Mirror of https://github.com/neondatabase/neon.git (synced 2026-02-10 14:10:37 +00:00)

Compare commits: proxy-fix-... -> wal_lz4_co... (13 commits)

Commit SHA1s: 42d0b040f8, 9832638c09, 62e7638c69, 0dad8e427d, 4cfa2fdca5, 56ddf8e37f, ed4bb3073f, b7e7aeed4d, a880178cca, b09d686335, 74d24582cf, 4834d22d2d, 86e8c43ddf

Cargo.lock (generated, 216 changes)
@@ -285,7 +285,7 @@ dependencies = [
|
||||
"futures",
|
||||
"git-version",
|
||||
"humantime",
|
||||
"hyper 0.14.26",
|
||||
"hyper",
|
||||
"metrics",
|
||||
"once_cell",
|
||||
"pageserver_api",
|
||||
@@ -331,7 +331,7 @@ dependencies = [
|
||||
"fastrand 2.0.0",
|
||||
"hex",
|
||||
"http 0.2.9",
|
||||
"hyper 0.14.26",
|
||||
"hyper",
|
||||
"ring 0.17.6",
|
||||
"time",
|
||||
"tokio",
|
||||
@@ -368,7 +368,7 @@ dependencies = [
|
||||
"bytes",
|
||||
"fastrand 2.0.0",
|
||||
"http 0.2.9",
|
||||
"http-body 0.4.5",
|
||||
"http-body",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"tracing",
|
||||
@@ -396,7 +396,7 @@ dependencies = [
|
||||
"aws-types",
|
||||
"bytes",
|
||||
"http 0.2.9",
|
||||
"http-body 0.4.5",
|
||||
"http-body",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"regex-lite",
|
||||
@@ -547,7 +547,7 @@ dependencies = [
|
||||
"crc32fast",
|
||||
"hex",
|
||||
"http 0.2.9",
|
||||
"http-body 0.4.5",
|
||||
"http-body",
|
||||
"md-5",
|
||||
"pin-project-lite",
|
||||
"sha1",
|
||||
@@ -579,7 +579,7 @@ dependencies = [
|
||||
"bytes-utils",
|
||||
"futures-core",
|
||||
"http 0.2.9",
|
||||
"http-body 0.4.5",
|
||||
"http-body",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
@@ -618,10 +618,10 @@ dependencies = [
|
||||
"aws-smithy-types",
|
||||
"bytes",
|
||||
"fastrand 2.0.0",
|
||||
"h2 0.3.24",
|
||||
"h2",
|
||||
"http 0.2.9",
|
||||
"http-body 0.4.5",
|
||||
"hyper 0.14.26",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"hyper-rustls",
|
||||
"once_cell",
|
||||
"pin-project-lite",
|
||||
@@ -658,7 +658,7 @@ dependencies = [
|
||||
"bytes-utils",
|
||||
"futures-core",
|
||||
"http 0.2.9",
|
||||
"http-body 0.4.5",
|
||||
"http-body",
|
||||
"itoa",
|
||||
"num-integer",
|
||||
"pin-project-lite",
|
||||
@@ -707,8 +707,8 @@ dependencies = [
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http 0.2.9",
|
||||
"http-body 0.4.5",
|
||||
"hyper 0.14.26",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"itoa",
|
||||
"matchit",
|
||||
"memchr",
|
||||
@@ -723,7 +723,7 @@ dependencies = [
|
||||
"sha1",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-tungstenite 0.20.0",
|
||||
"tokio-tungstenite",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
@@ -739,7 +739,7 @@ dependencies = [
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http 0.2.9",
|
||||
"http-body 0.4.5",
|
||||
"http-body",
|
||||
"mime",
|
||||
"rustversion",
|
||||
"tower-layer",
|
||||
@@ -1228,7 +1228,7 @@ dependencies = [
|
||||
"compute_api",
|
||||
"flate2",
|
||||
"futures",
|
||||
"hyper 0.14.26",
|
||||
"hyper",
|
||||
"nix 0.27.1",
|
||||
"notify",
|
||||
"num_cpus",
|
||||
@@ -1344,7 +1344,7 @@ dependencies = [
|
||||
"futures",
|
||||
"git-version",
|
||||
"hex",
|
||||
"hyper 0.14.26",
|
||||
"hyper",
|
||||
"nix 0.27.1",
|
||||
"once_cell",
|
||||
"pageserver_api",
|
||||
@@ -2244,25 +2244,6 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "31d030e59af851932b72ceebadf4a2b5986dba4c3b99dd2493f8273a0f151943"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fnv",
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
"futures-util",
|
||||
"http 1.0.0",
|
||||
"indexmap 2.0.1",
|
||||
"slab",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "half"
|
||||
version = "1.8.2"
|
||||
@@ -2428,29 +2409,6 @@ dependencies = [
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-body"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"http 1.0.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-body-util"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http 1.0.0",
|
||||
"http-body 1.0.0",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-types"
|
||||
version = "2.12.0"
|
||||
@@ -2509,9 +2467,9 @@ dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"h2 0.3.24",
|
||||
"h2",
|
||||
"http 0.2.9",
|
||||
"http-body 0.4.5",
|
||||
"http-body",
|
||||
"httparse",
|
||||
"httpdate",
|
||||
"itoa",
|
||||
@@ -2523,26 +2481,6 @@ dependencies = [
|
||||
"want",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "186548d73ac615b32a73aafe38fb4f56c0d340e110e5a200bcadbaf2e199263a"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
"h2 0.4.2",
|
||||
"http 1.0.0",
|
||||
"http-body 1.0.0",
|
||||
"httparse",
|
||||
"httpdate",
|
||||
"itoa",
|
||||
"pin-project-lite",
|
||||
"smallvec",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-rustls"
|
||||
version = "0.24.0"
|
||||
@@ -2550,7 +2488,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0646026eb1b3eea4cd9ba47912ea5ce9cc07713d105b1a14698f4e6433d348b7"
|
||||
dependencies = [
|
||||
"http 0.2.9",
|
||||
"hyper 0.14.26",
|
||||
"hyper",
|
||||
"log",
|
||||
"rustls 0.21.9",
|
||||
"rustls-native-certs",
|
||||
@@ -2564,7 +2502,7 @@ version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1"
|
||||
dependencies = [
|
||||
"hyper 0.14.26",
|
||||
"hyper",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
"tokio-io-timeout",
|
||||
@@ -2577,7 +2515,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"hyper 0.14.26",
|
||||
"hyper",
|
||||
"native-tls",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
@@ -2585,33 +2523,15 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "hyper-tungstenite"
|
||||
version = "0.13.0"
|
||||
version = "0.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a343d17fe7885302ed7252767dc7bb83609a874b6ff581142241ec4b73957ad"
|
||||
checksum = "7cc7dcb1ab67cd336f468a12491765672e61a3b6b148634dbfe2fe8acd3fe7d9"
|
||||
dependencies = [
|
||||
"http-body-util",
|
||||
"hyper 1.2.0",
|
||||
"hyper-util",
|
||||
"hyper",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
"tokio-tungstenite 0.21.0",
|
||||
"tungstenite 0.21.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-util"
|
||||
version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http 1.0.0",
|
||||
"http-body 1.0.0",
|
||||
"hyper 1.2.0",
|
||||
"pin-project-lite",
|
||||
"socket2 0.5.5",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tungstenite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2921,6 +2841,15 @@ version = "0.4.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f"
|
||||
|
||||
[[package]]
|
||||
name = "lz4_flex"
|
||||
version = "0.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3ea9b256699eda7b0387ffbc776dd625e28bde3918446381781245b7a50349d8"
|
||||
dependencies = [
|
||||
"twox-hash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "match_cfg"
|
||||
version = "0.1.0"
|
||||
@@ -3588,9 +3517,10 @@ dependencies = [
|
||||
"hex-literal",
|
||||
"humantime",
|
||||
"humantime-serde",
|
||||
"hyper 0.14.26",
|
||||
"hyper",
|
||||
"itertools",
|
||||
"leaky-bucket",
|
||||
"lz4_flex",
|
||||
"md5",
|
||||
"metrics",
|
||||
"nix 0.27.1",
|
||||
@@ -4258,13 +4188,9 @@ dependencies = [
|
||||
"hex",
|
||||
"hmac",
|
||||
"hostname",
|
||||
"http 1.0.0",
|
||||
"http-body-util",
|
||||
"humantime",
|
||||
"hyper 0.14.26",
|
||||
"hyper 1.2.0",
|
||||
"hyper",
|
||||
"hyper-tungstenite",
|
||||
"hyper-util",
|
||||
"ipnet",
|
||||
"itertools",
|
||||
"lasso",
|
||||
@@ -4596,7 +4522,7 @@ dependencies = [
|
||||
"futures-util",
|
||||
"http-types",
|
||||
"humantime",
|
||||
"hyper 0.14.26",
|
||||
"hyper",
|
||||
"itertools",
|
||||
"metrics",
|
||||
"once_cell",
|
||||
@@ -4626,10 +4552,10 @@ dependencies = [
|
||||
"encoding_rs",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"h2 0.3.24",
|
||||
"h2",
|
||||
"http 0.2.9",
|
||||
"http-body 0.4.5",
|
||||
"hyper 0.14.26",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"hyper-rustls",
|
||||
"hyper-tls",
|
||||
"ipnet",
|
||||
@@ -4687,7 +4613,7 @@ dependencies = [
|
||||
"futures",
|
||||
"getrandom 0.2.11",
|
||||
"http 0.2.9",
|
||||
"hyper 0.14.26",
|
||||
"hyper",
|
||||
"parking_lot 0.11.2",
|
||||
"reqwest",
|
||||
"reqwest-middleware",
|
||||
@@ -4774,7 +4700,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "496c1d3718081c45ba9c31fbfc07417900aa96f4070ff90dc29961836b7a9945"
|
||||
dependencies = [
|
||||
"http 0.2.9",
|
||||
"hyper 0.14.26",
|
||||
"hyper",
|
||||
"lazy_static",
|
||||
"percent-encoding",
|
||||
"regex",
|
||||
@@ -5053,7 +4979,7 @@ dependencies = [
|
||||
"git-version",
|
||||
"hex",
|
||||
"humantime",
|
||||
"hyper 0.14.26",
|
||||
"hyper",
|
||||
"metrics",
|
||||
"once_cell",
|
||||
"parking_lot 0.12.1",
|
||||
@@ -5528,9 +5454,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "smallvec"
|
||||
version = "1.13.1"
|
||||
version = "1.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7"
|
||||
checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9"
|
||||
|
||||
[[package]]
|
||||
name = "smol_str"
|
||||
@@ -5622,7 +5548,7 @@ dependencies = [
|
||||
"futures-util",
|
||||
"git-version",
|
||||
"humantime",
|
||||
"hyper 0.14.26",
|
||||
"hyper",
|
||||
"metrics",
|
||||
"once_cell",
|
||||
"parking_lot 0.12.1",
|
||||
@@ -6106,19 +6032,7 @@ dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tungstenite 0.20.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tungstenite 0.21.0",
|
||||
"tungstenite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6185,10 +6099,10 @@ dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"h2 0.3.24",
|
||||
"h2",
|
||||
"http 0.2.9",
|
||||
"http-body 0.4.5",
|
||||
"hyper 0.14.26",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"hyper-timeout",
|
||||
"percent-encoding",
|
||||
"pin-project",
|
||||
@@ -6374,7 +6288,7 @@ dependencies = [
|
||||
name = "tracing-utils"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"hyper 0.14.26",
|
||||
"hyper",
|
||||
"opentelemetry",
|
||||
"opentelemetry-otlp",
|
||||
"opentelemetry-semantic-conventions",
|
||||
@@ -6411,25 +6325,6 @@ dependencies = [
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"data-encoding",
|
||||
"http 1.0.0",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand 0.8.5",
|
||||
"sha1",
|
||||
"thiserror",
|
||||
"url",
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "twox-hash"
|
||||
version = "1.6.3"
|
||||
@@ -6593,7 +6488,7 @@ dependencies = [
|
||||
"heapless",
|
||||
"hex",
|
||||
"hex-literal",
|
||||
"hyper 0.14.26",
|
||||
"hyper",
|
||||
"jsonwebtoken",
|
||||
"leaky-bucket",
|
||||
"metrics",
|
||||
@@ -7118,7 +7013,7 @@ dependencies = [
|
||||
"hashbrown 0.14.0",
|
||||
"hex",
|
||||
"hmac",
|
||||
"hyper 0.14.26",
|
||||
"hyper",
|
||||
"indexmap 1.9.3",
|
||||
"itertools",
|
||||
"libc",
|
||||
@@ -7155,6 +7050,7 @@ dependencies = [
|
||||
"tower",
|
||||
"tracing",
|
||||
"tracing-core",
|
||||
"tungstenite",
|
||||
"url",
|
||||
"uuid",
|
||||
"zeroize",
|
||||
|
||||
@@ -92,7 +92,7 @@ http-types = { version = "2", default-features = false }
|
||||
humantime = "2.1"
|
||||
humantime-serde = "1.1.1"
|
||||
hyper = "0.14"
|
||||
hyper-tungstenite = "0.13.0"
|
||||
hyper-tungstenite = "0.11"
|
||||
inotify = "0.10.2"
|
||||
ipnet = "2.9.0"
|
||||
itertools = "0.10"
|
||||
@@ -100,6 +100,7 @@ jsonwebtoken = "9"
|
||||
lasso = "0.7"
|
||||
leaky-bucket = "1.0.1"
|
||||
libc = "0.2"
|
||||
lz4_flex = "0.11.1"
|
||||
md5 = "0.7.0"
|
||||
memoffset = "0.8"
|
||||
native-tls = "0.2"
|
||||
|
||||
Makefile (4 changes)

@@ -23,12 +23,12 @@ endif
 UNAME_S := $(shell uname -s)
 ifeq ($(UNAME_S),Linux)
 	# Seccomp BPF is only available for Linux
-	PG_CONFIGURE_OPTS += --with-libseccomp
+	PG_CONFIGURE_OPTS += --with-lz4 --with-libseccomp
 else ifeq ($(UNAME_S),Darwin)
 	# macOS with brew-installed openssl requires explicit paths
 	# It can be configured with OPENSSL_PREFIX variable
 	OPENSSL_PREFIX ?= $(shell brew --prefix openssl@3)
-	PG_CONFIGURE_OPTS += --with-includes=$(OPENSSL_PREFIX)/include --with-libraries=$(OPENSSL_PREFIX)/lib
+	PG_CONFIGURE_OPTS += --with-lz4 --with-includes=$(OPENSSL_PREFIX)/include --with-libraries=$(OPENSSL_PREFIX)/lib
 	PG_CONFIGURE_OPTS += PKG_CONFIG_PATH=$(shell brew --prefix icu4c)/lib/pkgconfig
 	# macOS already has bison and flex in the system, but they are old and result in postgres-v14 target failure
 	# brew formulae are keg-only and not symlinked into HOMEBREW_PREFIX, force their usage
||||
@@ -302,9 +302,9 @@ pub fn handle_roles(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
|
||||
RoleAction::Create => {
|
||||
// This branch only runs when roles are created through the console, so it is
|
||||
// safe to add more permissions here. BYPASSRLS and REPLICATION are inherited
|
||||
// from neon_superuser.
|
||||
// from neon_superuser. (NOTE: REPLICATION has been removed from here for now).
|
||||
let mut query: String = format!(
|
||||
"CREATE ROLE {} INHERIT CREATEROLE CREATEDB BYPASSRLS REPLICATION IN ROLE neon_superuser",
|
||||
"CREATE ROLE {} INHERIT CREATEROLE CREATEDB BYPASSRLS IN ROLE neon_superuser",
|
||||
name.pg_quote()
|
||||
);
|
||||
info!("running role create query: '{}'", &query);
|
||||
@@ -805,6 +805,18 @@ $$;"#,
|
||||
"",
|
||||
"",
|
||||
// Add new migrations below.
|
||||
r#"
|
||||
DO $$
|
||||
DECLARE
|
||||
role_name TEXT;
|
||||
BEGIN
|
||||
FOR role_name IN SELECT rolname FROM pg_roles WHERE rolreplication IS TRUE
|
||||
LOOP
|
||||
RAISE NOTICE 'EXECUTING ALTER ROLE % NOREPLICATION', quote_ident(role_name);
|
||||
EXECUTE 'ALTER ROLE ' || quote_ident(role_name) || ' NOREPLICATION';
|
||||
END LOOP;
|
||||
END
|
||||
$$;"#,
|
||||
];
|
||||
|
||||
let mut query = "CREATE SCHEMA IF NOT EXISTS neon_migration";
|
||||
|
||||
@@ -29,7 +29,6 @@ pub mod launch_timestamp;
|
||||
mod wrappers;
|
||||
pub use wrappers::{CountedReader, CountedWriter};
|
||||
mod hll;
|
||||
pub mod metric_vec_duration;
|
||||
pub use hll::{HyperLogLog, HyperLogLogVec};
|
||||
#[cfg(target_os = "linux")]
|
||||
pub mod more_process_metrics;
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
//! Helpers for observing duration on `HistogramVec` / `CounterVec` / `GaugeVec` / `MetricVec<T>`.
|
||||
|
||||
use std::{future::Future, time::Instant};
|
||||
|
||||
pub trait DurationResultObserver {
|
||||
fn observe_result<T, E>(&self, res: &Result<T, E>, duration: std::time::Duration);
|
||||
}
|
||||
|
||||
pub async fn observe_async_block_duration_by_result<
|
||||
T,
|
||||
E,
|
||||
F: Future<Output = Result<T, E>>,
|
||||
O: DurationResultObserver,
|
||||
>(
|
||||
observer: &O,
|
||||
block: F,
|
||||
) -> Result<T, E> {
|
||||
let start = Instant::now();
|
||||
let result = block.await;
|
||||
let duration = start.elapsed();
|
||||
observer.observe_result(&result, duration);
|
||||
result
|
||||
}
|
||||
@@ -757,6 +757,7 @@ pub enum PagestreamBeMessage {
     Error(PagestreamErrorResponse),
     DbSize(PagestreamDbSizeResponse),
     GetSlruSegment(PagestreamGetSlruSegmentResponse),
+    GetCompressedPage(PagestreamGetPageResponse),
 }

 // Keep in sync with `pagestore_client.h`
@@ -996,6 +997,12 @@ impl PagestreamBeMessage {
                 bytes.put(&resp.page[..]);
             }

+            Self::GetCompressedPage(resp) => {
+                bytes.put_u8(105); /* tag from pagestore_client.h */
+                bytes.put_u16(resp.page.len() as u16);
+                bytes.put(&resp.page[..]);
+            }
+
             Self::Error(resp) => {
                 bytes.put_u8(Tag::Error as u8);
                 bytes.put(resp.message.as_bytes());
@@ -1078,6 +1085,7 @@ impl PagestreamBeMessage {
             Self::Error(_) => "Error",
             Self::DbSize(_) => "DbSize",
             Self::GetSlruSegment(_) => "GetSlruSegment",
+            Self::GetCompressedPage(_) => "GetCompressedPage",
         }
     }
 }
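The GetCompressedPage response above is framed as a one-byte tag (105, kept in sync with pagestore_client.h), a big-endian u16 carrying the compressed payload length, and then the LZ4 block itself. Below is a minimal, self-contained sketch of that framing, assuming only the layout shown above; the helper functions and constant name are illustrative and not part of the repository.

```rust
use bytes::{Buf, BufMut, Bytes, BytesMut};

// Hypothetical constant mirroring the tag byte used above.
const GET_COMPRESSED_PAGE_TAG: u8 = 105;

// Encode: tag byte, big-endian u16 length, then the LZ4 payload.
fn encode_compressed_page(payload: &[u8]) -> BytesMut {
    let mut buf = BytesMut::with_capacity(3 + payload.len());
    buf.put_u8(GET_COMPRESSED_PAGE_TAG);
    buf.put_u16(payload.len() as u16); // a compressed 8 KiB page always fits in u16
    buf.put(payload);
    buf
}

// Decode: returns the payload if the frame is well formed.
fn decode_compressed_page(mut buf: Bytes) -> Option<Bytes> {
    if buf.remaining() < 3 || buf.get_u8() != GET_COMPRESSED_PAGE_TAG {
        return None;
    }
    let len = buf.get_u16() as usize;
    (buf.remaining() >= len).then(|| buf.split_to(len))
}

fn main() {
    let frame = encode_compressed_page(b"lz4-bytes");
    let decoded = decode_compressed_page(frame.freeze()).unwrap();
    assert_eq!(&decoded[..], b"lz4-bytes");
}
```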
|
||||
|
||||
@@ -144,6 +144,13 @@ pub fn bkpimage_is_compressed(bimg_info: u8, version: u32) -> anyhow::Result<bool> {
     dispatch_pgversion!(version, Ok(pgv::bindings::bkpimg_is_compressed(bimg_info)))
 }

+pub fn bkpimage_is_compressed_lz4(bimg_info: u8, version: u32) -> anyhow::Result<bool> {
+    dispatch_pgversion!(
+        version,
+        Ok(pgv::bindings::bkpimg_is_compressed_lz4(bimg_info))
+    )
+}
+
 pub fn generate_wal_segment(
     segno: u64,
     system_id: u64,

@@ -8,3 +8,7 @@ pub const SIZEOF_RELMAPFILE: usize = 512; /* sizeof(RelMapFile) in relmapper.c */
 pub fn bkpimg_is_compressed(bimg_info: u8) -> bool {
     (bimg_info & BKPIMAGE_IS_COMPRESSED) != 0
 }
+
+pub fn bkpimg_is_compressed_lz4(_bimg_info: u8) -> bool {
+    false
+}

@@ -16,3 +16,7 @@ pub fn bkpimg_is_compressed(bimg_info: u8) -> bool {
     (bimg_info & ANY_COMPRESS_FLAG) != 0
 }
+
+pub fn bkpimg_is_compressed_lz4(bimg_info: u8) -> bool {
+    (bimg_info & BKPIMAGE_COMPRESS_LZ4) != 0
+}

@@ -16,3 +16,7 @@ pub fn bkpimg_is_compressed(bimg_info: u8) -> bool {
     (bimg_info & ANY_COMPRESS_FLAG) != 0
 }
+
+pub fn bkpimg_is_compressed_lz4(bimg_info: u8) -> bool {
+    (bimg_info & BKPIMAGE_COMPRESS_LZ4) != 0
+}
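The per-version bindings reduce this to a single bit test on the WAL block image flags. A tiny standalone sketch of that test follows; the flag value here is a placeholder, since the real BKPIMAGE_COMPRESS_LZ4 constant comes from the generated PostgreSQL bindings and differs per version.

```rust
// Placeholder value for illustration only; the real constant is taken from
// the generated PostgreSQL headers.
const BKPIMAGE_COMPRESS_LZ4: u8 = 0x08;

fn bkpimg_is_compressed_lz4(bimg_info: u8) -> bool {
    (bimg_info & BKPIMAGE_COMPRESS_LZ4) != 0
}

fn main() {
    // Set flag => detected; other flags alone => not detected.
    assert!(bkpimg_is_compressed_lz4(BKPIMAGE_COMPRESS_LZ4 | 0x01));
    assert!(!bkpimg_is_compressed_lz4(0x01));
}
```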
|
||||
|
||||
@@ -17,6 +17,7 @@ use remote_storage::{
|
||||
};
|
||||
use test_context::test_context;
|
||||
use test_context::AsyncTestContext;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::info;
|
||||
|
||||
@@ -484,32 +485,33 @@ async fn download_is_cancelled(ctx: &mut MaybeEnabledStorage) {
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
let len = upload_large_enough_file(&ctx.client, &path, &cancel).await;
|
||||
let file_len = upload_large_enough_file(&ctx.client, &path, &cancel).await;
|
||||
|
||||
{
|
||||
let mut stream = ctx
|
||||
let stream = ctx
|
||||
.client
|
||||
.download(&path, &cancel)
|
||||
.await
|
||||
.expect("download succeeds")
|
||||
.download_stream;
|
||||
|
||||
let first = stream
|
||||
.next()
|
||||
.await
|
||||
.expect("should have the first blob")
|
||||
.expect("should have succeeded");
|
||||
let mut reader = std::pin::pin!(tokio_util::io::StreamReader::new(stream));
|
||||
|
||||
tracing::info!(len = first.len(), "downloaded first chunk");
|
||||
let first = reader.fill_buf().await.expect("should have the first blob");
|
||||
|
||||
let len = first.len();
|
||||
tracing::info!(len, "downloaded first chunk");
|
||||
|
||||
assert!(
|
||||
first.len() < len,
|
||||
first.len() < file_len,
|
||||
"uploaded file is too small, we downloaded all on first chunk"
|
||||
);
|
||||
|
||||
reader.consume(len);
|
||||
|
||||
cancel.cancel();
|
||||
|
||||
let next = stream.next().await.expect("stream should have more");
|
||||
let next = reader.fill_buf().await;
|
||||
|
||||
let e = next.expect_err("expected an error, but got a chunk?");
|
||||
|
||||
@@ -520,6 +522,10 @@ async fn download_is_cancelled(ctx: &mut MaybeEnabledStorage) {
|
||||
.is_some_and(|e| matches!(e, DownloadError::Cancelled)),
|
||||
"{inner:?}"
|
||||
);
|
||||
|
||||
let e = DownloadError::from(e);
|
||||
|
||||
assert!(matches!(e, DownloadError::Cancelled), "{e:?}");
|
||||
}
|
||||
|
||||
let cancel = CancellationToken::new();
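The test change above wraps the download byte-stream in tokio_util::io::StreamReader and drives it with fill_buf/consume instead of polling the stream directly. A self-contained sketch of that pattern over an in-memory stream, assuming the tokio, tokio-util, futures and bytes crates; the chunk contents are synthetic.

```rust
use futures::stream;
use tokio::io::AsyncBufReadExt;
use tokio_util::io::StreamReader;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // A stream of Result<Bytes, io::Error>, like a download body.
    let chunks = stream::iter(vec![
        Ok::<_, std::io::Error>(bytes::Bytes::from_static(b"first chunk ")),
        Ok(bytes::Bytes::from_static(b"second chunk")),
    ]);

    let mut reader = StreamReader::new(chunks);

    // fill_buf exposes the currently buffered bytes without copying them out...
    let first = reader.fill_buf().await?;
    let len = first.len();
    assert!(len > 0);

    // ...and consume marks them read, so the next fill_buf advances the stream.
    reader.consume(len);
    let next = reader.fill_buf().await?;
    assert_eq!(next, &b"second chunk"[..]);
    Ok(())
}
```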
|
||||
|
||||
@@ -37,6 +37,7 @@ humantime-serde.workspace = true
|
||||
hyper.workspace = true
|
||||
itertools.workspace = true
|
||||
leaky-bucket.workspace = true
|
||||
lz4_flex.workspace = true
|
||||
md5.workspace = true
|
||||
nix.workspace = true
|
||||
# hack to get the number of worker threads tokio uses
|
||||
|
||||
@@ -157,6 +157,7 @@ impl PagestreamClient {
|
||||
PagestreamBeMessage::Exists(_)
|
||||
| PagestreamBeMessage::Nblocks(_)
|
||||
| PagestreamBeMessage::DbSize(_)
|
||||
| PagestreamBeMessage::GetCompressedPage(_)
|
||||
| PagestreamBeMessage::GetSlruSegment(_) => {
|
||||
anyhow::bail!(
|
||||
"unexpected be message kind in response to getpage request: {}",
|
||||
|
||||
@@ -40,7 +40,14 @@ use tracing::info;
 /// format, bump this!
 /// Note that TimelineMetadata uses its own version number to track
 /// backwards-compatible changes to the metadata format.
-pub const STORAGE_FORMAT_VERSION: u16 = 3;
+pub const STORAGE_FORMAT_VERSION: u16 = 4;
+
+/// Minimal storage format version with compression support
+pub const COMPRESSED_STORAGE_FORMAT_VERSION: u16 = 4;
+
+/// Page image compression algorithm
+pub const NO_COMPRESSION: u8 = 0;
+pub const LZ4_COMPRESSION: u8 = 1;

 pub const DEFAULT_PG_VERSION: u32 = 15;
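Bumping STORAGE_FORMAT_VERSION and introducing COMPRESSED_STORAGE_FORMAT_VERSION lets readers tell whether a layer file can carry the new per-blob compression tag. A hedged sketch of how such a gate could look; the constant values mirror the ones above, but the enum and function are illustrative, not the repository's API.

```rust
pub const STORAGE_FORMAT_VERSION: u16 = 4;
pub const COMPRESSED_STORAGE_FORMAT_VERSION: u16 = 4;

#[derive(Debug, PartialEq)]
enum BlobFormat {
    Plain,           // length header only (pre-compression layers)
    MaybeCompressed, // one-byte algorithm tag, then the length header
}

fn blob_format_for(layer_format_version: u16) -> BlobFormat {
    if layer_format_version >= COMPRESSED_STORAGE_FORMAT_VERSION {
        BlobFormat::MaybeCompressed
    } else {
        BlobFormat::Plain
    }
}

fn main() {
    assert_eq!(blob_format_for(3), BlobFormat::Plain);
    assert_eq!(blob_format_for(STORAGE_FORMAT_VERSION), BlobFormat::MaybeCompressed);
}
```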
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use enum_map::EnumMap;
|
||||
use metrics::metric_vec_duration::DurationResultObserver;
|
||||
use metrics::{
|
||||
register_counter_vec, register_gauge_vec, register_histogram, register_histogram_vec,
|
||||
register_int_counter, register_int_counter_pair_vec, register_int_counter_vec,
|
||||
@@ -1283,11 +1282,65 @@ pub(crate) static BASEBACKUP_QUERY_TIME: Lazy<BasebackupQueryTime> = Lazy::new(|
|
||||
})
|
||||
});
|
||||
|
||||
impl DurationResultObserver for BasebackupQueryTime {
|
||||
fn observe_result<T, E>(&self, res: &Result<T, E>, duration: std::time::Duration) {
|
||||
pub(crate) struct BasebackupQueryTimeOngoingRecording<'a, 'c> {
|
||||
parent: &'a BasebackupQueryTime,
|
||||
ctx: &'c RequestContext,
|
||||
start: std::time::Instant,
|
||||
}
|
||||
|
||||
impl BasebackupQueryTime {
|
||||
pub(crate) fn start_recording<'c: 'a, 'a>(
|
||||
&'a self,
|
||||
ctx: &'c RequestContext,
|
||||
) -> BasebackupQueryTimeOngoingRecording<'_, '_> {
|
||||
let start = Instant::now();
|
||||
match ctx.micros_spent_throttled.open() {
|
||||
Ok(()) => (),
|
||||
Err(error) => {
|
||||
use utils::rate_limit::RateLimit;
|
||||
static LOGGED: Lazy<Mutex<RateLimit>> =
|
||||
Lazy::new(|| Mutex::new(RateLimit::new(Duration::from_secs(10))));
|
||||
let mut rate_limit = LOGGED.lock().unwrap();
|
||||
rate_limit.call(|| {
|
||||
warn!(error, "error opening micros_spent_throttled; this message is logged at a global rate limit");
|
||||
});
|
||||
}
|
||||
}
|
||||
BasebackupQueryTimeOngoingRecording {
|
||||
parent: self,
|
||||
ctx,
|
||||
start,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'c> BasebackupQueryTimeOngoingRecording<'a, 'c> {
|
||||
pub(crate) fn observe<T, E>(self, res: &Result<T, E>) {
|
||||
let elapsed = self.start.elapsed();
|
||||
let ex_throttled = self
|
||||
.ctx
|
||||
.micros_spent_throttled
|
||||
.close_and_checked_sub_from(elapsed);
|
||||
let ex_throttled = match ex_throttled {
|
||||
Ok(ex_throttled) => ex_throttled,
|
||||
Err(error) => {
|
||||
use utils::rate_limit::RateLimit;
|
||||
static LOGGED: Lazy<Mutex<RateLimit>> =
|
||||
Lazy::new(|| Mutex::new(RateLimit::new(Duration::from_secs(10))));
|
||||
let mut rate_limit = LOGGED.lock().unwrap();
|
||||
rate_limit.call(|| {
|
||||
warn!(error, "error deducting time spent throttled; this message is logged at a global rate limit");
|
||||
});
|
||||
elapsed
|
||||
}
|
||||
};
|
||||
let label_value = if res.is_ok() { "ok" } else { "error" };
|
||||
let metric = self.0.get_metric_with_label_values(&[label_value]).unwrap();
|
||||
metric.observe(duration.as_secs_f64());
|
||||
let metric = self
|
||||
.parent
|
||||
.0
|
||||
.get_metric_with_label_values(&[label_value])
|
||||
.unwrap();
|
||||
metric.observe(ex_throttled.as_secs_f64());
|
||||
}
|
||||
}
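The diff above replaces the trait-based observe_async_block_duration_by_result helper with an explicit start_recording()/observe() pair, so the time spent throttled can be subtracted before the histogram is updated. A simplified, standalone sketch of that start/observe shape, without the prometheus types or throttle bookkeeping; the Recorder type here is hypothetical.

```rust
use std::time::{Duration, Instant};

#[derive(Default)]
struct Recorder {
    ok: Vec<Duration>,
    err: Vec<Duration>,
}

struct OngoingRecording<'a> {
    parent: &'a mut Recorder,
    start: Instant,
}

impl Recorder {
    fn start_recording(&mut self) -> OngoingRecording<'_> {
        OngoingRecording { parent: self, start: Instant::now() }
    }
}

impl<'a> OngoingRecording<'a> {
    // Consumes the recording, so a duration is observed at most once.
    fn observe<T, E>(self, res: &Result<T, E>) {
        let elapsed = self.start.elapsed();
        match res {
            Ok(_) => self.parent.ok.push(elapsed),
            Err(_) => self.parent.err.push(elapsed),
        }
    }
}

fn main() {
    let mut recorder = Recorder::default();
    let rec = recorder.start_recording();
    let res: Result<(), &str> = Ok(());
    rec.observe(&res);
    assert_eq!(recorder.ok.len(), 1);
}
```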
|
||||
|
||||
|
||||
@@ -1155,9 +1155,18 @@ impl PageServerHandler {
             .get_rel_page_at_lsn(req.rel, req.blkno, Version::Lsn(lsn), req.latest, ctx)
             .await?;

-        Ok(PagestreamBeMessage::GetPage(PagestreamGetPageResponse {
-            page,
-        }))
+        let compressed = lz4_flex::block::compress(&page);
+        if compressed.len() < page.len() {
+            Ok(PagestreamBeMessage::GetCompressedPage(
+                PagestreamGetPageResponse {
+                    page: Bytes::from(compressed),
+                },
+            ))
+        } else {
+            Ok(PagestreamBeMessage::GetPage(PagestreamGetPageResponse {
+                page,
+            }))
+        }
     }
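The handler only switches to the compressed response when LZ4 actually shrinks the page; incompressible pages are sent as before. A standalone sketch of that decision using lz4_flex block compression; the enum is a stand-in for the real PagestreamBeMessage variants.

```rust
use bytes::Bytes;

enum PageReply {
    Plain(Bytes),
    Lz4(Bytes),
}

fn reply_for_page(page: Bytes) -> PageReply {
    let compressed = lz4_flex::block::compress(&page);
    if compressed.len() < page.len() {
        PageReply::Lz4(Bytes::from(compressed))
    } else {
        // Incompressible (e.g. already-compressed or random data): send as-is.
        PageReply::Plain(page)
    }
}

fn main() {
    // A highly compressible 8 KiB page of zeros.
    let page = Bytes::from(vec![0u8; 8192]);
    assert!(matches!(reply_for_page(page), PageReply::Lz4(_)));
}
```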
|
||||
|
||||
#[instrument(skip_all, fields(shard_id))]
|
||||
@@ -1199,7 +1208,7 @@ impl PageServerHandler {
|
||||
prev_lsn: Option<Lsn>,
|
||||
full_backup: bool,
|
||||
gzip: bool,
|
||||
ctx: RequestContext,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<(), QueryError>
|
||||
where
|
||||
IO: AsyncRead + AsyncWrite + Send + Sync + Unpin,
|
||||
@@ -1214,7 +1223,7 @@ impl PageServerHandler {
|
||||
if let Some(lsn) = lsn {
|
||||
// Backup was requested at a particular LSN. Wait for it to arrive.
|
||||
info!("waiting for {}", lsn);
|
||||
timeline.wait_lsn(lsn, &ctx).await?;
|
||||
timeline.wait_lsn(lsn, ctx).await?;
|
||||
timeline
|
||||
.check_lsn_is_in_scope(lsn, &latest_gc_cutoff_lsn)
|
||||
.context("invalid basebackup lsn")?;
|
||||
@@ -1236,7 +1245,7 @@ impl PageServerHandler {
|
||||
lsn,
|
||||
prev_lsn,
|
||||
full_backup,
|
||||
&ctx,
|
||||
ctx,
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
@@ -1257,7 +1266,7 @@ impl PageServerHandler {
|
||||
lsn,
|
||||
prev_lsn,
|
||||
full_backup,
|
||||
&ctx,
|
||||
ctx,
|
||||
)
|
||||
.await?;
|
||||
// shutdown the encoder to ensure the gzip footer is written
|
||||
@@ -1269,7 +1278,7 @@ impl PageServerHandler {
|
||||
lsn,
|
||||
prev_lsn,
|
||||
full_backup,
|
||||
&ctx,
|
||||
ctx,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
@@ -1449,25 +1458,25 @@ where
|
||||
false
|
||||
};
|
||||
|
||||
::metrics::metric_vec_duration::observe_async_block_duration_by_result(
|
||||
&*metrics::BASEBACKUP_QUERY_TIME,
|
||||
async move {
|
||||
self.handle_basebackup_request(
|
||||
pgb,
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
lsn,
|
||||
None,
|
||||
false,
|
||||
gzip,
|
||||
ctx,
|
||||
)
|
||||
.await?;
|
||||
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
Result::<(), QueryError>::Ok(())
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
let metric_recording = metrics::BASEBACKUP_QUERY_TIME.start_recording(&ctx);
|
||||
let res = async {
|
||||
self.handle_basebackup_request(
|
||||
pgb,
|
||||
tenant_id,
|
||||
timeline_id,
|
||||
lsn,
|
||||
None,
|
||||
false,
|
||||
gzip,
|
||||
&ctx,
|
||||
)
|
||||
.await?;
|
||||
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
Result::<(), QueryError>::Ok(())
|
||||
}
|
||||
.await;
|
||||
metric_recording.observe(&res);
|
||||
res?;
|
||||
}
|
||||
// return pair of prev_lsn and last_lsn
|
||||
else if query_string.starts_with("get_last_record_rlsn ") {
|
||||
@@ -1563,7 +1572,7 @@ where
|
||||
prev_lsn,
|
||||
true,
|
||||
false,
|
||||
ctx,
|
||||
&ctx,
|
||||
)
|
||||
.await?;
|
||||
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
|
||||
@@ -16,6 +16,7 @@ use anyhow::{ensure, Context};
 use bytes::{Buf, Bytes, BytesMut};
 use enum_map::Enum;
 use itertools::Itertools;
+use lz4_flex;
 use pageserver_api::key::{
     dbdir_key_range, is_rel_block_key, is_slru_block_key, rel_block_to_key, rel_dir_to_key,
     rel_key_range, rel_size_to_key, relmap_file_key, slru_block_to_key, slru_dir_to_key,

@@ -992,7 +993,15 @@ impl<'a> DatadirModification<'a> {
         img: Bytes,
     ) -> anyhow::Result<()> {
         anyhow::ensure!(rel.relnode != 0, RelationError::InvalidRelnode);
-        self.put(rel_block_to_key(rel, blknum), Value::Image(img));
+        let compressed = lz4_flex::block::compress(&img);
+        if compressed.len() < img.len() {
+            self.put(
+                rel_block_to_key(rel, blknum),
+                Value::CompressedImage(Bytes::from(compressed)),
+            );
+        } else {
+            self.put(rel_block_to_key(rel, blknum), Value::Image(img));
+        }
         Ok(())
     }
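Reading a Value::CompressedImage back requires the uncompressed size up front, because the raw LZ4 block format does not store it; here that size is always the PostgreSQL page size. A round-trip sketch assuming the usual 8 KiB BLCKSZ and the lz4_flex crate.

```rust
const BLCKSZ: usize = 8192;

fn main() {
    let page = vec![7u8; BLCKSZ];

    // Compress-if-smaller: constant pages compress very well.
    let compressed = lz4_flex::block::compress(&page);
    assert!(compressed.len() < page.len());

    // Decompression needs the known uncompressed size (BLCKSZ).
    let decompressed =
        lz4_flex::block::decompress(&compressed, BLCKSZ).expect("valid LZ4 block");
    assert_eq!(decompressed, page);
}
```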
|
||||
|
||||
@@ -1597,6 +1606,10 @@ impl<'a> DatadirModification<'a> {
|
||||
if let Some((_, value)) = values.last() {
|
||||
return if let Value::Image(img) = value {
|
||||
Ok(img.clone())
|
||||
} else if let Value::CompressedImage(img) = value {
|
||||
let decompressed = lz4_flex::block::decompress(&img, BLCKSZ as usize)
|
||||
.map_err(|msg| PageReconstructError::Other(anyhow::anyhow!(msg)))?;
|
||||
Ok(Bytes::from(decompressed))
|
||||
} else {
|
||||
// Currently, we never need to read back a WAL record that we
|
||||
// inserted in the same "transaction". All the metadata updates
|
||||
|
||||
@@ -13,6 +13,8 @@ pub use pageserver_api::key::{Key, KEY_SIZE};
|
||||
pub enum Value {
|
||||
/// An Image value contains a full copy of the value
|
||||
Image(Bytes),
|
||||
/// A compressed page image contains a full copy of the page
|
||||
CompressedImage(Bytes),
|
||||
/// A WalRecord value contains a WAL record that needs to be
|
||||
/// replayed get the full value. Replaying the WAL record
|
||||
/// might need a previous version of the value (if will_init()
|
||||
@@ -22,12 +24,17 @@ pub enum Value {
|
||||
|
||||
impl Value {
|
||||
pub fn is_image(&self) -> bool {
|
||||
matches!(self, Value::Image(_))
|
||||
match self {
|
||||
Value::Image(_) => true,
|
||||
Value::CompressedImage(_) => true,
|
||||
Value::WalRecord(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn will_init(&self) -> bool {
|
||||
match self {
|
||||
Value::Image(_) => true,
|
||||
Value::CompressedImage(_) => true,
|
||||
Value::WalRecord(rec) => rec.will_init(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -272,9 +272,6 @@ pub enum TaskKind {
|
||||
// Task that uploads a file to remote storage
|
||||
RemoteUploadTask,
|
||||
|
||||
// Task that downloads a file from remote storage
|
||||
RemoteDownloadTask,
|
||||
|
||||
// task that handles the initial downloading of all tenants
|
||||
InitialLoad,
|
||||
|
||||
|
||||
@@ -11,13 +11,16 @@
 //! len <  128: 0XXXXXXX
 //! len >= 128: 1XXXXXXX XXXXXXXX XXXXXXXX XXXXXXXX
 //!
-use bytes::{BufMut, BytesMut};
+use bytes::{BufMut, Bytes, BytesMut};
 use tokio_epoll_uring::{BoundedBuf, IoBuf, Slice};

 use crate::context::RequestContext;
 use crate::page_cache::PAGE_SZ;
 use crate::tenant::block_io::BlockCursor;
 use crate::virtual_file::VirtualFile;
+use crate::{LZ4_COMPRESSION, NO_COMPRESSION};
+use lz4_flex;
+use postgres_ffi::BLCKSZ;
 use std::cmp::min;
 use std::io::{Error, ErrorKind};

@@ -32,6 +35,29 @@ impl<'a> BlockCursor<'a> {
         self.read_blob_into_buf(offset, &mut buf, ctx).await?;
         Ok(buf)
     }
+    /// Read a blob that is prefixed with a one-byte compression tag,
+    /// decompressing it if it was stored LZ4-compressed.
+    pub async fn read_compressed_blob(
+        &self,
+        offset: u64,
+        ctx: &RequestContext,
+    ) -> Result<Vec<u8>, std::io::Error> {
+        let blknum = (offset / PAGE_SZ as u64) as u32;
+        let off = (offset % PAGE_SZ as u64) as usize;
+
+        let buf = self.read_blk(blknum, ctx).await?;
+        let compression_alg = buf[off];
+        let res = self.read_blob(offset + 1, ctx).await?;
+        if compression_alg == LZ4_COMPRESSION {
+            lz4_flex::block::decompress(&res, BLCKSZ as usize).map_err(|_| {
+                std::io::Error::new(std::io::ErrorKind::InvalidData, "decompress error")
+            })
+        } else {
+            assert_eq!(compression_alg, NO_COMPRESSION);
+            Ok(res)
+        }
+    }

     /// Read blob into the given buffer. Any previous contents in the buffer
     /// are overwritten.
     pub async fn read_blob_into_buf(
@@ -211,6 +237,58 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
         (src_buf, Ok(()))
     }

+    pub async fn write_compressed_blob(&mut self, srcbuf: Bytes) -> Result<u64, Error> {
+        let offset = self.offset;
+
+        let len = srcbuf.len();
+
+        let mut io_buf = self.io_buf.take().expect("we always put it back below");
+        io_buf.clear();
+        let mut is_compressed = false;
+        if len < 128 {
+            // Short blob. Write a 1-byte length header
+            io_buf.put_u8(NO_COMPRESSION);
+            io_buf.put_u8(len as u8);
+        } else {
+            // Write a 4-byte length header
+            if len > 0x7fff_ffff {
+                return Err(Error::new(
+                    ErrorKind::Other,
+                    format!("blob too large ({} bytes)", len),
+                ));
+            }
+            if len == BLCKSZ as usize {
+                let compressed = lz4_flex::block::compress(&srcbuf);
+                if compressed.len() < len {
+                    io_buf.put_u8(LZ4_COMPRESSION);
+                    let mut len_buf = (compressed.len() as u32).to_be_bytes();
+                    len_buf[0] |= 0x80;
+                    io_buf.extend_from_slice(&len_buf[..]);
+                    io_buf.extend_from_slice(&compressed[..]);
+                    is_compressed = true;
+                }
+            }
+            if !is_compressed {
+                io_buf.put_u8(NO_COMPRESSION);
+                let mut len_buf = (len as u32).to_be_bytes();
+                len_buf[0] |= 0x80;
+                io_buf.extend_from_slice(&len_buf[..]);
+            }
+        }
+        let (io_buf, hdr_res) = self.write_all(io_buf).await;
+        self.io_buf = Some(io_buf);
+        match hdr_res {
+            Ok(_) => (),
+            Err(e) => return Err(e),
+        }
+        if is_compressed {
+            Ok(offset)
+        } else {
+            let (_buf, res) = self.write_all(srcbuf).await;
+            res.map(|_| offset)
+        }
+    }
+
     /// Write a blob of data. Returns the offset that it was written to,
     /// which can be used to retrieve the data later.
     pub async fn write_blob<B: BoundedBuf<Buf = Buf>, Buf: IoBuf + Send>(
@@ -227,7 +305,6 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
             if len < 128 {
                 // Short blob. Write a 1-byte length header
                 io_buf.put_u8(len as u8);
-                self.write_all(io_buf).await
             } else {
                 // Write a 4-byte length header
                 if len > 0x7fff_ffff {
@@ -242,8 +319,8 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
                 let mut len_buf = (len as u32).to_be_bytes();
                 len_buf[0] |= 0x80;
                 io_buf.extend_from_slice(&len_buf[..]);
-                self.write_all(io_buf).await
             }
+            self.write_all(io_buf).await
         }
         .await;
         self.io_buf = Some(io_buf);
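The blob length header described in the module comment uses one byte for lengths below 128 and otherwise four big-endian bytes with the top bit set; in the compressed path that header is additionally preceded by a one-byte algorithm tag. A hedged sketch of just the length-header encoding and decoding, independent of the buffered VirtualFile writer.

```rust
// Encode the blob length header: 1 byte if len < 128, else 4 big-endian
// bytes with the high bit of the first byte set.
fn encode_len_header(len: usize) -> Vec<u8> {
    if len < 128 {
        vec![len as u8]
    } else {
        assert!(len <= 0x7fff_ffff, "blob too large");
        let mut buf = (len as u32).to_be_bytes();
        buf[0] |= 0x80;
        buf.to_vec()
    }
}

// Decode a header, returning (blob length, header size in bytes).
fn decode_len_header(buf: &[u8]) -> (usize, usize) {
    if buf[0] & 0x80 == 0 {
        (buf[0] as usize, 1)
    } else {
        let be = [buf[0] & 0x7f, buf[1], buf[2], buf[3]];
        (u32::from_be_bytes(be) as usize, 4)
    }
}

fn main() {
    for len in [0usize, 127, 128, 8192, 0x7fff_ffff] {
        let hdr = encode_len_header(len);
        assert_eq!(decode_len_header(&hdr), (len, hdr.len()));
    }
}
```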
|
||||
|
||||
@@ -20,6 +20,7 @@ use pageserver_api::keyspace::{KeySpace, KeySpaceRandomAccum};
|
||||
use pageserver_api::models::{
|
||||
LayerAccessKind, LayerResidenceEvent, LayerResidenceEventReason, LayerResidenceStatus,
|
||||
};
|
||||
use postgres_ffi::BLCKSZ;
|
||||
use std::cmp::{Ordering, Reverse};
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::collections::{BinaryHeap, HashMap};
|
||||
@@ -147,12 +148,13 @@ impl ValuesReconstructState {
|
||||
lsn: Lsn,
|
||||
value: Value,
|
||||
) -> ValueReconstructSituation {
|
||||
let state = self
|
||||
let mut error: Option<PageReconstructError> = None;
|
||||
let key_state = self
|
||||
.keys
|
||||
.entry(*key)
|
||||
.or_insert(Ok(VectoredValueReconstructState::default()));
|
||||
|
||||
if let Ok(state) = state {
|
||||
let situation = if let Ok(state) = key_state {
|
||||
let key_done = match state.situation {
|
||||
ValueReconstructSituation::Complete => unreachable!(),
|
||||
ValueReconstructSituation::Continue => match value {
|
||||
@@ -160,6 +162,21 @@ impl ValuesReconstructState {
|
||||
state.img = Some((lsn, img));
|
||||
true
|
||||
}
|
||||
Value::CompressedImage(img) => {
|
||||
match lz4_flex::block::decompress(&img, BLCKSZ as usize) {
|
||||
Ok(decompressed) => {
|
||||
state.img = Some((lsn, Bytes::from(decompressed)));
|
||||
true
|
||||
}
|
||||
Err(e) => {
|
||||
error = Some(PageReconstructError::from(anyhow::anyhow!(
|
||||
"Failed to decompress blobrom virtual file: {}",
|
||||
e
|
||||
)));
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
Value::WalRecord(rec) => {
|
||||
let reached_cache =
|
||||
state.get_cached_lsn().map(|clsn| clsn + 1) == Some(lsn);
|
||||
@@ -178,7 +195,11 @@ impl ValuesReconstructState {
|
||||
state.situation
|
||||
} else {
|
||||
ValueReconstructSituation::Complete
|
||||
};
|
||||
if let Some(err) = error {
|
||||
*key_state = Err(err);
|
||||
}
|
||||
situation
|
||||
}
|
||||
|
||||
/// Returns the Lsn at which this key is cached if one exists.
|
||||
|
||||
@@ -44,12 +44,13 @@ use crate::virtual_file::{self, VirtualFile};
|
||||
use crate::{walrecord, TEMP_FILE_SUFFIX};
|
||||
use crate::{DELTA_FILE_MAGIC, STORAGE_FORMAT_VERSION};
|
||||
use anyhow::{anyhow, bail, ensure, Context, Result};
|
||||
use bytes::BytesMut;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use camino::{Utf8Path, Utf8PathBuf};
|
||||
use futures::StreamExt;
|
||||
use pageserver_api::keyspace::KeySpace;
|
||||
use pageserver_api::models::LayerAccessKind;
|
||||
use pageserver_api::shard::TenantShardId;
|
||||
use postgres_ffi::BLCKSZ;
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs::File;
|
||||
@@ -813,6 +814,12 @@ impl DeltaLayerInner {
|
||||
need_image = false;
|
||||
break;
|
||||
}
|
||||
Value::CompressedImage(img) => {
|
||||
let decompressed = lz4_flex::block::decompress(&img, BLCKSZ as usize)?;
|
||||
reconstruct_state.img = Some((entry_lsn, Bytes::from(decompressed)));
|
||||
need_image = false;
|
||||
break;
|
||||
}
|
||||
Value::WalRecord(rec) => {
|
||||
let will_init = rec.will_init();
|
||||
reconstruct_state.records.push((entry_lsn, rec));
|
||||
@@ -1102,6 +1109,9 @@ impl DeltaLayerInner {
|
||||
Value::Image(img) => {
|
||||
format!(" img {} bytes", img.len())
|
||||
}
|
||||
Value::CompressedImage(img) => {
|
||||
format!(" compressed img {} bytes", img.len())
|
||||
}
|
||||
Value::WalRecord(rec) => {
|
||||
let wal_desc = walrecord::describe_wal_record(&rec)?;
|
||||
format!(
|
||||
@@ -1138,6 +1148,11 @@ impl DeltaLayerInner {
|
||||
let checkpoint = CheckPoint::decode(&img)?;
|
||||
println!(" CHECKPOINT: {:?}", checkpoint);
|
||||
}
|
||||
Value::CompressedImage(img) => {
|
||||
let decompressed = lz4_flex::block::decompress(&img, BLCKSZ as usize)?;
|
||||
let checkpoint = CheckPoint::decode(&decompressed)?;
|
||||
println!(" CHECKPOINT: {:?}", checkpoint);
|
||||
}
|
||||
Value::WalRecord(_rec) => {
|
||||
println!(" unexpected walrecord value for checkpoint key");
|
||||
}
|
||||
|
||||
@@ -39,7 +39,9 @@ use crate::tenant::vectored_blob_io::{
|
||||
};
|
||||
use crate::tenant::{PageReconstructError, Timeline};
|
||||
use crate::virtual_file::{self, VirtualFile};
|
||||
use crate::{IMAGE_FILE_MAGIC, STORAGE_FORMAT_VERSION, TEMP_FILE_SUFFIX};
|
||||
use crate::{
|
||||
COMPRESSED_STORAGE_FORMAT_VERSION, IMAGE_FILE_MAGIC, STORAGE_FORMAT_VERSION, TEMP_FILE_SUFFIX,
|
||||
};
|
||||
use anyhow::{anyhow, bail, ensure, Context, Result};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use camino::{Utf8Path, Utf8PathBuf};
|
||||
@@ -153,6 +155,7 @@ pub struct ImageLayerInner {
|
||||
// values copied from summary
|
||||
index_start_blk: u32,
|
||||
index_root_blk: u32,
|
||||
format_version: u16,
|
||||
|
||||
lsn: Lsn,
|
||||
|
||||
@@ -167,6 +170,7 @@ impl std::fmt::Debug for ImageLayerInner {
|
||||
f.debug_struct("ImageLayerInner")
|
||||
.field("index_start_blk", &self.index_start_blk)
|
||||
.field("index_root_blk", &self.index_root_blk)
|
||||
.field("format_version", &self.format_version)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
@@ -408,6 +412,7 @@ impl ImageLayerInner {
|
||||
Ok(Ok(ImageLayerInner {
|
||||
index_start_blk: actual_summary.index_start_blk,
|
||||
index_root_blk: actual_summary.index_root_blk,
|
||||
format_version: actual_summary.format_version,
|
||||
lsn,
|
||||
file,
|
||||
file_id,
|
||||
@@ -436,18 +441,20 @@ impl ImageLayerInner {
|
||||
)
|
||||
.await?
|
||||
{
|
||||
let blob = block_reader
|
||||
.block_cursor()
|
||||
.read_blob(
|
||||
offset,
|
||||
&RequestContextBuilder::extend(ctx)
|
||||
.page_content_kind(PageContentKind::ImageLayerValue)
|
||||
.build(),
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("failed to read value from offset {}", offset))?;
|
||||
let value = Bytes::from(blob);
|
||||
let ctx = RequestContextBuilder::extend(ctx)
|
||||
.page_content_kind(PageContentKind::ImageLayerValue)
|
||||
.build();
|
||||
let blob = (if self.format_version >= COMPRESSED_STORAGE_FORMAT_VERSION {
|
||||
block_reader
|
||||
.block_cursor()
|
||||
.read_compressed_blob(offset, &ctx)
|
||||
.await
|
||||
} else {
|
||||
block_reader.block_cursor().read_blob(offset, &ctx).await
|
||||
})
|
||||
.with_context(|| format!("failed to read value from offset {}", offset))?;
|
||||
|
||||
let value = Bytes::from(blob);
|
||||
reconstruct_state.img = Some((self.lsn, value));
|
||||
Ok(ValueReconstructResult::Complete)
|
||||
} else {
|
||||
@@ -658,10 +665,7 @@ impl ImageLayerWriterInner {
|
||||
///
|
||||
async fn put_image(&mut self, key: Key, img: Bytes) -> anyhow::Result<()> {
|
||||
ensure!(self.key_range.contains(&key));
|
||||
let (_img, res) = self.blob_writer.write_blob(img).await;
|
||||
// TODO: re-use the buffer for `img` further upstack
|
||||
let off = res?;
|
||||
|
||||
let off = self.blob_writer.write_compressed_blob(img).await?;
|
||||
let mut keybuf: [u8; KEY_SIZE] = [0u8; KEY_SIZE];
|
||||
key.write_to_byte_slice(&mut keybuf);
|
||||
self.tree.append(&keybuf, off)?;
|
||||
|
||||
@@ -14,9 +14,12 @@ use crate::tenant::timeline::GetVectoredError;
|
||||
use crate::tenant::{PageReconstructError, Timeline};
|
||||
use crate::walrecord;
|
||||
use anyhow::{anyhow, ensure, Result};
|
||||
use bytes::Bytes;
|
||||
use lz4_flex;
|
||||
use pageserver_api::keyspace::KeySpace;
|
||||
use pageserver_api::models::InMemoryLayerInfo;
|
||||
use pageserver_api::shard::TenantShardId;
|
||||
use postgres_ffi::BLCKSZ;
|
||||
use std::collections::{BinaryHeap, HashMap, HashSet};
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use tracing::*;
|
||||
@@ -133,6 +136,9 @@ impl InMemoryLayer {
|
||||
Ok(Value::Image(img)) => {
|
||||
write!(&mut desc, " img {} bytes", img.len())?;
|
||||
}
|
||||
Ok(Value::CompressedImage(img)) => {
|
||||
write!(&mut desc, " compressed img {} bytes", img.len())?;
|
||||
}
|
||||
Ok(Value::WalRecord(rec)) => {
|
||||
let wal_desc = walrecord::describe_wal_record(&rec).unwrap();
|
||||
write!(
|
||||
@@ -184,6 +190,11 @@ impl InMemoryLayer {
|
||||
reconstruct_state.img = Some((*entry_lsn, img));
|
||||
return Ok(ValueReconstructResult::Complete);
|
||||
}
|
||||
Value::CompressedImage(img) => {
|
||||
let decompressed = lz4_flex::block::decompress(&img, BLCKSZ as usize)?;
|
||||
reconstruct_state.img = Some((*entry_lsn, Bytes::from(decompressed)));
|
||||
return Ok(ValueReconstructResult::Complete);
|
||||
}
|
||||
Value::WalRecord(rec) => {
|
||||
let will_init = rec.will_init();
|
||||
reconstruct_state.records.push((*entry_lsn, rec));
|
||||
|
||||
@@ -880,23 +880,18 @@ impl LayerInner {
|
||||
) -> Result<heavier_once_cell::InitPermit, DownloadError> {
|
||||
debug_assert_current_span_has_tenant_and_timeline_id();
|
||||
|
||||
let task_name = format!("download layer {}", self);
|
||||
|
||||
let (tx, rx) = tokio::sync::oneshot::channel();
|
||||
|
||||
// this is sadly needed because of task_mgr::shutdown_tasks, otherwise we cannot
|
||||
// block tenant::mgr::remove_tenant_from_memory.
|
||||
|
||||
let this: Arc<Self> = self.clone();
|
||||
|
||||
crate::task_mgr::spawn(
|
||||
&tokio::runtime::Handle::current(),
|
||||
crate::task_mgr::TaskKind::RemoteDownloadTask,
|
||||
Some(self.desc.tenant_shard_id),
|
||||
Some(self.desc.timeline_id),
|
||||
&task_name,
|
||||
false,
|
||||
async move {
|
||||
let guard = timeline
|
||||
.gate
|
||||
.enter()
|
||||
.map_err(|_| DownloadError::DownloadCancelled)?;
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
|
||||
let _guard = guard;
|
||||
|
||||
let client = timeline
|
||||
.remote_client
|
||||
@@ -906,7 +901,7 @@ impl LayerInner {
|
||||
let result = client.download_layer_file(
|
||||
&this.desc.filename(),
|
||||
&this.metadata(),
|
||||
&crate::task_mgr::shutdown_token()
|
||||
&timeline.cancel
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -929,7 +924,6 @@ impl LayerInner {
|
||||
|
||||
tokio::select! {
|
||||
_ = tokio::time::sleep(backoff) => {},
|
||||
_ = crate::task_mgr::shutdown_token().cancelled_owned() => {},
|
||||
_ = timeline.cancel.cancelled() => {},
|
||||
};
|
||||
|
||||
@@ -959,11 +953,10 @@ impl LayerInner {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
.in_current_span(),
|
||||
);
|
||||
|
||||
match rx.await {
|
||||
Ok((Ok(()), permit)) => {
|
||||
if let Some(reason) = self
|
||||
|
||||
@@ -471,8 +471,9 @@ impl WalIngest {
             && decoded.xl_rmid == pg_constants::RM_XLOG_ID
             && (decoded.xl_info == pg_constants::XLOG_FPI
                 || decoded.xl_info == pg_constants::XLOG_FPI_FOR_HINT)
-            // compression of WAL is not yet supported: fall back to storing the original WAL record
-            && !postgres_ffi::bkpimage_is_compressed(blk.bimg_info, modification.tline.pg_version)?
+            // only lz4 compression of WAL is now supported, for other compression algorithms fall back to storing the original WAL record
+            && (!postgres_ffi::bkpimage_is_compressed(blk.bimg_info, modification.tline.pg_version)? ||
+                postgres_ffi::bkpimage_is_compressed_lz4(blk.bimg_info, modification.tline.pg_version)?)
            // do not materialize null pages because them most likely be soon replaced with real data
            && blk.bimg_len != 0
        {
@@ -480,7 +481,21 @@
             let img_len = blk.bimg_len as usize;
             let img_offs = blk.bimg_offset as usize;
             let mut image = BytesMut::with_capacity(BLCKSZ as usize);
-            image.extend_from_slice(&decoded.record[img_offs..img_offs + img_len]);
+            if postgres_ffi::bkpimage_is_compressed_lz4(
+                blk.bimg_info,
+                modification.tline.pg_version,
+            )? {
+                let decompressed_img_len = (BLCKSZ - blk.hole_length) as usize;
+                let decompressed = lz4_flex::block::decompress(
+                    &decoded.record[img_offs..img_offs + img_len],
+                    decompressed_img_len,
+                )
+                .map_err(|msg| PageReconstructError::Other(anyhow::anyhow!(msg)))?;
+                assert_eq!(decompressed.len(), decompressed_img_len);
+                image.extend_from_slice(&decompressed);
+            } else {
+                image.extend_from_slice(&decoded.record[img_offs..img_offs + img_len]);
+            }

             if blk.hole_length != 0 {
                 let tail = image.split_off(blk.hole_offset as usize);
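A full-page image in WAL can omit a "hole" of all-zero bytes, so an LZ4-compressed FPI only stores BLCKSZ minus hole_length bytes; after decompressing, the zeros have to be re-inserted at hole_offset, which is what the split_off logic above does. A standalone sketch of that reconstruction with synthetic input, assuming the usual 8 KiB BLCKSZ and the bytes and lz4_flex crates.

```rust
use bytes::BytesMut;

const BLCKSZ: usize = 8192;

fn restore_page(compressed: &[u8], hole_offset: usize, hole_length: usize) -> BytesMut {
    // The stored image is the page minus the hole.
    let decompressed =
        lz4_flex::block::decompress(compressed, BLCKSZ - hole_length).expect("valid LZ4 block");
    assert_eq!(decompressed.len(), BLCKSZ - hole_length);

    let mut image = BytesMut::with_capacity(BLCKSZ);
    image.extend_from_slice(&decompressed);

    // Re-insert hole_length zero bytes at hole_offset.
    let tail = image.split_off(hole_offset);
    image.resize(hole_offset + hole_length, 0u8);
    image.unsplit(tail);
    assert_eq!(image.len(), BLCKSZ);
    image
}

fn main() {
    // Build a synthetic page with a zero hole in the middle, then store it
    // without the hole, LZ4-compressed.
    let (hole_offset, hole_length) = (1000usize, 3000usize);
    let mut page = vec![1u8; BLCKSZ];
    page[hole_offset..hole_offset + hole_length].fill(0);

    let mut stored = Vec::new();
    stored.extend_from_slice(&page[..hole_offset]);
    stored.extend_from_slice(&page[hole_offset + hole_length..]);
    let compressed = lz4_flex::block::compress(&stored);

    let restored = restore_page(&compressed, hole_offset, hole_length);
    assert_eq!(&restored[..], &page[..]);
}
```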
|
||||
|
||||
@@ -18,10 +18,10 @@ OBJS = \
|
||||
|
||||
PG_CPPFLAGS = -I$(libpq_srcdir)
|
||||
SHLIB_LINK_INTERNAL = $(libpq)
|
||||
SHLIB_LINK = -lcurl
|
||||
SHLIB_LINK = -lcurl -llz4
|
||||
|
||||
EXTENSION = neon
|
||||
DATA = neon--1.0.sql neon--1.0--1.1.sql neon--1.1--1.2.sql neon--1.2--1.3.sql
|
||||
DATA = neon--1.0.sql neon--1.0--1.1.sql neon--1.1--1.2.sql neon--1.2--1.3.sql neon--1.3--1.2.sql neon--1.2--1.1.sql neon--1.1--1.0.sql
|
||||
PGFILEDESC = "neon - cloud storage for PostgreSQL"
|
||||
|
||||
EXTRA_CLEAN = \
|
||||
|
||||
6
pgxn/neon/neon--1.1--1.0.sql
Normal file
6
pgxn/neon/neon--1.1--1.0.sql
Normal file
@@ -0,0 +1,6 @@
|
||||
-- the order of operations is important here
|
||||
-- because the view depends on the function
|
||||
|
||||
DROP VIEW IF EXISTS neon_lfc_stats CASCADE;
|
||||
|
||||
DROP FUNCTION IF EXISTS neon_get_lfc_stats CASCADE;
|
||||
1
pgxn/neon/neon--1.2--1.1.sql
Normal file
1
pgxn/neon/neon--1.2--1.1.sql
Normal file
@@ -0,0 +1 @@
|
||||
DROP VIEW IF EXISTS NEON_STAT_FILE_CACHE CASCADE;
|
||||
1
pgxn/neon/neon--1.3--1.2.sql
Normal file
1
pgxn/neon/neon--1.3--1.2.sql
Normal file
@@ -0,0 +1 @@
|
||||
DROP FUNCTION IF EXISTS approximate_working_set_size(bool) CASCADE;
|
||||
@@ -44,6 +44,7 @@ typedef enum
|
||||
T_NeonErrorResponse,
|
||||
T_NeonDbSizeResponse,
|
||||
T_NeonGetSlruSegmentResponse,
|
||||
T_NeonGetCompressedPageResponse
|
||||
} NeonMessageTag;
|
||||
|
||||
/* base struct for c-style inheritance */
|
||||
@@ -144,6 +145,15 @@ typedef struct
|
||||
|
||||
#define PS_GETPAGERESPONSE_SIZE (MAXALIGN(offsetof(NeonGetPageResponse, page) + BLCKSZ))
|
||||
|
||||
typedef struct
|
||||
{
|
||||
NeonMessageTag tag;
|
||||
uint16 compressed_size;
|
||||
char page[FLEXIBLE_ARRAY_MEMBER];
|
||||
} NeonGetCompressedPageResponse;
|
||||
|
||||
#define PS_GETCOMPRESSEDPAGERESPONSE_SIZE(compressed_size) (MAXALIGN(offsetof(NeonGetCompressedPageResponse, page) + compressed_size))
|
||||
|
||||
typedef struct
|
||||
{
|
||||
NeonMessageTag tag;
|
||||
|
||||
@@ -45,6 +45,10 @@
|
||||
*/
|
||||
#include "postgres.h"
|
||||
|
||||
#ifdef USE_LZ4
|
||||
#include <lz4.h>
|
||||
#endif
|
||||
|
||||
#include "access/xact.h"
|
||||
#include "access/xlog.h"
|
||||
#include "access/xlogdefs.h"
|
||||
@@ -1059,6 +1063,7 @@ nm_pack_request(NeonRequest *msg)
|
||||
case T_NeonExistsResponse:
|
||||
case T_NeonNblocksResponse:
|
||||
case T_NeonGetPageResponse:
|
||||
case T_NeonGetCompressedPageResponse:
|
||||
case T_NeonErrorResponse:
|
||||
case T_NeonDbSizeResponse:
|
||||
case T_NeonGetSlruSegmentResponse:
|
||||
@@ -1114,6 +1119,21 @@ nm_unpack_response(StringInfo s)
|
||||
|
||||
Assert(msg_resp->tag == T_NeonGetPageResponse);
|
||||
|
||||
resp = (NeonResponse *) msg_resp;
|
||||
break;
|
||||
}
|
||||
case T_NeonGetCompressedPageResponse:
|
||||
{
|
||||
NeonGetCompressedPageResponse *msg_resp;
|
||||
uint16 compressed_size = pq_getmsgint(s, 2);
|
||||
msg_resp = palloc0(PS_GETCOMPRESSEDPAGERESPONSE_SIZE(compressed_size));
|
||||
msg_resp->tag = tag;
|
||||
msg_resp->compressed_size = compressed_size;
|
||||
memcpy(msg_resp->page, pq_getmsgbytes(s, compressed_size), compressed_size);
|
||||
pq_getmsgend(s);
|
||||
|
||||
Assert(msg_resp->tag == T_NeonGetCompressedPageResponse);
|
||||
|
||||
resp = (NeonResponse *) msg_resp;
|
||||
break;
|
||||
}
|
||||
@@ -1287,6 +1307,14 @@ nm_to_string(NeonMessage *msg)
|
||||
appendStringInfoChar(&s, '}');
|
||||
break;
|
||||
}
|
||||
case T_NeonGetCompressedPageResponse:
|
||||
{
|
||||
NeonGetCompressedPageResponse *msg_resp = (NeonGetCompressedPageResponse *) msg;
|
||||
appendStringInfoString(&s, "{\"type\": \"NeonGetCompressedPageResponse\"");
|
||||
appendStringInfo(&s, ", \"compressed_page_size\": \"%d\"}", msg_resp->compressed_size);
|
||||
appendStringInfoChar(&s, '}');
|
||||
break;
|
||||
}
|
||||
case T_NeonErrorResponse:
|
||||
{
|
||||
NeonErrorResponse *msg_resp = (NeonErrorResponse *) msg;
|
||||
@@ -2205,6 +2233,29 @@ neon_read_at_lsn(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
|
||||
lfc_write(rinfo, forkNum, blkno, buffer);
|
||||
break;
|
||||
|
||||
case T_NeonGetCompressedPageResponse:
|
||||
{
|
||||
#ifndef USE_LZ4
|
||||
ereport(ERROR, \
|
||||
(errcode(ERRCODE_FEATURE_NOT_SUPPORTED), \
|
||||
errmsg("compression method lz4 not supported"), \
|
||||
errdetail("This functionality requires the server to be built with lz4 support."), \
|
||||
errhint("You need to rebuild PostgreSQL using %s.", "--with-lz4")))
|
||||
#else
|
||||
NeonGetCompressedPageResponse* cp = (NeonGetCompressedPageResponse *) resp;
|
||||
int rc = LZ4_decompress_safe(cp->page,
|
||||
buffer,
|
||||
cp->compressed_size,
|
||||
BLCKSZ);
|
||||
if (rc != BLCKSZ) {
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_DATA_CORRUPTED),
|
||||
errmsg_internal("compressed lz4 data is corrupt")));
|
||||
}
|
||||
lfc_write(rinfo, forkNum, blkno, buffer);
|
||||
#endif
|
||||
break;
|
||||
}
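On the compute side, LZ4_decompress_safe bounds the output buffer to BLCKSZ and anything other than exactly one full page is reported as corruption. A hedged Rust-side analogue of that strict check, for illustration only; the compute node itself is C and uses liblz4 directly, not lz4_flex.

```rust
const BLCKSZ: usize = 8192;

#[derive(Debug)]
struct CorruptPage;

// Decompress a page and treat any result other than exactly BLCKSZ bytes as
// corrupt data, mirroring the error handling in the C code above.
fn decompress_page(compressed: &[u8]) -> Result<Vec<u8>, CorruptPage> {
    let page = lz4_flex::block::decompress(compressed, BLCKSZ).map_err(|_| CorruptPage)?;
    if page.len() != BLCKSZ {
        return Err(CorruptPage);
    }
    Ok(page)
}

fn main() {
    let compressed = lz4_flex::block::compress(&vec![0u8; BLCKSZ]);
    assert!(decompress_page(&compressed).is_ok());
    // A truncated block must not yield a full page.
    assert!(decompress_page(&compressed[..compressed.len() - 1]).is_err());
}
```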
|
||||
case T_NeonErrorResponse:
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_IO_ERROR),
|
||||
|
||||
@@ -30,10 +30,6 @@ hostname.workspace = true
|
||||
humantime.workspace = true
|
||||
hyper-tungstenite.workspace = true
|
||||
hyper.workspace = true
|
||||
hyper1 = { package = "hyper", version = "1.2", features = ["server", "http1", "http2"] }
|
||||
hyper-util = { version = "0.1", features = ["tokio"] }
|
||||
http1 = { package = "http", version = "1" }
|
||||
http-body-util = { version = "0.1" }
|
||||
ipnet.workspace = true
|
||||
itertools.workspace = true
|
||||
lasso = { workspace = true, features = ["multi-threaded"] }
|
||||
|
||||
@@ -175,7 +175,7 @@ async fn task_main(
|
||||
.context("failed to set socket option")?;
|
||||
|
||||
info!(%peer_addr, "serving");
|
||||
let ctx = RequestMonitoring::new(session_id, peer_addr, "sni_router", "sni");
|
||||
let ctx = RequestMonitoring::new(session_id, peer_addr.ip(), "sni_router", "sni");
|
||||
handle_client(ctx, dest_suffix, tls_config, tls_server_end_point, socket).await
|
||||
}
|
||||
.unwrap_or_else(|e| {
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
use chrono::Utc;
|
||||
use once_cell::sync::OnceCell;
|
||||
use smol_str::SmolStr;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::net::IpAddr;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{field::display, info_span, Span};
|
||||
use uuid::Uuid;
|
||||
@@ -62,7 +62,7 @@ pub enum AuthMethod {
|
||||
impl RequestMonitoring {
|
||||
pub fn new(
|
||||
session_id: Uuid,
|
||||
peer_addr: SocketAddr,
|
||||
peer_addr: IpAddr,
|
||||
protocol: &'static str,
|
||||
region: &'static str,
|
||||
) -> Self {
|
||||
@@ -75,7 +75,7 @@ impl RequestMonitoring {
|
||||
);
|
||||
|
||||
Self {
|
||||
peer_addr: peer_addr.ip(),
|
||||
peer_addr,
|
||||
session_id,
|
||||
protocol,
|
||||
first_packet: Utc::now(),
|
||||
@@ -100,12 +100,7 @@ impl RequestMonitoring {
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn test() -> Self {
|
||||
RequestMonitoring::new(
|
||||
Uuid::now_v7(),
|
||||
([127, 0, 0, 1], 5432).into(),
|
||||
"test",
|
||||
"test",
|
||||
)
|
||||
RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), "test", "test")
|
||||
}
|
||||
|
||||
pub fn console_application_name(&self) -> String {
|
||||
|
||||
@@ -5,13 +5,19 @@ use std::{
    io,
    net::SocketAddr,
    pin::{pin, Pin},
    sync::Mutex,
    task::{ready, Context, Poll},
};

use bytes::{Buf, BytesMut};
use hyper::server::conn::AddrIncoming;
use hyper::server::accept::Accept;
use hyper::server::conn::{AddrIncoming, AddrStream};
use metrics::IntCounterPairGuard;
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
use uuid::Uuid;

use crate::{metrics::NUM_CLIENT_CONNECTION_GAUGE, serverless::tls_listener::AsyncAccept};

pub struct ProxyProtocolAccept {
    pub incoming: AddrIncoming,

@@ -325,6 +331,87 @@ impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
    }
}

impl AsyncAccept for ProxyProtocolAccept {
    type Connection = WithConnectionGuard<WithClientIp<AddrStream>>;

    type Error = io::Error;

    fn poll_accept(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
        let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
        tracing::info!(protocol = self.protocol, "accepted new TCP connection");
        let Some(conn) = conn else {
            return Poll::Ready(None);
        };

        Poll::Ready(Some(Ok(WithConnectionGuard {
            inner: WithClientIp::new(conn),
            connection_id: Uuid::new_v4(),
            gauge: Mutex::new(Some(
                NUM_CLIENT_CONNECTION_GAUGE
                    .with_label_values(&[self.protocol])
                    .guard(),
            )),
        })))
    }
}

pin_project! {
    pub struct WithConnectionGuard<T> {
        #[pin]
        pub inner: T,
        pub connection_id: Uuid,
        pub gauge: Mutex<Option<IntCounterPairGuard>>,
    }
}

impl<T: AsyncWrite> AsyncWrite for WithConnectionGuard<T> {
    #[inline]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        self.project().inner.poll_write(cx, buf)
    }

    #[inline]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        self.project().inner.poll_flush(cx)
    }

    #[inline]
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        self.project().inner.poll_shutdown(cx)
    }

    #[inline]
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<Result<usize, io::Error>> {
        self.project().inner.poll_write_vectored(cx, bufs)
    }

    #[inline]
    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }
}

impl<T: AsyncRead> AsyncRead for WithConnectionGuard<T> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        self.project().inner.poll_read(cx, buf)
    }
}

#[cfg(test)]
mod tests {
    use std::pin::pin;

@@ -91,8 +91,9 @@ pub async fn task_main(

        connections.spawn(async move {
            let mut socket = WithClientIp::new(socket);
            let peer_addr = match socket.wait_for_addr().await {
                Ok(Some(addr)) => addr,
            let mut peer_addr = peer_addr.ip();
            match socket.wait_for_addr().await {
                Ok(Some(addr)) => peer_addr = addr.ip(),
                Err(e) => {
                    error!("per-client task finished with an error: {e:#}");
                    return;

@@ -101,8 +102,8 @@ pub async fn task_main(
                    error!("missing required client IP");
                    return;
                }
                Ok(None) => peer_addr
            };
                Ok(None) => {}
            }

            match socket.inner.set_nodelay(true) {
                Ok(()) => {},

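The hunk above keeps one simple rule: start from the TCP peer address and override it with the address carried in the PROXY protocol header, if one was sent. A minimal sketch of that fallback, assuming plain `std::net` types; `proxy_header_addr` stands in for the result of the diff's `socket.wait_for_addr()` and is not the real API:

```rust
use std::net::{IpAddr, SocketAddr};

// Prefer the client address from the PROXY protocol header; otherwise keep the
// address of the TCP connection itself.
fn effective_client_ip(tcp_peer: SocketAddr, proxy_header_addr: Option<SocketAddr>) -> IpAddr {
    match proxy_header_addr {
        Some(addr) => addr.ip(),
        None => tcp_peer.ip(),
    }
}

fn main() {
    let tcp_peer: SocketAddr = "203.0.113.9:60000".parse().unwrap();
    let from_header: Option<SocketAddr> = Some("198.51.100.4:5432".parse().unwrap());

    assert_eq!(
        effective_client_ip(tcp_peer, from_header),
        "198.51.100.4".parse::<IpAddr>().unwrap()
    );
    assert_eq!(
        effective_client_ip(tcp_peer, None),
        "203.0.113.9".parse::<IpAddr>().unwrap()
    );
}
```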
@@ -4,45 +4,46 @@

mod backend;
mod conn_pool;
mod http_auto;
mod json;
mod sql_over_http;
pub mod tls_listener;
mod websocket;

use bytes::Bytes;
pub use conn_pool::GlobalConnPoolOptions;

use anyhow::Context;
use futures::future::{select, Either};
use http1::{Method, Response, StatusCode};
use http_body_util::Full;
use hyper1::body::Incoming;
use anyhow::bail;
use hyper::StatusCode;
use metrics::IntCounterPairGuard;
use rand::rngs::StdRng;
use rand::SeedableRng;
pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use serde::Serialize;
use tokio::time::timeout;
use tokio_util::task::TaskTracker;

use crate::context::RequestMonitoring;
use crate::metrics::{NUM_CLIENT_CONNECTION_GAUGE, TLS_HANDSHAKE_FAILURES};
use crate::protocol2::WithClientIp;
use crate::proxy::run_until_cancelled;
use crate::metrics::TLS_HANDSHAKE_FAILURES;
use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard};
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_auto::Rewind;
use crate::{cancellation::CancellationHandler, config::ProxyConfig};
use futures::StreamExt;
use hyper::{
    server::{
        accept,
        conn::{AddrIncoming, AddrStream},
    },
    Body, Method, Request, Response,
};

use std::convert::Infallible;
use std::net::SocketAddr;
use std::pin::pin;
use std::sync::Arc;
use std::time::Duration;
use std::net::IpAddr;
use std::task::Poll;
use std::{future::ready, sync::Arc};
use tls_listener::TlsListener;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn, Instrument};
use utils::http::error::ApiError;
use utils::http::{error::ApiError, json::json_response};

pub const SERVERLESS_DRIVER_SNI: &str = "api";

@@ -94,221 +95,134 @@ pub async fn task_main(
    tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
    let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into();

    let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
    let _ = addr_incoming.set_nodelay(true);
    let addr_incoming = ProxyProtocolAccept {
        incoming: addr_incoming,
        protocol: "http",
    };

    let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
    ws_connections.close(); // allows `ws_connections.wait to complete`

    let http_connections = tokio_util::task::task_tracker::TaskTracker::new();
    http_connections.close();

    let server = http_auto::Builder::new();

    loop {
        let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await else {
            break;
        };

        let (conn, mut peer_addr) = res.context("could not accept TCP stream")?;
        if let Err(e) = conn.set_nodelay(true) {
            tracing::error!("could not set nodolay: {e}");
            continue;
        }
        let cancellation_token = cancellation_token.child_token();

        let tls = tls_acceptor.clone();

        let backend = backend.clone();
        let ws_connections = ws_connections.clone();
        let endpoint_rate_limiter = endpoint_rate_limiter.clone();
        let cancellation_handler = cancellation_handler.clone();
        let server = server.clone();

        http_connections.spawn(async move {
            let _gauge = NUM_CLIENT_CONNECTION_GAUGE
                .with_label_values(&["http"])
                .guard();

            // handle PROXY protocol
            let mut conn = WithClientIp::new(conn);
            let peer = match conn.wait_for_addr().await {
                Ok(peer) => peer,
                Err(e) => {
                    tracing::error!(
                        "failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}"
                    );
                    return;
                }
            };

            if let Some(peer) = peer {
                peer_addr = peer;
            }
            info!(%peer_addr, protocol = "http", "accepted new TCP connection");

            let accept = tls.accept(conn);
            let conn = match timeout(Duration::from_secs(10), accept).await {
                Ok(Ok(conn)) => {
                    info!(%peer_addr, protocol = "http", "accepted new TLS connection");
                    conn
                }
                // The handshake failed, try getting another connection from the queue
                Ok(Err(e)) => {
                    TLS_HANDSHAKE_FAILURES.inc();
                    warn!(%peer_addr, protocol = "http", "failed to accept TLS connection: {e:?}");
                    return;
                }
                // The handshake timed out, try getting another connection from the queue
                Err(_) => {
                    TLS_HANDSHAKE_FAILURES.inc();
                    warn!(%peer_addr, protocol = "http", "failed to accept TLS connection: timeout");
                    return;
                }
            };

            let (version, conn) = match conn.get_ref().1.alpn_protocol() {
                Some(b"http/1.1") => (http_auto::Version::H1, Rewind::new(conn)),
                Some(b"h2") => (http_auto::Version::H2, Rewind::new(conn)),
                _ => {
                    tracing::debug!("HTTP: no ALPN negotiated");
                    let conn = timeout(Duration::from_secs(10), http_auto::read_version(conn)).await;
                    match conn {
                        Ok(Ok(v)) => v,
                        Ok(Err(e)) => {
                            tracing::warn!("HTTP connection error: {e}");
                            return;
                        },
                        Err(_) => {
                            tracing::warn!("HTTP connection error: timeout determining http version");
                            return;
                        }
                    }
                }
            };

            let conn = server.serve_connection_with_upgrades(
                conn,
                version,
                hyper1::service::service_fn(move |req: hyper1::Request<Incoming>| {
                    let backend = backend.clone();
                    let ws_connections = ws_connections.clone();
                    let endpoint_rate_limiter = endpoint_rate_limiter.clone();
                    let cancellation_handler = cancellation_handler.clone();

                    async move {
                        Ok::<_, Infallible>(
                            request_handler(
                                req,
                                config,
                                backend,
                                ws_connections,
                                cancellation_handler,
                                peer_addr,
                                endpoint_rate_limiter,
                            )
                            .await
                            .map_or_else(api_error_into_response, |r| r),
                        )
                    }
                })
    let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
        if let Err(err) = conn {
            error!(
                protocol = "http",
                "failed to accept TLS connection: {err:?}"
            );
            TLS_HANDSHAKE_FAILURES.inc();
            ready(false)
        } else {
            info!(protocol = "http", "accepted new TLS connection");
            ready(true)
        }
    });

    let make_svc = hyper::service::make_service_fn(
        |stream: &tokio_rustls::server::TlsStream<
            WithConnectionGuard<WithClientIp<AddrStream>>,
        >| {
            let (conn, _) = stream.get_ref();

            let cancel = pin!(cancellation_token.cancelled());
            let conn = pin!(conn);
            let res = match select(cancel, conn).await {
                Either::Left((_cancelled, mut conn)) => {
                    conn.as_mut().graceful_shutdown();
                    conn.await
                }
                Either::Right((res, _)) => res,
            };
            // this is jank. should dissapear with hyper 1.0 migration.
            let gauge = conn
                .gauge
                .lock()
                .expect("lock should not be poisoned")
                .take()
                .expect("gauge should be set on connection start");

            match res {
                Ok(()) => {}
                Err(e) => {
                    tracing::warn!("HTTP connection error {e}")
                }
            let client_addr = conn.inner.client_addr();
            let remote_addr = conn.inner.inner.remote_addr();
            let backend = backend.clone();
            let ws_connections = ws_connections.clone();
            let endpoint_rate_limiter = endpoint_rate_limiter.clone();
            let cancellation_handler = cancellation_handler.clone();
            async move {
                let peer_addr = match client_addr {
                    Some(addr) => addr,
                    None if config.require_client_ip => bail!("missing required client ip"),
                    None => remote_addr,
                };
                Ok(MetricService::new(
                    hyper::service::service_fn(move |req: Request<Body>| {
                        let backend = backend.clone();
                        let ws_connections = ws_connections.clone();
                        let endpoint_rate_limiter = endpoint_rate_limiter.clone();
                        let cancellation_handler = cancellation_handler.clone();

                        async move {
                            Ok::<_, Infallible>(
                                request_handler(
                                    req,
                                    config,
                                    backend,
                                    ws_connections,
                                    cancellation_handler,
                                    peer_addr.ip(),
                                    endpoint_rate_limiter,
                                )
                                .await
                                .map_or_else(|e| e.into_response(), |r| r),
                            )
                        }
                    }),
                    gauge,
                ))
            }
        });
    }
        },
    );

    hyper::Server::builder(accept::from_stream(tls_listener))
        .serve(make_svc)
        .with_graceful_shutdown(cancellation_token.cancelled())
        .await?;

    // await websocket connections
    http_connections.wait().await;
    ws_connections.wait().await;

    Ok(())
}

fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
    match this {
        ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
            format!("{err:#?}"), // use debug printing so that we give the cause
            StatusCode::BAD_REQUEST,
        ),
        ApiError::Forbidden(_) => {
            HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::FORBIDDEN)
        }
        ApiError::Unauthorized(_) => {
            HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::UNAUTHORIZED)
        }
        ApiError::NotFound(_) => {
            HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::NOT_FOUND)
        }
        ApiError::Conflict(_) => {
            HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::CONFLICT)
        }
        ApiError::PreconditionFailed(_) => HttpErrorBody::response_from_msg_and_status(
            this.to_string(),
            StatusCode::PRECONDITION_FAILED,
        ),
        ApiError::ShuttingDown => HttpErrorBody::response_from_msg_and_status(
            "Shutting down".to_string(),
            StatusCode::SERVICE_UNAVAILABLE,
        ),
        ApiError::ResourceUnavailable(err) => HttpErrorBody::response_from_msg_and_status(
            err.to_string(),
            StatusCode::SERVICE_UNAVAILABLE,
        ),
        ApiError::Timeout(err) => HttpErrorBody::response_from_msg_and_status(
            err.to_string(),
            StatusCode::REQUEST_TIMEOUT,
        ),
        ApiError::InternalServerError(err) => HttpErrorBody::response_from_msg_and_status(
            err.to_string(),
            StatusCode::INTERNAL_SERVER_ERROR,
        ),
struct MetricService<S> {
    inner: S,
    _gauge: IntCounterPairGuard,
}

impl<S> MetricService<S> {
    fn new(inner: S, _gauge: IntCounterPairGuard) -> MetricService<S> {
        MetricService { inner, _gauge }
    }
}

#[derive(Serialize)]
struct HttpErrorBody {
    pub msg: String,
}
impl<S, ReqBody> hyper::service::Service<Request<ReqBody>> for MetricService<S>
where
    S: hyper::service::Service<Request<ReqBody>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

impl HttpErrorBody {
    pub fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Full<Bytes>> {
        HttpErrorBody { msg }.to_response(status)
    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    pub fn to_response(&self, status: StatusCode) -> Response<Full<Bytes>> {
        Response::builder()
            .status(status)
            .header(http1::header::CONTENT_TYPE, "application/json")
            // we do not have nested maps with non string keys so serialization shouldn't fail
            .body(Full::new(Bytes::from(serde_json::to_string(self).unwrap())))
            .unwrap()
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        self.inner.call(req)
    }
}

#[allow(clippy::too_many_arguments)]
async fn request_handler(
    mut request: hyper1::Request<Incoming>,
    mut request: Request<Body>,
    config: &'static ProxyConfig,
    backend: Arc<PoolingBackend>,
    ws_connections: TaskTracker,
    cancellation_handler: Arc<CancellationHandler>,
    peer_addr: SocketAddr,
    peer_addr: IpAddr,
    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Response<Full<Bytes>>, ApiError> {
) -> Result<Response<Body>, ApiError> {
    let session_id = uuid::Uuid::new_v4();

    let host = request

@@ -347,14 +261,14 @@ async fn request_handler(

        // Return the response so the spawned future can continue.
        Ok(response)
    } else if request.uri().path() == "/sql" && *request.method() == Method::POST {
    } else if request.uri().path() == "/sql" && request.method() == Method::POST {
        let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
        let span = ctx.span.clone();

        sql_over_http::handle(config, ctx, request, backend)
            .instrument(span)
            .await
    } else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
    } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
        Response::builder()
            .header("Allow", "OPTIONS, POST")
            .header("Access-Control-Allow-Origin", "*")

@@ -364,24 +278,9 @@ async fn request_handler(
            )
            .header("Access-Control-Max-Age", "86400" /* 24 hours */)
            .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
            .body(Full::new(Bytes::new()))
            .body(Body::empty())
            .map_err(|e| ApiError::InternalServerError(e.into()))
    } else {
        json_response(StatusCode::BAD_REQUEST, "query is not supported")
    }
}

fn json_response<T: Serialize>(
    status: StatusCode,
    data: T,
) -> Result<Response<Full<Bytes>>, ApiError> {
    let json = serde_json::to_string(&data)
        .context("Failed to serialize JSON response")
        .map_err(ApiError::InternalServerError)?;
    let response = Response::builder()
        .status(status)
        .header(http1::header::CONTENT_TYPE, "application/json")
        .body(Full::new(Bytes::from(json)))
        .map_err(|e| ApiError::InternalServerError(e.into()))?;
    Ok(response)
}

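The hunks above move the serverless entry point back onto hyper 0.14's `make_service_fn`/`service_fn` layering: one outer closure runs per connection and one inner closure runs per request. A stripped-down sketch of that wiring on plain TCP, without the TLS, PROXY protocol, or metric-guard pieces from the diff; it assumes hyper 0.14 with its server and tcp features plus a tokio runtime:

```rust
use std::convert::Infallible;
use std::net::SocketAddr;

use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response, Server};

#[tokio::main]
async fn main() -> Result<(), hyper::Error> {
    let addr: SocketAddr = ([127, 0, 0, 1], 0).into();

    let make_svc = make_service_fn(|_conn| async {
        // Per-connection state (client IP, connection gauge, ...) would be captured here.
        Ok::<_, Infallible>(service_fn(|_req: Request<Body>| async {
            // Per-request handling; the diff dispatches to request_handler at this point.
            Ok::<_, Infallible>(Response::new(Body::from("ok")))
        }))
    });

    let server = Server::bind(&addr).serve(make_svc);
    println!("listening on {}", server.local_addr());
    server.await
}
```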
@@ -1,316 +0,0 @@
//! [`hyper-util`] offers an 'auto' connection to detect whether the connection should be HTTP1 or HTTP2.
//! There's a bug in this implementation where graceful shutdowns are not properly respected.

use futures::ready;
use hyper1::body::Body;
use hyper1::service::HttpService;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use std::future::Future;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{error::Error as StdError, io, marker::Unpin};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use ::http1::{Request, Response};
use bytes::Bytes;
use hyper1::{body::Incoming, service::Service};

use hyper1::server::conn::http1;
use hyper1::{rt::bounds::Http2ServerConnExec, server::conn::http2};

use pin_project_lite::pin_project;

type Error = Box<dyn std::error::Error + Send + Sync>;

type Result<T> = std::result::Result<T, Error>;

const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";

/// Http1 or Http2 connection builder.
#[derive(Clone, Debug)]
pub struct Builder {
    http1: http1::Builder,
    http2: http2::Builder<TokioExecutor>,
}

impl Builder {
    /// Create a new auto connection builder.
    pub fn new() -> Self {
        let mut builder = Self {
            http1: http1::Builder::new(),
            http2: http2::Builder::new(TokioExecutor::new()),
        };

        builder.http1.timer(TokioTimer::new());
        builder.http2.timer(TokioTimer::new());

        builder
    }

    /// Bind a connection together with a [`Service`], with the ability to
    /// handle HTTP upgrades. This requires that the IO object implements
    /// `Send`.
    pub fn serve_connection_with_upgrades<I, S, B>(
        &self,
        io: Rewind<I>,
        version: Version,
        service: S,
    ) -> UpgradeableConnection<I, S>
    where
        S: Service<Request<Incoming>, Response = Response<B>>,
        S::Future: 'static,
        S::Error: Into<Box<dyn StdError + Send + Sync>>,
        B: Body + 'static,
        B::Error: Into<Box<dyn StdError + Send + Sync>>,
        I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
        TokioExecutor: Http2ServerConnExec<S::Future, B>,
    {
        match version {
            Version::H1 => {
                let conn = self
                    .http1
                    .serve_connection(TokioIo::new(io), service)
                    .with_upgrades();
                UpgradeableConnection {
                    state: UpgradeableConnState::H1 { conn },
                }
            }
            Version::H2 => {
                let conn = self.http2.serve_connection(TokioIo::new(io), service);
                UpgradeableConnection {
                    state: UpgradeableConnState::H2 { conn },
                }
            }
        }
    }
}

#[derive(Copy, Clone)]
pub(crate) enum Version {
    H1,
    H2,
}

pub(crate) fn read_version<I>(io: I) -> ReadVersion<I>
where
    I: AsyncRead + Unpin,
{
    ReadVersion {
        io: Some(io),
        buf: [0; 24],
        filled: 0,
        version: Version::H2,
        _pin: PhantomPinned,
    }
}

pin_project! {
    pub(crate) struct ReadVersion<I> {
        io: Option<I>,
        buf: [u8; 24],
        // the amount of `buf` thats been filled
        filled: usize,
        version: Version,
        // Make this future `!Unpin` for compatibility with async trait methods.
        #[pin]
        _pin: PhantomPinned,
    }
}

impl<I> Future for ReadVersion<I>
where
    I: AsyncRead + Unpin,
{
    type Output = io::Result<(Version, Rewind<I>)>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        let mut buf = ReadBuf::new(&mut *this.buf);
        buf.set_filled(*this.filled);

        // We start as H2 and switch to H1 as soon as we don't have the preface.
        while buf.filled().len() < H2_PREFACE.len() {
            let len = buf.filled().len();
            ready!(Pin::new(this.io.as_mut().unwrap()).poll_read(cx, &mut buf))?;
            *this.filled = buf.filled().len();

            // We starts as H2 and switch to H1 when we don't get the preface.
            if buf.filled().len() == len
                || buf.filled()[len..] != H2_PREFACE[len..buf.filled().len()]
            {
                *this.version = Version::H1;
                break;
            }
        }

        let io = this.io.take().unwrap();
        let buf = buf.filled().to_vec();
        Poll::Ready(Ok((
            *this.version,
            Rewind::new_buffered(io, Bytes::from(buf)),
        )))
    }
}

pin_project! {
    /// Connection future.
    pub struct UpgradeableConnection<I, S>
    where
        S: HttpService<Incoming>,
    {
        #[pin]
        state: UpgradeableConnState<I, S>,
    }
}

type Http1UpgradeableConnection<I, S> =
    hyper1::server::conn::http1::UpgradeableConnection<TokioIo<Rewind<I>>, S>;
type Http2Connection<I, S> =
    hyper1::server::conn::http2::Connection<TokioIo<Rewind<I>>, S, TokioExecutor>;

pin_project! {
    #[project = UpgradeableConnStateProj]
    enum UpgradeableConnState<I, S>
    where
        S: HttpService<Incoming>,
    {
        H1 {
            #[pin]
            conn: Http1UpgradeableConnection<I, S>,
        },
        H2 {
            #[pin]
            conn: Http2Connection<I, S>,
        },
    }
}

impl<I, S, B> UpgradeableConnection<I, S>
where
    S: HttpService<Incoming, ResBody = B>,
    S::Error: Into<Box<dyn StdError + Send + Sync>>,
    I: AsyncRead + AsyncWrite + Unpin,
    B: Body + 'static,
    B::Error: Into<Box<dyn StdError + Send + Sync>>,
    TokioExecutor: Http2ServerConnExec<S::Future, B>,
{
    /// Start a graceful shutdown process for this connection.
    ///
    /// This `UpgradeableConnection` should continue to be polled until shutdown can finish.
    ///
    /// # Note
    ///
    /// This should only be called while the `Connection` future is still pending. If
    /// called after `UpgradeableConnection::poll` has resolved, this does nothing.
    pub fn graceful_shutdown(self: Pin<&mut Self>) {
        match self.project().state.project() {
            UpgradeableConnStateProj::H1 { conn } => conn.graceful_shutdown(),
            UpgradeableConnStateProj::H2 { conn } => conn.graceful_shutdown(),
        }
    }
}

impl<I, S, B> Future for UpgradeableConnection<I, S>
where
    S: Service<Request<Incoming>, Response = Response<B>>,
    S::Future: 'static,
    S::Error: Into<Box<dyn StdError + Send + Sync>>,
    B: Body + 'static,
    B::Error: Into<Box<dyn StdError + Send + Sync>>,
    I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    TokioExecutor: Http2ServerConnExec<S::Future, B>,
{
    type Output = Result<()>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.as_mut().project();
        match this.state.as_mut().project() {
            UpgradeableConnStateProj::H1 { conn } => conn.poll(cx).map_err(Into::into),
            UpgradeableConnStateProj::H2 { conn } => conn.poll(cx).map_err(Into::into),
        }
    }
}

/// Combine a buffer with an IO, rewinding reads to use the buffer.
#[derive(Debug)]
pub(crate) struct Rewind<T> {
    pre: Option<Bytes>,
    inner: T,
}

impl<T> Rewind<T> {
    pub(crate) fn new(io: T) -> Self {
        Rewind {
            pre: None,
            inner: io,
        }
    }

    pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self {
        Rewind {
            pre: Some(buf),
            inner: io,
        }
    }
}

impl<T> AsyncRead for Rewind<T>
where
    T: AsyncRead + Unpin,
{
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        if let Some(prefix) = self.pre.take() {
            // If there are no remaining bytes, let the bytes get dropped.
            if !prefix.is_empty() {
                let copy_len = std::cmp::min(prefix.len(), buf.remaining());
                buf.put_slice(&prefix[..copy_len]);
                // Put back what's left
                if !prefix.is_empty() {
                    self.pre = Some(prefix);
                }

                return Poll::Ready(Ok(()));
            }
        }
        Pin::new(&mut self.inner).poll_read(cx, buf)
    }
}

impl<T> AsyncWrite for Rewind<T>
where
    T: AsyncWrite + Unpin,
{
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write(cx, buf)
    }

    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_shutdown(cx)
    }

    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }
}

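The deleted `read_version` future above sniffs the first bytes of a stream and falls back to HTTP/1 as soon as they stop matching the 24-byte HTTP/2 client preface. A synchronous sketch of that same comparison over an in-memory buffer, using the preface constant shown in the deleted file:

```rust
const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";

// Returns true only once the full preface has been seen; a real stream-based
// version (like the deleted one) can bail out to HTTP/1 as soon as any byte differs.
fn looks_like_h2(prefix: &[u8]) -> bool {
    prefix.len() >= H2_PREFACE.len() && &prefix[..H2_PREFACE.len()] == H2_PREFACE
}

fn main() {
    assert!(looks_like_h2(b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n\x00\x00"));
    assert!(!looks_like_h2(b"GET /sql HTTP/1.1\r\n"));
}
```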
@@ -1,19 +1,14 @@
use std::sync::Arc;

use super::json_response;
use anyhow::bail;
use bytes::Bytes;
use futures::StreamExt;
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper1::body::Body;
use hyper1::body::Incoming;
use hyper1::header;
use hyper1::http::HeaderName;
use hyper1::http::HeaderValue;
use hyper1::Response;
use hyper1::StatusCode;
use hyper1::{HeaderMap, Request};
use hyper::body::HttpBody;
use hyper::header;
use hyper::http::HeaderName;
use hyper::http::HeaderValue;
use hyper::Response;
use hyper::StatusCode;
use hyper::{Body, HeaderMap, Request};
use serde_json::json;
use serde_json::Value;
use tokio::try_join;

@@ -27,6 +22,7 @@ use tracing::error;
use tracing::info;
use url::Url;
use utils::http::error::ApiError;
use utils::http::json::json_response;

use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;

@@ -195,9 +191,9 @@ fn get_conn_info(
pub async fn handle(
    config: &'static ProxyConfig,
    mut ctx: RequestMonitoring,
    request: Request<Incoming>,
    request: Request<Body>,
    backend: Arc<PoolingBackend>,
) -> Result<Response<Full<Bytes>>, ApiError> {
) -> Result<Response<Body>, ApiError> {
    let result = tokio::time::timeout(
        config.http_config.request_timeout,
        handle_inner(config, &mut ctx, request, backend),

@@ -304,18 +300,19 @@ pub async fn handle(
        }
    };

    response
        .headers_mut()
        .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
    response.headers_mut().insert(
        "Access-Control-Allow-Origin",
        hyper::http::HeaderValue::from_static("*"),
    );
    Ok(response)
}

async fn handle_inner(
    config: &'static ProxyConfig,
    ctx: &mut RequestMonitoring,
    request: Request<Incoming>,
    request: Request<Body>,
    backend: Arc<PoolingBackend>,
) -> anyhow::Result<Response<Full<Bytes>>> {
) -> anyhow::Result<Response<Body>> {
    let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
        .with_label_values(&[ctx.protocol])
        .guard();

@@ -372,12 +369,9 @@ async fn handle_inner(
    }

    let fetch_and_process_request = async {
        let body = request
            .into_body()
            .collect()
        let body = hyper::body::to_bytes(request.into_body())
            .await
            .map_err(anyhow::Error::from)?
            .to_bytes();
            .map_err(anyhow::Error::from)?;
        info!(length = body.len(), "request payload read");
        let payload: Payload = serde_json::from_slice(&body)?;
        Ok::<Payload, anyhow::Error>(payload) // Adjust error type accordingly

@@ -496,7 +490,7 @@ async fn handle_inner(
    let body = serde_json::to_string(&result).expect("json serialization should not fail");
    let len = body.len();
    let response = response
        .body(Full::new(Bytes::from(body)))
        .body(Body::from(body))
        // only fails if invalid status code or invalid header/values are given.
        // these are not user configurable so it cannot fail dynamically
        .expect("building response payload should not fail");

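The body-reading hunk above swaps http-body-util's `.collect()` for hyper 0.14's `hyper::body::to_bytes`, which buffers the whole request body into one `Bytes` value before JSON parsing. A minimal sketch of that call in isolation, assuming hyper 0.14 and a tokio runtime; the payload string is only an example:

```rust
use hyper::{body::to_bytes, Body};

#[tokio::main]
async fn main() -> Result<(), hyper::Error> {
    // Buffer an entire body into memory, as handle_inner now does before serde_json.
    let body = Body::from(r#"{"query": "select 1"}"#);
    let bytes = to_bytes(body).await?;
    println!("request payload read: {} bytes", bytes.len());
    Ok(())
}
```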
proxy/src/serverless/tls_listener.rs (new file, 283 lines)
@@ -0,0 +1,283 @@
use std::{
    pin::Pin,
    task::{Context, Poll},
    time::Duration,
};

use futures::{Future, Stream, StreamExt};
use pin_project_lite::pin_project;
use thiserror::Error;
use tokio::{
    io::{AsyncRead, AsyncWrite},
    task::JoinSet,
    time::timeout,
};

/// Default timeout for the TLS handshake.
pub const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// Trait for TLS implementation.
///
/// Implementations are provided by the rustls and native-tls features.
pub trait AsyncTls<C: AsyncRead + AsyncWrite>: Clone {
    /// The type of the TLS stream created from the underlying stream.
    type Stream: Send + 'static;
    /// Error type for completing the TLS handshake
    type Error: std::error::Error + Send + 'static;
    /// Type of the Future for the TLS stream that is accepted.
    type AcceptFuture: Future<Output = Result<Self::Stream, Self::Error>> + Send + 'static;

    /// Accept a TLS connection on an underlying stream
    fn accept(&self, stream: C) -> Self::AcceptFuture;
}

/// Asynchronously accept connections.
pub trait AsyncAccept {
    /// The type of the connection that is accepted.
    type Connection: AsyncRead + AsyncWrite;
    /// The type of error that may be returned.
    type Error;

    /// Poll to accept the next connection.
    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Connection, Self::Error>>>;

    /// Return a new `AsyncAccept` that stops accepting connections after
    /// `ender` completes.
    ///
    /// Useful for graceful shutdown.
    ///
    /// See [examples/echo.rs](https://github.com/tmccombs/tls-listener/blob/main/examples/echo.rs)
    /// for example of how to use.
    fn until<F: Future>(self, ender: F) -> Until<Self, F>
    where
        Self: Sized,
    {
        Until {
            acceptor: self,
            ender,
        }
    }
}

pin_project! {
    ///
    /// Wraps a `Stream` of connections (such as a TCP listener) so that each connection is itself
    /// encrypted using TLS.
    ///
    /// It is similar to:
    ///
    /// ```ignore
    /// tcpListener.and_then(|s| tlsAcceptor.accept(s))
    /// ```
    ///
    /// except that it has the ability to accept multiple transport-level connections
    /// simultaneously while the TLS handshake is pending for other connections.
    ///
    /// By default, if a client fails the TLS handshake, that is treated as an error, and the
    /// `TlsListener` will return an `Err`. If the `TlsListener` is passed directly to a hyper
    /// [`Server`][1], then an invalid handshake can cause the server to stop accepting connections.
    /// See [`http-stream.rs`][2] or [`http-low-level`][3] examples, for examples of how to avoid this.
    ///
    /// Note that if the maximum number of pending connections is greater than 1, the resulting
    /// [`T::Stream`][4] connections may come in a different order than the connections produced by the
    /// underlying listener.
    ///
    /// [1]: https://docs.rs/hyper/latest/hyper/server/struct.Server.html
    /// [2]: https://github.com/tmccombs/tls-listener/blob/main/examples/http-stream.rs
    /// [3]: https://github.com/tmccombs/tls-listener/blob/main/examples/http-low-level.rs
    /// [4]: AsyncTls::Stream
    ///
    #[allow(clippy::type_complexity)]
    pub struct TlsListener<A: AsyncAccept, T: AsyncTls<A::Connection>> {
        #[pin]
        listener: A,
        tls: T,
        waiting: JoinSet<Result<Result<T::Stream, T::Error>, tokio::time::error::Elapsed>>,
        timeout: Duration,
    }
}

/// Builder for `TlsListener`.
#[derive(Clone)]
pub struct Builder<T> {
    tls: T,
    handshake_timeout: Duration,
}

/// Wraps errors from either the listener or the TLS Acceptor
#[derive(Debug, Error)]
pub enum Error<LE: std::error::Error, TE: std::error::Error> {
    /// An error that arose from the listener ([AsyncAccept::Error])
    #[error("{0}")]
    ListenerError(#[source] LE),
    /// An error that occurred during the TLS accept handshake
    #[error("{0}")]
    TlsAcceptError(#[source] TE),
}

impl<A: AsyncAccept, T> TlsListener<A, T>
where
    T: AsyncTls<A::Connection>,
{
    /// Create a `TlsListener` with default options.
    pub fn new(tls: T, listener: A) -> Self {
        builder(tls).listen(listener)
    }
}

impl<A, T> TlsListener<A, T>
where
    A: AsyncAccept,
    A::Error: std::error::Error,
    T: AsyncTls<A::Connection>,
{
    /// Accept the next connection
    ///
    /// This is essentially an alias to `self.next()` with a more domain-appropriate name.
    pub async fn accept(&mut self) -> Option<<Self as Stream>::Item>
    where
        Self: Unpin,
    {
        self.next().await
    }

    /// Replaces the Tls Acceptor configuration, which will be used for new connections.
    ///
    /// This can be used to change the certificate used at runtime.
    pub fn replace_acceptor(&mut self, acceptor: T) {
        self.tls = acceptor;
    }

    /// Replaces the Tls Acceptor configuration from a pinned reference to `Self`.
    ///
    /// This is useful if your listener is `!Unpin`.
    ///
    /// This can be used to change the certificate used at runtime.
    pub fn replace_acceptor_pin(self: Pin<&mut Self>, acceptor: T) {
        *self.project().tls = acceptor;
    }
}

impl<A, T> Stream for TlsListener<A, T>
where
    A: AsyncAccept,
    A::Error: std::error::Error,
    T: AsyncTls<A::Connection>,
{
    type Item = Result<T::Stream, Error<A::Error, T::Error>>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        loop {
            match this.listener.as_mut().poll_accept(cx) {
                Poll::Pending => break,
                Poll::Ready(Some(Ok(conn))) => {
                    this.waiting
                        .spawn(timeout(*this.timeout, this.tls.accept(conn)));
                }
                Poll::Ready(Some(Err(e))) => {
                    return Poll::Ready(Some(Err(Error::ListenerError(e))));
                }
                Poll::Ready(None) => return Poll::Ready(None),
            }
        }

        loop {
            return match this.waiting.poll_join_next(cx) {
                Poll::Ready(Some(Ok(Ok(conn)))) => {
                    Poll::Ready(Some(conn.map_err(Error::TlsAcceptError)))
                }
                // The handshake timed out, try getting another connection from the queue
                Poll::Ready(Some(Ok(Err(_)))) => continue,
                // The handshake panicked
                Poll::Ready(Some(Err(e))) if e.is_panic() => {
                    std::panic::resume_unwind(e.into_panic())
                }
                // The handshake was externally aborted
                Poll::Ready(Some(Err(_))) => unreachable!("handshake tasks are never aborted"),
                _ => Poll::Pending,
            };
        }
    }
}

impl<C: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncTls<C> for tokio_rustls::TlsAcceptor {
    type Stream = tokio_rustls::server::TlsStream<C>;
    type Error = std::io::Error;
    type AcceptFuture = tokio_rustls::Accept<C>;

    fn accept(&self, conn: C) -> Self::AcceptFuture {
        tokio_rustls::TlsAcceptor::accept(self, conn)
    }
}

impl<T> Builder<T> {
    /// Set the timeout for handshakes.
    ///
    /// If a timeout takes longer than `timeout`, then the handshake will be
    /// aborted and the underlying connection will be dropped.
    ///
    /// Defaults to `DEFAULT_HANDSHAKE_TIMEOUT`.
    pub fn handshake_timeout(&mut self, timeout: Duration) -> &mut Self {
        self.handshake_timeout = timeout;
        self
    }

    /// Create a `TlsListener` from the builder
    ///
    /// Actually build the `TlsListener`. The `listener` argument should be
    /// an implementation of the `AsyncAccept` trait that accepts new connections
    /// that the `TlsListener` will encrypt using TLS.
    pub fn listen<A: AsyncAccept>(&self, listener: A) -> TlsListener<A, T>
    where
        T: AsyncTls<A::Connection>,
    {
        TlsListener {
            listener,
            tls: self.tls.clone(),
            waiting: JoinSet::new(),
            timeout: self.handshake_timeout,
        }
    }
}

/// Create a new Builder for a TlsListener
///
/// `server_config` will be used to configure the TLS sessions.
pub fn builder<T>(tls: T) -> Builder<T> {
    Builder {
        tls,
        handshake_timeout: DEFAULT_HANDSHAKE_TIMEOUT,
    }
}

pin_project! {
    /// See [`AsyncAccept::until`]
    pub struct Until<A, E> {
        #[pin]
        acceptor: A,
        #[pin]
        ender: E,
    }
}

impl<A: AsyncAccept, E: Future> AsyncAccept for Until<A, E> {
    type Connection = A::Connection;
    type Error = A::Error;

    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
        let this = self.project();

        match this.ender.poll(cx) {
            Poll::Pending => this.acceptor.poll_accept(cx),
            Poll::Ready(_) => Poll::Ready(None),
        }
    }
}

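The new listener above pushes every pending TLS handshake into a `JoinSet`, each wrapped in `tokio::time::timeout`, so one slow or stalled client cannot block the accept loop. A minimal, runnable sketch of that same pattern with plain async blocks standing in for handshakes (the string values are illustrative only):

```rust
use std::time::Duration;
use tokio::{task::JoinSet, time::timeout};

#[tokio::main]
async fn main() {
    let mut waiting: JoinSet<Result<&'static str, tokio::time::error::Elapsed>> = JoinSet::new();

    // A fast "handshake" and a deliberately slow one.
    waiting.spawn(timeout(Duration::from_millis(100), async { "fast client" }));
    waiting.spawn(timeout(Duration::from_millis(100), async {
        tokio::time::sleep(Duration::from_secs(5)).await;
        "slow client"
    }));

    let mut accepted = Vec::new();
    while let Some(joined) = waiting.join_next().await {
        match joined.expect("handshake task panicked") {
            Ok(conn) => accepted.push(conn), // handshake finished in time
            Err(_elapsed) => continue,       // timed out: drop it, keep accepting others
        }
    }
    assert_eq!(accepted, vec!["fast client"]);
}
```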
@@ -15,7 +15,7 @@ def test_migrations(neon_simple_env: NeonEnv):

    endpoint.wait_for_migrations()

    num_migrations = 8
    num_migrations = 9

    with endpoint.cursor() as cur:
        cur.execute("SELECT id FROM neon_migration.migration_id")

@@ -29,3 +29,34 @@ def test_neon_extension(neon_env_builder: NeonEnvBuilder):
    log.info(res)
    assert len(res) == 1
    assert len(res[0]) == 5


# Verify that the neon extension can be upgraded/downgraded.
def test_neon_extension_compatibility(neon_env_builder: NeonEnvBuilder):
    env = neon_env_builder.init_start()
    env.neon_cli.create_branch("test_neon_extension_compatibility")

    endpoint_main = env.endpoints.create("test_neon_extension_compatibility")
    # don't skip pg_catalog updates - it runs CREATE EXTENSION neon
    endpoint_main.respec(skip_pg_catalog_updates=False)
    endpoint_main.start()

    with closing(endpoint_main.connect()) as conn:
        with conn.cursor() as cur:
            all_versions = ["1.3", "1.2", "1.1", "1.0"]
            current_version = "1.3"
            for idx, begin_version in enumerate(all_versions):
                for target_version in all_versions[idx + 1 :]:
                    if current_version != begin_version:
                        cur.execute(
                            f"ALTER EXTENSION neon UPDATE TO '{begin_version}'; -- {current_version}->{begin_version}"
                        )
                        current_version = begin_version
                    # downgrade
                    cur.execute(
                        f"ALTER EXTENSION neon UPDATE TO '{target_version}'; -- {begin_version}->{target_version}"
                    )
                    # upgrade
                    cur.execute(
                        f"ALTER EXTENSION neon UPDATE TO '{begin_version}'; -- {target_version}->{begin_version}"
                    )

@@ -190,6 +190,8 @@ def test_delete_tenant_exercise_crash_safety_failpoints(
            # So by ignoring these instead of waiting for empty upload queue
            # we execute more distinct code paths.
            '.*stopping left-over name="remote upload".*',
            # an on-demand is cancelled by shutdown
            ".*initial size calculation failed: downloading failed, possibly for shutdown",
        ]
    )

@@ -213,7 +213,9 @@ def test_delete_timeline_exercise_crash_safety_failpoints(
            # This happens when timeline remains are cleaned up during loading
            ".*Timeline dir entry become invalid.*",
            # In one of the branches we poll for tenant to become active. Polls can generate this log message:
            f".*Tenant {env.initial_tenant} is not active*",
            f".*Tenant {env.initial_tenant} is not active.*",
            # an on-demand is cancelled by shutdown
            ".*initial size calculation failed: downloading failed, possibly for shutdown",
        ]
    )

@@ -64,7 +64,7 @@ rustls = { version = "0.21", features = ["dangerous_configuration"] }
scopeguard = { version = "1" }
serde = { version = "1", features = ["alloc", "derive"] }
serde_json = { version = "1", features = ["raw_value"] }
smallvec = { version = "1", default-features = false, features = ["const_new", "write"] }
smallvec = { version = "1", default-features = false, features = ["write"] }
subtle = { version = "2" }
time = { version = "0.3", features = ["local-offset", "macros", "serde-well-known"] }
tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] }
@@ -76,6 +76,7 @@ tonic = { version = "0.9", features = ["tls-roots"] }
tower = { version = "0.4", default-features = false, features = ["balance", "buffer", "limit", "log", "timeout", "util"] }
tracing = { version = "0.1", features = ["log"] }
tracing-core = { version = "0.1" }
tungstenite = { version = "0.20" }
url = { version = "2", features = ["serde"] }
uuid = { version = "1", features = ["serde", "v4", "v7"] }
zeroize = { version = "1", features = ["derive"] }