Mirror of https://github.com/neondatabase/neon.git, synced 2026-02-11 06:30:37 +00:00

Compare commits: jcsp/ha-te... → proxy-fix-... (9 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 3c0eb1bf71 | |
| | ec7c878364 | |
| | 5d799f0a25 | |
| | d1bd8d377c | |
| | 71fda96c21 | |
| | 7afa5b3f35 | |
| | 2fc4e3df84 | |
| | d91ff747bb | |
| | 375dfd661c | |
Cargo.lock (generated): 208 lines changed
```diff
@@ -282,12 +282,10 @@ dependencies = [
  "control_plane",
  "diesel",
  "diesel_migrations",
- "fail",
- "futures",
  "git-version",
  "hex",
  "humantime",
- "hyper",
+ "hyper 0.14.26",
  "metrics",
  "once_cell",
  "pageserver_api",
@@ -333,7 +331,7 @@ dependencies = [
  "fastrand 2.0.0",
  "hex",
  "http 0.2.9",
- "hyper",
+ "hyper 0.14.26",
  "ring 0.17.6",
  "time",
  "tokio",
@@ -370,7 +368,7 @@ dependencies = [
  "bytes",
  "fastrand 2.0.0",
  "http 0.2.9",
- "http-body",
+ "http-body 0.4.5",
  "percent-encoding",
  "pin-project-lite",
  "tracing",
@@ -398,7 +396,7 @@ dependencies = [
  "aws-types",
  "bytes",
  "http 0.2.9",
- "http-body",
+ "http-body 0.4.5",
  "once_cell",
  "percent-encoding",
  "regex-lite",
@@ -549,7 +547,7 @@ dependencies = [
  "crc32fast",
  "hex",
  "http 0.2.9",
- "http-body",
+ "http-body 0.4.5",
  "md-5",
  "pin-project-lite",
  "sha1",
@@ -581,7 +579,7 @@ dependencies = [
  "bytes-utils",
  "futures-core",
  "http 0.2.9",
- "http-body",
+ "http-body 0.4.5",
  "once_cell",
  "percent-encoding",
  "pin-project-lite",
@@ -620,10 +618,10 @@ dependencies = [
  "aws-smithy-types",
  "bytes",
  "fastrand 2.0.0",
- "h2",
+ "h2 0.3.24",
  "http 0.2.9",
- "http-body",
- "hyper",
+ "http-body 0.4.5",
+ "hyper 0.14.26",
  "hyper-rustls",
  "once_cell",
  "pin-project-lite",
@@ -660,7 +658,7 @@ dependencies = [
  "bytes-utils",
  "futures-core",
  "http 0.2.9",
- "http-body",
+ "http-body 0.4.5",
  "itoa",
  "num-integer",
  "pin-project-lite",
@@ -709,8 +707,8 @@ dependencies = [
  "bytes",
  "futures-util",
  "http 0.2.9",
- "http-body",
- "hyper",
+ "http-body 0.4.5",
+ "hyper 0.14.26",
  "itoa",
  "matchit",
  "memchr",
@@ -725,7 +723,7 @@ dependencies = [
  "sha1",
  "sync_wrapper",
  "tokio",
- "tokio-tungstenite",
+ "tokio-tungstenite 0.20.0",
  "tower",
  "tower-layer",
  "tower-service",
@@ -741,7 +739,7 @@ dependencies = [
  "bytes",
  "futures-util",
  "http 0.2.9",
- "http-body",
+ "http-body 0.4.5",
  "mime",
  "rustversion",
  "tower-layer",
@@ -1230,7 +1228,7 @@ dependencies = [
  "compute_api",
  "flate2",
  "futures",
- "hyper",
+ "hyper 0.14.26",
  "nix 0.27.1",
  "notify",
  "num_cpus",
@@ -1346,7 +1344,7 @@ dependencies = [
  "futures",
  "git-version",
  "hex",
- "hyper",
+ "hyper 0.14.26",
  "nix 0.27.1",
  "once_cell",
  "pageserver_api",
@@ -2246,6 +2244,25 @@ dependencies = [
  "tracing",
 ]
 
+[[package]]
+name = "h2"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "31d030e59af851932b72ceebadf4a2b5986dba4c3b99dd2493f8273a0f151943"
+dependencies = [
+ "bytes",
+ "fnv",
+ "futures-core",
+ "futures-sink",
+ "futures-util",
+ "http 1.0.0",
+ "indexmap 2.0.1",
+ "slab",
+ "tokio",
+ "tokio-util",
+ "tracing",
+]
+
 [[package]]
 name = "half"
 version = "1.8.2"
@@ -2411,6 +2428,29 @@ dependencies = [
  "pin-project-lite",
 ]
 
+[[package]]
+name = "http-body"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643"
+dependencies = [
+ "bytes",
+ "http 1.0.0",
+]
+
+[[package]]
+name = "http-body-util"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840"
+dependencies = [
+ "bytes",
+ "futures-util",
+ "http 1.0.0",
+ "http-body 1.0.0",
+ "pin-project-lite",
+]
+
 [[package]]
 name = "http-types"
 version = "2.12.0"
@@ -2469,9 +2509,9 @@ dependencies = [
  "futures-channel",
  "futures-core",
  "futures-util",
- "h2",
+ "h2 0.3.24",
  "http 0.2.9",
- "http-body",
+ "http-body 0.4.5",
  "httparse",
  "httpdate",
  "itoa",
@@ -2483,6 +2523,26 @@ dependencies = [
  "want",
 ]
 
+[[package]]
+name = "hyper"
+version = "1.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "186548d73ac615b32a73aafe38fb4f56c0d340e110e5a200bcadbaf2e199263a"
+dependencies = [
+ "bytes",
+ "futures-channel",
+ "futures-util",
+ "h2 0.4.2",
+ "http 1.0.0",
+ "http-body 1.0.0",
+ "httparse",
+ "httpdate",
+ "itoa",
+ "pin-project-lite",
+ "smallvec",
+ "tokio",
+]
+
 [[package]]
 name = "hyper-rustls"
 version = "0.24.0"
@@ -2490,7 +2550,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0646026eb1b3eea4cd9ba47912ea5ce9cc07713d105b1a14698f4e6433d348b7"
 dependencies = [
  "http 0.2.9",
- "hyper",
+ "hyper 0.14.26",
  "log",
  "rustls 0.21.9",
  "rustls-native-certs",
@@ -2504,7 +2564,7 @@ version = "0.4.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1"
 dependencies = [
- "hyper",
+ "hyper 0.14.26",
  "pin-project-lite",
  "tokio",
  "tokio-io-timeout",
@@ -2517,7 +2577,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905"
 dependencies = [
  "bytes",
- "hyper",
+ "hyper 0.14.26",
  "native-tls",
  "tokio",
  "tokio-native-tls",
@@ -2525,15 +2585,33 @@ dependencies = [
 
 [[package]]
 name = "hyper-tungstenite"
-version = "0.11.1"
+version = "0.13.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7cc7dcb1ab67cd336f468a12491765672e61a3b6b148634dbfe2fe8acd3fe7d9"
+checksum = "7a343d17fe7885302ed7252767dc7bb83609a874b6ff581142241ec4b73957ad"
 dependencies = [
- "hyper",
+ "http-body-util",
+ "hyper 1.2.0",
+ "hyper-util",
  "pin-project-lite",
  "tokio",
- "tokio-tungstenite",
- "tungstenite",
+ "tokio-tungstenite 0.21.0",
+ "tungstenite 0.21.0",
 ]
 
+[[package]]
+name = "hyper-util"
+version = "0.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa"
+dependencies = [
+ "bytes",
+ "futures-util",
+ "http 1.0.0",
+ "http-body 1.0.0",
+ "hyper 1.2.0",
+ "pin-project-lite",
+ "socket2 0.5.5",
+ "tokio",
+]
+
 [[package]]
@@ -3510,7 +3588,7 @@ dependencies = [
  "hex-literal",
  "humantime",
  "humantime-serde",
- "hyper",
+ "hyper 0.14.26",
  "itertools",
  "leaky-bucket",
  "md5",
@@ -4180,9 +4258,13 @@ dependencies = [
  "hex",
  "hmac",
  "hostname",
+ "http 1.0.0",
+ "http-body-util",
  "humantime",
- "hyper",
+ "hyper 0.14.26",
+ "hyper 1.2.0",
  "hyper-tungstenite",
+ "hyper-util",
  "ipnet",
  "itertools",
  "lasso",
@@ -4514,7 +4596,7 @@ dependencies = [
  "futures-util",
  "http-types",
  "humantime",
- "hyper",
+ "hyper 0.14.26",
  "itertools",
  "metrics",
  "once_cell",
@@ -4544,10 +4626,10 @@ dependencies = [
  "encoding_rs",
  "futures-core",
  "futures-util",
- "h2",
+ "h2 0.3.24",
  "http 0.2.9",
- "http-body",
- "hyper",
+ "http-body 0.4.5",
+ "hyper 0.14.26",
  "hyper-rustls",
  "hyper-tls",
  "ipnet",
@@ -4605,7 +4687,7 @@ dependencies = [
  "futures",
  "getrandom 0.2.11",
  "http 0.2.9",
- "hyper",
+ "hyper 0.14.26",
  "parking_lot 0.11.2",
  "reqwest",
  "reqwest-middleware",
@@ -4692,7 +4774,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "496c1d3718081c45ba9c31fbfc07417900aa96f4070ff90dc29961836b7a9945"
 dependencies = [
  "http 0.2.9",
- "hyper",
+ "hyper 0.14.26",
  "lazy_static",
  "percent-encoding",
  "regex",
@@ -4971,7 +5053,7 @@ dependencies = [
  "git-version",
  "hex",
  "humantime",
- "hyper",
+ "hyper 0.14.26",
  "metrics",
  "once_cell",
  "parking_lot 0.12.1",
@@ -5446,9 +5528,9 @@ dependencies = [
 
 [[package]]
 name = "smallvec"
-version = "1.11.0"
+version = "1.13.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9"
+checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7"
 
 [[package]]
 name = "smol_str"
@@ -5540,7 +5622,7 @@ dependencies = [
  "futures-util",
  "git-version",
  "humantime",
- "hyper",
+ "hyper 0.14.26",
  "metrics",
  "once_cell",
  "parking_lot 0.12.1",
@@ -6024,7 +6106,19 @@ dependencies = [
  "futures-util",
  "log",
  "tokio",
- "tungstenite",
+ "tungstenite 0.20.1",
 ]
 
+[[package]]
+name = "tokio-tungstenite"
+version = "0.21.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
+dependencies = [
+ "futures-util",
+ "log",
+ "tokio",
+ "tungstenite 0.21.0",
+]
+
 [[package]]
@@ -6091,10 +6185,10 @@ dependencies = [
  "bytes",
  "futures-core",
  "futures-util",
- "h2",
+ "h2 0.3.24",
  "http 0.2.9",
- "http-body",
- "hyper",
+ "http-body 0.4.5",
+ "hyper 0.14.26",
  "hyper-timeout",
  "percent-encoding",
  "pin-project",
@@ -6280,7 +6374,7 @@ dependencies = [
 name = "tracing-utils"
 version = "0.1.0"
 dependencies = [
- "hyper",
+ "hyper 0.14.26",
  "opentelemetry",
  "opentelemetry-otlp",
  "opentelemetry-semantic-conventions",
@@ -6317,6 +6411,25 @@ dependencies = [
  "utf-8",
 ]
 
+[[package]]
+name = "tungstenite"
+version = "0.21.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
+dependencies = [
+ "byteorder",
+ "bytes",
+ "data-encoding",
+ "http 1.0.0",
+ "httparse",
+ "log",
+ "rand 0.8.5",
+ "sha1",
+ "thiserror",
+ "url",
+ "utf-8",
+]
+
 [[package]]
 name = "twox-hash"
 version = "1.6.3"
@@ -6480,7 +6593,7 @@ dependencies = [
  "heapless",
  "hex",
  "hex-literal",
- "hyper",
+ "hyper 0.14.26",
  "jsonwebtoken",
  "leaky-bucket",
  "metrics",
@@ -7005,7 +7118,7 @@ dependencies = [
  "hashbrown 0.14.0",
  "hex",
  "hmac",
- "hyper",
+ "hyper 0.14.26",
  "indexmap 1.9.3",
  "itertools",
  "libc",
@@ -7042,7 +7155,6 @@ dependencies = [
  "tower",
  "tracing",
  "tracing-core",
- "tungstenite",
  "url",
  "uuid",
  "zeroize",
```
```diff
@@ -92,7 +92,7 @@ http-types = { version = "2", default-features = false }
 humantime = "2.1"
 humantime-serde = "1.1.1"
 hyper = "0.14"
-hyper-tungstenite = "0.11"
+hyper-tungstenite = "0.13.0"
 inotify = "0.10.2"
 ipnet = "2.9.0"
 itertools = "0.10"
```

```diff
@@ -19,10 +19,8 @@ aws-config.workspace = true
 aws-sdk-secretsmanager.workspace = true
 camino.workspace = true
 clap.workspace = true
-fail.workspace = true
-futures.workspace = true
 git-version.workspace = true
 hex.workspace = true
 hyper.workspace = true
 humantime.workspace = true
 once_cell.workspace = true
```
```diff
@@ -1,4 +1,3 @@
-use std::sync::Arc;
 use std::{collections::HashMap, time::Duration};
 
 use control_plane::endpoint::{ComputeControlPlane, EndpointStatus};
@@ -24,13 +23,10 @@ struct ShardedComputeHookTenant {
     stripe_size: ShardStripeSize,
     shard_count: ShardCount,
     shards: Vec<(ShardNumber, NodeId)>,
-
-    // Async lock used for ensuring that remote compute hook calls are ordered identically to updates to this structure
-    lock: Arc<tokio::sync::Mutex<()>>,
 }
 
 enum ComputeHookTenant {
-    Unsharded((NodeId, Arc<tokio::sync::Mutex<()>>)),
+    Unsharded(NodeId),
     Sharded(ShardedComputeHookTenant),
 }
@@ -42,17 +38,9 @@ impl ComputeHookTenant {
                 shards: vec![(tenant_shard_id.shard_number, node_id)],
                 stripe_size,
                 shard_count: tenant_shard_id.shard_count,
-                lock: Arc::default(),
             })
         } else {
-            Self::Unsharded((node_id, Arc::default()))
+            Self::Unsharded(node_id)
         }
     }
-
-    fn get_lock(&self) -> &Arc<tokio::sync::Mutex<()>> {
-        match self {
-            Self::Unsharded((_node_id, lock)) => lock,
-            Self::Sharded(sharded_tenant) => &sharded_tenant.lock,
-        }
-    }
@@ -65,9 +53,7 @@ impl ComputeHookTenant {
         node_id: NodeId,
     ) {
         match self {
-            Self::Unsharded((existing_node_id, _lock))
-                if tenant_shard_id.shard_count.count() == 1 =>
-            {
+            Self::Unsharded(existing_node_id) if tenant_shard_id.shard_count.count() == 1 => {
                 *existing_node_id = node_id
             }
             Self::Sharded(sharded_tenant)
@@ -136,15 +122,9 @@ pub(crate) enum NotifyError {
 }
 
 impl ComputeHookTenant {
-    fn maybe_reconfigure(
-        &self,
-        tenant_id: TenantId,
-    ) -> Option<(
-        ComputeHookNotifyRequest,
-        impl std::future::Future<Output = tokio::sync::OwnedMutexGuard<()>>,
-    )> {
-        let request = match self {
-            Self::Unsharded((node_id, _lock)) => Some(ComputeHookNotifyRequest {
+    fn maybe_reconfigure(&self, tenant_id: TenantId) -> Option<ComputeHookNotifyRequest> {
+        match self {
+            Self::Unsharded(node_id) => Some(ComputeHookNotifyRequest {
                 tenant_id,
                 shards: vec![ComputeHookNotifyRequestShard {
                     shard_number: ShardNumber(0),
@@ -178,9 +158,7 @@ impl ComputeHookTenant {
                 );
                 None
             }
-        };
-
-        request.map(|r| (r, self.get_lock().clone().lock_owned()))
+        }
     }
 }
@@ -189,11 +167,8 @@ impl ComputeHookTenant {
 /// the compute connection string.
 pub(super) struct ComputeHook {
     config: Config,
-    state: std::sync::Mutex<HashMap<TenantId, ComputeHookTenant>>,
+    state: tokio::sync::Mutex<HashMap<TenantId, ComputeHookTenant>>,
     authorization_header: Option<String>,
-
-    // This lock is only used in testing enviroments, to serialize calls into neon_lock
-    neon_local_lock: tokio::sync::Mutex<()>,
 }
 
 impl ComputeHook {
@@ -207,7 +182,6 @@ impl ComputeHook {
             state: Default::default(),
             config,
             authorization_header,
-            neon_local_lock: Default::default(),
         }
     }
@@ -216,10 +190,6 @@ impl ComputeHook {
         &self,
         reconfigure_request: ComputeHookNotifyRequest,
     ) -> anyhow::Result<()> {
-        // neon_local updates are not safe to call concurrently, use a lock to serialize
-        // all calls to this function
-        let _locked = self.neon_local_lock.lock().await;
-
        let env = match LocalEnv::load_config() {
            Ok(e) => e,
            Err(e) => {
@@ -370,38 +340,30 @@ impl ComputeHook {
         stripe_size: ShardStripeSize,
         cancel: &CancellationToken,
     ) -> Result<(), NotifyError> {
-        let reconfigure_request = {
-            let mut locked = self.state.lock().unwrap();
-
-            use std::collections::hash_map::Entry;
-            let tenant = match locked.entry(tenant_shard_id.tenant_id) {
-                Entry::Vacant(e) => e.insert(ComputeHookTenant::new(
-                    tenant_shard_id,
-                    stripe_size,
-                    node_id,
-                )),
-                Entry::Occupied(e) => {
-                    let tenant = e.into_mut();
-                    tenant.update(tenant_shard_id, stripe_size, node_id);
-                    tenant
-                }
-            };
-
-            tenant.maybe_reconfigure(tenant_shard_id.tenant_id)
-        };
-        let Some((reconfigure_request, lock_fut)) = reconfigure_request else {
+        let mut locked = self.state.lock().await;
+
+        use std::collections::hash_map::Entry;
+        let tenant = match locked.entry(tenant_shard_id.tenant_id) {
+            Entry::Vacant(e) => e.insert(ComputeHookTenant::new(
+                tenant_shard_id,
+                stripe_size,
+                node_id,
+            )),
+            Entry::Occupied(e) => {
+                let tenant = e.into_mut();
+                tenant.update(tenant_shard_id, stripe_size, node_id);
+                tenant
+            }
+        };
+
+        let reconfigure_request = tenant.maybe_reconfigure(tenant_shard_id.tenant_id);
+        let Some(reconfigure_request) = reconfigure_request else {
             // The tenant doesn't yet have pageservers for all its shards: we won't notify anything
             // until it does.
             tracing::info!("Tenant isn't yet ready to emit a notification");
             return Ok(());
         };
 
-        // Finish acquiring the tenant's async lock: this future was created inside the self.state
-        // lock above, so we are guaranteed to get this lock in the same order as callers took
-        // that lock. This ordering is essential: the cloud control plane must end up with the
-        // same end state for the tenant that we see.
-        let _guard = lock_fut.await;
-
         if let Some(notify_url) = &self.config.compute_hook_url {
             self.do_notify(notify_url, reconfigure_request, cancel)
                 .await
@@ -443,7 +405,6 @@ pub(crate) mod tests {
             tenant_state
                 .maybe_reconfigure(tenant_id)
                 .unwrap()
-                .0
                 .shards
                 .len(),
             1
@@ -451,7 +412,6 @@ pub(crate) mod tests {
         assert!(tenant_state
             .maybe_reconfigure(tenant_id)
             .unwrap()
-            .0
            .stripe_size
            .is_none());
 
@@ -485,7 +445,6 @@ pub(crate) mod tests {
             tenant_state
                 .maybe_reconfigure(tenant_id)
                 .unwrap()
-                .0
                 .shards
                 .len(),
             2
@@ -494,7 +453,6 @@ pub(crate) mod tests {
             tenant_state
                 .maybe_reconfigure(tenant_id)
                 .unwrap()
-                .0
                 .stripe_size,
             Some(ShardStripeSize(32768))
         );
```
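The removed lock plumbing above implements a subtle ordering scheme: the owned-mutex future is created while the synchronous state lock is still held, and only awaited after that lock is released, so the remote notification for a tenant is taken in the same order as the state update that produced it. A minimal self-contained sketch of the pattern, with illustrative names that are not from the Neon codebase:

```rust
use std::sync::{Arc, Mutex};

// One async lock per entity, owned by the sync-locked state.
struct State {
    lock: Arc<tokio::sync::Mutex<()>>,
    value: u64,
}

async fn update_then_notify(state: &Mutex<State>) -> u64 {
    // 1. Mutate shared state and create the owned-lock future while the
    //    sync mutex is held. `lock_owned` takes the Arc by value.
    let (snapshot, lock_fut) = {
        let mut locked = state.lock().unwrap();
        locked.value += 1;
        (locked.value, locked.lock.clone().lock_owned())
    }; // sync mutex released here

    // 2. Await the per-entity lock outside the sync mutex, then perform the
    //    slow remote call while holding it.
    let _guard = lock_fut.await;
    snapshot // stand-in for sending the notification
}

#[tokio::main]
async fn main() {
    let state = Mutex::new(State {
        lock: Arc::default(),
        value: 0,
    });
    assert_eq!(update_then_notify(&state).await, 1);
}
```

The right-hand branch drops this machinery entirely by making the whole state map a `tokio::sync::Mutex`, trading lock granularity for simplicity.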
```diff
@@ -10,9 +10,7 @@ use pageserver_api::shard::TenantShardId;
 use pageserver_client::mgmt_api;
 use std::sync::Arc;
 use std::time::{Duration, Instant};
-use tokio_util::sync::CancellationToken;
 use utils::auth::{Scope, SwappableJwtAuth};
-use utils::failpoint_support::failpoints_handler;
 use utils::http::endpoint::{auth_middleware, check_permission_with, request_span};
 use utils::http::request::{must_get_query_param, parse_request_param};
 use utils::id::{TenantId, TimelineId};
@@ -440,24 +438,6 @@ async fn handle_tenants_dump(req: Request<Body>) -> Result<Response<Body>, ApiEr
     state.service.tenants_dump()
 }
 
-async fn handle_balance_all(
-    service: Arc<Service>,
-    req: Request<Body>,
-) -> Result<Response<Body>, ApiError> {
-    check_permissions(&req, Scope::Admin)?;
-    service.balance_all()?;
-    json_response(StatusCode::OK, ())
-}
-
-async fn handle_balance_attached(
-    service: Arc<Service>,
-    req: Request<Body>,
-) -> Result<Response<Body>, ApiError> {
-    check_permissions(&req, Scope::Admin)?;
-    service.balance_attached()?;
-    json_response(StatusCode::OK, ())
-}
-
 async fn handle_scheduler_dump(req: Request<Body>) -> Result<Response<Body>, ApiError> {
     check_permissions(&req, Scope::Admin)?;
@@ -574,9 +554,6 @@ pub fn make_router(
         .post("/debug/v1/consistency_check", |r| {
             request_span(r, handle_consistency_check)
         })
-        .put("/debug/v1/failpoints", |r| {
-            request_span(r, |r| failpoints_handler(r, CancellationToken::new()))
-        })
         .get("/control/v1/tenant/:tenant_id/locate", |r| {
             tenant_service_handler(r, handle_tenant_locate)
         })
@@ -595,12 +572,6 @@ pub fn make_router(
         .put("/control/v1/tenant/:tenant_id/shard_split", |r| {
             tenant_service_handler(r, handle_tenant_shard_split)
         })
-        .post("/control/v1/balance/all", |r| {
-            tenant_service_handler(r, handle_balance_all)
-        })
-        .post("/control/v1/balance/attached", |r| {
-            tenant_service_handler(r, handle_balance_attached)
-        })
         // Tenant operations
         // The ^/v1/ endpoints act as a "Virtual Pageserver", enabling shard-naive clients to call into
         // this service to manage tenants that actually consist of many tenant shards, as if they are a single entity.
```
```diff
@@ -1,54 +0,0 @@
-use std::{collections::HashMap, sync::Arc};
-
-/// A map of locks covering some arbitrary identifiers. Useful if you have a collection of objects but don't
-/// want to embed a lock in each one, or if your locking granularity is different to your object granularity.
-/// For example, used in the storage controller where the objects are tenant shards, but sometimes locking
-/// is needed at a tenant-wide granularity.
-pub(crate) struct IdLockMap<T>
-where
-    T: Eq + PartialEq + std::hash::Hash,
-{
-    /// A synchronous lock for getting/setting the async locks that our callers will wait on.
-    entities: std::sync::Mutex<std::collections::HashMap<T, Arc<tokio::sync::RwLock<()>>>>,
-}
-
-impl<T> IdLockMap<T>
-where
-    T: Eq + PartialEq + std::hash::Hash,
-{
-    pub(crate) fn shared(
-        &self,
-        key: T,
-    ) -> impl std::future::Future<Output = tokio::sync::OwnedRwLockReadGuard<()>> {
-        let mut locked = self.entities.lock().unwrap();
-        let entry = locked.entry(key).or_default();
-        entry.clone().read_owned()
-    }
-
-    pub(crate) fn exclusive(
-        &self,
-        key: T,
-    ) -> impl std::future::Future<Output = tokio::sync::OwnedRwLockWriteGuard<()>> {
-        let mut locked = self.entities.lock().unwrap();
-        let entry = locked.entry(key).or_default();
-        entry.clone().write_owned()
-    }
-
-    /// Rather than building a lock guard that re-takes the [`Self::entities`] lock, we just do
-    /// periodic housekeeping to avoid the map growing indefinitely
-    pub(crate) fn housekeeping(&self) {
-        let mut locked = self.entities.lock().unwrap();
-        locked.retain(|_k, lock| lock.try_write().is_err())
-    }
-}
-
-impl<T> Default for IdLockMap<T>
-where
-    T: Eq + PartialEq + std::hash::Hash,
-{
-    fn default() -> Self {
-        Self {
-            entities: std::sync::Mutex::new(HashMap::new()),
-        }
-    }
-}
```
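Usage of the deleted map is worth a quick illustration: callers never hold the inner sync mutex across an await, they only await the owned guard futures it hands out. A hypothetical sketch against the `IdLockMap` above, with a `u64` key standing in for `TenantId`:

```rust
// Hypothetical usage sketch for the IdLockMap defined above.
async fn demo(locks: &IdLockMap<u64>) {
    // Any number of holders may share the read lock for one id.
    let read_guard = locks.shared(42).await;
    drop(read_guard);

    // The write lock serializes exclusive work for that id only.
    let _write_guard = locks.exclusive(42).await;

    // housekeeping() drops map entries whose locks are currently free,
    // so the map does not grow without bound.
    locks.housekeeping();
}
```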
```diff
@@ -4,7 +4,6 @@ use utils::seqwait::MonotonicCounter;
 mod auth;
 mod compute_hook;
 pub mod http;
-mod id_lock_map;
 pub mod metrics;
 mod node;
 pub mod persistence;
```

```diff
@@ -11,9 +11,6 @@ use diesel::prelude::*;
 use diesel::Connection;
 use pageserver_api::controller_api::{NodeSchedulingPolicy, PlacementPolicy};
 use pageserver_api::models::TenantConfig;
-use pageserver_api::shard::ShardConfigError;
-use pageserver_api::shard::ShardIdentity;
-use pageserver_api::shard::ShardStripeSize;
 use pageserver_api::shard::{ShardCount, ShardNumber, TenantShardId};
 use serde::{Deserialize, Serialize};
 use utils::generation::Generation;
@@ -75,14 +72,6 @@ pub(crate) enum DatabaseError {
     Logical(String),
 }
 
-#[must_use]
-pub(crate) enum AbortShardSplitStatus {
-    /// We aborted the split in the database by reverting to the parent shards
-    Aborted,
-    /// The split had already been persisted.
-    Complete,
-}
-
 pub(crate) type DatabaseResult<T> = Result<T, DatabaseError>;
 
 impl Persistence {
@@ -581,42 +570,6 @@ impl Persistence {
         })
         .await
     }
-
-    /// Used when the remote part of a shard split failed: we will revert the database state to have only
-    /// the parent shards, with SplitState::Idle.
-    pub(crate) async fn abort_shard_split(
-        &self,
-        split_tenant_id: TenantId,
-        new_shard_count: ShardCount,
-    ) -> DatabaseResult<AbortShardSplitStatus> {
-        use crate::schema::tenant_shards::dsl::*;
-        self.with_conn(move |conn| -> DatabaseResult<AbortShardSplitStatus> {
-            let aborted = conn.transaction(|conn| -> QueryResult<AbortShardSplitStatus> {
-                // Clear the splitting state on parent shards
-                let updated = diesel::update(tenant_shards)
-                    .filter(tenant_id.eq(split_tenant_id.to_string()))
-                    .filter(shard_count.ne(new_shard_count.literal() as i32))
-                    .set((splitting.eq(0),))
-                    .execute(conn)?;
-
-                // Parent shards are already gone: we cannot abort.
-                if updated == 0 {
-                    return Ok(AbortShardSplitStatus::Complete);
-                }
-
-                // Erase child shards
-                diesel::delete(tenant_shards)
-                    .filter(tenant_id.eq(split_tenant_id.to_string()))
-                    .filter(shard_count.eq(new_shard_count.literal() as i32))
-                    .execute(conn)?;
-
-                Ok(AbortShardSplitStatus::Aborted)
-            })?;
-
-            Ok(aborted)
-        })
-        .await
-    }
 }
 
 /// Parts of [`crate::tenant_state::TenantState`] that are stored durably
@@ -651,28 +604,6 @@ pub(crate) struct TenantShardPersistence {
     pub(crate) config: String,
 }
 
-impl TenantShardPersistence {
-    pub(crate) fn get_shard_identity(&self) -> Result<ShardIdentity, ShardConfigError> {
-        if self.shard_count == 0 {
-            Ok(ShardIdentity::unsharded())
-        } else {
-            Ok(ShardIdentity::new(
-                ShardNumber(self.shard_number as u8),
-                ShardCount::new(self.shard_count as u8),
-                ShardStripeSize(self.shard_stripe_size as u32),
-            )?)
-        }
-    }
-
-    pub(crate) fn get_tenant_shard_id(&self) -> Result<TenantShardId, hex::FromHexError> {
-        Ok(TenantShardId {
-            tenant_id: TenantId::from_str(self.tenant_id.as_str())?,
-            shard_number: ShardNumber(self.shard_number as u8),
-            shard_count: ShardCount::new(self.shard_count as u8),
-        })
-    }
-}
-
 /// Parts of [`crate::node::Node`] that are stored durably
 #[derive(Serialize, Deserialize, Queryable, Selectable, Insertable, Eq, PartialEq)]
 #[diesel(table_name = crate::schema::nodes)]
```
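Because `AbortShardSplitStatus` is `#[must_use]`, call sites of the removed `abort_shard_split` have to branch on whether anything was actually reverted. A hedged call-site sketch; the surrounding function and logging are illustrative, not code from this diff:

```rust
// Illustrative call site for the removed abort_shard_split; `persistence`
// and the shard count value are placeholders.
async fn abort_split_example(
    persistence: &Persistence,
    tenant_id: TenantId,
) -> DatabaseResult<()> {
    match persistence
        .abort_shard_split(tenant_id, ShardCount::new(8))
        .await?
    {
        // Parent shards were restored and child rows deleted.
        AbortShardSplitStatus::Aborted => tracing::info!("split reverted"),
        // The split already committed; there was nothing to undo.
        AbortShardSplitStatus::Complete => tracing::info!("split already persisted"),
    }
    Ok(())
}
```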
```diff
@@ -1,6 +1,5 @@
 use crate::persistence::Persistence;
 use crate::service;
-use hyper::StatusCode;
 use pageserver_api::models::{
     LocationConfig, LocationConfigMode, LocationConfigSecondary, TenantConfig,
 };
@@ -19,8 +18,6 @@ use crate::compute_hook::{ComputeHook, NotifyError};
 use crate::node::Node;
 use crate::tenant_state::{IntentState, ObservedState, ObservedStateLocation};
 
-const DEFAULT_HEATMAP_PERIOD: &str = "60s";
-
 /// Object with the lifetime of the background reconcile task that is created
 /// for tenants which have a difference between their intent and observed states.
 pub(super) struct Reconciler {
@@ -488,29 +485,17 @@ impl Reconciler {
         )
         .await
         {
-            Some(Ok(observed)) => Some(observed),
-            Some(Err(mgmt_api::Error::ApiError(status, _msg)))
-                if status == StatusCode::NOT_FOUND =>
-            {
-                None
-            }
+            Some(Ok(observed)) => observed,
             Some(Err(e)) => return Err(e.into()),
             None => return Err(ReconcileError::Cancel),
         };
         tracing::info!("Scanned location configuration on {attached_node}: {observed_conf:?}");
-        match observed_conf {
-            Some(conf) => {
-                // Pageserver returned a state: update it in observed. This may still be an indeterminate (None) state,
-                // if internally the pageserver's TenantSlot was being mutated (e.g. some long running API call is still running)
-                self.observed
-                    .locations
-                    .insert(attached_node.get_id(), ObservedStateLocation { conf });
-            }
-            None => {
-                // Pageserver returned 404: we have confirmation that there is no state for this shard on that pageserver.
-                self.observed.locations.remove(&attached_node.get_id());
-            }
-        }
+        self.observed.locations.insert(
+            attached_node.get_id(),
+            ObservedStateLocation {
+                conf: observed_conf,
+            },
+        );
     }
 
     Ok(())
@@ -540,12 +525,7 @@ impl Reconciler {
             )));
         };
 
-        let mut wanted_conf = attached_location_conf(
-            generation,
-            &self.shard,
-            &self.config,
-            !self.intent.secondary.is_empty(),
-        );
+        let mut wanted_conf = attached_location_conf(generation, &self.shard, &self.config);
         match self.observed.locations.get(&node.get_id()) {
             Some(conf) if conf.conf.as_ref() == Some(&wanted_conf) => {
                 // Nothing to do
@@ -682,26 +662,10 @@ impl Reconciler {
     }
 }
 
-/// We tweak the externally-set TenantConfig while configuring
-/// locations, using our awareness of whether secondary locations
-/// are in use to automatically enable/disable heatmap uploads.
-fn ha_aware_config(config: &TenantConfig, has_secondaries: bool) -> TenantConfig {
-    let mut config = config.clone();
-    if has_secondaries {
-        if config.heatmap_period.is_none() {
-            config.heatmap_period = Some(DEFAULT_HEATMAP_PERIOD.to_string());
-        }
-    } else {
-        config.heatmap_period = None;
-    }
-    config
-}
-
 pub(crate) fn attached_location_conf(
     generation: Generation,
     shard: &ShardIdentity,
     config: &TenantConfig,
-    has_secondaries: bool,
 ) -> LocationConfig {
     LocationConfig {
         mode: LocationConfigMode::AttachedSingle,
@@ -710,7 +674,7 @@ pub(crate) fn attached_location_conf(
         shard_number: shard.number.0,
         shard_count: shard.count.literal(),
         shard_stripe_size: shard.stripe_size.0,
-        tenant_conf: ha_aware_config(config, has_secondaries),
+        tenant_conf: config.clone(),
     }
 }
 
@@ -725,6 +689,6 @@ pub(crate) fn secondary_location_conf(
         shard_number: shard.number.0,
         shard_count: shard.count.literal(),
         shard_stripe_size: shard.stripe_size.0,
-        tenant_conf: ha_aware_config(config, true),
+        tenant_conf: config.clone(),
     }
 }
```
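The effect of the removed `ha_aware_config` is easiest to see with concrete values: heatmap uploads are forced on (with the 60s default) whenever secondaries exist, and forced off otherwise. A test-style sketch, assuming `TenantConfig` implements `Default` and `heatmap_period` is an `Option<String>` as shown above:

```rust
// Sketch of ha_aware_config's behaviour with concrete inputs.
fn heatmap_period_examples() {
    // Secondaries present and no explicit period: the 60s default is applied.
    let base = TenantConfig::default();
    let with_secondaries = ha_aware_config(&base, true);
    assert_eq!(with_secondaries.heatmap_period.as_deref(), Some("60s"));

    // No secondaries: any configured period is cleared.
    let mut custom = TenantConfig::default();
    custom.heatmap_period = Some("120s".to_string());
    let without_secondaries = ha_aware_config(&custom, false);
    assert_eq!(without_secondaries.heatmap_period, None);
}
```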
(File diff suppressed because it is too large.)
```diff
@@ -577,12 +577,7 @@ impl TenantState {
             .generation
             .expect("Attempted to enter attached state without a generation");
 
-        let wanted_conf = attached_location_conf(
-            generation,
-            &self.shard,
-            &self.config,
-            !self.intent.secondary.is_empty(),
-        );
+        let wanted_conf = attached_location_conf(generation, &self.shard, &self.config);
         match self.observed.locations.get(&node_id) {
             Some(conf) if conf.conf.as_ref() == Some(&wanted_conf) => {}
             Some(_) | None => {
```

```diff
@@ -774,10 +774,7 @@ impl Endpoint {
             spec.shard_stripe_size = stripe_size.map(|s| s.0 as usize);
         }
 
-        let client = reqwest::Client::builder()
-            .timeout(Duration::from_secs(30))
-            .build()
-            .unwrap();
+        let client = reqwest::Client::new();
         let response = client
             .post(format!(
                 "http://{}:{}/configure",
```

```diff
@@ -17,7 +17,6 @@ use std::time::Duration;
 use anyhow::{bail, Context};
 use camino::Utf8PathBuf;
 use futures::SinkExt;
-use hyper::StatusCode;
 use pageserver_api::controller_api::NodeRegisterRequest;
 use pageserver_api::models::{
     self, LocationConfig, ShardParameters, TenantHistorySize, TenantInfo, TimelineInfo,
@@ -263,11 +262,6 @@ impl PageServerNode {
                 match st {
                     Ok(()) => Ok(true),
                     Err(mgmt_api::Error::ReceiveBody(_)) => Ok(false),
-                    Err(mgmt_api::Error::ApiError(status, _msg))
-                        if status == StatusCode::SERVICE_UNAVAILABLE =>
-                    {
-                        Ok(false)
-                    }
                     Err(e) => Err(anyhow::anyhow!("Failed to check node status: {e}")),
                 }
             },
```

```diff
@@ -2103,16 +2103,6 @@ where
     R: std::future::Future<Output = Result<Response<Body>, ApiError>> + Send + 'static,
     H: FnOnce(Request<Body>, CancellationToken) -> R + Send + Sync + 'static,
 {
-    if request.uri() != &"/v1/failpoints".parse::<Uri>().unwrap() {
-        fail::fail_point!("api-503", |_| Err(ApiError::ResourceUnavailable(
-            "failpoint".into()
-        )));
-
-        fail::fail_point!("api-500", |_| Err(ApiError::InternalServerError(
-            anyhow::anyhow!("failpoint")
-        )));
-    }
-
     // Spawn a new task to handle the request, to protect the handler from unexpected
     // async cancellations. Most pageserver functions are not async cancellation safe.
     // We arm a drop-guard, so that if Hyper drops the Future, we signal the task
@@ -2257,7 +2247,7 @@ pub fn make_router(
         .get("/v1/location_config", |r| {
             api_handler(r, list_location_config_handler)
         })
-        .get("/v1/location_config/:tenant_shard_id", |r| {
+        .get("/v1/location_config/:tenant_id", |r| {
             api_handler(r, get_location_config_handler)
         })
         .put(
```
```diff
@@ -1440,31 +1440,6 @@ impl TenantManager {
         tenant_shard_id: TenantShardId,
         new_shard_count: ShardCount,
         ctx: &RequestContext,
-    ) -> anyhow::Result<Vec<TenantShardId>> {
-        let r = self
-            .do_shard_split(tenant_shard_id, new_shard_count, ctx)
-            .await;
-        if r.is_err() {
-            // Shard splitting might have left the original shard in a partially shut down state (it
-            // stops the shard's remote timeline client). Reset it to ensure we leave things in
-            // a working state.
-            if self.get(tenant_shard_id).is_some() {
-                tracing::warn!("Resetting {tenant_shard_id} after shard split failure");
-                if let Err(e) = self.reset_tenant(tenant_shard_id, false, ctx).await {
-                    // Log this error because our return value will still be the original error, not this one.
-                    tracing::warn!("Failed to reset {tenant_shard_id}: {e}");
-                }
-            }
-        }
-
-        r
-    }
-
-    pub(crate) async fn do_shard_split(
-        &self,
-        tenant_shard_id: TenantShardId,
-        new_shard_count: ShardCount,
-        ctx: &RequestContext,
     ) -> anyhow::Result<Vec<TenantShardId>> {
         let tenant = get_tenant(tenant_shard_id, true)?;
@@ -1491,10 +1466,6 @@ impl TenantManager {
                 .join(",")
         );
 
-        fail::fail_point!("shard-split-pre-prepare", |_| Err(anyhow::anyhow!(
-            "failpoint"
-        )));
-
         // Phase 1: Write out child shards' remote index files, in the parent tenant's current generation
         if let Err(e) = tenant.split_prepare(&child_shards).await {
             // If [`Tenant::split_prepare`] fails, we must reload the tenant, because it might
@@ -1504,10 +1475,6 @@ impl TenantManager {
             return Err(e);
         }
 
-        fail::fail_point!("shard-split-post-prepare", |_| Err(anyhow::anyhow!(
-            "failpoint"
-        )));
-
         self.resources.deletion_queue_client.flush_advisory();
 
         // Phase 2: Put the parent shard to InProgress and grab a reference to the parent Tenant
@@ -1529,16 +1496,11 @@ impl TenantManager {
                 anyhow::bail!("Detached parent shard in the middle of split!")
             }
         };
-        fail::fail_point!("shard-split-pre-hardlink", |_| Err(anyhow::anyhow!(
-            "failpoint"
-        )));
 
         // Optimization: hardlink layers from the parent into the children, so that they don't have to
         // re-download & duplicate the data referenced in their initial IndexPart
         self.shard_split_hardlink(parent, child_shards.clone())
             .await?;
-        fail::fail_point!("shard-split-post-hardlink", |_| Err(anyhow::anyhow!(
-            "failpoint"
-        )));
 
         // Take a snapshot of where the parent's WAL ingest had got to: we will wait for
         // child shards to reach this point.
@@ -1575,10 +1537,6 @@ impl TenantManager {
                 .await?;
         }
 
-        fail::fail_point!("shard-split-post-child-conf", |_| Err(anyhow::anyhow!(
-            "failpoint"
-        )));
-
         // Phase 4: wait for child chards WAL ingest to catch up to target LSN
         for child_shard_id in &child_shards {
             let child_shard_id = *child_shard_id;
@@ -1611,10 +1569,6 @@ impl TenantManager {
                 timeline.timeline_id,
                 target_lsn
             );
-
-            fail::fail_point!("shard-split-lsn-wait", |_| Err(anyhow::anyhow!(
-                "failpoint"
-            )));
             if let Err(e) = timeline.wait_lsn(*target_lsn, ctx).await {
                 // Failure here might mean shutdown, in any case this part is an optimization
                 // and we shouldn't hold up the split operation.
@@ -1660,10 +1614,6 @@ impl TenantManager {
             },
         );
 
-        fail::fail_point!("shard-split-pre-finish", |_| Err(anyhow::anyhow!(
-            "failpoint"
-        )));
-
         parent_slot_guard.drop_old_value()?;
 
         // Phase 6: Release the InProgress on the parent shard
```
```diff
@@ -30,6 +30,10 @@ hostname.workspace = true
 humantime.workspace = true
 hyper-tungstenite.workspace = true
 hyper.workspace = true
+hyper1 = { package = "hyper", version = "1.2", features = ["server", "http1", "http2"] }
+hyper-util = { version = "0.1", features = ["tokio"] }
+http1 = { package = "http", version = "1" }
+http-body-util = { version = "0.1" }
 ipnet.workspace = true
 itertools.workspace = true
 lasso = { workspace = true, features = ["multi-threaded"] }
```
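The four added lines use Cargo's package-rename feature so that two major versions of the same crates can coexist in one crate: `hyper.workspace` continues to resolve 0.14 under the name `hyper`, while `hyper1` and `http1` are local aliases for the 1.x releases. In code the aliases are referenced like ordinary crate names, as the proxy diffs below do; a small sketch:

```rust
// Sketch: the renamed dependencies appear side by side under their aliases.
use hyper1::body::Incoming;

fn is_sql_post(req: &hyper1::Request<Incoming>) -> bool {
    // http1::Method comes from the http 1.x crate aliased above.
    req.uri().path() == "/sql" && *req.method() == http1::Method::POST
}
```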
```diff
@@ -175,7 +175,7 @@ async fn task_main(
             .context("failed to set socket option")?;
 
         info!(%peer_addr, "serving");
-        let ctx = RequestMonitoring::new(session_id, peer_addr.ip(), "sni_router", "sni");
+        let ctx = RequestMonitoring::new(session_id, peer_addr, "sni_router", "sni");
         handle_client(ctx, dest_suffix, tls_config, tls_server_end_point, socket).await
     }
     .unwrap_or_else(|e| {
```

```diff
@@ -3,7 +3,7 @@
 use chrono::Utc;
 use once_cell::sync::OnceCell;
 use smol_str::SmolStr;
-use std::net::IpAddr;
+use std::net::{IpAddr, SocketAddr};
 use tokio::sync::mpsc;
 use tracing::{field::display, info_span, Span};
 use uuid::Uuid;
@@ -62,7 +62,7 @@ pub enum AuthMethod {
 impl RequestMonitoring {
     pub fn new(
         session_id: Uuid,
-        peer_addr: IpAddr,
+        peer_addr: SocketAddr,
         protocol: &'static str,
         region: &'static str,
     ) -> Self {
@@ -75,7 +75,7 @@ impl RequestMonitoring {
         );
 
         Self {
-            peer_addr,
+            peer_addr: peer_addr.ip(),
             session_id,
             protocol,
             first_packet: Utc::now(),
@@ -100,7 +100,12 @@ impl RequestMonitoring {
 
     #[cfg(test)]
     pub fn test() -> Self {
-        RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), "test", "test")
+        RequestMonitoring::new(
+            Uuid::now_v7(),
+            ([127, 0, 0, 1], 5432).into(),
+            "test",
+            "test",
+        )
     }
 
     pub fn console_application_name(&self) -> String {
```
```diff
@@ -5,19 +5,13 @@ use std::{
     io,
     net::SocketAddr,
     pin::{pin, Pin},
-    sync::Mutex,
     task::{ready, Context, Poll},
 };
 
 use bytes::{Buf, BytesMut};
-use hyper::server::accept::Accept;
-use hyper::server::conn::{AddrIncoming, AddrStream};
-use metrics::IntCounterPairGuard;
+use hyper::server::conn::AddrIncoming;
 use pin_project_lite::pin_project;
 use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
 use uuid::Uuid;
 
-use crate::{metrics::NUM_CLIENT_CONNECTION_GAUGE, serverless::tls_listener::AsyncAccept};
-
 pub struct ProxyProtocolAccept {
     pub incoming: AddrIncoming,
@@ -331,87 +325,6 @@ impl<T: AsyncRead> AsyncRead for WithClientIp<T> {
     }
 }
 
-impl AsyncAccept for ProxyProtocolAccept {
-    type Connection = WithConnectionGuard<WithClientIp<AddrStream>>;
-
-    type Error = io::Error;
-
-    fn poll_accept(
-        mut self: Pin<&mut Self>,
-        cx: &mut Context<'_>,
-    ) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
-        let conn = ready!(Pin::new(&mut self.incoming).poll_accept(cx)?);
-        tracing::info!(protocol = self.protocol, "accepted new TCP connection");
-        let Some(conn) = conn else {
-            return Poll::Ready(None);
-        };
-
-        Poll::Ready(Some(Ok(WithConnectionGuard {
-            inner: WithClientIp::new(conn),
-            connection_id: Uuid::new_v4(),
-            gauge: Mutex::new(Some(
-                NUM_CLIENT_CONNECTION_GAUGE
-                    .with_label_values(&[self.protocol])
-                    .guard(),
-            )),
-        })))
-    }
-}
-
-pin_project! {
-    pub struct WithConnectionGuard<T> {
-        #[pin]
-        pub inner: T,
-        pub connection_id: Uuid,
-        pub gauge: Mutex<Option<IntCounterPairGuard>>,
-    }
-}
-
-impl<T: AsyncWrite> AsyncWrite for WithConnectionGuard<T> {
-    #[inline]
-    fn poll_write(
-        self: Pin<&mut Self>,
-        cx: &mut Context<'_>,
-        buf: &[u8],
-    ) -> Poll<Result<usize, io::Error>> {
-        self.project().inner.poll_write(cx, buf)
-    }
-
-    #[inline]
-    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
-        self.project().inner.poll_flush(cx)
-    }
-
-    #[inline]
-    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
-        self.project().inner.poll_shutdown(cx)
-    }
-
-    #[inline]
-    fn poll_write_vectored(
-        self: Pin<&mut Self>,
-        cx: &mut Context<'_>,
-        bufs: &[io::IoSlice<'_>],
-    ) -> Poll<Result<usize, io::Error>> {
-        self.project().inner.poll_write_vectored(cx, bufs)
-    }
-
-    #[inline]
-    fn is_write_vectored(&self) -> bool {
-        self.inner.is_write_vectored()
-    }
-}
-
-impl<T: AsyncRead> AsyncRead for WithConnectionGuard<T> {
-    fn poll_read(
-        self: Pin<&mut Self>,
-        cx: &mut Context<'_>,
-        buf: &mut ReadBuf<'_>,
-    ) -> Poll<io::Result<()>> {
-        self.project().inner.poll_read(cx, buf)
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use std::pin::pin;
```
```diff
@@ -91,9 +91,8 @@ pub async fn task_main(
 
         connections.spawn(async move {
             let mut socket = WithClientIp::new(socket);
-            let mut peer_addr = peer_addr.ip();
-            match socket.wait_for_addr().await {
-                Ok(Some(addr)) => peer_addr = addr.ip(),
+            let peer_addr = match socket.wait_for_addr().await {
+                Ok(Some(addr)) => addr,
                 Err(e) => {
                     error!("per-client task finished with an error: {e:#}");
                     return;
@@ -102,8 +101,8 @@ pub async fn task_main(
                     error!("missing required client IP");
                     return;
                 }
-                Ok(None) => {}
-            }
+                Ok(None) => peer_addr
+            };
 
             match socket.inner.set_nodelay(true) {
                 Ok(()) => {},
```
```diff
@@ -4,46 +4,45 @@
 
 mod backend;
 mod conn_pool;
+mod http_auto;
 mod json;
 mod sql_over_http;
 pub mod tls_listener;
 mod websocket;
 
+use bytes::Bytes;
 pub use conn_pool::GlobalConnPoolOptions;
 
-use anyhow::bail;
-use hyper::StatusCode;
-use metrics::IntCounterPairGuard;
+use anyhow::Context;
+use futures::future::{select, Either};
+use http1::{Method, Response, StatusCode};
+use http_body_util::Full;
+use hyper1::body::Incoming;
 use rand::rngs::StdRng;
 use rand::SeedableRng;
 pub use reqwest_middleware::{ClientWithMiddleware, Error};
 pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
 use serde::Serialize;
+use tokio::time::timeout;
 use tokio_util::task::TaskTracker;
 
 use crate::context::RequestMonitoring;
-use crate::metrics::TLS_HANDSHAKE_FAILURES;
-use crate::protocol2::{ProxyProtocolAccept, WithClientIp, WithConnectionGuard};
+use crate::metrics::{NUM_CLIENT_CONNECTION_GAUGE, TLS_HANDSHAKE_FAILURES};
+use crate::protocol2::WithClientIp;
+use crate::proxy::run_until_cancelled;
 use crate::rate_limiter::EndpointRateLimiter;
 use crate::serverless::backend::PoolingBackend;
+use crate::serverless::http_auto::Rewind;
 use crate::{cancellation::CancellationHandler, config::ProxyConfig};
-use futures::StreamExt;
-use hyper::{
-    server::{
-        accept,
-        conn::{AddrIncoming, AddrStream},
-    },
-    Body, Method, Request, Response,
-};
 
 use std::convert::Infallible;
-use std::net::IpAddr;
-use std::task::Poll;
-use std::{future::ready, sync::Arc};
-use tls_listener::TlsListener;
+use std::net::SocketAddr;
+use std::pin::pin;
+use std::sync::Arc;
 use std::time::Duration;
+use tokio::net::TcpListener;
 use tokio_util::sync::CancellationToken;
 use tracing::{error, info, warn, Instrument};
-use utils::http::{error::ApiError, json::json_response};
+use utils::http::error::ApiError;
 
 pub const SERVERLESS_DRIVER_SNI: &str = "api";
```
@@ -95,134 +94,221 @@ pub async fn task_main(
|
||||
tls_server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||
let tls_acceptor: tokio_rustls::TlsAcceptor = Arc::new(tls_server_config).into();
|
||||
|
||||
let mut addr_incoming = AddrIncoming::from_listener(ws_listener)?;
|
||||
let _ = addr_incoming.set_nodelay(true);
|
||||
let addr_incoming = ProxyProtocolAccept {
|
||||
incoming: addr_incoming,
|
||||
protocol: "http",
|
||||
};
|
||||
|
||||
let ws_connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
ws_connections.close(); // allows `ws_connections.wait to complete`
|
||||
|
||||
let tls_listener = TlsListener::new(tls_acceptor, addr_incoming).filter(|conn| {
|
||||
if let Err(err) = conn {
|
||||
error!(
|
||||
protocol = "http",
|
||||
"failed to accept TLS connection: {err:?}"
|
||||
);
|
||||
TLS_HANDSHAKE_FAILURES.inc();
|
||||
ready(false)
|
||||
} else {
|
||||
info!(protocol = "http", "accepted new TLS connection");
|
||||
ready(true)
|
||||
let http_connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
http_connections.close();
|
||||
|
||||
let server = http_auto::Builder::new();
|
||||
|
||||
loop {
|
||||
let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await else {
|
||||
break;
|
||||
};
|
||||
|
||||
let (conn, mut peer_addr) = res.context("could not accept TCP stream")?;
|
||||
if let Err(e) = conn.set_nodelay(true) {
|
||||
tracing::error!("could not set nodolay: {e}");
|
||||
continue;
|
||||
}
|
||||
});
|
||||
let cancellation_token = cancellation_token.child_token();
|
||||
|
||||
let make_svc = hyper::service::make_service_fn(
|
||||
|stream: &tokio_rustls::server::TlsStream<
|
||||
WithConnectionGuard<WithClientIp<AddrStream>>,
|
||||
>| {
|
||||
let (conn, _) = stream.get_ref();
|
||||
let tls = tls_acceptor.clone();
|
||||
|
||||
// this is jank. should dissapear with hyper 1.0 migration.
|
||||
let gauge = conn
|
||||
.gauge
|
||||
.lock()
|
||||
.expect("lock should not be poisoned")
|
||||
.take()
|
||||
.expect("gauge should be set on connection start");
|
||||
let backend = backend.clone();
|
||||
let ws_connections = ws_connections.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
let cancellation_handler = cancellation_handler.clone();
|
||||
let server = server.clone();
|
||||
|
||||
let client_addr = conn.inner.client_addr();
|
||||
let remote_addr = conn.inner.inner.remote_addr();
|
||||
let backend = backend.clone();
|
||||
let ws_connections = ws_connections.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
let cancellation_handler = cancellation_handler.clone();
|
||||
async move {
|
||||
let peer_addr = match client_addr {
|
||||
Some(addr) => addr,
|
||||
None if config.require_client_ip => bail!("missing required client ip"),
|
||||
None => remote_addr,
|
||||
};
|
||||
Ok(MetricService::new(
|
||||
hyper::service::service_fn(move |req: Request<Body>| {
|
||||
let backend = backend.clone();
|
||||
let ws_connections = ws_connections.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
let cancellation_handler = cancellation_handler.clone();
|
||||
http_connections.spawn(async move {
|
||||
let _gauge = NUM_CLIENT_CONNECTION_GAUGE
|
||||
.with_label_values(&["http"])
|
||||
.guard();
|
||||
|
||||
async move {
|
||||
Ok::<_, Infallible>(
|
||||
request_handler(
|
||||
req,
|
||||
config,
|
||||
backend,
|
||||
ws_connections,
|
||||
cancellation_handler,
|
||||
peer_addr.ip(),
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await
|
||||
.map_or_else(|e| e.into_response(), |r| r),
|
||||
)
|
||||
}
|
||||
}),
|
||||
gauge,
|
||||
))
|
||||
// handle PROXY protocol
|
||||
let mut conn = WithClientIp::new(conn);
|
||||
let peer = match conn.wait_for_addr().await {
|
||||
Ok(peer) => peer,
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
"failed to accept TCP connection: invalid PROXY protocol V2 header: {e:#}"
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(peer) = peer {
|
||||
peer_addr = peer;
|
||||
}
|
||||
},
|
||||
);
|
||||
info!(%peer_addr, protocol = "http", "accepted new TCP connection");
|
||||
|
||||
hyper::Server::builder(accept::from_stream(tls_listener))
|
||||
.serve(make_svc)
|
||||
.with_graceful_shutdown(cancellation_token.cancelled())
|
||||
.await?;
|
||||
let accept = tls.accept(conn);
|
||||
let conn = match timeout(Duration::from_secs(10), accept).await {
|
||||
Ok(Ok(conn)) => {
|
||||
info!(%peer_addr, protocol = "http", "accepted new TLS connection");
|
||||
conn
|
||||
}
|
||||
// The handshake failed, try getting another connection from the queue
|
||||
Ok(Err(e)) => {
|
||||
TLS_HANDSHAKE_FAILURES.inc();
|
||||
warn!(%peer_addr, protocol = "http", "failed to accept TLS connection: {e:?}");
|
||||
return;
|
||||
}
|
||||
// The handshake timed out, try getting another connection from the queue
|
||||
Err(_) => {
|
||||
TLS_HANDSHAKE_FAILURES.inc();
|
||||
warn!(%peer_addr, protocol = "http", "failed to accept TLS connection: timeout");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let (version, conn) = match conn.get_ref().1.alpn_protocol() {
|
||||
Some(b"http/1.1") => (http_auto::Version::H1, Rewind::new(conn)),
|
||||
Some(b"h2") => (http_auto::Version::H2, Rewind::new(conn)),
|
||||
_ => {
|
||||
tracing::debug!("HTTP: no ALPN negotiated");
|
||||
let conn = timeout(Duration::from_secs(10), http_auto::read_version(conn)).await;
|
||||
match conn {
|
||||
Ok(Ok(v)) => v,
|
||||
Ok(Err(e)) => {
|
||||
tracing::warn!("HTTP connection error: {e}");
|
||||
return;
|
||||
},
|
||||
Err(_) => {
|
||||
tracing::warn!("HTTP connection error: timeout determining http version");
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let conn = server.serve_connection_with_upgrades(
|
||||
conn,
|
||||
version,
|
||||
hyper1::service::service_fn(move |req: hyper1::Request<Incoming>| {
|
||||
let backend = backend.clone();
|
||||
let ws_connections = ws_connections.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
let cancellation_handler = cancellation_handler.clone();
|
||||
|
||||
async move {
|
||||
Ok::<_, Infallible>(
|
||||
request_handler(
|
||||
req,
|
||||
config,
|
||||
backend,
|
||||
ws_connections,
|
||||
cancellation_handler,
|
||||
peer_addr,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await
|
||||
.map_or_else(api_error_into_response, |r| r),
|
||||
)
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
|
||||
let cancel = pin!(cancellation_token.cancelled());
|
||||
let conn = pin!(conn);
|
||||
let res = match select(cancel, conn).await {
|
||||
Either::Left((_cancelled, mut conn)) => {
|
||||
conn.as_mut().graceful_shutdown();
|
||||
conn.await
|
||||
}
|
||||
Either::Right((res, _)) => res,
|
||||
};
|
||||
|
||||
match res {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
tracing::warn!("HTTP connection error {e}")
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// await websocket connections
|
||||
http_connections.wait().await;
|
||||
ws_connections.wait().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct MetricService<S> {
    inner: S,
    _gauge: IntCounterPairGuard,
}

impl<S> MetricService<S> {
    fn new(inner: S, _gauge: IntCounterPairGuard) -> MetricService<S> {
        MetricService { inner, _gauge }
    }
}

impl<S, ReqBody> hyper::service::Service<Request<ReqBody>> for MetricService<S>
where
    S: hyper::service::Service<Request<ReqBody>>,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        self.inner.call(req)
    }
}

fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
    match this {
        ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
            format!("{err:#?}"), // use debug printing so that we give the cause
            StatusCode::BAD_REQUEST,
        ),
        ApiError::Forbidden(_) => {
            HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::FORBIDDEN)
        }
        ApiError::Unauthorized(_) => {
            HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::UNAUTHORIZED)
        }
        ApiError::NotFound(_) => {
            HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::NOT_FOUND)
        }
        ApiError::Conflict(_) => {
            HttpErrorBody::response_from_msg_and_status(this.to_string(), StatusCode::CONFLICT)
        }
        ApiError::PreconditionFailed(_) => HttpErrorBody::response_from_msg_and_status(
            this.to_string(),
            StatusCode::PRECONDITION_FAILED,
        ),
        ApiError::ShuttingDown => HttpErrorBody::response_from_msg_and_status(
            "Shutting down".to_string(),
            StatusCode::SERVICE_UNAVAILABLE,
        ),
        ApiError::ResourceUnavailable(err) => HttpErrorBody::response_from_msg_and_status(
            err.to_string(),
            StatusCode::SERVICE_UNAVAILABLE,
        ),
        ApiError::Timeout(err) => HttpErrorBody::response_from_msg_and_status(
            err.to_string(),
            StatusCode::REQUEST_TIMEOUT,
        ),
        ApiError::InternalServerError(err) => HttpErrorBody::response_from_msg_and_status(
            err.to_string(),
            StatusCode::INTERNAL_SERVER_ERROR,
        ),
    }
}

#[derive(Serialize)]
struct HttpErrorBody {
    pub msg: String,
}

impl HttpErrorBody {
    pub fn response_from_msg_and_status(msg: String, status: StatusCode) -> Response<Full<Bytes>> {
        HttpErrorBody { msg }.to_response(status)
    }

    pub fn to_response(&self, status: StatusCode) -> Response<Full<Bytes>> {
        Response::builder()
            .status(status)
            .header(http1::header::CONTENT_TYPE, "application/json")
            // we do not have nested maps with non-string keys, so serialization shouldn't fail
            .body(Full::new(Bytes::from(serde_json::to_string(self).unwrap())))
            .unwrap()
    }
}

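For reference, `HttpErrorBody` above serializes to a single-field JSON object. A minimal sketch of that shape, assuming only serde and serde_json:

use serde::Serialize;

#[derive(Serialize)]
struct HttpErrorBody {
    msg: String,
}

fn main() {
    // The proxy returns this body for every ApiError variant above.
    let body = HttpErrorBody { msg: "Shutting down".to_string() };
    assert_eq!(
        serde_json::to_string(&body).unwrap(),
        r#"{"msg":"Shutting down"}"#
    );
}
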
#[allow(clippy::too_many_arguments)]
async fn request_handler(
-    mut request: Request<Body>,
+    mut request: hyper1::Request<Incoming>,
    config: &'static ProxyConfig,
    backend: Arc<PoolingBackend>,
    ws_connections: TaskTracker,
    cancellation_handler: Arc<CancellationHandler>,
-    peer_addr: IpAddr,
+    peer_addr: SocketAddr,
    endpoint_rate_limiter: Arc<EndpointRateLimiter>,
-) -> Result<Response<Body>, ApiError> {
+) -> Result<Response<Full<Bytes>>, ApiError> {
    let session_id = uuid::Uuid::new_v4();

    let host = request
@@ -261,14 +347,14 @@ async fn request_handler

        // Return the response so the spawned future can continue.
        Ok(response)
-    } else if request.uri().path() == "/sql" && request.method() == Method::POST {
+    } else if request.uri().path() == "/sql" && *request.method() == Method::POST {
        let ctx = RequestMonitoring::new(session_id, peer_addr, "http", &config.region);
        let span = ctx.span.clone();

        sql_over_http::handle(config, ctx, request, backend)
            .instrument(span)
            .await
-    } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
+    } else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
        Response::builder()
            .header("Allow", "OPTIONS, POST")
            .header("Access-Control-Allow-Origin", "*")
@@ -278,9 +364,24 @@ async fn request_handler
            )
            .header("Access-Control-Max-Age", "86400" /* 24 hours */)
            .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code
-            .body(Body::empty())
+            .body(Full::new(Bytes::new()))
            .map_err(|e| ApiError::InternalServerError(e.into()))
    } else {
        json_response(StatusCode::BAD_REQUEST, "query is not supported")
    }
}

fn json_response<T: Serialize>(
    status: StatusCode,
    data: T,
) -> Result<Response<Full<Bytes>>, ApiError> {
    let json = serde_json::to_string(&data)
        .context("Failed to serialize JSON response")
        .map_err(ApiError::InternalServerError)?;
    let response = Response::builder()
        .status(status)
        .header(http1::header::CONTENT_TYPE, "application/json")
        .body(Full::new(Bytes::from(json)))
        .map_err(|e| ApiError::InternalServerError(e.into()))?;
    Ok(response)
}
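A hypothetical unit test (not part of this diff) sketching what the `json_response` helper above produces:

#[test]
fn json_response_shape() {
    let resp = json_response(StatusCode::BAD_REQUEST, "query is not supported")
        .map_err(|_| "building the response failed")
        .unwrap();
    assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    // The body is serialized JSON, labelled application/json.
    assert_eq!(
        resp.headers()[http1::header::CONTENT_TYPE].to_str().unwrap(),
        "application/json"
    );
}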
proxy/src/serverless/http_auto.rs (new file, 316 lines)
@@ -0,0 +1,316 @@
//! [`hyper-util`] offers an 'auto' connection to detect whether the connection should be HTTP1 or HTTP2.
//! There's a bug in this implementation where graceful shutdowns are not properly respected.

use futures::ready;
use hyper1::body::Body;
use hyper1::service::HttpService;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use std::future::Future;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{error::Error as StdError, io, marker::Unpin};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use ::http1::{Request, Response};
use bytes::{Buf, Bytes};
use hyper1::{body::Incoming, service::Service};

use hyper1::server::conn::http1;
use hyper1::{rt::bounds::Http2ServerConnExec, server::conn::http2};

use pin_project_lite::pin_project;

type Error = Box<dyn std::error::Error + Send + Sync>;

type Result<T> = std::result::Result<T, Error>;

const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";

/// Http1 or Http2 connection builder.
#[derive(Clone, Debug)]
pub struct Builder {
    http1: http1::Builder,
    http2: http2::Builder<TokioExecutor>,
}

impl Builder {
    /// Create a new auto connection builder.
    pub fn new() -> Self {
        let mut builder = Self {
            http1: http1::Builder::new(),
            http2: http2::Builder::new(TokioExecutor::new()),
        };

        builder.http1.timer(TokioTimer::new());
        builder.http2.timer(TokioTimer::new());

        builder
    }

    /// Bind a connection together with a [`Service`], with the ability to
    /// handle HTTP upgrades. This requires that the IO object implements
    /// `Send`.
    pub fn serve_connection_with_upgrades<I, S, B>(
        &self,
        io: Rewind<I>,
        version: Version,
        service: S,
    ) -> UpgradeableConnection<I, S>
    where
        S: Service<Request<Incoming>, Response = Response<B>>,
        S::Future: 'static,
        S::Error: Into<Box<dyn StdError + Send + Sync>>,
        B: Body + 'static,
        B::Error: Into<Box<dyn StdError + Send + Sync>>,
        I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
        TokioExecutor: Http2ServerConnExec<S::Future, B>,
    {
        match version {
            Version::H1 => {
                let conn = self
                    .http1
                    .serve_connection(TokioIo::new(io), service)
                    .with_upgrades();
                UpgradeableConnection {
                    state: UpgradeableConnState::H1 { conn },
                }
            }
            Version::H2 => {
                let conn = self.http2.serve_connection(TokioIo::new(io), service);
                UpgradeableConnection {
                    state: UpgradeableConnState::H2 { conn },
                }
            }
        }
    }
}

#[derive(Copy, Clone)]
pub(crate) enum Version {
    H1,
    H2,
}

pub(crate) fn read_version<I>(io: I) -> ReadVersion<I>
where
    I: AsyncRead + Unpin,
{
    ReadVersion {
        io: Some(io),
        buf: [0; 24],
        filled: 0,
        version: Version::H2,
        _pin: PhantomPinned,
    }
}

pin_project! {
    pub(crate) struct ReadVersion<I> {
        io: Option<I>,
        buf: [u8; 24],
        // the amount of `buf` that's been filled
        filled: usize,
        version: Version,
        // Make this future `!Unpin` for compatibility with async trait methods.
        #[pin]
        _pin: PhantomPinned,
    }
}

impl<I> Future for ReadVersion<I>
where
    I: AsyncRead + Unpin,
{
    type Output = io::Result<(Version, Rewind<I>)>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        let mut buf = ReadBuf::new(&mut *this.buf);
        buf.set_filled(*this.filled);

        // We start as H2 and switch to H1 as soon as we don't have the preface.
        while buf.filled().len() < H2_PREFACE.len() {
            let len = buf.filled().len();
            ready!(Pin::new(this.io.as_mut().unwrap()).poll_read(cx, &mut buf))?;
            *this.filled = buf.filled().len();

            if buf.filled().len() == len
                || buf.filled()[len..] != H2_PREFACE[len..buf.filled().len()]
            {
                *this.version = Version::H1;
                break;
            }
        }

        let io = this.io.take().unwrap();
        let buf = buf.filled().to_vec();
        Poll::Ready(Ok((
            *this.version,
            Rewind::new_buffered(io, Bytes::from(buf)),
        )))
    }
}

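A hypothetical unit test (not in the diff) showing how the preface sniffing above behaves on an in-memory duplex stream:

use tokio::io::AsyncWriteExt;

#[tokio::test]
async fn sniffs_http_version() {
    // A full HTTP/2 preface is detected as H2.
    let (mut client, server) = tokio::io::duplex(64);
    client.write_all(H2_PREFACE).await.unwrap();
    let (version, _io) = read_version(server).await.unwrap();
    assert!(matches!(version, Version::H2));

    // Anything that diverges from the preface falls back to H1.
    let (mut client, server) = tokio::io::duplex(64);
    client.write_all(b"GET / HTTP/1.1\r\n").await.unwrap();
    let (version, _io) = read_version(server).await.unwrap();
    assert!(matches!(version, Version::H1));
}
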
pin_project! {
    /// Connection future.
    pub struct UpgradeableConnection<I, S>
    where
        S: HttpService<Incoming>,
    {
        #[pin]
        state: UpgradeableConnState<I, S>,
    }
}

type Http1UpgradeableConnection<I, S> =
    hyper1::server::conn::http1::UpgradeableConnection<TokioIo<Rewind<I>>, S>;
type Http2Connection<I, S> =
    hyper1::server::conn::http2::Connection<TokioIo<Rewind<I>>, S, TokioExecutor>;

pin_project! {
    #[project = UpgradeableConnStateProj]
    enum UpgradeableConnState<I, S>
    where
        S: HttpService<Incoming>,
    {
        H1 {
            #[pin]
            conn: Http1UpgradeableConnection<I, S>,
        },
        H2 {
            #[pin]
            conn: Http2Connection<I, S>,
        },
    }
}

impl<I, S, B> UpgradeableConnection<I, S>
where
    S: HttpService<Incoming, ResBody = B>,
    S::Error: Into<Box<dyn StdError + Send + Sync>>,
    I: AsyncRead + AsyncWrite + Unpin,
    B: Body + 'static,
    B::Error: Into<Box<dyn StdError + Send + Sync>>,
    TokioExecutor: Http2ServerConnExec<S::Future, B>,
{
    /// Start a graceful shutdown process for this connection.
    ///
    /// This `UpgradeableConnection` should continue to be polled until shutdown can finish.
    ///
    /// # Note
    ///
    /// This should only be called while the `Connection` future is still pending. If
    /// called after `UpgradeableConnection::poll` has resolved, this does nothing.
    pub fn graceful_shutdown(self: Pin<&mut Self>) {
        match self.project().state.project() {
            UpgradeableConnStateProj::H1 { conn } => conn.graceful_shutdown(),
            UpgradeableConnStateProj::H2 { conn } => conn.graceful_shutdown(),
        }
    }
}

impl<I, S, B> Future for UpgradeableConnection<I, S>
where
    S: Service<Request<Incoming>, Response = Response<B>>,
    S::Future: 'static,
    S::Error: Into<Box<dyn StdError + Send + Sync>>,
    B: Body + 'static,
    B::Error: Into<Box<dyn StdError + Send + Sync>>,
    I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    TokioExecutor: Http2ServerConnExec<S::Future, B>,
{
    type Output = Result<()>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.as_mut().project();
        match this.state.as_mut().project() {
            UpgradeableConnStateProj::H1 { conn } => conn.poll(cx).map_err(Into::into),
            UpgradeableConnStateProj::H2 { conn } => conn.poll(cx).map_err(Into::into),
        }
    }
}

/// Combine a buffer with an IO, rewinding reads to use the buffer.
#[derive(Debug)]
pub(crate) struct Rewind<T> {
    pre: Option<Bytes>,
    inner: T,
}

impl<T> Rewind<T> {
    pub(crate) fn new(io: T) -> Self {
        Rewind {
            pre: None,
            inner: io,
        }
    }

    pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self {
        Rewind {
            pre: Some(buf),
            inner: io,
        }
    }
}

impl<T> AsyncRead for Rewind<T>
where
    T: AsyncRead + Unpin,
{
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        if let Some(mut prefix) = self.pre.take() {
            // If there are no remaining bytes, let the bytes get dropped.
            if !prefix.is_empty() {
                let copy_len = std::cmp::min(prefix.len(), buf.remaining());
                buf.put_slice(&prefix[..copy_len]);
                // Advance past the bytes just copied out...
                prefix.advance(copy_len);
                // ...and put back what's left.
                if !prefix.is_empty() {
                    self.pre = Some(prefix);
                }

                return Poll::Ready(Ok(()));
            }
        }
        Pin::new(&mut self.inner).poll_read(cx, buf)
    }
}

impl<T> AsyncWrite for Rewind<T>
where
    T: AsyncWrite + Unpin,
{
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write(cx, buf)
    }

    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_shutdown(cx)
    }

    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }
}

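A companion sketch (also hypothetical) checking that `Rewind` above replays the sniffed bytes before reading from the inner IO:

use tokio::io::{AsyncReadExt, AsyncWriteExt};

#[tokio::test]
async fn rewind_replays_prefix() {
    let (mut client, server) = tokio::io::duplex(64);
    client.write_all(b" world").await.unwrap();

    // The buffered prefix is served first, then reads continue on the inner IO.
    let mut io = Rewind::new_buffered(server, Bytes::from_static(b"hello"));
    let mut out = vec![0u8; 11];
    io.read_exact(&mut out).await.unwrap();
    assert_eq!(&out[..], &b"hello world"[..]);
}
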
@@ -1,14 +1,19 @@
use std::sync::Arc;

+use super::json_response;
use anyhow::bail;
use bytes::Bytes;
use futures::StreamExt;
-use hyper::body::HttpBody;
-use hyper::header;
-use hyper::http::HeaderName;
-use hyper::http::HeaderValue;
-use hyper::Response;
-use hyper::StatusCode;
-use hyper::{Body, HeaderMap, Request};
+use http_body_util::BodyExt;
+use http_body_util::Full;
+use hyper1::body::Body;
+use hyper1::body::Incoming;
+use hyper1::header;
+use hyper1::http::HeaderName;
+use hyper1::http::HeaderValue;
+use hyper1::Response;
+use hyper1::StatusCode;
+use hyper1::{HeaderMap, Request};
use serde_json::json;
use serde_json::Value;
use tokio::try_join;
@@ -22,7 +27,6 @@ use tracing::error;
use tracing::info;
use url::Url;
use utils::http::error::ApiError;
-use utils::http::json::json_response;

use crate::auth::backend::ComputeUserInfo;
use crate::auth::endpoint_sni;
@@ -191,9 +195,9 @@ fn get_conn_info(
pub async fn handle(
    config: &'static ProxyConfig,
    mut ctx: RequestMonitoring,
-    request: Request<Body>,
+    request: Request<Incoming>,
    backend: Arc<PoolingBackend>,
-) -> Result<Response<Body>, ApiError> {
+) -> Result<Response<Full<Bytes>>, ApiError> {
    let result = tokio::time::timeout(
        config.http_config.request_timeout,
        handle_inner(config, &mut ctx, request, backend),
@@ -300,19 +304,18 @@ pub async fn handle(
        }
    };

-    response.headers_mut().insert(
-        "Access-Control-Allow-Origin",
-        hyper::http::HeaderValue::from_static("*"),
-    );
+    response
+        .headers_mut()
+        .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
    Ok(response)
}

async fn handle_inner(
    config: &'static ProxyConfig,
    ctx: &mut RequestMonitoring,
-    request: Request<Body>,
+    request: Request<Incoming>,
    backend: Arc<PoolingBackend>,
-) -> anyhow::Result<Response<Body>> {
+) -> anyhow::Result<Response<Full<Bytes>>> {
    let _request_gauge = NUM_CONNECTION_REQUESTS_GAUGE
        .with_label_values(&[ctx.protocol])
        .guard();
@@ -369,9 +372,12 @@ async fn handle_inner(
    }

    let fetch_and_process_request = async {
-        let body = hyper::body::to_bytes(request.into_body())
+        let body = request
+            .into_body()
+            .collect()
            .await
-            .map_err(anyhow::Error::from)?;
+            .map_err(anyhow::Error::from)?
+            .to_bytes();
        info!(length = body.len(), "request payload read");
        let payload: Payload = serde_json::from_slice(&body)?;
        Ok::<Payload, anyhow::Error>(payload) // Adjust error type accordingly
@@ -490,7 +496,7 @@ async fn handle_inner(
    let body = serde_json::to_string(&result).expect("json serialization should not fail");
    let len = body.len();
    let response = response
-        .body(Body::from(body))
+        .body(Full::new(Bytes::from(body)))
        // only fails if invalid status code or invalid header/values are given.
        // these are not user configurable so it cannot fail dynamically
        .expect("building response payload should not fail");

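The body-collection change above is the mechanical hyper 0.14 → 1.x translation; a small sketch of the new form in isolation, assuming bytes, http-body-util, tokio and anyhow:

use bytes::Bytes;
use http_body_util::{BodyExt, Full};

async fn collect_body(body: Full<Bytes>) -> anyhow::Result<Bytes> {
    // hyper 0.14: hyper::body::to_bytes(body).await
    // hyper 1.x: collect the body frames, then flatten them into contiguous bytes.
    Ok(body.collect().await?.to_bytes())
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let payload =
        collect_body(Full::new(Bytes::from_static(br#"{"query":"select 1"}"#))).await?;
    assert_eq!(&payload[..], &br#"{"query":"select 1"}"#[..]);
    Ok(())
}
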
@@ -1,283 +0,0 @@
use std::{
    pin::Pin,
    task::{Context, Poll},
    time::Duration,
};

use futures::{Future, Stream, StreamExt};
use pin_project_lite::pin_project;
use thiserror::Error;
use tokio::{
    io::{AsyncRead, AsyncWrite},
    task::JoinSet,
    time::timeout,
};

/// Default timeout for the TLS handshake.
pub const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// Trait for TLS implementation.
///
/// Implementations are provided by the rustls and native-tls features.
pub trait AsyncTls<C: AsyncRead + AsyncWrite>: Clone {
    /// The type of the TLS stream created from the underlying stream.
    type Stream: Send + 'static;
    /// Error type for completing the TLS handshake
    type Error: std::error::Error + Send + 'static;
    /// Type of the Future for the TLS stream that is accepted.
    type AcceptFuture: Future<Output = Result<Self::Stream, Self::Error>> + Send + 'static;

    /// Accept a TLS connection on an underlying stream
    fn accept(&self, stream: C) -> Self::AcceptFuture;
}

/// Asynchronously accept connections.
pub trait AsyncAccept {
    /// The type of the connection that is accepted.
    type Connection: AsyncRead + AsyncWrite;
    /// The type of error that may be returned.
    type Error;

    /// Poll to accept the next connection.
    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Connection, Self::Error>>>;

    /// Return a new `AsyncAccept` that stops accepting connections after
    /// `ender` completes.
    ///
    /// Useful for graceful shutdown.
    ///
    /// See [examples/echo.rs](https://github.com/tmccombs/tls-listener/blob/main/examples/echo.rs)
    /// for example of how to use.
    fn until<F: Future>(self, ender: F) -> Until<Self, F>
    where
        Self: Sized,
    {
        Until {
            acceptor: self,
            ender,
        }
    }
}

pin_project! {
    ///
    /// Wraps a `Stream` of connections (such as a TCP listener) so that each connection is itself
    /// encrypted using TLS.
    ///
    /// It is similar to:
    ///
    /// ```ignore
    /// tcpListener.and_then(|s| tlsAcceptor.accept(s))
    /// ```
    ///
    /// except that it has the ability to accept multiple transport-level connections
    /// simultaneously while the TLS handshake is pending for other connections.
    ///
    /// By default, if a client fails the TLS handshake, that is treated as an error, and the
    /// `TlsListener` will return an `Err`. If the `TlsListener` is passed directly to a hyper
    /// [`Server`][1], then an invalid handshake can cause the server to stop accepting connections.
    /// See [`http-stream.rs`][2] or [`http-low-level`][3] examples, for examples of how to avoid this.
    ///
    /// Note that if the maximum number of pending connections is greater than 1, the resulting
    /// [`T::Stream`][4] connections may come in a different order than the connections produced by the
    /// underlying listener.
    ///
    /// [1]: https://docs.rs/hyper/latest/hyper/server/struct.Server.html
    /// [2]: https://github.com/tmccombs/tls-listener/blob/main/examples/http-stream.rs
    /// [3]: https://github.com/tmccombs/tls-listener/blob/main/examples/http-low-level.rs
    /// [4]: AsyncTls::Stream
    ///
    #[allow(clippy::type_complexity)]
    pub struct TlsListener<A: AsyncAccept, T: AsyncTls<A::Connection>> {
        #[pin]
        listener: A,
        tls: T,
        waiting: JoinSet<Result<Result<T::Stream, T::Error>, tokio::time::error::Elapsed>>,
        timeout: Duration,
    }
}

/// Builder for `TlsListener`.
#[derive(Clone)]
pub struct Builder<T> {
    tls: T,
    handshake_timeout: Duration,
}

/// Wraps errors from either the listener or the TLS Acceptor
#[derive(Debug, Error)]
pub enum Error<LE: std::error::Error, TE: std::error::Error> {
    /// An error that arose from the listener ([AsyncAccept::Error])
    #[error("{0}")]
    ListenerError(#[source] LE),
    /// An error that occurred during the TLS accept handshake
    #[error("{0}")]
    TlsAcceptError(#[source] TE),
}

impl<A: AsyncAccept, T> TlsListener<A, T>
where
    T: AsyncTls<A::Connection>,
{
    /// Create a `TlsListener` with default options.
    pub fn new(tls: T, listener: A) -> Self {
        builder(tls).listen(listener)
    }
}

impl<A, T> TlsListener<A, T>
where
    A: AsyncAccept,
    A::Error: std::error::Error,
    T: AsyncTls<A::Connection>,
{
    /// Accept the next connection
    ///
    /// This is essentially an alias to `self.next()` with a more domain-appropriate name.
    pub async fn accept(&mut self) -> Option<<Self as Stream>::Item>
    where
        Self: Unpin,
    {
        self.next().await
    }

    /// Replaces the Tls Acceptor configuration, which will be used for new connections.
    ///
    /// This can be used to change the certificate used at runtime.
    pub fn replace_acceptor(&mut self, acceptor: T) {
        self.tls = acceptor;
    }

    /// Replaces the Tls Acceptor configuration from a pinned reference to `Self`.
    ///
    /// This is useful if your listener is `!Unpin`.
    ///
    /// This can be used to change the certificate used at runtime.
    pub fn replace_acceptor_pin(self: Pin<&mut Self>, acceptor: T) {
        *self.project().tls = acceptor;
    }
}

impl<A, T> Stream for TlsListener<A, T>
where
    A: AsyncAccept,
    A::Error: std::error::Error,
    T: AsyncTls<A::Connection>,
{
    type Item = Result<T::Stream, Error<A::Error, T::Error>>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        loop {
            match this.listener.as_mut().poll_accept(cx) {
                Poll::Pending => break,
                Poll::Ready(Some(Ok(conn))) => {
                    this.waiting
                        .spawn(timeout(*this.timeout, this.tls.accept(conn)));
                }
                Poll::Ready(Some(Err(e))) => {
                    return Poll::Ready(Some(Err(Error::ListenerError(e))));
                }
                Poll::Ready(None) => return Poll::Ready(None),
            }
        }

        loop {
            return match this.waiting.poll_join_next(cx) {
                Poll::Ready(Some(Ok(Ok(conn)))) => {
                    Poll::Ready(Some(conn.map_err(Error::TlsAcceptError)))
                }
                // The handshake timed out, try getting another connection from the queue
                Poll::Ready(Some(Ok(Err(_)))) => continue,
                // The handshake panicked
                Poll::Ready(Some(Err(e))) if e.is_panic() => {
                    std::panic::resume_unwind(e.into_panic())
                }
                // The handshake was externally aborted
                Poll::Ready(Some(Err(_))) => unreachable!("handshake tasks are never aborted"),
                _ => Poll::Pending,
            };
        }
    }
}

impl<C: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncTls<C> for tokio_rustls::TlsAcceptor {
    type Stream = tokio_rustls::server::TlsStream<C>;
    type Error = std::io::Error;
    type AcceptFuture = tokio_rustls::Accept<C>;

    fn accept(&self, conn: C) -> Self::AcceptFuture {
        tokio_rustls::TlsAcceptor::accept(self, conn)
    }
}

impl<T> Builder<T> {
    /// Set the timeout for handshakes.
    ///
    /// If a handshake takes longer than `timeout`, then it will be
    /// aborted and the underlying connection will be dropped.
    ///
    /// Defaults to `DEFAULT_HANDSHAKE_TIMEOUT`.
    pub fn handshake_timeout(&mut self, timeout: Duration) -> &mut Self {
        self.handshake_timeout = timeout;
        self
    }

    /// Create a `TlsListener` from the builder
    ///
    /// Actually build the `TlsListener`. The `listener` argument should be
    /// an implementation of the `AsyncAccept` trait that accepts new connections
    /// that the `TlsListener` will encrypt using TLS.
    pub fn listen<A: AsyncAccept>(&self, listener: A) -> TlsListener<A, T>
    where
        T: AsyncTls<A::Connection>,
    {
        TlsListener {
            listener,
            tls: self.tls.clone(),
            waiting: JoinSet::new(),
            timeout: self.handshake_timeout,
        }
    }
}

/// Create a new Builder for a TlsListener
///
/// `tls` will be used to configure the TLS sessions.
pub fn builder<T>(tls: T) -> Builder<T> {
    Builder {
        tls,
        handshake_timeout: DEFAULT_HANDSHAKE_TIMEOUT,
    }
}

pin_project! {
    /// See [`AsyncAccept::until`]
    pub struct Until<A, E> {
        #[pin]
        acceptor: A,
        #[pin]
        ender: E,
    }
}

impl<A: AsyncAccept, E: Future> AsyncAccept for Until<A, E> {
    type Connection = A::Connection;
    type Error = A::Error;

    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
        let this = self.project();

        match this.ender.poll(cx) {
            Poll::Pending => this.acceptor.poll_accept(cx),
            Poll::Ready(_) => Poll::Ready(None),
        }
    }
}

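For context on how the removed listener was driven, a hypothetical `AsyncAccept` adapter over a plain TCP listener (illustrative only; no such adapter appears in this diff):

use std::pin::Pin;
use std::task::{Context, Poll};

use tokio::net::{TcpListener, TcpStream};

struct TcpAccept(TcpListener);

impl AsyncAccept for TcpAccept {
    type Connection = TcpStream;
    type Error = std::io::Error;

    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
        // A TCP listener never terminates, so every result is wrapped in Some.
        self.get_mut()
            .0
            .poll_accept(cx)
            .map(|res| Some(res.map(|(stream, _addr)| stream)))
    }
}

// Hypothetical wiring, assuming a tokio_rustls::TlsAcceptor:
// let mut listener = builder(tls_acceptor).listen(TcpAccept(tcp));
// while let Some(conn) = listener.accept().await { /* handshake result */ }
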
@@ -1518,7 +1518,6 @@ class NeonCli(AbstractNeonCli):
        conf: Optional[Dict[str, Any]] = None,
        shard_count: Optional[int] = None,
        shard_stripe_size: Optional[int] = None,
-        placement_policy: Optional[str] = None,
        set_default: bool = False,
    ) -> Tuple[TenantId, TimelineId]:
        """
@@ -1552,9 +1551,6 @@ class NeonCli(AbstractNeonCli):
        if shard_stripe_size is not None:
            args.extend(["--shard-stripe-size", str(shard_stripe_size)])

-        if placement_policy is not None:
-            args.extend(["--placement-policy", str(placement_policy)])
-
        res = self.raw_cli(args)
        res.check_returncode()
        return tenant_id, timeline_id
@@ -2172,37 +2168,6 @@ class NeonAttachmentService(MetricsGetter):
        )
        log.info("Attachment service passed consistency check")

-    def configure_failpoints(self, config_strings: Tuple[str, str] | List[Tuple[str, str]]):
-        if isinstance(config_strings, tuple):
-            pairs = [config_strings]
-        else:
-            pairs = config_strings
-
-        log.info(f"Requesting config failpoints: {repr(pairs)}")
-
-        res = self.request(
-            "PUT",
-            f"{self.env.attachment_service_api}/debug/v1/failpoints",
-            json=[{"name": name, "actions": actions} for name, actions in pairs],
-            headers=self.headers(TokenScope.ADMIN),
-        )
-        log.info(f"Got failpoints request response code {res.status_code}")
-        res.raise_for_status()
-
-    def balance_all(self):
-        self.request(
-            "POST",
-            f"{self.env.attachment_service_api}/control/v1/balance/all",
-            headers=self.headers(TokenScope.ADMIN),
-        )
-
-    def balance_attached(self):
-        self.request(
-            "POST",
-            f"{self.env.attachment_service_api}/control/v1/balance/attached",
-            headers=self.headers(TokenScope.ADMIN),
-        )
-
    def __enter__(self) -> "NeonAttachmentService":
        return self

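The failpoints helper removed above is exercised by the shard-split tests later in this compare; it takes one (name, actions) tuple or a list of them:

# Single failpoint, as used by the shard-split tests below:
env.attachment_service.configure_failpoints(("shard-split-validation", "return(1)"))
# Several at once:
env.attachment_service.configure_failpoints([("api-500", "return(1)"), ("api-500", "off")])
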
@@ -2357,16 +2322,16 @@ class NeonPageserver(PgProtocol):
    def assert_no_errors(self):
        logfile = self.workdir / "pageserver.log"
        if not logfile.exists():
-            log.warning(f"Skipping log check on pageserver {self.id}: {logfile} does not exist")
+            log.warning(f"Skipping log check: {logfile} does not exist")
            return

        with logfile.open("r") as f:
            errors = scan_pageserver_log_for_errors(f, self.allowed_errors)

        for _lineno, error in errors:
-            log.info(f"not allowed error (pageserver {self.id}): {error.strip()}")
+            log.info(f"not allowed error: {error.strip()}")

-        assert not errors, f"Pageserver {self.id}: {errors}"
+        assert not errors

    def assert_no_metric_errors(self):
        """

@@ -1,4 +1,3 @@
-import threading
from typing import Optional

from fixtures.log_helper import log
@@ -12,10 +11,6 @@ from fixtures.neon_fixtures import (
from fixtures.pageserver.utils import wait_for_last_record_lsn, wait_for_upload
from fixtures.types import TenantId, TimelineId

-# neon_local doesn't handle creating/modifying endpoints concurrently, so we use a mutex
-# to ensure we don't do that: this enables running lots of Workloads in parallel safely.
-ENDPOINT_LOCK = threading.Lock()
-

class Workload:
    """
@@ -46,30 +41,17 @@ class Workload:

        self._endpoint: Optional[Endpoint] = None

-    def reconfigure(self):
-        """
-        Request the endpoint to reconfigure based on location reported by storage controller
-        """
-        if self._endpoint is not None:
-            with ENDPOINT_LOCK:
-                self._endpoint.reconfigure()
-
    def endpoint(self, pageserver_id: Optional[int] = None) -> Endpoint:
-        # We may be running alongside other Workloads for different tenants. Full TTID is
-        # obnoxiously long for use here, but a cut-down version is still unique enough for tests.
-        endpoint_id = f"ep-workload-{str(self.tenant_id)[0:4]}-{str(self.timeline_id)[0:4]}"
-
-        with ENDPOINT_LOCK:
-            if self._endpoint is None:
-                self._endpoint = self.env.endpoints.create(
-                    self.branch_name,
-                    tenant_id=self.tenant_id,
-                    pageserver_id=pageserver_id,
-                    endpoint_id=endpoint_id,
-                )
-                self._endpoint.start(pageserver_id=pageserver_id)
-            else:
-                self._endpoint.reconfigure(pageserver_id=pageserver_id)
+        if self._endpoint is None:
+            self._endpoint = self.env.endpoints.create(
+                self.branch_name,
+                tenant_id=self.tenant_id,
+                pageserver_id=pageserver_id,
+                endpoint_id="ep-workload",
+            )
+            self._endpoint.start(pageserver_id=pageserver_id)
+        else:
+            self._endpoint.reconfigure(pageserver_id=pageserver_id)

        connstring = self._endpoint.safe_psql(
            "SELECT setting FROM pg_settings WHERE name='neon.pageserver_connstring'"
@@ -112,7 +94,7 @@ class Workload:
        else:
            return False

-    def churn_rows(self, n, pageserver_id: Optional[int] = None, upload=True, ingest=True):
+    def churn_rows(self, n, pageserver_id: Optional[int] = None, upload=True):
        assert self.expect_rows >= n

        max_iters = 10
@@ -150,28 +132,22 @@ class Workload:
            ]
        )

-        if ingest:
-            # Wait for written data to be ingested by the pageserver
-            for tenant_shard_id, pageserver in tenant_get_shards(
-                self.env, self.tenant_id, pageserver_id
-            ):
-                last_flush_lsn = wait_for_last_flush_lsn(
-                    self.env,
-                    endpoint,
-                    self.tenant_id,
-                    self.timeline_id,
-                    pageserver_id=pageserver_id,
-                )
-                ps_http = pageserver.http_client()
-                wait_for_last_record_lsn(ps_http, tenant_shard_id, self.timeline_id, last_flush_lsn)
+        for tenant_shard_id, pageserver in tenant_get_shards(
+            self.env, self.tenant_id, pageserver_id
+        ):
+            last_flush_lsn = wait_for_last_flush_lsn(
+                self.env, endpoint, self.tenant_id, self.timeline_id, pageserver_id=pageserver_id
+            )
+            ps_http = pageserver.http_client()
+            wait_for_last_record_lsn(ps_http, tenant_shard_id, self.timeline_id, last_flush_lsn)

-            if upload:
-                # Wait for written data to be uploaded to S3 (force a checkpoint to trigger upload)
-                ps_http.timeline_checkpoint(tenant_shard_id, self.timeline_id)
-                wait_for_upload(ps_http, tenant_shard_id, self.timeline_id, last_flush_lsn)
-                log.info(f"Churn: waiting for remote LSN {last_flush_lsn}")
-            else:
-                log.info(f"Churn: not waiting for upload, disk LSN {last_flush_lsn}")
+            if upload:
+                # force a checkpoint to trigger upload
+                ps_http.timeline_checkpoint(tenant_shard_id, self.timeline_id)
+                wait_for_upload(ps_http, tenant_shard_id, self.timeline_id, last_flush_lsn)
+                log.info(f"Churn: waiting for remote LSN {last_flush_lsn}")
+            else:
+                log.info(f"Churn: not waiting for upload, disk LSN {last_flush_lsn}")

    def validate(self, pageserver_id: Optional[int] = None):
        endpoint = self.endpoint(pageserver_id)

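A sketch of how this helper is typically driven, assembled from the tests elsewhere in this compare (the test name is illustrative):

def test_workload_example(neon_env_builder):
    env = neon_env_builder.init_start()
    workload = Workload(env, env.initial_tenant, env.initial_timeline)
    workload.init()
    workload.write_rows(100)
    # Rewrite some rows, waiting for pageserver ingest and the S3 upload.
    workload.churn_rows(50, upload=True)
    workload.validate()
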
@@ -1,17 +1,13 @@
import os
from typing import Optional

import pytest
from fixtures.log_helper import log
from fixtures.neon_fixtures import (
    AttachmentServiceApiException,
    NeonEnv,
    NeonEnvBuilder,
    tenant_get_shards,
)
from fixtures.remote_storage import s3_storage
from fixtures.types import Lsn, TenantShardId, TimelineId
from fixtures.utils import wait_until
from fixtures.workload import Workload


@@ -404,245 +400,3 @@ def test_sharding_ingest(

    # Each shard may emit up to one huge layer, because initdb ingest doesn't respect checkpoint_distance.
    assert huge_layer_count <= shard_count
-
-
-class Failure:
-    pageserver_id: Optional[int]
-
-    def apply(self, env: NeonEnv):
-        raise NotImplementedError()
-
-    def clear(self, env: NeonEnv):
-        """
-        Clear the failure, in a way that should enable the system to proceed
-        to a totally clean state (all nodes online and reconciled)
-        """
-        raise NotImplementedError()
-
-    def expect_available(self):
-        raise NotImplementedError()
-
-    def can_mitigate(self):
-        """Whether Self.mitigate is available for use"""
-        return False
-
-    def mitigate(self, env: NeonEnv):
-        """
-        Mitigate the failure in a way that should allow shard split to
-        complete and service to resume, but does not guarantee to leave
-        the whole world in a clean state (e.g. an Offline node might have
-        junk LocationConfigs on it)
-        """
-        raise NotImplementedError()
-
-    def fails_forward(self):
-        """
-        If true, this failure results in a state that eventually completes the split.
-        """
-        return False
-
-
-class PageserverFailpoint(Failure):
-    def __init__(self, failpoint, pageserver_id, mitigate):
-        self.failpoint = failpoint
-        self.pageserver_id = pageserver_id
-        self._mitigate = mitigate
-
-    def apply(self, env: NeonEnv):
-        pageserver = env.get_pageserver(self.pageserver_id)
-        pageserver.allowed_errors.extend(
-            [".*failpoint.*", ".*Resetting.*after shard split failure.*"]
-        )
-        pageserver.http_client().configure_failpoints((self.failpoint, "return(1)"))
-
-    def clear(self, env: NeonEnv):
-        pageserver = env.get_pageserver(self.pageserver_id)
-        pageserver.http_client().configure_failpoints((self.failpoint, "off"))
-        if self._mitigate:
-            env.attachment_service.node_configure(self.pageserver_id, {"availability": "Active"})
-
-    def expect_available(self):
-        return True
-
-    def can_mitigate(self):
-        return self._mitigate
-
-    def mitigate(self, env):
-        env.attachment_service.node_configure(self.pageserver_id, {"availability": "Offline"})
-
-
-class StorageControllerFailpoint(Failure):
-    def __init__(self, failpoint):
-        self.failpoint = failpoint
-        self.pageserver_id = None
-
-    def apply(self, env: NeonEnv):
-        env.attachment_service.configure_failpoints((self.failpoint, "return(1)"))
-
-    def clear(self, env: NeonEnv):
-        env.attachment_service.configure_failpoints((self.failpoint, "off"))
-
-    def expect_available(self):
-        return True
-
-    def can_mitigate(self):
-        return False
-
-    def fails_forward(self):
-        # Edge case: the very last failpoint that simulates a DB connection error, where
-        # the abort path will fail-forward and result in a complete split.
-        return self.failpoint == "shard-split-post-complete"
-
-
-class NodeKill(Failure):
-    def __init__(self, pageserver_id, mitigate):
-        self.pageserver_id = pageserver_id
-        self._mitigate = mitigate
-
-    def apply(self, env: NeonEnv):
-        pageserver = env.get_pageserver(self.pageserver_id)
-        pageserver.stop(immediate=True)
-
-    def clear(self, env: NeonEnv):
-        pageserver = env.get_pageserver(self.pageserver_id)
-        pageserver.start()
-
-    def expect_available(self):
-        return False
-
-    def mitigate(self, env):
-        env.attachment_service.node_configure(self.pageserver_id, {"availability": "Offline"})
-
-
-@pytest.mark.parametrize(
-    "failure",
-    [
-        PageserverFailpoint("api-500", 1, False),
-        NodeKill(1, False),
-        PageserverFailpoint("api-500", 1, True),
-        NodeKill(1, True),
-        PageserverFailpoint("shard-split-pre-prepare", 1, False),
-        PageserverFailpoint("shard-split-post-prepare", 1, False),
-        PageserverFailpoint("shard-split-pre-hardlink", 1, False),
-        PageserverFailpoint("shard-split-post-hardlink", 1, False),
-        PageserverFailpoint("shard-split-post-child-conf", 1, False),
-        PageserverFailpoint("shard-split-lsn-wait", 1, False),
-        PageserverFailpoint("shard-split-pre-finish", 1, False),
-        StorageControllerFailpoint("shard-split-validation"),
-        StorageControllerFailpoint("shard-split-post-begin"),
-        StorageControllerFailpoint("shard-split-post-remote"),
-        StorageControllerFailpoint("shard-split-post-complete"),
-    ],
-)
-def test_sharding_split_failures(neon_env_builder: NeonEnvBuilder, failure: Failure):
-    neon_env_builder.num_pageservers = 4
-    initial_shard_count = 2
-    split_shard_count = 4
-
-    env = neon_env_builder.init_start(initial_tenant_shard_count=initial_shard_count)
-    tenant_id = env.initial_tenant
-    timeline_id = env.initial_timeline
-
-    # Make sure the node we're failing has a shard on it, otherwise the test isn't testing anything
-    assert (
-        failure.pageserver_id is None
-        or len(
-            env.get_pageserver(failure.pageserver_id)
-            .http_client()
-            .tenant_list_locations()["tenant_shards"]
-        )
-        > 0
-    )
-
-    workload = Workload(env, tenant_id, timeline_id)
-    workload.init()
-    workload.write_rows(100)
-
-    # Set one pageserver to 500 all requests, then do a split
-    # TODO: also test with a long-blocking failure: controller should time out its request and then
-    # clean up in a well defined way.
-    failure.apply(env)
-
-    with pytest.raises(AttachmentServiceApiException):
-        env.attachment_service.tenant_shard_split(tenant_id, shard_count=4)
-
-    # We expect that the overall operation will fail, but some split requests
-    # will have succeeded: the net result should be to return to a clean state, including
-    # detaching any child shards.
-    def assert_rolled_back(exclude_ps_id=None) -> None:
-        count = 0
-        for ps in env.pageservers:
-            if exclude_ps_id is not None and ps.id == exclude_ps_id:
-                continue
-
-            locations = ps.http_client().tenant_list_locations()["tenant_shards"]
-            for loc in locations:
-                tenant_shard_id = TenantShardId.parse(loc[0])
-                log.info(f"Shard {tenant_shard_id} seen on node {ps.id}")
-                assert tenant_shard_id.shard_count == initial_shard_count
-                count += 1
-        assert count == initial_shard_count
-
-    def assert_split_done(exclude_ps_id=None) -> None:
-        count = 0
-        for ps in env.pageservers:
-            if exclude_ps_id is not None and ps.id == exclude_ps_id:
-                continue
-
-            locations = ps.http_client().tenant_list_locations()["tenant_shards"]
-            for loc in locations:
-                tenant_shard_id = TenantShardId.parse(loc[0])
-                log.info(f"Shard {tenant_shard_id} seen on node {ps.id}")
-                assert tenant_shard_id.shard_count == split_shard_count
-                count += 1
-        assert count == split_shard_count
-
-    def finish_split():
-        # Having failed+rolled back, we should be able to split again
-        # No failures this time; it will succeed
-        env.attachment_service.tenant_shard_split(tenant_id, shard_count=split_shard_count)
-
-        workload.churn_rows(10)
-        workload.validate()
-
-    if failure.expect_available():
-        # Even though the split failed partway through, this should not have interrupted
-        # clients. Disable waiting for pageservers in the workload helper, because our
-        # failpoints may prevent API access.
-        # This only applies for failure modes that leave pageserver page_service API available.
-        workload.churn_rows(10, upload=False, ingest=False)
-        workload.validate()
-
-    if failure.fails_forward():
-        # A failure type which results in eventual completion of the split
-        wait_until(30, 1, assert_split_done)
-    elif failure.can_mitigate():
-        # Mitigation phase: we expect to be able to proceed with a successful shard split
-        failure.mitigate(env)
-
-        # The split should appear to be rolled back from the point of view of all pageservers
-        # apart from the one that is offline
-        wait_until(30, 1, lambda: assert_rolled_back(exclude_ps_id=failure.pageserver_id))
-
-        finish_split()
-        wait_until(30, 1, lambda: assert_split_done(exclude_ps_id=failure.pageserver_id))
-
-        # Having cleared the failure, everything should converge to a pristine state
-        failure.clear(env)
-        wait_until(30, 1, assert_split_done)
-    else:
-        # Once we restore the faulty pageserver's API to good health, rollback should
-        # eventually complete.
-        failure.clear(env)
-
-        wait_until(30, 1, assert_rolled_back)
-
-        # Having rolled back, the tenant should be working
-        workload.churn_rows(10)
-        workload.validate()
-
-        # Splitting again should work, since we cleared the failure
-        finish_split()
-        assert_split_done()
-
-    env.attachment_service.consistency_check()

@@ -1,5 +1,3 @@
import concurrent.futures
import random
import time
from collections import defaultdict
from datetime import datetime, timezone
@@ -25,9 +23,8 @@ from fixtures.pageserver.utils import (
)
from fixtures.pg_version import PgVersion
from fixtures.remote_storage import RemoteStorageKind, s3_storage
-from fixtures.types import TenantId, TenantShardId, TimelineId
+from fixtures.types import TenantId, TimelineId
from fixtures.utils import run_pg_bench_small, wait_until
from fixtures.workload import Workload
from mypy_boto3_s3.type_defs import (
    ObjectTypeDef,
)
@@ -773,186 +770,3 @@ def test_sharding_service_tenant_conf(neon_env_builder: NeonEnvBuilder):
    assert "pitr_interval" not in readback_ps.tenant_specific_overrides

    env.attachment_service.consistency_check()
-
-
-def test_storcon_rolling_failures(
-    neon_env_builder: NeonEnvBuilder, httpserver: HTTPServer, httpserver_listen_address
-):
-    neon_env_builder.num_pageservers = 8
-
-    (host, port) = httpserver_listen_address
-    neon_env_builder.control_plane_compute_hook_api = f"http://{host}:{port}/notify-attach"
-
-    workloads: dict[TenantId, Workload] = {}
-
-    # Do neon_local endpoint reconfiguration in the background so that we can
-    # accept a healthy rate of calls into notify-attach.
-    reconfigure_threads = concurrent.futures.ThreadPoolExecutor(max_workers=1)
-
-    def handler(request: Request):
-        """
-        Although storage controller can use neon_local directly, this causes problems when
-        the test is also concurrently modifying endpoints. Instead, configure storage controller
-        to send notifications up to this test code, which will route all endpoint updates
-        through Workload, which has a mutex to make it all safe.
-        """
-        assert request.json is not None
-        body: dict[str, Any] = request.json
-        log.info(f"notify-attach request: {body}")
-
-        try:
-            workload = workloads[TenantId(body["tenant_id"])]
-        except KeyError:
-            pass
-        else:
-            # This causes the endpoint to query storage controller for its location, which
-            # is redundant since we already have it here, but this avoids extending the
-            # neon_local CLI to take full lists of locations
-            reconfigure_threads.submit(lambda workload=workload: workload.reconfigure())  # type: ignore[no-any-return]
-
-        return Response(status=200)
-
-    httpserver.expect_request("/notify-attach", method="PUT").respond_with_handler(handler)
-
-    env = neon_env_builder.init_start()
-
-    for ps in env.pageservers:
-        # We will do unclean detaches
-        ps.allowed_errors.append(".*Dropped remote consistent LSN updates.*")
-
-    n_tenants = 32
-    tenants = [(env.initial_tenant, env.initial_timeline)]
-    for i in range(0, n_tenants - 1):
-        tenant_id = TenantId.generate()
-        timeline_id = TimelineId.generate()
-        shard_count = [1, 2, 4][i % 3]
-        env.neon_cli.create_tenant(
-            tenant_id, timeline_id, shard_count=shard_count, placement_policy='{"Double":1}'
-        )
-        tenants.append((tenant_id, timeline_id))
-
-    # Background pain:
-    # - TODO: some fraction of pageserver API requests hang
-    #   (this requires implementing wrap of location_conf calls with proper timeline/cancel)
-    # - TODO: continuous tenant/timeline creation/destruction over a different ID range than
-    #   the ones we're using for availability checks.
-
-    rng = random.Random(0xDEADBEEF)
-
-    for tenant_id, timeline_id in tenants:
-        workload = Workload(env, tenant_id, timeline_id)
-        workloads[tenant_id] = workload
-
-    def node_evacuated(node_id: int):
-        counts = get_node_shard_counts(env, [t[0] for t in tenants])
-        assert counts[node_id] == 0
-
-    def attachments_active():
-        for tid, _tlid in tenants:
-            for shard in env.attachment_service.locate(tid):
-                psid = shard["node_id"]
-                tsid = TenantShardId.parse(shard["shard_id"])
-                status = env.get_pageserver(psid).http_client().tenant_status(tenant_id=tsid)
-                assert status["state"]["slug"] == "Active"
-                log.info(f"Shard {tsid} active on node {psid}")
-
-    failpoints = ("api-503", "5%1000*return(1)")
-    failpoints_str = f"{failpoints[0]}={failpoints[1]}"
-    for ps in env.pageservers:
-        ps.http_client().configure_failpoints(failpoints)
-
-    def for_all_workloads(callback, timeout=60):
-        futs = []
-        with concurrent.futures.ThreadPoolExecutor() as pool:
-            for _tenant_id, workload in workloads.items():
-                futs.append(pool.submit(callback, workload))
-
-        for f in futs:
-            f.result(timeout=timeout)
-
-    def clean_fail_restore():
-        """
-        Clean shutdown of a node: mark it offline in storage controller, wait for new attachment
-        locations to activate, then SIGTERM it.
-        - Endpoints should not fail any queries
-        - New attach locations should activate within bounded time.
-        """
-        victim = rng.choice(env.pageservers)
-        env.attachment_service.node_configure(victim.id, {"availability": "Offline"})
-
-        wait_until(10, 1, lambda node_id=victim.id: node_evacuated(node_id))  # type: ignore[misc]
-        wait_until(10, 1, attachments_active)
-
-        victim.stop(immediate=False)
-
-        traffic()
-
-        victim.start(extra_env_vars={"FAILPOINTS": failpoints_str})
-
-        # Revert shards to attach at their original locations
-        env.attachment_service.balance_attached()
-        wait_until(10, 1, attachments_active)
-
-    def hard_fail_restore():
-        """
-        Simulate an unexpected death of a pageserver node
-        """
-        victim = rng.choice(env.pageservers)
-        victim.stop(immediate=True)
-        # TODO: once we implement heartbeats detecting node failures, remove this
-        # explicit marking offline and rely on storage controller to detect it itself.
-        env.attachment_service.node_configure(victim.id, {"availability": "Offline"})
-        wait_until(10, 1, lambda node_id=victim.id: node_evacuated(node_id))  # type: ignore[misc]
-        wait_until(10, 1, attachments_active)
-        traffic()
-        victim.start(extra_env_vars={"FAILPOINTS": failpoints_str})
-        env.attachment_service.balance_attached()
-        wait_until(10, 1, attachments_active)
-
-    def traffic():
-        """
-        Check that all tenants are working for postgres clients
-        """
-
-        def exercise_one(workload):
-            workload.churn_rows(100)
-            workload.validate()
-
-        for_all_workloads(exercise_one)
-
-    def init_one(workload):
-        workload.init()
-        workload.write_rows(100)
-
-    for_all_workloads(init_one, timeout=60)
-
-    for i in range(0, 20):
-        mode = rng.choice([0, 1, 2])
-        log.info(f"Iteration {i}, mode {mode}")
-        if mode == 0:
-            # Traffic interval: sometimes, instead of a failure, just let the clients
-            # write a load of data. This avoids chaos tests ending up with unrealistically
-            # small quantities of data in flight.
-            traffic()
-        elif mode == 1:
-            clean_fail_restore()
-        elif mode == 2:
-            hard_fail_restore()
-
-        # Fail and restart: hard-kill one node. Notify the storage controller that it is offline.
-        # Success criteria:
-        # - New attach locations should activate within bounded time
-        # - TODO: once we do heartbeating, we should not have to explicitly mark the node offline
-
-        # TODO: fail and remove: fail a node, and remove it from the cluster.
-        # Success criteria:
-        # - Endpoints should not fail any queries
-        # - New attach locations should activate within bounded time
-        # - New secondary locations should fill up with data within bounded time
-
-        # TODO: somehow need to wait for reconciles to complete before doing consistency check
-        # (or make the check wait).
-
-        # Do consistency check on every iteration, not just at the end: this makes it more obvious
-        # which change caused an issue.
-        env.attachment_service.consistency_check()

@@ -64,7 +64,7 @@ rustls = { version = "0.21", features = ["dangerous_configuration"] }
scopeguard = { version = "1" }
serde = { version = "1", features = ["alloc", "derive"] }
serde_json = { version = "1", features = ["raw_value"] }
-smallvec = { version = "1", default-features = false, features = ["write"] }
+smallvec = { version = "1", default-features = false, features = ["const_new", "write"] }
subtle = { version = "2" }
time = { version = "0.3", features = ["local-offset", "macros", "serde-well-known"] }
tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] }
@@ -76,7 +76,6 @@ tonic = { version = "0.9", features = ["tls-roots"] }
tower = { version = "0.4", default-features = false, features = ["balance", "buffer", "limit", "log", "timeout", "util"] }
tracing = { version = "0.1", features = ["log"] }
tracing-core = { version = "0.1" }
-tungstenite = { version = "0.20" }
url = { version = "2", features = ["serde"] }
uuid = { version = "1", features = ["serde", "v4", "v7"] }
zeroize = { version = "1", features = ["derive"] }