mirror of
https://github.com/neondatabase/neon.git
synced 2026-02-03 10:40:37 +00:00
Compare commits
11 Commits
conrad/pro
...
tristan957
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8b75265231 | ||
|
|
62cd3b8d3d | ||
|
|
8d26978ed9 | ||
|
|
35372a8f12 | ||
|
|
6d95a3fe2d | ||
|
|
99726495c7 | ||
|
|
4a4a457312 | ||
|
|
e78d1e2ec6 | ||
|
|
af429b4a62 | ||
|
|
3b4d4eb535 | ||
|
|
f060537a31 |
@@ -57,21 +57,6 @@ use tracing::{error, info};
|
||||
use url::Url;
|
||||
use utils::failpoint_support;
|
||||
|
||||
// Compatibility hack: if the control plane specified any remote-ext-config
|
||||
// use the default value for extension storage proxy gateway.
|
||||
// Remove this once the control plane is updated to pass the gateway URL
|
||||
fn parse_remote_ext_base_url(arg: &str) -> Result<String> {
|
||||
const FALLBACK_PG_EXT_GATEWAY_BASE_URL: &str =
|
||||
"http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local";
|
||||
|
||||
Ok(if arg.starts_with("http") {
|
||||
arg
|
||||
} else {
|
||||
FALLBACK_PG_EXT_GATEWAY_BASE_URL
|
||||
}
|
||||
.to_owned())
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(rename_all = "kebab-case")]
|
||||
struct Cli {
|
||||
@@ -80,7 +65,7 @@ struct Cli {
|
||||
|
||||
/// The base URL for the remote extension storage proxy gateway.
|
||||
/// Should be in the form of `http(s)://<gateway-hostname>[:<port>]`.
|
||||
#[arg(short = 'r', long, value_parser = parse_remote_ext_base_url, alias = "remote-ext-config")]
|
||||
#[arg(short = 'r', long, alias = "remote-ext-config")]
|
||||
pub remote_ext_base_url: Option<String>,
|
||||
|
||||
/// The port to bind the external listening HTTP server to. Clients running
|
||||
@@ -276,18 +261,4 @@ mod test {
|
||||
fn verify_cli() {
|
||||
Cli::command().debug_assert()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_pg_ext_gateway_base_url() {
|
||||
let arg = "http://pg-ext-s3-gateway2";
|
||||
let result = super::parse_remote_ext_base_url(arg).unwrap();
|
||||
assert_eq!(result, arg);
|
||||
|
||||
let arg = "pg-ext-s3-gateway";
|
||||
let result = super::parse_remote_ext_base_url(arg).unwrap();
|
||||
assert_eq!(
|
||||
result,
|
||||
"http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -339,6 +339,8 @@ async fn run_dump_restore(
|
||||
destination_connstring: String,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let dumpdir = workdir.join("dumpdir");
|
||||
let num_jobs = num_cpus::get().to_string();
|
||||
info!("using {num_jobs} jobs for dump/restore");
|
||||
|
||||
let common_args = [
|
||||
// schema mapping (prob suffices to specify them on one side)
|
||||
@@ -354,7 +356,7 @@ async fn run_dump_restore(
|
||||
"directory".to_string(),
|
||||
// concurrency
|
||||
"--jobs".to_string(),
|
||||
num_cpus::get().to_string(),
|
||||
num_jobs,
|
||||
// progress updates
|
||||
"--verbose".to_string(),
|
||||
];
|
||||
|
||||
@@ -71,7 +71,7 @@ impl AsyncAuthorizeRequest<Body> for Authorize {
|
||||
// Unauthorized because when we eventually do use
|
||||
// [`Validation`], we will hit the above `Err` match arm which
|
||||
// returns 401 Unauthorized.
|
||||
Some(ComputeClaimsScope::Admin) => {
|
||||
Some(ref scope) if scope.contains(&ComputeClaimsScope::Admin) => {
|
||||
let Some(ref audience) = data.claims.audience else {
|
||||
return Err(JsonResponse::error(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
|
||||
@@ -709,7 +709,7 @@ struct EndpointGenerateJwtCmdArgs {
|
||||
endpoint_id: String,
|
||||
|
||||
#[clap(short = 's', long, help = "Scope to generate the JWT with", value_parser = ComputeClaimsScope::from_str)]
|
||||
scope: Option<ComputeClaimsScope>,
|
||||
scope: Vec<ComputeClaimsScope>,
|
||||
}
|
||||
|
||||
#[derive(clap::Subcommand)]
|
||||
@@ -1580,7 +1580,7 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
|
||||
.with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?
|
||||
};
|
||||
|
||||
let jwt = endpoint.generate_jwt(args.scope)?;
|
||||
let jwt = endpoint.generate_jwt(Some(args.scope.clone()))?;
|
||||
|
||||
print!("{jwt}");
|
||||
}
|
||||
|
||||
@@ -632,14 +632,16 @@ impl Endpoint {
|
||||
}
|
||||
|
||||
/// Generate a JWT with the correct claims.
|
||||
pub fn generate_jwt(&self, scope: Option<ComputeClaimsScope>) -> Result<String> {
|
||||
pub fn generate_jwt(&self, scope: Option<Vec<ComputeClaimsScope>>) -> Result<String> {
|
||||
self.env.generate_auth_token(&ComputeClaims {
|
||||
audience: match scope {
|
||||
Some(ComputeClaimsScope::Admin) => Some(vec![COMPUTE_AUDIENCE.to_owned()]),
|
||||
Some(ref scope) if scope.contains(&ComputeClaimsScope::Admin) => {
|
||||
Some(vec![COMPUTE_AUDIENCE.to_owned()])
|
||||
}
|
||||
_ => None,
|
||||
},
|
||||
compute_id: match scope {
|
||||
Some(ComputeClaimsScope::Admin) => None,
|
||||
Some(ref scope) if scope.contains(&ComputeClaimsScope::Admin) => None,
|
||||
_ => Some(self.endpoint_id.clone()),
|
||||
},
|
||||
scope,
|
||||
@@ -918,7 +920,7 @@ impl Endpoint {
|
||||
self.external_http_address.port()
|
||||
),
|
||||
)
|
||||
.bearer_auth(self.generate_jwt(None::<ComputeClaimsScope>)?)
|
||||
.bearer_auth(self.generate_jwt(None::<Vec<ComputeClaimsScope>>)?)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
@@ -995,7 +997,7 @@ impl Endpoint {
|
||||
self.external_http_address.port()
|
||||
))
|
||||
.header(CONTENT_TYPE.as_str(), "application/json")
|
||||
.bearer_auth(self.generate_jwt(None::<ComputeClaimsScope>)?)
|
||||
.bearer_auth(self.generate_jwt(None::<Vec<ComputeClaimsScope>>)?)
|
||||
.body(
|
||||
serde_json::to_string(&ConfigurationRequest {
|
||||
spec,
|
||||
|
||||
@@ -41,8 +41,8 @@ pub struct ComputeClaims {
|
||||
/// [`ComputeClaimsScope::Admin`].
|
||||
pub compute_id: Option<String>,
|
||||
|
||||
/// The scope of what the token authorizes.
|
||||
pub scope: Option<ComputeClaimsScope>,
|
||||
/// The scopes of what the token is authorized for.
|
||||
pub scope: Option<Vec<ComputeClaimsScope>>,
|
||||
|
||||
/// The recipient the token is intended for.
|
||||
///
|
||||
|
||||
@@ -27,6 +27,7 @@ pub use prometheus::{
|
||||
|
||||
pub mod launch_timestamp;
|
||||
mod wrappers;
|
||||
pub use prometheus;
|
||||
pub use wrappers::{CountedReader, CountedWriter};
|
||||
mod hll;
|
||||
pub use hll::{HyperLogLog, HyperLogLogState, HyperLogLogVec};
|
||||
|
||||
@@ -4,6 +4,7 @@ use std::{sync::Arc, time::Duration};
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, info_span};
|
||||
|
||||
use crate::{FeatureStore, PostHogClient, PostHogClientConfig};
|
||||
|
||||
@@ -26,31 +27,35 @@ impl FeatureResolverBackgroundLoop {
|
||||
pub fn spawn(self: Arc<Self>, handle: &tokio::runtime::Handle, refresh_period: Duration) {
|
||||
let this = self.clone();
|
||||
let cancel = self.cancel.clone();
|
||||
handle.spawn(async move {
|
||||
tracing::info!("Starting PostHog feature resolver");
|
||||
let mut ticker = tokio::time::interval(refresh_period);
|
||||
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = ticker.tick() => {}
|
||||
_ = cancel.cancelled() => break
|
||||
}
|
||||
let resp = match this
|
||||
.posthog_client
|
||||
.get_feature_flags_local_evaluation()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
tracing::warn!("Cannot get feature flags: {}", e);
|
||||
continue;
|
||||
handle.spawn(
|
||||
async move {
|
||||
tracing::info!("Starting PostHog feature resolver");
|
||||
let mut ticker = tokio::time::interval(refresh_period);
|
||||
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = ticker.tick() => {}
|
||||
_ = cancel.cancelled() => break
|
||||
}
|
||||
};
|
||||
let feature_store = FeatureStore::new_with_flags(resp.flags);
|
||||
this.feature_store.store(Arc::new(feature_store));
|
||||
let resp = match this
|
||||
.posthog_client
|
||||
.get_feature_flags_local_evaluation()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
tracing::warn!("Cannot get feature flags: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let feature_store = FeatureStore::new_with_flags(resp.flags);
|
||||
this.feature_store.store(Arc::new(feature_store));
|
||||
tracing::info!("Feature flag updated");
|
||||
}
|
||||
tracing::info!("PostHog feature resolver stopped");
|
||||
}
|
||||
tracing::info!("PostHog feature resolver stopped");
|
||||
});
|
||||
.instrument(info_span!("posthog_feature_resolver")),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn feature_store(&self) -> Arc<FeatureStore> {
|
||||
|
||||
@@ -448,6 +448,18 @@ impl FeatureStore {
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Infer whether a feature flag is a boolean flag by checking if it has a multivariate filter.
|
||||
pub fn is_feature_flag_boolean(&self, flag_key: &str) -> Result<bool, PostHogEvaluationError> {
|
||||
if let Some(flag_config) = self.flags.get(flag_key) {
|
||||
Ok(flag_config.filters.multivariate.is_none())
|
||||
} else {
|
||||
Err(PostHogEvaluationError::NotAvailable(format!(
|
||||
"Not found in the local evaluation spec: {}",
|
||||
flag_key
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostHogClientConfig {
|
||||
@@ -528,7 +540,15 @@ impl PostHogClient {
|
||||
.bearer_auth(&self.config.server_api_key)
|
||||
.send()
|
||||
.await?;
|
||||
let status = response.status();
|
||||
let body = response.text().await?;
|
||||
if !status.is_success() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"Failed to get feature flags: {}, {}",
|
||||
status,
|
||||
body
|
||||
));
|
||||
}
|
||||
Ok(serde_json::from_str(&body)?)
|
||||
}
|
||||
|
||||
|
||||
@@ -264,10 +264,56 @@ mod propagation_of_cached_label_value {
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(histograms, histograms::bench_bucket_scalability);
|
||||
mod histograms {
|
||||
use std::time::Instant;
|
||||
|
||||
use criterion::{BenchmarkId, Criterion};
|
||||
use metrics::core::Collector;
|
||||
|
||||
pub fn bench_bucket_scalability(c: &mut Criterion) {
|
||||
let mut g = c.benchmark_group("bucket_scalability");
|
||||
|
||||
for n in [1, 4, 8, 16, 32, 64, 128, 256] {
|
||||
g.bench_with_input(BenchmarkId::new("nbuckets", n), &n, |b, n| {
|
||||
b.iter_custom(|iters| {
|
||||
let buckets: Vec<f64> = (0..*n).map(|i| i as f64 * 100.0).collect();
|
||||
let histo = metrics::Histogram::with_opts(
|
||||
metrics::prometheus::HistogramOpts::new("name", "help")
|
||||
.buckets(buckets.clone()),
|
||||
)
|
||||
.unwrap();
|
||||
let start = Instant::now();
|
||||
for i in 0..usize::try_from(iters).unwrap() {
|
||||
histo.observe(buckets[i % buckets.len()]);
|
||||
}
|
||||
let elapsed = start.elapsed();
|
||||
// self-test
|
||||
let mfs = histo.collect();
|
||||
assert_eq!(mfs.len(), 1);
|
||||
let metrics = mfs[0].get_metric();
|
||||
assert_eq!(metrics.len(), 1);
|
||||
let histo = metrics[0].get_histogram();
|
||||
let buckets = histo.get_bucket();
|
||||
assert!(
|
||||
buckets
|
||||
.iter()
|
||||
.enumerate()
|
||||
.all(|(i, b)| b.get_cumulative_count()
|
||||
>= i as u64 * (iters / buckets.len() as u64))
|
||||
);
|
||||
elapsed
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
criterion_main!(
|
||||
label_values,
|
||||
single_metric_multicore_scalability,
|
||||
propagation_of_cached_label_value
|
||||
propagation_of_cached_label_value,
|
||||
histograms,
|
||||
);
|
||||
|
||||
/*
|
||||
@@ -290,6 +336,14 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [211.50 ns 214.44 ns
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [14.135 ns 14.147 ns 14.160 ns]
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [14.243 ns 14.255 ns 14.268 ns]
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [14.470 ns 14.682 ns 14.895 ns]
|
||||
bucket_scalability/nbuckets/1 time: [30.352 ns 30.353 ns 30.354 ns]
|
||||
bucket_scalability/nbuckets/4 time: [30.464 ns 30.465 ns 30.467 ns]
|
||||
bucket_scalability/nbuckets/8 time: [30.569 ns 30.575 ns 30.584 ns]
|
||||
bucket_scalability/nbuckets/16 time: [30.961 ns 30.965 ns 30.969 ns]
|
||||
bucket_scalability/nbuckets/32 time: [35.691 ns 35.707 ns 35.722 ns]
|
||||
bucket_scalability/nbuckets/64 time: [47.829 ns 47.898 ns 47.974 ns]
|
||||
bucket_scalability/nbuckets/128 time: [73.479 ns 73.512 ns 73.545 ns]
|
||||
bucket_scalability/nbuckets/256 time: [127.92 ns 127.94 ns 127.96 ns]
|
||||
|
||||
Results on an i3en.3xlarge instance
|
||||
|
||||
@@ -344,6 +398,14 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [434.87 ns 456.4
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [3.3767 ns 3.3974 ns 3.4220 ns]
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [3.6105 ns 4.2355 ns 5.1463 ns]
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [4.0889 ns 4.9714 ns 6.0779 ns]
|
||||
bucket_scalability/nbuckets/1 time: [4.8455 ns 4.8542 ns 4.8646 ns]
|
||||
bucket_scalability/nbuckets/4 time: [4.5663 ns 4.5722 ns 4.5787 ns]
|
||||
bucket_scalability/nbuckets/8 time: [4.5531 ns 4.5670 ns 4.5842 ns]
|
||||
bucket_scalability/nbuckets/16 time: [4.6392 ns 4.6524 ns 4.6685 ns]
|
||||
bucket_scalability/nbuckets/32 time: [6.0302 ns 6.0439 ns 6.0589 ns]
|
||||
bucket_scalability/nbuckets/64 time: [10.608 ns 10.644 ns 10.691 ns]
|
||||
bucket_scalability/nbuckets/128 time: [22.178 ns 22.316 ns 22.483 ns]
|
||||
bucket_scalability/nbuckets/256 time: [42.190 ns 42.328 ns 42.492 ns]
|
||||
|
||||
Results on a Hetzner AX102 AMD Ryzen 9 7950X3D 16-Core Processor
|
||||
|
||||
@@ -362,5 +424,13 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [164.24 ns 170.1
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [2.2915 ns 2.2960 ns 2.3012 ns]
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [2.5726 ns 2.6158 ns 2.6624 ns]
|
||||
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [2.7068 ns 2.8243 ns 2.9824 ns]
|
||||
bucket_scalability/nbuckets/1 time: [6.3998 ns 6.4288 ns 6.4684 ns]
|
||||
bucket_scalability/nbuckets/4 time: [6.3603 ns 6.3620 ns 6.3637 ns]
|
||||
bucket_scalability/nbuckets/8 time: [6.1646 ns 6.1654 ns 6.1667 ns]
|
||||
bucket_scalability/nbuckets/16 time: [6.1341 ns 6.1391 ns 6.1454 ns]
|
||||
bucket_scalability/nbuckets/32 time: [8.2206 ns 8.2254 ns 8.2301 ns]
|
||||
bucket_scalability/nbuckets/64 time: [13.988 ns 13.994 ns 14.000 ns]
|
||||
bucket_scalability/nbuckets/128 time: [28.180 ns 28.216 ns 28.251 ns]
|
||||
bucket_scalability/nbuckets/256 time: [54.914 ns 54.931 ns 54.951 ns]
|
||||
|
||||
*/
|
||||
|
||||
@@ -91,4 +91,14 @@ impl FeatureResolver {
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_feature_flag_boolean(&self, flag_key: &str) -> Result<bool, PostHogEvaluationError> {
|
||||
if let Some(inner) = &self.inner {
|
||||
inner.feature_store().is_feature_flag_boolean(flag_key)
|
||||
} else {
|
||||
Err(PostHogEvaluationError::NotAvailable(
|
||||
"PostHog integration is not enabled".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3663,6 +3663,46 @@ async fn read_tar_eof(mut reader: (impl tokio::io::AsyncRead + Unpin)) -> anyhow
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn tenant_evaluate_feature_flag(
|
||||
request: Request<Body>,
|
||||
_cancel: CancellationToken,
|
||||
) -> Result<Response<Body>, ApiError> {
|
||||
let tenant_shard_id: TenantShardId = parse_request_param(&request, "tenant_shard_id")?;
|
||||
check_permission(&request, Some(tenant_shard_id.tenant_id))?;
|
||||
|
||||
let flag: String = must_parse_query_param(&request, "flag")?;
|
||||
let as_type: String = must_parse_query_param(&request, "as")?;
|
||||
|
||||
let state = get_state(&request);
|
||||
|
||||
async {
|
||||
let tenant = state
|
||||
.tenant_manager
|
||||
.get_attached_tenant_shard(tenant_shard_id)?;
|
||||
if as_type == "boolean" {
|
||||
let result = tenant.feature_resolver.evaluate_boolean(&flag, tenant_shard_id.tenant_id);
|
||||
let result = result.map(|_| true).map_err(|e| e.to_string());
|
||||
json_response(StatusCode::OK, result)
|
||||
} else if as_type == "multivariate" {
|
||||
let result = tenant.feature_resolver.evaluate_multivariate(&flag, tenant_shard_id.tenant_id).map_err(|e| e.to_string());
|
||||
json_response(StatusCode::OK, result)
|
||||
} else {
|
||||
// Auto infer the type of the feature flag.
|
||||
let is_boolean = tenant.feature_resolver.is_feature_flag_boolean(&flag).map_err(|e| ApiError::InternalServerError(anyhow::anyhow!("{e}")))?;
|
||||
if is_boolean {
|
||||
let result = tenant.feature_resolver.evaluate_boolean(&flag, tenant_shard_id.tenant_id);
|
||||
let result = result.map(|_| true).map_err(|e| e.to_string());
|
||||
json_response(StatusCode::OK, result)
|
||||
} else {
|
||||
let result = tenant.feature_resolver.evaluate_multivariate(&flag, tenant_shard_id.tenant_id).map_err(|e| e.to_string());
|
||||
json_response(StatusCode::OK, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
.instrument(info_span!("tenant_evaluate_feature_flag", tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug()))
|
||||
.await
|
||||
}
|
||||
|
||||
/// Common functionality of all the HTTP API handlers.
|
||||
///
|
||||
/// - Adds a tracing span to each request (by `request_span`)
|
||||
@@ -4039,5 +4079,8 @@ pub fn make_router(
|
||||
"/v1/tenant/:tenant_shard_id/timeline/:timeline_id/activate_post_import",
|
||||
|r| api_handler(r, activate_post_import_handler),
|
||||
)
|
||||
.get("/v1/tenant/:tenant_shard_id/feature_flag", |r| {
|
||||
api_handler(r, tenant_evaluate_feature_flag)
|
||||
})
|
||||
.any(handler_404))
|
||||
}
|
||||
|
||||
@@ -1312,11 +1312,44 @@ impl EvictionsWithLowResidenceDuration {
|
||||
//
|
||||
// Roughly logarithmic scale.
|
||||
const STORAGE_IO_TIME_BUCKETS: &[f64] = &[
|
||||
0.000030, // 30 usec
|
||||
0.001000, // 1000 usec
|
||||
0.030, // 30 ms
|
||||
1.000, // 1000 ms
|
||||
30.000, // 30000 ms
|
||||
0.00005, // 50us
|
||||
0.00006, // 60us
|
||||
0.00007, // 70us
|
||||
0.00008, // 80us
|
||||
0.00009, // 90us
|
||||
0.0001, // 100us
|
||||
0.000110, // 110us
|
||||
0.000120, // 120us
|
||||
0.000130, // 130us
|
||||
0.000140, // 140us
|
||||
0.000150, // 150us
|
||||
0.000160, // 160us
|
||||
0.000170, // 170us
|
||||
0.000180, // 180us
|
||||
0.000190, // 190us
|
||||
0.000200, // 200us
|
||||
0.000210, // 210us
|
||||
0.000220, // 220us
|
||||
0.000230, // 230us
|
||||
0.000240, // 240us
|
||||
0.000250, // 250us
|
||||
0.000300, // 300us
|
||||
0.000350, // 350us
|
||||
0.000400, // 400us
|
||||
0.000450, // 450us
|
||||
0.000500, // 500us
|
||||
0.000600, // 600us
|
||||
0.000700, // 700us
|
||||
0.000800, // 800us
|
||||
0.000900, // 900us
|
||||
0.001000, // 1ms
|
||||
0.002000, // 2ms
|
||||
0.003000, // 3ms
|
||||
0.004000, // 4ms
|
||||
0.005000, // 5ms
|
||||
0.01000, // 10ms
|
||||
0.02000, // 20ms
|
||||
0.05000, // 50ms
|
||||
];
|
||||
|
||||
/// VirtualFile fs operation variants.
|
||||
|
||||
@@ -383,7 +383,7 @@ pub struct TenantShard {
|
||||
|
||||
l0_flush_global_state: L0FlushGlobalState,
|
||||
|
||||
feature_resolver: FeatureResolver,
|
||||
pub(crate) feature_resolver: FeatureResolver,
|
||||
}
|
||||
impl std::fmt::Debug for TenantShard {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
@@ -5832,6 +5832,7 @@ pub(crate) mod harness {
|
||||
pub conf: &'static PageServerConf,
|
||||
pub tenant_conf: pageserver_api::models::TenantConfig,
|
||||
pub tenant_shard_id: TenantShardId,
|
||||
pub shard_identity: ShardIdentity,
|
||||
pub generation: Generation,
|
||||
pub shard: ShardIndex,
|
||||
pub remote_storage: GenericRemoteStorage,
|
||||
@@ -5899,6 +5900,7 @@ pub(crate) mod harness {
|
||||
conf,
|
||||
tenant_conf,
|
||||
tenant_shard_id,
|
||||
shard_identity,
|
||||
generation,
|
||||
shard,
|
||||
remote_storage,
|
||||
@@ -5960,8 +5962,7 @@ pub(crate) mod harness {
|
||||
&ShardParameters::default(),
|
||||
))
|
||||
.unwrap(),
|
||||
// This is a legacy/test code path: sharding isn't supported here.
|
||||
ShardIdentity::unsharded(),
|
||||
self.shard_identity,
|
||||
Some(walredo_mgr),
|
||||
self.tenant_shard_id,
|
||||
self.remote_storage.clone(),
|
||||
@@ -6083,6 +6084,7 @@ mod tests {
|
||||
use timeline::compaction::{KeyHistoryRetention, KeyLogAtLsn};
|
||||
use timeline::{CompactOptions, DeltaLayerTestDesc, VersionedKeySpaceQuery};
|
||||
use utils::id::TenantId;
|
||||
use utils::shard::{ShardCount, ShardNumber};
|
||||
|
||||
use super::*;
|
||||
use crate::DEFAULT_PG_VERSION;
|
||||
@@ -9418,6 +9420,77 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_failed_flush_should_not_update_disk_consistent_lsn() -> anyhow::Result<()> {
|
||||
//
|
||||
// Setup
|
||||
//
|
||||
let harness = TenantHarness::create_custom(
|
||||
"test_failed_flush_should_not_upload_disk_consistent_lsn",
|
||||
pageserver_api::models::TenantConfig::default(),
|
||||
TenantId::generate(),
|
||||
ShardIdentity::new(ShardNumber(0), ShardCount(4), ShardStripeSize(128)).unwrap(),
|
||||
Generation::new(1),
|
||||
)
|
||||
.await?;
|
||||
let (tenant, ctx) = harness.load().await;
|
||||
|
||||
let timeline = tenant
|
||||
.create_test_timeline(TIMELINE_ID, Lsn(0x10), DEFAULT_PG_VERSION, &ctx)
|
||||
.await?;
|
||||
assert_eq!(timeline.get_shard_identity().count, ShardCount(4));
|
||||
let mut writer = timeline.writer().await;
|
||||
writer
|
||||
.put(
|
||||
*TEST_KEY,
|
||||
Lsn(0x20),
|
||||
&Value::Image(test_img("foo at 0x20")),
|
||||
&ctx,
|
||||
)
|
||||
.await?;
|
||||
writer.finish_write(Lsn(0x20));
|
||||
drop(writer);
|
||||
timeline.freeze_and_flush().await.unwrap();
|
||||
|
||||
timeline.remote_client.wait_completion().await.unwrap();
|
||||
let disk_consistent_lsn = timeline.get_disk_consistent_lsn();
|
||||
let remote_consistent_lsn = timeline.get_remote_consistent_lsn_projected();
|
||||
assert_eq!(Some(disk_consistent_lsn), remote_consistent_lsn);
|
||||
|
||||
//
|
||||
// Test
|
||||
//
|
||||
|
||||
let mut writer = timeline.writer().await;
|
||||
writer
|
||||
.put(
|
||||
*TEST_KEY,
|
||||
Lsn(0x30),
|
||||
&Value::Image(test_img("foo at 0x30")),
|
||||
&ctx,
|
||||
)
|
||||
.await?;
|
||||
writer.finish_write(Lsn(0x30));
|
||||
drop(writer);
|
||||
|
||||
fail::cfg(
|
||||
"flush-layer-before-update-remote-consistent-lsn",
|
||||
"return()",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let flush_res = timeline.freeze_and_flush().await;
|
||||
// if flush failed, the disk/remote consistent LSN should not be updated
|
||||
assert!(flush_res.is_err());
|
||||
assert_eq!(disk_consistent_lsn, timeline.get_disk_consistent_lsn());
|
||||
assert_eq!(
|
||||
remote_consistent_lsn,
|
||||
timeline.get_remote_consistent_lsn_projected()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "testing")]
|
||||
#[tokio::test]
|
||||
async fn test_simple_bottom_most_compaction_deltas_1() -> anyhow::Result<()> {
|
||||
|
||||
@@ -4767,7 +4767,10 @@ impl Timeline {
|
||||
|| !flushed_to_lsn.is_valid()
|
||||
);
|
||||
|
||||
if flushed_to_lsn < frozen_to_lsn && self.shard_identity.count.count() > 1 {
|
||||
if flushed_to_lsn < frozen_to_lsn
|
||||
&& self.shard_identity.count.count() > 1
|
||||
&& result.is_ok()
|
||||
{
|
||||
// If our layer flushes didn't carry disk_consistent_lsn up to the `to_lsn` advertised
|
||||
// to us via layer_flush_start_rx, then advance it here.
|
||||
//
|
||||
@@ -4946,6 +4949,10 @@ impl Timeline {
|
||||
return Err(FlushLayerError::Cancelled);
|
||||
}
|
||||
|
||||
fail_point!("flush-layer-before-update-remote-consistent-lsn", |_| {
|
||||
Err(FlushLayerError::Other(anyhow!("failpoint").into()))
|
||||
});
|
||||
|
||||
let disk_consistent_lsn = Lsn(lsn_range.end.0 - 1);
|
||||
|
||||
// The new on-disk layers are now in the layer map. We can remove the
|
||||
|
||||
@@ -11,19 +11,7 @@
|
||||
//! - => S3 as the source for the PGDATA instead of local filesystem
|
||||
//!
|
||||
//! TODOs before productionization:
|
||||
//! - ChunkProcessingJob size / ImportJob::total_size does not account for sharding.
|
||||
//! => produced image layers likely too small.
|
||||
//! - ChunkProcessingJob should cut up an ImportJob to hit exactly target image layer size.
|
||||
//! - asserts / unwraps need to be replaced with errors
|
||||
//! - don't trust remote objects will be small (=prevent OOMs in those cases)
|
||||
//! - limit all in-memory buffers in size, or download to disk and read from there
|
||||
//! - limit task concurrency
|
||||
//! - generally play nice with other tenants in the system
|
||||
//! - importbucket is different bucket than main pageserver storage, so, should be fine wrt S3 rate limits
|
||||
//! - but concerns like network bandwidth, local disk write bandwidth, local disk capacity, etc
|
||||
//! - integrate with layer eviction system
|
||||
//! - audit for Tenant::cancel nor Timeline::cancel responsivity
|
||||
//! - audit for Tenant/Timeline gate holding (we spawn tokio tasks during this flow!)
|
||||
//!
|
||||
//! An incomplete set of TODOs from the Hackathon:
|
||||
//! - version-specific CheckPointData (=> pgv abstraction, already exists for regular walingest)
|
||||
@@ -44,7 +32,7 @@ use pageserver_api::key::{
|
||||
rel_dir_to_key, rel_size_to_key, relmap_file_key, slru_block_to_key, slru_dir_to_key,
|
||||
slru_segment_size_to_key,
|
||||
};
|
||||
use pageserver_api::keyspace::{contiguous_range_len, is_contiguous_range, singleton_range};
|
||||
use pageserver_api::keyspace::{ShardedRange, singleton_range};
|
||||
use pageserver_api::models::{ShardImportProgress, ShardImportProgressV1, ShardImportStatus};
|
||||
use pageserver_api::reltag::{RelTag, SlruKind};
|
||||
use pageserver_api::shard::ShardIdentity;
|
||||
@@ -167,6 +155,7 @@ impl Planner {
|
||||
/// This function is and must remain pure: given the same input, it will generate the same import plan.
|
||||
async fn plan(mut self, import_config: &TimelineImportConfig) -> anyhow::Result<Plan> {
|
||||
let pgdata_lsn = Lsn(self.control_file.control_file_data().checkPoint).align();
|
||||
anyhow::ensure!(pgdata_lsn.is_valid());
|
||||
|
||||
let datadir = PgDataDir::new(&self.storage).await?;
|
||||
|
||||
@@ -249,14 +238,22 @@ impl Planner {
|
||||
});
|
||||
|
||||
// Assigns parts of key space to later parallel jobs
|
||||
// Note: The image layers produced here may have gaps, meaning,
|
||||
// there is not an image for each key in the layer's key range.
|
||||
// The read path stops traversal at the first image layer, regardless
|
||||
// of whether a base image has been found for a key or not.
|
||||
// (Concept of sparse image layers doesn't exist.)
|
||||
// This behavior is exactly right for the base image layers we're producing here.
|
||||
// But, since no other place in the code currently produces image layers with gaps,
|
||||
// it seems noteworthy.
|
||||
let mut last_end_key = Key::MIN;
|
||||
let mut current_chunk = Vec::new();
|
||||
let mut current_chunk_size: usize = 0;
|
||||
let mut jobs = Vec::new();
|
||||
for task in std::mem::take(&mut self.tasks).into_iter() {
|
||||
if current_chunk_size + task.total_size()
|
||||
> import_config.import_job_soft_size_limit.into()
|
||||
{
|
||||
let task_size = task.total_size(&self.shard);
|
||||
let projected_chunk_size = current_chunk_size.saturating_add(task_size);
|
||||
if projected_chunk_size > import_config.import_job_soft_size_limit.into() {
|
||||
let key_range = last_end_key..task.key_range().start;
|
||||
jobs.push(ChunkProcessingJob::new(
|
||||
key_range.clone(),
|
||||
@@ -266,7 +263,7 @@ impl Planner {
|
||||
last_end_key = key_range.end;
|
||||
current_chunk_size = 0;
|
||||
}
|
||||
current_chunk_size += task.total_size();
|
||||
current_chunk_size = current_chunk_size.saturating_add(task_size);
|
||||
current_chunk.push(task);
|
||||
}
|
||||
jobs.push(ChunkProcessingJob::new(
|
||||
@@ -604,18 +601,18 @@ impl PgDataDirDb {
|
||||
};
|
||||
|
||||
let path = datadir_path.join(rel_tag.to_segfile_name(segno));
|
||||
assert!(filesize % BLCKSZ as usize == 0); // TODO: this should result in an error
|
||||
anyhow::ensure!(filesize % BLCKSZ as usize == 0);
|
||||
let nblocks = filesize / BLCKSZ as usize;
|
||||
|
||||
PgDataDirDbFile {
|
||||
Ok(PgDataDirDbFile {
|
||||
path,
|
||||
filesize,
|
||||
rel_tag,
|
||||
segno,
|
||||
nblocks: Some(nblocks), // first non-cummulative sizes
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
.collect::<anyhow::Result<_, _>>()?;
|
||||
|
||||
// Set cummulative sizes. Do all of that math here, so that later we could easier
|
||||
// parallelize over segments and know with which segments we need to write relsize
|
||||
@@ -650,12 +647,22 @@ impl PgDataDirDb {
|
||||
trait ImportTask {
|
||||
fn key_range(&self) -> Range<Key>;
|
||||
|
||||
fn total_size(&self) -> usize {
|
||||
// TODO: revisit this
|
||||
if is_contiguous_range(&self.key_range()) {
|
||||
contiguous_range_len(&self.key_range()) as usize * 8192
|
||||
fn total_size(&self, shard_identity: &ShardIdentity) -> usize {
|
||||
let range = ShardedRange::new(self.key_range(), shard_identity);
|
||||
let page_count = range.page_count();
|
||||
if page_count == u32::MAX {
|
||||
tracing::warn!(
|
||||
"Import task has non contiguous key range: {}..{}",
|
||||
self.key_range().start,
|
||||
self.key_range().end
|
||||
);
|
||||
|
||||
// Tasks should operate on contiguous ranges. It is unexpected for
|
||||
// ranges to violate this assumption. Calling code handles this by mapping
|
||||
// any task on a non contiguous range to its own image layer.
|
||||
usize::MAX
|
||||
} else {
|
||||
u32::MAX as usize
|
||||
page_count as usize * 8192
|
||||
}
|
||||
}
|
||||
|
||||
@@ -753,6 +760,8 @@ impl ImportTask for ImportRelBlocksTask {
|
||||
layer_writer: &mut ImageLayerWriter,
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<usize> {
|
||||
const MAX_BYTE_RANGE_SIZE: usize = 128 * 1024 * 1024;
|
||||
|
||||
debug!("Importing relation file");
|
||||
|
||||
let (rel_tag, start_blk) = self.key_range.start.to_rel_block()?;
|
||||
@@ -777,7 +786,7 @@ impl ImportTask for ImportRelBlocksTask {
|
||||
assert_eq!(key.len(), 1);
|
||||
assert!(!acc.is_empty());
|
||||
assert!(acc_end > acc_start);
|
||||
if acc_end == start /* TODO additional max range check here, to limit memory consumption per task to X */ {
|
||||
if acc_end == start && end - acc_start <= MAX_BYTE_RANGE_SIZE {
|
||||
acc.push(key.pop().unwrap());
|
||||
Ok((acc, acc_start, end))
|
||||
} else {
|
||||
@@ -792,8 +801,8 @@ impl ImportTask for ImportRelBlocksTask {
|
||||
.get_range(&self.path, range_start.into_u64(), range_end.into_u64())
|
||||
.await?;
|
||||
let mut buf = Bytes::from(range_buf);
|
||||
// TODO: batched writes
|
||||
for key in keys {
|
||||
// The writer buffers writes internally
|
||||
let image = buf.split_to(8192);
|
||||
layer_writer.put_image(key, image, ctx).await?;
|
||||
nimages += 1;
|
||||
@@ -846,6 +855,9 @@ impl ImportTask for ImportSlruBlocksTask {
|
||||
debug!("Importing SLRU segment file {}", self.path);
|
||||
let buf = self.storage.get(&self.path).await?;
|
||||
|
||||
// TODO(vlad): Does timestamp to LSN work for imported timelines?
|
||||
// Probably not since we don't append the `xact_time` to it as in
|
||||
// [`WalIngest::ingest_xact_record`].
|
||||
let (kind, segno, start_blk) = self.key_range.start.to_slru_block()?;
|
||||
let (_kind, _segno, end_blk) = self.key_range.end.to_slru_block()?;
|
||||
let mut blknum = start_blk;
|
||||
|
||||
@@ -6,7 +6,7 @@ use bytes::Bytes;
|
||||
use postgres_ffi::ControlFileData;
|
||||
use remote_storage::{
|
||||
Download, DownloadError, DownloadKind, DownloadOpts, GenericRemoteStorage, Listing,
|
||||
ListingObject, RemotePath,
|
||||
ListingObject, RemotePath, RemoteStorageConfig,
|
||||
};
|
||||
use serde::de::DeserializeOwned;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -22,11 +22,9 @@ pub async fn new(
|
||||
location: &index_part_format::Location,
|
||||
cancel: CancellationToken,
|
||||
) -> Result<RemoteStorageWrapper, anyhow::Error> {
|
||||
// FIXME: we probably want some timeout, and we might be able to assume the max file
|
||||
// size on S3 is 1GiB (postgres segment size). But the problem is that the individual
|
||||
// downloaders don't know enough about concurrent downloads to make a guess on the
|
||||
// expected bandwidth and resulting best timeout.
|
||||
let timeout = std::time::Duration::from_secs(24 * 60 * 60);
|
||||
// Downloads should be reasonably sized. We do ranged reads for relblock raw data
|
||||
// and full reads for SLRU segments which are bounded by Postgres.
|
||||
let timeout = RemoteStorageConfig::DEFAULT_TIMEOUT;
|
||||
let location_storage = match location {
|
||||
#[cfg(feature = "testing")]
|
||||
index_part_format::Location::LocalFs { path } => {
|
||||
@@ -50,9 +48,12 @@ pub async fn new(
|
||||
.import_pgdata_aws_endpoint_url
|
||||
.clone()
|
||||
.map(|url| url.to_string()), // by specifying None here, remote_storage/aws-sdk-rust will infer from env
|
||||
concurrency_limit: 100.try_into().unwrap(), // TODO: think about this
|
||||
max_keys_per_list_response: Some(1000), // TODO: think about this
|
||||
upload_storage_class: None, // irrelevant
|
||||
// This matches the default import job concurrency. This is managed
|
||||
// separately from the usual S3 client, but the concern here is bandwidth
|
||||
// usage.
|
||||
concurrency_limit: 128.try_into().unwrap(),
|
||||
max_keys_per_list_response: Some(1000),
|
||||
upload_storage_class: None, // irrelevant
|
||||
},
|
||||
timeout,
|
||||
)
|
||||
|
||||
@@ -17,7 +17,9 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::auth::credentials::check_peer_addr_is_in_list;
|
||||
use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
|
||||
use crate::auth::{
|
||||
self, AuthError, ComputeUserInfoMaybeEndpoint, IpPattern, validate_password_and_exchange,
|
||||
};
|
||||
use crate::cache::Cached;
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::context::RequestContext;
|
||||
@@ -135,6 +137,16 @@ impl<'a, T> Backend<'a, T> {
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<'a, T, E> Backend<'a, Result<T, E>> {
|
||||
/// Very similar to [`std::option::Option::transpose`].
|
||||
/// This is most useful for error handling.
|
||||
pub(crate) fn transpose(self) -> Result<Backend<'a, T>, E> {
|
||||
match self {
|
||||
Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)),
|
||||
Self::Local(l) => Ok(Backend::Local(l)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct ComputeCredentials {
|
||||
pub(crate) info: ComputeUserInfo,
|
||||
@@ -272,7 +284,7 @@ async fn auth_quirks(
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
) -> auth::Result<(ComputeCredentials, Option<Vec<IpPattern>>)> {
|
||||
// If there's no project so far, that entails that client doesn't
|
||||
// support SNI or other means of passing the endpoint (project) name.
|
||||
// We now expect to see a very specific payload in the place of password.
|
||||
@@ -289,12 +301,15 @@ async fn auth_quirks(
|
||||
debug!("fetching authentication info and allowlists");
|
||||
|
||||
// check allowed list
|
||||
if config.ip_allowlist_check_enabled {
|
||||
let allowed_ips = if config.ip_allowlist_check_enabled {
|
||||
let allowed_ips = api.get_allowed_ips(ctx, &info).await?;
|
||||
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
|
||||
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
|
||||
}
|
||||
}
|
||||
allowed_ips
|
||||
} else {
|
||||
Cached::new_uncached(Arc::new(vec![]))
|
||||
};
|
||||
|
||||
// check if a VPC endpoint ID is coming in and if yes, if it's allowed
|
||||
let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?;
|
||||
@@ -353,7 +368,7 @@ async fn auth_quirks(
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(keys) => Ok(keys),
|
||||
Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))),
|
||||
Err(e) => {
|
||||
if e.is_password_failed() {
|
||||
// The password could have been changed, so we invalidate the cache.
|
||||
@@ -405,39 +420,53 @@ async fn authenticate_with_secret(
|
||||
classic::authenticate(ctx, info, client, config, secret).await
|
||||
}
|
||||
|
||||
impl ControlPlaneClient {
|
||||
impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
/// Get username from the credentials.
|
||||
pub(crate) fn get_user(&self) -> &str {
|
||||
match self {
|
||||
Self::ControlPlane(_, user_info) => &user_info.user,
|
||||
Self::Local(_) => "local",
|
||||
}
|
||||
}
|
||||
|
||||
/// Authenticate the client via the requested backend, possibly using credentials.
|
||||
#[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
|
||||
pub(crate) async fn authenticate(
|
||||
&self,
|
||||
self,
|
||||
ctx: &RequestContext,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
user_info: ComputeUserInfoMaybeEndpoint,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
debug!(
|
||||
user = &*user_info.user,
|
||||
project = user_info.endpoint(),
|
||||
"performing authentication using the console"
|
||||
);
|
||||
) -> auth::Result<(Backend<'a, ComputeCredentials>, Option<Vec<IpPattern>>)> {
|
||||
let res = match self {
|
||||
Self::ControlPlane(api, user_info) => {
|
||||
debug!(
|
||||
user = &*user_info.user,
|
||||
project = user_info.endpoint(),
|
||||
"performing authentication using the console"
|
||||
);
|
||||
|
||||
let credentials = auth_quirks(
|
||||
ctx,
|
||||
self,
|
||||
user_info,
|
||||
client,
|
||||
allow_cleartext,
|
||||
config,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await?;
|
||||
let (credentials, ip_allowlist) = auth_quirks(
|
||||
ctx,
|
||||
&*api,
|
||||
user_info,
|
||||
client,
|
||||
allow_cleartext,
|
||||
config,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await?;
|
||||
Ok((Backend::ControlPlane(api, credentials), ip_allowlist))
|
||||
}
|
||||
Self::Local(_) => {
|
||||
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"));
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: replace with some metric
|
||||
info!("user successfully authenticated");
|
||||
|
||||
Ok(credentials)
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
@@ -507,25 +536,6 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ControlPlaneWakeCompute<'a> {
|
||||
pub cplane: &'a ControlPlaneClient,
|
||||
pub creds: ComputeCredentials,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ComputeConnectBackend for ControlPlaneWakeCompute<'_> {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
|
||||
self.cplane.wake_compute(ctx, &self.creds.info).await
|
||||
}
|
||||
|
||||
fn get_keys(&self) -> &ComputeCredentialKeys {
|
||||
&self.creds.keys
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::unimplemented, clippy::unwrap_used)]
|
||||
@@ -542,7 +552,6 @@ mod tests {
|
||||
use postgres_protocol::message::backend::Message as PgMessage;
|
||||
use postgres_protocol::message::frontend;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
|
||||
use tokio_util::task::TaskTracker;
|
||||
|
||||
use super::jwt::JwkCache;
|
||||
use super::{AuthRateLimiter, auth_quirks};
|
||||
@@ -693,7 +702,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_scram() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
@@ -775,7 +784,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_cleartext() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
@@ -829,7 +838,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn auth_quirks_password_hack() {
|
||||
let (mut client, server) = tokio::io::duplex(1024);
|
||||
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
|
||||
let mut stream = PqStream::new(Stream::from_raw(server));
|
||||
|
||||
let ctx = RequestContext::test();
|
||||
let api = Auth {
|
||||
@@ -878,7 +887,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(creds.info.endpoint, "my-endpoint");
|
||||
assert_eq!(creds.0.info.endpoint, "my-endpoint");
|
||||
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//! Client authentication mechanisms.
|
||||
|
||||
pub mod backend;
|
||||
pub use backend::{Backend, ControlPlaneWakeCompute};
|
||||
pub use backend::Backend;
|
||||
|
||||
mod credentials;
|
||||
pub(crate) use credentials::{
|
||||
|
||||
@@ -18,7 +18,6 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_rustls::TlsConnector;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::{Instrument, error, info};
|
||||
use utils::project_git_version;
|
||||
use utils::sentry_init::init_sentry;
|
||||
@@ -227,8 +226,7 @@ pub(super) async fn task_main(
|
||||
let dest_suffix = Arc::clone(&dest_suffix);
|
||||
let compute_tls_config = compute_tls_config.clone();
|
||||
|
||||
let tracker = connections.token();
|
||||
tokio::spawn(
|
||||
connections.spawn(
|
||||
async move {
|
||||
socket
|
||||
.set_nodelay(true)
|
||||
@@ -251,7 +249,6 @@ pub(super) async fn task_main(
|
||||
compute_tls_config,
|
||||
tls_server_end_point,
|
||||
socket,
|
||||
tracker,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -277,11 +274,10 @@ const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmod
|
||||
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
ctx: &RequestContext,
|
||||
raw_stream: S,
|
||||
tracker: TaskTrackerToken,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
) -> anyhow::Result<(Stream<S>, TaskTrackerToken)> {
|
||||
let mut stream = PqStream::new(Stream::from_raw(raw_stream), tracker);
|
||||
) -> anyhow::Result<Stream<S>> {
|
||||
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
|
||||
|
||||
let msg = stream.read_startup_packet().await?;
|
||||
use pq_proto::FeStartupPacket::SslRequest;
|
||||
@@ -295,7 +291,7 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
// Upgrade raw stream into a secure TLS-backed stream.
|
||||
// NOTE: We've consumed `tls`; this fact will be used later.
|
||||
|
||||
let (raw, read_buf, tracker) = stream.into_inner();
|
||||
let (raw, read_buf) = stream.into_inner();
|
||||
// TODO: Normally, client doesn't send any data before
|
||||
// server says TLS handshake is ok and read_buf is empty.
|
||||
// However, you could imagine pipelining of postgres
|
||||
@@ -306,16 +302,13 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
bail!("data is sent before server replied with EncryptionResponse");
|
||||
}
|
||||
|
||||
Ok((
|
||||
Stream::Tls {
|
||||
tls: Box::new(
|
||||
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
|
||||
.await?,
|
||||
),
|
||||
tls_server_end_point,
|
||||
},
|
||||
tracker,
|
||||
))
|
||||
Ok(Stream::Tls {
|
||||
tls: Box::new(
|
||||
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
|
||||
.await?,
|
||||
),
|
||||
tls_server_end_point,
|
||||
})
|
||||
}
|
||||
unexpected => {
|
||||
info!(
|
||||
@@ -336,10 +329,8 @@ async fn handle_client(
|
||||
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
|
||||
tls_server_end_point: TlsServerEndPoint,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin,
|
||||
tracker: TaskTrackerToken,
|
||||
) -> anyhow::Result<()> {
|
||||
let (mut tls_stream, _tracker) =
|
||||
ssl_handshake(&ctx, stream, tracker, tls_config, tls_server_end_point).await?;
|
||||
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
|
||||
|
||||
// Cut off first part of the SNI domain
|
||||
// We receive required destination details in the format of
|
||||
|
||||
@@ -323,7 +323,7 @@ impl CancellationHandler {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_key(self: Arc<Self>) -> Session {
|
||||
pub(crate) fn get_key(self: &Arc<Self>) -> Session {
|
||||
// we intentionally generate a random "backend pid" and "secret key" here.
|
||||
// we use the corresponding u64 as an identifier for the
|
||||
// actual endpoint+pid+secret for postgres/pgbouncer.
|
||||
@@ -340,7 +340,7 @@ impl CancellationHandler {
|
||||
Session {
|
||||
key,
|
||||
redis_key,
|
||||
cancellation_handler: self,
|
||||
cancellation_handler: Arc::clone(self),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::TryFutureExt;
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::{Instrument, debug, error, info};
|
||||
|
||||
use crate::auth::backend::ConsoleRedirectBackend;
|
||||
@@ -15,8 +14,10 @@ use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
||||
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
use crate::proxy::handshake::{HandshakeData, handshake};
|
||||
use crate::proxy::passthrough::passthrough;
|
||||
use crate::proxy::{ClientRequestError, prepare_client_connection, run_until_cancelled};
|
||||
use crate::proxy::passthrough::ProxyPassthrough;
|
||||
use crate::proxy::{
|
||||
ClientRequestError, ErrorSource, prepare_client_connection, run_until_cancelled,
|
||||
};
|
||||
|
||||
pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
@@ -34,6 +35,7 @@ pub async fn task_main(
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
|
||||
while let Some(accept_result) =
|
||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||
@@ -47,11 +49,11 @@ pub async fn task_main(
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
let cancellations = cancellations.clone();
|
||||
|
||||
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
|
||||
let tracker = connections.token();
|
||||
tokio::spawn(async move {
|
||||
connections.spawn(async move {
|
||||
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
|
||||
Err(e) => {
|
||||
error!("per-client task finished with an error: {e:#}");
|
||||
@@ -101,80 +103,99 @@ pub async fn task_main(
|
||||
&config.region,
|
||||
);
|
||||
|
||||
let span = ctx.span();
|
||||
let mut slot = Some(ctx);
|
||||
let res = handle_client(
|
||||
config,
|
||||
backend,
|
||||
&mut slot,
|
||||
&ctx,
|
||||
cancellation_handler,
|
||||
socket,
|
||||
conn_gauge,
|
||||
tracker,
|
||||
cancellations,
|
||||
)
|
||||
.instrument(span)
|
||||
.instrument(ctx.span())
|
||||
.boxed()
|
||||
.await;
|
||||
|
||||
match (slot, res) {
|
||||
(None, _) => {}
|
||||
(Some(ctx), Ok(())) => {
|
||||
ctx.success();
|
||||
}
|
||||
(Some(ctx), Err(e)) => {
|
||||
match res {
|
||||
Err(e) => {
|
||||
ctx.set_error_kind(e.get_error_kind());
|
||||
tracing::warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
|
||||
error!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
|
||||
}
|
||||
Ok(None) => {
|
||||
ctx.set_success();
|
||||
}
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
let _disconnect = ctx.log_connect();
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
error!(
|
||||
?session_id,
|
||||
"per-client task finished with an IO error from the client: {e:#}"
|
||||
);
|
||||
}
|
||||
Err(ErrorSource::Compute(e)) => {
|
||||
error!(
|
||||
?session_id,
|
||||
"per-client task finished with an IO error from the compute: {e:#}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
connections.close();
|
||||
cancellations.close();
|
||||
drop(listener);
|
||||
|
||||
// Drain connections
|
||||
connections.wait().await;
|
||||
cancellations.wait().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
config: &'static ProxyConfig,
|
||||
backend: &'static ConsoleRedirectBackend,
|
||||
ctx_slot: &mut Option<RequestContext>,
|
||||
ctx: &RequestContext,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
stream: S,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
tracker: TaskTrackerToken,
|
||||
) -> Result<(), ClientRequestError> {
|
||||
let protocol = ctx_slot.as_ref().expect("context must be set").protocol();
|
||||
debug!(%protocol, "handling interactive connection from client");
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
|
||||
debug!(
|
||||
protocol = %ctx.protocol(),
|
||||
"handling interactive connection from client"
|
||||
);
|
||||
|
||||
let metrics = &Metrics::get().proxy;
|
||||
let request_gauge = metrics.connection_requests.guard(protocol);
|
||||
let proto = ctx.protocol();
|
||||
let request_gauge = metrics.connection_requests.guard(proto);
|
||||
|
||||
let tls = config.tls_config.load();
|
||||
let tls = tls.as_deref();
|
||||
|
||||
let data = {
|
||||
let ctx = ctx_slot.as_ref().expect("context must be set");
|
||||
let record_handshake_error = !ctx.has_private_peer_addr();
|
||||
let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
let do_handshake = handshake(ctx, stream, tracker, tls, record_handshake_error);
|
||||
tokio::time::timeout(config.handshake_timeout, do_handshake).await??
|
||||
};
|
||||
let record_handshake_error = !ctx.has_private_peer_addr();
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
let do_handshake = handshake(ctx, stream, tls, record_handshake_error);
|
||||
|
||||
let (mut stream, params) = match data {
|
||||
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
|
||||
.await??
|
||||
{
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(cancel_key_data, tracker) => {
|
||||
HandshakeData::Cancel(cancel_key_data) => {
|
||||
// spawn a task to cancel the session, but don't wait for it
|
||||
tokio::spawn({
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let ctx = ctx_slot.take().expect("context must be set");
|
||||
cancellations.spawn({
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let ctx = ctx.clone();
|
||||
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
|
||||
cancel_span.follows_from(tracing::Span::current());
|
||||
async move {
|
||||
let _tracker = tracker;
|
||||
cancellation_handler_clone
|
||||
.cancel_session(
|
||||
cancel_key_data,
|
||||
@@ -184,17 +205,15 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'st
|
||||
backend.get_api(),
|
||||
)
|
||||
.await
|
||||
.inspect_err(|e| debug!(error = ?e, "cancel_session failed"))
|
||||
.ok();
|
||||
}
|
||||
.instrument(cancel_span)
|
||||
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
|
||||
}.instrument(cancel_span)
|
||||
});
|
||||
|
||||
return Ok(());
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
drop(pause);
|
||||
|
||||
let ctx = ctx_slot.as_ref().expect("context must be set");
|
||||
ctx.set_db_options(params.clone());
|
||||
|
||||
let (node_info, user_info, _ip_allowlist) = match backend
|
||||
@@ -209,13 +228,13 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'st
|
||||
|
||||
let mut node = connect_to_compute(
|
||||
ctx,
|
||||
TcpMechanism {
|
||||
&TcpMechanism {
|
||||
user_info,
|
||||
params_compat: true,
|
||||
params: ¶ms,
|
||||
locks: &config.connect_compute_locks,
|
||||
},
|
||||
node_info,
|
||||
&node_info,
|
||||
config.wake_compute_retry_config,
|
||||
&config.connect_to_compute,
|
||||
)
|
||||
@@ -233,22 +252,17 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'st
|
||||
// PqStream input buffer. Normally there is none, but our serverless npm
|
||||
// driver in pipeline mode sends startup, password and first query
|
||||
// immediately after opening the connection.
|
||||
let (stream, read_buf, tracker) = stream.into_inner();
|
||||
let (stream, read_buf) = stream.into_inner();
|
||||
node.stream.write_all(&read_buf).await?;
|
||||
|
||||
let ctx = ctx_slot.take().expect("context must be set");
|
||||
ctx.set_success();
|
||||
|
||||
tokio::spawn(passthrough(
|
||||
ctx,
|
||||
&config.connect_to_compute,
|
||||
stream,
|
||||
node,
|
||||
session,
|
||||
request_gauge,
|
||||
conn_gauge,
|
||||
tracker,
|
||||
));
|
||||
|
||||
Ok(())
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
aux: node.aux.clone(),
|
||||
private_link_id: None,
|
||||
compute: node,
|
||||
session_id: ctx.session_id(),
|
||||
cancel: session,
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ pub struct RequestContext(
|
||||
/// I would typically use a RefCell but that would break the `Send` requirements
|
||||
/// so we need something with thread-safety. `TryLock` is a cheap alternative
|
||||
/// that offers similar semantics to a `RefCell` but with synchronisation.
|
||||
TryLock<Box<RequestContextInner>>,
|
||||
TryLock<RequestContextInner>,
|
||||
);
|
||||
|
||||
struct RequestContextInner {
|
||||
@@ -89,7 +89,7 @@ pub(crate) enum AuthMethod {
|
||||
impl Clone for RequestContext {
|
||||
fn clone(&self) -> Self {
|
||||
let inner = self.0.try_lock().expect("should not deadlock");
|
||||
let new = Box::new(RequestContextInner {
|
||||
let new = RequestContextInner {
|
||||
conn_info: inner.conn_info.clone(),
|
||||
session_id: inner.session_id,
|
||||
protocol: inner.protocol,
|
||||
@@ -117,7 +117,7 @@ impl Clone for RequestContext {
|
||||
disconnect_sender: None,
|
||||
latency_timer: LatencyTimer::noop(inner.protocol),
|
||||
disconnect_timestamp: inner.disconnect_timestamp,
|
||||
});
|
||||
};
|
||||
|
||||
Self(TryLock::new(new))
|
||||
}
|
||||
@@ -140,7 +140,7 @@ impl RequestContext {
|
||||
role = tracing::field::Empty,
|
||||
);
|
||||
|
||||
let inner = Box::new(RequestContextInner {
|
||||
let inner = RequestContextInner {
|
||||
conn_info,
|
||||
session_id,
|
||||
protocol,
|
||||
@@ -168,7 +168,7 @@ impl RequestContext {
|
||||
disconnect_sender: LOG_CHAN_DISCONNECT.get().and_then(|tx| tx.upgrade()),
|
||||
latency_timer: LatencyTimer::new(protocol),
|
||||
disconnect_timestamp: None,
|
||||
});
|
||||
};
|
||||
|
||||
Self(TryLock::new(inner))
|
||||
}
|
||||
@@ -522,7 +522,7 @@ impl Drop for RequestContextInner {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DisconnectLogger(Box<RequestContextInner>);
|
||||
pub struct DisconnectLogger(RequestContextInner);
|
||||
|
||||
impl Drop for DisconnectLogger {
|
||||
fn drop(&mut self) {
|
||||
|
||||
@@ -53,25 +53,6 @@ pub(crate) trait ConnectMechanism {
|
||||
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T: ConnectMechanism + Sync> ConnectMechanism for &T {
|
||||
type Connection = T::Connection;
|
||||
type ConnectError = T::ConnectError;
|
||||
type Error = T::Error;
|
||||
async fn connect_once(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
node_info: &control_plane::CachedNodeInfo,
|
||||
config: &ComputeConfig,
|
||||
) -> Result<Self::Connection, Self::ConnectError> {
|
||||
T::connect_once(self, ctx, node_info, config).await
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, conf: &mut compute::ConnCfg) {
|
||||
T::update_connect_config(self, conf);
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait ComputeConnectBackend {
|
||||
async fn wake_compute(
|
||||
@@ -124,8 +105,8 @@ impl ConnectMechanism for TcpMechanism<'_> {
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
|
||||
ctx: &RequestContext,
|
||||
mechanism: M,
|
||||
backend: B,
|
||||
mechanism: &M,
|
||||
user_info: &B,
|
||||
wake_compute_retry_config: RetryConfig,
|
||||
compute: &ComputeConfig,
|
||||
) -> Result<M::Connection, M::Error>
|
||||
@@ -135,9 +116,9 @@ where
|
||||
{
|
||||
let mut num_retries = 0;
|
||||
let mut node_info =
|
||||
wake_compute(&mut num_retries, ctx, &backend, wake_compute_retry_config).await?;
|
||||
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
|
||||
|
||||
node_info.set_keys(backend.get_keys());
|
||||
node_info.set_keys(user_info.get_keys());
|
||||
mechanism.update_connect_config(&mut node_info.config);
|
||||
|
||||
// try once
|
||||
@@ -178,7 +159,7 @@ where
|
||||
let old_node_info = invalidate_cache(node_info);
|
||||
// TODO: increment num_retries?
|
||||
let mut node_info =
|
||||
wake_compute(&mut num_retries, ctx, &backend, wake_compute_retry_config).await?;
|
||||
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
|
||||
node_info.reuse_settings(old_node_info);
|
||||
|
||||
mechanism.update_connect_config(&mut node_info.config);
|
||||
|
||||
@@ -67,6 +67,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn copy_bidirectional_client_compute<Client, Compute>(
|
||||
client: &mut Client,
|
||||
compute: &mut Compute,
|
||||
|
||||
@@ -5,7 +5,6 @@ use pq_proto::{
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::auth::endpoint_sni;
|
||||
@@ -52,7 +51,7 @@ impl ReportableError for HandshakeError {
|
||||
|
||||
pub(crate) enum HandshakeData<S> {
|
||||
Startup(PqStream<Stream<S>>, StartupMessageParams),
|
||||
Cancel(CancelKeyData, TaskTrackerToken),
|
||||
Cancel(CancelKeyData),
|
||||
}
|
||||
|
||||
/// Establish a (most probably, secure) connection with the client.
|
||||
@@ -63,7 +62,6 @@ pub(crate) enum HandshakeData<S> {
|
||||
pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
ctx: &RequestContext,
|
||||
stream: S,
|
||||
tracker: TaskTrackerToken,
|
||||
mut tls: Option<&TlsConfig>,
|
||||
record_handshake_error: bool,
|
||||
) -> Result<HandshakeData<S>, HandshakeError> {
|
||||
@@ -73,7 +71,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
|
||||
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
|
||||
|
||||
let mut stream = PqStream::new(Stream::from_raw(stream), tracker);
|
||||
let mut stream = PqStream::new(Stream::from_raw(stream));
|
||||
loop {
|
||||
let msg = stream.read_startup_packet().await?;
|
||||
match msg {
|
||||
@@ -159,13 +157,15 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
let (_, tls_server_end_point) =
|
||||
tls.cert_resolver.resolve(conn_info.server_name());
|
||||
|
||||
stream.framed = Framed {
|
||||
stream: Stream::Tls {
|
||||
tls: Box::new(tls_stream),
|
||||
tls_server_end_point,
|
||||
stream = PqStream {
|
||||
framed: Framed {
|
||||
stream: Stream::Tls {
|
||||
tls: Box::new(tls_stream),
|
||||
tls_server_end_point,
|
||||
},
|
||||
read_buf,
|
||||
write_buf,
|
||||
},
|
||||
read_buf,
|
||||
write_buf,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -248,7 +248,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
FeStartupPacket::CancelRequest(cancel_key_data) => {
|
||||
info!(session_type = "cancellation", "successful handshake");
|
||||
break Ok(HandshakeData::Cancel(cancel_key_data, stream.tracker));
|
||||
break Ok(HandshakeData::Cancel(cancel_key_data));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,27 +10,26 @@ pub(crate) mod wake_compute;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
|
||||
use futures::TryFutureExt;
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use itertools::Itertools;
|
||||
use once_cell::sync::OnceCell;
|
||||
use passthrough::passthrough;
|
||||
use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams};
|
||||
use regex::Regex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol_str::{SmolStr, format_smolstr};
|
||||
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::{Instrument, debug, error, info, warn};
|
||||
|
||||
use self::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
use self::passthrough::ProxyPassthrough;
|
||||
use crate::cancellation::{self, CancellationHandler};
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
|
||||
use crate::proxy::handshake::{HandshakeData, handshake};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::stream::{PqStream, Stream};
|
||||
@@ -71,6 +70,7 @@ pub async fn task_main(
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
|
||||
while let Some(accept_result) =
|
||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||
@@ -84,12 +84,12 @@ pub async fn task_main(
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
let cancellations = cancellations.clone();
|
||||
|
||||
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
|
||||
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
|
||||
|
||||
let tracker = connections.token();
|
||||
tokio::spawn(async move {
|
||||
connections.spawn(async move {
|
||||
let (socket, conn_info) = match read_proxy_protocol(socket).await {
|
||||
Err(e) => {
|
||||
warn!("per-client task finished with an error: {e:#}");
|
||||
@@ -138,41 +138,60 @@ pub async fn task_main(
|
||||
crate::metrics::Protocol::Tcp,
|
||||
&config.region,
|
||||
);
|
||||
let span = ctx.span();
|
||||
let mut ctx = Some(ctx);
|
||||
|
||||
let res = handle_client(
|
||||
config,
|
||||
auth_backend,
|
||||
&mut ctx,
|
||||
&ctx,
|
||||
cancellation_handler,
|
||||
socket,
|
||||
ClientMode::Tcp,
|
||||
endpoint_rate_limiter2,
|
||||
conn_gauge,
|
||||
tracker,
|
||||
cancellations,
|
||||
)
|
||||
.instrument(span)
|
||||
.instrument(ctx.span())
|
||||
.boxed()
|
||||
.await;
|
||||
|
||||
match (ctx, res) {
|
||||
(None, _) => {}
|
||||
(Some(ctx), Ok(())) => {
|
||||
ctx.success();
|
||||
}
|
||||
(Some(ctx), Err(e)) => {
|
||||
match res {
|
||||
Err(e) => {
|
||||
ctx.set_error_kind(e.get_error_kind());
|
||||
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
|
||||
}
|
||||
Ok(None) => {
|
||||
ctx.set_success();
|
||||
}
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
let _disconnect = ctx.log_connect();
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
warn!(
|
||||
?session_id,
|
||||
"per-client task finished with an IO error from the client: {e:#}"
|
||||
);
|
||||
}
|
||||
Err(ErrorSource::Compute(e)) => {
|
||||
error!(
|
||||
?session_id,
|
||||
"per-client task finished with an IO error from the compute: {e:#}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
connections.close();
|
||||
cancellations.close();
|
||||
drop(listener);
|
||||
|
||||
// Drain connections
|
||||
connections.wait().await;
|
||||
cancellations.wait().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -239,79 +258,46 @@ impl ReportableError for ClientRequestError {
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static auth::Backend<'static, ()>,
|
||||
ctx_slot: &mut Option<RequestContext>,
|
||||
ctx: &RequestContext,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
stream: S,
|
||||
mode: ClientMode,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
conn_gauge: NumClientConnectionsGuard<'static>,
|
||||
tracker: TaskTrackerToken,
|
||||
) -> Result<(), ClientRequestError> {
|
||||
let cplane = match auth_backend {
|
||||
auth::Backend::ControlPlane(cplane, ()) => &**cplane,
|
||||
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
|
||||
};
|
||||
|
||||
let protocol = ctx_slot.as_ref().expect("context must be set").protocol();
|
||||
debug!(%protocol, "handling interactive connection from client");
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
|
||||
debug!(
|
||||
protocol = %ctx.protocol(),
|
||||
"handling interactive connection from client"
|
||||
);
|
||||
|
||||
let metrics = &Metrics::get().proxy;
|
||||
let request_gauge = metrics.connection_requests.guard(protocol);
|
||||
let proto = ctx.protocol();
|
||||
let request_gauge = metrics.connection_requests.guard(proto);
|
||||
|
||||
let handshake_result: Result<_, ClientRequestError> = async {
|
||||
let tls = config.tls_config.load();
|
||||
let tls = tls.as_deref();
|
||||
let tls = config.tls_config.load();
|
||||
let tls = tls.as_deref();
|
||||
|
||||
let ctx = ctx_slot.as_ref().expect("context must be set");
|
||||
let record_handshake_error = !ctx.has_private_peer_addr();
|
||||
let data = {
|
||||
let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
tokio::time::timeout(
|
||||
config.handshake_timeout,
|
||||
handshake(
|
||||
ctx,
|
||||
stream,
|
||||
tracker,
|
||||
mode.handshake_tls(tls),
|
||||
record_handshake_error,
|
||||
),
|
||||
)
|
||||
.await??
|
||||
};
|
||||
let record_handshake_error = !ctx.has_private_peer_addr();
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
|
||||
|
||||
match data {
|
||||
HandshakeData::Startup(mut stream, params) => {
|
||||
ctx.set_db_options(params.clone());
|
||||
|
||||
let host = mode.hostname(stream.get_ref());
|
||||
let cn = tls.map(|tls| &tls.common_names);
|
||||
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let result = auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, host, cn);
|
||||
let user_info = match result {
|
||||
Ok(user_info) => user_info,
|
||||
Err(e) => stream.throw_error(e, Some(ctx)).await?,
|
||||
};
|
||||
|
||||
let session = cancellation_handler.get_key();
|
||||
Ok(Some((stream, params, session, user_info)))
|
||||
}
|
||||
HandshakeData::Cancel(cancel_key_data, tracker) => {
|
||||
let ctx = ctx_slot.take().expect("context must be set");
|
||||
ctx.set_success();
|
||||
|
||||
let cancel_span = tracing::info_span!(parent: None, "cancel_session", session_id = ?ctx.session_id());
|
||||
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
|
||||
.await??
|
||||
{
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(cancel_key_data) => {
|
||||
// spawn a task to cancel the session, but don't wait for it
|
||||
cancellations.spawn({
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let ctx = ctx.clone();
|
||||
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
|
||||
cancel_span.follows_from(tracing::Span::current());
|
||||
|
||||
// spawn a task to cancel the session, but don't wait for it
|
||||
tokio::spawn(async move {
|
||||
// ensure the proxy doesn't shutdown until we complete this task.
|
||||
let _tracker = tracker;
|
||||
|
||||
cancellation_handler
|
||||
async move {
|
||||
cancellation_handler_clone
|
||||
.cancel_session(
|
||||
cancel_key_data,
|
||||
ctx,
|
||||
@@ -319,108 +305,111 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'st
|
||||
config.authentication_config.is_vpc_acccess_proxy,
|
||||
auth_backend.get_api(),
|
||||
)
|
||||
.instrument(cancel_span)
|
||||
.await
|
||||
.unwrap_or_else(|e| debug!(error = ?e, "cancel_session failed"));
|
||||
});
|
||||
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
|
||||
}.instrument(cancel_span)
|
||||
});
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
.await;
|
||||
|
||||
let Some((mut stream, params, session, user_info)) = handshake_result? else {
|
||||
return Ok(());
|
||||
};
|
||||
let ctx = ctx_slot.as_ref().expect("context must be set");
|
||||
drop(pause);
|
||||
|
||||
let auth_result: Result<_, ClientRequestError> = async {
|
||||
let user = user_info.user.clone();
|
||||
ctx.set_db_options(params.clone());
|
||||
|
||||
match cplane
|
||||
.authenticate(
|
||||
ctx,
|
||||
&mut stream,
|
||||
user_info,
|
||||
mode.allow_cleartext(),
|
||||
&config.authentication_config,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(auth_result) => Ok(auth_result),
|
||||
Err(e) => {
|
||||
let db = params.get("database");
|
||||
let app = params.get("application_name");
|
||||
let params_span = tracing::info_span!("", ?user, ?db, ?app);
|
||||
stream
|
||||
.throw_error(e, Some(ctx))
|
||||
.instrument(params_span)
|
||||
.await?
|
||||
}
|
||||
}
|
||||
}
|
||||
.await;
|
||||
let hostname = mode.hostname(stream.get_ref());
|
||||
|
||||
let compute_creds = auth_result?;
|
||||
let common_names = tls.map(|tls| &tls.common_names);
|
||||
|
||||
let connect_result: Result<_, ClientRequestError> = async {
|
||||
let compute_user_info = compute_creds.info.clone();
|
||||
let params_compat = compute_user_info
|
||||
.options
|
||||
.get(NeonOptions::PARAMS_COMPAT)
|
||||
.is_some();
|
||||
// Extract credentials which we're going to use for auth.
|
||||
let result = auth_backend
|
||||
.as_ref()
|
||||
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, ¶ms, hostname, common_names))
|
||||
.transpose();
|
||||
|
||||
let mut node = connect_to_compute(
|
||||
let user_info = match result {
|
||||
Ok(user_info) => user_info,
|
||||
Err(e) => stream.throw_error(e, Some(ctx)).await?,
|
||||
};
|
||||
|
||||
let user = user_info.get_user().to_owned();
|
||||
let (user_info, _ip_allowlist) = match user_info
|
||||
.authenticate(
|
||||
ctx,
|
||||
TcpMechanism {
|
||||
user_info: compute_user_info,
|
||||
params_compat,
|
||||
params: ¶ms,
|
||||
locks: &config.connect_compute_locks,
|
||||
},
|
||||
auth::ControlPlaneWakeCompute {
|
||||
cplane,
|
||||
creds: compute_creds,
|
||||
},
|
||||
config.wake_compute_retry_config,
|
||||
&config.connect_to_compute,
|
||||
&mut stream,
|
||||
mode.allow_cleartext(),
|
||||
&config.authentication_config,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
.or_else(|e| stream.throw_error(e, Some(ctx)))
|
||||
.await?;
|
||||
.await
|
||||
{
|
||||
Ok(auth_result) => auth_result,
|
||||
Err(e) => {
|
||||
let db = params.get("database");
|
||||
let app = params.get("application_name");
|
||||
let params_span = tracing::info_span!("", ?user, ?db, ?app);
|
||||
|
||||
session.write_cancel_key(node.cancel_closure.clone())?;
|
||||
prepare_client_connection(&node, *session.key(), &mut stream).await?;
|
||||
return stream
|
||||
.throw_error(e, Some(ctx))
|
||||
.instrument(params_span)
|
||||
.await?;
|
||||
}
|
||||
};
|
||||
|
||||
// Before proxy passing, forward to compute whatever data is left in the
|
||||
// PqStream input buffer. Normally there is none, but our serverless npm
|
||||
// driver in pipeline mode sends startup, password and first query
|
||||
// immediately after opening the connection.
|
||||
let (stream, read_buf, tracker) = stream.into_inner();
|
||||
node.stream.write_all(&read_buf).await?;
|
||||
let compute_user_info = match &user_info {
|
||||
auth::Backend::ControlPlane(_, info) => &info.info,
|
||||
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
|
||||
};
|
||||
let params_compat = compute_user_info
|
||||
.options
|
||||
.get(NeonOptions::PARAMS_COMPAT)
|
||||
.is_some();
|
||||
|
||||
Ok((node, stream, tracker))
|
||||
}
|
||||
.await;
|
||||
|
||||
let (node, stream, tracker) = connect_result?;
|
||||
|
||||
let ctx = ctx_slot.take().expect("context must be set");
|
||||
ctx.set_success();
|
||||
|
||||
tokio::spawn(passthrough(
|
||||
let mut node = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
user_info: compute_user_info.clone(),
|
||||
params_compat,
|
||||
params: ¶ms,
|
||||
locks: &config.connect_compute_locks,
|
||||
},
|
||||
&user_info,
|
||||
config.wake_compute_retry_config,
|
||||
&config.connect_to_compute,
|
||||
stream,
|
||||
node,
|
||||
session,
|
||||
request_gauge,
|
||||
conn_gauge,
|
||||
tracker,
|
||||
));
|
||||
)
|
||||
.or_else(|e| stream.throw_error(e, Some(ctx)))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
|
||||
let session = cancellation_handler_clone.get_key();
|
||||
|
||||
session.write_cancel_key(node.cancel_closure.clone())?;
|
||||
|
||||
prepare_client_connection(&node, *session.key(), &mut stream).await?;
|
||||
|
||||
// Before proxy passing, forward to compute whatever data is left in the
|
||||
// PqStream input buffer. Normally there is none, but our serverless npm
|
||||
// driver in pipeline mode sends startup, password and first query
|
||||
// immediately after opening the connection.
|
||||
let (stream, read_buf) = stream.into_inner();
|
||||
node.stream.write_all(&read_buf).await?;
|
||||
|
||||
let private_link_id = match ctx.extra() {
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
|
||||
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
|
||||
None => None,
|
||||
};
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
aux: node.aux.clone(),
|
||||
private_link_id,
|
||||
compute: node,
|
||||
session_id: ctx.session_id(),
|
||||
cancel: session,
|
||||
_req: request_gauge,
|
||||
_conn: conn_gauge,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Finish client connection initialization: confirm auth success, send params, etc.
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use smol_str::{SmolStr, ToSmolStr};
|
||||
use smol_str::SmolStr;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::debug;
|
||||
use utils::measured_stream::MeasuredStream;
|
||||
|
||||
@@ -8,14 +7,13 @@ use super::copy_bidirectional::ErrorSource;
|
||||
use crate::cancellation;
|
||||
use crate::compute::PostgresConnection;
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
|
||||
use crate::protocol2::ConnectionInfoExtra;
|
||||
use crate::stream::Stream;
|
||||
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
|
||||
|
||||
/// Forward bytes in both directions (client <-> compute).
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(crate) async fn proxy_pass(
|
||||
client: impl AsyncRead + AsyncWrite + Unpin,
|
||||
compute: impl AsyncRead + AsyncWrite + Unpin,
|
||||
@@ -63,53 +61,41 @@ pub(crate) async fn proxy_pass(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn passthrough<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
|
||||
ctx: RequestContext,
|
||||
compute_config: &'static ComputeConfig,
|
||||
pub(crate) struct ProxyPassthrough<S> {
|
||||
pub(crate) client: Stream<S>,
|
||||
pub(crate) compute: PostgresConnection,
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
pub(crate) session_id: uuid::Uuid,
|
||||
pub(crate) private_link_id: Option<SmolStr>,
|
||||
pub(crate) cancel: cancellation::Session,
|
||||
|
||||
client: Stream<S>,
|
||||
compute: PostgresConnection,
|
||||
cancel: cancellation::Session,
|
||||
|
||||
_req: NumConnectionRequestsGuard<'static>,
|
||||
_conn: NumClientConnectionsGuard<'static>,
|
||||
_tracker: TaskTrackerToken,
|
||||
) {
|
||||
let session_id = ctx.session_id();
|
||||
let private_link_id = match ctx.extra() {
|
||||
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
|
||||
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let _disconnect = ctx.log_connect();
|
||||
let res = proxy_pass(client, compute.stream, compute.aux, private_link_id).await;
|
||||
|
||||
match res {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
tracing::warn!(
|
||||
session_id = ?session_id,
|
||||
"per-client task finished with an IO error from the client: {e:#}"
|
||||
);
|
||||
}
|
||||
Err(ErrorSource::Compute(e)) => {
|
||||
tracing::error!(
|
||||
session_id = ?session_id,
|
||||
"per-client task finished with an IO error from the compute: {e:#}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(err) = compute
|
||||
.cancel_closure
|
||||
.try_cancel_query(compute_config)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(session_id = ?session_id, ?err, "could not cancel the query in the database");
|
||||
}
|
||||
|
||||
// we don't need a result. If the queue is full, we just log the error
|
||||
drop(cancel.remove_cancel_key());
|
||||
pub(crate) _req: NumConnectionRequestsGuard<'static>,
|
||||
pub(crate) _conn: NumClientConnectionsGuard<'static>,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
|
||||
pub(crate) async fn proxy_pass(
|
||||
self,
|
||||
compute_config: &ComputeConfig,
|
||||
) -> Result<(), ErrorSource> {
|
||||
let res = proxy_pass(
|
||||
self.client,
|
||||
self.compute.stream,
|
||||
self.aux,
|
||||
self.private_link_id,
|
||||
)
|
||||
.await;
|
||||
if let Err(err) = self
|
||||
.compute
|
||||
.cancel_closure
|
||||
.try_cancel_query(compute_config)
|
||||
.await
|
||||
{
|
||||
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
|
||||
}
|
||||
|
||||
drop(self.cancel.remove_cancel_key()); // we don't need a result. If the queue is full, we just log the error
|
||||
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,7 +38,6 @@ async fn proxy_mitm(
|
||||
let (end_client, startup) = match handshake(
|
||||
&RequestContext::test(),
|
||||
client1,
|
||||
TaskTracker::new().token(),
|
||||
Some(&server_config1),
|
||||
false,
|
||||
)
|
||||
@@ -46,7 +45,7 @@ async fn proxy_mitm(
|
||||
.unwrap()
|
||||
{
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(_, _) => panic!("cancellation not supported"),
|
||||
HandshakeData::Cancel(_) => panic!("cancellation not supported"),
|
||||
};
|
||||
|
||||
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);
|
||||
|
||||
@@ -15,7 +15,6 @@ use rstest::rstest;
|
||||
use rustls::crypto::ring;
|
||||
use rustls::pki_types;
|
||||
use tokio::io::DuplexStream;
|
||||
use tokio_util::task::TaskTracker;
|
||||
use tracing_test::traced_test;
|
||||
|
||||
use super::connect_compute::ConnectMechanism;
|
||||
@@ -179,12 +178,10 @@ async fn dummy_proxy(
|
||||
auth: impl TestAuth + Send,
|
||||
) -> anyhow::Result<()> {
|
||||
let (client, _) = read_proxy_protocol(client).await?;
|
||||
let t = TaskTracker::new().token();
|
||||
let mut stream =
|
||||
match handshake(&RequestContext::test(), client, t, tls.as_ref(), false).await? {
|
||||
HandshakeData::Startup(stream, _) => stream,
|
||||
HandshakeData::Cancel(_, _) => bail!("cancellation not supported"),
|
||||
};
|
||||
let mut stream = match handshake(&RequestContext::test(), client, tls.as_ref(), false).await? {
|
||||
HandshakeData::Startup(stream, _) => stream,
|
||||
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
|
||||
};
|
||||
|
||||
auth.authenticate(&mut stream).await?;
|
||||
|
||||
@@ -625,7 +622,7 @@ async fn connect_to_compute_success() {
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = config();
|
||||
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -639,7 +636,7 @@ async fn connect_to_compute_retry() {
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = config();
|
||||
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -654,7 +651,7 @@ async fn connect_to_compute_non_retry_1() {
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = config();
|
||||
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
@@ -669,7 +666,7 @@ async fn connect_to_compute_non_retry_2() {
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = config();
|
||||
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -694,7 +691,7 @@ async fn connect_to_compute_non_retry_3() {
|
||||
connect_to_compute(
|
||||
&ctx,
|
||||
&mechanism,
|
||||
user_info,
|
||||
&user_info,
|
||||
wake_compute_retry_config,
|
||||
&config,
|
||||
)
|
||||
@@ -712,7 +709,7 @@ async fn wake_retry() {
|
||||
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = config();
|
||||
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -727,7 +724,7 @@ async fn wake_non_retry() {
|
||||
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let config = config();
|
||||
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
@@ -746,7 +743,7 @@ async fn fail_but_wake_invalidates_cache() {
|
||||
let user = helper_create_connect_info(&mech);
|
||||
let cfg = config();
|
||||
|
||||
connect_to_compute(&ctx, &mech, user, cfg.retry, &cfg)
|
||||
connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -767,7 +764,7 @@ async fn fail_no_wake_skips_cache_invalidation() {
|
||||
let user = helper_create_connect_info(&mech);
|
||||
let cfg = config();
|
||||
|
||||
connect_to_compute(&ctx, &mech, user, cfg.retry, &cfg)
|
||||
connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -788,7 +785,7 @@ async fn retry_but_wake_invalidates_cache() {
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let cfg = config();
|
||||
|
||||
connect_to_compute(&ctx, &mechanism, user_info, cfg.retry, &cfg)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -811,7 +808,7 @@ async fn retry_no_wake_skips_invalidation() {
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
let cfg = config();
|
||||
|
||||
connect_to_compute(&ctx, &mechanism, user_info, cfg.retry, &cfg)
|
||||
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
|
||||
@@ -224,13 +224,13 @@ impl PoolingBackend {
|
||||
let backend = self.auth_backend.as_ref().map(|()| keys);
|
||||
crate::proxy::connect_compute::connect_to_compute(
|
||||
ctx,
|
||||
TokioMechanism {
|
||||
&TokioMechanism {
|
||||
conn_id,
|
||||
conn_info,
|
||||
pool: self.pool.clone(),
|
||||
locks: &self.config.connect_compute_locks,
|
||||
},
|
||||
backend,
|
||||
&backend,
|
||||
self.config.wake_compute_retry_config,
|
||||
&self.config.connect_to_compute,
|
||||
)
|
||||
@@ -268,13 +268,13 @@ impl PoolingBackend {
|
||||
});
|
||||
crate::proxy::connect_compute::connect_to_compute(
|
||||
ctx,
|
||||
HyperMechanism {
|
||||
&HyperMechanism {
|
||||
conn_id,
|
||||
conn_info,
|
||||
pool: self.http_conn_pool.clone(),
|
||||
locks: &self.config.connect_compute_locks,
|
||||
},
|
||||
backend,
|
||||
&backend,
|
||||
self.config.wake_compute_retry_config,
|
||||
&self.config.connect_to_compute,
|
||||
)
|
||||
|
||||
@@ -41,7 +41,7 @@ use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::time::timeout;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tokio_util::task::TaskTracker;
|
||||
use tracing::{Instrument, info, warn};
|
||||
|
||||
use crate::cancellation::CancellationHandler;
|
||||
@@ -124,6 +124,7 @@ pub async fn task_main(
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
connections.close(); // allows `connections.wait to complete`
|
||||
|
||||
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
|
||||
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
|
||||
if let Err(e) = conn.set_nodelay(true) {
|
||||
@@ -149,11 +150,11 @@ pub async fn task_main(
|
||||
let conn_token = cancellation_token.child_token();
|
||||
let tls_acceptor = tls_acceptor.clone();
|
||||
let backend = backend.clone();
|
||||
let connections2 = connections.clone();
|
||||
let cancellation_handler = cancellation_handler.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
|
||||
let tracker = connections.token();
|
||||
tokio::spawn(
|
||||
let cancellations = cancellations.clone();
|
||||
connections.spawn(
|
||||
async move {
|
||||
let conn_token2 = conn_token.clone();
|
||||
let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
|
||||
@@ -180,7 +181,8 @@ pub async fn task_main(
|
||||
Box::pin(connection_handler(
|
||||
config,
|
||||
backend,
|
||||
tracker,
|
||||
connections2,
|
||||
cancellations,
|
||||
cancellation_handler,
|
||||
endpoint_rate_limiter,
|
||||
conn_token,
|
||||
@@ -303,7 +305,8 @@ async fn connection_startup(
|
||||
async fn connection_handler(
|
||||
config: &'static ProxyConfig,
|
||||
backend: Arc<PoolingBackend>,
|
||||
tracker: TaskTrackerToken,
|
||||
connections: TaskTracker,
|
||||
cancellations: TaskTracker,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellation_token: CancellationToken,
|
||||
@@ -344,17 +347,19 @@ async fn connection_handler(
|
||||
|
||||
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
|
||||
// By spawning the future, we ensure it never gets cancelled until it decides to.
|
||||
let handler = tokio::spawn(
|
||||
let cancellations = cancellations.clone();
|
||||
let handler = connections.spawn(
|
||||
request_handler(
|
||||
req,
|
||||
config,
|
||||
backend.clone(),
|
||||
tracker.clone(),
|
||||
connections.clone(),
|
||||
cancellation_handler.clone(),
|
||||
session_id,
|
||||
conn_info2.clone(),
|
||||
http_request_token,
|
||||
endpoint_rate_limiter.clone(),
|
||||
cancellations,
|
||||
)
|
||||
.in_current_span()
|
||||
.map_ok_or_else(api_error_into_response, |r| r),
|
||||
@@ -395,13 +400,14 @@ async fn request_handler(
|
||||
mut request: hyper::Request<Incoming>,
|
||||
config: &'static ProxyConfig,
|
||||
backend: Arc<PoolingBackend>,
|
||||
tracker: TaskTrackerToken,
|
||||
ws_connections: TaskTracker,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
session_id: uuid::Uuid,
|
||||
conn_info: ConnectionInfo,
|
||||
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
||||
http_cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellations: TaskTracker,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
|
||||
let host = request
|
||||
.headers()
|
||||
@@ -435,17 +441,10 @@ async fn request_handler(
|
||||
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
|
||||
.map_err(|e| ApiError::BadRequest(e.into()))?;
|
||||
|
||||
tokio::spawn(
|
||||
let cancellations = cancellations.clone();
|
||||
ws_connections.spawn(
|
||||
async move {
|
||||
let websocket = match websocket.await {
|
||||
Err(e) => {
|
||||
warn!("could not upgrade websocket connection: {e:#}");
|
||||
return;
|
||||
}
|
||||
Ok(websocket) => websocket,
|
||||
};
|
||||
|
||||
websocket::serve_websocket(
|
||||
if let Err(e) = websocket::serve_websocket(
|
||||
config,
|
||||
backend.auth_backend,
|
||||
ctx,
|
||||
@@ -453,9 +452,12 @@ async fn request_handler(
|
||||
cancellation_handler,
|
||||
endpoint_rate_limiter,
|
||||
host,
|
||||
tracker,
|
||||
cancellations,
|
||||
)
|
||||
.await;
|
||||
.await
|
||||
{
|
||||
warn!("error in websocket connection: {e:#}");
|
||||
}
|
||||
}
|
||||
.instrument(span),
|
||||
);
|
||||
|
||||
@@ -2,14 +2,14 @@ use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll, ready};
|
||||
|
||||
use anyhow::Context as _;
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use framed_websockets::{Frame, OpCode, WebSocketServer};
|
||||
use futures::{Sink, Stream};
|
||||
use hyper::upgrade::Upgraded;
|
||||
use hyper::upgrade::OnUpgrade;
|
||||
use hyper_util::rt::TokioIo;
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::cancellation::CancellationHandler;
|
||||
@@ -17,7 +17,7 @@ use crate::config::ProxyConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::proxy::{ClientMode, handle_client};
|
||||
use crate::proxy::{ClientMode, ErrorSource, handle_client};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
|
||||
pin_project! {
|
||||
@@ -128,12 +128,13 @@ pub(crate) async fn serve_websocket(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static crate::auth::Backend<'static, ()>,
|
||||
ctx: RequestContext,
|
||||
websocket: Upgraded,
|
||||
websocket: OnUpgrade,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
hostname: Option<String>,
|
||||
tracker: TaskTrackerToken,
|
||||
) {
|
||||
cancellations: tokio_util::task::task_tracker::TaskTracker,
|
||||
) -> anyhow::Result<()> {
|
||||
let websocket = websocket.await?;
|
||||
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
|
||||
|
||||
let conn_gauge = Metrics::get()
|
||||
@@ -141,28 +142,36 @@ pub(crate) async fn serve_websocket(
|
||||
.client_connections
|
||||
.guard(crate::metrics::Protocol::Ws);
|
||||
|
||||
let mut ctx_slot = Some(ctx);
|
||||
let res = handle_client(
|
||||
let res = Box::pin(handle_client(
|
||||
config,
|
||||
auth_backend,
|
||||
&mut ctx_slot,
|
||||
&ctx,
|
||||
cancellation_handler,
|
||||
WebSocketRw::new(websocket),
|
||||
ClientMode::Websockets { hostname },
|
||||
endpoint_rate_limiter,
|
||||
conn_gauge,
|
||||
tracker,
|
||||
)
|
||||
cancellations,
|
||||
))
|
||||
.await;
|
||||
|
||||
match (ctx_slot, res) {
|
||||
(None, _) => {}
|
||||
(Some(ctx), Err(e)) => {
|
||||
match res {
|
||||
Err(e) => {
|
||||
ctx.set_error_kind(e.get_error_kind());
|
||||
tracing::warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
|
||||
Err(e.into())
|
||||
}
|
||||
(Some(ctx), Ok(())) => {
|
||||
Ok(None) => {
|
||||
ctx.set_success();
|
||||
Ok(())
|
||||
}
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
ctx.log_connect();
|
||||
match p.proxy_pass(&config.connect_to_compute).await {
|
||||
Ok(()) => Ok(()),
|
||||
Err(ErrorSource::Client(err)) => Err(err).context("client"),
|
||||
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio_rustls::server::TlsStream;
|
||||
use tokio_util::task::task_tracker::TaskTrackerToken;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::control_plane::messages::ColdStartInfo;
|
||||
@@ -25,22 +24,19 @@ use crate::tls::TlsServerEndPoint;
|
||||
/// to pass random malformed bytes through the connection).
|
||||
pub struct PqStream<S> {
|
||||
pub(crate) framed: Framed<S>,
|
||||
pub(crate) tracker: TaskTrackerToken,
|
||||
}
|
||||
|
||||
impl<S> PqStream<S> {
|
||||
/// Construct a new libpq protocol wrapper.
|
||||
pub fn new(stream: S, tracker: TaskTrackerToken) -> Self {
|
||||
pub fn new(stream: S) -> Self {
|
||||
Self {
|
||||
framed: Framed::new(stream),
|
||||
tracker,
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the underlying stream and read buffer.
|
||||
pub fn into_inner(self) -> (S, BytesMut, TaskTrackerToken) {
|
||||
let (stream, read) = self.framed.into_inner();
|
||||
(stream, read, self.tracker)
|
||||
pub fn into_inner(self) -> (S, BytesMut) {
|
||||
self.framed.into_inner()
|
||||
}
|
||||
|
||||
/// Get a shared reference to the underlying stream.
|
||||
|
||||
@@ -139,6 +139,14 @@ pub(crate) struct StorageControllerMetricGroup {
|
||||
/// HTTP request status counters for handled requests
|
||||
pub(crate) storage_controller_reconcile_long_running:
|
||||
measured::CounterVec<ReconcileLongRunningLabelGroupSet>,
|
||||
|
||||
/// Indicator of safekeeper reconciler queue depth, broken down by safekeeper, excluding ongoing reconciles.
|
||||
pub(crate) storage_controller_safkeeper_reconciles_queued:
|
||||
measured::GaugeVec<SafekeeperReconcilerLabelGroupSet>,
|
||||
|
||||
/// Indicator of completed safekeeper reconciles, broken down by safekeeper.
|
||||
pub(crate) storage_controller_safkeeper_reconciles_complete:
|
||||
measured::CounterVec<SafekeeperReconcilerLabelGroupSet>,
|
||||
}
|
||||
|
||||
impl StorageControllerMetrics {
|
||||
@@ -257,6 +265,17 @@ pub(crate) enum Method {
|
||||
Other,
|
||||
}
|
||||
|
||||
#[derive(measured::LabelGroup, Clone)]
|
||||
#[label(set = SafekeeperReconcilerLabelGroupSet)]
|
||||
pub(crate) struct SafekeeperReconcilerLabelGroup<'a> {
|
||||
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
|
||||
pub(crate) sk_az: &'a str,
|
||||
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
|
||||
pub(crate) sk_node_id: &'a str,
|
||||
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
|
||||
pub(crate) sk_hostname: &'a str,
|
||||
}
|
||||
|
||||
impl From<hyper::Method> for Method {
|
||||
fn from(value: hyper::Method) -> Self {
|
||||
if value == hyper::Method::GET {
|
||||
|
||||
@@ -20,7 +20,9 @@ use utils::{
|
||||
};
|
||||
|
||||
use crate::{
|
||||
persistence::SafekeeperTimelineOpKind, safekeeper::Safekeeper,
|
||||
metrics::{METRICS_REGISTRY, SafekeeperReconcilerLabelGroup},
|
||||
persistence::SafekeeperTimelineOpKind,
|
||||
safekeeper::Safekeeper,
|
||||
safekeeper_client::SafekeeperClient,
|
||||
};
|
||||
|
||||
@@ -218,7 +220,26 @@ impl ReconcilerHandle {
|
||||
fn schedule_reconcile(&self, req: ScheduleRequest) {
|
||||
let (cancel, token_id) = self.new_token_slot(req.tenant_id, req.timeline_id);
|
||||
let hostname = req.safekeeper.skp.host.clone();
|
||||
let sk_az = req.safekeeper.skp.availability_zone_id.clone();
|
||||
let sk_node_id = req.safekeeper.get_id().to_string();
|
||||
|
||||
// We don't have direct access to the queue depth here, so increase it blindly by 1.
|
||||
// We know that putting into the queue increases the queue depth. The receiver will
|
||||
// update with the correct value once it processes the next item. To avoid races where we
|
||||
// reduce before we increase, leaving the gauge with a 1 value for a long time, we
|
||||
// increase it before putting into the queue.
|
||||
let queued_gauge = &METRICS_REGISTRY
|
||||
.metrics_group
|
||||
.storage_controller_safkeeper_reconciles_queued;
|
||||
let label_group = SafekeeperReconcilerLabelGroup {
|
||||
sk_az: &sk_az,
|
||||
sk_node_id: &sk_node_id,
|
||||
sk_hostname: &hostname,
|
||||
};
|
||||
queued_gauge.inc(label_group.clone());
|
||||
|
||||
if let Err(err) = self.tx.send((req, cancel, token_id)) {
|
||||
queued_gauge.set(label_group, 0);
|
||||
tracing::info!("scheduling request onto {hostname} returned error: {err}");
|
||||
}
|
||||
}
|
||||
@@ -283,6 +304,18 @@ impl SafekeeperReconciler {
|
||||
continue;
|
||||
}
|
||||
|
||||
let queued_gauge = &METRICS_REGISTRY
|
||||
.metrics_group
|
||||
.storage_controller_safkeeper_reconciles_queued;
|
||||
queued_gauge.set(
|
||||
SafekeeperReconcilerLabelGroup {
|
||||
sk_az: &req.safekeeper.skp.availability_zone_id,
|
||||
sk_node_id: &req.safekeeper.get_id().to_string(),
|
||||
sk_hostname: &req.safekeeper.skp.host,
|
||||
},
|
||||
self.rx.len() as i64,
|
||||
);
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
let kind = req.kind;
|
||||
let tenant_id = req.tenant_id;
|
||||
@@ -511,6 +544,16 @@ impl SafekeeperReconcilerInner {
|
||||
req.generation,
|
||||
)
|
||||
.await;
|
||||
|
||||
let complete_counter = &METRICS_REGISTRY
|
||||
.metrics_group
|
||||
.storage_controller_safkeeper_reconciles_complete;
|
||||
complete_counter.inc(SafekeeperReconcilerLabelGroup {
|
||||
sk_az: &req.safekeeper.skp.availability_zone_id,
|
||||
sk_node_id: &req.safekeeper.get_id().to_string(),
|
||||
sk_hostname: &req.safekeeper.skp.host,
|
||||
});
|
||||
|
||||
if let Err(err) = res {
|
||||
tracing::info!(
|
||||
"couldn't remove reconciliation request onto {} from persistence: {err:?}",
|
||||
|
||||
@@ -536,16 +536,14 @@ class NeonLocalCli(AbstractNeonCli):
|
||||
res.check_returncode()
|
||||
return res
|
||||
|
||||
def endpoint_generate_jwt(
|
||||
self, endpoint_id: str, scope: ComputeClaimsScope | None = None
|
||||
) -> str:
|
||||
def endpoint_generate_jwt(self, endpoint_id: str, scope: list[ComputeClaimsScope]) -> str:
|
||||
"""
|
||||
Generate a JWT for making requests to the endpoint's external HTTP
|
||||
server.
|
||||
"""
|
||||
args = ["endpoint", "generate-jwt", endpoint_id]
|
||||
if scope:
|
||||
args += ["--scope", str(scope)]
|
||||
for s in scope:
|
||||
args += ["--scope", str(s)]
|
||||
|
||||
cmd = self.raw_cli(args)
|
||||
cmd.check_returncode()
|
||||
|
||||
@@ -4282,7 +4282,7 @@ class Endpoint(PgProtocol, LogUtils):
|
||||
|
||||
self.config(config_lines)
|
||||
|
||||
self.__jwt = self.generate_jwt()
|
||||
self.__jwt = self.generate_jwt([])
|
||||
|
||||
return self
|
||||
|
||||
@@ -4329,7 +4329,7 @@ class Endpoint(PgProtocol, LogUtils):
|
||||
|
||||
return self
|
||||
|
||||
def generate_jwt(self, scope: ComputeClaimsScope | None = None) -> str:
|
||||
def generate_jwt(self, scope: list[ComputeClaimsScope]) -> str:
|
||||
"""
|
||||
Generate a JWT for making requests to the endpoint's external HTTP
|
||||
server.
|
||||
|
||||
@@ -287,6 +287,17 @@ def test_pgdata_import_smoke(
|
||||
with pytest.raises(psycopg2.errors.UndefinedTable):
|
||||
br_initdb_endpoint.safe_psql(f"select * from {workload.table}")
|
||||
|
||||
# The storage controller might be overly eager and attempt to finalize
|
||||
# the import before the task got a chance to exit.
|
||||
env.storage_controller.allowed_errors.extend(
|
||||
[
|
||||
".*Call to node.*management API.*failed.*Import task still running.*",
|
||||
]
|
||||
)
|
||||
|
||||
for ps in env.pageservers:
|
||||
ps.allowed_errors.extend([".*Error processing HTTP request.*Import task not done yet.*"])
|
||||
|
||||
|
||||
@run_only_on_default_postgres(reason="PG version is irrelevant here")
|
||||
def test_import_completion_on_restart(
|
||||
@@ -471,6 +482,17 @@ def test_import_respects_timeline_lifecycle(
|
||||
else:
|
||||
raise RuntimeError(f"{action} param not recognized")
|
||||
|
||||
# The storage controller might be overly eager and attempt to finalize
|
||||
# the import before the task got a chance to exit.
|
||||
env.storage_controller.allowed_errors.extend(
|
||||
[
|
||||
".*Call to node.*management API.*failed.*Import task still running.*",
|
||||
]
|
||||
)
|
||||
|
||||
for ps in env.pageservers:
|
||||
ps.allowed_errors.extend([".*Error processing HTTP request.*Import task not done yet.*"])
|
||||
|
||||
|
||||
@skip_in_debug_build("Validation query takes too long in debug builds")
|
||||
def test_import_chaos(
|
||||
|
||||
@@ -124,6 +124,9 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver,
|
||||
".*downloading failed, possibly for shutdown",
|
||||
# {tenant_id=... timeline_id=...}:handle_pagerequests:handle_get_page_at_lsn_request{rel=1664/0/1260 blkno=0 req_lsn=0/149F0D8}: error reading relation or page version: Not found: will not become active. Current state: Stopping\n'
|
||||
".*page_service.*will not become active.*",
|
||||
# the following errors are possible when pageserver tries to ingest wal records despite being in unreadable state
|
||||
".*wal_connection_manager.*layer file download failed: No file found.*",
|
||||
".*wal_connection_manager.*could not ingest record.*",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -156,6 +159,45 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver,
|
||||
env.pageservers[2].id: ("Detached", None),
|
||||
}
|
||||
|
||||
# Track all the attached locations with mode and generation
|
||||
history: list[tuple[int, str, int | None]] = []
|
||||
|
||||
def may_read(pageserver: NeonPageserver, mode: str, generation: int | None) -> bool:
|
||||
# Rules for when a pageserver may read:
|
||||
# - our generation is higher than any previous
|
||||
# - our generation is equal to previous, but no other pageserver
|
||||
# in that generation has been AttachedSingle (i.e. allowed to compact/GC)
|
||||
# - our generation is equal to previous, and the previous holder of this
|
||||
# generation was the same node as we're attaching now.
|
||||
#
|
||||
# If these conditions are not met, then a read _might_ work, but the pageserver might
|
||||
# also hit errors trying to download layers.
|
||||
highest_historic_generation = max([i[2] for i in history if i[2] is not None], default=None)
|
||||
|
||||
if generation is None:
|
||||
# We're not in an attached state, we may not read
|
||||
return False
|
||||
elif highest_historic_generation is not None and generation < highest_historic_generation:
|
||||
# We are in an outdated generation, we may not read
|
||||
return False
|
||||
elif highest_historic_generation is not None and generation == highest_historic_generation:
|
||||
# We are re-using a generation: if any pageserver other than this one
|
||||
# has held AttachedSingle mode, this node may not read (because some other
|
||||
# node may be doing GC/compaction).
|
||||
if any(
|
||||
i[1] == "AttachedSingle"
|
||||
and i[2] == highest_historic_generation
|
||||
and i[0] != pageserver.id
|
||||
for i in history
|
||||
):
|
||||
log.info(
|
||||
f"Skipping read on {pageserver.id} because other pageserver has been in AttachedSingle mode in generation {highest_historic_generation}"
|
||||
)
|
||||
return False
|
||||
|
||||
# Fall through: we have passed conditions for readability
|
||||
return True
|
||||
|
||||
latest_attached = env.pageservers[0].id
|
||||
|
||||
for _i in range(0, 64):
|
||||
@@ -199,9 +241,10 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver,
|
||||
assert len(tenants) == 1
|
||||
assert tenants[0]["generation"] == new_generation
|
||||
|
||||
log.info("Entering postgres...")
|
||||
workload.churn_rows(rng.randint(128, 256), pageserver.id)
|
||||
workload.validate(pageserver.id)
|
||||
if may_read(pageserver, last_state_ps[0], last_state_ps[1]):
|
||||
log.info("Entering postgres...")
|
||||
workload.churn_rows(rng.randint(128, 256), pageserver.id)
|
||||
workload.validate(pageserver.id)
|
||||
elif last_state_ps[0].startswith("Attached"):
|
||||
# The `storage_controller` will only re-attach on startup when a pageserver was the
|
||||
# holder of the latest generation: otherwise the pageserver will revert to detached
|
||||
@@ -241,18 +284,16 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver,
|
||||
location_conf["generation"] = generation
|
||||
|
||||
pageserver.tenant_location_configure(tenant_id, location_conf)
|
||||
|
||||
last_state[pageserver.id] = (mode, generation)
|
||||
|
||||
# It's only valid to connect to the last generation. Newer generations may yank layer
|
||||
# files used in older generations.
|
||||
last_generation = max(
|
||||
[s[1] for s in last_state.values() if s[1] is not None], default=None
|
||||
)
|
||||
may_read_this_generation = may_read(pageserver, mode, generation)
|
||||
history.append((pageserver.id, mode, generation))
|
||||
|
||||
if mode.startswith("Attached") and generation == last_generation:
|
||||
# This is a basic test: we are validating that he endpoint works properly _between_
|
||||
# configuration changes. A stronger test would be to validate that clients see
|
||||
# no errors while we are making the changes.
|
||||
# This is a basic test: we are validating that he endpoint works properly _between_
|
||||
# configuration changes. A stronger test would be to validate that clients see
|
||||
# no errors while we are making the changes.
|
||||
if may_read_this_generation:
|
||||
workload.churn_rows(
|
||||
rng.randint(128, 256), pageserver.id, upload=mode != "AttachedStale"
|
||||
)
|
||||
@@ -265,9 +306,16 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver,
|
||||
assert gc_summary["remote_storage_errors"] == 0
|
||||
assert gc_summary["indices_deleted"] > 0
|
||||
|
||||
# Attach all pageservers
|
||||
# Attach all pageservers, in a higher generation than any previous. We will use the same
|
||||
# gen for all, and AttachedMulti mode so that they do not interfere with one another.
|
||||
generation = env.storage_controller.attach_hook_issue(tenant_id, env.pageservers[0].id)
|
||||
for ps in env.pageservers:
|
||||
location_conf = {"mode": "AttachedMulti", "secondary_conf": None, "tenant_conf": {}}
|
||||
location_conf = {
|
||||
"mode": "AttachedMulti",
|
||||
"secondary_conf": None,
|
||||
"tenant_conf": {},
|
||||
"generation": generation,
|
||||
}
|
||||
ps.tenant_location_configure(tenant_id, location_conf)
|
||||
|
||||
# Confirm that all are readable
|
||||
|
||||
Reference in New Issue
Block a user