From 3ffe6de0b9a4f49cf18f6a2ebf0fc2c6274dfccd Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Fri, 29 Nov 2024 10:40:08 +0100 Subject: [PATCH 01/15] test_runner/performance: add logical message ingest benchmark (#9749) Adds a benchmark for logical message WAL ingestion throughput end-to-end. Logical messages are essentially noops, and thus ignored by the Pageserver. Example results from my MacBook, with fsync enabled: ``` postgres_ingest: 14.445 s safekeeper_ingest: 29.948 s pageserver_ingest: 30.013 s pageserver_recover_ingest: 8.633 s wal_written: 10,340 MB message_count: 1310720 messages postgres_throughput: 715 MB/s safekeeper_throughput: 345 MB/s pageserver_throughput: 344 MB/s pageserver_recover_throughput: 1197 MB/s ``` See https://github.com/neondatabase/neon/issues/9642#issuecomment-2475995205 for running analysis. Touches #9642. --- test_runner/fixtures/neon_fixtures.py | 31 ++++++ .../test_ingest_logical_message.py | 101 ++++++++++++++++++ 2 files changed, 132 insertions(+) create mode 100644 test_runner/performance/test_ingest_logical_message.py diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 1f4d2aa5ec..e3c88e9965 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -4404,6 +4404,10 @@ class Safekeeper(LogUtils): log.info(f"sk {self.id} flush LSN: {flush_lsn}") return flush_lsn + def get_commit_lsn(self, tenant_id: TenantId, timeline_id: TimelineId) -> Lsn: + timeline_status = self.http_client().timeline_status(tenant_id, timeline_id) + return timeline_status.commit_lsn + def pull_timeline( self, srcs: list[Safekeeper], tenant_id: TenantId, timeline_id: TimelineId ) -> dict[str, Any]: @@ -4949,6 +4953,33 @@ def wait_for_last_flush_lsn( return min(results) +def wait_for_commit_lsn( + env: NeonEnv, + tenant: TenantId, + timeline: TimelineId, + lsn: Lsn, +) -> Lsn: + # TODO: it would be better to poll this in the compute, but there's no API for it. 
See: + # https://github.com/neondatabase/neon/issues/9758 + "Wait for the given LSN to be committed on any Safekeeper" + + max_commit_lsn = Lsn(0) + for i in range(1000): + for sk in env.safekeepers: + commit_lsn = sk.get_commit_lsn(tenant, timeline) + if commit_lsn >= lsn: + log.info(f"{tenant}/{timeline} at commit_lsn {commit_lsn}") + return commit_lsn + max_commit_lsn = max(max_commit_lsn, commit_lsn) + + if i % 10 == 0: + log.info( + f"{tenant}/{timeline} waiting for commit_lsn to reach {lsn}, now {max_commit_lsn}" + ) + time.sleep(0.1) + raise Exception(f"timed out while waiting for commit_lsn to reach {lsn}, was {max_commit_lsn}") + + def flush_ep_to_pageserver( env: NeonEnv, ep: Endpoint, diff --git a/test_runner/performance/test_ingest_logical_message.py b/test_runner/performance/test_ingest_logical_message.py new file mode 100644 index 0000000000..d3118eb15a --- /dev/null +++ b/test_runner/performance/test_ingest_logical_message.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import pytest +from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker +from fixtures.common_types import Lsn +from fixtures.log_helper import log +from fixtures.neon_fixtures import ( + NeonEnvBuilder, + wait_for_commit_lsn, + wait_for_last_flush_lsn, +) +from fixtures.pageserver.utils import wait_for_last_record_lsn + + +@pytest.mark.timeout(600) +@pytest.mark.parametrize("size", [1024, 8192, 131072]) +@pytest.mark.parametrize("fsync", [True, False], ids=["fsync", "nofsync"]) +def test_ingest_logical_message( + request: pytest.FixtureRequest, + neon_env_builder: NeonEnvBuilder, + zenbenchmark: NeonBenchmarker, + fsync: bool, + size: int, +): + """ + Benchmarks ingestion of 10 GB of logical message WAL. These are essentially noops, and don't + incur any pageserver writes. 
+ """ + + VOLUME = 10 * 1024**3 + count = VOLUME // size + + neon_env_builder.safekeepers_enable_fsync = fsync + + env = neon_env_builder.init_start() + endpoint = env.endpoints.create_start( + "main", + config_lines=[ + f"fsync = {fsync}", + # Disable backpressure. We don't want to block on pageserver. + "max_replication_apply_lag = 0", + "max_replication_flush_lag = 0", + "max_replication_write_lag = 0", + ], + ) + client = env.pageserver.http_client() + + # Wait for the timeline to be propagated to the pageserver. + wait_for_last_flush_lsn(env, endpoint, env.initial_tenant, env.initial_timeline) + + # Ingest data and measure durations. + start_lsn = Lsn(endpoint.safe_psql("select pg_current_wal_lsn()")[0][0]) + + with endpoint.cursor() as cur: + cur.execute("set statement_timeout = 0") + + # Postgres will return once the logical messages have been written to its local WAL, without + # waiting for Safekeeper commit. We measure ingestion time both for Postgres, Safekeeper, + # and Pageserver to detect bottlenecks. + log.info("Ingesting data") + with zenbenchmark.record_duration("pageserver_ingest"): + with zenbenchmark.record_duration("safekeeper_ingest"): + with zenbenchmark.record_duration("postgres_ingest"): + cur.execute(f""" + select pg_logical_emit_message(false, '', repeat('x', {size})) + from generate_series(1, {count}) + """) + + end_lsn = Lsn(endpoint.safe_psql("select pg_current_wal_lsn()")[0][0]) + + # Wait for Safekeeper. + log.info("Waiting for Safekeeper to catch up") + wait_for_commit_lsn(env, env.initial_tenant, env.initial_timeline, end_lsn) + + # Wait for Pageserver. + log.info("Waiting for Pageserver to catch up") + wait_for_last_record_lsn(client, env.initial_tenant, env.initial_timeline, end_lsn) + + # Now that all data is ingested, delete and recreate the tenant in the pageserver. This will + # reingest all the WAL from the safekeeper without any other constraints. 
This gives us a + # baseline of how fast the pageserver can ingest this WAL in isolation. + status = env.storage_controller.inspect(tenant_shard_id=env.initial_tenant) + assert status is not None + + client.tenant_delete(env.initial_tenant) + env.pageserver.tenant_create(tenant_id=env.initial_tenant, generation=status[0]) + + with zenbenchmark.record_duration("pageserver_recover_ingest"): + log.info("Recovering WAL into pageserver") + client.timeline_create(env.pg_version, env.initial_tenant, env.initial_timeline) + wait_for_last_flush_lsn(env, endpoint, env.initial_tenant, env.initial_timeline) + + # Emit metrics. + wal_written_mb = round((end_lsn - start_lsn) / (1024 * 1024)) + zenbenchmark.record("wal_written", wal_written_mb, "MB", MetricReport.TEST_PARAM) + zenbenchmark.record("message_count", count, "messages", MetricReport.TEST_PARAM) + + props = {p["name"]: p["value"] for _, p in request.node.user_properties} + for name in ("postgres", "safekeeper", "pageserver", "pageserver_recover"): + throughput = int(wal_written_mb / props[f"{name}_ingest"]) + zenbenchmark.record(f"{name}_throughput", throughput, "MB/s", MetricReport.HIGHER_IS_BETTER) From 1d642d6a57dd1cd1645a34aba5a2dd6e06a6c651 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Fri, 29 Nov 2024 11:08:01 +0000 Subject: [PATCH 02/15] chore(proxy): vendor a subset of rust-postgres (#9930) Our rust-postgres fork is getting messy. Mostly because proxy wants more control over the raw protocol than tokio-postgres provides. As such, it's diverging more and more. Storage and compute also make use of rust-postgres, but in more normal usage, thus they don't need our crazy changes. Idea: * proxy maintains their subset * other teams use a minimal patch set against upstream rust-postgres Reviewing this code will be difficult. To implement it, I 1. Copied tokio-postgres, postgres-protocol and postgres-types from https://github.com/neondatabase/rust-postgres/tree/00940fcdb57a8e99e805297b75839e7c4c7b1796 2. 
Updated their package names with the `2` suffix to make them compile in the workspace. 3. Updated proxy to use those packages 4. Copied in the code from tokio-postgres-rustls 0.13 (with some patches applied https://github.com/jbg/tokio-postgres-rustls/pull/32 https://github.com/jbg/tokio-postgres-rustls/pull/33) 5. Removed as much dead code as I could find in the vendored libraries 6. Updated the tokio-postgres-rustls code to use our existing channel binding implementation --- .config/hakari.toml | 3 + Cargo.lock | 56 +- Cargo.toml | 3 + libs/proxy/README.md | 6 + libs/proxy/postgres-protocol2/Cargo.toml | 21 + .../src/authentication/mod.rs | 37 + .../src/authentication/sasl.rs | 516 +++++ .../postgres-protocol2/src/escape/mod.rs | 93 + .../postgres-protocol2/src/escape/test.rs | 17 + libs/proxy/postgres-protocol2/src/lib.rs | 78 + .../postgres-protocol2/src/message/backend.rs | 766 ++++++++ .../src/message/frontend.rs | 297 +++ .../postgres-protocol2/src/message/mod.rs | 8 + .../postgres-protocol2/src/password/mod.rs | 107 ++ .../postgres-protocol2/src/password/test.rs | 19 + .../proxy/postgres-protocol2/src/types/mod.rs | 294 +++ .../postgres-protocol2/src/types/test.rs | 87 + libs/proxy/postgres-types2/Cargo.toml | 10 + libs/proxy/postgres-types2/src/lib.rs | 477 +++++ libs/proxy/postgres-types2/src/private.rs | 34 + libs/proxy/postgres-types2/src/type_gen.rs | 1524 +++++++++++++++ libs/proxy/tokio-postgres2/Cargo.toml | 21 + .../proxy/tokio-postgres2/src/cancel_query.rs | 40 + .../tokio-postgres2/src/cancel_query_raw.rs | 29 + .../proxy/tokio-postgres2/src/cancel_token.rs | 62 + libs/proxy/tokio-postgres2/src/client.rs | 439 +++++ libs/proxy/tokio-postgres2/src/codec.rs | 109 ++ libs/proxy/tokio-postgres2/src/config.rs | 897 +++++++++ libs/proxy/tokio-postgres2/src/connect.rs | 112 ++ libs/proxy/tokio-postgres2/src/connect_raw.rs | 359 ++++ .../tokio-postgres2/src/connect_socket.rs | 65 + libs/proxy/tokio-postgres2/src/connect_tls.rs | 48 + 
libs/proxy/tokio-postgres2/src/connection.rs | 323 ++++ libs/proxy/tokio-postgres2/src/error/mod.rs | 501 +++++ .../tokio-postgres2/src/error/sqlstate.rs | 1670 +++++++++++++++++ .../tokio-postgres2/src/generic_client.rs | 64 + libs/proxy/tokio-postgres2/src/lib.rs | 148 ++ .../tokio-postgres2/src/maybe_tls_stream.rs | 77 + libs/proxy/tokio-postgres2/src/prepare.rs | 262 +++ libs/proxy/tokio-postgres2/src/query.rs | 340 ++++ libs/proxy/tokio-postgres2/src/row.rs | 300 +++ .../proxy/tokio-postgres2/src/simple_query.rs | 142 ++ libs/proxy/tokio-postgres2/src/statement.rs | 157 ++ libs/proxy/tokio-postgres2/src/tls.rs | 162 ++ .../proxy/tokio-postgres2/src/to_statement.rs | 57 + libs/proxy/tokio-postgres2/src/transaction.rs | 74 + .../src/transaction_builder.rs | 113 ++ libs/proxy/tokio-postgres2/src/types.rs | 6 + proxy/Cargo.toml | 6 +- proxy/src/compute.rs | 5 +- proxy/src/context/mod.rs | 1 + proxy/src/lib.rs | 1 + proxy/src/postgres_rustls/mod.rs | 158 ++ proxy/src/proxy/tests/mod.rs | 2 +- proxy/src/serverless/backend.rs | 2 +- proxy/src/serverless/conn_pool.rs | 5 +- proxy/src/serverless/local_conn_pool.rs | 11 +- workspace_hack/Cargo.toml | 4 +- 58 files changed, 11199 insertions(+), 26 deletions(-) create mode 100644 libs/proxy/README.md create mode 100644 libs/proxy/postgres-protocol2/Cargo.toml create mode 100644 libs/proxy/postgres-protocol2/src/authentication/mod.rs create mode 100644 libs/proxy/postgres-protocol2/src/authentication/sasl.rs create mode 100644 libs/proxy/postgres-protocol2/src/escape/mod.rs create mode 100644 libs/proxy/postgres-protocol2/src/escape/test.rs create mode 100644 libs/proxy/postgres-protocol2/src/lib.rs create mode 100644 libs/proxy/postgres-protocol2/src/message/backend.rs create mode 100644 libs/proxy/postgres-protocol2/src/message/frontend.rs create mode 100644 libs/proxy/postgres-protocol2/src/message/mod.rs create mode 100644 libs/proxy/postgres-protocol2/src/password/mod.rs create mode 100644 
libs/proxy/postgres-protocol2/src/password/test.rs create mode 100644 libs/proxy/postgres-protocol2/src/types/mod.rs create mode 100644 libs/proxy/postgres-protocol2/src/types/test.rs create mode 100644 libs/proxy/postgres-types2/Cargo.toml create mode 100644 libs/proxy/postgres-types2/src/lib.rs create mode 100644 libs/proxy/postgres-types2/src/private.rs create mode 100644 libs/proxy/postgres-types2/src/type_gen.rs create mode 100644 libs/proxy/tokio-postgres2/Cargo.toml create mode 100644 libs/proxy/tokio-postgres2/src/cancel_query.rs create mode 100644 libs/proxy/tokio-postgres2/src/cancel_query_raw.rs create mode 100644 libs/proxy/tokio-postgres2/src/cancel_token.rs create mode 100644 libs/proxy/tokio-postgres2/src/client.rs create mode 100644 libs/proxy/tokio-postgres2/src/codec.rs create mode 100644 libs/proxy/tokio-postgres2/src/config.rs create mode 100644 libs/proxy/tokio-postgres2/src/connect.rs create mode 100644 libs/proxy/tokio-postgres2/src/connect_raw.rs create mode 100644 libs/proxy/tokio-postgres2/src/connect_socket.rs create mode 100644 libs/proxy/tokio-postgres2/src/connect_tls.rs create mode 100644 libs/proxy/tokio-postgres2/src/connection.rs create mode 100644 libs/proxy/tokio-postgres2/src/error/mod.rs create mode 100644 libs/proxy/tokio-postgres2/src/error/sqlstate.rs create mode 100644 libs/proxy/tokio-postgres2/src/generic_client.rs create mode 100644 libs/proxy/tokio-postgres2/src/lib.rs create mode 100644 libs/proxy/tokio-postgres2/src/maybe_tls_stream.rs create mode 100644 libs/proxy/tokio-postgres2/src/prepare.rs create mode 100644 libs/proxy/tokio-postgres2/src/query.rs create mode 100644 libs/proxy/tokio-postgres2/src/row.rs create mode 100644 libs/proxy/tokio-postgres2/src/simple_query.rs create mode 100644 libs/proxy/tokio-postgres2/src/statement.rs create mode 100644 libs/proxy/tokio-postgres2/src/tls.rs create mode 100644 libs/proxy/tokio-postgres2/src/to_statement.rs create mode 100644 
libs/proxy/tokio-postgres2/src/transaction.rs create mode 100644 libs/proxy/tokio-postgres2/src/transaction_builder.rs create mode 100644 libs/proxy/tokio-postgres2/src/types.rs create mode 100644 proxy/src/postgres_rustls/mod.rs diff --git a/.config/hakari.toml b/.config/hakari.toml index b5990d090e..3b6d9d8822 100644 --- a/.config/hakari.toml +++ b/.config/hakari.toml @@ -46,6 +46,9 @@ workspace-members = [ "utils", "wal_craft", "walproposer", + "postgres-protocol2", + "postgres-types2", + "tokio-postgres2", ] # Write out exact versions rather than a semver range. (Defaults to false.) diff --git a/Cargo.lock b/Cargo.lock index 43a46fb1eb..f05c6311dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4162,6 +4162,23 @@ dependencies = [ "tokio", ] +[[package]] +name = "postgres-protocol2" +version = "0.1.0" +dependencies = [ + "base64 0.20.0", + "byteorder", + "bytes", + "fallible-iterator", + "hmac", + "md-5", + "memchr", + "rand 0.8.5", + "sha2", + "stringprep", + "tokio", +] + [[package]] name = "postgres-types" version = "0.2.4" @@ -4170,8 +4187,15 @@ dependencies = [ "bytes", "fallible-iterator", "postgres-protocol", - "serde", - "serde_json", +] + +[[package]] +name = "postgres-types2" +version = "0.1.0" +dependencies = [ + "bytes", + "fallible-iterator", + "postgres-protocol2", ] [[package]] @@ -4501,7 +4525,7 @@ dependencies = [ "parquet_derive", "pbkdf2", "pin-project-lite", - "postgres-protocol", + "postgres-protocol2", "postgres_backend", "pq_proto", "prometheus", @@ -4536,8 +4560,7 @@ dependencies = [ "tikv-jemalloc-ctl", "tikv-jemallocator", "tokio", - "tokio-postgres", - "tokio-postgres-rustls", + "tokio-postgres2", "tokio-rustls 0.26.0", "tokio-tungstenite", "tokio-util", @@ -6421,6 +6444,7 @@ dependencies = [ "libc", "mio", "num_cpus", + "parking_lot 0.12.1", "pin-project-lite", "signal-hook-registry", "socket2", @@ -6502,6 +6526,26 @@ dependencies = [ "x509-certificate", ] +[[package]] +name = "tokio-postgres2" +version = "0.1.0" +dependencies = [ + 
"async-trait", + "byteorder", + "bytes", + "fallible-iterator", + "futures-util", + "log", + "parking_lot 0.12.1", + "percent-encoding", + "phf", + "pin-project-lite", + "postgres-protocol2", + "postgres-types2", + "tokio", + "tokio-util", +] + [[package]] name = "tokio-rustls" version = "0.24.0" @@ -7597,7 +7641,6 @@ dependencies = [ "num-traits", "once_cell", "parquet", - "postgres-types", "prettyplease", "proc-macro2", "prost", @@ -7622,7 +7665,6 @@ dependencies = [ "time", "time-macros", "tokio", - "tokio-postgres", "tokio-rustls 0.26.0", "tokio-stream", "tokio-util", diff --git a/Cargo.toml b/Cargo.toml index e3dc5b97f8..742201d0f5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,9 @@ members = [ "libs/walproposer", "libs/wal_decoder", "libs/postgres_initdb", + "libs/proxy/postgres-protocol2", + "libs/proxy/postgres-types2", + "libs/proxy/tokio-postgres2", ] [workspace.package] diff --git a/libs/proxy/README.md b/libs/proxy/README.md new file mode 100644 index 0000000000..2ae6210e46 --- /dev/null +++ b/libs/proxy/README.md @@ -0,0 +1,6 @@ +This directory contains libraries that are specific to proxy. + +Currently, it contains a significant fork/refactoring of rust-postgres that no longer reflects the API +of the original library. Since it was so significant, it made sense to upgrade it to its own set of libraries. + +Proxy needs unique access to the protocol, which explains why such heavy modifications were necessary. 
diff --git a/libs/proxy/postgres-protocol2/Cargo.toml b/libs/proxy/postgres-protocol2/Cargo.toml new file mode 100644 index 0000000000..284a632954 --- /dev/null +++ b/libs/proxy/postgres-protocol2/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "postgres-protocol2" +version = "0.1.0" +edition = "2018" +license = "MIT/Apache-2.0" + +[dependencies] +base64 = "0.20" +byteorder.workspace = true +bytes.workspace = true +fallible-iterator.workspace = true +hmac.workspace = true +md-5 = "0.10" +memchr = "2.0" +rand.workspace = true +sha2.workspace = true +stringprep = "0.1" +tokio = { workspace = true, features = ["rt"] } + +[dev-dependencies] +tokio = { workspace = true, features = ["full"] } diff --git a/libs/proxy/postgres-protocol2/src/authentication/mod.rs b/libs/proxy/postgres-protocol2/src/authentication/mod.rs new file mode 100644 index 0000000000..71afa4b9b6 --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/authentication/mod.rs @@ -0,0 +1,37 @@ +//! Authentication protocol support. +use md5::{Digest, Md5}; + +pub mod sasl; + +/// Hashes authentication information in a way suitable for use in response +/// to an `AuthenticationMd5Password` message. +/// +/// The resulting string should be sent back to the database in a +/// `PasswordMessage` message. 
+#[inline] +pub fn md5_hash(username: &[u8], password: &[u8], salt: [u8; 4]) -> String { + let mut md5 = Md5::new(); + md5.update(password); + md5.update(username); + let output = md5.finalize_reset(); + md5.update(format!("{:x}", output)); + md5.update(salt); + format!("md5{:x}", md5.finalize()) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn md5() { + let username = b"md5_user"; + let password = b"password"; + let salt = [0x2a, 0x3d, 0x8f, 0xe0]; + + assert_eq!( + md5_hash(username, password, salt), + "md562af4dd09bbb41884907a838a3233294" + ); + } +} diff --git a/libs/proxy/postgres-protocol2/src/authentication/sasl.rs b/libs/proxy/postgres-protocol2/src/authentication/sasl.rs new file mode 100644 index 0000000000..19aa3c1e9a --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/authentication/sasl.rs @@ -0,0 +1,516 @@ +//! SASL-based authentication support. + +use hmac::{Hmac, Mac}; +use rand::{self, Rng}; +use sha2::digest::FixedOutput; +use sha2::{Digest, Sha256}; +use std::fmt::Write; +use std::io; +use std::iter; +use std::mem; +use std::str; +use tokio::task::yield_now; + +const NONCE_LENGTH: usize = 24; + +/// The identifier of the SCRAM-SHA-256 SASL authentication mechanism. +pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256"; +/// The identifier of the SCRAM-SHA-256-PLUS SASL authentication mechanism. +pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS"; + +// since postgres passwords are not required to exclude saslprep-prohibited +// characters or even be valid UTF8, we run saslprep if possible and otherwise +// return the raw password. 
+fn normalize(pass: &[u8]) -> Vec { + let pass = match str::from_utf8(pass) { + Ok(pass) => pass, + Err(_) => return pass.to_vec(), + }; + + match stringprep::saslprep(pass) { + Ok(pass) => pass.into_owned().into_bytes(), + Err(_) => pass.as_bytes().to_vec(), + } +} + +pub(crate) async fn hi(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] { + let mut hmac = + Hmac::::new_from_slice(str).expect("HMAC is able to accept all key sizes"); + hmac.update(salt); + hmac.update(&[0, 0, 0, 1]); + let mut prev = hmac.finalize().into_bytes(); + + let mut hi = prev; + + for i in 1..iterations { + let mut hmac = Hmac::::new_from_slice(str).expect("already checked above"); + hmac.update(&prev); + prev = hmac.finalize().into_bytes(); + + for (hi, prev) in hi.iter_mut().zip(prev) { + *hi ^= prev; + } + // yield every ~250us + // hopefully reduces tail latencies + if i % 1024 == 0 { + yield_now().await + } + } + + hi.into() +} + +enum ChannelBindingInner { + Unrequested, + Unsupported, + TlsServerEndPoint(Vec), +} + +/// The channel binding configuration for a SCRAM authentication exchange. +pub struct ChannelBinding(ChannelBindingInner); + +impl ChannelBinding { + /// The server did not request channel binding. + pub fn unrequested() -> ChannelBinding { + ChannelBinding(ChannelBindingInner::Unrequested) + } + + /// The server requested channel binding but the client is unable to provide it. + pub fn unsupported() -> ChannelBinding { + ChannelBinding(ChannelBindingInner::Unsupported) + } + + /// The server requested channel binding and the client will use the `tls-server-end-point` + /// method. 
+ pub fn tls_server_end_point(signature: Vec) -> ChannelBinding { + ChannelBinding(ChannelBindingInner::TlsServerEndPoint(signature)) + } + + fn gs2_header(&self) -> &'static str { + match self.0 { + ChannelBindingInner::Unrequested => "y,,", + ChannelBindingInner::Unsupported => "n,,", + ChannelBindingInner::TlsServerEndPoint(_) => "p=tls-server-end-point,,", + } + } + + fn cbind_data(&self) -> &[u8] { + match self.0 { + ChannelBindingInner::Unrequested | ChannelBindingInner::Unsupported => &[], + ChannelBindingInner::TlsServerEndPoint(ref buf) => buf, + } + } +} + +/// A pair of keys for the SCRAM-SHA-256 mechanism. +/// See for details. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ScramKeys { + /// Used by server to authenticate client. + pub client_key: [u8; N], + /// Used by client to verify server's signature. + pub server_key: [u8; N], +} + +/// Password or keys which were derived from it. +enum Credentials { + /// A regular password as a vector of bytes. + Password(Vec), + /// A precomputed pair of keys. + Keys(Box>), +} + +enum State { + Update { + nonce: String, + password: Credentials<32>, + channel_binding: ChannelBinding, + }, + Finish { + server_key: [u8; 32], + auth_message: String, + }, + Done, +} + +/// A type which handles the client side of the SCRAM-SHA-256/SCRAM-SHA-256-PLUS authentication +/// process. +/// +/// During the authentication process, if the backend sends an `AuthenticationSASL` message which +/// includes `SCRAM-SHA-256` as an authentication mechanism, this type can be used. +/// +/// After a `ScramSha256` is constructed, the buffer returned by the `message()` method should be +/// sent to the backend in a `SASLInitialResponse` message along with the mechanism name. +/// +/// The server will reply with an `AuthenticationSASLContinue` message. 
Its contents should be +/// passed to the `update()` method, after which the buffer returned by the `message()` method +/// should be sent to the backend in a `SASLResponse` message. +/// +/// The server will reply with an `AuthenticationSASLFinal` message. Its contents should be passed +/// to the `finish()` method, after which the authentication process is complete. +pub struct ScramSha256 { + message: String, + state: State, +} + +fn nonce() -> String { + // rand 0.5's ThreadRng is cryptographically secure + let mut rng = rand::thread_rng(); + (0..NONCE_LENGTH) + .map(|_| { + let mut v = rng.gen_range(0x21u8..0x7e); + if v == 0x2c { + v = 0x7e + } + v as char + }) + .collect() +} + +impl ScramSha256 { + /// Constructs a new instance which will use the provided password for authentication. + pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 { + let password = Credentials::Password(normalize(password)); + ScramSha256::new_inner(password, channel_binding, nonce()) + } + + /// Constructs a new instance which will use the provided key pair for authentication. + pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 { + let password = Credentials::Keys(keys.into()); + ScramSha256::new_inner(password, channel_binding, nonce()) + } + + fn new_inner( + password: Credentials<32>, + channel_binding: ChannelBinding, + nonce: String, + ) -> ScramSha256 { + ScramSha256 { + message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce), + state: State::Update { + nonce, + password, + channel_binding, + }, + } + } + + /// Returns the message which should be sent to the backend in an `SASLResponse` message. + pub fn message(&self) -> &[u8] { + if let State::Done = self.state { + panic!("invalid SCRAM state"); + } + self.message.as_bytes() + } + + /// Updates the state machine with the response from the backend. + /// + /// This should be called when an `AuthenticationSASLContinue` message is received. 
+ pub async fn update(&mut self, message: &[u8]) -> io::Result<()> { + let (client_nonce, password, channel_binding) = + match mem::replace(&mut self.state, State::Done) { + State::Update { + nonce, + password, + channel_binding, + } => (nonce, password, channel_binding), + _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")), + }; + + let message = + str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; + + let parsed = Parser::new(message).server_first_message()?; + + if !parsed.nonce.starts_with(&client_nonce) { + return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce")); + } + + let (client_key, server_key) = match password { + Credentials::Password(password) => { + let salt = match base64::decode(parsed.salt) { + Ok(salt) => salt, + Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), + }; + + let salted_password = hi(&password, &salt, parsed.iteration_count).await; + + let make_key = |name| { + let mut hmac = Hmac::::new_from_slice(&salted_password) + .expect("HMAC is able to accept all key sizes"); + hmac.update(name); + + let mut key = [0u8; 32]; + key.copy_from_slice(hmac.finalize().into_bytes().as_slice()); + key + }; + + (make_key(b"Client Key"), make_key(b"Server Key")) + } + Credentials::Keys(keys) => (keys.client_key, keys.server_key), + }; + + let mut hash = Sha256::default(); + hash.update(client_key); + let stored_key = hash.finalize_fixed(); + + let mut cbind_input = vec![]; + cbind_input.extend(channel_binding.gs2_header().as_bytes()); + cbind_input.extend(channel_binding.cbind_data()); + let cbind_input = base64::encode(&cbind_input); + + self.message.clear(); + write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap(); + + let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message); + + let mut hmac = Hmac::::new_from_slice(&stored_key) + .expect("HMAC is able to accept all key sizes"); + 
hmac.update(auth_message.as_bytes()); + let client_signature = hmac.finalize().into_bytes(); + + let mut client_proof = client_key; + for (proof, signature) in client_proof.iter_mut().zip(client_signature) { + *proof ^= signature; + } + + write!(&mut self.message, ",p={}", base64::encode(client_proof)).unwrap(); + + self.state = State::Finish { + server_key, + auth_message, + }; + Ok(()) + } + + /// Finalizes the authentication process. + /// + /// This should be called when the backend sends an `AuthenticationSASLFinal` message. + /// Authentication has only succeeded if this method returns `Ok(())`. + pub fn finish(&mut self, message: &[u8]) -> io::Result<()> { + let (server_key, auth_message) = match mem::replace(&mut self.state, State::Done) { + State::Finish { + server_key, + auth_message, + } => (server_key, auth_message), + _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")), + }; + + let message = + str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; + + let parsed = Parser::new(message).server_final_message()?; + + let verifier = match parsed { + ServerFinalMessage::Error(e) => { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("SCRAM error: {}", e), + )); + } + ServerFinalMessage::Verifier(verifier) => verifier, + }; + + let verifier = match base64::decode(verifier) { + Ok(verifier) => verifier, + Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), + }; + + let mut hmac = Hmac::::new_from_slice(&server_key) + .expect("HMAC is able to accept all key sizes"); + hmac.update(auth_message.as_bytes()); + hmac.verify_slice(&verifier) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error")) + } +} + +struct Parser<'a> { + s: &'a str, + it: iter::Peekable>, +} + +impl<'a> Parser<'a> { + fn new(s: &'a str) -> Parser<'a> { + Parser { + s, + it: s.char_indices().peekable(), + } + } + + fn eat(&mut self, target: char) -> io::Result<()> { + match 
self.it.next() { + Some((_, c)) if c == target => Ok(()), + Some((i, c)) => { + let m = format!( + "unexpected character at byte {}: expected `{}` but got `{}", + i, target, c + ); + Err(io::Error::new(io::ErrorKind::InvalidInput, m)) + } + None => Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "unexpected EOF", + )), + } + } + + fn take_while(&mut self, f: F) -> io::Result<&'a str> + where + F: Fn(char) -> bool, + { + let start = match self.it.peek() { + Some(&(i, _)) => i, + None => return Ok(""), + }; + + loop { + match self.it.peek() { + Some(&(_, c)) if f(c) => { + self.it.next(); + } + Some(&(i, _)) => return Ok(&self.s[start..i]), + None => return Ok(&self.s[start..]), + } + } + } + + fn printable(&mut self) -> io::Result<&'a str> { + self.take_while(|c| matches!(c, '\x21'..='\x2b' | '\x2d'..='\x7e')) + } + + fn nonce(&mut self) -> io::Result<&'a str> { + self.eat('r')?; + self.eat('=')?; + self.printable() + } + + fn base64(&mut self) -> io::Result<&'a str> { + self.take_while(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '/' | '+' | '=')) + } + + fn salt(&mut self) -> io::Result<&'a str> { + self.eat('s')?; + self.eat('=')?; + self.base64() + } + + fn posit_number(&mut self) -> io::Result { + let n = self.take_while(|c| c.is_ascii_digit())?; + n.parse() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e)) + } + + fn iteration_count(&mut self) -> io::Result { + self.eat('i')?; + self.eat('=')?; + self.posit_number() + } + + fn eof(&mut self) -> io::Result<()> { + match self.it.peek() { + Some(&(i, _)) => Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unexpected trailing data at byte {}", i), + )), + None => Ok(()), + } + } + + fn server_first_message(&mut self) -> io::Result> { + let nonce = self.nonce()?; + self.eat(',')?; + let salt = self.salt()?; + self.eat(',')?; + let iteration_count = self.iteration_count()?; + self.eof()?; + + Ok(ServerFirstMessage { + nonce, + salt, + iteration_count, + }) + } + + fn value(&mut 
self) -> io::Result<&'a str> { + self.take_while(|c| matches!(c, '\0' | '=' | ',')) + } + + fn server_error(&mut self) -> io::Result> { + match self.it.peek() { + Some(&(_, 'e')) => {} + _ => return Ok(None), + } + + self.eat('e')?; + self.eat('=')?; + self.value().map(Some) + } + + fn verifier(&mut self) -> io::Result<&'a str> { + self.eat('v')?; + self.eat('=')?; + self.base64() + } + + fn server_final_message(&mut self) -> io::Result> { + let message = match self.server_error()? { + Some(error) => ServerFinalMessage::Error(error), + None => ServerFinalMessage::Verifier(self.verifier()?), + }; + self.eof()?; + Ok(message) + } +} + +struct ServerFirstMessage<'a> { + nonce: &'a str, + salt: &'a str, + iteration_count: u32, +} + +enum ServerFinalMessage<'a> { + Error(&'a str), + Verifier(&'a str), +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn parse_server_first_message() { + let message = "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096"; + let message = Parser::new(message).server_first_message().unwrap(); + assert_eq!(message.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j"); + assert_eq!(message.salt, "QSXCR+Q6sek8bf92"); + assert_eq!(message.iteration_count, 4096); + } + + // recorded auth exchange from psql + #[tokio::test] + async fn exchange() { + let password = "foobar"; + let nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB"; + + let client_first = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB"; + let server_first = + "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i\ + =4096"; + let client_final = + "c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS3\ + 1NTlQYNs5BTeQjdHdk7lOflDo5re2an8="; + let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw="; + + let mut scram = ScramSha256::new_inner( + Credentials::Password(normalize(password.as_bytes())), + ChannelBinding::unsupported(), + nonce.to_string(), + ); + assert_eq!(str::from_utf8(scram.message()).unwrap(), client_first); + + 
scram.update(server_first.as_bytes()).await.unwrap(); + assert_eq!(str::from_utf8(scram.message()).unwrap(), client_final); + + scram.finish(server_final.as_bytes()).unwrap(); + } +} diff --git a/libs/proxy/postgres-protocol2/src/escape/mod.rs b/libs/proxy/postgres-protocol2/src/escape/mod.rs new file mode 100644 index 0000000000..0ba7efdcac --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/escape/mod.rs @@ -0,0 +1,93 @@ +//! Provides functions for escaping literals and identifiers for use +//! in SQL queries. +//! +//! Prefer parameterized queries where possible. Do not escape +//! parameters in a parameterized query. + +#[cfg(test)] +mod test; + +/// Escape a literal and surround result with single quotes. Not +/// recommended in most cases. +/// +/// If input contains backslashes, result will be of the form ` +/// E'...'` so it is safe to use regardless of the setting of +/// standard_conforming_strings. +pub fn escape_literal(input: &str) -> String { + escape_internal(input, false) +} + +/// Escape an identifier and surround result with double quotes. +pub fn escape_identifier(input: &str) -> String { + escape_internal(input, true) +} + +// Translation of PostgreSQL libpq's PQescapeInternal(). Does not +// require a connection because input string is known to be valid +// UTF-8. +// +// Escape arbitrary strings. If as_ident is true, we escape the +// result as an identifier; if false, as a literal. The result is +// returned in a newly allocated String. Unlike the C original, +// this cannot fail: the input is already valid UTF-8, so there is +// no encoding violation to report. +fn escape_internal(input: &str, as_ident: bool) -> String { + let mut num_backslashes = 0; + let mut num_quotes = 0; + let quote_char = if as_ident { '"' } else { '\'' }; + + // Scan the string for characters that must be escaped.
+ for ch in input.chars() { + if ch == quote_char { + num_quotes += 1; + } else if ch == '\\' { + num_backslashes += 1; + } + } + + // Allocate output String. + let mut result_size = input.len() + num_quotes + 3; // two quotes, plus a NUL + if !as_ident && num_backslashes > 0 { + result_size += num_backslashes + 2; + } + + let mut output = String::with_capacity(result_size); + + // If we are escaping a literal that contains backslashes, we use + // the escape string syntax so that the result is correct under + // either value of standard_conforming_strings. We also emit a + // leading space in this case, to guard against the possibility + // that the result might be interpolated immediately following an + // identifier. + if !as_ident && num_backslashes > 0 { + output.push(' '); + output.push('E'); + } + + // Opening quote. + output.push(quote_char); + + // Use fast path if possible. + // + // We've already verified that the input string is well-formed in + // the current encoding. If it contains no quotes and, in the + // case of literal-escaping, no backslashes, then we can just copy + // it directly to the output buffer, adding the necessary quotes. + // + // If not, we must rescan the input and process each character + // individually. 
+ if num_quotes == 0 && (num_backslashes == 0 || as_ident) { + output.push_str(input); + } else { + for ch in input.chars() { + if ch == quote_char || (!as_ident && ch == '\\') { + output.push(ch); + } + output.push(ch); + } + } + + output.push(quote_char); + + output +} diff --git a/libs/proxy/postgres-protocol2/src/escape/test.rs b/libs/proxy/postgres-protocol2/src/escape/test.rs new file mode 100644 index 0000000000..4816a103b7 --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/escape/test.rs @@ -0,0 +1,17 @@ +use crate::escape::{escape_identifier, escape_literal}; + +#[test] +fn test_escape_idenifier() { + assert_eq!(escape_identifier("foo"), String::from("\"foo\"")); + assert_eq!(escape_identifier("f\\oo"), String::from("\"f\\oo\"")); + assert_eq!(escape_identifier("f'oo"), String::from("\"f'oo\"")); + assert_eq!(escape_identifier("f\"oo"), String::from("\"f\"\"oo\"")); +} + +#[test] +fn test_escape_literal() { + assert_eq!(escape_literal("foo"), String::from("'foo'")); + assert_eq!(escape_literal("f\\oo"), String::from(" E'f\\\\oo'")); + assert_eq!(escape_literal("f'oo"), String::from("'f''oo'")); + assert_eq!(escape_literal("f\"oo"), String::from("'f\"oo'")); +} diff --git a/libs/proxy/postgres-protocol2/src/lib.rs b/libs/proxy/postgres-protocol2/src/lib.rs new file mode 100644 index 0000000000..947f2f835d --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/lib.rs @@ -0,0 +1,78 @@ +//! Low level Postgres protocol APIs. +//! +//! This crate implements the low level components of Postgres's communication +//! protocol, including message and value serialization and deserialization. +//! It is designed to be used as a building block by higher level APIs such as +//! `rust-postgres`, and should not typically be used directly. +//! +//! # Note +//! +//! This library assumes that the `client_encoding` backend parameter has been +//! set to `UTF8`. It will most likely not behave properly if that is not the case. 
+#![doc(html_root_url = "https://docs.rs/postgres-protocol/0.6")] +#![warn(missing_docs, rust_2018_idioms, clippy::all)] + +use byteorder::{BigEndian, ByteOrder}; +use bytes::{BufMut, BytesMut}; +use std::io; + +pub mod authentication; +pub mod escape; +pub mod message; +pub mod password; +pub mod types; + +/// A Postgres OID. +pub type Oid = u32; + +/// A Postgres Log Sequence Number (LSN). +pub type Lsn = u64; + +/// An enum indicating if a value is `NULL` or not. +pub enum IsNull { + /// The value is `NULL`. + Yes, + /// The value is not `NULL`. + No, +} + +fn write_nullable(serializer: F, buf: &mut BytesMut) -> Result<(), E> +where + F: FnOnce(&mut BytesMut) -> Result, + E: From, +{ + let base = buf.len(); + buf.put_i32(0); + let size = match serializer(buf)? { + IsNull::No => i32::from_usize(buf.len() - base - 4)?, + IsNull::Yes => -1, + }; + BigEndian::write_i32(&mut buf[base..], size); + + Ok(()) +} + +trait FromUsize: Sized { + fn from_usize(x: usize) -> Result; +} + +macro_rules! 
from_usize { + ($t:ty) => { + impl FromUsize for $t { + #[inline] + fn from_usize(x: usize) -> io::Result<$t> { + if x > <$t>::MAX as usize { + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "value too large to transmit", + )) + } else { + Ok(x as $t) + } + } + } + }; +} + +from_usize!(i16); +from_usize!(i32); diff --git a/libs/proxy/postgres-protocol2/src/message/backend.rs b/libs/proxy/postgres-protocol2/src/message/backend.rs new file mode 100644 index 0000000000..356d142f3f --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/message/backend.rs @@ -0,0 +1,766 @@ +#![allow(missing_docs)] + +use byteorder::{BigEndian, ByteOrder, ReadBytesExt}; +use bytes::{Bytes, BytesMut}; +use fallible_iterator::FallibleIterator; +use memchr::memchr; +use std::cmp; +use std::io::{self, Read}; +use std::ops::Range; +use std::str; + +use crate::Oid; + +// top-level message tags +const PARSE_COMPLETE_TAG: u8 = b'1'; +const BIND_COMPLETE_TAG: u8 = b'2'; +const CLOSE_COMPLETE_TAG: u8 = b'3'; +pub const NOTIFICATION_RESPONSE_TAG: u8 = b'A'; +const COPY_DONE_TAG: u8 = b'c'; +const COMMAND_COMPLETE_TAG: u8 = b'C'; +const COPY_DATA_TAG: u8 = b'd'; +const DATA_ROW_TAG: u8 = b'D'; +const ERROR_RESPONSE_TAG: u8 = b'E'; +const COPY_IN_RESPONSE_TAG: u8 = b'G'; +const COPY_OUT_RESPONSE_TAG: u8 = b'H'; +const COPY_BOTH_RESPONSE_TAG: u8 = b'W'; +const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I'; +const BACKEND_KEY_DATA_TAG: u8 = b'K'; +pub const NO_DATA_TAG: u8 = b'n'; +pub const NOTICE_RESPONSE_TAG: u8 = b'N'; +const AUTHENTICATION_TAG: u8 = b'R'; +const PORTAL_SUSPENDED_TAG: u8 = b's'; +pub const PARAMETER_STATUS_TAG: u8 = b'S'; +const PARAMETER_DESCRIPTION_TAG: u8 = b't'; +const ROW_DESCRIPTION_TAG: u8 = b'T'; +pub const READY_FOR_QUERY_TAG: u8 = b'Z'; + +#[derive(Debug, Copy, Clone)] +pub struct Header { + tag: u8, + len: i32, +} + +#[allow(clippy::len_without_is_empty)] +impl Header { + #[inline] + pub fn parse(buf: &[u8]) -> io::Result> { + if buf.len() < 5 { + return Ok(None); + } + + let 
tag = buf[0]; + let len = BigEndian::read_i32(&buf[1..]); + + if len < 4 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid message length: header length < 4", + )); + } + + Ok(Some(Header { tag, len })) + } + + #[inline] + pub fn tag(self) -> u8 { + self.tag + } + + #[inline] + pub fn len(self) -> i32 { + self.len + } +} + +/// An enum representing Postgres backend messages. +#[non_exhaustive] +pub enum Message { + AuthenticationCleartextPassword, + AuthenticationGss, + AuthenticationKerberosV5, + AuthenticationMd5Password(AuthenticationMd5PasswordBody), + AuthenticationOk, + AuthenticationScmCredential, + AuthenticationSspi, + AuthenticationGssContinue, + AuthenticationSasl(AuthenticationSaslBody), + AuthenticationSaslContinue(AuthenticationSaslContinueBody), + AuthenticationSaslFinal(AuthenticationSaslFinalBody), + BackendKeyData(BackendKeyDataBody), + BindComplete, + CloseComplete, + CommandComplete(CommandCompleteBody), + CopyData, + CopyDone, + CopyInResponse, + CopyOutResponse, + CopyBothResponse, + DataRow(DataRowBody), + EmptyQueryResponse, + ErrorResponse(ErrorResponseBody), + NoData, + NoticeResponse(NoticeResponseBody), + NotificationResponse(NotificationResponseBody), + ParameterDescription(ParameterDescriptionBody), + ParameterStatus(ParameterStatusBody), + ParseComplete, + PortalSuspended, + ReadyForQuery(ReadyForQueryBody), + RowDescription(RowDescriptionBody), +} + +impl Message { + #[inline] + pub fn parse(buf: &mut BytesMut) -> io::Result> { + if buf.len() < 5 { + let to_read = 5 - buf.len(); + buf.reserve(to_read); + return Ok(None); + } + + let tag = buf[0]; + let len = (&buf[1..5]).read_u32::().unwrap(); + + if len < 4 { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: parsing u32", + )); + } + + let total_len = len as usize + 1; + if buf.len() < total_len { + let to_read = total_len - buf.len(); + buf.reserve(to_read); + return Ok(None); + } + + let mut buf = Buffer { + bytes: 
buf.split_to(total_len).freeze(), + idx: 5, + }; + + let message = match tag { + PARSE_COMPLETE_TAG => Message::ParseComplete, + BIND_COMPLETE_TAG => Message::BindComplete, + CLOSE_COMPLETE_TAG => Message::CloseComplete, + NOTIFICATION_RESPONSE_TAG => { + let process_id = buf.read_i32::()?; + let channel = buf.read_cstr()?; + let message = buf.read_cstr()?; + Message::NotificationResponse(NotificationResponseBody { + process_id, + channel, + message, + }) + } + COPY_DONE_TAG => Message::CopyDone, + COMMAND_COMPLETE_TAG => { + let tag = buf.read_cstr()?; + Message::CommandComplete(CommandCompleteBody { tag }) + } + COPY_DATA_TAG => Message::CopyData, + DATA_ROW_TAG => { + let len = buf.read_u16::()?; + let storage = buf.read_all(); + Message::DataRow(DataRowBody { storage, len }) + } + ERROR_RESPONSE_TAG => { + let storage = buf.read_all(); + Message::ErrorResponse(ErrorResponseBody { storage }) + } + COPY_IN_RESPONSE_TAG => Message::CopyInResponse, + COPY_OUT_RESPONSE_TAG => Message::CopyOutResponse, + COPY_BOTH_RESPONSE_TAG => Message::CopyBothResponse, + EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse, + BACKEND_KEY_DATA_TAG => { + let process_id = buf.read_i32::()?; + let secret_key = buf.read_i32::()?; + Message::BackendKeyData(BackendKeyDataBody { + process_id, + secret_key, + }) + } + NO_DATA_TAG => Message::NoData, + NOTICE_RESPONSE_TAG => { + let storage = buf.read_all(); + Message::NoticeResponse(NoticeResponseBody { storage }) + } + AUTHENTICATION_TAG => match buf.read_i32::()? 
{ + 0 => Message::AuthenticationOk, + 2 => Message::AuthenticationKerberosV5, + 3 => Message::AuthenticationCleartextPassword, + 5 => { + let mut salt = [0; 4]; + buf.read_exact(&mut salt)?; + Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody { salt }) + } + 6 => Message::AuthenticationScmCredential, + 7 => Message::AuthenticationGss, + 8 => Message::AuthenticationGssContinue, + 9 => Message::AuthenticationSspi, + 10 => { + let storage = buf.read_all(); + Message::AuthenticationSasl(AuthenticationSaslBody(storage)) + } + 11 => { + let storage = buf.read_all(); + Message::AuthenticationSaslContinue(AuthenticationSaslContinueBody(storage)) + } + 12 => { + let storage = buf.read_all(); + Message::AuthenticationSaslFinal(AuthenticationSaslFinalBody(storage)) + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown authentication tag `{}`", tag), + )); + } + }, + PORTAL_SUSPENDED_TAG => Message::PortalSuspended, + PARAMETER_STATUS_TAG => { + let name = buf.read_cstr()?; + let value = buf.read_cstr()?; + Message::ParameterStatus(ParameterStatusBody { name, value }) + } + PARAMETER_DESCRIPTION_TAG => { + let len = buf.read_u16::()?; + let storage = buf.read_all(); + Message::ParameterDescription(ParameterDescriptionBody { storage, len }) + } + ROW_DESCRIPTION_TAG => { + let len = buf.read_u16::()?; + let storage = buf.read_all(); + Message::RowDescription(RowDescriptionBody { storage, len }) + } + READY_FOR_QUERY_TAG => { + let status = buf.read_u8()?; + Message::ReadyForQuery(ReadyForQueryBody { status }) + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown message tag `{}`", tag), + )); + } + }; + + if !buf.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: expected buffer to be empty", + )); + } + + Ok(Some(message)) + } +} + +struct Buffer { + bytes: Bytes, + idx: usize, +} + +impl Buffer { + #[inline] + fn slice(&self) -> &[u8] 
{ + &self.bytes[self.idx..] + } + + #[inline] + fn is_empty(&self) -> bool { + self.slice().is_empty() + } + + #[inline] + fn read_cstr(&mut self) -> io::Result { + match memchr(0, self.slice()) { + Some(pos) => { + let start = self.idx; + let end = start + pos; + let cstr = self.bytes.slice(start..end); + self.idx = end + 1; + Ok(cstr) + } + None => Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "unexpected EOF", + )), + } + } + + #[inline] + fn read_all(&mut self) -> Bytes { + let buf = self.bytes.slice(self.idx..); + self.idx = self.bytes.len(); + buf + } +} + +impl Read for Buffer { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let len = { + let slice = self.slice(); + let len = cmp::min(slice.len(), buf.len()); + buf[..len].copy_from_slice(&slice[..len]); + len + }; + self.idx += len; + Ok(len) + } +} + +pub struct AuthenticationMd5PasswordBody { + salt: [u8; 4], +} + +impl AuthenticationMd5PasswordBody { + #[inline] + pub fn salt(&self) -> [u8; 4] { + self.salt + } +} + +pub struct AuthenticationSaslBody(Bytes); + +impl AuthenticationSaslBody { + #[inline] + pub fn mechanisms(&self) -> SaslMechanisms<'_> { + SaslMechanisms(&self.0) + } +} + +pub struct SaslMechanisms<'a>(&'a [u8]); + +impl<'a> FallibleIterator for SaslMechanisms<'a> { + type Item = &'a str; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result> { + let value_end = find_null(self.0, 0)?; + if value_end == 0 { + if self.0.len() != 1 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid message length: expected to be at end of iterator for sasl", + )); + } + Ok(None) + } else { + let value = get_str(&self.0[..value_end])?; + self.0 = &self.0[value_end + 1..]; + Ok(Some(value)) + } + } +} + +pub struct AuthenticationSaslContinueBody(Bytes); + +impl AuthenticationSaslContinueBody { + #[inline] + pub fn data(&self) -> &[u8] { + &self.0 + } +} + +pub struct AuthenticationSaslFinalBody(Bytes); + +impl AuthenticationSaslFinalBody { + 
#[inline] + pub fn data(&self) -> &[u8] { + &self.0 + } +} + +pub struct BackendKeyDataBody { + process_id: i32, + secret_key: i32, +} + +impl BackendKeyDataBody { + #[inline] + pub fn process_id(&self) -> i32 { + self.process_id + } + + #[inline] + pub fn secret_key(&self) -> i32 { + self.secret_key + } +} + +pub struct CommandCompleteBody { + tag: Bytes, +} + +impl CommandCompleteBody { + #[inline] + pub fn tag(&self) -> io::Result<&str> { + get_str(&self.tag) + } +} + +#[derive(Debug)] +pub struct DataRowBody { + storage: Bytes, + len: u16, +} + +impl DataRowBody { + #[inline] + pub fn ranges(&self) -> DataRowRanges<'_> { + DataRowRanges { + buf: &self.storage, + len: self.storage.len(), + remaining: self.len, + } + } + + #[inline] + pub fn buffer(&self) -> &[u8] { + &self.storage + } +} + +pub struct DataRowRanges<'a> { + buf: &'a [u8], + len: usize, + remaining: u16, +} + +impl FallibleIterator for DataRowRanges<'_> { + type Item = Option>; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result>>> { + if self.remaining == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: datarowrange is not empty", + )); + } + } + + self.remaining -= 1; + let len = self.buf.read_i32::()?; + if len < 0 { + Ok(Some(None)) + } else { + let len = len as usize; + if self.buf.len() < len { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "unexpected EOF", + )); + } + let base = self.len - self.buf.len(); + self.buf = &self.buf[len..]; + Ok(Some(Some(base..base + len))) + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.remaining as usize; + (len, Some(len)) + } +} + +pub struct ErrorResponseBody { + storage: Bytes, +} + +impl ErrorResponseBody { + #[inline] + pub fn fields(&self) -> ErrorFields<'_> { + ErrorFields { buf: &self.storage } + } +} + +pub struct ErrorFields<'a> { + buf: &'a [u8], +} + +impl<'a> FallibleIterator 
for ErrorFields<'a> { + type Item = ErrorField<'a>; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result>> { + let type_ = self.buf.read_u8()?; + if type_ == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: error fields is not drained", + )); + } + } + + let value_end = find_null(self.buf, 0)?; + let value = get_str(&self.buf[..value_end])?; + self.buf = &self.buf[value_end + 1..]; + + Ok(Some(ErrorField { type_, value })) + } +} + +pub struct ErrorField<'a> { + type_: u8, + value: &'a str, +} + +impl ErrorField<'_> { + #[inline] + pub fn type_(&self) -> u8 { + self.type_ + } + + #[inline] + pub fn value(&self) -> &str { + self.value + } +} + +pub struct NoticeResponseBody { + storage: Bytes, +} + +impl NoticeResponseBody { + #[inline] + pub fn fields(&self) -> ErrorFields<'_> { + ErrorFields { buf: &self.storage } + } +} + +pub struct NotificationResponseBody { + process_id: i32, + channel: Bytes, + message: Bytes, +} + +impl NotificationResponseBody { + #[inline] + pub fn process_id(&self) -> i32 { + self.process_id + } + + #[inline] + pub fn channel(&self) -> io::Result<&str> { + get_str(&self.channel) + } + + #[inline] + pub fn message(&self) -> io::Result<&str> { + get_str(&self.message) + } +} + +pub struct ParameterDescriptionBody { + storage: Bytes, + len: u16, +} + +impl ParameterDescriptionBody { + #[inline] + pub fn parameters(&self) -> Parameters<'_> { + Parameters { + buf: &self.storage, + remaining: self.len, + } + } +} + +pub struct Parameters<'a> { + buf: &'a [u8], + remaining: u16, +} + +impl FallibleIterator for Parameters<'_> { + type Item = Oid; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result> { + if self.remaining == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: parameters is not drained", + )); + } 
+ } + + self.remaining -= 1; + self.buf.read_u32::().map(Some) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.remaining as usize; + (len, Some(len)) + } +} + +pub struct ParameterStatusBody { + name: Bytes, + value: Bytes, +} + +impl ParameterStatusBody { + #[inline] + pub fn name(&self) -> io::Result<&str> { + get_str(&self.name) + } + + #[inline] + pub fn value(&self) -> io::Result<&str> { + get_str(&self.value) + } +} + +pub struct ReadyForQueryBody { + status: u8, +} + +impl ReadyForQueryBody { + #[inline] + pub fn status(&self) -> u8 { + self.status + } +} + +pub struct RowDescriptionBody { + storage: Bytes, + len: u16, +} + +impl RowDescriptionBody { + #[inline] + pub fn fields(&self) -> Fields<'_> { + Fields { + buf: &self.storage, + remaining: self.len, + } + } +} + +pub struct Fields<'a> { + buf: &'a [u8], + remaining: u16, +} + +impl<'a> FallibleIterator for Fields<'a> { + type Item = Field<'a>; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result>> { + if self.remaining == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: field is not drained", + )); + } + } + + self.remaining -= 1; + let name_end = find_null(self.buf, 0)?; + let name = get_str(&self.buf[..name_end])?; + self.buf = &self.buf[name_end + 1..]; + let table_oid = self.buf.read_u32::()?; + let column_id = self.buf.read_i16::()?; + let type_oid = self.buf.read_u32::()?; + let type_size = self.buf.read_i16::()?; + let type_modifier = self.buf.read_i32::()?; + let format = self.buf.read_i16::()?; + + Ok(Some(Field { + name, + table_oid, + column_id, + type_oid, + type_size, + type_modifier, + format, + })) + } +} + +pub struct Field<'a> { + name: &'a str, + table_oid: Oid, + column_id: i16, + type_oid: Oid, + type_size: i16, + type_modifier: i32, + format: i16, +} + +impl<'a> Field<'a> { + #[inline] + pub fn name(&self) -> &'a str { + self.name + 
} + + #[inline] + pub fn table_oid(&self) -> Oid { + self.table_oid + } + + #[inline] + pub fn column_id(&self) -> i16 { + self.column_id + } + + #[inline] + pub fn type_oid(&self) -> Oid { + self.type_oid + } + + #[inline] + pub fn type_size(&self) -> i16 { + self.type_size + } + + #[inline] + pub fn type_modifier(&self) -> i32 { + self.type_modifier + } + + #[inline] + pub fn format(&self) -> i16 { + self.format + } +} + +#[inline] +fn find_null(buf: &[u8], start: usize) -> io::Result { + match memchr(0, &buf[start..]) { + Some(pos) => Ok(pos + start), + None => Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "unexpected EOF", + )), + } +} + +#[inline] +fn get_str(buf: &[u8]) -> io::Result<&str> { + str::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e)) +} diff --git a/libs/proxy/postgres-protocol2/src/message/frontend.rs b/libs/proxy/postgres-protocol2/src/message/frontend.rs new file mode 100644 index 0000000000..5d0a8ff8c8 --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/message/frontend.rs @@ -0,0 +1,297 @@ +//! Frontend message serialization. 
+#![allow(missing_docs)] + +use byteorder::{BigEndian, ByteOrder}; +use bytes::{Buf, BufMut, BytesMut}; +use std::convert::TryFrom; +use std::error::Error; +use std::io; +use std::marker; + +use crate::{write_nullable, FromUsize, IsNull, Oid}; + +#[inline] +fn write_body(buf: &mut BytesMut, f: F) -> Result<(), E> +where + F: FnOnce(&mut BytesMut) -> Result<(), E>, + E: From, +{ + let base = buf.len(); + buf.extend_from_slice(&[0; 4]); + + f(buf)?; + + let size = i32::from_usize(buf.len() - base)?; + BigEndian::write_i32(&mut buf[base..], size); + Ok(()) +} + +pub enum BindError { + Conversion(Box), + Serialization(io::Error), +} + +impl From> for BindError { + #[inline] + fn from(e: Box) -> BindError { + BindError::Conversion(e) + } +} + +impl From for BindError { + #[inline] + fn from(e: io::Error) -> BindError { + BindError::Serialization(e) + } +} + +#[inline] +pub fn bind( + portal: &str, + statement: &str, + formats: I, + values: J, + mut serializer: F, + result_formats: K, + buf: &mut BytesMut, +) -> Result<(), BindError> +where + I: IntoIterator, + J: IntoIterator, + F: FnMut(T, &mut BytesMut) -> Result>, + K: IntoIterator, +{ + buf.put_u8(b'B'); + + write_body(buf, |buf| { + write_cstr(portal.as_bytes(), buf)?; + write_cstr(statement.as_bytes(), buf)?; + write_counted( + formats, + |f, buf| { + buf.put_i16(f); + Ok::<_, io::Error>(()) + }, + buf, + )?; + write_counted( + values, + |v, buf| write_nullable(|buf| serializer(v, buf), buf), + buf, + )?; + write_counted( + result_formats, + |f, buf| { + buf.put_i16(f); + Ok::<_, io::Error>(()) + }, + buf, + )?; + + Ok(()) + }) +} + +#[inline] +fn write_counted(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), E> +where + I: IntoIterator, + F: FnMut(T, &mut BytesMut) -> Result<(), E>, + E: From, +{ + let base = buf.len(); + buf.extend_from_slice(&[0; 2]); + let mut count = 0; + for item in items { + serializer(item, buf)?; + count += 1; + } + let count = i16::from_usize(count)?; + 
BigEndian::write_i16(&mut buf[base..], count); + + Ok(()) +} + +#[inline] +pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut BytesMut) { + write_body(buf, |buf| { + buf.put_i32(80_877_102); + buf.put_i32(process_id); + buf.put_i32(secret_key); + Ok::<_, io::Error>(()) + }) + .unwrap(); +} + +#[inline] +pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'C'); + write_body(buf, |buf| { + buf.put_u8(variant); + write_cstr(name.as_bytes(), buf) + }) +} + +pub struct CopyData { + buf: T, + len: i32, +} + +impl CopyData +where + T: Buf, +{ + pub fn new(buf: T) -> io::Result> { + let len = buf + .remaining() + .checked_add(4) + .and_then(|l| i32::try_from(l).ok()) + .ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "message length overflow") + })?; + + Ok(CopyData { buf, len }) + } + + pub fn write(self, out: &mut BytesMut) { + out.put_u8(b'd'); + out.put_i32(self.len); + out.put(self.buf); + } +} + +#[inline] +pub fn copy_done(buf: &mut BytesMut) { + buf.put_u8(b'c'); + write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); +} + +#[inline] +pub fn copy_fail(message: &str, buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'f'); + write_body(buf, |buf| write_cstr(message.as_bytes(), buf)) +} + +#[inline] +pub fn describe(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'D'); + write_body(buf, |buf| { + buf.put_u8(variant); + write_cstr(name.as_bytes(), buf) + }) +} + +#[inline] +pub fn execute(portal: &str, max_rows: i32, buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'E'); + write_body(buf, |buf| { + write_cstr(portal.as_bytes(), buf)?; + buf.put_i32(max_rows); + Ok(()) + }) +} + +#[inline] +pub fn parse(name: &str, query: &str, param_types: I, buf: &mut BytesMut) -> io::Result<()> +where + I: IntoIterator, +{ + buf.put_u8(b'P'); + write_body(buf, |buf| { + write_cstr(name.as_bytes(), buf)?; + write_cstr(query.as_bytes(), buf)?; + write_counted( + param_types, + 
|t, buf| { + buf.put_u32(t); + Ok::<_, io::Error>(()) + }, + buf, + )?; + Ok(()) + }) +} + +#[inline] +pub fn password_message(password: &[u8], buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'p'); + write_body(buf, |buf| write_cstr(password, buf)) +} + +#[inline] +pub fn query(query: &str, buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'Q'); + write_body(buf, |buf| write_cstr(query.as_bytes(), buf)) +} + +#[inline] +pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'p'); + write_body(buf, |buf| { + write_cstr(mechanism.as_bytes(), buf)?; + let len = i32::from_usize(data.len())?; + buf.put_i32(len); + buf.put_slice(data); + Ok(()) + }) +} + +#[inline] +pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'p'); + write_body(buf, |buf| { + buf.put_slice(data); + Ok(()) + }) +} + +#[inline] +pub fn ssl_request(buf: &mut BytesMut) { + write_body(buf, |buf| { + buf.put_i32(80_877_103); + Ok::<_, io::Error>(()) + }) + .unwrap(); +} + +#[inline] +pub fn startup_message<'a, I>(parameters: I, buf: &mut BytesMut) -> io::Result<()> +where + I: IntoIterator, +{ + write_body(buf, |buf| { + // postgres protocol version 3.0 (196608) in big-endian + buf.put_i32(0x00_03_00_00); + for (key, value) in parameters { + write_cstr(key.as_bytes(), buf)?; + write_cstr(value.as_bytes(), buf)?; + } + buf.put_u8(0); + Ok(()) + }) +} + +#[inline] +pub fn sync(buf: &mut BytesMut) { + buf.put_u8(b'S'); + write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); +} + +#[inline] +pub fn terminate(buf: &mut BytesMut) { + buf.put_u8(b'X'); + write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); +} + +#[inline] +fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> { + if s.contains(&0) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "string contains embedded null", + )); + } + buf.put_slice(s); + buf.put_u8(0); + Ok(()) +} diff --git
a/libs/proxy/postgres-protocol2/src/message/mod.rs b/libs/proxy/postgres-protocol2/src/message/mod.rs new file mode 100644 index 0000000000..9e5d997548 --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/message/mod.rs @@ -0,0 +1,8 @@ +//! Postgres message protocol support. +//! +//! See [Postgres's documentation][docs] for more information on message flow. +//! +//! [docs]: https://www.postgresql.org/docs/9.5/static/protocol-flow.html + +pub mod backend; +pub mod frontend; diff --git a/libs/proxy/postgres-protocol2/src/password/mod.rs b/libs/proxy/postgres-protocol2/src/password/mod.rs new file mode 100644 index 0000000000..e669e80f3f --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/password/mod.rs @@ -0,0 +1,107 @@ +//! Functions to encrypt a password in the client. +//! +//! This is intended to be used by client applications that wish to +//! send commands like `ALTER USER joe PASSWORD 'pwd'`. The password +//! need not be sent in cleartext if it is encrypted on the client +//! side. This is good because it ensures the cleartext password won't +//! end up in logs, pg_stat displays, etc. + +use crate::authentication::sasl; +use hmac::{Hmac, Mac}; +use md5::Md5; +use rand::RngCore; +use sha2::digest::FixedOutput; +use sha2::{Digest, Sha256}; + +#[cfg(test)] +mod test; + +const SCRAM_DEFAULT_ITERATIONS: u32 = 4096; +const SCRAM_DEFAULT_SALT_LEN: usize = 16; + +/// Hash password using SCRAM-SHA-256 with a randomly-generated +/// salt. +/// +/// The client may assume the returned string doesn't contain any +/// special characters that would require escaping in an SQL command. +pub async fn scram_sha_256(password: &[u8]) -> String { + let mut salt: [u8; SCRAM_DEFAULT_SALT_LEN] = [0; SCRAM_DEFAULT_SALT_LEN]; + let mut rng = rand::thread_rng(); + rng.fill_bytes(&mut salt); + scram_sha_256_salt(password, salt).await +} + +// Internal implementation of scram_sha_256 with a caller-provided +// salt. This is useful for testing.
+pub(crate) async fn scram_sha_256_salt( + password: &[u8], + salt: [u8; SCRAM_DEFAULT_SALT_LEN], +) -> String { + // Prepare the password, per [RFC + // 4013](https://tools.ietf.org/html/rfc4013), if possible. + // + // Postgres treats passwords as byte strings (without embedded NUL + // bytes), but SASL expects passwords to be valid UTF-8. + // + // Follow the behavior of libpq's PQencryptPasswordConn(), and + // also the backend. If the password is not valid UTF-8, or if it + // contains prohibited characters (such as non-ASCII whitespace), + // just skip the SASLprep step and use the original byte + // sequence. + let prepared: Vec = match std::str::from_utf8(password) { + Ok(password_str) => { + match stringprep::saslprep(password_str) { + Ok(p) => p.into_owned().into_bytes(), + // contains invalid characters; skip saslprep + Err(_) => Vec::from(password), + } + } + // not valid UTF-8; skip saslprep + Err(_) => Vec::from(password), + }; + + // salt password + let salted_password = sasl::hi(&prepared, &salt, SCRAM_DEFAULT_ITERATIONS).await; + + // client key + let mut hmac = Hmac::::new_from_slice(&salted_password) + .expect("HMAC is able to accept all key sizes"); + hmac.update(b"Client Key"); + let client_key = hmac.finalize().into_bytes(); + + // stored key + let mut hash = Sha256::default(); + hash.update(client_key.as_slice()); + let stored_key = hash.finalize_fixed(); + + // server key + let mut hmac = Hmac::::new_from_slice(&salted_password) + .expect("HMAC is able to accept all key sizes"); + hmac.update(b"Server Key"); + let server_key = hmac.finalize().into_bytes(); + + format!( + "SCRAM-SHA-256${}:{}${}:{}", + SCRAM_DEFAULT_ITERATIONS, + base64::encode(salt), + base64::encode(stored_key), + base64::encode(server_key) + ) +} + +/// **Not recommended, as MD5 is not considered to be secure.** +/// +/// Hash password using MD5 with the username as the salt. 
+/// +/// The client may assume the returned string doesn't contain any +/// special characters that would require escaping. +pub fn md5(password: &[u8], username: &str) -> String { + // salt password with username + let mut salted_password = Vec::from(password); + salted_password.extend_from_slice(username.as_bytes()); + + let mut hash = Md5::new(); + hash.update(&salted_password); + let digest = hash.finalize(); + format!("md5{:x}", digest) +} diff --git a/libs/proxy/postgres-protocol2/src/password/test.rs b/libs/proxy/postgres-protocol2/src/password/test.rs new file mode 100644 index 0000000000..c9d340f09d --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/password/test.rs @@ -0,0 +1,19 @@ +use crate::password; + +#[tokio::test] +async fn test_encrypt_scram_sha_256() { + // Specify the salt to make the test deterministic. Any bytes will do. + let salt: [u8; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + assert_eq!( + password::scram_sha_256_salt(b"secret", salt).await, + "SCRAM-SHA-256$4096:AQIDBAUGBwgJCgsMDQ4PEA==$8rrDg00OqaiWXJ7p+sCgHEIaBSHY89ZJl3mfIsf32oY=:05L1f+yZbiN8O0AnO40Og85NNRhvzTS57naKRWCcsIA=" + ); +} + +#[test] +fn test_encrypt_md5() { + assert_eq!( + password::md5(b"secret", "foo"), + "md54ab2c5d00339c4b2a4e921d2dc4edec7" + ); +} diff --git a/libs/proxy/postgres-protocol2/src/types/mod.rs b/libs/proxy/postgres-protocol2/src/types/mod.rs new file mode 100644 index 0000000000..78131c05bf --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/types/mod.rs @@ -0,0 +1,294 @@ +//! Conversions to and from Postgres's binary format for various types. +use byteorder::{BigEndian, ReadBytesExt}; +use bytes::{BufMut, BytesMut}; +use fallible_iterator::FallibleIterator; +use std::boxed::Box as StdBox; +use std::error::Error; +use std::str; + +use crate::Oid; + +#[cfg(test)] +mod test; + +/// Serializes a `TEXT`, `VARCHAR`, `CHAR(n)`, `NAME`, or `CITEXT` value. 
+#[inline] +pub fn text_to_sql(v: &str, buf: &mut BytesMut) { + buf.put_slice(v.as_bytes()); +} + +/// Deserializes a `TEXT`, `VARCHAR`, `CHAR(n)`, `NAME`, or `CITEXT` value. +#[inline] +pub fn text_from_sql(buf: &[u8]) -> Result<&str, StdBox> { + Ok(str::from_utf8(buf)?) +} + +/// Deserializes a `"char"` value. +#[inline] +pub fn char_from_sql(mut buf: &[u8]) -> Result> { + let v = buf.read_i8()?; + if !buf.is_empty() { + return Err("invalid buffer size".into()); + } + Ok(v) +} + +/// Serializes an `OID` value. +#[inline] +pub fn oid_to_sql(v: Oid, buf: &mut BytesMut) { + buf.put_u32(v); +} + +/// Deserializes an `OID` value. +#[inline] +pub fn oid_from_sql(mut buf: &[u8]) -> Result> { + let v = buf.read_u32::()?; + if !buf.is_empty() { + return Err("invalid buffer size".into()); + } + Ok(v) +} + +/// A fallible iterator over `HSTORE` entries. +pub struct HstoreEntries<'a> { + remaining: i32, + buf: &'a [u8], +} + +impl<'a> FallibleIterator for HstoreEntries<'a> { + type Item = (&'a str, Option<&'a str>); + type Error = StdBox; + + #[inline] + #[allow(clippy::type_complexity)] + fn next( + &mut self, + ) -> Result)>, StdBox> { + if self.remaining == 0 { + if !self.buf.is_empty() { + return Err("invalid buffer size".into()); + } + return Ok(None); + } + + self.remaining -= 1; + + let key_len = self.buf.read_i32::()?; + if key_len < 0 { + return Err("invalid key length".into()); + } + let key = self.buf.get(..key_len as usize).ok_or("invalid key length")?; + let key = str::from_utf8(key)?; + self.buf = &self.buf[key.len()..]; + + let value_len = self.buf.read_i32::()?; + let value = if value_len < 0 { + None + } else { + let value = self.buf.get(..value_len as usize).ok_or("invalid value length")?; + let value = str::from_utf8(value)?; + self.buf = &self.buf[value.len()..]; + Some(value) + }; + + Ok(Some((key, value))) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.remaining as usize; + (len, Some(len)) + } +} + +/// Deserializes an array value.
+#[inline] +pub fn array_from_sql(mut buf: &[u8]) -> Result, StdBox> { + let dimensions = buf.read_i32::()?; + if dimensions < 0 { + return Err("invalid dimension count".into()); + } + + let mut r = buf; + let mut elements = 1i32; + for _ in 0..dimensions { + let len = r.read_i32::()?; + if len < 0 { + return Err("invalid dimension size".into()); + } + let _lower_bound = r.read_i32::()?; + elements = match elements.checked_mul(len) { + Some(elements) => elements, + None => return Err("too many array elements".into()), + }; + } + + if dimensions == 0 { + elements = 0; + } + + Ok(Array { + dimensions, + elements, + buf, + }) +} + +/// A Postgres array. +pub struct Array<'a> { + dimensions: i32, + elements: i32, + buf: &'a [u8], +} + +impl<'a> Array<'a> { + /// Returns an iterator over the dimensions of the array. + #[inline] + pub fn dimensions(&self) -> ArrayDimensions<'a> { + ArrayDimensions(&self.buf[..self.dimensions as usize * 8]) + } + + /// Returns an iterator over the values of the array. + #[inline] + pub fn values(&self) -> ArrayValues<'a> { + ArrayValues { + remaining: self.elements, + buf: &self.buf[self.dimensions as usize * 8..], + } + } +} + +/// An iterator over the dimensions of an array. +pub struct ArrayDimensions<'a>(&'a [u8]); + +impl FallibleIterator for ArrayDimensions<'_> { + type Item = ArrayDimension; + type Error = StdBox; + + #[inline] + fn next(&mut self) -> Result, StdBox> { + if self.0.is_empty() { + return Ok(None); + } + + let len = self.0.read_i32::()?; + let lower_bound = self.0.read_i32::()?; + + Ok(Some(ArrayDimension { len, lower_bound })) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.0.len() / 8; + (len, Some(len)) + } +} + +/// Information about a dimension of an array. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct ArrayDimension { + /// The length of this dimension. + pub len: i32, + + /// The base value used to index into this dimension. 
+ pub lower_bound: i32, +} + +/// An iterator over the values of an array, in row-major order. +pub struct ArrayValues<'a> { + remaining: i32, + buf: &'a [u8], +} + +impl<'a> FallibleIterator for ArrayValues<'a> { + type Item = Option<&'a [u8]>; + type Error = StdBox; + + #[inline] + fn next(&mut self) -> Result>, StdBox> { + if self.remaining == 0 { + if !self.buf.is_empty() { + return Err("invalid message length: arrayvalue not drained".into()); + } + return Ok(None); + } + self.remaining -= 1; + + let len = self.buf.read_i32::()?; + let val = if len < 0 { + None + } else { + if self.buf.len() < len as usize { + return Err("invalid value length".into()); + } + + let (val, buf) = self.buf.split_at(len as usize); + self.buf = buf; + Some(val) + }; + + Ok(Some(val)) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.remaining as usize; + (len, Some(len)) + } +} + +/// Serializes a Postgres ltree string +#[inline] +pub fn ltree_to_sql(v: &str, buf: &mut BytesMut) { + // A version number is prepended to an ltree string per spec + buf.put_u8(1); + // Append the rest of the query + buf.put_slice(v.as_bytes()); +} + +/// Deserialize a Postgres ltree string +#[inline] +pub fn ltree_from_sql(buf: &[u8]) -> Result<&str, StdBox> { + match buf { + // Remove the version number from the front of the ltree per spec + [1u8, rest @ ..] => Ok(str::from_utf8(rest)?), + _ => Err("ltree version 1 only supported".into()), + } +} + +/// Serializes a Postgres lquery string +#[inline] +pub fn lquery_to_sql(v: &str, buf: &mut BytesMut) { + // A version number is prepended to an lquery string per spec + buf.put_u8(1); + // Append the rest of the query + buf.put_slice(v.as_bytes()); +} + +/// Deserialize a Postgres lquery string +#[inline] +pub fn lquery_from_sql(buf: &[u8]) -> Result<&str, StdBox> { + match buf { + // Remove the version number from the front of the lquery per spec + [1u8, rest @ ..] 
=> Ok(str::from_utf8(rest)?), + _ => Err("lquery version 1 only supported".into()), + } +} + +/// Serializes a Postgres ltxtquery string +#[inline] +pub fn ltxtquery_to_sql(v: &str, buf: &mut BytesMut) { + // A version number is prepended to an ltxtquery string per spec + buf.put_u8(1); + // Append the rest of the query + buf.put_slice(v.as_bytes()); +} + +/// Deserialize a Postgres ltxtquery string +#[inline] +pub fn ltxtquery_from_sql(buf: &[u8]) -> Result<&str, StdBox> { + match buf { + // Remove the version number from the front of the ltxtquery per spec + [1u8, rest @ ..] => Ok(str::from_utf8(rest)?), + _ => Err("ltxtquery version 1 only supported".into()), + } +} diff --git a/libs/proxy/postgres-protocol2/src/types/test.rs b/libs/proxy/postgres-protocol2/src/types/test.rs new file mode 100644 index 0000000000..96cc055bc3 --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/types/test.rs @@ -0,0 +1,87 @@ +use bytes::{Buf, BytesMut}; + +use super::*; + +#[test] +fn ltree_sql() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + let mut buf = BytesMut::new(); + + ltree_to_sql("A.B.C", &mut buf); + + assert_eq!(query.as_slice(), buf.chunk()); +} + +#[test] +fn ltree_str() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + assert!(ltree_from_sql(query.as_slice()).is_ok()) +} + +#[test] +fn ltree_wrong_version() { + let mut query = vec![2u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + assert!(ltree_from_sql(query.as_slice()).is_err()) +} + +#[test] +fn lquery_sql() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + let mut buf = BytesMut::new(); + + lquery_to_sql("A.B.C", &mut buf); + + assert_eq!(query.as_slice(), buf.chunk()); +} + +#[test] +fn lquery_str() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + assert!(lquery_from_sql(query.as_slice()).is_ok()) +} + +#[test] +fn lquery_wrong_version() { + let mut query = vec![2u8]; + 
query.extend_from_slice("A.B.C".as_bytes()); + + assert!(lquery_from_sql(query.as_slice()).is_err()) +} + +#[test] +fn ltxtquery_sql() { + let mut query = vec![1u8]; + query.extend_from_slice("a & b*".as_bytes()); + + let mut buf = BytesMut::new(); + + ltxtquery_to_sql("a & b*", &mut buf); + + assert_eq!(query.as_slice(), buf.chunk()); +} + +#[test] +fn ltxtquery_str() { + let mut query = vec![1u8]; + query.extend_from_slice("a & b*".as_bytes()); + + assert!(ltxtquery_from_sql(query.as_slice()).is_ok()) +} + +#[test] +fn ltxtquery_wrong_version() { + let mut query = vec![2u8]; + query.extend_from_slice("a & b*".as_bytes()); + + assert!(ltxtquery_from_sql(query.as_slice()).is_err()) +} diff --git a/libs/proxy/postgres-types2/Cargo.toml b/libs/proxy/postgres-types2/Cargo.toml new file mode 100644 index 0000000000..58cfb5571f --- /dev/null +++ b/libs/proxy/postgres-types2/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "postgres-types2" +version = "0.1.0" +edition = "2018" +license = "MIT/Apache-2.0" + +[dependencies] +bytes.workspace = true +fallible-iterator.workspace = true +postgres-protocol2 = { path = "../postgres-protocol2" } diff --git a/libs/proxy/postgres-types2/src/lib.rs b/libs/proxy/postgres-types2/src/lib.rs new file mode 100644 index 0000000000..18ba032151 --- /dev/null +++ b/libs/proxy/postgres-types2/src/lib.rs @@ -0,0 +1,477 @@ +//! Conversions to and from Postgres types. +//! +//! This crate is used by the `tokio-postgres` and `postgres` crates. You normally don't need to depend directly on it +//! unless you want to define your own `ToSql` or `FromSql` definitions.
+#![doc(html_root_url = "https://docs.rs/postgres-types/0.2")] +#![warn(clippy::all, rust_2018_idioms, missing_docs)] + +use fallible_iterator::FallibleIterator; +use postgres_protocol2::types; +use std::any::type_name; +use std::error::Error; +use std::fmt; +use std::sync::Arc; + +use crate::type_gen::{Inner, Other}; + +#[doc(inline)] +pub use postgres_protocol2::Oid; + +use bytes::BytesMut; + +/// Generates a simple implementation of `ToSql::accepts` which accepts the +/// types passed to it. +macro_rules! accepts { + ($($expected:ident),+) => ( + fn accepts(ty: &$crate::Type) -> bool { + matches!(*ty, $($crate::Type::$expected)|+) + } + ) +} + +/// Generates an implementation of `ToSql::to_sql_checked`. +/// +/// All `ToSql` implementations should use this macro. +macro_rules! to_sql_checked { + () => { + fn to_sql_checked( + &self, + ty: &$crate::Type, + out: &mut $crate::private::BytesMut, + ) -> ::std::result::Result< + $crate::IsNull, + Box, + > { + $crate::__to_sql_checked(self, ty, out) + } + }; +} + +// WARNING: this function is not considered part of this crate's public API. +// It is subject to change at any time. +#[doc(hidden)] +pub fn __to_sql_checked( + v: &T, + ty: &Type, + out: &mut BytesMut, +) -> Result> +where + T: ToSql, +{ + if !T::accepts(ty) { + return Err(Box::new(WrongType::new::(ty.clone()))); + } + v.to_sql(ty, out) +} + +// mod pg_lsn; +#[doc(hidden)] +pub mod private; +// mod special; +mod type_gen; + +/// A Postgres type. +#[derive(PartialEq, Eq, Clone, Hash)] +pub struct Type(Inner); + +impl fmt::Debug for Type { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&self.0, fmt) + } +} + +impl fmt::Display for Type { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.schema() { + "public" | "pg_catalog" => {} + schema => write!(fmt, "{}.", schema)?, + } + fmt.write_str(self.name()) + } +} + +impl Type { + /// Creates a new `Type`. 
+ pub fn new(name: String, oid: Oid, kind: Kind, schema: String) -> Type { + Type(Inner::Other(Arc::new(Other { + name, + oid, + kind, + schema, + }))) + } + + /// Returns the `Type` corresponding to the provided `Oid` if it + /// corresponds to a built-in type. + pub fn from_oid(oid: Oid) -> Option { + Inner::from_oid(oid).map(Type) + } + + /// Returns the OID of the `Type`. + pub fn oid(&self) -> Oid { + self.0.oid() + } + + /// Returns the kind of this type. + pub fn kind(&self) -> &Kind { + self.0.kind() + } + + /// Returns the schema of this type. + pub fn schema(&self) -> &str { + match self.0 { + Inner::Other(ref u) => &u.schema, + _ => "pg_catalog", + } + } + + /// Returns the name of this type. + pub fn name(&self) -> &str { + self.0.name() + } +} + +/// Represents the kind of a Postgres type. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum Kind { + /// A simple type like `VARCHAR` or `INTEGER`. + Simple, + /// An enumerated type along with its variants. + Enum(Vec), + /// A pseudo-type. + Pseudo, + /// An array type along with the type of its elements. + Array(Type), + /// A range type along with the type of its elements. + Range(Type), + /// A multirange type along with the type of its elements. + Multirange(Type), + /// A domain type along with its underlying type. + Domain(Type), + /// A composite type along with information about its fields. + Composite(Vec), +} + +/// Information about a field of a composite type. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Field { + name: String, + type_: Type, +} + +impl Field { + /// Creates a new `Field`. + pub fn new(name: String, type_: Type) -> Field { + Field { name, type_ } + } + + /// Returns the name of the field. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns the type of the field. 
+ pub fn type_(&self) -> &Type { + &self.type_ + } +} + +/// An error indicating that a `NULL` Postgres value was passed to a `FromSql` +/// implementation that does not support `NULL` values. +#[derive(Debug, Clone, Copy)] +pub struct WasNull; + +impl fmt::Display for WasNull { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.write_str("a Postgres value was `NULL`") + } +} + +impl Error for WasNull {} + +/// An error indicating that a conversion was attempted between incompatible +/// Rust and Postgres types. +#[derive(Debug)] +pub struct WrongType { + postgres: Type, + rust: &'static str, +} + +impl fmt::Display for WrongType { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + fmt, + "cannot convert between the Rust type `{}` and the Postgres type `{}`", + self.rust, self.postgres, + ) + } +} + +impl Error for WrongType {} + +impl WrongType { + /// Creates a new `WrongType` error. + pub fn new(ty: Type) -> WrongType { + WrongType { + postgres: ty, + rust: type_name::(), + } + } +} + +/// An error indicating that a as_text conversion was attempted on a binary +/// result. +#[derive(Debug)] +pub struct WrongFormat {} + +impl Error for WrongFormat {} + +impl fmt::Display for WrongFormat { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + fmt, + "cannot read column as text while it is in binary format" + ) + } +} + +/// A trait for types that can be created from a Postgres value. +pub trait FromSql<'a>: Sized { + /// Creates a new value of this type from a buffer of data of the specified + /// Postgres `Type` in its binary format. + /// + /// The caller of this method is responsible for ensuring that this type + /// is compatible with the Postgres `Type`. + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result>; + + /// Creates a new value of this type from a `NULL` SQL value. + /// + /// The caller of this method is responsible for ensuring that this type + /// is compatible with the Postgres `Type`. 
+ /// + /// The default implementation returns `Err(Box::new(WasNull))`. + #[allow(unused_variables)] + fn from_sql_null(ty: &Type) -> Result> { + Err(Box::new(WasNull)) + } + + /// A convenience function that delegates to `from_sql` and `from_sql_null` depending on the + /// value of `raw`. + fn from_sql_nullable( + ty: &Type, + raw: Option<&'a [u8]>, + ) -> Result> { + match raw { + Some(raw) => Self::from_sql(ty, raw), + None => Self::from_sql_null(ty), + } + } + + /// Determines if a value of this type can be created from the specified + /// Postgres `Type`. + fn accepts(ty: &Type) -> bool; +} + +/// A trait for types which can be created from a Postgres value without borrowing any data. +/// +/// This is primarily useful for trait bounds on functions. +pub trait FromSqlOwned: for<'a> FromSql<'a> {} + +impl FromSqlOwned for T where T: for<'a> FromSql<'a> {} + +impl<'a, T: FromSql<'a>> FromSql<'a> for Option { + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result, Box> { + ::from_sql(ty, raw).map(Some) + } + + fn from_sql_null(_: &Type) -> Result, Box> { + Ok(None) + } + + fn accepts(ty: &Type) -> bool { + ::accepts(ty) + } +} + +impl<'a, T: FromSql<'a>> FromSql<'a> for Vec { + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result, Box> { + let member_type = match *ty.kind() { + Kind::Array(ref member) => member, + _ => panic!("expected array type"), + }; + + let array = types::array_from_sql(raw)?; + if array.dimensions().count()? 
> 1 { + return Err("array contains too many dimensions".into()); + } + + array + .values() + .map(|v| T::from_sql_nullable(member_type, v)) + .collect() + } + + fn accepts(ty: &Type) -> bool { + match *ty.kind() { + Kind::Array(ref inner) => T::accepts(inner), + _ => false, + } + } +} + +impl<'a> FromSql<'a> for String { + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result> { + <&str as FromSql>::from_sql(ty, raw).map(ToString::to_string) + } + + fn accepts(ty: &Type) -> bool { + <&str as FromSql>::accepts(ty) + } +} + +impl<'a> FromSql<'a> for &'a str { + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<&'a str, Box> { + match *ty { + ref ty if ty.name() == "ltree" => types::ltree_from_sql(raw), + ref ty if ty.name() == "lquery" => types::lquery_from_sql(raw), + ref ty if ty.name() == "ltxtquery" => types::ltxtquery_from_sql(raw), + _ => types::text_from_sql(raw), + } + } + + fn accepts(ty: &Type) -> bool { + match *ty { + Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true, + ref ty + if (ty.name() == "citext" + || ty.name() == "ltree" + || ty.name() == "lquery" + || ty.name() == "ltxtquery") => + { + true + } + _ => false, + } + } +} + +macro_rules! simple_from { + ($t:ty, $f:ident, $($expected:ident),+) => { + impl<'a> FromSql<'a> for $t { + fn from_sql(_: &Type, raw: &'a [u8]) -> Result<$t, Box> { + types::$f(raw) + } + + accepts!($($expected),+); + } + } +} + +simple_from!(i8, char_from_sql, CHAR); +simple_from!(u32, oid_from_sql, OID); + +/// An enum representing the nullability of a Postgres value. +pub enum IsNull { + /// The value is NULL. + Yes, + /// The value is not NULL. + No, +} + +/// A trait for types that can be converted into Postgres values. +pub trait ToSql: fmt::Debug { + /// Converts the value of `self` into the binary format of the specified + /// Postgres `Type`, appending it to `out`. + /// + /// The caller of this method is responsible for ensuring that this type + /// is compatible with the Postgres `Type`. 
+ /// + /// The return value indicates if this value should be represented as + /// `NULL`. If this is the case, implementations **must not** write + /// anything to `out`. + fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result> + where + Self: Sized; + + /// Determines if a value of this type can be converted to the specified + /// Postgres `Type`. + fn accepts(ty: &Type) -> bool + where + Self: Sized; + + /// An adaptor method used internally by Rust-Postgres. + /// + /// *All* implementations of this method should be generated by the + /// `to_sql_checked!()` macro. + fn to_sql_checked( + &self, + ty: &Type, + out: &mut BytesMut, + ) -> Result>; + + /// Specify the encode format + fn encode_format(&self, _ty: &Type) -> Format { + Format::Binary + } +} + +/// Supported Postgres message format types +/// +/// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8` +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum Format { + /// Text format (UTF-8) + Text, + /// Compact, typed binary format + Binary, +} + +impl ToSql for &str { + fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { + match *ty { + ref ty if ty.name() == "ltree" => types::ltree_to_sql(self, w), + ref ty if ty.name() == "lquery" => types::lquery_to_sql(self, w), + ref ty if ty.name() == "ltxtquery" => types::ltxtquery_to_sql(self, w), + _ => types::text_to_sql(self, w), + } + Ok(IsNull::No) + } + + fn accepts(ty: &Type) -> bool { + match *ty { + Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true, + ref ty + if (ty.name() == "citext" + || ty.name() == "ltree" + || ty.name() == "lquery" + || ty.name() == "ltxtquery") => + { + true + } + _ => false, + } + } + + to_sql_checked!(); +} + +macro_rules! 
simple_to { + ($t:ty, $f:ident, $($expected:ident),+) => { + impl ToSql for $t { + fn to_sql(&self, + _: &Type, + w: &mut BytesMut) + -> Result> { + types::$f(*self, w); + Ok(IsNull::No) + } + + accepts!($($expected),+); + + to_sql_checked!(); + } + } +} + +simple_to!(u32, oid_to_sql, OID); diff --git a/libs/proxy/postgres-types2/src/private.rs b/libs/proxy/postgres-types2/src/private.rs new file mode 100644 index 0000000000..774f9a301c --- /dev/null +++ b/libs/proxy/postgres-types2/src/private.rs @@ -0,0 +1,34 @@ +use crate::{FromSql, Type}; +pub use bytes::BytesMut; +use std::error::Error; + +pub fn read_be_i32(buf: &mut &[u8]) -> Result> { + if buf.len() < 4 { + return Err("invalid buffer size".into()); + } + let mut bytes = [0; 4]; + bytes.copy_from_slice(&buf[..4]); + *buf = &buf[4..]; + Ok(i32::from_be_bytes(bytes)) +} + +pub fn read_value<'a, T>( + type_: &Type, + buf: &mut &'a [u8], +) -> Result> +where + T: FromSql<'a>, +{ + let len = read_be_i32(buf)?; + let value = if len < 0 { + None + } else { + if len as usize > buf.len() { + return Err("invalid buffer size".into()); + } + let (head, tail) = buf.split_at(len as usize); + *buf = tail; + Some(head) + }; + T::from_sql_nullable(type_, value) +} diff --git a/libs/proxy/postgres-types2/src/type_gen.rs b/libs/proxy/postgres-types2/src/type_gen.rs new file mode 100644 index 0000000000..a1bc3f85c0 --- /dev/null +++ b/libs/proxy/postgres-types2/src/type_gen.rs @@ -0,0 +1,1524 @@ +// Autogenerated file - DO NOT EDIT +use std::sync::Arc; + +use crate::{Kind, Oid, Type}; + +#[derive(PartialEq, Eq, Debug, Hash)] +pub struct Other { + pub name: String, + pub oid: Oid, + pub kind: Kind, + pub schema: String, +} + +#[derive(PartialEq, Eq, Clone, Debug, Hash)] +pub enum Inner { + Bool, + Bytea, + Char, + Name, + Int8, + Int2, + Int2Vector, + Int4, + Regproc, + Text, + Oid, + Tid, + Xid, + Cid, + OidVector, + PgDdlCommand, + Json, + Xml, + XmlArray, + PgNodeTree, + JsonArray, + TableAmHandler, + Xid8Array, + 
IndexAmHandler, + Point, + Lseg, + Path, + Box, + Polygon, + Line, + LineArray, + Cidr, + CidrArray, + Float4, + Float8, + Unknown, + Circle, + CircleArray, + Macaddr8, + Macaddr8Array, + Money, + MoneyArray, + Macaddr, + Inet, + BoolArray, + ByteaArray, + CharArray, + NameArray, + Int2Array, + Int2VectorArray, + Int4Array, + RegprocArray, + TextArray, + TidArray, + XidArray, + CidArray, + OidVectorArray, + BpcharArray, + VarcharArray, + Int8Array, + PointArray, + LsegArray, + PathArray, + BoxArray, + Float4Array, + Float8Array, + PolygonArray, + OidArray, + Aclitem, + AclitemArray, + MacaddrArray, + InetArray, + Bpchar, + Varchar, + Date, + Time, + Timestamp, + TimestampArray, + DateArray, + TimeArray, + Timestamptz, + TimestamptzArray, + Interval, + IntervalArray, + NumericArray, + CstringArray, + Timetz, + TimetzArray, + Bit, + BitArray, + Varbit, + VarbitArray, + Numeric, + Refcursor, + RefcursorArray, + Regprocedure, + Regoper, + Regoperator, + Regclass, + Regtype, + RegprocedureArray, + RegoperArray, + RegoperatorArray, + RegclassArray, + RegtypeArray, + Record, + Cstring, + Any, + Anyarray, + Void, + Trigger, + LanguageHandler, + Internal, + Anyelement, + RecordArray, + Anynonarray, + TxidSnapshotArray, + Uuid, + UuidArray, + TxidSnapshot, + FdwHandler, + PgLsn, + PgLsnArray, + TsmHandler, + PgNdistinct, + PgDependencies, + Anyenum, + TsVector, + Tsquery, + GtsVector, + TsVectorArray, + GtsVectorArray, + TsqueryArray, + Regconfig, + RegconfigArray, + Regdictionary, + RegdictionaryArray, + Jsonb, + JsonbArray, + AnyRange, + EventTrigger, + Int4Range, + Int4RangeArray, + NumRange, + NumRangeArray, + TsRange, + TsRangeArray, + TstzRange, + TstzRangeArray, + DateRange, + DateRangeArray, + Int8Range, + Int8RangeArray, + Jsonpath, + JsonpathArray, + Regnamespace, + RegnamespaceArray, + Regrole, + RegroleArray, + Regcollation, + RegcollationArray, + Int4multiRange, + NummultiRange, + TsmultiRange, + TstzmultiRange, + DatemultiRange, + Int8multiRange, + 
AnymultiRange, + AnycompatiblemultiRange, + PgBrinBloomSummary, + PgBrinMinmaxMultiSummary, + PgMcvList, + PgSnapshot, + PgSnapshotArray, + Xid8, + Anycompatible, + Anycompatiblearray, + Anycompatiblenonarray, + AnycompatibleRange, + Int4multiRangeArray, + NummultiRangeArray, + TsmultiRangeArray, + TstzmultiRangeArray, + DatemultiRangeArray, + Int8multiRangeArray, + Other(Arc), +} + +impl Inner { + pub fn from_oid(oid: Oid) -> Option { + match oid { + 16 => Some(Inner::Bool), + 17 => Some(Inner::Bytea), + 18 => Some(Inner::Char), + 19 => Some(Inner::Name), + 20 => Some(Inner::Int8), + 21 => Some(Inner::Int2), + 22 => Some(Inner::Int2Vector), + 23 => Some(Inner::Int4), + 24 => Some(Inner::Regproc), + 25 => Some(Inner::Text), + 26 => Some(Inner::Oid), + 27 => Some(Inner::Tid), + 28 => Some(Inner::Xid), + 29 => Some(Inner::Cid), + 30 => Some(Inner::OidVector), + 32 => Some(Inner::PgDdlCommand), + 114 => Some(Inner::Json), + 142 => Some(Inner::Xml), + 143 => Some(Inner::XmlArray), + 194 => Some(Inner::PgNodeTree), + 199 => Some(Inner::JsonArray), + 269 => Some(Inner::TableAmHandler), + 271 => Some(Inner::Xid8Array), + 325 => Some(Inner::IndexAmHandler), + 600 => Some(Inner::Point), + 601 => Some(Inner::Lseg), + 602 => Some(Inner::Path), + 603 => Some(Inner::Box), + 604 => Some(Inner::Polygon), + 628 => Some(Inner::Line), + 629 => Some(Inner::LineArray), + 650 => Some(Inner::Cidr), + 651 => Some(Inner::CidrArray), + 700 => Some(Inner::Float4), + 701 => Some(Inner::Float8), + 705 => Some(Inner::Unknown), + 718 => Some(Inner::Circle), + 719 => Some(Inner::CircleArray), + 774 => Some(Inner::Macaddr8), + 775 => Some(Inner::Macaddr8Array), + 790 => Some(Inner::Money), + 791 => Some(Inner::MoneyArray), + 829 => Some(Inner::Macaddr), + 869 => Some(Inner::Inet), + 1000 => Some(Inner::BoolArray), + 1001 => Some(Inner::ByteaArray), + 1002 => Some(Inner::CharArray), + 1003 => Some(Inner::NameArray), + 1005 => Some(Inner::Int2Array), + 1006 => Some(Inner::Int2VectorArray), + 1007 
=> Some(Inner::Int4Array), + 1008 => Some(Inner::RegprocArray), + 1009 => Some(Inner::TextArray), + 1010 => Some(Inner::TidArray), + 1011 => Some(Inner::XidArray), + 1012 => Some(Inner::CidArray), + 1013 => Some(Inner::OidVectorArray), + 1014 => Some(Inner::BpcharArray), + 1015 => Some(Inner::VarcharArray), + 1016 => Some(Inner::Int8Array), + 1017 => Some(Inner::PointArray), + 1018 => Some(Inner::LsegArray), + 1019 => Some(Inner::PathArray), + 1020 => Some(Inner::BoxArray), + 1021 => Some(Inner::Float4Array), + 1022 => Some(Inner::Float8Array), + 1027 => Some(Inner::PolygonArray), + 1028 => Some(Inner::OidArray), + 1033 => Some(Inner::Aclitem), + 1034 => Some(Inner::AclitemArray), + 1040 => Some(Inner::MacaddrArray), + 1041 => Some(Inner::InetArray), + 1042 => Some(Inner::Bpchar), + 1043 => Some(Inner::Varchar), + 1082 => Some(Inner::Date), + 1083 => Some(Inner::Time), + 1114 => Some(Inner::Timestamp), + 1115 => Some(Inner::TimestampArray), + 1182 => Some(Inner::DateArray), + 1183 => Some(Inner::TimeArray), + 1184 => Some(Inner::Timestamptz), + 1185 => Some(Inner::TimestamptzArray), + 1186 => Some(Inner::Interval), + 1187 => Some(Inner::IntervalArray), + 1231 => Some(Inner::NumericArray), + 1263 => Some(Inner::CstringArray), + 1266 => Some(Inner::Timetz), + 1270 => Some(Inner::TimetzArray), + 1560 => Some(Inner::Bit), + 1561 => Some(Inner::BitArray), + 1562 => Some(Inner::Varbit), + 1563 => Some(Inner::VarbitArray), + 1700 => Some(Inner::Numeric), + 1790 => Some(Inner::Refcursor), + 2201 => Some(Inner::RefcursorArray), + 2202 => Some(Inner::Regprocedure), + 2203 => Some(Inner::Regoper), + 2204 => Some(Inner::Regoperator), + 2205 => Some(Inner::Regclass), + 2206 => Some(Inner::Regtype), + 2207 => Some(Inner::RegprocedureArray), + 2208 => Some(Inner::RegoperArray), + 2209 => Some(Inner::RegoperatorArray), + 2210 => Some(Inner::RegclassArray), + 2211 => Some(Inner::RegtypeArray), + 2249 => Some(Inner::Record), + 2275 => Some(Inner::Cstring), + 2276 => 
Some(Inner::Any), + 2277 => Some(Inner::Anyarray), + 2278 => Some(Inner::Void), + 2279 => Some(Inner::Trigger), + 2280 => Some(Inner::LanguageHandler), + 2281 => Some(Inner::Internal), + 2283 => Some(Inner::Anyelement), + 2287 => Some(Inner::RecordArray), + 2776 => Some(Inner::Anynonarray), + 2949 => Some(Inner::TxidSnapshotArray), + 2950 => Some(Inner::Uuid), + 2951 => Some(Inner::UuidArray), + 2970 => Some(Inner::TxidSnapshot), + 3115 => Some(Inner::FdwHandler), + 3220 => Some(Inner::PgLsn), + 3221 => Some(Inner::PgLsnArray), + 3310 => Some(Inner::TsmHandler), + 3361 => Some(Inner::PgNdistinct), + 3402 => Some(Inner::PgDependencies), + 3500 => Some(Inner::Anyenum), + 3614 => Some(Inner::TsVector), + 3615 => Some(Inner::Tsquery), + 3642 => Some(Inner::GtsVector), + 3643 => Some(Inner::TsVectorArray), + 3644 => Some(Inner::GtsVectorArray), + 3645 => Some(Inner::TsqueryArray), + 3734 => Some(Inner::Regconfig), + 3735 => Some(Inner::RegconfigArray), + 3769 => Some(Inner::Regdictionary), + 3770 => Some(Inner::RegdictionaryArray), + 3802 => Some(Inner::Jsonb), + 3807 => Some(Inner::JsonbArray), + 3831 => Some(Inner::AnyRange), + 3838 => Some(Inner::EventTrigger), + 3904 => Some(Inner::Int4Range), + 3905 => Some(Inner::Int4RangeArray), + 3906 => Some(Inner::NumRange), + 3907 => Some(Inner::NumRangeArray), + 3908 => Some(Inner::TsRange), + 3909 => Some(Inner::TsRangeArray), + 3910 => Some(Inner::TstzRange), + 3911 => Some(Inner::TstzRangeArray), + 3912 => Some(Inner::DateRange), + 3913 => Some(Inner::DateRangeArray), + 3926 => Some(Inner::Int8Range), + 3927 => Some(Inner::Int8RangeArray), + 4072 => Some(Inner::Jsonpath), + 4073 => Some(Inner::JsonpathArray), + 4089 => Some(Inner::Regnamespace), + 4090 => Some(Inner::RegnamespaceArray), + 4096 => Some(Inner::Regrole), + 4097 => Some(Inner::RegroleArray), + 4191 => Some(Inner::Regcollation), + 4192 => Some(Inner::RegcollationArray), + 4451 => Some(Inner::Int4multiRange), + 4532 => Some(Inner::NummultiRange), + 4533 => 
Some(Inner::TsmultiRange), + 4534 => Some(Inner::TstzmultiRange), + 4535 => Some(Inner::DatemultiRange), + 4536 => Some(Inner::Int8multiRange), + 4537 => Some(Inner::AnymultiRange), + 4538 => Some(Inner::AnycompatiblemultiRange), + 4600 => Some(Inner::PgBrinBloomSummary), + 4601 => Some(Inner::PgBrinMinmaxMultiSummary), + 5017 => Some(Inner::PgMcvList), + 5038 => Some(Inner::PgSnapshot), + 5039 => Some(Inner::PgSnapshotArray), + 5069 => Some(Inner::Xid8), + 5077 => Some(Inner::Anycompatible), + 5078 => Some(Inner::Anycompatiblearray), + 5079 => Some(Inner::Anycompatiblenonarray), + 5080 => Some(Inner::AnycompatibleRange), + 6150 => Some(Inner::Int4multiRangeArray), + 6151 => Some(Inner::NummultiRangeArray), + 6152 => Some(Inner::TsmultiRangeArray), + 6153 => Some(Inner::TstzmultiRangeArray), + 6155 => Some(Inner::DatemultiRangeArray), + 6157 => Some(Inner::Int8multiRangeArray), + _ => None, + } + } + + pub fn oid(&self) -> Oid { + match *self { + Inner::Bool => 16, + Inner::Bytea => 17, + Inner::Char => 18, + Inner::Name => 19, + Inner::Int8 => 20, + Inner::Int2 => 21, + Inner::Int2Vector => 22, + Inner::Int4 => 23, + Inner::Regproc => 24, + Inner::Text => 25, + Inner::Oid => 26, + Inner::Tid => 27, + Inner::Xid => 28, + Inner::Cid => 29, + Inner::OidVector => 30, + Inner::PgDdlCommand => 32, + Inner::Json => 114, + Inner::Xml => 142, + Inner::XmlArray => 143, + Inner::PgNodeTree => 194, + Inner::JsonArray => 199, + Inner::TableAmHandler => 269, + Inner::Xid8Array => 271, + Inner::IndexAmHandler => 325, + Inner::Point => 600, + Inner::Lseg => 601, + Inner::Path => 602, + Inner::Box => 603, + Inner::Polygon => 604, + Inner::Line => 628, + Inner::LineArray => 629, + Inner::Cidr => 650, + Inner::CidrArray => 651, + Inner::Float4 => 700, + Inner::Float8 => 701, + Inner::Unknown => 705, + Inner::Circle => 718, + Inner::CircleArray => 719, + Inner::Macaddr8 => 774, + Inner::Macaddr8Array => 775, + Inner::Money => 790, + Inner::MoneyArray => 791, + Inner::Macaddr => 829, 
+ Inner::Inet => 869, + Inner::BoolArray => 1000, + Inner::ByteaArray => 1001, + Inner::CharArray => 1002, + Inner::NameArray => 1003, + Inner::Int2Array => 1005, + Inner::Int2VectorArray => 1006, + Inner::Int4Array => 1007, + Inner::RegprocArray => 1008, + Inner::TextArray => 1009, + Inner::TidArray => 1010, + Inner::XidArray => 1011, + Inner::CidArray => 1012, + Inner::OidVectorArray => 1013, + Inner::BpcharArray => 1014, + Inner::VarcharArray => 1015, + Inner::Int8Array => 1016, + Inner::PointArray => 1017, + Inner::LsegArray => 1018, + Inner::PathArray => 1019, + Inner::BoxArray => 1020, + Inner::Float4Array => 1021, + Inner::Float8Array => 1022, + Inner::PolygonArray => 1027, + Inner::OidArray => 1028, + Inner::Aclitem => 1033, + Inner::AclitemArray => 1034, + Inner::MacaddrArray => 1040, + Inner::InetArray => 1041, + Inner::Bpchar => 1042, + Inner::Varchar => 1043, + Inner::Date => 1082, + Inner::Time => 1083, + Inner::Timestamp => 1114, + Inner::TimestampArray => 1115, + Inner::DateArray => 1182, + Inner::TimeArray => 1183, + Inner::Timestamptz => 1184, + Inner::TimestamptzArray => 1185, + Inner::Interval => 1186, + Inner::IntervalArray => 1187, + Inner::NumericArray => 1231, + Inner::CstringArray => 1263, + Inner::Timetz => 1266, + Inner::TimetzArray => 1270, + Inner::Bit => 1560, + Inner::BitArray => 1561, + Inner::Varbit => 1562, + Inner::VarbitArray => 1563, + Inner::Numeric => 1700, + Inner::Refcursor => 1790, + Inner::RefcursorArray => 2201, + Inner::Regprocedure => 2202, + Inner::Regoper => 2203, + Inner::Regoperator => 2204, + Inner::Regclass => 2205, + Inner::Regtype => 2206, + Inner::RegprocedureArray => 2207, + Inner::RegoperArray => 2208, + Inner::RegoperatorArray => 2209, + Inner::RegclassArray => 2210, + Inner::RegtypeArray => 2211, + Inner::Record => 2249, + Inner::Cstring => 2275, + Inner::Any => 2276, + Inner::Anyarray => 2277, + Inner::Void => 2278, + Inner::Trigger => 2279, + Inner::LanguageHandler => 2280, + Inner::Internal => 2281, + 
Inner::Anyelement => 2283, + Inner::RecordArray => 2287, + Inner::Anynonarray => 2776, + Inner::TxidSnapshotArray => 2949, + Inner::Uuid => 2950, + Inner::UuidArray => 2951, + Inner::TxidSnapshot => 2970, + Inner::FdwHandler => 3115, + Inner::PgLsn => 3220, + Inner::PgLsnArray => 3221, + Inner::TsmHandler => 3310, + Inner::PgNdistinct => 3361, + Inner::PgDependencies => 3402, + Inner::Anyenum => 3500, + Inner::TsVector => 3614, + Inner::Tsquery => 3615, + Inner::GtsVector => 3642, + Inner::TsVectorArray => 3643, + Inner::GtsVectorArray => 3644, + Inner::TsqueryArray => 3645, + Inner::Regconfig => 3734, + Inner::RegconfigArray => 3735, + Inner::Regdictionary => 3769, + Inner::RegdictionaryArray => 3770, + Inner::Jsonb => 3802, + Inner::JsonbArray => 3807, + Inner::AnyRange => 3831, + Inner::EventTrigger => 3838, + Inner::Int4Range => 3904, + Inner::Int4RangeArray => 3905, + Inner::NumRange => 3906, + Inner::NumRangeArray => 3907, + Inner::TsRange => 3908, + Inner::TsRangeArray => 3909, + Inner::TstzRange => 3910, + Inner::TstzRangeArray => 3911, + Inner::DateRange => 3912, + Inner::DateRangeArray => 3913, + Inner::Int8Range => 3926, + Inner::Int8RangeArray => 3927, + Inner::Jsonpath => 4072, + Inner::JsonpathArray => 4073, + Inner::Regnamespace => 4089, + Inner::RegnamespaceArray => 4090, + Inner::Regrole => 4096, + Inner::RegroleArray => 4097, + Inner::Regcollation => 4191, + Inner::RegcollationArray => 4192, + Inner::Int4multiRange => 4451, + Inner::NummultiRange => 4532, + Inner::TsmultiRange => 4533, + Inner::TstzmultiRange => 4534, + Inner::DatemultiRange => 4535, + Inner::Int8multiRange => 4536, + Inner::AnymultiRange => 4537, + Inner::AnycompatiblemultiRange => 4538, + Inner::PgBrinBloomSummary => 4600, + Inner::PgBrinMinmaxMultiSummary => 4601, + Inner::PgMcvList => 5017, + Inner::PgSnapshot => 5038, + Inner::PgSnapshotArray => 5039, + Inner::Xid8 => 5069, + Inner::Anycompatible => 5077, + Inner::Anycompatiblearray => 5078, + Inner::Anycompatiblenonarray => 
5079, + Inner::AnycompatibleRange => 5080, + Inner::Int4multiRangeArray => 6150, + Inner::NummultiRangeArray => 6151, + Inner::TsmultiRangeArray => 6152, + Inner::TstzmultiRangeArray => 6153, + Inner::DatemultiRangeArray => 6155, + Inner::Int8multiRangeArray => 6157, + Inner::Other(ref u) => u.oid, + } + } + + pub fn kind(&self) -> &Kind { + match *self { + Inner::Bool => &Kind::Simple, + Inner::Bytea => &Kind::Simple, + Inner::Char => &Kind::Simple, + Inner::Name => &Kind::Simple, + Inner::Int8 => &Kind::Simple, + Inner::Int2 => &Kind::Simple, + Inner::Int2Vector => &Kind::Array(Type(Inner::Int2)), + Inner::Int4 => &Kind::Simple, + Inner::Regproc => &Kind::Simple, + Inner::Text => &Kind::Simple, + Inner::Oid => &Kind::Simple, + Inner::Tid => &Kind::Simple, + Inner::Xid => &Kind::Simple, + Inner::Cid => &Kind::Simple, + Inner::OidVector => &Kind::Array(Type(Inner::Oid)), + Inner::PgDdlCommand => &Kind::Pseudo, + Inner::Json => &Kind::Simple, + Inner::Xml => &Kind::Simple, + Inner::XmlArray => &Kind::Array(Type(Inner::Xml)), + Inner::PgNodeTree => &Kind::Simple, + Inner::JsonArray => &Kind::Array(Type(Inner::Json)), + Inner::TableAmHandler => &Kind::Pseudo, + Inner::Xid8Array => &Kind::Array(Type(Inner::Xid8)), + Inner::IndexAmHandler => &Kind::Pseudo, + Inner::Point => &Kind::Simple, + Inner::Lseg => &Kind::Simple, + Inner::Path => &Kind::Simple, + Inner::Box => &Kind::Simple, + Inner::Polygon => &Kind::Simple, + Inner::Line => &Kind::Simple, + Inner::LineArray => &Kind::Array(Type(Inner::Line)), + Inner::Cidr => &Kind::Simple, + Inner::CidrArray => &Kind::Array(Type(Inner::Cidr)), + Inner::Float4 => &Kind::Simple, + Inner::Float8 => &Kind::Simple, + Inner::Unknown => &Kind::Simple, + Inner::Circle => &Kind::Simple, + Inner::CircleArray => &Kind::Array(Type(Inner::Circle)), + Inner::Macaddr8 => &Kind::Simple, + Inner::Macaddr8Array => &Kind::Array(Type(Inner::Macaddr8)), + Inner::Money => &Kind::Simple, + Inner::MoneyArray => &Kind::Array(Type(Inner::Money)), + 
Inner::Macaddr => &Kind::Simple, + Inner::Inet => &Kind::Simple, + Inner::BoolArray => &Kind::Array(Type(Inner::Bool)), + Inner::ByteaArray => &Kind::Array(Type(Inner::Bytea)), + Inner::CharArray => &Kind::Array(Type(Inner::Char)), + Inner::NameArray => &Kind::Array(Type(Inner::Name)), + Inner::Int2Array => &Kind::Array(Type(Inner::Int2)), + Inner::Int2VectorArray => &Kind::Array(Type(Inner::Int2Vector)), + Inner::Int4Array => &Kind::Array(Type(Inner::Int4)), + Inner::RegprocArray => &Kind::Array(Type(Inner::Regproc)), + Inner::TextArray => &Kind::Array(Type(Inner::Text)), + Inner::TidArray => &Kind::Array(Type(Inner::Tid)), + Inner::XidArray => &Kind::Array(Type(Inner::Xid)), + Inner::CidArray => &Kind::Array(Type(Inner::Cid)), + Inner::OidVectorArray => &Kind::Array(Type(Inner::OidVector)), + Inner::BpcharArray => &Kind::Array(Type(Inner::Bpchar)), + Inner::VarcharArray => &Kind::Array(Type(Inner::Varchar)), + Inner::Int8Array => &Kind::Array(Type(Inner::Int8)), + Inner::PointArray => &Kind::Array(Type(Inner::Point)), + Inner::LsegArray => &Kind::Array(Type(Inner::Lseg)), + Inner::PathArray => &Kind::Array(Type(Inner::Path)), + Inner::BoxArray => &Kind::Array(Type(Inner::Box)), + Inner::Float4Array => &Kind::Array(Type(Inner::Float4)), + Inner::Float8Array => &Kind::Array(Type(Inner::Float8)), + Inner::PolygonArray => &Kind::Array(Type(Inner::Polygon)), + Inner::OidArray => &Kind::Array(Type(Inner::Oid)), + Inner::Aclitem => &Kind::Simple, + Inner::AclitemArray => &Kind::Array(Type(Inner::Aclitem)), + Inner::MacaddrArray => &Kind::Array(Type(Inner::Macaddr)), + Inner::InetArray => &Kind::Array(Type(Inner::Inet)), + Inner::Bpchar => &Kind::Simple, + Inner::Varchar => &Kind::Simple, + Inner::Date => &Kind::Simple, + Inner::Time => &Kind::Simple, + Inner::Timestamp => &Kind::Simple, + Inner::TimestampArray => &Kind::Array(Type(Inner::Timestamp)), + Inner::DateArray => &Kind::Array(Type(Inner::Date)), + Inner::TimeArray => &Kind::Array(Type(Inner::Time)), + 
Inner::Timestamptz => &Kind::Simple, + Inner::TimestamptzArray => &Kind::Array(Type(Inner::Timestamptz)), + Inner::Interval => &Kind::Simple, + Inner::IntervalArray => &Kind::Array(Type(Inner::Interval)), + Inner::NumericArray => &Kind::Array(Type(Inner::Numeric)), + Inner::CstringArray => &Kind::Array(Type(Inner::Cstring)), + Inner::Timetz => &Kind::Simple, + Inner::TimetzArray => &Kind::Array(Type(Inner::Timetz)), + Inner::Bit => &Kind::Simple, + Inner::BitArray => &Kind::Array(Type(Inner::Bit)), + Inner::Varbit => &Kind::Simple, + Inner::VarbitArray => &Kind::Array(Type(Inner::Varbit)), + Inner::Numeric => &Kind::Simple, + Inner::Refcursor => &Kind::Simple, + Inner::RefcursorArray => &Kind::Array(Type(Inner::Refcursor)), + Inner::Regprocedure => &Kind::Simple, + Inner::Regoper => &Kind::Simple, + Inner::Regoperator => &Kind::Simple, + Inner::Regclass => &Kind::Simple, + Inner::Regtype => &Kind::Simple, + Inner::RegprocedureArray => &Kind::Array(Type(Inner::Regprocedure)), + Inner::RegoperArray => &Kind::Array(Type(Inner::Regoper)), + Inner::RegoperatorArray => &Kind::Array(Type(Inner::Regoperator)), + Inner::RegclassArray => &Kind::Array(Type(Inner::Regclass)), + Inner::RegtypeArray => &Kind::Array(Type(Inner::Regtype)), + Inner::Record => &Kind::Pseudo, + Inner::Cstring => &Kind::Pseudo, + Inner::Any => &Kind::Pseudo, + Inner::Anyarray => &Kind::Pseudo, + Inner::Void => &Kind::Pseudo, + Inner::Trigger => &Kind::Pseudo, + Inner::LanguageHandler => &Kind::Pseudo, + Inner::Internal => &Kind::Pseudo, + Inner::Anyelement => &Kind::Pseudo, + Inner::RecordArray => &Kind::Pseudo, + Inner::Anynonarray => &Kind::Pseudo, + Inner::TxidSnapshotArray => &Kind::Array(Type(Inner::TxidSnapshot)), + Inner::Uuid => &Kind::Simple, + Inner::UuidArray => &Kind::Array(Type(Inner::Uuid)), + Inner::TxidSnapshot => &Kind::Simple, + Inner::FdwHandler => &Kind::Pseudo, + Inner::PgLsn => &Kind::Simple, + Inner::PgLsnArray => &Kind::Array(Type(Inner::PgLsn)), + Inner::TsmHandler => 
&Kind::Pseudo, + Inner::PgNdistinct => &Kind::Simple, + Inner::PgDependencies => &Kind::Simple, + Inner::Anyenum => &Kind::Pseudo, + Inner::TsVector => &Kind::Simple, + Inner::Tsquery => &Kind::Simple, + Inner::GtsVector => &Kind::Simple, + Inner::TsVectorArray => &Kind::Array(Type(Inner::TsVector)), + Inner::GtsVectorArray => &Kind::Array(Type(Inner::GtsVector)), + Inner::TsqueryArray => &Kind::Array(Type(Inner::Tsquery)), + Inner::Regconfig => &Kind::Simple, + Inner::RegconfigArray => &Kind::Array(Type(Inner::Regconfig)), + Inner::Regdictionary => &Kind::Simple, + Inner::RegdictionaryArray => &Kind::Array(Type(Inner::Regdictionary)), + Inner::Jsonb => &Kind::Simple, + Inner::JsonbArray => &Kind::Array(Type(Inner::Jsonb)), + Inner::AnyRange => &Kind::Pseudo, + Inner::EventTrigger => &Kind::Pseudo, + Inner::Int4Range => &Kind::Range(Type(Inner::Int4)), + Inner::Int4RangeArray => &Kind::Array(Type(Inner::Int4Range)), + Inner::NumRange => &Kind::Range(Type(Inner::Numeric)), + Inner::NumRangeArray => &Kind::Array(Type(Inner::NumRange)), + Inner::TsRange => &Kind::Range(Type(Inner::Timestamp)), + Inner::TsRangeArray => &Kind::Array(Type(Inner::TsRange)), + Inner::TstzRange => &Kind::Range(Type(Inner::Timestamptz)), + Inner::TstzRangeArray => &Kind::Array(Type(Inner::TstzRange)), + Inner::DateRange => &Kind::Range(Type(Inner::Date)), + Inner::DateRangeArray => &Kind::Array(Type(Inner::DateRange)), + Inner::Int8Range => &Kind::Range(Type(Inner::Int8)), + Inner::Int8RangeArray => &Kind::Array(Type(Inner::Int8Range)), + Inner::Jsonpath => &Kind::Simple, + Inner::JsonpathArray => &Kind::Array(Type(Inner::Jsonpath)), + Inner::Regnamespace => &Kind::Simple, + Inner::RegnamespaceArray => &Kind::Array(Type(Inner::Regnamespace)), + Inner::Regrole => &Kind::Simple, + Inner::RegroleArray => &Kind::Array(Type(Inner::Regrole)), + Inner::Regcollation => &Kind::Simple, + Inner::RegcollationArray => &Kind::Array(Type(Inner::Regcollation)), + Inner::Int4multiRange => 
&Kind::Multirange(Type(Inner::Int4)), + Inner::NummultiRange => &Kind::Multirange(Type(Inner::Numeric)), + Inner::TsmultiRange => &Kind::Multirange(Type(Inner::Timestamp)), + Inner::TstzmultiRange => &Kind::Multirange(Type(Inner::Timestamptz)), + Inner::DatemultiRange => &Kind::Multirange(Type(Inner::Date)), + Inner::Int8multiRange => &Kind::Multirange(Type(Inner::Int8)), + Inner::AnymultiRange => &Kind::Pseudo, + Inner::AnycompatiblemultiRange => &Kind::Pseudo, + Inner::PgBrinBloomSummary => &Kind::Simple, + Inner::PgBrinMinmaxMultiSummary => &Kind::Simple, + Inner::PgMcvList => &Kind::Simple, + Inner::PgSnapshot => &Kind::Simple, + Inner::PgSnapshotArray => &Kind::Array(Type(Inner::PgSnapshot)), + Inner::Xid8 => &Kind::Simple, + Inner::Anycompatible => &Kind::Pseudo, + Inner::Anycompatiblearray => &Kind::Pseudo, + Inner::Anycompatiblenonarray => &Kind::Pseudo, + Inner::AnycompatibleRange => &Kind::Pseudo, + Inner::Int4multiRangeArray => &Kind::Array(Type(Inner::Int4multiRange)), + Inner::NummultiRangeArray => &Kind::Array(Type(Inner::NummultiRange)), + Inner::TsmultiRangeArray => &Kind::Array(Type(Inner::TsmultiRange)), + Inner::TstzmultiRangeArray => &Kind::Array(Type(Inner::TstzmultiRange)), + Inner::DatemultiRangeArray => &Kind::Array(Type(Inner::DatemultiRange)), + Inner::Int8multiRangeArray => &Kind::Array(Type(Inner::Int8multiRange)), + Inner::Other(ref u) => &u.kind, + } + } + + pub fn name(&self) -> &str { + match *self { + Inner::Bool => "bool", + Inner::Bytea => "bytea", + Inner::Char => "char", + Inner::Name => "name", + Inner::Int8 => "int8", + Inner::Int2 => "int2", + Inner::Int2Vector => "int2vector", + Inner::Int4 => "int4", + Inner::Regproc => "regproc", + Inner::Text => "text", + Inner::Oid => "oid", + Inner::Tid => "tid", + Inner::Xid => "xid", + Inner::Cid => "cid", + Inner::OidVector => "oidvector", + Inner::PgDdlCommand => "pg_ddl_command", + Inner::Json => "json", + Inner::Xml => "xml", + Inner::XmlArray => "_xml", + Inner::PgNodeTree => 
"pg_node_tree", + Inner::JsonArray => "_json", + Inner::TableAmHandler => "table_am_handler", + Inner::Xid8Array => "_xid8", + Inner::IndexAmHandler => "index_am_handler", + Inner::Point => "point", + Inner::Lseg => "lseg", + Inner::Path => "path", + Inner::Box => "box", + Inner::Polygon => "polygon", + Inner::Line => "line", + Inner::LineArray => "_line", + Inner::Cidr => "cidr", + Inner::CidrArray => "_cidr", + Inner::Float4 => "float4", + Inner::Float8 => "float8", + Inner::Unknown => "unknown", + Inner::Circle => "circle", + Inner::CircleArray => "_circle", + Inner::Macaddr8 => "macaddr8", + Inner::Macaddr8Array => "_macaddr8", + Inner::Money => "money", + Inner::MoneyArray => "_money", + Inner::Macaddr => "macaddr", + Inner::Inet => "inet", + Inner::BoolArray => "_bool", + Inner::ByteaArray => "_bytea", + Inner::CharArray => "_char", + Inner::NameArray => "_name", + Inner::Int2Array => "_int2", + Inner::Int2VectorArray => "_int2vector", + Inner::Int4Array => "_int4", + Inner::RegprocArray => "_regproc", + Inner::TextArray => "_text", + Inner::TidArray => "_tid", + Inner::XidArray => "_xid", + Inner::CidArray => "_cid", + Inner::OidVectorArray => "_oidvector", + Inner::BpcharArray => "_bpchar", + Inner::VarcharArray => "_varchar", + Inner::Int8Array => "_int8", + Inner::PointArray => "_point", + Inner::LsegArray => "_lseg", + Inner::PathArray => "_path", + Inner::BoxArray => "_box", + Inner::Float4Array => "_float4", + Inner::Float8Array => "_float8", + Inner::PolygonArray => "_polygon", + Inner::OidArray => "_oid", + Inner::Aclitem => "aclitem", + Inner::AclitemArray => "_aclitem", + Inner::MacaddrArray => "_macaddr", + Inner::InetArray => "_inet", + Inner::Bpchar => "bpchar", + Inner::Varchar => "varchar", + Inner::Date => "date", + Inner::Time => "time", + Inner::Timestamp => "timestamp", + Inner::TimestampArray => "_timestamp", + Inner::DateArray => "_date", + Inner::TimeArray => "_time", + Inner::Timestamptz => "timestamptz", + Inner::TimestamptzArray => 
"_timestamptz", + Inner::Interval => "interval", + Inner::IntervalArray => "_interval", + Inner::NumericArray => "_numeric", + Inner::CstringArray => "_cstring", + Inner::Timetz => "timetz", + Inner::TimetzArray => "_timetz", + Inner::Bit => "bit", + Inner::BitArray => "_bit", + Inner::Varbit => "varbit", + Inner::VarbitArray => "_varbit", + Inner::Numeric => "numeric", + Inner::Refcursor => "refcursor", + Inner::RefcursorArray => "_refcursor", + Inner::Regprocedure => "regprocedure", + Inner::Regoper => "regoper", + Inner::Regoperator => "regoperator", + Inner::Regclass => "regclass", + Inner::Regtype => "regtype", + Inner::RegprocedureArray => "_regprocedure", + Inner::RegoperArray => "_regoper", + Inner::RegoperatorArray => "_regoperator", + Inner::RegclassArray => "_regclass", + Inner::RegtypeArray => "_regtype", + Inner::Record => "record", + Inner::Cstring => "cstring", + Inner::Any => "any", + Inner::Anyarray => "anyarray", + Inner::Void => "void", + Inner::Trigger => "trigger", + Inner::LanguageHandler => "language_handler", + Inner::Internal => "internal", + Inner::Anyelement => "anyelement", + Inner::RecordArray => "_record", + Inner::Anynonarray => "anynonarray", + Inner::TxidSnapshotArray => "_txid_snapshot", + Inner::Uuid => "uuid", + Inner::UuidArray => "_uuid", + Inner::TxidSnapshot => "txid_snapshot", + Inner::FdwHandler => "fdw_handler", + Inner::PgLsn => "pg_lsn", + Inner::PgLsnArray => "_pg_lsn", + Inner::TsmHandler => "tsm_handler", + Inner::PgNdistinct => "pg_ndistinct", + Inner::PgDependencies => "pg_dependencies", + Inner::Anyenum => "anyenum", + Inner::TsVector => "tsvector", + Inner::Tsquery => "tsquery", + Inner::GtsVector => "gtsvector", + Inner::TsVectorArray => "_tsvector", + Inner::GtsVectorArray => "_gtsvector", + Inner::TsqueryArray => "_tsquery", + Inner::Regconfig => "regconfig", + Inner::RegconfigArray => "_regconfig", + Inner::Regdictionary => "regdictionary", + Inner::RegdictionaryArray => "_regdictionary", + Inner::Jsonb => 
"jsonb", + Inner::JsonbArray => "_jsonb", + Inner::AnyRange => "anyrange", + Inner::EventTrigger => "event_trigger", + Inner::Int4Range => "int4range", + Inner::Int4RangeArray => "_int4range", + Inner::NumRange => "numrange", + Inner::NumRangeArray => "_numrange", + Inner::TsRange => "tsrange", + Inner::TsRangeArray => "_tsrange", + Inner::TstzRange => "tstzrange", + Inner::TstzRangeArray => "_tstzrange", + Inner::DateRange => "daterange", + Inner::DateRangeArray => "_daterange", + Inner::Int8Range => "int8range", + Inner::Int8RangeArray => "_int8range", + Inner::Jsonpath => "jsonpath", + Inner::JsonpathArray => "_jsonpath", + Inner::Regnamespace => "regnamespace", + Inner::RegnamespaceArray => "_regnamespace", + Inner::Regrole => "regrole", + Inner::RegroleArray => "_regrole", + Inner::Regcollation => "regcollation", + Inner::RegcollationArray => "_regcollation", + Inner::Int4multiRange => "int4multirange", + Inner::NummultiRange => "nummultirange", + Inner::TsmultiRange => "tsmultirange", + Inner::TstzmultiRange => "tstzmultirange", + Inner::DatemultiRange => "datemultirange", + Inner::Int8multiRange => "int8multirange", + Inner::AnymultiRange => "anymultirange", + Inner::AnycompatiblemultiRange => "anycompatiblemultirange", + Inner::PgBrinBloomSummary => "pg_brin_bloom_summary", + Inner::PgBrinMinmaxMultiSummary => "pg_brin_minmax_multi_summary", + Inner::PgMcvList => "pg_mcv_list", + Inner::PgSnapshot => "pg_snapshot", + Inner::PgSnapshotArray => "_pg_snapshot", + Inner::Xid8 => "xid8", + Inner::Anycompatible => "anycompatible", + Inner::Anycompatiblearray => "anycompatiblearray", + Inner::Anycompatiblenonarray => "anycompatiblenonarray", + Inner::AnycompatibleRange => "anycompatiblerange", + Inner::Int4multiRangeArray => "_int4multirange", + Inner::NummultiRangeArray => "_nummultirange", + Inner::TsmultiRangeArray => "_tsmultirange", + Inner::TstzmultiRangeArray => "_tstzmultirange", + Inner::DatemultiRangeArray => "_datemultirange", + 
Inner::Int8multiRangeArray => "_int8multirange", + Inner::Other(ref u) => &u.name, + } + } +} +impl Type { + /// BOOL - boolean, 'true'/'false' + pub const BOOL: Type = Type(Inner::Bool); + + /// BYTEA - variable-length string, binary values escaped + pub const BYTEA: Type = Type(Inner::Bytea); + + /// CHAR - single character + pub const CHAR: Type = Type(Inner::Char); + + /// NAME - 63-byte type for storing system identifiers + pub const NAME: Type = Type(Inner::Name); + + /// INT8 - ~18 digit integer, 8-byte storage + pub const INT8: Type = Type(Inner::Int8); + + /// INT2 - -32 thousand to 32 thousand, 2-byte storage + pub const INT2: Type = Type(Inner::Int2); + + /// INT2VECTOR - array of int2, used in system tables + pub const INT2_VECTOR: Type = Type(Inner::Int2Vector); + + /// INT4 - -2 billion to 2 billion integer, 4-byte storage + pub const INT4: Type = Type(Inner::Int4); + + /// REGPROC - registered procedure + pub const REGPROC: Type = Type(Inner::Regproc); + + /// TEXT - variable-length string, no limit specified + pub const TEXT: Type = Type(Inner::Text); + + /// OID - object identifier(oid), maximum 4 billion + pub const OID: Type = Type(Inner::Oid); + + /// TID - (block, offset), physical location of tuple + pub const TID: Type = Type(Inner::Tid); + + /// XID - transaction id + pub const XID: Type = Type(Inner::Xid); + + /// CID - command identifier type, sequence in transaction id + pub const CID: Type = Type(Inner::Cid); + + /// OIDVECTOR - array of oids, used in system tables + pub const OID_VECTOR: Type = Type(Inner::OidVector); + + /// PG_DDL_COMMAND - internal type for passing CollectedCommand + pub const PG_DDL_COMMAND: Type = Type(Inner::PgDdlCommand); + + /// JSON - JSON stored as text + pub const JSON: Type = Type(Inner::Json); + + /// XML - XML content + pub const XML: Type = Type(Inner::Xml); + + /// XML[] + pub const XML_ARRAY: Type = Type(Inner::XmlArray); + + /// PG_NODE_TREE - string representing an internal node tree + pub const 
PG_NODE_TREE: Type = Type(Inner::PgNodeTree); + + /// JSON[] + pub const JSON_ARRAY: Type = Type(Inner::JsonArray); + + /// TABLE_AM_HANDLER + pub const TABLE_AM_HANDLER: Type = Type(Inner::TableAmHandler); + + /// XID8[] + pub const XID8_ARRAY: Type = Type(Inner::Xid8Array); + + /// INDEX_AM_HANDLER - pseudo-type for the result of an index AM handler function + pub const INDEX_AM_HANDLER: Type = Type(Inner::IndexAmHandler); + + /// POINT - geometric point '(x, y)' + pub const POINT: Type = Type(Inner::Point); + + /// LSEG - geometric line segment '(pt1,pt2)' + pub const LSEG: Type = Type(Inner::Lseg); + + /// PATH - geometric path '(pt1,...)' + pub const PATH: Type = Type(Inner::Path); + + /// BOX - geometric box '(lower left,upper right)' + pub const BOX: Type = Type(Inner::Box); + + /// POLYGON - geometric polygon '(pt1,...)' + pub const POLYGON: Type = Type(Inner::Polygon); + + /// LINE - geometric line + pub const LINE: Type = Type(Inner::Line); + + /// LINE[] + pub const LINE_ARRAY: Type = Type(Inner::LineArray); + + /// CIDR - network IP address/netmask, network address + pub const CIDR: Type = Type(Inner::Cidr); + + /// CIDR[] + pub const CIDR_ARRAY: Type = Type(Inner::CidrArray); + + /// FLOAT4 - single-precision floating point number, 4-byte storage + pub const FLOAT4: Type = Type(Inner::Float4); + + /// FLOAT8 - double-precision floating point number, 8-byte storage + pub const FLOAT8: Type = Type(Inner::Float8); + + /// UNKNOWN - pseudo-type representing an undetermined type + pub const UNKNOWN: Type = Type(Inner::Unknown); + + /// CIRCLE - geometric circle '(center,radius)' + pub const CIRCLE: Type = Type(Inner::Circle); + + /// CIRCLE[] + pub const CIRCLE_ARRAY: Type = Type(Inner::CircleArray); + + /// MACADDR8 - XX:XX:XX:XX:XX:XX:XX:XX, MAC address + pub const MACADDR8: Type = Type(Inner::Macaddr8); + + /// MACADDR8[] + pub const MACADDR8_ARRAY: Type = Type(Inner::Macaddr8Array); + + /// MONEY - monetary amounts, $d,ddd.cc + pub const MONEY: Type = 
Type(Inner::Money); + + /// MONEY[] + pub const MONEY_ARRAY: Type = Type(Inner::MoneyArray); + + /// MACADDR - XX:XX:XX:XX:XX:XX, MAC address + pub const MACADDR: Type = Type(Inner::Macaddr); + + /// INET - IP address/netmask, host address, netmask optional + pub const INET: Type = Type(Inner::Inet); + + /// BOOL[] + pub const BOOL_ARRAY: Type = Type(Inner::BoolArray); + + /// BYTEA[] + pub const BYTEA_ARRAY: Type = Type(Inner::ByteaArray); + + /// CHAR[] + pub const CHAR_ARRAY: Type = Type(Inner::CharArray); + + /// NAME[] + pub const NAME_ARRAY: Type = Type(Inner::NameArray); + + /// INT2[] + pub const INT2_ARRAY: Type = Type(Inner::Int2Array); + + /// INT2VECTOR[] + pub const INT2_VECTOR_ARRAY: Type = Type(Inner::Int2VectorArray); + + /// INT4[] + pub const INT4_ARRAY: Type = Type(Inner::Int4Array); + + /// REGPROC[] + pub const REGPROC_ARRAY: Type = Type(Inner::RegprocArray); + + /// TEXT[] + pub const TEXT_ARRAY: Type = Type(Inner::TextArray); + + /// TID[] + pub const TID_ARRAY: Type = Type(Inner::TidArray); + + /// XID[] + pub const XID_ARRAY: Type = Type(Inner::XidArray); + + /// CID[] + pub const CID_ARRAY: Type = Type(Inner::CidArray); + + /// OIDVECTOR[] + pub const OID_VECTOR_ARRAY: Type = Type(Inner::OidVectorArray); + + /// BPCHAR[] + pub const BPCHAR_ARRAY: Type = Type(Inner::BpcharArray); + + /// VARCHAR[] + pub const VARCHAR_ARRAY: Type = Type(Inner::VarcharArray); + + /// INT8[] + pub const INT8_ARRAY: Type = Type(Inner::Int8Array); + + /// POINT[] + pub const POINT_ARRAY: Type = Type(Inner::PointArray); + + /// LSEG[] + pub const LSEG_ARRAY: Type = Type(Inner::LsegArray); + + /// PATH[] + pub const PATH_ARRAY: Type = Type(Inner::PathArray); + + /// BOX[] + pub const BOX_ARRAY: Type = Type(Inner::BoxArray); + + /// FLOAT4[] + pub const FLOAT4_ARRAY: Type = Type(Inner::Float4Array); + + /// FLOAT8[] + pub const FLOAT8_ARRAY: Type = Type(Inner::Float8Array); + + /// POLYGON[] + pub const POLYGON_ARRAY: Type = Type(Inner::PolygonArray); + + /// OID[] 
+ pub const OID_ARRAY: Type = Type(Inner::OidArray); + + /// ACLITEM - access control list + pub const ACLITEM: Type = Type(Inner::Aclitem); + + /// ACLITEM[] + pub const ACLITEM_ARRAY: Type = Type(Inner::AclitemArray); + + /// MACADDR[] + pub const MACADDR_ARRAY: Type = Type(Inner::MacaddrArray); + + /// INET[] + pub const INET_ARRAY: Type = Type(Inner::InetArray); + + /// BPCHAR - char(length), blank-padded string, fixed storage length + pub const BPCHAR: Type = Type(Inner::Bpchar); + + /// VARCHAR - varchar(length), non-blank-padded string, variable storage length + pub const VARCHAR: Type = Type(Inner::Varchar); + + /// DATE - date + pub const DATE: Type = Type(Inner::Date); + + /// TIME - time of day + pub const TIME: Type = Type(Inner::Time); + + /// TIMESTAMP - date and time + pub const TIMESTAMP: Type = Type(Inner::Timestamp); + + /// TIMESTAMP[] + pub const TIMESTAMP_ARRAY: Type = Type(Inner::TimestampArray); + + /// DATE[] + pub const DATE_ARRAY: Type = Type(Inner::DateArray); + + /// TIME[] + pub const TIME_ARRAY: Type = Type(Inner::TimeArray); + + /// TIMESTAMPTZ - date and time with time zone + pub const TIMESTAMPTZ: Type = Type(Inner::Timestamptz); + + /// TIMESTAMPTZ[] + pub const TIMESTAMPTZ_ARRAY: Type = Type(Inner::TimestamptzArray); + + /// INTERVAL - @ <number> <units>, time interval + pub const INTERVAL: Type = Type(Inner::Interval); + + /// INTERVAL[] + pub const INTERVAL_ARRAY: Type = Type(Inner::IntervalArray); + + /// NUMERIC[] + pub const NUMERIC_ARRAY: Type = Type(Inner::NumericArray); + + /// CSTRING[] + pub const CSTRING_ARRAY: Type = Type(Inner::CstringArray); + + /// TIMETZ - time of day with time zone + pub const TIMETZ: Type = Type(Inner::Timetz); + + /// TIMETZ[] + pub const TIMETZ_ARRAY: Type = Type(Inner::TimetzArray); + + /// BIT - fixed-length bit string + pub const BIT: Type = Type(Inner::Bit); + + /// BIT[] + pub const BIT_ARRAY: Type = Type(Inner::BitArray); + + /// VARBIT - variable-length bit string + pub const VARBIT: 
Type = Type(Inner::Varbit); + + /// VARBIT[] + pub const VARBIT_ARRAY: Type = Type(Inner::VarbitArray); + + /// NUMERIC - numeric(precision, decimal), arbitrary precision number + pub const NUMERIC: Type = Type(Inner::Numeric); + + /// REFCURSOR - reference to cursor (portal name) + pub const REFCURSOR: Type = Type(Inner::Refcursor); + + /// REFCURSOR[] + pub const REFCURSOR_ARRAY: Type = Type(Inner::RefcursorArray); + + /// REGPROCEDURE - registered procedure (with args) + pub const REGPROCEDURE: Type = Type(Inner::Regprocedure); + + /// REGOPER - registered operator + pub const REGOPER: Type = Type(Inner::Regoper); + + /// REGOPERATOR - registered operator (with args) + pub const REGOPERATOR: Type = Type(Inner::Regoperator); + + /// REGCLASS - registered class + pub const REGCLASS: Type = Type(Inner::Regclass); + + /// REGTYPE - registered type + pub const REGTYPE: Type = Type(Inner::Regtype); + + /// REGPROCEDURE[] + pub const REGPROCEDURE_ARRAY: Type = Type(Inner::RegprocedureArray); + + /// REGOPER[] + pub const REGOPER_ARRAY: Type = Type(Inner::RegoperArray); + + /// REGOPERATOR[] + pub const REGOPERATOR_ARRAY: Type = Type(Inner::RegoperatorArray); + + /// REGCLASS[] + pub const REGCLASS_ARRAY: Type = Type(Inner::RegclassArray); + + /// REGTYPE[] + pub const REGTYPE_ARRAY: Type = Type(Inner::RegtypeArray); + + /// RECORD - pseudo-type representing any composite type + pub const RECORD: Type = Type(Inner::Record); + + /// CSTRING - C-style string + pub const CSTRING: Type = Type(Inner::Cstring); + + /// ANY - pseudo-type representing any type + pub const ANY: Type = Type(Inner::Any); + + /// ANYARRAY - pseudo-type representing a polymorphic array type + pub const ANYARRAY: Type = Type(Inner::Anyarray); + + /// VOID - pseudo-type for the result of a function with no real result + pub const VOID: Type = Type(Inner::Void); + + /// TRIGGER - pseudo-type for the result of a trigger function + pub const TRIGGER: Type = Type(Inner::Trigger); + + /// LANGUAGE_HANDLER 
- pseudo-type for the result of a language handler function + pub const LANGUAGE_HANDLER: Type = Type(Inner::LanguageHandler); + + /// INTERNAL - pseudo-type representing an internal data structure + pub const INTERNAL: Type = Type(Inner::Internal); + + /// ANYELEMENT - pseudo-type representing a polymorphic base type + pub const ANYELEMENT: Type = Type(Inner::Anyelement); + + /// RECORD[] + pub const RECORD_ARRAY: Type = Type(Inner::RecordArray); + + /// ANYNONARRAY - pseudo-type representing a polymorphic base type that is not an array + pub const ANYNONARRAY: Type = Type(Inner::Anynonarray); + + /// TXID_SNAPSHOT[] + pub const TXID_SNAPSHOT_ARRAY: Type = Type(Inner::TxidSnapshotArray); + + /// UUID - UUID datatype + pub const UUID: Type = Type(Inner::Uuid); + + /// UUID[] + pub const UUID_ARRAY: Type = Type(Inner::UuidArray); + + /// TXID_SNAPSHOT - txid snapshot + pub const TXID_SNAPSHOT: Type = Type(Inner::TxidSnapshot); + + /// FDW_HANDLER - pseudo-type for the result of an FDW handler function + pub const FDW_HANDLER: Type = Type(Inner::FdwHandler); + + /// PG_LSN - PostgreSQL LSN datatype + pub const PG_LSN: Type = Type(Inner::PgLsn); + + /// PG_LSN[] + pub const PG_LSN_ARRAY: Type = Type(Inner::PgLsnArray); + + /// TSM_HANDLER - pseudo-type for the result of a tablesample method function + pub const TSM_HANDLER: Type = Type(Inner::TsmHandler); + + /// PG_NDISTINCT - multivariate ndistinct coefficients + pub const PG_NDISTINCT: Type = Type(Inner::PgNdistinct); + + /// PG_DEPENDENCIES - multivariate dependencies + pub const PG_DEPENDENCIES: Type = Type(Inner::PgDependencies); + + /// ANYENUM - pseudo-type representing a polymorphic base type that is an enum + pub const ANYENUM: Type = Type(Inner::Anyenum); + + /// TSVECTOR - text representation for text search + pub const TS_VECTOR: Type = Type(Inner::TsVector); + + /// TSQUERY - query representation for text search + pub const TSQUERY: Type = Type(Inner::Tsquery); + + /// GTSVECTOR - GiST index internal 
text representation for text search + pub const GTS_VECTOR: Type = Type(Inner::GtsVector); + + /// TSVECTOR[] + pub const TS_VECTOR_ARRAY: Type = Type(Inner::TsVectorArray); + + /// GTSVECTOR[] + pub const GTS_VECTOR_ARRAY: Type = Type(Inner::GtsVectorArray); + + /// TSQUERY[] + pub const TSQUERY_ARRAY: Type = Type(Inner::TsqueryArray); + + /// REGCONFIG - registered text search configuration + pub const REGCONFIG: Type = Type(Inner::Regconfig); + + /// REGCONFIG[] + pub const REGCONFIG_ARRAY: Type = Type(Inner::RegconfigArray); + + /// REGDICTIONARY - registered text search dictionary + pub const REGDICTIONARY: Type = Type(Inner::Regdictionary); + + /// REGDICTIONARY[] + pub const REGDICTIONARY_ARRAY: Type = Type(Inner::RegdictionaryArray); + + /// JSONB - Binary JSON + pub const JSONB: Type = Type(Inner::Jsonb); + + /// JSONB[] + pub const JSONB_ARRAY: Type = Type(Inner::JsonbArray); + + /// ANYRANGE - pseudo-type representing a range over a polymorphic base type + pub const ANY_RANGE: Type = Type(Inner::AnyRange); + + /// EVENT_TRIGGER - pseudo-type for the result of an event trigger function + pub const EVENT_TRIGGER: Type = Type(Inner::EventTrigger); + + /// INT4RANGE - range of integers + pub const INT4_RANGE: Type = Type(Inner::Int4Range); + + /// INT4RANGE[] + pub const INT4_RANGE_ARRAY: Type = Type(Inner::Int4RangeArray); + + /// NUMRANGE - range of numerics + pub const NUM_RANGE: Type = Type(Inner::NumRange); + + /// NUMRANGE[] + pub const NUM_RANGE_ARRAY: Type = Type(Inner::NumRangeArray); + + /// TSRANGE - range of timestamps without time zone + pub const TS_RANGE: Type = Type(Inner::TsRange); + + /// TSRANGE[] + pub const TS_RANGE_ARRAY: Type = Type(Inner::TsRangeArray); + + /// TSTZRANGE - range of timestamps with time zone + pub const TSTZ_RANGE: Type = Type(Inner::TstzRange); + + /// TSTZRANGE[] + pub const TSTZ_RANGE_ARRAY: Type = Type(Inner::TstzRangeArray); + + /// DATERANGE - range of dates + pub const DATE_RANGE: Type = Type(Inner::DateRange); 
+ + /// DATERANGE[] + pub const DATE_RANGE_ARRAY: Type = Type(Inner::DateRangeArray); + + /// INT8RANGE - range of bigints + pub const INT8_RANGE: Type = Type(Inner::Int8Range); + + /// INT8RANGE[] + pub const INT8_RANGE_ARRAY: Type = Type(Inner::Int8RangeArray); + + /// JSONPATH - JSON path + pub const JSONPATH: Type = Type(Inner::Jsonpath); + + /// JSONPATH[] + pub const JSONPATH_ARRAY: Type = Type(Inner::JsonpathArray); + + /// REGNAMESPACE - registered namespace + pub const REGNAMESPACE: Type = Type(Inner::Regnamespace); + + /// REGNAMESPACE[] + pub const REGNAMESPACE_ARRAY: Type = Type(Inner::RegnamespaceArray); + + /// REGROLE - registered role + pub const REGROLE: Type = Type(Inner::Regrole); + + /// REGROLE[] + pub const REGROLE_ARRAY: Type = Type(Inner::RegroleArray); + + /// REGCOLLATION - registered collation + pub const REGCOLLATION: Type = Type(Inner::Regcollation); + + /// REGCOLLATION[] + pub const REGCOLLATION_ARRAY: Type = Type(Inner::RegcollationArray); + + /// INT4MULTIRANGE - multirange of integers + pub const INT4MULTI_RANGE: Type = Type(Inner::Int4multiRange); + + /// NUMMULTIRANGE - multirange of numerics + pub const NUMMULTI_RANGE: Type = Type(Inner::NummultiRange); + + /// TSMULTIRANGE - multirange of timestamps without time zone + pub const TSMULTI_RANGE: Type = Type(Inner::TsmultiRange); + + /// TSTZMULTIRANGE - multirange of timestamps with time zone + pub const TSTZMULTI_RANGE: Type = Type(Inner::TstzmultiRange); + + /// DATEMULTIRANGE - multirange of dates + pub const DATEMULTI_RANGE: Type = Type(Inner::DatemultiRange); + + /// INT8MULTIRANGE - multirange of bigints + pub const INT8MULTI_RANGE: Type = Type(Inner::Int8multiRange); + + /// ANYMULTIRANGE - pseudo-type representing a polymorphic base type that is a multirange + pub const ANYMULTI_RANGE: Type = Type(Inner::AnymultiRange); + + /// ANYCOMPATIBLEMULTIRANGE - pseudo-type representing a multirange over a polymorphic common type + pub const ANYCOMPATIBLEMULTI_RANGE: Type = 
Type(Inner::AnycompatiblemultiRange); + + /// PG_BRIN_BLOOM_SUMMARY - BRIN bloom summary + pub const PG_BRIN_BLOOM_SUMMARY: Type = Type(Inner::PgBrinBloomSummary); + + /// PG_BRIN_MINMAX_MULTI_SUMMARY - BRIN minmax-multi summary + pub const PG_BRIN_MINMAX_MULTI_SUMMARY: Type = Type(Inner::PgBrinMinmaxMultiSummary); + + /// PG_MCV_LIST - multivariate MCV list + pub const PG_MCV_LIST: Type = Type(Inner::PgMcvList); + + /// PG_SNAPSHOT - snapshot + pub const PG_SNAPSHOT: Type = Type(Inner::PgSnapshot); + + /// PG_SNAPSHOT[] + pub const PG_SNAPSHOT_ARRAY: Type = Type(Inner::PgSnapshotArray); + + /// XID8 - full transaction id + pub const XID8: Type = Type(Inner::Xid8); + + /// ANYCOMPATIBLE - pseudo-type representing a polymorphic common type + pub const ANYCOMPATIBLE: Type = Type(Inner::Anycompatible); + + /// ANYCOMPATIBLEARRAY - pseudo-type representing an array of polymorphic common type elements + pub const ANYCOMPATIBLEARRAY: Type = Type(Inner::Anycompatiblearray); + + /// ANYCOMPATIBLENONARRAY - pseudo-type representing a polymorphic common type that is not an array + pub const ANYCOMPATIBLENONARRAY: Type = Type(Inner::Anycompatiblenonarray); + + /// ANYCOMPATIBLERANGE - pseudo-type representing a range over a polymorphic common type + pub const ANYCOMPATIBLE_RANGE: Type = Type(Inner::AnycompatibleRange); + + /// INT4MULTIRANGE[] + pub const INT4MULTI_RANGE_ARRAY: Type = Type(Inner::Int4multiRangeArray); + + /// NUMMULTIRANGE[] + pub const NUMMULTI_RANGE_ARRAY: Type = Type(Inner::NummultiRangeArray); + + /// TSMULTIRANGE[] + pub const TSMULTI_RANGE_ARRAY: Type = Type(Inner::TsmultiRangeArray); + + /// TSTZMULTIRANGE[] + pub const TSTZMULTI_RANGE_ARRAY: Type = Type(Inner::TstzmultiRangeArray); + + /// DATEMULTIRANGE[] + pub const DATEMULTI_RANGE_ARRAY: Type = Type(Inner::DatemultiRangeArray); + + /// INT8MULTIRANGE[] + pub const INT8MULTI_RANGE_ARRAY: Type = Type(Inner::Int8multiRangeArray); +} diff --git a/libs/proxy/tokio-postgres2/Cargo.toml 
b/libs/proxy/tokio-postgres2/Cargo.toml new file mode 100644 index 0000000000..7130c1b726 --- /dev/null +++ b/libs/proxy/tokio-postgres2/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "tokio-postgres2" +version = "0.1.0" +edition = "2018" +license = "MIT/Apache-2.0" + +[dependencies] +async-trait.workspace = true +bytes.workspace = true +byteorder.workspace = true +fallible-iterator.workspace = true +futures-util = { workspace = true, features = ["sink"] } +log = "0.4" +parking_lot.workspace = true +percent-encoding = "2.0" +pin-project-lite.workspace = true +phf = "0.11" +postgres-protocol2 = { path = "../postgres-protocol2" } +postgres-types2 = { path = "../postgres-types2" } +tokio = { workspace = true, features = ["io-util", "time", "net"] } +tokio-util = { workspace = true, features = ["codec"] } diff --git a/libs/proxy/tokio-postgres2/src/cancel_query.rs b/libs/proxy/tokio-postgres2/src/cancel_query.rs new file mode 100644 index 0000000000..cddbf16336 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/cancel_query.rs @@ -0,0 +1,40 @@ +use tokio::net::TcpStream; + +use crate::client::SocketConfig; +use crate::config::{Host, SslMode}; +use crate::tls::MakeTlsConnect; +use crate::{cancel_query_raw, connect_socket, Error}; +use std::io; + +pub(crate) async fn cancel_query( + config: Option, + ssl_mode: SslMode, + mut tls: T, + process_id: i32, + secret_key: i32, +) -> Result<(), Error> +where + T: MakeTlsConnect, +{ + let config = match config { + Some(config) => config, + None => { + return Err(Error::connect(io::Error::new( + io::ErrorKind::InvalidInput, + "unknown host", + ))) + } + }; + + let hostname = match &config.host { + Host::Tcp(host) => &**host, + }; + let tls = tls + .make_tls_connect(hostname) + .map_err(|e| Error::tls(e.into()))?; + + let socket = + connect_socket::connect_socket(&config.host, config.port, config.connect_timeout).await?; + + cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, process_id, secret_key).await +} diff --git 
a/libs/proxy/tokio-postgres2/src/cancel_query_raw.rs b/libs/proxy/tokio-postgres2/src/cancel_query_raw.rs new file mode 100644 index 0000000000..8c08296435 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/cancel_query_raw.rs @@ -0,0 +1,29 @@ +use crate::config::SslMode; +use crate::tls::TlsConnect; +use crate::{connect_tls, Error}; +use bytes::BytesMut; +use postgres_protocol2::message::frontend; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; + +pub async fn cancel_query_raw( + stream: S, + mode: SslMode, + tls: T, + process_id: i32, + secret_key: i32, +) -> Result<(), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsConnect, +{ + let mut stream = connect_tls::connect_tls(stream, mode, tls).await?; + + let mut buf = BytesMut::new(); + frontend::cancel_request(process_id, secret_key, &mut buf); + + stream.write_all(&buf).await.map_err(Error::io)?; + stream.flush().await.map_err(Error::io)?; + stream.shutdown().await.map_err(Error::io)?; + + Ok(()) +} diff --git a/libs/proxy/tokio-postgres2/src/cancel_token.rs b/libs/proxy/tokio-postgres2/src/cancel_token.rs new file mode 100644 index 0000000000..b949bf358f --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/cancel_token.rs @@ -0,0 +1,62 @@ +use crate::config::SslMode; +use crate::tls::TlsConnect; + +use crate::{cancel_query, client::SocketConfig, tls::MakeTlsConnect}; +use crate::{cancel_query_raw, Error}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpStream; + +/// The capability to request cancellation of in-progress queries on a +/// connection. +#[derive(Clone)] +pub struct CancelToken { + pub(crate) socket_config: Option, + pub(crate) ssl_mode: SslMode, + pub(crate) process_id: i32, + pub(crate) secret_key: i32, +} + +impl CancelToken { + /// Attempts to cancel the in-progress query on the connection associated + /// with this `CancelToken`. + /// + /// The server provides no information about whether a cancellation attempt was successful or not. 
An error will + /// only be returned if the client was unable to connect to the database. + /// + /// Cancellation is inherently racy. There is no guarantee that the + /// cancellation request will reach the server before the query terminates + /// normally, or that the connection associated with this token is still + /// active. + /// + /// Requires the `runtime` Cargo feature (enabled by default). + pub async fn cancel_query(&self, tls: T) -> Result<(), Error> + where + T: MakeTlsConnect, + { + cancel_query::cancel_query( + self.socket_config.clone(), + self.ssl_mode, + tls, + self.process_id, + self.secret_key, + ) + .await + } + + /// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new + /// connection itself. + pub async fn cancel_query_raw(&self, stream: S, tls: T) -> Result<(), Error> + where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsConnect, + { + cancel_query_raw::cancel_query_raw( + stream, + self.ssl_mode, + tls, + self.process_id, + self.secret_key, + ) + .await + } +} diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs new file mode 100644 index 0000000000..96200b71e7 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -0,0 +1,439 @@ +use crate::codec::{BackendMessages, FrontendMessage}; + +use crate::config::Host; +use crate::config::SslMode; +use crate::connection::{Request, RequestMessages}; + +use crate::query::RowStream; +use crate::simple_query::SimpleQueryStream; + +use crate::types::{Oid, ToSql, Type}; + +use crate::{ + prepare, query, simple_query, slice_iter, CancelToken, Error, ReadyForQueryStatus, Row, + SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder, +}; +use bytes::BytesMut; +use fallible_iterator::FallibleIterator; +use futures_util::{future, ready, TryStreamExt}; +use parking_lot::Mutex; +use postgres_protocol2::message::{backend::Message, frontend}; +use std::collections::HashMap; +use std::fmt; 
+use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::sync::mpsc; + +use std::time::Duration; + +pub struct Responses { + receiver: mpsc::Receiver, + cur: BackendMessages, +} + +impl Responses { + pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + match self.cur.next().map_err(Error::parse)? { + Some(Message::ErrorResponse(body)) => return Poll::Ready(Err(Error::db(body))), + Some(message) => return Poll::Ready(Ok(message)), + None => {} + } + + match ready!(self.receiver.poll_recv(cx)) { + Some(messages) => self.cur = messages, + None => return Poll::Ready(Err(Error::closed())), + } + } + } + + pub async fn next(&mut self) -> Result { + future::poll_fn(|cx| self.poll_next(cx)).await + } +} + +/// A cache of type info and prepared statements for fetching type info +/// (corresponding to the queries in the [prepare] module). +#[derive(Default)] +struct CachedTypeInfo { + /// A statement for basic information for a type from its + /// OID. Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_QUERY) (or its + /// fallback). + typeinfo: Option, + /// A statement for getting information for a composite type from its OID. + /// Corresponds to [TYPEINFO_COMPOSITE_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY). + typeinfo_composite: Option, + /// A statement for getting information for an enum type from its OID. + /// Corresponds to the enum type-info query in the [prepare] module (or + /// its fallback). + typeinfo_enum: Option, + + /// Cache of types already looked up. + types: HashMap, +} + +pub struct InnerClient { + sender: mpsc::UnboundedSender, + cached_typeinfo: Mutex, + + /// A buffer to use when writing out postgres commands. 
+ buffer: Mutex, +} + +impl InnerClient { + pub fn send(&self, messages: RequestMessages) -> Result { + let (sender, receiver) = mpsc::channel(1); + let request = Request { messages, sender }; + self.sender.send(request).map_err(|_| Error::closed())?; + + Ok(Responses { + receiver, + cur: BackendMessages::empty(), + }) + } + + pub fn typeinfo(&self) -> Option { + self.cached_typeinfo.lock().typeinfo.clone() + } + + pub fn set_typeinfo(&self, statement: &Statement) { + self.cached_typeinfo.lock().typeinfo = Some(statement.clone()); + } + + pub fn typeinfo_composite(&self) -> Option { + self.cached_typeinfo.lock().typeinfo_composite.clone() + } + + pub fn set_typeinfo_composite(&self, statement: &Statement) { + self.cached_typeinfo.lock().typeinfo_composite = Some(statement.clone()); + } + + pub fn typeinfo_enum(&self) -> Option { + self.cached_typeinfo.lock().typeinfo_enum.clone() + } + + pub fn set_typeinfo_enum(&self, statement: &Statement) { + self.cached_typeinfo.lock().typeinfo_enum = Some(statement.clone()); + } + + pub fn type_(&self, oid: Oid) -> Option { + self.cached_typeinfo.lock().types.get(&oid).cloned() + } + + pub fn set_type(&self, oid: Oid, type_: &Type) { + self.cached_typeinfo.lock().types.insert(oid, type_.clone()); + } + + /// Call the given function with a buffer to be used when writing out + /// postgres commands. + pub fn with_buf(&self, f: F) -> R + where + F: FnOnce(&mut BytesMut) -> R, + { + let mut buffer = self.buffer.lock(); + let r = f(&mut buffer); + buffer.clear(); + r + } +} + +#[derive(Clone)] +pub(crate) struct SocketConfig { + pub host: Host, + pub port: u16, + pub connect_timeout: Option, + // pub keepalive: Option, +} + +/// An asynchronous PostgreSQL client. +/// +/// The client is one half of what is returned when a connection is established. Users interact with the database +/// through this client object. 
+pub struct Client { + inner: Arc, + + socket_config: Option, + ssl_mode: SslMode, + process_id: i32, + secret_key: i32, +} + +impl Client { + pub(crate) fn new( + sender: mpsc::UnboundedSender, + ssl_mode: SslMode, + process_id: i32, + secret_key: i32, + ) -> Client { + Client { + inner: Arc::new(InnerClient { + sender, + cached_typeinfo: Default::default(), + buffer: Default::default(), + }), + + socket_config: None, + ssl_mode, + process_id, + secret_key, + } + } + + /// Returns process_id. + pub fn get_process_id(&self) -> i32 { + self.process_id + } + + pub(crate) fn inner(&self) -> &Arc { + &self.inner + } + + pub(crate) fn set_socket_config(&mut self, socket_config: SocketConfig) { + self.socket_config = Some(socket_config); + } + + /// Creates a new prepared statement. + /// + /// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc), + /// which are set when executed. Prepared statements can only be used with the connection that created them. + pub async fn prepare(&self, query: &str) -> Result { + self.prepare_typed(query, &[]).await + } + + /// Like `prepare`, but allows the types of query parameters to be explicitly specified. + /// + /// The list of types may be smaller than the number of parameters - the types of the remaining parameters will be + /// inferred. For example, `client.prepare_typed(query, &[])` is equivalent to `client.prepare(query)`. + pub async fn prepare_typed( + &self, + query: &str, + parameter_types: &[Type], + ) -> Result { + prepare::prepare(&self.inner, query, parameter_types).await + } + + /// Executes a statement, returning a vector of the resulting rows. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// The `statement` argument can either be a `Statement`, or a raw query string. 
If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + /// + /// # Panics + /// + /// Panics if the number of parameters provided does not match the number expected. + pub async fn query( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result, Error> + where + T: ?Sized + ToStatement, + { + self.query_raw(statement, slice_iter(params)) + .await? + .try_collect() + .await + } + + /// The maximally flexible version of [`query`]. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + /// + /// # Panics + /// + /// Panics if the number of parameters provided does not match the number expected. + /// + /// [`query`]: #method.query + pub async fn query_raw<'a, T, I>(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let statement = statement.__convert().into_statement(self).await?; + query::query(&self.inner, statement, params).await + } + + /// Pass text directly to the Postgres backend to allow it to sort out typing itself and + /// to save a roundtrip + pub async fn query_raw_txt(&self, statement: &str, params: I) -> Result + where + S: AsRef, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, + { + query::query_txt(&self.inner, statement, params).await + } + + /// Executes a statement, returning the number of rows modified. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. 
+ /// + /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + /// + /// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned. + /// + /// # Panics + /// + /// Panics if the number of parameters provided does not match the number expected. + pub async fn execute( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result + where + T: ?Sized + ToStatement, + { + self.execute_raw(statement, slice_iter(params)).await + } + + /// The maximally flexible version of [`execute`]. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + /// + /// # Panics + /// + /// Panics if the number of parameters provided does not match the number expected. + /// + /// [`execute`]: #method.execute + pub async fn execute_raw<'a, T, I>(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let statement = statement.__convert().into_statement(self).await?; + query::execute(self.inner(), statement, params).await + } + + /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows. + /// + /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that + /// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings, + /// so the associated row type doesn't work with the `FromSql` trait. 
Rather than simply returning a list of the + /// rows, this method returns a list of an enum which indicates either the completion of one of the commands, + /// or a row of data. This preserves the framing between the separate statements in the request. + /// + /// # Warning + /// + /// Prepared statements should be used for any query which contains user-specified data, as they provide the + /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass + /// them to this method! + pub async fn simple_query(&self, query: &str) -> Result, Error> { + self.simple_query_raw(query).await?.try_collect().await + } + + pub(crate) async fn simple_query_raw(&self, query: &str) -> Result { + simple_query::simple_query(self.inner(), query).await + } + + /// Executes a sequence of SQL statements using the simple query protocol. + /// + /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that + /// point. This is intended for use when, for example, initializing a database schema. + /// + /// # Warning + /// + /// Prepared statements should be used for any query which contains user-specified data, as they provide the + /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass + /// them to this method! + pub async fn batch_execute(&self, query: &str) -> Result { + simple_query::batch_execute(self.inner(), query).await + } + + /// Begins a new database transaction. + /// + /// The transaction will roll back by default - use the `commit` method to commit it. 
+ pub async fn transaction(&mut self) -> Result, Error> { + struct RollbackIfNotDone<'me> { + client: &'me Client, + done: bool, + } + + impl Drop for RollbackIfNotDone<'_> { + fn drop(&mut self) { + if self.done { + return; + } + + let buf = self.client.inner().with_buf(|buf| { + frontend::query("ROLLBACK", buf).unwrap(); + buf.split().freeze() + }); + let _ = self + .client + .inner() + .send(RequestMessages::Single(FrontendMessage::Raw(buf))); + } + } + + // This is done, as `Future` created by this method can be dropped after + // `RequestMessages` is synchronously sent to the `Connection` by + // `batch_execute()`, but before `Responses` is asynchronously polled to + // completion. In that case `Transaction` won't be created and thus + // won't be rolled back. + { + let mut cleaner = RollbackIfNotDone { + client: self, + done: false, + }; + self.batch_execute("BEGIN").await?; + cleaner.done = true; + } + + Ok(Transaction::new(self)) + } + + /// Returns a builder for a transaction with custom settings. + /// + /// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other + /// attributes. + pub fn build_transaction(&mut self) -> TransactionBuilder<'_> { + TransactionBuilder::new(self) + } + + /// Constructs a cancellation token that can later be used to request cancellation of a query running on the + /// connection associated with this client. + pub fn cancel_token(&self) -> CancelToken { + CancelToken { + socket_config: self.socket_config.clone(), + ssl_mode: self.ssl_mode, + process_id: self.process_id, + secret_key: self.secret_key, + } + } + + /// Query for type information + pub async fn get_type(&self, oid: Oid) -> Result { + crate::prepare::get_type(&self.inner, oid).await + } + + /// Determines if the connection to the server has already closed. + /// + /// In that case, all future queries will fail. 
+ pub fn is_closed(&self) -> bool { + self.inner.sender.is_closed() + } +} + +impl fmt::Debug for Client { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Client").finish() + } +} diff --git a/libs/proxy/tokio-postgres2/src/codec.rs b/libs/proxy/tokio-postgres2/src/codec.rs new file mode 100644 index 0000000000..7412db785b --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/codec.rs @@ -0,0 +1,109 @@ +use bytes::{Buf, Bytes, BytesMut}; +use fallible_iterator::FallibleIterator; +use postgres_protocol2::message::backend; +use postgres_protocol2::message::frontend::CopyData; +use std::io; +use tokio_util::codec::{Decoder, Encoder}; + +pub enum FrontendMessage { + Raw(Bytes), + CopyData(CopyData>), +} + +pub enum BackendMessage { + Normal { + messages: BackendMessages, + request_complete: bool, + }, + Async(backend::Message), +} + +pub struct BackendMessages(BytesMut); + +impl BackendMessages { + pub fn empty() -> BackendMessages { + BackendMessages(BytesMut::new()) + } +} + +impl FallibleIterator for BackendMessages { + type Item = backend::Message; + type Error = io::Error; + + fn next(&mut self) -> io::Result> { + backend::Message::parse(&mut self.0) + } +} + +pub struct PostgresCodec { + pub max_message_size: Option, +} + +impl Encoder for PostgresCodec { + type Error = io::Error; + + fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> { + match item { + FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf), + FrontendMessage::CopyData(data) => data.write(dst), + } + + Ok(()) + } +} + +impl Decoder for PostgresCodec { + type Item = BackendMessage; + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, io::Error> { + let mut idx = 0; + let mut request_complete = false; + + while let Some(header) = backend::Header::parse(&src[idx..])? 
{ + let len = header.len() as usize + 1; + if src[idx..].len() < len { + break; + } + + if let Some(max) = self.max_message_size { + if len > max { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "message too large", + )); + } + } + + match header.tag() { + backend::NOTICE_RESPONSE_TAG + | backend::NOTIFICATION_RESPONSE_TAG + | backend::PARAMETER_STATUS_TAG => { + if idx == 0 { + let message = backend::Message::parse(src)?.unwrap(); + return Ok(Some(BackendMessage::Async(message))); + } else { + break; + } + } + _ => {} + } + + idx += len; + + if header.tag() == backend::READY_FOR_QUERY_TAG { + request_complete = true; + break; + } + } + + if idx == 0 { + Ok(None) + } else { + Ok(Some(BackendMessage::Normal { + messages: BackendMessages(src.split_to(idx)), + request_complete, + })) + } + } +} diff --git a/libs/proxy/tokio-postgres2/src/config.rs b/libs/proxy/tokio-postgres2/src/config.rs new file mode 100644 index 0000000000..969c20ba47 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/config.rs @@ -0,0 +1,897 @@ +//! Connection configuration. + +use crate::connect::connect; +use crate::connect_raw::connect_raw; +use crate::tls::MakeTlsConnect; +use crate::tls::TlsConnect; +use crate::{Client, Connection, Error}; +use std::borrow::Cow; +use std::str; +use std::str::FromStr; +use std::time::Duration; +use std::{error, fmt, iter, mem}; +use tokio::io::{AsyncRead, AsyncWrite}; + +pub use postgres_protocol2::authentication::sasl::ScramKeys; +use tokio::net::TcpStream; + +/// Properties required of a session. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum TargetSessionAttrs { + /// No special properties are required. + Any, + /// The session must allow writes. + ReadWrite, +} + +/// TLS configuration. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum SslMode { + /// Do not use TLS. + Disable, + /// Attempt to connect with TLS but allow sessions without. + Prefer, + /// Require the use of TLS. 
+ Require, +} + +/// Channel binding configuration. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum ChannelBinding { + /// Do not use channel binding. + Disable, + /// Attempt to use channel binding but allow sessions without. + Prefer, + /// Require the use of channel binding. + Require, +} + +/// Replication mode configuration. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum ReplicationMode { + /// Physical replication. + Physical, + /// Logical replication. + Logical, +} + +/// A host specification. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Host { + /// A TCP hostname. + Tcp(String), +} + +/// Precomputed keys which may override password during auth. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AuthKeys { + /// A `ClientKey` & `ServerKey` pair for `SCRAM-SHA-256`. + ScramSha256(ScramKeys<32>), +} + +/// Connection configuration. +/// +/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats: +/// +/// # Key-Value +/// +/// This format consists of space-separated key-value pairs. Values which are either the empty string or contain +/// whitespace should be wrapped in `'`. `'` and `\` characters should be backslash-escaped. +/// +/// ## Keys +/// +/// * `user` - The username to authenticate with. Required. +/// * `password` - The password to authenticate with. +/// * `dbname` - The name of the database to connect to. Defaults to the username. +/// * `options` - Command line options used to configure the server. +/// * `application_name` - Sets the `application_name` parameter on the server. +/// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used +/// if available, but not used otherwise. If set to `require`, TLS will be forced to be used. Defaults to `prefer`. +/// * `host` - The host to connect to. 
On Unix platforms, if the host starts with a `/` character it is treated as the +/// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts +/// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting +/// with the `connect` method. +/// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be +/// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if +/// omitted or the empty string. +/// * `connect_timeout` - The time limit in seconds applied to each socket-level connection attempt. Note that hostnames +/// can resolve to multiple IP addresses, and this limit is applied to each address. Defaults to no timeout. +/// * `target_session_attrs` - Specifies requirements of the session. If set to `read-write`, the client will check that +/// the `transaction_read_write` session parameter is set to `on`. This can be used to connect to the primary server +/// in a database cluster as opposed to the secondary read-only mirrors. Defaults to `any`. +/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel +/// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise. +/// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`. 
+/// +/// ## Examples +/// +/// ```not_rust +/// host=localhost user=postgres connect_timeout=10 +/// ``` +/// +/// ```not_rust +/// host=/var/lib/postgresql,localhost port=1234 user=postgres password='password with spaces' +/// ``` +/// +/// ```not_rust +/// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write +/// ``` +/// +/// # Url +/// +/// This format resembles a URL with a scheme of either `postgres://` or `postgresql://`. All components are optional, +/// and the format accepts query parameters for all of the key-value pairs described in the section above. Multiple +/// host/port pairs can be comma-separated. Unix socket paths in the host section of the URL should be percent-encoded, +/// as the path component of the URL specifies the database name. +/// +/// ## Examples +/// +/// ```not_rust +/// postgresql://user@localhost +/// ``` +/// +/// ```not_rust +/// postgresql://user:password@%2Fvar%2Flib%2Fpostgresql/mydb?connect_timeout=10 +/// ``` +/// +/// ```not_rust +/// postgresql://user@host1:1234,host2,host3:5678?target_session_attrs=read-write +/// ``` +/// +/// ```not_rust +/// postgresql:///mydb?user=user&host=/var/lib/postgresql +/// ``` +#[derive(Clone, PartialEq, Eq)] +pub struct Config { + pub(crate) user: Option, + pub(crate) password: Option>, + pub(crate) auth_keys: Option>, + pub(crate) dbname: Option, + pub(crate) options: Option, + pub(crate) application_name: Option, + pub(crate) ssl_mode: SslMode, + pub(crate) host: Vec, + pub(crate) port: Vec, + pub(crate) connect_timeout: Option, + pub(crate) target_session_attrs: TargetSessionAttrs, + pub(crate) channel_binding: ChannelBinding, + pub(crate) replication_mode: Option, + pub(crate) max_backend_message_size: Option, +} + +impl Default for Config { + fn default() -> Config { + Config::new() + } +} + +impl Config { + /// Creates a new configuration. 
+ pub fn new() -> Config { + Config { + user: None, + password: None, + auth_keys: None, + dbname: None, + options: None, + application_name: None, + ssl_mode: SslMode::Prefer, + host: vec![], + port: vec![], + connect_timeout: None, + target_session_attrs: TargetSessionAttrs::Any, + channel_binding: ChannelBinding::Prefer, + replication_mode: None, + max_backend_message_size: None, + } + } + + /// Sets the user to authenticate with. + /// + /// Required. + pub fn user(&mut self, user: &str) -> &mut Config { + self.user = Some(user.to_string()); + self + } + + /// Gets the user to authenticate with, if one has been configured with + /// the `user` method. + pub fn get_user(&self) -> Option<&str> { + self.user.as_deref() + } + + /// Sets the password to authenticate with. + pub fn password(&mut self, password: T) -> &mut Config + where + T: AsRef<[u8]>, + { + self.password = Some(password.as_ref().to_vec()); + self + } + + /// Gets the password to authenticate with, if one has been configured with + /// the `password` method. + pub fn get_password(&self) -> Option<&[u8]> { + self.password.as_deref() + } + + /// Sets precomputed protocol-specific keys to authenticate with. + /// When set, this option will override `password`. + /// See [`AuthKeys`] for more information. + pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config { + self.auth_keys = Some(Box::new(keys)); + self + } + + /// Gets precomputed protocol-specific keys to authenticate with. + /// if one has been configured with the `auth_keys` method. + pub fn get_auth_keys(&self) -> Option { + self.auth_keys.as_deref().copied() + } + + /// Sets the name of the database to connect to. + /// + /// Defaults to the user. + pub fn dbname(&mut self, dbname: &str) -> &mut Config { + self.dbname = Some(dbname.to_string()); + self + } + + /// Gets the name of the database to connect to, if one has been configured + /// with the `dbname` method. 
+ pub fn get_dbname(&self) -> Option<&str> { + self.dbname.as_deref() + } + + /// Sets command line options used to configure the server. + pub fn options(&mut self, options: &str) -> &mut Config { + self.options = Some(options.to_string()); + self + } + + /// Gets the command line options used to configure the server, if the + /// options have been set with the `options` method. + pub fn get_options(&self) -> Option<&str> { + self.options.as_deref() + } + + /// Sets the value of the `application_name` runtime parameter. + pub fn application_name(&mut self, application_name: &str) -> &mut Config { + self.application_name = Some(application_name.to_string()); + self + } + + /// Gets the value of the `application_name` runtime parameter, if it has + /// been set with the `application_name` method. + pub fn get_application_name(&self) -> Option<&str> { + self.application_name.as_deref() + } + + /// Sets the SSL configuration. + /// + /// Defaults to `prefer`. + pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config { + self.ssl_mode = ssl_mode; + self + } + + /// Gets the SSL configuration. + pub fn get_ssl_mode(&self) -> SslMode { + self.ssl_mode + } + + /// Adds a host to the configuration. + /// + /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. + pub fn host(&mut self, host: &str) -> &mut Config { + self.host.push(Host::Tcp(host.to_string())); + self + } + + /// Gets the hosts that have been added to the configuration with `host`. + pub fn get_hosts(&self) -> &[Host] { + &self.host + } + + /// Adds a port to the configuration. + /// + /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which + /// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports + /// as hosts. 
+ pub fn port(&mut self, port: u16) -> &mut Config { + self.port.push(port); + self + } + + /// Gets the ports that have been added to the configuration with `port`. + pub fn get_ports(&self) -> &[u16] { + &self.port + } + + /// Sets the timeout applied to socket-level connection attempts. + /// + /// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each + /// host separately. Defaults to no limit. + pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config { + self.connect_timeout = Some(connect_timeout); + self + } + + /// Gets the connection timeout, if one has been set with the + /// `connect_timeout` method. + pub fn get_connect_timeout(&self) -> Option<&Duration> { + self.connect_timeout.as_ref() + } + + /// Sets the requirements of the session. + /// + /// This can be used to connect to the primary server in a clustered database rather than one of the read-only + /// secondary servers. Defaults to `Any`. + pub fn target_session_attrs( + &mut self, + target_session_attrs: TargetSessionAttrs, + ) -> &mut Config { + self.target_session_attrs = target_session_attrs; + self + } + + /// Gets the requirements of the session. + pub fn get_target_session_attrs(&self) -> TargetSessionAttrs { + self.target_session_attrs + } + + /// Sets the channel binding behavior. + /// + /// Defaults to `prefer`. + pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config { + self.channel_binding = channel_binding; + self + } + + /// Gets the channel binding behavior. + pub fn get_channel_binding(&self) -> ChannelBinding { + self.channel_binding + } + + /// Set replication mode. + pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config { + self.replication_mode = Some(replication_mode); + self + } + + /// Get replication mode. + pub fn get_replication_mode(&self) -> Option { + self.replication_mode + } + + /// Set limit for backend messages size. 
+ pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config { + self.max_backend_message_size = Some(max_backend_message_size); + self + } + + /// Get limit for backend messages size. + pub fn get_max_backend_message_size(&self) -> Option { + self.max_backend_message_size + } + + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { + match key { + "user" => { + self.user(value); + } + "password" => { + self.password(value); + } + "dbname" => { + self.dbname(value); + } + "options" => { + self.options(value); + } + "application_name" => { + self.application_name(value); + } + "sslmode" => { + let mode = match value { + "disable" => SslMode::Disable, + "prefer" => SslMode::Prefer, + "require" => SslMode::Require, + _ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))), + }; + self.ssl_mode(mode); + } + "host" => { + for host in value.split(',') { + self.host(host); + } + } + "port" => { + for port in value.split(',') { + let port = if port.is_empty() { + 5432 + } else { + port.parse() + .map_err(|_| Error::config_parse(Box::new(InvalidValue("port"))))? 
+ }; + self.port(port); + } + } + "connect_timeout" => { + let timeout = value + .parse::() + .map_err(|_| Error::config_parse(Box::new(InvalidValue("connect_timeout"))))?; + if timeout > 0 { + self.connect_timeout(Duration::from_secs(timeout as u64)); + } + } + "target_session_attrs" => { + let target_session_attrs = match value { + "any" => TargetSessionAttrs::Any, + "read-write" => TargetSessionAttrs::ReadWrite, + _ => { + return Err(Error::config_parse(Box::new(InvalidValue( + "target_session_attrs", + )))); + } + }; + self.target_session_attrs(target_session_attrs); + } + "channel_binding" => { + let channel_binding = match value { + "disable" => ChannelBinding::Disable, + "prefer" => ChannelBinding::Prefer, + "require" => ChannelBinding::Require, + _ => { + return Err(Error::config_parse(Box::new(InvalidValue( + "channel_binding", + )))) + } + }; + self.channel_binding(channel_binding); + } + "max_backend_message_size" => { + let limit = value.parse::().map_err(|_| { + Error::config_parse(Box::new(InvalidValue("max_backend_message_size"))) + })?; + if limit > 0 { + self.max_backend_message_size(limit); + } + } + key => { + return Err(Error::config_parse(Box::new(UnknownOption( + key.to_string(), + )))); + } + } + + Ok(()) + } + + /// Opens a connection to a PostgreSQL database. + /// + /// Requires the `runtime` Cargo feature (enabled by default). + pub async fn connect( + &self, + tls: T, + ) -> Result<(Client, Connection), Error> + where + T: MakeTlsConnect, + { + connect(tls, self).await + } + + /// Connects to a PostgreSQL database over an arbitrary stream. + /// + /// The `host`, `port`, `connect_timeout`, and `target_session_attrs` settings are ignored. 
+ pub async fn connect_raw( + &self, + stream: S, + tls: T, + ) -> Result<(Client, Connection), Error> + where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsConnect, + { + connect_raw(stream, tls, self).await + } +} + +impl FromStr for Config { + type Err = Error; + + fn from_str(s: &str) -> Result { + match UrlParser::parse(s)? { + Some(config) => Ok(config), + None => Parser::parse(s), + } + } +} + +// Omit password from debug output +impl fmt::Debug for Config { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + struct Redaction {} + impl fmt::Debug for Redaction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "_") + } + } + + f.debug_struct("Config") + .field("user", &self.user) + .field("password", &self.password.as_ref().map(|_| Redaction {})) + .field("dbname", &self.dbname) + .field("options", &self.options) + .field("application_name", &self.application_name) + .field("ssl_mode", &self.ssl_mode) + .field("host", &self.host) + .field("port", &self.port) + .field("connect_timeout", &self.connect_timeout) + .field("target_session_attrs", &self.target_session_attrs) + .field("channel_binding", &self.channel_binding) + .field("replication", &self.replication_mode) + .finish() + } +} + +#[derive(Debug)] +struct UnknownOption(String); + +impl fmt::Display for UnknownOption { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "unknown option `{}`", self.0) + } +} + +impl error::Error for UnknownOption {} + +#[derive(Debug)] +struct InvalidValue(&'static str); + +impl fmt::Display for InvalidValue { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "invalid value for option `{}`", self.0) + } +} + +impl error::Error for InvalidValue {} + +struct Parser<'a> { + s: &'a str, + it: iter::Peekable>, +} + +impl<'a> Parser<'a> { + fn parse(s: &'a str) -> Result { + let mut parser = Parser { + s, + it: s.char_indices().peekable(), + }; + + let mut config = Config::new(); + + while 
let Some((key, value)) = parser.parameter()? { + config.param(key, &value)?; + } + + Ok(config) + } + + fn skip_ws(&mut self) { + self.take_while(char::is_whitespace); + } + + fn take_while(&mut self, f: F) -> &'a str + where + F: Fn(char) -> bool, + { + let start = match self.it.peek() { + Some(&(i, _)) => i, + None => return "", + }; + + loop { + match self.it.peek() { + Some(&(_, c)) if f(c) => { + self.it.next(); + } + Some(&(i, _)) => return &self.s[start..i], + None => return &self.s[start..], + } + } + } + + fn eat(&mut self, target: char) -> Result<(), Error> { + match self.it.next() { + Some((_, c)) if c == target => Ok(()), + Some((i, c)) => { + let m = format!( + "unexpected character at byte {}: expected `{}` but got `{}`", + i, target, c + ); + Err(Error::config_parse(m.into())) + } + None => Err(Error::config_parse("unexpected EOF".into())), + } + } + + fn eat_if(&mut self, target: char) -> bool { + match self.it.peek() { + Some(&(_, c)) if c == target => { + self.it.next(); + true + } + _ => false, + } + } + + fn keyword(&mut self) -> Option<&'a str> { + let s = self.take_while(|c| match c { + c if c.is_whitespace() => false, + '=' => false, + _ => true, + }); + + if s.is_empty() { + None + } else { + Some(s) + } + } + + fn value(&mut self) -> Result { + let value = if self.eat_if('\'') { + let value = self.quoted_value()?; + self.eat('\'')?; + value + } else { + self.simple_value()? 
+ }; + + Ok(value) + } + + fn simple_value(&mut self) -> Result { + let mut value = String::new(); + + while let Some(&(_, c)) = self.it.peek() { + if c.is_whitespace() { + break; + } + + self.it.next(); + if c == '\\' { + if let Some((_, c2)) = self.it.next() { + value.push(c2); + } + } else { + value.push(c); + } + } + + if value.is_empty() { + return Err(Error::config_parse("unexpected EOF".into())); + } + + Ok(value) + } + + fn quoted_value(&mut self) -> Result { + let mut value = String::new(); + + while let Some(&(_, c)) = self.it.peek() { + if c == '\'' { + return Ok(value); + } + + self.it.next(); + if c == '\\' { + if let Some((_, c2)) = self.it.next() { + value.push(c2); + } + } else { + value.push(c); + } + } + + Err(Error::config_parse( + "unterminated quoted connection parameter value".into(), + )) + } + + fn parameter(&mut self) -> Result, Error> { + self.skip_ws(); + let keyword = match self.keyword() { + Some(keyword) => keyword, + None => return Ok(None), + }; + self.skip_ws(); + self.eat('=')?; + self.skip_ws(); + let value = self.value()?; + + Ok(Some((keyword, value))) + } +} + +// This is a pretty sloppy "URL" parser, but it matches the behavior of libpq, where things really aren't very strict +struct UrlParser<'a> { + s: &'a str, + config: Config, +} + +impl<'a> UrlParser<'a> { + fn parse(s: &'a str) -> Result, Error> { + let s = match Self::remove_url_prefix(s) { + Some(s) => s, + None => return Ok(None), + }; + + let mut parser = UrlParser { + s, + config: Config::new(), + }; + + parser.parse_credentials()?; + parser.parse_host()?; + parser.parse_path()?; + parser.parse_params()?; + + Ok(Some(parser.config)) + } + + fn remove_url_prefix(s: &str) -> Option<&str> { + for prefix in &["postgres://", "postgresql://"] { + if let Some(stripped) = s.strip_prefix(prefix) { + return Some(stripped); + } + } + + None + } + + fn take_until(&mut self, end: &[char]) -> Option<&'a str> { + match self.s.find(end) { + Some(pos) => { + let (head, tail) = 
self.s.split_at(pos); + self.s = tail; + Some(head) + } + None => None, + } + } + + fn take_all(&mut self) -> &'a str { + mem::take(&mut self.s) + } + + fn eat_byte(&mut self) { + self.s = &self.s[1..]; + } + + fn parse_credentials(&mut self) -> Result<(), Error> { + let creds = match self.take_until(&['@']) { + Some(creds) => creds, + None => return Ok(()), + }; + self.eat_byte(); + + let mut it = creds.splitn(2, ':'); + let user = self.decode(it.next().unwrap())?; + self.config.user(&user); + + if let Some(password) = it.next() { + let password = Cow::from(percent_encoding::percent_decode(password.as_bytes())); + self.config.password(password); + } + + Ok(()) + } + + fn parse_host(&mut self) -> Result<(), Error> { + let host = match self.take_until(&['/', '?']) { + Some(host) => host, + None => self.take_all(), + }; + + if host.is_empty() { + return Ok(()); + } + + for chunk in host.split(',') { + let (host, port) = if chunk.starts_with('[') { + let idx = match chunk.find(']') { + Some(idx) => idx, + None => return Err(Error::config_parse(InvalidValue("host").into())), + }; + + let host = &chunk[1..idx]; + let remaining = &chunk[idx + 1..]; + let port = if let Some(port) = remaining.strip_prefix(':') { + Some(port) + } else if remaining.is_empty() { + None + } else { + return Err(Error::config_parse(InvalidValue("host").into())); + }; + + (host, port) + } else { + let mut it = chunk.splitn(2, ':'); + (it.next().unwrap(), it.next()) + }; + + self.host_param(host)?; + let port = self.decode(port.unwrap_or("5432"))?; + self.config.param("port", &port)?; + } + + Ok(()) + } + + fn parse_path(&mut self) -> Result<(), Error> { + if !self.s.starts_with('/') { + return Ok(()); + } + self.eat_byte(); + + let dbname = match self.take_until(&['?']) { + Some(dbname) => dbname, + None => self.take_all(), + }; + + if !dbname.is_empty() { + self.config.dbname(&self.decode(dbname)?); + } + + Ok(()) + } + + fn parse_params(&mut self) -> Result<(), Error> { + if 
!self.s.starts_with('?') { + return Ok(()); + } + self.eat_byte(); + + while !self.s.is_empty() { + let key = match self.take_until(&['=']) { + Some(key) => self.decode(key)?, + None => return Err(Error::config_parse("unterminated parameter".into())), + }; + self.eat_byte(); + + let value = match self.take_until(&['&']) { + Some(value) => { + self.eat_byte(); + value + } + None => self.take_all(), + }; + + if key == "host" { + self.host_param(value)?; + } else { + let value = self.decode(value)?; + self.config.param(&key, &value)?; + } + } + + Ok(()) + } + + fn host_param(&mut self, s: &str) -> Result<(), Error> { + let s = self.decode(s)?; + self.config.param("host", &s) + } + + fn decode(&self, s: &'a str) -> Result, Error> { + percent_encoding::percent_decode(s.as_bytes()) + .decode_utf8() + .map_err(|e| Error::config_parse(e.into())) + } +} diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs new file mode 100644 index 0000000000..7517fe0cde --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/connect.rs @@ -0,0 +1,112 @@ +use crate::client::SocketConfig; +use crate::config::{Host, TargetSessionAttrs}; +use crate::connect_raw::connect_raw; +use crate::connect_socket::connect_socket; +use crate::tls::{MakeTlsConnect, TlsConnect}; +use crate::{Client, Config, Connection, Error, SimpleQueryMessage}; +use futures_util::{future, pin_mut, Future, FutureExt, Stream}; +use std::io; +use std::task::Poll; +use tokio::net::TcpStream; + +pub async fn connect( + mut tls: T, + config: &Config, +) -> Result<(Client, Connection), Error> +where + T: MakeTlsConnect, +{ + if config.host.is_empty() { + return Err(Error::config("host missing".into())); + } + + if config.port.len() > 1 && config.port.len() != config.host.len() { + return Err(Error::config("invalid number of ports".into())); + } + + let mut error = None; + for (i, host) in config.host.iter().enumerate() { + let port = config + .port + .get(i) + .or_else(|| 
config.port.first()) + .copied() + .unwrap_or(5432); + + let hostname = match host { + Host::Tcp(host) => host.as_str(), + }; + + let tls = tls + .make_tls_connect(hostname) + .map_err(|e| Error::tls(e.into()))?; + + match connect_once(host, port, tls, config).await { + Ok((client, connection)) => return Ok((client, connection)), + Err(e) => error = Some(e), + } + } + + Err(error.unwrap()) +} + +async fn connect_once( + host: &Host, + port: u16, + tls: T, + config: &Config, +) -> Result<(Client, Connection), Error> +where + T: TlsConnect, +{ + let socket = connect_socket(host, port, config.connect_timeout).await?; + let (mut client, mut connection) = connect_raw(socket, tls, config).await?; + + if let TargetSessionAttrs::ReadWrite = config.target_session_attrs { + let rows = client.simple_query_raw("SHOW transaction_read_only"); + pin_mut!(rows); + + let rows = future::poll_fn(|cx| { + if connection.poll_unpin(cx)?.is_ready() { + return Poll::Ready(Err(Error::closed())); + } + + rows.as_mut().poll(cx) + }) + .await?; + pin_mut!(rows); + + loop { + let next = future::poll_fn(|cx| { + if connection.poll_unpin(cx)?.is_ready() { + return Poll::Ready(Some(Err(Error::closed()))); + } + + rows.as_mut().poll_next(cx) + }); + + match next.await.transpose()? { + Some(SimpleQueryMessage::Row(row)) => { + if row.try_get(0)? 
== Some("on") { + return Err(Error::connect(io::Error::new( + io::ErrorKind::PermissionDenied, + "database does not allow writes", + ))); + } else { + break; + } + } + Some(_) => {} + None => return Err(Error::unexpected_message()), + } + } + } + + client.set_socket_config(SocketConfig { + host: host.clone(), + port, + connect_timeout: config.connect_timeout, + }); + + Ok((client, connection)) +} diff --git a/libs/proxy/tokio-postgres2/src/connect_raw.rs b/libs/proxy/tokio-postgres2/src/connect_raw.rs new file mode 100644 index 0000000000..80677af969 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/connect_raw.rs @@ -0,0 +1,359 @@ +use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; +use crate::config::{self, AuthKeys, Config, ReplicationMode}; +use crate::connect_tls::connect_tls; +use crate::maybe_tls_stream::MaybeTlsStream; +use crate::tls::{TlsConnect, TlsStream}; +use crate::{Client, Connection, Error}; +use bytes::BytesMut; +use fallible_iterator::FallibleIterator; +use futures_util::{ready, Sink, SinkExt, Stream, TryStreamExt}; +use postgres_protocol2::authentication; +use postgres_protocol2::authentication::sasl; +use postgres_protocol2::authentication::sasl::ScramSha256; +use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message}; +use postgres_protocol2::message::frontend; +use std::collections::{HashMap, VecDeque}; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::sync::mpsc; +use tokio_util::codec::Framed; + +pub struct StartupStream { + inner: Framed, PostgresCodec>, + buf: BackendMessages, + delayed: VecDeque, +} + +impl Sink for StartupStream +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + type Error = io::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_ready(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: 
FrontendMessage) -> io::Result<()> { + Pin::new(&mut self.inner).start_send(item) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_close(cx) + } +} + +impl Stream for StartupStream +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + type Item = io::Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + loop { + match self.buf.next() { + Ok(Some(message)) => return Poll::Ready(Some(Ok(message))), + Ok(None) => {} + Err(e) => return Poll::Ready(Some(Err(e))), + } + + match ready!(Pin::new(&mut self.inner).poll_next(cx)) { + Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages, + Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))), + Some(Err(e)) => return Poll::Ready(Some(Err(e))), + None => return Poll::Ready(None), + } + } + } +} + +pub async fn connect_raw( + stream: S, + tls: T, + config: &Config, +) -> Result<(Client, Connection), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsConnect, +{ + let stream = connect_tls(stream, config.ssl_mode, tls).await?; + + let mut stream = StartupStream { + inner: Framed::new( + stream, + PostgresCodec { + max_message_size: config.max_backend_message_size, + }, + ), + buf: BackendMessages::empty(), + delayed: VecDeque::new(), + }; + + startup(&mut stream, config).await?; + authenticate(&mut stream, config).await?; + let (process_id, secret_key, parameters) = read_info(&mut stream).await?; + + let (sender, receiver) = mpsc::unbounded_channel(); + let client = Client::new(sender, config.ssl_mode, process_id, secret_key); + let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver); + + Ok((client, connection)) +} + +async fn startup(stream: &mut StartupStream, config: &Config) -> 
Result<(), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + let mut params = vec![("client_encoding", "UTF8")]; + if let Some(user) = &config.user { + params.push(("user", &**user)); + } + if let Some(dbname) = &config.dbname { + params.push(("database", &**dbname)); + } + if let Some(options) = &config.options { + params.push(("options", &**options)); + } + if let Some(application_name) = &config.application_name { + params.push(("application_name", &**application_name)); + } + if let Some(replication_mode) = &config.replication_mode { + match replication_mode { + ReplicationMode::Physical => params.push(("replication", "true")), + ReplicationMode::Logical => params.push(("replication", "database")), + } + } + + let mut buf = BytesMut::new(); + frontend::startup_message(params, &mut buf).map_err(Error::encode)?; + + stream + .send(FrontendMessage::Raw(buf.freeze())) + .await + .map_err(Error::io) +} + +async fn authenticate(stream: &mut StartupStream, config: &Config) -> Result<(), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsStream + Unpin, +{ + match stream.try_next().await.map_err(Error::io)? 
{ + Some(Message::AuthenticationOk) => { + can_skip_channel_binding(config)?; + return Ok(()); + } + Some(Message::AuthenticationCleartextPassword) => { + can_skip_channel_binding(config)?; + + let pass = config + .password + .as_ref() + .ok_or_else(|| Error::config("password missing".into()))?; + + authenticate_password(stream, pass).await?; + } + Some(Message::AuthenticationMd5Password(body)) => { + can_skip_channel_binding(config)?; + + let user = config + .user + .as_ref() + .ok_or_else(|| Error::config("user missing".into()))?; + let pass = config + .password + .as_ref() + .ok_or_else(|| Error::config("password missing".into()))?; + + let output = authentication::md5_hash(user.as_bytes(), pass, body.salt()); + authenticate_password(stream, output.as_bytes()).await?; + } + Some(Message::AuthenticationSasl(body)) => { + authenticate_sasl(stream, body, config).await?; + } + Some(Message::AuthenticationKerberosV5) + | Some(Message::AuthenticationScmCredential) + | Some(Message::AuthenticationGss) + | Some(Message::AuthenticationSspi) => { + return Err(Error::authentication( + "unsupported authentication method".into(), + )) + } + Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), + Some(_) => return Err(Error::unexpected_message()), + None => return Err(Error::closed()), + } + + match stream.try_next().await.map_err(Error::io)? 
{ + Some(Message::AuthenticationOk) => Ok(()), + Some(Message::ErrorResponse(body)) => Err(Error::db(body)), + Some(_) => Err(Error::unexpected_message()), + None => Err(Error::closed()), + } +} + +fn can_skip_channel_binding(config: &Config) -> Result<(), Error> { + match config.channel_binding { + config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()), + config::ChannelBinding::Require => Err(Error::authentication( + "server did not use channel binding".into(), + )), + } +} + +async fn authenticate_password( + stream: &mut StartupStream, + password: &[u8], +) -> Result<(), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + let mut buf = BytesMut::new(); + frontend::password_message(password, &mut buf).map_err(Error::encode)?; + + stream + .send(FrontendMessage::Raw(buf.freeze())) + .await + .map_err(Error::io) +} + +async fn authenticate_sasl( + stream: &mut StartupStream, + body: AuthenticationSaslBody, + config: &Config, +) -> Result<(), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsStream + Unpin, +{ + let mut has_scram = false; + let mut has_scram_plus = false; + let mut mechanisms = body.mechanisms(); + while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? 
{ + match mechanism { + sasl::SCRAM_SHA_256 => has_scram = true, + sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true, + _ => {} + } + } + + let channel_binding = stream + .inner + .get_ref() + .channel_binding() + .tls_server_end_point + .filter(|_| config.channel_binding != config::ChannelBinding::Disable) + .map(sasl::ChannelBinding::tls_server_end_point); + + let (channel_binding, mechanism) = if has_scram_plus { + match channel_binding { + Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS), + None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256), + } + } else if has_scram { + match channel_binding { + Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256), + None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256), + } + } else { + return Err(Error::authentication("unsupported SASL mechanism".into())); + }; + + if mechanism != sasl::SCRAM_SHA_256_PLUS { + can_skip_channel_binding(config)?; + } + + let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() { + ScramSha256::new_with_keys(keys, channel_binding) + } else if let Some(password) = config.get_password() { + ScramSha256::new(password, channel_binding) + } else { + return Err(Error::config("password or auth keys missing".into())); + }; + + let mut buf = BytesMut::new(); + frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?; + stream + .send(FrontendMessage::Raw(buf.freeze())) + .await + .map_err(Error::io)?; + + let body = match stream.try_next().await.map_err(Error::io)? 
{ + Some(Message::AuthenticationSaslContinue(body)) => body, + Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), + Some(_) => return Err(Error::unexpected_message()), + None => return Err(Error::closed()), + }; + + scram + .update(body.data()) + .await + .map_err(|e| Error::authentication(e.into()))?; + + let mut buf = BytesMut::new(); + frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?; + stream + .send(FrontendMessage::Raw(buf.freeze())) + .await + .map_err(Error::io)?; + + let body = match stream.try_next().await.map_err(Error::io)? { + Some(Message::AuthenticationSaslFinal(body)) => body, + Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), + Some(_) => return Err(Error::unexpected_message()), + None => return Err(Error::closed()), + }; + + scram + .finish(body.data()) + .map_err(|e| Error::authentication(e.into()))?; + + Ok(()) +} + +async fn read_info( + stream: &mut StartupStream, +) -> Result<(i32, i32, HashMap), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + let mut process_id = 0; + let mut secret_key = 0; + let mut parameters = HashMap::new(); + + loop { + match stream.try_next().await.map_err(Error::io)? 
{ + Some(Message::BackendKeyData(body)) => { + process_id = body.process_id(); + secret_key = body.secret_key(); + } + Some(Message::ParameterStatus(body)) => { + parameters.insert( + body.name().map_err(Error::parse)?.to_string(), + body.value().map_err(Error::parse)?.to_string(), + ); + } + Some(msg @ Message::NoticeResponse(_)) => { + stream.delayed.push_back(BackendMessage::Async(msg)) + } + Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)), + Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), + Some(_) => return Err(Error::unexpected_message()), + None => return Err(Error::closed()), + } + } +} diff --git a/libs/proxy/tokio-postgres2/src/connect_socket.rs b/libs/proxy/tokio-postgres2/src/connect_socket.rs new file mode 100644 index 0000000000..336a13317f --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/connect_socket.rs @@ -0,0 +1,65 @@ +use crate::config::Host; +use crate::Error; +use std::future::Future; +use std::io; +use std::time::Duration; +use tokio::net::{self, TcpStream}; +use tokio::time; + +pub(crate) async fn connect_socket( + host: &Host, + port: u16, + connect_timeout: Option, +) -> Result { + match host { + Host::Tcp(host) => { + let addrs = net::lookup_host((&**host, port)) + .await + .map_err(Error::connect)?; + + let mut last_err = None; + + for addr in addrs { + let stream = + match connect_with_timeout(TcpStream::connect(addr), connect_timeout).await { + Ok(stream) => stream, + Err(e) => { + last_err = Some(e); + continue; + } + }; + + stream.set_nodelay(true).map_err(Error::connect)?; + + return Ok(stream); + } + + Err(last_err.unwrap_or_else(|| { + Error::connect(io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve any addresses", + )) + })) + } + } +} + +async fn connect_with_timeout(connect: F, timeout: Option) -> Result +where + F: Future>, +{ + match timeout { + Some(timeout) => match time::timeout(timeout, connect).await { + Ok(Ok(socket)) => Ok(socket), + Ok(Err(e)) 
=> Err(Error::connect(e)), + Err(_) => Err(Error::connect(io::Error::new( + io::ErrorKind::TimedOut, + "connection timed out", + ))), + }, + None => match connect.await { + Ok(socket) => Ok(socket), + Err(e) => Err(Error::connect(e)), + }, + } +} diff --git a/libs/proxy/tokio-postgres2/src/connect_tls.rs b/libs/proxy/tokio-postgres2/src/connect_tls.rs new file mode 100644 index 0000000000..64b0b68abc --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/connect_tls.rs @@ -0,0 +1,48 @@ +use crate::config::SslMode; +use crate::maybe_tls_stream::MaybeTlsStream; +use crate::tls::private::ForcePrivateApi; +use crate::tls::TlsConnect; +use crate::Error; +use bytes::BytesMut; +use postgres_protocol2::message::frontend; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +pub async fn connect_tls( + mut stream: S, + mode: SslMode, + tls: T, +) -> Result, Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsConnect, +{ + match mode { + SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)), + SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => { + return Ok(MaybeTlsStream::Raw(stream)) + } + SslMode::Prefer | SslMode::Require => {} + } + + let mut buf = BytesMut::new(); + frontend::ssl_request(&mut buf); + stream.write_all(&buf).await.map_err(Error::io)?; + + let mut buf = [0]; + stream.read_exact(&mut buf).await.map_err(Error::io)?; + + if buf[0] != b'S' { + if SslMode::Require == mode { + return Err(Error::tls("server does not support TLS".into())); + } else { + return Ok(MaybeTlsStream::Raw(stream)); + } + } + + let stream = tls + .connect(stream) + .await + .map_err(|e| Error::tls(e.into()))?; + + Ok(MaybeTlsStream::Tls(stream)) +} diff --git a/libs/proxy/tokio-postgres2/src/connection.rs b/libs/proxy/tokio-postgres2/src/connection.rs new file mode 100644 index 0000000000..0aa5c77e22 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/connection.rs @@ -0,0 +1,323 @@ +use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, 
PostgresCodec}; +use crate::error::DbError; +use crate::maybe_tls_stream::MaybeTlsStream; +use crate::{AsyncMessage, Error, Notification}; +use bytes::BytesMut; +use fallible_iterator::FallibleIterator; +use futures_util::{ready, Sink, Stream}; +use log::{info, trace}; +use postgres_protocol2::message::backend::Message; +use postgres_protocol2::message::frontend; +use std::collections::{HashMap, VecDeque}; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::sync::mpsc; +use tokio_util::codec::Framed; +use tokio_util::sync::PollSender; + +pub enum RequestMessages { + Single(FrontendMessage), +} + +pub struct Request { + pub messages: RequestMessages, + pub sender: mpsc::Sender, +} + +pub struct Response { + sender: PollSender, +} + +#[derive(PartialEq, Debug)] +enum State { + Active, + Terminating, + Closing, +} + +/// A connection to a PostgreSQL database. +/// +/// This is one half of what is returned when a new connection is established. It performs the actual IO with the +/// server, and should generally be spawned off onto an executor to run in the background. +/// +/// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has +/// occurred, or because its associated `Client` has dropped and all outstanding work has completed. +#[must_use = "futures do nothing unless polled"] +pub struct Connection { + /// HACK: we need this in the Neon Proxy. + pub stream: Framed, PostgresCodec>, + /// HACK: we need this in the Neon Proxy to forward params. 
+ pub parameters: HashMap, + receiver: mpsc::UnboundedReceiver, + pending_request: Option, + pending_responses: VecDeque, + responses: VecDeque, + state: State, +} + +impl Connection +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + pub(crate) fn new( + stream: Framed, PostgresCodec>, + pending_responses: VecDeque, + parameters: HashMap, + receiver: mpsc::UnboundedReceiver, + ) -> Connection { + Connection { + stream, + parameters, + receiver, + pending_request: None, + pending_responses, + responses: VecDeque::new(), + state: State::Active, + } + } + + fn poll_response( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + if let Some(message) = self.pending_responses.pop_front() { + trace!("retrying pending response"); + return Poll::Ready(Some(Ok(message))); + } + + Pin::new(&mut self.stream) + .poll_next(cx) + .map(|o| o.map(|r| r.map_err(Error::io))) + } + + fn poll_read(&mut self, cx: &mut Context<'_>) -> Result, Error> { + if self.state != State::Active { + trace!("poll_read: done"); + return Ok(None); + } + + loop { + let message = match self.poll_response(cx)? 
{ + Poll::Ready(Some(message)) => message, + Poll::Ready(None) => return Err(Error::closed()), + Poll::Pending => { + trace!("poll_read: waiting on response"); + return Ok(None); + } + }; + + let (mut messages, request_complete) = match message { + BackendMessage::Async(Message::NoticeResponse(body)) => { + let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?; + return Ok(Some(AsyncMessage::Notice(error))); + } + BackendMessage::Async(Message::NotificationResponse(body)) => { + let notification = Notification { + process_id: body.process_id(), + channel: body.channel().map_err(Error::parse)?.to_string(), + payload: body.message().map_err(Error::parse)?.to_string(), + }; + return Ok(Some(AsyncMessage::Notification(notification))); + } + BackendMessage::Async(Message::ParameterStatus(body)) => { + self.parameters.insert( + body.name().map_err(Error::parse)?.to_string(), + body.value().map_err(Error::parse)?.to_string(), + ); + continue; + } + BackendMessage::Async(_) => unreachable!(), + BackendMessage::Normal { + messages, + request_complete, + } => (messages, request_complete), + }; + + let mut response = match self.responses.pop_front() { + Some(response) => response, + None => match messages.next().map_err(Error::parse)? 
{ + Some(Message::ErrorResponse(error)) => return Err(Error::db(error)), + _ => return Err(Error::unexpected_message()), + }, + }; + + match response.sender.poll_reserve(cx) { + Poll::Ready(Ok(())) => { + let _ = response.sender.send_item(messages); + if !request_complete { + self.responses.push_front(response); + } + } + Poll::Ready(Err(_)) => { + // we need to keep paging through the rest of the messages even if the receiver's hung up + if !request_complete { + self.responses.push_front(response); + } + } + Poll::Pending => { + self.responses.push_front(response); + self.pending_responses.push_back(BackendMessage::Normal { + messages, + request_complete, + }); + trace!("poll_read: waiting on sender"); + return Ok(None); + } + } + } + } + + fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Some(messages) = self.pending_request.take() { + trace!("retrying pending request"); + return Poll::Ready(Some(messages)); + } + + if self.receiver.is_closed() { + return Poll::Ready(None); + } + + match self.receiver.poll_recv(cx) { + Poll::Ready(Some(request)) => { + trace!("polled new request"); + self.responses.push_back(Response { + sender: PollSender::new(request.sender), + }); + Poll::Ready(Some(request.messages)) + } + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } + + fn poll_write(&mut self, cx: &mut Context<'_>) -> Result { + loop { + if self.state == State::Closing { + trace!("poll_write: done"); + return Ok(false); + } + + if Pin::new(&mut self.stream) + .poll_ready(cx) + .map_err(Error::io)? 
+ .is_pending() + { + trace!("poll_write: waiting on socket"); + return Ok(false); + } + + let request = match self.poll_request(cx) { + Poll::Ready(Some(request)) => request, + Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => { + trace!("poll_write: at eof, terminating"); + self.state = State::Terminating; + let mut request = BytesMut::new(); + frontend::terminate(&mut request); + RequestMessages::Single(FrontendMessage::Raw(request.freeze())) + } + Poll::Ready(None) => { + trace!( + "poll_write: at eof, pending responses {}", + self.responses.len() + ); + return Ok(true); + } + Poll::Pending => { + trace!("poll_write: waiting on request"); + return Ok(true); + } + }; + + match request { + RequestMessages::Single(request) => { + Pin::new(&mut self.stream) + .start_send(request) + .map_err(Error::io)?; + if self.state == State::Terminating { + trace!("poll_write: sent eof, closing"); + self.state = State::Closing; + } + } + } + } + } + + fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> { + match Pin::new(&mut self.stream) + .poll_flush(cx) + .map_err(Error::io)? + { + Poll::Ready(()) => trace!("poll_flush: flushed"), + Poll::Pending => trace!("poll_flush: waiting on socket"), + } + Ok(()) + } + + fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.state != State::Closing { + return Poll::Pending; + } + + match Pin::new(&mut self.stream) + .poll_close(cx) + .map_err(Error::io)? + { + Poll::Ready(()) => { + trace!("poll_shutdown: complete"); + Poll::Ready(Ok(())) + } + Poll::Pending => { + trace!("poll_shutdown: waiting on socket"); + Poll::Pending + } + } + } + + /// Returns the value of a runtime parameter for this connection. + pub fn parameter(&self, name: &str) -> Option<&str> { + self.parameters.get(name).map(|s| &**s) + } + + /// Polls for asynchronous messages from the server. + /// + /// The server can send notices as well as notifications asynchronously to the client. 
Applications that wish to + /// examine those messages should use this method to drive the connection rather than its `Future` implementation. + pub fn poll_message( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + let message = self.poll_read(cx)?; + let want_flush = self.poll_write(cx)?; + if want_flush { + self.poll_flush(cx)?; + } + match message { + Some(message) => Poll::Ready(Some(Ok(message))), + None => match self.poll_shutdown(cx) { + Poll::Ready(Ok(())) => Poll::Ready(None), + Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), + Poll::Pending => Poll::Pending, + }, + } + } +} + +impl Future for Connection +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + type Output = Result<(), Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + while let Some(message) = ready!(self.poll_message(cx)?) { + if let AsyncMessage::Notice(notice) = message { + info!("{}: {}", notice.severity(), notice.message()); + } + } + Poll::Ready(Ok(())) + } +} diff --git a/libs/proxy/tokio-postgres2/src/error/mod.rs b/libs/proxy/tokio-postgres2/src/error/mod.rs new file mode 100644 index 0000000000..6514322250 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/error/mod.rs @@ -0,0 +1,501 @@ +//! Errors. + +use fallible_iterator::FallibleIterator; +use postgres_protocol2::message::backend::{ErrorFields, ErrorResponseBody}; +use std::error::{self, Error as _Error}; +use std::fmt; +use std::io; + +pub use self::sqlstate::*; + +#[allow(clippy::unreadable_literal)] +mod sqlstate; + +/// The severity of a Postgres error or notice. 
+#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum Severity { + /// PANIC + Panic, + /// FATAL + Fatal, + /// ERROR + Error, + /// WARNING + Warning, + /// NOTICE + Notice, + /// DEBUG + Debug, + /// INFO + Info, + /// LOG + Log, +} + +impl fmt::Display for Severity { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match *self { + Severity::Panic => "PANIC", + Severity::Fatal => "FATAL", + Severity::Error => "ERROR", + Severity::Warning => "WARNING", + Severity::Notice => "NOTICE", + Severity::Debug => "DEBUG", + Severity::Info => "INFO", + Severity::Log => "LOG", + }; + fmt.write_str(s) + } +} + +impl Severity { + fn from_str(s: &str) -> Option { + match s { + "PANIC" => Some(Severity::Panic), + "FATAL" => Some(Severity::Fatal), + "ERROR" => Some(Severity::Error), + "WARNING" => Some(Severity::Warning), + "NOTICE" => Some(Severity::Notice), + "DEBUG" => Some(Severity::Debug), + "INFO" => Some(Severity::Info), + "LOG" => Some(Severity::Log), + _ => None, + } + } +} + +/// A Postgres error or notice. 
+#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DbError { + severity: String, + parsed_severity: Option, + code: SqlState, + message: String, + detail: Option, + hint: Option, + position: Option, + where_: Option, + schema: Option, + table: Option, + column: Option, + datatype: Option, + constraint: Option, + file: Option, + line: Option, + routine: Option, +} + +impl DbError { + pub(crate) fn parse(fields: &mut ErrorFields<'_>) -> io::Result { + let mut severity = None; + let mut parsed_severity = None; + let mut code = None; + let mut message = None; + let mut detail = None; + let mut hint = None; + let mut normal_position = None; + let mut internal_position = None; + let mut internal_query = None; + let mut where_ = None; + let mut schema = None; + let mut table = None; + let mut column = None; + let mut datatype = None; + let mut constraint = None; + let mut file = None; + let mut line = None; + let mut routine = None; + + while let Some(field) = fields.next()? { + match field.type_() { + b'S' => severity = Some(field.value().to_owned()), + b'C' => code = Some(SqlState::from_code(field.value())), + b'M' => message = Some(field.value().to_owned()), + b'D' => detail = Some(field.value().to_owned()), + b'H' => hint = Some(field.value().to_owned()), + b'P' => { + normal_position = Some(field.value().parse::().map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "`P` field did not contain an integer", + ) + })?); + } + b'p' => { + internal_position = Some(field.value().parse::().map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "`p` field did not contain an integer", + ) + })?); + } + b'q' => internal_query = Some(field.value().to_owned()), + b'W' => where_ = Some(field.value().to_owned()), + b's' => schema = Some(field.value().to_owned()), + b't' => table = Some(field.value().to_owned()), + b'c' => column = Some(field.value().to_owned()), + b'd' => datatype = Some(field.value().to_owned()), + b'n' => constraint = 
Some(field.value().to_owned()), + b'F' => file = Some(field.value().to_owned()), + b'L' => { + line = Some(field.value().parse::().map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "`L` field did not contain an integer", + ) + })?); + } + b'R' => routine = Some(field.value().to_owned()), + b'V' => { + parsed_severity = Some(Severity::from_str(field.value()).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "`V` field contained an invalid value", + ) + })?); + } + _ => {} + } + } + + Ok(DbError { + severity: severity + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`S` field missing"))?, + parsed_severity, + code: code + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`C` field missing"))?, + message: message + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`M` field missing"))?, + detail, + hint, + position: match normal_position { + Some(position) => Some(ErrorPosition::Original(position)), + None => match internal_position { + Some(position) => Some(ErrorPosition::Internal { + position, + query: internal_query.ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "`q` field missing but `p` field present", + ) + })?, + }), + None => None, + }, + }, + where_, + schema, + table, + column, + datatype, + constraint, + file, + line, + routine, + }) + } + + /// The field contents are ERROR, FATAL, or PANIC (in an error message), + /// or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a + /// localized translation of one of these. + pub fn severity(&self) -> &str { + &self.severity + } + + /// A parsed, nonlocalized version of `severity`. (PostgreSQL 9.6+) + pub fn parsed_severity(&self) -> Option { + self.parsed_severity + } + + /// The SQLSTATE code for the error. + pub fn code(&self) -> &SqlState { + &self.code + } + + /// The primary human-readable error message. + /// + /// This should be accurate but terse (typically one line). 
+ pub fn message(&self) -> &str { + &self.message + } + + /// An optional secondary error message carrying more detail about the + /// problem. + /// + /// Might run to multiple lines. + pub fn detail(&self) -> Option<&str> { + self.detail.as_deref() + } + + /// An optional suggestion what to do about the problem. + /// + /// This is intended to differ from `detail` in that it offers advice + /// (potentially inappropriate) rather than hard facts. Might run to + /// multiple lines. + pub fn hint(&self) -> Option<&str> { + self.hint.as_deref() + } + + /// An optional error cursor position into either the original query string + /// or an internally generated query. + pub fn position(&self) -> Option<&ErrorPosition> { + self.position.as_ref() + } + + /// An indication of the context in which the error occurred. + /// + /// Presently this includes a call stack traceback of active procedural + /// language functions and internally-generated queries. The trace is one + /// entry per line, most recent first. + pub fn where_(&self) -> Option<&str> { + self.where_.as_deref() + } + + /// If the error was associated with a specific database object, the name + /// of the schema containing that object, if any. (PostgreSQL 9.3+) + pub fn schema(&self) -> Option<&str> { + self.schema.as_deref() + } + + /// If the error was associated with a specific table, the name of the + /// table. (Refer to the schema name field for the name of the table's + /// schema.) (PostgreSQL 9.3+) + pub fn table(&self) -> Option<&str> { + self.table.as_deref() + } + + /// If the error was associated with a specific table column, the name of + /// the column. + /// + /// (Refer to the schema and table name fields to identify the table.) + /// (PostgreSQL 9.3+) + pub fn column(&self) -> Option<&str> { + self.column.as_deref() + } + + /// If the error was associated with a specific data type, the name of the + /// data type. (Refer to the schema name field for the name of the data + /// type's schema.) 
(PostgreSQL 9.3+) + pub fn datatype(&self) -> Option<&str> { + self.datatype.as_deref() + } + + /// If the error was associated with a specific constraint, the name of the + /// constraint. + /// + /// Refer to fields listed above for the associated table or domain. + /// (For this purpose, indexes are treated as constraints, even if they + /// weren't created with constraint syntax.) (PostgreSQL 9.3+) + pub fn constraint(&self) -> Option<&str> { + self.constraint.as_deref() + } + + /// The file name of the source-code location where the error was reported. + pub fn file(&self) -> Option<&str> { + self.file.as_deref() + } + + /// The line number of the source-code location where the error was + /// reported. + pub fn line(&self) -> Option { + self.line + } + + /// The name of the source-code routine reporting the error. + pub fn routine(&self) -> Option<&str> { + self.routine.as_deref() + } +} + +impl fmt::Display for DbError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "{}: {}", self.severity, self.message)?; + if let Some(detail) = &self.detail { + write!(fmt, "\nDETAIL: {}", detail)?; + } + if let Some(hint) = &self.hint { + write!(fmt, "\nHINT: {}", hint)?; + } + Ok(()) + } +} + +impl error::Error for DbError {} + +/// Represents the position of an error in a query. +#[derive(Clone, PartialEq, Eq, Debug)] +pub enum ErrorPosition { + /// A position in the original query. + Original(u32), + /// A position in an internally generated query. + Internal { + /// The byte position. + position: u32, + /// A query generated by the Postgres server. + query: String, + }, +} + +#[derive(Debug, PartialEq)] +enum Kind { + Io, + UnexpectedMessage, + Tls, + ToSql(usize), + FromSql(usize), + Column(String), + Closed, + Db, + Parse, + Encode, + Authentication, + ConfigParse, + Config, + Connect, + Timeout, +} + +struct ErrorInner { + kind: Kind, + cause: Option>, +} + +/// An error communicating with the Postgres server. 
+pub struct Error(Box); + +impl fmt::Debug for Error { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Error") + .field("kind", &self.0.kind) + .field("cause", &self.0.cause) + .finish() + } +} + +impl fmt::Display for Error { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0.kind { + Kind::Io => fmt.write_str("error communicating with the server")?, + Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?, + Kind::Tls => fmt.write_str("error performing TLS handshake")?, + Kind::ToSql(idx) => write!(fmt, "error serializing parameter {}", idx)?, + Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?, + Kind::Column(column) => write!(fmt, "invalid column `{}`", column)?, + Kind::Closed => fmt.write_str("connection closed")?, + Kind::Db => fmt.write_str("db error")?, + Kind::Parse => fmt.write_str("error parsing response from server")?, + Kind::Encode => fmt.write_str("error encoding message to server")?, + Kind::Authentication => fmt.write_str("authentication error")?, + Kind::ConfigParse => fmt.write_str("invalid connection string")?, + Kind::Config => fmt.write_str("invalid configuration")?, + Kind::Connect => fmt.write_str("error connecting to server")?, + Kind::Timeout => fmt.write_str("timeout waiting for server")?, + }; + if let Some(ref cause) = self.0.cause { + write!(fmt, ": {}", cause)?; + } + Ok(()) + } +} + +impl error::Error for Error { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + self.0.cause.as_ref().map(|e| &**e as _) + } +} + +impl Error { + /// Consumes the error, returning its cause. + pub fn into_source(self) -> Option> { + self.0.cause + } + + /// Returns the source of this error if it was a `DbError`. + /// + /// This is a simple convenience method. + pub fn as_db_error(&self) -> Option<&DbError> { + self.source().and_then(|e| e.downcast_ref::()) + } + + /// Determines if the error was associated with closed connection. 
+ pub fn is_closed(&self) -> bool { + self.0.kind == Kind::Closed + } + + /// Returns the SQLSTATE error code associated with the error. + /// + /// This is a convenience method that downcasts the cause to a `DbError` and returns its code. + pub fn code(&self) -> Option<&SqlState> { + self.as_db_error().map(DbError::code) + } + + fn new(kind: Kind, cause: Option>) -> Error { + Error(Box::new(ErrorInner { kind, cause })) + } + + pub(crate) fn closed() -> Error { + Error::new(Kind::Closed, None) + } + + pub(crate) fn unexpected_message() -> Error { + Error::new(Kind::UnexpectedMessage, None) + } + + #[allow(clippy::needless_pass_by_value)] + pub(crate) fn db(error: ErrorResponseBody) -> Error { + match DbError::parse(&mut error.fields()) { + Ok(e) => Error::new(Kind::Db, Some(Box::new(e))), + Err(e) => Error::new(Kind::Parse, Some(Box::new(e))), + } + } + + pub(crate) fn parse(e: io::Error) -> Error { + Error::new(Kind::Parse, Some(Box::new(e))) + } + + pub(crate) fn encode(e: io::Error) -> Error { + Error::new(Kind::Encode, Some(Box::new(e))) + } + + #[allow(clippy::wrong_self_convention)] + pub(crate) fn to_sql(e: Box, idx: usize) -> Error { + Error::new(Kind::ToSql(idx), Some(e)) + } + + pub(crate) fn from_sql(e: Box, idx: usize) -> Error { + Error::new(Kind::FromSql(idx), Some(e)) + } + + pub(crate) fn column(column: String) -> Error { + Error::new(Kind::Column(column), None) + } + + pub(crate) fn tls(e: Box) -> Error { + Error::new(Kind::Tls, Some(e)) + } + + pub(crate) fn io(e: io::Error) -> Error { + Error::new(Kind::Io, Some(Box::new(e))) + } + + pub(crate) fn authentication(e: Box) -> Error { + Error::new(Kind::Authentication, Some(e)) + } + + pub(crate) fn config_parse(e: Box) -> Error { + Error::new(Kind::ConfigParse, Some(e)) + } + + pub(crate) fn config(e: Box) -> Error { + Error::new(Kind::Config, Some(e)) + } + + pub(crate) fn connect(e: io::Error) -> Error { + Error::new(Kind::Connect, Some(Box::new(e))) + } + + #[doc(hidden)] + pub fn 
__private_api_timeout() -> Error { + Error::new(Kind::Timeout, None) + } +} diff --git a/libs/proxy/tokio-postgres2/src/error/sqlstate.rs b/libs/proxy/tokio-postgres2/src/error/sqlstate.rs new file mode 100644 index 0000000000..13a1d75f95 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/error/sqlstate.rs @@ -0,0 +1,1670 @@ +// Autogenerated file - DO NOT EDIT + +/// A SQLSTATE error code +#[derive(PartialEq, Eq, Clone, Debug)] +pub struct SqlState(Inner); + +impl SqlState { + /// Creates a `SqlState` from its error code. + pub fn from_code(s: &str) -> SqlState { + match SQLSTATE_MAP.get(s) { + Some(state) => state.clone(), + None => SqlState(Inner::Other(s.into())), + } + } + + /// Returns the error code corresponding to the `SqlState`. + pub fn code(&self) -> &str { + match &self.0 { + Inner::E00000 => "00000", + Inner::E01000 => "01000", + Inner::E0100C => "0100C", + Inner::E01008 => "01008", + Inner::E01003 => "01003", + Inner::E01007 => "01007", + Inner::E01006 => "01006", + Inner::E01004 => "01004", + Inner::E01P01 => "01P01", + Inner::E02000 => "02000", + Inner::E02001 => "02001", + Inner::E03000 => "03000", + Inner::E08000 => "08000", + Inner::E08003 => "08003", + Inner::E08006 => "08006", + Inner::E08001 => "08001", + Inner::E08004 => "08004", + Inner::E08007 => "08007", + Inner::E08P01 => "08P01", + Inner::E09000 => "09000", + Inner::E0A000 => "0A000", + Inner::E0B000 => "0B000", + Inner::E0F000 => "0F000", + Inner::E0F001 => "0F001", + Inner::E0L000 => "0L000", + Inner::E0LP01 => "0LP01", + Inner::E0P000 => "0P000", + Inner::E0Z000 => "0Z000", + Inner::E0Z002 => "0Z002", + Inner::E20000 => "20000", + Inner::E21000 => "21000", + Inner::E22000 => "22000", + Inner::E2202E => "2202E", + Inner::E22021 => "22021", + Inner::E22008 => "22008", + Inner::E22012 => "22012", + Inner::E22005 => "22005", + Inner::E2200B => "2200B", + Inner::E22022 => "22022", + Inner::E22015 => "22015", + Inner::E2201E => "2201E", + Inner::E22014 => "22014", + Inner::E22016 => 
"22016", + Inner::E2201F => "2201F", + Inner::E2201G => "2201G", + Inner::E22018 => "22018", + Inner::E22007 => "22007", + Inner::E22019 => "22019", + Inner::E2200D => "2200D", + Inner::E22025 => "22025", + Inner::E22P06 => "22P06", + Inner::E22010 => "22010", + Inner::E22023 => "22023", + Inner::E22013 => "22013", + Inner::E2201B => "2201B", + Inner::E2201W => "2201W", + Inner::E2201X => "2201X", + Inner::E2202H => "2202H", + Inner::E2202G => "2202G", + Inner::E22009 => "22009", + Inner::E2200C => "2200C", + Inner::E2200G => "2200G", + Inner::E22004 => "22004", + Inner::E22002 => "22002", + Inner::E22003 => "22003", + Inner::E2200H => "2200H", + Inner::E22026 => "22026", + Inner::E22001 => "22001", + Inner::E22011 => "22011", + Inner::E22027 => "22027", + Inner::E22024 => "22024", + Inner::E2200F => "2200F", + Inner::E22P01 => "22P01", + Inner::E22P02 => "22P02", + Inner::E22P03 => "22P03", + Inner::E22P04 => "22P04", + Inner::E22P05 => "22P05", + Inner::E2200L => "2200L", + Inner::E2200M => "2200M", + Inner::E2200N => "2200N", + Inner::E2200S => "2200S", + Inner::E2200T => "2200T", + Inner::E22030 => "22030", + Inner::E22031 => "22031", + Inner::E22032 => "22032", + Inner::E22033 => "22033", + Inner::E22034 => "22034", + Inner::E22035 => "22035", + Inner::E22036 => "22036", + Inner::E22037 => "22037", + Inner::E22038 => "22038", + Inner::E22039 => "22039", + Inner::E2203A => "2203A", + Inner::E2203B => "2203B", + Inner::E2203C => "2203C", + Inner::E2203D => "2203D", + Inner::E2203E => "2203E", + Inner::E2203F => "2203F", + Inner::E2203G => "2203G", + Inner::E23000 => "23000", + Inner::E23001 => "23001", + Inner::E23502 => "23502", + Inner::E23503 => "23503", + Inner::E23505 => "23505", + Inner::E23514 => "23514", + Inner::E23P01 => "23P01", + Inner::E24000 => "24000", + Inner::E25000 => "25000", + Inner::E25001 => "25001", + Inner::E25002 => "25002", + Inner::E25008 => "25008", + Inner::E25003 => "25003", + Inner::E25004 => "25004", + Inner::E25005 => "25005", + 
Inner::E25006 => "25006", + Inner::E25007 => "25007", + Inner::E25P01 => "25P01", + Inner::E25P02 => "25P02", + Inner::E25P03 => "25P03", + Inner::E26000 => "26000", + Inner::E27000 => "27000", + Inner::E28000 => "28000", + Inner::E28P01 => "28P01", + Inner::E2B000 => "2B000", + Inner::E2BP01 => "2BP01", + Inner::E2D000 => "2D000", + Inner::E2F000 => "2F000", + Inner::E2F005 => "2F005", + Inner::E2F002 => "2F002", + Inner::E2F003 => "2F003", + Inner::E2F004 => "2F004", + Inner::E34000 => "34000", + Inner::E38000 => "38000", + Inner::E38001 => "38001", + Inner::E38002 => "38002", + Inner::E38003 => "38003", + Inner::E38004 => "38004", + Inner::E39000 => "39000", + Inner::E39001 => "39001", + Inner::E39004 => "39004", + Inner::E39P01 => "39P01", + Inner::E39P02 => "39P02", + Inner::E39P03 => "39P03", + Inner::E3B000 => "3B000", + Inner::E3B001 => "3B001", + Inner::E3D000 => "3D000", + Inner::E3F000 => "3F000", + Inner::E40000 => "40000", + Inner::E40002 => "40002", + Inner::E40001 => "40001", + Inner::E40003 => "40003", + Inner::E40P01 => "40P01", + Inner::E42000 => "42000", + Inner::E42601 => "42601", + Inner::E42501 => "42501", + Inner::E42846 => "42846", + Inner::E42803 => "42803", + Inner::E42P20 => "42P20", + Inner::E42P19 => "42P19", + Inner::E42830 => "42830", + Inner::E42602 => "42602", + Inner::E42622 => "42622", + Inner::E42939 => "42939", + Inner::E42804 => "42804", + Inner::E42P18 => "42P18", + Inner::E42P21 => "42P21", + Inner::E42P22 => "42P22", + Inner::E42809 => "42809", + Inner::E428C9 => "428C9", + Inner::E42703 => "42703", + Inner::E42883 => "42883", + Inner::E42P01 => "42P01", + Inner::E42P02 => "42P02", + Inner::E42704 => "42704", + Inner::E42701 => "42701", + Inner::E42P03 => "42P03", + Inner::E42P04 => "42P04", + Inner::E42723 => "42723", + Inner::E42P05 => "42P05", + Inner::E42P06 => "42P06", + Inner::E42P07 => "42P07", + Inner::E42712 => "42712", + Inner::E42710 => "42710", + Inner::E42702 => "42702", + Inner::E42725 => "42725", + 
Inner::E42P08 => "42P08", + Inner::E42P09 => "42P09", + Inner::E42P10 => "42P10", + Inner::E42611 => "42611", + Inner::E42P11 => "42P11", + Inner::E42P12 => "42P12", + Inner::E42P13 => "42P13", + Inner::E42P14 => "42P14", + Inner::E42P15 => "42P15", + Inner::E42P16 => "42P16", + Inner::E42P17 => "42P17", + Inner::E44000 => "44000", + Inner::E53000 => "53000", + Inner::E53100 => "53100", + Inner::E53200 => "53200", + Inner::E53300 => "53300", + Inner::E53400 => "53400", + Inner::E54000 => "54000", + Inner::E54001 => "54001", + Inner::E54011 => "54011", + Inner::E54023 => "54023", + Inner::E55000 => "55000", + Inner::E55006 => "55006", + Inner::E55P02 => "55P02", + Inner::E55P03 => "55P03", + Inner::E55P04 => "55P04", + Inner::E57000 => "57000", + Inner::E57014 => "57014", + Inner::E57P01 => "57P01", + Inner::E57P02 => "57P02", + Inner::E57P03 => "57P03", + Inner::E57P04 => "57P04", + Inner::E57P05 => "57P05", + Inner::E58000 => "58000", + Inner::E58030 => "58030", + Inner::E58P01 => "58P01", + Inner::E58P02 => "58P02", + Inner::E72000 => "72000", + Inner::EF0000 => "F0000", + Inner::EF0001 => "F0001", + Inner::EHV000 => "HV000", + Inner::EHV005 => "HV005", + Inner::EHV002 => "HV002", + Inner::EHV010 => "HV010", + Inner::EHV021 => "HV021", + Inner::EHV024 => "HV024", + Inner::EHV007 => "HV007", + Inner::EHV008 => "HV008", + Inner::EHV004 => "HV004", + Inner::EHV006 => "HV006", + Inner::EHV091 => "HV091", + Inner::EHV00B => "HV00B", + Inner::EHV00C => "HV00C", + Inner::EHV00D => "HV00D", + Inner::EHV090 => "HV090", + Inner::EHV00A => "HV00A", + Inner::EHV009 => "HV009", + Inner::EHV014 => "HV014", + Inner::EHV001 => "HV001", + Inner::EHV00P => "HV00P", + Inner::EHV00J => "HV00J", + Inner::EHV00K => "HV00K", + Inner::EHV00Q => "HV00Q", + Inner::EHV00R => "HV00R", + Inner::EHV00L => "HV00L", + Inner::EHV00M => "HV00M", + Inner::EHV00N => "HV00N", + Inner::EP0000 => "P0000", + Inner::EP0001 => "P0001", + Inner::EP0002 => "P0002", + Inner::EP0003 => "P0003", + 
Inner::EP0004 => "P0004", + Inner::EXX000 => "XX000", + Inner::EXX001 => "XX001", + Inner::EXX002 => "XX002", + Inner::Other(code) => code, + } + } + + /// 00000 + pub const SUCCESSFUL_COMPLETION: SqlState = SqlState(Inner::E00000); + + /// 01000 + pub const WARNING: SqlState = SqlState(Inner::E01000); + + /// 0100C + pub const WARNING_DYNAMIC_RESULT_SETS_RETURNED: SqlState = SqlState(Inner::E0100C); + + /// 01008 + pub const WARNING_IMPLICIT_ZERO_BIT_PADDING: SqlState = SqlState(Inner::E01008); + + /// 01003 + pub const WARNING_NULL_VALUE_ELIMINATED_IN_SET_FUNCTION: SqlState = SqlState(Inner::E01003); + + /// 01007 + pub const WARNING_PRIVILEGE_NOT_GRANTED: SqlState = SqlState(Inner::E01007); + + /// 01006 + pub const WARNING_PRIVILEGE_NOT_REVOKED: SqlState = SqlState(Inner::E01006); + + /// 01004 + pub const WARNING_STRING_DATA_RIGHT_TRUNCATION: SqlState = SqlState(Inner::E01004); + + /// 01P01 + pub const WARNING_DEPRECATED_FEATURE: SqlState = SqlState(Inner::E01P01); + + /// 02000 + pub const NO_DATA: SqlState = SqlState(Inner::E02000); + + /// 02001 + pub const NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED: SqlState = SqlState(Inner::E02001); + + /// 03000 + pub const SQL_STATEMENT_NOT_YET_COMPLETE: SqlState = SqlState(Inner::E03000); + + /// 08000 + pub const CONNECTION_EXCEPTION: SqlState = SqlState(Inner::E08000); + + /// 08003 + pub const CONNECTION_DOES_NOT_EXIST: SqlState = SqlState(Inner::E08003); + + /// 08006 + pub const CONNECTION_FAILURE: SqlState = SqlState(Inner::E08006); + + /// 08001 + pub const SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION: SqlState = SqlState(Inner::E08001); + + /// 08004 + pub const SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION: SqlState = SqlState(Inner::E08004); + + /// 08007 + pub const TRANSACTION_RESOLUTION_UNKNOWN: SqlState = SqlState(Inner::E08007); + + /// 08P01 + pub const PROTOCOL_VIOLATION: SqlState = SqlState(Inner::E08P01); + + /// 09000 + pub const TRIGGERED_ACTION_EXCEPTION: SqlState = SqlState(Inner::E09000); + 
+ /// 0A000 + pub const FEATURE_NOT_SUPPORTED: SqlState = SqlState(Inner::E0A000); + + /// 0B000 + pub const INVALID_TRANSACTION_INITIATION: SqlState = SqlState(Inner::E0B000); + + /// 0F000 + pub const LOCATOR_EXCEPTION: SqlState = SqlState(Inner::E0F000); + + /// 0F001 + pub const L_E_INVALID_SPECIFICATION: SqlState = SqlState(Inner::E0F001); + + /// 0L000 + pub const INVALID_GRANTOR: SqlState = SqlState(Inner::E0L000); + + /// 0LP01 + pub const INVALID_GRANT_OPERATION: SqlState = SqlState(Inner::E0LP01); + + /// 0P000 + pub const INVALID_ROLE_SPECIFICATION: SqlState = SqlState(Inner::E0P000); + + /// 0Z000 + pub const DIAGNOSTICS_EXCEPTION: SqlState = SqlState(Inner::E0Z000); + + /// 0Z002 + pub const STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER: SqlState = + SqlState(Inner::E0Z002); + + /// 20000 + pub const CASE_NOT_FOUND: SqlState = SqlState(Inner::E20000); + + /// 21000 + pub const CARDINALITY_VIOLATION: SqlState = SqlState(Inner::E21000); + + /// 22000 + pub const DATA_EXCEPTION: SqlState = SqlState(Inner::E22000); + + /// 2202E + pub const ARRAY_ELEMENT_ERROR: SqlState = SqlState(Inner::E2202E); + + /// 2202E + pub const ARRAY_SUBSCRIPT_ERROR: SqlState = SqlState(Inner::E2202E); + + /// 22021 + pub const CHARACTER_NOT_IN_REPERTOIRE: SqlState = SqlState(Inner::E22021); + + /// 22008 + pub const DATETIME_FIELD_OVERFLOW: SqlState = SqlState(Inner::E22008); + + /// 22008 + pub const DATETIME_VALUE_OUT_OF_RANGE: SqlState = SqlState(Inner::E22008); + + /// 22012 + pub const DIVISION_BY_ZERO: SqlState = SqlState(Inner::E22012); + + /// 22005 + pub const ERROR_IN_ASSIGNMENT: SqlState = SqlState(Inner::E22005); + + /// 2200B + pub const ESCAPE_CHARACTER_CONFLICT: SqlState = SqlState(Inner::E2200B); + + /// 22022 + pub const INDICATOR_OVERFLOW: SqlState = SqlState(Inner::E22022); + + /// 22015 + pub const INTERVAL_FIELD_OVERFLOW: SqlState = SqlState(Inner::E22015); + + /// 2201E + pub const INVALID_ARGUMENT_FOR_LOG: SqlState = SqlState(Inner::E2201E); + + /// 
22014 + pub const INVALID_ARGUMENT_FOR_NTILE: SqlState = SqlState(Inner::E22014); + + /// 22016 + pub const INVALID_ARGUMENT_FOR_NTH_VALUE: SqlState = SqlState(Inner::E22016); + + /// 2201F + pub const INVALID_ARGUMENT_FOR_POWER_FUNCTION: SqlState = SqlState(Inner::E2201F); + + /// 2201G + pub const INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION: SqlState = SqlState(Inner::E2201G); + + /// 22018 + pub const INVALID_CHARACTER_VALUE_FOR_CAST: SqlState = SqlState(Inner::E22018); + + /// 22007 + pub const INVALID_DATETIME_FORMAT: SqlState = SqlState(Inner::E22007); + + /// 22019 + pub const INVALID_ESCAPE_CHARACTER: SqlState = SqlState(Inner::E22019); + + /// 2200D + pub const INVALID_ESCAPE_OCTET: SqlState = SqlState(Inner::E2200D); + + /// 22025 + pub const INVALID_ESCAPE_SEQUENCE: SqlState = SqlState(Inner::E22025); + + /// 22P06 + pub const NONSTANDARD_USE_OF_ESCAPE_CHARACTER: SqlState = SqlState(Inner::E22P06); + + /// 22010 + pub const INVALID_INDICATOR_PARAMETER_VALUE: SqlState = SqlState(Inner::E22010); + + /// 22023 + pub const INVALID_PARAMETER_VALUE: SqlState = SqlState(Inner::E22023); + + /// 22013 + pub const INVALID_PRECEDING_OR_FOLLOWING_SIZE: SqlState = SqlState(Inner::E22013); + + /// 2201B + pub const INVALID_REGULAR_EXPRESSION: SqlState = SqlState(Inner::E2201B); + + /// 2201W + pub const INVALID_ROW_COUNT_IN_LIMIT_CLAUSE: SqlState = SqlState(Inner::E2201W); + + /// 2201X + pub const INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE: SqlState = SqlState(Inner::E2201X); + + /// 2202H + pub const INVALID_TABLESAMPLE_ARGUMENT: SqlState = SqlState(Inner::E2202H); + + /// 2202G + pub const INVALID_TABLESAMPLE_REPEAT: SqlState = SqlState(Inner::E2202G); + + /// 22009 + pub const INVALID_TIME_ZONE_DISPLACEMENT_VALUE: SqlState = SqlState(Inner::E22009); + + /// 2200C + pub const INVALID_USE_OF_ESCAPE_CHARACTER: SqlState = SqlState(Inner::E2200C); + + /// 2200G + pub const MOST_SPECIFIC_TYPE_MISMATCH: SqlState = SqlState(Inner::E2200G); + + /// 22004 + pub const 
NULL_VALUE_NOT_ALLOWED: SqlState = SqlState(Inner::E22004); + + /// 22002 + pub const NULL_VALUE_NO_INDICATOR_PARAMETER: SqlState = SqlState(Inner::E22002); + + /// 22003 + pub const NUMERIC_VALUE_OUT_OF_RANGE: SqlState = SqlState(Inner::E22003); + + /// 2200H + pub const SEQUENCE_GENERATOR_LIMIT_EXCEEDED: SqlState = SqlState(Inner::E2200H); + + /// 22026 + pub const STRING_DATA_LENGTH_MISMATCH: SqlState = SqlState(Inner::E22026); + + /// 22001 + pub const STRING_DATA_RIGHT_TRUNCATION: SqlState = SqlState(Inner::E22001); + + /// 22011 + pub const SUBSTRING_ERROR: SqlState = SqlState(Inner::E22011); + + /// 22027 + pub const TRIM_ERROR: SqlState = SqlState(Inner::E22027); + + /// 22024 + pub const UNTERMINATED_C_STRING: SqlState = SqlState(Inner::E22024); + + /// 2200F + pub const ZERO_LENGTH_CHARACTER_STRING: SqlState = SqlState(Inner::E2200F); + + /// 22P01 + pub const FLOATING_POINT_EXCEPTION: SqlState = SqlState(Inner::E22P01); + + /// 22P02 + pub const INVALID_TEXT_REPRESENTATION: SqlState = SqlState(Inner::E22P02); + + /// 22P03 + pub const INVALID_BINARY_REPRESENTATION: SqlState = SqlState(Inner::E22P03); + + /// 22P04 + pub const BAD_COPY_FILE_FORMAT: SqlState = SqlState(Inner::E22P04); + + /// 22P05 + pub const UNTRANSLATABLE_CHARACTER: SqlState = SqlState(Inner::E22P05); + + /// 2200L + pub const NOT_AN_XML_DOCUMENT: SqlState = SqlState(Inner::E2200L); + + /// 2200M + pub const INVALID_XML_DOCUMENT: SqlState = SqlState(Inner::E2200M); + + /// 2200N + pub const INVALID_XML_CONTENT: SqlState = SqlState(Inner::E2200N); + + /// 2200S + pub const INVALID_XML_COMMENT: SqlState = SqlState(Inner::E2200S); + + /// 2200T + pub const INVALID_XML_PROCESSING_INSTRUCTION: SqlState = SqlState(Inner::E2200T); + + /// 22030 + pub const DUPLICATE_JSON_OBJECT_KEY_VALUE: SqlState = SqlState(Inner::E22030); + + /// 22031 + pub const INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION: SqlState = SqlState(Inner::E22031); + + /// 22032 + pub const INVALID_JSON_TEXT: SqlState = 
SqlState(Inner::E22032); + + /// 22033 + pub const INVALID_SQL_JSON_SUBSCRIPT: SqlState = SqlState(Inner::E22033); + + /// 22034 + pub const MORE_THAN_ONE_SQL_JSON_ITEM: SqlState = SqlState(Inner::E22034); + + /// 22035 + pub const NO_SQL_JSON_ITEM: SqlState = SqlState(Inner::E22035); + + /// 22036 + pub const NON_NUMERIC_SQL_JSON_ITEM: SqlState = SqlState(Inner::E22036); + + /// 22037 + pub const NON_UNIQUE_KEYS_IN_A_JSON_OBJECT: SqlState = SqlState(Inner::E22037); + + /// 22038 + pub const SINGLETON_SQL_JSON_ITEM_REQUIRED: SqlState = SqlState(Inner::E22038); + + /// 22039 + pub const SQL_JSON_ARRAY_NOT_FOUND: SqlState = SqlState(Inner::E22039); + + /// 2203A + pub const SQL_JSON_MEMBER_NOT_FOUND: SqlState = SqlState(Inner::E2203A); + + /// 2203B + pub const SQL_JSON_NUMBER_NOT_FOUND: SqlState = SqlState(Inner::E2203B); + + /// 2203C + pub const SQL_JSON_OBJECT_NOT_FOUND: SqlState = SqlState(Inner::E2203C); + + /// 2203D + pub const TOO_MANY_JSON_ARRAY_ELEMENTS: SqlState = SqlState(Inner::E2203D); + + /// 2203E + pub const TOO_MANY_JSON_OBJECT_MEMBERS: SqlState = SqlState(Inner::E2203E); + + /// 2203F + pub const SQL_JSON_SCALAR_REQUIRED: SqlState = SqlState(Inner::E2203F); + + /// 2203G + pub const SQL_JSON_ITEM_CANNOT_BE_CAST_TO_TARGET_TYPE: SqlState = SqlState(Inner::E2203G); + + /// 23000 + pub const INTEGRITY_CONSTRAINT_VIOLATION: SqlState = SqlState(Inner::E23000); + + /// 23001 + pub const RESTRICT_VIOLATION: SqlState = SqlState(Inner::E23001); + + /// 23502 + pub const NOT_NULL_VIOLATION: SqlState = SqlState(Inner::E23502); + + /// 23503 + pub const FOREIGN_KEY_VIOLATION: SqlState = SqlState(Inner::E23503); + + /// 23505 + pub const UNIQUE_VIOLATION: SqlState = SqlState(Inner::E23505); + + /// 23514 + pub const CHECK_VIOLATION: SqlState = SqlState(Inner::E23514); + + /// 23P01 + pub const EXCLUSION_VIOLATION: SqlState = SqlState(Inner::E23P01); + + /// 24000 + pub const INVALID_CURSOR_STATE: SqlState = SqlState(Inner::E24000); + + /// 25000 + pub const 
INVALID_TRANSACTION_STATE: SqlState = SqlState(Inner::E25000); + + /// 25001 + pub const ACTIVE_SQL_TRANSACTION: SqlState = SqlState(Inner::E25001); + + /// 25002 + pub const BRANCH_TRANSACTION_ALREADY_ACTIVE: SqlState = SqlState(Inner::E25002); + + /// 25008 + pub const HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL: SqlState = SqlState(Inner::E25008); + + /// 25003 + pub const INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION: SqlState = SqlState(Inner::E25003); + + /// 25004 + pub const INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION: SqlState = + SqlState(Inner::E25004); + + /// 25005 + pub const NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION: SqlState = SqlState(Inner::E25005); + + /// 25006 + pub const READ_ONLY_SQL_TRANSACTION: SqlState = SqlState(Inner::E25006); + + /// 25007 + pub const SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED: SqlState = SqlState(Inner::E25007); + + /// 25P01 + pub const NO_ACTIVE_SQL_TRANSACTION: SqlState = SqlState(Inner::E25P01); + + /// 25P02 + pub const IN_FAILED_SQL_TRANSACTION: SqlState = SqlState(Inner::E25P02); + + /// 25P03 + pub const IDLE_IN_TRANSACTION_SESSION_TIMEOUT: SqlState = SqlState(Inner::E25P03); + + /// 26000 + pub const INVALID_SQL_STATEMENT_NAME: SqlState = SqlState(Inner::E26000); + + /// 26000 + pub const UNDEFINED_PSTATEMENT: SqlState = SqlState(Inner::E26000); + + /// 27000 + pub const TRIGGERED_DATA_CHANGE_VIOLATION: SqlState = SqlState(Inner::E27000); + + /// 28000 + pub const INVALID_AUTHORIZATION_SPECIFICATION: SqlState = SqlState(Inner::E28000); + + /// 28P01 + pub const INVALID_PASSWORD: SqlState = SqlState(Inner::E28P01); + + /// 2B000 + pub const DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST: SqlState = SqlState(Inner::E2B000); + + /// 2BP01 + pub const DEPENDENT_OBJECTS_STILL_EXIST: SqlState = SqlState(Inner::E2BP01); + + /// 2D000 + pub const INVALID_TRANSACTION_TERMINATION: SqlState = SqlState(Inner::E2D000); + + /// 2F000 + pub const SQL_ROUTINE_EXCEPTION: SqlState = SqlState(Inner::E2F000); + 
+ /// 2F005 + pub const S_R_E_FUNCTION_EXECUTED_NO_RETURN_STATEMENT: SqlState = SqlState(Inner::E2F005); + + /// 2F002 + pub const S_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E2F002); + + /// 2F003 + pub const S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED: SqlState = SqlState(Inner::E2F003); + + /// 2F004 + pub const S_R_E_READING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E2F004); + + /// 34000 + pub const INVALID_CURSOR_NAME: SqlState = SqlState(Inner::E34000); + + /// 34000 + pub const UNDEFINED_CURSOR: SqlState = SqlState(Inner::E34000); + + /// 38000 + pub const EXTERNAL_ROUTINE_EXCEPTION: SqlState = SqlState(Inner::E38000); + + /// 38001 + pub const E_R_E_CONTAINING_SQL_NOT_PERMITTED: SqlState = SqlState(Inner::E38001); + + /// 38002 + pub const E_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E38002); + + /// 38003 + pub const E_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED: SqlState = SqlState(Inner::E38003); + + /// 38004 + pub const E_R_E_READING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E38004); + + /// 39000 + pub const EXTERNAL_ROUTINE_INVOCATION_EXCEPTION: SqlState = SqlState(Inner::E39000); + + /// 39001 + pub const E_R_I_E_INVALID_SQLSTATE_RETURNED: SqlState = SqlState(Inner::E39001); + + /// 39004 + pub const E_R_I_E_NULL_VALUE_NOT_ALLOWED: SqlState = SqlState(Inner::E39004); + + /// 39P01 + pub const E_R_I_E_TRIGGER_PROTOCOL_VIOLATED: SqlState = SqlState(Inner::E39P01); + + /// 39P02 + pub const E_R_I_E_SRF_PROTOCOL_VIOLATED: SqlState = SqlState(Inner::E39P02); + + /// 39P03 + pub const E_R_I_E_EVENT_TRIGGER_PROTOCOL_VIOLATED: SqlState = SqlState(Inner::E39P03); + + /// 3B000 + pub const SAVEPOINT_EXCEPTION: SqlState = SqlState(Inner::E3B000); + + /// 3B001 + pub const S_E_INVALID_SPECIFICATION: SqlState = SqlState(Inner::E3B001); + + /// 3D000 + pub const INVALID_CATALOG_NAME: SqlState = SqlState(Inner::E3D000); + + /// 3D000 + pub const UNDEFINED_DATABASE: SqlState = SqlState(Inner::E3D000); + + /// 
3F000 + pub const INVALID_SCHEMA_NAME: SqlState = SqlState(Inner::E3F000); + + /// 3F000 + pub const UNDEFINED_SCHEMA: SqlState = SqlState(Inner::E3F000); + + /// 40000 + pub const TRANSACTION_ROLLBACK: SqlState = SqlState(Inner::E40000); + + /// 40002 + pub const T_R_INTEGRITY_CONSTRAINT_VIOLATION: SqlState = SqlState(Inner::E40002); + + /// 40001 + pub const T_R_SERIALIZATION_FAILURE: SqlState = SqlState(Inner::E40001); + + /// 40003 + pub const T_R_STATEMENT_COMPLETION_UNKNOWN: SqlState = SqlState(Inner::E40003); + + /// 40P01 + pub const T_R_DEADLOCK_DETECTED: SqlState = SqlState(Inner::E40P01); + + /// 42000 + pub const SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION: SqlState = SqlState(Inner::E42000); + + /// 42601 + pub const SYNTAX_ERROR: SqlState = SqlState(Inner::E42601); + + /// 42501 + pub const INSUFFICIENT_PRIVILEGE: SqlState = SqlState(Inner::E42501); + + /// 42846 + pub const CANNOT_COERCE: SqlState = SqlState(Inner::E42846); + + /// 42803 + pub const GROUPING_ERROR: SqlState = SqlState(Inner::E42803); + + /// 42P20 + pub const WINDOWING_ERROR: SqlState = SqlState(Inner::E42P20); + + /// 42P19 + pub const INVALID_RECURSION: SqlState = SqlState(Inner::E42P19); + + /// 42830 + pub const INVALID_FOREIGN_KEY: SqlState = SqlState(Inner::E42830); + + /// 42602 + pub const INVALID_NAME: SqlState = SqlState(Inner::E42602); + + /// 42622 + pub const NAME_TOO_LONG: SqlState = SqlState(Inner::E42622); + + /// 42939 + pub const RESERVED_NAME: SqlState = SqlState(Inner::E42939); + + /// 42804 + pub const DATATYPE_MISMATCH: SqlState = SqlState(Inner::E42804); + + /// 42P18 + pub const INDETERMINATE_DATATYPE: SqlState = SqlState(Inner::E42P18); + + /// 42P21 + pub const COLLATION_MISMATCH: SqlState = SqlState(Inner::E42P21); + + /// 42P22 + pub const INDETERMINATE_COLLATION: SqlState = SqlState(Inner::E42P22); + + /// 42809 + pub const WRONG_OBJECT_TYPE: SqlState = SqlState(Inner::E42809); + + /// 428C9 + pub const GENERATED_ALWAYS: SqlState = SqlState(Inner::E428C9); + + 
/// 42703 + pub const UNDEFINED_COLUMN: SqlState = SqlState(Inner::E42703); + + /// 42883 + pub const UNDEFINED_FUNCTION: SqlState = SqlState(Inner::E42883); + + /// 42P01 + pub const UNDEFINED_TABLE: SqlState = SqlState(Inner::E42P01); + + /// 42P02 + pub const UNDEFINED_PARAMETER: SqlState = SqlState(Inner::E42P02); + + /// 42704 + pub const UNDEFINED_OBJECT: SqlState = SqlState(Inner::E42704); + + /// 42701 + pub const DUPLICATE_COLUMN: SqlState = SqlState(Inner::E42701); + + /// 42P03 + pub const DUPLICATE_CURSOR: SqlState = SqlState(Inner::E42P03); + + /// 42P04 + pub const DUPLICATE_DATABASE: SqlState = SqlState(Inner::E42P04); + + /// 42723 + pub const DUPLICATE_FUNCTION: SqlState = SqlState(Inner::E42723); + + /// 42P05 + pub const DUPLICATE_PSTATEMENT: SqlState = SqlState(Inner::E42P05); + + /// 42P06 + pub const DUPLICATE_SCHEMA: SqlState = SqlState(Inner::E42P06); + + /// 42P07 + pub const DUPLICATE_TABLE: SqlState = SqlState(Inner::E42P07); + + /// 42712 + pub const DUPLICATE_ALIAS: SqlState = SqlState(Inner::E42712); + + /// 42710 + pub const DUPLICATE_OBJECT: SqlState = SqlState(Inner::E42710); + + /// 42702 + pub const AMBIGUOUS_COLUMN: SqlState = SqlState(Inner::E42702); + + /// 42725 + pub const AMBIGUOUS_FUNCTION: SqlState = SqlState(Inner::E42725); + + /// 42P08 + pub const AMBIGUOUS_PARAMETER: SqlState = SqlState(Inner::E42P08); + + /// 42P09 + pub const AMBIGUOUS_ALIAS: SqlState = SqlState(Inner::E42P09); + + /// 42P10 + pub const INVALID_COLUMN_REFERENCE: SqlState = SqlState(Inner::E42P10); + + /// 42611 + pub const INVALID_COLUMN_DEFINITION: SqlState = SqlState(Inner::E42611); + + /// 42P11 + pub const INVALID_CURSOR_DEFINITION: SqlState = SqlState(Inner::E42P11); + + /// 42P12 + pub const INVALID_DATABASE_DEFINITION: SqlState = SqlState(Inner::E42P12); + + /// 42P13 + pub const INVALID_FUNCTION_DEFINITION: SqlState = SqlState(Inner::E42P13); + + /// 42P14 + pub const INVALID_PSTATEMENT_DEFINITION: SqlState = SqlState(Inner::E42P14); + + /// 
42P15 + pub const INVALID_SCHEMA_DEFINITION: SqlState = SqlState(Inner::E42P15); + + /// 42P16 + pub const INVALID_TABLE_DEFINITION: SqlState = SqlState(Inner::E42P16); + + /// 42P17 + pub const INVALID_OBJECT_DEFINITION: SqlState = SqlState(Inner::E42P17); + + /// 44000 + pub const WITH_CHECK_OPTION_VIOLATION: SqlState = SqlState(Inner::E44000); + + /// 53000 + pub const INSUFFICIENT_RESOURCES: SqlState = SqlState(Inner::E53000); + + /// 53100 + pub const DISK_FULL: SqlState = SqlState(Inner::E53100); + + /// 53200 + pub const OUT_OF_MEMORY: SqlState = SqlState(Inner::E53200); + + /// 53300 + pub const TOO_MANY_CONNECTIONS: SqlState = SqlState(Inner::E53300); + + /// 53400 + pub const CONFIGURATION_LIMIT_EXCEEDED: SqlState = SqlState(Inner::E53400); + + /// 54000 + pub const PROGRAM_LIMIT_EXCEEDED: SqlState = SqlState(Inner::E54000); + + /// 54001 + pub const STATEMENT_TOO_COMPLEX: SqlState = SqlState(Inner::E54001); + + /// 54011 + pub const TOO_MANY_COLUMNS: SqlState = SqlState(Inner::E54011); + + /// 54023 + pub const TOO_MANY_ARGUMENTS: SqlState = SqlState(Inner::E54023); + + /// 55000 + pub const OBJECT_NOT_IN_PREREQUISITE_STATE: SqlState = SqlState(Inner::E55000); + + /// 55006 + pub const OBJECT_IN_USE: SqlState = SqlState(Inner::E55006); + + /// 55P02 + pub const CANT_CHANGE_RUNTIME_PARAM: SqlState = SqlState(Inner::E55P02); + + /// 55P03 + pub const LOCK_NOT_AVAILABLE: SqlState = SqlState(Inner::E55P03); + + /// 55P04 + pub const UNSAFE_NEW_ENUM_VALUE_USAGE: SqlState = SqlState(Inner::E55P04); + + /// 57000 + pub const OPERATOR_INTERVENTION: SqlState = SqlState(Inner::E57000); + + /// 57014 + pub const QUERY_CANCELED: SqlState = SqlState(Inner::E57014); + + /// 57P01 + pub const ADMIN_SHUTDOWN: SqlState = SqlState(Inner::E57P01); + + /// 57P02 + pub const CRASH_SHUTDOWN: SqlState = SqlState(Inner::E57P02); + + /// 57P03 + pub const CANNOT_CONNECT_NOW: SqlState = SqlState(Inner::E57P03); + + /// 57P04 + pub const DATABASE_DROPPED: SqlState = 
SqlState(Inner::E57P04); + + /// 57P05 + pub const IDLE_SESSION_TIMEOUT: SqlState = SqlState(Inner::E57P05); + + /// 58000 + pub const SYSTEM_ERROR: SqlState = SqlState(Inner::E58000); + + /// 58030 + pub const IO_ERROR: SqlState = SqlState(Inner::E58030); + + /// 58P01 + pub const UNDEFINED_FILE: SqlState = SqlState(Inner::E58P01); + + /// 58P02 + pub const DUPLICATE_FILE: SqlState = SqlState(Inner::E58P02); + + /// 72000 + pub const SNAPSHOT_TOO_OLD: SqlState = SqlState(Inner::E72000); + + /// F0000 + pub const CONFIG_FILE_ERROR: SqlState = SqlState(Inner::EF0000); + + /// F0001 + pub const LOCK_FILE_EXISTS: SqlState = SqlState(Inner::EF0001); + + /// HV000 + pub const FDW_ERROR: SqlState = SqlState(Inner::EHV000); + + /// HV005 + pub const FDW_COLUMN_NAME_NOT_FOUND: SqlState = SqlState(Inner::EHV005); + + /// HV002 + pub const FDW_DYNAMIC_PARAMETER_VALUE_NEEDED: SqlState = SqlState(Inner::EHV002); + + /// HV010 + pub const FDW_FUNCTION_SEQUENCE_ERROR: SqlState = SqlState(Inner::EHV010); + + /// HV021 + pub const FDW_INCONSISTENT_DESCRIPTOR_INFORMATION: SqlState = SqlState(Inner::EHV021); + + /// HV024 + pub const FDW_INVALID_ATTRIBUTE_VALUE: SqlState = SqlState(Inner::EHV024); + + /// HV007 + pub const FDW_INVALID_COLUMN_NAME: SqlState = SqlState(Inner::EHV007); + + /// HV008 + pub const FDW_INVALID_COLUMN_NUMBER: SqlState = SqlState(Inner::EHV008); + + /// HV004 + pub const FDW_INVALID_DATA_TYPE: SqlState = SqlState(Inner::EHV004); + + /// HV006 + pub const FDW_INVALID_DATA_TYPE_DESCRIPTORS: SqlState = SqlState(Inner::EHV006); + + /// HV091 + pub const FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER: SqlState = SqlState(Inner::EHV091); + + /// HV00B + pub const FDW_INVALID_HANDLE: SqlState = SqlState(Inner::EHV00B); + + /// HV00C + pub const FDW_INVALID_OPTION_INDEX: SqlState = SqlState(Inner::EHV00C); + + /// HV00D + pub const FDW_INVALID_OPTION_NAME: SqlState = SqlState(Inner::EHV00D); + + /// HV090 + pub const FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH: SqlState = 
SqlState(Inner::EHV090); + + /// HV00A + pub const FDW_INVALID_STRING_FORMAT: SqlState = SqlState(Inner::EHV00A); + + /// HV009 + pub const FDW_INVALID_USE_OF_NULL_POINTER: SqlState = SqlState(Inner::EHV009); + + /// HV014 + pub const FDW_TOO_MANY_HANDLES: SqlState = SqlState(Inner::EHV014); + + /// HV001 + pub const FDW_OUT_OF_MEMORY: SqlState = SqlState(Inner::EHV001); + + /// HV00P + pub const FDW_NO_SCHEMAS: SqlState = SqlState(Inner::EHV00P); + + /// HV00J + pub const FDW_OPTION_NAME_NOT_FOUND: SqlState = SqlState(Inner::EHV00J); + + /// HV00K + pub const FDW_REPLY_HANDLE: SqlState = SqlState(Inner::EHV00K); + + /// HV00Q + pub const FDW_SCHEMA_NOT_FOUND: SqlState = SqlState(Inner::EHV00Q); + + /// HV00R + pub const FDW_TABLE_NOT_FOUND: SqlState = SqlState(Inner::EHV00R); + + /// HV00L + pub const FDW_UNABLE_TO_CREATE_EXECUTION: SqlState = SqlState(Inner::EHV00L); + + /// HV00M + pub const FDW_UNABLE_TO_CREATE_REPLY: SqlState = SqlState(Inner::EHV00M); + + /// HV00N + pub const FDW_UNABLE_TO_ESTABLISH_CONNECTION: SqlState = SqlState(Inner::EHV00N); + + /// P0000 + pub const PLPGSQL_ERROR: SqlState = SqlState(Inner::EP0000); + + /// P0001 + pub const RAISE_EXCEPTION: SqlState = SqlState(Inner::EP0001); + + /// P0002 + pub const NO_DATA_FOUND: SqlState = SqlState(Inner::EP0002); + + /// P0003 + pub const TOO_MANY_ROWS: SqlState = SqlState(Inner::EP0003); + + /// P0004 + pub const ASSERT_FAILURE: SqlState = SqlState(Inner::EP0004); + + /// XX000 + pub const INTERNAL_ERROR: SqlState = SqlState(Inner::EXX000); + + /// XX001 + pub const DATA_CORRUPTED: SqlState = SqlState(Inner::EXX001); + + /// XX002 + pub const INDEX_CORRUPTED: SqlState = SqlState(Inner::EXX002); +} + +#[derive(PartialEq, Eq, Clone, Debug)] +#[allow(clippy::upper_case_acronyms)] +enum Inner { + E00000, + E01000, + E0100C, + E01008, + E01003, + E01007, + E01006, + E01004, + E01P01, + E02000, + E02001, + E03000, + E08000, + E08003, + E08006, + E08001, + E08004, + E08007, + E08P01, + E09000, + 
E0A000, + E0B000, + E0F000, + E0F001, + E0L000, + E0LP01, + E0P000, + E0Z000, + E0Z002, + E20000, + E21000, + E22000, + E2202E, + E22021, + E22008, + E22012, + E22005, + E2200B, + E22022, + E22015, + E2201E, + E22014, + E22016, + E2201F, + E2201G, + E22018, + E22007, + E22019, + E2200D, + E22025, + E22P06, + E22010, + E22023, + E22013, + E2201B, + E2201W, + E2201X, + E2202H, + E2202G, + E22009, + E2200C, + E2200G, + E22004, + E22002, + E22003, + E2200H, + E22026, + E22001, + E22011, + E22027, + E22024, + E2200F, + E22P01, + E22P02, + E22P03, + E22P04, + E22P05, + E2200L, + E2200M, + E2200N, + E2200S, + E2200T, + E22030, + E22031, + E22032, + E22033, + E22034, + E22035, + E22036, + E22037, + E22038, + E22039, + E2203A, + E2203B, + E2203C, + E2203D, + E2203E, + E2203F, + E2203G, + E23000, + E23001, + E23502, + E23503, + E23505, + E23514, + E23P01, + E24000, + E25000, + E25001, + E25002, + E25008, + E25003, + E25004, + E25005, + E25006, + E25007, + E25P01, + E25P02, + E25P03, + E26000, + E27000, + E28000, + E28P01, + E2B000, + E2BP01, + E2D000, + E2F000, + E2F005, + E2F002, + E2F003, + E2F004, + E34000, + E38000, + E38001, + E38002, + E38003, + E38004, + E39000, + E39001, + E39004, + E39P01, + E39P02, + E39P03, + E3B000, + E3B001, + E3D000, + E3F000, + E40000, + E40002, + E40001, + E40003, + E40P01, + E42000, + E42601, + E42501, + E42846, + E42803, + E42P20, + E42P19, + E42830, + E42602, + E42622, + E42939, + E42804, + E42P18, + E42P21, + E42P22, + E42809, + E428C9, + E42703, + E42883, + E42P01, + E42P02, + E42704, + E42701, + E42P03, + E42P04, + E42723, + E42P05, + E42P06, + E42P07, + E42712, + E42710, + E42702, + E42725, + E42P08, + E42P09, + E42P10, + E42611, + E42P11, + E42P12, + E42P13, + E42P14, + E42P15, + E42P16, + E42P17, + E44000, + E53000, + E53100, + E53200, + E53300, + E53400, + E54000, + E54001, + E54011, + E54023, + E55000, + E55006, + E55P02, + E55P03, + E55P04, + E57000, + E57014, + E57P01, + E57P02, + E57P03, + E57P04, + E57P05, + E58000, + E58030, + 
E58P01, + E58P02, + E72000, + EF0000, + EF0001, + EHV000, + EHV005, + EHV002, + EHV010, + EHV021, + EHV024, + EHV007, + EHV008, + EHV004, + EHV006, + EHV091, + EHV00B, + EHV00C, + EHV00D, + EHV090, + EHV00A, + EHV009, + EHV014, + EHV001, + EHV00P, + EHV00J, + EHV00K, + EHV00Q, + EHV00R, + EHV00L, + EHV00M, + EHV00N, + EP0000, + EP0001, + EP0002, + EP0003, + EP0004, + EXX000, + EXX001, + EXX002, + Other(Box<str>), +} + +#[rustfmt::skip] +static SQLSTATE_MAP: phf::Map<&'static str, SqlState> = +::phf::Map { + key: 12913932095322966823, + disps: &[ + (0, 24), + (0, 12), + (0, 74), + (0, 109), + (0, 11), + (0, 9), + (0, 0), + (4, 38), + (3, 155), + (0, 6), + (1, 242), + (0, 66), + (0, 53), + (5, 180), + (3, 221), + (7, 230), + (0, 125), + (1, 46), + (0, 11), + (1, 2), + (0, 5), + (0, 13), + (0, 171), + (0, 15), + (0, 4), + (0, 22), + (1, 85), + (0, 75), + (2, 0), + (1, 25), + (7, 47), + (0, 45), + (0, 35), + (0, 7), + (7, 124), + (0, 0), + (14, 104), + (1, 183), + (61, 50), + (3, 76), + (0, 12), + (0, 7), + (4, 189), + (0, 1), + (64, 102), + (0, 0), + (16, 192), + (24, 19), + (0, 5), + (0, 87), + (0, 89), + (0, 14), + ], + entries: &[ + ("2F000", SqlState::SQL_ROUTINE_EXCEPTION), + ("01008", SqlState::WARNING_IMPLICIT_ZERO_BIT_PADDING), + ("42501", SqlState::INSUFFICIENT_PRIVILEGE), + ("22000", SqlState::DATA_EXCEPTION), + ("0100C", SqlState::WARNING_DYNAMIC_RESULT_SETS_RETURNED), + ("2200N", SqlState::INVALID_XML_CONTENT), + ("40001", SqlState::T_R_SERIALIZATION_FAILURE), + ("28P01", SqlState::INVALID_PASSWORD), + ("38000", SqlState::EXTERNAL_ROUTINE_EXCEPTION), + ("25006", SqlState::READ_ONLY_SQL_TRANSACTION), + ("2203D", SqlState::TOO_MANY_JSON_ARRAY_ELEMENTS), + ("42P09", SqlState::AMBIGUOUS_ALIAS), + ("F0000", SqlState::CONFIG_FILE_ERROR), + ("42P18", SqlState::INDETERMINATE_DATATYPE), + ("40002", SqlState::T_R_INTEGRITY_CONSTRAINT_VIOLATION), + ("22009", SqlState::INVALID_TIME_ZONE_DISPLACEMENT_VALUE), + ("42P08", SqlState::AMBIGUOUS_PARAMETER), + ("08000", 
SqlState::CONNECTION_EXCEPTION), + ("25P01", SqlState::NO_ACTIVE_SQL_TRANSACTION), + ("22024", SqlState::UNTERMINATED_C_STRING), + ("55000", SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE), + ("25001", SqlState::ACTIVE_SQL_TRANSACTION), + ("03000", SqlState::SQL_STATEMENT_NOT_YET_COMPLETE), + ("42710", SqlState::DUPLICATE_OBJECT), + ("2D000", SqlState::INVALID_TRANSACTION_TERMINATION), + ("2200G", SqlState::MOST_SPECIFIC_TYPE_MISMATCH), + ("22022", SqlState::INDICATOR_OVERFLOW), + ("55006", SqlState::OBJECT_IN_USE), + ("53200", SqlState::OUT_OF_MEMORY), + ("22012", SqlState::DIVISION_BY_ZERO), + ("P0002", SqlState::NO_DATA_FOUND), + ("XX001", SqlState::DATA_CORRUPTED), + ("22P05", SqlState::UNTRANSLATABLE_CHARACTER), + ("40003", SqlState::T_R_STATEMENT_COMPLETION_UNKNOWN), + ("22021", SqlState::CHARACTER_NOT_IN_REPERTOIRE), + ("25000", SqlState::INVALID_TRANSACTION_STATE), + ("42P15", SqlState::INVALID_SCHEMA_DEFINITION), + ("0B000", SqlState::INVALID_TRANSACTION_INITIATION), + ("22004", SqlState::NULL_VALUE_NOT_ALLOWED), + ("42804", SqlState::DATATYPE_MISMATCH), + ("42803", SqlState::GROUPING_ERROR), + ("02001", SqlState::NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED), + ("25002", SqlState::BRANCH_TRANSACTION_ALREADY_ACTIVE), + ("28000", SqlState::INVALID_AUTHORIZATION_SPECIFICATION), + ("HV009", SqlState::FDW_INVALID_USE_OF_NULL_POINTER), + ("22P01", SqlState::FLOATING_POINT_EXCEPTION), + ("2B000", SqlState::DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST), + ("42723", SqlState::DUPLICATE_FUNCTION), + ("21000", SqlState::CARDINALITY_VIOLATION), + ("0Z002", SqlState::STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER), + ("23505", SqlState::UNIQUE_VIOLATION), + ("HV00J", SqlState::FDW_OPTION_NAME_NOT_FOUND), + ("23P01", SqlState::EXCLUSION_VIOLATION), + ("39P03", SqlState::E_R_I_E_EVENT_TRIGGER_PROTOCOL_VIOLATED), + ("42P10", SqlState::INVALID_COLUMN_REFERENCE), + ("2202H", SqlState::INVALID_TABLESAMPLE_ARGUMENT), + ("55P04", SqlState::UNSAFE_NEW_ENUM_VALUE_USAGE), + 
("P0000", SqlState::PLPGSQL_ERROR), + ("2F005", SqlState::S_R_E_FUNCTION_EXECUTED_NO_RETURN_STATEMENT), + ("HV00M", SqlState::FDW_UNABLE_TO_CREATE_REPLY), + ("0A000", SqlState::FEATURE_NOT_SUPPORTED), + ("24000", SqlState::INVALID_CURSOR_STATE), + ("25008", SqlState::HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL), + ("01003", SqlState::WARNING_NULL_VALUE_ELIMINATED_IN_SET_FUNCTION), + ("42712", SqlState::DUPLICATE_ALIAS), + ("HV014", SqlState::FDW_TOO_MANY_HANDLES), + ("58030", SqlState::IO_ERROR), + ("2201W", SqlState::INVALID_ROW_COUNT_IN_LIMIT_CLAUSE), + ("22033", SqlState::INVALID_SQL_JSON_SUBSCRIPT), + ("2BP01", SqlState::DEPENDENT_OBJECTS_STILL_EXIST), + ("HV005", SqlState::FDW_COLUMN_NAME_NOT_FOUND), + ("25004", SqlState::INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION), + ("54000", SqlState::PROGRAM_LIMIT_EXCEEDED), + ("20000", SqlState::CASE_NOT_FOUND), + ("2203G", SqlState::SQL_JSON_ITEM_CANNOT_BE_CAST_TO_TARGET_TYPE), + ("22038", SqlState::SINGLETON_SQL_JSON_ITEM_REQUIRED), + ("22007", SqlState::INVALID_DATETIME_FORMAT), + ("08004", SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION), + ("2200H", SqlState::SEQUENCE_GENERATOR_LIMIT_EXCEEDED), + ("HV00D", SqlState::FDW_INVALID_OPTION_NAME), + ("P0004", SqlState::ASSERT_FAILURE), + ("22018", SqlState::INVALID_CHARACTER_VALUE_FOR_CAST), + ("0L000", SqlState::INVALID_GRANTOR), + ("22P04", SqlState::BAD_COPY_FILE_FORMAT), + ("22031", SqlState::INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION), + ("01P01", SqlState::WARNING_DEPRECATED_FEATURE), + ("0LP01", SqlState::INVALID_GRANT_OPERATION), + ("58P02", SqlState::DUPLICATE_FILE), + ("26000", SqlState::INVALID_SQL_STATEMENT_NAME), + ("54001", SqlState::STATEMENT_TOO_COMPLEX), + ("22010", SqlState::INVALID_INDICATOR_PARAMETER_VALUE), + ("HV00C", SqlState::FDW_INVALID_OPTION_INDEX), + ("22008", SqlState::DATETIME_FIELD_OVERFLOW), + ("42P06", SqlState::DUPLICATE_SCHEMA), + ("25007", SqlState::SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED), + ("42P20", 
SqlState::WINDOWING_ERROR), + ("HV091", SqlState::FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER), + ("HV021", SqlState::FDW_INCONSISTENT_DESCRIPTOR_INFORMATION), + ("42702", SqlState::AMBIGUOUS_COLUMN), + ("02000", SqlState::NO_DATA), + ("54011", SqlState::TOO_MANY_COLUMNS), + ("HV004", SqlState::FDW_INVALID_DATA_TYPE), + ("01006", SqlState::WARNING_PRIVILEGE_NOT_REVOKED), + ("42701", SqlState::DUPLICATE_COLUMN), + ("08P01", SqlState::PROTOCOL_VIOLATION), + ("42622", SqlState::NAME_TOO_LONG), + ("P0003", SqlState::TOO_MANY_ROWS), + ("22003", SqlState::NUMERIC_VALUE_OUT_OF_RANGE), + ("42P03", SqlState::DUPLICATE_CURSOR), + ("23001", SqlState::RESTRICT_VIOLATION), + ("57000", SqlState::OPERATOR_INTERVENTION), + ("22027", SqlState::TRIM_ERROR), + ("42P12", SqlState::INVALID_DATABASE_DEFINITION), + ("3B000", SqlState::SAVEPOINT_EXCEPTION), + ("2201B", SqlState::INVALID_REGULAR_EXPRESSION), + ("22030", SqlState::DUPLICATE_JSON_OBJECT_KEY_VALUE), + ("2F004", SqlState::S_R_E_READING_SQL_DATA_NOT_PERMITTED), + ("428C9", SqlState::GENERATED_ALWAYS), + ("2200S", SqlState::INVALID_XML_COMMENT), + ("22039", SqlState::SQL_JSON_ARRAY_NOT_FOUND), + ("42809", SqlState::WRONG_OBJECT_TYPE), + ("2201X", SqlState::INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE), + ("39001", SqlState::E_R_I_E_INVALID_SQLSTATE_RETURNED), + ("25P02", SqlState::IN_FAILED_SQL_TRANSACTION), + ("0P000", SqlState::INVALID_ROLE_SPECIFICATION), + ("HV00N", SqlState::FDW_UNABLE_TO_ESTABLISH_CONNECTION), + ("53100", SqlState::DISK_FULL), + ("42601", SqlState::SYNTAX_ERROR), + ("23000", SqlState::INTEGRITY_CONSTRAINT_VIOLATION), + ("HV006", SqlState::FDW_INVALID_DATA_TYPE_DESCRIPTORS), + ("HV00B", SqlState::FDW_INVALID_HANDLE), + ("HV00Q", SqlState::FDW_SCHEMA_NOT_FOUND), + ("01000", SqlState::WARNING), + ("42883", SqlState::UNDEFINED_FUNCTION), + ("57P01", SqlState::ADMIN_SHUTDOWN), + ("22037", SqlState::NON_UNIQUE_KEYS_IN_A_JSON_OBJECT), + ("00000", SqlState::SUCCESSFUL_COMPLETION), + ("55P03", 
SqlState::LOCK_NOT_AVAILABLE), + ("42P01", SqlState::UNDEFINED_TABLE), + ("42830", SqlState::INVALID_FOREIGN_KEY), + ("22005", SqlState::ERROR_IN_ASSIGNMENT), + ("22025", SqlState::INVALID_ESCAPE_SEQUENCE), + ("XX002", SqlState::INDEX_CORRUPTED), + ("42P16", SqlState::INVALID_TABLE_DEFINITION), + ("55P02", SqlState::CANT_CHANGE_RUNTIME_PARAM), + ("22019", SqlState::INVALID_ESCAPE_CHARACTER), + ("P0001", SqlState::RAISE_EXCEPTION), + ("72000", SqlState::SNAPSHOT_TOO_OLD), + ("42P11", SqlState::INVALID_CURSOR_DEFINITION), + ("40P01", SqlState::T_R_DEADLOCK_DETECTED), + ("57P02", SqlState::CRASH_SHUTDOWN), + ("HV00A", SqlState::FDW_INVALID_STRING_FORMAT), + ("2F002", SqlState::S_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED), + ("23503", SqlState::FOREIGN_KEY_VIOLATION), + ("40000", SqlState::TRANSACTION_ROLLBACK), + ("22032", SqlState::INVALID_JSON_TEXT), + ("2202E", SqlState::ARRAY_ELEMENT_ERROR), + ("42P19", SqlState::INVALID_RECURSION), + ("42611", SqlState::INVALID_COLUMN_DEFINITION), + ("42P13", SqlState::INVALID_FUNCTION_DEFINITION), + ("25003", SqlState::INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION), + ("39P02", SqlState::E_R_I_E_SRF_PROTOCOL_VIOLATED), + ("XX000", SqlState::INTERNAL_ERROR), + ("08006", SqlState::CONNECTION_FAILURE), + ("57P04", SqlState::DATABASE_DROPPED), + ("42P07", SqlState::DUPLICATE_TABLE), + ("22P03", SqlState::INVALID_BINARY_REPRESENTATION), + ("22035", SqlState::NO_SQL_JSON_ITEM), + ("42P14", SqlState::INVALID_PSTATEMENT_DEFINITION), + ("01007", SqlState::WARNING_PRIVILEGE_NOT_GRANTED), + ("38004", SqlState::E_R_E_READING_SQL_DATA_NOT_PERMITTED), + ("42P21", SqlState::COLLATION_MISMATCH), + ("0Z000", SqlState::DIAGNOSTICS_EXCEPTION), + ("HV001", SqlState::FDW_OUT_OF_MEMORY), + ("0F000", SqlState::LOCATOR_EXCEPTION), + ("22013", SqlState::INVALID_PRECEDING_OR_FOLLOWING_SIZE), + ("2201E", SqlState::INVALID_ARGUMENT_FOR_LOG), + ("22011", SqlState::SUBSTRING_ERROR), + ("42602", SqlState::INVALID_NAME), + ("01004", 
SqlState::WARNING_STRING_DATA_RIGHT_TRUNCATION), + ("42P02", SqlState::UNDEFINED_PARAMETER), + ("2203C", SqlState::SQL_JSON_OBJECT_NOT_FOUND), + ("HV002", SqlState::FDW_DYNAMIC_PARAMETER_VALUE_NEEDED), + ("0F001", SqlState::L_E_INVALID_SPECIFICATION), + ("58P01", SqlState::UNDEFINED_FILE), + ("38001", SqlState::E_R_E_CONTAINING_SQL_NOT_PERMITTED), + ("42703", SqlState::UNDEFINED_COLUMN), + ("57P05", SqlState::IDLE_SESSION_TIMEOUT), + ("57P03", SqlState::CANNOT_CONNECT_NOW), + ("HV007", SqlState::FDW_INVALID_COLUMN_NAME), + ("22014", SqlState::INVALID_ARGUMENT_FOR_NTILE), + ("22P06", SqlState::NONSTANDARD_USE_OF_ESCAPE_CHARACTER), + ("2203F", SqlState::SQL_JSON_SCALAR_REQUIRED), + ("2200F", SqlState::ZERO_LENGTH_CHARACTER_STRING), + ("09000", SqlState::TRIGGERED_ACTION_EXCEPTION), + ("2201F", SqlState::INVALID_ARGUMENT_FOR_POWER_FUNCTION), + ("08003", SqlState::CONNECTION_DOES_NOT_EXIST), + ("38002", SqlState::E_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED), + ("F0001", SqlState::LOCK_FILE_EXISTS), + ("42P22", SqlState::INDETERMINATE_COLLATION), + ("2200C", SqlState::INVALID_USE_OF_ESCAPE_CHARACTER), + ("2203E", SqlState::TOO_MANY_JSON_OBJECT_MEMBERS), + ("23514", SqlState::CHECK_VIOLATION), + ("22P02", SqlState::INVALID_TEXT_REPRESENTATION), + ("54023", SqlState::TOO_MANY_ARGUMENTS), + ("2200T", SqlState::INVALID_XML_PROCESSING_INSTRUCTION), + ("22016", SqlState::INVALID_ARGUMENT_FOR_NTH_VALUE), + ("25P03", SqlState::IDLE_IN_TRANSACTION_SESSION_TIMEOUT), + ("3B001", SqlState::S_E_INVALID_SPECIFICATION), + ("08001", SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION), + ("22036", SqlState::NON_NUMERIC_SQL_JSON_ITEM), + ("3F000", SqlState::INVALID_SCHEMA_NAME), + ("39P01", SqlState::E_R_I_E_TRIGGER_PROTOCOL_VIOLATED), + ("22026", SqlState::STRING_DATA_LENGTH_MISMATCH), + ("42P17", SqlState::INVALID_OBJECT_DEFINITION), + ("22034", SqlState::MORE_THAN_ONE_SQL_JSON_ITEM), + ("HV000", SqlState::FDW_ERROR), + ("2200B", SqlState::ESCAPE_CHARACTER_CONFLICT), + ("HV008", 
SqlState::FDW_INVALID_COLUMN_NUMBER), + ("34000", SqlState::INVALID_CURSOR_NAME), + ("2201G", SqlState::INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION), + ("44000", SqlState::WITH_CHECK_OPTION_VIOLATION), + ("HV010", SqlState::FDW_FUNCTION_SEQUENCE_ERROR), + ("39004", SqlState::E_R_I_E_NULL_VALUE_NOT_ALLOWED), + ("22001", SqlState::STRING_DATA_RIGHT_TRUNCATION), + ("3D000", SqlState::INVALID_CATALOG_NAME), + ("25005", SqlState::NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION), + ("2200L", SqlState::NOT_AN_XML_DOCUMENT), + ("27000", SqlState::TRIGGERED_DATA_CHANGE_VIOLATION), + ("HV090", SqlState::FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH), + ("42939", SqlState::RESERVED_NAME), + ("58000", SqlState::SYSTEM_ERROR), + ("2200M", SqlState::INVALID_XML_DOCUMENT), + ("HV00L", SqlState::FDW_UNABLE_TO_CREATE_EXECUTION), + ("57014", SqlState::QUERY_CANCELED), + ("23502", SqlState::NOT_NULL_VIOLATION), + ("22002", SqlState::NULL_VALUE_NO_INDICATOR_PARAMETER), + ("HV00R", SqlState::FDW_TABLE_NOT_FOUND), + ("HV00P", SqlState::FDW_NO_SCHEMAS), + ("38003", SqlState::E_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED), + ("39000", SqlState::EXTERNAL_ROUTINE_INVOCATION_EXCEPTION), + ("22015", SqlState::INTERVAL_FIELD_OVERFLOW), + ("HV00K", SqlState::FDW_REPLY_HANDLE), + ("HV024", SqlState::FDW_INVALID_ATTRIBUTE_VALUE), + ("2200D", SqlState::INVALID_ESCAPE_OCTET), + ("08007", SqlState::TRANSACTION_RESOLUTION_UNKNOWN), + ("2F003", SqlState::S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED), + ("42725", SqlState::AMBIGUOUS_FUNCTION), + ("2203A", SqlState::SQL_JSON_MEMBER_NOT_FOUND), + ("42846", SqlState::CANNOT_COERCE), + ("42P04", SqlState::DUPLICATE_DATABASE), + ("42000", SqlState::SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION), + ("2203B", SqlState::SQL_JSON_NUMBER_NOT_FOUND), + ("42P05", SqlState::DUPLICATE_PSTATEMENT), + ("53300", SqlState::TOO_MANY_CONNECTIONS), + ("53400", SqlState::CONFIGURATION_LIMIT_EXCEEDED), + ("42704", SqlState::UNDEFINED_OBJECT), + ("2202G", SqlState::INVALID_TABLESAMPLE_REPEAT), + 
("22023", SqlState::INVALID_PARAMETER_VALUE), + ("53000", SqlState::INSUFFICIENT_RESOURCES), + ], +}; diff --git a/libs/proxy/tokio-postgres2/src/generic_client.rs b/libs/proxy/tokio-postgres2/src/generic_client.rs new file mode 100644 index 0000000000..768213f8ed --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/generic_client.rs @@ -0,0 +1,64 @@ +use crate::query::RowStream; +use crate::types::Type; +use crate::{Client, Error, Transaction}; +use async_trait::async_trait; +use postgres_protocol2::Oid; + +mod private { + pub trait Sealed {} +} + +/// A trait allowing abstraction over connections and transactions. +/// +/// This trait is "sealed", and cannot be implemented outside of this crate. +#[async_trait] +pub trait GenericClient: private::Sealed { + /// Like `Client::query_raw_txt`. + async fn query_raw_txt(&self, statement: &str, params: I) -> Result + where + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send; + + /// Query for type information + async fn get_type(&self, oid: Oid) -> Result; +} + +impl private::Sealed for Client {} + +#[async_trait] +impl GenericClient for Client { + async fn query_raw_txt(&self, statement: &str, params: I) -> Result + where + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send, + { + self.query_raw_txt(statement, params).await + } + + /// Query for type information + async fn get_type(&self, oid: Oid) -> Result { + self.get_type(oid).await + } +} + +impl private::Sealed for Transaction<'_> {} + +#[async_trait] +#[allow(clippy::needless_lifetimes)] +impl GenericClient for Transaction<'_> { + async fn query_raw_txt(&self, statement: &str, params: I) -> Result + where + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send, + { + self.query_raw_txt(statement, params).await + } + + /// Query for type information + async fn get_type(&self, oid: Oid) -> Result { + 
self.client().get_type(oid).await + } +} diff --git a/libs/proxy/tokio-postgres2/src/lib.rs b/libs/proxy/tokio-postgres2/src/lib.rs new file mode 100644 index 0000000000..72ba8172b2 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/lib.rs @@ -0,0 +1,148 @@ +//! An asynchronous, pipelined, PostgreSQL client. +#![warn(rust_2018_idioms, clippy::all, missing_docs)] + +pub use crate::cancel_token::CancelToken; +pub use crate::client::Client; +pub use crate::config::Config; +pub use crate::connection::Connection; +use crate::error::DbError; +pub use crate::error::Error; +pub use crate::generic_client::GenericClient; +pub use crate::query::RowStream; +pub use crate::row::{Row, SimpleQueryRow}; +pub use crate::simple_query::SimpleQueryStream; +pub use crate::statement::{Column, Statement}; +use crate::tls::MakeTlsConnect; +pub use crate::tls::NoTls; +pub use crate::to_statement::ToStatement; +pub use crate::transaction::Transaction; +pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder}; +use crate::types::ToSql; +use postgres_protocol2::message::backend::ReadyForQueryBody; +use tokio::net::TcpStream; + +/// After executing a query, the connection will be in one of these states +#[derive(Clone, Copy, Debug, PartialEq)] +#[repr(u8)] +pub enum ReadyForQueryStatus { + /// Connection state is unknown + Unknown, + /// Connection is idle (no transactions) + Idle = b'I', + /// Connection is in a transaction block + Transaction = b'T', + /// Connection is in a failed transaction block + FailedTransaction = b'E', +} + +impl From for ReadyForQueryStatus { + fn from(value: ReadyForQueryBody) -> Self { + match value.status() { + b'I' => Self::Idle, + b'T' => Self::Transaction, + b'E' => Self::FailedTransaction, + _ => Self::Unknown, + } + } +} + +mod cancel_query; +mod cancel_query_raw; +mod cancel_token; +mod client; +mod codec; +pub mod config; +mod connect; +mod connect_raw; +mod connect_socket; +mod connect_tls; +mod connection; +pub mod error; +mod 
generic_client; +pub mod maybe_tls_stream; +mod prepare; +mod query; +pub mod row; +mod simple_query; +mod statement; +pub mod tls; +mod to_statement; +mod transaction; +mod transaction_builder; +pub mod types; + +/// A convenience function which parses a connection string and connects to the database. +/// +/// See the documentation for [`Config`] for details on the connection string format. +/// +/// Requires the `runtime` Cargo feature (enabled by default). +/// +/// [`Config`]: config/struct.Config.html +pub async fn connect( + config: &str, + tls: T, +) -> Result<(Client, Connection), Error> +where + T: MakeTlsConnect, +{ + let config = config.parse::()?; + config.connect(tls).await +} + +/// An asynchronous notification. +#[derive(Clone, Debug)] +pub struct Notification { + process_id: i32, + channel: String, + payload: String, +} + +impl Notification { + /// The process ID of the notifying backend process. + pub fn process_id(&self) -> i32 { + self.process_id + } + + /// The name of the channel that the notify has been raised on. + pub fn channel(&self) -> &str { + &self.channel + } + + /// The "payload" string passed from the notifying process. + pub fn payload(&self) -> &str { + &self.payload + } +} + +/// An asynchronous message from the server. +#[allow(clippy::large_enum_variant)] +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum AsyncMessage { + /// A notice. + /// + /// Notices use the same format as errors, but aren't "errors" per-se. + Notice(DbError), + /// A notification. + /// + /// Connections can subscribe to notifications with the `LISTEN` command. + Notification(Notification), +} + +/// Message returned by the `SimpleQuery` stream. +#[derive(Debug)] +#[non_exhaustive] +pub enum SimpleQueryMessage { + /// A row of data. + Row(SimpleQueryRow), + /// A statement in the query has completed. + /// + /// The number of rows modified or selected is returned. 
+ CommandComplete(u64), +} + +fn slice_iter<'a>( + s: &'a [&'a (dyn ToSql + Sync)], +) -> impl ExactSizeIterator + 'a { + s.iter().map(|s| *s as _) +} diff --git a/libs/proxy/tokio-postgres2/src/maybe_tls_stream.rs b/libs/proxy/tokio-postgres2/src/maybe_tls_stream.rs new file mode 100644 index 0000000000..9a7e248997 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/maybe_tls_stream.rs @@ -0,0 +1,77 @@ +//! MaybeTlsStream. +//! +//! Represents a stream that may or may not be encrypted with TLS. +use crate::tls::{ChannelBinding, TlsStream}; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +/// A stream that may or may not be encrypted with TLS. +pub enum MaybeTlsStream { + /// An unencrypted stream. + Raw(S), + /// An encrypted stream. + Tls(T), +} + +impl AsyncRead for MaybeTlsStream +where + S: AsyncRead + Unpin, + T: AsyncRead + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match &mut *self { + MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf), + MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for MaybeTlsStream +where + S: AsyncWrite + Unpin, + T: AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match &mut *self { + MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf), + MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx), + MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx), + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx), + MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx), + } + } +} + +impl TlsStream for 
MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsStream + Unpin, +{ + fn channel_binding(&self) -> ChannelBinding { + match self { + MaybeTlsStream::Raw(_) => ChannelBinding::none(), + MaybeTlsStream::Tls(s) => s.channel_binding(), + } + } +} diff --git a/libs/proxy/tokio-postgres2/src/prepare.rs b/libs/proxy/tokio-postgres2/src/prepare.rs new file mode 100644 index 0000000000..da0c755c5b --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/prepare.rs @@ -0,0 +1,262 @@ +use crate::client::InnerClient; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::error::SqlState; +use crate::types::{Field, Kind, Oid, Type}; +use crate::{query, slice_iter}; +use crate::{Column, Error, Statement}; +use bytes::Bytes; +use fallible_iterator::FallibleIterator; +use futures_util::{pin_mut, TryStreamExt}; +use log::debug; +use postgres_protocol2::message::backend::Message; +use postgres_protocol2::message::frontend; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +pub(crate) const TYPEINFO_QUERY: &str = "\ +SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid +FROM pg_catalog.pg_type t +LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid +INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid +WHERE t.oid = $1 +"; + +// Range types weren't added until Postgres 9.2, so pg_range may not exist +const TYPEINFO_FALLBACK_QUERY: &str = "\ +SELECT t.typname, t.typtype, t.typelem, NULL::OID, t.typbasetype, n.nspname, t.typrelid +FROM pg_catalog.pg_type t +INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid +WHERE t.oid = $1 +"; + +const TYPEINFO_ENUM_QUERY: &str = "\ +SELECT enumlabel +FROM pg_catalog.pg_enum +WHERE enumtypid = $1 +ORDER BY enumsortorder +"; + +// Postgres 9.0 didn't have enumsortorder +const TYPEINFO_ENUM_FALLBACK_QUERY: &str = "\ +SELECT enumlabel +FROM pg_catalog.pg_enum +WHERE enumtypid 
= $1 +ORDER BY oid +"; + +pub(crate) const TYPEINFO_COMPOSITE_QUERY: &str = "\ +SELECT attname, atttypid +FROM pg_catalog.pg_attribute +WHERE attrelid = $1 +AND NOT attisdropped +AND attnum > 0 +ORDER BY attnum +"; + +static NEXT_ID: AtomicUsize = AtomicUsize::new(0); + +pub async fn prepare( + client: &Arc, + query: &str, + types: &[Type], +) -> Result { + let name = format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst)); + let buf = encode(client, &name, query, types)?; + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + match responses.next().await? { + Message::ParseComplete => {} + _ => return Err(Error::unexpected_message()), + } + + let parameter_description = match responses.next().await? { + Message::ParameterDescription(body) => body, + _ => return Err(Error::unexpected_message()), + }; + + let row_description = match responses.next().await? { + Message::RowDescription(body) => Some(body), + Message::NoData => None, + _ => return Err(Error::unexpected_message()), + }; + + let mut parameters = vec![]; + let mut it = parameter_description.parameters(); + while let Some(oid) = it.next().map_err(Error::parse)? { + let type_ = get_type(client, oid).await?; + parameters.push(type_); + } + + let mut columns = vec![]; + if let Some(row_description) = row_description { + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? 
{ + let type_ = get_type(client, field.type_oid()).await?; + let column = Column::new(field.name().to_string(), type_, field); + columns.push(column); + } + } + + Ok(Statement::new(client, name, parameters, columns)) +} + +fn prepare_rec<'a>( + client: &'a Arc, + query: &'a str, + types: &'a [Type], +) -> Pin> + 'a + Send>> { + Box::pin(prepare(client, query, types)) +} + +fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result { + if types.is_empty() { + debug!("preparing query {}: {}", name, query); + } else { + debug!("preparing query {} with types {:?}: {}", name, types, query); + } + + client.with_buf(|buf| { + frontend::parse(name, query, types.iter().map(Type::oid), buf).map_err(Error::encode)?; + frontend::describe(b'S', name, buf).map_err(Error::encode)?; + frontend::sync(buf); + Ok(buf.split().freeze()) + }) +} + +pub async fn get_type(client: &Arc, oid: Oid) -> Result { + if let Some(type_) = Type::from_oid(oid) { + return Ok(type_); + } + + if let Some(type_) = client.type_(oid) { + return Ok(type_); + } + + let stmt = typeinfo_statement(client).await?; + + let rows = query::query(client, stmt, slice_iter(&[&oid])).await?; + pin_mut!(rows); + + let row = match rows.try_next().await? 
{ + Some(row) => row, + None => return Err(Error::unexpected_message()), + }; + + let name: String = row.try_get(0)?; + let type_: i8 = row.try_get(1)?; + let elem_oid: Oid = row.try_get(2)?; + let rngsubtype: Option = row.try_get(3)?; + let basetype: Oid = row.try_get(4)?; + let schema: String = row.try_get(5)?; + let relid: Oid = row.try_get(6)?; + + let kind = if type_ == b'e' as i8 { + let variants = get_enum_variants(client, oid).await?; + Kind::Enum(variants) + } else if type_ == b'p' as i8 { + Kind::Pseudo + } else if basetype != 0 { + let type_ = get_type_rec(client, basetype).await?; + Kind::Domain(type_) + } else if elem_oid != 0 { + let type_ = get_type_rec(client, elem_oid).await?; + Kind::Array(type_) + } else if relid != 0 { + let fields = get_composite_fields(client, relid).await?; + Kind::Composite(fields) + } else if let Some(rngsubtype) = rngsubtype { + let type_ = get_type_rec(client, rngsubtype).await?; + Kind::Range(type_) + } else { + Kind::Simple + }; + + let type_ = Type::new(name, oid, kind, schema); + client.set_type(oid, &type_); + + Ok(type_) +} + +fn get_type_rec<'a>( + client: &'a Arc, + oid: Oid, +) -> Pin> + Send + 'a>> { + Box::pin(get_type(client, oid)) +} + +async fn typeinfo_statement(client: &Arc) -> Result { + if let Some(stmt) = client.typeinfo() { + return Ok(stmt); + } + + let stmt = match prepare_rec(client, TYPEINFO_QUERY, &[]).await { + Ok(stmt) => stmt, + Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_TABLE) => { + prepare_rec(client, TYPEINFO_FALLBACK_QUERY, &[]).await? + } + Err(e) => return Err(e), + }; + + client.set_typeinfo(&stmt); + Ok(stmt) +} + +async fn get_enum_variants(client: &Arc, oid: Oid) -> Result, Error> { + let stmt = typeinfo_enum_statement(client).await?; + + query::query(client, stmt, slice_iter(&[&oid])) + .await? 
+ .and_then(|row| async move { row.try_get(0) }) + .try_collect() + .await +} + +async fn typeinfo_enum_statement(client: &Arc) -> Result { + if let Some(stmt) = client.typeinfo_enum() { + return Ok(stmt); + } + + let stmt = match prepare_rec(client, TYPEINFO_ENUM_QUERY, &[]).await { + Ok(stmt) => stmt, + Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_COLUMN) => { + prepare_rec(client, TYPEINFO_ENUM_FALLBACK_QUERY, &[]).await? + } + Err(e) => return Err(e), + }; + + client.set_typeinfo_enum(&stmt); + Ok(stmt) +} + +async fn get_composite_fields(client: &Arc, oid: Oid) -> Result, Error> { + let stmt = typeinfo_composite_statement(client).await?; + + let rows = query::query(client, stmt, slice_iter(&[&oid])) + .await? + .try_collect::>() + .await?; + + let mut fields = vec![]; + for row in rows { + let name = row.try_get(0)?; + let oid = row.try_get(1)?; + let type_ = get_type_rec(client, oid).await?; + fields.push(Field::new(name, type_)); + } + + Ok(fields) +} + +async fn typeinfo_composite_statement(client: &Arc) -> Result { + if let Some(stmt) = client.typeinfo_composite() { + return Ok(stmt); + } + + let stmt = prepare_rec(client, TYPEINFO_COMPOSITE_QUERY, &[]).await?; + + client.set_typeinfo_composite(&stmt); + Ok(stmt) +} diff --git a/libs/proxy/tokio-postgres2/src/query.rs b/libs/proxy/tokio-postgres2/src/query.rs new file mode 100644 index 0000000000..534195a707 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/query.rs @@ -0,0 +1,340 @@ +use crate::client::{InnerClient, Responses}; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::types::IsNull; +use crate::{Column, Error, ReadyForQueryStatus, Row, Statement}; +use bytes::{BufMut, Bytes, BytesMut}; +use fallible_iterator::FallibleIterator; +use futures_util::{ready, Stream}; +use log::{debug, log_enabled, Level}; +use pin_project_lite::pin_project; +use postgres_protocol2::message::backend::Message; +use postgres_protocol2::message::frontend; +use 
postgres_types2::{Format, ToSql, Type}; +use std::fmt; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +struct BorrowToSqlParamsDebug<'a>(&'a [&'a (dyn ToSql + Sync)]); + +impl fmt::Debug for BorrowToSqlParamsDebug<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.0.iter()).finish() + } +} + +pub async fn query<'a, I>( + client: &InnerClient, + statement: Statement, + params: I, +) -> Result +where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ + let buf = if log_enabled!(Level::Debug) { + let params = params.into_iter().collect::>(); + debug!( + "executing statement {} with parameters: {:?}", + statement.name(), + BorrowToSqlParamsDebug(params.as_slice()), + ); + encode(client, &statement, params)? + } else { + encode(client, &statement, params)? + }; + let responses = start(client, buf).await?; + Ok(RowStream { + statement, + responses, + command_tag: None, + status: ReadyForQueryStatus::Unknown, + output_format: Format::Binary, + _p: PhantomPinned, + }) +} + +pub async fn query_txt( + client: &Arc, + query: &str, + params: I, +) -> Result +where + S: AsRef, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, +{ + let params = params.into_iter(); + + let buf = client.with_buf(|buf| { + frontend::parse( + "", // unnamed prepared statement + query, // query to parse + std::iter::empty(), // give no type info + buf, + ) + .map_err(Error::encode)?; + frontend::describe(b'S', "", buf).map_err(Error::encode)?; + // Bind, pass params as text, retrieve as binary + match frontend::bind( + "", // empty string selects the unnamed portal + "", // unnamed prepared statement + std::iter::empty(), // all parameters use the default format (text) + params, + |param, buf| match param { + Some(param) => { + buf.put_slice(param.as_ref().as_bytes()); + Ok(postgres_protocol2::IsNull::No) + } + None => Ok(postgres_protocol2::IsNull::Yes), + }, + Some(0), // all text + 
buf, + ) { + Ok(()) => Ok(()), + Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)), + Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), + }?; + + // Execute + frontend::execute("", 0, buf).map_err(Error::encode)?; + // Sync + frontend::sync(buf); + + Ok(buf.split().freeze()) + })?; + + // now read the responses + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + match responses.next().await? { + Message::ParseComplete => {} + _ => return Err(Error::unexpected_message()), + } + + let parameter_description = match responses.next().await? { + Message::ParameterDescription(body) => body, + _ => return Err(Error::unexpected_message()), + }; + + let row_description = match responses.next().await? { + Message::RowDescription(body) => Some(body), + Message::NoData => None, + _ => return Err(Error::unexpected_message()), + }; + + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } + + let mut parameters = vec![]; + let mut it = parameter_description.parameters(); + while let Some(oid) = it.next().map_err(Error::parse)? { + let type_ = Type::from_oid(oid).unwrap_or(Type::UNKNOWN); + parameters.push(type_); + } + + let mut columns = vec![]; + if let Some(row_description) = row_description { + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? 
{ + let type_ = Type::from_oid(field.type_oid()).unwrap_or(Type::UNKNOWN); + let column = Column::new(field.name().to_string(), type_, field); + columns.push(column); + } + } + + Ok(RowStream { + statement: Statement::new_anonymous(parameters, columns), + responses, + command_tag: None, + status: ReadyForQueryStatus::Unknown, + output_format: Format::Text, + _p: PhantomPinned, + }) +} + +pub async fn execute<'a, I>( + client: &InnerClient, + statement: Statement, + params: I, +) -> Result +where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ + let buf = if log_enabled!(Level::Debug) { + let params = params.into_iter().collect::>(); + debug!( + "executing statement {} with parameters: {:?}", + statement.name(), + BorrowToSqlParamsDebug(params.as_slice()), + ); + encode(client, &statement, params)? + } else { + encode(client, &statement, params)? + }; + let mut responses = start(client, buf).await?; + + let mut rows = 0; + loop { + match responses.next().await? { + Message::DataRow(_) => {} + Message::CommandComplete(body) => { + rows = body + .tag() + .map_err(Error::parse)? + .rsplit(' ') + .next() + .unwrap() + .parse() + .unwrap_or(0); + } + Message::EmptyQueryResponse => rows = 0, + Message::ReadyForQuery(_) => return Ok(rows), + _ => return Err(Error::unexpected_message()), + } + } +} + +async fn start(client: &InnerClient, buf: Bytes) -> Result { + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + match responses.next().await? 
{ + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } + + Ok(responses) +} + +pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result +where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ + client.with_buf(|buf| { + encode_bind(statement, params, "", buf)?; + frontend::execute("", 0, buf).map_err(Error::encode)?; + frontend::sync(buf); + Ok(buf.split().freeze()) + }) +} + +pub fn encode_bind<'a, I>( + statement: &Statement, + params: I, + portal: &str, + buf: &mut BytesMut, +) -> Result<(), Error> +where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ + let param_types = statement.params(); + let params = params.into_iter(); + + assert!( + param_types.len() == params.len(), + "expected {} parameters but got {}", + param_types.len(), + params.len() + ); + + let (param_formats, params): (Vec<_>, Vec<_>) = params + .zip(param_types.iter()) + .map(|(p, ty)| (p.encode_format(ty) as i16, p)) + .unzip(); + + let params = params.into_iter(); + + let mut error_idx = 0; + let r = frontend::bind( + portal, + statement.name(), + param_formats, + params.zip(param_types).enumerate(), + |(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) { + Ok(IsNull::No) => Ok(postgres_protocol2::IsNull::No), + Ok(IsNull::Yes) => Ok(postgres_protocol2::IsNull::Yes), + Err(e) => { + error_idx = idx; + Err(e) + } + }, + Some(1), + buf, + ); + match r { + Ok(()) => Ok(()), + Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)), + Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), + } +} + +pin_project! { + /// A stream of table rows. 
+ pub struct RowStream { + statement: Statement, + responses: Responses, + command_tag: Option, + output_format: Format, + status: ReadyForQueryStatus, + #[pin] + _p: PhantomPinned, + } +} + +impl Stream for RowStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + loop { + match ready!(this.responses.poll_next(cx)?) { + Message::DataRow(body) => { + return Poll::Ready(Some(Ok(Row::new( + this.statement.clone(), + body, + *this.output_format, + )?))) + } + Message::EmptyQueryResponse | Message::PortalSuspended => {} + Message::CommandComplete(body) => { + if let Ok(tag) = body.tag() { + *this.command_tag = Some(tag.to_string()); + } + } + Message::ReadyForQuery(status) => { + *this.status = status.into(); + return Poll::Ready(None); + } + _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), + } + } + } +} + +impl RowStream { + /// Returns information about the columns of data in the row. + pub fn columns(&self) -> &[Column] { + self.statement.columns() + } + + /// Returns the command tag of this query. + /// + /// This is only available after the stream has been exhausted. + pub fn command_tag(&self) -> Option { + self.command_tag.clone() + } + + /// Returns if the connection is ready for querying, with the status of the connection. + /// + /// This might be available only after the stream has been exhausted. + pub fn ready_status(&self) -> ReadyForQueryStatus { + self.status + } +} diff --git a/libs/proxy/tokio-postgres2/src/row.rs b/libs/proxy/tokio-postgres2/src/row.rs new file mode 100644 index 0000000000..10e130707d --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/row.rs @@ -0,0 +1,300 @@ +//! Rows. 
+ +use crate::row::sealed::{AsName, Sealed}; +use crate::simple_query::SimpleColumn; +use crate::statement::Column; +use crate::types::{FromSql, Type, WrongType}; +use crate::{Error, Statement}; +use fallible_iterator::FallibleIterator; +use postgres_protocol2::message::backend::DataRowBody; +use postgres_types2::{Format, WrongFormat}; +use std::fmt; +use std::ops::Range; +use std::str; +use std::sync::Arc; + +mod sealed { + pub trait Sealed {} + + pub trait AsName { + fn as_name(&self) -> &str; + } +} + +impl AsName for Column { + fn as_name(&self) -> &str { + self.name() + } +} + +impl AsName for String { + fn as_name(&self) -> &str { + self + } +} + +/// A trait implemented by types that can index into columns of a row. +/// +/// This cannot be implemented outside of this crate. +pub trait RowIndex: Sealed { + #[doc(hidden)] + fn __idx(&self, columns: &[T]) -> Option + where + T: AsName; +} + +impl Sealed for usize {} + +impl RowIndex for usize { + #[inline] + fn __idx(&self, columns: &[T]) -> Option + where + T: AsName, + { + if *self >= columns.len() { + None + } else { + Some(*self) + } + } +} + +impl Sealed for str {} + +impl RowIndex for str { + #[inline] + fn __idx(&self, columns: &[T]) -> Option + where + T: AsName, + { + if let Some(idx) = columns.iter().position(|d| d.as_name() == self) { + return Some(idx); + }; + + // FIXME ASCII-only case insensitivity isn't really the right thing to + // do. Postgres itself uses a dubious wrapper around tolower and JDBC + // uses the US locale. + columns + .iter() + .position(|d| d.as_name().eq_ignore_ascii_case(self)) + } +} + +impl Sealed for &T where T: ?Sized + Sealed {} + +impl RowIndex for &T +where + T: ?Sized + RowIndex, +{ + #[inline] + fn __idx(&self, columns: &[U]) -> Option + where + U: AsName, + { + T::__idx(*self, columns) + } +} + +/// A row of data returned from the database by a query. 
+pub struct Row { + statement: Statement, + output_format: Format, + body: DataRowBody, + ranges: Vec>>, +} + +impl fmt::Debug for Row { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Row") + .field("columns", &self.columns()) + .finish() + } +} + +impl Row { + pub(crate) fn new( + statement: Statement, + body: DataRowBody, + output_format: Format, + ) -> Result { + let ranges = body.ranges().collect().map_err(Error::parse)?; + Ok(Row { + statement, + body, + ranges, + output_format, + }) + } + + /// Returns information about the columns of data in the row. + pub fn columns(&self) -> &[Column] { + self.statement.columns() + } + + /// Determines if the row contains no values. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the number of values in the row. + pub fn len(&self) -> usize { + self.columns().len() + } + + /// Deserializes a value from the row. + /// + /// The value can be specified either by its numeric index in the row, or by its column name. + /// + /// # Panics + /// + /// Panics if the index is out of bounds or if the value cannot be converted to the specified type. + pub fn get<'a, I, T>(&'a self, idx: I) -> T + where + I: RowIndex + fmt::Display, + T: FromSql<'a>, + { + match self.get_inner(&idx) { + Ok(ok) => ok, + Err(err) => panic!("error retrieving column {}: {}", idx, err), + } + } + + /// Like `Row::get`, but returns a `Result` rather than panicking. 
+ pub fn try_get<'a, I, T>(&'a self, idx: I) -> Result + where + I: RowIndex + fmt::Display, + T: FromSql<'a>, + { + self.get_inner(&idx) + } + + fn get_inner<'a, I, T>(&'a self, idx: &I) -> Result + where + I: RowIndex + fmt::Display, + T: FromSql<'a>, + { + let idx = match idx.__idx(self.columns()) { + Some(idx) => idx, + None => return Err(Error::column(idx.to_string())), + }; + + let ty = self.columns()[idx].type_(); + if !T::accepts(ty) { + return Err(Error::from_sql( + Box::new(WrongType::new::(ty.clone())), + idx, + )); + } + + FromSql::from_sql_nullable(ty, self.col_buffer(idx)).map_err(|e| Error::from_sql(e, idx)) + } + + /// Get the raw bytes for the column at the given index. + fn col_buffer(&self, idx: usize) -> Option<&[u8]> { + let range = self.ranges.get(idx)?.to_owned()?; + Some(&self.body.buffer()[range]) + } + + /// Interpret the column at the given index as text + /// + /// Useful when using query_raw_txt() which sets text transfer mode + pub fn as_text(&self, idx: usize) -> Result, Error> { + if self.output_format == Format::Text { + match self.col_buffer(idx) { + Some(raw) => { + FromSql::from_sql(&Type::TEXT, raw).map_err(|e| Error::from_sql(e, idx)) + } + None => Ok(None), + } + } else { + Err(Error::from_sql(Box::new(WrongFormat {}), idx)) + } + } + + /// Row byte size + pub fn body_len(&self) -> usize { + self.body.buffer().len() + } +} + +impl AsName for SimpleColumn { + fn as_name(&self) -> &str { + self.name() + } +} + +/// A row of data returned from the database by a simple query. +#[derive(Debug)] +pub struct SimpleQueryRow { + columns: Arc<[SimpleColumn]>, + body: DataRowBody, + ranges: Vec>>, +} + +impl SimpleQueryRow { + #[allow(clippy::new_ret_no_self)] + pub(crate) fn new( + columns: Arc<[SimpleColumn]>, + body: DataRowBody, + ) -> Result { + let ranges = body.ranges().collect().map_err(Error::parse)?; + Ok(SimpleQueryRow { + columns, + body, + ranges, + }) + } + + /// Returns information about the columns of data in the row. 
+ pub fn columns(&self) -> &[SimpleColumn] { + &self.columns + } + + /// Determines if the row contains no values. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the number of values in the row. + pub fn len(&self) -> usize { + self.columns.len() + } + + /// Returns a value from the row. + /// + /// The value can be specified either by its numeric index in the row, or by its column name. + /// + /// # Panics + /// + /// Panics if the index is out of bounds or if the value cannot be converted to the specified type. + pub fn get(&self, idx: I) -> Option<&str> + where + I: RowIndex + fmt::Display, + { + match self.get_inner(&idx) { + Ok(ok) => ok, + Err(err) => panic!("error retrieving column {}: {}", idx, err), + } + } + + /// Like `SimpleQueryRow::get`, but returns a `Result` rather than panicking. + pub fn try_get(&self, idx: I) -> Result, Error> + where + I: RowIndex + fmt::Display, + { + self.get_inner(&idx) + } + + fn get_inner(&self, idx: &I) -> Result, Error> + where + I: RowIndex + fmt::Display, + { + let idx = match idx.__idx(&self.columns) { + Some(idx) => idx, + None => return Err(Error::column(idx.to_string())), + }; + + let buf = self.ranges[idx].clone().map(|r| &self.body.buffer()[r]); + FromSql::from_sql_nullable(&Type::TEXT, buf).map_err(|e| Error::from_sql(e, idx)) + } +} diff --git a/libs/proxy/tokio-postgres2/src/simple_query.rs b/libs/proxy/tokio-postgres2/src/simple_query.rs new file mode 100644 index 0000000000..fb2550377b --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/simple_query.rs @@ -0,0 +1,142 @@ +use crate::client::{InnerClient, Responses}; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::{Error, ReadyForQueryStatus, SimpleQueryMessage, SimpleQueryRow}; +use bytes::Bytes; +use fallible_iterator::FallibleIterator; +use futures_util::{ready, Stream}; +use log::debug; +use pin_project_lite::pin_project; +use postgres_protocol2::message::backend::Message; +use 
postgres_protocol2::message::frontend; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +/// Information about a column of a single query row. +#[derive(Debug)] +pub struct SimpleColumn { + name: String, +} + +impl SimpleColumn { + pub(crate) fn new(name: String) -> SimpleColumn { + SimpleColumn { name } + } + + /// Returns the name of the column. + pub fn name(&self) -> &str { + &self.name + } +} + +pub async fn simple_query(client: &InnerClient, query: &str) -> Result { + debug!("executing simple query: {}", query); + + let buf = encode(client, query)?; + let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + Ok(SimpleQueryStream { + responses, + columns: None, + status: ReadyForQueryStatus::Unknown, + _p: PhantomPinned, + }) +} + +pub async fn batch_execute( + client: &InnerClient, + query: &str, +) -> Result { + debug!("executing statement batch: {}", query); + + let buf = encode(client, query)?; + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + loop { + match responses.next().await? { + Message::ReadyForQuery(status) => return Ok(status.into()), + Message::CommandComplete(_) + | Message::EmptyQueryResponse + | Message::RowDescription(_) + | Message::DataRow(_) => {} + _ => return Err(Error::unexpected_message()), + } + } +} + +pub(crate) fn encode(client: &InnerClient, query: &str) -> Result { + client.with_buf(|buf| { + frontend::query(query, buf).map_err(Error::encode)?; + Ok(buf.split().freeze()) + }) +} + +pin_project! { + /// A stream of simple query results. + pub struct SimpleQueryStream { + responses: Responses, + columns: Option>, + status: ReadyForQueryStatus, + #[pin] + _p: PhantomPinned, + } +} + +impl SimpleQueryStream { + /// Returns if the connection is ready for querying, with the status of the connection. + /// + /// This might be available only after the stream has been exhausted. 
+ pub fn ready_status(&self) -> ReadyForQueryStatus { + self.status + } +} + +impl Stream for SimpleQueryStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + loop { + match ready!(this.responses.poll_next(cx)?) { + Message::CommandComplete(body) => { + let rows = body + .tag() + .map_err(Error::parse)? + .rsplit(' ') + .next() + .unwrap() + .parse() + .unwrap_or(0); + return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(rows)))); + } + Message::EmptyQueryResponse => { + return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(0)))); + } + Message::RowDescription(body) => { + let columns = body + .fields() + .map(|f| Ok(SimpleColumn::new(f.name().to_string()))) + .collect::>() + .map_err(Error::parse)? + .into(); + + *this.columns = Some(columns); + } + Message::DataRow(body) => { + let row = match &this.columns { + Some(columns) => SimpleQueryRow::new(columns.clone(), body)?, + None => return Poll::Ready(Some(Err(Error::unexpected_message()))), + }; + return Poll::Ready(Some(Ok(SimpleQueryMessage::Row(row)))); + } + Message::ReadyForQuery(s) => { + *this.status = s.into(); + return Poll::Ready(None); + } + _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), + } + } + } +} diff --git a/libs/proxy/tokio-postgres2/src/statement.rs b/libs/proxy/tokio-postgres2/src/statement.rs new file mode 100644 index 0000000000..22e160fc05 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/statement.rs @@ -0,0 +1,157 @@ +use crate::client::InnerClient; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::types::Type; +use postgres_protocol2::{ + message::{backend::Field, frontend}, + Oid, +}; +use std::{ + fmt, + sync::{Arc, Weak}, +}; + +struct StatementInner { + client: Weak, + name: String, + params: Vec, + columns: Vec, +} + +impl Drop for StatementInner { + fn drop(&mut self) { + if let Some(client) = self.client.upgrade() { + let buf 
= client.with_buf(|buf| { + frontend::close(b'S', &self.name, buf).unwrap(); + frontend::sync(buf); + buf.split().freeze() + }); + let _ = client.send(RequestMessages::Single(FrontendMessage::Raw(buf))); + } + } +} + +/// A prepared statement. +/// +/// Prepared statements can only be used with the connection that created them. +#[derive(Clone)] +pub struct Statement(Arc); + +impl Statement { + pub(crate) fn new( + inner: &Arc, + name: String, + params: Vec, + columns: Vec, + ) -> Statement { + Statement(Arc::new(StatementInner { + client: Arc::downgrade(inner), + name, + params, + columns, + })) + } + + pub(crate) fn new_anonymous(params: Vec, columns: Vec) -> Statement { + Statement(Arc::new(StatementInner { + client: Weak::new(), + name: String::new(), + params, + columns, + })) + } + + pub(crate) fn name(&self) -> &str { + &self.0.name + } + + /// Returns the expected types of the statement's parameters. + pub fn params(&self) -> &[Type] { + &self.0.params + } + + /// Returns information about the columns returned when the statement is queried. + pub fn columns(&self) -> &[Column] { + &self.0.columns + } +} + +/// Information about a column of a query. +pub struct Column { + name: String, + type_: Type, + + // raw fields from RowDescription + table_oid: Oid, + column_id: i16, + format: i16, + + // that better be stored in self.type_, but that is more radical refactoring + type_oid: Oid, + type_size: i16, + type_modifier: i32, +} + +impl Column { + pub(crate) fn new(name: String, type_: Type, raw_field: Field<'_>) -> Column { + Column { + name, + type_, + table_oid: raw_field.table_oid(), + column_id: raw_field.column_id(), + format: raw_field.format(), + type_oid: raw_field.type_oid(), + type_size: raw_field.type_size(), + type_modifier: raw_field.type_modifier(), + } + } + + /// Returns the name of the column. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns the type of the column. 
+ pub fn type_(&self) -> &Type { + &self.type_ + } + + /// Returns the table OID of the column. + pub fn table_oid(&self) -> Oid { + self.table_oid + } + + /// Returns the column ID of the column. + pub fn column_id(&self) -> i16 { + self.column_id + } + + /// Returns the format of the column. + pub fn format(&self) -> i16 { + self.format + } + + /// Returns the type OID of the column. + pub fn type_oid(&self) -> Oid { + self.type_oid + } + + /// Returns the type size of the column. + pub fn type_size(&self) -> i16 { + self.type_size + } + + /// Returns the type modifier of the column. + pub fn type_modifier(&self) -> i32 { + self.type_modifier + } +} + +impl fmt::Debug for Column { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Column") + .field("name", &self.name) + .field("type", &self.type_) + .finish() + } +} diff --git a/libs/proxy/tokio-postgres2/src/tls.rs b/libs/proxy/tokio-postgres2/src/tls.rs new file mode 100644 index 0000000000..dc8140719f --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/tls.rs @@ -0,0 +1,162 @@ +//! TLS support. + +use std::error::Error; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{fmt, io}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +pub(crate) mod private { + pub struct ForcePrivateApi; +} + +/// Channel binding information returned from a TLS handshake. +pub struct ChannelBinding { + pub(crate) tls_server_end_point: Option>, +} + +impl ChannelBinding { + /// Creates a `ChannelBinding` containing no information. + pub fn none() -> ChannelBinding { + ChannelBinding { + tls_server_end_point: None, + } + } + + /// Creates a `ChannelBinding` containing `tls-server-end-point` channel binding information. + pub fn tls_server_end_point(tls_server_end_point: Vec) -> ChannelBinding { + ChannelBinding { + tls_server_end_point: Some(tls_server_end_point), + } + } +} + +/// A constructor of `TlsConnect`ors. 
+/// +/// Requires the `runtime` Cargo feature (enabled by default). +pub trait MakeTlsConnect { + /// The stream type created by the `TlsConnect` implementation. + type Stream: TlsStream + Unpin; + /// The `TlsConnect` implementation created by this type. + type TlsConnect: TlsConnect; + /// The error type returned by the `TlsConnect` implementation. + type Error: Into>; + + /// Creates a new `TlsConnect`or. + /// + /// The domain name is provided for certificate verification and SNI. + fn make_tls_connect(&mut self, domain: &str) -> Result; +} + +/// An asynchronous function wrapping a stream in a TLS session. +pub trait TlsConnect { + /// The stream returned by the future. + type Stream: TlsStream + Unpin; + /// The error returned by the future. + type Error: Into>; + /// The future returned by the connector. + type Future: Future>; + + /// Returns a future performing a TLS handshake over the stream. + fn connect(self, stream: S) -> Self::Future; + + #[doc(hidden)] + fn can_connect(&self, _: private::ForcePrivateApi) -> bool { + true + } +} + +/// A TLS-wrapped connection to a PostgreSQL database. +pub trait TlsStream: AsyncRead + AsyncWrite { + /// Returns channel binding information for the session. + fn channel_binding(&self) -> ChannelBinding; +} + +/// A `MakeTlsConnect` and `TlsConnect` implementation which simply returns an error. +/// +/// This can be used when `sslmode` is `disable` or `prefer`. +#[derive(Debug, Copy, Clone)] +pub struct NoTls; + +impl MakeTlsConnect for NoTls { + type Stream = NoTlsStream; + type TlsConnect = NoTls; + type Error = NoTlsError; + + fn make_tls_connect(&mut self, _: &str) -> Result { + Ok(NoTls) + } +} + +impl TlsConnect for NoTls { + type Stream = NoTlsStream; + type Error = NoTlsError; + type Future = NoTlsFuture; + + fn connect(self, _: S) -> NoTlsFuture { + NoTlsFuture(()) + } + + fn can_connect(&self, _: private::ForcePrivateApi) -> bool { + false + } +} + +/// The future returned by `NoTls`. 
+pub struct NoTlsFuture(()); + +impl Future for NoTlsFuture { + type Output = Result; + + fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { + Poll::Ready(Err(NoTlsError(()))) + } +} + +/// The TLS "stream" type produced by the `NoTls` connector. +/// +/// Since `NoTls` doesn't support TLS, this type is uninhabited. +pub enum NoTlsStream {} + +impl AsyncRead for NoTlsStream { + fn poll_read( + self: Pin<&mut Self>, + _: &mut Context<'_>, + _: &mut ReadBuf<'_>, + ) -> Poll> { + match *self {} + } +} + +impl AsyncWrite for NoTlsStream { + fn poll_write(self: Pin<&mut Self>, _: &mut Context<'_>, _: &[u8]) -> Poll> { + match *self {} + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + match *self {} + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + match *self {} + } +} + +impl TlsStream for NoTlsStream { + fn channel_binding(&self) -> ChannelBinding { + match *self {} + } +} + +/// The error returned by `NoTls`. +#[derive(Debug)] +pub struct NoTlsError(()); + +impl fmt::Display for NoTlsError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.write_str("no TLS implementation configured") + } +} + +impl Error for NoTlsError {} diff --git a/libs/proxy/tokio-postgres2/src/to_statement.rs b/libs/proxy/tokio-postgres2/src/to_statement.rs new file mode 100644 index 0000000000..427f77dd79 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/to_statement.rs @@ -0,0 +1,57 @@ +use crate::to_statement::private::{Sealed, ToStatementType}; +use crate::Statement; + +mod private { + use crate::{Client, Error, Statement}; + + pub trait Sealed {} + + pub enum ToStatementType<'a> { + Statement(&'a Statement), + Query(&'a str), + } + + impl<'a> ToStatementType<'a> { + pub async fn into_statement(self, client: &Client) -> Result { + match self { + ToStatementType::Statement(s) => Ok(s.clone()), + ToStatementType::Query(s) => client.prepare(s).await, + } + } + } +} + +/// A trait abstracting over prepared 
and unprepared statements. +/// +/// Many methods are generic over this bound, so that they support both a raw query string as well as a statement which +/// was prepared previously. +/// +/// This trait is "sealed" and cannot be implemented by anything outside this crate. +pub trait ToStatement: Sealed { + #[doc(hidden)] + fn __convert(&self) -> ToStatementType<'_>; +} + +impl ToStatement for Statement { + fn __convert(&self) -> ToStatementType<'_> { + ToStatementType::Statement(self) + } +} + +impl Sealed for Statement {} + +impl ToStatement for str { + fn __convert(&self) -> ToStatementType<'_> { + ToStatementType::Query(self) + } +} + +impl Sealed for str {} + +impl ToStatement for String { + fn __convert(&self) -> ToStatementType<'_> { + ToStatementType::Query(self) + } +} + +impl Sealed for String {} diff --git a/libs/proxy/tokio-postgres2/src/transaction.rs b/libs/proxy/tokio-postgres2/src/transaction.rs new file mode 100644 index 0000000000..03a57e4947 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/transaction.rs @@ -0,0 +1,74 @@ +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::query::RowStream; +use crate::{CancelToken, Client, Error, ReadyForQueryStatus}; +use postgres_protocol2::message::frontend; + +/// A representation of a PostgreSQL database transaction. +/// +/// Transactions will implicitly roll back when dropped. Use the `commit` method to commit the changes made in the +/// transaction. Transactions can be nested, with inner transactions implemented via savepoints. 
+pub struct Transaction<'a> { + client: &'a mut Client, + done: bool, +} + +impl Drop for Transaction<'_> { + fn drop(&mut self) { + if self.done { + return; + } + + let buf = self.client.inner().with_buf(|buf| { + frontend::query("ROLLBACK", buf).unwrap(); + buf.split().freeze() + }); + let _ = self + .client + .inner() + .send(RequestMessages::Single(FrontendMessage::Raw(buf))); + } +} + +impl<'a> Transaction<'a> { + pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> { + Transaction { + client, + done: false, + } + } + + /// Consumes the transaction, committing all changes made within it. + pub async fn commit(mut self) -> Result { + self.done = true; + self.client.batch_execute("COMMIT").await + } + + /// Rolls the transaction back, discarding all changes made within it. + /// + /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller. + pub async fn rollback(mut self) -> Result { + self.done = true; + self.client.batch_execute("ROLLBACK").await + } + + /// Like `Client::query_raw_txt`. + pub async fn query_raw_txt(&self, statement: &str, params: I) -> Result + where + S: AsRef, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, + { + self.client.query_raw_txt(statement, params).await + } + + /// Like `Client::cancel_token`. + pub fn cancel_token(&self) -> CancelToken { + self.client.cancel_token() + } + + /// Returns a reference to the underlying `Client`. + pub fn client(&self) -> &Client { + self.client + } +} diff --git a/libs/proxy/tokio-postgres2/src/transaction_builder.rs b/libs/proxy/tokio-postgres2/src/transaction_builder.rs new file mode 100644 index 0000000000..9718ac588c --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/transaction_builder.rs @@ -0,0 +1,113 @@ +use crate::{Client, Error, Transaction}; + +/// The isolation level of a database transaction. +#[derive(Debug, Copy, Clone)] +#[non_exhaustive] +pub enum IsolationLevel { + /// Equivalent to `ReadCommitted`. 
+ ReadUncommitted, + + /// An individual statement in the transaction will see rows committed before it began. + ReadCommitted, + + /// All statements in the transaction will see the same view of rows committed before the first query in the + /// transaction. + RepeatableRead, + + /// The reads and writes in this transaction must be able to be committed as an atomic "unit" with respect to reads + /// and writes of all other concurrent serializable transactions without interleaving. + Serializable, +} + +/// A builder for database transactions. +pub struct TransactionBuilder<'a> { + client: &'a mut Client, + isolation_level: Option, + read_only: Option, + deferrable: Option, +} + +impl<'a> TransactionBuilder<'a> { + pub(crate) fn new(client: &'a mut Client) -> TransactionBuilder<'a> { + TransactionBuilder { + client, + isolation_level: None, + read_only: None, + deferrable: None, + } + } + + /// Sets the isolation level of the transaction. + pub fn isolation_level(mut self, isolation_level: IsolationLevel) -> Self { + self.isolation_level = Some(isolation_level); + self + } + + /// Sets the access mode of the transaction. + pub fn read_only(mut self, read_only: bool) -> Self { + self.read_only = Some(read_only); + self + } + + /// Sets the deferrability of the transaction. + /// + /// If the transaction is also serializable and read only, creation of the transaction may block, but when it + /// completes the transaction is able to run with less overhead and a guarantee that it will not be aborted due to + /// serialization failure. + pub fn deferrable(mut self, deferrable: bool) -> Self { + self.deferrable = Some(deferrable); + self + } + + /// Begins the transaction. + /// + /// The transaction will roll back by default - use the `commit` method to commit it. 
+ pub async fn start(self) -> Result, Error> { + let mut query = "START TRANSACTION".to_string(); + let mut first = true; + + if let Some(level) = self.isolation_level { + first = false; + + query.push_str(" ISOLATION LEVEL "); + let level = match level { + IsolationLevel::ReadUncommitted => "READ UNCOMMITTED", + IsolationLevel::ReadCommitted => "READ COMMITTED", + IsolationLevel::RepeatableRead => "REPEATABLE READ", + IsolationLevel::Serializable => "SERIALIZABLE", + }; + query.push_str(level); + } + + if let Some(read_only) = self.read_only { + if !first { + query.push(','); + } + first = false; + + let s = if read_only { + " READ ONLY" + } else { + " READ WRITE" + }; + query.push_str(s); + } + + if let Some(deferrable) = self.deferrable { + if !first { + query.push(','); + } + + let s = if deferrable { + " DEFERRABLE" + } else { + " NOT DEFERRABLE" + }; + query.push_str(s); + } + + self.client.batch_execute(&query).await?; + + Ok(Transaction::new(self.client)) + } +} diff --git a/libs/proxy/tokio-postgres2/src/types.rs b/libs/proxy/tokio-postgres2/src/types.rs new file mode 100644 index 0000000000..e571d7ee00 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/types.rs @@ -0,0 +1,6 @@ +//! Types. +//! +//! This module is a reexport of the `postgres_types` crate. 
+ +#[doc(inline)] +pub use postgres_types2::*; diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 1665d6361a..0d774d529d 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -55,6 +55,7 @@ parquet.workspace = true parquet_derive.workspace = true pin-project-lite.workspace = true postgres_backend.workspace = true +postgres-protocol = { package = "postgres-protocol2", path = "../libs/proxy/postgres-protocol2" } pq_proto.workspace = true prometheus.workspace = true rand.workspace = true @@ -80,8 +81,7 @@ subtle.workspace = true thiserror.workspace = true tikv-jemallocator.workspace = true tikv-jemalloc-ctl = { workspace = true, features = ["use_std"] } -tokio-postgres = { workspace = true, features = ["with-serde_json-1"] } -tokio-postgres-rustls.workspace = true +tokio-postgres = { package = "tokio-postgres2", path = "../libs/proxy/tokio-postgres2" } tokio-rustls.workspace = true tokio-util.workspace = true tokio = { workspace = true, features = ["signal"] } @@ -96,7 +96,6 @@ utils.workspace = true uuid.workspace = true rustls-native-certs.workspace = true x509-parser.workspace = true -postgres-protocol.workspace = true redis.workspace = true zerocopy.workspace = true @@ -117,6 +116,5 @@ tokio-tungstenite.workspace = true pbkdf2 = { workspace = true, features = ["simple", "std"] } rcgen.workspace = true rstest.workspace = true -tokio-postgres-rustls.workspace = true walkdir.workspace = true rand_distr = "0.4" diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 8408d4720b..2abe88ac88 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -13,7 +13,6 @@ use rustls::pki_types::InvalidDnsNameError; use thiserror::Error; use tokio::net::TcpStream; use tokio_postgres::tls::MakeTlsConnect; -use tokio_postgres_rustls::MakeRustlsConnect; use tracing::{debug, error, info, warn}; use crate::auth::parse_endpoint_param; @@ -24,6 +23,7 @@ use crate::control_plane::errors::WakeComputeError; use crate::control_plane::messages::MetricsAuxInfo; use 
crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumDbConnectionsGuard}; +use crate::postgres_rustls::MakeRustlsConnect; use crate::proxy::neon_option; use crate::types::Host; @@ -244,7 +244,6 @@ impl ConnCfg { let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432); let host = match host { Host::Tcp(host) => host.as_str(), - Host::Unix(_) => continue, // unix sockets are not welcome here }; match connect_once(host, *port).await { @@ -315,7 +314,7 @@ impl ConnCfg { }; let client_config = client_config.with_no_client_auth(); - let mut mk_tls = tokio_postgres_rustls::MakeRustlsConnect::new(client_config); + let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config); let tls = >::make_tls_connect( &mut mk_tls, host, diff --git a/proxy/src/context/mod.rs b/proxy/src/context/mod.rs index 5c19a23e36..4a063a5faa 100644 --- a/proxy/src/context/mod.rs +++ b/proxy/src/context/mod.rs @@ -414,6 +414,7 @@ impl RequestContextInner { outcome, }); } + if let Some(tx) = self.sender.take() { // If type changes, this error handling needs to be updated. 
let tx: mpsc::UnboundedSender = tx; diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index ad7e1d2771..ba69f9cf2d 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -88,6 +88,7 @@ pub mod jemalloc; pub mod logging; pub mod metrics; pub mod parse; +pub mod postgres_rustls; pub mod protocol2; pub mod proxy; pub mod rate_limiter; diff --git a/proxy/src/postgres_rustls/mod.rs b/proxy/src/postgres_rustls/mod.rs new file mode 100644 index 0000000000..31e7915e89 --- /dev/null +++ b/proxy/src/postgres_rustls/mod.rs @@ -0,0 +1,158 @@ +use std::convert::TryFrom; +use std::sync::Arc; + +use rustls::pki_types::ServerName; +use rustls::ClientConfig; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_postgres::tls::MakeTlsConnect; + +mod private { + use std::future::Future; + use std::io; + use std::pin::Pin; + use std::task::{Context, Poll}; + + use rustls::pki_types::ServerName; + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + use tokio_postgres::tls::{ChannelBinding, TlsConnect}; + use tokio_rustls::client::TlsStream; + use tokio_rustls::TlsConnector; + + use crate::config::TlsServerEndPoint; + + pub struct TlsConnectFuture { + inner: tokio_rustls::Connect, + } + + impl Future for TlsConnectFuture + where + S: AsyncRead + AsyncWrite + Unpin, + { + type Output = io::Result>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream) + } + } + + pub struct RustlsConnect(pub RustlsConnectData); + + pub struct RustlsConnectData { + pub hostname: ServerName<'static>, + pub connector: TlsConnector, + } + + impl TlsConnect for RustlsConnect + where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + type Stream = RustlsStream; + type Error = io::Error; + type Future = TlsConnectFuture; + + fn connect(self, stream: S) -> Self::Future { + TlsConnectFuture { + inner: self.0.connector.connect(self.0.hostname, stream), + } + } + } + + pub struct RustlsStream(TlsStream); + + impl 
tokio_postgres::tls::TlsStream for RustlsStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn channel_binding(&self) -> ChannelBinding { + let (_, session) = self.0.get_ref(); + match session.peer_certificates() { + Some([cert, ..]) => TlsServerEndPoint::new(cert) + .ok() + .and_then(|cb| match cb { + TlsServerEndPoint::Sha256(hash) => Some(hash), + TlsServerEndPoint::Undefined => None, + }) + .map_or_else(ChannelBinding::none, |hash| { + ChannelBinding::tls_server_end_point(hash.to_vec()) + }), + _ => ChannelBinding::none(), + } + } + } + + impl AsyncRead for RustlsStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) + } + } + + impl AsyncWrite for RustlsStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } + } +} + +/// A `MakeTlsConnect` implementation using `rustls`. +/// +/// That way you can connect to PostgreSQL using `rustls` as the TLS stack. +#[derive(Clone)] +pub struct MakeRustlsConnect { + config: Arc, +} + +impl MakeRustlsConnect { + /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`. 
+ #[must_use] + pub fn new(config: ClientConfig) -> Self { + Self { + config: Arc::new(config), + } + } +} + +impl MakeTlsConnect for MakeRustlsConnect +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + type Stream = private::RustlsStream; + type TlsConnect = private::RustlsConnect; + type Error = rustls::pki_types::InvalidDnsNameError; + + fn make_tls_connect(&mut self, hostname: &str) -> Result { + ServerName::try_from(hostname).map(|dns_name| { + private::RustlsConnect(private::RustlsConnectData { + hostname: dns_name.to_owned(), + connector: Arc::clone(&self.config).into(), + }) + }) + } +} diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 3de8ca8736..2c2c2964b6 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -14,7 +14,6 @@ use rustls::pki_types; use tokio::io::DuplexStream; use tokio_postgres::config::SslMode; use tokio_postgres::tls::{MakeTlsConnect, NoTls}; -use tokio_postgres_rustls::MakeRustlsConnect; use super::connect_compute::ConnectMechanism; use super::retry::CouldRetry; @@ -29,6 +28,7 @@ use crate::control_plane::{ self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo, NodeInfoCache, }; use crate::error::ErrorKind; +use crate::postgres_rustls::MakeRustlsConnect; use crate::types::{BranchId, EndpointId, ProjectId}; use crate::{sasl, scram}; diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 3037e20888..75909f3358 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -333,7 +333,7 @@ impl PoolingBackend { debug!("setting up backend session state"); // initiates the auth session - if let Err(e) = client.query("select auth.init()", &[]).await { + if let Err(e) = client.execute("select auth.init()", &[]).await { discard.discard(); return Err(e.into()); } diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index bd262f45ed..c302eac568 100644 --- 
a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -6,9 +6,10 @@ use std::task::{ready, Poll}; use futures::future::poll_fn; use futures::Future; use smallvec::SmallVec; +use tokio::net::TcpStream; use tokio::time::Instant; use tokio_postgres::tls::NoTlsStream; -use tokio_postgres::{AsyncMessage, Socket}; +use tokio_postgres::AsyncMessage; use tokio_util::sync::CancellationToken; use tracing::{error, info, info_span, warn, Instrument}; #[cfg(test)] @@ -57,7 +58,7 @@ pub(crate) fn poll_client( ctx: &RequestContext, conn_info: ConnInfo, client: C, - mut connection: tokio_postgres::Connection, + mut connection: tokio_postgres::Connection, conn_id: uuid::Uuid, aux: MetricsAuxInfo, ) -> Client { diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs index 9abe35db08..db9ac49dae 100644 --- a/proxy/src/serverless/local_conn_pool.rs +++ b/proxy/src/serverless/local_conn_pool.rs @@ -24,10 +24,11 @@ use p256::ecdsa::{Signature, SigningKey}; use parking_lot::RwLock; use serde_json::value::RawValue; use signature::Signer; +use tokio::net::TcpStream; use tokio::time::Instant; use tokio_postgres::tls::NoTlsStream; use tokio_postgres::types::ToSql; -use tokio_postgres::{AsyncMessage, Socket}; +use tokio_postgres::AsyncMessage; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, info_span, warn, Instrument}; @@ -163,7 +164,7 @@ pub(crate) fn poll_client( ctx: &RequestContext, conn_info: ConnInfo, client: C, - mut connection: tokio_postgres::Connection, + mut connection: tokio_postgres::Connection, key: SigningKey, conn_id: uuid::Uuid, aux: MetricsAuxInfo, @@ -286,11 +287,11 @@ impl ClientInnerCommon { let token = resign_jwt(&local_data.key, payload, local_data.jti)?; // initiates the auth session - self.inner.simple_query("discard all").await?; + self.inner.batch_execute("discard all").await?; self.inner - .query( + .execute( "select auth.jwt_session_init($1)", - &[&token as &(dyn ToSql + 
Sync)], + &[&&*token as &(dyn ToSql + Sync)], ) .await?; diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index a73d9d6352..c0a3abc377 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -60,7 +60,6 @@ num-integer = { version = "0.1", features = ["i128"] } num-traits = { version = "0.2", features = ["i128", "libm"] } once_cell = { version = "1" } parquet = { version = "53", default-features = false, features = ["zstd"] } -postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon", default-features = false, features = ["with-serde_json-1"] } prost = { version = "0.13", features = ["prost-derive"] } rand = { version = "0.8", features = ["small_rng"] } regex = { version = "1" } @@ -79,8 +78,7 @@ subtle = { version = "2" } sync_wrapper = { version = "0.1", default-features = false, features = ["futures"] } tikv-jemalloc-sys = { version = "0.6", features = ["stats"] } time = { version = "0.3", features = ["macros", "serde-well-known"] } -tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] } -tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon", features = ["with-serde_json-1"] } +tokio = { version = "1", features = ["full", "test-util"] } tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring", "tls12"] } tokio-stream = { version = "0.1", features = ["net"] } tokio-util = { version = "0.7", features = ["codec", "compat", "io", "rt"] } From ea3798e3b30f808f2851f58ff2390150b89959c6 Mon Sep 17 00:00:00 2001 From: John Spray Date: Fri, 29 Nov 2024 13:27:49 +0000 Subject: [PATCH 03/15] storage controller: use proper ScheduleContext when evacuating a node (#9908) ## Problem When picking locations for a shard, we should use a ScheduleContext that includes all the other shards in the tenant, so that we apply proper anti-affinity between shards. 
If we don't do this, then it can lead to unstable scheduling, where we place a shard on a node and the optimizer then immediately moves it away again. We didn't always do this, because it was a bit awkward to accumulate the context for a tenant rather than just walking tenants. This was a TODO in `handle_node_availability_transition`: ``` // TODO: populate a ScheduleContext including all shards in the same tenant_id (only matters // for tenants without secondary locations: if they have a secondary location, then this // schedule() call is just promoting an existing secondary) ``` This is a precursor to https://github.com/neondatabase/neon/issues/8264, where the current imperfect scheduling during node evacuation hampers testing. ## Summary of changes - Add an iterator type that yields each tenant's shards together with a `ScheduleContext` accumulated over all the shards of that tenant - Use the iterator to replace hand-crafted logic in `optimize_all_plan` (functionally identical) - Use the iterator in `handle_node_availability_transition` to apply proper anti-affinity during node evacuation. 
--- storage_controller/src/scheduler.rs | 17 +- storage_controller/src/service.rs | 200 +++++++----------- .../src/service/context_iterator.rs | 139 ++++++++++++ storage_controller/src/tenant_shard.rs | 11 +- 4 files changed, 245 insertions(+), 122 deletions(-) create mode 100644 storage_controller/src/service/context_iterator.rs diff --git a/storage_controller/src/scheduler.rs b/storage_controller/src/scheduler.rs index 2414d95eb8..ecc6b11e47 100644 --- a/storage_controller/src/scheduler.rs +++ b/storage_controller/src/scheduler.rs @@ -305,7 +305,7 @@ impl std::ops::Add for AffinityScore { /// Hint for whether this is a sincere attempt to schedule, or a speculative /// check for where we _would_ schedule (done during optimization) -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) enum ScheduleMode { Normal, Speculative, @@ -319,7 +319,7 @@ impl Default for ScheduleMode { // For carrying state between multiple calls to [`TenantShard::schedule`], e.g. when calling // it for many shards in the same tenant. -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub(crate) struct ScheduleContext { /// Sparse map of nodes: omitting a node implicitly makes its affinity [`AffinityScore::FREE`] pub(crate) nodes: HashMap, @@ -331,6 +331,14 @@ pub(crate) struct ScheduleContext { } impl ScheduleContext { + pub(crate) fn new(mode: ScheduleMode) -> Self { + Self { + nodes: HashMap::new(), + attached_nodes: HashMap::new(), + mode, + } + } + /// Input is a list of nodes we would like to avoid using again within this context. The more /// times a node is passed into this call, the less inclined we are to use it. 
pub(crate) fn avoid(&mut self, nodes: &[NodeId]) { @@ -355,6 +363,11 @@ impl ScheduleContext { pub(crate) fn get_node_attachments(&self, node_id: NodeId) -> usize { self.attached_nodes.get(&node_id).copied().unwrap_or(0) } + + #[cfg(test)] + pub(crate) fn attach_count(&self) -> usize { + self.attached_nodes.values().sum() + } } pub(crate) enum RefCountUpdate { diff --git a/storage_controller/src/service.rs b/storage_controller/src/service.rs index 446c476b99..636ccf11a1 100644 --- a/storage_controller/src/service.rs +++ b/storage_controller/src/service.rs @@ -1,3 +1,6 @@ +pub mod chaos_injector; +mod context_iterator; + use hyper::Uri; use std::{ borrow::Cow, @@ -95,7 +98,7 @@ use crate::{ }, }; -pub mod chaos_injector; +use context_iterator::TenantShardContextIterator; // For operations that should be quick, like attaching a new tenant const SHORT_RECONCILE_TIMEOUT: Duration = Duration::from_secs(5); @@ -5498,49 +5501,51 @@ impl Service { let mut tenants_affected: usize = 0; - for (tenant_shard_id, tenant_shard) in tenants { - if let Some(observed_loc) = tenant_shard.observed.locations.get_mut(&node_id) { - // When a node goes offline, we set its observed configuration to None, indicating unknown: we will - // not assume our knowledge of the node's configuration is accurate until it comes back online - observed_loc.conf = None; - } + for (_tenant_id, mut schedule_context, shards) in + TenantShardContextIterator::new(tenants, ScheduleMode::Normal) + { + for tenant_shard in shards { + let tenant_shard_id = tenant_shard.tenant_shard_id; + if let Some(observed_loc) = + tenant_shard.observed.locations.get_mut(&node_id) + { + // When a node goes offline, we set its observed configuration to None, indicating unknown: we will + // not assume our knowledge of the node's configuration is accurate until it comes back online + observed_loc.conf = None; + } - if nodes.len() == 1 { - // Special case for single-node cluster: there is no point trying to reschedule - // any tenant 
shards: avoid doing so, in order to avoid spewing warnings about - // failures to schedule them. - continue; - } + if nodes.len() == 1 { + // Special case for single-node cluster: there is no point trying to reschedule + // any tenant shards: avoid doing so, in order to avoid spewing warnings about + // failures to schedule them. + continue; + } - if !nodes - .values() - .any(|n| matches!(n.may_schedule(), MaySchedule::Yes(_))) - { - // Special case for when all nodes are unavailable and/or unschedulable: there is no point - // trying to reschedule since there's nowhere else to go. Without this - // branch we incorrectly detach tenants in response to node unavailability. - continue; - } + if !nodes + .values() + .any(|n| matches!(n.may_schedule(), MaySchedule::Yes(_))) + { + // Special case for when all nodes are unavailable and/or unschedulable: there is no point + // trying to reschedule since there's nowhere else to go. Without this + // branch we incorrectly detach tenants in response to node unavailability. + continue; + } - if tenant_shard.intent.demote_attached(scheduler, node_id) { - tenant_shard.sequence = tenant_shard.sequence.next(); + if tenant_shard.intent.demote_attached(scheduler, node_id) { + tenant_shard.sequence = tenant_shard.sequence.next(); - // TODO: populate a ScheduleContext including all shards in the same tenant_id (only matters - // for tenants without secondary locations: if they have a secondary location, then this - // schedule() call is just promoting an existing secondary) - let mut schedule_context = ScheduleContext::default(); - - match tenant_shard.schedule(scheduler, &mut schedule_context) { - Err(e) => { - // It is possible that some tenants will become unschedulable when too many pageservers - // go offline: in this case there isn't much we can do other than make the issue observable. - // TODO: give TenantShard a scheduling error attribute to be queried later. 
- tracing::warn!(%tenant_shard_id, "Scheduling error when marking pageserver {} offline: {e}", node_id); - } - Ok(()) => { - if self.maybe_reconcile_shard(tenant_shard, nodes).is_some() { - tenants_affected += 1; - }; + match tenant_shard.schedule(scheduler, &mut schedule_context) { + Err(e) => { + // It is possible that some tenants will become unschedulable when too many pageservers + // go offline: in this case there isn't much we can do other than make the issue observable. + // TODO: give TenantShard a scheduling error attribute to be queried later. + tracing::warn!(%tenant_shard_id, "Scheduling error when marking pageserver {} offline: {e}", node_id); + } + Ok(()) => { + if self.maybe_reconcile_shard(tenant_shard, nodes).is_some() { + tenants_affected += 1; + }; + } } } } @@ -6011,14 +6016,8 @@ impl Service { let (nodes, tenants, _scheduler) = locked.parts_mut(); let pageservers = nodes.clone(); - let mut schedule_context = ScheduleContext::default(); - let mut reconciles_spawned = 0; - for (tenant_shard_id, shard) in tenants.iter_mut() { - if tenant_shard_id.is_shard_zero() { - schedule_context = ScheduleContext::default(); - } - + for shard in tenants.values_mut() { // Skip checking if this shard is already enqueued for reconciliation if shard.delayed_reconcile && self.reconciler_concurrency.available_permits() == 0 { // If there is something delayed, then return a nonzero count so that @@ -6033,8 +6032,6 @@ impl Service { if self.maybe_reconcile_shard(shard, &pageservers).is_some() { reconciles_spawned += 1; } - - schedule_context.avoid(&shard.intent.all_pageservers()); } reconciles_spawned @@ -6103,95 +6100,62 @@ impl Service { } fn optimize_all_plan(&self) -> Vec<(TenantShardId, ScheduleOptimization)> { - let mut schedule_context = ScheduleContext::default(); - - let mut tenant_shards: Vec<&TenantShard> = Vec::new(); - // How many candidate optimizations we will generate, before evaluating them for readniess: setting // this higher than the execution 
limit gives us a chance to execute some work even if the first // few optimizations we find are not ready. const MAX_OPTIMIZATIONS_PLAN_PER_PASS: usize = 8; let mut work = Vec::new(); - let mut locked = self.inner.write().unwrap(); let (nodes, tenants, scheduler) = locked.parts_mut(); - for (tenant_shard_id, shard) in tenants.iter() { - if tenant_shard_id.is_shard_zero() { - // Reset accumulators on the first shard in a tenant - schedule_context = ScheduleContext::default(); - schedule_context.mode = ScheduleMode::Speculative; - tenant_shards.clear(); - } - if work.len() >= MAX_OPTIMIZATIONS_PLAN_PER_PASS { - break; - } - - match shard.get_scheduling_policy() { - ShardSchedulingPolicy::Active => { - // Ok to do optimization + for (_tenant_id, schedule_context, shards) in + TenantShardContextIterator::new(tenants, ScheduleMode::Speculative) + { + for shard in shards { + if work.len() >= MAX_OPTIMIZATIONS_PLAN_PER_PASS { + break; } - ShardSchedulingPolicy::Essential - | ShardSchedulingPolicy::Pause - | ShardSchedulingPolicy::Stop => { - // Policy prevents optimizing this shard. - continue; + match shard.get_scheduling_policy() { + ShardSchedulingPolicy::Active => { + // Ok to do optimization + } + ShardSchedulingPolicy::Essential + | ShardSchedulingPolicy::Pause + | ShardSchedulingPolicy::Stop => { + // Policy prevents optimizing this shard. + continue; + } } - } - // Accumulate the schedule context for all the shards in a tenant: we must have - // the total view of all shards before we can try to optimize any of them. 
- schedule_context.avoid(&shard.intent.all_pageservers()); - if let Some(attached) = shard.intent.get_attached() { - schedule_context.push_attached(*attached); - } - tenant_shards.push(shard); - - // Once we have seen the last shard in the tenant, proceed to search across all shards - // in the tenant for optimizations - if shard.shard.number.0 == shard.shard.count.count() - 1 { - if tenant_shards.iter().any(|s| s.reconciler.is_some()) { + if !matches!(shard.splitting, SplitState::Idle) + || matches!(shard.policy, PlacementPolicy::Detached) + || shard.reconciler.is_some() + { // Do not start any optimizations while another change to the tenant is ongoing: this // is not necessary for correctness, but simplifies operations and implicitly throttles // optimization changes to happen in a "trickle" over time. continue; } - if tenant_shards.iter().any(|s| { - !matches!(s.splitting, SplitState::Idle) - || matches!(s.policy, PlacementPolicy::Detached) - }) { - // Never attempt to optimize a tenant that is currently being split, or - // a tenant that is meant to be detached - continue; - } - // TODO: optimization calculations are relatively expensive: create some fast-path for // the common idle case (avoiding the search on tenants that we have recently checked) - - for shard in &tenant_shards { - if let Some(optimization) = - // If idle, maybe ptimize attachments: if a shard has a secondary location that is preferable to - // its primary location based on soft constraints, cut it over. - shard.optimize_attachment(nodes, &schedule_context) - { - work.push((shard.tenant_shard_id, optimization)); - break; - } else if let Some(optimization) = - // If idle, maybe optimize secondary locations: if a shard has a secondary location that would be - // better placed on another node, based on ScheduleContext, then adjust it. 
This - // covers cases like after a shard split, where we might have too many shards - // in the same tenant with secondary locations on the node where they originally split. - shard.optimize_secondary(scheduler, &schedule_context) - { - work.push((shard.tenant_shard_id, optimization)); - break; - } - - // TODO: extend this mechanism to prefer attaching on nodes with fewer attached - // tenants (i.e. extend schedule state to distinguish attached from secondary counts), - // for the total number of attachments on a node (not just within a tenant.) + if let Some(optimization) = + // If idle, maybe ptimize attachments: if a shard has a secondary location that is preferable to + // its primary location based on soft constraints, cut it over. + shard.optimize_attachment(nodes, &schedule_context) + { + work.push((shard.tenant_shard_id, optimization)); + break; + } else if let Some(optimization) = + // If idle, maybe optimize secondary locations: if a shard has a secondary location that would be + // better placed on another node, based on ScheduleContext, then adjust it. This + // covers cases like after a shard split, where we might have too many shards + // in the same tenant with secondary locations on the node where they originally split. + shard.optimize_secondary(scheduler, &schedule_context) + { + work.push((shard.tenant_shard_id, optimization)); + break; } } } diff --git a/storage_controller/src/service/context_iterator.rs b/storage_controller/src/service/context_iterator.rs new file mode 100644 index 0000000000..d38010a27e --- /dev/null +++ b/storage_controller/src/service/context_iterator.rs @@ -0,0 +1,139 @@ +use std::collections::BTreeMap; + +use utils::id::TenantId; +use utils::shard::TenantShardId; + +use crate::scheduler::{ScheduleContext, ScheduleMode}; +use crate::tenant_shard::TenantShard; + +/// When making scheduling decisions, it is useful to have the ScheduleContext for a whole +/// tenant while considering the individual shards within it. 
This iterator is a helper +/// that gathers all the shards in a tenant and then yields them together with a ScheduleContext +/// for the tenant. +pub(super) struct TenantShardContextIterator<'a> { + schedule_mode: ScheduleMode, + inner: std::collections::btree_map::IterMut<'a, TenantShardId, TenantShard>, +} + +impl<'a> TenantShardContextIterator<'a> { + pub(super) fn new( + tenants: &'a mut BTreeMap, + schedule_mode: ScheduleMode, + ) -> Self { + Self { + schedule_mode, + inner: tenants.iter_mut(), + } + } +} + +impl<'a> Iterator for TenantShardContextIterator<'a> { + type Item = (TenantId, ScheduleContext, Vec<&'a mut TenantShard>); + + fn next(&mut self) -> Option { + let mut tenant_shards = Vec::new(); + let mut schedule_context = ScheduleContext::new(self.schedule_mode.clone()); + loop { + let (tenant_shard_id, shard) = self.inner.next()?; + + if tenant_shard_id.is_shard_zero() { + // Cleared on last shard of previous tenant + assert!(tenant_shards.is_empty()); + } + + // Accumulate the schedule context for all the shards in a tenant + schedule_context.avoid(&shard.intent.all_pageservers()); + if let Some(attached) = shard.intent.get_attached() { + schedule_context.push_attached(*attached); + } + tenant_shards.push(shard); + + if tenant_shard_id.shard_number.0 == tenant_shard_id.shard_count.count() - 1 { + return Some((tenant_shard_id.tenant_id, schedule_context, tenant_shards)); + } + } + } +} + +#[cfg(test)] +mod tests { + use std::{collections::BTreeMap, str::FromStr}; + + use pageserver_api::controller_api::PlacementPolicy; + use utils::shard::{ShardCount, ShardNumber}; + + use crate::{ + scheduler::test_utils::make_test_nodes, service::Scheduler, + tenant_shard::tests::make_test_tenant_with_id, + }; + + use super::*; + + #[test] + fn test_context_iterator() { + // Hand-crafted tenant IDs to ensure they appear in the expected order when put into + // a btreemap & iterated + let mut t_1_shards = make_test_tenant_with_id( + 
TenantId::from_str("af0480929707ee75372337efaa5ecf96").unwrap(), + PlacementPolicy::Attached(1), + ShardCount(1), + None, + ); + let t_2_shards = make_test_tenant_with_id( + TenantId::from_str("bf0480929707ee75372337efaa5ecf96").unwrap(), + PlacementPolicy::Attached(1), + ShardCount(4), + None, + ); + let mut t_3_shards = make_test_tenant_with_id( + TenantId::from_str("cf0480929707ee75372337efaa5ecf96").unwrap(), + PlacementPolicy::Attached(1), + ShardCount(1), + None, + ); + + let t1_id = t_1_shards[0].tenant_shard_id.tenant_id; + let t2_id = t_2_shards[0].tenant_shard_id.tenant_id; + let t3_id = t_3_shards[0].tenant_shard_id.tenant_id; + + let mut tenants = BTreeMap::new(); + tenants.insert(t_1_shards[0].tenant_shard_id, t_1_shards.pop().unwrap()); + for shard in t_2_shards { + tenants.insert(shard.tenant_shard_id, shard); + } + tenants.insert(t_3_shards[0].tenant_shard_id, t_3_shards.pop().unwrap()); + + let nodes = make_test_nodes(3, &[]); + let mut scheduler = Scheduler::new(nodes.values()); + let mut context = ScheduleContext::default(); + for shard in tenants.values_mut() { + shard.schedule(&mut scheduler, &mut context).unwrap(); + } + + let mut iter = TenantShardContextIterator::new(&mut tenants, ScheduleMode::Speculative); + let (tenant_id, context, shards) = iter.next().unwrap(); + assert_eq!(tenant_id, t1_id); + assert_eq!(shards[0].tenant_shard_id.shard_number, ShardNumber(0)); + assert_eq!(shards.len(), 1); + assert_eq!(context.attach_count(), 1); + + let (tenant_id, context, shards) = iter.next().unwrap(); + assert_eq!(tenant_id, t2_id); + assert_eq!(shards[0].tenant_shard_id.shard_number, ShardNumber(0)); + assert_eq!(shards[1].tenant_shard_id.shard_number, ShardNumber(1)); + assert_eq!(shards[2].tenant_shard_id.shard_number, ShardNumber(2)); + assert_eq!(shards[3].tenant_shard_id.shard_number, ShardNumber(3)); + assert_eq!(shards.len(), 4); + assert_eq!(context.attach_count(), 4); + + let (tenant_id, context, shards) = iter.next().unwrap(); + 
assert_eq!(tenant_id, t3_id); + assert_eq!(shards[0].tenant_shard_id.shard_number, ShardNumber(0)); + assert_eq!(shards.len(), 1); + assert_eq!(context.attach_count(), 1); + + for shard in tenants.values_mut() { + shard.intent.clear(&mut scheduler); + } + } +} diff --git a/storage_controller/src/tenant_shard.rs b/storage_controller/src/tenant_shard.rs index 27c97d3b86..2eb98ee825 100644 --- a/storage_controller/src/tenant_shard.rs +++ b/storage_controller/src/tenant_shard.rs @@ -1574,13 +1574,20 @@ pub(crate) mod tests { ) } - fn make_test_tenant( + pub(crate) fn make_test_tenant( policy: PlacementPolicy, shard_count: ShardCount, preferred_az: Option, ) -> Vec { - let tenant_id = TenantId::generate(); + make_test_tenant_with_id(TenantId::generate(), policy, shard_count, preferred_az) + } + pub(crate) fn make_test_tenant_with_id( + tenant_id: TenantId, + policy: PlacementPolicy, + shard_count: ShardCount, + preferred_az: Option, + ) -> Vec { (0..shard_count.count()) .map(|i| { let shard_number = ShardNumber(i); From a6073b5013fb1513e1f9937642fb3610f62854dc Mon Sep 17 00:00:00 2001 From: Erik Grinaker Date: Fri, 29 Nov 2024 15:38:04 +0200 Subject: [PATCH 04/15] safekeeper: use jemalloc (#9780) ## Problem To add Safekeeper heap profiling in #9778, we need to switch to an allocator that supports it. Pageserver and proxy already use jemalloc. Touches #9534. ## Summary of changes Use jemalloc in Safekeeper. 
--- Cargo.lock | 1 + safekeeper/Cargo.toml | 1 + safekeeper/benches/receive_wal.rs | 30 +++++++++++++++++++++++++++++- safekeeper/src/bin/safekeeper.rs | 3 +++ 4 files changed, 34 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index f05c6311dd..abe69525c9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5409,6 +5409,7 @@ dependencies = [ "strum", "strum_macros", "thiserror", + "tikv-jemallocator", "tokio", "tokio-io-timeout", "tokio-postgres", diff --git a/safekeeper/Cargo.toml b/safekeeper/Cargo.toml index 635a9222e1..0422c46ab1 100644 --- a/safekeeper/Cargo.toml +++ b/safekeeper/Cargo.toml @@ -41,6 +41,7 @@ serde_json.workspace = true strum.workspace = true strum_macros.workspace = true thiserror.workspace = true +tikv-jemallocator.workspace = true tokio = { workspace = true, features = ["fs"] } tokio-util = { workspace = true } tokio-io-timeout.workspace = true diff --git a/safekeeper/benches/receive_wal.rs b/safekeeper/benches/receive_wal.rs index c637b4fb24..8c4281cf52 100644 --- a/safekeeper/benches/receive_wal.rs +++ b/safekeeper/benches/receive_wal.rs @@ -6,6 +6,7 @@ mod benchutils; use std::io::Write as _; use benchutils::Env; +use bytes::BytesMut; use camino_tempfile::tempfile; use criterion::{criterion_group, criterion_main, BatchSize, Bencher, Criterion}; use itertools::Itertools as _; @@ -23,6 +24,9 @@ const KB: usize = 1024; const MB: usize = 1024 * KB; const GB: usize = 1024 * MB; +#[global_allocator] +static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; + // Register benchmarks with Criterion. criterion_group!( name = benches; @@ -30,7 +34,8 @@ criterion_group!( targets = bench_process_msg, bench_wal_acceptor, bench_wal_acceptor_throughput, - bench_file_write + bench_file_write, + bench_bytes_reserve, ); criterion_main!(benches); @@ -341,3 +346,26 @@ fn bench_file_write(c: &mut Criterion) { Ok(()) } } + +/// Benchmarks the cost of memory allocations when receiving WAL messages. 
This emulates the logic +/// in FeMessage::parse, which extends the read buffer. It is primarily intended to test jemalloc. +fn bench_bytes_reserve(c: &mut Criterion) { + let mut g = c.benchmark_group("bytes_reserve"); + for size in [1, 64, KB, 8 * KB, 128 * KB] { + g.throughput(criterion::Throughput::Bytes(size as u64)); + g.bench_function(format!("size={size}"), |b| run_bench(b, size).unwrap()); + } + + fn run_bench(b: &mut Bencher, size: usize) -> anyhow::Result<()> { + let mut bytes = BytesMut::new(); + let data = vec![0; size]; + + b.iter(|| { + bytes.reserve(size); + bytes.extend_from_slice(&data); + bytes.split_to(size).freeze(); + }); + + Ok(()) + } +} diff --git a/safekeeper/src/bin/safekeeper.rs b/safekeeper/src/bin/safekeeper.rs index 1248428d33..3659bcd7e0 100644 --- a/safekeeper/src/bin/safekeeper.rs +++ b/safekeeper/src/bin/safekeeper.rs @@ -48,6 +48,9 @@ use utils::{ tcp_listener, }; +#[global_allocator] +static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; + const PID_FILE_NAME: &str = "safekeeper.pid"; const ID_FILE_NAME: &str = "safekeeper.id"; From 538e2312a617c65d489d391892c70b2e4d7407b5 Mon Sep 17 00:00:00 2001 From: Alexey Kondratov Date: Fri, 29 Nov 2024 14:55:56 +0100 Subject: [PATCH 05/15] feat(compute_ctl): Always set application_name (#9934) ## Problem It was not always possible to judge what exactly some `cloud_admin` connections were doing because we didn't consistently set `application_name` everywhere. ## Summary of changes Unify the way we connect to Postgres: 1. Switch to building configs everywhere 2. 
Always set `application_name` and make naming consistent Follow-up for #9919 Part of neondatabase/cloud#20948 --- compute_tools/src/bin/compute_ctl.rs | 10 ++++- compute_tools/src/catalog.rs | 7 +--- compute_tools/src/checker.rs | 3 +- compute_tools/src/compute.rs | 49 ++++++++++++++++------- compute_tools/src/http/api.rs | 11 +++-- compute_tools/src/installed_extensions.rs | 14 +++---- compute_tools/src/monitor.rs | 13 +++--- 7 files changed, 64 insertions(+), 43 deletions(-) diff --git a/compute_tools/src/bin/compute_ctl.rs b/compute_tools/src/bin/compute_ctl.rs index 6b670de2ea..b178d7abd6 100644 --- a/compute_tools/src/bin/compute_ctl.rs +++ b/compute_tools/src/bin/compute_ctl.rs @@ -37,6 +37,7 @@ use std::collections::HashMap; use std::fs::File; use std::path::Path; use std::process::exit; +use std::str::FromStr; use std::sync::atomic::Ordering; use std::sync::{mpsc, Arc, Condvar, Mutex, RwLock}; use std::{thread, time::Duration}; @@ -322,8 +323,15 @@ fn wait_spec( } else { spec_set = false; } + let connstr = Url::parse(connstr).context("cannot parse connstr as a URL")?; + let conn_conf = postgres::config::Config::from_str(connstr.as_str()) + .context("cannot build postgres config from connstr")?; + let tokio_conn_conf = tokio_postgres::config::Config::from_str(connstr.as_str()) + .context("cannot build tokio postgres config from connstr")?; let compute_node = ComputeNode { - connstr: Url::parse(connstr).context("cannot parse connstr as a URL")?, + connstr, + conn_conf, + tokio_conn_conf, pgdata: pgdata.to_string(), pgbin: pgbin.to_string(), pgversion: get_pg_version_string(pgbin), diff --git a/compute_tools/src/catalog.rs b/compute_tools/src/catalog.rs index 08ae8bf44d..72198a9479 100644 --- a/compute_tools/src/catalog.rs +++ b/compute_tools/src/catalog.rs @@ -6,7 +6,6 @@ use tokio::{ process::Command, spawn, }; -use tokio_postgres::connect; use tokio_stream::{self as stream, StreamExt}; use tokio_util::codec::{BytesCodec, FramedRead}; use tracing::warn; @@ 
-16,10 +15,8 @@ use crate::pg_helpers::{get_existing_dbs_async, get_existing_roles_async, postgr use compute_api::responses::CatalogObjects; pub async fn get_dbs_and_roles(compute: &Arc) -> anyhow::Result { - let connstr = compute.connstr.clone(); - - let (client, connection): (tokio_postgres::Client, _) = - connect(connstr.as_str(), NoTls).await?; + let conf = compute.get_tokio_conn_conf(Some("compute_ctl:get_dbs_and_roles")); + let (client, connection): (tokio_postgres::Client, _) = conf.connect(NoTls).await?; spawn(async move { if let Err(e) = connection.await { diff --git a/compute_tools/src/checker.rs b/compute_tools/src/checker.rs index cec2b1bed8..62d61a8bc9 100644 --- a/compute_tools/src/checker.rs +++ b/compute_tools/src/checker.rs @@ -9,7 +9,8 @@ use crate::compute::ComputeNode; #[instrument(skip_all)] pub async fn check_writability(compute: &ComputeNode) -> Result<()> { // Connect to the database. - let (client, connection) = tokio_postgres::connect(compute.connstr.as_str(), NoTls).await?; + let conf = compute.get_tokio_conn_conf(Some("compute_ctl:availability_checker")); + let (client, connection) = conf.connect(NoTls).await?; if client.is_closed() { return Err(anyhow!("connection to postgres closed")); } diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index 1a026a4014..da1caf1a9b 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -20,8 +20,9 @@ use futures::future::join_all; use futures::stream::FuturesUnordered; use futures::StreamExt; use nix::unistd::Pid; +use postgres; use postgres::error::SqlState; -use postgres::{Client, NoTls}; +use postgres::NoTls; use tracing::{debug, error, info, instrument, warn}; use utils::id::{TenantId, TimelineId}; use utils::lsn::Lsn; @@ -58,6 +59,10 @@ pub static PG_PID: AtomicU32 = AtomicU32::new(0); pub struct ComputeNode { // Url type maintains proper escaping pub connstr: url::Url, + // We connect to Postgres from many different places, so build configs once + 
// and reuse them where needed. + pub conn_conf: postgres::config::Config, + pub tokio_conn_conf: tokio_postgres::config::Config, pub pgdata: String, pub pgbin: String, pub pgversion: String, @@ -800,10 +805,10 @@ impl ComputeNode { /// version. In the future, it may upgrade all 3rd-party extensions. #[instrument(skip_all)] pub fn post_apply_config(&self) -> Result<()> { - let connstr = self.connstr.clone(); + let conf = self.get_conn_conf(Some("compute_ctl:post_apply_config")); thread::spawn(move || { let func = || { - let mut client = Client::connect(connstr.as_str(), NoTls)?; + let mut client = conf.connect(NoTls)?; handle_neon_extension_upgrade(&mut client) .context("handle_neon_extension_upgrade")?; Ok::<_, anyhow::Error>(()) @@ -815,12 +820,27 @@ impl ComputeNode { Ok(()) } + pub fn get_conn_conf(&self, application_name: Option<&str>) -> postgres::Config { + let mut conf = self.conn_conf.clone(); + if let Some(application_name) = application_name { + conf.application_name(application_name); + } + conf + } + + pub fn get_tokio_conn_conf(&self, application_name: Option<&str>) -> tokio_postgres::Config { + let mut conf = self.tokio_conn_conf.clone(); + if let Some(application_name) = application_name { + conf.application_name(application_name); + } + conf + } + async fn get_maintenance_client( conf: &tokio_postgres::Config, ) -> Result { let mut conf = conf.clone(); - - conf.application_name("apply_config"); + conf.application_name("compute_ctl:apply_config"); let (client, conn) = match conf.connect(NoTls).await { // If connection fails, it may be the old node with `zenith_admin` superuser. @@ -837,6 +857,7 @@ impl ComputeNode { e ); let mut zenith_admin_conf = postgres::config::Config::from(conf.clone()); + zenith_admin_conf.application_name("compute_ctl:apply_config"); zenith_admin_conf.user("zenith_admin"); let mut client = @@ -1134,8 +1155,7 @@ impl ComputeNode { /// Do initial configuration of the already started Postgres. 
#[instrument(skip_all)] pub fn apply_config(&self, compute_state: &ComputeState) -> Result<()> { - let mut conf = tokio_postgres::Config::from_str(self.connstr.as_str()).unwrap(); - conf.application_name("apply_config"); + let conf = self.get_tokio_conn_conf(Some("compute_ctl:apply_config")); let conf = Arc::new(conf); let spec = Arc::new( @@ -1161,7 +1181,7 @@ impl ComputeNode { thread::spawn(move || { let conf = conf.as_ref().clone(); let mut conf = postgres::config::Config::from(conf); - conf.application_name("migrations"); + conf.application_name("compute_ctl:migrations"); let mut client = conf.connect(NoTls)?; handle_migrations(&mut client).context("apply_config handle_migrations") @@ -1369,9 +1389,9 @@ impl ComputeNode { } self.post_apply_config()?; - let connstr = self.connstr.clone(); + let conf = self.get_conn_conf(None); thread::spawn(move || { - let res = get_installed_extensions(&connstr); + let res = get_installed_extensions(conf); match res { Ok(extensions) => { info!( @@ -1510,7 +1530,8 @@ impl ComputeNode { /// Select `pg_stat_statements` data and return it as a stringified JSON pub async fn collect_insights(&self) -> String { let mut result_rows: Vec = Vec::new(); - let connect_result = tokio_postgres::connect(self.connstr.as_str(), NoTls).await; + let conf = self.get_tokio_conn_conf(Some("compute_ctl:collect_insights")); + let connect_result = conf.connect(NoTls).await; let (client, connection) = connect_result.unwrap(); tokio::spawn(async move { if let Err(e) = connection.await { @@ -1636,10 +1657,9 @@ LIMIT 100", privileges: &[Privilege], role_name: &PgIdent, ) -> Result<()> { - use tokio_postgres::config::Config; use tokio_postgres::NoTls; - let mut conf = Config::from_str(self.connstr.as_str()).unwrap(); + let mut conf = self.get_tokio_conn_conf(Some("compute_ctl:set_role_grants")); conf.dbname(db_name); let (db_client, conn) = conf @@ -1676,10 +1696,9 @@ LIMIT 100", db_name: &PgIdent, ext_version: ExtVersion, ) -> Result { - use 
tokio_postgres::config::Config; use tokio_postgres::NoTls; - let mut conf = Config::from_str(self.connstr.as_str()).unwrap(); + let mut conf = self.get_tokio_conn_conf(Some("compute_ctl:install_extension")); conf.dbname(db_name); let (db_client, conn) = conf diff --git a/compute_tools/src/http/api.rs b/compute_tools/src/http/api.rs index a6c6cff20a..7fa6426d8f 100644 --- a/compute_tools/src/http/api.rs +++ b/compute_tools/src/http/api.rs @@ -295,12 +295,11 @@ async fn routes(req: Request, compute: &Arc) -> Response render_json(Body::from(serde_json::to_string(&res).unwrap())), diff --git a/compute_tools/src/installed_extensions.rs b/compute_tools/src/installed_extensions.rs index f473c29a55..5f62f08858 100644 --- a/compute_tools/src/installed_extensions.rs +++ b/compute_tools/src/installed_extensions.rs @@ -10,8 +10,6 @@ use metrics::core::Collector; use metrics::{register_uint_gauge_vec, UIntGaugeVec}; use once_cell::sync::Lazy; -use crate::pg_helpers::postgres_conf_for_db; - /// We don't reuse get_existing_dbs() just for code clarity /// and to make database listing query here more explicit. /// @@ -41,14 +39,16 @@ fn list_dbs(client: &mut Client) -> Result> { /// /// Same extension can be installed in multiple databases with different versions, /// we only keep the highest and lowest version across all databases. 
-pub fn get_installed_extensions(connstr: &url::Url) -> Result { - let mut client = Client::connect(connstr.as_str(), NoTls)?; +pub fn get_installed_extensions(mut conf: postgres::config::Config) -> Result { + conf.application_name("compute_ctl:get_installed_extensions"); + let mut client = conf.connect(NoTls)?; + let databases: Vec = list_dbs(&mut client)?; let mut extensions_map: HashMap = HashMap::new(); for db in databases.iter() { - let config = postgres_conf_for_db(connstr, db)?; - let mut db_client = config.connect(NoTls)?; + conf.dbname(db); + let mut db_client = conf.connect(NoTls)?; let extensions: Vec<(String, String)> = db_client .query( "SELECT extname, extversion FROM pg_catalog.pg_extension;", @@ -82,7 +82,7 @@ pub fn get_installed_extensions(connstr: &url::Url) -> Result = None; @@ -57,7 +54,7 @@ fn watch_compute_activity(compute: &ComputeNode) { info!("connection to Postgres is closed, trying to reconnect"); // Connection is closed, reconnect and try again. - client = Client::connect(connstr, NoTls); + client = conf.connect(NoTls); continue; } @@ -196,7 +193,7 @@ fn watch_compute_activity(compute: &ComputeNode) { debug!("could not connect to Postgres: {}, retrying", e); // Establish a new connection and try again. - client = Client::connect(connstr, NoTls); + client = conf.connect(NoTls); } } } From d5624cc50521098d16a49ad92a735184a48981ae Mon Sep 17 00:00:00 2001 From: John Spray Date: Fri, 29 Nov 2024 15:11:44 +0000 Subject: [PATCH 06/15] pageserver: download small objects using a smaller timeout (#9938) ## Problem It appears that the Azure storage API tends to hang TCP connections more than S3 does. Currently we use a 2 minute timeout for all downloads. This is large because sometimes the objects we download are large. 
However, waiting 2 minutes when doing something like downloading a manifest on tenant attach is problematic, because when someone is doing a "create tenant, create timeline" workflow, that 2 minutes is long enough for them reasonably to give up creating that timeline. Rather than propagate oversized timeouts further up the stack, we should use a different timeout for objects that we expect to be small. Closes: https://github.com/neondatabase/neon/issues/9836 ## Summary of changes - Add a `small_timeout` configuration attribute to remote storage, defaulting to 30 seconds (still a very generous period to do something like download an index) - Add a DownloadKind parameter to DownloadOpts, so that callers can indicate whether they expect the object to be small or large. - In the azure client, use small timeout for HEAD requests, and for GET requests if DownloadKind::Small is used. - Use DownloadKind::Small for manifests, indices, and heatmap downloads. This PR intentionally does not make the equivalent change to the S3 client, to reduce blast radius in case this has unexpected consequences (we could accomplish the same thing by editing lots of configs, but just skipping the code is simpler for right now) --- libs/remote_storage/src/azure_blob.rs | 23 ++++++++++++++--- libs/remote_storage/src/config.rs | 25 ++++++++++++++++--- libs/remote_storage/src/lib.rs | 20 ++++++++++++++- libs/remote_storage/tests/test_real_azure.rs | 3 ++- libs/remote_storage/tests/test_real_s3.rs | 1 + pageserver/src/deletion_queue.rs | 1 + pageserver/src/tenant.rs | 1 + .../tenant/remote_timeline_client/download.rs | 21 +++++++++++++--- pageserver/src/tenant/secondary/downloader.rs | 3 ++- .../import_pgdata/importbucket_client.rs | 4 ++- proxy/src/context/parquet.rs | 2 ++ 11 files changed, 89 insertions(+), 15 deletions(-) diff --git a/libs/remote_storage/src/azure_blob.rs b/libs/remote_storage/src/azure_blob.rs index 840917ef68..8d1962fa29 100644 --- a/libs/remote_storage/src/azure_blob.rs 
+++ b/libs/remote_storage/src/azure_blob.rs @@ -35,6 +35,7 @@ use utils::backoff; use utils::backoff::exponential_backoff_duration_seconds; use crate::metrics::{start_measuring_requests, AttemptOutcome, RequestKind}; +use crate::DownloadKind; use crate::{ config::AzureConfig, error::Cancelled, ConcurrencyLimiter, Download, DownloadError, DownloadOpts, Listing, ListingMode, ListingObject, RemotePath, RemoteStorage, StorageMetadata, @@ -49,10 +50,17 @@ pub struct AzureBlobStorage { concurrency_limiter: ConcurrencyLimiter, // Per-request timeout. Accessible for tests. pub timeout: Duration, + + // Alternative timeout used for metadata objects which are expected to be small + pub small_timeout: Duration, } impl AzureBlobStorage { - pub fn new(azure_config: &AzureConfig, timeout: Duration) -> Result { + pub fn new( + azure_config: &AzureConfig, + timeout: Duration, + small_timeout: Duration, + ) -> Result { debug!( "Creating azure remote storage for azure container {}", azure_config.container_name @@ -94,6 +102,7 @@ impl AzureBlobStorage { max_keys_per_list_response, concurrency_limiter: ConcurrencyLimiter::new(azure_config.concurrency_limit.get()), timeout, + small_timeout, }) } @@ -133,6 +142,7 @@ impl AzureBlobStorage { async fn download_for_builder( &self, builder: GetBlobBuilder, + timeout: Duration, cancel: &CancellationToken, ) -> Result { let kind = RequestKind::Get; @@ -156,7 +166,7 @@ impl AzureBlobStorage { .map_err(to_download_error); // apply per request timeout - let response = tokio_stream::StreamExt::timeout(response, self.timeout); + let response = tokio_stream::StreamExt::timeout(response, timeout); // flatten let response = response.map(|res| match res { @@ -415,7 +425,7 @@ impl RemoteStorage for AzureBlobStorage { let blob_client = self.client.blob_client(self.relative_path_to_name(key)); let properties_future = blob_client.get_properties().into_future(); - let properties_future = tokio::time::timeout(self.timeout, properties_future); + let 
properties_future = tokio::time::timeout(self.small_timeout, properties_future); let res = tokio::select! { res = properties_future => res, @@ -521,7 +531,12 @@ impl RemoteStorage for AzureBlobStorage { }); } - self.download_for_builder(builder, cancel).await + let timeout = match opts.kind { + DownloadKind::Small => self.small_timeout, + DownloadKind::Large => self.timeout, + }; + + self.download_for_builder(builder, timeout, cancel).await } async fn delete(&self, path: &RemotePath, cancel: &CancellationToken) -> anyhow::Result<()> { diff --git a/libs/remote_storage/src/config.rs b/libs/remote_storage/src/config.rs index e99ae4f747..f6ef31077c 100644 --- a/libs/remote_storage/src/config.rs +++ b/libs/remote_storage/src/config.rs @@ -24,6 +24,13 @@ pub struct RemoteStorageConfig { skip_serializing_if = "is_default_timeout" )] pub timeout: Duration, + /// Alternative timeout used for metadata objects which are expected to be small + #[serde( + with = "humantime_serde", + default = "default_small_timeout", + skip_serializing_if = "is_default_small_timeout" + )] + pub small_timeout: Duration, } impl RemoteStorageKind { @@ -40,10 +47,18 @@ fn default_timeout() -> Duration { RemoteStorageConfig::DEFAULT_TIMEOUT } +fn default_small_timeout() -> Duration { + RemoteStorageConfig::DEFAULT_SMALL_TIMEOUT +} + fn is_default_timeout(d: &Duration) -> bool { *d == RemoteStorageConfig::DEFAULT_TIMEOUT } +fn is_default_small_timeout(d: &Duration) -> bool { + *d == RemoteStorageConfig::DEFAULT_SMALL_TIMEOUT +} + /// A kind of a remote storage to connect to, with its connection configuration. 
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] #[serde(untagged)] @@ -184,6 +199,7 @@ fn serialize_storage_class( impl RemoteStorageConfig { pub const DEFAULT_TIMEOUT: Duration = std::time::Duration::from_secs(120); + pub const DEFAULT_SMALL_TIMEOUT: Duration = std::time::Duration::from_secs(30); pub fn from_toml(toml: &toml_edit::Item) -> anyhow::Result { Ok(utils::toml_edit_ext::deserialize_item(toml)?) @@ -219,7 +235,8 @@ timeout = '5s'"; storage: RemoteStorageKind::LocalFs { local_path: Utf8PathBuf::from(".") }, - timeout: Duration::from_secs(5) + timeout: Duration::from_secs(5), + small_timeout: RemoteStorageConfig::DEFAULT_SMALL_TIMEOUT } ); } @@ -247,7 +264,8 @@ timeout = '5s'"; max_keys_per_list_response: DEFAULT_MAX_KEYS_PER_LIST_RESPONSE, upload_storage_class: Some(StorageClass::IntelligentTiering), }), - timeout: Duration::from_secs(7) + timeout: Duration::from_secs(7), + small_timeout: RemoteStorageConfig::DEFAULT_SMALL_TIMEOUT } ); } @@ -299,7 +317,8 @@ timeout = '5s'"; concurrency_limit: default_remote_storage_azure_concurrency_limit(), max_keys_per_list_response: DEFAULT_MAX_KEYS_PER_LIST_RESPONSE, }), - timeout: Duration::from_secs(7) + timeout: Duration::from_secs(7), + small_timeout: RemoteStorageConfig::DEFAULT_SMALL_TIMEOUT } ); } diff --git a/libs/remote_storage/src/lib.rs b/libs/remote_storage/src/lib.rs index 719608dd5f..0ece29d99e 100644 --- a/libs/remote_storage/src/lib.rs +++ b/libs/remote_storage/src/lib.rs @@ -178,6 +178,15 @@ pub struct DownloadOpts { /// The end of the byte range to download, or unbounded. Must be after the /// start bound. 
pub byte_end: Bound, + /// Indicate whether we're downloading something small or large: this indirectly controls + /// timeouts: for something like an index/manifest/heatmap, we should time out faster than + /// for layer files + pub kind: DownloadKind, +} + +pub enum DownloadKind { + Large, + Small, } impl Default for DownloadOpts { @@ -186,6 +195,7 @@ impl Default for DownloadOpts { etag: Default::default(), byte_start: Bound::Unbounded, byte_end: Bound::Unbounded, + kind: DownloadKind::Large, } } } @@ -584,6 +594,10 @@ impl GenericRemoteStorage> { impl GenericRemoteStorage { pub async fn from_config(storage_config: &RemoteStorageConfig) -> anyhow::Result { let timeout = storage_config.timeout; + + // If someone overrides timeout to be small without adjusting small_timeout, then adjust it automatically + let small_timeout = std::cmp::min(storage_config.small_timeout, timeout); + Ok(match &storage_config.storage { RemoteStorageKind::LocalFs { local_path: path } => { info!("Using fs root '{path}' as a remote storage"); @@ -606,7 +620,11 @@ impl GenericRemoteStorage { .unwrap_or(""); info!("Using azure container '{}' in account '{storage_account}' in region '{}' as a remote storage, prefix in container: '{:?}'", azure_config.container_name, azure_config.container_region, azure_config.prefix_in_container); - Self::AzureBlob(Arc::new(AzureBlobStorage::new(azure_config, timeout)?)) + Self::AzureBlob(Arc::new(AzureBlobStorage::new( + azure_config, + timeout, + small_timeout, + )?)) } }) } diff --git a/libs/remote_storage/tests/test_real_azure.rs b/libs/remote_storage/tests/test_real_azure.rs index 3a20649490..92d579fec8 100644 --- a/libs/remote_storage/tests/test_real_azure.rs +++ b/libs/remote_storage/tests/test_real_azure.rs @@ -219,7 +219,8 @@ async fn create_azure_client( concurrency_limit: NonZeroUsize::new(100).unwrap(), max_keys_per_list_response, }), - timeout: Duration::from_secs(120), + timeout: RemoteStorageConfig::DEFAULT_TIMEOUT, + small_timeout:
RemoteStorageConfig::DEFAULT_SMALL_TIMEOUT, }; Ok(Arc::new( GenericRemoteStorage::from_config(&remote_storage_config) diff --git a/libs/remote_storage/tests/test_real_s3.rs b/libs/remote_storage/tests/test_real_s3.rs index 3e99a65fac..e60ec18c93 100644 --- a/libs/remote_storage/tests/test_real_s3.rs +++ b/libs/remote_storage/tests/test_real_s3.rs @@ -396,6 +396,7 @@ async fn create_s3_client( upload_storage_class: None, }), timeout: RemoteStorageConfig::DEFAULT_TIMEOUT, + small_timeout: RemoteStorageConfig::DEFAULT_SMALL_TIMEOUT, }; Ok(Arc::new( GenericRemoteStorage::from_config(&remote_storage_config) diff --git a/pageserver/src/deletion_queue.rs b/pageserver/src/deletion_queue.rs index e74c8ecf5a..1d508f5fe9 100644 --- a/pageserver/src/deletion_queue.rs +++ b/pageserver/src/deletion_queue.rs @@ -838,6 +838,7 @@ mod test { local_path: remote_fs_dir.clone(), }, timeout: RemoteStorageConfig::DEFAULT_TIMEOUT, + small_timeout: RemoteStorageConfig::DEFAULT_SMALL_TIMEOUT, }; let storage = GenericRemoteStorage::from_config(&storage_config) .await diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 339a3ca1bb..cd0690bb1a 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -5423,6 +5423,7 @@ pub(crate) mod harness { local_path: remote_fs_dir.clone(), }, timeout: RemoteStorageConfig::DEFAULT_TIMEOUT, + small_timeout: RemoteStorageConfig::DEFAULT_SMALL_TIMEOUT, }; let remote_storage = GenericRemoteStorage::from_config(&config).await.unwrap(); let deletion_queue = MockDeletionQueue::new(Some(remote_storage.clone())); diff --git a/pageserver/src/tenant/remote_timeline_client/download.rs b/pageserver/src/tenant/remote_timeline_client/download.rs index d632e595ad..739615be9c 100644 --- a/pageserver/src/tenant/remote_timeline_client/download.rs +++ b/pageserver/src/tenant/remote_timeline_client/download.rs @@ -30,7 +30,9 @@ use crate::tenant::Generation; use crate::virtual_file::owned_buffers_io::io_buf_ext::IoBufExt; use 
crate::virtual_file::{on_fatal_io_error, MaybeFatalIo, VirtualFile}; use crate::TEMP_FILE_SUFFIX; -use remote_storage::{DownloadError, DownloadOpts, GenericRemoteStorage, ListingMode, RemotePath}; +use remote_storage::{ + DownloadError, DownloadKind, DownloadOpts, GenericRemoteStorage, ListingMode, RemotePath, +}; use utils::crashsafe::path_with_suffix_extension; use utils::id::{TenantId, TimelineId}; use utils::pausable_failpoint; @@ -345,12 +347,13 @@ pub async fn list_remote_timelines( async fn do_download_remote_path_retry_forever( storage: &GenericRemoteStorage, remote_path: &RemotePath, + download_opts: DownloadOpts, cancel: &CancellationToken, ) -> Result<(Vec, SystemTime), DownloadError> { download_retry_forever( || async { let download = storage - .download(remote_path, &DownloadOpts::default(), cancel) + .download(remote_path, &download_opts, cancel) .await?; let mut bytes = Vec::new(); @@ -377,8 +380,13 @@ async fn do_download_tenant_manifest( ) -> Result<(TenantManifest, Generation, SystemTime), DownloadError> { let remote_path = remote_tenant_manifest_path(tenant_shard_id, generation); + let download_opts = DownloadOpts { + kind: DownloadKind::Small, + ..Default::default() + }; + let (manifest_bytes, manifest_bytes_mtime) = - do_download_remote_path_retry_forever(storage, &remote_path, cancel).await?; + do_download_remote_path_retry_forever(storage, &remote_path, download_opts, cancel).await?; let tenant_manifest = TenantManifest::from_json_bytes(&manifest_bytes) .with_context(|| format!("deserialize tenant manifest file at {remote_path:?}")) @@ -398,8 +406,13 @@ async fn do_download_index_part( timeline_id.expect("A timeline ID is always provided when downloading an index"); let remote_path = remote_index_path(tenant_shard_id, timeline_id, index_generation); + let download_opts = DownloadOpts { + kind: DownloadKind::Small, + ..Default::default() + }; + let (index_part_bytes, index_part_mtime) = - do_download_remote_path_retry_forever(storage, 
&remote_path, cancel).await?; + do_download_remote_path_retry_forever(storage, &remote_path, download_opts, cancel).await?; let index_part: IndexPart = serde_json::from_slice(&index_part_bytes) .with_context(|| format!("deserialize index part file at {remote_path:?}")) diff --git a/pageserver/src/tenant/secondary/downloader.rs b/pageserver/src/tenant/secondary/downloader.rs index 7443261a9c..8d771dc405 100644 --- a/pageserver/src/tenant/secondary/downloader.rs +++ b/pageserver/src/tenant/secondary/downloader.rs @@ -49,7 +49,7 @@ use futures::Future; use metrics::UIntGauge; use pageserver_api::models::SecondaryProgress; use pageserver_api::shard::TenantShardId; -use remote_storage::{DownloadError, DownloadOpts, Etag, GenericRemoteStorage}; +use remote_storage::{DownloadError, DownloadKind, DownloadOpts, Etag, GenericRemoteStorage}; use tokio_util::sync::CancellationToken; use tracing::{info_span, instrument, warn, Instrument}; @@ -946,6 +946,7 @@ impl<'a> TenantDownloader<'a> { let cancel = &self.secondary_state.cancel; let opts = DownloadOpts { etag: prev_etag.cloned(), + kind: DownloadKind::Small, ..Default::default() }; diff --git a/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs b/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs index 8d5ab1780f..bc4d148a29 100644 --- a/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs +++ b/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs @@ -4,7 +4,8 @@ use anyhow::Context; use bytes::Bytes; use postgres_ffi::ControlFileData; use remote_storage::{ - Download, DownloadError, DownloadOpts, GenericRemoteStorage, Listing, ListingObject, RemotePath, + Download, DownloadError, DownloadKind, DownloadOpts, GenericRemoteStorage, Listing, + ListingObject, RemotePath, }; use serde::de::DeserializeOwned; use tokio_util::sync::CancellationToken; @@ -239,6 +240,7 @@ impl RemoteStorageWrapper { .download( path, &DownloadOpts { + kind: DownloadKind::Large, etag: None, 
byte_start: Bound::Included(start_inclusive), byte_end: Bound::Excluded(end_exclusive) diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index e328c6de79..b375eb886e 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -486,6 +486,7 @@ mod tests { upload_storage_class: None, }), timeout: RemoteStorageConfig::DEFAULT_TIMEOUT, + small_timeout: RemoteStorageConfig::DEFAULT_SMALL_TIMEOUT, }) ); assert_eq!(parquet_upload.parquet_upload_row_group_size, 100); @@ -545,6 +546,7 @@ mod tests { local_path: tmpdir.to_path_buf(), }, timeout: std::time::Duration::from_secs(120), + small_timeout: std::time::Duration::from_secs(30), }; let storage = GenericRemoteStorage::from_config(&remote_storage_config) .await From c848f25ec25e04afba9f2b0509372504b35cafe9 Mon Sep 17 00:00:00 2001 From: Gleb Novikov Date: Fri, 29 Nov 2024 17:58:36 +0000 Subject: [PATCH 07/15] Fixed fast_import pgbin in calling get_pg_version (#9933) Was working on https://github.com/neondatabase/cloud/pull/20795 and discovered that fast_import is not working normally. --- compute_tools/src/bin/fast_import.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/compute_tools/src/bin/fast_import.rs b/compute_tools/src/bin/fast_import.rs index 6716cc6234..b6db3eb11a 100644 --- a/compute_tools/src/bin/fast_import.rs +++ b/compute_tools/src/bin/fast_import.rs @@ -21,7 +21,7 @@ //! - Build the image with the following command: //! //! ```bash -//! docker buildx build --build-arg DEBIAN_FLAVOR=bullseye-slim --build-arg GIT_VERSION=local --build-arg PG_VERSION=v14 --build-arg BUILD_TAG="$(date --iso-8601=s -u)" -t localhost:3030/localregistry/compute-node-v14:latest -f compute/Dockerfile.com +//! 
docker buildx build --platform linux/amd64 --build-arg DEBIAN_VERSION=bullseye --build-arg GIT_VERSION=local --build-arg PG_VERSION=v14 --build-arg BUILD_TAG="$(date --iso-8601=s -u)" -t localhost:3030/localregistry/compute-node-v14:latest -f compute/compute-node.Dockerfile . //! docker push localhost:3030/localregistry/compute-node-v14:latest //! ``` @@ -132,7 +132,8 @@ pub(crate) async fn main() -> anyhow::Result<()> { // // Initialize pgdata // - let pg_version = match get_pg_version(pg_bin_dir.as_str()) { + let pgbin = pg_bin_dir.join("postgres"); + let pg_version = match get_pg_version(pgbin.as_ref()) { PostgresMajorVersion::V14 => 14, PostgresMajorVersion::V15 => 15, PostgresMajorVersion::V16 => 16, @@ -155,7 +156,7 @@ pub(crate) async fn main() -> anyhow::Result<()> { // // Launch postgres process // - let mut postgres_proc = tokio::process::Command::new(pg_bin_dir.join("postgres")) + let mut postgres_proc = tokio::process::Command::new(pgbin) .arg("-D") .arg(&pgdata_dir) .args(["-c", "wal_level=minimal"]) From 973a8d2680f968e83e5668e69c87636189146e54 Mon Sep 17 00:00:00 2001 From: Matthias van de Meent Date: Fri, 29 Nov 2024 20:10:26 +0100 Subject: [PATCH 08/15] Fix timeout value used in XLogWaitForReplayOf (#9937) The previous value assumed usec precision, while the timeout used is in milliseconds, causing replica backends to wait for (potentially) many hours for WAL replay without the expected progress reports in logs. This fixes the issue. Reported-By: Alexander Lakhin ## Problem https://github.com/neondatabase/postgres/pull/279#issuecomment-2507671817 The timeout value was configured with the assumption the indicated value would be microseconds, where it's actually milliseconds. That causes the backend to wait for much longer (2h46m40s) before it emits the "I'm waiting for recovery" message. 
While we do have wait events configured on this, it's not great to have stuck backends without clear logs, so this fixes the timeout value in all our PostgreSQL branches. ## PG PRs * PG14: https://github.com/neondatabase/postgres/pull/542 * PG15: https://github.com/neondatabase/postgres/pull/543 * PG16: https://github.com/neondatabase/postgres/pull/544 * PG17: https://github.com/neondatabase/postgres/pull/545 --- vendor/postgres-v14 | 2 +- vendor/postgres-v15 | 2 +- vendor/postgres-v16 | 2 +- vendor/postgres-v17 | 2 +- vendor/revisions.json | 8 ++++---- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vendor/postgres-v14 b/vendor/postgres-v14 index 284ae56be2..c1989c934d 160000 --- a/vendor/postgres-v14 +++ b/vendor/postgres-v14 @@ -1 +1 @@ -Subproject commit 284ae56be2397fd3eaf20777fa220b2d0ad968f5 +Subproject commit c1989c934d46e04e78b3c496c8a34bcd40ddceeb diff --git a/vendor/postgres-v15 b/vendor/postgres-v15 index aed79ee87b..d929b9a8b9 160000 --- a/vendor/postgres-v15 +++ b/vendor/postgres-v15 @@ -1 +1 @@ -Subproject commit aed79ee87b94779cc52ec13e3b74eba6ada93f05 +Subproject commit d929b9a8b9f32f6fe5a0eac3e6e963f0e44e27e6 diff --git a/vendor/postgres-v16 b/vendor/postgres-v16 index f5cfc6fa89..13e9e35394 160000 --- a/vendor/postgres-v16 +++ b/vendor/postgres-v16 @@ -1 +1 @@ -Subproject commit f5cfc6fa898544050e821ac688adafece1ac3cff +Subproject commit 13e9e3539419003e79bd9aa29e1bc44f3fd555dd diff --git a/vendor/postgres-v17 b/vendor/postgres-v17 index 3c15b6565f..faebe5e5af 160000 --- a/vendor/postgres-v17 +++ b/vendor/postgres-v17 @@ -1 +1 @@ -Subproject commit 3c15b6565f6c8d36d169ed9ea7412cf90cfb2a8f +Subproject commit faebe5e5aff5687908504453623778f8515529db diff --git a/vendor/revisions.json b/vendor/revisions.json index 4dae88e73d..abeddcadf7 100644 --- a/vendor/revisions.json +++ b/vendor/revisions.json @@ -1,18 +1,18 @@ { "v17": [ "17.2", - "3c15b6565f6c8d36d169ed9ea7412cf90cfb2a8f" + "faebe5e5aff5687908504453623778f8515529db" ], "v16": [ 
"16.6", - "f5cfc6fa898544050e821ac688adafece1ac3cff" + "13e9e3539419003e79bd9aa29e1bc44f3fd555dd" ], "v15": [ "15.10", - "aed79ee87b94779cc52ec13e3b74eba6ada93f05" + "d929b9a8b9f32f6fe5a0eac3e6e963f0e44e27e6" ], "v14": [ "14.15", - "284ae56be2397fd3eaf20777fa220b2d0ad968f5" + "c1989c934d46e04e78b3c496c8a34bcd40ddceeb" ] } From aa4ec11af9c982a4022074f18a05745d91633bca Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Sat, 30 Nov 2024 01:16:24 +0100 Subject: [PATCH 09/15] page_service: rewrite batching to work without a timeout (#9851) # Problem The timeout-based batching adds latency to unbatchable workloads. We can choose a short batching timeout (e.g. 10us) but that requires high-resolution timers, which tokio doesn't have. I thoroughly explored options to use OS timers (see [this](https://github.com/neondatabase/neon/pull/9822) abandoned PR). In short, it's not an attractive option because any timer implementation adds non-trivial overheads. # Solution The insight is that, in the steady state of a batchable workload, the time we spend in `get_vectored` will be hundreds of microseconds anyway. If we prepare the next batch concurrently with `get_vectored`, we will have a sizeable batch ready once `get_vectored` of the current batch is done and do not need an explicit timeout. This can be reasonably described as **pipelining of the protocol handler**. # Implementation We model the sub-protocol handler for pagestream requests (`handle_pagerequests`) as two futures that form a pipeline: 1. Batching: read requests from the connection and fill the current batch 2. Execution: `take` the current batch, execute it using `get_vectored`, and send the response. The two stages are connected through a new type of channel called `spsc_fold`. See the long comment in `handle_pagerequests_pipelined` for details.
# Changes - Refactor `handle_pagerequests` - separate functions for - reading one protocol message; produces a `BatchedFeMessage` with just one page request in it - batching; tries to merge an incoming `BatchedFeMessage` into an existing `BatchedFeMessage`; returns `None` on success and hands the incoming message back in case merging isn't possible - execution of a batched message - unify the timeline handle acquisition & request span construction; it now happens in the function that reads the protocol message - Implement serial and pipelined model - serial: what we had before any of the batching changes - read one protocol message - execute protocol messages - pipelined: the design described above - optionality for execution of the pipeline: either via concurrent futures or via tokio tasks - Pageserver config - remove batching timeout field - add ability to configure pipelining mode - add ability to limit max batch size for pipelined configurations (required for the rollout, cf https://github.com/neondatabase/cloud/issues/20620 ) - ability to configure execution mode - Tests - remove `batch_timeout` parametrization - rename `test_getpage_merge_smoke` to `test_throughput` - add parametrization to test different max batch sizes and execution modes - rename `test_timer_precision` to `test_latency` - rename the test case file to `test_page_service_batching.py` - better descriptions of what the tests actually do ## On holding the `TimelineHandle` in the pending batch While batching, we hold the `TimelineHandle` in the pending batch. Therefore, the timeline will not finish shutting down while we're batching. This is not a problem in practice because the concurrently ongoing `get_vectored` call will fail quickly with an error indicating that the timeline is shutting down. This results in the Execution stage returning a `QueryError::Shutdown`, which causes the pipeline / entire page service connection to shut down.
This drops all references to the `Arc<Mutex<Option<Box<BatchedFeMessage>>>>` object, thereby dropping the contained `TimelineHandle`s. - => fixes https://github.com/neondatabase/neon/issues/9850 # Performance Local run of the benchmarks, results in [this empty commit](https://github.com/neondatabase/neon/pull/9851/commits/1cf5b1463f69ba5066cbb0713912aec7bb5579ad) in the PR branch. Key take-aways: * `concurrent-futures` and `tasks` deliver identical `batching_factor` * tail latency impact unknown, cf https://github.com/neondatabase/neon/issues/9837 * `concurrent-futures` has higher throughput than `tasks` in all workloads (=lower `time` metric) * In unbatchable workloads, `concurrent-futures` has 5% higher `CPU-per-throughput` than that of `tasks`, and 15% higher than that of `serial`. * In batchable-32 workload, `concurrent-futures` has 8% lower `CPU-per-throughput` than that of `tasks` (comparison to tput of `serial` is irrelevant) * in unbatchable workloads, mean and tail latencies of `concurrent-futures` are practically identical to `serial`, whereas `tasks` adds 20-30us of overhead Overall, `concurrent-futures` seems like a slightly more attractive choice. # Rollout This change is disabled-by-default.
Rollout plan: - https://github.com/neondatabase/cloud/issues/20620 # Refs - epic: https://github.com/neondatabase/neon/issues/9376 - this sub-task: https://github.com/neondatabase/neon/issues/9377 - the abandoned attempt to improve batching timeout resolution: https://github.com/neondatabase/neon/pull/9820 - closes https://github.com/neondatabase/neon/issues/9850 - fixes https://github.com/neondatabase/neon/issues/9835 --- Cargo.lock | 10 +- Cargo.toml | 1 + libs/pageserver_api/src/config.rs | 30 +- libs/utils/Cargo.toml | 2 + libs/utils/src/sync.rs | 2 + libs/utils/src/sync/spsc_fold.rs | 452 +++++++ pageserver/src/config.rs | 10 +- pageserver/src/lib.rs | 19 + pageserver/src/page_service.rs | 1059 ++++++++++------- test_runner/fixtures/neon_fixtures.py | 3 +- ...merge.py => test_page_service_batching.py} | 131 +- 11 files changed, 1262 insertions(+), 457 deletions(-) create mode 100644 libs/utils/src/sync/spsc_fold.rs rename test_runner/performance/pageserver/{test_pageserver_getpage_merge.py => test_page_service_batching.py} (69%) diff --git a/Cargo.lock b/Cargo.lock index abe69525c9..313222cf3c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. 
-version = 3 +version = 4 [[package]] name = "RustyXML" @@ -1717,6 +1717,12 @@ dependencies = [ "utils", ] +[[package]] +name = "diatomic-waker" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab03c107fafeb3ee9f5925686dbb7a73bc76e3932abb0d2b365cb64b169cf04c" + [[package]] name = "diesel" version = "2.2.3" @@ -7045,6 +7051,7 @@ dependencies = [ "chrono", "const_format", "criterion", + "diatomic-waker", "fail", "futures", "git-version", @@ -7063,6 +7070,7 @@ dependencies = [ "rand 0.8.5", "regex", "routerify", + "scopeguard", "sentry", "serde", "serde_assert", diff --git a/Cargo.toml b/Cargo.toml index 742201d0f5..64c384f17a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -83,6 +83,7 @@ comfy-table = "7.1" const_format = "0.2" crc32c = "0.6" dashmap = { version = "5.5.0", features = ["raw-api"] } +diatomic-waker = { version = "0.2.3" } either = "1.8" enum-map = "2.4.2" enumset = "1.0.12" diff --git a/libs/pageserver_api/src/config.rs b/libs/pageserver_api/src/config.rs index 721d97404b..e49d15ba87 100644 --- a/libs/pageserver_api/src/config.rs +++ b/libs/pageserver_api/src/config.rs @@ -118,9 +118,8 @@ pub struct ConfigToml { pub virtual_file_io_mode: Option, #[serde(skip_serializing_if = "Option::is_none")] pub no_sync: Option, - #[serde(with = "humantime_serde")] - pub server_side_batch_timeout: Option, pub wal_receiver_protocol: PostgresClientProtocol, + pub page_service_pipelining: PageServicePipeliningConfig, } #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -137,6 +136,28 @@ pub struct DiskUsageEvictionTaskConfig { pub eviction_order: EvictionOrder, } +#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[serde(tag = "mode", rename_all = "kebab-case")] +#[serde(deny_unknown_fields)] +pub enum PageServicePipeliningConfig { + Serial, + Pipelined(PageServicePipeliningConfigPipelined), +} +#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, 
serde::Deserialize)] +#[serde(deny_unknown_fields)] +pub struct PageServicePipeliningConfigPipelined { + /// Causes runtime errors if larger than max get_vectored batch size. + pub max_batch_size: NonZeroUsize, + pub execution: PageServiceProtocolPipelinedExecutionStrategy, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum PageServiceProtocolPipelinedExecutionStrategy { + ConcurrentFutures, + Tasks, +} + pub mod statvfs { pub mod mock { #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -332,8 +353,6 @@ pub mod defaults { pub const DEFAULT_IO_BUFFER_ALIGNMENT: usize = 512; - pub const DEFAULT_SERVER_SIDE_BATCH_TIMEOUT: Option<&str> = None; - pub const DEFAULT_WAL_RECEIVER_PROTOCOL: utils::postgres_client::PostgresClientProtocol = utils::postgres_client::PostgresClientProtocol::Vanilla; } @@ -420,11 +439,10 @@ impl Default for ConfigToml { ephemeral_bytes_per_memory_kb: (DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB), l0_flush: None, virtual_file_io_mode: None, - server_side_batch_timeout: DEFAULT_SERVER_SIDE_BATCH_TIMEOUT - .map(|duration| humantime::parse_duration(duration).unwrap()), tenant_config: TenantConfigToml::default(), no_sync: None, wal_receiver_protocol: DEFAULT_WAL_RECEIVER_PROTOCOL, + page_service_pipelining: PageServicePipeliningConfig::Serial, } } } diff --git a/libs/utils/Cargo.toml b/libs/utils/Cargo.toml index f440b81d8f..5648072a83 100644 --- a/libs/utils/Cargo.toml +++ b/libs/utils/Cargo.toml @@ -19,6 +19,7 @@ bincode.workspace = true bytes.workspace = true camino.workspace = true chrono.workspace = true +diatomic-waker.workspace = true git-version.workspace = true hex = { workspace = true, features = ["serde"] } humantime.workspace = true @@ -45,6 +46,7 @@ tracing.workspace = true tracing-error.workspace = true tracing-subscriber = { workspace = true, features = ["json", "registry"] } rand.workspace = true +scopeguard.workspace = true 
strum.workspace = true strum_macros.workspace = true url.workspace = true diff --git a/libs/utils/src/sync.rs b/libs/utils/src/sync.rs index 2ee8f35449..7aa26e24bc 100644 --- a/libs/utils/src/sync.rs +++ b/libs/utils/src/sync.rs @@ -1,3 +1,5 @@ pub mod heavier_once_cell; pub mod gate; + +pub mod spsc_fold; diff --git a/libs/utils/src/sync/spsc_fold.rs b/libs/utils/src/sync/spsc_fold.rs new file mode 100644 index 0000000000..b44f766ef0 --- /dev/null +++ b/libs/utils/src/sync/spsc_fold.rs @@ -0,0 +1,452 @@ +use core::{future::poll_fn, task::Poll}; +use std::sync::{Arc, Mutex}; + +use diatomic_waker::DiatomicWaker; + +pub struct Sender { + state: Arc>, +} + +pub struct Receiver { + state: Arc>, +} + +struct Inner { + wake_receiver: DiatomicWaker, + wake_sender: DiatomicWaker, + value: Mutex>, +} + +enum State { + NoData, + HasData(T), + TryFoldFailed, // transient state + SenderWaitsForReceiverToConsume(T), + SenderGone(Option), + ReceiverGone, + AllGone, + SenderDropping, // transient state + ReceiverDropping, // transient state +} + +pub fn channel() -> (Sender, Receiver) { + let inner = Inner { + wake_receiver: DiatomicWaker::new(), + wake_sender: DiatomicWaker::new(), + value: Mutex::new(State::NoData), + }; + + let state = Arc::new(inner); + ( + Sender { + state: state.clone(), + }, + Receiver { state }, + ) +} + +#[derive(Debug, thiserror::Error)] +pub enum SendError { + #[error("receiver is gone")] + ReceiverGone, +} + +impl Sender { + /// # Panics + /// + /// If `try_fold` panics, any subsequent call to `send` panic. 
+ pub async fn send(&mut self, value: T, try_fold: F) -> Result<(), SendError> + where + F: Fn(&mut T, T) -> Result<(), T>, + { + let mut value = Some(value); + poll_fn(|cx| { + let mut guard = self.state.value.lock().unwrap(); + match &mut *guard { + State::NoData => { + *guard = State::HasData(value.take().unwrap()); + self.state.wake_receiver.notify(); + Poll::Ready(Ok(())) + } + State::HasData(_) => { + let State::HasData(acc_mut) = &mut *guard else { + unreachable!("this match arm guarantees that the guard is HasData"); + }; + match try_fold(acc_mut, value.take().unwrap()) { + Ok(()) => { + // no need to wake receiver, if it was waiting it already + // got a wake-up when we transitioned from NoData to HasData + Poll::Ready(Ok(())) + } + Err(unfoldable_value) => { + value = Some(unfoldable_value); + let State::HasData(acc) = + std::mem::replace(&mut *guard, State::TryFoldFailed) + else { + unreachable!("this match arm guarantees that the guard is HasData"); + }; + *guard = State::SenderWaitsForReceiverToConsume(acc); + // SAFETY: send is single threaded due to `&mut self` requirement, + // therefore register is not concurrent. + unsafe { + self.state.wake_sender.register(cx.waker()); + } + Poll::Pending + } + } + } + State::SenderWaitsForReceiverToConsume(_data) => { + // Really, we shouldn't be polled until receiver has consumed and wakes us. + Poll::Pending + } + State::ReceiverGone => Poll::Ready(Err(SendError::ReceiverGone)), + State::SenderGone(_) + | State::AllGone + | State::SenderDropping + | State::ReceiverDropping + | State::TryFoldFailed => { + unreachable!(); + } + } + }) + .await + } +} + +impl Drop for Sender { + fn drop(&mut self) { + scopeguard::defer! 
{ + self.state.wake_receiver.notify() + }; + let Ok(mut guard) = self.state.value.lock() else { + return; + }; + *guard = match std::mem::replace(&mut *guard, State::SenderDropping) { + State::NoData => State::SenderGone(None), + State::HasData(data) | State::SenderWaitsForReceiverToConsume(data) => { + State::SenderGone(Some(data)) + } + State::ReceiverGone => State::AllGone, + State::TryFoldFailed + | State::SenderGone(_) + | State::AllGone + | State::SenderDropping + | State::ReceiverDropping => { + unreachable!("unreachable state {:?}", guard.discriminant_str()) + } + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum RecvError { + #[error("sender is gone")] + SenderGone, +} + +impl Receiver { + pub async fn recv(&mut self) -> Result { + poll_fn(|cx| { + let mut guard = self.state.value.lock().unwrap(); + match &mut *guard { + State::NoData => { + // SAFETY: recv is single threaded due to `&mut self` requirement, + // therefore register is not concurrent. + unsafe { + self.state.wake_receiver.register(cx.waker()); + } + Poll::Pending + } + guard @ State::HasData(_) + | guard @ State::SenderWaitsForReceiverToConsume(_) + | guard @ State::SenderGone(Some(_)) => { + let data = guard + .take_data() + .expect("in these states, data is guaranteed to be present"); + self.state.wake_sender.notify(); + Poll::Ready(Ok(data)) + } + State::SenderGone(None) => Poll::Ready(Err(RecvError::SenderGone)), + State::ReceiverGone + | State::AllGone + | State::SenderDropping + | State::ReceiverDropping + | State::TryFoldFailed => { + unreachable!("unreachable state {:?}", guard.discriminant_str()); + } + } + }) + .await + } +} + +impl Drop for Receiver { + fn drop(&mut self) { + scopeguard::defer! 
{ + self.state.wake_sender.notify() + }; + let Ok(mut guard) = self.state.value.lock() else { + return; + }; + *guard = match std::mem::replace(&mut *guard, State::ReceiverDropping) { + State::NoData => State::ReceiverGone, + State::HasData(_) | State::SenderWaitsForReceiverToConsume(_) => State::ReceiverGone, + State::SenderGone(_) => State::AllGone, + State::TryFoldFailed + | State::ReceiverGone + | State::AllGone + | State::SenderDropping + | State::ReceiverDropping => { + unreachable!("unreachable state {:?}", guard.discriminant_str()) + } + } + } +} + +impl State { + fn take_data(&mut self) -> Option { + match self { + State::HasData(_) => { + let State::HasData(data) = std::mem::replace(self, State::NoData) else { + unreachable!("this match arm guarantees that the state is HasData"); + }; + Some(data) + } + State::SenderWaitsForReceiverToConsume(_) => { + let State::SenderWaitsForReceiverToConsume(data) = + std::mem::replace(self, State::NoData) + else { + unreachable!( + "this match arm guarantees that the state is SenderWaitsForReceiverToConsume" + ); + }; + Some(data) + } + State::SenderGone(data) => Some(data.take().unwrap()), + State::NoData + | State::TryFoldFailed + | State::ReceiverGone + | State::AllGone + | State::SenderDropping + | State::ReceiverDropping => None, + } + } + fn discriminant_str(&self) -> &'static str { + match self { + State::NoData => "NoData", + State::HasData(_) => "HasData", + State::TryFoldFailed => "TryFoldFailed", + State::SenderWaitsForReceiverToConsume(_) => "SenderWaitsForReceiverToConsume", + State::SenderGone(_) => "SenderGone", + State::ReceiverGone => "ReceiverGone", + State::AllGone => "AllGone", + State::SenderDropping => "SenderDropping", + State::ReceiverDropping => "ReceiverDropping", + } + } +} + +#[cfg(test)] +mod tests { + + use super::*; + + const FOREVER: std::time::Duration = std::time::Duration::from_secs(u64::MAX); + + #[tokio::test] + async fn test_send_recv() { + let (mut sender, mut receiver) = 
channel(); + + sender + .send(42, |acc, val| { + *acc += val; + Ok(()) + }) + .await + .unwrap(); + + let received = receiver.recv().await.unwrap(); + assert_eq!(received, 42); + } + + #[tokio::test] + async fn test_send_recv_with_fold() { + let (mut sender, mut receiver) = channel(); + + sender + .send(1, |acc, val| { + *acc += val; + Ok(()) + }) + .await + .unwrap(); + sender + .send(2, |acc, val| { + *acc += val; + Ok(()) + }) + .await + .unwrap(); + + let received = receiver.recv().await.unwrap(); + assert_eq!(received, 3); + } + + #[tokio::test(start_paused = true)] + async fn test_sender_waits_for_receiver_if_try_fold_fails() { + let (mut sender, mut receiver) = channel(); + + sender.send(23, |_, _| panic!("first send")).await.unwrap(); + + let send_fut = sender.send(42, |_, val| Err(val)); + let mut send_fut = std::pin::pin!(send_fut); + + tokio::select! { + _ = tokio::time::sleep(FOREVER) => {}, + _ = &mut send_fut => { + panic!("send should not complete"); + }, + } + + let val = receiver.recv().await.unwrap(); + assert_eq!(val, 23); + + tokio::select! 
{ + _ = tokio::time::sleep(FOREVER) => { + panic!("receiver should have consumed the value"); + }, + _ = &mut send_fut => { }, + } + + let val = receiver.recv().await.unwrap(); + assert_eq!(val, 42); + } + + #[tokio::test(start_paused = true)] + async fn test_sender_errors_if_waits_for_receiver_and_receiver_drops() { + let (mut sender, receiver) = channel(); + + sender.send(23, |_, _| unreachable!()).await.unwrap(); + + let send_fut = sender.send(42, |_, val| Err(val)); + let send_fut = std::pin::pin!(send_fut); + + drop(receiver); + + let result = send_fut.await; + assert!(matches!(result, Err(SendError::ReceiverGone))); + } + + #[tokio::test(start_paused = true)] + async fn test_receiver_errors_if_waits_for_sender_and_sender_drops() { + let (sender, mut receiver) = channel::<()>(); + + let recv_fut = receiver.recv(); + let recv_fut = std::pin::pin!(recv_fut); + + drop(sender); + + let result = recv_fut.await; + assert!(matches!(result, Err(RecvError::SenderGone))); + } + + #[tokio::test(start_paused = true)] + async fn test_receiver_errors_if_waits_for_sender_and_sender_drops_with_data() { + let (mut sender, mut receiver) = channel(); + + sender.send(42, |_, _| unreachable!()).await.unwrap(); + + { + let recv_fut = receiver.recv(); + let recv_fut = std::pin::pin!(recv_fut); + + drop(sender); + + let val = recv_fut.await.unwrap(); + assert_eq!(val, 42); + } + + let result = receiver.recv().await; + assert!(matches!(result, Err(RecvError::SenderGone))); + } + + #[tokio::test(start_paused = true)] + async fn test_receiver_waits_for_sender_if_no_data() { + let (mut sender, mut receiver) = channel(); + + let recv_fut = receiver.recv(); + let mut recv_fut = std::pin::pin!(recv_fut); + + tokio::select! 
{ + _ = tokio::time::sleep(FOREVER) => {}, + _ = &mut recv_fut => { + panic!("recv should not complete"); + }, + } + + sender.send(42, |_, _| Ok(())).await.unwrap(); + + let val = recv_fut.await.unwrap(); + assert_eq!(val, 42); + } + + #[tokio::test] + async fn test_receiver_gone_while_nodata() { + let (mut sender, receiver) = channel(); + drop(receiver); + + let result = sender.send(42, |_, _| Ok(())).await; + assert!(matches!(result, Err(SendError::ReceiverGone))); + } + + #[tokio::test] + async fn test_sender_gone_while_nodata() { + let (sender, mut receiver) = super::channel::(); + drop(sender); + + let result = receiver.recv().await; + assert!(matches!(result, Err(RecvError::SenderGone))); + } + + #[tokio::test(start_paused = true)] + async fn test_receiver_drops_after_sender_went_to_sleep() { + let (mut sender, receiver) = channel(); + let state = receiver.state.clone(); + + sender.send(23, |_, _| unreachable!()).await.unwrap(); + + let send_task = tokio::spawn(async move { sender.send(42, |_, v| Err(v)).await }); + + tokio::time::sleep(FOREVER).await; + + assert!(matches!( + &*state.value.lock().unwrap(), + &State::SenderWaitsForReceiverToConsume(_) + )); + + drop(receiver); + + let err = send_task + .await + .unwrap() + .expect_err("should unblock immediately"); + assert!(matches!(err, SendError::ReceiverGone)); + } + + #[tokio::test(start_paused = true)] + async fn test_sender_drops_after_receiver_went_to_sleep() { + let (sender, mut receiver) = channel::(); + let state = sender.state.clone(); + + let recv_task = tokio::spawn(async move { receiver.recv().await }); + + tokio::time::sleep(FOREVER).await; + + assert!(matches!(&*state.value.lock().unwrap(), &State::NoData)); + + drop(sender); + + let err = recv_task.await.unwrap().expect_err("should error"); + assert!(matches!(err, RecvError::SenderGone)); + } +} diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index 2cf237e72b..1651db8500 100644 --- a/pageserver/src/config.rs +++ 
b/pageserver/src/config.rs @@ -188,11 +188,9 @@ pub struct PageServerConf { /// Optionally disable disk syncs (unsafe!) pub no_sync: bool, - /// Maximum amount of time for which a get page request request - /// might be held up for request merging. - pub server_side_batch_timeout: Option, - pub wal_receiver_protocol: PostgresClientProtocol, + + pub page_service_pipelining: pageserver_api::config::PageServicePipeliningConfig, } /// Token for authentication to safekeepers @@ -350,10 +348,10 @@ impl PageServerConf { concurrent_tenant_warmup, concurrent_tenant_size_logical_size_queries, virtual_file_io_engine, - server_side_batch_timeout, tenant_config, no_sync, wal_receiver_protocol, + page_service_pipelining, } = config_toml; let mut conf = PageServerConf { @@ -393,11 +391,11 @@ impl PageServerConf { image_compression, timeline_offloading, ephemeral_bytes_per_memory_kb, - server_side_batch_timeout, import_pgdata_upcall_api, import_pgdata_upcall_api_token: import_pgdata_upcall_api_token.map(SecretString::from), import_pgdata_aws_endpoint_url, wal_receiver_protocol, + page_service_pipelining, // ------------------------------------------------------------ // fields that require additional validation or custom handling diff --git a/pageserver/src/lib.rs b/pageserver/src/lib.rs index ef6711397a..ff6af3566c 100644 --- a/pageserver/src/lib.rs +++ b/pageserver/src/lib.rs @@ -356,6 +356,25 @@ async fn timed( } } +/// Like [`timed`], but the warning timeout only starts after `cancel` has been cancelled. +async fn timed_after_cancellation( + fut: Fut, + name: &str, + warn_at: std::time::Duration, + cancel: &CancellationToken, +) -> ::Output { + let mut fut = std::pin::pin!(fut); + + tokio::select! 
{ + _ = cancel.cancelled() => { + timed(fut, name, warn_at).await + } + ret = &mut fut => { + ret + } + } +} + #[cfg(test)] mod timed_tests { use super::timed; diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 5fd02d8749..1917e7f5b7 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -7,6 +7,10 @@ use bytes::Buf; use futures::FutureExt; use itertools::Itertools; use once_cell::sync::OnceCell; +use pageserver_api::config::{ + PageServicePipeliningConfig, PageServicePipeliningConfigPipelined, + PageServiceProtocolPipelinedExecutionStrategy, +}; use pageserver_api::models::{self, TenantState}; use pageserver_api::models::{ PagestreamBeMessage, PagestreamDbSizeRequest, PagestreamDbSizeResponse, @@ -16,12 +20,15 @@ use pageserver_api::models::{ PagestreamProtocolVersion, }; use pageserver_api::shard::TenantShardId; -use postgres_backend::{is_expected_io_error, AuthType, PostgresBackend, QueryError}; +use postgres_backend::{ + is_expected_io_error, AuthType, PostgresBackend, PostgresBackendReader, QueryError, +}; use pq_proto::framed::ConnectionError; use pq_proto::FeStartupPacket; use pq_proto::{BeMessage, FeMessage, RowDescriptor}; use std::borrow::Cow; use std::io; +use std::num::NonZeroUsize; use std::str; use std::str::FromStr; use std::sync::Arc; @@ -32,6 +39,7 @@ use tokio::io::{AsyncWriteExt, BufWriter}; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use tracing::*; +use utils::sync::spsc_fold; use utils::{ auth::{Claims, Scope, SwappableJwtAuth}, id::{TenantId, TimelineId}, @@ -40,7 +48,6 @@ use utils::{ }; use crate::auth::check_permission; -use crate::basebackup; use crate::basebackup::BasebackupError; use crate::config::PageServerConf; use crate::context::{DownloadBehavior, RequestContext}; @@ -58,6 +65,7 @@ use crate::tenant::timeline::{self, WaitLsnError}; use crate::tenant::GetTimelineError; use crate::tenant::PageReconstructError; use crate::tenant::Timeline; +use 
crate::{basebackup, timed_after_cancellation}; use pageserver_api::key::rel_block_to_key; use pageserver_api::reltag::{BlockNumber, RelTag, SlruKind}; use postgres_ffi::pg_constants::DEFAULTTABLESPACE_OID; @@ -105,7 +113,7 @@ pub fn spawn( pg_auth, tcp_listener, conf.pg_auth_type, - conf.server_side_batch_timeout, + conf.page_service_pipelining.clone(), libpq_ctx, cancel.clone(), ) @@ -154,7 +162,7 @@ pub async fn libpq_listener_main( auth: Option>, listener: tokio::net::TcpListener, auth_type: AuthType, - server_side_batch_timeout: Option, + pipelining_config: PageServicePipeliningConfig, listener_ctx: RequestContext, listener_cancel: CancellationToken, ) -> Connections { @@ -185,7 +193,7 @@ pub async fn libpq_listener_main( local_auth, socket, auth_type, - server_side_batch_timeout, + pipelining_config.clone(), connection_ctx, connections_cancel.child_token(), )); @@ -213,7 +221,7 @@ async fn page_service_conn_main( auth: Option>, socket: tokio::net::TcpStream, auth_type: AuthType, - server_side_batch_timeout: Option, + pipelining_config: PageServicePipeliningConfig, connection_ctx: RequestContext, cancel: CancellationToken, ) -> ConnectionHandlerResult { @@ -256,7 +264,7 @@ async fn page_service_conn_main( // a while: we will tear down this PageServerHandler and instantiate a new one if/when // they reconnect. 
socket.set_timeout(Some(std::time::Duration::from_millis(socket_timeout_ms))); - let socket = std::pin::pin!(socket); + let socket = Box::pin(socket); fail::fail_point!("ps::connection-start::pre-login"); @@ -267,7 +275,7 @@ async fn page_service_conn_main( let mut conn_handler = PageServerHandler::new( tenant_manager, auth, - server_side_batch_timeout, + pipelining_config, connection_ctx, cancel.clone(), ); @@ -283,7 +291,7 @@ async fn page_service_conn_main( info!("Postgres client disconnected ({io_error})"); Ok(()) } else { - let tenant_id = conn_handler.timeline_handles.tenant_id(); + let tenant_id = conn_handler.timeline_handles.as_ref().unwrap().tenant_id(); Err(io_error).context(format!( "Postgres connection error for tenant_id={:?} client at peer_addr={}", tenant_id, peer_addr @@ -291,7 +299,7 @@ async fn page_service_conn_main( } } other => { - let tenant_id = conn_handler.timeline_handles.tenant_id(); + let tenant_id = conn_handler.timeline_handles.as_ref().unwrap().tenant_id(); other.context(format!( "Postgres query error for tenant_id={:?} client peer_addr={}", tenant_id, peer_addr @@ -312,13 +320,10 @@ struct PageServerHandler { cancel: CancellationToken, - timeline_handles: TimelineHandles, + /// None only while pagestream protocol is being processed. 
+ timeline_handles: Option, - /// Messages queued up for the next processing batch - next_batch: Option, - - /// See [`PageServerConf::server_side_batch_timeout`] - server_side_batch_timeout: Option, + pipelining_config: PageServicePipeliningConfig, } struct TimelineHandles { @@ -535,10 +540,12 @@ impl From for QueryError { enum BatchedFeMessage { Exists { span: Span, + shard: timeline::handle::Handle, req: models::PagestreamExistsRequest, }, Nblocks { span: Span, + shard: timeline::handle::Handle, req: models::PagestreamNblocksRequest, }, GetPage { @@ -549,10 +556,12 @@ enum BatchedFeMessage { }, DbSize { span: Span, + shard: timeline::handle::Handle, req: models::PagestreamDbSizeRequest, }, GetSlruSegment { span: Span, + shard: timeline::handle::Handle, req: models::PagestreamGetSlruSegmentRequest, }, RespondError { @@ -561,18 +570,11 @@ enum BatchedFeMessage { }, } -enum BatchOrEof { - /// In the common case, this has one entry. - /// At most, it has two entries: the first is the leftover batch, the second is an error. 
- Batch(smallvec::SmallVec<[BatchedFeMessage; 1]>), - Eof, -} - impl PageServerHandler { pub fn new( tenant_manager: Arc, auth: Option>, - server_side_batch_timeout: Option, + pipelining_config: PageServicePipeliningConfig, connection_ctx: RequestContext, cancel: CancellationToken, ) -> Self { @@ -580,10 +582,9 @@ impl PageServerHandler { auth, claims: None, connection_ctx, - timeline_handles: TimelineHandles::new(tenant_manager), + timeline_handles: Some(TimelineHandles::new(tenant_manager)), cancel, - next_batch: None, - server_side_batch_timeout, + pipelining_config, } } @@ -611,219 +612,356 @@ impl PageServerHandler { ) } - async fn read_batch_from_connection( - &mut self, - pgb: &mut PostgresBackend, - tenant_id: &TenantId, - timeline_id: &TimelineId, + async fn pagestream_read_message( + pgb: &mut PostgresBackendReader, + tenant_id: TenantId, + timeline_id: TimelineId, + timeline_handles: &mut TimelineHandles, + cancel: &CancellationToken, ctx: &RequestContext, - ) -> Result, QueryError> + parent_span: Span, + ) -> Result, QueryError> + where + IO: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, + { + let msg = tokio::select! { + biased; + _ = cancel.cancelled() => { + return Err(QueryError::Shutdown) + } + msg = pgb.read_message() => { msg } + }; + + let copy_data_bytes = match msg? 
{ + Some(FeMessage::CopyData(bytes)) => bytes, + Some(FeMessage::Terminate) => { + return Ok(None); + } + Some(m) => { + return Err(QueryError::Other(anyhow::anyhow!( + "unexpected message: {m:?} during COPY" + ))); + } + None => { + return Ok(None); + } // client disconnected + }; + trace!("query: {copy_data_bytes:?}"); + + fail::fail_point!("ps::handle-pagerequest-message"); + + // parse request + let neon_fe_msg = PagestreamFeMessage::parse(&mut copy_data_bytes.reader())?; + + let batched_msg = match neon_fe_msg { + PagestreamFeMessage::Exists(req) => { + let span = tracing::info_span!(parent: parent_span, "handle_get_rel_exists_request", rel = %req.rel, req_lsn = %req.request_lsn); + let shard = timeline_handles + .get(tenant_id, timeline_id, ShardSelector::Zero) + .instrument(span.clone()) // sets `shard_id` field + .await?; + BatchedFeMessage::Exists { span, shard, req } + } + PagestreamFeMessage::Nblocks(req) => { + let span = tracing::info_span!(parent: parent_span, "handle_get_nblocks_request", rel = %req.rel, req_lsn = %req.request_lsn); + let shard = timeline_handles + .get(tenant_id, timeline_id, ShardSelector::Zero) + .instrument(span.clone()) // sets `shard_id` field + .await?; + BatchedFeMessage::Nblocks { span, shard, req } + } + PagestreamFeMessage::DbSize(req) => { + let span = tracing::info_span!(parent: parent_span, "handle_db_size_request", dbnode = %req.dbnode, req_lsn = %req.request_lsn); + let shard = timeline_handles + .get(tenant_id, timeline_id, ShardSelector::Zero) + .instrument(span.clone()) // sets `shard_id` field + .await?; + BatchedFeMessage::DbSize { span, shard, req } + } + PagestreamFeMessage::GetSlruSegment(req) => { + let span = tracing::info_span!(parent: parent_span, "handle_get_slru_segment_request", kind = %req.kind, segno = %req.segno, req_lsn = %req.request_lsn); + let shard = timeline_handles + .get(tenant_id, timeline_id, ShardSelector::Zero) + .instrument(span.clone()) // sets `shard_id` field + .await?; + 
BatchedFeMessage::GetSlruSegment { span, shard, req } + } + PagestreamFeMessage::GetPage(PagestreamGetPageRequest { + request_lsn, + not_modified_since, + rel, + blkno, + }) => { + let span = tracing::info_span!(parent: parent_span, "handle_get_page_at_lsn_request_batched", req_lsn = %request_lsn); + + macro_rules! respond_error { + ($error:expr) => {{ + let error = BatchedFeMessage::RespondError { + span, + error: $error, + }; + Ok(Some(error)) + }}; + } + + let key = rel_block_to_key(rel, blkno); + let shard = match timeline_handles + .get(tenant_id, timeline_id, ShardSelector::Page(key)) + .instrument(span.clone()) // sets `shard_id` field + .await + { + Ok(tl) => tl, + Err(GetActiveTimelineError::Tenant(GetActiveTenantError::NotFound(_))) => { + // We already know this tenant exists in general, because we resolved it at + // start of connection. Getting a NotFound here indicates that the shard containing + // the requested page is not present on this node: the client's knowledge of shard->pageserver + // mapping is out of date. + // + // Closing the connection by returning ``::Reconnect` has the side effect of rate-limiting above message, via + // client's reconnect backoff, as well as hopefully prompting the client to load its updated configuration + // and talk to a different pageserver. 
+ return respond_error!(PageStreamError::Reconnect( + "getpage@lsn request routed to wrong shard".into() + )); + } + Err(e) => { + return respond_error!(e.into()); + } + }; + let effective_request_lsn = match Self::wait_or_get_last_lsn( + &shard, + request_lsn, + not_modified_since, + &shard.get_latest_gc_cutoff_lsn(), + ctx, + ) + // TODO: if we actually need to wait for lsn here, it delays the entire batch which doesn't need to wait + .await + { + Ok(lsn) => lsn, + Err(e) => { + return respond_error!(e); + } + }; + BatchedFeMessage::GetPage { + span, + shard, + effective_request_lsn, + pages: smallvec::smallvec![(rel, blkno)], + } + } + }; + Ok(Some(batched_msg)) + } + + /// Post-condition: `batch` is Some() + #[instrument(skip_all, level = tracing::Level::TRACE)] + #[allow(clippy::boxed_local)] + fn pagestream_do_batch( + max_batch_size: NonZeroUsize, + batch: &mut Result, + this_msg: Result, + ) -> Result<(), Result> { + debug_assert_current_span_has_tenant_and_timeline_id_no_shard_id(); + + let this_msg = match this_msg { + Ok(this_msg) => this_msg, + Err(e) => return Err(Err(e)), + }; + + match (&mut *batch, this_msg) { + // something batched already, let's see if we can add this message to the batch + ( + Ok(BatchedFeMessage::GetPage { + span: _, + shard: accum_shard, + pages: ref mut accum_pages, + effective_request_lsn: accum_lsn, + }), + BatchedFeMessage::GetPage { + span: _, + shard: this_shard, + pages: this_pages, + effective_request_lsn: this_lsn, + }, + ) if (|| { + assert_eq!(this_pages.len(), 1); + if accum_pages.len() >= max_batch_size.get() { + trace!(%accum_lsn, %this_lsn, %max_batch_size, "stopping batching because of batch size"); + assert_eq!(accum_pages.len(), max_batch_size.get()); + return false; + } + if (accum_shard.tenant_shard_id, accum_shard.timeline_id) + != (this_shard.tenant_shard_id, this_shard.timeline_id) + { + trace!(%accum_lsn, %this_lsn, "stopping batching because timeline object mismatch"); + // TODO: we _could_ batch & 
execute each shard separately (and in parallel). + // But the current logic for keeping responses in order does not support that. + return false; + } + // the vectored get currently only supports a single LSN, so, bounce as soon + // as the effective request_lsn changes + if *accum_lsn != this_lsn { + trace!(%accum_lsn, %this_lsn, "stopping batching because LSN changed"); + return false; + } + true + })() => + { + // ok to batch + accum_pages.extend(this_pages); + Ok(()) + } + // something batched already but this message is unbatchable + (_, this_msg) => { + // by default, don't continue batching + Err(Ok(this_msg)) + } + } + } + + #[instrument(level = tracing::Level::DEBUG, skip_all)] + async fn pagesteam_handle_batched_message( + &mut self, + pgb_writer: &mut PostgresBackend, + batch: BatchedFeMessage, + cancel: &CancellationToken, + ctx: &RequestContext, + ) -> Result<(), QueryError> where IO: AsyncRead + AsyncWrite + Send + Sync + Unpin, { - let mut batch = self.next_batch.take(); - let mut batch_started_at: Option = None; - - let next_batch: Option = loop { - let sleep_fut = match (self.server_side_batch_timeout, batch_started_at) { - (Some(batch_timeout), Some(started_at)) => futures::future::Either::Left( - tokio::time::sleep_until((started_at + batch_timeout).into()), - ), - _ => futures::future::Either::Right(futures::future::pending()), - }; - - let msg = tokio::select! { - biased; - _ = self.cancel.cancelled() => { - return Err(QueryError::Shutdown) - } - msg = pgb.read_message() => { - msg - } - _ = sleep_fut => { - assert!(batch.is_some()); - break None; - } - }; - let copy_data_bytes = match msg? 
{ - Some(FeMessage::CopyData(bytes)) => bytes, - Some(FeMessage::Terminate) => { - return Ok(Some(BatchOrEof::Eof)); - } - Some(m) => { - return Err(QueryError::Other(anyhow::anyhow!( - "unexpected message: {m:?} during COPY" - ))); - } - None => { - return Ok(Some(BatchOrEof::Eof)); - } // client disconnected - }; - trace!("query: {copy_data_bytes:?}"); - fail::fail_point!("ps::handle-pagerequest-message"); - - // parse request - let neon_fe_msg = PagestreamFeMessage::parse(&mut copy_data_bytes.reader())?; - - let this_msg = match neon_fe_msg { - PagestreamFeMessage::Exists(req) => BatchedFeMessage::Exists { - span: tracing::info_span!("handle_get_rel_exists_request", rel = %req.rel, req_lsn = %req.request_lsn), - req, - }, - PagestreamFeMessage::Nblocks(req) => BatchedFeMessage::Nblocks { - span: tracing::info_span!("handle_get_nblocks_request", rel = %req.rel, req_lsn = %req.request_lsn), - req, - }, - PagestreamFeMessage::DbSize(req) => BatchedFeMessage::DbSize { - span: tracing::info_span!("handle_db_size_request", dbnode = %req.dbnode, req_lsn = %req.request_lsn), - req, - }, - PagestreamFeMessage::GetSlruSegment(req) => BatchedFeMessage::GetSlruSegment { - span: tracing::info_span!("handle_get_slru_segment_request", kind = %req.kind, segno = %req.segno, req_lsn = %req.request_lsn), - req, - }, - PagestreamFeMessage::GetPage(PagestreamGetPageRequest { - request_lsn, - not_modified_since, - rel, - blkno, - }) => { - // shard_id is filled in by the handler - let span = tracing::info_span!( - "handle_get_page_at_lsn_request_batched", - %tenant_id, %timeline_id, shard_id = tracing::field::Empty, req_lsn = %request_lsn, - batch_size = tracing::field::Empty, batch_id = tracing::field::Empty - ); - - macro_rules! 
current_batch_and_error { - ($error:expr) => {{ - let error = BatchedFeMessage::RespondError { - span, - error: $error, - }; - let batch_and_error = match batch { - Some(b) => smallvec::smallvec![b, error], - None => smallvec::smallvec![error], - }; - Ok(Some(BatchOrEof::Batch(batch_and_error))) - }}; - } - - let key = rel_block_to_key(rel, blkno); - let shard = match self - .timeline_handles - .get(*tenant_id, *timeline_id, ShardSelector::Page(key)) - .instrument(span.clone()) - .await - { - Ok(tl) => tl, - Err(GetActiveTimelineError::Tenant(GetActiveTenantError::NotFound(_))) => { - // We already know this tenant exists in general, because we resolved it at - // start of connection. Getting a NotFound here indicates that the shard containing - // the requested page is not present on this node: the client's knowledge of shard->pageserver - // mapping is out of date. - // - // Closing the connection by returning ``::Reconnect` has the side effect of rate-limiting above message, via - // client's reconnect backoff, as well as hopefully prompting the client to load its updated configuration - // and talk to a different pageserver. 
- return current_batch_and_error!(PageStreamError::Reconnect( - "getpage@lsn request routed to wrong shard".into() - )); - } - Err(e) => { - return current_batch_and_error!(e.into()); - } - }; - let effective_request_lsn = match Self::wait_or_get_last_lsn( - &shard, - request_lsn, - not_modified_since, - &shard.get_latest_gc_cutoff_lsn(), - ctx, - ) - // TODO: if we actually need to wait for lsn here, it delays the entire batch which doesn't need to wait - .await - { - Ok(lsn) => lsn, - Err(e) => { - return current_batch_and_error!(e); - } - }; - BatchedFeMessage::GetPage { + // invoke handler function + let (handler_results, span): (Vec>, _) = + match batch { + BatchedFeMessage::Exists { span, shard, req } => { + fail::fail_point!("ps::handle-pagerequest-message::exists"); + ( + vec![ + self.handle_get_rel_exists_request(&shard, &req, ctx) + .instrument(span.clone()) + .await, + ], span, - shard, - effective_request_lsn, - pages: smallvec::smallvec![(rel, blkno)], - } + ) + } + BatchedFeMessage::Nblocks { span, shard, req } => { + fail::fail_point!("ps::handle-pagerequest-message::nblocks"); + ( + vec![ + self.handle_get_nblocks_request(&shard, &req, ctx) + .instrument(span.clone()) + .await, + ], + span, + ) + } + BatchedFeMessage::GetPage { + span, + shard, + effective_request_lsn, + pages, + } => { + fail::fail_point!("ps::handle-pagerequest-message::getpage"); + ( + { + let npages = pages.len(); + trace!(npages, "handling getpage request"); + let res = self + .handle_get_page_at_lsn_request_batched( + &shard, + effective_request_lsn, + pages, + ctx, + ) + .instrument(span.clone()) + .await; + assert_eq!(res.len(), npages); + res + }, + span, + ) + } + BatchedFeMessage::DbSize { span, shard, req } => { + fail::fail_point!("ps::handle-pagerequest-message::dbsize"); + ( + vec![ + self.handle_db_size_request(&shard, &req, ctx) + .instrument(span.clone()) + .await, + ], + span, + ) + } + BatchedFeMessage::GetSlruSegment { span, shard, req } => { + 
fail::fail_point!("ps::handle-pagerequest-message::slrusegment"); + ( + vec![ + self.handle_get_slru_segment_request(&shard, &req, ctx) + .instrument(span.clone()) + .await, + ], + span, + ) + } + BatchedFeMessage::RespondError { span, error } => { + // We've already decided to respond with an error, so we don't need to + // call the handler. + (vec![Err(error)], span) } }; - let batch_timeout = match self.server_side_batch_timeout { - Some(value) => value, - None => { - // Batching is not enabled - stop on the first message. - return Ok(Some(BatchOrEof::Batch(smallvec::smallvec![this_msg]))); - } + // Map handler result to protocol behavior. + // Some handler errors cause exit from pagestream protocol. + // Other handler errors are sent back as an error message and we stay in pagestream protocol. + for handler_result in handler_results { + let response_msg = match handler_result { + Err(e) => match &e { + PageStreamError::Shutdown => { + // If we fail to fulfil a request during shutdown, which may be _because_ of + // shutdown, then do not send the error to the client. Instead just drop the + // connection. + span.in_scope(|| info!("dropping connection due to shutdown")); + return Err(QueryError::Shutdown); + } + PageStreamError::Reconnect(reason) => { + span.in_scope(|| info!("handler requested reconnect: {reason}")); + return Err(QueryError::Reconnect); + } + PageStreamError::Read(_) + | PageStreamError::LsnTimeout(_) + | PageStreamError::NotFound(_) + | PageStreamError::BadRequest(_) => { + // print the all details to the log with {:#}, but for the client the + // error message is enough. Do not log if shutting down, as the anyhow::Error + // here includes cancellation which is not an error. 
+ let full = utils::error::report_compact_sources(&e); + span.in_scope(|| { + error!("error reading relation or page version: {full:#}") + }); + PagestreamBeMessage::Error(PagestreamErrorResponse { + message: e.to_string(), + }) + } + }, + Ok(response_msg) => response_msg, }; - // check if we can batch - match (&mut batch, this_msg) { - (None, this_msg) => { - batch = Some(this_msg); - } - ( - Some(BatchedFeMessage::GetPage { - span: _, - shard: accum_shard, - pages: accum_pages, - effective_request_lsn: accum_lsn, - }), - BatchedFeMessage::GetPage { - span: _, - shard: this_shard, - pages: this_pages, - effective_request_lsn: this_lsn, - }, - ) if async { - assert_eq!(this_pages.len(), 1); - if accum_pages.len() >= Timeline::MAX_GET_VECTORED_KEYS as usize { - assert_eq!(accum_pages.len(), Timeline::MAX_GET_VECTORED_KEYS as usize); - return false; - } - if (accum_shard.tenant_shard_id, accum_shard.timeline_id) - != (this_shard.tenant_shard_id, this_shard.timeline_id) - { - // TODO: we _could_ batch & execute each shard seperately (and in parallel). - // But the current logic for keeping responses in order does not support that. - return false; - } - // the vectored get currently only supports a single LSN, so, bounce as soon - // as the effective request_lsn changes - if *accum_lsn != this_lsn { - return false; - } - true - } - .await => - { - // ok to batch - accum_pages.extend(this_pages); - } - (Some(_), this_msg) => { - // by default, don't continue batching - break Some(this_msg); - } + // marshal & transmit response message + pgb_writer.write_message_noflush(&BeMessage::CopyData(&response_msg.serialize()))?; + } + tokio::select! { + biased; + _ = cancel.cancelled() => { + // We were requested to shut down. 
+ info!("shutdown request received in page handler"); + return Err(QueryError::Shutdown) } - - // batching impl piece - let started_at = batch_started_at.get_or_insert_with(Instant::now); - if started_at.elapsed() > batch_timeout { - break None; + res = pgb_writer.flush() => { + res?; } - }; - - self.next_batch = next_batch; - Ok(batch.map(|b| BatchOrEof::Batch(smallvec::smallvec![b]))) + } + Ok(()) } /// Pagestream sub-protocol handler. @@ -845,7 +983,7 @@ impl PageServerHandler { ctx: RequestContext, ) -> Result<(), QueryError> where - IO: AsyncRead + AsyncWrite + Send + Sync + Unpin, + IO: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, { debug_assert_current_span_has_tenant_and_timeline_id_no_shard_id(); @@ -861,169 +999,283 @@ impl PageServerHandler { } } - // If [`PageServerHandler`] is reused for multiple pagestreams, - // then make sure to not process requests from the previous ones. - self.next_batch = None; + let pgb_reader = pgb + .split() + .context("implementation error: split pgb into reader and writer")?; - loop { - let maybe_batched = self - .read_batch_from_connection(pgb, &tenant_id, &timeline_id, &ctx) - .await?; - let batched = match maybe_batched { - Some(BatchOrEof::Batch(b)) => b, - Some(BatchOrEof::Eof) => { - break; - } + let timeline_handles = self + .timeline_handles + .take() + .expect("implementation error: timeline_handles should not be locked"); + + let request_span = info_span!("request", shard_id = tracing::field::Empty); + let ((pgb_reader, timeline_handles), result) = match self.pipelining_config.clone() { + PageServicePipeliningConfig::Pipelined(pipelining_config) => { + self.handle_pagerequests_pipelined( + pgb, + pgb_reader, + tenant_id, + timeline_id, + timeline_handles, + request_span, + pipelining_config, + &ctx, + ) + .await + } + PageServicePipeliningConfig::Serial => { + self.handle_pagerequests_serial( + pgb, + pgb_reader, + tenant_id, + timeline_id, + timeline_handles, + request_span, + &ctx, + ) + .await + } + 
}; + + debug!("pagestream subprotocol shut down cleanly"); + + pgb.unsplit(pgb_reader) + .context("implementation error: unsplit pgb")?; + + let replaced = self.timeline_handles.replace(timeline_handles); + assert!(replaced.is_none()); + + result + } + + #[allow(clippy::too_many_arguments)] + async fn handle_pagerequests_serial( + &mut self, + pgb_writer: &mut PostgresBackend, + mut pgb_reader: PostgresBackendReader, + tenant_id: TenantId, + timeline_id: TimelineId, + mut timeline_handles: TimelineHandles, + request_span: Span, + ctx: &RequestContext, + ) -> ( + (PostgresBackendReader, TimelineHandles), + Result<(), QueryError>, + ) + where + IO: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, + { + let cancel = self.cancel.clone(); + let err = loop { + let msg = Self::pagestream_read_message( + &mut pgb_reader, + tenant_id, + timeline_id, + &mut timeline_handles, + &cancel, + ctx, + request_span.clone(), + ) + .await; + let msg = match msg { + Ok(msg) => msg, + Err(e) => break e, + }; + let msg = match msg { + Some(msg) => msg, None => { - continue; + debug!("pagestream subprotocol end observed"); + return ((pgb_reader, timeline_handles), Ok(())); } }; + let err = self + .pagesteam_handle_batched_message(pgb_writer, msg, &cancel, ctx) + .await; + match err { + Ok(()) => {} + Err(e) => break e, + } + }; + ((pgb_reader, timeline_handles), Err(err)) + } - for batch in batched { - // invoke handler function - let (handler_results, span): ( - Vec>, - _, - ) = match batch { - BatchedFeMessage::Exists { span, req } => { - fail::fail_point!("ps::handle-pagerequest-message::exists"); - ( - vec![ - self.handle_get_rel_exists_request( - tenant_id, - timeline_id, - &req, - &ctx, - ) - .instrument(span.clone()) - .await, - ], - span, - ) - } - BatchedFeMessage::Nblocks { span, req } => { - fail::fail_point!("ps::handle-pagerequest-message::nblocks"); - ( - vec![ - self.handle_get_nblocks_request(tenant_id, timeline_id, &req, &ctx) - .instrument(span.clone()) - .await, 
- ], - span, - ) - } - BatchedFeMessage::GetPage { - span, - shard, - effective_request_lsn, - pages, - } => { - fail::fail_point!("ps::handle-pagerequest-message::getpage"); - ( - { - let npages = pages.len(); - let res = self - .handle_get_page_at_lsn_request_batched( - &shard, - effective_request_lsn, - pages, - &ctx, - ) - .instrument(span.clone()) - .await; - assert_eq!(res.len(), npages); - res - }, - span, - ) - } - BatchedFeMessage::DbSize { span, req } => { - fail::fail_point!("ps::handle-pagerequest-message::dbsize"); - ( - vec![ - self.handle_db_size_request(tenant_id, timeline_id, &req, &ctx) - .instrument(span.clone()) - .await, - ], - span, - ) - } - BatchedFeMessage::GetSlruSegment { span, req } => { - fail::fail_point!("ps::handle-pagerequest-message::slrusegment"); - ( - vec![ - self.handle_get_slru_segment_request( - tenant_id, - timeline_id, - &req, - &ctx, - ) - .instrument(span.clone()) - .await, - ], - span, - ) - } - BatchedFeMessage::RespondError { span, error } => { - // We've already decided to respond with an error, so we don't need to - // call the handler. - (vec![Err(error)], span) - } - }; + /// # Cancel-Safety + /// + /// May leak tokio tasks if not polled to completion. + #[allow(clippy::too_many_arguments)] + async fn handle_pagerequests_pipelined( + &mut self, + pgb_writer: &mut PostgresBackend, + pgb_reader: PostgresBackendReader, + tenant_id: TenantId, + timeline_id: TimelineId, + mut timeline_handles: TimelineHandles, + request_span: Span, + pipelining_config: PageServicePipeliningConfigPipelined, + ctx: &RequestContext, + ) -> ( + (PostgresBackendReader, TimelineHandles), + Result<(), QueryError>, + ) + where + IO: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, + { + // + // Pipelined pagestream handling consists of + // - a Batcher that reads requests off the wire and + // and batches them if possible, + // - an Executor that processes the batched requests. 
+ // + // The batch is built up inside an `spsc_fold` channel, + // shared between Batcher (Sender) and Executor (Receiver). + // + // The Batcher continuously folds client requests into the batch, + // while the Executor can at any time take out what's in the batch + // in order to process it. + // This means the next batch builds up while the Executor + // executes the last batch. + // + // CANCELLATION + // + // We run both Batcher and Executor futures to completion before + // returning from this function. + // + // If Executor exits first, it signals cancellation to the Batcher + // via a CancellationToken that is a child of `self.cancel`. + // If Batcher exits first, it signals cancellation to the Executor + // by dropping the spsc_fold channel Sender. + // + // CLEAN SHUTDOWN + // + // Clean shutdown means that the client ends the COPYBOTH session. + // In response to such a client message, the Batcher exits. + // The Executor continues to run, draining the spsc_fold channel. + // Once drained, the spsc_fold recv will fail with a distinct error + // indicating that the sender disconnected. + // The Executor exits with Ok(()) in response to that error. + // + // Server initiated shutdown is not clean shutdown, but instead + // is an error Err(QueryError::Shutdown) that is propagated through + // error propagation. + // + // ERROR PROPAGATION + // + // When the Batcher encounters an error, it sends it as a value + // through the spsc_fold channel and exits afterwards. + // When the Executor observes such an error in the channel, + // it exits returning that error value. + // + // This design ensures that the Executor stage will still process + // the batch that was in flight when the Batcher encountered an error, + // thereby behaving identically to a serial implementation. - // Map handler result to protocol behavior. - // Some handler errors cause exit from pagestream protocol. 
- // Other handler errors are sent back as an error message and we stay in pagestream protocol. - for handler_result in handler_results { - let response_msg = match handler_result { - Err(e) => match &e { - PageStreamError::Shutdown => { - // If we fail to fulfil a request during shutdown, which may be _because_ of - // shutdown, then do not send the error to the client. Instead just drop the - // connection. - span.in_scope(|| info!("dropping connection due to shutdown")); - return Err(QueryError::Shutdown); - } - PageStreamError::Reconnect(reason) => { - span.in_scope(|| info!("handler requested reconnect: {reason}")); - return Err(QueryError::Reconnect); - } - PageStreamError::Read(_) - | PageStreamError::LsnTimeout(_) - | PageStreamError::NotFound(_) - | PageStreamError::BadRequest(_) => { - // print the all details to the log with {:#}, but for the client the - // error message is enough. Do not log if shutting down, as the anyhow::Error - // here includes cancellation which is not an error. - let full = utils::error::report_compact_sources(&e); - span.in_scope(|| { - error!("error reading relation or page version: {full:#}") - }); - PagestreamBeMessage::Error(PagestreamErrorResponse { - message: e.to_string(), - }) - } - }, - Ok(response_msg) => response_msg, - }; + let PageServicePipeliningConfigPipelined { + max_batch_size, + execution, + } = pipelining_config; - // marshal & transmit response message - pgb.write_message_noflush(&BeMessage::CopyData(&response_msg.serialize()))?; + // Macro to _define_ a pipeline stage. + macro_rules! pipeline_stage { + ($name:literal, $cancel:expr, $make_fut:expr) => {{ + let cancel: CancellationToken = $cancel; + let stage_fut = $make_fut(cancel.clone()); + async move { + scopeguard::defer! { + debug!("exiting"); + } + timed_after_cancellation(stage_fut, $name, Duration::from_millis(100), &cancel) + .await } - tokio::select! { - biased; - _ = self.cancel.cancelled() => { - // We were requested to shut down. 
- info!("shutdown request received in page handler"); - return Err(QueryError::Shutdown) - } - res = pgb.flush() => { - res?; - } + .instrument(tracing::info_span!($name)) + }}; + } + + // + // Batcher + // + + let cancel_batcher = self.cancel.child_token(); + let (mut batch_tx, mut batch_rx) = spsc_fold::channel(); + let batcher = pipeline_stage!("batcher", cancel_batcher.clone(), move |cancel_batcher| { + let ctx = ctx.attached_child(); + async move { + let mut pgb_reader = pgb_reader; + let mut exit = false; + while !exit { + let read_res = Self::pagestream_read_message( + &mut pgb_reader, + tenant_id, + timeline_id, + &mut timeline_handles, + &cancel_batcher, + &ctx, + request_span.clone(), + ) + .await; + let Some(read_res) = read_res.transpose() else { + debug!("client-initiated shutdown"); + break; + }; + exit |= read_res.is_err(); + let could_send = batch_tx + .send(read_res, |batch, res| { + Self::pagestream_do_batch(max_batch_size, batch, res) + }) + .await; + exit |= could_send.is_err(); + } + (pgb_reader, timeline_handles) + } + }); + + // + // Executor + // + + let executor = pipeline_stage!("executor", self.cancel.clone(), move |cancel| { + let ctx = ctx.attached_child(); + async move { + let _cancel_batcher = cancel_batcher.drop_guard(); + loop { + let maybe_batch = batch_rx.recv().await; + let batch = match maybe_batch { + Ok(batch) => batch, + Err(spsc_fold::RecvError::SenderGone) => { + debug!("upstream gone"); + return Ok(()); + } + }; + let batch = match batch { + Ok(batch) => batch, + Err(e) => { + return Err(e); + } + }; + self.pagesteam_handle_batched_message(pgb_writer, batch, &cancel, &ctx) + .await?; } } + }); + + // + // Execute the stages. + // + + match execution { + PageServiceProtocolPipelinedExecutionStrategy::ConcurrentFutures => { + tokio::join!(batcher, executor) + } + PageServiceProtocolPipelinedExecutionStrategy::Tasks => { + // These tasks are not tracked anywhere. 
+ let read_messages_task = tokio::spawn(batcher); + let (read_messages_task_res, executor_res_) = + tokio::join!(read_messages_task, executor,); + ( + read_messages_task_res.expect("propagated panic from read_messages"), + executor_res_, + ) + } } - Ok(()) } /// Helper function to handle the LSN from client request. @@ -1131,6 +1383,8 @@ impl PageServerHandler { { let timeline = self .timeline_handles + .as_mut() + .unwrap() .get( tenant_shard_id.tenant_id, timeline_id, @@ -1165,22 +1419,17 @@ impl PageServerHandler { #[instrument(skip_all, fields(shard_id))] async fn handle_get_rel_exists_request( &mut self, - tenant_id: TenantId, - timeline_id: TimelineId, + timeline: &Timeline, req: &PagestreamExistsRequest, ctx: &RequestContext, ) -> Result { - let timeline = self - .timeline_handles - .get(tenant_id, timeline_id, ShardSelector::Zero) - .await?; let _timer = timeline .query_metrics .start_timer(metrics::SmgrQueryType::GetRelExists, ctx); let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn( - &timeline, + timeline, req.request_lsn, req.not_modified_since, &latest_gc_cutoff_lsn, @@ -1200,23 +1449,17 @@ impl PageServerHandler { #[instrument(skip_all, fields(shard_id))] async fn handle_get_nblocks_request( &mut self, - tenant_id: TenantId, - timeline_id: TimelineId, + timeline: &Timeline, req: &PagestreamNblocksRequest, ctx: &RequestContext, ) -> Result { - let timeline = self - .timeline_handles - .get(tenant_id, timeline_id, ShardSelector::Zero) - .await?; - let _timer = timeline .query_metrics .start_timer(metrics::SmgrQueryType::GetRelSize, ctx); let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn( - &timeline, + timeline, req.request_lsn, req.not_modified_since, &latest_gc_cutoff_lsn, @@ -1236,23 +1479,17 @@ impl PageServerHandler { #[instrument(skip_all, fields(shard_id))] async fn handle_db_size_request( &mut self, - tenant_id: TenantId, - timeline_id: 
TimelineId, + timeline: &Timeline, req: &PagestreamDbSizeRequest, ctx: &RequestContext, ) -> Result { - let timeline = self - .timeline_handles - .get(tenant_id, timeline_id, ShardSelector::Zero) - .await?; - let _timer = timeline .query_metrics .start_timer(metrics::SmgrQueryType::GetDbSize, ctx); let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn( - &timeline, + timeline, req.request_lsn, req.not_modified_since, &latest_gc_cutoff_lsn, @@ -1300,23 +1537,17 @@ impl PageServerHandler { #[instrument(skip_all, fields(shard_id))] async fn handle_get_slru_segment_request( &mut self, - tenant_id: TenantId, - timeline_id: TimelineId, + timeline: &Timeline, req: &PagestreamGetSlruSegmentRequest, ctx: &RequestContext, ) -> Result { - let timeline = self - .timeline_handles - .get(tenant_id, timeline_id, ShardSelector::Zero) - .await?; - let _timer = timeline .query_metrics .start_timer(metrics::SmgrQueryType::GetSlruSegment, ctx); let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn(); let lsn = Self::wait_or_get_last_lsn( - &timeline, + timeline, req.request_lsn, req.not_modified_since, &latest_gc_cutoff_lsn, @@ -1374,6 +1605,8 @@ impl PageServerHandler { let timeline = self .timeline_handles + .as_mut() + .unwrap() .get(tenant_id, timeline_id, ShardSelector::Zero) .await?; @@ -1716,7 +1949,7 @@ impl PageServiceCmd { impl postgres_backend::Handler for PageServerHandler where - IO: AsyncRead + AsyncWrite + Send + Sync + Unpin, + IO: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, { fn check_auth_jwt( &mut self, diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index e3c88e9965..9bcfffeb9c 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -3804,9 +3804,10 @@ class Endpoint(PgProtocol, LogUtils): # shared_buffers = 512kB to make postgres use LFC intensively # neon.max_file_cache_size and neon.file_cache size limit are # set 
to 1MB because small LFC is better for testing (helps to find more problems) + lfc_path_escaped = str(lfc_path).replace("'", "''") config_lines = [ "shared_buffers = 512kB", - f"neon.file_cache_path = '{self.lfc_path()}'", + f"neon.file_cache_path = '{lfc_path_escaped}'", "neon.max_file_cache_size = 1MB", "neon.file_cache_size_limit = 1MB", ] + config_lines diff --git a/test_runner/performance/pageserver/test_pageserver_getpage_merge.py b/test_runner/performance/pageserver/test_page_service_batching.py similarity index 69% rename from test_runner/performance/pageserver/test_pageserver_getpage_merge.py rename to test_runner/performance/pageserver/test_page_service_batching.py index 34cce9900b..c47a849fec 100644 --- a/test_runner/performance/pageserver/test_pageserver_getpage_merge.py +++ b/test_runner/performance/pageserver/test_page_service_batching.py @@ -11,36 +11,95 @@ from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnvBuilder, PgBin, wait_for_last_flush_lsn from fixtures.utils import humantime_to_ms -TARGET_RUNTIME = 60 +TARGET_RUNTIME = 30 + + +@dataclass +class PageServicePipeliningConfig: + pass + + +@dataclass +class PageServicePipeliningConfigSerial(PageServicePipeliningConfig): + mode: str = "serial" + + +@dataclass +class PageServicePipeliningConfigPipelined(PageServicePipeliningConfig): + max_batch_size: int + execution: str + mode: str = "pipelined" + + +EXECUTION = ["concurrent-futures", "tasks"] + +NON_BATCHABLE: list[PageServicePipeliningConfig] = [PageServicePipeliningConfigSerial()] +for max_batch_size in [1, 32]: + for execution in EXECUTION: + NON_BATCHABLE.append(PageServicePipeliningConfigPipelined(max_batch_size, execution)) + +BATCHABLE: list[PageServicePipeliningConfig] = [PageServicePipeliningConfigSerial()] +for max_batch_size in [1, 2, 4, 8, 16, 32]: + for execution in EXECUTION: + BATCHABLE.append(PageServicePipeliningConfigPipelined(max_batch_size, execution)) -@pytest.mark.skip("See 
https://github.com/neondatabase/neon/pull/9820#issue-2675856095") @pytest.mark.parametrize( - "tablesize_mib, batch_timeout, target_runtime, effective_io_concurrency, readhead_buffer_size, name", + "tablesize_mib, pipelining_config, target_runtime, effective_io_concurrency, readhead_buffer_size, name", [ - # the next 4 cases demonstrate how not-batchable workloads suffer from batching timeout - (50, None, TARGET_RUNTIME, 1, 128, "not batchable no batching"), - (50, "10us", TARGET_RUNTIME, 1, 128, "not batchable 10us timeout"), - (50, "1ms", TARGET_RUNTIME, 1, 128, "not batchable 1ms timeout"), - # the next 4 cases demonstrate how batchable workloads benefit from batching - (50, None, TARGET_RUNTIME, 100, 128, "batchable no batching"), - (50, "10us", TARGET_RUNTIME, 100, 128, "batchable 10us timeout"), - (50, "100us", TARGET_RUNTIME, 100, 128, "batchable 100us timeout"), - (50, "1ms", TARGET_RUNTIME, 100, 128, "batchable 1ms timeout"), + # non-batchable workloads + # (A separate benchmark will consider latency). + *[ + ( + 50, + config, + TARGET_RUNTIME, + 1, + 128, + f"not batchable {dataclasses.asdict(config)}", + ) + for config in NON_BATCHABLE + ], + # batchable workloads should show throughput and CPU efficiency improvements + *[ + ( + 50, + config, + TARGET_RUNTIME, + 100, + 128, + f"batchable {dataclasses.asdict(config)}", + ) + for config in BATCHABLE + ], ], ) -def test_getpage_merge_smoke( +def test_throughput( neon_env_builder: NeonEnvBuilder, zenbenchmark: NeonBenchmarker, tablesize_mib: int, - batch_timeout: str | None, + pipelining_config: PageServicePipeliningConfig, target_runtime: int, effective_io_concurrency: int, readhead_buffer_size: int, name: str, ): """ - Do a bunch of sequential scans and ensure that the pageserver does some merging. + Do a bunch of sequential scans with varying compute and pipelining configurations. + Primary performance metrics are the achieved batching factor and throughput (wall clock time). 
+ Resource utilization is also interesting - we currently measure CPU time. + + The test is a fixed-runtime based type of test (target_runtime). + Hence, the results are normalized to the number of iterations completed within target runtime. + + If the compute doesn't provide pipeline depth (effective_io_concurrency=1), + performance should be about identical in all configurations. + Pipelining can still yield improvements in these scenarios because it parses the + next request while the current one is still being executed. + + If the compute provides pipeline depth (effective_io_concurrency=100), then + pipelining configs, especially with max_batch_size>1 should yield dramatic improvements + in all performance metrics. """ # @@ -51,14 +110,16 @@ def test_getpage_merge_smoke( params.update( { "tablesize_mib": (tablesize_mib, {"unit": "MiB"}), - "batch_timeout": ( - -1 if batch_timeout is None else 1e3 * humantime_to_ms(batch_timeout), - {"unit": "us"}, - ), # target_runtime is just a polite ask to the workload to run for this long "effective_io_concurrency": (effective_io_concurrency, {}), "readhead_buffer_size": (readhead_buffer_size, {}), - # name is not a metric + # name is not a metric, we just use it to identify the test easily in the `test_...[...]`` notation + } + ) + params.update( + { + f"pipelining_config.{k}": (v, {}) + for k, v in dataclasses.asdict(pipelining_config).items() } ) @@ -170,7 +231,9 @@ def test_getpage_merge_smoke( after = get_metrics() return (after - before).normalize(iters - 1) - env.pageserver.patch_config_toml_nonrecursive({"server_side_batch_timeout": batch_timeout}) + env.pageserver.patch_config_toml_nonrecursive( + {"page_service_pipelining": dataclasses.asdict(pipelining_config)} + ) env.pageserver.restart() metrics = workload() @@ -199,23 +262,30 @@ def test_getpage_merge_smoke( ) -@pytest.mark.skip("See https://github.com/neondatabase/neon/pull/9820#issue-2675856095") +PRECISION_CONFIGS: list[PageServicePipeliningConfig] = 
[PageServicePipeliningConfigSerial()] +for max_batch_size in [1, 32]: + for execution in EXECUTION: + PRECISION_CONFIGS.append(PageServicePipeliningConfigPipelined(max_batch_size, execution)) + + @pytest.mark.parametrize( - "batch_timeout", [None, "10us", "20us", "50us", "100us", "200us", "500us", "1ms"] + "pipelining_config,name", + [(config, f"{dataclasses.asdict(config)}") for config in PRECISION_CONFIGS], ) -def test_timer_precision( +def test_latency( neon_env_builder: NeonEnvBuilder, zenbenchmark: NeonBenchmarker, pg_bin: PgBin, - batch_timeout: str | None, + pipelining_config: PageServicePipeliningConfig, + name: str, ): """ - Determine the batching timeout precision (mean latency) and tail latency impact. + Measure the latency impact of pipelining in an un-batchable workloads. - The baseline is `None`; an ideal batching timeout implementation would increase - the mean latency by exactly `batch_timeout`. + An ideal implementation should not increase average or tail latencies for such workloads. - That is not the case with the current implementation, will be addressed in future changes. + We don't have support in pagebench to create queue depth yet. + => https://github.com/neondatabase/neon/issues/9837 """ # @@ -223,7 +293,8 @@ def test_timer_precision( # def patch_ps_config(ps_config): - ps_config["server_side_batch_timeout"] = batch_timeout + if pipelining_config is not None: + ps_config["page_service_pipelining"] = dataclasses.asdict(pipelining_config) neon_env_builder.pageserver_config_override = patch_ps_config From 4abc8e5282037c85a922ae113e0677c50841b309 Mon Sep 17 00:00:00 2001 From: Folke Behrens Date: Sat, 30 Nov 2024 11:11:37 +0100 Subject: [PATCH 10/15] Merge the consumption metric pushes (#9939) #8564 ## Problem The main and backup consumption metric pushes are completely independent, resulting in different event time windows and different idempotency keys. ## Summary of changes * Merge the push tasks, but keep chunks the same size. 
--- Cargo.lock | 1 + libs/consumption_metrics/src/lib.rs | 5 +- proxy/Cargo.toml | 1 + proxy/src/bin/proxy.rs | 4 - proxy/src/usage_metrics.rs | 351 ++++++++++++++-------------- 5 files changed, 181 insertions(+), 181 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 313222cf3c..5ce27a7d45 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4501,6 +4501,7 @@ dependencies = [ "ecdsa 0.16.9", "env_logger", "fallible-iterator", + "flate2", "framed-websockets", "futures", "hashbrown 0.14.5", diff --git a/libs/consumption_metrics/src/lib.rs b/libs/consumption_metrics/src/lib.rs index fbe2e6830f..448134f31a 100644 --- a/libs/consumption_metrics/src/lib.rs +++ b/libs/consumption_metrics/src/lib.rs @@ -103,11 +103,12 @@ impl<'a> IdempotencyKey<'a> { } } +/// Split into chunks of 1000 metrics to avoid exceeding the max request size. pub const CHUNK_SIZE: usize = 1000; // Just a wrapper around a slice of events // to serialize it as `{"events" : [ ] } -#[derive(serde::Serialize, Deserialize)] -pub struct EventChunk<'a, T: Clone> { +#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq)] +pub struct EventChunk<'a, T: Clone + PartialEq> { pub events: std::borrow::Cow<'a, [T]>, } diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 0d774d529d..f5934c8a89 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -112,6 +112,7 @@ workspace_hack.workspace = true [dev-dependencies] camino-tempfile.workspace = true fallible-iterator.workspace = true +flate2.workspace = true tokio-tungstenite.workspace = true pbkdf2 = { workspace = true, features = ["simple", "std"] } rcgen.workspace = true diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index a935378162..b772a987ee 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -517,10 +517,6 @@ async fn main() -> anyhow::Result<()> { if let Some(metrics_config) = &config.metric_collection { // TODO: Add gc regardles of the metric collection being enabled. 
maintenance_tasks.spawn(usage_metrics::task_main(metrics_config)); - client_tasks.spawn(usage_metrics::task_backup( - &metrics_config.backup_metric_collection_config, - cancellation_token.clone(), - )); } if let Either::Left(auth::Backend::ControlPlane(api, _)) = &auth_backend { diff --git a/proxy/src/usage_metrics.rs b/proxy/src/usage_metrics.rs index c5e8588623..65e74466f2 100644 --- a/proxy/src/usage_metrics.rs +++ b/proxy/src/usage_metrics.rs @@ -1,19 +1,18 @@ //! Periodically collect proxy consumption metrics //! and push them to a HTTP endpoint. +use std::borrow::Cow; use std::convert::Infallible; -use std::pin::pin; use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; use std::time::Duration; -use anyhow::Context; +use anyhow::{bail, Context}; use async_compression::tokio::write::GzipEncoder; use bytes::Bytes; use chrono::{DateTime, Datelike, Timelike, Utc}; use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_SIZE}; use dashmap::mapref::entry::Entry; use dashmap::DashMap; -use futures::future::select; use once_cell::sync::Lazy; use remote_storage::{GenericRemoteStorage, RemotePath, TimeoutOrCancel}; use serde::{Deserialize, Serialize}; @@ -23,7 +22,7 @@ use tracing::{error, info, instrument, trace, warn}; use utils::backoff; use uuid::{NoContext, Timestamp}; -use crate::config::{MetricBackupCollectionConfig, MetricCollectionConfig}; +use crate::config::MetricCollectionConfig; use crate::context::parquet::{FAILED_UPLOAD_MAX_RETRIES, FAILED_UPLOAD_WARN_THRESHOLD}; use crate::http; use crate::intern::{BranchIdInt, EndpointIdInt}; @@ -58,55 +57,21 @@ trait MetricCounterReporter { fn move_metrics(&self) -> (u64, usize); } -#[derive(Debug)] -struct MetricBackupCounter { - transmitted: AtomicU64, - opened_connections: AtomicUsize, -} - -impl MetricCounterRecorder for MetricBackupCounter { - fn record_egress(&self, bytes: u64) { - self.transmitted.fetch_add(bytes, Ordering::AcqRel); - } - - fn 
record_connection(&self, count: usize) { - self.opened_connections.fetch_add(count, Ordering::AcqRel); - } -} - -impl MetricCounterReporter for MetricBackupCounter { - fn get_metrics(&mut self) -> (u64, usize) { - ( - *self.transmitted.get_mut(), - *self.opened_connections.get_mut(), - ) - } - fn move_metrics(&self) -> (u64, usize) { - ( - self.transmitted.swap(0, Ordering::AcqRel), - self.opened_connections.swap(0, Ordering::AcqRel), - ) - } -} - #[derive(Debug)] pub(crate) struct MetricCounter { transmitted: AtomicU64, opened_connections: AtomicUsize, - backup: Arc, } impl MetricCounterRecorder for MetricCounter { /// Record that some bytes were sent from the proxy to the client fn record_egress(&self, bytes: u64) { - self.transmitted.fetch_add(bytes, Ordering::AcqRel); - self.backup.record_egress(bytes); + self.transmitted.fetch_add(bytes, Ordering::Relaxed); } /// Record that some connections were opened fn record_connection(&self, count: usize) { - self.opened_connections.fetch_add(count, Ordering::AcqRel); - self.backup.record_connection(count); + self.opened_connections.fetch_add(count, Ordering::Relaxed); } } @@ -119,8 +84,8 @@ impl MetricCounterReporter for MetricCounter { } fn move_metrics(&self) -> (u64, usize) { ( - self.transmitted.swap(0, Ordering::AcqRel), - self.opened_connections.swap(0, Ordering::AcqRel), + self.transmitted.swap(0, Ordering::Relaxed), + self.opened_connections.swap(0, Ordering::Relaxed), ) } } @@ -173,26 +138,11 @@ type FastHasher = std::hash::BuildHasherDefault; #[derive(Default)] pub(crate) struct Metrics { endpoints: DashMap, FastHasher>, - backup_endpoints: DashMap, FastHasher>, } impl Metrics { /// Register a new byte metrics counter for this endpoint pub(crate) fn register(&self, ids: Ids) -> Arc { - let backup = if let Some(entry) = self.backup_endpoints.get(&ids) { - entry.clone() - } else { - self.backup_endpoints - .entry(ids.clone()) - .or_insert_with(|| { - Arc::new(MetricBackupCounter { - transmitted: 
AtomicU64::new(0), - opened_connections: AtomicUsize::new(0), - }) - }) - .clone() - }; - let entry = if let Some(entry) = self.endpoints.get(&ids) { entry.clone() } else { @@ -202,7 +152,6 @@ impl Metrics { Arc::new(MetricCounter { transmitted: AtomicU64::new(0), opened_connections: AtomicUsize::new(0), - backup: backup.clone(), }) }) .clone() @@ -227,6 +176,21 @@ pub async fn task_main(config: &MetricCollectionConfig) -> anyhow::Result anyhow::Result( now: DateTime, chunk_size: usize, ) -> impl Iterator>> + 'a { - // Split into chunks of 1000 metrics to avoid exceeding the max request size metrics_to_send .chunks(chunk_size) .map(move |chunk| EventChunk { @@ -303,11 +268,14 @@ fn create_event_chunks<'a>( }) } +#[expect(clippy::too_many_arguments)] #[instrument(skip_all)] async fn collect_metrics_iteration( endpoints: &DashMap, FastHasher>, client: &http::ClientWithMiddleware, metric_collection_endpoint: &reqwest::Url, + storage: Option<&GenericRemoteStorage>, + outer_chunk_size: usize, hostname: &str, prev: DateTime, now: DateTime, @@ -323,17 +291,54 @@ async fn collect_metrics_iteration( trace!("no new metrics to send"); } + let cancel = CancellationToken::new(); + let path_prefix = create_remote_path_prefix(now); + // Send metrics. 
- for chunk in create_event_chunks(&metrics_to_send, hostname, prev, now, CHUNK_SIZE) { + for chunk in create_event_chunks(&metrics_to_send, hostname, prev, now, outer_chunk_size) { + tokio::join!( + upload_main_events_chunked(client, metric_collection_endpoint, &chunk, CHUNK_SIZE), + async { + if let Err(e) = upload_backup_events(storage, &chunk, &path_prefix, &cancel).await { + error!("failed to upload consumption events to remote storage: {e:?}"); + } + } + ); + } +} + +fn create_remote_path_prefix(now: DateTime) -> String { + format!( + "year={year:04}/month={month:02}/day={day:02}/{hour:02}:{minute:02}:{second:02}Z", + year = now.year(), + month = now.month(), + day = now.day(), + hour = now.hour(), + minute = now.minute(), + second = now.second(), + ) +} + +async fn upload_main_events_chunked( + client: &http::ClientWithMiddleware, + metric_collection_endpoint: &reqwest::Url, + chunk: &EventChunk<'_, Event>, + subchunk_size: usize, +) { + // Split into smaller chunks to avoid exceeding the max request size + for subchunk in chunk.events.chunks(subchunk_size).map(|c| EventChunk { + events: Cow::Borrowed(c), + }) { let res = client .post(metric_collection_endpoint.clone()) - .json(&chunk) + .json(&subchunk) .send() .await; let res = match res { Ok(x) => x, Err(err) => { + // TODO: retry? 
error!("failed to send metrics: {:?}", err); continue; } @@ -341,7 +346,7 @@ async fn collect_metrics_iteration( if !res.status().is_success() { error!("metrics endpoint refused the sent metrics: {:?}", res); - for metric in chunk.events.iter().filter(|e| e.value > (1u64 << 40)) { + for metric in subchunk.events.iter().filter(|e| e.value > (1u64 << 40)) { // Report if the metric value is suspiciously large warn!("potentially abnormal metric value: {:?}", metric); } @@ -349,113 +354,34 @@ async fn collect_metrics_iteration( } } -pub async fn task_backup( - backup_config: &MetricBackupCollectionConfig, - cancellation_token: CancellationToken, -) -> anyhow::Result<()> { - info!("metrics backup config: {backup_config:?}"); - scopeguard::defer! { - info!("metrics backup has shut down"); - } - // Even if the remote storage is not configured, we still want to clear the metrics. - let storage = if let Some(config) = backup_config.remote_storage_config.as_ref() { - Some( - GenericRemoteStorage::from_config(config) - .await - .context("remote storage init")?, - ) - } else { - None - }; - let mut ticker = tokio::time::interval(backup_config.interval); - let mut prev = Utc::now(); - let hostname = hostname::get()?.as_os_str().to_string_lossy().into_owned(); - loop { - select(pin!(ticker.tick()), pin!(cancellation_token.cancelled())).await; - let now = Utc::now(); - collect_metrics_backup_iteration( - &USAGE_METRICS.backup_endpoints, - storage.as_ref(), - &hostname, - prev, - now, - backup_config.chunk_size, - ) - .await; - - prev = now; - if cancellation_token.is_cancelled() { - info!("metrics backup has been cancelled"); - break; - } - } - Ok(()) -} - -#[instrument(skip_all)] -async fn collect_metrics_backup_iteration( - endpoints: &DashMap, FastHasher>, +async fn upload_backup_events( storage: Option<&GenericRemoteStorage>, - hostname: &str, - prev: DateTime, - now: DateTime, - chunk_size: usize, -) { - let year = now.year(); - let month = now.month(); - let day = now.day(); 
- let hour = now.hour(); - let minute = now.minute(); - let second = now.second(); - let cancel = CancellationToken::new(); - - info!("starting collect_metrics_backup_iteration"); - - let metrics_to_send = collect_and_clear_metrics(endpoints); - - if metrics_to_send.is_empty() { - trace!("no new metrics to send"); - } - - // Send metrics. - for chunk in create_event_chunks(&metrics_to_send, hostname, prev, now, chunk_size) { - let real_now = Utc::now(); - let id = uuid::Uuid::new_v7(Timestamp::from_unix( - NoContext, - real_now.second().into(), - real_now.nanosecond(), - )); - let path = format!("year={year:04}/month={month:02}/day={day:02}/{hour:02}:{minute:02}:{second:02}Z_{id}.json.gz"); - let remote_path = match RemotePath::from_string(&path) { - Ok(remote_path) => remote_path, - Err(e) => { - error!("failed to create remote path from str {path}: {:?}", e); - continue; - } - }; - - let res = upload_events_chunk(storage, chunk, &remote_path, &cancel).await; - - if let Err(e) = res { - error!( - "failed to upload consumption events to remote storage: {:?}", - e - ); - } - } -} - -async fn upload_events_chunk( - storage: Option<&GenericRemoteStorage>, - chunk: EventChunk<'_, Event>, - remote_path: &RemotePath, + chunk: &EventChunk<'_, Event>, + path_prefix: &str, cancel: &CancellationToken, ) -> anyhow::Result<()> { let Some(storage) = storage else { - error!("no remote storage configured"); + warn!("no remote storage configured"); return Ok(()); }; - let data = serde_json::to_vec(&chunk).context("serialize metrics")?; + + let real_now = Utc::now(); + let id = uuid::Uuid::new_v7(Timestamp::from_unix( + NoContext, + real_now.second().into(), + real_now.nanosecond(), + )); + let path = format!("{path_prefix}_{id}.json.gz"); + let remote_path = match RemotePath::from_string(&path) { + Ok(remote_path) => remote_path, + Err(e) => { + bail!("failed to create remote path from str {path}: {:?}", e); + } + }; + + // TODO: This is async compression from Vec to Vec. 
Rewrite as byte stream. + // Use sync compression in blocking threadpool. + let data = serde_json::to_vec(chunk).context("serialize metrics")?; let mut encoder = GzipEncoder::new(Vec::new()); encoder.write_all(&data).await.context("compress metrics")?; encoder.shutdown().await.context("compress metrics")?; @@ -464,7 +390,7 @@ async fn upload_events_chunk( || async { let stream = futures::stream::once(futures::future::ready(Ok(compressed_data.clone()))); storage - .upload(stream, compressed_data.len(), remote_path, None, cancel) + .upload(stream, compressed_data.len(), &remote_path, None, cancel) .await }, TimeoutOrCancel::caused_by_cancel, @@ -482,9 +408,12 @@ async fn upload_events_chunk( #[cfg(test)] mod tests { + use std::fs; + use std::io::BufReader; use std::sync::{Arc, Mutex}; use anyhow::Error; + use camino_tempfile::tempdir; use chrono::Utc; use consumption_metrics::{Event, EventChunk}; use http_body_util::BodyExt; @@ -493,6 +422,7 @@ mod tests { use hyper::service::service_fn; use hyper::{Request, Response}; use hyper_util::rt::TokioIo; + use remote_storage::{RemoteStorageConfig, RemoteStorageKind}; use tokio::net::TcpListener; use url::Url; @@ -538,8 +468,34 @@ mod tests { let endpoint = Url::parse(&format!("http://{addr}")).unwrap(); let now = Utc::now(); + let storage_test_dir = tempdir().unwrap(); + let local_fs_path = storage_test_dir.path().join("usage_metrics"); + fs::create_dir_all(&local_fs_path).unwrap(); + let storage = GenericRemoteStorage::from_config(&RemoteStorageConfig { + storage: RemoteStorageKind::LocalFs { + local_path: local_fs_path.clone(), + }, + timeout: Duration::from_secs(10), + small_timeout: Duration::from_secs(1), + }) + .await + .unwrap(); + + let mut pushed_chunks: Vec = Vec::new(); + let mut stored_chunks: Vec = Vec::new(); + // no counters have been registered - collect_metrics_iteration(&metrics.endpoints, &client, &endpoint, "foo", now, now).await; + collect_metrics_iteration( + &metrics.endpoints, + &client, + &endpoint, 
+ Some(&storage), + 1000, + "foo", + now, + now, + ) + .await; let r = std::mem::take(&mut *reports.lock().unwrap()); assert!(r.is_empty()); @@ -551,39 +507,84 @@ mod tests { }); // the counter should be observed despite 0 egress - collect_metrics_iteration(&metrics.endpoints, &client, &endpoint, "foo", now, now).await; + collect_metrics_iteration( + &metrics.endpoints, + &client, + &endpoint, + Some(&storage), + 1000, + "foo", + now, + now, + ) + .await; let r = std::mem::take(&mut *reports.lock().unwrap()); assert_eq!(r.len(), 1); assert_eq!(r[0].events.len(), 1); assert_eq!(r[0].events[0].value, 0); + pushed_chunks.extend(r); // record egress counter.record_egress(1); // egress should be observered - collect_metrics_iteration(&metrics.endpoints, &client, &endpoint, "foo", now, now).await; + collect_metrics_iteration( + &metrics.endpoints, + &client, + &endpoint, + Some(&storage), + 1000, + "foo", + now, + now, + ) + .await; let r = std::mem::take(&mut *reports.lock().unwrap()); assert_eq!(r.len(), 1); assert_eq!(r[0].events.len(), 1); assert_eq!(r[0].events[0].value, 1); + pushed_chunks.extend(r); // release counter drop(counter); // we do not observe the counter - collect_metrics_iteration(&metrics.endpoints, &client, &endpoint, "foo", now, now).await; + collect_metrics_iteration( + &metrics.endpoints, + &client, + &endpoint, + Some(&storage), + 1000, + "foo", + now, + now, + ) + .await; let r = std::mem::take(&mut *reports.lock().unwrap()); assert!(r.is_empty()); // counter is unregistered assert!(metrics.endpoints.is_empty()); - collect_metrics_backup_iteration(&metrics.backup_endpoints, None, "foo", now, now, 1000) - .await; - assert!(!metrics.backup_endpoints.is_empty()); - collect_metrics_backup_iteration(&metrics.backup_endpoints, None, "foo", now, now, 1000) - .await; - // backup counter is unregistered after the second iteration - assert!(metrics.backup_endpoints.is_empty()); + let path_prefix = create_remote_path_prefix(now); + for entry in 
walkdir::WalkDir::new(&local_fs_path) + .into_iter() + .filter_map(|e| e.ok()) + { + let path = local_fs_path.join(&path_prefix).to_string(); + if entry.path().to_str().unwrap().starts_with(&path) { + let chunk = serde_json::from_reader(flate2::bufread::GzDecoder::new( + BufReader::new(fs::File::open(entry.into_path()).unwrap()), + )) + .unwrap(); + stored_chunks.push(chunk); + } + } + storage_test_dir.close().ok(); + + // sort by first event's idempotency key because the order of files is nondeterministic + pushed_chunks.sort_by_cached_key(|c| c.events[0].idempotency_key.clone()); + stored_chunks.sort_by_cached_key(|c| c.events[0].idempotency_key.clone()); + assert_eq!(pushed_chunks, stored_chunks); } } From 97a9abd18131708d0daadfd9c43b95048b538910 Mon Sep 17 00:00:00 2001 From: Konstantin Knizhnik Date: Sun, 1 Dec 2024 14:23:10 +0200 Subject: [PATCH 11/15] Add GUC controlling whether to pause recovery if some critical GUCs at replica have smaller value than on primary (#9057) ## Problem See https://github.com/neondatabase/neon/issues/9023 ## Summary of changes Add GUC `recovery_pause_on_misconfig` allowing not to pause in case of replica and primary configuration mismatch See https://github.com/neondatabase/postgres/pull/501 See https://github.com/neondatabase/postgres/pull/502 See https://github.com/neondatabase/postgres/pull/503 See https://github.com/neondatabase/postgres/pull/504 ## Checklist before requesting a review - [ ] I have performed a self-review of my code. - [ ] If it is a core feature, I have added thorough tests. - [ ] Do we need to implement analytics? if so did you add the relevant metrics to the dashboard? - [ ] If this PR requires public announcement, mark it with /release-notes label and add several sentences in this section.
## Checklist before merging - [ ] Do not forget to reformat commit message to not include the above checklist --------- Co-authored-by: Konstantin Knizhnik Co-authored-by: Heikki Linnakangas --- pgxn/neon/neon.c | 13 ++ .../regress/test_physical_replication.py | 221 +++++++++++++++++- vendor/postgres-v14 | 2 +- vendor/postgres-v15 | 2 +- vendor/postgres-v16 | 2 +- vendor/postgres-v17 | 2 +- vendor/revisions.json | 8 +- 7 files changed, 241 insertions(+), 9 deletions(-) diff --git a/pgxn/neon/neon.c b/pgxn/neon/neon.c index 51b9f58bbc..ff08f9164d 100644 --- a/pgxn/neon/neon.c +++ b/pgxn/neon/neon.c @@ -15,6 +15,9 @@ #include "access/subtrans.h" #include "access/twophase.h" #include "access/xlog.h" +#if PG_MAJORVERSION_NUM >= 15 +#include "access/xlogrecovery.h" +#endif #include "replication/logical.h" #include "replication/slot.h" #include "replication/walsender.h" @@ -432,6 +435,16 @@ _PG_init(void) restore_running_xacts_callback = RestoreRunningXactsFromClog; + DefineCustomBoolVariable( + "neon.allow_replica_misconfig", + "Allow replica startup when some critical GUCs have smaller value than on primary node", + NULL, + &allowReplicaMisconfig, + true, + PGC_POSTMASTER, + 0, + NULL, NULL, NULL); + DefineCustomEnumVariable( "neon.running_xacts_overflow_policy", "Action performed on snapshot overflow when restoring runnings xacts from CLOG", diff --git a/test_runner/regress/test_physical_replication.py b/test_runner/regress/test_physical_replication.py index 043aff686b..6cb11b825d 100644 --- a/test_runner/regress/test_physical_replication.py +++ b/test_runner/regress/test_physical_replication.py @@ -4,6 +4,10 @@ import random import time from typing import TYPE_CHECKING +import pytest +from fixtures.log_helper import log +from fixtures.neon_fixtures import wait_replica_caughtup + if TYPE_CHECKING: from fixtures.neon_fixtures import NeonEnv @@ -19,8 +23,8 @@ def test_physical_replication(neon_simple_env: NeonEnv): p_cur.execute( "CREATE TABLE t(pk bigint primary key, 
payload text default repeat('?',200))" ) - time.sleep(1) with env.endpoints.new_replica_start(origin=primary, endpoint_id="secondary") as secondary: + wait_replica_caughtup(primary, secondary) with primary.connect() as p_con: with p_con.cursor() as p_cur: with secondary.connect() as s_con: @@ -42,3 +46,218 @@ def test_physical_replication(neon_simple_env: NeonEnv): s_cur.execute( "select * from t where pk=%s", (random.randrange(1, 2 * pk),) ) + + +def test_physical_replication_config_mismatch_max_connections(neon_simple_env: NeonEnv): + """ + Test for primary and replica with different configuration settings (max_connections). + PostgreSQL enforces that settings that affect how many transactions can be open at the same time + have values equal to or higher in a hot standby replica than in the primary. If they don't, the replica refuses + to start up. If the settings are changed in the primary, it emits a WAL record with the new settings, and + when the replica sees that record it pauses the replay. + + PostgreSQL enforces this to ensure that the replica can hold all the XIDs in the so-called + "known-assigned XIDs" array, which is a fixed size array that needs to be allocated + upfront at server startup. That's pretty pessimistic, though; usually you can get + away with smaller settings, because we allocate space for 64 subtransactions per + transaction too. If you get unlucky and you run out of space, WAL redo dies with + "ERROR: too many KnownAssignedXids". It's better to take the chances than refuse + to start up, especially in Neon: if the WAL redo dies, the server is restarted, which is + no worse than refusing to start up in the first place. Furthermore, the control plane + tries to ensure that on restart, the settings are set high enough, so most likely it will + work after restart. Because of that, we have patched Postgres to disable the checks when + the `recovery_pause_on_misconfig` setting is set to `false` (which is the default on neon).
+ + This test tests all those cases of running out of space in known-assigned XIDs array that + we can hit with `recovery_pause_on_misconfig=false`, which are unreachable in unpatched + Postgres. + There's a similar check for `max_locks_per_transactions` too, which is related to running out + of space in the lock manager rather than known-assigned XIDs. Similar story with that, although + running out of space in the lock manager is possible in unmodified Postgres too. Enforcing the + check for `max_locks_per_transactions` ensures that you don't run out of space in the lock manager + when there are no read-only queries holding locks in the replica, but you can still run out if you have + those. + """ + env = neon_simple_env + with env.endpoints.create_start( + branch_name="main", + endpoint_id="primary", + ) as primary: + with primary.connect() as p_con: + with p_con.cursor() as p_cur: + p_cur.execute( + "CREATE TABLE t(pk bigint primary key, payload text default repeat('?',200))" + ) + with env.endpoints.new_replica_start( + origin=primary, + endpoint_id="secondary", + config_lines=["max_connections=5"], + ) as secondary: + wait_replica_caughtup(primary, secondary) + with secondary.connect() as s_con: + with s_con.cursor() as s_cur: + cursors = [] + for i in range(10): + p_con = primary.connect() + p_cur = p_con.cursor() + p_cur.execute("begin") + p_cur.execute("insert into t (pk) values (%s)", (i,)) + cursors.append(p_cur) + + for p_cur in cursors: + p_cur.execute("commit") + + wait_replica_caughtup(primary, secondary) + s_cur.execute("select count(*) from t") + assert s_cur.fetchall()[0][0] == 10 + + +def test_physical_replication_config_mismatch_max_prepared(neon_simple_env: NeonEnv): + """ + Test for primary and replica with different configuration settings (max_prepared_transactions). + If number of transactions at primary exceeds its limit at replica then WAL replay is terminated. 
+ """ + env = neon_simple_env + primary = env.endpoints.create_start( + branch_name="main", + endpoint_id="primary", + config_lines=["max_prepared_transactions=10"], + ) + p_con = primary.connect() + p_cur = p_con.cursor() + p_cur.execute("CREATE TABLE t(pk bigint primary key, payload text default repeat('?',200))") + + secondary = env.endpoints.new_replica_start( + origin=primary, + endpoint_id="secondary", + config_lines=["max_prepared_transactions=5"], + ) + wait_replica_caughtup(primary, secondary) + + s_con = secondary.connect() + s_cur = s_con.cursor() + cursors = [] + for i in range(10): + p_con = primary.connect() + p_cur = p_con.cursor() + p_cur.execute("begin") + p_cur.execute("insert into t (pk) values (%s)", (i,)) + p_cur.execute(f"prepare transaction 't{i}'") + cursors.append(p_cur) + + for i in range(10): + cursors[i].execute(f"commit prepared 't{i}'") + + time.sleep(5) + with pytest.raises(Exception) as e: + s_cur.execute("select count(*) from t") + assert s_cur.fetchall()[0][0] == 10 + secondary.stop() + + log.info(f"Replica crashed with {e}") + assert secondary.log_contains("maximum number of prepared transactions reached") + + +def connect(ep): + max_reconnect_attempts = 10 + for _ in range(max_reconnect_attempts): + try: + return ep.connect() + except Exception as e: + log.info(f"Failed to connect with primary: {e}") + time.sleep(1) + + +def test_physical_replication_config_mismatch_too_many_known_xids(neon_simple_env: NeonEnv): + """ + Test for primary and replica with different configuration settings (max_connections). + In this case large difference in this setting and larger number of concurrent transactions at primary + # cause too many known xids error at replica. 
+ """ + env = neon_simple_env + primary = env.endpoints.create_start( + branch_name="main", + endpoint_id="primary", + config_lines=[ + "max_connections=1000", + "shared_buffers=128MB", # prevent "no unpinned buffers available" error + ], + ) + secondary = env.endpoints.new_replica_start( + origin=primary, + endpoint_id="secondary", + config_lines=[ + "max_connections=2", + "autovacuum_max_workers=1", + "max_worker_processes=5", + "max_wal_senders=1", + "superuser_reserved_connections=0", + ], + ) + + p_con = primary.connect() + p_cur = p_con.cursor() + p_cur.execute("CREATE TABLE t(x integer)") + + n_connections = 990 + cursors = [] + for i in range(n_connections): + p_con = connect(primary) + p_cur = p_con.cursor() + p_cur.execute("begin") + p_cur.execute(f"insert into t values({i})") + cursors.append(p_cur) + + for cur in cursors: + cur.execute("commit") + + time.sleep(5) + with pytest.raises(Exception) as e: + s_con = secondary.connect() + s_cur = s_con.cursor() + s_cur.execute("select count(*) from t") + assert s_cur.fetchall()[0][0] == n_connections + secondary.stop() + + log.info(f"Replica crashed with {e}") + assert secondary.log_contains("too many KnownAssignedXids") + + +def test_physical_replication_config_mismatch_max_locks_per_transaction(neon_simple_env: NeonEnv): + """ + Test for primary and replica with different configuration settings (max_locks_per_transaction). + In conjunction with different number of max_connections at primary and standby it can cause "out of shared memory" + error if the primary obtains more AccessExclusiveLocks than the standby can hold. 
+ """ + env = neon_simple_env + primary = env.endpoints.create_start( + branch_name="main", + endpoint_id="primary", + config_lines=[ + "max_locks_per_transaction = 100", + ], + ) + secondary = env.endpoints.new_replica_start( + origin=primary, + endpoint_id="secondary", + config_lines=[ + "max_connections=10", + "max_locks_per_transaction = 10", + ], + ) + + n_tables = 1000 + + p_con = primary.connect() + p_cur = p_con.cursor() + p_cur.execute("begin") + for i in range(n_tables): + p_cur.execute(f"CREATE TABLE t_{i}(x integer)") + p_cur.execute("commit") + + with pytest.raises(Exception) as e: + wait_replica_caughtup(primary, secondary) + secondary.stop() + + log.info(f"Replica crashed with {e}") + assert secondary.log_contains("You might need to increase") diff --git a/vendor/postgres-v14 b/vendor/postgres-v14 index c1989c934d..373f9decad 160000 --- a/vendor/postgres-v14 +++ b/vendor/postgres-v14 @@ -1 +1 @@ -Subproject commit c1989c934d46e04e78b3c496c8a34bcd40ddceeb +Subproject commit 373f9decad933d2d46f321231032ae8b0da81acd diff --git a/vendor/postgres-v15 b/vendor/postgres-v15 index d929b9a8b9..972e325e62 160000 --- a/vendor/postgres-v15 +++ b/vendor/postgres-v15 @@ -1 +1 @@ -Subproject commit d929b9a8b9f32f6fe5a0eac3e6e963f0e44e27e6 +Subproject commit 972e325e62b455957adbbdd8580e31275bb5b8c9 diff --git a/vendor/postgres-v16 b/vendor/postgres-v16 index 13e9e35394..dff6615a8e 160000 --- a/vendor/postgres-v16 +++ b/vendor/postgres-v16 @@ -1 +1 @@ -Subproject commit 13e9e3539419003e79bd9aa29e1bc44f3fd555dd +Subproject commit dff6615a8e48a10bb17a03fa3c00635f1ace7a92 diff --git a/vendor/postgres-v17 b/vendor/postgres-v17 index faebe5e5af..a10d95be67 160000 --- a/vendor/postgres-v17 +++ b/vendor/postgres-v17 @@ -1 +1 @@ -Subproject commit faebe5e5aff5687908504453623778f8515529db +Subproject commit a10d95be67265e0f10a422ba0457f5a7af01de71 diff --git a/vendor/revisions.json b/vendor/revisions.json index abeddcadf7..8a73e14dcf 100644 --- a/vendor/revisions.json +++ 
b/vendor/revisions.json @@ -1,18 +1,18 @@ { "v17": [ "17.2", - "faebe5e5aff5687908504453623778f8515529db" + "a10d95be67265e0f10a422ba0457f5a7af01de71" ], "v16": [ "16.6", - "13e9e3539419003e79bd9aa29e1bc44f3fd555dd" + "dff6615a8e48a10bb17a03fa3c00635f1ace7a92" ], "v15": [ "15.10", - "d929b9a8b9f32f6fe5a0eac3e6e963f0e44e27e6" + "972e325e62b455957adbbdd8580e31275bb5b8c9" ], "v14": [ "14.15", - "c1989c934d46e04e78b3c496c8a34bcd40ddceeb" + "373f9decad933d2d46f321231032ae8b0da81acd" ] } From fae8e7ba76b134a06f7175eadb71c87038a1b399 Mon Sep 17 00:00:00 2001 From: Alexander Bayandin Date: Sun, 1 Dec 2024 13:04:37 +0000 Subject: [PATCH 12/15] Compute image: prepare Postgres v14-v16 for Debian 12 (#9954) ## Problem Current compute images for Postgres 14-16 don't build on Debian 12 because of issues with extensions. This PR fixes that, but for the current setup, it is mostly a no-op change. ## Summary of changes - Use `/bin/bash -euo pipefail` as SHELL to fail earlier - Fix `plv8` build: backport a trivial patch for v8 - Fix `postgis` build: depend `sfgal` version on Debian version instead of Postgres version Tested in: https://github.com/neondatabase/neon/pull/9849 --- compute/compute-node.Dockerfile | 17 ++++++++----- compute/patches/plv8-3.1.10.patch | 42 +++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 6 deletions(-) create mode 100644 compute/patches/plv8-3.1.10.patch diff --git a/compute/compute-node.Dockerfile b/compute/compute-node.Dockerfile index 2fcd9985bc..9567018053 100644 --- a/compute/compute-node.Dockerfile +++ b/compute/compute-node.Dockerfile @@ -14,6 +14,9 @@ ARG DEBIAN_FLAVOR=${DEBIAN_VERSION}-slim FROM debian:$DEBIAN_FLAVOR AS build-deps ARG DEBIAN_VERSION +# Use strict mode for bash to catch errors early +SHELL ["/bin/bash", "-euo", "pipefail", "-c"] + RUN case $DEBIAN_VERSION in \ # Version-specific installs for Bullseye (PG14-PG16): # The h3_pg extension needs a cmake 3.20+, but Debian bullseye has 3.18. 
@@ -106,6 +109,7 @@ RUN cd postgres && \ # ######################################################################################### FROM build-deps AS postgis-build +ARG DEBIAN_VERSION ARG PG_VERSION COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/ RUN apt update && \ @@ -122,12 +126,12 @@ RUN apt update && \ # and also we must check backward compatibility with older versions of PostGIS. # # Use new version only for v17 -RUN case "${PG_VERSION}" in \ - "v17") \ +RUN case "${DEBIAN_VERSION}" in \ + "bookworm") \ export SFCGAL_VERSION=1.4.1 \ export SFCGAL_CHECKSUM=1800c8a26241588f11cddcf433049e9b9aea902e923414d2ecef33a3295626c3 \ ;; \ - "v14" | "v15" | "v16") \ + "bullseye") \ export SFCGAL_VERSION=1.3.10 \ export SFCGAL_CHECKSUM=4e39b3b2adada6254a7bdba6d297bb28e1a9835a9f879b74f37e2dab70203232 \ ;; \ @@ -228,6 +232,8 @@ FROM build-deps AS plv8-build ARG PG_VERSION COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/ +COPY compute/patches/plv8-3.1.10.patch /plv8-3.1.10.patch + RUN apt update && \ apt install --no-install-recommends -y ninja-build python3-dev libncurses5 binutils clang @@ -239,8 +245,6 @@ RUN apt update && \ # # Use new version only for v17 # because since v3.2, plv8 doesn't include plcoffee and plls extensions -ENV PLV8_TAG=v3.2.3 - RUN case "${PG_VERSION}" in \ "v17") \ export PLV8_TAG=v3.2.3 \ @@ -255,8 +259,9 @@ RUN case "${PG_VERSION}" in \ git clone --recurse-submodules --depth 1 --branch ${PLV8_TAG} https://github.com/plv8/plv8.git plv8-src && \ tar -czf plv8.tar.gz --exclude .git plv8-src && \ cd plv8-src && \ + if [[ "${PG_VERSION}" < "v17" ]]; then patch -p1 < /plv8-3.1.10.patch; fi && \ # generate and copy upgrade scripts - mkdir -p upgrade && ./generate_upgrade.sh 3.1.10 && \ + mkdir -p upgrade && ./generate_upgrade.sh ${PLV8_TAG#v} && \ cp upgrade/* /usr/local/pgsql/share/extension/ && \ export PATH="/usr/local/pgsql/bin:$PATH" && \ make DOCKER=1 -j $(getconf _NPROCESSORS_ONLN) install && \ diff --git 
a/compute/patches/plv8-3.1.10.patch b/compute/patches/plv8-3.1.10.patch new file mode 100644 index 0000000000..43cdb479f7 --- /dev/null +++ b/compute/patches/plv8-3.1.10.patch @@ -0,0 +1,42 @@ +commit 46b38d3e46f9cd6c70d9b189dd6ff4abaa17cf5e +Author: Alexander Bayandin +Date: Sat Nov 30 18:29:32 2024 +0000 + + Fix v8 9.7.37 compilation on Debian 12 + +diff --git a/patches/code/84cf3230a9680aac3b73c410c2b758760b6d3066.patch b/patches/code/84cf3230a9680aac3b73c410c2b758760b6d3066.patch +new file mode 100644 +index 0000000..f0a5dc7 +--- /dev/null ++++ b/patches/code/84cf3230a9680aac3b73c410c2b758760b6d3066.patch +@@ -0,0 +1,30 @@ ++From 84cf3230a9680aac3b73c410c2b758760b6d3066 Mon Sep 17 00:00:00 2001 ++From: Michael Lippautz ++Date: Thu, 27 Jan 2022 14:14:11 +0100 ++Subject: [PATCH] cppgc: Fix include ++ ++Add to cover for std::exchange. ++ ++Bug: v8:12585 ++Change-Id: Ida65144e93e466be8914527d0e646f348c136bcb ++Reviewed-on: https://chromium-review.googlesource.com/c/v8/v8/+/3420309 ++Auto-Submit: Michael Lippautz ++Reviewed-by: Omer Katz ++Commit-Queue: Michael Lippautz ++Cr-Commit-Position: refs/heads/main@{#78820} ++--- ++ src/heap/cppgc/prefinalizer-handler.h | 1 + ++ 1 file changed, 1 insertion(+) ++ ++diff --git a/src/heap/cppgc/prefinalizer-handler.h b/src/heap/cppgc/prefinalizer-handler.h ++index bc17c99b1838..c82c91ff5a45 100644 ++--- a/src/heap/cppgc/prefinalizer-handler.h +++++ b/src/heap/cppgc/prefinalizer-handler.h ++@@ -5,6 +5,7 @@ ++ #ifndef V8_HEAP_CPPGC_PREFINALIZER_HANDLER_H_ ++ #define V8_HEAP_CPPGC_PREFINALIZER_HANDLER_H_ ++ +++#include ++ #include ++ ++ #include "include/cppgc/prefinalizer.h" From aad809b048afbc86c2ffe48461e75f9e6d6fe3fb Mon Sep 17 00:00:00 2001 From: Konstantin Knizhnik Date: Sun, 1 Dec 2024 17:47:28 +0200 Subject: [PATCH 13/15] Fix issues with prefetch ring buffer resize (#9847) ## Problem See https://neondb.slack.com/archives/C04DGM6SMTM/p1732110190129479 We observe the following error in the logs ``` [XX000] ERROR: 
[NEON_SMGR] [shard 3] Incorrect prefetch read: status=1 response=0x7fafef335138 my=128 receive=128 ``` most likely caused by changing `neon.readahead_buffer_size` ## Summary of changes 1. Copy shard state 2. Do not use prefetch_set_unused in readahead_buffer_resize 3. Change prefetch buffer overflow criteria --------- Co-authored-by: Konstantin Knizhnik --- pgxn/neon/pagestore_smgr.c | 13 +++++- .../regress/test_prefetch_buffer_resize.py | 40 +++++++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) create mode 100644 test_runner/regress/test_prefetch_buffer_resize.py diff --git a/pgxn/neon/pagestore_smgr.c b/pgxn/neon/pagestore_smgr.c index cbb0e2ae6d..a5e0c402fb 100644 --- a/pgxn/neon/pagestore_smgr.c +++ b/pgxn/neon/pagestore_smgr.c @@ -439,6 +439,8 @@ readahead_buffer_resize(int newsize, void *extra) newPState->ring_unused = newsize; newPState->ring_receive = newsize; newPState->ring_flush = newsize; + newPState->max_shard_no = MyPState->max_shard_no; + memcpy(newPState->shard_bitmap, MyPState->shard_bitmap, sizeof(MyPState->shard_bitmap)); /* * Copy over the prefetches. @@ -495,7 +497,11 @@ readahead_buffer_resize(int newsize, void *extra) for (; end >= MyPState->ring_last && end != UINT64_MAX; end -= 1) { - prefetch_set_unused(end); + PrefetchRequest *slot = GetPrfSlot(end); + if (slot->status == PRFS_RECEIVED) + { + pfree(slot->response); + } } prfh_destroy(MyPState->prf_hash); @@ -944,6 +950,9 @@ Retry: Assert(entry == NULL); Assert(slot == NULL); + /* There should be no buffer overflow */ + Assert(MyPState->ring_last + readahead_buffer_size >= MyPState->ring_unused); + /* * If the prefetch queue is full, we need to make room by clearing the * oldest slot. 
If the oldest slot holds a buffer that was already @@ -958,7 +967,7 @@ Retry: * a prefetch request kind of goes against the principles of * prefetching) */ - if (MyPState->ring_last + readahead_buffer_size - 1 == MyPState->ring_unused) + if (MyPState->ring_last + readahead_buffer_size == MyPState->ring_unused) { uint64 cleanup_index = MyPState->ring_last; diff --git a/test_runner/regress/test_prefetch_buffer_resize.py b/test_runner/regress/test_prefetch_buffer_resize.py new file mode 100644 index 0000000000..7676b78b0e --- /dev/null +++ b/test_runner/regress/test_prefetch_buffer_resize.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import random + +import pytest +from fixtures.neon_fixtures import NeonEnvBuilder + + +@pytest.mark.parametrize("shard_count", [None, 4]) +@pytest.mark.timeout(600) +def test_prefetch(neon_env_builder: NeonEnvBuilder, shard_count: int | None): + if shard_count is not None: + neon_env_builder.num_pageservers = shard_count + env = neon_env_builder.init_start( + initial_tenant_shard_count=shard_count, + ) + n_iter = 10 + n_rec = 100000 + + endpoint = env.endpoints.create_start( + "main", + config_lines=[ + "shared_buffers=10MB", + ], + ) + + cur = endpoint.connect().cursor() + + cur.execute("CREATE TABLE t(pk integer, filler text default repeat('?', 200))") + cur.execute(f"insert into t (pk) values (generate_series(1,{n_rec}))") + + cur.execute("set statement_timeout=0") + cur.execute("set effective_io_concurrency=20") + cur.execute("set max_parallel_workers_per_gather=0") + + for _ in range(n_iter): + buf_size = random.randrange(16, 32) + cur.execute(f"set neon.readahead_buffer_size={buf_size}") + limit = random.randrange(1, n_rec) + cur.execute(f"select sum(pk) from (select pk from t limit {limit}) s") From 14853a32846fdf0c571f0f13c9d83315b2974dd6 Mon Sep 17 00:00:00 2001 From: John Spray Date: Sun, 1 Dec 2024 18:09:58 +0000 Subject: [PATCH 14/15] storcon: don't take any Service locks in /status and /ready (#9944) ## Problem 
We saw unexpected container terminations when running in k8s with small CPU resource requests. The /status and /ready handlers called `maybe_forward`, which always takes the lock on Service::inner. If there is a lot of writer lock contention, and the container is starved of CPU, this increases the likelihood that we will get killed by the kubelet. It isn't certain that this was a cause of issues, but it is a potential source that we can eliminate. ## Summary of changes - Revise logic to return immediately if the URL is in the non-forwarded list, rather than calling maybe_forward --- storage_controller/src/http.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/storage_controller/src/http.rs b/storage_controller/src/http.rs index 9b5d4caf31..39e078ba7c 100644 --- a/storage_controller/src/http.rs +++ b/storage_controller/src/http.rs @@ -1452,10 +1452,15 @@ async fn maybe_forward(req: Request) -> ForwardOutcome { let uri = req.uri().to_string(); let uri_for_forward = !NOT_FOR_FORWARD.contains(&uri.as_str()); + // Fast return before trying to take any Service locks, if we will never forward anyway + if !uri_for_forward { + return ForwardOutcome::NotForwarded(req); + } + let state = get_state(&req); let leadership_status = state.service.get_leadership_status(); - if leadership_status != LeadershipStatus::SteppedDown || !uri_for_forward { + if leadership_status != LeadershipStatus::SteppedDown { return ForwardOutcome::NotForwarded(req); } From 304af5c9e3e490ca8b96bdbff26e63c3266b3a42 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 2 Dec 2024 06:05:37 +0000 Subject: [PATCH 15/15] Storage & Compute release 2024-12-02