Compare commits


13 Commits

Author SHA1 Message Date
Konstantin Knizhnik
42d0b040f8 Fix merge conflicts 2024-03-11 17:34:05 +02:00
Konstantin Knizhnik
9832638c09 Add compression tag to BLOBs stored in image layer 2024-03-10 22:10:12 +02:00
Konstantin Knizhnik
62e7638c69 Store compression algorithm in image layer metadata 2024-03-10 21:54:47 +02:00
Konstantin Knizhnik
0dad8e427d Update pageserver/src/walingest.rs
Co-authored-by: Joonas Koivunen <joonas@neon.tech>
2024-03-10 21:52:03 +02:00
Konstantin Knizhnik
4cfa2fdca5 Support compression of get_page responses 2024-03-10 21:52:01 +02:00
Konstantin Knizhnik
56ddf8e37f Build Postgres with lz4 support 2024-03-10 21:49:34 +02:00
Konstantin Knizhnik
ed4bb3073f Resolve merge conflict 2024-03-10 21:48:47 +02:00
Konstantin Knizhnik
b7e7aeed4d Perform compression of page images in storage 2024-03-10 21:48:00 +02:00
Konstantin Knizhnik
a880178cca Support lz4 WAL compression 2024-03-10 21:38:00 +02:00
Joonas Koivunen
b09d686335 fix: on-demand downloads can outlive timeline shutdown (#7051)
## Problem

Before this PR, it was possible that on-demand downloads were started
after `Timeline::shutdown()`.

For example, we have observed a walreceiver-connection-handler-initiated
on-demand download that was started after `Timeline::shutdown()`s final
`task_mgr::shutdown_tasks()` call.

The underlying issue is that `task_mgr::shutdown_tasks()` isn't sticky,
i.e., new tasks can be spawned during or after
`task_mgr::shutdown_tasks()`.

Cc: https://github.com/neondatabase/neon/issues/4175 in lieu of a more
specific issue for task_mgr. We already decided we want to get rid of it
anyway.

Original investigation:
https://neondb.slack.com/archives/C033RQ5SPDH/p1709824952465949

## Changes

- enter gate while downloading
- use timeline cancellation token for cancelling download

thereby, fixes #7054

Entering the gate might also eliminate the recent "kept the gate from closing"
occurrences seen in staging.
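
A minimal sketch of the pattern, for illustration only (a `tokio::sync::Semaphore`
stands in for the timeline `Gate`, and the names are not the actual pageserver types):

```rust
use tokio_util::sync::CancellationToken;

#[derive(Debug)]
enum DownloadError {
    Cancelled,
}

async fn download_layer(
    timeline_gate: &tokio::sync::Semaphore, // stand-in for the timeline's Gate
    timeline_cancel: &CancellationToken,    // timeline-scoped token, not task_mgr's
) -> Result<(), DownloadError> {
    // Once shutdown has closed the gate, new downloads fail here instead of
    // starting after Timeline::shutdown().
    let _guard = timeline_gate
        .acquire()
        .await
        .map_err(|_| DownloadError::Cancelled)?;

    // An in-flight download is cancelled via the timeline's own token.
    tokio::select! {
        _ = timeline_cancel.cancelled() => Err(DownloadError::Cancelled),
        _ = fetch_from_remote_storage() => Ok(()),
    }
}

async fn fetch_from_remote_storage() { /* download bytes from remote storage */ }
```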
2024-03-09 13:09:08 +00:00
Christian Schwarz
74d24582cf throttling: exclude throttled time from basebackup (fixup of #6953) (#7072)
PR #6953 only excluded throttled time from the handle_pagerequests
(aka smgr metrics).

This PR implements the deduction for `basebackup` queries.

The other page_service methods either don't use Timeline::get
or they aren't used in production.
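
For illustration, a hedged sketch of the deduction (the names here are made up;
the real code lives in the pageserver metrics module and takes the throttled time
from the RequestContext):

```rust
use std::time::{Duration, Instant};

// Observe basebackup latency with the time spent throttled deducted.
// `observe` stands in for Histogram::observe on the "ok"/"error" labelled metric.
fn observe_basebackup_latency(
    start: Instant,
    spent_throttled: Duration, // accumulated by the request context in the real code
    observe: impl Fn(f64),
) {
    let elapsed = start.elapsed();
    // If the accounting went wrong, fall back to raw elapsed time
    // (the real code logs a rate-limited warning in that case).
    let ex_throttled = elapsed.checked_sub(spent_throttled).unwrap_or(elapsed);
    observe(ex_throttled.as_secs_f64());
}
```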

Found by manually inspecting in [staging
logs](https://neonprod.grafana.net/explore?schemaVersion=1&panes=%7B%22wx8%22:%7B%22datasource%22:%22xHHYY0dVz%22,%22queries%22:%5B%7B%22refId%22:%22A%22,%22expr%22:%22%7Bhostname%3D%5C%22pageserver-0.eu-west-1.aws.neon.build%5C%22%7D%20%7C~%20%60git-env%7CERR%7CWARN%60%22,%22queryType%22:%22range%22,%22datasource%22:%7B%22type%22:%22loki%22,%22uid%22:%22xHHYY0dVz%22%7D,%22editorMode%22:%22code%22%7D%5D,%22range%22:%7B%22to%22:%221709919114642%22,%22from%22:%221709904430898%22%7D%7D%7D).
2024-03-09 13:37:02 +01:00
Sasha Krassovsky
4834d22d2d Revoke REPLICATION (#7052)
## Problem
Currently users can cause problems with replication
## Summary of changes
Don't let them replicate
2024-03-08 22:24:30 +00:00
Anastasia Lubennikova
86e8c43ddf Add downgrade scripts for neon extension. (#7065)
## Problem

When we start a compute with a newer version of the extension (e.g. 1.2) and
then roll back the release, downgrading the compute version, the next compute
start will try to update the extension to the latest version available in
neon.control (e.g. 1.1).

Thus we need to provide downgrade scripts such as neon--1.2--1.1.sql.

These scripts must revert the changes made by the upgrade scripts in the
reverse order. This is necessary to ensure that the next upgrade will
work correctly.

In general, we need to write upgrade and downgrade scripts to be more
robust and add IF EXISTS / CREATE OR REPLACE clauses to all statements
(where applicable).
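
To illustrate the problem, a hedged sketch (the helper name is hypothetical and a
synchronous `postgres::Client` is assumed; this is not the actual compute_tools code)
of what a compute start effectively issues:

```rust
use postgres::Client;

// Hypothetical helper: bring the installed `neon` extension to the version
// shipped in this compute image's neon.control.
fn sync_neon_extension(client: &mut Client) -> Result<(), postgres::Error> {
    // If the cluster previously ran a newer compute (extension 1.2) and the
    // release was rolled back to 1.1, this statement has to walk the downgrade
    // path, which is exactly what neon--1.2--1.1.sql must provide.
    client.batch_execute("ALTER EXTENSION neon UPDATE")?;
    Ok(())
}
```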

## Summary of changes
Adds downgrade scripts.
Adds test cases for extension downgrade/upgrade. 

fixes #7066

This is a follow-up for
https://app.incident.io/neondb/incidents/167?tab=follow-ups

Signed-off-by: Alex Chi Z <chi@neon.tech>
Co-authored-by: Alex Chi Z <iskyzh@gmail.com>
Co-authored-by: Anastasia Lubennikova <anastasia@neon.tech>
2024-03-08 20:42:35 +00:00
47 changed files with 1051 additions and 865 deletions

Cargo.lock generated
View File

@@ -285,7 +285,7 @@ dependencies = [
"futures",
"git-version",
"humantime",
"hyper 0.14.26",
"hyper",
"metrics",
"once_cell",
"pageserver_api",
@@ -331,7 +331,7 @@ dependencies = [
"fastrand 2.0.0",
"hex",
"http 0.2.9",
"hyper 0.14.26",
"hyper",
"ring 0.17.6",
"time",
"tokio",
@@ -368,7 +368,7 @@ dependencies = [
"bytes",
"fastrand 2.0.0",
"http 0.2.9",
"http-body 0.4.5",
"http-body",
"percent-encoding",
"pin-project-lite",
"tracing",
@@ -396,7 +396,7 @@ dependencies = [
"aws-types",
"bytes",
"http 0.2.9",
"http-body 0.4.5",
"http-body",
"once_cell",
"percent-encoding",
"regex-lite",
@@ -547,7 +547,7 @@ dependencies = [
"crc32fast",
"hex",
"http 0.2.9",
"http-body 0.4.5",
"http-body",
"md-5",
"pin-project-lite",
"sha1",
@@ -579,7 +579,7 @@ dependencies = [
"bytes-utils",
"futures-core",
"http 0.2.9",
"http-body 0.4.5",
"http-body",
"once_cell",
"percent-encoding",
"pin-project-lite",
@@ -618,10 +618,10 @@ dependencies = [
"aws-smithy-types",
"bytes",
"fastrand 2.0.0",
"h2 0.3.24",
"h2",
"http 0.2.9",
"http-body 0.4.5",
"hyper 0.14.26",
"http-body",
"hyper",
"hyper-rustls",
"once_cell",
"pin-project-lite",
@@ -658,7 +658,7 @@ dependencies = [
"bytes-utils",
"futures-core",
"http 0.2.9",
"http-body 0.4.5",
"http-body",
"itoa",
"num-integer",
"pin-project-lite",
@@ -707,8 +707,8 @@ dependencies = [
"bytes",
"futures-util",
"http 0.2.9",
"http-body 0.4.5",
"hyper 0.14.26",
"http-body",
"hyper",
"itoa",
"matchit",
"memchr",
@@ -723,7 +723,7 @@ dependencies = [
"sha1",
"sync_wrapper",
"tokio",
"tokio-tungstenite 0.20.0",
"tokio-tungstenite",
"tower",
"tower-layer",
"tower-service",
@@ -739,7 +739,7 @@ dependencies = [
"bytes",
"futures-util",
"http 0.2.9",
"http-body 0.4.5",
"http-body",
"mime",
"rustversion",
"tower-layer",
@@ -1228,7 +1228,7 @@ dependencies = [
"compute_api",
"flate2",
"futures",
"hyper 0.14.26",
"hyper",
"nix 0.27.1",
"notify",
"num_cpus",
@@ -1344,7 +1344,7 @@ dependencies = [
"futures",
"git-version",
"hex",
"hyper 0.14.26",
"hyper",
"nix 0.27.1",
"once_cell",
"pageserver_api",
@@ -2244,25 +2244,6 @@ dependencies = [
"tracing",
]
[[package]]
name = "h2"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31d030e59af851932b72ceebadf4a2b5986dba4c3b99dd2493f8273a0f151943"
dependencies = [
"bytes",
"fnv",
"futures-core",
"futures-sink",
"futures-util",
"http 1.0.0",
"indexmap 2.0.1",
"slab",
"tokio",
"tokio-util",
"tracing",
]
[[package]]
name = "half"
version = "1.8.2"
@@ -2428,29 +2409,6 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "http-body"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643"
dependencies = [
"bytes",
"http 1.0.0",
]
[[package]]
name = "http-body-util"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840"
dependencies = [
"bytes",
"futures-util",
"http 1.0.0",
"http-body 1.0.0",
"pin-project-lite",
]
[[package]]
name = "http-types"
version = "2.12.0"
@@ -2509,9 +2467,9 @@ dependencies = [
"futures-channel",
"futures-core",
"futures-util",
"h2 0.3.24",
"h2",
"http 0.2.9",
"http-body 0.4.5",
"http-body",
"httparse",
"httpdate",
"itoa",
@@ -2523,26 +2481,6 @@ dependencies = [
"want",
]
[[package]]
name = "hyper"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "186548d73ac615b32a73aafe38fb4f56c0d340e110e5a200bcadbaf2e199263a"
dependencies = [
"bytes",
"futures-channel",
"futures-util",
"h2 0.4.2",
"http 1.0.0",
"http-body 1.0.0",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"smallvec",
"tokio",
]
[[package]]
name = "hyper-rustls"
version = "0.24.0"
@@ -2550,7 +2488,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0646026eb1b3eea4cd9ba47912ea5ce9cc07713d105b1a14698f4e6433d348b7"
dependencies = [
"http 0.2.9",
"hyper 0.14.26",
"hyper",
"log",
"rustls 0.21.9",
"rustls-native-certs",
@@ -2564,7 +2502,7 @@ version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1"
dependencies = [
"hyper 0.14.26",
"hyper",
"pin-project-lite",
"tokio",
"tokio-io-timeout",
@@ -2577,7 +2515,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905"
dependencies = [
"bytes",
"hyper 0.14.26",
"hyper",
"native-tls",
"tokio",
"tokio-native-tls",
@@ -2585,33 +2523,15 @@ dependencies = [
[[package]]
name = "hyper-tungstenite"
version = "0.13.0"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a343d17fe7885302ed7252767dc7bb83609a874b6ff581142241ec4b73957ad"
checksum = "7cc7dcb1ab67cd336f468a12491765672e61a3b6b148634dbfe2fe8acd3fe7d9"
dependencies = [
"http-body-util",
"hyper 1.2.0",
"hyper-util",
"hyper",
"pin-project-lite",
"tokio",
"tokio-tungstenite 0.21.0",
"tungstenite 0.21.0",
]
[[package]]
name = "hyper-util"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa"
dependencies = [
"bytes",
"futures-util",
"http 1.0.0",
"http-body 1.0.0",
"hyper 1.2.0",
"pin-project-lite",
"socket2 0.5.5",
"tokio",
"tokio-tungstenite",
"tungstenite",
]
[[package]]
@@ -2921,6 +2841,15 @@ version = "0.4.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f"
[[package]]
name = "lz4_flex"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ea9b256699eda7b0387ffbc776dd625e28bde3918446381781245b7a50349d8"
dependencies = [
"twox-hash",
]
[[package]]
name = "match_cfg"
version = "0.1.0"
@@ -3588,9 +3517,10 @@ dependencies = [
"hex-literal",
"humantime",
"humantime-serde",
"hyper 0.14.26",
"hyper",
"itertools",
"leaky-bucket",
"lz4_flex",
"md5",
"metrics",
"nix 0.27.1",
@@ -4258,13 +4188,9 @@ dependencies = [
"hex",
"hmac",
"hostname",
"http 1.0.0",
"http-body-util",
"humantime",
"hyper 0.14.26",
"hyper 1.2.0",
"hyper",
"hyper-tungstenite",
"hyper-util",
"ipnet",
"itertools",
"lasso",
@@ -4596,7 +4522,7 @@ dependencies = [
"futures-util",
"http-types",
"humantime",
"hyper 0.14.26",
"hyper",
"itertools",
"metrics",
"once_cell",
@@ -4626,10 +4552,10 @@ dependencies = [
"encoding_rs",
"futures-core",
"futures-util",
"h2 0.3.24",
"h2",
"http 0.2.9",
"http-body 0.4.5",
"hyper 0.14.26",
"http-body",
"hyper",
"hyper-rustls",
"hyper-tls",
"ipnet",
@@ -4687,7 +4613,7 @@ dependencies = [
"futures",
"getrandom 0.2.11",
"http 0.2.9",
"hyper 0.14.26",
"hyper",
"parking_lot 0.11.2",
"reqwest",
"reqwest-middleware",
@@ -4774,7 +4700,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "496c1d3718081c45ba9c31fbfc07417900aa96f4070ff90dc29961836b7a9945"
dependencies = [
"http 0.2.9",
"hyper 0.14.26",
"hyper",
"lazy_static",
"percent-encoding",
"regex",
@@ -5053,7 +4979,7 @@ dependencies = [
"git-version",
"hex",
"humantime",
"hyper 0.14.26",
"hyper",
"metrics",
"once_cell",
"parking_lot 0.12.1",
@@ -5528,9 +5454,9 @@ dependencies = [
[[package]]
name = "smallvec"
version = "1.13.1"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7"
checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9"
[[package]]
name = "smol_str"
@@ -5622,7 +5548,7 @@ dependencies = [
"futures-util",
"git-version",
"humantime",
"hyper 0.14.26",
"hyper",
"metrics",
"once_cell",
"parking_lot 0.12.1",
@@ -6106,19 +6032,7 @@ dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite 0.20.1",
]
[[package]]
name = "tokio-tungstenite"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite 0.21.0",
"tungstenite",
]
[[package]]
@@ -6185,10 +6099,10 @@ dependencies = [
"bytes",
"futures-core",
"futures-util",
"h2 0.3.24",
"h2",
"http 0.2.9",
"http-body 0.4.5",
"hyper 0.14.26",
"http-body",
"hyper",
"hyper-timeout",
"percent-encoding",
"pin-project",
@@ -6374,7 +6288,7 @@ dependencies = [
name = "tracing-utils"
version = "0.1.0"
dependencies = [
"hyper 0.14.26",
"hyper",
"opentelemetry",
"opentelemetry-otlp",
"opentelemetry-semantic-conventions",
@@ -6411,25 +6325,6 @@ dependencies = [
"utf-8",
]
[[package]]
name = "tungstenite"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
dependencies = [
"byteorder",
"bytes",
"data-encoding",
"http 1.0.0",
"httparse",
"log",
"rand 0.8.5",
"sha1",
"thiserror",
"url",
"utf-8",
]
[[package]]
name = "twox-hash"
version = "1.6.3"
@@ -6593,7 +6488,7 @@ dependencies = [
"heapless",
"hex",
"hex-literal",
"hyper 0.14.26",
"hyper",
"jsonwebtoken",
"leaky-bucket",
"metrics",
@@ -7118,7 +7013,7 @@ dependencies = [
"hashbrown 0.14.0",
"hex",
"hmac",
"hyper 0.14.26",
"hyper",
"indexmap 1.9.3",
"itertools",
"libc",
@@ -7155,6 +7050,7 @@ dependencies = [
"tower",
"tracing",
"tracing-core",
"tungstenite",
"url",
"uuid",
"zeroize",

View File

@@ -92,7 +92,7 @@ http-types = { version = "2", default-features = false }
humantime = "2.1"
humantime-serde = "1.1.1"
hyper = "0.14"
hyper-tungstenite = "0.13.0"
hyper-tungstenite = "0.11"
inotify = "0.10.2"
ipnet = "2.9.0"
itertools = "0.10"
@@ -100,6 +100,7 @@ jsonwebtoken = "9"
lasso = "0.7"
leaky-bucket = "1.0.1"
libc = "0.2"
lz4_flex = "0.11.1"
md5 = "0.7.0"
memoffset = "0.8"
native-tls = "0.2"

View File

@@ -23,12 +23,12 @@ endif
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S),Linux)
# Seccomp BPF is only available for Linux
PG_CONFIGURE_OPTS += --with-libseccomp
PG_CONFIGURE_OPTS += --with-lz4 --with-libseccomp
else ifeq ($(UNAME_S),Darwin)
# macOS with brew-installed openssl requires explicit paths
# It can be configured with OPENSSL_PREFIX variable
OPENSSL_PREFIX ?= $(shell brew --prefix openssl@3)
PG_CONFIGURE_OPTS += --with-includes=$(OPENSSL_PREFIX)/include --with-libraries=$(OPENSSL_PREFIX)/lib
PG_CONFIGURE_OPTS += --with-lz4 --with-includes=$(OPENSSL_PREFIX)/include --with-libraries=$(OPENSSL_PREFIX)/lib
PG_CONFIGURE_OPTS += PKG_CONFIG_PATH=$(shell brew --prefix icu4c)/lib/pkgconfig
# macOS already has bison and flex in the system, but they are old and result in postgres-v14 target failure
# brew formulae are keg-only and not symlinked into HOMEBREW_PREFIX, force their usage

View File

@@ -302,9 +302,9 @@ pub fn handle_roles(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
RoleAction::Create => {
// This branch only runs when roles are created through the console, so it is
// safe to add more permissions here. BYPASSRLS and REPLICATION are inherited
// from neon_superuser.
// from neon_superuser. (NOTE: REPLICATION has been removed from here for now).
let mut query: String = format!(
"CREATE ROLE {} INHERIT CREATEROLE CREATEDB BYPASSRLS REPLICATION IN ROLE neon_superuser",
"CREATE ROLE {} INHERIT CREATEROLE CREATEDB BYPASSRLS IN ROLE neon_superuser",
name.pg_quote()
);
info!("running role create query: '{}'", &query);
@@ -805,6 +805,18 @@ $$;"#,
"",
"",
// Add new migrations below.
r#"
DO $$
DECLARE
role_name TEXT;
BEGIN
FOR role_name IN SELECT rolname FROM pg_roles WHERE rolreplication IS TRUE
LOOP
RAISE NOTICE 'EXECUTING ALTER ROLE % NOREPLICATION', quote_ident(role_name);
EXECUTE 'ALTER ROLE ' || quote_ident(role_name) || ' NOREPLICATION';
END LOOP;
END
$$;"#,
];
let mut query = "CREATE SCHEMA IF NOT EXISTS neon_migration";

View File

@@ -29,7 +29,6 @@ pub mod launch_timestamp;
mod wrappers;
pub use wrappers::{CountedReader, CountedWriter};
mod hll;
pub mod metric_vec_duration;
pub use hll::{HyperLogLog, HyperLogLogVec};
#[cfg(target_os = "linux")]
pub mod more_process_metrics;

View File

@@ -1,23 +0,0 @@
//! Helpers for observing duration on `HistogramVec` / `CounterVec` / `GaugeVec` / `MetricVec<T>`.
use std::{future::Future, time::Instant};
pub trait DurationResultObserver {
fn observe_result<T, E>(&self, res: &Result<T, E>, duration: std::time::Duration);
}
pub async fn observe_async_block_duration_by_result<
T,
E,
F: Future<Output = Result<T, E>>,
O: DurationResultObserver,
>(
observer: &O,
block: F,
) -> Result<T, E> {
let start = Instant::now();
let result = block.await;
let duration = start.elapsed();
observer.observe_result(&result, duration);
result
}

View File

@@ -757,6 +757,7 @@ pub enum PagestreamBeMessage {
Error(PagestreamErrorResponse),
DbSize(PagestreamDbSizeResponse),
GetSlruSegment(PagestreamGetSlruSegmentResponse),
GetCompressedPage(PagestreamGetPageResponse),
}
// Keep in sync with `pagestore_client.h`
@@ -996,6 +997,12 @@ impl PagestreamBeMessage {
bytes.put(&resp.page[..]);
}
Self::GetCompressedPage(resp) => {
bytes.put_u8(105); /* tag from pagestore_client.h */
bytes.put_u16(resp.page.len() as u16);
bytes.put(&resp.page[..]);
}
Self::Error(resp) => {
bytes.put_u8(Tag::Error as u8);
bytes.put(resp.message.as_bytes());
@@ -1078,6 +1085,7 @@ impl PagestreamBeMessage {
Self::Error(_) => "Error",
Self::DbSize(_) => "DbSize",
Self::GetSlruSegment(_) => "GetSlruSegment",
Self::GetCompressedPage(_) => "GetCompressedPage",
}
}
}

View File

@@ -144,6 +144,13 @@ pub fn bkpimage_is_compressed(bimg_info: u8, version: u32) -> anyhow::Result<boo
dispatch_pgversion!(version, Ok(pgv::bindings::bkpimg_is_compressed(bimg_info)))
}
pub fn bkpimage_is_compressed_lz4(bimg_info: u8, version: u32) -> anyhow::Result<bool> {
dispatch_pgversion!(
version,
Ok(pgv::bindings::bkpimg_is_compressed_lz4(bimg_info))
)
}
pub fn generate_wal_segment(
segno: u64,
system_id: u64,

View File

@@ -8,3 +8,7 @@ pub const SIZEOF_RELMAPFILE: usize = 512; /* sizeof(RelMapFile) in relmapper.c *
pub fn bkpimg_is_compressed(bimg_info: u8) -> bool {
(bimg_info & BKPIMAGE_IS_COMPRESSED) != 0
}
pub fn bkpimg_is_compressed_lz4(_bimg_info: u8) -> bool {
false
}

View File

@@ -16,3 +16,7 @@ pub fn bkpimg_is_compressed(bimg_info: u8) -> bool {
(bimg_info & ANY_COMPRESS_FLAG) != 0
}
pub fn bkpimg_is_compressed_lz4(bimg_info: u8) -> bool {
(bimg_info & BKPIMAGE_COMPRESS_LZ4) != 0
}

View File

@@ -16,3 +16,7 @@ pub fn bkpimg_is_compressed(bimg_info: u8) -> bool {
(bimg_info & ANY_COMPRESS_FLAG) != 0
}
pub fn bkpimg_is_compressed_lz4(bimg_info: u8) -> bool {
(bimg_info & BKPIMAGE_COMPRESS_LZ4) != 0
}

View File

@@ -17,6 +17,7 @@ use remote_storage::{
};
use test_context::test_context;
use test_context::AsyncTestContext;
use tokio::io::AsyncBufReadExt;
use tokio_util::sync::CancellationToken;
use tracing::info;
@@ -484,32 +485,33 @@ async fn download_is_cancelled(ctx: &mut MaybeEnabledStorage) {
))
.unwrap();
let len = upload_large_enough_file(&ctx.client, &path, &cancel).await;
let file_len = upload_large_enough_file(&ctx.client, &path, &cancel).await;
{
let mut stream = ctx
let stream = ctx
.client
.download(&path, &cancel)
.await
.expect("download succeeds")
.download_stream;
let first = stream
.next()
.await
.expect("should have the first blob")
.expect("should have succeeded");
let mut reader = std::pin::pin!(tokio_util::io::StreamReader::new(stream));
tracing::info!(len = first.len(), "downloaded first chunk");
let first = reader.fill_buf().await.expect("should have the first blob");
let len = first.len();
tracing::info!(len, "downloaded first chunk");
assert!(
first.len() < len,
first.len() < file_len,
"uploaded file is too small, we downloaded all on first chunk"
);
reader.consume(len);
cancel.cancel();
let next = stream.next().await.expect("stream should have more");
let next = reader.fill_buf().await;
let e = next.expect_err("expected an error, but got a chunk?");
@@ -520,6 +522,10 @@ async fn download_is_cancelled(ctx: &mut MaybeEnabledStorage) {
.is_some_and(|e| matches!(e, DownloadError::Cancelled)),
"{inner:?}"
);
let e = DownloadError::from(e);
assert!(matches!(e, DownloadError::Cancelled), "{e:?}");
}
let cancel = CancellationToken::new();

View File

@@ -37,6 +37,7 @@ humantime-serde.workspace = true
hyper.workspace = true
itertools.workspace = true
leaky-bucket.workspace = true
lz4_flex.workspace = true
md5.workspace = true
nix.workspace = true
# hack to get the number of worker threads tokio uses

View File

@@ -157,6 +157,7 @@ impl PagestreamClient {
PagestreamBeMessage::Exists(_)
| PagestreamBeMessage::Nblocks(_)
| PagestreamBeMessage::DbSize(_)
| PagestreamBeMessage::GetCompressedPage(_)
| PagestreamBeMessage::GetSlruSegment(_) => {
anyhow::bail!(
"unexpected be message kind in response to getpage request: {}",

View File

@@ -40,7 +40,14 @@ use tracing::info;
/// format, bump this!
/// Note that TimelineMetadata uses its own version number to track
/// backwards-compatible changes to the metadata format.
pub const STORAGE_FORMAT_VERSION: u16 = 3;
pub const STORAGE_FORMAT_VERSION: u16 = 4;
/// Minimal storage format version with compression support
pub const COMPRESSED_STORAGE_FORMAT_VERSION: u16 = 4;
/// Page image compression algorithm
pub const NO_COMPRESSION: u8 = 0;
pub const LZ4_COMPRESSION: u8 = 1; // must be distinct from NO_COMPRESSION
pub const DEFAULT_PG_VERSION: u32 = 15;

View File

@@ -1,5 +1,4 @@
use enum_map::EnumMap;
use metrics::metric_vec_duration::DurationResultObserver;
use metrics::{
register_counter_vec, register_gauge_vec, register_histogram, register_histogram_vec,
register_int_counter, register_int_counter_pair_vec, register_int_counter_vec,
@@ -1283,11 +1282,65 @@ pub(crate) static BASEBACKUP_QUERY_TIME: Lazy<BasebackupQueryTime> = Lazy::new(|
})
});
impl DurationResultObserver for BasebackupQueryTime {
fn observe_result<T, E>(&self, res: &Result<T, E>, duration: std::time::Duration) {
pub(crate) struct BasebackupQueryTimeOngoingRecording<'a, 'c> {
parent: &'a BasebackupQueryTime,
ctx: &'c RequestContext,
start: std::time::Instant,
}
impl BasebackupQueryTime {
pub(crate) fn start_recording<'c: 'a, 'a>(
&'a self,
ctx: &'c RequestContext,
) -> BasebackupQueryTimeOngoingRecording<'_, '_> {
let start = Instant::now();
match ctx.micros_spent_throttled.open() {
Ok(()) => (),
Err(error) => {
use utils::rate_limit::RateLimit;
static LOGGED: Lazy<Mutex<RateLimit>> =
Lazy::new(|| Mutex::new(RateLimit::new(Duration::from_secs(10))));
let mut rate_limit = LOGGED.lock().unwrap();
rate_limit.call(|| {
warn!(error, "error opening micros_spent_throttled; this message is logged at a global rate limit");
});
}
}
BasebackupQueryTimeOngoingRecording {
parent: self,
ctx,
start,
}
}
}
impl<'a, 'c> BasebackupQueryTimeOngoingRecording<'a, 'c> {
pub(crate) fn observe<T, E>(self, res: &Result<T, E>) {
let elapsed = self.start.elapsed();
let ex_throttled = self
.ctx
.micros_spent_throttled
.close_and_checked_sub_from(elapsed);
let ex_throttled = match ex_throttled {
Ok(ex_throttled) => ex_throttled,
Err(error) => {
use utils::rate_limit::RateLimit;
static LOGGED: Lazy<Mutex<RateLimit>> =
Lazy::new(|| Mutex::new(RateLimit::new(Duration::from_secs(10))));
let mut rate_limit = LOGGED.lock().unwrap();
rate_limit.call(|| {
warn!(error, "error deducting time spent throttled; this message is logged at a global rate limit");
});
elapsed
}
};
let label_value = if res.is_ok() { "ok" } else { "error" };
let metric = self.0.get_metric_with_label_values(&[label_value]).unwrap();
metric.observe(duration.as_secs_f64());
let metric = self
.parent
.0
.get_metric_with_label_values(&[label_value])
.unwrap();
metric.observe(ex_throttled.as_secs_f64());
}
}

View File

@@ -1155,9 +1155,18 @@ impl PageServerHandler {
.get_rel_page_at_lsn(req.rel, req.blkno, Version::Lsn(lsn), req.latest, ctx)
.await?;
Ok(PagestreamBeMessage::GetPage(PagestreamGetPageResponse {
page,
}))
let compressed = lz4_flex::block::compress(&page);
if compressed.len() < page.len() {
Ok(PagestreamBeMessage::GetCompressedPage(
PagestreamGetPageResponse {
page: Bytes::from(compressed),
},
))
} else {
Ok(PagestreamBeMessage::GetPage(PagestreamGetPageResponse {
page,
}))
}
}
#[instrument(skip_all, fields(shard_id))]
@@ -1199,7 +1208,7 @@ impl PageServerHandler {
prev_lsn: Option<Lsn>,
full_backup: bool,
gzip: bool,
ctx: RequestContext,
ctx: &RequestContext,
) -> Result<(), QueryError>
where
IO: AsyncRead + AsyncWrite + Send + Sync + Unpin,
@@ -1214,7 +1223,7 @@ impl PageServerHandler {
if let Some(lsn) = lsn {
// Backup was requested at a particular LSN. Wait for it to arrive.
info!("waiting for {}", lsn);
timeline.wait_lsn(lsn, &ctx).await?;
timeline.wait_lsn(lsn, ctx).await?;
timeline
.check_lsn_is_in_scope(lsn, &latest_gc_cutoff_lsn)
.context("invalid basebackup lsn")?;
@@ -1236,7 +1245,7 @@ impl PageServerHandler {
lsn,
prev_lsn,
full_backup,
&ctx,
ctx,
)
.await?;
} else {
@@ -1257,7 +1266,7 @@ impl PageServerHandler {
lsn,
prev_lsn,
full_backup,
&ctx,
ctx,
)
.await?;
// shutdown the encoder to ensure the gzip footer is written
@@ -1269,7 +1278,7 @@ impl PageServerHandler {
lsn,
prev_lsn,
full_backup,
&ctx,
ctx,
)
.await?;
}
@@ -1449,25 +1458,25 @@ where
false
};
::metrics::metric_vec_duration::observe_async_block_duration_by_result(
&*metrics::BASEBACKUP_QUERY_TIME,
async move {
self.handle_basebackup_request(
pgb,
tenant_id,
timeline_id,
lsn,
None,
false,
gzip,
ctx,
)
.await?;
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
Result::<(), QueryError>::Ok(())
},
)
.await?;
let metric_recording = metrics::BASEBACKUP_QUERY_TIME.start_recording(&ctx);
let res = async {
self.handle_basebackup_request(
pgb,
tenant_id,
timeline_id,
lsn,
None,
false,
gzip,
&ctx,
)
.await?;
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
Result::<(), QueryError>::Ok(())
}
.await;
metric_recording.observe(&res);
res?;
}
// return pair of prev_lsn and last_lsn
else if query_string.starts_with("get_last_record_rlsn ") {
@@ -1563,7 +1572,7 @@ where
prev_lsn,
true,
false,
ctx,
&ctx,
)
.await?;
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;

View File

@@ -16,6 +16,7 @@ use anyhow::{ensure, Context};
use bytes::{Buf, Bytes, BytesMut};
use enum_map::Enum;
use itertools::Itertools;
use lz4_flex;
use pageserver_api::key::{
dbdir_key_range, is_rel_block_key, is_slru_block_key, rel_block_to_key, rel_dir_to_key,
rel_key_range, rel_size_to_key, relmap_file_key, slru_block_to_key, slru_dir_to_key,
@@ -992,7 +993,15 @@ impl<'a> DatadirModification<'a> {
img: Bytes,
) -> anyhow::Result<()> {
anyhow::ensure!(rel.relnode != 0, RelationError::InvalidRelnode);
self.put(rel_block_to_key(rel, blknum), Value::Image(img));
let compressed = lz4_flex::block::compress(&img);
if compressed.len() < img.len() {
self.put(
rel_block_to_key(rel, blknum),
Value::CompressedImage(Bytes::from(compressed)),
);
} else {
self.put(rel_block_to_key(rel, blknum), Value::Image(img));
}
Ok(())
}
@@ -1597,6 +1606,10 @@ impl<'a> DatadirModification<'a> {
if let Some((_, value)) = values.last() {
return if let Value::Image(img) = value {
Ok(img.clone())
} else if let Value::CompressedImage(img) = value {
let decompressed = lz4_flex::block::decompress(&img, BLCKSZ as usize)
.map_err(|msg| PageReconstructError::Other(anyhow::anyhow!(msg)))?;
Ok(Bytes::from(decompressed))
} else {
// Currently, we never need to read back a WAL record that we
// inserted in the same "transaction". All the metadata updates

View File

@@ -13,6 +13,8 @@ pub use pageserver_api::key::{Key, KEY_SIZE};
pub enum Value {
/// An Image value contains a full copy of the value
Image(Bytes),
/// An compressed page image contains a full copy of the page
CompressedImage(Bytes),
/// A WalRecord value contains a WAL record that needs to be
/// replayed get the full value. Replaying the WAL record
/// might need a previous version of the value (if will_init()
@@ -22,12 +24,17 @@ pub enum Value {
impl Value {
pub fn is_image(&self) -> bool {
matches!(self, Value::Image(_))
match self {
Value::Image(_) => true,
Value::CompressedImage(_) => true,
Value::WalRecord(_) => false,
}
}
pub fn will_init(&self) -> bool {
match self {
Value::Image(_) => true,
Value::CompressedImage(_) => true,
Value::WalRecord(rec) => rec.will_init(),
}
}

View File

@@ -272,9 +272,6 @@ pub enum TaskKind {
// Task that uploads a file to remote storage
RemoteUploadTask,
// Task that downloads a file from remote storage
RemoteDownloadTask,
// task that handles the initial downloading of all tenants
InitialLoad,

View File

@@ -11,13 +11,16 @@
//! len < 128: 0XXXXXXX
//! len >= 128: 1XXXXXXX XXXXXXXX XXXXXXXX XXXXXXXX
//!
use bytes::{BufMut, BytesMut};
use bytes::{BufMut, Bytes, BytesMut};
use tokio_epoll_uring::{BoundedBuf, IoBuf, Slice};
use crate::context::RequestContext;
use crate::page_cache::PAGE_SZ;
use crate::tenant::block_io::BlockCursor;
use crate::virtual_file::VirtualFile;
use crate::{LZ4_COMPRESSION, NO_COMPRESSION};
use lz4_flex;
use postgres_ffi::BLCKSZ;
use std::cmp::min;
use std::io::{Error, ErrorKind};
@@ -32,6 +35,29 @@ impl<'a> BlockCursor<'a> {
self.read_blob_into_buf(offset, &mut buf, ctx).await?;
Ok(buf)
}
/// Read blob into the given buffer. Any previous contents in the buffer
/// are overwritten.
pub async fn read_compressed_blob(
&self,
offset: u64,
ctx: &RequestContext,
) -> Result<Vec<u8>, std::io::Error> {
let blknum = (offset / PAGE_SZ as u64) as u32;
let off = (offset % PAGE_SZ as u64) as usize;
let buf = self.read_blk(blknum, ctx).await?;
let compression_alg = buf[off];
let res = self.read_blob(offset + 1, ctx).await?;
if compression_alg == LZ4_COMPRESSION {
lz4_flex::block::decompress(&res, BLCKSZ as usize).map_err(|_| {
std::io::Error::new(std::io::ErrorKind::InvalidData, "decompress error")
})
} else {
assert_eq!(compression_alg, NO_COMPRESSION);
Ok(res)
}
}
/// Read blob into the given buffer. Any previous contents in the buffer
/// are overwritten.
pub async fn read_blob_into_buf(
@@ -211,6 +237,58 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
(src_buf, Ok(()))
}
pub async fn write_compressed_blob(&mut self, srcbuf: Bytes) -> Result<u64, Error> {
let offset = self.offset;
let len = srcbuf.len();
let mut io_buf = self.io_buf.take().expect("we always put it back below");
io_buf.clear();
let mut is_compressed = false;
if len < 128 {
// Short blob. Write a 1-byte length header
io_buf.put_u8(NO_COMPRESSION);
io_buf.put_u8(len as u8);
} else {
// Write a 4-byte length header
if len > 0x7fff_ffff {
return Err(Error::new(
ErrorKind::Other,
format!("blob too large ({} bytes)", len),
));
}
if len == BLCKSZ as usize {
let compressed = lz4_flex::block::compress(&srcbuf);
if compressed.len() < len {
io_buf.put_u8(LZ4_COMPRESSION);
let mut len_buf = (compressed.len() as u32).to_be_bytes();
len_buf[0] |= 0x80;
io_buf.extend_from_slice(&len_buf[..]);
io_buf.extend_from_slice(&compressed[..]);
is_compressed = true;
}
}
// Fall back to an uncompressed header when compression was not attempted
// or did not reduce the size.
if !is_compressed {
io_buf.put_u8(NO_COMPRESSION);
let mut len_buf = (len as u32).to_be_bytes();
len_buf[0] |= 0x80;
io_buf.extend_from_slice(&len_buf[..]);
}
}
let (io_buf, hdr_res) = self.write_all(io_buf).await;
match hdr_res {
Ok(_) => (),
Err(e) => return Err(e),
}
self.io_buf = Some(io_buf);
if is_compressed {
hdr_res.map(|_| offset)
} else {
let (_buf, res) = self.write_all(srcbuf).await;
res.map(|_| offset)
}
}
/// Write a blob of data. Returns the offset that it was written to,
/// which can be used to retrieve the data later.
pub async fn write_blob<B: BoundedBuf<Buf = Buf>, Buf: IoBuf + Send>(
@@ -227,7 +305,6 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
if len < 128 {
// Short blob. Write a 1-byte length header
io_buf.put_u8(len as u8);
self.write_all(io_buf).await
} else {
// Write a 4-byte length header
if len > 0x7fff_ffff {
@@ -242,8 +319,8 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
let mut len_buf = (len as u32).to_be_bytes();
len_buf[0] |= 0x80;
io_buf.extend_from_slice(&len_buf[..]);
self.write_all(io_buf).await
}
self.write_all(io_buf).await
}
.await;
self.io_buf = Some(io_buf);

View File

@@ -20,6 +20,7 @@ use pageserver_api::keyspace::{KeySpace, KeySpaceRandomAccum};
use pageserver_api::models::{
LayerAccessKind, LayerResidenceEvent, LayerResidenceEventReason, LayerResidenceStatus,
};
use postgres_ffi::BLCKSZ;
use std::cmp::{Ordering, Reverse};
use std::collections::hash_map::Entry;
use std::collections::{BinaryHeap, HashMap};
@@ -147,12 +148,13 @@ impl ValuesReconstructState {
lsn: Lsn,
value: Value,
) -> ValueReconstructSituation {
let state = self
let mut error: Option<PageReconstructError> = None;
let key_state = self
.keys
.entry(*key)
.or_insert(Ok(VectoredValueReconstructState::default()));
if let Ok(state) = state {
let situation = if let Ok(state) = key_state {
let key_done = match state.situation {
ValueReconstructSituation::Complete => unreachable!(),
ValueReconstructSituation::Continue => match value {
@@ -160,6 +162,21 @@ impl ValuesReconstructState {
state.img = Some((lsn, img));
true
}
Value::CompressedImage(img) => {
match lz4_flex::block::decompress(&img, BLCKSZ as usize) {
Ok(decompressed) => {
state.img = Some((lsn, Bytes::from(decompressed)));
true
}
Err(e) => {
error = Some(PageReconstructError::from(anyhow::anyhow!(
"Failed to decompress blobrom virtual file: {}",
e
)));
true
}
}
}
Value::WalRecord(rec) => {
let reached_cache =
state.get_cached_lsn().map(|clsn| clsn + 1) == Some(lsn);
@@ -178,7 +195,11 @@ impl ValuesReconstructState {
state.situation
} else {
ValueReconstructSituation::Complete
};
if let Some(err) = error {
*key_state = Err(err);
}
situation
}
/// Returns the Lsn at which this key is cached if one exists.

View File

@@ -44,12 +44,13 @@ use crate::virtual_file::{self, VirtualFile};
use crate::{walrecord, TEMP_FILE_SUFFIX};
use crate::{DELTA_FILE_MAGIC, STORAGE_FORMAT_VERSION};
use anyhow::{anyhow, bail, ensure, Context, Result};
use bytes::BytesMut;
use bytes::{Bytes, BytesMut};
use camino::{Utf8Path, Utf8PathBuf};
use futures::StreamExt;
use pageserver_api::keyspace::KeySpace;
use pageserver_api::models::LayerAccessKind;
use pageserver_api::shard::TenantShardId;
use postgres_ffi::BLCKSZ;
use rand::{distributions::Alphanumeric, Rng};
use serde::{Deserialize, Serialize};
use std::fs::File;
@@ -813,6 +814,12 @@ impl DeltaLayerInner {
need_image = false;
break;
}
Value::CompressedImage(img) => {
let decompressed = lz4_flex::block::decompress(&img, BLCKSZ as usize)?;
reconstruct_state.img = Some((entry_lsn, Bytes::from(decompressed)));
need_image = false;
break;
}
Value::WalRecord(rec) => {
let will_init = rec.will_init();
reconstruct_state.records.push((entry_lsn, rec));
@@ -1102,6 +1109,9 @@ impl DeltaLayerInner {
Value::Image(img) => {
format!(" img {} bytes", img.len())
}
Value::CompressedImage(img) => {
format!(" compressed img {} bytes", img.len())
}
Value::WalRecord(rec) => {
let wal_desc = walrecord::describe_wal_record(&rec)?;
format!(
@@ -1138,6 +1148,11 @@ impl DeltaLayerInner {
let checkpoint = CheckPoint::decode(&img)?;
println!(" CHECKPOINT: {:?}", checkpoint);
}
Value::CompressedImage(img) => {
let decompressed = lz4_flex::block::decompress(&img, BLCKSZ as usize)?;
let checkpoint = CheckPoint::decode(&decompressed)?;
println!(" CHECKPOINT: {:?}", checkpoint);
}
Value::WalRecord(_rec) => {
println!(" unexpected walrecord value for checkpoint key");
}

View File

@@ -39,7 +39,9 @@ use crate::tenant::vectored_blob_io::{
};
use crate::tenant::{PageReconstructError, Timeline};
use crate::virtual_file::{self, VirtualFile};
use crate::{IMAGE_FILE_MAGIC, STORAGE_FORMAT_VERSION, TEMP_FILE_SUFFIX};
use crate::{
COMPRESSED_STORAGE_FORMAT_VERSION, IMAGE_FILE_MAGIC, STORAGE_FORMAT_VERSION, TEMP_FILE_SUFFIX,
};
use anyhow::{anyhow, bail, ensure, Context, Result};
use bytes::{Bytes, BytesMut};
use camino::{Utf8Path, Utf8PathBuf};
@@ -153,6 +155,7 @@ pub struct ImageLayerInner {
// values copied from summary
index_start_blk: u32,
index_root_blk: u32,
format_version: u16,
lsn: Lsn,
@@ -167,6 +170,7 @@ impl std::fmt::Debug for ImageLayerInner {
f.debug_struct("ImageLayerInner")
.field("index_start_blk", &self.index_start_blk)
.field("index_root_blk", &self.index_root_blk)
.field("format_version", &self.format_version)
.finish()
}
}
@@ -408,6 +412,7 @@ impl ImageLayerInner {
Ok(Ok(ImageLayerInner {
index_start_blk: actual_summary.index_start_blk,
index_root_blk: actual_summary.index_root_blk,
format_version: actual_summary.format_version,
lsn,
file,
file_id,
@@ -436,18 +441,20 @@ impl ImageLayerInner {
)
.await?
{
let blob = block_reader
.block_cursor()
.read_blob(
offset,
&RequestContextBuilder::extend(ctx)
.page_content_kind(PageContentKind::ImageLayerValue)
.build(),
)
.await
.with_context(|| format!("failed to read value from offset {}", offset))?;
let value = Bytes::from(blob);
let ctx = RequestContextBuilder::extend(ctx)
.page_content_kind(PageContentKind::ImageLayerValue)
.build();
let blob = (if self.format_version >= COMPRESSED_STORAGE_FORMAT_VERSION {
block_reader
.block_cursor()
.read_compressed_blob(offset, &ctx)
.await
} else {
block_reader.block_cursor().read_blob(offset, &ctx).await
})
.with_context(|| format!("failed to read value from offset {}", offset))?;
let value = Bytes::from(blob);
reconstruct_state.img = Some((self.lsn, value));
Ok(ValueReconstructResult::Complete)
} else {
@@ -658,10 +665,7 @@ impl ImageLayerWriterInner {
///
async fn put_image(&mut self, key: Key, img: Bytes) -> anyhow::Result<()> {
ensure!(self.key_range.contains(&key));
let (_img, res) = self.blob_writer.write_blob(img).await;
// TODO: re-use the buffer for `img` further upstack
let off = res?;
let off = self.blob_writer.write_compressed_blob(img).await?;
let mut keybuf: [u8; KEY_SIZE] = [0u8; KEY_SIZE];
key.write_to_byte_slice(&mut keybuf);
self.tree.append(&keybuf, off)?;

View File

@@ -14,9 +14,12 @@ use crate::tenant::timeline::GetVectoredError;
use crate::tenant::{PageReconstructError, Timeline};
use crate::walrecord;
use anyhow::{anyhow, ensure, Result};
use bytes::Bytes;
use lz4_flex;
use pageserver_api::keyspace::KeySpace;
use pageserver_api::models::InMemoryLayerInfo;
use pageserver_api::shard::TenantShardId;
use postgres_ffi::BLCKSZ;
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::sync::{Arc, OnceLock};
use tracing::*;
@@ -133,6 +136,9 @@ impl InMemoryLayer {
Ok(Value::Image(img)) => {
write!(&mut desc, " img {} bytes", img.len())?;
}
Ok(Value::CompressedImage(img)) => {
write!(&mut desc, " compressed img {} bytes", img.len())?;
}
Ok(Value::WalRecord(rec)) => {
let wal_desc = walrecord::describe_wal_record(&rec).unwrap();
write!(
@@ -184,6 +190,11 @@ impl InMemoryLayer {
reconstruct_state.img = Some((*entry_lsn, img));
return Ok(ValueReconstructResult::Complete);
}
Value::CompressedImage(img) => {
let decompressed = lz4_flex::block::decompress(&img, BLCKSZ as usize)?;
reconstruct_state.img = Some((*entry_lsn, Bytes::from(decompressed)));
return Ok(ValueReconstructResult::Complete);
}
Value::WalRecord(rec) => {
let will_init = rec.will_init();
reconstruct_state.records.push((*entry_lsn, rec));

View File

@@ -880,23 +880,18 @@ impl LayerInner {
) -> Result<heavier_once_cell::InitPermit, DownloadError> {
debug_assert_current_span_has_tenant_and_timeline_id();
let task_name = format!("download layer {}", self);
let (tx, rx) = tokio::sync::oneshot::channel();
// this is sadly needed because of task_mgr::shutdown_tasks, otherwise we cannot
// block tenant::mgr::remove_tenant_from_memory.
let this: Arc<Self> = self.clone();
crate::task_mgr::spawn(
&tokio::runtime::Handle::current(),
crate::task_mgr::TaskKind::RemoteDownloadTask,
Some(self.desc.tenant_shard_id),
Some(self.desc.timeline_id),
&task_name,
false,
async move {
let guard = timeline
.gate
.enter()
.map_err(|_| DownloadError::DownloadCancelled)?;
tokio::task::spawn(async move {
let _guard = guard;
let client = timeline
.remote_client
@@ -906,7 +901,7 @@ impl LayerInner {
let result = client.download_layer_file(
&this.desc.filename(),
&this.metadata(),
&crate::task_mgr::shutdown_token()
&timeline.cancel
)
.await;
@@ -929,7 +924,6 @@ impl LayerInner {
tokio::select! {
_ = tokio::time::sleep(backoff) => {},
_ = crate::task_mgr::shutdown_token().cancelled_owned() => {},
_ = timeline.cancel.cancelled() => {},
};
@@ -959,11 +953,10 @@ impl LayerInner {
}
}
}
Ok(())
}
.in_current_span(),
);
match rx.await {
Ok((Ok(()), permit)) => {
if let Some(reason) = self

View File

@@ -471,8 +471,9 @@ impl WalIngest {
&& decoded.xl_rmid == pg_constants::RM_XLOG_ID
&& (decoded.xl_info == pg_constants::XLOG_FPI
|| decoded.xl_info == pg_constants::XLOG_FPI_FOR_HINT)
// compression of WAL is not yet supported: fall back to storing the original WAL record
&& !postgres_ffi::bkpimage_is_compressed(blk.bimg_info, modification.tline.pg_version)?
// only lz4 compression of WAL is supported for now; for other compression algorithms, fall back to storing the original WAL record
&& (!postgres_ffi::bkpimage_is_compressed(blk.bimg_info, modification.tline.pg_version)? ||
postgres_ffi::bkpimage_is_compressed_lz4(blk.bimg_info, modification.tline.pg_version)?)
// do not materialize null pages because they will most likely soon be replaced with real data
&& blk.bimg_len != 0
{
@@ -480,7 +481,21 @@ impl WalIngest {
let img_len = blk.bimg_len as usize;
let img_offs = blk.bimg_offset as usize;
let mut image = BytesMut::with_capacity(BLCKSZ as usize);
image.extend_from_slice(&decoded.record[img_offs..img_offs + img_len]);
if postgres_ffi::bkpimage_is_compressed_lz4(
blk.bimg_info,
modification.tline.pg_version,
)? {
let decompressed_img_len = (BLCKSZ - blk.hole_length) as usize;
let decompressed = lz4_flex::block::decompress(
&decoded.record[img_offs..img_offs + img_len],
decompressed_img_len,
)
.map_err(|msg| PageReconstructError::Other(anyhow::anyhow!(msg)))?;
assert_eq!(decompressed.len(), decompressed_img_len);
image.extend_from_slice(&decompressed);
} else {
image.extend_from_slice(&decoded.record[img_offs..img_offs + img_len]);
}
if blk.hole_length != 0 {
let tail = image.split_off(blk.hole_offset as usize);

View File

@@ -18,10 +18,10 @@ OBJS = \
PG_CPPFLAGS = -I$(libpq_srcdir)
SHLIB_LINK_INTERNAL = $(libpq)
SHLIB_LINK = -lcurl
SHLIB_LINK = -lcurl -llz4
EXTENSION = neon
DATA = neon--1.0.sql neon--1.0--1.1.sql neon--1.1--1.2.sql neon--1.2--1.3.sql
DATA = neon--1.0.sql neon--1.0--1.1.sql neon--1.1--1.2.sql neon--1.2--1.3.sql neon--1.3--1.2.sql neon--1.2--1.1.sql neon--1.1--1.0.sql
PGFILEDESC = "neon - cloud storage for PostgreSQL"
EXTRA_CLEAN = \

View File

@@ -0,0 +1,6 @@
-- the order of operations is important here
-- because the view depends on the function
DROP VIEW IF EXISTS neon_lfc_stats CASCADE;
DROP FUNCTION IF EXISTS neon_get_lfc_stats CASCADE;

View File

@@ -0,0 +1 @@
DROP VIEW IF EXISTS NEON_STAT_FILE_CACHE CASCADE;

View File

@@ -0,0 +1 @@
DROP FUNCTION IF EXISTS approximate_working_set_size(bool) CASCADE;

View File

@@ -44,6 +44,7 @@ typedef enum
T_NeonErrorResponse,
T_NeonDbSizeResponse,
T_NeonGetSlruSegmentResponse,
T_NeonGetCompressedPageResponse
} NeonMessageTag;
/* base struct for c-style inheritance */
@@ -144,6 +145,15 @@ typedef struct
#define PS_GETPAGERESPONSE_SIZE (MAXALIGN(offsetof(NeonGetPageResponse, page) + BLCKSZ))
typedef struct
{
NeonMessageTag tag;
uint16 compressed_size;
char page[FLEXIBLE_ARRAY_MEMBER];
} NeonGetCompressedPageResponse;
#define PS_GETCOMPRESSEDPAGERESPONSE_SIZE(compressed_size) (MAXALIGN(offsetof(NeonGetCompressedPageResponse, page) + compressed_size))
typedef struct
{
NeonMessageTag tag;

View File

@@ -45,6 +45,10 @@
*/
#include "postgres.h"
#ifdef USE_LZ4
#include <lz4.h>
#endif
#include "access/xact.h"
#include "access/xlog.h"
#include "access/xlogdefs.h"
@@ -1059,6 +1063,7 @@ nm_pack_request(NeonRequest *msg)
case T_NeonExistsResponse:
case T_NeonNblocksResponse:
case T_NeonGetPageResponse:
case T_NeonGetCompressedPageResponse:
case T_NeonErrorResponse:
case T_NeonDbSizeResponse:
case T_NeonGetSlruSegmentResponse:
@@ -1114,6 +1119,21 @@ nm_unpack_response(StringInfo s)
Assert(msg_resp->tag == T_NeonGetPageResponse);
resp = (NeonResponse *) msg_resp;
break;
}
case T_NeonGetCompressedPageResponse:
{
NeonGetCompressedPageResponse *msg_resp;
uint16 compressed_size = pq_getmsgint(s, 2);
msg_resp = palloc0(PS_GETCOMPRESSEDPAGERESPONSE_SIZE(compressed_size));
msg_resp->tag = tag;
msg_resp->compressed_size = compressed_size;
memcpy(msg_resp->page, pq_getmsgbytes(s, compressed_size), compressed_size);
pq_getmsgend(s);
Assert(msg_resp->tag == T_NeonGetCompressedPageResponse);
resp = (NeonResponse *) msg_resp;
break;
}
@@ -1287,6 +1307,14 @@ nm_to_string(NeonMessage *msg)
appendStringInfoChar(&s, '}');
break;
}
case T_NeonGetCompressedPageResponse:
{
NeonGetCompressedPageResponse *msg_resp = (NeonGetCompressedPageResponse *) msg;
appendStringInfoString(&s, "{\"type\": \"NeonGetCompressedPageResponse\"");
appendStringInfo(&s, ", \"compressed_page_size\": \"%d\"", msg_resp->compressed_size);
appendStringInfoChar(&s, '}');
break;
}
case T_NeonErrorResponse:
{
NeonErrorResponse *msg_resp = (NeonErrorResponse *) msg;
@@ -2205,6 +2233,29 @@ neon_read_at_lsn(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
lfc_write(rinfo, forkNum, blkno, buffer);
break;
case T_NeonGetCompressedPageResponse:
{
#ifndef USE_LZ4
ereport(ERROR,
(errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("compression method lz4 not supported"),
errdetail("This functionality requires the server to be built with lz4 support."),
errhint("You need to rebuild PostgreSQL using %s.", "--with-lz4")));
#else
NeonGetCompressedPageResponse* cp = (NeonGetCompressedPageResponse *) resp;
int rc = LZ4_decompress_safe(cp->page,
buffer,
cp->compressed_size,
BLCKSZ);
if (rc != BLCKSZ) {
ereport(ERROR,
(errcode(ERRCODE_DATA_CORRUPTED),
errmsg_internal("compressed lz4 data is corrupt")));
}
lfc_write(rinfo, forkNum, blkno, buffer);
#endif
break;
}
case T_NeonErrorResponse:
ereport(ERROR,
(errcode(ERRCODE_IO_ERROR),

View File

@@ -30,10 +30,6 @@ hostname.workspace = true
humantime.workspace = true
hyper-tungstenite.workspace = true
hyper.workspace = true
hyper1 = { package = "hyper", version = "1.2", features = ["server", "http1", "http2"] }
hyper-util = { version = "0.1", features = ["tokio"] }
http1 = { package = "http", version = "1" }
http-body-util = { version = "0.1" }
ipnet.workspace = true
itertools.workspace = true
lasso = { workspace = true, features = ["multi-threaded"] }

View File

@@ -175,7 +175,7 @@ async fn task_main(
.context("failed to set socket option")?;
info!(%peer_addr, "serving");
let ctx = RequestMonitoring::new(session_id, peer_addr, "sni_router", "sni");
let ctx = RequestMonitoring::new(session_id, peer_addr.ip(), "sni_router", "sni");
handle_client(ctx, dest_suffix, tls_config, tls_server_end_point, socket).await
}
.unwrap_or_else(|e| {

View File

@@ -3,7 +3,7 @@
use chrono::Utc;
use once_cell::sync::OnceCell;
use smol_str::SmolStr;
use std::net::{IpAddr, SocketAddr};
use std::net::IpAddr;
use tokio::sync::mpsc;
use tracing::{field::display, info_span, Span};
use uuid::Uuid;
@@ -62,7 +62,7 @@ pub enum AuthMethod {
impl RequestMonitoring {
pub fn new(
session_id: Uuid,
peer_addr: SocketAddr,
peer_addr: IpAddr,
protocol: &'static str,
region: &'static str,
) -> Self {
@@ -75,7 +75,7 @@ impl RequestMonitoring {
);
Self {
peer_addr: peer_addr.ip(),
peer_addr,
session_id,
protocol,
first_packet: Utc::now(),
@@ -100,12 +100,7 @@ impl RequestMonitoring {
#[cfg(test)]
pub fn test() -> Self {
RequestMonitoring::new(
Uuid::now_v7(),
([127, 0, 0, 1], 5432).into(),
"test",
"test",
)
RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), "test", "test")
}
pub fn console_application_name(&self) -> String {

View File

@@ -5,13 +5,19 @@ use std::{
io,
net::SocketAddr,
pin::{pin, Pin},
sync::Mutex,
task::{ready, Context, Poll},
};
use bytes::{Buf, BytesMut};
use hyper::server::conn::AddrIncoming;
use hyper::server::accept::Accept;
use hyper::server::conn::{AddrIncoming, AddrStream};
use metrics::IntCounterPairGuard;
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
use uuid::Uuid;
use crate::{metrics::NUM_CLIENT_CONNECTION_GAUGE, serverless::tls_listener::AsyncAccept};
pub struct ProxyProtocolAccept {
pub incoming: AddrIncoming,
@@ -325,6 +331,87 @@ impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
}
}
impl AsyncAccept for ProxyProtocolAccept {
type Connection = WithConnectionGuard<WithClientIp<AddrStream>>;
type Error = io::Error;
fn poll_accept(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
tracing::info!(protocol = self.protocol, "accepted new TCP connection");
let Some(conn) = conn else {
return Poll::Ready(None);
};
Poll::Ready(Some(Ok(WithConnectionGuard {
inner: WithClientIp::new(conn),
connection_id: Uuid::new_v4(),
gauge: Mutex::new(Some(
NUM_CLIENT_CONNECTION_GAUGE
.with_label_values(&[self.protocol])
.guard(),
)),
})))
}
}
pin_project! {
pub struct WithConnectionGuard<T> {
#[pin]
pub inner: T,
pub connection_id: Uuid,
pub gauge: Mutex<Option<IntCounterPairGuard>>,
}
}
impl<T: AsyncWrite> AsyncWrite for WithConnectionGuard<T> {
#[inline]
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
self.project().inner.poll_write(cx, buf)
}
#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().inner.poll_flush(cx)
}
#[inline]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
self.project().inner.poll_shutdown(cx)
}
#[inline]
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
self.project().inner.poll_write_vectored(cx, bufs)
}
#[inline]
fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
}
impl<T: AsyncRead> AsyncRead for WithConnectionGuard<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.project().inner.poll_read(cx, buf)
}
}
#[cfg(test)]
mod tests {
use std::pin::pin;

View File

@@ -91,8 +91,9 @@ pub async fn task_main(
connections.spawn(async move {
let mut socket = WithClientIp::new(socket);
let peer_addr = match socket.wait_for_addr().await {
Ok(Some(addr)) => addr,
let mut peer_addr = peer_addr.ip();
match socket.wait_for_addr().await {
Ok(Some(addr)) => peer_addr = addr.ip(),
Err(e) => {
error!("per-client task finished with an error: {e:#}");
return;
@@ -101,8 +102,8 @@ pub async fn task_main(
error!("missing required client IP");
return;
}
Ok(None) => peer_addr
};
Ok(None) => {}
}
match socket.inner.set_nodelay(true) {
Ok(()) => {},

View File

@@ -4,45 +4,46 @@
mod backend;
mod conn_pool;
mod http_auto;
mod json;
mod sql_over_http;
pub mod tls_listener;
mod websocket;
use bytes::Bytes;
pub use conn_pool::GlobalConnPoolOptions;
use anyhow::Context;
use futures::future::{select, Either};
use http1::{Method, Response, StatusCode};
use http_body_util::Full;
use hyper1::body::Incoming;
use anyhow::bail;
use hyper::StatusCode;
use metrics::IntCounterPairGuard;
use rand::rngs::StdRng;
use rand::SeedableRng;
pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use serde::Serialize;
use tokio::time::timeout;
use tokio_util::task::TaskTracker;
use crate::context::RequestMonitoring;
use crate::metrics::{NUM_CLIENT_CONNECTION_GAUGE, TLS_HANDSHAKE_FAILURES};
use crate::protocol2::WithClientIp;
use crate::proxy::run_until_cancelled;
use crate::metrics::TLS_HANDSHAKE_FAILURES;
use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard};
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_auto::Rewind;
use crate::{cancellation::CancellationHandler, config::ProxyConfig};
use futures::StreamExt;
use hyper::{
server::{
accept,
conn::{AddrIncoming, AddrStream},
},
Body, Method, Request, Response,
};
use std::convert::Infallible;
use std::net::SocketAddr;
use std::pin::pin;
use std::sync::Arc;
use std::time::Duration;
use std::net::IpAddr;
use std::task::Poll;
use std::{future::ready, sync::Arc};
use tls_listener::TlsListener;
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn, Instrument};
use utils::http::error::ApiError;
use utils::http::{error::ApiError, json::json_response};
pub const SERVERLESS_DRIVER_SNI: &str = "api";
@@ -94,221 +95,134 @@ pub async fn task_main(
tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into();
let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
let _ = addr_incoming.set_nodelay(true);
let addr_incoming = ProxyProtocolAccept {
incoming: addr_incoming,
protocol: "http",
};
let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
ws_connections.close(); // allows `ws_connections.wait to complete`
let http_connections = tokio_util::task::task_tracker::TaskTracker::new();
http_connections.close();
let server = http_auto::Builder::new();
loop {
let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await else {
break;
};
let (conn, mut peer_addr) = res.context("could not accept TCP stream")?;
if let Err(e) = conn.set_nodelay(true) {
tracing::error!("could not set nodolay: {e}");
continue;
}
let cancellation_token = cancellation_token.child_token();
let tls = tls_acceptor.clone();
let backend = backend.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellation_handler = cancellation_handler.clone();
let server = server.clone();
http_connections.spawn(async move {
let _gauge = NUM_CLIENT_CONNECTION_GAUGE
.with_label_values(&["http"])
.guard();
// handle PROXY protocol
let mut conn = WithClientIp::new(conn);
let peer = match conn.wait_for_addr().await {
Ok(peer) => peer,
Err(e) => {
tracing::error!(
"failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}"
);
return;
}
};
if let Some(peer) = peer {
peer_addr = peer;
}
info!(%peer_addr, protocol = "http", "accepted new TCP connection");
let accept = tls.accept(conn);
let conn = match timeout(Duration::from_secs(10), accept).await {
Ok(Ok(conn)) => {
info!(%peer_addr, protocol = "http", "accepted new TLS connection");
conn
}
// The handshake failed, try getting another connection from the queue
Ok(Err(e)) => {
TLS_HANDSHAKE_FAILURES.inc();
warn!(%peer_addr, protocol = "http", "failed to accept TLS connection: {e:?}");
return;
}
// The handshake timed out, try getting another connection from the queue
Err(_) => {
TLS_HANDSHAKE_FAILURES.inc();
warn!(%peer_addr, protocol = "http", "failed to accept TLS connection: timeout");
return;
}
};
let (version, conn) = match conn.get_ref().1.alpn_protocol() {
Some(b"http/1.1") => (http_auto::Version::H1, Rewind::new(conn)),
Some(b"h2") => (http_auto::Version::H2, Rewind::new(conn)),
_ => {
tracing::debug!("HTTP: no ALPN negotiated");
let conn = timeout(Duration::from_secs(10), http_auto::read_version(conn)).await;
match conn {
Ok(Ok(v)) => v,
Ok(Err(e)) => {
tracing::warn!("HTTP connection error: {e}");
return;
},
Err(_) => {
tracing::warn!("HTTP connection error: timeout determining http version");
return;
}
}
}
};
let conn = server.serve_connection_with_upgrades(
conn,
version,
hyper1::service::service_fn(move |req: hyper1::Request<Incoming>| {
let backend = backend.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellation_handler = cancellation_handler.clone();
async move {
Ok::<_, Infallible>(
request_handler(
req,
config,
backend,
ws_connections,
cancellation_handler,
peer_addr,
endpoint_rate_limiter,
)
.await
.map_or_else(api_error_into_response, |r| r),
)
}
})
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
if let Err(err) = conn {
error!(
protocol = "http",
"failed to accept TLS connection: {err:?}"
);
TLS_HANDSHAKE_FAILURES.inc();
ready(false)
} else {
info!(protocol = "http", "accepted new TLS connection");
ready(true)
}
});
let make_svc = hyper::service::make_service_fn(
|stream: &tokio_rustls::server::TlsStream<
WithConnectionGuard<WithClientIp<AddrStream>>,
>| {
let (conn, _) = stream.get_ref();
let cancel = pin!(cancellation_token.cancelled());
let conn = pin!(conn);
let res = match select(cancel, conn).await {
Either::Left((_cancelled, mut conn)) => {
conn.as_mut().graceful_shutdown();
conn.await
}
Either::Right((res, _)) => res,
};
// this is jank. should disappear with hyper 1.0 migration.
let gauge = conn
.gauge
.lock()
.expect("lock should not be poisoned")
.take()
.expect("gauge should be set on connection start");
match res {
Ok(()) => {}
Err(e) => {
tracing::warn!("HTTP connection error {e}")
}
let client_addr = conn.inner.client_addr();
let remote_addr = conn.inner.inner.remote_addr();
let backend = backend.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellation_handler = cancellation_handler.clone();
async move {
let peer_addr = match client_addr {
Some(addr) => addr,
None if config.require_client_ip => bail!("missing required client ip"),
None => remote_addr,
};
Ok(MetricService::new(
hyper::service::service_fn(move |req: Request<Body>| {
let backend = backend.clone();
let ws_connections = ws_connections.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellation_handler = cancellation_handler.clone();
async move {
Ok::<_, Infallible>(
request_handler(
req,
config,
backend,
ws_connections,
cancellation_handler,
peer_addr.ip(),
endpoint_rate_limiter,
)
.await
.map_or_else(|e| e.into_response(), |r| r),
)
}
}),
gauge,
))
}
});
}
},
);
hyper::Server::builder(accept::from_stream(tls_listener))
.serve(make_svc)
.with_graceful_shutdown(cancellation_token.cancelled())
.await?;
// await HTTP and websocket connections
http_connections.wait().await;
ws_connections.wait().await;
Ok(())
}
fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
match this {
ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
format!("{err:#?}"), // use debug printing so that we give the cause
StatusCode::BAD_REQUEST,
),
ApiError::Forbidden(_) => {
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::FORBIDDEN)
}
ApiError::Unauthorized(_) => {
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::UNAUTHORIZED)
}
ApiError::NotFound(_) => {
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::NOT_FOUND)
}
ApiError::Conflict(_) => {
HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::CONFLICT)
}
ApiError::PreconditionFailed(_) => HttpErrorBody::response_from_msg_and_status(
this.to_string(),
StatusCode::PRECONDITION_FAILED,
),
ApiError::ShuttingDown => HttpErrorBody::response_from_msg_and_status(
"Shutting down".to_string(),
StatusCode::SERVICE_UNAVAILABLE,
),
ApiError::ResourceUnavailable(err) => HttpErrorBody::response_from_msg_and_status(
err.to_string(),
StatusCode::SERVICE_UNAVAILABLE,
),
ApiError::Timeout(err) => HttpErrorBody::response_from_msg_and_status(
err.to_string(),
StatusCode::REQUEST_TIMEOUT,
),
ApiError::InternalServerError(err) => HttpErrorBody::response_from_msg_and_status(
err.to_string(),
StatusCode::INTERNAL_SERVER_ERROR,
),
}
}
struct MetricService<S> {
inner: S,
_gauge: IntCounterPairGuard,
}
impl<S> MetricService<S> {
fn new(inner: S, _gauge: IntCounterPairGuard) -> MetricService<S> {
MetricService { inner, _gauge }
}
}
#[derive(Serialize)]
struct HttpErrorBody {
pub msg: String,
}
impl<S, ReqBody> hyper::service::Service<Request<ReqBody>> for MetricService<S>
where
S: hyper::service::Service<Request<ReqBody>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
impl HttpErrorBody {
pub fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Full<Bytes>> {
HttpErrorBody { msg }.to_response(status)
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
pub fn to_response(&self, status: StatusCode) -> Response<Full<Bytes>> {
Response::builder()
.status(status)
.header(http1::header::CONTENT_TYPE, "application/json")
// we do not have nested maps with non string keys so serialization shouldn't fail
.body(Full::new(Bytes::from(serde_json::to_string(self).unwrap())))
.unwrap()
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
self.inner.call(req)
}
}
#[allow(clippy::too_many_arguments)]
async fn request_handler(
mut request: hyper1::Request<Incoming>,
mut request: Request<Body>,
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
ws_connections: TaskTracker,
cancellation_handler: Arc<CancellationHandler>,
peer_addr: SocketAddr,
peer_addr: IpAddr,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> Result<Response<Full<Bytes>>, ApiError> {
) -> Result<Response<Body>, ApiError> {
let session_id = uuid::Uuid::new_v4();
let host = request
@@ -347,14 +261,14 @@ async fn request_handler(
// Return the response so the spawned future can continue.
Ok(response)
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
} else if request.uri().path() == "/sql" && request.method() == Method::POST {
let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
let span = ctx.span.clone();
sql_over_http::handle(config, ctx, request, backend)
.instrument(span)
.await
} else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
} else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
Response::builder()
.header("Allow", "OPTIONS, POST")
.header("Access-Control-Allow-Origin", "*")
@@ -364,24 +278,9 @@ async fn request_handler(
)
.header("Access-Control-Max-Age", "86400" /* 24 hours */)
.status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
.body(Full::new(Bytes::new()))
.body(Body::empty())
.map_err(|e| ApiError::InternalServerError(e.into()))
} else {
json_response(StatusCode::BAD_REQUEST, "query is not supported")
}
}
fn json_response<T: Serialize>(
status: StatusCode,
data: T,
) -> Result<Response<Full<Bytes>>, ApiError> {
let json = serde_json::to_string(&data)
.context("Failed to serialize JSON response")
.map_err(ApiError::InternalServerError)?;
let response = Response::builder()
.status(status)
.header(http1::header::CONTENT_TYPE, "application/json")
.body(Full::new(Bytes::from(json)))
.map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response)
}
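// A minimal illustrative sketch of how a handler might use `json_response`; the
// payload shape here is a made-up example, not something the proxy actually returns.
fn health_response_example() -> Result<Response<Full<Bytes>>, ApiError> {
    // Serializes the value to JSON and sets the content-type header for us.
    json_response(StatusCode::OK, serde_json::json!({ "status": "ok" }))
}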

View File

@@ -1,316 +0,0 @@
//! [`hyper-util`] offers an 'auto' connection to detect whether the connection should be HTTP1 or HTTP2.
//! There's a bug in this implementation where graceful shutdowns are not properly respected.
use futures::ready;
use hyper1::body::Body;
use hyper1::service::HttpService;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use std::future::Future;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{error::Error as StdError, io, marker::Unpin};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use ::http1::{Request, Response};
use bytes::Bytes;
use hyper1::{body::Incoming, service::Service};
use hyper1::server::conn::http1;
use hyper1::{rt::bounds::Http2ServerConnExec, server::conn::http2};
use pin_project_lite::pin_project;
type Error = Box<dyn std::error::Error + Send + Sync>;
type Result<T> = std::result::Result<T, Error>;
const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
/// Http1 or Http2 connection builder.
#[derive(Clone, Debug)]
pub struct Builder {
http1: http1::Builder,
http2: http2::Builder<TokioExecutor>,
}
impl Builder {
/// Create a new auto connection builder.
pub fn new() -> Self {
let mut builder = Self {
http1: http1::Builder::new(),
http2: http2::Builder::new(TokioExecutor::new()),
};
builder.http1.timer(TokioTimer::new());
builder.http2.timer(TokioTimer::new());
builder
}
/// Bind a connection together with a [`Service`], with the ability to
/// handle HTTP upgrades. This requires that the IO object implements
/// `Send`.
pub fn serve_connection_with_upgrades<I, S, B>(
&self,
io: Rewind<I>,
version: Version,
service: S,
) -> UpgradeableConnection<I, S>
where
S: Service<Request<Incoming>, Response = Response<B>>,
S::Future: 'static,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
B: Body + 'static,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
TokioExecutor: Http2ServerConnExec<S::Future, B>,
{
match version {
Version::H1 => {
let conn = self
.http1
.serve_connection(TokioIo::new(io), service)
.with_upgrades();
UpgradeableConnection {
state: UpgradeableConnState::H1 { conn },
}
}
Version::H2 => {
let conn = self.http2.serve_connection(TokioIo::new(io), service);
UpgradeableConnection {
state: UpgradeableConnState::H2 { conn },
}
}
}
}
}
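// A minimal illustrative sketch of driving the builder above once a connection's
// HTTP version is known; `io` is assumed to be an accepted (and, where applicable,
// TLS-wrapped) stream and `svc` any compatible service. The bounds simply mirror
// those of `serve_connection_with_upgrades`.
async fn serve_example<I, S, B>(builder: &Builder, io: I, version: Version, svc: S) -> Result<()>
where
    S: Service<Request<Incoming>, Response = Response<B>>,
    S::Future: 'static,
    S::Error: Into<Box<dyn StdError + Send + Sync>>,
    B: Body + 'static,
    B::Error: Into<Box<dyn StdError + Send + Sync>>,
    I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    TokioExecutor: Http2ServerConnExec<S::Future, B>,
{
    // Wrap the IO so pre-read bytes could be replayed, then drive the connection
    // future until the peer finishes or an error occurs.
    builder
        .serve_connection_with_upgrades(Rewind::new(io), version, svc)
        .await
}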
#[derive(Copy, Clone)]
pub(crate) enum Version {
H1,
H2,
}
pub(crate) fn read_version<I>(io: I) -> ReadVersion<I>
where
I: AsyncRead + Unpin,
{
ReadVersion {
io: Some(io),
buf: [0; 24],
filled: 0,
version: Version::H2,
_pin: PhantomPinned,
}
}
pin_project! {
pub(crate) struct ReadVersion<I> {
io: Option<I>,
buf: [u8; 24],
// the amount of `buf` that's been filled
filled: usize,
version: Version,
// Make this future `!Unpin` for compatibility with async trait methods.
#[pin]
_pin: PhantomPinned,
}
}
impl<I> Future for ReadVersion<I>
where
I: AsyncRead + Unpin,
{
type Output = io::Result<(Version, Rewind<I>)>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let mut buf = ReadBuf::new(&mut *this.buf);
buf.set_filled(*this.filled);
// We start as H2 and switch to H1 as soon as we don't have the preface.
while buf.filled().len() < H2_PREFACE.len() {
let len = buf.filled().len();
ready!(Pin::new(this.io.as_mut().unwrap()).poll_read(cx, &mut buf))?;
*this.filled = buf.filled().len();
// We start as H2 and switch to H1 when we don't get the preface.
if buf.filled().len() == len
|| buf.filled()[len..] != H2_PREFACE[len..buf.filled().len()]
{
*this.version = Version::H1;
break;
}
}
let io = this.io.take().unwrap();
let buf = buf.filled().to_vec();
Poll::Ready(Ok((
*this.version,
Rewind::new_buffered(io, Bytes::from(buf)),
)))
}
}
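// A small illustrative helper showing the same check `ReadVersion` performs
// incrementally: a connection counts as HTTP/2 only once the full `H2_PREFACE`
// has been seen, and any deviating byte means HTTP/1. Unlike the future above,
// this assumes the whole prefix is already buffered.
fn sniff_version_example(prefix: &[u8]) -> Version {
    let n = std::cmp::min(prefix.len(), H2_PREFACE.len());
    if prefix.len() >= H2_PREFACE.len() && prefix[..n] == H2_PREFACE[..n] {
        Version::H2
    } else {
        Version::H1
    }
}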
pin_project! {
/// Connection future.
pub struct UpgradeableConnection<I, S>
where
S: HttpService<Incoming>,
{
#[pin]
state: UpgradeableConnState<I, S>,
}
}
type Http1UpgradeableConnection<I, S> =
hyper1::server::conn::http1::UpgradeableConnection<TokioIo<Rewind<I>>, S>;
type Http2Connection<I, S> =
hyper1::server::conn::http2::Connection<TokioIo<Rewind<I>>, S, TokioExecutor>;
pin_project! {
#[project = UpgradeableConnStateProj]
enum UpgradeableConnState<I, S>
where
S: HttpService<Incoming>,
{
H1 {
#[pin]
conn: Http1UpgradeableConnection<I, S>,
},
H2 {
#[pin]
conn: Http2Connection<I, S>,
},
}
}
impl<I, S, B> UpgradeableConnection<I, S>
where
S: HttpService<Incoming, ResBody = B>,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
I: AsyncRead + AsyncWrite + Unpin,
B: Body + 'static,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
TokioExecutor: Http2ServerConnExec<S::Future, B>,
{
/// Start a graceful shutdown process for this connection.
///
/// This `UpgradeableConnection` should continue to be polled until shutdown can finish.
///
/// # Note
///
/// This should only be called while the `Connection` future is still pending. If
/// called after `UpgradeableConnection::poll` has resolved, this does nothing.
pub fn graceful_shutdown(self: Pin<&mut Self>) {
match self.project().state.project() {
UpgradeableConnStateProj::H1 { conn } => conn.graceful_shutdown(),
UpgradeableConnStateProj::H2 { conn } => conn.graceful_shutdown(),
}
}
}
impl<I, S, B> Future for UpgradeableConnection<I, S>
where
S: Service<Request<Incoming>, Response = Response<B>>,
S::Future: 'static,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
B: Body + 'static,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
TokioExecutor: Http2ServerConnExec<S::Future, B>,
{
type Output = Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut().project();
match this.state.as_mut().project() {
UpgradeableConnStateProj::H1 { conn } => conn.poll(cx).map_err(Into::into),
UpgradeableConnStateProj::H2 { conn } => conn.poll(cx).map_err(Into::into),
}
}
}
/// Combine a buffer with an IO, rewinding reads to use the buffer.
#[derive(Debug)]
pub(crate) struct Rewind<T> {
pre: Option<Bytes>,
inner: T,
}
impl<T> Rewind<T> {
pub(crate) fn new(io: T) -> Self {
Rewind {
pre: None,
inner: io,
}
}
pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self {
Rewind {
pre: Some(buf),
inner: io,
}
}
}
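// A minimal illustrative sketch of the rewind behaviour, assuming `tokio::io::AsyncReadExt`
// is available: bytes handed to `new_buffered` are yielded first, after which reads fall
// through to the inner stream.
async fn rewind_example() -> io::Result<()> {
    use tokio::io::AsyncReadExt;
    // A byte slice implements `AsyncRead`, so it can stand in for a real socket here.
    let inner: &[u8] = b" world";
    let mut rw = Rewind::new_buffered(inner, Bytes::from_static(b"hello"));
    let mut out = Vec::new();
    rw.read_to_end(&mut out).await?;
    assert_eq!(out, b"hello world".to_vec());
    Ok(())
}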
impl<T> AsyncRead for Rewind<T>
where
T: AsyncRead + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if let Some(prefix) = self.pre.take() {
// If there are no remaining bytes, let the bytes get dropped.
if !prefix.is_empty() {
let copy_len = std::cmp::min(prefix.len(), buf.remaining());
buf.put_slice(&prefix[..copy_len]);
// Put back whatever is left past the bytes we just copied
let rest = prefix.slice(copy_len..);
if !rest.is_empty() {
self.pre = Some(rest);
}
return Poll::Ready(Ok(()));
}
}
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}
impl<T> AsyncWrite for Rewind<T>
where
T: AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}
fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
}

View File

@@ -1,19 +1,14 @@
use std::sync::Arc;
use super::json_response;
use anyhow::bail;
use bytes::Bytes;
use futures::StreamExt;
use http_body_util::BodyExt;
use http_body_util::Full;
use hyper1::body::Body;
use hyper1::body::Incoming;
use hyper1::header;
use hyper1::http::HeaderName;
use hyper1::http::HeaderValue;
use hyper1::Response;
use hyper1::StatusCode;
use hyper1::{HeaderMap, Request};
use hyper::body::HttpBody;
use hyper::header;
use hyper::http::HeaderName;
use hyper::http::HeaderValue;
use hyper::Response;
use hyper::StatusCode;
use hyper::{Body, HeaderMap, Request};
use serde_json::json;
use serde_json::Value;
use tokio::try_join;
@@ -27,6 +22,7 @@ use tracing::error;
use tracing::info;
use url::Url;
use utils::http::error::ApiError;
use utils::http::json::json_response;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
@@ -195,9 +191,9 @@ fn get_conn_info(
pub async fn handle(
config: &'static ProxyConfig,
mut ctx: RequestMonitoring,
request: Request<Incoming>,
request: Request<Body>,
backend: Arc<PoolingBackend>,
) -> Result<Response<Full<Bytes>>, ApiError> {
) -> Result<Response<Body>, ApiError> {
let result = tokio::time::timeout(
config.http_config.request_timeout,
handle_inner(config, &mut ctx, request, backend),
@@ -304,18 +300,19 @@ pub async fn handle(
}
};
response
.headers_mut()
.insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
response.headers_mut().insert(
"Access-Control-Allow-Origin",
hyper::http::HeaderValue::from_static("*"),
);
Ok(response)
}
async fn handle_inner(
config: &'static ProxyConfig,
ctx: &mut RequestMonitoring,
request: Request<Incoming>,
request: Request<Body>,
backend: Arc<PoolingBackend>,
) -> anyhow::Result<Response<Full<Bytes>>> {
) -> anyhow::Result<Response<Body>> {
let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
.with_label_values(&[ctx.protocol])
.guard();
@@ -372,12 +369,9 @@ async fn handle_inner(
}
let fetch_and_process_request = async {
let body = request
.into_body()
.collect()
let body = hyper::body::to_bytes(request.into_body())
.await
.map_err(anyhow::Error::from)?
.to_bytes();
.map_err(anyhow::Error::from)?;
info!(length = body.len(), "request payload read");
let payload: Payload = serde_json::from_slice(&body)?;
Ok::<Payload, anyhow::Error>(payload) // Adjust error type accordingly
@@ -496,7 +490,7 @@ async fn handle_inner(
let body = serde_json::to_string(&result).expect("json serialization should not fail");
let len = body.len();
let response = response
.body(Full::new(Bytes::from(body)))
.body(Body::from(body))
// only fails if invalid status code or invalid header/values are given.
// these are not user configurable so it cannot fail dynamically
.expect("building response payload should not fail");

View File

@@ -0,0 +1,283 @@
use std::{
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use futures::{Future, Stream, StreamExt};
use pin_project_lite::pin_project;
use thiserror::Error;
use tokio::{
io::{AsyncRead, AsyncWrite},
task::JoinSet,
time::timeout,
};
/// Default timeout for the TLS handshake.
pub const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
/// Trait for TLS implementation.
///
/// Implementations are provided by the rustls and native-tls features.
pub trait AsyncTls<C: AsyncRead + AsyncWrite>: Clone {
/// The type of the TLS stream created from the underlying stream.
type Stream: Send + 'static;
/// Error type for completing the TLS handshake
type Error: std::error::Error + Send + 'static;
/// Type of the Future for the TLS stream that is accepted.
type AcceptFuture: Future<Output = Result<Self::Stream, Self::Error>> + Send + 'static;
/// Accept a TLS connection on an underlying stream
fn accept(&self, stream: C) -> Self::AcceptFuture;
}
/// Asynchronously accept connections.
pub trait AsyncAccept {
/// The type of the connection that is accepted.
type Connection: AsyncRead + AsyncWrite;
/// The type of error that may be returned.
type Error;
/// Poll to accept the next connection.
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>>;
/// Return a new `AsyncAccept` that stops accepting connections after
/// `ender` completes.
///
/// Useful for graceful shutdown.
///
/// See [examples/echo.rs](https://github.com/tmccombs/tls-listener/blob/main/examples/echo.rs)
/// for example of how to use.
fn until<F: Future>(self, ender: F) -> Until<Self, F>
where
Self: Sized,
{
Until {
acceptor: self,
ender,
}
}
}
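// A minimal illustrative impl of `AsyncAccept` for a plain TCP listener; the
// newtype name `TcpAccept` is a placeholder invented for this sketch.
struct TcpAccept(tokio::net::TcpListener);

impl AsyncAccept for TcpAccept {
    type Connection = tokio::net::TcpStream;
    type Error = std::io::Error;

    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
        // `TcpListener::poll_accept` yields `(stream, peer_addr)`; the address is
        // dropped here, and the stream of connections never ends on its own.
        match self.0.poll_accept(cx) {
            Poll::Ready(Ok((stream, _addr))) => Poll::Ready(Some(Ok(stream))),
            Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
            Poll::Pending => Poll::Pending,
        }
    }
}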
pin_project! {
///
/// Wraps a `Stream` of connections (such as a TCP listener) so that each connection is itself
/// encrypted using TLS.
///
/// It is similar to:
///
/// ```ignore
/// tcpListener.and_then(|s| tlsAcceptor.accept(s))
/// ```
///
/// except that it has the ability to accept multiple transport-level connections
/// simultaneously while the TLS handshake is pending for other connections.
///
/// By default, if a client fails the TLS handshake, that is treated as an error, and the
/// `TlsListener` will return an `Err`. If the `TlsListener` is passed directly to a hyper
/// [`Server`][1], then an invalid handshake can cause the server to stop accepting connections.
/// See [`http-stream.rs`][2] or [`http-low-level`][3] examples, for examples of how to avoid this.
///
/// Note that if the maximum number of pending connections is greater than 1, the resulting
/// [`T::Stream`][4] connections may come in a different order than the connections produced by the
/// underlying listener.
///
/// [1]: https://docs.rs/hyper/latest/hyper/server/struct.Server.html
/// [2]: https://github.com/tmccombs/tls-listener/blob/main/examples/http-stream.rs
/// [3]: https://github.com/tmccombs/tls-listener/blob/main/examples/http-low-level.rs
/// [4]: AsyncTls::Stream
///
#[allow(clippy::type_complexity)]
pub struct TlsListener<A: AsyncAccept, T: AsyncTls<A::Connection>> {
#[pin]
listener: A,
tls: T,
waiting: JoinSet<Result<Result<T::Stream, T::Error>, tokio::time::error::Elapsed>>,
timeout: Duration,
}
}
/// Builder for `TlsListener`.
#[derive(Clone)]
pub struct Builder<T> {
tls: T,
handshake_timeout: Duration,
}
/// Wraps errors from either the listener or the TLS Acceptor
#[derive(Debug, Error)]
pub enum Error<LE: std::error::Error, TE: std::error::Error> {
/// An error that arose from the listener ([AsyncAccept::Error])
#[error("{0}")]
ListenerError(#[source] LE),
/// An error that occurred during the TLS accept handshake
#[error("{0}")]
TlsAcceptError(#[source] TE),
}
impl<A: AsyncAccept, T> TlsListener<A, T>
where
T: AsyncTls<A::Connection>,
{
/// Create a `TlsListener` with default options.
pub fn new(tls: T, listener: A) -> Self {
builder(tls).listen(listener)
}
}
impl<A, T> TlsListener<A, T>
where
A: AsyncAccept,
A::Error: std::error::Error,
T: AsyncTls<A::Connection>,
{
/// Accept the next connection
///
/// This is essentially an alias to `self.next()` with a more domain-appropriate name.
pub async fn accept(&mut self) -> Option<<Self as Stream>::Item>
where
Self: Unpin,
{
self.next().await
}
/// Replaces the Tls Acceptor configuration, which will be used for new connections.
///
/// This can be used to change the certificate used at runtime.
pub fn replace_acceptor(&mut self, acceptor: T) {
self.tls = acceptor;
}
/// Replaces the Tls Acceptor configuration from a pinned reference to `Self`.
///
/// This is useful if your listener is `!Unpin`.
///
/// This can be used to change the certificate used at runtime.
pub fn replace_acceptor_pin(self: Pin<&mut Self>, acceptor: T) {
*self.project().tls = acceptor;
}
}
impl<A, T> Stream for TlsListener<A, T>
where
A: AsyncAccept,
A::Error: std::error::Error,
T: AsyncTls<A::Connection>,
{
type Item = Result<T::Stream, Error<A::Error, T::Error>>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();
loop {
match this.listener.as_mut().poll_accept(cx) {
Poll::Pending => break,
Poll::Ready(Some(Ok(conn))) => {
this.waiting
.spawn(timeout(*this.timeout, this.tls.accept(conn)));
}
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(Err(Error::ListenerError(e))));
}
Poll::Ready(None) => return Poll::Ready(None),
}
}
loop {
return match this.waiting.poll_join_next(cx) {
Poll::Ready(Some(Ok(Ok(conn)))) => {
Poll::Ready(Some(conn.map_err(Error::TlsAcceptError)))
}
// The handshake timed out, try getting another connection from the queue
Poll::Ready(Some(Ok(Err(_)))) => continue,
// The handshake panicked
Poll::Ready(Some(Err(e))) if e.is_panic() => {
std::panic::resume_unwind(e.into_panic())
}
// The handshake was externally aborted
Poll::Ready(Some(Err(_))) => unreachable!("handshake tasks are never aborted"),
_ => Poll::Pending,
};
}
}
}
impl<C: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncTls<C> for tokio_rustls::TlsAcceptor {
type Stream = tokio_rustls::server::TlsStream<C>;
type Error = std::io::Error;
type AcceptFuture = tokio_rustls::Accept<C>;
fn accept(&self, conn: C) -> Self::AcceptFuture {
tokio_rustls::TlsAcceptor::accept(self, conn)
}
}
impl<T> Builder<T> {
/// Set the timeout for handshakes.
///
/// If a handshake takes longer than `timeout`, it will be
/// aborted and the underlying connection will be dropped.
///
/// Defaults to `DEFAULT_HANDSHAKE_TIMEOUT`.
pub fn handshake_timeout(&mut self, timeout: Duration) -> &mut Self {
self.handshake_timeout = timeout;
self
}
/// Create a `TlsListener` from the builder
///
/// Actually build the `TlsListener`. The `listener` argument should be
/// an implementation of the `AsyncAccept` trait that accepts new connections
/// that the `TlsListener` will encrypt using TLS.
pub fn listen<A: AsyncAccept>(&self, listener: A) -> TlsListener<A, T>
where
T: AsyncTls<A::Connection>,
{
TlsListener {
listener,
tls: self.tls.clone(),
waiting: JoinSet::new(),
timeout: self.handshake_timeout,
}
}
}
/// Create a new Builder for a TlsListener
///
/// `tls` will be used to configure the TLS sessions.
pub fn builder<T>(tls: T) -> Builder<T> {
Builder {
tls,
handshake_timeout: DEFAULT_HANDSHAKE_TIMEOUT,
}
}
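// A minimal illustrative accept loop, assuming the `TcpAccept` wrapper sketched
// earlier and an existing `tokio_rustls::TlsAcceptor`; error handling is elided.
async fn accept_loop_example(tls: tokio_rustls::TlsAcceptor, tcp: tokio::net::TcpListener) {
    let mut listener = builder(tls)
        .handshake_timeout(Duration::from_secs(10))
        .listen(TcpAccept(tcp));
    while let Some(conn) = listener.accept().await {
        match conn {
            // Hand the established TLS stream to an HTTP server here.
            Ok(_tls_stream) => {}
            Err(e) => eprintln!("TLS accept failed: {e}"),
        }
    }
}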
pin_project! {
/// See [`AsyncAccept::until`]
pub struct Until<A, E> {
#[pin]
acceptor: A,
#[pin]
ender: E,
}
}
impl<A: AsyncAccept, E: Future> AsyncAccept for Until<A, E> {
type Connection = A::Connection;
type Error = A::Error;
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
let this = self.project();
match this.ender.poll(cx) {
Poll::Pending => this.acceptor.poll_accept(cx),
Poll::Ready(_) => Poll::Ready(None),
}
}
}
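// A minimal illustrative sketch of graceful shutdown via `until`, reusing the
// `TcpAccept` placeholder from the earlier example: once `shutdown` resolves,
// `poll_accept` returns `None` and the stream of incoming TLS connections ends.
fn shutdown_aware_listener_example<F: Future>(
    tls: tokio_rustls::TlsAcceptor,
    tcp: tokio::net::TcpListener,
    shutdown: F,
) -> TlsListener<Until<TcpAccept, F>, tokio_rustls::TlsAcceptor> {
    TlsListener::new(tls, TcpAccept(tcp).until(shutdown))
}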

View File

@@ -15,7 +15,7 @@ def test_migrations(neon_simple_env: NeonEnv):
endpoint.wait_for_migrations()
num_migrations = 8
num_migrations = 9
with endpoint.cursor() as cur:
cur.execute("SELECT id FROM neon_migration.migration_id")

View File

@@ -29,3 +29,34 @@ def test_neon_extension(neon_env_builder: NeonEnvBuilder):
log.info(res)
assert len(res) == 1
assert len(res[0]) == 5
# Verify that the neon extension can be upgraded/downgraded.
def test_neon_extension_compatibility(neon_env_builder: NeonEnvBuilder):
env = neon_env_builder.init_start()
env.neon_cli.create_branch("test_neon_extension_compatibility")
endpoint_main = env.endpoints.create("test_neon_extension_compatibility")
# don't skip pg_catalog updates - it runs CREATE EXTENSION neon
endpoint_main.respec(skip_pg_catalog_updates=False)
endpoint_main.start()
with closing(endpoint_main.connect()) as conn:
with conn.cursor() as cur:
all_versions = ["1.3", "1.2", "1.1", "1.0"]
current_version = "1.3"
for idx, begin_version in enumerate(all_versions):
for target_version in all_versions[idx + 1 :]:
if current_version != begin_version:
cur.execute(
f"ALTER EXTENSION neon UPDATE TO '{begin_version}'; -- {current_version}->{begin_version}"
)
current_version = begin_version
# downgrade
cur.execute(
f"ALTER EXTENSION neon UPDATE TO '{target_version}'; -- {begin_version}->{target_version}"
)
# upgrade
cur.execute(
f"ALTER EXTENSION neon UPDATE TO '{begin_version}'; -- {target_version}->{begin_version}"
)

View File

@@ -190,6 +190,8 @@ def test_delete_tenant_exercise_crash_safety_failpoints(
# So by ignoring these instead of waiting for empty upload queue
# we execute more distinct code paths.
'.*stopping left-over name="remote upload".*',
# an on-demand download is cancelled by shutdown
".*initial size calculation failed: downloading failed, possibly for shutdown",
]
)

View File

@@ -213,7 +213,9 @@ def test_delete_timeline_exercise_crash_safety_failpoints(
# This happens when timeline remains are cleaned up during loading
".*Timeline dir entry become invalid.*",
# In one of the branches we poll for tenant to become active. Polls can generate this log message:
f".*Tenant {env.initial_tenant} is not active*",
f".*Tenant {env.initial_tenant} is not active.*",
# an on-demand download is cancelled by shutdown
".*initial size calculation failed: downloading failed, possibly for shutdown",
]
)

View File

@@ -64,7 +64,7 @@ rustls = { version = "0.21", features = ["dangerous_configuration"] }
scopeguard = { version = "1" }
serde = { version = "1", features = ["alloc", "derive"] }
serde_json = { version = "1", features = ["raw_value"] }
smallvec = { version = "1", default-features = false, features = ["const_new", "write"] }
smallvec = { version = "1", default-features = false, features = ["write"] }
subtle = { version = "2" }
time = { version = "0.3", features = ["local-offset", "macros", "serde-well-known"] }
tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] }
@@ -76,6 +76,7 @@ tonic = { version = "0.9", features = ["tls-roots"] }
tower = { version = "0.4", default-features = false, features = ["balance", "buffer", "limit", "log", "timeout", "util"] }
tracing = { version = "0.1", features = ["log"] }
tracing-core = { version = "0.1" }
tungstenite = { version = "0.20" }
url = { version = "2", features = ["serde"] }
uuid = { version = "1", features = ["serde", "v4", "v7"] }
zeroize = { version = "1", features = ["derive"] }