mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-21 07:00:38 +00:00
Compare commits
443 Commits
arpad/walp
...
release-pr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b147439d6b | ||
|
|
87179e26b3 | ||
|
|
f05df409bd | ||
|
|
f6c0f6c4ec | ||
|
|
62cd3b8d3d | ||
|
|
8d26978ed9 | ||
|
|
35372a8f12 | ||
|
|
6d95a3fe2d | ||
|
|
99726495c7 | ||
|
|
4a4a457312 | ||
|
|
e78d1e2ec6 | ||
|
|
af429b4a62 | ||
|
|
3b4d4eb535 | ||
|
|
f060537a31 | ||
|
|
8a6fc6fd8c | ||
|
|
51639cd6af | ||
|
|
529d661532 | ||
|
|
9e4cf52949 | ||
|
|
831f2a4ba7 | ||
|
|
eadabeddb8 | ||
|
|
67ddf1de28 | ||
|
|
541fcd8d2f | ||
|
|
e77961c1c6 | ||
|
|
cdfa06caad | ||
|
|
f0bb93a9c9 | ||
|
|
30adf8e2bd | ||
|
|
5d538a9503 | ||
|
|
54433c0839 | ||
|
|
40bb9ff62a | ||
|
|
4688b815b1 | ||
|
|
0982ca4636 | ||
|
|
7272d9f7b3 | ||
|
|
37d555aa59 | ||
|
|
cae3e2976b | ||
|
|
51ecd1bb37 | ||
|
|
1e6bb48076 | ||
|
|
1470af0b42 | ||
|
|
f92f92b91b | ||
|
|
dbb205ae92 | ||
|
|
85072b715f | ||
|
|
6c86fe7143 | ||
|
|
66d5fe7f5b | ||
|
|
a1b9528757 | ||
|
|
1423bb8aa2 | ||
|
|
332f064a42 | ||
|
|
c962f2b447 | ||
|
|
446b3f9d28 | ||
|
|
23352dc2e9 | ||
|
|
c65fc5a955 | ||
|
|
3e624581cd | ||
|
|
fedf4f169c | ||
|
|
86d5798108 | ||
|
|
8b4088dd8a | ||
|
|
c91905e643 | ||
|
|
44b4e355a2 | ||
|
|
03666a1f37 | ||
|
|
9c92242ca0 | ||
|
|
a354071dd0 | ||
|
|
758680d4f8 | ||
|
|
1738fd0a96 | ||
|
|
87b7edfc72 | ||
|
|
def05700d5 | ||
|
|
b547681e08 | ||
|
|
0fd211537b | ||
|
|
a83bd4e81c | ||
|
|
ecdad5e6d5 | ||
|
|
d028929945 | ||
|
|
7b0e3db868 | ||
|
|
088eb72dd7 | ||
|
|
d550e3f626 | ||
|
|
8c6b41daf5 | ||
|
|
bbb050459b | ||
|
|
cab498c787 | ||
|
|
6359342ffb | ||
|
|
13285c2a5e | ||
|
|
33790d14a3 | ||
|
|
709b8cd371 | ||
|
|
1c9bbf1a92 | ||
|
|
16163fb850 | ||
|
|
73ccc2b08c | ||
|
|
c719be6474 | ||
|
|
718645e56c | ||
|
|
fbc8c36983 | ||
|
|
5519e42612 | ||
|
|
4157eaf4c5 | ||
|
|
60241127e2 | ||
|
|
f7d5322e8b | ||
|
|
41bb9c5280 | ||
|
|
69c0d61c5c | ||
|
|
63cb8ce975 | ||
|
|
907e4aa3c4 | ||
|
|
0a2a84b766 | ||
|
|
85b12ddd52 | ||
|
|
dd76f1eeee | ||
|
|
8963ac85f9 | ||
|
|
4a488b3e24 | ||
|
|
c4987b0b13 | ||
|
|
84b4821118 | ||
|
|
32ba9811f9 | ||
|
|
a0cd64c4d3 | ||
|
|
84687b743d | ||
|
|
b6f93dcec9 | ||
|
|
4f6c594973 | ||
|
|
a750c14735 | ||
|
|
9ce0dd4e55 | ||
|
|
0e1a336607 | ||
|
|
7fc2912d06 | ||
|
|
fdf231c237 | ||
|
|
1e08b5dccc | ||
|
|
030810ed3e | ||
|
|
62b74bdc2c | ||
|
|
8b7e9ed820 | ||
|
|
5dad89acd4 | ||
|
|
547b2d2827 | ||
|
|
93f29a0065 | ||
|
|
4f36494615 | ||
|
|
0a550f3e7d | ||
|
|
4bb9554e4a | ||
|
|
008616cfe6 | ||
|
|
e61ec94fbc | ||
|
|
e5152551ad | ||
|
|
b0822a5499 | ||
|
|
1fb6ab59e8 | ||
|
|
e16439400d | ||
|
|
e401f66698 | ||
|
|
2fa461b668 | ||
|
|
03d90bc0b3 | ||
|
|
268bc890ea | ||
|
|
8a6ee79f6f | ||
|
|
9052c32b46 | ||
|
|
995e729ebe | ||
|
|
76077e1ddf | ||
|
|
0467d88f06 | ||
|
|
f5eec194e7 | ||
|
|
7e00be391d | ||
|
|
d56599df2a | ||
|
|
9d9aab3680 | ||
|
|
a202b1b5cc | ||
|
|
90f731f3b1 | ||
|
|
7736b748d3 | ||
|
|
9c23333cb3 | ||
|
|
66a99009ba | ||
|
|
5d4c57491f | ||
|
|
73935ea3a2 | ||
|
|
32e595d4dd | ||
|
|
b0d69acb07 | ||
|
|
98355a419a | ||
|
|
cfb03d6cf0 | ||
|
|
d81ef3f962 | ||
|
|
5d62c67e75 | ||
|
|
53d53d5b1e | ||
|
|
29fe6ea47a | ||
|
|
640327ccb3 | ||
|
|
7cf0f6b37e | ||
|
|
03c2c569be | ||
|
|
eff6d4538a | ||
|
|
5ef7782e9c | ||
|
|
73101db8c4 | ||
|
|
bccdfc6d39 | ||
|
|
99595813bb | ||
|
|
fe07b54758 | ||
|
|
a42d173e7b | ||
|
|
e07f689238 | ||
|
|
7831eddc88 | ||
|
|
943b1bc80c | ||
|
|
95a184e9b7 | ||
|
|
3fa17e9d17 | ||
|
|
55e0fd9789 | ||
|
|
2a88889f44 | ||
|
|
5bad8126dc | ||
|
|
27bc242085 | ||
|
|
192b49cc6d | ||
|
|
e1b60f3693 | ||
|
|
2804f5323b | ||
|
|
676adc6b32 | ||
|
|
96a4e8de66 | ||
|
|
01180666b0 | ||
|
|
6c94269c32 | ||
|
|
edc691647d | ||
|
|
855d7b4781 | ||
|
|
c49c9707ce | ||
|
|
2227540a0d | ||
|
|
f1347f2417 | ||
|
|
30b295b017 | ||
|
|
1cef395266 | ||
|
|
78d160f76d | ||
|
|
b9238059d6 | ||
|
|
d0cb4b88c8 | ||
|
|
1ec3e39d4e | ||
|
|
a1a74eef2c | ||
|
|
90e689adda | ||
|
|
f0b2d4b053 | ||
|
|
299d9474c9 | ||
|
|
7234208b36 | ||
|
|
93450f11f5 | ||
|
|
2f0f9edf33 | ||
|
|
d424f2b7c8 | ||
|
|
21315e80bc | ||
|
|
483b66d383 | ||
|
|
aa72a22661 | ||
|
|
5c0264b591 | ||
|
|
9f13277729 | ||
|
|
54aa319805 | ||
|
|
4a227484bf | ||
|
|
2f83f85291 | ||
|
|
d6cfcb0d93 | ||
|
|
392843ad2a | ||
|
|
bd4dae8f4a | ||
|
|
b05fe53cfd | ||
|
|
c13a2f0df1 | ||
|
|
39be366fc5 | ||
|
|
6eda0a3158 | ||
|
|
306c7a1813 | ||
|
|
80be423a58 | ||
|
|
5dcfef82f2 | ||
|
|
e67b8f69c0 | ||
|
|
e546872ab4 | ||
|
|
322ea1cf7c | ||
|
|
3633742de9 | ||
|
|
079d3a37ba | ||
|
|
a46e77b476 | ||
|
|
a92702b01e | ||
|
|
8ff3253f20 | ||
|
|
04b82c92a7 | ||
|
|
e5bf423e68 | ||
|
|
60af392e45 | ||
|
|
661fc41e71 | ||
|
|
702c488f32 | ||
|
|
45c5122754 | ||
|
|
558394f710 | ||
|
|
73b0898608 | ||
|
|
e65be4c2dc | ||
|
|
40087b8164 | ||
|
|
c762b59483 | ||
|
|
5d71601ca9 | ||
|
|
a113c3e433 | ||
|
|
e81fc598f4 | ||
|
|
48b845fa76 | ||
|
|
27096858dc | ||
|
|
4430d0ae7d | ||
|
|
6e183aa0de | ||
|
|
fd6d0b7635 | ||
|
|
3710c32aae | ||
|
|
be83bee49d | ||
|
|
cf28e5922a | ||
|
|
7d384d6953 | ||
|
|
4b3b37b912 | ||
|
|
1d8d200f4d | ||
|
|
0d80d6ce18 | ||
|
|
f653ee039f | ||
|
|
e614a95853 | ||
|
|
850db4cc13 | ||
|
|
8a316b1277 | ||
|
|
4d13bae449 | ||
|
|
49377abd98 | ||
|
|
a6b2f4e54e | ||
|
|
face60d50b | ||
|
|
9768aa27f2 | ||
|
|
96b2e575e1 | ||
|
|
7222777784 | ||
|
|
5469fdede0 | ||
|
|
72aa6b9fdd | ||
|
|
ae0634b7be | ||
|
|
70711f32fa | ||
|
|
52a88af0aa | ||
|
|
b7a43bf817 | ||
|
|
dce91b33a4 | ||
|
|
23ee4f3050 | ||
|
|
46857e8282 | ||
|
|
368ab0ce54 | ||
|
|
a5987eebfd | ||
|
|
6686ede30f | ||
|
|
373c7057cc | ||
|
|
7d6ec16166 | ||
|
|
0e6fdc8a58 | ||
|
|
521438a5c6 | ||
|
|
07d7874bc8 | ||
|
|
1804111a02 | ||
|
|
cd0178efed | ||
|
|
333574be57 | ||
|
|
79a799a143 | ||
|
|
9da06af6c9 | ||
|
|
ce1753d036 | ||
|
|
67db8432b4 | ||
|
|
4e2e44e524 | ||
|
|
ed786104f3 | ||
|
|
84b74f2bd1 | ||
|
|
fec2ad6283 | ||
|
|
98eebd4682 | ||
|
|
2f74287c9b | ||
|
|
aee1bf95e3 | ||
|
|
b9de9d75ff | ||
|
|
7943b709e6 | ||
|
|
d7d066d493 | ||
|
|
e78ac22107 | ||
|
|
76a8f2bb44 | ||
|
|
8d59a8581f | ||
|
|
b1ddd01289 | ||
|
|
6eae4fc9aa | ||
|
|
765455bca2 | ||
|
|
4204960942 | ||
|
|
67345d66ea | ||
|
|
2266ee5971 | ||
|
|
b58445d855 | ||
|
|
36050e7f3d | ||
|
|
33360ed96d | ||
|
|
39a28d1108 | ||
|
|
efa6aa134f | ||
|
|
2c724e56e2 | ||
|
|
feff887c6f | ||
|
|
353d915fcf | ||
|
|
2e38098cbc | ||
|
|
a6fe5ea1ac | ||
|
|
05b0aed0c1 | ||
|
|
cd1705357d | ||
|
|
6bc7561290 | ||
|
|
fbd3ac14b5 | ||
|
|
e437787c8f | ||
|
|
3460dbf90b | ||
|
|
6b89d99677 | ||
|
|
6cc8ea86e4 | ||
|
|
e62a492d6f | ||
|
|
a475cdf642 | ||
|
|
7002c79a47 | ||
|
|
ee6cf357b4 | ||
|
|
e5c2086b5f | ||
|
|
5f1208296a | ||
|
|
88e8e473cd | ||
|
|
b0a77844f6 | ||
|
|
1baf464307 | ||
|
|
e9b8e81cea | ||
|
|
85d6194aa4 | ||
|
|
333a7a68ef | ||
|
|
6aa4e41bee | ||
|
|
840183e51f | ||
|
|
cbccc94b03 | ||
|
|
fce227df22 | ||
|
|
bd787e800f | ||
|
|
4a7704b4a3 | ||
|
|
ff1119da66 | ||
|
|
4c3ba1627b | ||
|
|
1407174fb2 | ||
|
|
ec9dcb1889 | ||
|
|
d11d781afc | ||
|
|
4e44565b71 | ||
|
|
4ed51ad33b | ||
|
|
1c1ebe5537 | ||
|
|
c19cb7f386 | ||
|
|
4b97d31b16 | ||
|
|
923ade3dd7 | ||
|
|
b04e711975 | ||
|
|
afd0a6b39a | ||
|
|
99752286d8 | ||
|
|
15df93363c | ||
|
|
bc0ab741af | ||
|
|
51d9dfeaa3 | ||
|
|
f63cb18155 | ||
|
|
0de603d88e | ||
|
|
240913912a | ||
|
|
91a4ea0de2 | ||
|
|
8608704f49 | ||
|
|
efef68ce99 | ||
|
|
8daefd24da | ||
|
|
46cc8b7982 | ||
|
|
38cd90dd0c | ||
|
|
a51b269f15 | ||
|
|
43bf6d0a0f | ||
|
|
15273a9b66 | ||
|
|
78aca668d0 | ||
|
|
acbf4148ea | ||
|
|
6508540561 | ||
|
|
a41b5244a8 | ||
|
|
2b3189be95 | ||
|
|
248563c595 | ||
|
|
14cd6ca933 | ||
|
|
eb36403e71 | ||
|
|
3c6f779698 | ||
|
|
f67f0c1c11 | ||
|
|
edb02d3299 | ||
|
|
664a69e65b | ||
|
|
478322ebf9 | ||
|
|
802f174072 | ||
|
|
47f9890bae | ||
|
|
262265daad | ||
|
|
300da5b872 | ||
|
|
7b22b5c433 | ||
|
|
ffca97bc1e | ||
|
|
cb356f3259 | ||
|
|
c85374295f | ||
|
|
4992160677 | ||
|
|
bd535b3371 | ||
|
|
d90c5a03af | ||
|
|
2d02cc9079 | ||
|
|
49ad94b99f | ||
|
|
948a217398 | ||
|
|
125381eae7 | ||
|
|
cd01bbc715 | ||
|
|
d8b5e3b88d | ||
|
|
06d25f2186 | ||
|
|
f759b561f3 | ||
|
|
ece0555600 | ||
|
|
73ea0a0b01 | ||
|
|
d8f6d6fd6f | ||
|
|
d24de169a7 | ||
|
|
0816168296 | ||
|
|
277b44d57a | ||
|
|
68c2c3880e | ||
|
|
49da498f65 | ||
|
|
2c76ba3dd7 | ||
|
|
dbe3dc69ad | ||
|
|
8e5bb3ed49 | ||
|
|
ab0be7b8da | ||
|
|
b4c55f5d24 | ||
|
|
ede70d833c | ||
|
|
70c3d18bb0 | ||
|
|
7a491f52c4 | ||
|
|
323c4ecb4f | ||
|
|
3d2466607e | ||
|
|
ed478b39f4 | ||
|
|
91585a558d | ||
|
|
93467eae1f | ||
|
|
f3aac81d19 | ||
|
|
979ad60c19 | ||
|
|
9316cb1b1f | ||
|
|
e7939a527a | ||
|
|
36d26665e1 | ||
|
|
873347f977 | ||
|
|
e814ac16f9 | ||
|
|
ad3055d386 | ||
|
|
94e03eb452 | ||
|
|
380f26ef79 | ||
|
|
3c5b7f59d7 | ||
|
|
fee89f80b5 | ||
|
|
41cce8eaf1 | ||
|
|
f88fe0218d | ||
|
|
cc856eca85 | ||
|
|
cf350c6002 | ||
|
|
0ce6b6a0a3 | ||
|
|
73f247d537 | ||
|
|
960be82183 | ||
|
|
806e5a6c19 | ||
|
|
8d5df07cce | ||
|
|
df7a9d1407 |
@@ -310,13 +310,13 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux
|
||||
. "$HOME/.cargo/env" && \
|
||||
cargo --version && rustup --version && \
|
||||
rustup component add llvm-tools rustfmt clippy && \
|
||||
cargo install rustfilt --version ${RUSTFILT_VERSION} && \
|
||||
cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} && \
|
||||
cargo install cargo-deny --locked --version ${CARGO_DENY_VERSION} && \
|
||||
cargo install cargo-hack --version ${CARGO_HACK_VERSION} && \
|
||||
cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} && \
|
||||
cargo install cargo-chef --locked --version ${CARGO_CHEF_VERSION} && \
|
||||
cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} \
|
||||
cargo install rustfilt --version ${RUSTFILT_VERSION} --locked && \
|
||||
cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} --locked && \
|
||||
cargo install cargo-deny --version ${CARGO_DENY_VERSION} --locked && \
|
||||
cargo install cargo-hack --version ${CARGO_HACK_VERSION} --locked && \
|
||||
cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} --locked && \
|
||||
cargo install cargo-chef --version ${CARGO_CHEF_VERSION} --locked && \
|
||||
cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} --locked \
|
||||
--features postgres-bundled --no-default-features && \
|
||||
rm -rf /home/nonroot/.cargo/registry && \
|
||||
rm -rf /home/nonroot/.cargo/git
|
||||
|
||||
@@ -57,21 +57,6 @@ use tracing::{error, info};
|
||||
use url::Url;
|
||||
use utils::failpoint_support;
|
||||
|
||||
// Compatibility hack: if the control plane specified any remote-ext-config
|
||||
// use the default value for extension storage proxy gateway.
|
||||
// Remove this once the control plane is updated to pass the gateway URL
|
||||
fn parse_remote_ext_base_url(arg: &str) -> Result<String> {
|
||||
const FALLBACK_PG_EXT_GATEWAY_BASE_URL: &str =
|
||||
"http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local";
|
||||
|
||||
Ok(if arg.starts_with("http") {
|
||||
arg
|
||||
} else {
|
||||
FALLBACK_PG_EXT_GATEWAY_BASE_URL
|
||||
}
|
||||
.to_owned())
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(rename_all = "kebab-case")]
|
||||
struct Cli {
|
||||
@@ -79,9 +64,8 @@ struct Cli {
|
||||
pub pgbin: String,
|
||||
|
||||
/// The base URL for the remote extension storage proxy gateway.
|
||||
/// Should be in the form of `http(s)://<gateway-hostname>[:<port>]`.
|
||||
#[arg(short = 'r', long, value_parser = parse_remote_ext_base_url, alias = "remote-ext-config")]
|
||||
pub remote_ext_base_url: Option<String>,
|
||||
#[arg(short = 'r', long)]
|
||||
pub remote_ext_base_url: Option<Url>,
|
||||
|
||||
/// The port to bind the external listening HTTP server to. Clients running
|
||||
/// outside the compute will talk to the compute through this port. Keep
|
||||
@@ -136,6 +120,10 @@ struct Cli {
|
||||
requires = "compute-id"
|
||||
)]
|
||||
pub control_plane_uri: Option<String>,
|
||||
|
||||
/// Interval in seconds for collecting installed extensions statistics
|
||||
#[arg(long, default_value = "3600")]
|
||||
pub installed_extensions_collection_interval: u64,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@@ -179,6 +167,7 @@ fn main() -> Result<()> {
|
||||
cgroup: cli.cgroup,
|
||||
#[cfg(target_os = "linux")]
|
||||
vm_monitor_addr: cli.vm_monitor_addr,
|
||||
installed_extensions_collection_interval: cli.installed_extensions_collection_interval,
|
||||
},
|
||||
config,
|
||||
)?;
|
||||
@@ -271,18 +260,4 @@ mod test {
|
||||
fn verify_cli() {
|
||||
Cli::command().debug_assert()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_pg_ext_gateway_base_url() {
|
||||
let arg = "http://pg-ext-s3-gateway2";
|
||||
let result = super::parse_remote_ext_base_url(arg).unwrap();
|
||||
assert_eq!(result, arg);
|
||||
|
||||
let arg = "pg-ext-s3-gateway";
|
||||
let result = super::parse_remote_ext_base_url(arg).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
"http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -339,6 +339,8 @@ async fn run_dump_restore(
|
||||
destination_connstring: String,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let dumpdir = workdir.join("dumpdir");
|
||||
let num_jobs = num_cpus::get().to_string();
|
||||
info!("using {num_jobs} jobs for dump/restore");
|
||||
|
||||
let common_args = [
|
||||
// schema mapping (prob suffices to specify them on one side)
|
||||
@@ -354,7 +356,7 @@ async fn run_dump_restore(
|
||||
"directory".to_string(),
|
||||
// concurrency
|
||||
"--jobs".to_string(),
|
||||
num_cpus::get().to_string(),
|
||||
num_jobs,
|
||||
// progress updates
|
||||
"--verbose".to_string(),
|
||||
];
|
||||
|
||||
@@ -31,6 +31,7 @@ use std::time::{Duration, Instant};
|
||||
use std::{env, fs};
|
||||
use tokio::spawn;
|
||||
use tracing::{Instrument, debug, error, info, instrument, warn};
|
||||
use url::Url;
|
||||
use utils::id::{TenantId, TimelineId};
|
||||
use utils::lsn::Lsn;
|
||||
use utils::measured_stream::MeasuredReader;
|
||||
@@ -96,7 +97,10 @@ pub struct ComputeNodeParams {
|
||||
pub internal_http_port: u16,
|
||||
|
||||
/// the address of extension storage proxy gateway
|
||||
pub remote_ext_base_url: Option<String>,
|
||||
pub remote_ext_base_url: Option<Url>,
|
||||
|
||||
/// Interval for installed extensions collection
|
||||
pub installed_extensions_collection_interval: u64,
|
||||
}
|
||||
|
||||
/// Compute node info shared across several `compute_ctl` threads.
|
||||
@@ -742,17 +746,7 @@ impl ComputeNode {
|
||||
|
||||
let conf = self.get_tokio_conn_conf(None);
|
||||
tokio::task::spawn(async {
|
||||
let res = get_installed_extensions(conf).await;
|
||||
match res {
|
||||
Ok(extensions) => {
|
||||
info!(
|
||||
"[NEON_EXT_STAT] {}",
|
||||
serde_json::to_string(&extensions)
|
||||
.expect("failed to serialize extensions list")
|
||||
);
|
||||
}
|
||||
Err(err) => error!("could not get installed extensions: {err:?}"),
|
||||
}
|
||||
let _ = installed_extensions(conf).await;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -782,6 +776,9 @@ impl ComputeNode {
|
||||
// Log metrics so that we can search for slow operations in logs
|
||||
info!(?metrics, postmaster_pid = %postmaster_pid, "compute start finished");
|
||||
|
||||
// Spawn the extension stats background task
|
||||
self.spawn_extension_stats_task();
|
||||
|
||||
if pspec.spec.prewarm_lfc_on_startup {
|
||||
self.prewarm_lfc();
|
||||
}
|
||||
@@ -2192,6 +2189,41 @@ LIMIT 100",
|
||||
info!("Pageserver config changed");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn spawn_extension_stats_task(&self) {
|
||||
let conf = self.tokio_conn_conf.clone();
|
||||
let installed_extensions_collection_interval =
|
||||
self.params.installed_extensions_collection_interval;
|
||||
tokio::spawn(async move {
|
||||
// An initial sleep is added to ensure that two collections don't happen at the same time.
|
||||
// The first collection happens during compute startup.
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(
|
||||
installed_extensions_collection_interval,
|
||||
))
|
||||
.await;
|
||||
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(
|
||||
installed_extensions_collection_interval,
|
||||
));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
let _ = installed_extensions(conf.clone()).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn installed_extensions(conf: tokio_postgres::Config) -> Result<()> {
|
||||
let res = get_installed_extensions(conf).await;
|
||||
match res {
|
||||
Ok(extensions) => {
|
||||
info!(
|
||||
"[NEON_EXT_STAT] {}",
|
||||
serde_json::to_string(&extensions).expect("failed to serialize extensions list")
|
||||
);
|
||||
}
|
||||
Err(err) => error!("could not get installed extensions: {err:?}"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn forward_termination_signal() {
|
||||
|
||||
@@ -83,6 +83,7 @@ use reqwest::StatusCode;
|
||||
use tar::Archive;
|
||||
use tracing::info;
|
||||
use tracing::log::warn;
|
||||
use url::Url;
|
||||
use zstd::stream::read::Decoder;
|
||||
|
||||
use crate::metrics::{REMOTE_EXT_REQUESTS_TOTAL, UNKNOWN_HTTP_STATUS};
|
||||
@@ -158,14 +159,14 @@ fn parse_pg_version(human_version: &str) -> PostgresMajorVersion {
|
||||
pub async fn download_extension(
|
||||
ext_name: &str,
|
||||
ext_path: &RemotePath,
|
||||
remote_ext_base_url: &str,
|
||||
remote_ext_base_url: &Url,
|
||||
pgbin: &str,
|
||||
) -> Result<u64> {
|
||||
info!("Download extension {:?} from {:?}", ext_name, ext_path);
|
||||
|
||||
// TODO add retry logic
|
||||
let download_buffer =
|
||||
match download_extension_tar(remote_ext_base_url, &ext_path.to_string()).await {
|
||||
match download_extension_tar(remote_ext_base_url.as_str(), &ext_path.to_string()).await {
|
||||
Ok(buffer) => buffer,
|
||||
Err(error_message) => {
|
||||
return Err(anyhow::anyhow!(
|
||||
|
||||
@@ -20,7 +20,7 @@ first_path="$(ldconfig --verbose 2>/dev/null \
|
||||
| grep --invert-match ^$'\t' \
|
||||
| cut --delimiter=: --fields=1 \
|
||||
| head --lines=1)"
|
||||
test "$first_path" == '/usr/local/lib' || true # Remove the || true in a follow-up PR. Needed for backwards compat.
|
||||
test "$first_path" == '/usr/local/lib'
|
||||
|
||||
echo "Waiting pageserver become ready."
|
||||
while ! nc -z pageserver 6400; do
|
||||
|
||||
@@ -27,6 +27,7 @@ pub use prometheus::{
|
||||
|
||||
pub mod launch_timestamp;
|
||||
mod wrappers;
|
||||
pub use prometheus;
|
||||
pub use wrappers::{CountedReader, CountedWriter};
|
||||
mod hll;
|
||||
pub use hll::{HyperLogLog, HyperLogLogState, HyperLogLogVec};
|
||||
|
||||
@@ -354,6 +354,9 @@ pub struct ShardImportProgressV1 {
|
||||
pub completed: usize,
|
||||
/// Hash of the plan
|
||||
pub import_plan_hash: u64,
|
||||
/// Soft limit for the job size
|
||||
/// This needs to remain constant throughout the import
|
||||
pub job_soft_size_limit: usize,
|
||||
}
|
||||
|
||||
impl ShardImportStatus {
|
||||
|
||||
@@ -4,6 +4,7 @@ use std::{sync::Arc, time::Duration};
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, info_span};
|
||||
|
||||
use crate::{FeatureStore, PostHogClient, PostHogClientConfig};
|
||||
|
||||
@@ -26,31 +27,35 @@ impl FeatureResolverBackgroundLoop {
|
||||
pub fn spawn(self: Arc<Self>, handle: &tokio::runtime::Handle, refresh_period: Duration) {
|
||||
let this = self.clone();
|
||||
let cancel = self.cancel.clone();
|
||||
handle.spawn(async move {
|
||||
tracing::info!("Starting PostHog feature resolver");
|
||||
let mut ticker = tokio::time::interval(refresh_period);
|
||||
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = ticker.tick() => {}
|
||||
_ = cancel.cancelled() => break
|
||||
}
|
||||
let resp = match this
|
||||
.posthog_client
|
||||
.get_feature_flags_local_evaluation()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
tracing::warn!("Cannot get feature flags: {}", e);
|
||||
continue;
|
||||
handle.spawn(
|
||||
async move {
|
||||
tracing::info!("Starting PostHog feature resolver");
|
||||
let mut ticker = tokio::time::interval(refresh_period);
|
||||
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = ticker.tick() => {}
|
||||
_ = cancel.cancelled() => break
|
||||
}
|
||||
};
|
||||
let feature_store = FeatureStore::new_with_flags(resp.flags);
|
||||
this.feature_store.store(Arc::new(feature_store));
|
||||
let resp = match this
|
||||
.posthog_client
|
||||
.get_feature_flags_local_evaluation()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
tracing::warn!("Cannot get feature flags: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let feature_store = FeatureStore::new_with_flags(resp.flags);
|
||||
this.feature_store.store(Arc::new(feature_store));
|
||||
tracing::info!("Feature flag updated");
|
||||
}
|
||||
tracing::info!("PostHog feature resolver stopped");
|
||||
}
|
||||
tracing::info!("PostHog feature resolver stopped");
|
||||
});
|
||||
.instrument(info_span!("posthog_feature_resolver")),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn feature_store(&self) -> Arc<FeatureStore> {
|
||||
|
||||
@@ -37,7 +37,7 @@ pub struct LocalEvaluationFlag {
|
||||
#[derive(Deserialize)]
|
||||
pub struct LocalEvaluationFlagFilters {
|
||||
groups: Vec<LocalEvaluationFlagFilterGroup>,
|
||||
multivariate: LocalEvaluationFlagMultivariate,
|
||||
multivariate: Option<LocalEvaluationFlagMultivariate>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -254,7 +254,7 @@ impl FeatureStore {
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate a multivariate feature flag. Returns `None` if the flag is not available or if there are errors
|
||||
/// Evaluate a multivariate feature flag. Returns an error if the flag is not available or if there are errors
|
||||
/// during the evaluation.
|
||||
///
|
||||
/// The parsing logic is as follows:
|
||||
@@ -272,6 +272,10 @@ impl FeatureStore {
|
||||
/// Example: we have a multivariate flag with 3 groups of the configured global rollout percentage: A (10%), B (20%), C (70%).
|
||||
/// There is a single group with a condition that has a rollout percentage of 10% and it does not have a variant override.
|
||||
/// Then, we will have 1% of the users evaluated to A, 2% to B, and 7% to C.
|
||||
///
|
||||
/// Error handling: the caller should inspect the error and decide the behavior when a feature flag
|
||||
/// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be
|
||||
/// propagated beyond where the feature flag gets resolved.
|
||||
pub fn evaluate_multivariate(
|
||||
&self,
|
||||
flag_key: &str,
|
||||
@@ -290,6 +294,35 @@ impl FeatureStore {
|
||||
)
|
||||
}
|
||||
|
||||
/// Evaluate a boolean feature flag. Returns an error if the flag is not available or if there are errors
|
||||
/// during the evaluation.
|
||||
///
|
||||
/// The parsing logic is as follows:
|
||||
///
|
||||
/// * Generate a consistent hash for the tenant-feature.
|
||||
/// * Match each filter group.
|
||||
/// - If a group is matched, it will first determine whether the user is in the range of the rollout
|
||||
/// percentage.
|
||||
/// - If the hash falls within the group's rollout percentage, return true.
|
||||
/// * Otherwise, continue with the next group until all groups are evaluated and no group is within the
|
||||
/// rollout percentage.
|
||||
/// * If there are no matching groups, return an error.
|
||||
///
|
||||
/// Returns `Ok(())` if the feature flag evaluates to true. In the future, it will return a payload.
|
||||
///
|
||||
/// Error handling: the caller should inspect the error and decide the behavior when a feature flag
|
||||
/// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be
|
||||
/// propagated beyond where the feature flag gets resolved.
|
||||
pub fn evaluate_boolean(
|
||||
&self,
|
||||
flag_key: &str,
|
||||
user_id: &str,
|
||||
properties: &HashMap<String, PostHogFlagFilterPropertyValue>,
|
||||
) -> Result<(), PostHogEvaluationError> {
|
||||
let hash_on_global_rollout_percentage = Self::consistent_hash(user_id, flag_key, "boolean");
|
||||
self.evaluate_boolean_inner(flag_key, hash_on_global_rollout_percentage, properties)
|
||||
}
|
||||
|
||||
/// Evaluate a multivariate feature flag. Note that we directly take the mapped user ID
|
||||
/// (a consistent hash ranging from 0 to 1) so that it is easier to use it in the tests
|
||||
/// and avoid duplicate computations.
|
||||
@@ -316,6 +349,11 @@ impl FeatureStore {
|
||||
flag_key
|
||||
)));
|
||||
}
|
||||
let Some(ref multivariate) = flag_config.filters.multivariate else {
|
||||
return Err(PostHogEvaluationError::Internal(format!(
|
||||
"No multivariate available, should use evaluate_boolean?: {flag_key}"
|
||||
)));
|
||||
};
|
||||
// TODO: sort the groups so that variant overrides always get evaluated first and it follows the PostHog
|
||||
// Python SDK behavior; for now we do not configure conditions without variant overrides in Neon so it
|
||||
// does not matter.
|
||||
@@ -324,7 +362,7 @@ impl FeatureStore {
|
||||
GroupEvaluationResult::MatchedAndOverride(variant) => return Ok(variant),
|
||||
GroupEvaluationResult::MatchedAndEvaluate => {
|
||||
let mut percentage = 0;
|
||||
for variant in &flag_config.filters.multivariate.variants {
|
||||
for variant in &multivariate.variants {
|
||||
percentage += variant.rollout_percentage;
|
||||
if self
|
||||
.evaluate_percentage(hash_on_global_rollout_percentage, percentage)
|
||||
@@ -352,6 +390,76 @@ impl FeatureStore {
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate a multivariate feature flag. Note that we directly take the mapped user ID
|
||||
/// (a consistent hash ranging from 0 to 1) so that it is easier to use it in the tests
|
||||
/// and avoid duplicate computations.
|
||||
///
|
||||
/// Use a different consistent hash for evaluating the group rollout percentage.
|
||||
/// The behavior: if the condition is set to rolling out to 10% of the users, and
|
||||
/// we set the variant A to 20% in the global config, then 2% of the total users will
|
||||
/// be evaluated to variant A.
|
||||
///
|
||||
/// Note that the hash to determine group rollout percentage is shared across all groups. So if we have two
|
||||
/// exactly-the-same conditions with 10% and 20% rollout percentage respectively, a total of 20% of the users
|
||||
/// will be evaluated (versus 30% if group evaluation is done independently).
|
||||
pub(crate) fn evaluate_boolean_inner(
|
||||
&self,
|
||||
flag_key: &str,
|
||||
hash_on_global_rollout_percentage: f64,
|
||||
properties: &HashMap<String, PostHogFlagFilterPropertyValue>,
|
||||
) -> Result<(), PostHogEvaluationError> {
|
||||
if let Some(flag_config) = self.flags.get(flag_key) {
|
||||
if !flag_config.active {
|
||||
return Err(PostHogEvaluationError::NotAvailable(format!(
|
||||
"The feature flag is not active: {}",
|
||||
flag_key
|
||||
)));
|
||||
}
|
||||
if flag_config.filters.multivariate.is_some() {
|
||||
return Err(PostHogEvaluationError::Internal(format!(
|
||||
"This looks like a multivariate flag, should use evaluate_multivariate?: {flag_key}"
|
||||
)));
|
||||
};
|
||||
// TODO: sort the groups so that variant overrides always get evaluated first and it follows the PostHog
|
||||
// Python SDK behavior; for now we do not configure conditions without variant overrides in Neon so it
|
||||
// does not matter.
|
||||
for group in &flag_config.filters.groups {
|
||||
match self.evaluate_group(group, hash_on_global_rollout_percentage, properties)? {
|
||||
GroupEvaluationResult::MatchedAndOverride(_) => {
|
||||
return Err(PostHogEvaluationError::Internal(format!(
|
||||
"Boolean flag cannot have overrides: {}",
|
||||
flag_key
|
||||
)));
|
||||
}
|
||||
GroupEvaluationResult::MatchedAndEvaluate => {
|
||||
return Ok(());
|
||||
}
|
||||
GroupEvaluationResult::Unmatched => continue,
|
||||
}
|
||||
}
|
||||
// If no group is matched, the feature is not available, and up to the caller to decide what to do.
|
||||
Err(PostHogEvaluationError::NoConditionGroupMatched)
|
||||
} else {
|
||||
// The feature flag is not available yet
|
||||
Err(PostHogEvaluationError::NotAvailable(format!(
|
||||
"Not found in the local evaluation spec: {}",
|
||||
flag_key
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Infer whether a feature flag is a boolean flag by checking if it has a multivariate filter.
|
||||
pub fn is_feature_flag_boolean(&self, flag_key: &str) -> Result<bool, PostHogEvaluationError> {
|
||||
if let Some(flag_config) = self.flags.get(flag_key) {
|
||||
Ok(flag_config.filters.multivariate.is_none())
|
||||
} else {
|
||||
Err(PostHogEvaluationError::NotAvailable(format!(
|
||||
"Not found in the local evaluation spec: {}",
|
||||
flag_key
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostHogClientConfig {
|
||||
@@ -432,7 +540,15 @@ impl PostHogClient {
|
||||
.bearer_auth(&self.config.server_api_key)
|
||||
.send()
|
||||
.await?;
|
||||
let status = response.status();
|
||||
let body = response.text().await?;
|
||||
if !status.is_success() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Failed to get feature flags: {}, {}",
|
||||
status,
|
||||
body
|
||||
));
|
||||
}
|
||||
Ok(serde_json::from_str(&body)?)
|
||||
}
|
||||
|
||||
@@ -469,95 +585,162 @@ mod tests {
|
||||
|
||||
fn data() -> &'static str {
|
||||
r#"{
|
||||
"flags": [
|
||||
{
|
||||
"id": 132794,
|
||||
"team_id": 152860,
|
||||
"name": "",
|
||||
"key": "gc-compaction",
|
||||
"filters": {
|
||||
"groups": [
|
||||
{
|
||||
"variant": "enabled-stage-2",
|
||||
"properties": [
|
||||
{
|
||||
"key": "plan_type",
|
||||
"type": "person",
|
||||
"value": [
|
||||
"free"
|
||||
],
|
||||
"operator": "exact"
|
||||
},
|
||||
{
|
||||
"key": "pageserver_remote_size",
|
||||
"type": "person",
|
||||
"value": "10000000",
|
||||
"operator": "lt"
|
||||
}
|
||||
],
|
||||
"rollout_percentage": 50
|
||||
},
|
||||
{
|
||||
"properties": [
|
||||
{
|
||||
"key": "plan_type",
|
||||
"type": "person",
|
||||
"value": [
|
||||
"free"
|
||||
],
|
||||
"operator": "exact"
|
||||
},
|
||||
{
|
||||
"key": "pageserver_remote_size",
|
||||
"type": "person",
|
||||
"value": "10000000",
|
||||
"operator": "lt"
|
||||
}
|
||||
],
|
||||
"rollout_percentage": 80
|
||||
}
|
||||
],
|
||||
"payloads": {},
|
||||
"multivariate": {
|
||||
"variants": [
|
||||
{
|
||||
"key": "disabled",
|
||||
"name": "",
|
||||
"rollout_percentage": 90
|
||||
},
|
||||
{
|
||||
"key": "enabled-stage-1",
|
||||
"name": "",
|
||||
"rollout_percentage": 10
|
||||
},
|
||||
{
|
||||
"key": "enabled-stage-2",
|
||||
"name": "",
|
||||
"rollout_percentage": 0
|
||||
},
|
||||
{
|
||||
"key": "enabled-stage-3",
|
||||
"name": "",
|
||||
"rollout_percentage": 0
|
||||
},
|
||||
{
|
||||
"key": "enabled",
|
||||
"name": "",
|
||||
"rollout_percentage": 0
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"deleted": false,
|
||||
"active": true,
|
||||
"ensure_experience_continuity": false,
|
||||
"has_encrypted_payloads": false,
|
||||
"version": 6
|
||||
}
|
||||
"flags": [
|
||||
{
|
||||
"id": 141807,
|
||||
"team_id": 152860,
|
||||
"name": "",
|
||||
"key": "image-compaction-boundary",
|
||||
"filters": {
|
||||
"groups": [
|
||||
{
|
||||
"variant": null,
|
||||
"properties": [
|
||||
{
|
||||
"key": "plan_type",
|
||||
"type": "person",
|
||||
"value": [
|
||||
"free"
|
||||
],
|
||||
"operator": "exact"
|
||||
}
|
||||
],
|
||||
"group_type_mapping": {},
|
||||
"cohorts": {}
|
||||
}"#
|
||||
"rollout_percentage": 40
|
||||
},
|
||||
{
|
||||
"variant": null,
|
||||
"properties": [],
|
||||
"rollout_percentage": 10
|
||||
}
|
||||
],
|
||||
"payloads": {},
|
||||
"multivariate": null
|
||||
},
|
||||
"deleted": false,
|
||||
"active": true,
|
||||
"ensure_experience_continuity": false,
|
||||
"has_encrypted_payloads": false,
|
||||
"version": 1
|
||||
},
|
||||
{
|
||||
"id": 135586,
|
||||
"team_id": 152860,
|
||||
"name": "",
|
||||
"key": "boolean-flag",
|
||||
"filters": {
|
||||
"groups": [
|
||||
{
|
||||
"variant": null,
|
||||
"properties": [
|
||||
{
|
||||
"key": "plan_type",
|
||||
"type": "person",
|
||||
"value": [
|
||||
"free"
|
||||
],
|
||||
"operator": "exact"
|
||||
}
|
||||
],
|
||||
"rollout_percentage": 47
|
||||
}
|
||||
],
|
||||
"payloads": {},
|
||||
"multivariate": null
|
||||
},
|
||||
"deleted": false,
|
||||
"active": true,
|
||||
"ensure_experience_continuity": false,
|
||||
"has_encrypted_payloads": false,
|
||||
"version": 1
|
||||
},
|
||||
{
|
||||
"id": 132794,
|
||||
"team_id": 152860,
|
||||
"name": "",
|
||||
"key": "gc-compaction",
|
||||
"filters": {
|
||||
"groups": [
|
||||
{
|
||||
"variant": "enabled-stage-2",
|
||||
"properties": [
|
||||
{
|
||||
"key": "plan_type",
|
||||
"type": "person",
|
||||
"value": [
|
||||
"free"
|
||||
],
|
||||
"operator": "exact"
|
||||
},
|
||||
{
|
||||
"key": "pageserver_remote_size",
|
||||
"type": "person",
|
||||
"value": "10000000",
|
||||
"operator": "lt"
|
||||
}
|
||||
],
|
||||
"rollout_percentage": 50
|
||||
},
|
||||
{
|
||||
"properties": [
|
||||
{
|
||||
"key": "plan_type",
|
||||
"type": "person",
|
||||
"value": [
|
||||
"free"
|
||||
],
|
||||
"operator": "exact"
|
||||
},
|
||||
{
|
||||
"key": "pageserver_remote_size",
|
||||
"type": "person",
|
||||
"value": "10000000",
|
||||
"operator": "lt"
|
||||
}
|
||||
],
|
||||
"rollout_percentage": 80
|
||||
}
|
||||
],
|
||||
"payloads": {},
|
||||
"multivariate": {
|
||||
"variants": [
|
||||
{
|
||||
"key": "disabled",
|
||||
"name": "",
|
||||
"rollout_percentage": 90
|
||||
},
|
||||
{
|
||||
"key": "enabled-stage-1",
|
||||
"name": "",
|
||||
"rollout_percentage": 10
|
||||
},
|
||||
{
|
||||
"key": "enabled-stage-2",
|
||||
"name": "",
|
||||
"rollout_percentage": 0
|
||||
},
|
||||
{
|
||||
"key": "enabled-stage-3",
|
||||
"name": "",
|
||||
"rollout_percentage": 0
|
||||
},
|
||||
{
|
||||
"key": "enabled",
|
||||
"name": "",
|
||||
"rollout_percentage": 0
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"deleted": false,
|
||||
"active": true,
|
||||
"ensure_experience_continuity": false,
|
||||
"has_encrypted_payloads": false,
|
||||
"version": 7
|
||||
}
|
||||
],
|
||||
"group_type_mapping": {},
|
||||
"cohorts": {}
|
||||
}"#
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -633,4 +816,125 @@ mod tests {
|
||||
Err(PostHogEvaluationError::NoConditionGroupMatched)
|
||||
),);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn evaluate_boolean_1() {
|
||||
// The `boolean-flag` feature flag only has one group that matches on the free user.
|
||||
|
||||
let mut store = FeatureStore::new();
|
||||
let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap();
|
||||
store.set_flags(response.flags);
|
||||
|
||||
// This lacks the required properties and cannot be evaluated.
|
||||
let variant = store.evaluate_boolean_inner("boolean-flag", 1.00, &HashMap::new());
|
||||
assert!(matches!(
|
||||
variant,
|
||||
Err(PostHogEvaluationError::NotAvailable(_))
|
||||
),);
|
||||
|
||||
let properties_unmatched = HashMap::from([
|
||||
(
|
||||
"plan_type".to_string(),
|
||||
PostHogFlagFilterPropertyValue::String("paid".to_string()),
|
||||
),
|
||||
(
|
||||
"pageserver_remote_size".to_string(),
|
||||
PostHogFlagFilterPropertyValue::Number(1000.0),
|
||||
),
|
||||
]);
|
||||
|
||||
// This does not match any group so there will be an error.
|
||||
let variant = store.evaluate_boolean_inner("boolean-flag", 1.00, &properties_unmatched);
|
||||
assert!(matches!(
|
||||
variant,
|
||||
Err(PostHogEvaluationError::NoConditionGroupMatched)
|
||||
),);
|
||||
|
||||
let properties = HashMap::from([
|
||||
(
|
||||
"plan_type".to_string(),
|
||||
PostHogFlagFilterPropertyValue::String("free".to_string()),
|
||||
),
|
||||
(
|
||||
"pageserver_remote_size".to_string(),
|
||||
PostHogFlagFilterPropertyValue::Number(1000.0),
|
||||
),
|
||||
]);
|
||||
|
||||
// It matches the first group as 0.10 <= 0.50 and the properties are matched. Then it gets evaluated to the variant override.
|
||||
let variant = store.evaluate_boolean_inner("boolean-flag", 0.10, &properties);
|
||||
assert!(variant.is_ok());
|
||||
|
||||
// It matches the group conditions but not the group rollout percentage.
|
||||
let variant = store.evaluate_boolean_inner("boolean-flag", 1.00, &properties);
|
||||
assert!(matches!(
|
||||
variant,
|
||||
Err(PostHogEvaluationError::NoConditionGroupMatched)
|
||||
),);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn evaluate_boolean_2() {
|
||||
// The `image-compaction-boundary` feature flag has one group that matches on the free user and a group that matches on all users.
|
||||
|
||||
let mut store = FeatureStore::new();
|
||||
let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap();
|
||||
store.set_flags(response.flags);
|
||||
|
||||
// This lacks the required properties and cannot be evaluated.
|
||||
let variant =
|
||||
store.evaluate_boolean_inner("image-compaction-boundary", 1.00, &HashMap::new());
|
||||
assert!(matches!(
|
||||
variant,
|
||||
Err(PostHogEvaluationError::NotAvailable(_))
|
||||
),);
|
||||
|
||||
let properties_unmatched = HashMap::from([
|
||||
(
|
||||
"plan_type".to_string(),
|
||||
PostHogFlagFilterPropertyValue::String("paid".to_string()),
|
||||
),
|
||||
(
|
||||
"pageserver_remote_size".to_string(),
|
||||
PostHogFlagFilterPropertyValue::Number(1000.0),
|
||||
),
|
||||
]);
|
||||
|
||||
// This does not match the filtered group but the all user group.
|
||||
let variant =
|
||||
store.evaluate_boolean_inner("image-compaction-boundary", 1.00, &properties_unmatched);
|
||||
assert!(matches!(
|
||||
variant,
|
||||
Err(PostHogEvaluationError::NoConditionGroupMatched)
|
||||
),);
|
||||
let variant =
|
||||
store.evaluate_boolean_inner("image-compaction-boundary", 0.05, &properties_unmatched);
|
||||
assert!(variant.is_ok());
|
||||
|
||||
let properties = HashMap::from([
|
||||
(
|
||||
"plan_type".to_string(),
|
||||
PostHogFlagFilterPropertyValue::String("free".to_string()),
|
||||
),
|
||||
(
|
||||
"pageserver_remote_size".to_string(),
|
||||
PostHogFlagFilterPropertyValue::Number(1000.0),
|
||||
),
|
||||
]);
|
||||
|
||||
// It matches the first group as 0.30 <= 0.40 and the properties are matched. Then it gets evaluated to the variant override.
|
||||
let variant = store.evaluate_boolean_inner("image-compaction-boundary", 0.30, &properties);
|
||||
assert!(variant.is_ok());
|
||||
|
||||
// It matches the group conditions but not the group rollout percentage.
|
||||
let variant = store.evaluate_boolean_inner("image-compaction-boundary", 1.00, &properties);
|
||||
assert!(matches!(
|
||||
variant,
|
||||
Err(PostHogEvaluationError::NoConditionGroupMatched)
|
||||
),);
|
||||
|
||||
// It matches the second "all" group conditions.
|
||||
let variant = store.evaluate_boolean_inner("image-compaction-boundary", 0.09, &properties);
|
||||
assert!(variant.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -264,10 +264,56 @@ mod propagation_of_cached_label_value {
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(histograms, histograms::bench_bucket_scalability);
|
||||
mod histograms {
|
||||
use std::time::Instant;
|
||||
|
||||
use criterion::{BenchmarkId, Criterion};
|
||||
use metrics::core::Collector;
|
||||
|
||||
pub fn bench_bucket_scalability(c: &mut Criterion) {
|
||||
let mut g = c.benchmark_group("bucket_scalability");
|
||||
|
||||
for n in [1, 4, 8, 16, 32, 64, 128, 256] {
|
||||
g.bench_with_input(BenchmarkId::new("nbuckets", n), &n, |b, n| {
|
||||
b.iter_custom(|iters| {
|
||||
let buckets: Vec<f64> = (0..*n).map(|i| i as f64 * 100.0).collect();
|
||||
let histo = metrics::Histogram::with_opts(
|
||||
metrics::prometheus::HistogramOpts::new("name", "help")
|
||||
.buckets(buckets.clone()),
|
||||
)
|
||||
.unwrap();
|
||||
let start = Instant::now();
|
||||
for i in 0..usize::try_from(iters).unwrap() {
|
||||
histo.observe(buckets[i % buckets.len()]);
|
||||
}
|
||||
let elapsed = start.elapsed();
|
||||
// self-test
|
||||
let mfs = histo.collect();
|
||||
assert_eq!(mfs.len(), 1);
|
||||
let metrics = mfs[0].get_metric();
|
||||
assert_eq!(metrics.len(), 1);
|
||||
let histo = metrics[0].get_histogram();
|
||||
let buckets = histo.get_bucket();
|
||||
assert!(
|
||||
buckets
|
||||
.iter()
|
||||
.enumerate()
|
||||
.all(|(i, b)| b.get_cumulative_count()
|
||||
>= i as u64 * (iters / buckets.len() as u64))
|
||||
);
|
||||
elapsed
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
criterion_main!(
|
||||
label_values,
|
||||
single_metric_multicore_scalability,
|
||||
propagation_of_cached_label_value
|
||||
propagation_of_cached_label_value,
|
||||
histograms,
|
||||
);
|
||||
|
||||
/*
|
||||
@@ -290,6 +336,14 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [211.50 ns 214.44 ns
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [14.135 ns 14.147 ns 14.160 ns]
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [14.243 ns 14.255 ns 14.268 ns]
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [14.470 ns 14.682 ns 14.895 ns]
|
||||
bucket_scalability/nbuckets/1 time: [30.352 ns 30.353 ns 30.354 ns]
|
||||
bucket_scalability/nbuckets/4 time: [30.464 ns 30.465 ns 30.467 ns]
|
||||
bucket_scalability/nbuckets/8 time: [30.569 ns 30.575 ns 30.584 ns]
|
||||
bucket_scalability/nbuckets/16 time: [30.961 ns 30.965 ns 30.969 ns]
|
||||
bucket_scalability/nbuckets/32 time: [35.691 ns 35.707 ns 35.722 ns]
|
||||
bucket_scalability/nbuckets/64 time: [47.829 ns 47.898 ns 47.974 ns]
|
||||
bucket_scalability/nbuckets/128 time: [73.479 ns 73.512 ns 73.545 ns]
|
||||
bucket_scalability/nbuckets/256 time: [127.92 ns 127.94 ns 127.96 ns]
|
||||
|
||||
Results on an i3en.3xlarge instance
|
||||
|
||||
@@ -344,6 +398,14 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [434.87 ns 456.4
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [3.3767 ns 3.3974 ns 3.4220 ns]
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [3.6105 ns 4.2355 ns 5.1463 ns]
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [4.0889 ns 4.9714 ns 6.0779 ns]
|
||||
bucket_scalability/nbuckets/1 time: [4.8455 ns 4.8542 ns 4.8646 ns]
|
||||
bucket_scalability/nbuckets/4 time: [4.5663 ns 4.5722 ns 4.5787 ns]
|
||||
bucket_scalability/nbuckets/8 time: [4.5531 ns 4.5670 ns 4.5842 ns]
|
||||
bucket_scalability/nbuckets/16 time: [4.6392 ns 4.6524 ns 4.6685 ns]
|
||||
bucket_scalability/nbuckets/32 time: [6.0302 ns 6.0439 ns 6.0589 ns]
|
||||
bucket_scalability/nbuckets/64 time: [10.608 ns 10.644 ns 10.691 ns]
|
||||
bucket_scalability/nbuckets/128 time: [22.178 ns 22.316 ns 22.483 ns]
|
||||
bucket_scalability/nbuckets/256 time: [42.190 ns 42.328 ns 42.492 ns]
|
||||
|
||||
Results on a Hetzner AX102 AMD Ryzen 9 7950X3D 16-Core Processor
|
||||
|
||||
@@ -362,5 +424,13 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [164.24 ns 170.1
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [2.2915 ns 2.2960 ns 2.3012 ns]
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [2.5726 ns 2.6158 ns 2.6624 ns]
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [2.7068 ns 2.8243 ns 2.9824 ns]
|
||||
bucket_scalability/nbuckets/1 time: [6.3998 ns 6.4288 ns 6.4684 ns]
|
||||
bucket_scalability/nbuckets/4 time: [6.3603 ns 6.3620 ns 6.3637 ns]
|
||||
bucket_scalability/nbuckets/8 time: [6.1646 ns 6.1654 ns 6.1667 ns]
|
||||
bucket_scalability/nbuckets/16 time: [6.1341 ns 6.1391 ns 6.1454 ns]
|
||||
bucket_scalability/nbuckets/32 time: [8.2206 ns 8.2254 ns 8.2301 ns]
|
||||
bucket_scalability/nbuckets/64 time: [13.988 ns 13.994 ns 14.000 ns]
|
||||
bucket_scalability/nbuckets/128 time: [28.180 ns 28.216 ns 28.251 ns]
|
||||
bucket_scalability/nbuckets/256 time: [54.914 ns 54.931 ns 54.951 ns]
|
||||
|
||||
*/
|
||||
|
||||
@@ -54,9 +54,9 @@ service PageService {
|
||||
// RPCs use regular unary requests, since they are not as frequent and
|
||||
// performance-critical, and this simplifies implementation.
|
||||
//
|
||||
// NB: a status response (e.g. errors) will terminate the stream. The stream
|
||||
// may be shared by e.g. multiple Postgres backends, so we should avoid this.
|
||||
// Most errors are therefore sent as GetPageResponse.status instead.
|
||||
// NB: a gRPC status response (e.g. errors) will terminate the stream. The
|
||||
// stream may be shared by multiple Postgres backends, so we avoid this by
|
||||
// sending them as GetPageResponse.status_code instead.
|
||||
rpc GetPages (stream GetPageRequest) returns (stream GetPageResponse);
|
||||
|
||||
// Returns the size of a relation, as # of blocks.
|
||||
@@ -159,8 +159,8 @@ message GetPageRequest {
|
||||
// A GetPageRequest class. Primarily intended for observability, but may also be
|
||||
// used for prioritization in the future.
|
||||
enum GetPageClass {
|
||||
// Unknown class. For forwards compatibility: used when the client sends a
|
||||
// class that the server doesn't know about.
|
||||
// Unknown class. For backwards compatibility: used when an older client version sends a class
|
||||
// that a newer server version has removed.
|
||||
GET_PAGE_CLASS_UNKNOWN = 0;
|
||||
// A normal request. This is the default.
|
||||
GET_PAGE_CLASS_NORMAL = 1;
|
||||
@@ -180,31 +180,37 @@ message GetPageResponse {
|
||||
// The original request's ID.
|
||||
uint64 request_id = 1;
|
||||
// The response status code.
|
||||
GetPageStatus status = 2;
|
||||
GetPageStatusCode status_code = 2;
|
||||
// A string describing the status, if any.
|
||||
string reason = 3;
|
||||
// The 8KB page images, in the same order as the request. Empty if status != OK.
|
||||
// The 8KB page images, in the same order as the request. Empty if status_code != OK.
|
||||
repeated bytes page_image = 4;
|
||||
}
|
||||
|
||||
// A GetPageResponse status code. Since we use a bidirectional stream, we don't
|
||||
// want to send errors as gRPC statuses, since this would terminate the stream.
|
||||
enum GetPageStatus {
|
||||
// Unknown status. For forwards compatibility: used when the server sends a
|
||||
// status code that the client doesn't know about.
|
||||
GET_PAGE_STATUS_UNKNOWN = 0;
|
||||
// A GetPageResponse status code.
|
||||
//
|
||||
// These are effectively equivalent to gRPC statuses. However, we use a bidirectional stream
|
||||
// (potentially shared by many backends), and a gRPC status response would terminate the stream so
|
||||
// we send GetPageResponse messages with these codes instead.
|
||||
enum GetPageStatusCode {
|
||||
// Unknown status. For forwards compatibility: used when an older client version receives a new
|
||||
// status code from a newer server version.
|
||||
GET_PAGE_STATUS_CODE_UNKNOWN = 0;
|
||||
// The request was successful.
|
||||
GET_PAGE_STATUS_OK = 1;
|
||||
GET_PAGE_STATUS_CODE_OK = 1;
|
||||
// The page did not exist. The tenant/timeline/shard has already been
|
||||
// validated during stream setup.
|
||||
GET_PAGE_STATUS_NOT_FOUND = 2;
|
||||
GET_PAGE_STATUS_CODE_NOT_FOUND = 2;
|
||||
// The request was invalid.
|
||||
GET_PAGE_STATUS_INVALID = 3;
|
||||
GET_PAGE_STATUS_CODE_INVALID_REQUEST = 3;
|
||||
// The request failed due to an internal server error.
|
||||
GET_PAGE_STATUS_CODE_INTERNAL_ERROR = 4;
|
||||
// The tenant is rate limited. Slow down and retry later.
|
||||
GET_PAGE_STATUS_SLOW_DOWN = 4;
|
||||
// TODO: consider adding a GET_PAGE_STATUS_LAYER_DOWNLOAD in the case of a
|
||||
// layer download. This could free up the server task to process other
|
||||
// requests while the layer download is in progress.
|
||||
GET_PAGE_STATUS_CODE_SLOW_DOWN = 5;
|
||||
// NB: shutdown errors are emitted as a gRPC Unavailable status.
|
||||
//
|
||||
// TODO: consider adding a GET_PAGE_STATUS_CODE_LAYER_DOWNLOAD in the case of a layer download.
|
||||
// This could free up the server task to process other requests while the download is in progress.
|
||||
}
|
||||
|
||||
// Fetches the size of a relation at a given LSN, as # of blocks. Only valid on
|
||||
|
||||
@@ -35,6 +35,12 @@ impl ProtocolError {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ProtocolError> for tonic::Status {
|
||||
fn from(err: ProtocolError) -> Self {
|
||||
tonic::Status::invalid_argument(format!("{err}"))
|
||||
}
|
||||
}
|
||||
|
||||
/// The LSN a request should read at.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct ReadLsn {
|
||||
@@ -328,7 +334,7 @@ pub type RequestID = u64;
|
||||
/// A GetPage request class.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub enum GetPageClass {
|
||||
/// Unknown status. For backwards compatibility: used when an older client version sends a class
|
||||
/// Unknown class. For backwards compatibility: used when an older client version sends a class
|
||||
/// that a newer server version has removed.
|
||||
Unknown,
|
||||
/// A normal request. This is the default.
|
||||
@@ -386,7 +392,7 @@ pub struct GetPageResponse {
|
||||
/// The original request's ID.
|
||||
pub request_id: RequestID,
|
||||
/// The response status code.
|
||||
pub status: GetPageStatus,
|
||||
pub status_code: GetPageStatusCode,
|
||||
/// A string describing the status, if any.
|
||||
pub reason: Option<String>,
|
||||
/// The 8KB page images, in the same order as the request. Empty if status != OK.
|
||||
@@ -397,7 +403,7 @@ impl From<proto::GetPageResponse> for GetPageResponse {
|
||||
fn from(pb: proto::GetPageResponse) -> Self {
|
||||
Self {
|
||||
request_id: pb.request_id,
|
||||
status: pb.status.into(),
|
||||
status_code: pb.status_code.into(),
|
||||
reason: Some(pb.reason).filter(|r| !r.is_empty()),
|
||||
page_images: pb.page_image.into(),
|
||||
}
|
||||
@@ -408,16 +414,20 @@ impl From<GetPageResponse> for proto::GetPageResponse {
|
||||
fn from(response: GetPageResponse) -> Self {
|
||||
Self {
|
||||
request_id: response.request_id,
|
||||
status: response.status.into(),
|
||||
status_code: response.status_code.into(),
|
||||
reason: response.reason.unwrap_or_default(),
|
||||
page_image: response.page_images.into_vec(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A GetPage response status.
|
||||
/// A GetPage response status code.
|
||||
///
|
||||
/// These are effectively equivalent to gRPC statuses. However, we use a bidirectional stream
|
||||
/// (potentially shared by many backends), and a gRPC status response would terminate the stream so
|
||||
/// we send GetPageResponse messages with these codes instead.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub enum GetPageStatus {
|
||||
pub enum GetPageStatusCode {
|
||||
/// Unknown status. For forwards compatibility: used when an older client version receives a new
|
||||
/// status code from a newer server version.
|
||||
Unknown,
|
||||
@@ -427,46 +437,50 @@ pub enum GetPageStatus {
|
||||
/// setup.
|
||||
NotFound,
|
||||
/// The request was invalid.
|
||||
Invalid,
|
||||
InvalidRequest,
|
||||
/// The request failed due to an internal server error.
|
||||
InternalError,
|
||||
/// The tenant is rate limited. Slow down and retry later.
|
||||
SlowDown,
|
||||
}
|
||||
|
||||
impl From<proto::GetPageStatus> for GetPageStatus {
|
||||
fn from(pb: proto::GetPageStatus) -> Self {
|
||||
impl From<proto::GetPageStatusCode> for GetPageStatusCode {
|
||||
fn from(pb: proto::GetPageStatusCode) -> Self {
|
||||
match pb {
|
||||
proto::GetPageStatus::Unknown => Self::Unknown,
|
||||
proto::GetPageStatus::Ok => Self::Ok,
|
||||
proto::GetPageStatus::NotFound => Self::NotFound,
|
||||
proto::GetPageStatus::Invalid => Self::Invalid,
|
||||
proto::GetPageStatus::SlowDown => Self::SlowDown,
|
||||
proto::GetPageStatusCode::Unknown => Self::Unknown,
|
||||
proto::GetPageStatusCode::Ok => Self::Ok,
|
||||
proto::GetPageStatusCode::NotFound => Self::NotFound,
|
||||
proto::GetPageStatusCode::InvalidRequest => Self::InvalidRequest,
|
||||
proto::GetPageStatusCode::InternalError => Self::InternalError,
|
||||
proto::GetPageStatusCode::SlowDown => Self::SlowDown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<i32> for GetPageStatus {
|
||||
fn from(status: i32) -> Self {
|
||||
proto::GetPageStatus::try_from(status)
|
||||
.unwrap_or(proto::GetPageStatus::Unknown)
|
||||
impl From<i32> for GetPageStatusCode {
|
||||
fn from(status_code: i32) -> Self {
|
||||
proto::GetPageStatusCode::try_from(status_code)
|
||||
.unwrap_or(proto::GetPageStatusCode::Unknown)
|
||||
.into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<GetPageStatus> for proto::GetPageStatus {
|
||||
fn from(status: GetPageStatus) -> Self {
|
||||
match status {
|
||||
GetPageStatus::Unknown => Self::Unknown,
|
||||
GetPageStatus::Ok => Self::Ok,
|
||||
GetPageStatus::NotFound => Self::NotFound,
|
||||
GetPageStatus::Invalid => Self::Invalid,
|
||||
GetPageStatus::SlowDown => Self::SlowDown,
|
||||
impl From<GetPageStatusCode> for proto::GetPageStatusCode {
|
||||
fn from(status_code: GetPageStatusCode) -> Self {
|
||||
match status_code {
|
||||
GetPageStatusCode::Unknown => Self::Unknown,
|
||||
GetPageStatusCode::Ok => Self::Ok,
|
||||
GetPageStatusCode::NotFound => Self::NotFound,
|
||||
GetPageStatusCode::InvalidRequest => Self::InvalidRequest,
|
||||
GetPageStatusCode::InternalError => Self::InternalError,
|
||||
GetPageStatusCode::SlowDown => Self::SlowDown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<GetPageStatus> for i32 {
|
||||
fn from(status: GetPageStatus) -> Self {
|
||||
proto::GetPageStatus::from(status).into()
|
||||
impl From<GetPageStatusCode> for i32 {
|
||||
fn from(status_code: GetPageStatusCode) -> Self {
|
||||
proto::GetPageStatusCode::from(status_code).into()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -837,7 +837,30 @@ async fn collect_eviction_candidates(
|
||||
continue;
|
||||
}
|
||||
let info = tl.get_local_layers_for_disk_usage_eviction().await;
|
||||
debug!(tenant_id=%tl.tenant_shard_id.tenant_id, shard_id=%tl.tenant_shard_id.shard_slug(), timeline_id=%tl.timeline_id, "timeline resident layers count: {}", info.resident_layers.len());
|
||||
debug!(
|
||||
tenant_id=%tl.tenant_shard_id.tenant_id,
|
||||
shard_id=%tl.tenant_shard_id.shard_slug(),
|
||||
timeline_id=%tl.timeline_id,
|
||||
"timeline resident layers count: {}", info.resident_layers.len()
|
||||
);
|
||||
|
||||
tenant_candidates.extend(info.resident_layers.into_iter());
|
||||
max_layer_size = max_layer_size.max(info.max_layer_size.unwrap_or(0));
|
||||
|
||||
if cancel.is_cancelled() {
|
||||
return Ok(EvictionCandidates::Cancelled);
|
||||
}
|
||||
}
|
||||
|
||||
// Also consider layers of timelines being imported for eviction
|
||||
for tl in tenant.list_importing_timelines() {
|
||||
let info = tl.timeline.get_local_layers_for_disk_usage_eviction().await;
|
||||
debug!(
|
||||
tenant_id=%tl.timeline.tenant_shard_id.tenant_id,
|
||||
shard_id=%tl.timeline.tenant_shard_id.shard_slug(),
|
||||
timeline_id=%tl.timeline.timeline_id,
|
||||
"timeline resident layers count: {}", info.resident_layers.len()
|
||||
);
|
||||
|
||||
tenant_candidates.extend(info.resident_layers.into_iter());
|
||||
max_layer_size = max_layer_size.max(info.max_layer_size.unwrap_or(0));
|
||||
|
||||
@@ -45,6 +45,10 @@ impl FeatureResolver {
|
||||
}
|
||||
|
||||
/// Evaluate a multivariate feature flag. Currently, we do not support any properties.
|
||||
///
|
||||
/// Error handling: the caller should inspect the error and decide the behavior when a feature flag
|
||||
/// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be
|
||||
/// propagated beyond where the feature flag gets resolved.
|
||||
pub fn evaluate_multivariate(
|
||||
&self,
|
||||
flag_key: &str,
|
||||
@@ -62,4 +66,39 @@ impl FeatureResolver {
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate a boolean feature flag. Currently, we do not support any properties.
|
||||
///
|
||||
/// Returns `Ok(())` if the flag is evaluated to true, otherwise returns an error.
|
||||
///
|
||||
/// Error handling: the caller should inspect the error and decide the behavior when a feature flag
|
||||
/// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be
|
||||
/// propagated beyond where the feature flag gets resolved.
|
||||
pub fn evaluate_boolean(
|
||||
&self,
|
||||
flag_key: &str,
|
||||
tenant_id: TenantId,
|
||||
) -> Result<(), PostHogEvaluationError> {
|
||||
if let Some(inner) = &self.inner {
|
||||
inner.feature_store().evaluate_boolean(
|
||||
flag_key,
|
||||
&tenant_id.to_string(),
|
||||
&HashMap::new(),
|
||||
)
|
||||
} else {
|
||||
Err(PostHogEvaluationError::NotAvailable(
|
||||
"PostHog integration is not enabled".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_feature_flag_boolean(&self, flag_key: &str) -> Result<bool, PostHogEvaluationError> {
|
||||
if let Some(inner) = &self.inner {
|
||||
inner.feature_store().is_feature_flag_boolean(flag_key)
|
||||
} else {
|
||||
Err(PostHogEvaluationError::NotAvailable(
|
||||
"PostHog integration is not enabled".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -353,6 +353,33 @@ paths:
|
||||
"200":
|
||||
description: OK
|
||||
|
||||
/v1/tenant/{tenant_shard_id}/timeline/{timeline_id}/mark_invisible:
|
||||
parameters:
|
||||
- name: tenant_shard_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: timeline_id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
format: hex
|
||||
put:
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
is_visible:
|
||||
type: boolean
|
||||
default: false
|
||||
responses:
|
||||
"200":
|
||||
description: OK
|
||||
|
||||
/v1/tenant/{tenant_shard_id}/location_config:
|
||||
parameters:
|
||||
- name: tenant_shard_id
|
||||
|
||||
@@ -3663,6 +3663,46 @@ async fn read_tar_eof(mut reader: (impl tokio::io::AsyncRead + Unpin)) -> anyhow
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn tenant_evaluate_feature_flag(
|
||||
request: Request<Body>,
|
||||
_cancel: CancellationToken,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
let tenant_shard_id: TenantShardId = parse_request_param(&request, "tenant_shard_id")?;
|
||||
check_permission(&request, Some(tenant_shard_id.tenant_id))?;
|
||||
|
||||
let flag: String = must_parse_query_param(&request, "flag")?;
|
||||
let as_type: String = must_parse_query_param(&request, "as")?;
|
||||
|
||||
let state = get_state(&request);
|
||||
|
||||
async {
|
||||
let tenant = state
|
||||
.tenant_manager
|
||||
.get_attached_tenant_shard(tenant_shard_id)?;
|
||||
if as_type == "boolean" {
|
||||
let result = tenant.feature_resolver.evaluate_boolean(&flag, tenant_shard_id.tenant_id);
|
||||
let result = result.map(|_| true).map_err(|e| e.to_string());
|
||||
json_response(StatusCode::OK, result)
|
||||
} else if as_type == "multivariate" {
|
||||
let result = tenant.feature_resolver.evaluate_multivariate(&flag, tenant_shard_id.tenant_id).map_err(|e| e.to_string());
|
||||
json_response(StatusCode::OK, result)
|
||||
} else {
|
||||
// Auto infer the type of the feature flag.
|
||||
let is_boolean = tenant.feature_resolver.is_feature_flag_boolean(&flag).map_err(|e| ApiError::InternalServerError(anyhow::anyhow!("{e}")))?;
|
||||
if is_boolean {
|
||||
let result = tenant.feature_resolver.evaluate_boolean(&flag, tenant_shard_id.tenant_id);
|
||||
let result = result.map(|_| true).map_err(|e| e.to_string());
|
||||
json_response(StatusCode::OK, result)
|
||||
} else {
|
||||
let result = tenant.feature_resolver.evaluate_multivariate(&flag, tenant_shard_id.tenant_id).map_err(|e| e.to_string());
|
||||
json_response(StatusCode::OK, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
.instrument(info_span!("tenant_evaluate_feature_flag", tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug()))
|
||||
.await
|
||||
}
|
||||
|
||||
/// Common functionality of all the HTTP API handlers.
|
||||
///
|
||||
/// - Adds a tracing span to each request (by `request_span`)
|
||||
@@ -4039,5 +4079,8 @@ pub fn make_router(
|
||||
"/v1/tenant/:tenant_shard_id/timeline/:timeline_id/activate_post_import",
|
||||
|r| api_handler(r, activate_post_import_handler),
|
||||
)
|
||||
.get("/v1/tenant/:tenant_shard_id/feature_flag", |r| {
|
||||
api_handler(r, tenant_evaluate_feature_flag)
|
||||
})
|
||||
.any(handler_404))
|
||||
}
|
||||
|
||||
@@ -1312,11 +1312,44 @@ impl EvictionsWithLowResidenceDuration {
|
||||
//
|
||||
// Roughly logarithmic scale.
|
||||
const STORAGE_IO_TIME_BUCKETS: &[f64] = &[
|
||||
0.000030, // 30 usec
|
||||
0.001000, // 1000 usec
|
||||
0.030, // 30 ms
|
||||
1.000, // 1000 ms
|
||||
30.000, // 30000 ms
|
||||
0.00005, // 50us
|
||||
0.00006, // 60us
|
||||
0.00007, // 70us
|
||||
0.00008, // 80us
|
||||
0.00009, // 90us
|
||||
0.0001, // 100us
|
||||
0.000110, // 110us
|
||||
0.000120, // 120us
|
||||
0.000130, // 130us
|
||||
0.000140, // 140us
|
||||
0.000150, // 150us
|
||||
0.000160, // 160us
|
||||
0.000170, // 170us
|
||||
0.000180, // 180us
|
||||
0.000190, // 190us
|
||||
0.000200, // 200us
|
||||
0.000210, // 210us
|
||||
0.000220, // 220us
|
||||
0.000230, // 230us
|
||||
0.000240, // 240us
|
||||
0.000250, // 250us
|
||||
0.000300, // 300us
|
||||
0.000350, // 350us
|
||||
0.000400, // 400us
|
||||
0.000450, // 450us
|
||||
0.000500, // 500us
|
||||
0.000600, // 600us
|
||||
0.000700, // 700us
|
||||
0.000800, // 800us
|
||||
0.000900, // 900us
|
||||
0.001000, // 1ms
|
||||
0.002000, // 2ms
|
||||
0.003000, // 3ms
|
||||
0.004000, // 4ms
|
||||
0.005000, // 5ms
|
||||
0.01000, // 10ms
|
||||
0.02000, // 20ms
|
||||
0.05000, // 50ms
|
||||
];
|
||||
|
||||
/// VirtualFile fs operation variants.
|
||||
|
||||
@@ -769,6 +769,9 @@ struct BatchedGetPageRequest {
|
||||
timer: SmgrOpTimer,
|
||||
lsn_range: LsnRange,
|
||||
ctx: RequestContext,
|
||||
// If the request is perf enabled, this contains a context
|
||||
// with a perf span tracking the time spent waiting for the executor.
|
||||
batch_wait_ctx: Option<RequestContext>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "testing")]
|
||||
@@ -781,6 +784,7 @@ struct BatchedTestRequest {
|
||||
/// so that we don't keep the [`Timeline::gate`] open while the batch
|
||||
/// is being built up inside the [`spsc_fold`] (pagestream pipelining).
|
||||
#[derive(IntoStaticStr)]
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
enum BatchedFeMessage {
|
||||
Exists {
|
||||
span: Span,
|
||||
@@ -1298,6 +1302,22 @@ impl PageServerHandler {
|
||||
}
|
||||
};
|
||||
|
||||
let batch_wait_ctx = if ctx.has_perf_span() {
|
||||
Some(
|
||||
RequestContextBuilder::from(&ctx)
|
||||
.perf_span(|crnt_perf_span| {
|
||||
info_span!(
|
||||
target: PERF_TRACE_TARGET,
|
||||
parent: crnt_perf_span,
|
||||
"WAIT_EXECUTOR",
|
||||
)
|
||||
})
|
||||
.attached_child(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
BatchedFeMessage::GetPage {
|
||||
span,
|
||||
shard: shard.downgrade(),
|
||||
@@ -1309,6 +1329,7 @@ impl PageServerHandler {
|
||||
request_lsn: req.hdr.request_lsn
|
||||
},
|
||||
ctx,
|
||||
batch_wait_ctx,
|
||||
}],
|
||||
// The executor grabs the batch when it becomes idle.
|
||||
// Hence, [`GetPageBatchBreakReason::ExecutorSteal`] is the
|
||||
@@ -1464,7 +1485,7 @@ impl PageServerHandler {
|
||||
let mut flush_timers = Vec::with_capacity(handler_results.len());
|
||||
for handler_result in &mut handler_results {
|
||||
let flush_timer = match handler_result {
|
||||
Ok((_, timer)) => Some(
|
||||
Ok((_response, timer, _ctx)) => Some(
|
||||
timer
|
||||
.observe_execution_end(flushing_start_time)
|
||||
.expect("we are the first caller"),
|
||||
@@ -1484,7 +1505,7 @@ impl PageServerHandler {
|
||||
// Some handler errors cause exit from pagestream protocol.
|
||||
// Other handler errors are sent back as an error message and we stay in pagestream protocol.
|
||||
for (handler_result, flushing_timer) in handler_results.into_iter().zip(flush_timers) {
|
||||
let response_msg = match handler_result {
|
||||
let (response_msg, ctx) = match handler_result {
|
||||
Err(e) => match &e.err {
|
||||
PageStreamError::Shutdown => {
|
||||
// If we fail to fulfil a request during shutdown, which may be _because_ of
|
||||
@@ -1509,15 +1530,30 @@ impl PageServerHandler {
|
||||
error!("error reading relation or page version: {full:#}")
|
||||
});
|
||||
|
||||
PagestreamBeMessage::Error(PagestreamErrorResponse {
|
||||
req: e.req,
|
||||
message: e.err.to_string(),
|
||||
})
|
||||
(
|
||||
PagestreamBeMessage::Error(PagestreamErrorResponse {
|
||||
req: e.req,
|
||||
message: e.err.to_string(),
|
||||
}),
|
||||
None,
|
||||
)
|
||||
}
|
||||
},
|
||||
Ok((response_msg, _op_timer_already_observed)) => response_msg,
|
||||
Ok((response_msg, _op_timer_already_observed, ctx)) => (response_msg, Some(ctx)),
|
||||
};
|
||||
|
||||
let ctx = ctx.map(|req_ctx| {
|
||||
RequestContextBuilder::from(&req_ctx)
|
||||
.perf_span(|crnt_perf_span| {
|
||||
info_span!(
|
||||
target: PERF_TRACE_TARGET,
|
||||
parent: crnt_perf_span,
|
||||
"FLUSH_RESPONSE",
|
||||
)
|
||||
})
|
||||
.attached_child()
|
||||
});
|
||||
|
||||
//
|
||||
// marshal & transmit response message
|
||||
//
|
||||
@@ -1540,6 +1576,17 @@ impl PageServerHandler {
|
||||
)),
|
||||
None => futures::future::Either::Right(flush_fut),
|
||||
};
|
||||
|
||||
let flush_fut = if let Some(req_ctx) = ctx.as_ref() {
|
||||
futures::future::Either::Left(
|
||||
flush_fut.maybe_perf_instrument(req_ctx, |current_perf_span| {
|
||||
current_perf_span.clone()
|
||||
}),
|
||||
)
|
||||
} else {
|
||||
futures::future::Either::Right(flush_fut)
|
||||
};
|
||||
|
||||
// do it while respecting cancellation
|
||||
let _: () = async move {
|
||||
tokio::select! {
|
||||
@@ -1569,7 +1616,7 @@ impl PageServerHandler {
|
||||
ctx: &RequestContext,
|
||||
) -> Result<
|
||||
(
|
||||
Vec<Result<(PagestreamBeMessage, SmgrOpTimer), BatchedPageStreamError>>,
|
||||
Vec<Result<(PagestreamBeMessage, SmgrOpTimer, RequestContext), BatchedPageStreamError>>,
|
||||
Span,
|
||||
),
|
||||
QueryError,
|
||||
@@ -1596,7 +1643,7 @@ impl PageServerHandler {
|
||||
self.handle_get_rel_exists_request(&shard, &req, &ctx)
|
||||
.instrument(span.clone())
|
||||
.await
|
||||
.map(|msg| (msg, timer))
|
||||
.map(|msg| (msg, timer, ctx))
|
||||
.map_err(|err| BatchedPageStreamError { err, req: req.hdr }),
|
||||
],
|
||||
span,
|
||||
@@ -1615,7 +1662,7 @@ impl PageServerHandler {
|
||||
self.handle_get_nblocks_request(&shard, &req, &ctx)
|
||||
.instrument(span.clone())
|
||||
.await
|
||||
.map(|msg| (msg, timer))
|
||||
.map(|msg| (msg, timer, ctx))
|
||||
.map_err(|err| BatchedPageStreamError { err, req: req.hdr }),
|
||||
],
|
||||
span,
|
||||
@@ -1662,7 +1709,7 @@ impl PageServerHandler {
|
||||
self.handle_db_size_request(&shard, &req, &ctx)
|
||||
.instrument(span.clone())
|
||||
.await
|
||||
.map(|msg| (msg, timer))
|
||||
.map(|msg| (msg, timer, ctx))
|
||||
.map_err(|err| BatchedPageStreamError { err, req: req.hdr }),
|
||||
],
|
||||
span,
|
||||
@@ -1681,7 +1728,7 @@ impl PageServerHandler {
|
||||
self.handle_get_slru_segment_request(&shard, &req, &ctx)
|
||||
.instrument(span.clone())
|
||||
.await
|
||||
.map(|msg| (msg, timer))
|
||||
.map(|msg| (msg, timer, ctx))
|
||||
.map_err(|err| BatchedPageStreamError { err, req: req.hdr }),
|
||||
],
|
||||
span,
|
||||
@@ -2033,12 +2080,25 @@ impl PageServerHandler {
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
let batch = match batch {
|
||||
let mut batch = match batch {
|
||||
Ok(batch) => batch,
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
if let BatchedFeMessage::GetPage {
|
||||
pages,
|
||||
span: _,
|
||||
shard: _,
|
||||
batch_break_reason: _,
|
||||
} = &mut batch
|
||||
{
|
||||
for req in pages {
|
||||
req.batch_wait_ctx.take();
|
||||
}
|
||||
}
|
||||
|
||||
self.pagestream_handle_batched_message(
|
||||
pgb_writer,
|
||||
batch,
|
||||
@@ -2351,7 +2411,8 @@ impl PageServerHandler {
|
||||
io_concurrency: IoConcurrency,
|
||||
batch_break_reason: GetPageBatchBreakReason,
|
||||
ctx: &RequestContext,
|
||||
) -> Vec<Result<(PagestreamBeMessage, SmgrOpTimer), BatchedPageStreamError>> {
|
||||
) -> Vec<Result<(PagestreamBeMessage, SmgrOpTimer, RequestContext), BatchedPageStreamError>>
|
||||
{
|
||||
debug_assert_current_span_has_tenant_and_timeline_id();
|
||||
|
||||
timeline
|
||||
@@ -2458,6 +2519,7 @@ impl PageServerHandler {
|
||||
page,
|
||||
}),
|
||||
req.timer,
|
||||
req.ctx,
|
||||
)
|
||||
})
|
||||
.map_err(|e| BatchedPageStreamError {
|
||||
@@ -2502,7 +2564,8 @@ impl PageServerHandler {
|
||||
timeline: &Timeline,
|
||||
requests: Vec<BatchedTestRequest>,
|
||||
_ctx: &RequestContext,
|
||||
) -> Vec<Result<(PagestreamBeMessage, SmgrOpTimer), BatchedPageStreamError>> {
|
||||
) -> Vec<Result<(PagestreamBeMessage, SmgrOpTimer, RequestContext), BatchedPageStreamError>>
|
||||
{
|
||||
// real requests would do something with the timeline
|
||||
let mut results = Vec::with_capacity(requests.len());
|
||||
for _req in requests.iter() {
|
||||
@@ -2529,6 +2592,10 @@ impl PageServerHandler {
|
||||
req: req.req.clone(),
|
||||
}),
|
||||
req.timer,
|
||||
RequestContext::new(
|
||||
TaskKind::PageRequestHandler,
|
||||
DownloadBehavior::Warn,
|
||||
),
|
||||
)
|
||||
})
|
||||
.map_err(|e| BatchedPageStreamError {
|
||||
|
||||
@@ -300,7 +300,7 @@ pub struct TenantShard {
|
||||
/// as in progress.
|
||||
/// * Imported timelines are removed when the storage controller calls the post timeline
|
||||
/// import activation endpoint.
|
||||
timelines_importing: std::sync::Mutex<HashMap<TimelineId, ImportingTimeline>>,
|
||||
timelines_importing: std::sync::Mutex<HashMap<TimelineId, Arc<ImportingTimeline>>>,
|
||||
|
||||
/// The last tenant manifest known to be in remote storage. None if the manifest has not yet
|
||||
/// been either downloaded or uploaded. Always Some after tenant attach.
|
||||
@@ -383,7 +383,7 @@ pub struct TenantShard {
|
||||
|
||||
l0_flush_global_state: L0FlushGlobalState,
|
||||
|
||||
feature_resolver: FeatureResolver,
|
||||
pub(crate) feature_resolver: FeatureResolver,
|
||||
}
|
||||
impl std::fmt::Debug for TenantShard {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
@@ -672,6 +672,7 @@ pub enum MaybeOffloaded {
|
||||
pub enum TimelineOrOffloaded {
|
||||
Timeline(Arc<Timeline>),
|
||||
Offloaded(Arc<OffloadedTimeline>),
|
||||
Importing(Arc<ImportingTimeline>),
|
||||
}
|
||||
|
||||
impl TimelineOrOffloaded {
|
||||
@@ -683,6 +684,9 @@ impl TimelineOrOffloaded {
|
||||
TimelineOrOffloaded::Offloaded(offloaded) => {
|
||||
TimelineOrOffloadedArcRef::Offloaded(offloaded)
|
||||
}
|
||||
TimelineOrOffloaded::Importing(importing) => {
|
||||
TimelineOrOffloadedArcRef::Importing(importing)
|
||||
}
|
||||
}
|
||||
}
|
||||
pub fn tenant_shard_id(&self) -> TenantShardId {
|
||||
@@ -695,12 +699,16 @@ impl TimelineOrOffloaded {
|
||||
match self {
|
||||
TimelineOrOffloaded::Timeline(timeline) => &timeline.delete_progress,
|
||||
TimelineOrOffloaded::Offloaded(offloaded) => &offloaded.delete_progress,
|
||||
TimelineOrOffloaded::Importing(importing) => &importing.delete_progress,
|
||||
}
|
||||
}
|
||||
fn maybe_remote_client(&self) -> Option<Arc<RemoteTimelineClient>> {
|
||||
match self {
|
||||
TimelineOrOffloaded::Timeline(timeline) => Some(timeline.remote_client.clone()),
|
||||
TimelineOrOffloaded::Offloaded(_offloaded) => None,
|
||||
TimelineOrOffloaded::Importing(importing) => {
|
||||
Some(importing.timeline.remote_client.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -708,6 +716,7 @@ impl TimelineOrOffloaded {
|
||||
pub enum TimelineOrOffloadedArcRef<'a> {
|
||||
Timeline(&'a Arc<Timeline>),
|
||||
Offloaded(&'a Arc<OffloadedTimeline>),
|
||||
Importing(&'a Arc<ImportingTimeline>),
|
||||
}
|
||||
|
||||
impl TimelineOrOffloadedArcRef<'_> {
|
||||
@@ -715,12 +724,14 @@ impl TimelineOrOffloadedArcRef<'_> {
|
||||
match self {
|
||||
TimelineOrOffloadedArcRef::Timeline(timeline) => timeline.tenant_shard_id,
|
||||
TimelineOrOffloadedArcRef::Offloaded(offloaded) => offloaded.tenant_shard_id,
|
||||
TimelineOrOffloadedArcRef::Importing(importing) => importing.timeline.tenant_shard_id,
|
||||
}
|
||||
}
|
||||
pub fn timeline_id(&self) -> TimelineId {
|
||||
match self {
|
||||
TimelineOrOffloadedArcRef::Timeline(timeline) => timeline.timeline_id,
|
||||
TimelineOrOffloadedArcRef::Offloaded(offloaded) => offloaded.timeline_id,
|
||||
TimelineOrOffloadedArcRef::Importing(importing) => importing.timeline.timeline_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -737,6 +748,12 @@ impl<'a> From<&'a Arc<OffloadedTimeline>> for TimelineOrOffloadedArcRef<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<&'a Arc<ImportingTimeline>> for TimelineOrOffloadedArcRef<'a> {
|
||||
fn from(timeline: &'a Arc<ImportingTimeline>) -> Self {
|
||||
Self::Importing(timeline)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
|
||||
pub enum GetTimelineError {
|
||||
#[error("Timeline is shutting down")]
|
||||
@@ -1789,20 +1806,25 @@ impl TenantShard {
|
||||
},
|
||||
) => {
|
||||
let timeline_id = timeline.timeline_id;
|
||||
let import_task_gate = Gate::default();
|
||||
let import_task_guard = import_task_gate.enter().unwrap();
|
||||
let import_task_handle =
|
||||
tokio::task::spawn(self.clone().create_timeline_import_pgdata_task(
|
||||
timeline.clone(),
|
||||
import_pgdata,
|
||||
guard,
|
||||
import_task_guard,
|
||||
ctx.detached_child(TaskKind::ImportPgdata, DownloadBehavior::Warn),
|
||||
));
|
||||
|
||||
let prev = self.timelines_importing.lock().unwrap().insert(
|
||||
timeline_id,
|
||||
ImportingTimeline {
|
||||
Arc::new(ImportingTimeline {
|
||||
timeline: timeline.clone(),
|
||||
import_task_handle,
|
||||
},
|
||||
import_task_gate,
|
||||
delete_progress: TimelineDeleteProgress::default(),
|
||||
}),
|
||||
);
|
||||
|
||||
assert!(prev.is_none());
|
||||
@@ -2420,6 +2442,17 @@ impl TenantShard {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Lists timelines the tenant contains.
|
||||
/// It's up to callers to omit certain timelines that are not considered ready for use.
|
||||
pub fn list_importing_timelines(&self) -> Vec<Arc<ImportingTimeline>> {
|
||||
self.timelines_importing
|
||||
.lock()
|
||||
.unwrap()
|
||||
.values()
|
||||
.map(Arc::clone)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Lists timelines the tenant manages, including offloaded ones.
|
||||
///
|
||||
/// It's up to callers to omit certain timelines that are not considered ready for use.
|
||||
@@ -2853,19 +2886,25 @@ impl TenantShard {
|
||||
|
||||
let (timeline, timeline_create_guard) = uninit_timeline.finish_creation_myself();
|
||||
|
||||
let import_task_gate = Gate::default();
|
||||
let import_task_guard = import_task_gate.enter().unwrap();
|
||||
|
||||
let import_task_handle = tokio::spawn(self.clone().create_timeline_import_pgdata_task(
|
||||
timeline.clone(),
|
||||
index_part,
|
||||
timeline_create_guard,
|
||||
import_task_guard,
|
||||
timeline_ctx.detached_child(TaskKind::ImportPgdata, DownloadBehavior::Warn),
|
||||
));
|
||||
|
||||
let prev = self.timelines_importing.lock().unwrap().insert(
|
||||
timeline.timeline_id,
|
||||
ImportingTimeline {
|
||||
Arc::new(ImportingTimeline {
|
||||
timeline: timeline.clone(),
|
||||
import_task_handle,
|
||||
},
|
||||
import_task_gate,
|
||||
delete_progress: TimelineDeleteProgress::default(),
|
||||
}),
|
||||
);
|
||||
|
||||
// Idempotency is enforced higher up the stack
|
||||
@@ -2924,6 +2963,7 @@ impl TenantShard {
|
||||
timeline: Arc<Timeline>,
|
||||
index_part: import_pgdata::index_part_format::Root,
|
||||
timeline_create_guard: TimelineCreateGuard,
|
||||
_import_task_guard: GateGuard,
|
||||
ctx: RequestContext,
|
||||
) {
|
||||
debug_assert_current_span_has_tenant_and_timeline_id();
|
||||
@@ -3835,6 +3875,9 @@ impl TenantShard {
|
||||
.build_timeline_client(offloaded.timeline_id, self.remote_storage.clone());
|
||||
Arc::new(remote_client)
|
||||
}
|
||||
TimelineOrOffloadedArcRef::Importing(_) => {
|
||||
unreachable!("Importing timelines are not included in the iterator")
|
||||
}
|
||||
};
|
||||
|
||||
// Shut down the timeline's remote client: this means that the indices we write
|
||||
@@ -5044,6 +5087,14 @@ impl TenantShard {
|
||||
info!("timeline already exists but is offloaded");
|
||||
Err(CreateTimelineError::Conflict)
|
||||
}
|
||||
Err(TimelineExclusionError::AlreadyExists {
|
||||
existing: TimelineOrOffloaded::Importing(_existing),
|
||||
..
|
||||
}) => {
|
||||
// If there's a timeline already importing, then we would hit
|
||||
// the [`TimelineExclusionError::AlreadyCreating`] branch above.
|
||||
unreachable!("Importing timelines hold the creation guard")
|
||||
}
|
||||
Err(TimelineExclusionError::AlreadyExists {
|
||||
existing: TimelineOrOffloaded::Timeline(existing),
|
||||
arg,
|
||||
@@ -5315,6 +5366,7 @@ impl TenantShard {
|
||||
l0_compaction_trigger: self.l0_compaction_trigger.clone(),
|
||||
l0_flush_global_state: self.l0_flush_global_state.clone(),
|
||||
basebackup_prepare_sender: self.basebackup_prepare_sender.clone(),
|
||||
feature_resolver: self.feature_resolver.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5780,6 +5832,7 @@ pub(crate) mod harness {
|
||||
pub conf: &'static PageServerConf,
|
||||
pub tenant_conf: pageserver_api::models::TenantConfig,
|
||||
pub tenant_shard_id: TenantShardId,
|
||||
pub shard_identity: ShardIdentity,
|
||||
pub generation: Generation,
|
||||
pub shard: ShardIndex,
|
||||
pub remote_storage: GenericRemoteStorage,
|
||||
@@ -5847,6 +5900,7 @@ pub(crate) mod harness {
|
||||
conf,
|
||||
tenant_conf,
|
||||
tenant_shard_id,
|
||||
shard_identity,
|
||||
generation,
|
||||
shard,
|
||||
remote_storage,
|
||||
@@ -5908,8 +5962,7 @@ pub(crate) mod harness {
|
||||
&ShardParameters::default(),
|
||||
))
|
||||
.unwrap(),
|
||||
// This is a legacy/test code path: sharding isn't supported here.
|
||||
ShardIdentity::unsharded(),
|
||||
self.shard_identity,
|
||||
Some(walredo_mgr),
|
||||
self.tenant_shard_id,
|
||||
self.remote_storage.clone(),
|
||||
@@ -6031,6 +6084,7 @@ mod tests {
|
||||
use timeline::compaction::{KeyHistoryRetention, KeyLogAtLsn};
|
||||
use timeline::{CompactOptions, DeltaLayerTestDesc, VersionedKeySpaceQuery};
|
||||
use utils::id::TenantId;
|
||||
use utils::shard::{ShardCount, ShardNumber};
|
||||
|
||||
use super::*;
|
||||
use crate::DEFAULT_PG_VERSION;
|
||||
@@ -8359,10 +8413,24 @@ mod tests {
|
||||
}
|
||||
|
||||
tline.freeze_and_flush().await?;
|
||||
// Force layers to L1
|
||||
tline
|
||||
.compact(
|
||||
&cancel,
|
||||
{
|
||||
let mut flags = EnumSet::new();
|
||||
flags.insert(CompactFlags::ForceL0Compaction);
|
||||
flags
|
||||
},
|
||||
&ctx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
if iter % 5 == 0 {
|
||||
let scan_lsn = Lsn(lsn.0 + 1);
|
||||
info!("scanning at {}", scan_lsn);
|
||||
let (_, before_delta_file_accessed) =
|
||||
scan_with_statistics(&tline, &keyspace, lsn, &ctx, io_concurrency.clone())
|
||||
scan_with_statistics(&tline, &keyspace, scan_lsn, &ctx, io_concurrency.clone())
|
||||
.await?;
|
||||
tline
|
||||
.compact(
|
||||
@@ -8371,13 +8439,14 @@ mod tests {
|
||||
let mut flags = EnumSet::new();
|
||||
flags.insert(CompactFlags::ForceImageLayerCreation);
|
||||
flags.insert(CompactFlags::ForceRepartition);
|
||||
flags.insert(CompactFlags::ForceL0Compaction);
|
||||
flags
|
||||
},
|
||||
&ctx,
|
||||
)
|
||||
.await?;
|
||||
let (_, after_delta_file_accessed) =
|
||||
scan_with_statistics(&tline, &keyspace, lsn, &ctx, io_concurrency.clone())
|
||||
scan_with_statistics(&tline, &keyspace, scan_lsn, &ctx, io_concurrency.clone())
|
||||
.await?;
|
||||
assert!(
|
||||
after_delta_file_accessed < before_delta_file_accessed,
|
||||
@@ -8818,6 +8887,8 @@ mod tests {
|
||||
|
||||
let cancel = CancellationToken::new();
|
||||
|
||||
// Image layer creation happens on the disk_consistent_lsn so we need to force set it now.
|
||||
tline.force_set_disk_consistent_lsn(Lsn(0x40));
|
||||
tline
|
||||
.compact(
|
||||
&cancel,
|
||||
@@ -8831,8 +8902,7 @@ mod tests {
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Image layers are created at last_record_lsn
|
||||
// Image layers are created at repartition LSN
|
||||
let images = tline
|
||||
.inspect_image_layers(Lsn(0x40), &ctx, io_concurrency.clone())
|
||||
.await
|
||||
@@ -9350,6 +9420,77 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_failed_flush_should_not_update_disk_consistent_lsn() -> anyhow::Result<()> {
|
||||
//
|
||||
// Setup
|
||||
//
|
||||
let harness = TenantHarness::create_custom(
|
||||
"test_failed_flush_should_not_upload_disk_consistent_lsn",
|
||||
pageserver_api::models::TenantConfig::default(),
|
||||
TenantId::generate(),
|
||||
ShardIdentity::new(ShardNumber(0), ShardCount(4), ShardStripeSize(128)).unwrap(),
|
||||
Generation::new(1),
|
||||
)
|
||||
.await?;
|
||||
let (tenant, ctx) = harness.load().await;
|
||||
|
||||
let timeline = tenant
|
||||
.create_test_timeline(TIMELINE_ID, Lsn(0x10), DEFAULT_PG_VERSION, &ctx)
|
||||
.await?;
|
||||
assert_eq!(timeline.get_shard_identity().count, ShardCount(4));
|
||||
let mut writer = timeline.writer().await;
|
||||
writer
|
||||
.put(
|
||||
*TEST_KEY,
|
||||
Lsn(0x20),
|
||||
&Value::Image(test_img("foo at 0x20")),
|
||||
&ctx,
|
||||
)
|
||||
.await?;
|
||||
writer.finish_write(Lsn(0x20));
|
||||
drop(writer);
|
||||
timeline.freeze_and_flush().await.unwrap();
|
||||
|
||||
timeline.remote_client.wait_completion().await.unwrap();
|
||||
let disk_consistent_lsn = timeline.get_disk_consistent_lsn();
|
||||
let remote_consistent_lsn = timeline.get_remote_consistent_lsn_projected();
|
||||
assert_eq!(Some(disk_consistent_lsn), remote_consistent_lsn);
|
||||
|
||||
//
|
||||
// Test
|
||||
//
|
||||
|
||||
let mut writer = timeline.writer().await;
|
||||
writer
|
||||
.put(
|
||||
*TEST_KEY,
|
||||
Lsn(0x30),
|
||||
&Value::Image(test_img("foo at 0x30")),
|
||||
&ctx,
|
||||
)
|
||||
.await?;
|
||||
writer.finish_write(Lsn(0x30));
|
||||
drop(writer);
|
||||
|
||||
fail::cfg(
|
||||
"flush-layer-before-update-remote-consistent-lsn",
|
||||
"return()",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let flush_res = timeline.freeze_and_flush().await;
|
||||
// if flush failed, the disk/remote consistent LSN should not be updated
|
||||
assert!(flush_res.is_err());
|
||||
assert_eq!(disk_consistent_lsn, timeline.get_disk_consistent_lsn());
|
||||
assert_eq!(
|
||||
remote_consistent_lsn,
|
||||
timeline.get_remote_consistent_lsn_projected()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "testing")]
|
||||
#[tokio::test]
|
||||
async fn test_simple_bottom_most_compaction_deltas_1() -> anyhow::Result<()> {
|
||||
|
||||
@@ -1348,6 +1348,21 @@ impl RemoteTimelineClient {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn schedule_unlinking_of_layers_from_index_part<I>(
|
||||
self: &Arc<Self>,
|
||||
names: I,
|
||||
) -> Result<(), NotInitialized>
|
||||
where
|
||||
I: IntoIterator<Item = LayerName>,
|
||||
{
|
||||
let mut guard = self.upload_queue.lock().unwrap();
|
||||
let upload_queue = guard.initialized_mut()?;
|
||||
|
||||
self.schedule_unlinking_of_layers_from_index_part0(upload_queue, names);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the remote index file, removing the to-be-deleted files from the index,
|
||||
/// allowing scheduling of actual deletions later.
|
||||
fn schedule_unlinking_of_layers_from_index_part0<I>(
|
||||
|
||||
@@ -103,6 +103,7 @@ use crate::context::{
|
||||
DownloadBehavior, PerfInstrumentFutureExt, RequestContext, RequestContextBuilder,
|
||||
};
|
||||
use crate::disk_usage_eviction_task::{DiskUsageEvictionInfo, EvictionCandidate, finite_f32};
|
||||
use crate::feature_resolver::FeatureResolver;
|
||||
use crate::keyspace::{KeyPartitioning, KeySpace};
|
||||
use crate::l0_flush::{self, L0FlushGlobalState};
|
||||
use crate::metrics::{
|
||||
@@ -198,6 +199,7 @@ pub struct TimelineResources {
|
||||
pub l0_compaction_trigger: Arc<Notify>,
|
||||
pub l0_flush_global_state: l0_flush::L0FlushGlobalState,
|
||||
pub basebackup_prepare_sender: BasebackupPrepareSender,
|
||||
pub feature_resolver: FeatureResolver,
|
||||
}
|
||||
|
||||
pub struct Timeline {
|
||||
@@ -444,6 +446,8 @@ pub struct Timeline {
|
||||
|
||||
/// A channel to send async requests to prepare a basebackup for the basebackup cache.
|
||||
basebackup_prepare_sender: BasebackupPrepareSender,
|
||||
|
||||
feature_resolver: FeatureResolver,
|
||||
}
|
||||
|
||||
pub(crate) enum PreviousHeatmap {
|
||||
@@ -3072,6 +3076,8 @@ impl Timeline {
|
||||
wait_lsn_log_slow: tokio::sync::Semaphore::new(1),
|
||||
|
||||
basebackup_prepare_sender: resources.basebackup_prepare_sender,
|
||||
|
||||
feature_resolver: resources.feature_resolver,
|
||||
};
|
||||
|
||||
result.repartition_threshold =
|
||||
@@ -4761,7 +4767,10 @@ impl Timeline {
|
||||
|| !flushed_to_lsn.is_valid()
|
||||
);
|
||||
|
||||
if flushed_to_lsn < frozen_to_lsn && self.shard_identity.count.count() > 1 {
|
||||
if flushed_to_lsn < frozen_to_lsn
|
||||
&& self.shard_identity.count.count() > 1
|
||||
&& result.is_ok()
|
||||
{
|
||||
// If our layer flushes didn't carry disk_consistent_lsn up to the `to_lsn` advertised
|
||||
// to us via layer_flush_start_rx, then advance it here.
|
||||
//
|
||||
@@ -4906,6 +4915,7 @@ impl Timeline {
|
||||
LastImageLayerCreationStatus::Initial,
|
||||
false, // don't yield for L0, we're flushing L0
|
||||
)
|
||||
.instrument(info_span!("create_image_layers", mode = %ImageLayerCreationMode::Initial, partition_mode = "initial", lsn = %self.initdb_lsn))
|
||||
.await?;
|
||||
debug_assert!(
|
||||
matches!(is_complete, LastImageLayerCreationStatus::Complete),
|
||||
@@ -4939,6 +4949,10 @@ impl Timeline {
|
||||
return Err(FlushLayerError::Cancelled);
|
||||
}
|
||||
|
||||
fail_point!("flush-layer-before-update-remote-consistent-lsn", |_| {
|
||||
Err(FlushLayerError::Other(anyhow!("failpoint").into()))
|
||||
});
|
||||
|
||||
let disk_consistent_lsn = Lsn(lsn_range.end.0 - 1);
|
||||
|
||||
// The new on-disk layers are now in the layer map. We can remove the
|
||||
@@ -5462,7 +5476,8 @@ impl Timeline {
|
||||
|
||||
/// Returns the image layers generated and an enum indicating whether the process is fully completed.
|
||||
/// true = we have generate all image layers, false = we preempt the process for L0 compaction.
|
||||
#[tracing::instrument(skip_all, fields(%lsn, %mode))]
|
||||
///
|
||||
/// `partition_mode` is only for logging purpose and is not used anywhere in this function.
|
||||
async fn create_image_layers(
|
||||
self: &Arc<Timeline>,
|
||||
partitioning: &KeyPartitioning,
|
||||
|
||||
@@ -206,8 +206,8 @@ pub struct GcCompactionQueue {
|
||||
}
|
||||
|
||||
static CONCURRENT_GC_COMPACTION_TASKS: Lazy<Arc<Semaphore>> = Lazy::new(|| {
|
||||
// Only allow two timelines on one pageserver to run gc compaction at a time.
|
||||
Arc::new(Semaphore::new(2))
|
||||
// Only allow one timeline on one pageserver to run gc compaction at a time.
|
||||
Arc::new(Semaphore::new(1))
|
||||
});
|
||||
|
||||
impl GcCompactionQueue {
|
||||
@@ -1278,11 +1278,55 @@ impl Timeline {
|
||||
}
|
||||
|
||||
let gc_cutoff = *self.applied_gc_cutoff_lsn.read();
|
||||
let l0_l1_boundary_lsn = {
|
||||
// We do the repartition on the L0-L1 boundary. All data below the boundary
|
||||
// are compacted by L0 with low read amplification, thus making the `repartition`
|
||||
// function run fast.
|
||||
let guard = self.layers.read().await;
|
||||
guard
|
||||
.all_persistent_layers()
|
||||
.iter()
|
||||
.map(|x| {
|
||||
// Use the end LSN of delta layers OR the start LSN of image layers.
|
||||
if x.is_delta {
|
||||
x.lsn_range.end
|
||||
} else {
|
||||
x.lsn_range.start
|
||||
}
|
||||
})
|
||||
.max()
|
||||
};
|
||||
|
||||
let (partition_mode, partition_lsn) = if cfg!(test)
|
||||
|| cfg!(feature = "testing")
|
||||
|| self
|
||||
.feature_resolver
|
||||
.evaluate_boolean("image-compaction-boundary", self.tenant_shard_id.tenant_id)
|
||||
.is_ok()
|
||||
{
|
||||
let last_repartition_lsn = self.partitioning.read().1;
|
||||
let lsn = match l0_l1_boundary_lsn {
|
||||
Some(boundary) => gc_cutoff
|
||||
.max(boundary)
|
||||
.max(last_repartition_lsn)
|
||||
.max(self.initdb_lsn)
|
||||
.max(self.ancestor_lsn),
|
||||
None => self.get_last_record_lsn(),
|
||||
};
|
||||
if lsn <= self.initdb_lsn || lsn <= self.ancestor_lsn {
|
||||
// Do not attempt to create image layers below the initdb or ancestor LSN -- no data below it
|
||||
("l0_l1_boundary", self.get_last_record_lsn())
|
||||
} else {
|
||||
("l0_l1_boundary", lsn)
|
||||
}
|
||||
} else {
|
||||
("latest_record", self.get_last_record_lsn())
|
||||
};
|
||||
|
||||
// 2. Repartition and create image layers if necessary
|
||||
match self
|
||||
.repartition(
|
||||
self.get_last_record_lsn(),
|
||||
partition_lsn,
|
||||
self.get_compaction_target_size(),
|
||||
options.flags,
|
||||
ctx,
|
||||
@@ -1301,18 +1345,19 @@ impl Timeline {
|
||||
.extend(sparse_partitioning.into_dense().parts);
|
||||
|
||||
// 3. Create new image layers for partitions that have been modified "enough".
|
||||
let mode = if options
|
||||
.flags
|
||||
.contains(CompactFlags::ForceImageLayerCreation)
|
||||
{
|
||||
ImageLayerCreationMode::Force
|
||||
} else {
|
||||
ImageLayerCreationMode::Try
|
||||
};
|
||||
let (image_layers, outcome) = self
|
||||
.create_image_layers(
|
||||
&partitioning,
|
||||
lsn,
|
||||
if options
|
||||
.flags
|
||||
.contains(CompactFlags::ForceImageLayerCreation)
|
||||
{
|
||||
ImageLayerCreationMode::Force
|
||||
} else {
|
||||
ImageLayerCreationMode::Try
|
||||
},
|
||||
mode,
|
||||
&image_ctx,
|
||||
self.last_image_layer_creation_status
|
||||
.load()
|
||||
@@ -1320,6 +1365,7 @@ impl Timeline {
|
||||
.clone(),
|
||||
options.flags.contains(CompactFlags::YieldForL0),
|
||||
)
|
||||
.instrument(info_span!("create_image_layers", mode = %mode, partition_mode = %partition_mode, lsn = %lsn))
|
||||
.await
|
||||
.inspect_err(|err| {
|
||||
if let CreateImageLayersError::GetVectoredError(
|
||||
@@ -1344,7 +1390,8 @@ impl Timeline {
|
||||
}
|
||||
|
||||
Ok(_) => {
|
||||
info!("skipping repartitioning due to image compaction LSN being below GC cutoff");
|
||||
// This happens very frequently so we don't want to log it.
|
||||
debug!("skipping repartitioning due to image compaction LSN being below GC cutoff");
|
||||
}
|
||||
|
||||
// Suppress errors when cancelled.
|
||||
|
||||
@@ -121,6 +121,7 @@ async fn remove_maybe_offloaded_timeline_from_tenant(
|
||||
// This observes the locking order between timelines and timelines_offloaded
|
||||
let mut timelines = tenant.timelines.lock().unwrap();
|
||||
let mut timelines_offloaded = tenant.timelines_offloaded.lock().unwrap();
|
||||
let mut timelines_importing = tenant.timelines_importing.lock().unwrap();
|
||||
let offloaded_children_exist = timelines_offloaded
|
||||
.iter()
|
||||
.any(|(_, entry)| entry.ancestor_timeline_id == Some(timeline.timeline_id()));
|
||||
@@ -150,8 +151,12 @@ async fn remove_maybe_offloaded_timeline_from_tenant(
|
||||
.expect("timeline that we were deleting was concurrently removed from 'timelines_offloaded' map");
|
||||
offloaded_timeline.delete_from_ancestor_with_timelines(&timelines);
|
||||
}
|
||||
TimelineOrOffloaded::Importing(importing) => {
|
||||
timelines_importing.remove(&importing.timeline.timeline_id);
|
||||
}
|
||||
}
|
||||
|
||||
drop(timelines_importing);
|
||||
drop(timelines_offloaded);
|
||||
drop(timelines);
|
||||
|
||||
@@ -203,8 +208,17 @@ impl DeleteTimelineFlow {
|
||||
guard.mark_in_progress()?;
|
||||
|
||||
// Now that the Timeline is in Stopping state, request all the related tasks to shut down.
|
||||
if let TimelineOrOffloaded::Timeline(timeline) = &timeline {
|
||||
timeline.shutdown(super::ShutdownMode::Hard).await;
|
||||
// TODO(vlad): shut down imported timeline here
|
||||
match &timeline {
|
||||
TimelineOrOffloaded::Timeline(timeline) => {
|
||||
timeline.shutdown(super::ShutdownMode::Hard).await;
|
||||
}
|
||||
TimelineOrOffloaded::Importing(importing) => {
|
||||
importing.shutdown().await;
|
||||
}
|
||||
TimelineOrOffloaded::Offloaded(_offloaded) => {
|
||||
// Nothing to shut down in this case
|
||||
}
|
||||
}
|
||||
|
||||
tenant.gc_block.before_delete(&timeline.timeline_id());
|
||||
@@ -389,10 +403,18 @@ impl DeleteTimelineFlow {
|
||||
Err(anyhow::anyhow!("failpoint: timeline-delete-before-rm"))?
|
||||
});
|
||||
|
||||
// Offloaded timelines have no local state
|
||||
// TODO: once we persist offloaded information, delete the timeline from there, too
|
||||
if let TimelineOrOffloaded::Timeline(timeline) = timeline {
|
||||
delete_local_timeline_directory(conf, tenant.tenant_shard_id, timeline).await;
|
||||
match timeline {
|
||||
TimelineOrOffloaded::Timeline(timeline) => {
|
||||
delete_local_timeline_directory(conf, tenant.tenant_shard_id, timeline).await;
|
||||
}
|
||||
TimelineOrOffloaded::Importing(importing) => {
|
||||
delete_local_timeline_directory(conf, tenant.tenant_shard_id, &importing.timeline)
|
||||
.await;
|
||||
}
|
||||
TimelineOrOffloaded::Offloaded(_offloaded) => {
|
||||
// Offloaded timelines have no local state
|
||||
// TODO: once we persist offloaded information, delete the timeline from there, too
|
||||
}
|
||||
}
|
||||
|
||||
fail::fail_point!("timeline-delete-after-rm", |_| {
|
||||
@@ -451,12 +473,16 @@ pub(super) fn make_timeline_delete_guard(
|
||||
// For more context see this discussion: `https://github.com/neondatabase/neon/pull/4552#discussion_r1253437346`
|
||||
let timelines = tenant.timelines.lock().unwrap();
|
||||
let timelines_offloaded = tenant.timelines_offloaded.lock().unwrap();
|
||||
let timelines_importing = tenant.timelines_importing.lock().unwrap();
|
||||
|
||||
let timeline = match timelines.get(&timeline_id) {
|
||||
Some(t) => TimelineOrOffloaded::Timeline(Arc::clone(t)),
|
||||
None => match timelines_offloaded.get(&timeline_id) {
|
||||
Some(t) => TimelineOrOffloaded::Offloaded(Arc::clone(t)),
|
||||
None => return Err(DeleteTimelineError::NotFound),
|
||||
None => match timelines_importing.get(&timeline_id) {
|
||||
Some(t) => TimelineOrOffloaded::Importing(Arc::clone(t)),
|
||||
None => return Err(DeleteTimelineError::NotFound),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -8,8 +8,10 @@ use tokio::task::JoinHandle;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::info;
|
||||
use utils::lsn::Lsn;
|
||||
use utils::pausable_failpoint;
|
||||
use utils::sync::gate::Gate;
|
||||
|
||||
use super::Timeline;
|
||||
use super::{Timeline, TimelineDeleteProgress};
|
||||
use crate::context::RequestContext;
|
||||
use crate::controller_upcall_client::{StorageControllerUpcallApi, StorageControllerUpcallClient};
|
||||
use crate::tenant::metadata::TimelineMetadata;
|
||||
@@ -19,15 +21,23 @@ mod importbucket_client;
|
||||
mod importbucket_format;
|
||||
pub(crate) mod index_part_format;
|
||||
|
||||
pub(crate) struct ImportingTimeline {
|
||||
pub struct ImportingTimeline {
|
||||
pub import_task_handle: JoinHandle<()>,
|
||||
pub import_task_gate: Gate,
|
||||
pub timeline: Arc<Timeline>,
|
||||
pub delete_progress: TimelineDeleteProgress,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ImportingTimeline {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "ImportingTimeline<{}>", self.timeline.timeline_id)
|
||||
}
|
||||
}
|
||||
|
||||
impl ImportingTimeline {
|
||||
pub(crate) async fn shutdown(self) {
|
||||
pub async fn shutdown(&self) {
|
||||
self.import_task_handle.abort();
|
||||
let _ = self.import_task_handle.await;
|
||||
self.import_task_gate.close().await;
|
||||
|
||||
self.timeline.remote_client.shutdown().await;
|
||||
}
|
||||
@@ -101,6 +111,8 @@ pub async fn doit(
|
||||
.schedule_index_upload_for_file_changes()?;
|
||||
timeline.remote_client.wait_completion().await?;
|
||||
|
||||
pausable_failpoint!("import-timeline-pre-success-notify-pausable");
|
||||
|
||||
// Communicate that shard is done.
|
||||
// Ensure at-least-once delivery of the upcall to storage controller
|
||||
// before we mark the task as done and never come here again.
|
||||
|
||||
@@ -11,25 +11,14 @@
|
||||
//! - => S3 as the source for the PGDATA instead of local filesystem
|
||||
//!
|
||||
//! TODOs before productionization:
|
||||
//! - ChunkProcessingJob size / ImportJob::total_size does not account for sharding.
|
||||
//! => produced image layers likely too small.
|
||||
//! - ChunkProcessingJob should cut up an ImportJob to hit exactly target image layer size.
|
||||
//! - asserts / unwraps need to be replaced with errors
|
||||
//! - don't trust remote objects will be small (=prevent OOMs in those cases)
|
||||
//! - limit all in-memory buffers in size, or download to disk and read from there
|
||||
//! - limit task concurrency
|
||||
//! - generally play nice with other tenants in the system
|
||||
//! - importbucket is different bucket than main pageserver storage, so, should be fine wrt S3 rate limits
|
||||
//! - but concerns like network bandwidth, local disk write bandwidth, local disk capacity, etc
|
||||
//! - integrate with layer eviction system
|
||||
//! - audit for Tenant::cancel nor Timeline::cancel responsivity
|
||||
//! - audit for Tenant/Timeline gate holding (we spawn tokio tasks during this flow!)
|
||||
//!
|
||||
//! An incomplete set of TODOs from the Hackathon:
|
||||
//! - version-specific CheckPointData (=> pgv abstraction, already exists for regular walingest)
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::num::NonZeroUsize;
|
||||
use std::ops::Range;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -43,7 +32,7 @@ use pageserver_api::key::{
|
||||
rel_dir_to_key, rel_size_to_key, relmap_file_key, slru_block_to_key, slru_dir_to_key,
|
||||
slru_segment_size_to_key,
|
||||
};
|
||||
use pageserver_api::keyspace::{contiguous_range_len, is_contiguous_range, singleton_range};
|
||||
use pageserver_api::keyspace::{ShardedRange, singleton_range};
|
||||
use pageserver_api::models::{ShardImportProgress, ShardImportProgressV1, ShardImportStatus};
|
||||
use pageserver_api::reltag::{RelTag, SlruKind};
|
||||
use pageserver_api::shard::ShardIdentity;
|
||||
@@ -100,8 +89,24 @@ async fn run_v1(
|
||||
tasks: Vec::default(),
|
||||
};
|
||||
|
||||
let import_config = &timeline.conf.timeline_import_config;
|
||||
let plan = planner.plan(import_config).await?;
|
||||
// Use the job size limit encoded in the progress if we are resuming an import.
|
||||
// This ensures that imports have stable plans even if the pageserver config changes.
|
||||
let import_config = {
|
||||
match &import_progress {
|
||||
Some(progress) => {
|
||||
let base = &timeline.conf.timeline_import_config;
|
||||
TimelineImportConfig {
|
||||
import_job_soft_size_limit: NonZeroUsize::new(progress.job_soft_size_limit)
|
||||
.unwrap(),
|
||||
import_job_concurrency: base.import_job_concurrency,
|
||||
import_job_checkpoint_threshold: base.import_job_checkpoint_threshold,
|
||||
}
|
||||
}
|
||||
None => timeline.conf.timeline_import_config.clone(),
|
||||
}
|
||||
};
|
||||
|
||||
let plan = planner.plan(&import_config).await?;
|
||||
|
||||
// Hash the plan and compare with the hash of the plan we got back from the storage controller.
|
||||
// If the two match, it means that the planning stage had the same output.
|
||||
@@ -126,7 +131,7 @@ async fn run_v1(
|
||||
pausable_failpoint!("import-timeline-pre-execute-pausable");
|
||||
|
||||
let start_from_job_idx = import_progress.map(|progress| progress.completed);
|
||||
plan.execute(timeline, start_from_job_idx, plan_hash, import_config, ctx)
|
||||
plan.execute(timeline, start_from_job_idx, plan_hash, &import_config, ctx)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -150,6 +155,7 @@ impl Planner {
|
||||
/// This function is and must remain pure: given the same input, it will generate the same import plan.
|
||||
async fn plan(mut self, import_config: &TimelineImportConfig) -> anyhow::Result<Plan> {
|
||||
let pgdata_lsn = Lsn(self.control_file.control_file_data().checkPoint).align();
|
||||
anyhow::ensure!(pgdata_lsn.is_valid());
|
||||
|
||||
let datadir = PgDataDir::new(&self.storage).await?;
|
||||
|
||||
@@ -232,14 +238,22 @@ impl Planner {
|
||||
});
|
||||
|
||||
// Assigns parts of key space to later parallel jobs
|
||||
// Note: The image layers produced here may have gaps, meaning,
|
||||
// there is not an image for each key in the layer's key range.
|
||||
// The read path stops traversal at the first image layer, regardless
|
||||
// of whether a base image has been found for a key or not.
|
||||
// (Concept of sparse image layers doesn't exist.)
|
||||
// This behavior is exactly right for the base image layers we're producing here.
|
||||
// But, since no other place in the code currently produces image layers with gaps,
|
||||
// it seems noteworthy.
|
||||
let mut last_end_key = Key::MIN;
|
||||
let mut current_chunk = Vec::new();
|
||||
let mut current_chunk_size: usize = 0;
|
||||
let mut jobs = Vec::new();
|
||||
for task in std::mem::take(&mut self.tasks).into_iter() {
|
||||
if current_chunk_size + task.total_size()
|
||||
> import_config.import_job_soft_size_limit.into()
|
||||
{
|
||||
let task_size = task.total_size(&self.shard);
|
||||
let projected_chunk_size = current_chunk_size.saturating_add(task_size);
|
||||
if projected_chunk_size > import_config.import_job_soft_size_limit.into() {
|
||||
let key_range = last_end_key..task.key_range().start;
|
||||
jobs.push(ChunkProcessingJob::new(
|
||||
key_range.clone(),
|
||||
@@ -249,7 +263,7 @@ impl Planner {
|
||||
last_end_key = key_range.end;
|
||||
current_chunk_size = 0;
|
||||
}
|
||||
current_chunk_size += task.total_size();
|
||||
current_chunk_size = current_chunk_size.saturating_add(task_size);
|
||||
current_chunk.push(task);
|
||||
}
|
||||
jobs.push(ChunkProcessingJob::new(
|
||||
@@ -453,6 +467,7 @@ impl Plan {
|
||||
jobs: jobs_in_plan,
|
||||
completed: last_completed_job_idx,
|
||||
import_plan_hash,
|
||||
job_soft_size_limit: import_config.import_job_soft_size_limit.into(),
|
||||
};
|
||||
|
||||
timeline.remote_client.schedule_index_upload_for_file_changes()?;
|
||||
@@ -586,18 +601,18 @@ impl PgDataDirDb {
|
||||
};
|
||||
|
||||
let path = datadir_path.join(rel_tag.to_segfile_name(segno));
|
||||
assert!(filesize % BLCKSZ as usize == 0); // TODO: this should result in an error
|
||||
anyhow::ensure!(filesize % BLCKSZ as usize == 0);
|
||||
let nblocks = filesize / BLCKSZ as usize;
|
||||
|
||||
PgDataDirDbFile {
|
||||
Ok(PgDataDirDbFile {
|
||||
path,
|
||||
filesize,
|
||||
rel_tag,
|
||||
segno,
|
||||
nblocks: Some(nblocks), // first non-cummulative sizes
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
.collect::<anyhow::Result<_, _>>()?;
|
||||
|
||||
// Set cummulative sizes. Do all of that math here, so that later we could easier
|
||||
// parallelize over segments and know with which segments we need to write relsize
|
||||
@@ -632,12 +647,22 @@ impl PgDataDirDb {
|
||||
trait ImportTask {
|
||||
fn key_range(&self) -> Range<Key>;
|
||||
|
||||
fn total_size(&self) -> usize {
|
||||
// TODO: revisit this
|
||||
if is_contiguous_range(&self.key_range()) {
|
||||
contiguous_range_len(&self.key_range()) as usize * 8192
|
||||
fn total_size(&self, shard_identity: &ShardIdentity) -> usize {
|
||||
let range = ShardedRange::new(self.key_range(), shard_identity);
|
||||
let page_count = range.page_count();
|
||||
if page_count == u32::MAX {
|
||||
tracing::warn!(
|
||||
"Import task has non contiguous key range: {}..{}",
|
||||
self.key_range().start,
|
||||
self.key_range().end
|
||||
);
|
||||
|
||||
// Tasks should operate on contiguous ranges. It is unexpected for
|
||||
// ranges to violate this assumption. Calling code handles this by mapping
|
||||
// any task on a non contiguous range to its own image layer.
|
||||
usize::MAX
|
||||
} else {
|
||||
u32::MAX as usize
|
||||
page_count as usize * 8192
|
||||
}
|
||||
}
|
||||
|
||||
@@ -735,6 +760,8 @@ impl ImportTask for ImportRelBlocksTask {
|
||||
layer_writer: &mut ImageLayerWriter,
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<usize> {
|
||||
const MAX_BYTE_RANGE_SIZE: usize = 128 * 1024 * 1024;
|
||||
|
||||
debug!("Importing relation file");
|
||||
|
||||
let (rel_tag, start_blk) = self.key_range.start.to_rel_block()?;
|
||||
@@ -759,7 +786,7 @@ impl ImportTask for ImportRelBlocksTask {
|
||||
assert_eq!(key.len(), 1);
|
||||
assert!(!acc.is_empty());
|
||||
assert!(acc_end > acc_start);
|
||||
if acc_end == start /* TODO additional max range check here, to limit memory consumption per task to X */ {
|
||||
if acc_end == start && end - acc_start <= MAX_BYTE_RANGE_SIZE {
|
||||
acc.push(key.pop().unwrap());
|
||||
Ok((acc, acc_start, end))
|
||||
} else {
|
||||
@@ -774,8 +801,8 @@ impl ImportTask for ImportRelBlocksTask {
|
||||
.get_range(&self.path, range_start.into_u64(), range_end.into_u64())
|
||||
.await?;
|
||||
let mut buf = Bytes::from(range_buf);
|
||||
// TODO: batched writes
|
||||
for key in keys {
|
||||
// The writer buffers writes internally
|
||||
let image = buf.split_to(8192);
|
||||
layer_writer.put_image(key, image, ctx).await?;
|
||||
nimages += 1;
|
||||
@@ -828,6 +855,9 @@ impl ImportTask for ImportSlruBlocksTask {
|
||||
debug!("Importing SLRU segment file {}", self.path);
|
||||
let buf = self.storage.get(&self.path).await?;
|
||||
|
||||
// TODO(vlad): Does timestamp to LSN work for imported timelines?
|
||||
// Probably not since we don't append the `xact_time` to it as in
|
||||
// [`WalIngest::ingest_xact_record`].
|
||||
let (kind, segno, start_blk) = self.key_range.start.to_slru_block()?;
|
||||
let (_kind, _segno, end_blk) = self.key_range.end.to_slru_block()?;
|
||||
let mut blknum = start_blk;
|
||||
@@ -964,6 +994,15 @@ impl ChunkProcessingJob {
|
||||
.cloned();
|
||||
match existing_layer {
|
||||
Some(existing) => {
|
||||
// Unlink the remote layer from the index without scheduling its deletion.
|
||||
// When `existing_layer` drops [`LayerInner::drop`] will schedule its deletion from
|
||||
// remote storage, but that assumes that the layer was unlinked from the index first.
|
||||
timeline
|
||||
.remote_client
|
||||
.schedule_unlinking_of_layers_from_index_part(std::iter::once(
|
||||
existing.layer_desc().layer_name(),
|
||||
))?;
|
||||
|
||||
guard.open_mut()?.rewrite_layers(
|
||||
&[(existing.clone(), resident_layer.clone())],
|
||||
&[],
|
||||
|
||||
@@ -6,7 +6,7 @@ use bytes::Bytes;
|
||||
use postgres_ffi::ControlFileData;
|
||||
use remote_storage::{
|
||||
Download, DownloadError, DownloadKind, DownloadOpts, GenericRemoteStorage, Listing,
|
||||
ListingObject, RemotePath,
|
||||
ListingObject, RemotePath, RemoteStorageConfig,
|
||||
};
|
||||
use serde::de::DeserializeOwned;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -22,11 +22,9 @@ pub async fn new(
|
||||
location: &index_part_format::Location,
|
||||
cancel: CancellationToken,
|
||||
) -> Result<RemoteStorageWrapper, anyhow::Error> {
|
||||
// FIXME: we probably want some timeout, and we might be able to assume the max file
|
||||
// size on S3 is 1GiB (postgres segment size). But the problem is that the individual
|
||||
// downloaders don't know enough about concurrent downloads to make a guess on the
|
||||
// expected bandwidth and resulting best timeout.
|
||||
let timeout = std::time::Duration::from_secs(24 * 60 * 60);
|
||||
// Downloads should be reasonably sized. We do ranged reads for relblock raw data
|
||||
// and full reads for SLRU segments which are bounded by Postgres.
|
||||
let timeout = RemoteStorageConfig::DEFAULT_TIMEOUT;
|
||||
let location_storage = match location {
|
||||
#[cfg(feature = "testing")]
|
||||
index_part_format::Location::LocalFs { path } => {
|
||||
@@ -50,9 +48,12 @@ pub async fn new(
|
||||
.import_pgdata_aws_endpoint_url
|
||||
.clone()
|
||||
.map(|url| url.to_string()), // by specifying None here, remote_storage/aws-sdk-rust will infer from env
|
||||
concurrency_limit: 100.try_into().unwrap(), // TODO: think about this
|
||||
max_keys_per_list_response: Some(1000), // TODO: think about this
|
||||
upload_storage_class: None, // irrelevant
|
||||
// This matches the default import job concurrency. This is managed
|
||||
// separately from the usual S3 client, but the concern here is bandwidth
|
||||
// usage.
|
||||
concurrency_limit: 128.try_into().unwrap(),
|
||||
max_keys_per_list_response: Some(1000),
|
||||
upload_storage_class: None, // irrelevant
|
||||
},
|
||||
timeout,
|
||||
)
|
||||
|
||||
@@ -173,8 +173,10 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api)
|
||||
}
|
||||
wp->quorum = wp->n_safekeepers / 2 + 1;
|
||||
|
||||
if (wp->config->proto_version != 3)
|
||||
if (wp->config->proto_version != 2 && wp->config->proto_version != 3)
|
||||
wp_log(FATAL, "unsupported safekeeper protocol version %d", wp->config->proto_version);
|
||||
if (wp->safekeepers_generation > INVALID_GENERATION && wp->config->proto_version < 3)
|
||||
wp_log(FATAL, "enabling generations requires protocol version 3");
|
||||
wp_log(LOG, "using safekeeper protocol version %d", wp->config->proto_version);
|
||||
|
||||
/* Fill the greeting package */
|
||||
@@ -2177,79 +2179,183 @@ MembershipConfigurationSerialize(MembershipConfiguration *mconf, StringInfo buf)
|
||||
}
|
||||
}
|
||||
|
||||
/* Serialize proposer -> acceptor message into buf */
|
||||
/* Serialize proposer -> acceptor message into buf using specified version */
|
||||
static void
|
||||
PAMessageSerialize(WalProposer *wp, ProposerAcceptorMessage *msg, StringInfo buf, int proto_version)
|
||||
{
|
||||
/* only version 3 is supported */
|
||||
Assert(proto_version == 3);
|
||||
/* both version are supported currently until we fully migrate to 3 */
|
||||
Assert(proto_version == 3 || proto_version == 2);
|
||||
|
||||
resetStringInfo(buf);
|
||||
|
||||
/*
|
||||
* v2 sends structs for some messages as is, so commonly send tag only
|
||||
* for v3
|
||||
*/
|
||||
pq_sendint8(buf, msg->tag);
|
||||
|
||||
switch (msg->tag)
|
||||
if (proto_version == 3)
|
||||
{
|
||||
case 'g':
|
||||
{
|
||||
ProposerGreeting *m = (ProposerGreeting *) msg;
|
||||
/*
|
||||
* v2 sends structs for some messages as is, so commonly send tag only
|
||||
* for v3
|
||||
*/
|
||||
pq_sendint8(buf, msg->tag);
|
||||
|
||||
pq_send_ascii_string(buf, m->tenant_id);
|
||||
pq_send_ascii_string(buf, m->timeline_id);
|
||||
MembershipConfigurationSerialize(&m->mconf, buf);
|
||||
pq_sendint32(buf, m->pg_version);
|
||||
pq_sendint64(buf, m->system_id);
|
||||
pq_sendint32(buf, m->wal_seg_size);
|
||||
break;
|
||||
}
|
||||
case 'v':
|
||||
{
|
||||
VoteRequest *m = (VoteRequest *) msg;
|
||||
|
||||
pq_sendint32(buf, m->generation);
|
||||
pq_sendint64(buf, m->term);
|
||||
break;
|
||||
|
||||
}
|
||||
case 'e':
|
||||
{
|
||||
ProposerElected *m = (ProposerElected *) msg;
|
||||
|
||||
pq_sendint32(buf, m->generation);
|
||||
pq_sendint64(buf, m->term);
|
||||
pq_sendint64(buf, m->startStreamingAt);
|
||||
pq_sendint32(buf, m->termHistory->n_entries);
|
||||
for (uint32 i = 0; i < m->termHistory->n_entries; i++)
|
||||
switch (msg->tag)
|
||||
{
|
||||
case 'g':
|
||||
{
|
||||
pq_sendint64(buf, m->termHistory->entries[i].term);
|
||||
pq_sendint64(buf, m->termHistory->entries[i].lsn);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'a':
|
||||
{
|
||||
/*
|
||||
* Note: this serializes only AppendRequestHeader, caller
|
||||
* is expected to append WAL data later.
|
||||
*/
|
||||
AppendRequestHeader *m = (AppendRequestHeader *) msg;
|
||||
ProposerGreeting *m = (ProposerGreeting *) msg;
|
||||
|
||||
pq_sendint32(buf, m->generation);
|
||||
pq_sendint64(buf, m->term);
|
||||
pq_sendint64(buf, m->beginLsn);
|
||||
pq_sendint64(buf, m->endLsn);
|
||||
pq_sendint64(buf, m->commitLsn);
|
||||
pq_sendint64(buf, m->truncateLsn);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
wp_log(FATAL, "unexpected message type %c to serialize", msg->tag);
|
||||
pq_send_ascii_string(buf, m->tenant_id);
|
||||
pq_send_ascii_string(buf, m->timeline_id);
|
||||
MembershipConfigurationSerialize(&m->mconf, buf);
|
||||
pq_sendint32(buf, m->pg_version);
|
||||
pq_sendint64(buf, m->system_id);
|
||||
pq_sendint32(buf, m->wal_seg_size);
|
||||
break;
|
||||
}
|
||||
case 'v':
|
||||
{
|
||||
VoteRequest *m = (VoteRequest *) msg;
|
||||
|
||||
pq_sendint32(buf, m->generation);
|
||||
pq_sendint64(buf, m->term);
|
||||
break;
|
||||
|
||||
}
|
||||
case 'e':
|
||||
{
|
||||
ProposerElected *m = (ProposerElected *) msg;
|
||||
|
||||
pq_sendint32(buf, m->generation);
|
||||
pq_sendint64(buf, m->term);
|
||||
pq_sendint64(buf, m->startStreamingAt);
|
||||
pq_sendint32(buf, m->termHistory->n_entries);
|
||||
for (uint32 i = 0; i < m->termHistory->n_entries; i++)
|
||||
{
|
||||
pq_sendint64(buf, m->termHistory->entries[i].term);
|
||||
pq_sendint64(buf, m->termHistory->entries[i].lsn);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'a':
|
||||
{
|
||||
/*
|
||||
* Note: this serializes only AppendRequestHeader, caller
|
||||
* is expected to append WAL data later.
|
||||
*/
|
||||
AppendRequestHeader *m = (AppendRequestHeader *) msg;
|
||||
|
||||
pq_sendint32(buf, m->generation);
|
||||
pq_sendint64(buf, m->term);
|
||||
pq_sendint64(buf, m->beginLsn);
|
||||
pq_sendint64(buf, m->endLsn);
|
||||
pq_sendint64(buf, m->commitLsn);
|
||||
pq_sendint64(buf, m->truncateLsn);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
wp_log(FATAL, "unexpected message type %c to serialize", msg->tag);
|
||||
}
|
||||
return;
|
||||
}
|
||||
return;
|
||||
|
||||
if (proto_version == 2)
|
||||
{
|
||||
switch (msg->tag)
|
||||
{
|
||||
case 'g':
|
||||
{
|
||||
/* v2 sent struct as is */
|
||||
ProposerGreeting *m = (ProposerGreeting *) msg;
|
||||
ProposerGreetingV2 greetRequestV2;
|
||||
|
||||
/* Fill also v2 struct. */
|
||||
greetRequestV2.tag = 'g';
|
||||
greetRequestV2.protocolVersion = proto_version;
|
||||
greetRequestV2.pgVersion = m->pg_version;
|
||||
|
||||
/*
|
||||
* v3 removed this field because it's easier to pass as
|
||||
* libq or START_WAL_PUSH options
|
||||
*/
|
||||
memset(&greetRequestV2.proposerId, 0, sizeof(greetRequestV2.proposerId));
|
||||
greetRequestV2.systemId = wp->config->systemId;
|
||||
if (*m->timeline_id != '\0' &&
|
||||
!HexDecodeString(greetRequestV2.timeline_id, m->timeline_id, 16))
|
||||
wp_log(FATAL, "could not parse neon.timeline_id, %s", m->timeline_id);
|
||||
if (*m->tenant_id != '\0' &&
|
||||
!HexDecodeString(greetRequestV2.tenant_id, m->tenant_id, 16))
|
||||
wp_log(FATAL, "could not parse neon.tenant_id, %s", m->tenant_id);
|
||||
|
||||
greetRequestV2.timeline = wp->config->pgTimeline;
|
||||
greetRequestV2.walSegSize = wp->config->wal_segment_size;
|
||||
|
||||
pq_sendbytes(buf, (char *) &greetRequestV2, sizeof(greetRequestV2));
|
||||
break;
|
||||
}
|
||||
case 'v':
|
||||
{
|
||||
/* v2 sent struct as is */
|
||||
VoteRequest *m = (VoteRequest *) msg;
|
||||
VoteRequestV2 voteRequestV2;
|
||||
|
||||
voteRequestV2.tag = m->pam.tag;
|
||||
voteRequestV2.term = m->term;
|
||||
/* removed field */
|
||||
memset(&voteRequestV2.proposerId, 0, sizeof(voteRequestV2.proposerId));
|
||||
pq_sendbytes(buf, (char *) &voteRequestV2, sizeof(voteRequestV2));
|
||||
break;
|
||||
}
|
||||
case 'e':
|
||||
{
|
||||
ProposerElected *m = (ProposerElected *) msg;
|
||||
|
||||
pq_sendint64_le(buf, m->apm.tag);
|
||||
pq_sendint64_le(buf, m->term);
|
||||
pq_sendint64_le(buf, m->startStreamingAt);
|
||||
pq_sendint32_le(buf, m->termHistory->n_entries);
|
||||
for (int i = 0; i < m->termHistory->n_entries; i++)
|
||||
{
|
||||
pq_sendint64_le(buf, m->termHistory->entries[i].term);
|
||||
pq_sendint64_le(buf, m->termHistory->entries[i].lsn);
|
||||
}
|
||||
|
||||
/*
|
||||
* Removed timeline_start_lsn. Still send it as a valid
|
||||
* value until safekeepers taking it from term history are
|
||||
* deployed.
|
||||
*/
|
||||
pq_sendint64_le(buf, m->termHistory->entries[0].lsn);
|
||||
break;
|
||||
}
|
||||
case 'a':
|
||||
|
||||
/*
|
||||
* Note: this serializes only AppendRequestHeader, caller is
|
||||
* expected to append WAL data later.
|
||||
*/
|
||||
{
|
||||
/* v2 sent struct as is */
|
||||
AppendRequestHeader *m = (AppendRequestHeader *) msg;
|
||||
AppendRequestHeaderV2 appendRequestHeaderV2;
|
||||
|
||||
appendRequestHeaderV2.tag = m->apm.tag;
|
||||
appendRequestHeaderV2.term = m->term;
|
||||
appendRequestHeaderV2.epochStartLsn = 0; /* removed field */
|
||||
appendRequestHeaderV2.beginLsn = m->beginLsn;
|
||||
appendRequestHeaderV2.endLsn = m->endLsn;
|
||||
appendRequestHeaderV2.commitLsn = m->commitLsn;
|
||||
appendRequestHeaderV2.truncateLsn = m->truncateLsn;
|
||||
/* removed field */
|
||||
memset(&appendRequestHeaderV2.proposerId, 0, sizeof(appendRequestHeaderV2.proposerId));
|
||||
|
||||
pq_sendbytes(buf, (char *) &appendRequestHeaderV2, sizeof(appendRequestHeaderV2));
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
wp_log(FATAL, "unexpected message type %c to serialize", msg->tag);
|
||||
}
|
||||
return;
|
||||
}
|
||||
wp_log(FATAL, "unexpected proto_version %d", proto_version);
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -2343,72 +2449,141 @@ AsyncReadMessage(Safekeeper *sk, AcceptorProposerMessage *anymsg)
|
||||
s.maxlen = buf_size;
|
||||
s.cursor = 0;
|
||||
|
||||
/* only version 3 is supported */
|
||||
Assert(wp->config->proto_version == 3);
|
||||
|
||||
tag = pq_getmsgbyte(&s);
|
||||
if (tag != anymsg->tag)
|
||||
if (wp->config->proto_version == 3)
|
||||
{
|
||||
wp_log(WARNING, "unexpected message tag %c from node %s:%s in state %s", (char) tag, sk->host,
|
||||
sk->port, FormatSafekeeperState(sk));
|
||||
ResetConnection(sk);
|
||||
return false;
|
||||
}
|
||||
switch (tag)
|
||||
{
|
||||
case 'g':
|
||||
{
|
||||
AcceptorGreeting *msg = (AcceptorGreeting *) anymsg;
|
||||
|
||||
msg->nodeId = pq_getmsgint64(&s);
|
||||
MembershipConfigurationDeserialize(&msg->mconf, &s);
|
||||
msg->term = pq_getmsgint64(&s);
|
||||
pq_getmsgend(&s);
|
||||
return true;
|
||||
}
|
||||
case 'v':
|
||||
{
|
||||
VoteResponse *msg = (VoteResponse *) anymsg;
|
||||
|
||||
msg->generation = pq_getmsgint32(&s);
|
||||
msg->term = pq_getmsgint64(&s);
|
||||
msg->voteGiven = pq_getmsgbyte(&s);
|
||||
msg->flushLsn = pq_getmsgint64(&s);
|
||||
msg->truncateLsn = pq_getmsgint64(&s);
|
||||
msg->termHistory.n_entries = pq_getmsgint32(&s);
|
||||
msg->termHistory.entries = palloc(sizeof(TermSwitchEntry) * msg->termHistory.n_entries);
|
||||
for (uint32 i = 0; i < msg->termHistory.n_entries; i++)
|
||||
tag = pq_getmsgbyte(&s);
|
||||
if (tag != anymsg->tag)
|
||||
{
|
||||
wp_log(WARNING, "unexpected message tag %c from node %s:%s in state %s", (char) tag, sk->host,
|
||||
sk->port, FormatSafekeeperState(sk));
|
||||
ResetConnection(sk);
|
||||
return false;
|
||||
}
|
||||
switch (tag)
|
||||
{
|
||||
case 'g':
|
||||
{
|
||||
msg->termHistory.entries[i].term = pq_getmsgint64(&s);
|
||||
msg->termHistory.entries[i].lsn = pq_getmsgint64(&s);
|
||||
}
|
||||
pq_getmsgend(&s);
|
||||
return true;
|
||||
}
|
||||
case 'a':
|
||||
{
|
||||
AppendResponse *msg = (AppendResponse *) anymsg;
|
||||
AcceptorGreeting *msg = (AcceptorGreeting *) anymsg;
|
||||
|
||||
msg->generation = pq_getmsgint32(&s);
|
||||
msg->term = pq_getmsgint64(&s);
|
||||
msg->flushLsn = pq_getmsgint64(&s);
|
||||
msg->commitLsn = pq_getmsgint64(&s);
|
||||
msg->hs.ts = pq_getmsgint64(&s);
|
||||
msg->hs.xmin.value = pq_getmsgint64(&s);
|
||||
msg->hs.catalog_xmin.value = pq_getmsgint64(&s);
|
||||
if (s.len > s.cursor)
|
||||
ParsePageserverFeedbackMessage(wp, &s, &msg->ps_feedback);
|
||||
else
|
||||
msg->ps_feedback.present = false;
|
||||
pq_getmsgend(&s);
|
||||
return true;
|
||||
}
|
||||
default:
|
||||
{
|
||||
wp_log(FATAL, "unexpected message tag %c to read", (char) tag);
|
||||
return false;
|
||||
}
|
||||
msg->nodeId = pq_getmsgint64(&s);
|
||||
MembershipConfigurationDeserialize(&msg->mconf, &s);
|
||||
msg->term = pq_getmsgint64(&s);
|
||||
pq_getmsgend(&s);
|
||||
return true;
|
||||
}
|
||||
case 'v':
|
||||
{
|
||||
VoteResponse *msg = (VoteResponse *) anymsg;
|
||||
|
||||
msg->generation = pq_getmsgint32(&s);
|
||||
msg->term = pq_getmsgint64(&s);
|
||||
msg->voteGiven = pq_getmsgbyte(&s);
|
||||
msg->flushLsn = pq_getmsgint64(&s);
|
||||
msg->truncateLsn = pq_getmsgint64(&s);
|
||||
msg->termHistory.n_entries = pq_getmsgint32(&s);
|
||||
msg->termHistory.entries = palloc(sizeof(TermSwitchEntry) * msg->termHistory.n_entries);
|
||||
for (uint32 i = 0; i < msg->termHistory.n_entries; i++)
|
||||
{
|
||||
msg->termHistory.entries[i].term = pq_getmsgint64(&s);
|
||||
msg->termHistory.entries[i].lsn = pq_getmsgint64(&s);
|
||||
}
|
||||
pq_getmsgend(&s);
|
||||
return true;
|
||||
}
|
||||
case 'a':
|
||||
{
|
||||
AppendResponse *msg = (AppendResponse *) anymsg;
|
||||
|
||||
msg->generation = pq_getmsgint32(&s);
|
||||
msg->term = pq_getmsgint64(&s);
|
||||
msg->flushLsn = pq_getmsgint64(&s);
|
||||
msg->commitLsn = pq_getmsgint64(&s);
|
||||
msg->hs.ts = pq_getmsgint64(&s);
|
||||
msg->hs.xmin.value = pq_getmsgint64(&s);
|
||||
msg->hs.catalog_xmin.value = pq_getmsgint64(&s);
|
||||
if (s.len > s.cursor)
|
||||
ParsePageserverFeedbackMessage(wp, &s, &msg->ps_feedback);
|
||||
else
|
||||
msg->ps_feedback.present = false;
|
||||
pq_getmsgend(&s);
|
||||
return true;
|
||||
}
|
||||
default:
|
||||
{
|
||||
wp_log(FATAL, "unexpected message tag %c to read", (char) tag);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (wp->config->proto_version == 2)
|
||||
{
|
||||
tag = pq_getmsgint64_le(&s);
|
||||
if (tag != anymsg->tag)
|
||||
{
|
||||
wp_log(WARNING, "unexpected message tag %c from node %s:%s in state %s", (char) tag, sk->host,
|
||||
sk->port, FormatSafekeeperState(sk));
|
||||
ResetConnection(sk);
|
||||
return false;
|
||||
}
|
||||
switch (tag)
|
||||
{
|
||||
case 'g':
|
||||
{
|
||||
AcceptorGreeting *msg = (AcceptorGreeting *) anymsg;
|
||||
|
||||
msg->term = pq_getmsgint64_le(&s);
|
||||
msg->nodeId = pq_getmsgint64_le(&s);
|
||||
pq_getmsgend(&s);
|
||||
return true;
|
||||
}
|
||||
|
||||
case 'v':
|
||||
{
|
||||
VoteResponse *msg = (VoteResponse *) anymsg;
|
||||
|
||||
msg->term = pq_getmsgint64_le(&s);
|
||||
msg->voteGiven = pq_getmsgint64_le(&s);
|
||||
msg->flushLsn = pq_getmsgint64_le(&s);
|
||||
msg->truncateLsn = pq_getmsgint64_le(&s);
|
||||
msg->termHistory.n_entries = pq_getmsgint32_le(&s);
|
||||
msg->termHistory.entries = palloc(sizeof(TermSwitchEntry) * msg->termHistory.n_entries);
|
||||
for (int i = 0; i < msg->termHistory.n_entries; i++)
|
||||
{
|
||||
msg->termHistory.entries[i].term = pq_getmsgint64_le(&s);
|
||||
msg->termHistory.entries[i].lsn = pq_getmsgint64_le(&s);
|
||||
}
|
||||
pq_getmsgint64_le(&s); /* timelineStartLsn */
|
||||
pq_getmsgend(&s);
|
||||
return true;
|
||||
}
|
||||
|
||||
case 'a':
|
||||
{
|
||||
AppendResponse *msg = (AppendResponse *) anymsg;
|
||||
|
||||
msg->term = pq_getmsgint64_le(&s);
|
||||
msg->flushLsn = pq_getmsgint64_le(&s);
|
||||
msg->commitLsn = pq_getmsgint64_le(&s);
|
||||
msg->hs.ts = pq_getmsgint64_le(&s);
|
||||
msg->hs.xmin.value = pq_getmsgint64_le(&s);
|
||||
msg->hs.catalog_xmin.value = pq_getmsgint64_le(&s);
|
||||
if (s.len > s.cursor)
|
||||
ParsePageserverFeedbackMessage(wp, &s, &msg->ps_feedback);
|
||||
else
|
||||
msg->ps_feedback.present = false;
|
||||
pq_getmsgend(&s);
|
||||
return true;
|
||||
}
|
||||
|
||||
default:
|
||||
{
|
||||
wp_log(FATAL, "unexpected message tag %c to read", (char) tag);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
wp_log(FATAL, "unsupported proto_version %d", wp->config->proto_version);
|
||||
return false; /* keep the compiler quiet */
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
@@ -17,35 +17,27 @@ pub(super) async fn authenticate(
|
||||
config: &'static AuthenticationConfig,
|
||||
secret: AuthSecret,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
let flow = AuthFlow::new(client);
|
||||
let scram_keys = match secret {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
AuthSecret::Md5(_) => {
|
||||
debug!("auth endpoint chooses MD5");
|
||||
return Err(auth::AuthError::bad_auth_method("MD5"));
|
||||
return Err(auth::AuthError::MalformedPassword("MD5 not supported"));
|
||||
}
|
||||
AuthSecret::Scram(secret) => {
|
||||
debug!("auth endpoint chooses SCRAM");
|
||||
let scram = auth::Scram(&secret, ctx);
|
||||
|
||||
let auth_outcome = tokio::time::timeout(
|
||||
config.scram_protocol_timeout,
|
||||
async {
|
||||
|
||||
flow.begin(scram).await.map_err(|error| {
|
||||
warn!(?error, "error sending scram acknowledgement");
|
||||
error
|
||||
})?.authenticate().await.map_err(|error| {
|
||||
let auth_outcome = tokio::time::timeout(config.scram_protocol_timeout, async {
|
||||
AuthFlow::new(client, scram)
|
||||
.authenticate()
|
||||
.await
|
||||
.inspect_err(|error| {
|
||||
warn!(?error, "error processing scram messages");
|
||||
error
|
||||
})
|
||||
}
|
||||
)
|
||||
})
|
||||
.await
|
||||
.map_err(|e| {
|
||||
warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs());
|
||||
auth::AuthError::user_timeout(e)
|
||||
})??;
|
||||
.inspect_err(|_| warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs()))
|
||||
.map_err(auth::AuthError::user_timeout)??;
|
||||
|
||||
let client_key = match auth_outcome {
|
||||
sasl::Outcome::Success(key) => key,
|
||||
|
||||
@@ -2,7 +2,6 @@ use std::fmt;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use postgres_client::config::SslMode;
|
||||
use pq_proto::BeMessage as Be;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, info_span};
|
||||
@@ -16,6 +15,7 @@ use crate::context::RequestContext;
|
||||
use crate::control_plane::client::cplane_proxy_v1;
|
||||
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::pqproto::BeMessage;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::proxy::connect_compute::ComputeConnectBackend;
|
||||
use crate::stream::PqStream;
|
||||
@@ -154,11 +154,13 @@ async fn authenticate(
|
||||
|
||||
// Give user a URL to spawn a new database.
|
||||
info!(parent: &span, "sending the auth URL to the user");
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&Be::CLIENT_ENCODING)?
|
||||
.write_message(&Be::NoticeResponse(&greeting))
|
||||
.await?;
|
||||
client.write_message(BeMessage::AuthenticationOk);
|
||||
client.write_message(BeMessage::ParameterStatus {
|
||||
name: b"client_encoding",
|
||||
value: b"UTF8",
|
||||
});
|
||||
client.write_message(BeMessage::NoticeResponse(&greeting));
|
||||
client.flush().await?;
|
||||
|
||||
// Wait for console response via control plane (see `mgmt`).
|
||||
info!(parent: &span, "waiting for console's reply...");
|
||||
@@ -188,7 +190,7 @@ async fn authenticate(
|
||||
}
|
||||
}
|
||||
|
||||
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
|
||||
client.write_message(BeMessage::NoticeResponse("Connecting to database."));
|
||||
|
||||
// This config should be self-contained, because we won't
|
||||
// take username or dbname from client's startup message.
|
||||
|
||||
@@ -24,23 +24,25 @@ pub(crate) async fn authenticate_cleartext(
|
||||
debug!("cleartext auth flow override is enabled, proceeding");
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
|
||||
// pause the timer while we communicate with the client
|
||||
let paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
let ep = EndpointIdInt::from(&info.endpoint);
|
||||
|
||||
let auth_flow = AuthFlow::new(client)
|
||||
.begin(auth::CleartextPassword {
|
||||
let auth_flow = AuthFlow::new(
|
||||
client,
|
||||
auth::CleartextPassword {
|
||||
secret,
|
||||
endpoint: ep,
|
||||
pool: config.thread_pool.clone(),
|
||||
})
|
||||
.await?;
|
||||
drop(paused);
|
||||
// cleartext auth is only allowed to the ws/http protocol.
|
||||
// If we're here, we already received the password in the first message.
|
||||
// Scram protocol will be executed on the proxy side.
|
||||
let auth_outcome = auth_flow.authenticate().await?;
|
||||
},
|
||||
);
|
||||
let auth_outcome = {
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
// cleartext auth is only allowed to the ws/http protocol.
|
||||
// If we're here, we already received the password in the first message.
|
||||
// Scram protocol will be executed on the proxy side.
|
||||
auth_flow.authenticate().await?
|
||||
};
|
||||
|
||||
let keys = match auth_outcome {
|
||||
sasl::Outcome::Success(key) => key,
|
||||
@@ -67,9 +69,7 @@ pub(crate) async fn password_hack_no_authentication(
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
let payload = AuthFlow::new(client)
|
||||
.begin(auth::PasswordHack)
|
||||
.await?
|
||||
let payload = AuthFlow::new(client, auth::PasswordHack)
|
||||
.get_password()
|
||||
.await?;
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ use crate::control_plane::{
|
||||
};
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::pqproto::BeMessage;
|
||||
use crate::protocol2::ConnectionInfoExtra;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::proxy::connect_compute::ComputeConnectBackend;
|
||||
@@ -402,7 +403,7 @@ async fn authenticate_with_secret(
|
||||
};
|
||||
|
||||
// we have authenticated the password
|
||||
client.write_message_noflush(&pq_proto::BeMessage::AuthenticationOk)?;
|
||||
client.write_message(BeMessage::AuthenticationOk);
|
||||
|
||||
return Ok(ComputeCredentials { info, keys });
|
||||
}
|
||||
@@ -702,7 +703,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_scram() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
@@ -784,7 +785,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_cleartext() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
@@ -838,7 +839,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_password_hack() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
let mut stream = PqStream::new_skip_handshake(Stream::from_raw(server));
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
|
||||
@@ -5,7 +5,6 @@ use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
|
||||
use itertools::Itertools;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use thiserror::Error;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
@@ -13,6 +12,7 @@ use crate::auth::password_hack::parse_endpoint_param;
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, SniGroup, SniKind};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::serverless::{AUTH_BROKER_SNI, SERVERLESS_DRIVER_SNI};
|
||||
use crate::types::{EndpointId, RoleName};
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
//! Main authentication flow.
|
||||
|
||||
use std::io;
|
||||
use std::sync::Arc;
|
||||
|
||||
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
|
||||
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
|
||||
@@ -13,35 +11,26 @@ use super::{AuthError, PasswordHackPayload};
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::AuthSecret;
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
|
||||
use crate::sasl;
|
||||
use crate::scram::threadpool::ThreadPool;
|
||||
use crate::scram::{self};
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::tls::TlsServerEndPoint;
|
||||
|
||||
/// Every authentication selector is supposed to implement this trait.
|
||||
pub(crate) trait AuthMethod {
|
||||
/// Any authentication selector should provide initial backend message
|
||||
/// containing auth method name and parameters, e.g. md5 salt.
|
||||
fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
|
||||
}
|
||||
|
||||
/// Initial state of [`AuthFlow`].
|
||||
pub(crate) struct Begin;
|
||||
|
||||
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
|
||||
pub(crate) struct Scram<'a>(
|
||||
pub(crate) &'a scram::ServerSecret,
|
||||
pub(crate) &'a RequestContext,
|
||||
);
|
||||
|
||||
impl AuthMethod for Scram<'_> {
|
||||
impl Scram<'_> {
|
||||
#[inline(always)]
|
||||
fn first_message(&self, channel_binding: bool) -> BeMessage<'_> {
|
||||
if channel_binding {
|
||||
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
|
||||
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(scram::METHODS))
|
||||
} else {
|
||||
Be::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
|
||||
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(
|
||||
scram::METHODS_WITHOUT_PLUS,
|
||||
))
|
||||
}
|
||||
@@ -52,13 +41,6 @@ impl AuthMethod for Scram<'_> {
|
||||
/// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
|
||||
pub(crate) struct PasswordHack;
|
||||
|
||||
impl AuthMethod for PasswordHack {
|
||||
#[inline(always)]
|
||||
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
|
||||
Be::AuthenticationCleartextPassword
|
||||
}
|
||||
}
|
||||
|
||||
/// Use clear-text password auth called `password` in docs
|
||||
/// <https://www.postgresql.org/docs/current/auth-password.html>
|
||||
pub(crate) struct CleartextPassword {
|
||||
@@ -67,53 +49,37 @@ pub(crate) struct CleartextPassword {
|
||||
pub(crate) secret: AuthSecret,
|
||||
}
|
||||
|
||||
impl AuthMethod for CleartextPassword {
|
||||
#[inline(always)]
|
||||
fn first_message(&self, _channel_binding: bool) -> BeMessage<'_> {
|
||||
Be::AuthenticationCleartextPassword
|
||||
}
|
||||
}
|
||||
|
||||
/// This wrapper for [`PqStream`] performs client authentication.
|
||||
#[must_use]
|
||||
pub(crate) struct AuthFlow<'a, S, State> {
|
||||
/// The underlying stream which implements libpq's protocol.
|
||||
stream: &'a mut PqStream<Stream<S>>,
|
||||
/// State might contain ancillary data (see [`Self::begin`]).
|
||||
/// State might contain ancillary data.
|
||||
state: State,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
}
|
||||
|
||||
/// Initial state of the stream wrapper.
|
||||
impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
|
||||
impl<'a, S: AsyncRead + AsyncWrite + Unpin, M> AuthFlow<'a, S, M> {
|
||||
/// Create a new wrapper for client authentication.
|
||||
pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
|
||||
pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>, method: M) -> Self {
|
||||
let tls_server_end_point = stream.get_ref().tls_server_end_point();
|
||||
|
||||
Self {
|
||||
stream,
|
||||
state: Begin,
|
||||
state: method,
|
||||
tls_server_end_point,
|
||||
}
|
||||
}
|
||||
|
||||
/// Move to the next step by sending auth method's name & params to client.
|
||||
pub(crate) async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
|
||||
self.stream
|
||||
.write_message(&method.first_message(self.tls_server_end_point.supported()))
|
||||
.await?;
|
||||
|
||||
Ok(AuthFlow {
|
||||
stream: self.stream,
|
||||
state: method,
|
||||
tls_server_end_point: self.tls_server_end_point,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub(crate) async fn get_password(self) -> super::Result<PasswordHackPayload> {
|
||||
self.stream
|
||||
.write_message(BeMessage::AuthenticationCleartextPassword);
|
||||
self.stream.flush().await?;
|
||||
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let password = msg
|
||||
.strip_suffix(&[0])
|
||||
@@ -133,6 +99,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
|
||||
self.stream
|
||||
.write_message(BeMessage::AuthenticationCleartextPassword);
|
||||
self.stream.flush().await?;
|
||||
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let password = msg
|
||||
.strip_suffix(&[0])
|
||||
@@ -147,7 +117,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
|
||||
.await?;
|
||||
|
||||
if let sasl::Outcome::Success(_) = &outcome {
|
||||
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
|
||||
self.stream.write_message(BeMessage::AuthenticationOk);
|
||||
}
|
||||
|
||||
Ok(outcome)
|
||||
@@ -159,42 +129,36 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
|
||||
let Scram(secret, ctx) = self.state;
|
||||
let channel_binding = self.tls_server_end_point;
|
||||
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
// send sasl message.
|
||||
{
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
// Initial client message contains the chosen auth method's name.
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let sasl = sasl::FirstMessage::parse(&msg)
|
||||
.ok_or(AuthError::MalformedPassword("bad sasl message"))?;
|
||||
|
||||
// Currently, the only supported SASL method is SCRAM.
|
||||
if !scram::METHODS.contains(&sasl.method) {
|
||||
return Err(super::AuthError::bad_auth_method(sasl.method));
|
||||
let sasl = self.state.first_message(channel_binding.supported());
|
||||
self.stream.write_message(sasl);
|
||||
self.stream.flush().await?;
|
||||
}
|
||||
|
||||
match sasl.method {
|
||||
SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
|
||||
SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus),
|
||||
_ => {}
|
||||
}
|
||||
// complete sasl handshake.
|
||||
sasl::authenticate(ctx, self.stream, |method| {
|
||||
// Currently, the only supported SASL method is SCRAM.
|
||||
match method {
|
||||
SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
|
||||
SCRAM_SHA_256_PLUS => {
|
||||
ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus);
|
||||
}
|
||||
method => return Err(sasl::Error::BadAuthMethod(method.into())),
|
||||
}
|
||||
|
||||
// TODO: make this a metric instead
|
||||
info!("client chooses {}", sasl.method);
|
||||
// TODO: make this a metric instead
|
||||
info!("client chooses {}", method);
|
||||
|
||||
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
|
||||
.authenticate(scram::Exchange::new(
|
||||
secret,
|
||||
rand::random,
|
||||
self.tls_server_end_point,
|
||||
))
|
||||
.await?;
|
||||
|
||||
if let sasl::Outcome::Success(_) = &outcome {
|
||||
self.stream.write_message_noflush(&Be::AuthenticationOk)?;
|
||||
}
|
||||
|
||||
Ok(outcome)
|
||||
Ok(scram::Exchange::new(secret, rand::random, channel_binding))
|
||||
})
|
||||
.await
|
||||
.map_err(AuthError::Sasl)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,8 +4,9 @@
|
||||
//! This allows connecting to pods/services running in the same Kubernetes cluster from
|
||||
//! the outside. Similar to an ingress controller for HTTPS.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::path::Path;
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context, anyhow, bail, ensure};
|
||||
use clap::Arg;
|
||||
@@ -17,6 +18,7 @@ use rustls::pki_types::{DnsName, PrivateKeyDer};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_rustls::TlsConnector;
|
||||
use tokio_rustls::server::TlsStream;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, error, info};
|
||||
use utils::project_git_version;
|
||||
@@ -24,10 +26,12 @@ use utils::sentry_init::init_sentry;
|
||||
|
||||
use crate::context::RequestContext;
|
||||
use crate::metrics::{Metrics, ThreadPoolMetrics};
|
||||
use crate::pqproto::FeStartupPacket;
|
||||
use crate::protocol2::ConnectionInfo;
|
||||
use crate::proxy::{ErrorSource, copy_bidirectional_client_compute, run_until_cancelled};
|
||||
use crate::proxy::{
|
||||
ErrorSource, TlsRequired, copy_bidirectional_client_compute, run_until_cancelled,
|
||||
};
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::tls::TlsServerEndPoint;
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
|
||||
@@ -84,7 +88,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
.parse()?;
|
||||
|
||||
// Configure TLS
|
||||
let (tls_config, tls_server_end_point): (Arc<rustls::ServerConfig>, TlsServerEndPoint) = match (
|
||||
let tls_config = match (
|
||||
args.get_one::<String>("tls-key"),
|
||||
args.get_one::<String>("tls-cert"),
|
||||
) {
|
||||
@@ -117,7 +121,6 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
dest.clone(),
|
||||
tls_config.clone(),
|
||||
None,
|
||||
tls_server_end_point,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
))
|
||||
@@ -127,7 +130,6 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
dest,
|
||||
tls_config,
|
||||
Some(compute_tls_config),
|
||||
tls_server_end_point,
|
||||
proxy_listener_compute_tls,
|
||||
cancellation_token.clone(),
|
||||
))
|
||||
@@ -154,7 +156,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
pub(super) fn parse_tls(
|
||||
key_path: &Path,
|
||||
cert_path: &Path,
|
||||
) -> anyhow::Result<(Arc<rustls::ServerConfig>, TlsServerEndPoint)> {
|
||||
) -> anyhow::Result<Arc<rustls::ServerConfig>> {
|
||||
let key = {
|
||||
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
|
||||
|
||||
@@ -187,10 +189,6 @@ pub(super) fn parse_tls(
|
||||
})?
|
||||
};
|
||||
|
||||
// needed for channel bindings
|
||||
let first_cert = cert_chain.first().context("missing certificate")?;
|
||||
let tls_server_end_point = TlsServerEndPoint::new(first_cert)?;
|
||||
|
||||
let tls_config =
|
||||
rustls::ServerConfig::builder_with_provider(Arc::new(ring::default_provider()))
|
||||
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
|
||||
@@ -199,14 +197,13 @@ pub(super) fn parse_tls(
|
||||
.with_single_cert(cert_chain, key)?
|
||||
.into();
|
||||
|
||||
Ok((tls_config, tls_server_end_point))
|
||||
Ok(tls_config)
|
||||
}
|
||||
|
||||
pub(super) async fn task_main(
|
||||
dest_suffix: Arc<String>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -242,15 +239,7 @@ pub(super) async fn task_main(
|
||||
crate::metrics::Protocol::SniRouter,
|
||||
"sni",
|
||||
);
|
||||
handle_client(
|
||||
ctx,
|
||||
dest_suffix,
|
||||
tls_config,
|
||||
compute_tls_config,
|
||||
tls_server_end_point,
|
||||
socket,
|
||||
)
|
||||
.await
|
||||
handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
|
||||
}
|
||||
.unwrap_or_else(|e| {
|
||||
// Acknowledge that the task has finished with an error.
|
||||
@@ -269,55 +258,26 @@ pub(super) async fn task_main(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||
|
||||
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
ctx: &RequestContext,
|
||||
raw_stream: S,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
) -> anyhow::Result<Stream<S>> {
|
||||
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
|
||||
|
||||
let msg = stream.read_startup_packet().await?;
|
||||
use pq_proto::FeStartupPacket::SslRequest;
|
||||
|
||||
) -> anyhow::Result<TlsStream<S>> {
|
||||
let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream)).await?;
|
||||
match msg {
|
||||
SslRequest { direct: false } => {
|
||||
stream
|
||||
.write_message(&pq_proto::BeMessage::EncryptionResponse(true))
|
||||
.await?;
|
||||
FeStartupPacket::SslRequest { direct: None } => {
|
||||
let raw = stream.accept_tls().await?;
|
||||
|
||||
// Upgrade raw stream into a secure TLS-backed stream.
|
||||
// NOTE: We've consumed `tls`; this fact will be used later.
|
||||
|
||||
let (raw, read_buf) = stream.into_inner();
|
||||
// TODO: Normally, client doesn't send any data before
|
||||
// server says TLS handshake is ok and read_buf is empty.
|
||||
// However, you could imagine pipelining of postgres
|
||||
// SSLRequest + TLS ClientHello in one hunk similar to
|
||||
// pipelining in our node js driver. We should probably
|
||||
// support that by chaining read_buf with the stream.
|
||||
if !read_buf.is_empty() {
|
||||
bail!("data is sent before server replied with EncryptionResponse");
|
||||
}
|
||||
|
||||
Ok(Stream::Tls {
|
||||
tls: Box::new(
|
||||
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
|
||||
.await?,
|
||||
),
|
||||
tls_server_end_point,
|
||||
})
|
||||
Ok(raw
|
||||
.upgrade(tls_config, !ctx.has_private_peer_addr())
|
||||
.await?)
|
||||
}
|
||||
unexpected => {
|
||||
info!(
|
||||
?unexpected,
|
||||
"unexpected startup packet, rejecting connection"
|
||||
);
|
||||
stream
|
||||
.throw_error_str(ERR_INSECURE_CONNECTION, crate::error::ErrorKind::User, None)
|
||||
.await?
|
||||
Err(stream.throw_error(TlsRequired, None).await)?
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -327,15 +287,18 @@ async fn handle_client(
|
||||
dest_suffix: Arc<String>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
|
||||
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config).await?;
|
||||
|
||||
// Cut off first part of the SNI domain
|
||||
// We receive required destination details in the format of
|
||||
// `{k8s_service_name}--{k8s_namespace}--{port}.non-sni-domain`
|
||||
let sni = tls_stream.sni_hostname().ok_or(anyhow!("SNI missing"))?;
|
||||
let sni = tls_stream
|
||||
.get_ref()
|
||||
.1
|
||||
.server_name()
|
||||
.ok_or(anyhow!("SNI missing"))?;
|
||||
let dest: Vec<&str> = sni
|
||||
.split_once('.')
|
||||
.context("invalid SNI")?
|
||||
|
||||
@@ -476,8 +476,7 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
let key_path = args.tls_key.expect("already asserted it is set");
|
||||
let cert_path = args.tls_cert.expect("already asserted it is set");
|
||||
|
||||
let (tls_config, tls_server_end_point) =
|
||||
super::pg_sni_router::parse_tls(&key_path, &cert_path)?;
|
||||
let tls_config = super::pg_sni_router::parse_tls(&key_path, &cert_path)?;
|
||||
|
||||
let dest = Arc::new(dest);
|
||||
|
||||
@@ -485,7 +484,6 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
dest.clone(),
|
||||
tls_config.clone(),
|
||||
None,
|
||||
tls_server_end_point,
|
||||
listen,
|
||||
cancellation_token.clone(),
|
||||
));
|
||||
@@ -494,7 +492,6 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
dest,
|
||||
tls_config,
|
||||
Some(config.connect_to_compute.tls.clone()),
|
||||
tls_server_end_point,
|
||||
listen_tls,
|
||||
cancellation_token.clone(),
|
||||
));
|
||||
|
||||
@@ -5,7 +5,6 @@ use anyhow::{Context, anyhow};
|
||||
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
|
||||
use postgres_client::CancelToken;
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use pq_proto::CancelKeyData;
|
||||
use redis::{Cmd, FromRedisValue, Value};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
@@ -21,6 +20,7 @@ use crate::control_plane::ControlPlaneApi;
|
||||
use crate::error::ReportableError;
|
||||
use crate::ext::LockExt;
|
||||
use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind};
|
||||
use crate::pqproto::CancelKeyData;
|
||||
use crate::protocol2::ConnectionInfoExtra;
|
||||
use crate::rate_limiter::LeakyBucketRateLimiter;
|
||||
use crate::redis::keys::KeyPrefix;
|
||||
|
||||
@@ -8,7 +8,6 @@ use itertools::Itertools;
|
||||
use postgres_client::tls::MakeTlsConnect;
|
||||
use postgres_client::{CancelToken, RawConnection};
|
||||
use postgres_protocol::message::backend::NoticeResponseBody;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use rustls::pki_types::InvalidDnsNameError;
|
||||
use thiserror::Error;
|
||||
use tokio::net::{TcpStream, lookup_host};
|
||||
@@ -24,6 +23,7 @@ use crate::control_plane::errors::WakeComputeError;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, NumDbConnectionsGuard};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::neon_option;
|
||||
use crate::tls::postgres_rustls::MakeRustlsConnect;
|
||||
use crate::types::Host;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, debug, error, info};
|
||||
|
||||
@@ -221,12 +221,10 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
.await
|
||||
{
|
||||
Ok(auth_result) => auth_result,
|
||||
Err(e) => {
|
||||
return stream.throw_error(e, Some(ctx)).await?;
|
||||
}
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
};
|
||||
|
||||
let mut node = connect_to_compute(
|
||||
let node = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
user_info,
|
||||
@@ -238,7 +236,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
config.wake_compute_retry_config,
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
.or_else(|e| stream.throw_error(e, Some(ctx)))
|
||||
.or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) })
|
||||
.await?;
|
||||
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
@@ -246,14 +244,8 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
|
||||
session.write_cancel_key(node.cancel_closure.clone())?;
|
||||
|
||||
prepare_client_connection(&node, *session.key(), &mut stream).await?;
|
||||
|
||||
// Before proxy passing, forward to compute whatever data is left in the
|
||||
// PqStream input buffer. Normally there is none, but our serverless npm
|
||||
// driver in pipeline mode sends startup, password and first query
|
||||
// immediately after opening the connection.
|
||||
let (stream, read_buf) = stream.into_inner();
|
||||
node.stream.write_all(&read_buf).await?;
|
||||
prepare_client_connection(&node, *session.key(), &mut stream);
|
||||
let stream = stream.flush_and_into_inner().await?;
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
|
||||
@@ -4,7 +4,6 @@ use std::net::IpAddr;
|
||||
|
||||
use chrono::Utc;
|
||||
use once_cell::sync::OnceCell;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use smol_str::SmolStr;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::field::display;
|
||||
@@ -20,6 +19,7 @@ use crate::metrics::{
|
||||
ConnectOutcome, InvalidEndpointsGroup, LatencyAccumulated, LatencyTimer, Metrics, Protocol,
|
||||
Waiting,
|
||||
};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::protocol2::{ConnectionInfo, ConnectionInfoExtra};
|
||||
use crate::types::{DbName, EndpointId, RoleName};
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ use parquet::file::metadata::RowGroupMetaDataPtr;
|
||||
use parquet::file::properties::{DEFAULT_PAGE_SIZE, WriterProperties, WriterPropertiesPtr};
|
||||
use parquet::file::writer::SerializedFileWriter;
|
||||
use parquet::record::RecordWriter;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use remote_storage::{GenericRemoteStorage, RemotePath, RemoteStorageConfig, TimeoutOrCancel};
|
||||
use serde::ser::SerializeMap;
|
||||
use tokio::sync::mpsc;
|
||||
@@ -24,6 +23,7 @@ use super::{LOG_CHAN, RequestContextInner};
|
||||
use crate::config::remote_storage_from_toml;
|
||||
use crate::context::LOG_CHAN_DISCONNECT;
|
||||
use crate::ext::TaskExt;
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
|
||||
#[derive(clap::Args, Clone, Debug)]
|
||||
pub struct ParquetUploadArgs {
|
||||
|
||||
@@ -92,6 +92,7 @@ mod logging;
|
||||
mod metrics;
|
||||
mod parse;
|
||||
mod pglb;
|
||||
mod pqproto;
|
||||
mod protocol2;
|
||||
mod proxy;
|
||||
mod rate_limiter;
|
||||
|
||||
693
proxy/src/pqproto.rs
Normal file
693
proxy/src/pqproto.rs
Normal file
@@ -0,0 +1,693 @@
|
||||
//! Postgres protocol codec
|
||||
//!
|
||||
//! <https://www.postgresql.org/docs/current/protocol-message-formats.html>
|
||||
|
||||
use std::fmt;
|
||||
use std::io::{self, Cursor};
|
||||
|
||||
use bytes::{Buf, BufMut};
|
||||
use itertools::Itertools;
|
||||
use rand::distributions::{Distribution, Standard};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian};
|
||||
|
||||
pub type ErrorCode = [u8; 5];
|
||||
|
||||
pub const FE_PASSWORD_MESSAGE: u8 = b'p';
|
||||
|
||||
pub const SQLSTATE_INTERNAL_ERROR: [u8; 5] = *b"XX000";
|
||||
|
||||
/// The protocol version number.
|
||||
///
|
||||
/// The most significant 16 bits are the major version number (3 for the protocol described here).
|
||||
/// The least significant 16 bits are the minor version number (0 for the protocol described here).
|
||||
/// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-STARTUPMESSAGE>
|
||||
#[derive(Clone, Copy, PartialEq, PartialOrd, FromBytes, IntoBytes, Immutable)]
|
||||
#[repr(C)]
|
||||
pub struct ProtocolVersion {
|
||||
major: big_endian::U16,
|
||||
minor: big_endian::U16,
|
||||
}
|
||||
|
||||
impl ProtocolVersion {
|
||||
pub const fn new(major: u16, minor: u16) -> Self {
|
||||
Self {
|
||||
major: big_endian::U16::new(major),
|
||||
minor: big_endian::U16::new(minor),
|
||||
}
|
||||
}
|
||||
pub const fn minor(self) -> u16 {
|
||||
self.minor.get()
|
||||
}
|
||||
pub const fn major(self) -> u16 {
|
||||
self.major.get()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for ProtocolVersion {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_list()
|
||||
.entry(&self.major())
|
||||
.entry(&self.minor())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
/// read the type from the stream using zerocopy.
|
||||
///
|
||||
/// not cancel safe.
|
||||
macro_rules! read {
|
||||
($s:expr => $t:ty) => {{
|
||||
// cannot be implemented as a function due to lack of const-generic-expr
|
||||
let mut buf = [0; size_of::<$t>()];
|
||||
$s.read_exact(&mut buf).await?;
|
||||
let res: $t = zerocopy::transmute!(buf);
|
||||
res
|
||||
}};
|
||||
}
|
||||
|
||||
pub async fn read_startup<S>(stream: &mut S) -> io::Result<FeStartupPacket>
|
||||
where
|
||||
S: AsyncRead + Unpin,
|
||||
{
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L118>
|
||||
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
|
||||
const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234;
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L132>
|
||||
const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678);
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L166>
|
||||
const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679);
|
||||
/// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/include/libpq/pqcomm.h#L167>
|
||||
const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680);
|
||||
|
||||
/// This first reads the startup message header, is 8 bytes.
|
||||
/// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number.
|
||||
///
|
||||
/// The length value is inclusive of the header. For example,
|
||||
/// an empty message will always have length 8.
|
||||
#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)]
|
||||
#[repr(C)]
|
||||
struct StartupHeader {
|
||||
len: big_endian::U32,
|
||||
version: ProtocolVersion,
|
||||
}
|
||||
|
||||
let header = read!(stream => StartupHeader);
|
||||
|
||||
// <https://github.com/postgres/postgres/blob/04bcf9e19a4261fe9c7df37c777592c2e10c32a7/src/backend/tcop/backend_startup.c#L378-L382>
|
||||
// First byte indicates standard SSL handshake message
|
||||
// (It can't be a Postgres startup length because in network byte order
|
||||
// that would be a startup packet hundreds of megabytes long)
|
||||
if header.as_bytes()[0] == 0x16 {
|
||||
return Ok(FeStartupPacket::SslRequest {
|
||||
// The bytes we read for the header are actually part of a TLS ClientHello.
|
||||
// In theory, if the ClientHello was < 8 bytes we would fail with EOF before we get here.
|
||||
// In practice though, I see no world where a ClientHello is less than 8 bytes
|
||||
// since it includes ephemeral keys etc.
|
||||
direct: Some(zerocopy::transmute!(header)),
|
||||
});
|
||||
}
|
||||
|
||||
let Some(len) = (header.len.get() as usize).checked_sub(8) else {
|
||||
return Err(io::Error::other(format!(
|
||||
"invalid startup message length {}, must be at least 8.",
|
||||
header.len,
|
||||
)));
|
||||
};
|
||||
|
||||
// TODO: add a histogram for startup packet lengths
|
||||
if len > MAX_STARTUP_PACKET_LENGTH {
|
||||
tracing::warn!("large startup message detected: {len} bytes");
|
||||
return Err(io::Error::other(format!(
|
||||
"invalid startup message length {len}"
|
||||
)));
|
||||
}
|
||||
|
||||
match header.version {
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-CANCELREQUEST>
|
||||
CANCEL_REQUEST_CODE => {
|
||||
if len != 8 {
|
||||
return Err(io::Error::other(
|
||||
"CancelRequest message is malformed, backend PID / secret key missing",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(FeStartupPacket::CancelRequest(
|
||||
read!(stream => CancelKeyData),
|
||||
))
|
||||
}
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-SSLREQUEST>
|
||||
NEGOTIATE_SSL_CODE => {
|
||||
// Requested upgrade to SSL (aka TLS)
|
||||
Ok(FeStartupPacket::SslRequest { direct: None })
|
||||
}
|
||||
NEGOTIATE_GSS_CODE => {
|
||||
// Requested upgrade to GSSAPI
|
||||
Ok(FeStartupPacket::GssEncRequest)
|
||||
}
|
||||
version if version.major() == RESERVED_INVALID_MAJOR_VERSION => Err(io::Error::other(
|
||||
format!("Unrecognized request code {version:?}"),
|
||||
)),
|
||||
// StartupMessage
|
||||
version => {
|
||||
// The protocol version number is followed by one or more pairs of parameter name and value strings.
|
||||
// A zero byte is required as a terminator after the last name/value pair.
|
||||
// Parameters can appear in any order. user is required, others are optional.
|
||||
|
||||
let mut buf = vec![0; len];
|
||||
stream.read_exact(&mut buf).await?;
|
||||
|
||||
if buf.pop() != Some(b'\0') {
|
||||
return Err(io::Error::other(
|
||||
"StartupMessage params: missing null terminator",
|
||||
));
|
||||
}
|
||||
|
||||
// TODO: Don't do this.
|
||||
// There's no guarantee that these messages are utf8,
|
||||
// but they usually happen to be simple ascii.
|
||||
let params = String::from_utf8(buf)
|
||||
.map_err(|_| io::Error::other("StartupMessage params: invalid utf-8"))?;
|
||||
|
||||
Ok(FeStartupPacket::StartupMessage {
|
||||
version,
|
||||
params: StartupMessageParams { params },
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read a raw postgres packet, which will respect the max length requested.
|
||||
///
|
||||
/// This returns the message tag, as well as the message body. The message
|
||||
/// body is written into `buf`, and it is otherwise completely overwritten.
|
||||
///
|
||||
/// This is not cancel safe.
|
||||
pub async fn read_message<'a, S>(
|
||||
stream: &mut S,
|
||||
buf: &'a mut Vec<u8>,
|
||||
max: usize,
|
||||
) -> io::Result<(u8, &'a mut [u8])>
|
||||
where
|
||||
S: AsyncRead + Unpin,
|
||||
{
|
||||
/// This first reads the header, which for regular messages in the 3.0 protocol is 5 bytes.
|
||||
/// The first byte is a message tag, and the next 4 bytes is a big-endian length.
|
||||
///
|
||||
/// Awkwardly, the length value is inclusive of itself, but not of the tag. For example,
|
||||
/// an empty message will always have length 4.
|
||||
#[derive(Clone, Copy, FromBytes)]
|
||||
#[repr(C)]
|
||||
struct Header {
|
||||
tag: u8,
|
||||
len: big_endian::U32,
|
||||
}
|
||||
|
||||
let header = read!(stream => Header);
|
||||
|
||||
// as described above, the length must be at least 4.
|
||||
let Some(len) = (header.len.get() as usize).checked_sub(4) else {
|
||||
return Err(io::Error::other(format!(
|
||||
"invalid startup message length {}, must be at least 4.",
|
||||
header.len,
|
||||
)));
|
||||
};
|
||||
|
||||
// TODO: add a histogram for message lengths
|
||||
|
||||
// check if the message exceeds our desired max.
|
||||
if len > max {
|
||||
tracing::warn!("large postgres message detected: {len} bytes");
|
||||
return Err(io::Error::other(format!("invalid message length {len}")));
|
||||
}
|
||||
|
||||
// read in our entire message.
|
||||
buf.resize(len, 0);
|
||||
stream.read_exact(buf).await?;
|
||||
|
||||
Ok((header.tag, buf))
|
||||
}
|
||||
|
||||
pub struct WriteBuf(Cursor<Vec<u8>>);
|
||||
|
||||
impl Buf for WriteBuf {
|
||||
#[inline]
|
||||
fn remaining(&self) -> usize {
|
||||
self.0.remaining()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn chunk(&self) -> &[u8] {
|
||||
self.0.chunk()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn advance(&mut self, cnt: usize) {
|
||||
self.0.advance(cnt);
|
||||
}
|
||||
}
|
||||
|
||||
impl WriteBuf {
|
||||
pub const fn new() -> Self {
|
||||
Self(Cursor::new(Vec::new()))
|
||||
}
|
||||
|
||||
/// Use a heuristic to determine if we should shrink the write buffer.
|
||||
#[inline]
|
||||
fn should_shrink(&self) -> bool {
|
||||
let n = self.0.position() as usize;
|
||||
let len = self.0.get_ref().len();
|
||||
|
||||
// the unused space at the front of our buffer is 2x the size of our filled portion.
|
||||
n + n > len
|
||||
}
|
||||
|
||||
/// Shrink the write buffer so that subsequent writes have more spare capacity.
|
||||
#[cold]
|
||||
fn shrink(&mut self) {
|
||||
let n = self.0.position() as usize;
|
||||
let buf = self.0.get_mut();
|
||||
|
||||
// buf repr:
|
||||
// [----unused------|-----filled-----|-----uninit-----]
|
||||
// ^ n ^ buf.len() ^ buf.capacity()
|
||||
let filled = n..buf.len();
|
||||
let filled_len = filled.len();
|
||||
buf.copy_within(filled, 0);
|
||||
buf.truncate(filled_len);
|
||||
self.0.set_position(0);
|
||||
}
|
||||
|
||||
/// clear the write buffer.
|
||||
pub fn reset(&mut self) {
|
||||
let buf = self.0.get_mut();
|
||||
buf.clear();
|
||||
self.0.set_position(0);
|
||||
}
|
||||
|
||||
/// Write a raw message to the internal buffer.
|
||||
///
|
||||
/// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since
|
||||
/// we calculate the length after the fact.
|
||||
pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
|
||||
if self.should_shrink() {
|
||||
self.shrink();
|
||||
}
|
||||
|
||||
let buf = self.0.get_mut();
|
||||
buf.reserve(5 + size_hint);
|
||||
|
||||
buf.push(tag);
|
||||
let start = buf.len();
|
||||
buf.extend_from_slice(&[0, 0, 0, 0]);
|
||||
|
||||
f(buf);
|
||||
|
||||
let end = buf.len();
|
||||
let len = (end - start) as u32;
|
||||
buf[start..start + 4].copy_from_slice(&len.to_be_bytes());
|
||||
}
|
||||
|
||||
/// Write an encryption response message.
|
||||
pub fn encryption(&mut self, m: u8) {
|
||||
self.0.get_mut().push(m);
|
||||
}
|
||||
|
||||
pub fn write_error(&mut self, msg: &str, error_code: ErrorCode) {
|
||||
self.shrink();
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-ERRORRESPONSE>
|
||||
// <https://www.postgresql.org/docs/current/protocol-error-fields.html>
|
||||
// "SERROR\0CXXXXX\0M\0\0".len() == 17
|
||||
self.write_raw(17 + msg.len(), b'E', |buf| {
|
||||
// Severity: ERROR
|
||||
buf.put_slice(b"SERROR\0");
|
||||
|
||||
// Code: error_code
|
||||
buf.put_u8(b'C');
|
||||
buf.put_slice(&error_code);
|
||||
buf.put_u8(0);
|
||||
|
||||
// Message: msg
|
||||
buf.put_u8(b'M');
|
||||
buf.put_slice(msg.as_bytes());
|
||||
buf.put_u8(0);
|
||||
|
||||
// End.
|
||||
buf.put_u8(0);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum FeStartupPacket {
|
||||
CancelRequest(CancelKeyData),
|
||||
SslRequest {
|
||||
direct: Option<[u8; 8]>,
|
||||
},
|
||||
GssEncRequest,
|
||||
StartupMessage {
|
||||
version: ProtocolVersion,
|
||||
params: StartupMessageParams,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct StartupMessageParams {
|
||||
pub params: String,
|
||||
}
|
||||
|
||||
impl StartupMessageParams {
|
||||
/// Get parameter's value by its name.
|
||||
pub fn get(&self, name: &str) -> Option<&str> {
|
||||
self.iter().find_map(|(k, v)| (k == name).then_some(v))
|
||||
}
|
||||
|
||||
/// Split command-line options according to PostgreSQL's logic,
|
||||
/// taking into account all escape sequences but leaving them as-is.
|
||||
/// [`None`] means that there's no `options` in [`Self`].
|
||||
pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
|
||||
self.get("options").map(Self::parse_options_raw)
|
||||
}
|
||||
|
||||
/// Split command-line options according to PostgreSQL's logic,
|
||||
/// taking into account all escape sequences but leaving them as-is.
|
||||
pub fn parse_options_raw(input: &str) -> impl Iterator<Item = &str> {
|
||||
// See `postgres: pg_split_opts`.
|
||||
let mut last_was_escape = false;
|
||||
input
|
||||
.split(move |c: char| {
|
||||
// We split by non-escaped whitespace symbols.
|
||||
let should_split = c.is_ascii_whitespace() && !last_was_escape;
|
||||
last_was_escape = c == '\\' && !last_was_escape;
|
||||
should_split
|
||||
})
|
||||
.filter(|s| !s.is_empty())
|
||||
}
|
||||
|
||||
/// Iterate through key-value pairs in an arbitrary order.
|
||||
pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
|
||||
self.params.split_terminator('\0').tuples()
|
||||
}
|
||||
|
||||
// This function is mostly useful in tests.
|
||||
#[cfg(test)]
|
||||
pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
|
||||
let mut b = Self {
|
||||
params: String::new(),
|
||||
};
|
||||
for (k, v) in pairs {
|
||||
b.insert(k, v);
|
||||
}
|
||||
b
|
||||
}
|
||||
|
||||
/// Set parameter's value by its name.
|
||||
/// name and value must not contain a \0 byte
|
||||
pub fn insert(&mut self, name: &str, value: &str) {
|
||||
self.params.reserve(name.len() + value.len() + 2);
|
||||
self.params.push_str(name);
|
||||
self.params.push('\0');
|
||||
self.params.push_str(value);
|
||||
self.params.push('\0');
|
||||
}
|
||||
}
|
||||
|
||||
/// Cancel keys usually are represented as PID+SecretKey, but to proxy they're just
|
||||
/// opaque bytes.
|
||||
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, FromBytes, IntoBytes, Immutable)]
|
||||
pub struct CancelKeyData(pub big_endian::U64);
|
||||
|
||||
pub fn id_to_cancel_key(id: u64) -> CancelKeyData {
|
||||
CancelKeyData(big_endian::U64::new(id))
|
||||
}
|
||||
|
||||
impl fmt::Display for CancelKeyData {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let id = self.0;
|
||||
f.debug_tuple("CancelKeyData")
|
||||
.field(&format_args!("{id:x}"))
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
impl Distribution<CancelKeyData> for Standard {
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
|
||||
id_to_cancel_key(rng.r#gen())
|
||||
}
|
||||
}
|
||||
|
||||
pub enum BeMessage<'a> {
|
||||
AuthenticationOk,
|
||||
AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
|
||||
AuthenticationCleartextPassword,
|
||||
BackendKeyData(CancelKeyData),
|
||||
ParameterStatus {
|
||||
name: &'a [u8],
|
||||
value: &'a [u8],
|
||||
},
|
||||
ReadyForQuery,
|
||||
NoticeResponse(&'a str),
|
||||
NegotiateProtocolVersion {
|
||||
version: ProtocolVersion,
|
||||
options: &'a [&'a str],
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum BeAuthenticationSaslMessage<'a> {
|
||||
Methods(&'a [&'a str]),
|
||||
Continue(&'a [u8]),
|
||||
Final(&'a [u8]),
|
||||
}
|
||||
|
||||
impl BeMessage<'_> {
|
||||
/// Write the message into an internal buffer
|
||||
pub fn write_message(self, buf: &mut WriteBuf) {
|
||||
match self {
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD>
|
||||
BeMessage::AuthenticationOk => {
|
||||
buf.write_raw(1, b'R', |buf| buf.put_i32(0));
|
||||
}
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONCLEARTEXTPASSWORD>
|
||||
BeMessage::AuthenticationCleartextPassword => {
|
||||
buf.write_raw(1, b'R', |buf| buf.put_i32(3));
|
||||
}
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
|
||||
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Methods(methods)) => {
|
||||
let len: usize = methods.iter().map(|m| m.len() + 1).sum();
|
||||
buf.write_raw(len + 2, b'R', |buf| {
|
||||
buf.put_i32(10); // Specifies that SASL auth method is used.
|
||||
for method in methods {
|
||||
buf.put_slice(method.as_bytes());
|
||||
buf.put_u8(0);
|
||||
}
|
||||
buf.put_u8(0); // zero terminator for the list
|
||||
});
|
||||
}
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
|
||||
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Continue(extra)) => {
|
||||
buf.write_raw(extra.len() + 1, b'R', |buf| {
|
||||
buf.put_i32(11); // Continue SASL auth.
|
||||
buf.put_slice(extra);
|
||||
});
|
||||
}
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-AUTHENTICATIONSASL>
|
||||
BeMessage::AuthenticationSasl(BeAuthenticationSaslMessage::Final(extra)) => {
|
||||
buf.write_raw(extra.len() + 1, b'R', |buf| {
|
||||
buf.put_i32(12); // Send final SASL message.
|
||||
buf.put_slice(extra);
|
||||
});
|
||||
}
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-BACKENDKEYDATA>
|
||||
BeMessage::BackendKeyData(key_data) => {
|
||||
buf.write_raw(8, b'K', |buf| buf.put_slice(key_data.as_bytes()));
|
||||
}
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NOTICERESPONSE>
|
||||
// <https://www.postgresql.org/docs/current/protocol-error-fields.html>
|
||||
BeMessage::NoticeResponse(msg) => {
|
||||
// 'N' signalizes NoticeResponse messages
|
||||
buf.write_raw(18 + msg.len(), b'N', |buf| {
|
||||
// Severity: NOTICE
|
||||
buf.put_slice(b"SNOTICE\0");
|
||||
|
||||
// Code: XX000 (ignored for notice, but still required)
|
||||
buf.put_slice(b"CXX000\0");
|
||||
|
||||
// Message: msg
|
||||
buf.put_u8(b'M');
|
||||
buf.put_slice(msg.as_bytes());
|
||||
buf.put_u8(0);
|
||||
|
||||
// End notice.
|
||||
buf.put_u8(0);
|
||||
});
|
||||
}
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-PARAMETERSTATUS>
|
||||
BeMessage::ParameterStatus { name, value } => {
|
||||
buf.write_raw(name.len() + value.len() + 2, b'S', |buf| {
|
||||
buf.put_slice(name.as_bytes());
|
||||
buf.put_u8(0);
|
||||
buf.put_slice(value.as_bytes());
|
||||
buf.put_u8(0);
|
||||
});
|
||||
}
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NEGOTIATEPROTOCOLVERSION>
|
||||
BeMessage::ReadyForQuery => {
|
||||
buf.write_raw(1, b'Z', |buf| buf.put_u8(b'I'));
|
||||
}
|
||||
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-NEGOTIATEPROTOCOLVERSION>
|
||||
BeMessage::NegotiateProtocolVersion { version, options } => {
|
||||
let len: usize = options.iter().map(|o| o.len() + 1).sum();
|
||||
buf.write_raw(8 + len, b'v', |buf| {
|
||||
buf.put_slice(version.as_bytes());
|
||||
buf.put_u32(options.len() as u32);
|
||||
for option in options {
|
||||
buf.put_slice(option.as_bytes());
|
||||
buf.put_u8(0);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::io::Cursor;
|
||||
|
||||
use tokio::io::{AsyncWriteExt, duplex};
|
||||
use zerocopy::IntoBytes;
|
||||
|
||||
use crate::pqproto::{FeStartupPacket, read_message, read_startup};
|
||||
|
||||
use super::ProtocolVersion;
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_large_startup() {
|
||||
// we're going to define a v3.0 startup message with far too many parameters.
|
||||
let mut payload = vec![];
|
||||
// 10001 + 8 bytes.
|
||||
payload.extend_from_slice(&10009_u32.to_be_bytes());
|
||||
payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes());
|
||||
payload.resize(10009, b'a');
|
||||
|
||||
let (mut server, mut client) = duplex(128);
|
||||
#[rustfmt::skip]
|
||||
let (server, client) = tokio::join!(
|
||||
async move { read_startup(&mut server).await.unwrap_err() },
|
||||
async move { client.write_all(&payload).await.unwrap_err() },
|
||||
);
|
||||
|
||||
assert_eq!(server.to_string(), "invalid startup message length 10001");
|
||||
assert_eq!(client.to_string(), "broken pipe");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn reject_large_password() {
|
||||
// we're going to define a password message that is far too long.
|
||||
let mut payload = vec![];
|
||||
payload.push(b'p');
|
||||
payload.extend_from_slice(&517_u32.to_be_bytes());
|
||||
payload.resize(518, b'a');
|
||||
|
||||
let (mut server, mut client) = duplex(128);
|
||||
#[rustfmt::skip]
|
||||
let (server, client) = tokio::join!(
|
||||
async move { read_message(&mut server, &mut vec![], 512).await.unwrap_err() },
|
||||
async move { client.write_all(&payload).await.unwrap_err() },
|
||||
);
|
||||
|
||||
assert_eq!(server.to_string(), "invalid message length 513");
|
||||
assert_eq!(client.to_string(), "broken pipe");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_startup_message() {
|
||||
let mut payload = vec![];
|
||||
payload.extend_from_slice(&17_u32.to_be_bytes());
|
||||
payload.extend_from_slice(ProtocolVersion::new(3, 0).as_bytes());
|
||||
payload.extend_from_slice(b"abc\0def\0\0");
|
||||
|
||||
let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap();
|
||||
let FeStartupPacket::StartupMessage { version, params } = startup else {
|
||||
panic!("unexpected startup message: {startup:?}");
|
||||
};
|
||||
|
||||
assert_eq!(version.major(), 3);
|
||||
assert_eq!(version.minor(), 0);
|
||||
assert_eq!(params.params, "abc\0def\0");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_ssl_message() {
|
||||
let mut payload = vec![];
|
||||
payload.extend_from_slice(&8_u32.to_be_bytes());
|
||||
payload.extend_from_slice(ProtocolVersion::new(1234, 5679).as_bytes());
|
||||
|
||||
let startup = read_startup(&mut Cursor::new(&payload)).await.unwrap();
|
||||
let FeStartupPacket::SslRequest { direct: None } = startup else {
|
||||
panic!("unexpected startup message: {startup:?}");
|
||||
};
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_tls_message() {
|
||||
// sample client hello taken from <https://tls13.xargs.org/#client-hello>
|
||||
let client_hello = [
|
||||
0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00, 0xf4, 0x03, 0x03, 0x00, 0x01, 0x02,
|
||||
0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10,
|
||||
0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e,
|
||||
0x1f, 0x20, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb,
|
||||
0xec, 0xed, 0xee, 0xef, 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9,
|
||||
0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, 0x00, 0x08, 0x13, 0x02, 0x13, 0x03, 0x13, 0x01,
|
||||
0x00, 0xff, 0x01, 0x00, 0x00, 0xa3, 0x00, 0x00, 0x00, 0x18, 0x00, 0x16, 0x00, 0x00,
|
||||
0x13, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x75, 0x6c, 0x66, 0x68, 0x65,
|
||||
0x69, 0x6d, 0x2e, 0x6e, 0x65, 0x74, 0x00, 0x0b, 0x00, 0x04, 0x03, 0x00, 0x01, 0x02,
|
||||
0x00, 0x0a, 0x00, 0x16, 0x00, 0x14, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x1e, 0x00, 0x19,
|
||||
0x00, 0x18, 0x01, 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x23,
|
||||
0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x17, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x1e,
|
||||
0x00, 0x1c, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x07, 0x08, 0x08, 0x08, 0x09,
|
||||
0x08, 0x0a, 0x08, 0x0b, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01,
|
||||
0x06, 0x01, 0x00, 0x2b, 0x00, 0x03, 0x02, 0x03, 0x04, 0x00, 0x2d, 0x00, 0x02, 0x01,
|
||||
0x01, 0x00, 0x33, 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x35, 0x80, 0x72,
|
||||
0xd6, 0x36, 0x58, 0x80, 0xd1, 0xae, 0xea, 0x32, 0x9a, 0xdf, 0x91, 0x21, 0x38, 0x38,
|
||||
0x51, 0xed, 0x21, 0xa2, 0x8e, 0x3b, 0x75, 0xe9, 0x65, 0xd0, 0xd2, 0xcd, 0x16, 0x62,
|
||||
0x54,
|
||||
];
|
||||
|
||||
let mut cursor = Cursor::new(&client_hello);
|
||||
|
||||
let startup = read_startup(&mut cursor).await.unwrap();
|
||||
let FeStartupPacket::SslRequest {
|
||||
direct: Some(prefix),
|
||||
} = startup
|
||||
else {
|
||||
panic!("unexpected startup message: {startup:?}");
|
||||
};
|
||||
|
||||
// check that no data is lost.
|
||||
assert_eq!(prefix, [0x16, 0x03, 0x01, 0x00, 0xf8, 0x01, 0x00, 0x00]);
|
||||
assert_eq!(cursor.position(), 8);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn read_message_success() {
|
||||
let query = b"Q\0\0\0\x0cSELECT 1Q\0\0\0\x0cSELECT 2";
|
||||
let mut cursor = Cursor::new(&query);
|
||||
|
||||
let mut buf = vec![];
|
||||
let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap();
|
||||
assert_eq!(tag, b'Q');
|
||||
assert_eq!(message, b"SELECT 1");
|
||||
|
||||
let (tag, message) = read_message(&mut cursor, &mut buf, 100).await.unwrap();
|
||||
assert_eq!(tag, b'Q');
|
||||
assert_eq!(message, b"SELECT 2");
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,4 @@
|
||||
use async_trait::async_trait;
|
||||
use pq_proto::StartupMessageParams;
|
||||
use tokio::time;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
@@ -15,6 +14,7 @@ use crate::error::ReportableError;
|
||||
use crate::metrics::{
|
||||
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
|
||||
};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::retry::{CouldRetry, retry_after, should_retry};
|
||||
use crate::proxy::wake_compute::wake_compute;
|
||||
use crate::types::Host;
|
||||
|
||||
@@ -1,8 +1,3 @@
|
||||
use bytes::Buf;
|
||||
use pq_proto::framed::Framed;
|
||||
use pq_proto::{
|
||||
BeMessage as Be, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{debug, info, warn};
|
||||
@@ -12,7 +7,10 @@ use crate::config::TlsConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::proxy::ERR_INSECURE_CONNECTION;
|
||||
use crate::pqproto::{
|
||||
BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
|
||||
};
|
||||
use crate::proxy::TlsRequired;
|
||||
use crate::stream::{PqStream, Stream, StreamUpgradeError};
|
||||
use crate::tls::PG_ALPN_PROTOCOL;
|
||||
|
||||
@@ -71,33 +69,25 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
|
||||
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
|
||||
|
||||
let mut stream = PqStream::new(Stream::from_raw(stream));
|
||||
let (mut stream, mut msg) = PqStream::parse_startup(Stream::from_raw(stream)).await?;
|
||||
loop {
|
||||
let msg = stream.read_startup_packet().await?;
|
||||
match msg {
|
||||
FeStartupPacket::SslRequest { direct } => match stream.get_ref() {
|
||||
Stream::Raw { .. } if !tried_ssl => {
|
||||
tried_ssl = true;
|
||||
|
||||
// We can't perform TLS handshake without a config
|
||||
let have_tls = tls.is_some();
|
||||
if !direct {
|
||||
stream
|
||||
.write_message(&Be::EncryptionResponse(have_tls))
|
||||
.await?;
|
||||
} else if !have_tls {
|
||||
return Err(HandshakeError::ProtocolViolation);
|
||||
}
|
||||
|
||||
if let Some(tls) = tls.take() {
|
||||
// Upgrade raw stream into a secure TLS-backed stream.
|
||||
// NOTE: We've consumed `tls`; this fact will be used later.
|
||||
|
||||
let Framed {
|
||||
stream: raw,
|
||||
read_buf,
|
||||
write_buf,
|
||||
} = stream.framed;
|
||||
let mut read_buf;
|
||||
let raw = if let Some(direct) = &direct {
|
||||
read_buf = &direct[..];
|
||||
stream.accept_direct_tls()
|
||||
} else {
|
||||
read_buf = &[];
|
||||
stream.accept_tls().await?
|
||||
};
|
||||
|
||||
let Stream::Raw { raw } = raw else {
|
||||
return Err(HandshakeError::StreamUpgradeError(
|
||||
@@ -105,12 +95,11 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
));
|
||||
};
|
||||
|
||||
let mut read_buf = read_buf.reader();
|
||||
let mut res = Ok(());
|
||||
let accept = tokio_rustls::TlsAcceptor::from(tls.pg_config.clone())
|
||||
.accept_with(raw, |session| {
|
||||
// push the early data to the tls session
|
||||
while !read_buf.get_ref().is_empty() {
|
||||
while !read_buf.is_empty() {
|
||||
match session.read_tls(&mut read_buf) {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
@@ -123,7 +112,6 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
|
||||
res?;
|
||||
|
||||
let read_buf = read_buf.into_inner();
|
||||
if !read_buf.is_empty() {
|
||||
return Err(HandshakeError::EarlyData);
|
||||
}
|
||||
@@ -157,16 +145,17 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
let (_, tls_server_end_point) =
|
||||
tls.cert_resolver.resolve(conn_info.server_name());
|
||||
|
||||
stream = PqStream {
|
||||
framed: Framed {
|
||||
stream: Stream::Tls {
|
||||
tls: Box::new(tls_stream),
|
||||
tls_server_end_point,
|
||||
},
|
||||
read_buf,
|
||||
write_buf,
|
||||
},
|
||||
let tls = Stream::Tls {
|
||||
tls: Box::new(tls_stream),
|
||||
tls_server_end_point,
|
||||
};
|
||||
(stream, msg) = PqStream::parse_startup(tls).await?;
|
||||
} else {
|
||||
if direct.is_some() {
|
||||
// client sent us a ClientHello already, we can't do anything with it.
|
||||
return Err(HandshakeError::ProtocolViolation);
|
||||
}
|
||||
msg = stream.reject_encryption().await?;
|
||||
}
|
||||
}
|
||||
_ => return Err(HandshakeError::ProtocolViolation),
|
||||
@@ -176,7 +165,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
tried_gss = true;
|
||||
|
||||
// Currently, we don't support GSSAPI
|
||||
stream.write_message(&Be::EncryptionResponse(false)).await?;
|
||||
msg = stream.reject_encryption().await?;
|
||||
}
|
||||
_ => return Err(HandshakeError::ProtocolViolation),
|
||||
},
|
||||
@@ -186,13 +175,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
// Check that the config has been consumed during upgrade
|
||||
// OR we didn't provide it at all (for dev purposes).
|
||||
if tls.is_some() {
|
||||
return stream
|
||||
.throw_error_str(
|
||||
ERR_INSECURE_CONNECTION,
|
||||
crate::error::ErrorKind::User,
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
Err(stream.throw_error(TlsRequired, None).await)?;
|
||||
}
|
||||
|
||||
// This log highlights the start of the connection.
|
||||
@@ -214,20 +197,21 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
// no protocol extensions are supported.
|
||||
// <https://github.com/postgres/postgres/blob/ca481d3c9ab7bf69ff0c8d71ad3951d407f6a33c/src/backend/tcop/backend_startup.c#L744-L753>
|
||||
let mut unsupported = vec![];
|
||||
for (k, _) in params.iter() {
|
||||
let mut supported = StartupMessageParams::default();
|
||||
|
||||
for (k, v) in params.iter() {
|
||||
if k.starts_with("_pq_.") {
|
||||
unsupported.push(k);
|
||||
} else {
|
||||
supported.insert(k, v);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: remove unsupported options so we don't send them to compute.
|
||||
|
||||
stream
|
||||
.write_message(&Be::NegotiateProtocolVersion {
|
||||
version: PG_PROTOCOL_LATEST,
|
||||
options: &unsupported,
|
||||
})
|
||||
.await?;
|
||||
stream.write_message(BeMessage::NegotiateProtocolVersion {
|
||||
version: PG_PROTOCOL_LATEST,
|
||||
options: &unsupported,
|
||||
});
|
||||
stream.flush().await?;
|
||||
|
||||
info!(
|
||||
?version,
|
||||
@@ -235,7 +219,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
session_type = "normal",
|
||||
"successful handshake; unsupported minor version requested"
|
||||
);
|
||||
break Ok(HandshakeData::Startup(stream, params));
|
||||
break Ok(HandshakeData::Startup(stream, supported));
|
||||
}
|
||||
FeStartupPacket::StartupMessage { version, params } => {
|
||||
warn!(
|
||||
|
||||
@@ -10,15 +10,14 @@ pub(crate) mod wake_compute;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use futures::FutureExt;
|
||||
use itertools::Itertools;
|
||||
use once_cell::sync::OnceCell;
|
||||
use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams};
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, debug, error, info, warn};
|
||||
|
||||
@@ -27,8 +26,9 @@ use self::passthrough::ProxyPassthrough;
|
||||
use crate::cancellation::{self, CancellationHandler};
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
|
||||
use crate::proxy::handshake::{HandshakeData, handshake};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
@@ -38,6 +38,18 @@ use crate::{auth, compute};
|
||||
|
||||
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
#[error("{ERR_INSECURE_CONNECTION}")]
|
||||
pub struct TlsRequired;
|
||||
|
||||
impl ReportableError for TlsRequired {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
crate::error::ErrorKind::User
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for TlsRequired {}
|
||||
|
||||
pub async fn run_until_cancelled<F: std::future::Future>(
|
||||
f: F,
|
||||
cancellation_token: &CancellationToken,
|
||||
@@ -329,7 +341,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
|
||||
let user_info = match result {
|
||||
Ok(user_info) => user_info,
|
||||
Err(e) => stream.throw_error(e, Some(ctx)).await?,
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
};
|
||||
|
||||
let user = user_info.get_user().to_owned();
|
||||
@@ -349,10 +361,10 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
let app = params.get("application_name");
|
||||
let params_span = tracing::info_span!("", ?user, ?db, ?app);
|
||||
|
||||
return stream
|
||||
return Err(stream
|
||||
.throw_error(e, Some(ctx))
|
||||
.instrument(params_span)
|
||||
.await?;
|
||||
.await)?;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -365,7 +377,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
.get(NeonOptions::PARAMS_COMPAT)
|
||||
.is_some();
|
||||
|
||||
let mut node = connect_to_compute(
|
||||
let res = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
user_info: compute_user_info.clone(),
|
||||
@@ -377,22 +389,19 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
config.wake_compute_retry_config,
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
.or_else(|e| stream.throw_error(e, Some(ctx)))
|
||||
.await?;
|
||||
.await;
|
||||
|
||||
let node = match res {
|
||||
Ok(node) => node,
|
||||
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
|
||||
};
|
||||
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let session = cancellation_handler_clone.get_key();
|
||||
|
||||
session.write_cancel_key(node.cancel_closure.clone())?;
|
||||
|
||||
prepare_client_connection(&node, *session.key(), &mut stream).await?;
|
||||
|
||||
// Before proxy passing, forward to compute whatever data is left in the
|
||||
// PqStream input buffer. Normally there is none, but our serverless npm
|
||||
// driver in pipeline mode sends startup, password and first query
|
||||
// immediately after opening the connection.
|
||||
let (stream, read_buf) = stream.into_inner();
|
||||
node.stream.write_all(&read_buf).await?;
|
||||
prepare_client_connection(&node, *session.key(), &mut stream);
|
||||
let stream = stream.flush_and_into_inner().await?;
|
||||
|
||||
let private_link_id = match ctx.extra() {
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
|
||||
@@ -413,31 +422,28 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
|
||||
/// Finish client connection initialization: confirm auth success, send params, etc.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(crate) async fn prepare_client_connection(
|
||||
pub(crate) fn prepare_client_connection(
|
||||
node: &compute::PostgresConnection,
|
||||
cancel_key_data: CancelKeyData,
|
||||
stream: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> Result<(), std::io::Error> {
|
||||
) {
|
||||
// Forward all deferred notices to the client.
|
||||
for notice in &node.delayed_notice {
|
||||
stream.write_message_noflush(&Be::Raw(b'N', notice.as_bytes()))?;
|
||||
stream.write_raw(notice.as_bytes().len(), b'N', |buf| {
|
||||
buf.extend_from_slice(notice.as_bytes());
|
||||
});
|
||||
}
|
||||
|
||||
// Forward all postgres connection params to the client.
|
||||
for (name, value) in &node.params {
|
||||
stream.write_message_noflush(&Be::ParameterStatus {
|
||||
stream.write_message(BeMessage::ParameterStatus {
|
||||
name: name.as_bytes(),
|
||||
value: value.as_bytes(),
|
||||
})?;
|
||||
});
|
||||
}
|
||||
|
||||
stream
|
||||
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
|
||||
.write_message(&Be::ReadyForQuery)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
stream.write_message(BeMessage::BackendKeyData(cancel_key_data));
|
||||
stream.write_message(BeMessage::ReadyForQuery);
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
|
||||
|
||||
@@ -125,9 +125,10 @@ pub(crate) fn retry_after(num_retries: u32, config: RetryConfig) -> time::Durati
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::ShouldRetryWakeCompute;
|
||||
use postgres_client::error::{DbError, SqlState};
|
||||
|
||||
use super::ShouldRetryWakeCompute;
|
||||
|
||||
#[test]
|
||||
fn should_retry_wake_compute_for_db_error() {
|
||||
// These SQLStates should NOT trigger a wake_compute retry.
|
||||
|
||||
@@ -10,7 +10,7 @@ use bytes::{Bytes, BytesMut};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use postgres_client::tls::TlsConnect;
|
||||
use postgres_protocol::message::frontend;
|
||||
use tokio::io::{AsyncReadExt, DuplexStream};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream};
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
use super::*;
|
||||
@@ -49,15 +49,14 @@ async fn proxy_mitm(
|
||||
};
|
||||
|
||||
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
|
||||
let (end_client, buf) = end_client.framed.into_inner();
|
||||
assert!(buf.is_empty());
|
||||
let end_client = end_client.flush_and_into_inner().await.unwrap();
|
||||
let mut end_client = tokio_util::codec::Framed::new(end_client, PgFrame);
|
||||
|
||||
// give the end_server the startup parameters
|
||||
let mut buf = BytesMut::new();
|
||||
frontend::startup_message(
|
||||
&postgres_protocol::message::frontend::StartupMessageParams {
|
||||
params: startup.params.into(),
|
||||
params: startup.params.as_bytes().into(),
|
||||
},
|
||||
&mut buf,
|
||||
)
|
||||
|
||||
@@ -128,7 +128,7 @@ trait TestAuth: Sized {
|
||||
self,
|
||||
stream: &mut PqStream<Stream<S>>,
|
||||
) -> anyhow::Result<()> {
|
||||
stream.write_message_noflush(&Be::AuthenticationOk)?;
|
||||
stream.write_message(BeMessage::AuthenticationOk);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -157,9 +157,7 @@ impl TestAuth for Scram {
|
||||
self,
|
||||
stream: &mut PqStream<Stream<S>>,
|
||||
) -> anyhow::Result<()> {
|
||||
let outcome = auth::AuthFlow::new(stream)
|
||||
.begin(auth::Scram(&self.0, &RequestContext::test()))
|
||||
.await?
|
||||
let outcome = auth::AuthFlow::new(stream, auth::Scram(&self.0, &RequestContext::test()))
|
||||
.authenticate()
|
||||
.await?;
|
||||
|
||||
@@ -185,10 +183,12 @@ async fn dummy_proxy(
|
||||
|
||||
auth.authenticate(&mut stream).await?;
|
||||
|
||||
stream
|
||||
.write_message_noflush(&Be::CLIENT_ENCODING)?
|
||||
.write_message(&Be::ReadyForQuery)
|
||||
.await?;
|
||||
stream.write_message(BeMessage::ParameterStatus {
|
||||
name: b"client_encoding",
|
||||
value: b"UTF8",
|
||||
});
|
||||
stream.write_message(BeMessage::ReadyForQuery);
|
||||
stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
use core::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use pq_proto::CancelKeyData;
|
||||
use tokio::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::pqproto::CancelKeyData;
|
||||
|
||||
pub trait CancellationPublisherMut: Send + Sync + 'static {
|
||||
#[allow(async_fn_in_trait)]
|
||||
async fn try_publish(
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
use std::io::ErrorKind;
|
||||
|
||||
use anyhow::Ok;
|
||||
use pq_proto::{CancelKeyData, id_to_cancel_key};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::pqproto::{CancelKeyData, id_to_cancel_key};
|
||||
|
||||
pub mod keyspace {
|
||||
pub const CANCEL_PREFIX: &str = "cancel";
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
#[derive(Clone, Debug, Eq, PartialEq)]
|
||||
pub(crate) enum KeyPrefix {
|
||||
#[serde(untagged)]
|
||||
Cancel(CancelKeyData),
|
||||
}
|
||||
|
||||
@@ -18,9 +17,7 @@ impl KeyPrefix {
|
||||
pub(crate) fn build_redis_key(&self) -> String {
|
||||
match self {
|
||||
KeyPrefix::Cancel(key) => {
|
||||
let hi = (key.backend_pid as u64) << 32;
|
||||
let lo = (key.cancel_key as u64) & 0xffff_ffff;
|
||||
let id = hi | lo;
|
||||
let id = key.0.get();
|
||||
let keyspace = keyspace::CANCEL_PREFIX;
|
||||
format!("{keyspace}:{id:x}")
|
||||
}
|
||||
@@ -63,10 +60,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_build_redis_key() {
|
||||
let cancel_key: KeyPrefix = KeyPrefix::Cancel(CancelKeyData {
|
||||
backend_pid: 12345,
|
||||
cancel_key: 54321,
|
||||
});
|
||||
let cancel_key: KeyPrefix = KeyPrefix::Cancel(id_to_cancel_key(12345 << 32 | 54321));
|
||||
|
||||
let redis_key = cancel_key.build_redis_key();
|
||||
assert_eq!(redis_key, "cancel:30390000d431");
|
||||
@@ -77,10 +71,7 @@ mod tests {
|
||||
let redis_key = "cancel:30390000d431";
|
||||
let key: KeyPrefix = parse_redis_key(redis_key).expect("Failed to parse key");
|
||||
|
||||
let ref_key = CancelKeyData {
|
||||
backend_pid: 12345,
|
||||
cancel_key: 54321,
|
||||
};
|
||||
let ref_key = id_to_cancel_key(12345 << 32 | 54321);
|
||||
|
||||
assert_eq!(key.as_str(), KeyPrefix::Cancel(ref_key).as_str());
|
||||
let KeyPrefix::Cancel(cancel_key) = key;
|
||||
|
||||
@@ -2,11 +2,9 @@ use std::convert::Infallible;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::StreamExt;
|
||||
use pq_proto::CancelKeyData;
|
||||
use redis::aio::PubSub;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
|
||||
use crate::cache::project_info::ProjectInfoCache;
|
||||
@@ -100,14 +98,6 @@ pub(crate) struct PasswordUpdate {
|
||||
role_name: RoleNameInt,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct CancelSession {
|
||||
pub(crate) region_id: Option<String>,
|
||||
pub(crate) cancel_key_data: CancelKeyData,
|
||||
pub(crate) session_id: Uuid,
|
||||
pub(crate) peer_addr: Option<std::net::IpAddr>,
|
||||
}
|
||||
|
||||
fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
|
||||
where
|
||||
T: for<'de2> serde::Deserialize<'de2>,
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
//! Definitions for SASL messages.
|
||||
|
||||
use pq_proto::{BeAuthenticationSaslMessage, BeMessage};
|
||||
|
||||
use crate::parse::split_cstr;
|
||||
|
||||
/// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage).
|
||||
@@ -30,26 +28,6 @@ impl<'a> FirstMessage<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
/// A single SASL message.
|
||||
/// This struct is deliberately decoupled from lower-level
|
||||
/// [`BeAuthenticationSaslMessage`].
|
||||
#[derive(Debug)]
|
||||
pub(super) enum ServerMessage<T> {
|
||||
/// We expect to see more steps.
|
||||
Continue(T),
|
||||
/// This is the final step.
|
||||
Final(T),
|
||||
}
|
||||
|
||||
impl<'a> ServerMessage<&'a str> {
|
||||
pub(super) fn to_reply(&self) -> BeMessage<'a> {
|
||||
BeMessage::AuthenticationSasl(match self {
|
||||
ServerMessage::Continue(s) => BeAuthenticationSaslMessage::Continue(s.as_bytes()),
|
||||
ServerMessage::Final(s) => BeAuthenticationSaslMessage::Final(s.as_bytes()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -14,7 +14,7 @@ use std::io;
|
||||
|
||||
pub(crate) use channel_binding::ChannelBinding;
|
||||
pub(crate) use messages::FirstMessage;
|
||||
pub(crate) use stream::{Outcome, SaslStream};
|
||||
pub(crate) use stream::{Outcome, authenticate};
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
@@ -22,6 +22,9 @@ use crate::error::{ReportableError, UserFacingError};
|
||||
/// Fine-grained auth errors help in writing tests.
|
||||
#[derive(Error, Debug)]
|
||||
pub(crate) enum Error {
|
||||
#[error("Unsupported authentication method: {0}")]
|
||||
BadAuthMethod(Box<str>),
|
||||
|
||||
#[error("Channel binding failed: {0}")]
|
||||
ChannelBindingFailed(&'static str),
|
||||
|
||||
@@ -54,6 +57,7 @@ impl UserFacingError for Error {
|
||||
impl ReportableError for Error {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
Error::BadAuthMethod(_) => crate::error::ErrorKind::User,
|
||||
Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User,
|
||||
Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
|
||||
Error::BadClientMessage(_) => crate::error::ErrorKind::User,
|
||||
|
||||
@@ -3,61 +3,12 @@
|
||||
use std::io;
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
|
||||
use super::Mechanism;
|
||||
use super::messages::ServerMessage;
|
||||
use super::{Mechanism, Step};
|
||||
use crate::context::RequestContext;
|
||||
use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
|
||||
use crate::stream::PqStream;
|
||||
|
||||
/// Abstracts away all peculiarities of the libpq's protocol.
|
||||
pub(crate) struct SaslStream<'a, S> {
|
||||
/// The underlying stream.
|
||||
stream: &'a mut PqStream<S>,
|
||||
/// Current password message we received from client.
|
||||
current: bytes::Bytes,
|
||||
/// First SASL message produced by client.
|
||||
first: Option<&'a str>,
|
||||
}
|
||||
|
||||
impl<'a, S> SaslStream<'a, S> {
|
||||
pub(crate) fn new(stream: &'a mut PqStream<S>, first: &'a str) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
current: bytes::Bytes::new(),
|
||||
first: Some(first),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
|
||||
// Receive a new SASL message from the client.
|
||||
async fn recv(&mut self) -> io::Result<&str> {
|
||||
if let Some(first) = self.first.take() {
|
||||
return Ok(first);
|
||||
}
|
||||
|
||||
self.current = self.stream.read_password_message().await?;
|
||||
let s = std::str::from_utf8(&self.current)
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
|
||||
|
||||
Ok(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
|
||||
// Send a SASL message to the client.
|
||||
async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
|
||||
self.stream.write_message(&msg.to_reply()).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Queue a SASL message for the client.
|
||||
fn send_noflush(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
|
||||
self.stream.write_message_noflush(&msg.to_reply())?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// SASL authentication outcome.
|
||||
/// It's much easier to match on those two variants
|
||||
/// than to peek into a noisy protocol error type.
|
||||
@@ -69,33 +20,62 @@ pub(crate) enum Outcome<R> {
|
||||
Failure(&'static str),
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
|
||||
/// Perform SASL message exchange according to the underlying algorithm
|
||||
/// until user is either authenticated or denied access.
|
||||
pub(crate) async fn authenticate<M: Mechanism>(
|
||||
mut self,
|
||||
mut mechanism: M,
|
||||
) -> super::Result<Outcome<M::Output>> {
|
||||
loop {
|
||||
let input = self.recv().await?;
|
||||
let step = mechanism.exchange(input).map_err(|error| {
|
||||
info!(?error, "error during SASL exchange");
|
||||
error
|
||||
})?;
|
||||
pub async fn authenticate<S, F, M>(
|
||||
ctx: &RequestContext,
|
||||
stream: &mut PqStream<S>,
|
||||
mechanism: F,
|
||||
) -> super::Result<Outcome<M::Output>>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
F: FnOnce(&str) -> super::Result<M>,
|
||||
M: Mechanism,
|
||||
{
|
||||
let sasl = {
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
use super::Step;
|
||||
return Ok(match step {
|
||||
Step::Continue(moved_mechanism, reply) => {
|
||||
self.send(&ServerMessage::Continue(&reply)).await?;
|
||||
mechanism = moved_mechanism;
|
||||
continue;
|
||||
}
|
||||
Step::Success(result, reply) => {
|
||||
self.send_noflush(&ServerMessage::Final(&reply))?;
|
||||
Outcome::Success(result)
|
||||
}
|
||||
Step::Failure(reason) => Outcome::Failure(reason),
|
||||
});
|
||||
// Initial client message contains the chosen auth method's name.
|
||||
let msg = stream.read_password_message().await?;
|
||||
super::FirstMessage::parse(msg).ok_or(super::Error::BadClientMessage("bad sasl message"))?
|
||||
};
|
||||
|
||||
let mut mechanism = mechanism(sasl.method)?;
|
||||
let mut input = sasl.message;
|
||||
loop {
|
||||
let step = mechanism
|
||||
.exchange(input)
|
||||
.inspect_err(|error| tracing::info!(?error, "error during SASL exchange"))?;
|
||||
|
||||
match step {
|
||||
Step::Continue(moved_mechanism, reply) => {
|
||||
mechanism = moved_mechanism;
|
||||
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
// write reply
|
||||
let sasl_msg = BeAuthenticationSaslMessage::Continue(reply.as_bytes());
|
||||
stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
|
||||
|
||||
// get next input
|
||||
stream.flush().await?;
|
||||
let msg = stream.read_password_message().await?;
|
||||
input = std::str::from_utf8(msg)
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
|
||||
}
|
||||
Step::Success(result, reply) => {
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
// write reply
|
||||
let sasl_msg = BeAuthenticationSaslMessage::Final(reply.as_bytes());
|
||||
stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
|
||||
stream.write_message(BeMessage::AuthenticationOk);
|
||||
// exit with success
|
||||
break Ok(Outcome::Success(result));
|
||||
}
|
||||
// exit with failure
|
||||
Step::Failure(reason) => break Ok(Outcome::Failure(reason)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ use postgres_client::error::{DbError, ErrorPosition, SqlState};
|
||||
use postgres_client::{
|
||||
GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction,
|
||||
};
|
||||
use pq_proto::StartupMessageParamsBuilder;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::value::RawValue;
|
||||
@@ -41,6 +40,7 @@ use crate::context::RequestContext;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::http::{ReadBodyError, read_body_with_limit};
|
||||
use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::{NeonOptions, run_until_cancelled};
|
||||
use crate::serverless::backend::HttpConnError;
|
||||
use crate::types::{DbName, RoleName};
|
||||
@@ -219,7 +219,7 @@ fn get_conn_info(
|
||||
|
||||
let mut options = Option::None;
|
||||
|
||||
let mut params = StartupMessageParamsBuilder::default();
|
||||
let mut params = StartupMessageParams::default();
|
||||
params.insert("user", &username);
|
||||
params.insert("database", &dbname);
|
||||
for (key, value) in pairs {
|
||||
|
||||
@@ -2,19 +2,17 @@ use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::{io, task};
|
||||
|
||||
use bytes::BytesMut;
|
||||
use pq_proto::framed::{ConnectionError, Framed};
|
||||
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
|
||||
use rustls::ServerConfig;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||
use tokio_rustls::server::TlsStream;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::control_plane::messages::ColdStartInfo;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::metrics::Metrics;
|
||||
use crate::pqproto::{
|
||||
BeMessage, FE_PASSWORD_MESSAGE, FeStartupPacket, SQLSTATE_INTERNAL_ERROR, WriteBuf,
|
||||
read_message, read_startup,
|
||||
};
|
||||
use crate::tls::TlsServerEndPoint;
|
||||
|
||||
/// Stream wrapper which implements libpq's protocol.
|
||||
@@ -23,58 +21,77 @@ use crate::tls::TlsServerEndPoint;
|
||||
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
|
||||
/// to pass random malformed bytes through the connection).
|
||||
pub struct PqStream<S> {
|
||||
pub(crate) framed: Framed<S>,
|
||||
stream: S,
|
||||
read: Vec<u8>,
|
||||
write: WriteBuf,
|
||||
}
|
||||
|
||||
impl<S> PqStream<S> {
|
||||
/// Construct a new libpq protocol wrapper.
|
||||
pub fn new(stream: S) -> Self {
|
||||
pub fn get_ref(&self) -> &S {
|
||||
&self.stream
|
||||
}
|
||||
|
||||
/// Construct a new libpq protocol wrapper over a stream without the first startup message.
|
||||
#[cfg(test)]
|
||||
pub fn new_skip_handshake(stream: S) -> Self {
|
||||
Self {
|
||||
framed: Framed::new(stream),
|
||||
stream,
|
||||
read: Vec::new(),
|
||||
write: WriteBuf::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the underlying stream and read buffer.
|
||||
pub fn into_inner(self) -> (S, BytesMut) {
|
||||
self.framed.into_inner()
|
||||
}
|
||||
|
||||
/// Get a shared reference to the underlying stream.
|
||||
pub(crate) fn get_ref(&self) -> &S {
|
||||
self.framed.get_ref()
|
||||
}
|
||||
}
|
||||
|
||||
fn err_connection() -> io::Error {
|
||||
io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost")
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> PqStream<S> {
|
||||
/// Construct a new libpq protocol wrapper and read the first startup message.
|
||||
///
|
||||
/// This is not cancel safe.
|
||||
pub async fn parse_startup(mut stream: S) -> io::Result<(Self, FeStartupPacket)> {
|
||||
let startup = read_startup(&mut stream).await?;
|
||||
Ok((
|
||||
Self {
|
||||
stream,
|
||||
read: Vec::new(),
|
||||
write: WriteBuf::new(),
|
||||
},
|
||||
startup,
|
||||
))
|
||||
}
|
||||
|
||||
/// Tell the client that encryption is not supported.
|
||||
///
|
||||
/// This is not cancel safe
|
||||
pub async fn reject_encryption(&mut self) -> io::Result<FeStartupPacket> {
|
||||
// N for No.
|
||||
self.write.encryption(b'N');
|
||||
self.flush().await?;
|
||||
read_startup(&mut self.stream).await
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
|
||||
pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
|
||||
self.framed
|
||||
.read_startup_message()
|
||||
.await
|
||||
.map_err(ConnectionError::into_io_error)?
|
||||
.ok_or_else(err_connection)
|
||||
}
|
||||
|
||||
async fn read_message(&mut self) -> io::Result<FeMessage> {
|
||||
self.framed
|
||||
.read_message()
|
||||
.await
|
||||
.map_err(ConnectionError::into_io_error)?
|
||||
.ok_or_else(err_connection)
|
||||
}
|
||||
|
||||
pub(crate) async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
|
||||
match self.read_message().await? {
|
||||
FeMessage::PasswordMessage(msg) => Ok(msg),
|
||||
bad => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("unexpected message type: {bad:?}"),
|
||||
)),
|
||||
/// Read a raw postgres packet, which will respect the max length requested.
|
||||
/// This is not cancel safe.
|
||||
async fn read_raw_expect(&mut self, tag: u8, max: usize) -> io::Result<&mut [u8]> {
|
||||
let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
|
||||
if actual_tag != tag {
|
||||
return Err(io::Error::other(format!(
|
||||
"incorrect message tag, expected {:?}, got {:?}",
|
||||
tag as char, actual_tag as char,
|
||||
)));
|
||||
}
|
||||
Ok(msg)
|
||||
}
|
||||
|
||||
/// Read a postgres password message, which will respect the max length requested.
|
||||
/// This is not cancel safe.
|
||||
pub async fn read_password_message(&mut self) -> io::Result<&mut [u8]> {
|
||||
// passwords are usually pretty short
|
||||
// and SASL SCRAM messages are no longer than 256 bytes in my testing
|
||||
// (a few hashes and random bytes, encoded into base64).
|
||||
const MAX_PASSWORD_LENGTH: usize = 512;
|
||||
self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,6 +101,16 @@ pub struct ReportedError {
|
||||
error_kind: ErrorKind,
|
||||
}
|
||||
|
||||
impl ReportedError {
|
||||
pub fn new(e: (impl UserFacingError + Into<anyhow::Error>)) -> Self {
|
||||
let error_kind = e.get_error_kind();
|
||||
Self {
|
||||
source: e.into(),
|
||||
error_kind,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ReportedError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.source.fmt(f)
|
||||
@@ -102,109 +129,65 @@ impl ReportableError for ReportedError {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
enum ErrorTag {
|
||||
#[serde(rename = "proxy")]
|
||||
Proxy,
|
||||
#[serde(rename = "compute")]
|
||||
Compute,
|
||||
#[serde(rename = "client")]
|
||||
Client,
|
||||
#[serde(rename = "controlplane")]
|
||||
ControlPlane,
|
||||
#[serde(rename = "other")]
|
||||
Other,
|
||||
}
|
||||
|
||||
impl From<ErrorKind> for ErrorTag {
|
||||
fn from(error_kind: ErrorKind) -> Self {
|
||||
match error_kind {
|
||||
ErrorKind::User => Self::Client,
|
||||
ErrorKind::ClientDisconnect => Self::Client,
|
||||
ErrorKind::RateLimit => Self::Proxy,
|
||||
ErrorKind::ServiceRateLimit => Self::Proxy, // considering rate limit as proxy error for SLI
|
||||
ErrorKind::Quota => Self::Proxy,
|
||||
ErrorKind::Service => Self::Proxy,
|
||||
ErrorKind::ControlPlane => Self::ControlPlane,
|
||||
ErrorKind::Postgres => Self::Other,
|
||||
ErrorKind::Compute => Self::Compute,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
struct ProbeErrorData {
|
||||
tag: ErrorTag,
|
||||
msg: String,
|
||||
cold_start_info: Option<ColdStartInfo>,
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> PqStream<S> {
|
||||
/// Write the message into an internal buffer, but don't flush the underlying stream.
|
||||
pub(crate) fn write_message_noflush(
|
||||
&mut self,
|
||||
message: &BeMessage<'_>,
|
||||
) -> io::Result<&mut Self> {
|
||||
self.framed
|
||||
.write_message(message)
|
||||
.map_err(ProtocolError::into_io_error)?;
|
||||
Ok(self)
|
||||
/// Tell the client that we are willing to accept SSL.
|
||||
/// This is not cancel safe
|
||||
pub async fn accept_tls(mut self) -> io::Result<S> {
|
||||
// S for SSL.
|
||||
self.write.encryption(b'S');
|
||||
self.flush().await?;
|
||||
Ok(self.stream)
|
||||
}
|
||||
|
||||
/// Write the message into an internal buffer and flush it.
|
||||
pub async fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
|
||||
self.write_message_noflush(message)?;
|
||||
self.flush().await?;
|
||||
Ok(self)
|
||||
/// Assert that we are using direct TLS.
|
||||
pub fn accept_direct_tls(self) -> S {
|
||||
self.stream
|
||||
}
|
||||
|
||||
/// Write a raw message to the internal buffer.
|
||||
pub fn write_raw(&mut self, size_hint: usize, tag: u8, f: impl FnOnce(&mut Vec<u8>)) {
|
||||
self.write.write_raw(size_hint, tag, f);
|
||||
}
|
||||
|
||||
/// Write the message into an internal buffer
|
||||
pub fn write_message(&mut self, message: BeMessage<'_>) {
|
||||
message.write_message(&mut self.write);
|
||||
}
|
||||
|
||||
/// Flush the output buffer into the underlying stream.
|
||||
pub(crate) async fn flush(&mut self) -> io::Result<&mut Self> {
|
||||
self.framed.flush().await?;
|
||||
Ok(self)
|
||||
///
|
||||
/// This is cancel safe.
|
||||
pub async fn flush(&mut self) -> io::Result<()> {
|
||||
self.stream.write_all_buf(&mut self.write).await?;
|
||||
self.write.reset();
|
||||
|
||||
self.stream.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Writes message with the given error kind to the stream.
|
||||
/// Used only for probe queries
|
||||
async fn write_format_message(
|
||||
&mut self,
|
||||
msg: &str,
|
||||
error_kind: ErrorKind,
|
||||
ctx: Option<&crate::context::RequestContext>,
|
||||
) -> String {
|
||||
let formatted_msg = match ctx {
|
||||
Some(ctx) if ctx.get_testodrome_id().is_some() => {
|
||||
serde_json::to_string(&ProbeErrorData {
|
||||
tag: ErrorTag::from(error_kind),
|
||||
msg: msg.to_string(),
|
||||
cold_start_info: Some(ctx.cold_start_info()),
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
_ => msg.to_string(),
|
||||
};
|
||||
|
||||
// already error case, ignore client IO error
|
||||
self.write_message(&BeMessage::ErrorResponse(&formatted_msg, None))
|
||||
.await
|
||||
.inspect_err(|e| debug!("write_message failed: {e}"))
|
||||
.ok();
|
||||
|
||||
formatted_msg
|
||||
/// Flush the output buffer into the underlying stream.
|
||||
///
|
||||
/// This is cancel safe.
|
||||
pub async fn flush_and_into_inner(mut self) -> io::Result<S> {
|
||||
self.flush().await?;
|
||||
Ok(self.stream)
|
||||
}
|
||||
|
||||
/// Write the error message using [`Self::write_format_message`], then re-throw it.
|
||||
/// Allowing string literals is safe under the assumption they might not contain any runtime info.
|
||||
/// This method exists due to `&str` not implementing `Into<anyhow::Error>`.
|
||||
/// Write the error message to the client, then re-throw it.
|
||||
///
|
||||
/// Trait [`UserFacingError`] acts as an allowlist for error types.
|
||||
/// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
|
||||
pub async fn throw_error_str<T>(
|
||||
pub(crate) async fn throw_error<E>(
|
||||
&mut self,
|
||||
msg: &'static str,
|
||||
error_kind: ErrorKind,
|
||||
error: E,
|
||||
ctx: Option<&crate::context::RequestContext>,
|
||||
) -> Result<T, ReportedError> {
|
||||
self.write_format_message(msg, error_kind, ctx).await;
|
||||
) -> ReportedError
|
||||
where
|
||||
E: UserFacingError + Into<anyhow::Error>,
|
||||
{
|
||||
let error_kind = error.get_error_kind();
|
||||
let msg = error.to_string_client();
|
||||
|
||||
if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
|
||||
tracing::info!(
|
||||
@@ -214,39 +197,39 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
|
||||
);
|
||||
}
|
||||
|
||||
Err(ReportedError {
|
||||
source: anyhow::anyhow!(msg),
|
||||
error_kind,
|
||||
})
|
||||
}
|
||||
|
||||
/// Write the error message using [`Self::write_format_message`], then re-throw it.
|
||||
/// Trait [`UserFacingError`] acts as an allowlist for error types.
|
||||
/// If `ctx` is provided and has testodrome_id set, error messages will be prefixed according to error kind.
|
||||
pub(crate) async fn throw_error<T, E>(
|
||||
&mut self,
|
||||
error: E,
|
||||
ctx: Option<&crate::context::RequestContext>,
|
||||
) -> Result<T, ReportedError>
|
||||
where
|
||||
E: UserFacingError + Into<anyhow::Error>,
|
||||
{
|
||||
let error_kind = error.get_error_kind();
|
||||
let msg = error.to_string_client();
|
||||
self.write_format_message(&msg, error_kind, ctx).await;
|
||||
if error_kind != ErrorKind::RateLimit && error_kind != ErrorKind::User {
|
||||
tracing::info!(
|
||||
kind=error_kind.to_metric_label(),
|
||||
error=%error,
|
||||
msg,
|
||||
"forwarding error to user",
|
||||
);
|
||||
let probe_msg;
|
||||
let mut msg = &*msg;
|
||||
if let Some(ctx) = ctx {
|
||||
if ctx.get_testodrome_id().is_some() {
|
||||
let tag = match error_kind {
|
||||
ErrorKind::User => "client",
|
||||
ErrorKind::ClientDisconnect => "client",
|
||||
ErrorKind::RateLimit => "proxy",
|
||||
ErrorKind::ServiceRateLimit => "proxy",
|
||||
ErrorKind::Quota => "proxy",
|
||||
ErrorKind::Service => "proxy",
|
||||
ErrorKind::ControlPlane => "controlplane",
|
||||
ErrorKind::Postgres => "other",
|
||||
ErrorKind::Compute => "compute",
|
||||
};
|
||||
probe_msg = typed_json::json!({
|
||||
"tag": tag,
|
||||
"msg": msg,
|
||||
"cold_start_info": ctx.cold_start_info(),
|
||||
})
|
||||
.to_string();
|
||||
msg = &probe_msg;
|
||||
}
|
||||
}
|
||||
|
||||
Err(ReportedError {
|
||||
source: anyhow::anyhow!(error),
|
||||
error_kind,
|
||||
})
|
||||
// TODO: either preserve the error code from postgres, or assign error codes to proxy errors.
|
||||
self.write.write_error(msg, SQLSTATE_INTERNAL_ERROR);
|
||||
|
||||
self.flush()
|
||||
.await
|
||||
.unwrap_or_else(|e| tracing::debug!("write_message failed: {e}"));
|
||||
|
||||
ReportedError::new(error)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ struct GlobalTimelinesState {
|
||||
// on-demand timeline creation from recreating deleted timelines. This is only soft-enforced, as
|
||||
// this map is dropped on restart.
|
||||
tombstones: HashMap<TenantTimelineId, Instant>,
|
||||
tenant_tombstones: HashMap<TenantId, Instant>,
|
||||
|
||||
conf: Arc<SafeKeeperConf>,
|
||||
broker_active_set: Arc<TimelinesSet>,
|
||||
@@ -81,10 +82,25 @@ impl GlobalTimelinesState {
|
||||
}
|
||||
}
|
||||
|
||||
fn has_tombstone(&self, ttid: &TenantTimelineId) -> bool {
|
||||
self.tombstones.contains_key(ttid) || self.tenant_tombstones.contains_key(&ttid.tenant_id)
|
||||
}
|
||||
|
||||
/// Removes all blocking tombstones for the given timeline ID.
|
||||
/// Returns `true` if there have been actual changes.
|
||||
fn remove_tombstone(&mut self, ttid: &TenantTimelineId) -> bool {
|
||||
self.tombstones.remove(ttid).is_some()
|
||||
|| self.tenant_tombstones.remove(&ttid.tenant_id).is_some()
|
||||
}
|
||||
|
||||
fn delete(&mut self, ttid: TenantTimelineId) {
|
||||
self.timelines.remove(&ttid);
|
||||
self.tombstones.insert(ttid, Instant::now());
|
||||
}
|
||||
|
||||
fn add_tenant_tombstone(&mut self, tenant_id: TenantId) {
|
||||
self.tenant_tombstones.insert(tenant_id, Instant::now());
|
||||
}
|
||||
}
|
||||
|
||||
/// A struct used to manage access to the global timelines map.
|
||||
@@ -99,6 +115,7 @@ impl GlobalTimelines {
|
||||
state: Mutex::new(GlobalTimelinesState {
|
||||
timelines: HashMap::new(),
|
||||
tombstones: HashMap::new(),
|
||||
tenant_tombstones: HashMap::new(),
|
||||
conf,
|
||||
broker_active_set: Arc::new(TimelinesSet::default()),
|
||||
global_rate_limiter: RateLimiter::new(1, 1),
|
||||
@@ -245,7 +262,7 @@ impl GlobalTimelines {
|
||||
return Ok(timeline);
|
||||
}
|
||||
|
||||
if state.tombstones.contains_key(&ttid) {
|
||||
if state.has_tombstone(&ttid) {
|
||||
anyhow::bail!("Timeline {ttid} is deleted, refusing to recreate");
|
||||
}
|
||||
|
||||
@@ -295,13 +312,14 @@ impl GlobalTimelines {
|
||||
_ => {}
|
||||
}
|
||||
if check_tombstone {
|
||||
if state.tombstones.contains_key(&ttid) {
|
||||
if state.has_tombstone(&ttid) {
|
||||
anyhow::bail!("timeline {ttid} is deleted, refusing to recreate");
|
||||
}
|
||||
} else {
|
||||
// We may be have been asked to load a timeline that was previously deleted (e.g. from `pull_timeline.rs`). We trust
|
||||
// that the human doing this manual intervention knows what they are doing, and remove its tombstone.
|
||||
if state.tombstones.remove(&ttid).is_some() {
|
||||
// It's also possible that we enter this when the tenant has been deleted, even if the timeline itself has never existed.
|
||||
if state.remove_tombstone(&ttid) {
|
||||
warn!("un-deleted timeline {ttid}");
|
||||
}
|
||||
}
|
||||
@@ -482,6 +500,7 @@ impl GlobalTimelines {
|
||||
let tli_res = {
|
||||
let state = self.state.lock().unwrap();
|
||||
|
||||
// Do NOT check tenant tombstones here: those were set earlier
|
||||
if state.tombstones.contains_key(ttid) {
|
||||
// Presence of a tombstone guarantees that a previous deletion has completed and there is no work to do.
|
||||
info!("Timeline {ttid} was already deleted");
|
||||
@@ -557,6 +576,10 @@ impl GlobalTimelines {
|
||||
action: DeleteOrExclude,
|
||||
) -> Result<HashMap<TenantTimelineId, TimelineDeleteResult>> {
|
||||
info!("deleting all timelines for tenant {}", tenant_id);
|
||||
|
||||
// Adding a tombstone before getting the timelines to prevent new timeline additions
|
||||
self.state.lock().unwrap().add_tenant_tombstone(*tenant_id);
|
||||
|
||||
let to_delete = self.get_all_for_tenant(*tenant_id);
|
||||
|
||||
let mut err = None;
|
||||
@@ -600,6 +623,9 @@ impl GlobalTimelines {
|
||||
state
|
||||
.tombstones
|
||||
.retain(|_, v| now.duration_since(*v) < *tombstone_ttl);
|
||||
state
|
||||
.tenant_tombstones
|
||||
.retain(|_, v| now.duration_since(*v) < *tombstone_ttl);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -482,6 +482,10 @@ async fn handle_tenant_timeline_delete(
|
||||
ForwardOutcome::NotForwarded(_req) => {}
|
||||
};
|
||||
|
||||
service
|
||||
.maybe_delete_timeline_import(tenant_id, timeline_id)
|
||||
.await?;
|
||||
|
||||
// For timeline deletions, which both implement an "initially return 202, then 404 once
|
||||
// we're done" semantic, we wrap with a retry loop to expose a simpler API upstream.
|
||||
async fn deletion_wrapper<R, F>(service: Arc<Service>, f: F) -> Result<Response<Body>, ApiError>
|
||||
|
||||
@@ -139,6 +139,14 @@ pub(crate) struct StorageControllerMetricGroup {
|
||||
/// HTTP request status counters for handled requests
|
||||
pub(crate) storage_controller_reconcile_long_running:
|
||||
measured::CounterVec<ReconcileLongRunningLabelGroupSet>,
|
||||
|
||||
/// Indicator of safekeeper reconciler queue depth, broken down by safekeeper, excluding ongoing reconciles.
|
||||
pub(crate) storage_controller_safkeeper_reconciles_queued:
|
||||
measured::GaugeVec<SafekeeperReconcilerLabelGroupSet>,
|
||||
|
||||
/// Indicator of completed safekeeper reconciles, broken down by safekeeper.
|
||||
pub(crate) storage_controller_safkeeper_reconciles_complete:
|
||||
measured::CounterVec<SafekeeperReconcilerLabelGroupSet>,
|
||||
}
|
||||
|
||||
impl StorageControllerMetrics {
|
||||
@@ -257,6 +265,17 @@ pub(crate) enum Method {
|
||||
Other,
|
||||
}
|
||||
|
||||
#[derive(measured::LabelGroup, Clone)]
|
||||
#[label(set = SafekeeperReconcilerLabelGroupSet)]
|
||||
pub(crate) struct SafekeeperReconcilerLabelGroup<'a> {
|
||||
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
|
||||
pub(crate) sk_az: &'a str,
|
||||
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
|
||||
pub(crate) sk_node_id: &'a str,
|
||||
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
|
||||
pub(crate) sk_hostname: &'a str,
|
||||
}
|
||||
|
||||
impl From<hyper::Method> for Method {
|
||||
fn from(value: hyper::Method) -> Self {
|
||||
if value == hyper::Method::GET {
|
||||
|
||||
@@ -99,8 +99,8 @@ use crate::tenant_shard::{
|
||||
ScheduleOptimization, ScheduleOptimizationAction, TenantShard,
|
||||
};
|
||||
use crate::timeline_import::{
|
||||
ImportResult, ShardImportStatuses, TimelineImport, TimelineImportFinalizeError,
|
||||
TimelineImportState, UpcallClient,
|
||||
FinalizingImport, ImportResult, ShardImportStatuses, TimelineImport,
|
||||
TimelineImportFinalizeError, TimelineImportState, UpcallClient,
|
||||
};
|
||||
|
||||
const WAITER_FILL_DRAIN_POLL_TIMEOUT: Duration = Duration::from_millis(500);
|
||||
@@ -232,6 +232,9 @@ struct ServiceState {
|
||||
|
||||
/// Queue of tenants who are waiting for concurrency limits to permit them to reconcile
|
||||
delayed_reconcile_rx: tokio::sync::mpsc::Receiver<TenantShardId>,
|
||||
|
||||
/// Tracks ongoing timeline import finalization tasks
|
||||
imports_finalizing: BTreeMap<(TenantId, TimelineId), FinalizingImport>,
|
||||
}
|
||||
|
||||
/// Transform an error from a pageserver into an error to return to callers of a storage
|
||||
@@ -308,6 +311,7 @@ impl ServiceState {
|
||||
scheduler,
|
||||
ongoing_operation: None,
|
||||
delayed_reconcile_rx,
|
||||
imports_finalizing: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4097,13 +4101,58 @@ impl Service {
|
||||
///
|
||||
/// If this method gets pre-empted by shut down, it will be called again at start-up (on-going
|
||||
/// imports are stored in the database).
|
||||
///
|
||||
/// # Cancel-Safety
|
||||
/// Not cancel safe.
|
||||
/// If the caller stops polling, the import will not be removed from
|
||||
/// [`ServiceState::imports_finalizing`].
|
||||
#[instrument(skip_all, fields(
|
||||
tenant_id=%import.tenant_id,
|
||||
timeline_id=%import.timeline_id,
|
||||
))]
|
||||
|
||||
async fn finalize_timeline_import(
|
||||
self: &Arc<Self>,
|
||||
import: TimelineImport,
|
||||
) -> Result<(), TimelineImportFinalizeError> {
|
||||
let tenant_timeline = (import.tenant_id, import.timeline_id);
|
||||
|
||||
let (_finalize_import_guard, cancel) = {
|
||||
let mut locked = self.inner.write().unwrap();
|
||||
let gate = Gate::default();
|
||||
let cancel = CancellationToken::default();
|
||||
|
||||
let guard = gate.enter().unwrap();
|
||||
|
||||
locked.imports_finalizing.insert(
|
||||
tenant_timeline,
|
||||
FinalizingImport {
|
||||
gate,
|
||||
cancel: cancel.clone(),
|
||||
},
|
||||
);
|
||||
|
||||
(guard, cancel)
|
||||
};
|
||||
|
||||
let res = tokio::select! {
|
||||
res = self.finalize_timeline_import_impl(import) => {
|
||||
res
|
||||
},
|
||||
_ = cancel.cancelled() => {
|
||||
Err(TimelineImportFinalizeError::Cancelled)
|
||||
}
|
||||
};
|
||||
|
||||
let mut locked = self.inner.write().unwrap();
|
||||
locked.imports_finalizing.remove(&tenant_timeline);
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
async fn finalize_timeline_import_impl(
|
||||
self: &Arc<Self>,
|
||||
import: TimelineImport,
|
||||
) -> Result<(), TimelineImportFinalizeError> {
|
||||
tracing::info!("Finalizing timeline import");
|
||||
|
||||
@@ -4303,6 +4352,46 @@ impl Service {
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Delete a timeline import if it exists
|
||||
///
|
||||
/// Firstly, delete the entry from the database. Any updates
|
||||
/// from pageservers after the update will fail with a 404, so the
|
||||
/// import cannot progress into finalizing state if it's not there already.
|
||||
/// Secondly, cancel the finalization if one is in progress.
|
||||
pub(crate) async fn maybe_delete_timeline_import(
|
||||
self: &Arc<Self>,
|
||||
tenant_id: TenantId,
|
||||
timeline_id: TimelineId,
|
||||
) -> Result<(), DatabaseError> {
|
||||
let tenant_has_ongoing_import = {
|
||||
let locked = self.inner.read().unwrap();
|
||||
locked
|
||||
.tenants
|
||||
.range(TenantShardId::tenant_range(tenant_id))
|
||||
.any(|(_tid, shard)| shard.importing == TimelineImportState::Importing)
|
||||
};
|
||||
|
||||
if !tenant_has_ongoing_import {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.persistence
|
||||
.delete_timeline_import(tenant_id, timeline_id)
|
||||
.await?;
|
||||
|
||||
let maybe_finalizing = {
|
||||
let mut locked = self.inner.write().unwrap();
|
||||
locked.imports_finalizing.remove(&(tenant_id, timeline_id))
|
||||
};
|
||||
|
||||
if let Some(finalizing) = maybe_finalizing {
|
||||
finalizing.cancel.cancel();
|
||||
finalizing.gate.close().await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn tenant_timeline_archival_config(
|
||||
&self,
|
||||
tenant_id: TenantId,
|
||||
@@ -8538,8 +8627,9 @@ impl Service {
|
||||
Some(ShardCount(new_shard_count))
|
||||
}
|
||||
|
||||
/// Fetches the top tenant shards from every node, in descending order of
|
||||
/// max logical size. Any node errors will be logged and ignored.
|
||||
/// Fetches the top tenant shards from every available node, in descending order of
|
||||
/// max logical size. Offline nodes are skipped, and any errors from available nodes
|
||||
/// will be logged and ignored.
|
||||
async fn get_top_tenant_shards(
|
||||
&self,
|
||||
request: &TopTenantShardsRequest,
|
||||
@@ -8550,6 +8640,7 @@ impl Service {
|
||||
.unwrap()
|
||||
.nodes
|
||||
.values()
|
||||
.filter(|node| node.is_available())
|
||||
.cloned()
|
||||
.collect_vec();
|
||||
|
||||
|
||||
@@ -20,7 +20,9 @@ use utils::{
|
||||
};
|
||||
|
||||
use crate::{
|
||||
persistence::SafekeeperTimelineOpKind, safekeeper::Safekeeper,
|
||||
metrics::{METRICS_REGISTRY, SafekeeperReconcilerLabelGroup},
|
||||
persistence::SafekeeperTimelineOpKind,
|
||||
safekeeper::Safekeeper,
|
||||
safekeeper_client::SafekeeperClient,
|
||||
};
|
||||
|
||||
@@ -218,7 +220,26 @@ impl ReconcilerHandle {
|
||||
fn schedule_reconcile(&self, req: ScheduleRequest) {
|
||||
let (cancel, token_id) = self.new_token_slot(req.tenant_id, req.timeline_id);
|
||||
let hostname = req.safekeeper.skp.host.clone();
|
||||
let sk_az = req.safekeeper.skp.availability_zone_id.clone();
|
||||
let sk_node_id = req.safekeeper.get_id().to_string();
|
||||
|
||||
// We don't have direct access to the queue depth here, so increase it blindly by 1.
|
||||
// We know that putting into the queue increases the queue depth. The receiver will
|
||||
// update with the correct value once it processes the next item. To avoid races where we
|
||||
// reduce before we increase, leaving the gauge with a 1 value for a long time, we
|
||||
// increase it before putting into the queue.
|
||||
let queued_gauge = &METRICS_REGISTRY
|
||||
.metrics_group
|
||||
.storage_controller_safkeeper_reconciles_queued;
|
||||
let label_group = SafekeeperReconcilerLabelGroup {
|
||||
sk_az: &sk_az,
|
||||
sk_node_id: &sk_node_id,
|
||||
sk_hostname: &hostname,
|
||||
};
|
||||
queued_gauge.inc(label_group.clone());
|
||||
|
||||
if let Err(err) = self.tx.send((req, cancel, token_id)) {
|
||||
queued_gauge.set(label_group, 0);
|
||||
tracing::info!("scheduling request onto {hostname} returned error: {err}");
|
||||
}
|
||||
}
|
||||
@@ -283,6 +304,18 @@ impl SafekeeperReconciler {
|
||||
continue;
|
||||
}
|
||||
|
||||
let queued_gauge = &METRICS_REGISTRY
|
||||
.metrics_group
|
||||
.storage_controller_safkeeper_reconciles_queued;
|
||||
queued_gauge.set(
|
||||
SafekeeperReconcilerLabelGroup {
|
||||
sk_az: &req.safekeeper.skp.availability_zone_id,
|
||||
sk_node_id: &req.safekeeper.get_id().to_string(),
|
||||
sk_hostname: &req.safekeeper.skp.host,
|
||||
},
|
||||
self.rx.len() as i64,
|
||||
);
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
let kind = req.kind;
|
||||
let tenant_id = req.tenant_id;
|
||||
@@ -511,6 +544,16 @@ impl SafekeeperReconcilerInner {
|
||||
req.generation,
|
||||
)
|
||||
.await;
|
||||
|
||||
let complete_counter = &METRICS_REGISTRY
|
||||
.metrics_group
|
||||
.storage_controller_safkeeper_reconciles_complete;
|
||||
complete_counter.inc(SafekeeperReconcilerLabelGroup {
|
||||
sk_az: &req.safekeeper.skp.availability_zone_id,
|
||||
sk_node_id: &req.safekeeper.get_id().to_string(),
|
||||
sk_hostname: &req.safekeeper.skp.host,
|
||||
});
|
||||
|
||||
if let Err(err) = res {
|
||||
tracing::info!(
|
||||
"couldn't remove reconciliation request onto {} from persistence: {err:?}",
|
||||
|
||||
@@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use pageserver_api::models::{ShardImportProgress, ShardImportStatus};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use utils::sync::gate::Gate;
|
||||
use utils::{
|
||||
id::{TenantId, TimelineId},
|
||||
shard::ShardIndex,
|
||||
@@ -55,6 +56,8 @@ pub(crate) enum TimelineImportUpdateFollowUp {
|
||||
pub(crate) enum TimelineImportFinalizeError {
|
||||
#[error("Shut down interrupted import finalize")]
|
||||
ShuttingDown,
|
||||
#[error("Import finalization was cancelled")]
|
||||
Cancelled,
|
||||
#[error("Mismatched shard detected during import finalize: {0}")]
|
||||
MismatchedShards(ShardIndex),
|
||||
}
|
||||
@@ -164,6 +167,11 @@ impl TimelineImport {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct FinalizingImport {
|
||||
pub(crate) gate: Gate,
|
||||
pub(crate) cancel: CancellationToken,
|
||||
}
|
||||
|
||||
pub(crate) type ImportResult = Result<(), String>;
|
||||
|
||||
pub(crate) struct UpcallClient {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
@@ -11,6 +12,7 @@ from _pytest.config import Config
|
||||
|
||||
from fixtures.log_helper import log
|
||||
from fixtures.neon_cli import AbstractNeonCli
|
||||
from fixtures.neon_fixtures import Endpoint, VanillaPostgres
|
||||
from fixtures.pg_version import PgVersion
|
||||
from fixtures.remote_storage import MockS3Server
|
||||
|
||||
@@ -161,3 +163,57 @@ def fast_import(
|
||||
f.write(fi.cmd.stderr)
|
||||
|
||||
log.info("Written logs to %s", test_output_dir)
|
||||
|
||||
|
||||
def mock_import_bucket(vanilla_pg: VanillaPostgres, path: Path):
|
||||
"""
|
||||
Mock the import S3 bucket into a local directory for a provided vanilla PG instance.
|
||||
"""
|
||||
assert not vanilla_pg.is_running()
|
||||
|
||||
path.mkdir()
|
||||
# what cplane writes before scheduling fast_import
|
||||
specpath = path / "spec.json"
|
||||
specpath.write_text(json.dumps({"branch_id": "somebranch", "project_id": "someproject"}))
|
||||
# what fast_import writes
|
||||
vanilla_pg.pgdatadir.rename(path / "pgdata")
|
||||
statusdir = path / "status"
|
||||
statusdir.mkdir()
|
||||
(statusdir / "pgdata").write_text(json.dumps({"done": True}))
|
||||
(statusdir / "fast_import").write_text(json.dumps({"command": "pgdata", "done": True}))
|
||||
|
||||
|
||||
def populate_vanilla_pg(vanilla_pg: VanillaPostgres, target_relblock_size: int) -> int:
|
||||
assert vanilla_pg.is_running()
|
||||
|
||||
vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser")
|
||||
# fillfactor so we don't need to produce that much data
|
||||
# 900 byte per row is > 10% => 1 row per page
|
||||
vanilla_pg.safe_psql("""create table t (data char(900)) with (fillfactor = 10)""")
|
||||
|
||||
nrows = 0
|
||||
while True:
|
||||
relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')")
|
||||
log.info(
|
||||
f"relblock size: {relblock_size / 8192} pages (target: {target_relblock_size // 8192}) pages"
|
||||
)
|
||||
if relblock_size >= target_relblock_size:
|
||||
break
|
||||
addrows = int((target_relblock_size - relblock_size) // 8192)
|
||||
assert addrows >= 1, "forward progress"
|
||||
vanilla_pg.safe_psql(
|
||||
f"insert into t select generate_series({nrows + 1}, {nrows + addrows})"
|
||||
)
|
||||
nrows += addrows
|
||||
|
||||
return nrows
|
||||
|
||||
|
||||
def validate_import_from_vanilla_pg(endpoint: Endpoint, nrows: int):
|
||||
assert endpoint.safe_psql_many(
|
||||
[
|
||||
"set effective_io_concurrency=32;",
|
||||
"SET statement_timeout='300s';",
|
||||
"select count(*), sum(data::bigint)::bigint from t",
|
||||
]
|
||||
) == [[], [], [(nrows, nrows * (nrows + 1) // 2)]]
|
||||
|
||||
@@ -2337,6 +2337,22 @@ class NeonStorageController(MetricsGetter, LogUtils):
|
||||
headers=self.headers(TokenScope.ADMIN),
|
||||
)
|
||||
|
||||
def import_status(
|
||||
self, tenant_shard_id: TenantShardId, timeline_id: TimelineId, generation: int
|
||||
):
|
||||
payload = {
|
||||
"tenant_shard_id": str(tenant_shard_id),
|
||||
"timeline_id": str(timeline_id),
|
||||
"generation": generation,
|
||||
}
|
||||
|
||||
self.request(
|
||||
"GET",
|
||||
f"{self.api}/upcall/v1/timeline_import_status",
|
||||
headers=self.headers(TokenScope.GENERATIONS_API),
|
||||
json=payload,
|
||||
)
|
||||
|
||||
def reconcile_all(self):
|
||||
r = self.request(
|
||||
"POST",
|
||||
@@ -2813,6 +2829,11 @@ class NeonPageserver(PgProtocol, LogUtils):
|
||||
if self.running:
|
||||
self.http_client().configure_failpoints([(name, action)])
|
||||
|
||||
def clear_persistent_failpoint(self, name: str):
|
||||
del self._persistent_failpoints[name]
|
||||
if self.running:
|
||||
self.http_client().configure_failpoints([(name, "off")])
|
||||
|
||||
def timeline_dir(
|
||||
self,
|
||||
tenant_shard_id: TenantId | TenantShardId,
|
||||
|
||||
@@ -675,7 +675,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
|
||||
|
||||
def timeline_delete(
|
||||
self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, **kwargs
|
||||
):
|
||||
) -> int:
|
||||
"""
|
||||
Note that deletion is not instant, it is scheduled and performed mostly in the background.
|
||||
So if you need to wait for it to complete use `timeline_delete_wait_completed`.
|
||||
@@ -688,6 +688,8 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
|
||||
res_json = res.json()
|
||||
assert res_json is None
|
||||
|
||||
return res.status_code
|
||||
|
||||
def timeline_gc(
|
||||
self,
|
||||
tenant_id: TenantId | TenantShardId,
|
||||
|
||||
@@ -1,31 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import json
|
||||
import time
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from threading import Event
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from fixtures.common_types import Lsn, TenantId, TimelineId
|
||||
from fixtures.fast_import import mock_import_bucket, populate_vanilla_pg
|
||||
from fixtures.log_helper import log
|
||||
from fixtures.neon_fixtures import (
|
||||
NeonEnv,
|
||||
NeonEnvBuilder,
|
||||
NeonPageserver,
|
||||
PgBin,
|
||||
VanillaPostgres,
|
||||
wait_for_last_flush_lsn,
|
||||
)
|
||||
from fixtures.pageserver.http import (
|
||||
ImportPgdataIdemptencyKey,
|
||||
)
|
||||
from fixtures.pageserver.utils import wait_for_upload_queue_empty
|
||||
from fixtures.remote_storage import RemoteStorageKind
|
||||
from fixtures.utils import human_bytes, wait_until
|
||||
from fixtures.utils import human_bytes, run_only_on_default_postgres, wait_until
|
||||
from werkzeug.wrappers.response import Response
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from fixtures.pageserver.http import PageserverHttpClient
|
||||
from pytest_httpserver import HTTPServer
|
||||
from werkzeug.wrappers.request import Request
|
||||
|
||||
|
||||
GLOBAL_LRU_LOG_LINE = "tenant_min_resident_size-respecting LRU would not relieve pressure, evicting more following global LRU policy"
|
||||
@@ -164,6 +174,7 @@ class EvictionEnv:
|
||||
min_avail_bytes,
|
||||
mock_behavior,
|
||||
eviction_order: EvictionOrder,
|
||||
wait_logical_size: bool = True,
|
||||
):
|
||||
"""
|
||||
Starts pageserver up with mocked statvfs setup. The startup is
|
||||
@@ -201,11 +212,12 @@ class EvictionEnv:
|
||||
pageserver.start()
|
||||
|
||||
# we now do initial logical size calculation on startup, which on debug builds can fight with disk usage based eviction
|
||||
for tenant_id, timeline_id in self.timelines:
|
||||
tenant_ps = self.neon_env.get_tenant_pageserver(tenant_id)
|
||||
# Pageserver may be none if we are currently not attached anywhere, e.g. during secondary eviction test
|
||||
if tenant_ps is not None:
|
||||
tenant_ps.http_client().timeline_wait_logical_size(tenant_id, timeline_id)
|
||||
if wait_logical_size:
|
||||
for tenant_id, timeline_id in self.timelines:
|
||||
tenant_ps = self.neon_env.get_tenant_pageserver(tenant_id)
|
||||
# Pageserver may be none if we are currently not attached anywhere, e.g. during secondary eviction test
|
||||
if tenant_ps is not None:
|
||||
tenant_ps.http_client().timeline_wait_logical_size(tenant_id, timeline_id)
|
||||
|
||||
def statvfs_called():
|
||||
pageserver.assert_log_contains(".*running mocked statvfs.*")
|
||||
@@ -882,3 +894,121 @@ def test_secondary_mode_eviction(eviction_env_ha: EvictionEnv):
|
||||
assert total_size - post_eviction_total_size >= evict_bytes, (
|
||||
"we requested at least evict_bytes worth of free space"
|
||||
)
|
||||
|
||||
|
||||
@run_only_on_default_postgres(reason="PG version is irrelevant here")
|
||||
def test_import_timeline_disk_pressure_eviction(
|
||||
neon_env_builder: NeonEnvBuilder,
|
||||
vanilla_pg: VanillaPostgres,
|
||||
make_httpserver: HTTPServer,
|
||||
pg_bin: PgBin,
|
||||
):
|
||||
"""
|
||||
TODO
|
||||
"""
|
||||
# Set up mock control plane HTTP server to listen for import completions
|
||||
import_completion_signaled = Event()
|
||||
|
||||
def handler(request: Request) -> Response:
|
||||
log.info(f"control plane /import_complete request: {request.json}")
|
||||
import_completion_signaled.set()
|
||||
return Response(json.dumps({}), status=200)
|
||||
|
||||
cplane_mgmt_api_server = make_httpserver
|
||||
cplane_mgmt_api_server.expect_request(
|
||||
"/storage/api/v1/import_complete", method="PUT"
|
||||
).respond_with_handler(handler)
|
||||
|
||||
# Plug the cplane mock in
|
||||
neon_env_builder.control_plane_hooks_api = (
|
||||
f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/storage/api/v1/"
|
||||
)
|
||||
|
||||
# The import will specifiy a local filesystem path mocking remote storage
|
||||
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
|
||||
|
||||
vanilla_pg.start()
|
||||
target_relblock_size = 1024 * 1024 * 128
|
||||
populate_vanilla_pg(vanilla_pg, target_relblock_size)
|
||||
vanilla_pg.stop()
|
||||
|
||||
env = neon_env_builder.init_configs()
|
||||
env.start()
|
||||
|
||||
importbucket_path = neon_env_builder.repo_dir / "test_import_completion_bucket"
|
||||
mock_import_bucket(vanilla_pg, importbucket_path)
|
||||
|
||||
tenant_id = TenantId.generate()
|
||||
timeline_id = TimelineId.generate()
|
||||
idempotency = ImportPgdataIdemptencyKey.random()
|
||||
|
||||
eviction_env = EvictionEnv(
|
||||
timelines=[(tenant_id, timeline_id)],
|
||||
neon_env=env,
|
||||
pageserver_http=env.pageserver.http_client(),
|
||||
layer_size=5 * 1024 * 1024, # Doesn't apply here
|
||||
pg_bin=pg_bin, # Not used here
|
||||
pgbench_init_lsns={}, # Not used here
|
||||
)
|
||||
|
||||
# Pause before delivering the final notification to storcon.
|
||||
# This keeps the import in progress.
|
||||
failpoint_name = "import-timeline-pre-success-notify-pausable"
|
||||
env.pageserver.add_persistent_failpoint(failpoint_name, "pause")
|
||||
|
||||
env.storage_controller.tenant_create(tenant_id)
|
||||
env.storage_controller.timeline_create(
|
||||
tenant_id,
|
||||
{
|
||||
"new_timeline_id": str(timeline_id),
|
||||
"import_pgdata": {
|
||||
"idempotency_key": str(idempotency),
|
||||
"location": {"LocalFs": {"path": str(importbucket_path.absolute())}},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def hit_failpoint():
|
||||
log.info("Checking log for pattern...")
|
||||
try:
|
||||
assert env.pageserver.log_contains(f".*at failpoint {failpoint_name}.*")
|
||||
except Exception:
|
||||
log.exception("Failed to find pattern in log")
|
||||
raise
|
||||
|
||||
wait_until(hit_failpoint)
|
||||
assert not import_completion_signaled.is_set()
|
||||
|
||||
env.pageserver.stop()
|
||||
|
||||
total_size, _, _ = eviction_env.timelines_du(env.pageserver)
|
||||
blocksize = 512
|
||||
total_blocks = (total_size + (blocksize - 1)) // blocksize
|
||||
|
||||
eviction_env.pageserver_start_with_disk_usage_eviction(
|
||||
env.pageserver,
|
||||
period="1s",
|
||||
max_usage_pct=33,
|
||||
min_avail_bytes=0,
|
||||
mock_behavior={
|
||||
"type": "Success",
|
||||
"blocksize": blocksize,
|
||||
"total_blocks": total_blocks,
|
||||
# Only count layer files towards used bytes in the mock_statvfs.
|
||||
# This avoids accounting for metadata files & tenant conf in the tests.
|
||||
"name_filter": ".*__.*",
|
||||
},
|
||||
eviction_order=EvictionOrder.RELATIVE_ORDER_SPARE,
|
||||
wait_logical_size=False,
|
||||
)
|
||||
|
||||
wait_until(lambda: env.pageserver.assert_log_contains(".*disk usage pressure relieved"))
|
||||
|
||||
env.pageserver.clear_persistent_failpoint(failpoint_name)
|
||||
|
||||
def cplane_notified():
|
||||
assert import_completion_signaled.is_set()
|
||||
|
||||
wait_until(cplane_notified)
|
||||
|
||||
env.pageserver.allowed_errors.append(r".* running disk usage based eviction due to pressure.*")
|
||||
|
||||
@@ -12,13 +12,19 @@ import psycopg2
|
||||
import psycopg2.errors
|
||||
import pytest
|
||||
from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId
|
||||
from fixtures.fast_import import FastImport
|
||||
from fixtures.fast_import import (
|
||||
FastImport,
|
||||
mock_import_bucket,
|
||||
populate_vanilla_pg,
|
||||
validate_import_from_vanilla_pg,
|
||||
)
|
||||
from fixtures.log_helper import log
|
||||
from fixtures.neon_fixtures import (
|
||||
NeonEnvBuilder,
|
||||
PageserverImportConfig,
|
||||
PgBin,
|
||||
PgProtocol,
|
||||
StorageControllerApiException,
|
||||
StorageControllerMigrationConfig,
|
||||
VanillaPostgres,
|
||||
)
|
||||
@@ -59,24 +65,6 @@ smoke_params = [
|
||||
]
|
||||
|
||||
|
||||
def mock_import_bucket(vanilla_pg: VanillaPostgres, path: Path):
|
||||
"""
|
||||
Mock the import S3 bucket into a local directory for a provided vanilla PG instance.
|
||||
"""
|
||||
assert not vanilla_pg.is_running()
|
||||
|
||||
path.mkdir()
|
||||
# what cplane writes before scheduling fast_import
|
||||
specpath = path / "spec.json"
|
||||
specpath.write_text(json.dumps({"branch_id": "somebranch", "project_id": "someproject"}))
|
||||
# what fast_import writes
|
||||
vanilla_pg.pgdatadir.rename(path / "pgdata")
|
||||
statusdir = path / "status"
|
||||
statusdir.mkdir()
|
||||
(statusdir / "pgdata").write_text(json.dumps({"done": True}))
|
||||
(statusdir / "fast_import").write_text(json.dumps({"command": "pgdata", "done": True}))
|
||||
|
||||
|
||||
@skip_in_debug_build("MULTIPLE_RELATION_SEGMENTS has non trivial amount of data")
|
||||
@pytest.mark.parametrize("shard_count,stripe_size,rel_block_size", smoke_params)
|
||||
def test_pgdata_import_smoke(
|
||||
@@ -131,10 +119,6 @@ def test_pgdata_import_smoke(
|
||||
# Put data in vanilla pg
|
||||
#
|
||||
|
||||
vanilla_pg.start()
|
||||
vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser")
|
||||
|
||||
log.info("create relblock data")
|
||||
if rel_block_size == RelBlockSize.ONE_STRIPE_SIZE:
|
||||
target_relblock_size = stripe_size * 8192
|
||||
elif rel_block_size == RelBlockSize.TWO_STRPES_PER_SHARD:
|
||||
@@ -145,45 +129,8 @@ def test_pgdata_import_smoke(
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
# fillfactor so we don't need to produce that much data
|
||||
# 900 byte per row is > 10% => 1 row per page
|
||||
vanilla_pg.safe_psql("""create table t (data char(900)) with (fillfactor = 10)""")
|
||||
|
||||
nrows = 0
|
||||
while True:
|
||||
relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')")
|
||||
log.info(
|
||||
f"relblock size: {relblock_size / 8192} pages (target: {target_relblock_size // 8192}) pages"
|
||||
)
|
||||
if relblock_size >= target_relblock_size:
|
||||
break
|
||||
addrows = int((target_relblock_size - relblock_size) // 8192)
|
||||
assert addrows >= 1, "forward progress"
|
||||
vanilla_pg.safe_psql(
|
||||
f"insert into t select generate_series({nrows + 1}, {nrows + addrows})"
|
||||
)
|
||||
nrows += addrows
|
||||
expect_nrows = nrows
|
||||
expect_sum = (
|
||||
(nrows) * (nrows + 1) // 2
|
||||
) # https://stackoverflow.com/questions/43901484/sum-of-the-integers-from-1-to-n
|
||||
|
||||
def validate_vanilla_equivalence(ep):
|
||||
# TODO: would be nicer to just compare pgdump
|
||||
|
||||
# Enable IO concurrency for batching on large sequential scan, to avoid making
|
||||
# this test unnecessarily onerous on CPU. Especially on debug mode, it's still
|
||||
# pretty onerous though, so increase statement_timeout to avoid timeouts.
|
||||
assert ep.safe_psql_many(
|
||||
[
|
||||
"set effective_io_concurrency=32;",
|
||||
"SET statement_timeout='300s';",
|
||||
"select count(*), sum(data::bigint)::bigint from t",
|
||||
]
|
||||
) == [[], [], [(expect_nrows, expect_sum)]]
|
||||
|
||||
validate_vanilla_equivalence(vanilla_pg)
|
||||
|
||||
vanilla_pg.start()
|
||||
rows_inserted = populate_vanilla_pg(vanilla_pg, target_relblock_size)
|
||||
vanilla_pg.stop()
|
||||
|
||||
#
|
||||
@@ -274,14 +221,14 @@ def test_pgdata_import_smoke(
|
||||
config_lines=ep_config,
|
||||
)
|
||||
|
||||
validate_vanilla_equivalence(ro_endpoint)
|
||||
validate_import_from_vanilla_pg(ro_endpoint, rows_inserted)
|
||||
|
||||
# ensure the import survives restarts
|
||||
ro_endpoint.stop()
|
||||
env.pageserver.stop(immediate=True)
|
||||
env.pageserver.start()
|
||||
ro_endpoint.start()
|
||||
validate_vanilla_equivalence(ro_endpoint)
|
||||
validate_import_from_vanilla_pg(ro_endpoint, rows_inserted)
|
||||
|
||||
#
|
||||
# validate the layer files in each shard only have the shard-specific data
|
||||
@@ -321,7 +268,7 @@ def test_pgdata_import_smoke(
|
||||
child_workload = workload.branch(timeline_id=child_timeline_id, branch_name="br-tip")
|
||||
child_workload.validate()
|
||||
|
||||
validate_vanilla_equivalence(child_workload.endpoint())
|
||||
validate_import_from_vanilla_pg(child_workload.endpoint(), rows_inserted)
|
||||
|
||||
# ... at the initdb lsn
|
||||
_ = env.create_branch(
|
||||
@@ -336,10 +283,21 @@ def test_pgdata_import_smoke(
|
||||
tenant_id=tenant_id,
|
||||
config_lines=ep_config,
|
||||
)
|
||||
validate_vanilla_equivalence(br_initdb_endpoint)
|
||||
validate_import_from_vanilla_pg(br_initdb_endpoint, rows_inserted)
|
||||
with pytest.raises(psycopg2.errors.UndefinedTable):
|
||||
br_initdb_endpoint.safe_psql(f"select * from {workload.table}")
|
||||
|
||||
# The storage controller might be overly eager and attempt to finalize
|
||||
# the import before the task got a chance to exit.
|
||||
env.storage_controller.allowed_errors.extend(
|
||||
[
|
||||
".*Call to node.*management API.*failed.*Import task still running.*",
|
||||
]
|
||||
)
|
||||
|
||||
for ps in env.pageservers:
|
||||
ps.allowed_errors.extend([".*Error processing HTTP request.*Import task not done yet.*"])
|
||||
|
||||
|
||||
@run_only_on_default_postgres(reason="PG version is irrelevant here")
|
||||
def test_import_completion_on_restart(
|
||||
@@ -423,8 +381,12 @@ def test_import_completion_on_restart(
|
||||
|
||||
|
||||
@run_only_on_default_postgres(reason="PG version is irrelevant here")
|
||||
def test_import_respects_tenant_shutdown(
|
||||
neon_env_builder: NeonEnvBuilder, vanilla_pg: VanillaPostgres, make_httpserver: HTTPServer
|
||||
@pytest.mark.parametrize("action", ["restart", "delete"])
|
||||
def test_import_respects_timeline_lifecycle(
|
||||
neon_env_builder: NeonEnvBuilder,
|
||||
vanilla_pg: VanillaPostgres,
|
||||
make_httpserver: HTTPServer,
|
||||
action: str,
|
||||
):
|
||||
"""
|
||||
Validate that importing timelines respect the usual timeline life cycle:
|
||||
@@ -492,16 +454,44 @@ def test_import_respects_tenant_shutdown(
|
||||
wait_until(hit_failpoint)
|
||||
assert not import_completion_signaled.is_set()
|
||||
|
||||
# Restart the pageserver while an import job is in progress.
|
||||
# This clears the failpoint and we expect that the import starts up afresh
|
||||
# after the restart and eventually completes.
|
||||
env.pageserver.stop()
|
||||
env.pageserver.start()
|
||||
if action == "restart":
|
||||
# Restart the pageserver while an import job is in progress.
|
||||
# This clears the failpoint and we expect that the import starts up afresh
|
||||
# after the restart and eventually completes.
|
||||
env.pageserver.stop()
|
||||
env.pageserver.start()
|
||||
|
||||
def cplane_notified():
|
||||
assert import_completion_signaled.is_set()
|
||||
def cplane_notified():
|
||||
assert import_completion_signaled.is_set()
|
||||
|
||||
wait_until(cplane_notified)
|
||||
wait_until(cplane_notified)
|
||||
elif action == "delete":
|
||||
status = env.storage_controller.pageserver_api().timeline_delete(tenant_id, timeline_id)
|
||||
assert status == 200
|
||||
|
||||
timeline_path = env.pageserver.timeline_dir(tenant_id, timeline_id)
|
||||
assert not timeline_path.exists(), "Timeline dir exists after deletion"
|
||||
|
||||
shard_zero = TenantShardId(tenant_id, 0, 0)
|
||||
location = env.storage_controller.inspect(shard_zero)
|
||||
assert location is not None
|
||||
generation = location[0]
|
||||
|
||||
with pytest.raises(StorageControllerApiException, match="not found"):
|
||||
env.storage_controller.import_status(shard_zero, timeline_id, generation)
|
||||
else:
|
||||
raise RuntimeError(f"{action} param not recognized")
|
||||
|
||||
# The storage controller might be overly eager and attempt to finalize
|
||||
# the import before the task got a chance to exit.
|
||||
env.storage_controller.allowed_errors.extend(
|
||||
[
|
||||
".*Call to node.*management API.*failed.*Import task still running.*",
|
||||
]
|
||||
)
|
||||
|
||||
for ps in env.pageservers:
|
||||
ps.allowed_errors.extend([".*Error processing HTTP request.*Import task not done yet.*"])
|
||||
|
||||
|
||||
@skip_in_debug_build("Validation query takes too long in debug builds")
|
||||
@@ -556,23 +546,8 @@ def test_import_chaos(
|
||||
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
|
||||
|
||||
vanilla_pg.start()
|
||||
vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser")
|
||||
vanilla_pg.safe_psql("""create table t (data char(900)) with (fillfactor = 10)""")
|
||||
|
||||
nrows = 0
|
||||
while True:
|
||||
relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')")
|
||||
log.info(
|
||||
f"relblock size: {relblock_size / 8192} pages (target: {TARGET_RELBOCK_SIZE // 8192}) pages"
|
||||
)
|
||||
if relblock_size >= TARGET_RELBOCK_SIZE:
|
||||
break
|
||||
addrows = int((TARGET_RELBOCK_SIZE - relblock_size) // 8192)
|
||||
assert addrows >= 1, "forward progress"
|
||||
vanilla_pg.safe_psql(
|
||||
f"insert into t select generate_series({nrows + 1}, {nrows + addrows})"
|
||||
)
|
||||
nrows += addrows
|
||||
inserted_rows = populate_vanilla_pg(vanilla_pg, TARGET_RELBOCK_SIZE)
|
||||
|
||||
vanilla_pg.stop()
|
||||
|
||||
@@ -740,13 +715,7 @@ def test_import_chaos(
|
||||
endpoint = env.endpoints.create_start(branch_name=import_branch_name, tenant_id=tenant_id)
|
||||
|
||||
# Validate the imported data is legit
|
||||
assert endpoint.safe_psql_many(
|
||||
[
|
||||
"set effective_io_concurrency=32;",
|
||||
"SET statement_timeout='300s';",
|
||||
"select count(*), sum(data::bigint)::bigint from t",
|
||||
]
|
||||
) == [[], [], [(nrows, nrows * (nrows + 1) // 2)]]
|
||||
validate_import_from_vanilla_pg(endpoint, inserted_rows)
|
||||
|
||||
endpoint.stop()
|
||||
|
||||
|
||||
@@ -20,6 +20,9 @@ from fixtures.remote_storage import LocalFsStorage, RemoteStorageKind
|
||||
from fixtures.utils import query_scalar, wait_until
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="We won't create future layers any more after https://github.com/neondatabase/neon/pull/10548"
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"attach_mode",
|
||||
["default_generation", "same_generation"],
|
||||
|
||||
@@ -10,9 +10,14 @@ if TYPE_CHECKING:
|
||||
from fixtures.pageserver.http import PageserverHttpClient
|
||||
|
||||
|
||||
def check_tenant(env: NeonEnv, pageserver_http: PageserverHttpClient):
|
||||
def check_tenant(
|
||||
env: NeonEnv, pageserver_http: PageserverHttpClient, safekeeper_proto_version: int
|
||||
):
|
||||
tenant_id, timeline_id = env.create_tenant()
|
||||
endpoint = env.endpoints.create_start("main", tenant_id=tenant_id)
|
||||
config_lines = [
|
||||
f"neon.safekeeper_proto_version = {safekeeper_proto_version}",
|
||||
]
|
||||
endpoint = env.endpoints.create_start("main", tenant_id=tenant_id, config_lines=config_lines)
|
||||
# we rely upon autocommit after each statement
|
||||
res_1 = endpoint.safe_psql_many(
|
||||
queries=[
|
||||
@@ -37,10 +42,13 @@ def check_tenant(env: NeonEnv, pageserver_http: PageserverHttpClient):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_timelines,num_safekeepers", [(3, 1)])
|
||||
# Test both proto versions until we fully migrate.
|
||||
@pytest.mark.parametrize("safekeeper_proto_version", [2, 3])
|
||||
def test_normal_work(
|
||||
neon_env_builder: NeonEnvBuilder,
|
||||
num_timelines: int,
|
||||
num_safekeepers: int,
|
||||
safekeeper_proto_version: int,
|
||||
):
|
||||
"""
|
||||
Basic test:
|
||||
@@ -60,4 +68,4 @@ def test_normal_work(
|
||||
pageserver_http = env.pageserver.http_client()
|
||||
|
||||
for _ in range(num_timelines):
|
||||
check_tenant(env, pageserver_http)
|
||||
check_tenant(env, pageserver_http, safekeeper_proto_version)
|
||||
|
||||
@@ -124,6 +124,9 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver,
|
||||
".*downloading failed, possibly for shutdown",
|
||||
# {tenant_id=... timeline_id=...}:handle_pagerequests:handle_get_page_at_lsn_request{rel=1664/0/1260 blkno=0 req_lsn=0/149F0D8}: error reading relation or page version: Not found: will not become active. Current state: Stopping\n'
|
||||
".*page_service.*will not become active.*",
|
||||
# the following errors are possible when pageserver tries to ingest wal records despite being in unreadable state
|
||||
".*wal_connection_manager.*layer file download failed: No file found.*",
|
||||
".*wal_connection_manager.*could not ingest record.*",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -156,6 +159,45 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver,
|
||||
env.pageservers[2].id: ("Detached", None),
|
||||
}
|
||||
|
||||
# Track all the attached locations with mode and generation
|
||||
history: list[tuple[int, str, int | None]] = []
|
||||
|
||||
def may_read(pageserver: NeonPageserver, mode: str, generation: int | None) -> bool:
|
||||
# Rules for when a pageserver may read:
|
||||
# - our generation is higher than any previous
|
||||
# - our generation is equal to previous, but no other pageserver
|
||||
# in that generation has been AttachedSingle (i.e. allowed to compact/GC)
|
||||
# - our generation is equal to previous, and the previous holder of this
|
||||
# generation was the same node as we're attaching now.
|
||||
#
|
||||
# If these conditions are not met, then a read _might_ work, but the pageserver might
|
||||
# also hit errors trying to download layers.
|
||||
highest_historic_generation = max([i[2] for i in history if i[2] is not None], default=None)
|
||||
|
||||
if generation is None:
|
||||
# We're not in an attached state, we may not read
|
||||
return False
|
||||
elif highest_historic_generation is not None and generation < highest_historic_generation:
|
||||
# We are in an outdated generation, we may not read
|
||||
return False
|
||||
elif highest_historic_generation is not None and generation == highest_historic_generation:
|
||||
# We are re-using a generation: if any pageserver other than this one
|
||||
# has held AttachedSingle mode, this node may not read (because some other
|
||||
# node may be doing GC/compaction).
|
||||
if any(
|
||||
i[1] == "AttachedSingle"
|
||||
and i[2] == highest_historic_generation
|
||||
and i[0] != pageserver.id
|
||||
for i in history
|
||||
):
|
||||
log.info(
|
||||
f"Skipping read on {pageserver.id} because other pageserver has been in AttachedSingle mode in generation {highest_historic_generation}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Fall through: we have passed conditions for readability
|
||||
return True
|
||||
|
||||
latest_attached = env.pageservers[0].id
|
||||
|
||||
for _i in range(0, 64):
|
||||
@@ -199,9 +241,10 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver,
|
||||
assert len(tenants) == 1
|
||||
assert tenants[0]["generation"] == new_generation
|
||||
|
||||
log.info("Entering postgres...")
|
||||
workload.churn_rows(rng.randint(128, 256), pageserver.id)
|
||||
workload.validate(pageserver.id)
|
||||
if may_read(pageserver, last_state_ps[0], last_state_ps[1]):
|
||||
log.info("Entering postgres...")
|
||||
workload.churn_rows(rng.randint(128, 256), pageserver.id)
|
||||
workload.validate(pageserver.id)
|
||||
elif last_state_ps[0].startswith("Attached"):
|
||||
# The `storage_controller` will only re-attach on startup when a pageserver was the
|
||||
# holder of the latest generation: otherwise the pageserver will revert to detached
|
||||
@@ -241,18 +284,16 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver,
|
||||
location_conf["generation"] = generation
|
||||
|
||||
pageserver.tenant_location_configure(tenant_id, location_conf)
|
||||
|
||||
last_state[pageserver.id] = (mode, generation)
|
||||
|
||||
# It's only valid to connect to the last generation. Newer generations may yank layer
|
||||
# files used in older generations.
|
||||
last_generation = max(
|
||||
[s[1] for s in last_state.values() if s[1] is not None], default=None
|
||||
)
|
||||
may_read_this_generation = may_read(pageserver, mode, generation)
|
||||
history.append((pageserver.id, mode, generation))
|
||||
|
||||
if mode.startswith("Attached") and generation == last_generation:
|
||||
# This is a basic test: we are validating that he endpoint works properly _between_
|
||||
# configuration changes. A stronger test would be to validate that clients see
|
||||
# no errors while we are making the changes.
|
||||
# This is a basic test: we are validating that he endpoint works properly _between_
|
||||
# configuration changes. A stronger test would be to validate that clients see
|
||||
# no errors while we are making the changes.
|
||||
if may_read_this_generation:
|
||||
workload.churn_rows(
|
||||
rng.randint(128, 256), pageserver.id, upload=mode != "AttachedStale"
|
||||
)
|
||||
@@ -265,9 +306,16 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver,
|
||||
assert gc_summary["remote_storage_errors"] == 0
|
||||
assert gc_summary["indices_deleted"] > 0
|
||||
|
||||
# Attach all pageservers
|
||||
# Attach all pageservers, in a higher generation than any previous. We will use the same
|
||||
# gen for all, and AttachedMulti mode so that they do not interfere with one another.
|
||||
generation = env.storage_controller.attach_hook_issue(tenant_id, env.pageservers[0].id)
|
||||
for ps in env.pageservers:
|
||||
location_conf = {"mode": "AttachedMulti", "secondary_conf": None, "tenant_conf": {}}
|
||||
location_conf = {
|
||||
"mode": "AttachedMulti",
|
||||
"secondary_conf": None,
|
||||
"tenant_conf": {},
|
||||
"generation": generation,
|
||||
}
|
||||
ps.tenant_location_configure(tenant_id, location_conf)
|
||||
|
||||
# Confirm that all are readable
|
||||
|
||||
@@ -4192,10 +4192,10 @@ def test_storcon_create_delete_sk_down(
|
||||
# ensure the safekeeper deleted the timeline
|
||||
def timeline_deleted_on_active_sks():
|
||||
env.safekeepers[0].assert_log_contains(
|
||||
f"deleting timeline {tenant_id}/{child_timeline_id} from disk"
|
||||
f"((deleting timeline|Timeline) {tenant_id}/{child_timeline_id} (from disk|was already deleted)|DELETE.*tenant/{tenant_id} .*status: 200 OK)"
|
||||
)
|
||||
env.safekeepers[2].assert_log_contains(
|
||||
f"deleting timeline {tenant_id}/{child_timeline_id} from disk"
|
||||
f"((deleting timeline|Timeline) {tenant_id}/{child_timeline_id} (from disk|was already deleted)|DELETE.*tenant/{tenant_id} .*status: 200 OK)"
|
||||
)
|
||||
|
||||
wait_until(timeline_deleted_on_active_sks)
|
||||
@@ -4210,7 +4210,7 @@ def test_storcon_create_delete_sk_down(
|
||||
# ensure that there is log msgs for the third safekeeper too
|
||||
def timeline_deleted_on_sk():
|
||||
env.safekeepers[1].assert_log_contains(
|
||||
f"deleting timeline {tenant_id}/{child_timeline_id} from disk"
|
||||
f"((deleting timeline|Timeline) {tenant_id}/{child_timeline_id} (from disk|was already deleted)|DELETE.*tenant/{tenant_id} .*status: 200 OK)"
|
||||
)
|
||||
|
||||
wait_until(timeline_deleted_on_sk)
|
||||
|
||||
@@ -540,13 +540,16 @@ def test_recovery_uncommitted(neon_env_builder: NeonEnvBuilder):
|
||||
asyncio.run(run_recovery_uncommitted(env))
|
||||
|
||||
|
||||
async def run_wal_truncation(env: NeonEnv):
|
||||
async def run_wal_truncation(env: NeonEnv, safekeeper_proto_version: int):
|
||||
tenant_id = env.initial_tenant
|
||||
timeline_id = env.initial_timeline
|
||||
|
||||
(sk1, sk2, sk3) = env.safekeepers
|
||||
|
||||
ep = env.endpoints.create_start("main")
|
||||
config_lines = [
|
||||
f"neon.safekeeper_proto_version = {safekeeper_proto_version}",
|
||||
]
|
||||
ep = env.endpoints.create_start("main", config_lines=config_lines)
|
||||
ep.safe_psql("create table t (key int, value text)")
|
||||
ep.safe_psql("insert into t select generate_series(1, 100), 'payload'")
|
||||
|
||||
@@ -565,6 +568,7 @@ async def run_wal_truncation(env: NeonEnv):
|
||||
sk2.start()
|
||||
ep = env.endpoints.create_start(
|
||||
"main",
|
||||
config_lines=config_lines,
|
||||
)
|
||||
ep.safe_psql("insert into t select generate_series(1, 200), 'payload'")
|
||||
|
||||
@@ -583,11 +587,13 @@ async def run_wal_truncation(env: NeonEnv):
|
||||
|
||||
# Simple deterministic test creating tail of WAL on safekeeper which is
|
||||
# truncated when majority without this sk elects walproposer starting earlier.
|
||||
def test_wal_truncation(neon_env_builder: NeonEnvBuilder):
|
||||
# Test both proto versions until we fully migrate.
|
||||
@pytest.mark.parametrize("safekeeper_proto_version", [2, 3])
|
||||
def test_wal_truncation(neon_env_builder: NeonEnvBuilder, safekeeper_proto_version: int):
|
||||
neon_env_builder.num_safekeepers = 3
|
||||
env = neon_env_builder.init_start()
|
||||
|
||||
asyncio.run(run_wal_truncation(env))
|
||||
asyncio.run(run_wal_truncation(env, safekeeper_proto_version))
|
||||
|
||||
|
||||
async def quorum_sanity_single(
|
||||
|
||||
Reference in New Issue
Block a user