Mirror of https://github.com/neondatabase/neon.git (synced 2026-02-04 19:20:36 +00:00)

Compare commits: remove-sel ... proxy-conf (3 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 462713802d | |
| | 0b2e0d8af5 | |
| | 7ac2179aeb | |
@@ -2,7 +2,6 @@
 # This is only present for local builds, as it will be overridden
 # by the RUSTDOCFLAGS env var in CI.
 rustdocflags = ["-Arustdoc::private_intra_doc_links"]
-rustflags = ["--cfg", "tokio_unstable"]

 [alias]
 build_testing = ["build", "--features", "testing"]
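Note: `--cfg tokio_unstable` is the compile-time flag that gates Tokio's unstable APIs, most relevantly `tokio::task::Builder` for named tasks; the rest of this diff removes exactly those call sites. A minimal illustrative sketch of the two spellings (not taken from this repository):

    // With RUSTFLAGS="--cfg tokio_unstable", named spawns are available:
    #[cfg(tokio_unstable)]
    fn spawn_named() {
        tokio::task::Builder::new()
            .name("my task")
            .spawn(async { /* work */ })
            .unwrap(); // Builder::spawn is fallible, unlike tokio::spawn
    }

    // On stable Tokio, the equivalent unnamed spawn:
    fn spawn_plain() {
        tokio::spawn(async { /* work */ });
    }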
.github/workflows/build_and_test.yml (vendored): 1 line changed
@@ -214,7 +214,6 @@ jobs:
   BUILD_TYPE: ${{ matrix.build_type }}
   GIT_VERSION: ${{ github.event.pull_request.head.sha || github.sha }}
   BUILD_TAG: ${{ needs.tag.outputs.build-tag }}
-  RUSTFLAGS: "--cfg=tokio_unstable"

 steps:
   - name: Fix git ownership
Cargo.lock (generated): 142 lines changed
@@ -1240,43 +1240,6 @@ dependencies = [
  "crossbeam-utils",
 ]

-[[package]]
-name = "console-api"
-version = "0.6.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fd326812b3fd01da5bb1af7d340d0d555fd3d4b641e7f1dfcf5962a902952787"
-dependencies = [
- "futures-core",
- "prost 0.12.4",
- "prost-types 0.12.4",
- "tonic 0.10.2",
- "tracing-core",
-]
-
-[[package]]
-name = "console-subscriber"
-version = "0.2.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7481d4c57092cd1c19dd541b92bdce883de840df30aa5d03fd48a3935c01842e"
-dependencies = [
- "console-api",
- "crossbeam-channel",
- "crossbeam-utils",
- "futures-task",
- "hdrhistogram",
- "humantime",
- "prost-types 0.12.4",
- "serde",
- "serde_json",
- "thread_local",
- "tokio",
- "tokio-stream",
- "tonic 0.10.2",
- "tracing",
- "tracing-core",
- "tracing-subscriber",
-]
-
 [[package]]
 name = "const-oid"
 version = "0.9.5"
@@ -3445,11 +3408,11 @@ dependencies = [
  "opentelemetry-semantic-conventions",
  "opentelemetry_api",
  "opentelemetry_sdk",
- "prost 0.11.9",
+ "prost",
  "reqwest",
  "thiserror",
  "tokio",
- "tonic 0.9.2",
+ "tonic",
 ]

 [[package]]
@@ -3460,8 +3423,8 @@ checksum = "b1e3f814aa9f8c905d0ee4bde026afd3b2577a97c10e1699912e3e44f0c4cbeb"
 dependencies = [
  "opentelemetry_api",
  "opentelemetry_sdk",
- "prost 0.11.9",
- "tonic 0.9.2",
+ "prost",
+ "tonic",
 ]

 [[package]]
@@ -4046,6 +4009,17 @@ dependencies = [
  "tokio-postgres",
 ]

+[[package]]
+name = "postgres-native-tls"
+version = "0.5.0"
+source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#20031d7a9ee1addeae6e0968e3899ae6bf01cee2"
+dependencies = [
+ "native-tls",
+ "tokio",
+ "tokio-native-tls",
+ "tokio-postgres",
+]
+
 [[package]]
 name = "postgres-protocol"
 version = "0.6.4"
@@ -4260,17 +4234,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd"
 dependencies = [
  "bytes",
- "prost-derive 0.11.9",
-]
-
-[[package]]
-name = "prost"
-version = "0.12.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d0f5d036824e4761737860779c906171497f6d55681139d8312388f8fe398922"
-dependencies = [
- "bytes",
- "prost-derive 0.12.4",
+ "prost-derive",
 ]

 [[package]]
@@ -4287,8 +4251,8 @@ dependencies = [
  "multimap",
  "petgraph",
  "prettyplease 0.1.25",
- "prost 0.11.9",
- "prost-types 0.11.9",
+ "prost",
+ "prost-types",
  "regex",
  "syn 1.0.109",
  "tempfile",
@@ -4308,35 +4272,13 @@ dependencies = [
  "syn 1.0.109",
 ]

-[[package]]
-name = "prost-derive"
-version = "0.12.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "19de2de2a00075bf566bee3bd4db014b11587e84184d3f7a791bc17f1a8e9e48"
-dependencies = [
- "anyhow",
- "itertools",
- "proc-macro2",
- "quote",
- "syn 2.0.52",
-]
-
 [[package]]
 name = "prost-types"
 version = "0.11.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "213622a1460818959ac1181aaeb2dc9c7f63df720db7d788b3e24eacd1983e13"
 dependencies = [
- "prost 0.11.9",
-]
-
-[[package]]
-name = "prost-types"
-version = "0.12.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3235c33eb02c1f1e212abdbe34c78b264b038fb58ca612664343271e36e55ffe"
-dependencies = [
- "prost 0.12.4",
+ "prost",
 ]

 [[package]]
@@ -4358,7 +4300,6 @@ dependencies = [
  "camino-tempfile",
  "chrono",
  "clap",
- "console-subscriber",
  "consumption_metrics",
  "dashmap",
  "env_logger",
@@ -4373,6 +4314,7 @@ dependencies = [
  "http 1.1.0",
  "http-body-util",
  "humantime",
  "humantime-serde",
+ "hyper 0.14.26",
  "hyper 1.2.0",
  "hyper-tungstenite",
@@ -4383,6 +4325,7 @@ dependencies = [
  "md5",
  "measured",
  "metrics",
+ "native-tls",
  "once_cell",
  "opentelemetry",
  "parking_lot 0.12.1",
@@ -4390,6 +4333,7 @@ dependencies = [
  "parquet_derive",
  "pbkdf2",
  "pin-project-lite",
+ "postgres-native-tls",
  "postgres-protocol",
  "postgres_backend",
  "pq_proto",
@@ -4408,11 +4352,11 @@ dependencies = [
  "rstest",
  "rustc-hash",
  "rustls 0.22.2",
- "rustls-native-certs 0.7.0",
  "rustls-pemfile 2.1.1",
  "scopeguard",
  "serde",
  "serde_json",
  "serde_with",
  "sha2",
  "smallvec",
  "smol_str",
@@ -5775,10 +5719,10 @@ dependencies = [
  "metrics",
  "once_cell",
  "parking_lot 0.12.1",
- "prost 0.11.9",
+ "prost",
  "tokio",
  "tokio-stream",
- "tonic 0.9.2",
+ "tonic",
  "tonic-build",
  "tracing",
  "utils",
@@ -6170,7 +6114,6 @@ dependencies = [
  "signal-hook-registry",
  "socket2 0.5.5",
  "tokio-macros",
- "tracing",
  "windows-sys 0.48.0",
 ]

@@ -6401,7 +6344,7 @@ dependencies = [
  "hyper-timeout",
  "percent-encoding",
  "pin-project",
- "prost 0.11.9",
+ "prost",
  "rustls-native-certs 0.6.2",
  "rustls-pemfile 1.0.2",
  "tokio",
@@ -6413,33 +6356,6 @@ dependencies = [
  "tracing",
 ]

-[[package]]
-name = "tonic"
-version = "0.10.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d560933a0de61cf715926b9cac824d4c883c2c43142f787595e48280c40a1d0e"
-dependencies = [
- "async-stream",
- "async-trait",
- "axum",
- "base64 0.21.1",
- "bytes",
- "h2 0.3.26",
- "http 0.2.9",
- "http-body 0.4.5",
- "hyper 0.14.26",
- "hyper-timeout",
- "percent-encoding",
- "pin-project",
- "prost 0.12.4",
- "tokio",
- "tokio-stream",
- "tower",
- "tower-layer",
- "tower-service",
- "tracing",
-]
-
 [[package]]
 name = "tonic-build"
 version = "0.9.2"
@@ -7420,7 +7336,6 @@ dependencies = [
  "futures-util",
  "getrandom 0.2.11",
  "hashbrown 0.14.0",
- "hdrhistogram",
  "hex",
  "hmac",
  "hyper 0.14.26",
@@ -7435,7 +7350,7 @@ dependencies = [
  "num-traits",
  "once_cell",
  "parquet",
- "prost 0.11.9",
+ "prost",
  "rand 0.8.5",
  "regex",
  "regex-automata 0.4.3",
@@ -7454,11 +7369,10 @@ dependencies = [
  "time-macros",
  "tokio",
  "tokio-rustls 0.24.0",
  "tokio-stream",
  "tokio-util",
  "toml_datetime",
  "toml_edit",
- "tonic 0.9.2",
+ "tonic",
  "tower",
  "tracing",
  "tracing-core",

@@ -47,7 +47,7 @@ COPY --chown=nonroot . .
 # Show build caching stats to check if it was used in the end.
 # Has to be the part of the same RUN since cachepot daemon is killed in the end of this RUN, losing the compilation stats.
 RUN set -e \
-    && RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment --cfg=tokio_unstable" cargo build \
+    && RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment" cargo build \
     --bin pg_sni_router \
     --bin pageserver \
     --bin pagectl \

@@ -35,6 +35,7 @@ hmac.workspace = true
 hostname.workspace = true
 http.workspace = true
 humantime.workspace = true
 humantime-serde.workspace = true
 hyper-tungstenite.workspace = true
+hyper.workspace = true
 hyper1 = { package = "hyper", version = "1.2", features = ["server"] }
@@ -67,10 +68,10 @@ routerify.workspace = true
 rustc-hash.workspace = true
 rustls-pemfile.workspace = true
 rustls.workspace = true
-rustls-native-certs = "0.7.0"
 scopeguard.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 serde_with.workspace = true
 sha2 = { workspace = true, features = ["asm"] }
 smol_str.workspace = true
 smallvec.workspace = true
@@ -82,10 +83,9 @@ thiserror.workspace = true
 tikv-jemallocator.workspace = true
 tikv-jemalloc-ctl = { workspace = true, features = ["use_std"] }
 tokio-postgres.workspace = true
-tokio-postgres-rustls.workspace = true
 tokio-rustls.workspace = true
 tokio-util.workspace = true
-tokio = { workspace = true, features = ["signal", "tracing"] }
+tokio = { workspace = true, features = ["signal"] }
 tracing-opentelemetry.workspace = true
 tracing-subscriber.workspace = true
 tracing-utils.workspace = true
@@ -96,11 +96,11 @@ utils.workspace = true
 uuid.workspace = true
 webpki-roots.workspace = true
 x509-parser.workspace = true
+native-tls.workspace = true
+postgres-native-tls.workspace = true
 postgres-protocol.workspace = true
 redis.workspace = true

-console-subscriber = "0.2.0"
-
 workspace_hack.workspace = true

 [dev-dependencies]
@@ -108,5 +108,6 @@ camino-tempfile.workspace = true
 fallible-iterator.workspace = true
 rcgen.workspace = true
 rstest.workspace = true
+tokio-postgres-rustls.workspace = true
 walkdir.workspace = true
 rand_distr = "0.4"

@@ -121,5 +121,6 @@ pub(super) async fn authenticate(
     Ok(NodeInfo {
         config,
         aux: db_info.aux,
+        allow_self_signed_compute: false, // caller may override
     })
 }

@@ -35,8 +35,6 @@ use proxy::config::{self, ProxyConfig};
 use proxy::serverless;
 use std::net::SocketAddr;
 use std::pin::pin;
-use std::sync::atomic::AtomicUsize;
-use std::sync::atomic::Ordering;
 use std::sync::Arc;
 use tokio::net::TcpListener;
 use tokio::sync::Mutex;
@@ -122,6 +120,9 @@ struct ProxyCliArgs {
     /// lock for `wake_compute` api method. example: "shards=32,permits=4,epoch=10m,timeout=1s". (use `permits=0` to disable).
     #[clap(long, default_value = config::WakeComputeLockOptions::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK)]
     wake_compute_lock: String,
+    /// Allow self-signed certificates for compute nodes (for testing)
+    #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
+    allow_self_signed_compute: bool,
     #[clap(flatten)]
     sql_over_http: SqlOverHttpArgs,
     /// timeout for scram authentication protocol
@@ -235,21 +236,8 @@ struct SqlOverHttpArgs {
     sql_over_http_pool_shards: usize,
 }

-fn main() -> anyhow::Result<()> {
-    let rt = tokio::runtime::Builder::new_multi_thread()
-        .enable_all()
-        .thread_name_fn(|| {
-            static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
-            let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
-            format!("worker-{}", id)
-        })
-        .build()
-        .unwrap();
-
-    rt.block_on(main2())
-}
-
-async fn main2() -> anyhow::Result<()> {
+#[tokio::main]
+async fn main() -> anyhow::Result<()> {
     let _logging_guard = proxy::logging::init().await?;
     let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
     let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
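The hand-rolled runtime setup above existed only to give worker threads numbered names. Roughly, `#[tokio::main]` expands to the same builder minus the naming — a sketch of the equivalence:

    fn main() -> anyhow::Result<()> {
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap()
            .block_on(async { /* body of the async fn main */ Ok(()) })
    }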
@@ -364,16 +352,12 @@
     // client facing tasks. these will exit on error or on cancellation
     // cancellation returns Ok(())
     let mut client_tasks = JoinSet::new();
-    client_tasks
-        .build_task()
-        .name("tcp main")
-        .spawn(proxy::proxy::task_main(
-            config,
-            proxy_listener,
-            cancellation_token.clone(),
-            cancellation_handler.clone(),
-        ))
-        .unwrap();
+    client_tasks.spawn(proxy::proxy::task_main(
+        config,
+        proxy_listener,
+        cancellation_token.clone(),
+        cancellation_handler.clone(),
+    ));

     // TODO: rename the argument to something like serverless.
     // It now covers more than just websockets, it also covers SQL over HTTP.
@@ -382,98 +366,58 @@
         info!("Starting wss on {serverless_address}");
         let serverless_listener = TcpListener::bind(serverless_address).await?;

-        client_tasks
-            .build_task()
-            .name("serverless main")
-            .spawn(serverless::task_main(
-                config,
-                serverless_listener,
-                cancellation_token.clone(),
-                cancellation_handler.clone(),
-            ))
-            .unwrap();
+        client_tasks.spawn(serverless::task_main(
+            config,
+            serverless_listener,
+            cancellation_token.clone(),
+            cancellation_handler.clone(),
+        ));
     }

-    client_tasks
-        .build_task()
-        .name("parquet worker")
-        .spawn(proxy::context::parquet::worker(
-            cancellation_token.clone(),
-            args.parquet_upload,
-        ))
-        .unwrap();
+    client_tasks.spawn(proxy::context::parquet::worker(
+        cancellation_token.clone(),
+        args.parquet_upload,
+    ));

     // maintenance tasks. these never return unless there's an error
     let mut maintenance_tasks = JoinSet::new();
-    maintenance_tasks
-        .build_task()
-        .name("signal handler")
-        .spawn(proxy::handle_signals(cancellation_token.clone()))
-        .unwrap();
-    maintenance_tasks
-        .build_task()
-        .name("health server")
-        .spawn(http::health_server::task_main(
-            http_listener,
-            AppMetrics {
-                jemalloc,
-                neon_metrics,
-                proxy: proxy::metrics::Metrics::get(),
-            },
-        ))
-        .unwrap();
-    maintenance_tasks
-        .build_task()
-        .name("mangement main")
-        .spawn(console::mgmt::task_main(mgmt_listener))
-        .unwrap();
+    maintenance_tasks.spawn(proxy::handle_signals(cancellation_token.clone()));
+    maintenance_tasks.spawn(http::health_server::task_main(
+        http_listener,
+        AppMetrics {
+            jemalloc,
+            neon_metrics,
+            proxy: proxy::metrics::Metrics::get(),
+        },
+    ));
+    maintenance_tasks.spawn(console::mgmt::task_main(mgmt_listener));

     if let Some(metrics_config) = &config.metric_collection {
         // TODO: Add gc regardles of the metric collection being enabled.
-        maintenance_tasks
-            .build_task()
-            .name("")
-            .spawn(usage_metrics::task_main(metrics_config))
-            .unwrap();
-        client_tasks
-            .build_task()
-            .name("")
-            .spawn(usage_metrics::task_backup(
-                &metrics_config.backup_metric_collection_config,
-                cancellation_token,
-            ))
-            .unwrap();
+        maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
+        client_tasks.spawn(usage_metrics::task_backup(
+            &metrics_config.backup_metric_collection_config,
+            cancellation_token,
+        ));
     }

     if let auth::BackendType::Console(api, _) = &config.auth_backend {
         if let proxy::console::provider::ConsoleBackend::Console(api) = &**api {
             if let Some(redis_notifications_client) = redis_notifications_client {
                 let cache = api.caches.project_info.clone();
-                maintenance_tasks
-                    .build_task()
-                    .name("redis notifications")
-                    .spawn(notifications::task_main(
-                        redis_notifications_client,
-                        cache.clone(),
-                        cancel_map.clone(),
-                        args.region.clone(),
-                    ))
-                    .unwrap();
-                maintenance_tasks
-                    .build_task()
-                    .name("proj info cache gc")
-                    .spawn(async move { cache.clone().gc_worker().await })
-                    .unwrap();
+                maintenance_tasks.spawn(notifications::task_main(
+                    redis_notifications_client,
+                    cache.clone(),
+                    cancel_map.clone(),
+                    args.region.clone(),
+                ));
+                maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
             }
             if let Some(regional_redis_client) = regional_redis_client {
                 let cache = api.caches.endpoints_cache.clone();
                 let con = regional_redis_client;
                 let span = tracing::info_span!("endpoints_cache");
-                maintenance_tasks
-                    .build_task()
-                    .name("redis endpoints cache read")
-                    .spawn(async move { cache.do_read(con).await }.instrument(span))
-                    .unwrap();
+                maintenance_tasks.spawn(async move { cache.do_read(con).await }.instrument(span));
             }
         }
     }
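Both task sets here are `tokio::task::JoinSet`s; the change swaps the unstable named-task entry point for the stable anonymous one. A minimal sketch of the two spellings (the task body is hypothetical):

    async fn demo() {
        let mut tasks = tokio::task::JoinSet::new();
        // Stable API: spawn directly; the task has no name.
        tasks.spawn(async { 42 });
        // Unstable API (needs --cfg tokio_unstable), as removed above:
        // tasks.build_task().name("answer").spawn(async { 42 }).unwrap();
        while let Some(res) = tasks.join_next().await {
            assert_eq!(res.unwrap(), 42);
        }
    }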
@@ -514,6 +458,9 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
         _ => bail!("either both or neither tls-key and tls-cert must be specified"),
     };

+    if args.allow_self_signed_compute {
+        warn!("allowing self-signed compute certificates");
+    }
     let backup_metric_collection_config = config::MetricBackupCollectionConfig {
         interval: args.metric_backup_collection_interval,
         remote_storage_config: remote_storage_from_toml(
@@ -578,10 +525,7 @@
         )
         .unwrap(),
     ));
-    tokio::task::Builder::new()
-        .name("wake compute lock gc")
-        .spawn(locks.garbage_collect_worker())
-        .unwrap();
+    tokio::spawn(locks.garbage_collect_worker());

     let url = args.auth_endpoint.parse()?;
     let endpoint = http::Endpoint::new(url, http::new_client());
@@ -631,6 +575,7 @@
         tls_config,
         auth_backend,
         metric_collection,
+        allow_self_signed_compute: args.allow_self_signed_compute,
         http_config,
         authentication_config,
         require_client_ip: args.require_client_ip,

@@ -10,12 +10,7 @@ use crate::{
 use futures::{FutureExt, TryFutureExt};
 use itertools::Itertools;
 use pq_proto::StartupMessageParams;
-use std::{
-    io,
-    net::SocketAddr,
-    sync::{Arc, OnceLock},
-    time::Duration,
-};
+use std::{io, net::SocketAddr, time::Duration};
 use thiserror::Error;
 use tokio::net::TcpStream;
 use tokio_postgres::tls::MakeTlsConnect;
@@ -33,6 +28,9 @@ pub enum ConnectionError {
     #[error("{COULD_NOT_CONNECT}: {0}")]
     CouldNotConnect(#[from] io::Error),

+    #[error("{COULD_NOT_CONNECT}: {0}")]
+    TlsError(#[from] native_tls::Error),
+
     #[error("{COULD_NOT_CONNECT}: {0}")]
     WakeComputeError(#[from] WakeComputeError),
 }
@@ -71,6 +69,7 @@ impl ReportableError for ConnectionError {
             }
             ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute,
             ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute,
+            ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
             ConnectionError::WakeComputeError(e) => e.get_error_kind(),
         }
     }
@@ -240,7 +239,7 @@ pub struct PostgresConnection {
     /// Socket connected to a compute node.
     pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
         tokio::net::TcpStream,
-        tokio_postgres_rustls::RustlsStream<tokio::net::TcpStream>,
+        postgres_native_tls::TlsStream<tokio::net::TcpStream>,
     >,
     /// PostgreSQL connection parameters.
     pub params: std::collections::HashMap<String, String>,
@@ -252,39 +251,22 @@
     _guage: NumDbConnectionsGuard<'static>,
 }

-static ROOT_STORE: OnceLock<Arc<rustls::RootCertStore>> = OnceLock::new();
-
 impl ConnCfg {
     /// Connect to a corresponding compute node.
     pub async fn connect(
         &self,
         ctx: &mut RequestMonitoring,
+        allow_self_signed_compute: bool,
         aux: MetricsAuxInfo,
         timeout: Duration,
     ) -> Result<PostgresConnection, ConnectionError> {
         let (socket_addr, stream, host) = self.connect_raw(timeout).await?;

-        let root_store = ROOT_STORE.get_or_init(|| {
-            let mut roots = rustls::RootCertStore::empty();
-
-            let certs = match rustls_native_certs::load_native_certs() {
-                Ok(certs) => certs,
-                Err(e) => {
-                    error!("could not load native ssl certs: {e:?}");
-                    return Arc::new(roots);
-                }
-            };
-
-            let (added, ignored) = roots.add_parsable_certificates(certs);
-            info!(added, ignored, "loaded native ssl certifications");
-
-            Arc::new(roots)
-        });
-
-        let client_config = rustls::ClientConfig::builder()
-            .with_root_certificates(root_store.clone())
-            .with_no_client_auth();
-        let mut mk_tls = tokio_postgres_rustls::MakeRustlsConnect::new(client_config);
+        let tls_connector = native_tls::TlsConnector::builder()
+            .danger_accept_invalid_certs(allow_self_signed_compute)
+            .build()
+            .unwrap();
+        let mut mk_tls = postgres_native_tls::MakeTlsConnector::new(tls_connector);
         let tls = MakeTlsConnect::<tokio::net::TcpStream>::make_tls_connect(&mut mk_tls, host)?;

         // connect_raw() will not use TLS if sslmode is "disable"
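For reference, the shape of the new client-side TLS path: `postgres-native-tls` adapts a `native_tls::TlsConnector` into the `MakeTlsConnect` interface that `tokio-postgres` expects. A standalone sketch (the connection string and server are assumptions, not values from this repository):

    use native_tls::TlsConnector;
    use postgres_native_tls::MakeTlsConnector;

    async fn connect_example(allow_self_signed: bool) -> Result<(), Box<dyn std::error::Error>> {
        // Same construction as the new proxy code: self-signed certs are only
        // accepted when explicitly requested (e.g. for test compute nodes).
        let connector = TlsConnector::builder()
            .danger_accept_invalid_certs(allow_self_signed)
            .build()?;
        let tls = MakeTlsConnector::new(connector);
        let (_client, connection) =
            tokio_postgres::connect("host=localhost user=postgres sslmode=prefer", tls).await?;
        // The connection future must be polled for the client to make progress.
        tokio::spawn(connection);
        Ok(())
    }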
@@ -3,13 +3,19 @@ use crate::{
     rate_limiter::RateBucketInfo,
     serverless::GlobalConnPoolOptions,
 };
-use anyhow::{bail, ensure, Context, Ok};
+use anyhow::{ensure, Context};
+use humantime::parse_duration;
 use itertools::Itertools;
 use remote_storage::RemoteStorageConfig;
 use rustls::{
     crypto::ring::sign,
     pki_types::{CertificateDer, PrivateKeyDer},
 };
+use serde::{
+    de::{value::BorrowedStrDeserializer, MapAccess},
+    forward_to_deserialize_any, Deserialize, Deserializer,
+};
+use serde_with::serde_as;
 use sha2::{Digest, Sha256};
 use std::{
     collections::{HashMap, HashSet},
@@ -24,6 +30,7 @@ pub struct ProxyConfig {
     pub tls_config: Option<TlsConfig>,
     pub auth_backend: auth::BackendType<'static, (), ()>,
     pub metric_collection: Option<MetricCollectionConfig>,
+    pub allow_self_signed_compute: bool,
     pub http_config: HttpConfig,
     pub authentication_config: AuthenticationConfig,
     pub require_client_ip: bool,
@@ -336,45 +343,88 @@ impl EndpointCacheConfig {
     /// Notice that by default the limiter is empty, which means that cache is disabled.
     pub const CACHE_DEFAULT_OPTIONS: &'static str =
         "initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s,retry_interval=1s";
 }

-    /// Parse cache options passed via cmdline.
-    /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
-    fn parse(options: &str) -> anyhow::Result<Self> {
-        let mut initial_batch_size = None;
-        let mut default_batch_size = None;
-        let mut xread_timeout = None;
-        let mut stream_name = None;
-        let mut limiter_info = vec![];
-        let mut disable_cache = false;
-        let mut retry_interval = None;
+impl<'de> serde::Deserialize<'de> for EndpointCacheConfig {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        struct Visitor;
+        impl<'de> serde::de::Visitor<'de> for Visitor {
+            type Value = EndpointCacheConfig;
+            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+                f.write_str("struct EndpointCacheConfig")
+            }

-        for option in options.split(',') {
-            let (key, value) = option
-                .split_once('=')
-                .with_context(|| format!("bad key-value pair: {option}"))?;
+            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
+            where
+                A: serde::de::MapAccess<'de>,
+            {
+                fn e<E: serde::de::Error, T: std::fmt::Display>(t: T) -> E {
+                    E::custom(t)
+                }

-            match key {
-                "initial_batch_size" => initial_batch_size = Some(value.parse()?),
-                "default_batch_size" => default_batch_size = Some(value.parse()?),
-                "xread_timeout" => xread_timeout = Some(humantime::parse_duration(value)?),
-                "stream_name" => stream_name = Some(value.to_string()),
-                "limiter_info" => limiter_info.push(RateBucketInfo::from_str(value)?),
-                "disable_cache" => disable_cache = value.parse()?,
-                "retry_interval" => retry_interval = Some(humantime::parse_duration(value)?),
-                unknown => bail!("unknown key: {unknown}"),
+                let mut initial_batch_size: Option<usize> = None;
+                let mut default_batch_size: Option<usize> = None;
+                let mut xread_timeout: Option<Duration> = None;
+                let mut stream_name: Option<String> = None;
+                let mut limiter_info: Vec<RateBucketInfo> = vec![];
+                let mut disable_cache: bool = false;
+                let mut retry_interval: Option<Duration> = None;
+                while let Some((k, v)) = map.next_entry::<&str, &str>()? {
+                    match k {
+                        "initial_batch_size" => initial_batch_size = Some(v.parse().map_err(e)?),
+                        "default_batch_size" => default_batch_size = Some(v.parse().map_err(e)?),
+                        "xread_timeout" => {
+                            xread_timeout = Some(parse_duration(v).map_err(e)?);
+                        }
+                        "stream_name" => stream_name = Some(v.to_owned()),
+                        "limiter_info" => limiter_info.push(v.parse().map_err(e)?),
+                        "disable_cache" => disable_cache = v.parse().map_err(e)?,
+                        "retry_interval" => retry_interval = Some(parse_duration(v).map_err(e)?),
+                        x => {
+                            return Err(serde::de::Error::unknown_field(
+                                x,
+                                &[
+                                    "initial_batch_size",
+                                    "default_batch_size",
+                                    "xread_timeout",
+                                    "stream_name",
+                                    "limiter_info",
+                                    "disable_cache",
+                                    "retry_interval",
+                                ],
+                            ));
+                        }
+                    }
+                }

+                let initial_batch_size = initial_batch_size
+                    .ok_or_else(|| serde::de::Error::missing_field("initial_batch_size"))?;
+                let default_batch_size = default_batch_size
+                    .ok_or_else(|| serde::de::Error::missing_field("default_batch_size"))?;
+                let xread_timeout = xread_timeout
+                    .ok_or_else(|| serde::de::Error::missing_field("xread_timeout"))?;
+                let stream_name =
+                    stream_name.ok_or_else(|| serde::de::Error::missing_field("stream_name"))?;
+                let retry_interval = retry_interval
+                    .ok_or_else(|| serde::de::Error::missing_field("retry_interval"))?;

+                RateBucketInfo::validate(&mut limiter_info).map_err(e)?;

+                Ok(EndpointCacheConfig {
+                    initial_batch_size,
+                    default_batch_size,
+                    xread_timeout,
+                    stream_name,
+                    limiter_info,
+                    disable_cache,
+                    retry_interval,
+                })
+            }
+        }

-        RateBucketInfo::validate(&mut limiter_info)?;
-
-        Ok(Self {
-            initial_batch_size: initial_batch_size.context("missing `initial_batch_size`")?,
-            default_batch_size: default_batch_size.context("missing `default_batch_size`")?,
-            xread_timeout: xread_timeout.context("missing `xread_timeout`")?,
-            stream_name: stream_name.context("missing `stream_name`")?,
-            disable_cache,
-            limiter_info,
-            retry_interval: retry_interval.context("missing `retry_interval`")?,
-        })
+        serde::Deserializer::deserialize_map(deserializer, Visitor)
     }
 }

@@ -383,7 +433,7 @@ impl FromStr for EndpointCacheConfig {

     fn from_str(options: &str) -> Result<Self, Self::Err> {
         let error = || format!("failed to parse endpoint cache options '{options}'");
-        Self::parse(options).with_context(error)
+        Self::deserialize(SimpleKVConfig(options)).with_context(error)
     }
 }
 #[derive(Debug)]
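Behaviour is intended to stay the same from the caller's side: the options string still round-trips through `FromStr`. A sketch using the default options constant shown above (which sets `disable_cache=true`):

    #[test]
    fn parses_default_options() {
        let opts: EndpointCacheConfig =
            EndpointCacheConfig::CACHE_DEFAULT_OPTIONS.parse().unwrap();
        assert!(opts.disable_cache); // the default string sets disable_cache=true
    }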
@@ -402,11 +452,15 @@ pub fn remote_storage_from_toml(s: &str) -> anyhow::Result<OptRemoteStorageConfi
 }

 /// Helper for cmdline cache options parsing.
-#[derive(Debug)]
+#[serde_as]
+#[derive(Debug, Deserialize)]
 pub struct CacheOptions {
     /// Max number of entries.
+    #[serde_as(as = "serde_with::DisplayFromStr")]
     pub size: usize,
     /// Entry's time-to-live.
+    #[serde(with = "humantime_serde")]
+    #[serde(default)]
     pub ttl: Duration,
 }

@@ -417,30 +471,7 @@ impl CacheOptions {
     /// Parse cache options passed via cmdline.
     /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
     fn parse(options: &str) -> anyhow::Result<Self> {
-        let mut size = None;
-        let mut ttl = None;
-
-        for option in options.split(',') {
-            let (key, value) = option
-                .split_once('=')
-                .with_context(|| format!("bad key-value pair: {option}"))?;
-
-            match key {
-                "size" => size = Some(value.parse()?),
-                "ttl" => ttl = Some(humantime::parse_duration(value)?),
-                unknown => bail!("unknown key: {unknown}"),
-            }
-        }
-
-        // TTL doesn't matter if cache is always empty.
-        if let Some(0) = size {
-            ttl.get_or_insert(Duration::default());
-        }
-
-        Ok(Self {
-            size: size.context("missing `size`")?,
-            ttl: ttl.context("missing `ttl`")?,
-        })
+        Ok(Self::deserialize(SimpleKVConfig(options))?)
     }
 }

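So `CacheOptions` keeps its `size=..,ttl=..` cmdline format, but the field plumbing now comes from derives: `DisplayFromStr` parses `size` out of the string value and `humantime_serde` handles `ttl`. A usage sketch (values are illustrative):

    fn cache_options_sketch() {
        let CacheOptions { size, ttl } = "size=64,ttl=5m".parse().unwrap();
        assert_eq!(size, 64);
        assert_eq!(ttl, std::time::Duration::from_secs(300));
    }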
@@ -454,15 +485,21 @@ impl FromStr for CacheOptions {
 }

 /// Helper for cmdline cache options parsing.
-#[derive(Debug)]
+#[serde_as]
+#[derive(Debug, Deserialize)]
 pub struct ProjectInfoCacheOptions {
     /// Max number of entries.
+    #[serde_as(as = "serde_with::DisplayFromStr")]
     pub size: usize,
     /// Entry's time-to-live.
+    #[serde(with = "humantime_serde")]
+    #[serde(default)]
     pub ttl: Duration,
     /// Max number of roles per endpoint.
+    #[serde_as(as = "serde_with::DisplayFromStr")]
     pub max_roles: usize,
     /// Gc interval.
+    #[serde(with = "humantime_serde")]
     pub gc_interval: Duration,
 }

@@ -474,36 +511,7 @@ impl ProjectInfoCacheOptions {
     /// Parse cache options passed via cmdline.
     /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
     fn parse(options: &str) -> anyhow::Result<Self> {
-        let mut size = None;
-        let mut ttl = None;
-        let mut max_roles = None;
-        let mut gc_interval = None;
-
-        for option in options.split(',') {
-            let (key, value) = option
-                .split_once('=')
-                .with_context(|| format!("bad key-value pair: {option}"))?;
-
-            match key {
-                "size" => size = Some(value.parse()?),
-                "ttl" => ttl = Some(humantime::parse_duration(value)?),
-                "max_roles" => max_roles = Some(value.parse()?),
-                "gc_interval" => gc_interval = Some(humantime::parse_duration(value)?),
-                unknown => bail!("unknown key: {unknown}"),
-            }
-        }
-
-        // TTL doesn't matter if cache is always empty.
-        if let Some(0) = size {
-            ttl.get_or_insert(Duration::default());
-        }
-
-        Ok(Self {
-            size: size.context("missing `size`")?,
-            ttl: ttl.context("missing `ttl`")?,
-            max_roles: max_roles.context("missing `max_roles`")?,
-            gc_interval: gc_interval.context("missing `gc_interval`")?,
-        })
+        Ok(Self::deserialize(SimpleKVConfig(options))?)
     }
 }

@@ -517,14 +525,23 @@ impl FromStr for ProjectInfoCacheOptions {
 }

 /// Helper for cmdline cache options parsing.
+#[serde_as]
+#[derive(Deserialize)]
 pub struct WakeComputeLockOptions {
     /// The number of shards the lock map should have
+    #[serde_as(as = "serde_with::DisplayFromStr")]
+    #[serde(default)]
     pub shards: usize,
     /// The number of allowed concurrent requests for each endpoitn
+    #[serde_as(as = "serde_with::DisplayFromStr")]
     pub permits: usize,
     /// Garbage collection epoch
+    #[serde(with = "humantime_serde")]
+    #[serde(default)]
    pub epoch: Duration,
     /// Lock timeout
+    #[serde(with = "humantime_serde")]
+    #[serde(default)]
     pub timeout: Duration,
 }

@@ -537,45 +554,23 @@ impl WakeComputeLockOptions {
     /// Parse lock options passed via cmdline.
     /// Example: [`Self::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK`].
     fn parse(options: &str) -> anyhow::Result<Self> {
-        let mut shards = None;
-        let mut permits = None;
-        let mut epoch = None;
-        let mut timeout = None;
-
-        for option in options.split(',') {
-            let (key, value) = option
-                .split_once('=')
-                .with_context(|| format!("bad key-value pair: {option}"))?;
-
-            match key {
-                "shards" => shards = Some(value.parse()?),
-                "permits" => permits = Some(value.parse()?),
-                "epoch" => epoch = Some(humantime::parse_duration(value)?),
-                "timeout" => timeout = Some(humantime::parse_duration(value)?),
-                unknown => bail!("unknown key: {unknown}"),
-            }
+        let out = Self::deserialize(SimpleKVConfig(options))?;
+        if out.permits != 0 {
+            ensure!(
+                out.timeout > Duration::ZERO,
+                "wake compute lock timeout should be non-zero"
+            );
+            ensure!(
+                out.epoch > Duration::ZERO,
+                "wake compute lock gc epoch should be non-zero"
+            );
+            ensure!(out.shards > 1, "shard count must be > 1");
+            ensure!(
+                out.shards.is_power_of_two(),
+                "shard count must be a power of two"
+            );
         }

-        // these dont matter if lock is disabled
-        if let Some(0) = permits {
-            timeout = Some(Duration::default());
-            epoch = Some(Duration::default());
-            shards = Some(2);
-        }
-
-        let out = Self {
-            shards: shards.context("missing `shards`")?,
-            permits: permits.context("missing `permits`")?,
-            epoch: epoch.context("missing `epoch`")?,
-            timeout: timeout.context("missing `timeout`")?,
-        };
-
-        ensure!(out.shards > 1, "shard count must be > 1");
-        ensure!(
-            out.shards.is_power_of_two(),
-            "shard count must be a power of two"
-        );
-
         Ok(out)
     }
 }
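The validation now runs only when the lock is enabled; with `permits=0` the parse short-circuits and the `#[serde(default)]` fields stay at their zero values. A sketch of both paths (matching the updated test below):

    fn wake_compute_lock_sketch() -> anyhow::Result<()> {
        let disabled: WakeComputeLockOptions = "permits=0".parse()?;
        assert_eq!(disabled.shards, 0); // serde default; the old parser forced 2

        let enabled: WakeComputeLockOptions =
            "shards=32,permits=4,epoch=10m,timeout=1s".parse()?;
        assert!(enabled.shards.is_power_of_two());
        Ok(())
    }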
@@ -589,6 +584,100 @@ impl FromStr for WakeComputeLockOptions {
     }
 }

+struct SimpleKVConfig<'a>(&'a str);
+struct SimpleKVConfigMapAccess<'a> {
+    kv: std::str::Split<'a, char>,
+    val: Option<&'a str>,
+}
+
+#[derive(Debug)]
+struct SimpleKVConfigErr(String);
+
+impl std::fmt::Display for SimpleKVConfigErr {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(&self.0)
+    }
+}
+
+impl std::error::Error for SimpleKVConfigErr {}
+
+impl serde::de::Error for SimpleKVConfigErr {
+    fn custom<T>(msg: T) -> Self
+    where
+        T: std::fmt::Display,
+    {
+        Self(msg.to_string())
+    }
+}
+
+impl<'de> MapAccess<'de> for SimpleKVConfigMapAccess<'de> {
+    type Error = SimpleKVConfigErr;
+
+    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
+    where
+        K: serde::de::DeserializeSeed<'de>,
+    {
+        let Some(kv) = self.kv.next() else {
+            return Ok(None);
+        };
+        let (key, value) = kv
+            .split_once('=')
+            .ok_or_else(|| SimpleKVConfigErr("invalid kv pair".to_string()))?;
+        self.val = Some(value);
+
+        seed.deserialize(BorrowedStrDeserializer::new(key))
+            .map(Some)
+    }
+
+    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::DeserializeSeed<'de>,
+    {
+        seed.deserialize(BorrowedStrDeserializer::new(self.val.take().unwrap()))
+    }
+
+    fn next_entry_seed<K, V>(
+        &mut self,
+        kseed: K,
+        vseed: V,
+    ) -> Result<Option<(K::Value, V::Value)>, Self::Error>
+    where
+        K: serde::de::DeserializeSeed<'de>,
+        V: serde::de::DeserializeSeed<'de>,
+    {
+        let Some(kv) = self.kv.next() else {
+            return Ok(None);
+        };
+        let (key, value) = kv
+            .split_once('=')
+            .ok_or_else(|| SimpleKVConfigErr("invalid kv pair".to_string()))?;
+
+        let key = kseed.deserialize(BorrowedStrDeserializer::new(key))?;
+        let value = vseed.deserialize(BorrowedStrDeserializer::new(value))?;
+        Ok(Some((key, value)))
+    }
+}
+
+impl<'de> Deserializer<'de> for SimpleKVConfig<'de> {
+    type Error = SimpleKVConfigErr;
+
+    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        visitor.visit_map(SimpleKVConfigMapAccess {
+            kv: self.0.split(','),
+            val: None,
+        })
+    }
+
+    forward_to_deserialize_any! {
+        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
+        bytes byte_buf option unit struct unit_struct newtype_struct seq tuple
+        tuple_struct map enum identifier ignored_any
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
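`SimpleKVConfig` is reusable for any struct whose fields can deserialize from borrowed strings. A sketch of wiring a new options struct to it (the struct and field names here are hypothetical):

    #[serde_with::serde_as]
    #[derive(serde::Deserialize)]
    struct SampleOptions {
        #[serde_as(as = "serde_with::DisplayFromStr")]
        count: usize,
        #[serde(with = "humantime_serde")]
        every: std::time::Duration,
    }

    fn parse_sample() -> Result<SampleOptions, SimpleKVConfigErr> {
        use serde::Deserialize;
        // Every value arrives as a &str, so string-typed fields need
        // DisplayFromStr / humantime_serde adapters, as above.
        SampleOptions::deserialize(SimpleKVConfig("count=3,every=250ms"))
    }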
@@ -646,7 +735,7 @@ mod tests {
         } = "permits=0".parse()?;
         assert_eq!(epoch, Duration::ZERO);
         assert_eq!(timeout, Duration::ZERO);
-        assert_eq!(shards, 2);
+        assert_eq!(shards, 0);
         assert_eq!(permits, 0);

         Ok(())
@@ -40,31 +40,28 @@ pub async fn task_main(listener: TcpListener) -> anyhow::Result<Infallible> {

         let span = info_span!("mgmt", peer = %peer_addr);

-        tokio::task::Builder::new()
-            .name("mgmt handler")
-            .spawn(
-                async move {
-                    info!("serving a new console management API connection");
+        tokio::task::spawn(
+            async move {
+                info!("serving a new console management API connection");

-                    // these might be long running connections, have a separate logging for cancelling
-                    // on shutdown and other ways of stopping.
-                    let cancelled = scopeguard::guard(tracing::Span::current(), |span| {
-                        let _e = span.entered();
-                        info!("console management API task cancelled");
-                    });
+                // these might be long running connections, have a separate logging for cancelling
+                // on shutdown and other ways of stopping.
+                let cancelled = scopeguard::guard(tracing::Span::current(), |span| {
+                    let _e = span.entered();
+                    info!("console management API task cancelled");
+                });

-                    if let Err(e) = handle_connection(socket).await {
-                        error!("serving failed with an error: {e}");
-                    } else {
-                        info!("serving completed");
-                    }
+                if let Err(e) = handle_connection(socket).await {
+                    error!("serving failed with an error: {e}");
+                } else {
+                    info!("serving completed");
+                }

-                    // we can no longer get dropped
-                    scopeguard::ScopeGuard::into_inner(cancelled);
-                }
-                .instrument(span),
-            )
-            .unwrap();
+                // we can no longer get dropped
+                scopeguard::ScopeGuard::into_inner(cancelled);
+            }
+            .instrument(span),
+        );
     }
 }

@@ -17,7 +17,7 @@ use crate::{
     scram, EndpointCacheKey,
 };
 use dashmap::DashMap;
-use std::{sync::Arc, time::Duration};
+use std::{num::NonZeroUsize, sync::Arc, time::Duration};
 use tokio::sync::{OwnedSemaphorePermit, Semaphore};
 use tokio::time::Instant;
 use tracing::info;
@@ -293,6 +293,9 @@ pub struct NodeInfo {

     /// Labels for proxy's metrics.
     pub aux: MetricsAuxInfo,
+
+    /// Whether we should accept self-signed certificates (for testing)
+    pub allow_self_signed_compute: bool,
 }

 impl NodeInfo {
@@ -301,9 +304,17 @@ impl NodeInfo {
         ctx: &mut RequestMonitoring,
         timeout: Duration,
     ) -> Result<compute::PostgresConnection, compute::ConnectionError> {
-        self.config.connect(ctx, self.aux.clone(), timeout).await
+        self.config
+            .connect(
+                ctx,
+                self.allow_self_signed_compute,
+                self.aux.clone(),
+                timeout,
+            )
+            .await
     }
     pub fn reuse_settings(&mut self, other: Self) {
+        self.allow_self_signed_compute = other.allow_self_signed_compute;
         self.config.reuse_password(other.config);
     }

@@ -438,13 +449,17 @@
 /// Various caches for [`console`](super).
 pub struct ApiLocks {
     name: &'static str,
-    node_locks: DashMap<EndpointCacheKey, Arc<Semaphore>>,
-    permits: usize,
+    inner: Option<ApiLocksInner>,
     timeout: Duration,
     epoch: std::time::Duration,
     metrics: &'static ApiLockMetrics,
 }

+struct ApiLocksInner {
+    permits: NonZeroUsize,
+    node_locks: DashMap<EndpointCacheKey, Arc<Semaphore>>,
+}
+
 impl ApiLocks {
     pub fn new(
@@ -454,10 +469,14 @@
         epoch: std::time::Duration,
         metrics: &'static ApiLockMetrics,
     ) -> prometheus::Result<Self> {
+        let inner = NonZeroUsize::new(permits).map(|permits| ApiLocksInner {
+            permits,
+            node_locks: DashMap::with_shard_amount(shards),
+        });
+
         Ok(Self {
             name,
-            node_locks: DashMap::with_shard_amount(shards),
-            permits,
+            inner,
             timeout,
             epoch,
             metrics,
@@ -468,20 +487,21 @@
         &self,
         key: &EndpointCacheKey,
     ) -> Result<WakeComputePermit, errors::WakeComputeError> {
-        if self.permits == 0 {
+        let Some(inner) = &self.inner else {
             return Ok(WakeComputePermit { permit: None });
-        }
+        };
         let now = Instant::now();
         let semaphore = {
             // get fast path
-            if let Some(semaphore) = self.node_locks.get(key) {
+            if let Some(semaphore) = inner.node_locks.get(key) {
                 semaphore.clone()
             } else {
-                self.node_locks
+                inner
+                    .node_locks
                     .entry(key.clone())
                     .or_insert_with(|| {
                         self.metrics.semaphores_registered.inc();
-                        Arc::new(Semaphore::new(self.permits))
+                        Arc::new(Semaphore::new(inner.permits.get()))
                     })
                     .clone()
             }
@@ -498,13 +518,13 @@
     }

     pub async fn garbage_collect_worker(&self) {
-        if self.permits == 0 {
+        let Some(inner) = &self.inner else {
             return;
-        }
+        };
         let mut interval =
-            tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
+            tokio::time::interval(self.epoch / (inner.node_locks.shards().len()) as u32);
         loop {
-            for (i, shard) in self.node_locks.shards().iter().enumerate() {
+            for (i, shard) in inner.node_locks.shards().iter().enumerate() {
                 interval.tick().await;
                 // temporary lock a single shard and then clear any semaphores that aren't currently checked out
                 // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
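The `Option<ApiLocksInner>` refactor replaces the `permits == 0` sentinel with the type system: `NonZeroUsize::new` returns `None` for zero, so "disabled" can no longer be confused with a real permit count. The pattern in isolation (illustrative names):

    use std::num::NonZeroUsize;

    struct Limiter {
        inner: Option<NonZeroUsize>,
    }

    impl Limiter {
        fn new(permits: usize) -> Self {
            // None when permits == 0, i.e. the limiter is disabled.
            Self { inner: NonZeroUsize::new(permits) }
        }
        fn acquire(&self) -> Option<usize> {
            let permits = self.inner?; // disabled: early-return None
            Some(permits.get())
        }
    }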
@@ -63,10 +63,7 @@ impl Api {
         let (client, connection) =
             tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;

-        tokio::task::Builder::new()
-            .name("mock conn")
-            .spawn(connection)
-            .unwrap();
+        tokio::spawn(connection);
         let secret = match get_execute_postgres_query(
             &client,
             "select rolpassword from pg_catalog.pg_authid where rolname = $1",
@@ -129,6 +126,7 @@
                 branch_id: (&BranchId::from("branch")).into(),
                 cold_start_info: crate::console::messages::ColdStartInfo::Warm,
             },
+            allow_self_signed_compute: false,
         };

         Ok(node)
@@ -175,6 +175,7 @@
         let node = NodeInfo {
             config,
             aux: body.aux,
+            allow_self_signed_compute: false,
         };

         Ok(node)
@@ -141,15 +141,12 @@ pub async fn worker(
     LOG_CHAN.set(tx.downgrade()).unwrap();

     // setup row stream that will close on cancellation
-    tokio::task::Builder::new()
-        .name("drop parquet conn")
-        .spawn(async move {
-            cancellation_token.cancelled().await;
-            // dropping this sender will cause the channel to close only once
-            // all the remaining inflight requests have been completed.
-            drop(tx);
-        })
-        .unwrap();
+    tokio::spawn(async move {
+        cancellation_token.cancelled().await;
+        // dropping this sender will cause the channel to close only once
+        // all the remaining inflight requests have been completed.
+        drop(tx);
+    });
     let rx = futures::stream::poll_fn(move |cx| rx.poll_recv(cx));
     let rx = rx.map(RequestData::from);

@@ -75,10 +75,6 @@ async fn prometheus_metrics_handler(

     let span = info_span!("blocking");
     let body = tokio::task::spawn_blocking(move || {
-        // there are situations where we lose scraped metrics under load, try to gather some clues
-        // since all nodes are queried this, keep the message count low.
-        let spawned_at = std::time::Instant::now();
-
         let _span = span.entered();

         let mut state = state.lock().unwrap();
@@ -88,19 +84,11 @@ async fn prometheus_metrics_handler(
             .collect_group_into(&mut *encoder)
             .unwrap_or_else(|infallible| match infallible {});

-        let encoded_at = std::time::Instant::now();
-
         let body = encoder.finish();

-        let spawned_in = spawned_at - started_at;
-        let encoded_in = encoded_at - spawned_at;
-        let total = encoded_at - started_at;
-
         tracing::info!(
             bytes = body.len(),
-            total_ms = total.as_millis(),
-            spawning_ms = spawned_in.as_millis(),
-            encoding_ms = encoded_in.as_millis(),
+            elapsed_ms = started_at.elapsed().as_millis(),
             "responded /metrics"
         );

@@ -26,12 +26,7 @@ pub async fn init() -> anyhow::Result<LoggingGuard> {
         .await
         .map(OpenTelemetryLayer::new);

-    // spawn the console server in the background,
-    // returning a `Layer`:
-    let console_layer = console_subscriber::spawn();
-
     tracing_subscriber::registry()
-        .with(console_layer)
         .with(env_filter)
         .with(otlp_layer)
         .with(fmt_layer)

@@ -87,7 +87,7 @@ pub async fn task_main(

         tracing::info!(protocol = "tcp", %session_id, "accepted new TCP connection");

-        tokio::task::Builder::new().name("tcp client connection").spawn(connections.track_future(async move {
+        connections.spawn(async move {
             let mut socket = WithClientIp::new(socket);
             let mut peer_addr = peer_addr.ip();
             match socket.wait_for_addr().await {
@@ -152,7 +152,7 @@ pub async fn task_main(
                 }
             }
         }
-        })).unwrap();
+        });
     }

     connections.close();
@@ -178,6 +178,13 @@ impl ClientMode {
         }
     }

+    pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
+        match self {
+            ClientMode::Tcp => config.allow_self_signed_compute,
+            ClientMode::Websockets { .. } => false,
+        }
+    }
+
     fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
         match self {
             ClientMode::Tcp => s.sni_hostname(),
@@ -296,9 +303,14 @@
         }
     };

-    let mut node = connect_to_compute(ctx, &TcpMechanism { params: &params }, &user_info)
-        .or_else(|e| stream.throw_error(e))
-        .await?;
+    let mut node = connect_to_compute(
+        ctx,
+        &TcpMechanism { params: &params },
+        &user_info,
+        mode.allow_self_signed_compute(config),
+    )
+    .or_else(|e| stream.throw_error(e))
+    .await?;

     let session = cancellation_handler.get_session();
     prepare_client_connection(&node, &session, &mut stream).await?;
@@ -92,6 +92,7 @@ pub async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
     ctx: &mut RequestMonitoring,
     mechanism: &M,
     user_info: &B,
+    allow_self_signed_compute: bool,
 ) -> Result<M::Connection, M::Error>
 where
     M::ConnectError: ShouldRetry + std::fmt::Debug,
@@ -102,6 +103,8 @@ where
     if let Some(keys) = user_info.get_keys() {
         node_info.set_keys(keys);
     }
+    node_info.allow_self_signed_compute = allow_self_signed_compute;
+    // let mut node_info = credentials.get_node_info(ctx, user_info).await?;
     mechanism.update_connect_config(&mut node_info.config);

     // try once

@@ -519,6 +519,7 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn
             branch_id: (&BranchId::from("branch")).into(),
             cold_start_info: crate::console::messages::ColdStartInfo::Warm,
         },
+        allow_self_signed_compute: false,
     };
     let (_, node) = cache.insert("key".into(), node);
     node
@@ -548,7 +549,7 @@ async fn connect_to_compute_success() {
     let mut ctx = RequestMonitoring::test();
     let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
     let user_info = helper_create_connect_info(&mechanism);
-    connect_to_compute(&mut ctx, &mechanism, &user_info)
+    connect_to_compute(&mut ctx, &mechanism, &user_info, false)
         .await
         .unwrap();
     mechanism.verify();
@@ -561,7 +562,7 @@ async fn connect_to_compute_retry() {
     let mut ctx = RequestMonitoring::test();
     let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
     let user_info = helper_create_connect_info(&mechanism);
-    connect_to_compute(&mut ctx, &mechanism, &user_info)
+    connect_to_compute(&mut ctx, &mechanism, &user_info, false)
         .await
         .unwrap();
     mechanism.verify();
@@ -575,7 +576,7 @@ async fn connect_to_compute_non_retry_1() {
     let mut ctx = RequestMonitoring::test();
     let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
     let user_info = helper_create_connect_info(&mechanism);
-    connect_to_compute(&mut ctx, &mechanism, &user_info)
+    connect_to_compute(&mut ctx, &mechanism, &user_info, false)
         .await
         .unwrap_err();
     mechanism.verify();
@@ -589,7 +590,7 @@ async fn connect_to_compute_non_retry_2() {
     let mut ctx = RequestMonitoring::test();
     let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
     let user_info = helper_create_connect_info(&mechanism);
-    connect_to_compute(&mut ctx, &mechanism, &user_info)
+    connect_to_compute(&mut ctx, &mechanism, &user_info, false)
         .await
         .unwrap();
     mechanism.verify();
@@ -607,7 +608,7 @@ async fn connect_to_compute_non_retry_3() {
         Retry, Retry, Retry, Retry, Retry, /* the 17th time */ Retry,
     ]);
     let user_info = helper_create_connect_info(&mechanism);
-    connect_to_compute(&mut ctx, &mechanism, &user_info)
+    connect_to_compute(&mut ctx, &mechanism, &user_info, false)
         .await
         .unwrap_err();
     mechanism.verify();
@@ -621,7 +622,7 @@ async fn wake_retry() {
     let mut ctx = RequestMonitoring::test();
     let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
     let user_info = helper_create_connect_info(&mechanism);
-    connect_to_compute(&mut ctx, &mechanism, &user_info)
+    connect_to_compute(&mut ctx, &mechanism, &user_info, false)
         .await
         .unwrap();
     mechanism.verify();
@@ -635,7 +636,7 @@ async fn wake_non_retry() {
     let mut ctx = RequestMonitoring::test();
     let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
     let user_info = helper_create_connect_info(&mechanism);
-    connect_to_compute(&mut ctx, &mechanism, &user_info)
+    connect_to_compute(&mut ctx, &mechanism, &user_info, false)
         .await
         .unwrap_err();
     mechanism.verify();

@@ -108,12 +108,10 @@ impl ConnectionWithCredentialsProvider {
         if let Credentials::Dynamic(credentials_provider, _) = &self.credentials {
             let credentials_provider = credentials_provider.clone();
             let con2 = con.clone();
-            let f = tokio::task::Builder::new()
-                .name("redis keep connection")
-                .spawn(async move {
-                    let _ = Self::keep_connection(con2, credentials_provider).await;
-                });
-            self.refresh_token_task = Some(f.unwrap());
+            let f = tokio::spawn(async move {
+                let _ = Self::keep_connection(con2, credentials_provider).await;
+            });
+            self.refresh_token_task = Some(f);
         }
         match Self::ping(&mut con).await {
             Ok(()) => {

@@ -142,13 +142,10 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
             // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
             // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
             let cache = self.cache.clone();
-            tokio::task::Builder::new()
-                .name("invalidate cache lazy")
-                .spawn(async move {
-                    tokio::time::sleep(INVALIDATION_LAG).await;
-                    invalidate_cache(cache, msg);
-                })
-                .unwrap();
+            tokio::spawn(async move {
+                tokio::time::sleep(INVALIDATION_LAG).await;
+                invalidate_cache(cache, msg);
+            });
         }
     }

@@ -61,28 +61,22 @@ pub async fn task_main(
     let conn_pool = conn_pool::GlobalConnPool::new(&config.http_config);
     {
         let conn_pool = Arc::clone(&conn_pool);
-        tokio::task::Builder::new()
-            .name("serverless pool gc")
-            .spawn(async move {
-                conn_pool.gc_worker(StdRng::from_entropy()).await;
-            })
-            .unwrap();
+        tokio::spawn(async move {
+            conn_pool.gc_worker(StdRng::from_entropy()).await;
+        });
     }

     // shutdown the connection pool
-    tokio::task::Builder::new()
-        .name("serverless pool shutdown")
-        .spawn({
-            let cancellation_token = cancellation_token.clone();
-            let conn_pool = conn_pool.clone();
-            async move {
-                cancellation_token.cancelled().await;
-                tokio::task::spawn_blocking(move || conn_pool.shutdown())
-                    .await
-                    .unwrap();
-            }
-        })
-        .unwrap();
+    tokio::spawn({
+        let cancellation_token = cancellation_token.clone();
+        let conn_pool = conn_pool.clone();
+        async move {
+            cancellation_token.cancelled().await;
+            tokio::task::spawn_blocking(move || conn_pool.shutdown())
+                .await
+                .unwrap();
+        }
+    });

     let backend = Arc::new(PoolingBackend {
         pool: Arc::clone(&conn_pool),
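`connections` in the hunks below appears to be a task tracker: the old code paired `track_future` with a named `Builder` spawn, while the new code relies on a tracker-provided `spawn` that tracks and spawns in one call. A sketch of that pattern, assuming `tokio_util::task::TaskTracker` (which matches this shape):

    use tokio_util::task::TaskTracker;

    async fn drain_example() {
        let connections = TaskTracker::new();
        connections.spawn(async { /* handle one connection */ });
        connections.close();      // no new tasks after this point
        connections.wait().await; // graceful drain, as at the end of task_main
    }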
@@ -115,25 +109,20 @@ pub async fn task_main(
         let conn_id = uuid::Uuid::new_v4();
         let http_conn_span = tracing::info_span!("http_conn", ?conn_id);

-        tokio::task::Builder::new()
-            .name("serverless conn handler")
-            .spawn(
-                connections.track_future(
-                    connection_handler(
-                        config,
-                        backend.clone(),
-                        connections.clone(),
-                        cancellation_handler.clone(),
-                        cancellation_token.clone(),
-                        server.clone(),
-                        tls_acceptor.clone(),
-                        conn,
-                        peer_addr,
-                    )
-                    .instrument(http_conn_span),
-                ),
-            )
-            .unwrap();
+        connections.spawn(
+            connection_handler(
+                config,
+                backend.clone(),
+                connections.clone(),
+                cancellation_handler.clone(),
+                cancellation_token.clone(),
+                server.clone(),
+                tls_acceptor.clone(),
+                conn,
+                peer_addr,
+            )
+            .instrument(http_conn_span),
+        );
     }

     connections.wait().await;
@@ -229,25 +218,20 @@ async fn connection_handler(

     // `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
     // By spawning the future, we ensure it never gets cancelled until it decides to.
-    let handler = tokio::task::Builder::new()
-        .name("serverless request handler")
-        .spawn(
-            connections.track_future(
-                request_handler(
-                    req,
-                    config,
-                    backend.clone(),
-                    connections.clone(),
-                    cancellation_handler.clone(),
-                    session_id,
-                    peer_addr,
-                    http_request_token,
-                )
-                .in_current_span()
-                .map_ok_or_else(api_error_into_response, |r| r),
-            ),
-        )
-        .unwrap();
+    let handler = connections.spawn(
+        request_handler(
+            req,
+            config,
+            backend.clone(),
+            connections.clone(),
+            cancellation_handler.clone(),
+            session_id,
+            peer_addr,
+            http_request_token,
+        )
+        .in_current_span()
+        .map_ok_or_else(api_error_into_response, |r| r),
+    );

     async move {
         let res = handler.await;
@@ -306,27 +290,17 @@ async fn request_handler(
     let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)
         .map_err(|e| ApiError::BadRequest(e.into()))?;

-    tokio::task::Builder::new()
-        .name("websocket client conn")
-        .spawn(
-            ws_connections.track_future(
-                async move {
-                    if let Err(e) = websocket::serve_websocket(
-                        config,
-                        ctx,
-                        websocket,
-                        cancellation_handler,
-                        host,
-                    )
-                    .await
-                    {
-                        error!("error in websocket connection: {e:#}");
-                    }
-                }
-                .instrument(span),
-            ),
-        )
-        .unwrap();
+    ws_connections.spawn(
+        async move {
+            if let Err(e) =
+                websocket::serve_websocket(config, ctx, websocket, cancellation_handler, host)
+                    .await
+            {
+                error!("error in websocket connection: {e:#}");
+            }
+        }
+        .instrument(span),
+    );

     // Return the response so the spawned future can continue.
     Ok(response)

@@ -107,6 +107,7 @@ impl PoolingBackend {
                 pool: self.pool.clone(),
             },
             &backend,
+            false, // do not allow self signed compute for http flow
         )
         .await
     }

@@ -492,7 +492,7 @@ pub fn poll_client<C: ClientInnerExt>(
     let cancel = CancellationToken::new();
     let cancelled = cancel.clone().cancelled_owned();

-    tokio::task::Builder::new().name("pooled conn").spawn(
+    tokio::spawn(
         async move {
             let _conn_gauge = conn_gauge;
             let mut idle_timeout = pin!(tokio::time::sleep(idle));
@@ -565,7 +565,7 @@ pub fn poll_client<C: ClientInnerExt>(
             }).await;

         }
-        .instrument(span)).unwrap();
+        .instrument(span));
     let inner = ClientInner {
         inner: client,
         session: tx,

@@ -38,7 +38,6 @@ futures-sink = { version = "0.3" }
 futures-util = { version = "0.3", features = ["channel", "io", "sink"] }
 getrandom = { version = "0.2", default-features = false, features = ["std"] }
 hashbrown = { version = "0.14", features = ["raw"] }
-hdrhistogram = { version = "7" }
 hex = { version = "0.4", features = ["serde"] }
 hmac = { version = "0.12", default-features = false, features = ["reset"] }
 hyper = { version = "0.14", features = ["full"] }
@@ -67,9 +66,8 @@ sha2 = { version = "0.10", features = ["asm"] }
 smallvec = { version = "1", default-features = false, features = ["const_new", "write"] }
 subtle = { version = "2" }
 time = { version = "0.3", features = ["local-offset", "macros", "serde-well-known"] }
-tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util", "tracing"] }
+tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] }
 tokio-rustls = { version = "0.24" }
 tokio-stream = { version = "0.1", features = ["net"] }
 tokio-util = { version = "0.7", features = ["codec", "compat", "io", "rt"] }
 toml_datetime = { version = "0.6", default-features = false, features = ["serde"] }
 toml_edit = { version = "0.19", features = ["serde"] }