Compare commits

..

1 Commits

Author SHA1 Message Date
Arseny Sher
ee46acee26 Add safekeeper option to patch control file.
https://github.com/neondatabase/neon/issues/6397
2024-01-21 00:22:30 +03:00
66 changed files with 1478 additions and 2527 deletions

View File

@@ -1,2 +1,2 @@
[profile.default]
slow-timeout = { period = "20s", terminate-after = 3 }
slow-timeout = "1m"

191
Cargo.lock generated
View File

@@ -10,9 +10,9 @@ checksum = "8b5ace29ee3216de37c0546865ad08edef58b0f9e76838ed8959a84a990e58c5"
[[package]]
name = "addr2line"
version = "0.21.0"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb"
checksum = "a76fd60b23679b7d19bd066031410fb7e458ccc5e958eb5c325888ce4baedc97"
dependencies = [
"gimli",
]
@@ -840,15 +840,15 @@ dependencies = [
[[package]]
name = "backtrace"
version = "0.3.69"
version = "0.3.67"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837"
checksum = "233d376d6d185f2a3093e58f283f60f880315b6c60075b01f36b3b85154564ca"
dependencies = [
"addr2line",
"cc",
"cfg-if",
"libc",
"miniz_oxide",
"miniz_oxide 0.6.2",
"object",
"rustc-demangle",
]
@@ -1215,7 +1215,7 @@ dependencies = [
"flate2",
"futures",
"hyper",
"nix 0.27.1",
"nix 0.26.2",
"notify",
"num_cpus",
"opentelemetry",
@@ -1331,7 +1331,7 @@ dependencies = [
"git-version",
"hex",
"hyper",
"nix 0.27.1",
"nix 0.26.2",
"once_cell",
"pageserver_api",
"pageserver_client",
@@ -1872,13 +1872,13 @@ dependencies = [
[[package]]
name = "filetime"
version = "0.2.22"
version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d4029edd3e734da6fe05b6cd7bd2960760a616bd2ddd0d59a0124746d6272af0"
checksum = "5cbc844cecaee9d4443931972e1289c8ff485cb4cc2767cb03ca139ed6885153"
dependencies = [
"cfg-if",
"libc",
"redox_syscall 0.3.5",
"redox_syscall 0.2.16",
"windows-sys 0.48.0",
]
@@ -1895,7 +1895,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b9429470923de8e8cbd4d2dc513535400b4b3fef0319fb5c4e1f520a7bef743"
dependencies = [
"crc32fast",
"miniz_oxide",
"miniz_oxide 0.7.1",
]
[[package]]
@@ -2093,9 +2093,9 @@ dependencies = [
[[package]]
name = "gimli"
version = "0.28.1"
version = "0.27.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
checksum = "ad0a93d233ebf96623465aad4046a8d3aa4da22d4f4beba5388838c8a434bbb4"
[[package]]
name = "git-version"
@@ -2748,18 +2748,18 @@ checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167"
[[package]]
name = "memoffset"
version = "0.8.0"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1"
checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4"
dependencies = [
"autocfg",
]
[[package]]
name = "memoffset"
version = "0.9.0"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c"
checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1"
dependencies = [
"autocfg",
]
@@ -2797,6 +2797,15 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "miniz_oxide"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b275950c28b37e794e8c55d88aeb5e139d0ce23fdbbeda68f8d7174abdf9e8fa"
dependencies = [
"adler",
]
[[package]]
name = "miniz_oxide"
version = "0.7.1"
@@ -2856,14 +2865,16 @@ dependencies = [
[[package]]
name = "nix"
version = "0.27.1"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053"
checksum = "bfdda3d196821d6af13126e40375cdf7da646a96114af134d5f417a9a1dc8e1a"
dependencies = [
"bitflags 2.4.1",
"bitflags 1.3.2",
"cfg-if",
"libc",
"memoffset 0.9.0",
"memoffset 0.7.1",
"pin-utils",
"static_assertions",
]
[[package]]
@@ -2878,21 +2889,20 @@ dependencies = [
[[package]]
name = "notify"
version = "6.1.1"
version = "5.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6205bd8bb1e454ad2e27422015fb5e4f2bcc7e08fa8f27058670d208324a4d2d"
checksum = "729f63e1ca555a43fe3efa4f3efdf4801c479da85b432242a7b726f353c88486"
dependencies = [
"bitflags 2.4.1",
"bitflags 1.3.2",
"crossbeam-channel",
"filetime",
"fsevent-sys",
"inotify 0.9.6",
"kqueue",
"libc",
"log",
"mio",
"walkdir",
"windows-sys 0.48.0",
"windows-sys 0.45.0",
]
[[package]]
@@ -3018,9 +3028,9 @@ dependencies = [
[[package]]
name = "object"
version = "0.32.2"
version = "0.30.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441"
checksum = "ea86265d3d3dcb6a27fc51bd29a4bf387fae9d2986b823079d4986af253eb439"
dependencies = [
"memchr",
]
@@ -3092,9 +3102,9 @@ dependencies = [
[[package]]
name = "opentelemetry"
version = "0.20.0"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9591d937bc0e6d2feb6f71a559540ab300ea49955229c347a517a28d27784c54"
checksum = "5f4b8347cc26099d3aeee044065ecc3ae11469796b4d65d065a23a584ed92a6f"
dependencies = [
"opentelemetry_api",
"opentelemetry_sdk",
@@ -3102,9 +3112,9 @@ dependencies = [
[[package]]
name = "opentelemetry-http"
version = "0.9.0"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7594ec0e11d8e33faf03530a4c49af7064ebba81c1480e01be67d90b356508b"
checksum = "a819b71d6530c4297b49b3cae2939ab3a8cc1b9f382826a1bc29dd0ca3864906"
dependencies = [
"async-trait",
"bytes",
@@ -3115,56 +3125,54 @@ dependencies = [
[[package]]
name = "opentelemetry-otlp"
version = "0.13.0"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e5e5a5c4135864099f3faafbe939eb4d7f9b80ebf68a8448da961b32a7c1275"
checksum = "8af72d59a4484654ea8eb183fea5ae4eb6a41d7ac3e3bae5f4d2a282a3a7d3ca"
dependencies = [
"async-trait",
"futures-core",
"futures",
"futures-util",
"http",
"opentelemetry",
"opentelemetry-http",
"opentelemetry-proto",
"opentelemetry-semantic-conventions",
"opentelemetry_api",
"opentelemetry_sdk",
"prost",
"reqwest",
"thiserror",
"tokio",
"tonic",
]
[[package]]
name = "opentelemetry-proto"
version = "0.3.0"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1e3f814aa9f8c905d0ee4bde026afd3b2577a97c10e1699912e3e44f0c4cbeb"
checksum = "045f8eea8c0fa19f7d48e7bc3128a39c2e5c533d5c61298c548dfefc1064474c"
dependencies = [
"opentelemetry_api",
"opentelemetry_sdk",
"futures",
"futures-util",
"opentelemetry",
"prost",
"tonic",
"tonic 0.8.3",
]
[[package]]
name = "opentelemetry-semantic-conventions"
version = "0.12.0"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73c9f9340ad135068800e7f1b24e9e09ed9e7143f5bf8518ded3d3ec69789269"
checksum = "24e33428e6bf08c6f7fcea4ddb8e358fab0fe48ab877a87c70c6ebe20f673ce5"
dependencies = [
"opentelemetry",
]
[[package]]
name = "opentelemetry_api"
version = "0.20.0"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a81f725323db1b1206ca3da8bb19874bbd3f57c3bcd59471bfb04525b265b9b"
checksum = "ed41783a5bf567688eb38372f2b7a8530f5a607a4b49d38dd7573236c23ca7e2"
dependencies = [
"fnv",
"futures-channel",
"futures-util",
"indexmap 1.9.3",
"js-sys",
"once_cell",
"pin-project-lite",
"thiserror",
@@ -3173,22 +3181,21 @@ dependencies = [
[[package]]
name = "opentelemetry_sdk"
version = "0.20.0"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa8e705a0612d48139799fcbaba0d4a90f06277153e43dd2bdc16c6f0edd8026"
checksum = "8b3a2a91fdbfdd4d212c0dcc2ab540de2c2bcbbd90be17de7a7daf8822d010c1"
dependencies = [
"async-trait",
"crossbeam-channel",
"dashmap",
"fnv",
"futures-channel",
"futures-executor",
"futures-util",
"once_cell",
"opentelemetry_api",
"ordered-float 3.9.2",
"percent-encoding",
"rand 0.8.5",
"regex",
"serde_json",
"thiserror",
"tokio",
"tokio-stream",
@@ -3203,15 +3210,6 @@ dependencies = [
"num-traits",
]
[[package]]
name = "ordered-float"
version = "3.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1e1c390732d15f1d48471625cd92d154e66db2c56645e29a9cd26f4699f72dc"
dependencies = [
"num-traits",
]
[[package]]
name = "ordered-multimap"
version = "0.7.1"
@@ -3327,7 +3325,7 @@ dependencies = [
"itertools",
"md5",
"metrics",
"nix 0.27.1",
"nix 0.26.2",
"num-traits",
"num_cpus",
"once_cell",
@@ -4341,9 +4339,9 @@ dependencies = [
[[package]]
name = "reqwest-tracing"
version = "0.4.7"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a0152176687dd5cfe7f507ac1cb1a491c679cfe483afd133a7db7aaea818bb3"
checksum = "1b97ad83c2fc18113346b7158d79732242002427c30f620fa817c1f32901e0a8"
dependencies = [
"anyhow",
"async-trait",
@@ -5033,9 +5031,9 @@ dependencies = [
[[package]]
name = "shlex"
version = "1.3.0"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
checksum = "43b2853a4d09f215c24cc5489c992ce46052d359b5109343cbafbf26bc62f8a3"
[[package]]
name = "signal-hook"
@@ -5197,7 +5195,7 @@ dependencies = [
"prost",
"tokio",
"tokio-stream",
"tonic",
"tonic 0.9.2",
"tonic-build",
"tracing",
"utils",
@@ -5417,7 +5415,7 @@ checksum = "7e54bc85fc7faa8bc175c4bab5b92ba8d9a3ce893d0e9f42cc455c8ab16a9e09"
dependencies = [
"byteorder",
"integer-encoding",
"ordered-float 2.10.1",
"ordered-float",
]
[[package]]
@@ -5683,6 +5681,38 @@ dependencies = [
"winnow",
]
[[package]]
name = "tonic"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f219fad3b929bef19b1f86fbc0358d35daed8f2cac972037ac0dc10bbb8d5fb"
dependencies = [
"async-stream",
"async-trait",
"axum",
"base64 0.13.1",
"bytes",
"futures-core",
"futures-util",
"h2",
"http",
"http-body",
"hyper",
"hyper-timeout",
"percent-encoding",
"pin-project",
"prost",
"prost-derive",
"tokio",
"tokio-stream",
"tokio-util",
"tower",
"tower-layer",
"tower-service",
"tracing",
"tracing-futures",
]
[[package]]
name = "tonic"
version = "0.9.2"
@@ -5826,6 +5856,16 @@ dependencies = [
"tracing-subscriber",
]
[[package]]
name = "tracing-futures"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97d095ae15e245a057c8e8451bab9b3ee1e1f68e9ba2b4fbc18d0ac5237835f2"
dependencies = [
"pin-project",
"tracing",
]
[[package]]
name = "tracing-log"
version = "0.1.3"
@@ -5839,9 +5879,9 @@ dependencies = [
[[package]]
name = "tracing-opentelemetry"
version = "0.20.0"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc09e402904a5261e42cf27aea09ccb7d5318c6717a9eec3d8e2e65c56b18f19"
checksum = "00a39dcf9bfc1742fa4d6215253b33a6e474be78275884c216fc2a06267b3600"
dependencies = [
"once_cell",
"opentelemetry",
@@ -6078,7 +6118,7 @@ dependencies = [
"hyper",
"jsonwebtoken",
"metrics",
"nix 0.27.1",
"nix 0.26.2",
"once_cell",
"pin-project-lite",
"postgres_connection",
@@ -6586,8 +6626,10 @@ dependencies = [
"clap",
"clap_builder",
"crossbeam-utils",
"dashmap",
"either",
"fail",
"futures",
"futures-channel",
"futures-core",
"futures-executor",
@@ -6632,7 +6674,6 @@ dependencies = [
"tokio-util",
"toml_datetime",
"toml_edit",
"tonic",
"tower",
"tracing",
"tracing-core",

View File

@@ -99,14 +99,14 @@ libc = "0.2"
md5 = "0.7.0"
memoffset = "0.8"
native-tls = "0.2"
nix = { version = "0.27", features = ["fs", "process", "socket", "signal", "poll"] }
notify = "6.0.0"
nix = "0.26"
notify = "5.0.0"
num_cpus = "1.15"
num-traits = "0.2.15"
once_cell = "1.13"
opentelemetry = "0.20.0"
opentelemetry-otlp = { version = "0.13.0", default_features=false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry-semantic-conventions = "0.12.0"
opentelemetry = "0.19.0"
opentelemetry-otlp = { version = "0.12.0", default_features=false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry-semantic-conventions = "0.11.0"
parking_lot = "0.12"
parquet = { version = "49.0.0", default-features = false, features = ["zstd"] }
parquet_derive = "49.0.0"
@@ -118,7 +118,7 @@ rand = "0.8"
redis = { version = "0.24.0", features = ["tokio-rustls-comp", "keep-alive"] }
regex = "1.10.2"
reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] }
reqwest-tracing = { version = "0.4.7", features = ["opentelemetry_0_20"] }
reqwest-tracing = { version = "0.4.0", features = ["opentelemetry_0_19"] }
reqwest-middleware = "0.2.0"
reqwest-retry = "0.2.2"
routerify = "3"
@@ -162,7 +162,7 @@ toml_edit = "0.19"
tonic = {version = "0.9", features = ["tls", "tls-roots"]}
tracing = "0.1"
tracing-error = "0.2.0"
tracing-opentelemetry = "0.20.0"
tracing-opentelemetry = "0.19.0"
tracing-subscriber = { version = "0.3", default_features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] }
url = "2.2"
uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] }

View File

@@ -143,8 +143,6 @@ RUN wget https://github.com/pgRouting/pgrouting/archive/v3.4.2.tar.gz -O pgrouti
#########################################################################################
FROM build-deps AS plv8-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
ARG PG_VERSION
RUN apt update && \
apt install -y ninja-build python3-dev libncurses5 binutils clang
@@ -619,7 +617,6 @@ RUN wget https://github.com/theory/pg-semver/archive/refs/tags/v0.32.1.tar.gz -O
FROM build-deps AS pg-embedding-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
ARG PG_VERSION
ENV PATH "/usr/local/pgsql/bin/:$PATH"
RUN case "${PG_VERSION}" in \
"v14" | "v15") \
@@ -782,8 +779,6 @@ RUN wget https://github.com/eulerto/wal2json/archive/refs/tags/wal2json_2_5.tar.
#
#########################################################################################
FROM build-deps AS neon-pg-ext-build
ARG PG_VERSION
# Public extensions
COPY --from=postgis-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=postgis-build /sfcgal/* /

View File

@@ -700,14 +700,13 @@ impl ComputeNode {
// In this case we need to connect with old `zenith_admin` name
// and create new user. We cannot simply rename connected user,
// but we can create a new one and grant it all privileges.
let connstr = self.connstr.clone();
let mut client = match Client::connect(connstr.as_str(), NoTls) {
let mut client = match Client::connect(self.connstr.as_str(), NoTls) {
Err(e) => {
info!(
"cannot connect to postgres: {}, retrying with `zenith_admin` username",
e
);
let mut zenith_admin_connstr = connstr.clone();
let mut zenith_admin_connstr = self.connstr.clone();
zenith_admin_connstr
.set_username("zenith_admin")
@@ -720,8 +719,8 @@ impl ComputeNode {
client.simple_query("GRANT zenith_admin TO cloud_admin")?;
drop(client);
// reconnect with connstring with expected name
Client::connect(connstr.as_str(), NoTls)?
// reconnect with connsting with expected name
Client::connect(self.connstr.as_str(), NoTls)?
}
Ok(client) => client,
};
@@ -735,8 +734,8 @@ impl ComputeNode {
cleanup_instance(&mut client)?;
handle_roles(spec, &mut client)?;
handle_databases(spec, &mut client)?;
handle_role_deletions(spec, connstr.as_str(), &mut client)?;
handle_grants(spec, &mut client, connstr.as_str())?;
handle_role_deletions(spec, self.connstr.as_str(), &mut client)?;
handle_grants(spec, &mut client, self.connstr.as_str())?;
handle_extensions(spec, &mut client)?;
handle_extension_neon(&mut client)?;
create_availability_check_data(&mut client)?;
@@ -744,12 +743,6 @@ impl ComputeNode {
// 'Close' connection
drop(client);
if self.has_feature(ComputeFeature::Migrations) {
thread::spawn(move || {
let mut client = Client::connect(connstr.as_str(), NoTls)?;
handle_migrations(&mut client)
});
}
Ok(())
}
@@ -814,10 +807,6 @@ impl ComputeNode {
handle_grants(&spec, &mut client, self.connstr.as_str())?;
handle_extensions(&spec, &mut client)?;
handle_extension_neon(&mut client)?;
// We can skip handle_migrations here because a new migration can only appear
// if we have a new version of the compute_ctl binary, which can only happen
// if compute got restarted, in which case we'll end up inside of apply_config
// instead of reconfigure.
}
// 'Close' connection

View File

@@ -727,79 +727,3 @@ pub fn handle_extension_neon(client: &mut Client) -> Result<()> {
Ok(())
}
#[instrument(skip_all)]
pub fn handle_migrations(client: &mut Client) -> Result<()> {
info!("handle migrations");
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// !BE SURE TO ONLY ADD MIGRATIONS TO THE END OF THIS ARRAY. IF YOU DO NOT, VERY VERY BAD THINGS MAY HAPPEN!
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
let migrations = [
"ALTER ROLE neon_superuser BYPASSRLS",
r#"
DO $$
DECLARE
role_name text;
BEGIN
FOR role_name IN SELECT rolname FROM pg_roles WHERE pg_has_role(rolname, 'neon_superuser', 'member')
LOOP
RAISE NOTICE 'EXECUTING ALTER ROLE % INHERIT', quote_ident(role_name);
EXECUTE 'ALTER ROLE ' || quote_ident(role_name) || ' INHERIT';
END LOOP;
FOR role_name IN SELECT rolname FROM pg_roles
WHERE
NOT pg_has_role(rolname, 'neon_superuser', 'member') AND NOT starts_with(rolname, 'pg_')
LOOP
RAISE NOTICE 'EXECUTING ALTER ROLE % NOBYPASSRLS', quote_ident(role_name);
EXECUTE 'ALTER ROLE ' || quote_ident(role_name) || ' NOBYPASSRLS';
END LOOP;
END $$;
"#,
];
let mut query = "CREATE SCHEMA IF NOT EXISTS neon_migration";
client.simple_query(query)?;
query = "CREATE TABLE IF NOT EXISTS neon_migration.migration_id (key INT NOT NULL PRIMARY KEY, id bigint NOT NULL DEFAULT 0)";
client.simple_query(query)?;
query = "INSERT INTO neon_migration.migration_id VALUES (0, 0) ON CONFLICT DO NOTHING";
client.simple_query(query)?;
query = "ALTER SCHEMA neon_migration OWNER TO cloud_admin";
client.simple_query(query)?;
query = "REVOKE ALL ON SCHEMA neon_migration FROM PUBLIC";
client.simple_query(query)?;
query = "SELECT id FROM neon_migration.migration_id";
let row = client.query_one(query, &[])?;
let mut current_migration: usize = row.get::<&str, i64>("id") as usize;
let starting_migration_id = current_migration;
query = "BEGIN";
client.simple_query(query)?;
while current_migration < migrations.len() {
info!("Running migration:\n{}\n", migrations[current_migration]);
client.simple_query(migrations[current_migration])?;
current_migration += 1;
}
let setval = format!(
"UPDATE neon_migration.migration_id SET id={}",
migrations.len()
);
client.simple_query(&setval)?;
query = "COMMIT";
client.simple_query(query)?;
info!(
"Ran {} migrations",
(migrations.len() - starting_migration_id)
);
Ok(())
}

View File

@@ -184,7 +184,7 @@ impl Persistence {
pub(crate) async fn increment_generation(
&self,
tenant_shard_id: TenantShardId,
node_id: NodeId,
node_id: Option<NodeId>,
) -> anyhow::Result<Generation> {
let (write, gen) = {
let mut locked = self.state.lock().unwrap();
@@ -192,9 +192,14 @@ impl Persistence {
anyhow::bail!("Tried to increment generation of unknown shard");
};
shard.generation += 1;
shard.generation_pageserver = Some(node_id);
// If we're called with a None pageserver, we need only update the generation
// record to disassociate it with this pageserver, not actually increment the number, as
// the increment is guaranteed to happen the next time this tenant is attached.
if node_id.is_some() {
shard.generation += 1;
}
shard.generation_pageserver = node_id;
let gen = Generation::new(shard.generation);
(locked.save(), gen)
};
@@ -203,19 +208,6 @@ impl Persistence {
Ok(gen)
}
pub(crate) async fn detach(&self, tenant_shard_id: TenantShardId) -> anyhow::Result<()> {
let write = {
let mut locked = self.state.lock().unwrap();
let Some(shard) = locked.tenants.get_mut(&tenant_shard_id) else {
anyhow::bail!("Tried to increment generation of unknown shard");
};
shard.generation_pageserver = None;
locked.save()
};
write.commit().await?;
Ok(())
}
pub(crate) async fn re_attach(
&self,
node_id: NodeId,

View File

@@ -296,7 +296,7 @@ impl Reconciler {
// Increment generation before attaching to new pageserver
self.generation = self
.persistence
.increment_generation(self.tenant_shard_id, dest_ps_id)
.increment_generation(self.tenant_shard_id, Some(dest_ps_id))
.await?;
let dest_conf = build_location_config(
@@ -395,7 +395,7 @@ impl Reconciler {
// as locations with unknown (None) observed state.
self.generation = self
.persistence
.increment_generation(self.tenant_shard_id, node_id)
.increment_generation(self.tenant_shard_id, Some(node_id))
.await?;
wanted_conf.generation = self.generation.into();
tracing::info!("Observed configuration requires update.");

View File

@@ -362,14 +362,13 @@ impl Service {
);
}
let new_generation = if let Some(req_node_id) = attach_req.node_id {
let new_generation = if attach_req.node_id.is_some() {
Some(
self.persistence
.increment_generation(attach_req.tenant_shard_id, req_node_id)
.increment_generation(attach_req.tenant_shard_id, attach_req.node_id)
.await?,
)
} else {
self.persistence.detach(attach_req.tenant_shard_id).await?;
None
};
@@ -408,7 +407,6 @@ impl Service {
"attach_hook: tenant {} set generation {:?}, pageserver {}",
attach_req.tenant_shard_id,
tenant_state.generation,
// TODO: this is an odd number of 0xf's
attach_req.node_id.unwrap_or(utils::id::NodeId(0xfffffff))
);

View File

@@ -57,7 +57,7 @@ use crate::local_env::LocalEnv;
use crate::postgresql_conf::PostgresConf;
use compute_api::responses::{ComputeState, ComputeStatus};
use compute_api::spec::{Cluster, ComputeFeature, ComputeMode, ComputeSpec};
use compute_api::spec::{Cluster, ComputeMode, ComputeSpec};
// contents of a endpoint.json file
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
@@ -70,7 +70,6 @@ pub struct EndpointConf {
http_port: u16,
pg_version: u32,
skip_pg_catalog_updates: bool,
features: Vec<ComputeFeature>,
}
//
@@ -141,7 +140,6 @@ impl ComputeControlPlane {
// with this we basically test a case of waking up an idle compute, where
// we also skip catalog updates in the cloud.
skip_pg_catalog_updates: true,
features: vec![],
});
ep.create_endpoint_dir()?;
@@ -156,7 +154,6 @@ impl ComputeControlPlane {
pg_port,
pg_version,
skip_pg_catalog_updates: true,
features: vec![],
})?,
)?;
std::fs::write(
@@ -218,9 +215,6 @@ pub struct Endpoint {
// Optimizations
skip_pg_catalog_updates: bool,
// Feature flags
features: Vec<ComputeFeature>,
}
impl Endpoint {
@@ -250,7 +244,6 @@ impl Endpoint {
tenant_id: conf.tenant_id,
pg_version: conf.pg_version,
skip_pg_catalog_updates: conf.skip_pg_catalog_updates,
features: conf.features,
})
}
@@ -526,7 +519,7 @@ impl Endpoint {
skip_pg_catalog_updates: self.skip_pg_catalog_updates,
format_version: 1.0,
operation_uuid: None,
features: self.features.clone(),
features: vec![],
cluster: Cluster {
cluster_id: None, // project ID: not used
name: None, // project name: not used

View File

@@ -90,9 +90,6 @@ pub enum ComputeFeature {
/// track short-lived connections as user activity.
ActivityMonitorExperimental,
/// Enable running migrations
Migrations,
/// This is a special feature flag that is used to represent unknown feature flags.
/// Basically all unknown to enum flags are represented as this one. See unit test
/// `parse_unknown_features()` for more details.

View File

@@ -1,11 +1,9 @@
use anyhow::{bail, Result};
use byteorder::{ByteOrder, BE};
use postgres_ffi::relfile_utils::{FSM_FORKNUM, VISIBILITYMAP_FORKNUM};
use postgres_ffi::{Oid, TransactionId};
use serde::{Deserialize, Serialize};
use std::{fmt, ops::Range};
use std::fmt;
use crate::reltag::{BlockNumber, RelTag, SlruKind};
use crate::reltag::{BlockNumber, RelTag};
/// Key used in the Repository kv-store.
///
@@ -145,390 +143,12 @@ impl Key {
}
}
// Layout of the Key address space
//
// The Key struct, used to address the underlying key-value store, consists of
// 18 bytes, split into six fields. See 'Key' in repository.rs. We need to map
// all the data and metadata keys into those 18 bytes.
//
// Principles for the mapping:
//
// - Things that are often accessed or modified together, should be close to
// each other in the key space. For example, if a relation is extended by one
// block, we create a new key-value pair for the block data, and update the
// relation size entry. Because of that, the RelSize key comes after all the
// RelBlocks of a relation: the RelSize and the last RelBlock are always next
// to each other.
//
// The key space is divided into four major sections, identified by the first
// byte, and the form a hierarchy:
//
// 00 Relation data and metadata
//
// DbDir () -> (dbnode, spcnode)
// Filenodemap
// RelDir -> relnode forknum
// RelBlocks
// RelSize
//
// 01 SLRUs
//
// SlruDir kind
// SlruSegBlocks segno
// SlruSegSize
//
// 02 pg_twophase
//
// 03 misc
// Controlfile
// checkpoint
// pg_version
//
// 04 aux files
//
// Below is a full list of the keyspace allocation:
//
// DbDir:
// 00 00000000 00000000 00000000 00 00000000
//
// Filenodemap:
// 00 SPCNODE DBNODE 00000000 00 00000000
//
// RelDir:
// 00 SPCNODE DBNODE 00000000 00 00000001 (Postgres never uses relfilenode 0)
//
// RelBlock:
// 00 SPCNODE DBNODE RELNODE FORK BLKNUM
//
// RelSize:
// 00 SPCNODE DBNODE RELNODE FORK FFFFFFFF
//
// SlruDir:
// 01 kind 00000000 00000000 00 00000000
//
// SlruSegBlock:
// 01 kind 00000001 SEGNO 00 BLKNUM
//
// SlruSegSize:
// 01 kind 00000001 SEGNO 00 FFFFFFFF
//
// TwoPhaseDir:
// 02 00000000 00000000 00000000 00 00000000
//
// TwoPhaseFile:
// 02 00000000 00000000 00000000 00 XID
//
// ControlFile:
// 03 00000000 00000000 00000000 00 00000000
//
// Checkpoint:
// 03 00000000 00000000 00000000 00 00000001
//
// AuxFiles:
// 03 00000000 00000000 00000000 00 00000002
//
//-- Section 01: relation data and metadata
pub const DBDIR_KEY: Key = Key {
field1: 0x00,
field2: 0,
field3: 0,
field4: 0,
field5: 0,
field6: 0,
};
#[inline(always)]
pub fn dbdir_key_range(spcnode: Oid, dbnode: Oid) -> Range<Key> {
Key {
field1: 0x00,
field2: spcnode,
field3: dbnode,
field4: 0,
field5: 0,
field6: 0,
}..Key {
field1: 0x00,
field2: spcnode,
field3: dbnode,
field4: 0xffffffff,
field5: 0xff,
field6: 0xffffffff,
}
}
#[inline(always)]
pub fn relmap_file_key(spcnode: Oid, dbnode: Oid) -> Key {
Key {
field1: 0x00,
field2: spcnode,
field3: dbnode,
field4: 0,
field5: 0,
field6: 0,
}
}
#[inline(always)]
pub fn rel_dir_to_key(spcnode: Oid, dbnode: Oid) -> Key {
Key {
field1: 0x00,
field2: spcnode,
field3: dbnode,
field4: 0,
field5: 0,
field6: 1,
}
}
#[inline(always)]
pub fn rel_block_to_key(rel: RelTag, blknum: BlockNumber) -> Key {
Key {
field1: 0x00,
field2: rel.spcnode,
field3: rel.dbnode,
field4: rel.relnode,
field5: rel.forknum,
field6: blknum,
}
}
#[inline(always)]
pub fn rel_size_to_key(rel: RelTag) -> Key {
Key {
field1: 0x00,
field2: rel.spcnode,
field3: rel.dbnode,
field4: rel.relnode,
field5: rel.forknum,
field6: 0xffffffff,
}
}
#[inline(always)]
pub fn rel_key_range(rel: RelTag) -> Range<Key> {
Key {
field1: 0x00,
field2: rel.spcnode,
field3: rel.dbnode,
field4: rel.relnode,
field5: rel.forknum,
field6: 0,
}..Key {
field1: 0x00,
field2: rel.spcnode,
field3: rel.dbnode,
field4: rel.relnode,
field5: rel.forknum + 1,
field6: 0,
}
}
//-- Section 02: SLRUs
#[inline(always)]
pub fn slru_dir_to_key(kind: SlruKind) -> Key {
Key {
field1: 0x01,
field2: match kind {
SlruKind::Clog => 0x00,
SlruKind::MultiXactMembers => 0x01,
SlruKind::MultiXactOffsets => 0x02,
},
field3: 0,
field4: 0,
field5: 0,
field6: 0,
}
}
#[inline(always)]
pub fn slru_block_to_key(kind: SlruKind, segno: u32, blknum: BlockNumber) -> Key {
Key {
field1: 0x01,
field2: match kind {
SlruKind::Clog => 0x00,
SlruKind::MultiXactMembers => 0x01,
SlruKind::MultiXactOffsets => 0x02,
},
field3: 1,
field4: segno,
field5: 0,
field6: blknum,
}
}
#[inline(always)]
pub fn slru_segment_size_to_key(kind: SlruKind, segno: u32) -> Key {
Key {
field1: 0x01,
field2: match kind {
SlruKind::Clog => 0x00,
SlruKind::MultiXactMembers => 0x01,
SlruKind::MultiXactOffsets => 0x02,
},
field3: 1,
field4: segno,
field5: 0,
field6: 0xffffffff,
}
}
#[inline(always)]
pub fn slru_segment_key_range(kind: SlruKind, segno: u32) -> Range<Key> {
let field2 = match kind {
SlruKind::Clog => 0x00,
SlruKind::MultiXactMembers => 0x01,
SlruKind::MultiXactOffsets => 0x02,
};
Key {
field1: 0x01,
field2,
field3: 1,
field4: segno,
field5: 0,
field6: 0,
}..Key {
field1: 0x01,
field2,
field3: 1,
field4: segno,
field5: 1,
field6: 0,
}
}
//-- Section 03: pg_twophase
pub const TWOPHASEDIR_KEY: Key = Key {
field1: 0x02,
field2: 0,
field3: 0,
field4: 0,
field5: 0,
field6: 0,
};
#[inline(always)]
pub fn twophase_file_key(xid: TransactionId) -> Key {
Key {
field1: 0x02,
field2: 0,
field3: 0,
field4: 0,
field5: 0,
field6: xid,
}
}
#[inline(always)]
pub fn twophase_key_range(xid: TransactionId) -> Range<Key> {
let (next_xid, overflowed) = xid.overflowing_add(1);
Key {
field1: 0x02,
field2: 0,
field3: 0,
field4: 0,
field5: 0,
field6: xid,
}..Key {
field1: 0x02,
field2: 0,
field3: 0,
field4: 0,
field5: u8::from(overflowed),
field6: next_xid,
}
}
//-- Section 03: Control file
pub const CONTROLFILE_KEY: Key = Key {
field1: 0x03,
field2: 0,
field3: 0,
field4: 0,
field5: 0,
field6: 0,
};
pub const CHECKPOINT_KEY: Key = Key {
field1: 0x03,
field2: 0,
field3: 0,
field4: 0,
field5: 0,
field6: 1,
};
pub const AUX_FILES_KEY: Key = Key {
field1: 0x03,
field2: 0,
field3: 0,
field4: 0,
field5: 0,
field6: 2,
};
// Reverse mappings for a few Keys.
// These are needed by WAL redo manager.
// AUX_FILES currently stores only data for logical replication (slots etc), and
// we don't preserve these on a branch because safekeepers can't follow timeline
// switch (and generally it likely should be optional), so ignore these.
#[inline(always)]
pub fn is_inherited_key(key: Key) -> bool {
key != AUX_FILES_KEY
}
#[inline(always)]
pub fn is_rel_fsm_block_key(key: Key) -> bool {
key.field1 == 0x00 && key.field4 != 0 && key.field5 == FSM_FORKNUM && key.field6 != 0xffffffff
}
#[inline(always)]
pub fn is_rel_vm_block_key(key: Key) -> bool {
key.field1 == 0x00
&& key.field4 != 0
&& key.field5 == VISIBILITYMAP_FORKNUM
&& key.field6 != 0xffffffff
}
#[inline(always)]
pub fn key_to_slru_block(key: Key) -> anyhow::Result<(SlruKind, u32, BlockNumber)> {
Ok(match key.field1 {
0x01 => {
let kind = match key.field2 {
0x00 => SlruKind::Clog,
0x01 => SlruKind::MultiXactMembers,
0x02 => SlruKind::MultiXactOffsets,
_ => anyhow::bail!("unrecognized slru kind 0x{:02x}", key.field2),
};
let segno = key.field4;
let blknum = key.field6;
(kind, segno, blknum)
}
_ => anyhow::bail!("unexpected value kind 0x{:02x}", key.field1),
})
}
#[inline(always)]
pub fn is_slru_block_key(key: Key) -> bool {
key.field1 == 0x01 // SLRU-related
&& key.field3 == 0x00000001 // but not SlruDir
&& key.field6 != 0xffffffff // and not SlruSegSize
}
#[inline(always)]
pub fn is_rel_block_key(key: &Key) -> bool {
key.field1 == 0x00 && key.field4 != 0 && key.field6 != 0xffffffff
}
/// Guaranteed to return `Ok()` if [[is_rel_block_key]] returns `true` for `key`.
#[inline(always)]
pub fn key_to_rel_block(key: Key) -> anyhow::Result<(RelTag, BlockNumber)> {
Ok(match key.field1 {
0x00 => (

View File

@@ -329,8 +329,8 @@ impl CheckPoint {
///
/// Returns 'true' if the XID was updated.
pub fn update_next_xid(&mut self, xid: u32) -> bool {
// nextXid should be greater than any XID in WAL, so increment provided XID and check for wraparround.
let mut new_xid = std::cmp::max(xid.wrapping_add(1), pg_constants::FIRST_NORMAL_TRANSACTION_ID);
// nextXid should nw greater than any XID in WAL, so increment provided XID and check for wraparround.
let mut new_xid = std::cmp::max(xid + 1, pg_constants::FIRST_NORMAL_TRANSACTION_ID);
// To reduce number of metadata checkpoints, we forward align XID on XID_CHECKPOINT_INTERVAL.
// XID_CHECKPOINT_INTERVAL should not be larger than BLCKSZ*CLOG_XACTS_PER_BYTE
new_xid =

View File

@@ -82,19 +82,6 @@ impl<S> Framed<S> {
write_buf: self.write_buf,
})
}
/// Return new Framed with stream type transformed by f. For dynamic dispatch.
pub fn map_stream_sync<S2, F>(self, f: F) -> Framed<S2>
where
F: FnOnce(S) -> S2,
{
let stream = f(self.stream);
Framed {
stream,
read_buf: self.read_buf,
write_buf: self.write_buf,
}
}
}
impl<S: AsyncRead + Unpin> Framed<S> {

View File

@@ -5,10 +5,10 @@ use std::os::unix::io::RawFd;
pub fn set_nonblock(fd: RawFd) -> Result<(), std::io::Error> {
let bits = fcntl(fd, F_GETFL)?;
// If F_GETFL returns some unknown bits, they should be valid
// Safety: If F_GETFL returns some unknown bits, they should be valid
// for passing back to F_SETFL, too. If we left them out, the F_SETFL
// would effectively clear them, which is not what we want.
let mut flags = OFlag::from_bits_retain(bits);
let mut flags = unsafe { OFlag::from_bits_unchecked(bits) };
flags |= OFlag::O_NONBLOCK;
fcntl(fd, F_SETFL(flags))?;

View File

@@ -1,6 +1,7 @@
use std::{
io,
net::{TcpListener, ToSocketAddrs},
os::unix::prelude::AsRawFd,
};
use nix::sys::socket::{setsockopt, sockopt::ReuseAddr};
@@ -9,7 +10,7 @@ use nix::sys::socket::{setsockopt, sockopt::ReuseAddr};
pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<TcpListener> {
let listener = TcpListener::bind(addr)?;
setsockopt(&listener, ReuseAddr, &true)?;
setsockopt(listener.as_raw_fd(), ReuseAddr, &true)?;
Ok(listener)
}

View File

@@ -61,7 +61,7 @@ use crate::context::{DownloadBehavior, RequestContext};
use crate::import_datadir::import_wal_from_tar;
use crate::metrics;
use crate::metrics::LIVE_CONNECTIONS_COUNT;
use crate::pgdatadir_mapping::Version;
use crate::pgdatadir_mapping::{rel_block_to_key, Version};
use crate::task_mgr;
use crate::task_mgr::TaskKind;
use crate::tenant::debug_assert_current_span_has_tenant_and_timeline_id;
@@ -75,7 +75,6 @@ use crate::tenant::PageReconstructError;
use crate::tenant::Timeline;
use crate::trace::Tracer;
use pageserver_api::key::rel_block_to_key;
use postgres_ffi::pg_constants::DEFAULTTABLESPACE_OID;
use postgres_ffi::BLCKSZ;
@@ -387,18 +386,12 @@ impl PageServerHandler {
/// Future that completes when we need to shut down the connection.
///
/// We currently need to shut down when any of the following happens:
/// 1. any of the timelines we hold GateGuards for in `shard_timelines` is cancelled
/// 2. task_mgr requests shutdown of the connection
/// Reasons for need to shut down are:
/// - any of the timelines we hold GateGuards for in `shard_timelines` is cancelled
/// - task_mgr requests shutdown of the connection
///
/// NB on (1): the connection's lifecycle is not actually tied to any of the
/// `shard_timelines`s' lifecycles. But it's _necessary_ in the current
/// implementation to be responsive to timeline cancellation because
/// the connection holds their `GateGuards` open (sored in `shard_timelines`).
/// We currently do the easy thing and terminate the connection if any of the
/// shard_timelines gets cancelled. But really, we cuold spend more effort
/// and simply remove the cancelled timeline from the `shard_timelines`, thereby
/// dropping the guard.
/// The need to check for `task_mgr` cancellation arises mainly from `handle_pagerequests`
/// where, at first, `shard_timelines` is empty, see <https://github.com/neondatabase/neon/pull/6388>
///
/// NB: keep in sync with [`Self::is_connection_cancelled`]
async fn await_connection_cancelled(&self) {
@@ -411,17 +404,16 @@ impl PageServerHandler {
// immutable &self). So it's fine to evaluate shard_timelines after the sleep, we don't risk
// missing any inserts to the map.
let mut cancellation_sources = Vec::with_capacity(1 + self.shard_timelines.len());
use futures::future::Either;
cancellation_sources.push(Either::Left(task_mgr::shutdown_watcher()));
cancellation_sources.extend(
self.shard_timelines
.values()
.map(|ht| Either::Right(ht.timeline.cancel.cancelled())),
);
FuturesUnordered::from_iter(cancellation_sources)
.next()
.await;
let mut futs = self
.shard_timelines
.values()
.map(|ht| ht.timeline.cancel.cancelled())
.collect::<FuturesUnordered<_>>();
tokio::select! {
_ = task_mgr::shutdown_watcher() => { }
_ = futs.next() => {}
}
}
/// Checking variant of [`Self::await_connection_cancelled`].

View File

@@ -13,12 +13,7 @@ use crate::repository::*;
use crate::walrecord::NeonWalRecord;
use anyhow::{ensure, Context};
use bytes::{Buf, Bytes};
use pageserver_api::key::{
dbdir_key_range, is_rel_block_key, is_slru_block_key, rel_block_to_key, rel_dir_to_key,
rel_key_range, rel_size_to_key, relmap_file_key, slru_block_to_key, slru_dir_to_key,
slru_segment_key_range, slru_segment_size_to_key, twophase_file_key, twophase_key_range,
AUX_FILES_KEY, CHECKPOINT_KEY, CONTROLFILE_KEY, DBDIR_KEY, TWOPHASEDIR_KEY,
};
use pageserver_api::key::is_rel_block_key;
use pageserver_api::reltag::{BlockNumber, RelTag, SlruKind};
use postgres_ffi::relfile_utils::{FSM_FORKNUM, VISIBILITYMAP_FORKNUM};
use postgres_ffi::BLCKSZ;
@@ -1540,6 +1535,366 @@ struct SlruSegmentDirectory {
static ZERO_PAGE: Bytes = Bytes::from_static(&[0u8; BLCKSZ as usize]);
// Layout of the Key address space
//
// The Key struct, used to address the underlying key-value store, consists of
// 18 bytes, split into six fields. See 'Key' in repository.rs. We need to map
// all the data and metadata keys into those 18 bytes.
//
// Principles for the mapping:
//
// - Things that are often accessed or modified together, should be close to
// each other in the key space. For example, if a relation is extended by one
// block, we create a new key-value pair for the block data, and update the
// relation size entry. Because of that, the RelSize key comes after all the
// RelBlocks of a relation: the RelSize and the last RelBlock are always next
// to each other.
//
// The key space is divided into four major sections, identified by the first
// byte, and the form a hierarchy:
//
// 00 Relation data and metadata
//
// DbDir () -> (dbnode, spcnode)
// Filenodemap
// RelDir -> relnode forknum
// RelBlocks
// RelSize
//
// 01 SLRUs
//
// SlruDir kind
// SlruSegBlocks segno
// SlruSegSize
//
// 02 pg_twophase
//
// 03 misc
// Controlfile
// checkpoint
// pg_version
//
// 04 aux files
//
// Below is a full list of the keyspace allocation:
//
// DbDir:
// 00 00000000 00000000 00000000 00 00000000
//
// Filenodemap:
// 00 SPCNODE DBNODE 00000000 00 00000000
//
// RelDir:
// 00 SPCNODE DBNODE 00000000 00 00000001 (Postgres never uses relfilenode 0)
//
// RelBlock:
// 00 SPCNODE DBNODE RELNODE FORK BLKNUM
//
// RelSize:
// 00 SPCNODE DBNODE RELNODE FORK FFFFFFFF
//
// SlruDir:
// 01 kind 00000000 00000000 00 00000000
//
// SlruSegBlock:
// 01 kind 00000001 SEGNO 00 BLKNUM
//
// SlruSegSize:
// 01 kind 00000001 SEGNO 00 FFFFFFFF
//
// TwoPhaseDir:
// 02 00000000 00000000 00000000 00 00000000
//
// TwoPhaseFile:
// 02 00000000 00000000 00000000 00 XID
//
// ControlFile:
// 03 00000000 00000000 00000000 00 00000000
//
// Checkpoint:
// 03 00000000 00000000 00000000 00 00000001
//
// AuxFiles:
// 03 00000000 00000000 00000000 00 00000002
//
//-- Section 01: relation data and metadata
const DBDIR_KEY: Key = Key {
field1: 0x00,
field2: 0,
field3: 0,
field4: 0,
field5: 0,
field6: 0,
};
fn dbdir_key_range(spcnode: Oid, dbnode: Oid) -> Range<Key> {
Key {
field1: 0x00,
field2: spcnode,
field3: dbnode,
field4: 0,
field5: 0,
field6: 0,
}..Key {
field1: 0x00,
field2: spcnode,
field3: dbnode,
field4: 0xffffffff,
field5: 0xff,
field6: 0xffffffff,
}
}
fn relmap_file_key(spcnode: Oid, dbnode: Oid) -> Key {
Key {
field1: 0x00,
field2: spcnode,
field3: dbnode,
field4: 0,
field5: 0,
field6: 0,
}
}
fn rel_dir_to_key(spcnode: Oid, dbnode: Oid) -> Key {
Key {
field1: 0x00,
field2: spcnode,
field3: dbnode,
field4: 0,
field5: 0,
field6: 1,
}
}
pub(crate) fn rel_block_to_key(rel: RelTag, blknum: BlockNumber) -> Key {
Key {
field1: 0x00,
field2: rel.spcnode,
field3: rel.dbnode,
field4: rel.relnode,
field5: rel.forknum,
field6: blknum,
}
}
fn rel_size_to_key(rel: RelTag) -> Key {
Key {
field1: 0x00,
field2: rel.spcnode,
field3: rel.dbnode,
field4: rel.relnode,
field5: rel.forknum,
field6: 0xffffffff,
}
}
fn rel_key_range(rel: RelTag) -> Range<Key> {
Key {
field1: 0x00,
field2: rel.spcnode,
field3: rel.dbnode,
field4: rel.relnode,
field5: rel.forknum,
field6: 0,
}..Key {
field1: 0x00,
field2: rel.spcnode,
field3: rel.dbnode,
field4: rel.relnode,
field5: rel.forknum + 1,
field6: 0,
}
}
//-- Section 02: SLRUs
fn slru_dir_to_key(kind: SlruKind) -> Key {
Key {
field1: 0x01,
field2: match kind {
SlruKind::Clog => 0x00,
SlruKind::MultiXactMembers => 0x01,
SlruKind::MultiXactOffsets => 0x02,
},
field3: 0,
field4: 0,
field5: 0,
field6: 0,
}
}
fn slru_block_to_key(kind: SlruKind, segno: u32, blknum: BlockNumber) -> Key {
Key {
field1: 0x01,
field2: match kind {
SlruKind::Clog => 0x00,
SlruKind::MultiXactMembers => 0x01,
SlruKind::MultiXactOffsets => 0x02,
},
field3: 1,
field4: segno,
field5: 0,
field6: blknum,
}
}
fn slru_segment_size_to_key(kind: SlruKind, segno: u32) -> Key {
Key {
field1: 0x01,
field2: match kind {
SlruKind::Clog => 0x00,
SlruKind::MultiXactMembers => 0x01,
SlruKind::MultiXactOffsets => 0x02,
},
field3: 1,
field4: segno,
field5: 0,
field6: 0xffffffff,
}
}
fn slru_segment_key_range(kind: SlruKind, segno: u32) -> Range<Key> {
let field2 = match kind {
SlruKind::Clog => 0x00,
SlruKind::MultiXactMembers => 0x01,
SlruKind::MultiXactOffsets => 0x02,
};
Key {
field1: 0x01,
field2,
field3: 1,
field4: segno,
field5: 0,
field6: 0,
}..Key {
field1: 0x01,
field2,
field3: 1,
field4: segno,
field5: 1,
field6: 0,
}
}
//-- Section 03: pg_twophase
const TWOPHASEDIR_KEY: Key = Key {
field1: 0x02,
field2: 0,
field3: 0,
field4: 0,
field5: 0,
field6: 0,
};
fn twophase_file_key(xid: TransactionId) -> Key {
Key {
field1: 0x02,
field2: 0,
field3: 0,
field4: 0,
field5: 0,
field6: xid,
}
}
fn twophase_key_range(xid: TransactionId) -> Range<Key> {
let (next_xid, overflowed) = xid.overflowing_add(1);
Key {
field1: 0x02,
field2: 0,
field3: 0,
field4: 0,
field5: 0,
field6: xid,
}..Key {
field1: 0x02,
field2: 0,
field3: 0,
field4: 0,
field5: u8::from(overflowed),
field6: next_xid,
}
}
//-- Section 03: Control file
const CONTROLFILE_KEY: Key = Key {
field1: 0x03,
field2: 0,
field3: 0,
field4: 0,
field5: 0,
field6: 0,
};
const CHECKPOINT_KEY: Key = Key {
field1: 0x03,
field2: 0,
field3: 0,
field4: 0,
field5: 0,
field6: 1,
};
const AUX_FILES_KEY: Key = Key {
field1: 0x03,
field2: 0,
field3: 0,
field4: 0,
field5: 0,
field6: 2,
};
// Reverse mappings for a few Keys.
// These are needed by WAL redo manager.
// AUX_FILES currently stores only data for logical replication (slots etc), and
// we don't preserve these on a branch because safekeepers can't follow timeline
// switch (and generally it likely should be optional), so ignore these.
pub fn is_inherited_key(key: Key) -> bool {
key != AUX_FILES_KEY
}
pub fn is_rel_fsm_block_key(key: Key) -> bool {
key.field1 == 0x00 && key.field4 != 0 && key.field5 == FSM_FORKNUM && key.field6 != 0xffffffff
}
pub fn is_rel_vm_block_key(key: Key) -> bool {
key.field1 == 0x00
&& key.field4 != 0
&& key.field5 == VISIBILITYMAP_FORKNUM
&& key.field6 != 0xffffffff
}
pub fn key_to_slru_block(key: Key) -> anyhow::Result<(SlruKind, u32, BlockNumber)> {
Ok(match key.field1 {
0x01 => {
let kind = match key.field2 {
0x00 => SlruKind::Clog,
0x01 => SlruKind::MultiXactMembers,
0x02 => SlruKind::MultiXactOffsets,
_ => anyhow::bail!("unrecognized slru kind 0x{:02x}", key.field2),
};
let segno = key.field4;
let blknum = key.field6;
(kind, segno, blknum)
}
_ => anyhow::bail!("unexpected value kind 0x{:02x}", key.field1),
})
}
fn is_slru_block_key(key: Key) -> bool {
key.field1 == 0x01 // SLRU-related
&& key.field3 == 0x00000001 // but not SlruDir
&& key.field6 != 0xffffffff // and not SlruSegSize
}
#[allow(clippy::bool_assert_comparison)]
#[cfg(test)]
mod tests {

View File

@@ -716,10 +716,6 @@ impl Tenant {
// stayed in Activating for such a long time that shutdown found it in
// that state.
tracing::info!(state=%tenant_clone.current_state(), "Tenant shut down before activation");
// Make the tenant broken so that set_stopping will not hang waiting for it to leave
// the Attaching state. This is an over-reaction (nothing really broke, the tenant is
// just shutting down), but ensures progress.
make_broken(&tenant_clone, anyhow::anyhow!("Shut down while Attaching"));
return Ok(());
},
)

View File

@@ -73,8 +73,8 @@ use crate::metrics::{
TimelineMetrics, MATERIALIZED_PAGE_CACHE_HIT, MATERIALIZED_PAGE_CACHE_HIT_DIRECT,
};
use crate::pgdatadir_mapping::CalculateLogicalSizeError;
use crate::pgdatadir_mapping::{is_inherited_key, is_rel_fsm_block_key, is_rel_vm_block_key};
use crate::tenant::config::TenantConfOpt;
use pageserver_api::key::{is_inherited_key, is_rel_fsm_block_key, is_rel_vm_block_key};
use pageserver_api::reltag::RelTag;
use pageserver_api::shard::ShardIndex;

View File

@@ -33,12 +33,11 @@ use utils::failpoint_support;
use crate::context::RequestContext;
use crate::metrics::WAL_INGEST;
use crate::pgdatadir_mapping::{DatadirModification, Version};
use crate::pgdatadir_mapping::*;
use crate::tenant::PageReconstructError;
use crate::tenant::Timeline;
use crate::walrecord::*;
use crate::ZERO_PAGE;
use pageserver_api::key::rel_block_to_key;
use pageserver_api::reltag::{BlockNumber, RelTag, SlruKind};
use postgres_ffi::pg_constants;
use postgres_ffi::relfile_utils::{FSM_FORKNUM, INIT_FORKNUM, MAIN_FORKNUM, VISIBILITYMAP_FORKNUM};
@@ -103,9 +102,7 @@ impl WalIngest {
buf.advance(decoded.main_data_offset);
assert!(!self.checkpoint_modified);
if decoded.xl_xid != pg_constants::INVALID_TRANSACTION_ID
&& self.checkpoint.update_next_xid(decoded.xl_xid)
{
if self.checkpoint.update_next_xid(decoded.xl_xid) {
self.checkpoint_modified = true;
}
@@ -333,13 +330,8 @@ impl WalIngest {
< 0
{
self.checkpoint.oldestXid = xlog_checkpoint.oldestXid;
self.checkpoint_modified = true;
}
// Write a new checkpoint key-value pair on every checkpoint record, even
// if nothing really changed. Not strictly required, but it seems nice to
// have some trace of the checkpoint records in the layer files at the same
// LSNs.
self.checkpoint_modified = true;
}
}
pg_constants::RM_LOGICALMSG_ID => {

View File

@@ -47,10 +47,11 @@ use crate::metrics::{
WAL_REDO_PROCESS_LAUNCH_DURATION_HISTOGRAM, WAL_REDO_RECORDS_HISTOGRAM,
WAL_REDO_RECORD_COUNTER, WAL_REDO_TIME,
};
use crate::pgdatadir_mapping::key_to_slru_block;
use crate::repository::Key;
use crate::walrecord::NeonWalRecord;
use pageserver_api::key::{key_to_rel_block, key_to_slru_block};
use pageserver_api::key::key_to_rel_block;
use pageserver_api::reltag::{RelTag, SlruKind};
use postgres_ffi::pg_constants;
use postgres_ffi::relfile_utils::VISIBILITYMAP_FORKNUM;
@@ -836,8 +837,9 @@ impl WalRedoProcess {
let mut proc = { input }; // TODO: remove this legacy rename, but this keep the patch small.
let mut nwrite = 0usize;
let mut stdin_pollfds = [PollFd::new(proc.stdin.as_raw_fd(), PollFlags::POLLOUT)];
while nwrite < writebuf.len() {
let mut stdin_pollfds = [PollFd::new(&proc.stdin, PollFlags::POLLOUT)];
let n = loop {
match nix::poll::poll(&mut stdin_pollfds[..], wal_redo_timeout.as_millis() as i32) {
Err(nix::errno::Errno::EINTR) => continue,
@@ -876,6 +878,7 @@ impl WalRedoProcess {
// advancing processed responses number.
let mut output = self.stdout.lock().unwrap();
let mut stdout_pollfds = [PollFd::new(output.stdout.as_raw_fd(), PollFlags::POLLIN)];
let n_processed_responses = output.n_processed_responses;
while n_processed_responses + output.pending_responses.len() <= request_no {
// We expect the WAL redo process to respond with an 8k page image. We read it
@@ -883,7 +886,6 @@ impl WalRedoProcess {
let mut resultbuf = vec![0; BLCKSZ.into()];
let mut nresult: usize = 0; // # of bytes read into 'resultbuf' so far
while nresult < BLCKSZ.into() {
let mut stdout_pollfds = [PollFd::new(&output.stdout, PollFlags::POLLIN)];
// We do two things simultaneously: reading response from stdout
// and forward any logging information that the child writes to its stderr to the page server's log.
let n = loop {

View File

@@ -637,7 +637,7 @@ HandleAlterRole(AlterRoleStmt *stmt)
ListCell *option;
const char *role_name = stmt->role->rolename;
if (RoleIsNeonSuperuser(role_name) && !superuser())
if (RoleIsNeonSuperuser(role_name))
elog(ERROR, "can't ALTER neon_superuser");
foreach(option, stmt->options)

View File

@@ -64,26 +64,10 @@ static int max_reconnect_attempts = 60;
#define MAX_PAGESERVER_CONNSTRING_SIZE 256
/*
* The "neon.pageserver_connstring" GUC is marked with the PGC_SIGHUP option,
* allowing it to be changed using pg_reload_conf(). The control plane can
* update the connection string if the pageserver crashes, is relocated, or
* new shards are added. A copy of the current value of the GUC is kept in
* shared memory, updated by the postmaster, because regular backends don't
* reload the config during query execution, but we might need to re-establish
* the pageserver connection with the new connection string even in the middle
* of a query.
*
* The shared memory copy is protected by a lockless algorithm using two
* atomic counters. The counters allow a backend to quickly check if the value
* has changed since last access, and to detect and retry copying the value if
* the postmaster changes the value concurrently. (Postmaster doesn't have a
* PGPROC entry and therefore cannot use LWLocks.)
*/
typedef struct
{
pg_atomic_uint64 begin_update_counter;
pg_atomic_uint64 end_update_counter;
LWLockId lock;
pg_atomic_uint64 update_counter;
char pageserver_connstring[MAX_PAGESERVER_CONNSTRING_SIZE];
} PagestoreShmemState;
@@ -100,7 +84,7 @@ static bool pageserver_flush(void);
static void pageserver_disconnect(void);
static bool
PagestoreShmemIsValid(void)
PagestoreShmemIsValid()
{
return pagestore_shared && UsedShmemSegAddr;
}
@@ -114,58 +98,31 @@ CheckPageserverConnstring(char **newval, void **extra, GucSource source)
static void
AssignPageserverConnstring(const char *newval, void *extra)
{
/*
* Only postmaster updates the copy in shared memory.
*/
if (!PagestoreShmemIsValid() || IsUnderPostmaster)
if (!PagestoreShmemIsValid())
return;
pg_atomic_add_fetch_u64(&pagestore_shared->begin_update_counter, 1);
pg_write_barrier();
LWLockAcquire(pagestore_shared->lock, LW_EXCLUSIVE);
strlcpy(pagestore_shared->pageserver_connstring, newval, MAX_PAGESERVER_CONNSTRING_SIZE);
pg_write_barrier();
pg_atomic_add_fetch_u64(&pagestore_shared->end_update_counter, 1);
pg_atomic_fetch_add_u64(&pagestore_shared->update_counter, 1);
LWLockRelease(pagestore_shared->lock);
}
static bool
CheckConnstringUpdated(void)
CheckConnstringUpdated()
{
if (!PagestoreShmemIsValid())
return false;
return pagestore_local_counter < pg_atomic_read_u64(&pagestore_shared->begin_update_counter);
return pagestore_local_counter < pg_atomic_read_u64(&pagestore_shared->update_counter);
}
static void
ReloadConnstring(void)
ReloadConnstring()
{
uint64 begin_update_counter;
uint64 end_update_counter;
if (!PagestoreShmemIsValid())
return;
/*
* Copy the current settnig from shared to local memory. Postmaster can
* update the value concurrently, in which case we would copy a garbled
* mix of the old and new values. We will detect it because the counter's
* won't match, and retry. But it's important that we don't do anything
* within the retry-loop that would depend on the string having valid
* contents.
*/
do
{
begin_update_counter = pg_atomic_read_u64(&pagestore_shared->begin_update_counter);
end_update_counter = pg_atomic_read_u64(&pagestore_shared->end_update_counter);
pg_read_barrier();
strlcpy(local_pageserver_connstring, pagestore_shared->pageserver_connstring, sizeof(local_pageserver_connstring));
pg_read_barrier();
}
while (begin_update_counter != end_update_counter
|| begin_update_counter != pg_atomic_read_u64(&pagestore_shared->begin_update_counter)
|| end_update_counter != pg_atomic_read_u64(&pagestore_shared->end_update_counter));
pagestore_local_counter = end_update_counter;
LWLockAcquire(pagestore_shared->lock, LW_SHARED);
strlcpy(local_pageserver_connstring, pagestore_shared->pageserver_connstring, sizeof(local_pageserver_connstring));
pagestore_local_counter = pg_atomic_read_u64(&pagestore_shared->update_counter);
LWLockRelease(pagestore_shared->lock);
}
static bool
@@ -180,7 +137,7 @@ pageserver_connect(int elevel)
static TimestampTz last_connect_time = 0;
static uint64_t delay_us = MIN_RECONNECT_INTERVAL_USEC;
TimestampTz now;
uint64_t us_since_last_connect;
uint64_t us_since_last_connect;
Assert(!connected);
@@ -190,7 +147,7 @@ pageserver_connect(int elevel)
}
now = GetCurrentTimestamp();
us_since_last_connect = now - last_connect_time;
us_since_last_connect = now - last_connect_time;
if (us_since_last_connect < delay_us)
{
pg_usleep(delay_us - us_since_last_connect);
@@ -548,8 +505,8 @@ PagestoreShmemInit(void)
&found);
if (!found)
{
pg_atomic_init_u64(&pagestore_shared->begin_update_counter, 0);
pg_atomic_init_u64(&pagestore_shared->end_update_counter, 0);
pagestore_shared->lock = &(GetNamedLWLockTranche("neon_libpagestore")->lock);
pg_atomic_init_u64(&pagestore_shared->update_counter, 0);
AssignPageserverConnstring(page_server_connstring, NULL);
}
LWLockRelease(AddinShmemInitLock);
@@ -574,6 +531,7 @@ pagestore_shmem_request(void)
#endif
RequestAddinShmemSpace(PagestoreShmemSize());
RequestNamedLWLockTranche("neon_libpagestore", 1);
}
static void

View File

@@ -5,7 +5,7 @@ edition.workspace = true
license.workspace = true
[features]
default = []
default = ["testing"]
testing = []
[dependencies]

View File

@@ -13,7 +13,7 @@ use password_hack::PasswordHackPayload;
mod flow;
pub use flow::*;
use crate::error::{ReportableError, UserFacingError};
use crate::{console, error::UserFacingError};
use std::io;
use thiserror::Error;
@@ -23,6 +23,15 @@ pub type Result<T> = std::result::Result<T, AuthError>;
/// Common authentication error.
#[derive(Debug, Error)]
pub enum AuthErrorImpl {
#[error(transparent)]
Link(#[from] backend::LinkAuthError),
#[error(transparent)]
GetAuthInfo(#[from] console::errors::GetAuthInfoError),
#[error(transparent)]
WakeCompute(#[from] console::errors::WakeComputeError),
/// SASL protocol errors (includes [SCRAM](crate::scram)).
#[error(transparent)]
Sasl(#[from] crate::sasl::Error),
@@ -90,25 +99,13 @@ impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
}
}
impl ReportableError for AuthError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self.0.as_ref() {
AuthErrorImpl::Sasl(s) => s.get_error_type(),
AuthErrorImpl::BadAuthMethod(_) => crate::error::ErrorKind::User,
AuthErrorImpl::MalformedPassword(_) => crate::error::ErrorKind::User,
AuthErrorImpl::MissingEndpointName => crate::error::ErrorKind::User,
AuthErrorImpl::AuthFailed(_) => crate::error::ErrorKind::User,
AuthErrorImpl::Io(_) => crate::error::ErrorKind::Disconnect,
AuthErrorImpl::IpAddressNotAllowed => crate::error::ErrorKind::User,
AuthErrorImpl::TooManyConnections => crate::error::ErrorKind::RateLimit,
}
}
}
impl UserFacingError for AuthError {
fn to_string_client(&self) -> String {
use AuthErrorImpl::*;
match self.0.as_ref() {
Link(e) => e.to_string_client(),
GetAuthInfo(e) => e.to_string_client(),
WakeCompute(e) => e.to_string_client(),
Sasl(e) => e.to_string_client(),
AuthFailed(_) => self.to_string(),
BadAuthMethod(_) => self.to_string(),

View File

@@ -2,27 +2,21 @@ mod classic;
mod hacks;
mod link;
use pq_proto::StartupMessageParams;
pub use link::LinkAuthError;
use smol_str::SmolStr;
use tokio_postgres::config::AuthKeys;
use crate::auth::backend::link::NeedsLinkAuthentication;
use crate::auth::credentials::check_peer_addr_is_in_list;
use crate::auth::validate_password_and_exchange;
use crate::cache::Cached;
use crate::cancellation::Session;
use crate::config::ProxyConfig;
use crate::console::errors::GetAuthInfoError;
use crate::console::provider::ConsoleBackend;
use crate::console::AuthSecret;
use crate::context::RequestMonitoring;
use crate::proxy::wake_compute::NeedsWakeCompute;
use crate::proxy::ClientMode;
use crate::proxy::connect_compute::handle_try_wake;
use crate::proxy::retry::retry_after;
use crate::proxy::NeonOptions;
use crate::rate_limiter::EndpointRateLimiter;
use crate::scram;
use crate::state_machine::{user_facing_error, DynStage, ResultExt, Stage, StageError};
use crate::stream::{PqStream, Stream};
use crate::stream::Stream;
use crate::{
auth::{self, ComputeUserInfoMaybeEndpoint},
config::AuthenticationConfig,
@@ -35,11 +29,10 @@ use crate::{
};
use futures::TryFutureExt;
use std::borrow::Cow;
use std::ops::ControlFlow;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use self::hacks::NeedsPasswordHack;
use tracing::{error, info, warn};
/// This type serves two purposes:
///
@@ -50,8 +43,11 @@ use self::hacks::NeedsPasswordHack;
/// this helps us provide the credentials only to those auth
/// backends which require them for the authentication process.
pub enum BackendType<'a, T> {
/// Cloud API (V2).
Console(Cow<'a, ConsoleBackend>, T),
/// Current Cloud API (V2).
Console(Cow<'a, console::provider::neon::Api>, T),
/// Local mock of Cloud API (V2).
#[cfg(feature = "testing")]
Postgres(Cow<'a, console::provider::mock::Api>, T),
/// Authentication via a web browser.
Link(Cow<'a, url::ApiUrl>),
#[cfg(test)]
@@ -68,15 +64,9 @@ impl std::fmt::Display for BackendType<'_, ()> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use BackendType::*;
match self {
Console(api, _) => match &**api {
ConsoleBackend::Console(endpoint) => {
fmt.debug_tuple("Console").field(&endpoint.url()).finish()
}
#[cfg(feature = "testing")]
ConsoleBackend::Postgres(endpoint) => {
fmt.debug_tuple("Postgres").field(&endpoint.url()).finish()
}
},
Console(endpoint, _) => fmt.debug_tuple("Console").field(&endpoint.url()).finish(),
#[cfg(feature = "testing")]
Postgres(endpoint, _) => fmt.debug_tuple("Postgres").field(&endpoint.url()).finish(),
Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
#[cfg(test)]
Test(_) => fmt.debug_tuple("Test").finish(),
@@ -91,6 +81,8 @@ impl<T> BackendType<'_, T> {
use BackendType::*;
match self {
Console(c, x) => Console(Cow::Borrowed(c), x),
#[cfg(feature = "testing")]
Postgres(c, x) => Postgres(Cow::Borrowed(c), x),
Link(c) => Link(Cow::Borrowed(c)),
#[cfg(test)]
Test(x) => Test(*x),
@@ -106,6 +98,8 @@ impl<'a, T> BackendType<'a, T> {
use BackendType::*;
match self {
Console(c, x) => Console(c, f(x)),
#[cfg(feature = "testing")]
Postgres(c, x) => Postgres(c, f(x)),
Link(c) => Link(c),
#[cfg(test)]
Test(x) => Test(x),
@@ -120,6 +114,8 @@ impl<'a, T, E> BackendType<'a, Result<T, E>> {
use BackendType::*;
match self {
Console(c, x) => x.map(|x| Console(c, x)),
#[cfg(feature = "testing")]
Postgres(c, x) => x.map(|x| Postgres(c, x)),
Link(c) => Ok(Link(c)),
#[cfg(test)]
Test(x) => Ok(Test(x)),
@@ -176,94 +172,69 @@ impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
}
}
struct NeedsAuthSecret<S> {
stream: PqStream<Stream<S>>,
api: Cow<'static, ConsoleBackend>,
params: StartupMessageParams,
allow_self_signed_compute: bool,
/// True to its name, this function encapsulates our current auth trade-offs.
/// Here, we choose the appropriate auth flow based on circumstances.
///
/// All authentication flows will emit an AuthenticationOk message if successful.
async fn auth_quirks(
ctx: &mut RequestMonitoring,
api: &impl console::Api,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
info: ComputeUserInfo,
unauthenticated_password: Option<Vec<u8>>,
config: &'static AuthenticationConfig,
// monitoring
ctx: RequestMonitoring,
cancel_session: Session,
}
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stage for NeedsAuthSecret<S> {
fn span(&self) -> tracing::Span {
tracing::info_span!("get_auth_secret")
}
async fn run(self) -> Result<DynStage, StageError> {
let Self {
stream,
api,
params,
allow_cleartext,
allow_self_signed_compute,
info,
unauthenticated_password,
config,
mut ctx,
cancel_session,
} = self;
info!("fetching user's authentication info");
let (allowed_ips, stream) = api
.get_allowed_ips(&mut ctx, &info)
.await
.send_error_to_user(&mut ctx, stream)?;
// check allowed list
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
return Err(user_facing_error(
auth::AuthError::ip_address_not_allowed(),
&mut ctx,
stream,
));
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
let (info, unauthenticated_password) = match user_info.try_into() {
Err(info) => {
let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer)
.await?;
ctx.set_endpoint_id(Some(res.info.endpoint.clone()));
(res.info, Some(res.keys))
}
let (cached_secret, mut stream) = api
.get_role_secret(&mut ctx, &info)
.await
.send_error_to_user(&mut ctx, stream)?;
Ok(info) => (info, None),
};
let secret = cached_secret.value.clone().unwrap_or_else(|| {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
AuthSecret::Scram(scram::ServerSecret::mock(&info.user, rand::random()))
});
info!("fetching user's authentication info");
let allowed_ips = api.get_allowed_ips(ctx, &info).await?;
let (keys, stream) = authenticate_with_secret(
&mut ctx,
secret,
info,
&mut stream,
unauthenticated_password,
allow_cleartext,
config,
)
.await
.map_err(|e| {
// check allowed list
if !check_peer_addr_is_in_list(&ctx.peer_addr, &allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed());
}
let maybe_secret = api.get_role_secret(ctx, &info).await?;
let cached_secret = maybe_secret.unwrap_or_else(|| {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
Cached::new_uncached(AuthSecret::Scram(scram::ServerSecret::mock(
&info.user,
rand::random(),
)))
});
match authenticate_with_secret(
ctx,
cached_secret.value.clone(),
info,
client,
unauthenticated_password,
allow_cleartext,
config,
)
.await
{
Ok(keys) => Ok(keys),
Err(e) => {
if e.is_auth_failed() {
// The password could have been changed, so we invalidate the cache.
cached_secret.invalidate();
}
e
})
.send_error_to_user(&mut ctx, stream)?;
Ok(Box::new(NeedsWakeCompute {
stream,
api,
params,
allow_self_signed_compute,
creds: keys,
ctx,
cancel_session,
}))
Err(e)
}
}
}
@@ -304,6 +275,49 @@ async fn authenticate_with_secret(
classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await
}
/// Authenticate the user and then wake a compute (or retrieve an existing compute session from cache)
/// only if authentication was successfuly.
async fn auth_and_wake_compute(
ctx: &mut RequestMonitoring,
api: &impl console::Api,
user_info: ComputeUserInfoMaybeEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<(CachedNodeInfo, ComputeUserInfo)> {
let compute_credentials =
auth_quirks(ctx, api, user_info, client, allow_cleartext, config).await?;
let mut num_retries = 0;
let mut node = loop {
let wake_res = api.wake_compute(ctx, &compute_credentials.info).await;
match handle_try_wake(wake_res, num_retries) {
Err(e) => {
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
return Err(e.into());
}
Ok(ControlFlow::Continue(e)) => {
warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
}
Ok(ControlFlow::Break(n)) => break n,
}
let wait_duration = retry_after(num_retries);
num_retries += 1;
tokio::time::sleep(wait_duration).await;
};
ctx.set_project(node.aux.clone());
match compute_credentials.keys {
#[cfg(feature = "testing")]
ComputeCredentialKeys::Password(password) => node.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys),
};
Ok((node, compute_credentials.info))
}
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
/// Get compute endpoint name from the credentials.
pub fn get_endpoint(&self) -> Option<SmolStr> {
@@ -311,6 +325,8 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
match self {
Console(_, user_info) => user_info.project.clone(),
#[cfg(feature = "testing")]
Postgres(_, user_info) => user_info.project.clone(),
Link(_) => Some("link".into()),
#[cfg(test)]
Test(_) => Some("test".into()),
@@ -323,101 +339,70 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
match self {
Console(_, user_info) => &user_info.user,
#[cfg(feature = "testing")]
Postgres(_, user_info) => &user_info.user,
Link(_) => "link",
#[cfg(test)]
Test(_) => "test",
}
}
}
pub struct NeedsAuthentication<S> {
pub stream: PqStream<Stream<S>>,
pub creds: BackendType<'static, auth::ComputeUserInfoMaybeEndpoint>,
pub params: StartupMessageParams,
pub endpoint_rate_limiter: Arc<EndpointRateLimiter>,
pub mode: ClientMode,
pub config: &'static ProxyConfig,
/// Authenticate the client via the requested backend, possibly using credentials.
#[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
pub async fn authenticate(
self,
ctx: &mut RequestMonitoring,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
) -> auth::Result<(CachedNodeInfo, BackendType<'a, ComputeUserInfo>)> {
use BackendType::*;
// monitoring
pub ctx: RequestMonitoring,
pub cancel_session: Session,
}
let res = match self {
Console(api, user_info) => {
info!(
user = &*user_info.user,
project = user_info.project(),
"performing authentication using the console"
);
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stage for NeedsAuthentication<S> {
fn span(&self) -> tracing::Span {
tracing::info_span!("authenticate")
}
async fn run(self) -> Result<DynStage, StageError> {
let Self {
stream,
creds,
params,
endpoint_rate_limiter,
mode,
config,
mut ctx,
cancel_session,
} = self;
// check rate limit
if let Some(ep) = creds.get_endpoint() {
if !endpoint_rate_limiter.check(ep) {
return Err(user_facing_error(
auth::AuthError::too_many_connections(),
&mut ctx,
stream,
));
let (cache_info, user_info) =
auth_and_wake_compute(ctx, &*api, user_info, client, allow_cleartext, config)
.await?;
(cache_info, BackendType::Console(api, user_info))
}
}
#[cfg(feature = "testing")]
Postgres(api, user_info) => {
info!(
user = &*user_info.user,
project = user_info.project(),
"performing authentication using a local postgres instance"
);
let allow_self_signed_compute = mode.allow_self_signed_compute(config);
let allow_cleartext = mode.allow_cleartext();
match creds {
BackendType::Console(api, creds) => {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
match creds.try_into() {
Err(info) => Ok(Box::new(NeedsPasswordHack {
stream,
api,
params,
allow_self_signed_compute,
info,
allow_cleartext,
config: &config.authentication_config,
ctx,
cancel_session,
})),
Ok(info) => Ok(Box::new(NeedsAuthSecret {
stream,
api,
params,
allow_self_signed_compute,
info,
unauthenticated_password: None,
allow_cleartext,
config: &config.authentication_config,
ctx,
cancel_session,
})),
}
let (cache_info, user_info) =
auth_and_wake_compute(ctx, &*api, user_info, client, allow_cleartext, config)
.await?;
(cache_info, BackendType::Postgres(api, user_info))
}
// NOTE: this auth backend doesn't use client credentials.
BackendType::Link(link) => Ok(Box::new(NeedsLinkAuthentication {
stream,
link,
params,
allow_self_signed_compute,
ctx,
cancel_session,
})),
Link(url) => {
info!("performing link authentication");
let node_info = link::authenticate(&url, client).await?;
(
CachedNodeInfo::new_uncached(node_info),
BackendType::Link(url),
)
}
#[cfg(test)]
BackendType::Test(_) => {
Test(_) => {
unreachable!("this function should never be called in the test backend")
}
}
};
info!("user successfully authenticated");
Ok(res)
}
}
@@ -429,6 +414,8 @@ impl BackendType<'_, ComputeUserInfo> {
use BackendType::*;
match self {
Console(api, user_info) => api.get_allowed_ips(ctx, user_info).await,
#[cfg(feature = "testing")]
Postgres(api, user_info) => api.get_allowed_ips(ctx, user_info).await,
Link(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
#[cfg(test)]
Test(x) => Ok(Cached::new_uncached(Arc::new(x.get_allowed_ips()?))),
@@ -445,6 +432,8 @@ impl BackendType<'_, ComputeUserInfo> {
match self {
Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await,
#[cfg(feature = "testing")]
Postgres(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await,
Link(_) => Ok(None),
#[cfg(test)]
Test(x) => x.wake_compute().map(Some),

View File

@@ -1,21 +1,13 @@
use std::borrow::Cow;
use super::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, ComputeUserInfoNoEndpoint,
NeedsAuthSecret,
};
use crate::{
auth::{self, AuthFlow},
cancellation::Session,
config::AuthenticationConfig,
console::{provider::ConsoleBackend, AuthSecret},
context::RequestMonitoring,
console::AuthSecret,
metrics::LatencyTimer,
sasl,
state_machine::{DynStage, ResultExt, Stage, StageError},
stream::{self, PqStream, Stream},
stream::{self, Stream},
};
use pq_proto::StartupMessageParams;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
@@ -54,7 +46,7 @@ pub async fn authenticate_cleartext(
/// Workaround for clients which don't provide an endpoint (project) name.
/// Similar to [`authenticate_cleartext`], but there's a specific password format,
/// and passwords are not yet validated (we don't know how to validate them!)
async fn password_hack_no_authentication(
pub async fn password_hack_no_authentication(
info: ComputeUserInfoNoEndpoint,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
latency_timer: &mut LatencyTimer,
@@ -82,47 +74,3 @@ async fn password_hack_no_authentication(
keys: payload.password,
})
}
pub struct NeedsPasswordHack<S> {
pub stream: PqStream<Stream<S>>,
pub api: Cow<'static, ConsoleBackend>,
pub params: StartupMessageParams,
pub allow_self_signed_compute: bool,
pub allow_cleartext: bool,
pub info: ComputeUserInfoNoEndpoint,
pub config: &'static AuthenticationConfig,
// monitoring
pub ctx: RequestMonitoring,
pub cancel_session: Session,
}
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stage for NeedsPasswordHack<S> {
fn span(&self) -> tracing::Span {
tracing::info_span!("password_hack")
}
async fn run(mut self) -> Result<DynStage, StageError> {
let (res, stream) = password_hack_no_authentication(
self.info,
&mut self.stream,
&mut self.ctx.latency_timer,
)
.await
.send_error_to_user(&mut self.ctx, self.stream)?;
self.ctx.set_endpoint_id(Some(res.info.endpoint.clone()));
Ok(Box::new(NeedsAuthSecret {
stream,
info: res.info,
unauthenticated_password: Some(res.keys),
api: self.api,
params: self.params,
allow_self_signed_compute: self.allow_self_signed_compute,
allow_cleartext: self.allow_cleartext,
ctx: self.ctx,
cancel_session: self.cancel_session,
config: self.config,
}))
}
}

View File

@@ -1,20 +1,41 @@
use std::borrow::Cow;
use crate::{
auth::BackendType,
cancellation::Session,
compute,
console::{self, mgmt::ComputeReady, provider::NodeInfo, CachedNodeInfo},
context::RequestMonitoring,
proxy::connect_compute::{NeedsComputeConnection, TcpMechanism},
state_machine::{DynStage, ResultExt, Stage, StageError},
stream::{PqStream, Stream},
waiters::Waiter,
auth, compute,
console::{self, provider::NodeInfo},
error::UserFacingError,
stream::PqStream,
waiters,
};
use pq_proto::{BeMessage as Be, StartupMessageParams};
use pq_proto::BeMessage as Be;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::config::SslMode;
use tracing::info;
use tracing::{info, info_span};
#[derive(Debug, Error)]
pub enum LinkAuthError {
/// Authentication error reported by the console.
#[error("Authentication failed: {0}")]
AuthFailed(String),
#[error(transparent)]
WaiterRegister(#[from] waiters::RegisterError),
#[error(transparent)]
WaiterWait(#[from] waiters::WaitError),
#[error(transparent)]
Io(#[from] std::io::Error),
}
impl UserFacingError for LinkAuthError {
fn to_string_client(&self) -> String {
use LinkAuthError::*;
match self {
AuthFailed(_) => self.to_string(),
_ => "Internal error".to_string(),
}
}
}
fn hello_message(redirect_uri: &reqwest::Url, session_id: &str) -> String {
format!(
@@ -32,146 +53,57 @@ pub fn new_psql_session_id() -> String {
hex::encode(rand::random::<[u8; 8]>())
}
pub struct NeedsLinkAuthentication<S> {
pub stream: PqStream<Stream<S>>,
pub link: Cow<'static, crate::url::ApiUrl>,
pub params: StartupMessageParams,
pub allow_self_signed_compute: bool,
pub(super) async fn authenticate(
link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<NodeInfo> {
let psql_session_id = new_psql_session_id();
let span = info_span!("link", psql_session_id = &psql_session_id);
let greeting = hello_message(link_uri, &psql_session_id);
// monitoring
pub ctx: RequestMonitoring,
pub cancel_session: Session,
}
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stage for NeedsLinkAuthentication<S> {
fn span(&self) -> tracing::Span {
tracing::info_span!("link", psql_session_id = tracing::field::Empty)
}
async fn run(self) -> Result<DynStage, StageError> {
let Self {
mut stream,
link,
params,
allow_self_signed_compute,
mut ctx,
cancel_session,
} = self;
// registering waiter can fail if we get unlucky with rng.
// just try again.
let (psql_session_id, waiter) = loop {
let psql_session_id = new_psql_session_id();
match console::mgmt::get_waiter(&psql_session_id) {
Ok(waiter) => break (psql_session_id, waiter),
Err(_e) => continue,
}
};
tracing::Span::current().record("psql_session_id", &psql_session_id);
let greeting = hello_message(&link, &psql_session_id);
info!("sending the auth URL to the user");
stream
.write_message_noflush(&Be::AuthenticationOk)
.and_then(|s| s.write_message_noflush(&Be::CLIENT_ENCODING))
.and_then(|s| s.write_message_noflush(&Be::NoticeResponse(&greeting)))
.no_user_error(&mut ctx, crate::error::ErrorKind::Service)?
.flush()
.await
.no_user_error(&mut ctx, crate::error::ErrorKind::Disconnect)?;
Ok(Box::new(NeedsLinkAuthenticationResponse {
stream,
link,
params,
allow_self_signed_compute,
waiter,
psql_session_id,
ctx,
cancel_session,
}))
}
}
struct NeedsLinkAuthenticationResponse<S> {
stream: PqStream<Stream<S>>,
link: Cow<'static, crate::url::ApiUrl>,
params: StartupMessageParams,
allow_self_signed_compute: bool,
waiter: Waiter<'static, ComputeReady>,
psql_session_id: String,
// monitoring
ctx: RequestMonitoring,
cancel_session: Session,
}
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stage
for NeedsLinkAuthenticationResponse<S>
{
fn span(&self) -> tracing::Span {
tracing::info_span!("link_wait", psql_session_id = self.psql_session_id)
}
async fn run(self) -> Result<DynStage, StageError> {
let Self {
mut stream,
link,
params,
allow_self_signed_compute,
waiter,
psql_session_id: _,
mut ctx,
cancel_session,
} = self;
let db_info = console::mgmt::with_waiter(psql_session_id, |waiter| async {
// Give user a URL to spawn a new database.
info!(parent: &span, "sending the auth URL to the user");
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&Be::CLIENT_ENCODING)?
.write_message(&Be::NoticeResponse(&greeting))
.await?;
// Wait for web console response (see `mgmt`).
info!("waiting for console's reply...");
let db_info = waiter
.await
.no_user_error(&mut ctx, crate::error::ErrorKind::Service)?;
info!(parent: &span, "waiting for console's reply...");
waiter.await?.map_err(LinkAuthError::AuthFailed)
})
.await?;
stream
.write_message_noflush(&Be::NoticeResponse("Connecting to database."))
.no_user_error(&mut ctx, crate::error::ErrorKind::Service)?;
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
// This config should be self-contained, because we won't
// take username or dbname from client's startup message.
let mut config = compute::ConnCfg::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
// This config should be self-contained, because we won't
// take username or dbname from client's startup message.
let mut config = compute::ConnCfg::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
// while direct connections do not. Once we migrate to pg_sni_proxy
// everywhere, we can remove this.
if db_info.host.contains("--") {
// we need TLS connection with SNI info to properly route it
config.ssl_mode(SslMode::Require);
} else {
config.ssl_mode(SslMode::Disable);
}
if let Some(password) = db_info.password {
config.password(password.as_ref());
}
let node_info = CachedNodeInfo::new_uncached(NodeInfo {
config,
aux: db_info.aux,
allow_self_signed_compute,
});
let user_info = BackendType::Link(link);
Ok(Box::new(NeedsComputeConnection {
stream,
user_info,
mechanism: TcpMechanism { params },
node_info,
ctx,
cancel_session,
}))
// Backwards compatibility. pg_sni_proxy uses "--" in domain names
// while direct connections do not. Once we migrate to pg_sni_proxy
// everywhere, we can remove this.
if db_info.host.contains("--") {
// we need TLS connection with SNI info to properly route it
config.ssl_mode(SslMode::Require);
} else {
config.ssl_mode(SslMode::Disable);
}
if let Some(password) = db_info.password {
config.password(password.as_ref());
}
Ok(NodeInfo {
config,
aux: db_info.aux,
allow_self_signed_compute: false, // caller may override
})
}

View File

@@ -1,11 +1,8 @@
//! User credentials used in authentication.
use crate::{
auth::password_hack::parse_endpoint_param,
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI,
proxy::NeonOptions,
auth::password_hack::parse_endpoint_param, context::RequestMonitoring, error::UserFacingError,
metrics::NUM_CONNECTION_ACCEPTED_BY_SNI, proxy::NeonOptions,
};
use itertools::Itertools;
use pq_proto::StartupMessageParams;
@@ -36,24 +33,7 @@ pub enum ComputeUserInfoParseError {
MalformedProjectName(SmolStr),
}
impl ReportableError for ComputeUserInfoParseError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
ComputeUserInfoParseError::MissingKey(_) => crate::error::ErrorKind::User,
ComputeUserInfoParseError::InconsistentProjectNames { .. } => {
crate::error::ErrorKind::User
}
ComputeUserInfoParseError::UnknownCommonName { .. } => crate::error::ErrorKind::User,
ComputeUserInfoParseError::MalformedProjectName(_) => crate::error::ErrorKind::User,
}
}
}
impl UserFacingError for ComputeUserInfoParseError {
fn to_string_client(&self) -> String {
self.to_string()
}
}
impl UserFacingError for ComputeUserInfoParseError {}
/// Various client credentials which we use for authentication.
/// Note that we don't store any kind of client key or password here.

View File

@@ -164,13 +164,6 @@ async fn task_main(
let tls_config = Arc::clone(&tls_config);
let dest_suffix = Arc::clone(&dest_suffix);
let root_span = tracing::info_span!(
"handle_client",
?session_id,
endpoint = tracing::field::Empty
);
let root_span2 = root_span.clone();
connections.spawn(
async move {
socket
@@ -178,13 +171,8 @@ async fn task_main(
.context("failed to set socket option")?;
info!(%peer_addr, "serving");
let mut ctx = RequestMonitoring::new(
session_id,
peer_addr.ip(),
"sni_router",
"sni",
root_span2,
);
let mut ctx =
RequestMonitoring::new(session_id, peer_addr.ip(), "sni_router", "sni");
handle_client(
&mut ctx,
dest_suffix,
@@ -198,7 +186,7 @@ async fn task_main(
// Acknowledge that the task has finished with an error.
error!("per-client task finished with an error: {e:#}");
})
.instrument(root_span),
.instrument(tracing::info_span!("handle_client", ?session_id)),
);
}
@@ -283,7 +271,6 @@ async fn handle_client(
let client = tokio::net::TcpStream::connect(destination).await?;
ctx.log();
let metrics_aux: MetricsAuxInfo = Default::default();
proxy::proxy::pass::proxy_pass(tls_stream, client, metrics_aux).await
proxy::proxy::proxy_pass(ctx, tls_stream, client, metrics_aux).await
}

View File

@@ -249,19 +249,12 @@ async fn main() -> anyhow::Result<()> {
}
if let auth::BackendType::Console(api, _) = &config.auth_backend {
match &**api {
proxy::console::provider::ConsoleBackend::Console(api) => {
let cache = api.caches.project_info.clone();
if let Some(url) = args.redis_notifications {
info!("Starting redis notifications listener ({url})");
maintenance_tasks
.spawn(notifications::task_main(url.to_owned(), cache.clone()));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
#[cfg(feature = "testing")]
proxy::console::provider::ConsoleBackend::Postgres(_) => {}
let cache = api.caches.project_info.clone();
if let Some(url) = args.redis_notifications {
info!("Starting redis notifications listener ({url})");
maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone()));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
let maintenance = loop {
@@ -358,15 +351,13 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let endpoint = http::Endpoint::new(url, http::new_client(rate_limiter_config));
let api = console::provider::neon::Api::new(endpoint, caches, locks);
let api = console::provider::ConsoleBackend::Console(api);
auth::BackendType::Console(Cow::Owned(api), ())
}
#[cfg(feature = "testing")]
AuthBackend::Postgres => {
let url = args.auth_endpoint.parse()?;
let api = console::provider::mock::Api::new(url);
let api = console::provider::ConsoleBackend::Postgres(api);
auth::BackendType::Console(Cow::Owned(api), ())
auth::BackendType::Postgres(Cow::Owned(api), ())
}
AuthBackend::Link => {
let url = args.uri.parse()?;

View File

@@ -44,7 +44,7 @@ impl<T> From<T> for Entry<T> {
#[derive(Default)]
struct EndpointInfo {
secret: std::collections::HashMap<SmolStr, Entry<Option<AuthSecret>>>,
secret: std::collections::HashMap<SmolStr, Entry<AuthSecret>>,
allowed_ips: Option<Entry<Arc<Vec<SmolStr>>>>,
}
@@ -60,7 +60,7 @@ impl EndpointInfo {
role_name: &SmolStr,
valid_since: Instant,
ignore_cache_since: Option<Instant>,
) -> Option<(Option<AuthSecret>, bool)> {
) -> Option<(AuthSecret, bool)> {
if let Some(secret) = self.secret.get(role_name) {
if valid_since < secret.created_at {
return Some((
@@ -169,7 +169,7 @@ impl ProjectInfoCacheImpl {
&self,
endpoint_id: &SmolStr,
role_name: &SmolStr,
) -> Option<Cached<&Self, Option<AuthSecret>>> {
) -> Option<Cached<&Self, AuthSecret>> {
let (valid_since, ignore_cache_since) = self.get_cache_times();
let endpoint_info = self.cache.get(endpoint_id)?;
let (value, ignore_cache) =
@@ -208,7 +208,7 @@ impl ProjectInfoCacheImpl {
project_id: &SmolStr,
endpoint_id: &SmolStr,
role_name: &SmolStr,
secret: Option<AuthSecret>,
secret: AuthSecret,
) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
@@ -266,7 +266,7 @@ impl ProjectInfoCacheImpl {
tokio::time::interval(self.config.gc_interval / (self.cache.shards().len()) as u32);
loop {
interval.tick().await;
if self.cache.len() < self.config.size {
if self.cache.len() <= self.config.size {
// If there are not too many entries, wait until the next gc cycle.
continue;
}
@@ -364,11 +364,8 @@ mod tests {
let endpoint_id = "endpoint".into();
let user1: SmolStr = "user1".into();
let user2: SmolStr = "user2".into();
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock(
user1.as_str(),
[1; 32],
)));
let secret2 = None;
let secret1 = AuthSecret::Scram(ServerSecret::mock(user1.as_str(), [1; 32]));
let secret2 = AuthSecret::Scram(ServerSecret::mock(user2.as_str(), [2; 32]));
let allowed_ips = Arc::new(vec!["allowed_ip1".into(), "allowed_ip2".into()]);
cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone());
cache.insert_role_secret(&project_id, &endpoint_id, &user2, secret2.clone());
@@ -383,10 +380,7 @@ mod tests {
// Shouldn't add more than 2 roles.
let user3: SmolStr = "user3".into();
let secret3 = Some(AuthSecret::Scram(ServerSecret::mock(
user3.as_str(),
[3; 32],
)));
let secret3 = AuthSecret::Scram(ServerSecret::mock(user3.as_str(), [3; 32]));
cache.insert_role_secret(&project_id, &endpoint_id, &user3, secret3.clone());
assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());
@@ -419,14 +413,8 @@ mod tests {
let endpoint_id = "endpoint".into();
let user1: SmolStr = "user1".into();
let user2: SmolStr = "user2".into();
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock(
user1.as_str(),
[1; 32],
)));
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock(
user2.as_str(),
[2; 32],
)));
let secret1 = AuthSecret::Scram(ServerSecret::mock(user1.as_str(), [1; 32]));
let secret2 = AuthSecret::Scram(ServerSecret::mock(user2.as_str(), [2; 32]));
let allowed_ips = Arc::new(vec!["allowed_ip1".into(), "allowed_ip2".into()]);
cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone());
cache.insert_role_secret(&project_id, &endpoint_id, &user2, secret2.clone());
@@ -471,14 +459,8 @@ mod tests {
let endpoint_id = "endpoint".into();
let user1: SmolStr = "user1".into();
let user2: SmolStr = "user2".into();
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock(
user1.as_str(),
[1; 32],
)));
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock(
user2.as_str(),
[2; 32],
)));
let secret1 = AuthSecret::Scram(ServerSecret::mock(user1.as_str(), [1; 32]));
let secret2 = AuthSecret::Scram(ServerSecret::mock(user2.as_str(), [2; 32]));
let allowed_ips = Arc::new(vec!["allowed_ip1".into(), "allowed_ip2".into()]);
cache.insert_role_secret(&project_id, &endpoint_id, &user1, secret1.clone());
cache.clone().disable_ttl();

View File

@@ -1,7 +1,7 @@
use anyhow::Context;
use anyhow::{bail, Context};
use dashmap::DashMap;
use pq_proto::CancelKeyData;
use std::{net::SocketAddr, sync::Arc};
use std::net::SocketAddr;
use tokio::net::TcpStream;
use tokio_postgres::{CancelToken, NoTls};
use tracing::info;
@@ -25,33 +25,39 @@ impl CancelMap {
}
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
pub fn get_session(self: Arc<Self>) -> Session {
pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
where
F: FnOnce(Session<'a>) -> R,
R: std::future::Future<Output = anyhow::Result<V>>,
{
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
// expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the
// actual backend_pid, but backend_pid is not used for anything
// so it doesn't matter.
let key = loop {
let key = rand::random();
let key = rand::random();
// Random key collisions are unlikely to happen here, but they're still possible,
// which is why we have to take care not to rewrite an existing key.
match self.0.entry(key) {
dashmap::mapref::entry::Entry::Occupied(_) => {
continue;
}
dashmap::mapref::entry::Entry::Vacant(e) => {
e.insert(None);
}
// Random key collisions are unlikely to happen here, but they're still possible,
// which is why we have to take care not to rewrite an existing key.
match self.0.entry(key) {
dashmap::mapref::entry::Entry::Occupied(_) => {
bail!("query cancellation key already exists: {key}")
}
break key;
};
dashmap::mapref::entry::Entry::Vacant(e) => {
e.insert(None);
}
}
// This will guarantee that the session gets dropped
// as soon as the future is finished.
scopeguard::defer! {
self.0.remove(&key);
info!("dropped query cancellation key {key}");
}
info!("registered new query cancellation key {key}");
Session {
key,
cancel_map: self,
}
let session = Session::new(key, self);
f(session).await
}
#[cfg(test)]
@@ -92,17 +98,23 @@ impl CancelClosure {
}
/// Helper for registering query cancellation tokens.
pub struct Session {
pub struct Session<'a> {
/// The user-facing key identifying this session.
key: CancelKeyData,
/// The [`CancelMap`] this session belongs to.
cancel_map: Arc<CancelMap>,
cancel_map: &'a CancelMap,
}
impl Session {
impl<'a> Session<'a> {
fn new(key: CancelKeyData, cancel_map: &'a CancelMap) -> Self {
Self { key, cancel_map }
}
}
impl Session<'_> {
/// Store the cancel token for the given session.
/// This enables query cancellation in `crate::proxy::prepare_client_connection`.
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
info!("enabling query cancellation for this session");
self.cancel_map.0.insert(self.key, Some(cancel_closure));
@@ -110,26 +122,37 @@ impl Session {
}
}
impl Drop for Session {
fn drop(&mut self) {
self.cancel_map.0.remove(&self.key);
info!("dropped query cancellation key {}", &self.key);
}
}
#[cfg(test)]
mod tests {
use super::*;
use once_cell::sync::Lazy;
#[tokio::test]
async fn check_session_drop() -> anyhow::Result<()> {
let cancel_map: Arc<CancelMap> = Default::default();
static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(Default::default);
let (tx, rx) = tokio::sync::oneshot::channel();
let task = tokio::spawn(CANCEL_MAP.with_session(|session| async move {
assert!(CANCEL_MAP.contains(&session));
tx.send(()).expect("failed to send");
futures::future::pending::<()>().await; // sleep forever
Ok(())
}));
// Wait until the task has been spawned.
rx.await.context("failed to hear from the task")?;
// Drop the session's entry by cancelling the task.
task.abort();
let error = task.await.expect_err("task should have failed");
if !error.is_cancelled() {
anyhow::bail!(error);
}
let session = cancel_map.clone().get_session();
assert!(cancel_map.contains(&session));
drop(session);
// Check that the session has been dropped.
assert!(cancel_map.is_empty());
assert!(CANCEL_MAP.is_empty());
Ok(())
}

View File

@@ -1,10 +1,6 @@
use crate::{
auth::parse_endpoint_param,
cancellation::CancelClosure,
console::errors::WakeComputeError,
context::RequestMonitoring,
error::{ReportableError, UserFacingError},
metrics::NUM_DB_CONNECTIONS_GAUGE,
auth::parse_endpoint_param, cancellation::CancelClosure, console::errors::WakeComputeError,
context::RequestMonitoring, error::UserFacingError, metrics::NUM_DB_CONNECTIONS_GAUGE,
proxy::neon_option,
};
use futures::{FutureExt, TryFutureExt};
@@ -36,17 +32,6 @@ pub enum ConnectionError {
WakeComputeError(#[from] WakeComputeError),
}
impl ReportableError for ConnectionError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(_) => crate::error::ErrorKind::ControlPlane,
}
}
}
impl UserFacingError for ConnectionError {
fn to_string_client(&self) -> String {
use ConnectionError::*;

View File

@@ -13,10 +13,16 @@ use tracing::{error, info, info_span, Instrument};
static CPLANE_WAITERS: Lazy<Waiters<ComputeReady>> = Lazy::new(Default::default);
/// Give caller an opportunity to wait for the cloud's reply.
pub fn get_waiter(
pub async fn with_waiter<R, T, E>(
psql_session_id: impl Into<String>,
) -> Result<Waiter<'static, ComputeReady>, waiters::RegisterError> {
CPLANE_WAITERS.register(psql_session_id.into())
action: impl FnOnce(Waiter<'static, ComputeReady>) -> R,
) -> Result<T, E>
where
R: std::future::Future<Output = Result<T, E>>,
E: From<waiters::RegisterError>,
{
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
action(waiter).await
}
pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::NotifyError> {
@@ -71,7 +77,7 @@ async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
}
/// A message received by `mgmt` when a compute node is ready.
pub type ComputeReady = DatabaseInfo;
pub type ComputeReady = Result<DatabaseInfo, String>;
// TODO: replace with an http-based protocol.
struct MgmtHandler;
@@ -96,7 +102,7 @@ fn try_process_query(pgb: &mut PostgresBackendTCP, query: &str) -> Result<(), Qu
let _enter = span.enter();
info!("got response: {:?}", resp.result);
match notify(resp.session_id, resp.result) {
match notify(resp.session_id, Ok(resp.result)) {
Ok(()) => {
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?

View File

@@ -21,7 +21,7 @@ use tracing::info;
pub mod errors {
use crate::{
error::{io_error, ReportableError, UserFacingError},
error::{io_error, UserFacingError},
http,
proxy::retry::ShouldRetry,
};
@@ -56,15 +56,6 @@ pub mod errors {
}
}
impl ReportableError for ApiError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
ApiError::Console { .. } => crate::error::ErrorKind::ControlPlane,
ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
}
}
}
impl UserFacingError for ApiError {
fn to_string_client(&self) -> String {
use ApiError::*;
@@ -149,15 +140,6 @@ pub mod errors {
}
}
impl ReportableError for GetAuthInfoError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane,
GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane,
}
}
}
impl UserFacingError for GetAuthInfoError {
fn to_string_client(&self) -> String {
use GetAuthInfoError::*;
@@ -199,16 +181,6 @@ pub mod errors {
}
}
impl ReportableError for WakeComputeError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
WakeComputeError::ApiError(e) => e.get_error_type(),
WakeComputeError::TimeoutError => crate::error::ErrorKind::RateLimit,
}
}
}
impl UserFacingError for WakeComputeError {
fn to_string_client(&self) -> String {
use WakeComputeError::*;
@@ -263,7 +235,7 @@ pub struct NodeInfo {
pub type NodeInfoCache = TimedLru<SmolStr, NodeInfo>;
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, AuthSecret>;
pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<SmolStr>>>;
/// This will allocate per each call, but the http requests alone
@@ -276,75 +248,23 @@ pub trait Api {
async fn get_role_secret(
&self,
ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
creds: &ComputeUserInfo,
) -> Result<Option<CachedRoleSecret>, errors::GetAuthInfoError>;
async fn get_allowed_ips(
&self,
ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo,
creds: &ComputeUserInfo,
) -> Result<CachedAllowedIps, errors::GetAuthInfoError>;
/// Wake up the compute node and return the corresponding connection info.
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo,
creds: &ComputeUserInfo,
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
}
#[derive(Clone)]
pub enum ConsoleBackend {
/// Current Cloud API (V2).
Console(neon::Api),
/// Local mock of Cloud API (V2).
#[cfg(feature = "testing")]
Postgres(mock::Api),
}
#[async_trait]
impl Api for ConsoleBackend {
async fn get_role_secret(
&self,
ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
use ConsoleBackend::*;
match self {
Console(api) => api.get_role_secret(ctx, user_info).await,
#[cfg(feature = "testing")]
Postgres(api) => api.get_role_secret(ctx, user_info).await,
}
}
async fn get_allowed_ips(
&self,
ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<CachedAllowedIps, errors::GetAuthInfoError> {
use ConsoleBackend::*;
match self {
Console(api) => api.get_allowed_ips(ctx, user_info).await,
#[cfg(feature = "testing")]
Postgres(api) => api.get_allowed_ips(ctx, user_info).await,
}
}
async fn wake_compute(
&self,
ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<CachedNodeInfo, errors::WakeComputeError> {
use ConsoleBackend::*;
match self {
Console(api) => api.wake_compute(ctx, user_info).await,
#[cfg(feature = "testing")]
Postgres(api) => api.wake_compute(ctx, user_info).await,
}
}
}
/// Various caches for [`console`](super).
pub struct ApiCaches {
/// Cache for the `wake_compute` API method.

View File

@@ -150,10 +150,12 @@ impl super::Api for Api {
&self,
_ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
Ok(CachedRoleSecret::new_uncached(
self.do_get_auth_info(user_info).await?.secret,
))
) -> Result<Option<CachedRoleSecret>, GetAuthInfoError> {
Ok(self
.do_get_auth_info(user_info)
.await?
.secret
.map(CachedRoleSecret::new_uncached))
}
async fn get_allowed_ips(

View File

@@ -86,14 +86,9 @@ impl Api {
},
};
let secret = if body.role_secret.is_empty() {
None
} else {
let secret = scram::ServerSecret::parse(&body.role_secret)
.map(AuthSecret::Scram)
.ok_or(GetAuthInfoError::BadSecret)?;
Some(secret)
};
let secret = scram::ServerSecret::parse(&body.role_secret)
.map(AuthSecret::Scram)
.ok_or(GetAuthInfoError::BadSecret)?;
let allowed_ips = body
.allowed_ips
.into_iter()
@@ -102,7 +97,7 @@ impl Api {
.collect_vec();
ALLOWED_IPS_NUMBER.observe(allowed_ips.len() as f64);
Ok(AuthInfo {
secret,
secret: Some(secret),
allowed_ips,
project_id: body.project_id.map(SmolStr::from),
})
@@ -177,28 +172,26 @@ impl super::Api for Api {
&self,
ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
) -> Result<Option<CachedRoleSecret>, GetAuthInfoError> {
let ep = &user_info.endpoint;
let user = &user_info.user;
if let Some(role_secret) = self.caches.project_info.get_role_secret(ep, user) {
return Ok(role_secret);
return Ok(Some(role_secret));
}
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
if let Some(project_id) = auth_info.project_id {
self.caches.project_info.insert_role_secret(
&project_id,
ep,
user,
auth_info.secret.clone(),
);
self.caches.project_info.insert_allowed_ips(
&project_id,
ep,
Arc::new(auth_info.allowed_ips),
);
let project_id = auth_info.project_id.unwrap_or(ep.clone());
if let Some(secret) = &auth_info.secret {
self.caches
.project_info
.insert_role_secret(&project_id, ep, user, secret.clone())
}
self.caches.project_info.insert_allowed_ips(
&project_id,
ep,
Arc::new(auth_info.allowed_ips),
);
// When we just got a secret, we don't need to invalidate it.
Ok(Cached::new_uncached(auth_info.secret))
Ok(auth_info.secret.map(Cached::new_uncached))
}
async fn get_allowed_ips(
@@ -219,17 +212,15 @@ impl super::Api for Api {
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
let allowed_ips = Arc::new(auth_info.allowed_ips);
let user = &user_info.user;
if let Some(project_id) = auth_info.project_id {
self.caches.project_info.insert_role_secret(
&project_id,
ep,
user,
auth_info.secret.clone(),
);
let project_id = auth_info.project_id.unwrap_or(ep.clone());
if let Some(secret) = &auth_info.secret {
self.caches
.project_info
.insert_allowed_ips(&project_id, ep, allowed_ips.clone());
.insert_role_secret(&project_id, ep, user, secret.clone())
}
self.caches
.project_info
.insert_allowed_ips(&project_id, ep, allowed_ips.clone());
Ok(Cached::new_uncached(allowed_ips))
}

View File

@@ -38,7 +38,6 @@ pub struct RequestMonitoring {
// This sender is here to keep the request monitoring channel open while requests are taking place.
sender: Option<mpsc::UnboundedSender<RequestMonitoring>>,
pub latency_timer: LatencyTimer,
root_span: tracing::Span,
}
impl RequestMonitoring {
@@ -47,7 +46,6 @@ impl RequestMonitoring {
peer_addr: IpAddr,
protocol: &'static str,
region: &'static str,
root_span: tracing::Span,
) -> Self {
Self {
peer_addr,
@@ -66,19 +64,12 @@ impl RequestMonitoring {
sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
latency_timer: LatencyTimer::new(protocol),
root_span,
}
}
#[cfg(test)]
pub fn test() -> Self {
RequestMonitoring::new(
Uuid::now_v7(),
[127, 0, 0, 1].into(),
"test",
"test",
tracing::Span::none(),
)
RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), "test", "test")
}
pub fn console_application_name(&self) -> String {
@@ -96,10 +87,7 @@ impl RequestMonitoring {
}
pub fn set_endpoint_id(&mut self, endpoint_id: Option<SmolStr>) {
if let (None, Some(ep)) = (self.endpoint_id.as_ref(), endpoint_id) {
self.root_span.record("ep", &*ep);
self.endpoint_id = Some(ep)
}
self.endpoint_id = endpoint_id.or_else(|| self.endpoint_id.clone());
}
pub fn set_application(&mut self, app: Option<SmolStr>) {
@@ -114,10 +102,6 @@ impl RequestMonitoring {
self.success = true;
}
pub fn error(&mut self, err: ErrorKind) {
self.error_kind = Some(err);
}
pub fn log(&mut self) {
if let Some(tx) = self.sender.take() {
let _: Result<(), _> = tx.send(self.clone());

View File

@@ -17,16 +17,19 @@ pub fn log_error<E: fmt::Display>(e: E) -> E {
/// NOTE: This trait should not be implemented for [`anyhow::Error`], since it
/// is way too convenient and tends to proliferate all across the codebase,
/// ultimately leading to accidental leaks of sensitive data.
pub trait UserFacingError: ReportableError {
pub trait UserFacingError: fmt::Display {
/// Format the error for client, stripping all sensitive info.
///
/// Although this might be a no-op for many types, it's highly
/// recommended to override the default impl in case error type
/// contains anything sensitive: various IDs, IP addresses etc.
fn to_string_client(&self) -> String;
#[inline(always)]
fn to_string_client(&self) -> String {
self.to_string()
}
}
#[derive(Clone, Copy)]
#[derive(Clone)]
pub enum ErrorKind {
/// Wrong password, unknown endpoint, protocol violation, etc...
User,
@@ -59,7 +62,3 @@ impl ErrorKind {
}
}
}
pub trait ReportableError: fmt::Display + Send + 'static {
fn get_error_type(&self) -> ErrorKind;
}

View File

@@ -26,7 +26,6 @@ pub mod redis;
pub mod sasl;
pub mod scram;
pub mod serverless;
pub mod state_machine;
pub mod stream;
pub mod url;
pub mod usage_metrics;

View File

@@ -2,32 +2,38 @@
mod tests;
pub mod connect_compute;
pub mod handshake;
pub mod pass;
pub mod retry;
pub mod wake_compute;
use crate::{
cancellation::CancelMap,
config::{ProxyConfig, TlsConfig},
auth,
cancellation::{self, CancelMap},
compute,
config::{AuthenticationConfig, ProxyConfig, TlsConfig},
console::messages::MetricsAuxInfo,
context::RequestMonitoring,
metrics::{NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE},
metrics::{
NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER,
NUM_CLIENT_CONNECTION_GAUGE, NUM_CONNECTION_REQUESTS_GAUGE,
},
protocol2::WithClientIp,
proxy::handshake::NeedsHandshake,
rate_limiter::EndpointRateLimiter,
state_machine::{DynStage, StageResult},
stream::Stream,
stream::{PqStream, Stream},
usage_metrics::{Ids, USAGE_METRICS},
};
use anyhow::Context;
use anyhow::{bail, Context};
use futures::TryFutureExt;
use itertools::Itertools;
use once_cell::sync::OnceCell;
use pq_proto::StartupMessageParams;
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use regex::Regex;
use smol_str::SmolStr;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, Instrument};
use utils::measured_stream::MeasuredStream;
use self::connect_compute::{connect_to_compute, TcpMechanism};
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
const ERR_PROTO_VIOLATION: &str = "protocol violation";
@@ -73,64 +79,45 @@ pub async fn task_main(
let cancel_map = Arc::clone(&cancel_map);
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let root_span = info_span!(
"handle_client",
?session_id,
peer_addr = tracing::field::Empty,
ep = tracing::field::Empty,
);
let root_span2 = root_span.clone();
connections.spawn(
async move {
info!("accepted postgres client connection");
let mut socket = WithClientIp::new(socket);
let mut peer_addr = peer_addr.ip();
match socket.wait_for_addr().await {
Err(e) => {
error!("IO error: {e:#}");
return;
}
Ok(Some(addr)) => {
peer_addr = addr.ip();
root_span2.record("peer_addr", &tracing::field::display(addr));
}
Ok(None) if config.require_client_ip => {
error!("missing required client IP");
return;
}
Ok(None) => {}
};
if let Some(addr) = socket.wait_for_addr().await? {
peer_addr = addr.ip();
tracing::Span::current().record("peer_addr", &tracing::field::display(addr));
} else if config.require_client_ip {
bail!("missing required client IP");
}
let ctx = RequestMonitoring::new(
session_id,
peer_addr,
"tcp",
&config.region,
root_span2,
);
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "tcp", &config.region);
if let Err(e) = socket
socket
.inner
.set_nodelay(true)
.context("failed to set socket option")
{
error!("could not set nodelay: {e:#}");
return;
}
.context("failed to set socket option")?;
handle_client(
config,
ctx,
cancel_map,
&mut ctx,
&cancel_map,
socket,
ClientMode::Tcp,
endpoint_rate_limiter,
)
.await;
.await
}
.instrument(root_span),
.instrument(info_span!(
"handle_client",
?session_id,
peer_addr = tracing::field::Empty
))
.unwrap_or_else(move |e| {
// Acknowledge that the task has finished with an error.
error!(?session_id, "per-client task finished with an error: {e:#}");
}),
);
}
@@ -150,14 +137,14 @@ pub enum ClientMode {
/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
pub fn allow_cleartext(&self) -> bool {
fn allow_cleartext(&self) -> bool {
match self {
ClientMode::Tcp => false,
ClientMode::Websockets { .. } => true,
}
}
pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
match self {
ClientMode::Tcp => config.allow_self_signed_compute,
ClientMode::Websockets { .. } => false,
@@ -180,14 +167,14 @@ impl ClientMode {
}
}
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + 'static + Send>(
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
ctx: RequestMonitoring,
cancel_map: Arc<CancelMap>,
ctx: &mut RequestMonitoring,
cancel_map: &CancelMap,
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) {
) -> anyhow::Result<()> {
info!(
protocol = ctx.protocol,
"handling interactive connection from client"
@@ -201,26 +188,311 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + 'static + Send>(
.with_label_values(&[proto])
.guard();
let mut stage = Box::new(NeedsHandshake {
stream,
config,
cancel_map,
mode,
endpoint_rate_limiter,
ctx,
}) as DynStage;
let tls = config.tls_config.as_ref();
while let StageResult::Run(handle) = stage.run() {
stage = match handle.await.expect("tasks should not panic") {
Ok(s) => s,
Err(e) => {
e.finish().await;
break;
let pause = ctx.latency_timer.pause();
let do_handshake = handshake(stream, mode.handshake_tls(tls), cancel_map);
let (mut stream, params) = match do_handshake.await? {
Some(x) => x,
None => return Ok(()), // it's a cancellation request
};
drop(pause);
// Extract credentials which we're going to use for auth.
let user_info = {
let hostname = mode.hostname(stream.get_ref());
let common_names = tls.map(|tls| &tls.common_names);
let result = config
.auth_backend
.as_ref()
.map(|_| {
auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names)
})
.transpose();
match result {
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e).await?,
}
};
ctx.set_endpoint_id(user_info.get_endpoint());
let client = Client::new(
stream,
user_info,
&params,
mode.allow_self_signed_compute(config),
endpoint_rate_limiter,
);
cancel_map
.with_session(|session| {
client.connect_to_db(ctx, session, mode, &config.authentication_config)
})
.await
}
/// Establish a (most probably, secure) connection with the client.
/// For better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
/// we also take an extra care of propagating only the select handshake errors to client.
#[tracing::instrument(skip_all)]
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<&TlsConfig>,
cancel_map: &CancelMap,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
let mut stream = PqStream::new(Stream::from_raw(stream));
loop {
let msg = stream.read_startup_packet().await?;
info!("received {msg:?}");
use FeStartupPacket::*;
match msg {
SslRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_ssl => {
tried_ssl = true;
// We can't perform TLS handshake without a config
let enc = tls.is_some();
stream.write_message(&Be::EncryptionResponse(enc)).await?;
if let Some(tls) = tls.take() {
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.
let (raw, read_buf) = stream.into_inner();
// TODO: Normally, client doesn't send any data before
// server says TLS handshake is ok and read_buf is empy.
// However, you could imagine pipelining of postgres
// SSLRequest + TLS ClientHello in one hunk similar to
// pipelining in our node js driver. We should probably
// support that by chaining read_buf with the stream.
if !read_buf.is_empty() {
bail!("data is sent before server replied with EncryptionResponse");
}
let tls_stream = raw.upgrade(tls.to_server_config()).await?;
let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(tls_stream.get_ref().1.server_name())
.context("missing certificate")?;
stream = PqStream::new(Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
});
}
}
_ => bail!(ERR_PROTO_VIOLATION),
},
GssEncRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_gss => {
tried_gss = true;
// Currently, we don't support GSSAPI
stream.write_message(&Be::EncryptionResponse(false)).await?;
}
_ => bail!(ERR_PROTO_VIOLATION),
},
StartupMessage { params, .. } => {
// Check that the config has been consumed during upgrade
// OR we didn't provide it at all (for dev purposes).
if tls.is_some() {
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
}
info!(session_type = "normal", "successful handshake");
break Ok(Some((stream, params)));
}
CancelRequest(cancel_key_data) => {
cancel_map.cancel_session(cancel_key_data).await?;
info!(session_type = "cancellation", "successful handshake");
break Ok(None);
}
}
}
}
/// Finish client connection initialization: confirm auth success, send params, etc.
#[tracing::instrument(skip_all)]
async fn prepare_client_connection(
node: &compute::PostgresConnection,
session: cancellation::Session<'_>,
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> anyhow::Result<()> {
// Register compute's query cancellation token and produce a new, unique one.
// The new token (cancel_key_data) will be sent to the client.
let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
// Forward all postgres connection params to the client.
// Right now the implementation is very hacky and inefficent (ideally,
// we don't need an intermediate hashmap), but at least it should be correct.
for (name, value) in &node.params {
// TODO: Theoretically, this could result in a big pile of params...
stream.write_message_noflush(&Be::ParameterStatus {
name: name.as_bytes(),
value: value.as_bytes(),
})?;
}
stream
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
.write_message(&Be::ReadyForQuery)
.await?;
Ok(())
}
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
pub async fn proxy_pass(
ctx: &mut RequestMonitoring,
client: impl AsyncRead + AsyncWrite + Unpin,
compute: impl AsyncRead + AsyncWrite + Unpin,
aux: MetricsAuxInfo,
) -> anyhow::Result<()> {
ctx.set_success();
ctx.log();
let usage = USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id.clone(),
branch_id: aux.branch_id.clone(),
});
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]);
let m_sent2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("tx"));
let mut client = MeasuredStream::new(
client,
|_| {},
|cnt| {
// Number of bytes we sent to the client (outbound).
m_sent.inc_by(cnt as u64);
m_sent2.inc_by(cnt as u64);
usage.record_egress(cnt as u64);
},
);
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx"]);
let m_recv2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("rx"));
let mut compute = MeasuredStream::new(
compute,
|_| {},
|cnt| {
// Number of bytes the client sent to the compute node (inbound).
m_recv.inc_by(cnt as u64);
m_recv2.inc_by(cnt as u64);
},
);
// Starting from here we only proxy the client's traffic.
info!("performing the proxy pass...");
let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;
Ok(())
}
/// Thin connection context.
struct Client<'a, S> {
/// The underlying libpq protocol stream.
stream: PqStream<Stream<S>>,
/// Client credentials that we care about.
user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>,
/// KV-dictionary with PostgreSQL connection params.
params: &'a StartupMessageParams,
/// Allow self-signed certificates (for testing).
allow_self_signed_compute: bool,
/// Rate limiter for endpoints
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
}
impl<'a, S> Client<'a, S> {
/// Construct a new connection context.
fn new(
stream: PqStream<Stream<S>>,
user_info: auth::BackendType<'a, auth::ComputeUserInfoMaybeEndpoint>,
params: &'a StartupMessageParams,
allow_self_signed_compute: bool,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Self {
Self {
stream,
user_info,
params,
allow_self_signed_compute,
endpoint_rate_limiter,
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
/// Let the client authenticate and connect to the designated compute node.
// Instrumentation logs endpoint name everywhere. Doesn't work for link
// auth; strictly speaking we don't know endpoint name in its case.
#[tracing::instrument(name = "", fields(ep = %self.user_info.get_endpoint().unwrap_or_default()), skip_all)]
async fn connect_to_db(
self,
ctx: &mut RequestMonitoring,
session: cancellation::Session<'_>,
mode: ClientMode,
config: &'static AuthenticationConfig,
) -> anyhow::Result<()> {
let Self {
mut stream,
user_info,
params,
allow_self_signed_compute,
endpoint_rate_limiter,
} = self;
// check rate limit
if let Some(ep) = user_info.get_endpoint() {
if !endpoint_rate_limiter.check(ep) {
return stream
.throw_error(auth::AuthError::too_many_connections())
.await;
}
}
let user = user_info.get_user().to_owned();
let auth_result = match user_info
.authenticate(ctx, &mut stream, mode.allow_cleartext(), config)
.await
{
Ok(auth_result) => auth_result,
Err(e) => {
let db = params.get("database");
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
return stream.throw_error(e).instrument(params_span).await;
}
};
let (mut node_info, user_info) = auth_result;
node_info.allow_self_signed_compute = allow_self_signed_compute;
let aux = node_info.aux.clone();
let mut node = connect_to_compute(ctx, &TcpMechanism { params }, node_info, &user_info)
.or_else(|e| stream.throw_error(e))
.await?;
prepare_client_connection(&node, session, &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
proxy_pass(ctx, stream, node.stream, aux).await
}
}
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct NeonOptions(Vec<(SmolStr, SmolStr)>);

View File

@@ -1,120 +1,20 @@
use crate::{
auth,
cancellation::{self, Session},
compute::{self, PostgresConnection},
console::{self, errors::WakeComputeError, Api},
context::RequestMonitoring,
metrics::{bool_to_str, NUM_CONNECTION_FAILURES, NUM_WAKEUP_FAILURES},
state_machine::{DynStage, ResultExt, Stage, StageError},
stream::{PqStream, Stream},
proxy::retry::{retry_after, ShouldRetry},
};
use async_trait::async_trait;
use hyper::StatusCode;
use pq_proto::StartupMessageParams;
use std::ops::ControlFlow;
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
time,
};
use tokio::time;
use tracing::{error, info, warn};
use pq_proto::BeMessage as Be;
use super::{
pass::ProxyPass,
retry::{retry_after, ShouldRetry},
};
const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
pub struct NeedsComputeConnection<S> {
pub stream: PqStream<Stream<S>>,
pub user_info: auth::BackendType<'static, auth::backend::ComputeUserInfo>,
pub mechanism: TcpMechanism,
pub node_info: console::CachedNodeInfo,
// monitoring
pub ctx: RequestMonitoring,
pub cancel_session: Session,
}
impl<S> Stage for NeedsComputeConnection<S>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
fn span(&self) -> tracing::Span {
tracing::info_span!("connect_to_compute")
}
async fn run(self) -> Result<DynStage, StageError> {
let Self {
stream,
user_info,
mechanism,
node_info,
mut ctx,
cancel_session,
} = self;
let aux = node_info.aux.clone();
let (mut node, mut stream) =
connect_to_compute(&mut ctx, &mechanism, node_info, &user_info)
.await
.send_error_to_user(&mut ctx, stream)?;
prepare_client_connection(&node, &cancel_session, &mut stream)
.await
.no_user_error(&mut ctx, crate::error::ErrorKind::Disconnect)?;
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream
.write_all(&read_buf)
.await
.no_user_error(&mut ctx, crate::error::ErrorKind::Disconnect)?;
Ok(Box::new(ProxyPass {
client: stream,
compute: node.stream,
aux,
cancel_session,
}))
}
}
/// Finish client connection initialization: confirm auth success, send params, etc.
#[tracing::instrument(skip_all)]
async fn prepare_client_connection(
node: &compute::PostgresConnection,
session: &cancellation::Session,
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> std::io::Result<()> {
// Register compute's query cancellation token and produce a new, unique one.
// The new token (cancel_key_data) will be sent to the client.
let cancel_key_data = session.enable_query_cancellation(node.cancel_closure.clone());
// Forward all postgres connection params to the client.
// Right now the implementation is very hacky and inefficent (ideally,
// we don't need an intermediate hashmap), but at least it should be correct.
for (name, value) in &node.params {
// TODO: Theoretically, this could result in a big pile of params...
stream.write_message_noflush(&Be::ParameterStatus {
name: name.as_bytes(),
value: value.as_bytes(),
})?;
}
stream
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
.write_message(&Be::ReadyForQuery)
.await?;
Ok(())
}
/// If we couldn't connect, a cached connection info might be to blame
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
@@ -163,13 +63,13 @@ pub trait ConnectMechanism {
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
}
pub struct TcpMechanism {
pub struct TcpMechanism<'a> {
/// KV-dictionary with PostgreSQL connection params.
pub params: StartupMessageParams,
pub params: &'a StartupMessageParams,
}
#[async_trait]
impl ConnectMechanism for TcpMechanism {
impl ConnectMechanism for TcpMechanism<'_> {
type Connection = PostgresConnection;
type ConnectError = compute::ConnectionError;
type Error = compute::ConnectionError;
@@ -184,7 +84,7 @@ impl ConnectMechanism for TcpMechanism {
}
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
config.set_startup_params(&self.params);
config.set_startup_params(self.params);
}
}
@@ -260,6 +160,8 @@ where
let node_info = loop {
let wake_res = match user_info {
auth::BackendType::Console(api, user_info) => api.wake_compute(ctx, user_info).await,
#[cfg(feature = "testing")]
auth::BackendType::Postgres(api, user_info) => api.wake_compute(ctx, user_info).await,
// nothing to do?
auth::BackendType::Link(_) => return Err(err.into()),
// test backend

View File

@@ -1,203 +0,0 @@
use crate::{
auth::{self, backend::NeedsAuthentication},
cancellation::CancelMap,
config::{ProxyConfig, TlsConfig},
context::RequestMonitoring,
error::ReportableError,
proxy::{ERR_INSECURE_CONNECTION, ERR_PROTO_VIOLATION},
rate_limiter::EndpointRateLimiter,
state_machine::{DynStage, Finished, ResultExt, Stage, StageError},
stream::{PqStream, Stream, StreamUpgradeError},
};
use anyhow::{anyhow, Context};
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use std::{io, sync::Arc};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info};
use super::ClientMode;
pub struct NeedsHandshake<S> {
pub stream: S,
pub config: &'static ProxyConfig,
pub cancel_map: Arc<CancelMap>,
pub mode: ClientMode,
pub endpoint_rate_limiter: Arc<EndpointRateLimiter>,
// monitoring
pub ctx: RequestMonitoring,
}
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stage for NeedsHandshake<S> {
fn span(&self) -> tracing::Span {
tracing::info_span!("handshake")
}
async fn run(self) -> Result<DynStage, StageError> {
let Self {
stream,
config,
cancel_map,
mode,
endpoint_rate_limiter,
mut ctx,
} = self;
let tls = config.tls_config.as_ref();
let pause_timer = ctx.latency_timer.pause();
let handshake = handshake(stream, mode.handshake_tls(tls), &cancel_map).await;
drop(pause_timer);
let (stream, params) = match handshake {
Err(err) => {
// TODO: proper handling
error!("could not complete handshake: {err:#}");
return Err(StageError::Done);
}
// cancellation
Ok(None) => return Ok(Box::new(Finished)),
Ok(Some(s)) => s,
};
let hostname = mode.hostname(stream.get_ref());
let common_names = tls.map(|tls| &tls.common_names);
let (creds, stream) = config
.auth_backend
.as_ref()
.map(|_| {
auth::ComputeUserInfoMaybeEndpoint::parse(&mut ctx, &params, hostname, common_names)
})
.transpose()
.send_error_to_user(&mut ctx, stream)?;
ctx.set_endpoint_id(creds.get_endpoint());
Ok(Box::new(NeedsAuthentication {
stream,
creds,
params,
endpoint_rate_limiter,
mode,
config,
ctx,
cancel_session: cancel_map.get_session(),
}))
}
}
#[derive(Error, Debug)]
pub enum HandshakeError {
#[error("client disconnected: {0}")]
ClientIO(#[from] io::Error),
#[error("protocol violation: {0}")]
ProtocolError(#[from] anyhow::Error),
#[error("could not initiate tls connection: {0}")]
TLSError(#[from] StreamUpgradeError),
#[error("could not cancel connection: {0}")]
Cancel(anyhow::Error),
}
impl ReportableError for HandshakeError {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
HandshakeError::ClientIO(_) => crate::error::ErrorKind::Disconnect,
HandshakeError::ProtocolError(_) => crate::error::ErrorKind::User,
HandshakeError::TLSError(_) => crate::error::ErrorKind::User,
HandshakeError::Cancel(_) => crate::error::ErrorKind::Compute,
}
}
}
type SuccessfulHandshake<S> = (PqStream<Stream<S>>, StartupMessageParams);
/// Establish a (most probably, secure) connection with the client.
/// For better testing experience, `stream` can be any object satisfying the traits.
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
/// we also take an extra care of propagating only the select handshake errors to client.
pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<&TlsConfig>,
cancel_map: &CancelMap,
) -> Result<Option<SuccessfulHandshake<S>>, HandshakeError> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
let mut stream = PqStream::new(Stream::from_raw(stream));
loop {
let msg = stream.read_startup_packet().await?;
info!("received {msg:?}");
use FeStartupPacket::*;
match msg {
SslRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_ssl => {
tried_ssl = true;
// We can't perform TLS handshake without a config
let enc = tls.is_some();
stream.write_message(&Be::EncryptionResponse(enc)).await?;
if let Some(tls) = tls.take() {
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.
let (raw, read_buf) = stream.into_inner();
// TODO: Normally, client doesn't send any data before
// server says TLS handshake is ok and read_buf is empy.
// However, you could imagine pipelining of postgres
// SSLRequest + TLS ClientHello in one hunk similar to
// pipelining in our node js driver. We should probably
// support that by chaining read_buf with the stream.
if !read_buf.is_empty() {
return Err(HandshakeError::ProtocolError(anyhow!(
"data is sent before server replied with EncryptionResponse"
)));
}
let tls_stream = raw.upgrade(tls.to_server_config()).await?;
let (_, tls_server_end_point) = tls
.cert_resolver
.resolve(tls_stream.get_ref().1.server_name())
.context("missing certificate")?;
stream = PqStream::new(Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
});
}
}
_ => return Err(HandshakeError::ProtocolError(anyhow!(ERR_PROTO_VIOLATION))),
},
GssEncRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_gss => {
tried_gss = true;
// Currently, we don't support GSSAPI
stream.write_message(&Be::EncryptionResponse(false)).await?;
}
_ => return Err(HandshakeError::ProtocolError(anyhow!(ERR_PROTO_VIOLATION))),
},
StartupMessage { params, .. } => {
// Check that the config has been consumed during upgrade
// OR we didn't provide it at all (for dev purposes).
if tls.is_some() {
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
}
info!(session_type = "normal", "successful handshake");
break Ok(Some((stream, params)));
}
CancelRequest(cancel_key_data) => {
cancel_map
.cancel_session(cancel_key_data)
.await
.map_err(HandshakeError::Cancel)?;
info!(session_type = "cancellation", "successful handshake");
break Ok(None);
}
}
}
}

View File

@@ -1,82 +0,0 @@
use crate::{
cancellation::Session,
console::messages::MetricsAuxInfo,
metrics::{NUM_BYTES_PROXIED_COUNTER, NUM_BYTES_PROXIED_PER_CLIENT_COUNTER},
state_machine::{DynStage, Finished, Stage, StageError},
stream::Stream,
usage_metrics::{Ids, USAGE_METRICS},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info};
use utils::measured_stream::MeasuredStream;
pub struct ProxyPass<Client, Compute> {
pub client: Stream<Client>,
pub compute: Compute,
// monitoring
pub aux: MetricsAuxInfo,
pub cancel_session: Session,
}
impl<Client, Compute> Stage for ProxyPass<Client, Compute>
where
Client: AsyncRead + AsyncWrite + Unpin + Send + 'static,
Compute: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
fn span(&self) -> tracing::Span {
tracing::info_span!("proxy_pass")
}
async fn run(self) -> Result<DynStage, StageError> {
if let Err(e) = proxy_pass(self.client, self.compute, self.aux).await {
error!("{e:#}")
}
drop(self.cancel_session);
Ok(Box::new(Finished))
}
}
/// Forward bytes in both directions (client <-> compute).
pub async fn proxy_pass(
client: impl AsyncRead + AsyncWrite + Unpin,
compute: impl AsyncRead + AsyncWrite + Unpin,
aux: MetricsAuxInfo,
) -> anyhow::Result<()> {
let usage = USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id.clone(),
branch_id: aux.branch_id.clone(),
});
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]);
let m_sent2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("tx"));
let mut client = MeasuredStream::new(
client,
|_| {},
|cnt| {
// Number of bytes we sent to the client (outbound).
m_sent.inc_by(cnt as u64);
m_sent2.inc_by(cnt as u64);
usage.record_egress(cnt as u64);
},
);
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx"]);
let m_recv2 = NUM_BYTES_PROXIED_PER_CLIENT_COUNTER.with_label_values(&aux.traffic_labels("rx"));
let mut compute = MeasuredStream::new(
compute,
|_| {},
|cnt| {
// Number of bytes the client sent to the compute node (inbound).
m_recv.inc_by(cnt as u64);
m_recv2.inc_by(cnt as u64);
},
);
// Starting from here we only proxy the client's traffic.
info!("performing the proxy pass...");
let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;
Ok(())
}

View File

@@ -3,19 +3,14 @@
mod mitm;
use super::connect_compute::ConnectMechanism;
use super::handshake::handshake;
use super::retry::ShouldRetry;
use super::*;
use crate::auth::backend::{ComputeUserInfo, TestBackend};
use crate::config::CertResolver;
use crate::console::{self, CachedNodeInfo, NodeInfo};
use crate::proxy::connect_compute::connect_to_compute;
use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
use crate::stream::PqStream;
use crate::{auth, compute, http, sasl, scram};
use anyhow::bail;
use crate::{auth, http, sasl, scram};
use async_trait::async_trait;
use pq_proto::BeMessage as Be;
use rstest::rstest;
use smol_str::SmolStr;
use tokio_postgres::config::SslMode;
@@ -207,7 +202,7 @@ async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
.err() // -> Option<E>
.context("server shouldn't accept client")?;
assert!(server_err.to_string().contains(ERR_INSECURE_CONNECTION));
assert!(client_err.to_string().contains(&server_err.to_string()));
Ok(())
}

View File

@@ -10,7 +10,7 @@ use super::*;
use bytes::{Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use postgres_protocol::message::frontend;
use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream};
use tokio::io::{AsyncReadExt, DuplexStream};
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::TlsConnect;
use tokio_util::codec::{Decoder, Encoder};

View File

@@ -1,89 +0,0 @@
use std::{borrow::Cow, ops::ControlFlow};
use pq_proto::StartupMessageParams;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, warn};
use crate::{
auth::{
backend::{ComputeCredentialKeys, ComputeCredentials},
BackendType,
},
cancellation::Session,
console::{provider::ConsoleBackend, Api},
context::RequestMonitoring,
state_machine::{user_facing_error, DynStage, Stage, StageError},
stream::{PqStream, Stream},
};
use super::{
connect_compute::{handle_try_wake, NeedsComputeConnection, TcpMechanism},
retry::retry_after,
};
pub struct NeedsWakeCompute<S> {
pub stream: PqStream<Stream<S>>,
pub api: Cow<'static, ConsoleBackend>,
pub params: StartupMessageParams,
pub allow_self_signed_compute: bool,
pub creds: ComputeCredentials<ComputeCredentialKeys>,
// monitoring
pub ctx: RequestMonitoring,
pub cancel_session: Session,
}
impl<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> Stage for NeedsWakeCompute<S> {
fn span(&self) -> tracing::Span {
tracing::info_span!("wake_compute")
}
async fn run(self) -> Result<DynStage, StageError> {
let Self {
stream,
api,
params,
allow_self_signed_compute,
creds,
mut ctx,
cancel_session,
} = self;
let mut num_retries = 0;
let mut node_info = loop {
let wake_res = api.wake_compute(&mut ctx, &creds.info).await;
match handle_try_wake(wake_res, num_retries) {
Err(e) => {
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
return Err(user_facing_error(e, &mut ctx, stream));
}
Ok(ControlFlow::Continue(e)) => {
warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node");
}
Ok(ControlFlow::Break(n)) => break n,
}
let wait_duration = retry_after(num_retries);
num_retries += 1;
tokio::time::sleep(wait_duration).await;
};
ctx.set_project(node_info.aux.clone());
node_info.allow_self_signed_compute = allow_self_signed_compute;
match creds.keys {
#[cfg(feature = "testing")]
ComputeCredentialKeys::Password(password) => node_info.config.password(password),
ComputeCredentialKeys::AuthKeys(auth_keys) => node_info.config.auth_keys(auth_keys),
};
Ok(Box::new(NeedsComputeConnection {
stream,
user_info: BackendType::Console(api, creds.info),
mechanism: TcpMechanism { params },
node_info,
ctx,
cancel_session,
}))
}
}

View File

@@ -10,7 +10,7 @@ mod channel_binding;
mod messages;
mod stream;
use crate::error::{ReportableError, UserFacingError};
use crate::error::UserFacingError;
use std::io;
use thiserror::Error;
@@ -37,25 +37,6 @@ pub enum Error {
Io(#[from] io::Error),
}
impl ReportableError for Error {
fn get_error_type(&self) -> crate::error::ErrorKind {
match self {
Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User,
Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
Error::BadClientMessage(_) => crate::error::ErrorKind::User,
Error::MissingBinding => crate::error::ErrorKind::Service,
Error::Io(io) => match io.kind() {
// tokio postgres uses these for various scram failures
io::ErrorKind::InvalidInput
| io::ErrorKind::UnexpectedEof
| io::ErrorKind::Other => crate::error::ErrorKind::User,
// all other IO errors are likely disconnects.
_ => crate::error::ErrorKind::Disconnect,
},
}
}
}
impl UserFacingError for Error {
fn to_string_client(&self) -> String {
use Error::*;

View File

@@ -124,12 +124,6 @@ pub async fn task_main(
let cancel_map = Arc::new(CancelMap::default());
let session_id = uuid::Uuid::new_v4();
let root_span = info_span!(
"serverless",
session = %session_id,
%peer_addr,
);
request_handler(
req,
config,
@@ -141,9 +135,12 @@ pub async fn task_main(
sni_name,
peer_addr.ip(),
endpoint_rate_limiter,
root_span.clone(),
)
.instrument(root_span)
.instrument(info_span!(
"serverless",
session = %session_id,
%peer_addr,
))
.await
}
},
@@ -208,7 +205,6 @@ async fn request_handler(
sni_hostname: Option<String>,
peer_addr: IpAddr,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
root_span: tracing::Span,
) -> Result<Response<Body>, ApiError> {
let host = request
.headers()
@@ -219,33 +215,27 @@ async fn request_handler(
// Check if the request is a websocket upgrade request.
if hyper_tungstenite::is_upgrade_request(&request) {
info!("performing websocket upgrade");
info!(session_id = ?session_id, "performing websocket upgrade");
let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
.map_err(|e| ApiError::BadRequest(e.into()))?;
ws_connections.spawn(
async move {
let ctx =
RequestMonitoring::new(session_id, peer_addr, "ws", &config.region, root_span);
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "ws", &config.region);
let websocket = match websocket.await {
Err(e) => {
error!("error in websocket connection: {e:#}");
return;
}
Ok(ws) => ws,
};
websocket::serve_websocket(
if let Err(e) = websocket::serve_websocket(
config,
ctx,
&mut ctx,
websocket,
cancel_map,
&cancel_map,
host,
endpoint_rate_limiter,
)
.await
{
error!(session_id = ?session_id, "error in websocket connection: {e:#}");
}
}
.in_current_span(),
);
@@ -253,8 +243,7 @@ async fn request_handler(
// Return the response so the spawned future can continue.
Ok(response)
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
let mut ctx =
RequestMonitoring::new(session_id, peer_addr, "http", &config.region, root_span);
let mut ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
sql_over_http::handle(
tls,

View File

@@ -9,7 +9,7 @@ use crate::{
use bytes::{Buf, Bytes};
use futures::{Sink, Stream};
use hyper::upgrade::Upgraded;
use hyper_tungstenite::{tungstenite::Message, WebSocketStream};
use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream};
use pin_project_lite::pin_project;
use std::{
@@ -131,12 +131,13 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
pub async fn serve_websocket(
config: &'static ProxyConfig,
ctx: RequestMonitoring,
websocket: WebSocketStream<Upgraded>,
cancel_map: Arc<CancelMap>,
ctx: &mut RequestMonitoring,
websocket: HyperWebsocket,
cancel_map: &CancelMap,
hostname: Option<String>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) {
) -> anyhow::Result<()> {
let websocket = websocket.await?;
handle_client(
config,
ctx,
@@ -145,7 +146,8 @@ pub async fn serve_websocket(
ClientMode::Websockets { hostname },
endpoint_rate_limiter,
)
.await
.await?;
Ok(())
}
#[cfg(test)]

View File

@@ -1,149 +0,0 @@
use futures::Future;
use pq_proto::{framed::Framed, BeMessage};
use tokio::{io::AsyncWrite, task::JoinHandle};
use tracing::{info, warn, Instrument};
pub trait Captures<T> {}
impl<T, U> Captures<T> for U {}
#[must_use]
pub enum StageError {
Flush(Framed<Box<dyn AsyncWrite + Unpin + Send + 'static>>),
Done,
}
impl StageError {
pub async fn finish(self) {
match self {
StageError::Flush(mut f) => {
// ignore result. we can't do anything about it.
// this is already the error case anyway...
if let Err(e) = f.flush().await {
warn!("could not send message to user: {e:?}")
}
}
StageError::Done => {}
}
info!("task finished");
}
}
pub type DynStage = Box<dyn StageSpawn>;
/// Stage represents a single stage in a state machine.
pub trait Stage: 'static + Send {
/// The span this stage should be run inside.
fn span(&self) -> tracing::Span;
/// Run the current stage, returning a new [`DynStage`], or an error
///
/// Can be implemented as `async fn run(self) -> Result<DynStage, StageError>`
fn run(self) -> impl 'static + Send + Future<Output = Result<DynStage, StageError>>;
}
pub enum StageResult {
Finished,
Run(JoinHandle<Result<DynStage, StageError>>),
}
pub trait StageSpawn: 'static + Send {
fn run(self: Box<Self>) -> StageResult;
}
/// Stage spawn is a helper trait for the state machine. It spawns the stages as a tokio task
impl<S: Stage> StageSpawn for S {
fn run(self: Box<Self>) -> StageResult {
let span = self.span();
StageResult::Run(tokio::spawn(S::run(*self).instrument(span)))
}
}
pub struct Finished;
impl StageSpawn for Finished {
fn run(self: Box<Self>) -> StageResult {
StageResult::Finished
}
}
use crate::{
context::RequestMonitoring,
error::{ErrorKind, UserFacingError},
stream::PqStream,
};
pub trait ResultExt<T, E> {
fn send_error_to_user<S>(
self,
ctx: &mut RequestMonitoring,
stream: PqStream<S>,
) -> Result<(T, PqStream<S>), StageError>
where
S: AsyncWrite + Unpin + Send + 'static,
E: UserFacingError;
fn no_user_error(self, ctx: &mut RequestMonitoring, kind: ErrorKind) -> Result<T, StageError>
where
E: std::fmt::Display;
}
impl<T, E> ResultExt<T, E> for Result<T, E> {
fn send_error_to_user<S>(
self,
ctx: &mut RequestMonitoring,
stream: PqStream<S>,
) -> Result<(T, PqStream<S>), StageError>
where
S: AsyncWrite + Unpin + Send + 'static,
E: UserFacingError,
{
match self {
Ok(t) => Ok((t, stream)),
Err(e) => Err(user_facing_error(e, ctx, stream)),
}
}
fn no_user_error(self, ctx: &mut RequestMonitoring, kind: ErrorKind) -> Result<T, StageError>
where
E: std::fmt::Display,
{
match self {
Ok(t) => Ok(t),
Err(e) => {
tracing::error!(
kind = kind.to_str(),
user_msg = "",
"task finished with error: {e}"
);
ctx.error(kind);
ctx.log();
Err(StageError::Done)
}
}
}
}
pub fn user_facing_error<S, E>(
err: E,
ctx: &mut RequestMonitoring,
mut stream: PqStream<S>,
) -> StageError
where
S: AsyncWrite + Unpin + Send + 'static,
E: UserFacingError,
{
let kind = err.get_error_type();
ctx.error(kind);
ctx.log();
let msg = err.to_string_client();
tracing::error!(
kind = kind.to_str(),
user_msg = msg,
"task finished with error: {err}"
);
if let Err(err) = stream.write_message_noflush(&BeMessage::ErrorResponse(&msg, None)) {
warn!("could not process error message: {err:?}")
}
StageError::Flush(stream.framed.map_stream_sync(|f| Box::new(f) as Box<_>))
}

View File

@@ -1,5 +1,5 @@
use crate::config::TlsServerEndPoint;
use crate::error::ErrorKind;
use crate::error::UserFacingError;
use anyhow::bail;
use bytes::BytesMut;
@@ -99,17 +99,24 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
/// Allowing string literals is safe under the assumption they might not contain any runtime info.
/// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
pub async fn throw_error_str<T>(&mut self, error: &'static str) -> anyhow::Result<T> {
let kind = ErrorKind::User;
tracing::error!(
kind = kind.to_str(),
full_msg = error,
user_msg = error,
"task finished with error"
);
tracing::info!("forwarding error to user: {error}");
self.write_message(&BeMessage::ErrorResponse(error, None))
.await?;
bail!(error)
}
/// Write the error message using [`Self::write_message`], then re-throw it.
/// Trait [`UserFacingError`] acts as an allowlist for error types.
pub async fn throw_error<T, E>(&mut self, error: E) -> anyhow::Result<T>
where
E: UserFacingError + Into<anyhow::Error>,
{
let msg = error.to_string_client();
tracing::info!("forwarding error to user: {msg}");
self.write_message(&BeMessage::ErrorResponse(&msg, None))
.await?;
bail!(error)
}
}
/// Wrapper for upgrading raw streams into secure streams.

View File

@@ -8,6 +8,8 @@ use futures::future::BoxFuture;
use futures::stream::FuturesUnordered;
use futures::{FutureExt, StreamExt};
use remote_storage::RemoteStorageConfig;
use safekeeper::control_file::FileStorage;
use safekeeper::state::TimelinePersistentState;
use sd_notify::NotifyState;
use tokio::runtime::Handle;
use tokio::signal::unix::{signal, SignalKind};
@@ -30,12 +32,12 @@ use safekeeper::defaults::{
DEFAULT_HEARTBEAT_TIMEOUT, DEFAULT_HTTP_LISTEN_ADDR, DEFAULT_MAX_OFFLOADER_LAG_BYTES,
DEFAULT_PG_LISTEN_ADDR,
};
use safekeeper::wal_service;
use safekeeper::GlobalTimelines;
use safekeeper::SafeKeeperConf;
use safekeeper::{broker, WAL_SERVICE_RUNTIME};
use safekeeper::{control_file, BROKER_RUNTIME};
use safekeeper::{http, WAL_REMOVER_RUNTIME};
use safekeeper::{json_merge, wal_service};
use safekeeper::{remove_wal, WAL_BACKUP_RUNTIME};
use safekeeper::{wal_backup, HTTP_RUNTIME};
use storage_broker::DEFAULT_ENDPOINT;
@@ -105,9 +107,6 @@ struct Args {
/// Do not wait for changes to be written safely to disk. Unsafe.
#[arg(short, long)]
no_sync: bool,
/// Dump control file at path specified by this argument and exit.
#[arg(long)]
dump_control_file: Option<Utf8PathBuf>,
/// Broker endpoint for storage nodes coordination in the form
/// http[s]://host:port. In case of https schema TLS is connection is
/// established; plaintext otherwise.
@@ -166,6 +165,21 @@ struct Args {
/// useful for debugging.
#[arg(long)]
current_thread_runtime: bool,
/// Dump control file at path specified by this argument and exit.
#[arg(long)]
dump_control_file: Option<Utf8PathBuf>,
/// Patch control file at path specified by this argument and exit.
/// Patch is specified in --patch option and imposed over
/// control file as per rfc7386.
/// Without --write-patched the result is only printed.
#[arg(long, verbatim_doc_comment)]
patch_control_file: Option<Utf8PathBuf>,
/// The patch to apply to control file at --patch-control-file, in JSON.
#[arg(long, default_value = None)]
patch: Option<String>,
/// Write --patch-control-file result back in place.
#[arg(long, default_value = "false")]
write_patched: bool,
}
// Like PathBufValueParser, but allows empty string.
@@ -207,7 +221,13 @@ async fn main() -> anyhow::Result<()> {
if let Some(addr) = args.dump_control_file {
let state = control_file::FileStorage::load_control_file(addr)?;
let json = serde_json::to_string(&state)?;
print!("{json}");
println!("{json}");
return Ok(());
}
if let Some(cfile_path) = args.patch_control_file {
let patch = args.patch.ok_or(anyhow::anyhow!("patch is missing"))?;
patch_control_file(cfile_path, patch, args.write_patched).await?;
return Ok(());
}
@@ -529,6 +549,26 @@ fn parse_remote_storage(storage_conf: &str) -> anyhow::Result<RemoteStorageConfi
})
}
async fn patch_control_file(
cfile_path: Utf8PathBuf,
patch: String,
write: bool,
) -> anyhow::Result<()> {
let state = control_file::FileStorage::load_control_file(&cfile_path)?;
// serialize to json, impose patch and deserialize back
let mut state_json =
serde_json::to_value(state).context("failed to serialize state to json")?;
let patch_json = serde_json::from_str(&patch).context("failed to parse patch")?;
json_merge(&mut state_json, patch_json);
let patched_state: TimelinePersistentState =
serde_json::from_value(state_json.clone()).context("failed to deserialize patched json")?;
println!("{state_json}");
if write {
FileStorage::do_persist(&patched_state, &cfile_path, true).await?;
}
return Ok(());
}
#[test]
fn verify_cli() {
use clap::CommandFactory;

View File

@@ -2,7 +2,7 @@
use anyhow::{bail, ensure, Context, Result};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use camino::Utf8PathBuf;
use camino::{Utf8Path, Utf8PathBuf};
use tokio::fs::{self, File};
use tokio::io::AsyncWriteExt;
@@ -155,6 +155,46 @@ impl FileStorage {
})?;
Ok(state)
}
/// Persist state s to dst_path, optionally fsyncing file.
pub async fn do_persist(
s: &TimelinePersistentState,
dst_path: &Utf8Path,
sync: bool,
) -> Result<()> {
let mut f = File::create(&dst_path)
.await
.with_context(|| format!("failed to create partial control file at: {}", &dst_path))?;
let mut buf: Vec<u8> = Vec::new();
WriteBytesExt::write_u32::<LittleEndian>(&mut buf, SK_MAGIC)?;
WriteBytesExt::write_u32::<LittleEndian>(&mut buf, SK_FORMAT_VERSION)?;
s.ser_into(&mut buf)?;
// calculate checksum before resize
let checksum = crc32c::crc32c(&buf);
buf.extend_from_slice(&checksum.to_le_bytes());
f.write_all(&buf).await.with_context(|| {
format!(
"failed to write safekeeper state into control file at: {}",
dst_path
)
})?;
f.flush().await.with_context(|| {
format!(
"failed to flush safekeeper state into control file at: {}",
dst_path
)
})?;
// fsync the file
if sync {
f.sync_all()
.await
.with_context(|| format!("failed to sync partial control file at {}", dst_path))?;
}
Ok(())
}
}
impl Deref for FileStorage {
@@ -167,7 +207,7 @@ impl Deref for FileStorage {
#[async_trait::async_trait]
impl Storage for FileStorage {
/// Persists state durably to the underlying storage.
/// Atomically persists state durably to the underlying storage.
///
/// For a description, see <https://lwn.net/Articles/457667/>.
async fn persist(&mut self, s: &TimelinePersistentState) -> Result<()> {
@@ -175,46 +215,9 @@ impl Storage for FileStorage {
// write data to safekeeper.control.partial
let control_partial_path = self.timeline_dir.join(CONTROL_FILE_NAME_PARTIAL);
let mut control_partial = File::create(&control_partial_path).await.with_context(|| {
format!(
"failed to create partial control file at: {}",
&control_partial_path
)
})?;
let mut buf: Vec<u8> = Vec::new();
WriteBytesExt::write_u32::<LittleEndian>(&mut buf, SK_MAGIC)?;
WriteBytesExt::write_u32::<LittleEndian>(&mut buf, SK_FORMAT_VERSION)?;
s.ser_into(&mut buf)?;
// calculate checksum before resize
let checksum = crc32c::crc32c(&buf);
buf.extend_from_slice(&checksum.to_le_bytes());
control_partial.write_all(&buf).await.with_context(|| {
format!(
"failed to write safekeeper state into control file at: {}",
control_partial_path
)
})?;
control_partial.flush().await.with_context(|| {
format!(
"failed to flush safekeeper state into control file at: {}",
control_partial_path
)
})?;
// fsync the file
if !self.conf.no_sync {
control_partial.sync_all().await.with_context(|| {
format!(
"failed to sync partial control file at {}",
control_partial_path
)
})?;
}
FileStorage::do_persist(s, &control_partial_path, !self.conf.no_sync).await?;
let control_path = self.timeline_dir.join(CONTROL_FILE_NAME);
// rename should be atomic
fs::rename(&control_partial_path, &control_path).await?;
// this sync is not required by any standard but postgres does this (see durable_rename)

View File

@@ -2,6 +2,7 @@
use camino::Utf8PathBuf;
use once_cell::sync::Lazy;
use remote_storage::RemoteStorageConfig;
use serde_json::Value;
use tokio::runtime::Runtime;
use std::time::Duration;
@@ -175,3 +176,24 @@ pub static METRICS_SHIFTER_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
.build()
.expect("Failed to create broker runtime")
});
/// Merge json b into json a according to
/// https://www.rfc-editor.org/rfc/rfc7396
/// https://stackoverflow.com/a/54118457/4014587
pub fn json_merge(a: &mut Value, b: Value) {
if let Value::Object(a) = a {
if let Value::Object(b) = b {
for (k, v) in b {
if v.is_null() {
a.remove(&k);
} else {
json_merge(a.entry(k).or_insert(Value::Null), v);
}
}
return;
}
}
*a = b;
}

View File

@@ -110,7 +110,7 @@ pub static REMOVED_WAL_SEGMENTS: Lazy<IntCounter> = Lazy::new(|| {
pub static BACKED_UP_SEGMENTS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"safekeeper_backed_up_segments_total",
"Number of WAL segments backed up to the S3"
"Number of WAL segments backed up to the broker"
)
.expect("Failed to register safekeeper_backed_up_segments_total counter")
});
@@ -337,7 +337,6 @@ pub struct TimelineCollector {
flushed_wal_seconds: GaugeVec,
collect_timeline_metrics: Gauge,
timelines_count: IntGauge,
active_timelines_count: IntGauge,
}
impl Default for TimelineCollector {
@@ -521,13 +520,6 @@ impl TimelineCollector {
.unwrap();
descs.extend(timelines_count.desc().into_iter().cloned());
let active_timelines_count = IntGauge::new(
"safekeeper_active_timelines",
"Total number of active timelines",
)
.unwrap();
descs.extend(active_timelines_count.desc().into_iter().cloned());
TimelineCollector {
descs,
commit_lsn,
@@ -548,7 +540,6 @@ impl TimelineCollector {
flushed_wal_seconds,
collect_timeline_metrics,
timelines_count,
active_timelines_count,
}
}
}
@@ -581,7 +572,6 @@ impl Collector for TimelineCollector {
let timelines = GlobalTimelines::get_all();
let timelines_count = timelines.len();
let mut active_timelines_count = 0;
// Prometheus Collector is sync, and data is stored under async lock. To
// bridge the gap with a crutch, collect data in spawned thread with
@@ -600,10 +590,6 @@ impl Collector for TimelineCollector {
let timeline_id = tli.ttid.timeline_id.to_string();
let labels = &[tenant_id.as_str(), timeline_id.as_str()];
if tli.timeline_is_active {
active_timelines_count += 1;
}
self.commit_lsn
.with_label_values(labels)
.set(tli.mem_state.commit_lsn.into());
@@ -695,8 +681,6 @@ impl Collector for TimelineCollector {
// report total number of timelines
self.timelines_count.set(timelines_count as i64);
self.active_timelines_count
.set(active_timelines_count as i64);
mfs.extend(self.timelines_count.collect());
mfs

View File

@@ -10,15 +10,11 @@ use crate::{GlobalTimelines, SafeKeeperConf};
pub async fn task_main(conf: SafeKeeperConf) -> anyhow::Result<()> {
let wal_removal_interval = Duration::from_millis(5000);
loop {
let now = tokio::time::Instant::now();
let mut active_timelines = 0;
let tlis = GlobalTimelines::get_all();
for tli in &tlis {
if !tli.is_active().await {
continue;
}
active_timelines += 1;
let ttid = tli.ttid;
async {
if let Err(e) = tli.maybe_persist_control_file().await {
@@ -31,17 +27,6 @@ pub async fn task_main(conf: SafeKeeperConf) -> anyhow::Result<()> {
.instrument(info_span!("WAL removal", ttid = %ttid))
.await;
}
let elapsed = now.elapsed();
let total_timelines = tlis.len();
if elapsed > wal_removal_interval {
info!(
"WAL removal is too long, processed {} active timelines ({} total) in {:?}",
active_timelines, total_timelines, elapsed
);
}
sleep(wal_removal_interval).await;
}
}

View File

@@ -2914,7 +2914,6 @@ class Endpoint(PgProtocol):
# Write it back updated
with open(config_path, "w") as file:
log.info(json.dumps(dict(data_dict, **kwargs)))
json.dump(dict(data_dict, **kwargs), file, indent=4)
# Mock the extension part of spec passed from control plane for local testing

View File

@@ -248,15 +248,8 @@ def test_ddl_forwarding(ddl: DdlForwardingContext):
# We don't have compute_ctl, so here, so create neon_superuser here manually
cur.execute("CREATE ROLE neon_superuser NOLOGIN CREATEDB CREATEROLE")
# Contrary to popular belief, being superman does not make you superuser
cur.execute("CREATE ROLE superman LOGIN NOSUPERUSER PASSWORD 'jungle_man'")
with ddl.pg.cursor(user="superman", password="jungle_man") as superman_cur:
# We allow real SUPERUSERs to ALTER neon_superuser
with pytest.raises(psycopg2.InternalError):
superman_cur.execute("ALTER ROLE neon_superuser LOGIN")
cur.execute("ALTER ROLE neon_superuser LOGIN")
with pytest.raises(psycopg2.InternalError):
cur.execute("ALTER ROLE neon_superuser LOGIN")
with pytest.raises(psycopg2.InternalError):
cur.execute("CREATE DATABASE trololobus WITH OWNER neon_superuser")

View File

@@ -1,37 +0,0 @@
import time
from fixtures.neon_fixtures import NeonEnv
def test_migrations(neon_simple_env: NeonEnv):
env = neon_simple_env
env.neon_cli.create_branch("test_migrations", "empty")
endpoint = env.endpoints.create("test_migrations")
log_path = endpoint.endpoint_path() / "compute.log"
endpoint.respec(skip_pg_catalog_updates=False, features=["migrations"])
endpoint.start()
time.sleep(1) # Sleep to let migrations run
with endpoint.cursor() as cur:
cur.execute("SELECT id FROM neon_migration.migration_id")
migration_id = cur.fetchall()
assert migration_id[0][0] == 2
with open(log_path, "r") as log_file:
logs = log_file.read()
assert "INFO handle_migrations: Ran 2 migrations" in logs
endpoint.stop()
endpoint.start()
time.sleep(1) # Sleep to let migrations run
with endpoint.cursor() as cur:
cur.execute("SELECT id FROM neon_migration.migration_id")
migration_id = cur.fetchall()
assert migration_id[0][0] == 2
with open(log_path, "r") as log_file:
logs = log_file.read()
assert "INFO handle_migrations: Ran 0 migrations" in logs

View File

@@ -1,18 +1,10 @@
import json
import os
import time
from pathlib import Path
from fixtures.log_helper import log
from fixtures.neon_fixtures import NeonEnvBuilder, wait_for_wal_insert_lsn
from fixtures.pageserver.utils import (
wait_for_last_record_lsn,
)
from fixtures.remote_storage import RemoteStorageKind
from fixtures.types import Lsn, TenantId, TimelineId
from fixtures.utils import query_scalar
from fixtures.neon_fixtures import NeonEnvBuilder
# Test restarting page server, while safekeeper and compute node keep
# running.
def test_next_xid(neon_env_builder: NeonEnvBuilder):
env = neon_env_builder.init_start()
@@ -60,161 +52,3 @@ def test_next_xid(neon_env_builder: NeonEnvBuilder):
cur = conn.cursor()
cur.execute("SELECT count(*) FROM t")
assert cur.fetchone() == (iterations,)
# Test for a bug we had, where nextXid was incorrectly updated when the
# XID counter reached 2 billion. The nextXid tracking logic incorrectly
# treated 0 (InvalidTransactionId) as a regular XID, and after reaching
# 2 billion, it started to look like a very new XID, which caused nextXid
# to be immediately advanced to the next epoch.
#
def test_import_at_2bil(
neon_env_builder: NeonEnvBuilder,
test_output_dir: Path,
pg_distrib_dir: Path,
pg_bin,
vanilla_pg,
):
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
env = neon_env_builder.init_start()
ps_http = env.pageserver.http_client()
# Set LD_LIBRARY_PATH in the env properly, otherwise we may use the wrong libpq.
# PgBin sets it automatically, but here we need to pipe psql output to the tar command.
psql_env = {"LD_LIBRARY_PATH": str(pg_distrib_dir / "lib")}
# Reset the vanilla Postgres instance to somewhat before 2 billion transactions.
pg_resetwal_path = os.path.join(pg_bin.pg_bin_path, "pg_resetwal")
cmd = [pg_resetwal_path, "--next-transaction-id=2129920000", "-D", str(vanilla_pg.pgdatadir)]
pg_bin.run_capture(cmd, env=psql_env)
vanilla_pg.start()
vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser")
vanilla_pg.safe_psql(
"""create table tt as select 'long string to consume some space' || g
from generate_series(1,300000) g"""
)
assert vanilla_pg.safe_psql("select count(*) from tt") == [(300000,)]
vanilla_pg.safe_psql("CREATE TABLE t (t text);")
vanilla_pg.safe_psql("INSERT INTO t VALUES ('inserted in vanilla')")
endpoint_id = "ep-import_from_vanilla"
tenant = TenantId.generate()
timeline = TimelineId.generate()
env.pageserver.tenant_create(tenant)
# Take basebackup
basebackup_dir = os.path.join(test_output_dir, "basebackup")
base_tar = os.path.join(basebackup_dir, "base.tar")
wal_tar = os.path.join(basebackup_dir, "pg_wal.tar")
os.mkdir(basebackup_dir)
vanilla_pg.safe_psql("CHECKPOINT")
pg_bin.run(
[
"pg_basebackup",
"-F",
"tar",
"-d",
vanilla_pg.connstr(),
"-D",
basebackup_dir,
]
)
# Get start_lsn and end_lsn
with open(os.path.join(basebackup_dir, "backup_manifest")) as f:
manifest = json.load(f)
start_lsn = manifest["WAL-Ranges"][0]["Start-LSN"]
end_lsn = manifest["WAL-Ranges"][0]["End-LSN"]
def import_tar(base, wal):
env.neon_cli.raw_cli(
[
"timeline",
"import",
"--tenant-id",
str(tenant),
"--timeline-id",
str(timeline),
"--node-name",
endpoint_id,
"--base-lsn",
start_lsn,
"--base-tarfile",
base,
"--end-lsn",
end_lsn,
"--wal-tarfile",
wal,
"--pg-version",
env.pg_version,
]
)
# Importing correct backup works
import_tar(base_tar, wal_tar)
wait_for_last_record_lsn(ps_http, tenant, timeline, Lsn(end_lsn))
endpoint = env.endpoints.create_start(
endpoint_id,
tenant_id=tenant,
config_lines=[
"log_autovacuum_min_duration = 0",
"autovacuum_naptime='5 s'",
],
)
assert endpoint.safe_psql("select count(*) from t") == [(1,)]
# Ok, consume
conn = endpoint.connect()
cur = conn.cursor()
# Install extension containing function needed for test
cur.execute("CREATE EXTENSION neon_test_utils")
# Advance nextXid close to 2 billion XIDs
while True:
xid = int(query_scalar(cur, "SELECT txid_current()"))
log.info(f"xid now {xid}")
# Consume 10k transactons at a time until we get to 2^31 - 200k
if xid < 2 * 1024 * 1024 * 1024 - 100000:
cur.execute("select test_consume_xids(50000);")
elif xid < 2 * 1024 * 1024 * 1024 - 10000:
cur.execute("select test_consume_xids(5000);")
else:
break
# Run a bunch of real INSERTs to cross over the 2 billion mark
# Use a begin-exception block to have a separate sub-XID for each insert.
cur.execute(
"""
do $$
begin
for i in 1..10000 loop
-- Use a begin-exception block to generate a new subtransaction on each iteration
begin
insert into t values (i);
exception when others then
raise 'not expected %', sqlerrm;
end;
end loop;
end;
$$;
"""
)
# A checkpoint writes a WAL record with xl_xid=0. Many other WAL
# records would have the same effect.
cur.execute("checkpoint")
# wait until pageserver receives that data
wait_for_wal_insert_lsn(env, endpoint, tenant, timeline)
# Restart endpoint
endpoint.stop()
endpoint.start()
conn = endpoint.connect()
cur = conn.cursor()
cur.execute("SELECT count(*) from t")
assert cur.fetchone() == (10000 + 1,)

View File

@@ -1,42 +0,0 @@
import threading
import time
from contextlib import closing
from fixtures.log_helper import log
from fixtures.neon_fixtures import NeonEnv, PgBin
# Test updating neon.pageserver_connstring setting on the fly.
#
# This merely changes some whitespace in the connection string, so
# this doesn't prove that the new string actually takes effect. But at
# least the code gets exercised.
def test_pageserver_reconnect(neon_simple_env: NeonEnv, pg_bin: PgBin):
env = neon_simple_env
env.neon_cli.create_branch("test_pageserver_restarts")
endpoint = env.endpoints.create_start("test_pageserver_restarts")
n_reconnects = 1000
timeout = 0.01
scale = 10
def run_pgbench(connstr: str):
log.info(f"Start a pgbench workload on pg {connstr}")
pg_bin.run_capture(["pgbench", "-i", f"-s{scale}", connstr])
pg_bin.run_capture(["pgbench", f"-T{int(n_reconnects*timeout)}", connstr])
thread = threading.Thread(target=run_pgbench, args=(endpoint.connstr(),), daemon=True)
thread.start()
with closing(endpoint.connect()) as con:
with con.cursor() as c:
c.execute("SELECT setting FROM pg_settings WHERE name='neon.pageserver_connstring'")
connstring = c.fetchall()[0][0]
for i in range(n_reconnects):
time.sleep(timeout)
c.execute(
"alter system set neon.pageserver_connstring=%s",
(connstring + (" " * (i % 2)),),
)
c.execute("select pg_reload_conf()")
thread.join()

View File

@@ -1,4 +1,3 @@
import pytest
from fixtures.log_helper import log
from fixtures.neon_fixtures import NeonEnv, fork_at_current_lsn
@@ -118,8 +117,6 @@ def test_vm_bit_clear(neon_simple_env: NeonEnv):
# Test that the ALL_FROZEN VM bit is cleared correctly at a HEAP_LOCK
# record.
#
# FIXME: This test is broken
@pytest.mark.skip("See https://github.com/neondatabase/neon/pull/6412#issuecomment-1902072541")
def test_vm_bit_clear_on_heap_lock(neon_simple_env: NeonEnv):
env = neon_simple_env

View File

@@ -29,8 +29,10 @@ chrono = { version = "0.4", default-features = false, features = ["clock", "serd
clap = { version = "4", features = ["derive", "string"] }
clap_builder = { version = "4", default-features = false, features = ["color", "help", "std", "string", "suggestions", "usage"] }
crossbeam-utils = { version = "0.8" }
dashmap = { version = "5", default-features = false, features = ["raw-api"] }
either = { version = "1" }
fail = { version = "0.5", default-features = false, features = ["failpoints"] }
futures = { version = "0.3" }
futures-channel = { version = "0.3", features = ["sink"] }
futures-core = { version = "0.3" }
futures-executor = { version = "0.3" }
@@ -72,7 +74,6 @@ tokio-rustls = { version = "0.24" }
tokio-util = { version = "0.7", features = ["codec", "compat", "io", "rt"] }
toml_datetime = { version = "0.6", default-features = false, features = ["serde"] }
toml_edit = { version = "0.19", features = ["serde"] }
tonic = { version = "0.9", features = ["tls-roots"] }
tower = { version = "0.4", default-features = false, features = ["balance", "buffer", "limit", "log", "timeout", "util"] }
tracing = { version = "0.1", features = ["log"] }
tracing-core = { version = "0.1" }