Compare commits


16 Commits

Author  SHA1  Message  Date
Conrad Ludgate  631139ceeb  turns out the boxing isn't necessary, we just needed to massage the stack usage properly  2025-05-30 08:47:44 +01:00
Conrad Ludgate  fd43058bd7  optimise passthrough calling convention to further reduce memory  2025-05-29 18:35:24 +01:00
Conrad Ludgate  cf07c5b5f9  dont box handle_client anymore and move spawning passthrough into handle_client so we don't need to move a heavy object in return position anymore  2025-05-29 18:20:29 +01:00
Conrad Ludgate  11bb84c38d  save 1000 bytes by removing instrument  2025-05-29 17:56:25 +01:00
Conrad Ludgate  219c72c24c  optimise proxy_pass memory size a little, also boxing requestcontext since it is large  2025-05-29 17:52:26 +01:00
Conrad Ludgate  0633cd6385  small changes to connect compute mechanism/backend handling  2025-05-29 16:21:55 +01:00
Conrad Ludgate  0cdb0c5704  reuse the same tracker token for websockets and http  2025-05-29 16:04:14 +01:00
Conrad Ludgate  eefac5d78b  box the connect to compute task  2025-05-29 15:58:28 +01:00
Conrad Ludgate  7d1c908b1b  box authenticate task  2025-05-29 15:55:17 +01:00
Conrad Ludgate  cfa2813446  remove unnecessary aux field from passthrough  2025-05-29 15:51:57 +01:00
Conrad Ludgate  034bdb1552  move more work inside handshake  2025-05-29 15:50:10 +01:00
Conrad Ludgate  8b1ffa1718  simplify cplane authentication  2025-05-29 15:46:40 +01:00
Conrad Ludgate  2d3ea77953  box the handshake task  2025-05-29 15:39:33 +01:00
Conrad Ludgate  3124729f53  spawn passthrough as a separate task to reduce influence from the handshake task  2025-05-29 15:21:54 +01:00
Conrad Ludgate  6463eb38be  manually handle task tracker tokens  2025-05-29 15:19:03 +01:00
Conrad Ludgate  ae506fd791  proxy: remove unused ip return value  2025-05-29 15:04:40 +01:00
40 changed files with 590 additions and 944 deletions
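A note on the recurring theme of these commits: boxing the handshake, authenticate, and connect-to-compute tasks, and later concluding that the boxing "isn't necessary, we just needed to massage the stack usage properly", is all about the size of async state machines. An async fn's future inlines every local held across an .await, so a large sub-future inflates every caller unless it is boxed onto the heap. A minimal standalone sketch, not code from this diff (the 16 KiB buffer is an arbitrary stand-in):

// Sketch: boxing a large sub-future keeps the parent future small, because the
// parent then stores only a heap pointer instead of the sub-future's whole state.
async fn large_subtask() -> u32 {
    // Held across an .await point, so it becomes part of this future's state.
    let buf = [0u8; 16 * 1024];
    tokio::task::yield_now().await;
    u32::from(buf[0])
}

async fn parent_inline() -> u32 {
    large_subtask().await // parent's state now embeds the 16 KiB buffer
}

async fn parent_boxed() -> u32 {
    Box::pin(large_subtask()).await // parent's state holds roughly one pointer
}

fn main() {
    // The futures are never polled; we only inspect their sizes.
    println!("inline: {} bytes", std::mem::size_of_val(&parent_inline()));
    println!("boxed:  {} bytes", std::mem::size_of_val(&parent_boxed()));
}

The trade-off the commits weigh is that boxing costs an allocation per call, whereas restructuring the code so that large values are not held across .await points shrinks the future for free.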

View File

@@ -57,6 +57,21 @@ use tracing::{error, info};
use url::Url;
use utils::failpoint_support;
// Compatibility hack: if the control plane specified any remote-ext-config
// use the default value for extension storage proxy gateway.
// Remove this once the control plane is updated to pass the gateway URL
fn parse_remote_ext_base_url(arg: &str) -> Result<String> {
const FALLBACK_PG_EXT_GATEWAY_BASE_URL: &str =
"http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local";
Ok(if arg.starts_with("http") {
arg
} else {
FALLBACK_PG_EXT_GATEWAY_BASE_URL
}
.to_owned())
}
#[derive(Parser)]
#[command(rename_all = "kebab-case")]
struct Cli {
@@ -65,7 +80,7 @@ struct Cli {
/// The base URL for the remote extension storage proxy gateway.
/// Should be in the form of `http(s)://<gateway-hostname>[:<port>]`.
#[arg(short = 'r', long, alias = "remote-ext-config")]
#[arg(short = 'r', long, value_parser = parse_remote_ext_base_url, alias = "remote-ext-config")]
pub remote_ext_base_url: Option<String>,
/// The port to bind the external listening HTTP server to. Clients running
@@ -261,4 +276,18 @@ mod test {
fn verify_cli() {
Cli::command().debug_assert()
}
#[test]
fn parse_pg_ext_gateway_base_url() {
let arg = "http://pg-ext-s3-gateway2";
let result = super::parse_remote_ext_base_url(arg).unwrap();
assert_eq!(result, arg);
let arg = "pg-ext-s3-gateway";
let result = super::parse_remote_ext_base_url(arg).unwrap();
assert_eq!(
result,
"http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local"
);
}
}

View File

@@ -339,8 +339,6 @@ async fn run_dump_restore(
destination_connstring: String,
) -> Result<(), anyhow::Error> {
let dumpdir = workdir.join("dumpdir");
let num_jobs = num_cpus::get().to_string();
info!("using {num_jobs} jobs for dump/restore");
let common_args = [
// schema mapping (prob suffices to specify them on one side)
@@ -356,7 +354,7 @@ async fn run_dump_restore(
"directory".to_string(),
// concurrency
"--jobs".to_string(),
num_jobs,
num_cpus::get().to_string(),
// progress updates
"--verbose".to_string(),
];

View File

@@ -71,7 +71,7 @@ impl AsyncAuthorizeRequest<Body> for Authorize {
// Unauthorized because when we eventually do use
// [`Validation`], we will hit the above `Err` match arm which
// returns 401 Unauthorized.
Some(ref scope) if scope.contains(&ComputeClaimsScope::Admin) => {
Some(ComputeClaimsScope::Admin) => {
let Some(ref audience) = data.claims.audience else {
return Err(JsonResponse::error(
StatusCode::UNAUTHORIZED,

View File

@@ -709,7 +709,7 @@ struct EndpointGenerateJwtCmdArgs {
endpoint_id: String,
#[clap(short = 's', long, help = "Scope to generate the JWT with", value_parser = ComputeClaimsScope::from_str)]
scope: Vec<ComputeClaimsScope>,
scope: Option<ComputeClaimsScope>,
}
#[derive(clap::Subcommand)]
@@ -1580,7 +1580,7 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
.with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?
};
let jwt = endpoint.generate_jwt(Some(args.scope.clone()))?;
let jwt = endpoint.generate_jwt(args.scope)?;
print!("{jwt}");
}

View File

@@ -632,16 +632,14 @@ impl Endpoint {
}
/// Generate a JWT with the correct claims.
pub fn generate_jwt(&self, scope: Option<Vec<ComputeClaimsScope>>) -> Result<String> {
pub fn generate_jwt(&self, scope: Option<ComputeClaimsScope>) -> Result<String> {
self.env.generate_auth_token(&ComputeClaims {
audience: match scope {
Some(ref scope) if scope.contains(&ComputeClaimsScope::Admin) => {
Some(vec![COMPUTE_AUDIENCE.to_owned()])
}
Some(ComputeClaimsScope::Admin) => Some(vec![COMPUTE_AUDIENCE.to_owned()]),
_ => None,
},
compute_id: match scope {
Some(ref scope) if scope.contains(&ComputeClaimsScope::Admin) => None,
Some(ComputeClaimsScope::Admin) => None,
_ => Some(self.endpoint_id.clone()),
},
scope,
@@ -920,7 +918,7 @@ impl Endpoint {
self.external_http_address.port()
),
)
.bearer_auth(self.generate_jwt(None::<Vec<ComputeClaimsScope>>)?)
.bearer_auth(self.generate_jwt(None::<ComputeClaimsScope>)?)
.send()
.await?;
@@ -997,7 +995,7 @@ impl Endpoint {
self.external_http_address.port()
))
.header(CONTENT_TYPE.as_str(), "application/json")
.bearer_auth(self.generate_jwt(None::<Vec<ComputeClaimsScope>>)?)
.bearer_auth(self.generate_jwt(None::<ComputeClaimsScope>)?)
.body(
serde_json::to_string(&ConfigurationRequest {
spec,

View File

@@ -41,8 +41,8 @@ pub struct ComputeClaims {
/// [`ComputeClaimsScope::Admin`].
pub compute_id: Option<String>,
/// The scopes of what the token is authorized for.
pub scope: Option<Vec<ComputeClaimsScope>>,
/// The scope of what the token authorizes.
pub scope: Option<ComputeClaimsScope>,
/// The recipient the token is intended for.
///

View File

@@ -27,7 +27,6 @@ pub use prometheus::{
pub mod launch_timestamp;
mod wrappers;
pub use prometheus;
pub use wrappers::{CountedReader, CountedWriter};
mod hll;
pub use hll::{HyperLogLog, HyperLogLogState, HyperLogLogVec};

View File

@@ -4,7 +4,6 @@ use std::{sync::Arc, time::Duration};
use arc_swap::ArcSwap;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, info_span};
use crate::{FeatureStore, PostHogClient, PostHogClientConfig};
@@ -27,35 +26,31 @@ impl FeatureResolverBackgroundLoop {
pub fn spawn(self: Arc<Self>, handle: &tokio::runtime::Handle, refresh_period: Duration) {
let this = self.clone();
let cancel = self.cancel.clone();
handle.spawn(
async move {
tracing::info!("Starting PostHog feature resolver");
let mut ticker = tokio::time::interval(refresh_period);
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tokio::select! {
_ = ticker.tick() => {}
_ = cancel.cancelled() => break
}
let resp = match this
.posthog_client
.get_feature_flags_local_evaluation()
.await
{
Ok(resp) => resp,
Err(e) => {
tracing::warn!("Cannot get feature flags: {}", e);
continue;
}
};
let feature_store = FeatureStore::new_with_flags(resp.flags);
this.feature_store.store(Arc::new(feature_store));
tracing::info!("Feature flag updated");
handle.spawn(async move {
tracing::info!("Starting PostHog feature resolver");
let mut ticker = tokio::time::interval(refresh_period);
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tokio::select! {
_ = ticker.tick() => {}
_ = cancel.cancelled() => break
}
tracing::info!("PostHog feature resolver stopped");
let resp = match this
.posthog_client
.get_feature_flags_local_evaluation()
.await
{
Ok(resp) => resp,
Err(e) => {
tracing::warn!("Cannot get feature flags: {}", e);
continue;
}
};
let feature_store = FeatureStore::new_with_flags(resp.flags);
this.feature_store.store(Arc::new(feature_store));
}
.instrument(info_span!("posthog_feature_resolver")),
);
tracing::info!("PostHog feature resolver stopped");
});
}
pub fn feature_store(&self) -> Arc<FeatureStore> {
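The hunk above is mostly a re-indentation of the PostHog refresh loop: the surrounding .instrument(info_span!("posthog_feature_resolver")) wrapper is dropped, matching the "save 1000 bytes by removing instrument" commit, which makes the interleaved before/after lines hard to read. The loop pattern itself, reduced to a standalone sketch (the flag-fetching step is only indicated by a comment, not reproduced here):

// Sketch of the refresh-loop pattern: tick periodically, skip missed ticks,
// and stop as soon as the cancellation token fires.
use std::time::Duration;

use tokio_util::sync::CancellationToken;

async fn refresh_loop(cancel: CancellationToken, refresh_period: Duration) {
    let mut ticker = tokio::time::interval(refresh_period);
    ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
    loop {
        tokio::select! {
            _ = ticker.tick() => {}
            _ = cancel.cancelled() => break,
        }
        // One refresh per tick goes here (fetching feature flags in the real
        // code); errors are logged and the loop simply continues.
    }
}

#[tokio::main]
async fn main() {
    let cancel = CancellationToken::new();
    let task = tokio::spawn(refresh_loop(cancel.clone(), Duration::from_secs(10)));
    cancel.cancel();
    task.await.unwrap();
}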

View File

@@ -448,18 +448,6 @@ impl FeatureStore {
)))
}
}
/// Infer whether a feature flag is a boolean flag by checking if it has a multivariate filter.
pub fn is_feature_flag_boolean(&self, flag_key: &str) -> Result<bool, PostHogEvaluationError> {
if let Some(flag_config) = self.flags.get(flag_key) {
Ok(flag_config.filters.multivariate.is_none())
} else {
Err(PostHogEvaluationError::NotAvailable(format!(
"Not found in the local evaluation spec: {}",
flag_key
)))
}
}
}
pub struct PostHogClientConfig {
@@ -540,15 +528,7 @@ impl PostHogClient {
.bearer_auth(&self.config.server_api_key)
.send()
.await?;
let status = response.status();
let body = response.text().await?;
if !status.is_success() {
return Err(anyhow::anyhow!(
"Failed to get feature flags: {}, {}",
status,
body
));
}
Ok(serde_json::from_str(&body)?)
}

View File

@@ -264,56 +264,10 @@ mod propagation_of_cached_label_value {
}
}
criterion_group!(histograms, histograms::bench_bucket_scalability);
mod histograms {
use std::time::Instant;
use criterion::{BenchmarkId, Criterion};
use metrics::core::Collector;
pub fn bench_bucket_scalability(c: &mut Criterion) {
let mut g = c.benchmark_group("bucket_scalability");
for n in [1, 4, 8, 16, 32, 64, 128, 256] {
g.bench_with_input(BenchmarkId::new("nbuckets", n), &n, |b, n| {
b.iter_custom(|iters| {
let buckets: Vec<f64> = (0..*n).map(|i| i as f64 * 100.0).collect();
let histo = metrics::Histogram::with_opts(
metrics::prometheus::HistogramOpts::new("name", "help")
.buckets(buckets.clone()),
)
.unwrap();
let start = Instant::now();
for i in 0..usize::try_from(iters).unwrap() {
histo.observe(buckets[i % buckets.len()]);
}
let elapsed = start.elapsed();
// self-test
let mfs = histo.collect();
assert_eq!(mfs.len(), 1);
let metrics = mfs[0].get_metric();
assert_eq!(metrics.len(), 1);
let histo = metrics[0].get_histogram();
let buckets = histo.get_bucket();
assert!(
buckets
.iter()
.enumerate()
.all(|(i, b)| b.get_cumulative_count()
>= i as u64 * (iters / buckets.len() as u64))
);
elapsed
})
});
}
}
}
criterion_main!(
label_values,
single_metric_multicore_scalability,
propagation_of_cached_label_value,
histograms,
propagation_of_cached_label_value
);
/*
@@ -336,14 +290,6 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [211.50 ns 214.44 ns
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [14.135 ns 14.147 ns 14.160 ns]
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [14.243 ns 14.255 ns 14.268 ns]
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [14.470 ns 14.682 ns 14.895 ns]
bucket_scalability/nbuckets/1 time: [30.352 ns 30.353 ns 30.354 ns]
bucket_scalability/nbuckets/4 time: [30.464 ns 30.465 ns 30.467 ns]
bucket_scalability/nbuckets/8 time: [30.569 ns 30.575 ns 30.584 ns]
bucket_scalability/nbuckets/16 time: [30.961 ns 30.965 ns 30.969 ns]
bucket_scalability/nbuckets/32 time: [35.691 ns 35.707 ns 35.722 ns]
bucket_scalability/nbuckets/64 time: [47.829 ns 47.898 ns 47.974 ns]
bucket_scalability/nbuckets/128 time: [73.479 ns 73.512 ns 73.545 ns]
bucket_scalability/nbuckets/256 time: [127.92 ns 127.94 ns 127.96 ns]
Results on an i3en.3xlarge instance
@@ -398,14 +344,6 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [434.87 ns 456.4
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [3.3767 ns 3.3974 ns 3.4220 ns]
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [3.6105 ns 4.2355 ns 5.1463 ns]
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [4.0889 ns 4.9714 ns 6.0779 ns]
bucket_scalability/nbuckets/1 time: [4.8455 ns 4.8542 ns 4.8646 ns]
bucket_scalability/nbuckets/4 time: [4.5663 ns 4.5722 ns 4.5787 ns]
bucket_scalability/nbuckets/8 time: [4.5531 ns 4.5670 ns 4.5842 ns]
bucket_scalability/nbuckets/16 time: [4.6392 ns 4.6524 ns 4.6685 ns]
bucket_scalability/nbuckets/32 time: [6.0302 ns 6.0439 ns 6.0589 ns]
bucket_scalability/nbuckets/64 time: [10.608 ns 10.644 ns 10.691 ns]
bucket_scalability/nbuckets/128 time: [22.178 ns 22.316 ns 22.483 ns]
bucket_scalability/nbuckets/256 time: [42.190 ns 42.328 ns 42.492 ns]
Results on a Hetzner AX102 AMD Ryzen 9 7950X3D 16-Core Processor
@@ -424,13 +362,5 @@ propagation_of_cached_label_value__naive/nthreads/8 time: [164.24 ns 170.1
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/1 time: [2.2915 ns 2.2960 ns 2.3012 ns]
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/4 time: [2.5726 ns 2.6158 ns 2.6624 ns]
propagation_of_cached_label_value__long_lived_reference_per_thread/nthreads/8 time: [2.7068 ns 2.8243 ns 2.9824 ns]
bucket_scalability/nbuckets/1 time: [6.3998 ns 6.4288 ns 6.4684 ns]
bucket_scalability/nbuckets/4 time: [6.3603 ns 6.3620 ns 6.3637 ns]
bucket_scalability/nbuckets/8 time: [6.1646 ns 6.1654 ns 6.1667 ns]
bucket_scalability/nbuckets/16 time: [6.1341 ns 6.1391 ns 6.1454 ns]
bucket_scalability/nbuckets/32 time: [8.2206 ns 8.2254 ns 8.2301 ns]
bucket_scalability/nbuckets/64 time: [13.988 ns 13.994 ns 14.000 ns]
bucket_scalability/nbuckets/128 time: [28.180 ns 28.216 ns 28.251 ns]
bucket_scalability/nbuckets/256 time: [54.914 ns 54.931 ns 54.951 ns]
*/

View File

@@ -91,14 +91,4 @@ impl FeatureResolver {
))
}
}
pub fn is_feature_flag_boolean(&self, flag_key: &str) -> Result<bool, PostHogEvaluationError> {
if let Some(inner) = &self.inner {
inner.feature_store().is_feature_flag_boolean(flag_key)
} else {
Err(PostHogEvaluationError::NotAvailable(
"PostHog integration is not enabled".to_string(),
))
}
}
}

View File

@@ -3663,46 +3663,6 @@ async fn read_tar_eof(mut reader: (impl tokio::io::AsyncRead + Unpin)) -> anyhow
Ok(())
}
async fn tenant_evaluate_feature_flag(
request: Request<Body>,
_cancel: CancellationToken,
) -> Result<Response<Body>, ApiError> {
let tenant_shard_id: TenantShardId = parse_request_param(&request, "tenant_shard_id")?;
check_permission(&request, Some(tenant_shard_id.tenant_id))?;
let flag: String = must_parse_query_param(&request, "flag")?;
let as_type: String = must_parse_query_param(&request, "as")?;
let state = get_state(&request);
async {
let tenant = state
.tenant_manager
.get_attached_tenant_shard(tenant_shard_id)?;
if as_type == "boolean" {
let result = tenant.feature_resolver.evaluate_boolean(&flag, tenant_shard_id.tenant_id);
let result = result.map(|_| true).map_err(|e| e.to_string());
json_response(StatusCode::OK, result)
} else if as_type == "multivariate" {
let result = tenant.feature_resolver.evaluate_multivariate(&flag, tenant_shard_id.tenant_id).map_err(|e| e.to_string());
json_response(StatusCode::OK, result)
} else {
// Auto infer the type of the feature flag.
let is_boolean = tenant.feature_resolver.is_feature_flag_boolean(&flag).map_err(|e| ApiError::InternalServerError(anyhow::anyhow!("{e}")))?;
if is_boolean {
let result = tenant.feature_resolver.evaluate_boolean(&flag, tenant_shard_id.tenant_id);
let result = result.map(|_| true).map_err(|e| e.to_string());
json_response(StatusCode::OK, result)
} else {
let result = tenant.feature_resolver.evaluate_multivariate(&flag, tenant_shard_id.tenant_id).map_err(|e| e.to_string());
json_response(StatusCode::OK, result)
}
}
}
.instrument(info_span!("tenant_evaluate_feature_flag", tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug()))
.await
}
/// Common functionality of all the HTTP API handlers.
///
/// - Adds a tracing span to each request (by `request_span`)
@@ -4079,8 +4039,5 @@ pub fn make_router(
"/v1/tenant/:tenant_shard_id/timeline/:timeline_id/activate_post_import",
|r| api_handler(r, activate_post_import_handler),
)
.get("/v1/tenant/:tenant_shard_id/feature_flag", |r| {
api_handler(r, tenant_evaluate_feature_flag)
})
.any(handler_404))
}

View File

@@ -1312,44 +1312,11 @@ impl EvictionsWithLowResidenceDuration {
//
// Roughly logarithmic scale.
const STORAGE_IO_TIME_BUCKETS: &[f64] = &[
0.00005, // 50us
0.00006, // 60us
0.00007, // 70us
0.00008, // 80us
0.00009, // 90us
0.0001, // 100us
0.000110, // 110us
0.000120, // 120us
0.000130, // 130us
0.000140, // 140us
0.000150, // 150us
0.000160, // 160us
0.000170, // 170us
0.000180, // 180us
0.000190, // 190us
0.000200, // 200us
0.000210, // 210us
0.000220, // 220us
0.000230, // 230us
0.000240, // 240us
0.000250, // 250us
0.000300, // 300us
0.000350, // 350us
0.000400, // 400us
0.000450, // 450us
0.000500, // 500us
0.000600, // 600us
0.000700, // 700us
0.000800, // 800us
0.000900, // 900us
0.001000, // 1ms
0.002000, // 2ms
0.003000, // 3ms
0.004000, // 4ms
0.005000, // 5ms
0.01000, // 10ms
0.02000, // 20ms
0.05000, // 50ms
0.000030, // 30 usec
0.001000, // 1000 usec
0.030, // 30 ms
1.000, // 1000 ms
30.000, // 30000 ms
];
/// VirtualFile fs operation variants.

View File

@@ -383,7 +383,7 @@ pub struct TenantShard {
l0_flush_global_state: L0FlushGlobalState,
pub(crate) feature_resolver: FeatureResolver,
feature_resolver: FeatureResolver,
}
impl std::fmt::Debug for TenantShard {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -5832,7 +5832,6 @@ pub(crate) mod harness {
pub conf: &'static PageServerConf,
pub tenant_conf: pageserver_api::models::TenantConfig,
pub tenant_shard_id: TenantShardId,
pub shard_identity: ShardIdentity,
pub generation: Generation,
pub shard: ShardIndex,
pub remote_storage: GenericRemoteStorage,
@@ -5900,7 +5899,6 @@ pub(crate) mod harness {
conf,
tenant_conf,
tenant_shard_id,
shard_identity,
generation,
shard,
remote_storage,
@@ -5962,7 +5960,8 @@ pub(crate) mod harness {
&ShardParameters::default(),
))
.unwrap(),
self.shard_identity,
// This is a legacy/test code path: sharding isn't supported here.
ShardIdentity::unsharded(),
Some(walredo_mgr),
self.tenant_shard_id,
self.remote_storage.clone(),
@@ -6084,7 +6083,6 @@ mod tests {
use timeline::compaction::{KeyHistoryRetention, KeyLogAtLsn};
use timeline::{CompactOptions, DeltaLayerTestDesc, VersionedKeySpaceQuery};
use utils::id::TenantId;
use utils::shard::{ShardCount, ShardNumber};
use super::*;
use crate::DEFAULT_PG_VERSION;
@@ -9420,77 +9418,6 @@ mod tests {
Ok(())
}
#[tokio::test]
async fn test_failed_flush_should_not_update_disk_consistent_lsn() -> anyhow::Result<()> {
//
// Setup
//
let harness = TenantHarness::create_custom(
"test_failed_flush_should_not_upload_disk_consistent_lsn",
pageserver_api::models::TenantConfig::default(),
TenantId::generate(),
ShardIdentity::new(ShardNumber(0), ShardCount(4), ShardStripeSize(128)).unwrap(),
Generation::new(1),
)
.await?;
let (tenant, ctx) = harness.load().await;
let timeline = tenant
.create_test_timeline(TIMELINE_ID, Lsn(0x10), DEFAULT_PG_VERSION, &ctx)
.await?;
assert_eq!(timeline.get_shard_identity().count, ShardCount(4));
let mut writer = timeline.writer().await;
writer
.put(
*TEST_KEY,
Lsn(0x20),
&Value::Image(test_img("foo at 0x20")),
&ctx,
)
.await?;
writer.finish_write(Lsn(0x20));
drop(writer);
timeline.freeze_and_flush().await.unwrap();
timeline.remote_client.wait_completion().await.unwrap();
let disk_consistent_lsn = timeline.get_disk_consistent_lsn();
let remote_consistent_lsn = timeline.get_remote_consistent_lsn_projected();
assert_eq!(Some(disk_consistent_lsn), remote_consistent_lsn);
//
// Test
//
let mut writer = timeline.writer().await;
writer
.put(
*TEST_KEY,
Lsn(0x30),
&Value::Image(test_img("foo at 0x30")),
&ctx,
)
.await?;
writer.finish_write(Lsn(0x30));
drop(writer);
fail::cfg(
"flush-layer-before-update-remote-consistent-lsn",
"return()",
)
.unwrap();
let flush_res = timeline.freeze_and_flush().await;
// if flush failed, the disk/remote consistent LSN should not be updated
assert!(flush_res.is_err());
assert_eq!(disk_consistent_lsn, timeline.get_disk_consistent_lsn());
assert_eq!(
remote_consistent_lsn,
timeline.get_remote_consistent_lsn_projected()
);
Ok(())
}
#[cfg(feature = "testing")]
#[tokio::test]
async fn test_simple_bottom_most_compaction_deltas_1() -> anyhow::Result<()> {

View File

@@ -4767,10 +4767,7 @@ impl Timeline {
|| !flushed_to_lsn.is_valid()
);
if flushed_to_lsn < frozen_to_lsn
&& self.shard_identity.count.count() > 1
&& result.is_ok()
{
if flushed_to_lsn < frozen_to_lsn && self.shard_identity.count.count() > 1 {
// If our layer flushes didn't carry disk_consistent_lsn up to the `to_lsn` advertised
// to us via layer_flush_start_rx, then advance it here.
//
@@ -4949,10 +4946,6 @@ impl Timeline {
return Err(FlushLayerError::Cancelled);
}
fail_point!("flush-layer-before-update-remote-consistent-lsn", |_| {
Err(FlushLayerError::Other(anyhow!("failpoint").into()))
});
let disk_consistent_lsn = Lsn(lsn_range.end.0 - 1);
// The new on-disk layers are now in the layer map. We can remove the

View File

@@ -11,7 +11,19 @@
//! - => S3 as the source for the PGDATA instead of local filesystem
//!
//! TODOs before productionization:
//! - ChunkProcessingJob size / ImportJob::total_size does not account for sharding.
//! => produced image layers likely too small.
//! - ChunkProcessingJob should cut up an ImportJob to hit exactly target image layer size.
//! - asserts / unwraps need to be replaced with errors
//! - don't trust remote objects will be small (=prevent OOMs in those cases)
//! - limit all in-memory buffers in size, or download to disk and read from there
//! - limit task concurrency
//! - generally play nice with other tenants in the system
//! - importbucket is different bucket than main pageserver storage, so, should be fine wrt S3 rate limits
//! - but concerns like network bandwidth, local disk write bandwidth, local disk capacity, etc
//! - integrate with layer eviction system
//! - audit for Tenant::cancel nor Timeline::cancel responsivity
//! - audit for Tenant/Timeline gate holding (we spawn tokio tasks during this flow!)
//!
//! An incomplete set of TODOs from the Hackathon:
//! - version-specific CheckPointData (=> pgv abstraction, already exists for regular walingest)
@@ -32,7 +44,7 @@ use pageserver_api::key::{
rel_dir_to_key, rel_size_to_key, relmap_file_key, slru_block_to_key, slru_dir_to_key,
slru_segment_size_to_key,
};
use pageserver_api::keyspace::{ShardedRange, singleton_range};
use pageserver_api::keyspace::{contiguous_range_len, is_contiguous_range, singleton_range};
use pageserver_api::models::{ShardImportProgress, ShardImportProgressV1, ShardImportStatus};
use pageserver_api::reltag::{RelTag, SlruKind};
use pageserver_api::shard::ShardIdentity;
@@ -155,7 +167,6 @@ impl Planner {
/// This function is and must remain pure: given the same input, it will generate the same import plan.
async fn plan(mut self, import_config: &TimelineImportConfig) -> anyhow::Result<Plan> {
let pgdata_lsn = Lsn(self.control_file.control_file_data().checkPoint).align();
anyhow::ensure!(pgdata_lsn.is_valid());
let datadir = PgDataDir::new(&self.storage).await?;
@@ -238,22 +249,14 @@ impl Planner {
});
// Assigns parts of key space to later parallel jobs
// Note: The image layers produced here may have gaps, meaning,
// there is not an image for each key in the layer's key range.
// The read path stops traversal at the first image layer, regardless
// of whether a base image has been found for a key or not.
// (Concept of sparse image layers doesn't exist.)
// This behavior is exactly right for the base image layers we're producing here.
// But, since no other place in the code currently produces image layers with gaps,
// it seems noteworthy.
let mut last_end_key = Key::MIN;
let mut current_chunk = Vec::new();
let mut current_chunk_size: usize = 0;
let mut jobs = Vec::new();
for task in std::mem::take(&mut self.tasks).into_iter() {
let task_size = task.total_size(&self.shard);
let projected_chunk_size = current_chunk_size.saturating_add(task_size);
if projected_chunk_size > import_config.import_job_soft_size_limit.into() {
if current_chunk_size + task.total_size()
> import_config.import_job_soft_size_limit.into()
{
let key_range = last_end_key..task.key_range().start;
jobs.push(ChunkProcessingJob::new(
key_range.clone(),
@@ -263,7 +266,7 @@ impl Planner {
last_end_key = key_range.end;
current_chunk_size = 0;
}
current_chunk_size = current_chunk_size.saturating_add(task_size);
current_chunk_size += task.total_size();
current_chunk.push(task);
}
jobs.push(ChunkProcessingJob::new(
@@ -601,18 +604,18 @@ impl PgDataDirDb {
};
let path = datadir_path.join(rel_tag.to_segfile_name(segno));
anyhow::ensure!(filesize % BLCKSZ as usize == 0);
assert!(filesize % BLCKSZ as usize == 0); // TODO: this should result in an error
let nblocks = filesize / BLCKSZ as usize;
Ok(PgDataDirDbFile {
PgDataDirDbFile {
path,
filesize,
rel_tag,
segno,
nblocks: Some(nblocks), // first non-cummulative sizes
})
}
})
.collect::<anyhow::Result<_, _>>()?;
.collect();
// Set cummulative sizes. Do all of that math here, so that later we could easier
// parallelize over segments and know with which segments we need to write relsize
@@ -647,22 +650,12 @@ impl PgDataDirDb {
trait ImportTask {
fn key_range(&self) -> Range<Key>;
fn total_size(&self, shard_identity: &ShardIdentity) -> usize {
let range = ShardedRange::new(self.key_range(), shard_identity);
let page_count = range.page_count();
if page_count == u32::MAX {
tracing::warn!(
"Import task has non contiguous key range: {}..{}",
self.key_range().start,
self.key_range().end
);
// Tasks should operate on contiguous ranges. It is unexpected for
// ranges to violate this assumption. Calling code handles this by mapping
// any task on a non contiguous range to its own image layer.
usize::MAX
fn total_size(&self) -> usize {
// TODO: revisit this
if is_contiguous_range(&self.key_range()) {
contiguous_range_len(&self.key_range()) as usize * 8192
} else {
page_count as usize * 8192
u32::MAX as usize
}
}
@@ -760,8 +753,6 @@ impl ImportTask for ImportRelBlocksTask {
layer_writer: &mut ImageLayerWriter,
ctx: &RequestContext,
) -> anyhow::Result<usize> {
const MAX_BYTE_RANGE_SIZE: usize = 128 * 1024 * 1024;
debug!("Importing relation file");
let (rel_tag, start_blk) = self.key_range.start.to_rel_block()?;
@@ -786,7 +777,7 @@ impl ImportTask for ImportRelBlocksTask {
assert_eq!(key.len(), 1);
assert!(!acc.is_empty());
assert!(acc_end > acc_start);
if acc_end == start && end - acc_start <= MAX_BYTE_RANGE_SIZE {
if acc_end == start /* TODO additional max range check here, to limit memory consumption per task to X */ {
acc.push(key.pop().unwrap());
Ok((acc, acc_start, end))
} else {
@@ -801,8 +792,8 @@ impl ImportTask for ImportRelBlocksTask {
.get_range(&self.path, range_start.into_u64(), range_end.into_u64())
.await?;
let mut buf = Bytes::from(range_buf);
// TODO: batched writes
for key in keys {
// The writer buffers writes internally
let image = buf.split_to(8192);
layer_writer.put_image(key, image, ctx).await?;
nimages += 1;
@@ -855,9 +846,6 @@ impl ImportTask for ImportSlruBlocksTask {
debug!("Importing SLRU segment file {}", self.path);
let buf = self.storage.get(&self.path).await?;
// TODO(vlad): Does timestamp to LSN work for imported timelines?
// Probably not since we don't append the `xact_time` to it as in
// [`WalIngest::ingest_xact_record`].
let (kind, segno, start_blk) = self.key_range.start.to_slru_block()?;
let (_kind, _segno, end_blk) = self.key_range.end.to_slru_block()?;
let mut blknum = start_blk;

View File

@@ -6,7 +6,7 @@ use bytes::Bytes;
use postgres_ffi::ControlFileData;
use remote_storage::{
Download, DownloadError, DownloadKind, DownloadOpts, GenericRemoteStorage, Listing,
ListingObject, RemotePath, RemoteStorageConfig,
ListingObject, RemotePath,
};
use serde::de::DeserializeOwned;
use tokio_util::sync::CancellationToken;
@@ -22,9 +22,11 @@ pub async fn new(
location: &index_part_format::Location,
cancel: CancellationToken,
) -> Result<RemoteStorageWrapper, anyhow::Error> {
// Downloads should be reasonably sized. We do ranged reads for relblock raw data
// and full reads for SLRU segments which are bounded by Postgres.
let timeout = RemoteStorageConfig::DEFAULT_TIMEOUT;
// FIXME: we probably want some timeout, and we might be able to assume the max file
// size on S3 is 1GiB (postgres segment size). But the problem is that the individual
// downloaders don't know enough about concurrent downloads to make a guess on the
// expected bandwidth and resulting best timeout.
let timeout = std::time::Duration::from_secs(24 * 60 * 60);
let location_storage = match location {
#[cfg(feature = "testing")]
index_part_format::Location::LocalFs { path } => {
@@ -48,12 +50,9 @@ pub async fn new(
.import_pgdata_aws_endpoint_url
.clone()
.map(|url| url.to_string()), // by specifying None here, remote_storage/aws-sdk-rust will infer from env
// This matches the default import job concurrency. This is managed
// separately from the usual S3 client, but the concern here is bandwidth
// usage.
concurrency_limit: 128.try_into().unwrap(),
max_keys_per_list_response: Some(1000),
upload_storage_class: None, // irrelevant
concurrency_limit: 100.try_into().unwrap(), // TODO: think about this
max_keys_per_list_response: Some(1000), // TODO: think about this
upload_storage_class: None, // irrelevant
},
timeout,
)

View File

@@ -17,9 +17,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info, warn};
use crate::auth::credentials::check_peer_addr_is_in_list;
use crate::auth::{
self, AuthError, ComputeUserInfoMaybeEndpoint, IpPattern, validate_password_and_exchange,
};
use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
@@ -137,16 +135,6 @@ impl<'a, T> Backend<'a, T> {
}
}
}
impl<'a, T, E> Backend<'a, Result<T, E>> {
/// Very similar to [`std::option::Option::transpose`].
/// This is most useful for error handling.
pub(crate) fn transpose(self) -> Result<Backend<'a, T>, E> {
match self {
Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)),
Self::Local(l) => Ok(Backend::Local(l)),
}
}
}
pub(crate) struct ComputeCredentials {
pub(crate) info: ComputeUserInfo,
@@ -284,7 +272,7 @@ async fn auth_quirks(
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<(ComputeCredentials, Option<Vec<IpPattern>>)> {
) -> auth::Result<ComputeCredentials> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
@@ -301,15 +289,12 @@ async fn auth_quirks(
debug!("fetching authentication info and allowlists");
// check allowed list
let allowed_ips = if config.ip_allowlist_check_enabled {
if config.ip_allowlist_check_enabled {
let allowed_ips = api.get_allowed_ips(ctx, &info).await?;
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
allowed_ips
} else {
Cached::new_uncached(Arc::new(vec![]))
};
}
// check if a VPC endpoint ID is coming in and if yes, if it's allowed
let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?;
@@ -368,7 +353,7 @@ async fn auth_quirks(
)
.await
{
Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))),
Ok(keys) => Ok(keys),
Err(e) => {
if e.is_password_failed() {
// The password could have been changed, so we invalidate the cache.
@@ -420,53 +405,39 @@ async fn authenticate_with_secret(
classic::authenticate(ctx, info, client, config, secret).await
}
impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
/// Get username from the credentials.
pub(crate) fn get_user(&self) -> &str {
match self {
Self::ControlPlane(_, user_info) => &user_info.user,
Self::Local(_) => "local",
}
}
impl ControlPlaneClient {
/// Authenticate the client via the requested backend, possibly using credentials.
#[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
pub(crate) async fn authenticate(
self,
&self,
ctx: &RequestContext,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
user_info: ComputeUserInfoMaybeEndpoint,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<(Backend<'a, ComputeCredentials>, Option<Vec<IpPattern>>)> {
let res = match self {
Self::ControlPlane(api, user_info) => {
debug!(
user = &*user_info.user,
project = user_info.endpoint(),
"performing authentication using the console"
);
) -> auth::Result<ComputeCredentials> {
debug!(
user = &*user_info.user,
project = user_info.endpoint(),
"performing authentication using the console"
);
let (credentials, ip_allowlist) = auth_quirks(
ctx,
&*api,
user_info,
client,
allow_cleartext,
config,
endpoint_rate_limiter,
)
.await?;
Ok((Backend::ControlPlane(api, credentials), ip_allowlist))
}
Self::Local(_) => {
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"));
}
};
let credentials = auth_quirks(
ctx,
self,
user_info,
client,
allow_cleartext,
config,
endpoint_rate_limiter,
)
.await?;
// TODO: replace with some metric
info!("user successfully authenticated");
res
Ok(credentials)
}
}
@@ -536,6 +507,25 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
}
}
pub struct ControlPlaneWakeCompute<'a> {
pub cplane: &'a ControlPlaneClient,
pub creds: ComputeCredentials,
}
#[async_trait::async_trait]
impl ComputeConnectBackend for ControlPlaneWakeCompute<'_> {
async fn wake_compute(
&self,
ctx: &RequestContext,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
self.cplane.wake_compute(ctx, &self.creds.info).await
}
fn get_keys(&self) -> &ComputeCredentialKeys {
&self.creds.keys
}
}
#[cfg(test)]
mod tests {
#![allow(clippy::unimplemented, clippy::unwrap_used)]
@@ -552,6 +542,7 @@ mod tests {
use postgres_protocol::message::backend::Message as PgMessage;
use postgres_protocol::message::frontend;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use tokio_util::task::TaskTracker;
use super::jwt::JwkCache;
use super::{AuthRateLimiter, auth_quirks};
@@ -702,7 +693,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_scram() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
let ctx = RequestContext::test();
let api = Auth {
@@ -784,7 +775,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_cleartext() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
let ctx = RequestContext::test();
let api = Auth {
@@ -838,7 +829,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_password_hack() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
let ctx = RequestContext::test();
let api = Auth {
@@ -887,7 +878,7 @@ mod tests {
.await
.unwrap();
assert_eq!(creds.0.info.endpoint, "my-endpoint");
assert_eq!(creds.info.endpoint, "my-endpoint");
handle.await.unwrap();
}

View File

@@ -1,7 +1,7 @@
//! Client authentication mechanisms.
pub mod backend;
pub use backend::Backend;
pub use backend::{Backend, ControlPlaneWakeCompute};
mod credentials;
pub(crate) use credentials::{

View File

@@ -18,6 +18,7 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio_rustls::TlsConnector;
use tokio_util::sync::CancellationToken;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{Instrument, error, info};
use utils::project_git_version;
use utils::sentry_init::init_sentry;
@@ -226,7 +227,8 @@ pub(super) async fn task_main(
let dest_suffix = Arc::clone(&dest_suffix);
let compute_tls_config = compute_tls_config.clone();
connections.spawn(
let tracker = connections.token();
tokio::spawn(
async move {
socket
.set_nodelay(true)
@@ -249,6 +251,7 @@ pub(super) async fn task_main(
compute_tls_config,
tls_server_end_point,
socket,
tracker,
)
.await
}
@@ -274,10 +277,11 @@ const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmod
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
ctx: &RequestContext,
raw_stream: S,
tracker: TaskTrackerToken,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
) -> anyhow::Result<Stream<S>> {
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
) -> anyhow::Result<(Stream<S>, TaskTrackerToken)> {
let mut stream = PqStream::new(Stream::from_raw(raw_stream), tracker);
let msg = stream.read_startup_packet().await?;
use pq_proto::FeStartupPacket::SslRequest;
@@ -291,7 +295,7 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.
let (raw, read_buf) = stream.into_inner();
let (raw, read_buf, tracker) = stream.into_inner();
// TODO: Normally, client doesn't send any data before
// server says TLS handshake is ok and read_buf is empty.
// However, you could imagine pipelining of postgres
@@ -302,13 +306,16 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
bail!("data is sent before server replied with EncryptionResponse");
}
Ok(Stream::Tls {
tls: Box::new(
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?,
),
tls_server_end_point,
})
Ok((
Stream::Tls {
tls: Box::new(
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?,
),
tls_server_end_point,
},
tracker,
))
}
unexpected => {
info!(
@@ -329,8 +336,10 @@ async fn handle_client(
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
tls_server_end_point: TlsServerEndPoint,
stream: impl AsyncRead + AsyncWrite + Unpin,
tracker: TaskTrackerToken,
) -> anyhow::Result<()> {
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
let (mut tls_stream, _tracker) =
ssl_handshake(&ctx, stream, tracker, tls_config, tls_server_end_point).await?;
// Cut off first part of the SNI domain
// We receive required destination details in the format of
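This file, and the proxy files further down, switch from connections.spawn(...) to plain tokio::spawn plus a TaskTrackerToken obtained via connections.token() and threaded through the handshake (the "manually handle task tracker tokens" commit). The property being relied on is that TaskTracker::wait() stays pending until every outstanding token has been dropped, so a token carried by a detached task still holds up graceful shutdown until that task finishes. A small standalone sketch of that behaviour (names here are illustrative, not taken from the diff):

// Sketch: a TaskTrackerToken keeps tracker.wait() pending until it is dropped,
// even though the task itself was started with plain tokio::spawn.
use std::time::Duration;

use tokio_util::task::task_tracker::TaskTracker;

#[tokio::main]
async fn main() {
    let tracker = TaskTracker::new();

    let token = tracker.token();
    tokio::spawn(async move {
        // Moving the token into the task ties the task's lifetime to the tracker;
        // it is dropped automatically when the task finishes.
        let _token = token;
        tokio::time::sleep(Duration::from_millis(50)).await;
        println!("tracked task done");
    });

    tracker.close();
    tracker.wait().await; // resolves only after the task above drops its token
    println!("shutdown can proceed");
}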

View File

@@ -323,7 +323,7 @@ impl CancellationHandler {
}
}
pub(crate) fn get_key(self: &Arc<Self>) -> Session {
pub(crate) fn get_key(self: Arc<Self>) -> Session {
// we intentionally generate a random "backend pid" and "secret key" here.
// we use the corresponding u64 as an identifier for the
// actual endpoint+pid+secret for postgres/pgbouncer.
@@ -340,7 +340,7 @@ impl CancellationHandler {
Session {
key,
redis_key,
cancellation_handler: Arc::clone(self),
cancellation_handler: self,
}
}

View File

@@ -1,8 +1,9 @@
use std::sync::Arc;
use futures::{FutureExt, TryFutureExt};
use futures::TryFutureExt;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{Instrument, debug, error, info};
use crate::auth::backend::ConsoleRedirectBackend;
@@ -14,10 +15,8 @@ use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::handshake::{HandshakeData, handshake};
use crate::proxy::passthrough::ProxyPassthrough;
use crate::proxy::{
ClientRequestError, ErrorSource, prepare_client_connection, run_until_cancelled,
};
use crate::proxy::passthrough::passthrough;
use crate::proxy::{ClientRequestError, prepare_client_connection, run_until_cancelled};
pub async fn task_main(
config: &'static ProxyConfig,
@@ -35,7 +34,6 @@ pub async fn task_main(
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
@@ -49,11 +47,11 @@ pub async fn task_main(
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
connections.spawn(async move {
let tracker = connections.token();
tokio::spawn(async move {
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
Err(e) => {
error!("per-client task finished with an error: {e:#}");
@@ -103,99 +101,80 @@ pub async fn task_main(
&config.region,
);
let span = ctx.span();
let mut slot = Some(ctx);
let res = handle_client(
config,
backend,
&ctx,
&mut slot,
cancellation_handler,
socket,
conn_gauge,
cancellations,
tracker,
)
.instrument(ctx.span())
.boxed()
.instrument(span)
.await;
match res {
Err(e) => {
match (slot, res) {
(None, _) => {}
(Some(ctx), Ok(())) => {
ctx.success();
}
(Some(ctx), Err(e)) => {
ctx.set_error_kind(e.get_error_kind());
error!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
tracing::warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
}
});
}
connections.close();
cancellations.close();
drop(listener);
// Drain connections
connections.wait().await;
cancellations.wait().await;
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
config: &'static ProxyConfig,
backend: &'static ConsoleRedirectBackend,
ctx: &RequestContext,
ctx_slot: &mut Option<RequestContext>,
cancellation_handler: Arc<CancellationHandler>,
stream: S,
conn_gauge: NumClientConnectionsGuard<'static>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
);
tracker: TaskTrackerToken,
) -> Result<(), ClientRequestError> {
let protocol = ctx_slot.as_ref().expect("context must be set").protocol();
debug!(%protocol, "handling interactive connection from client");
let metrics = &Metrics::get().proxy;
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let request_gauge = metrics.connection_requests.guard(protocol);
let tls = config.tls_config.load();
let tls = tls.as_deref();
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, tls, record_handshake_error);
let data = {
let ctx = ctx_slot.as_ref().expect("context must be set");
let record_handshake_error = !ctx.has_private_peer_addr();
let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, tracker, tls, record_handshake_error);
tokio::time::timeout(config.handshake_timeout, do_handshake).await??
};
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
let (mut stream, params) = match data {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
HandshakeData::Cancel(cancel_key_data, tracker) => {
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx.clone();
tokio::spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx_slot.take().expect("context must be set");
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
async move {
let _tracker = tracker;
cancellation_handler_clone
.cancel_session(
cancel_key_data,
@@ -205,15 +184,17 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
backend.get_api(),
)
.await
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
}.instrument(cancel_span)
.inspect_err(|e| debug!(error = ?e, "cancel_session failed"))
.ok();
}
.instrument(cancel_span)
});
return Ok(None);
return Ok(());
}
};
drop(pause);
let ctx = ctx_slot.as_ref().expect("context must be set");
ctx.set_db_options(params.clone());
let (node_info, user_info, _ip_allowlist) = match backend
@@ -228,13 +209,13 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let mut node = connect_to_compute(
ctx,
&TcpMechanism {
TcpMechanism {
user_info,
params_compat: true,
params: &params,
locks: &config.connect_compute_locks,
},
&node_info,
node_info,
config.wake_compute_retry_config,
&config.connect_to_compute,
)
@@ -252,17 +233,22 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
let (stream, read_buf, tracker) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
Ok(Some(ProxyPassthrough {
client: stream,
aux: node.aux.clone(),
private_link_id: None,
compute: node,
session_id: ctx.session_id(),
cancel: session,
_req: request_gauge,
_conn: conn_gauge,
}))
let ctx = ctx_slot.take().expect("context must be set");
ctx.set_success();
tokio::spawn(passthrough(
ctx,
&config.connect_to_compute,
stream,
node,
session,
request_gauge,
conn_gauge,
tracker,
));
Ok(())
}

View File

@@ -38,7 +38,7 @@ pub struct RequestContext(
/// I would typically use a RefCell but that would break the `Send` requirements
/// so we need something with thread-safety. `TryLock` is a cheap alternative
/// that offers similar semantics to a `RefCell` but with synchronisation.
TryLock<RequestContextInner>,
TryLock<Box<RequestContextInner>>,
);
struct RequestContextInner {
@@ -89,7 +89,7 @@ pub(crate) enum AuthMethod {
impl Clone for RequestContext {
fn clone(&self) -> Self {
let inner = self.0.try_lock().expect("should not deadlock");
let new = RequestContextInner {
let new = Box::new(RequestContextInner {
conn_info: inner.conn_info.clone(),
session_id: inner.session_id,
protocol: inner.protocol,
@@ -117,7 +117,7 @@ impl Clone for RequestContext {
disconnect_sender: None,
latency_timer: LatencyTimer::noop(inner.protocol),
disconnect_timestamp: inner.disconnect_timestamp,
};
});
Self(TryLock::new(new))
}
@@ -140,7 +140,7 @@ impl RequestContext {
role = tracing::field::Empty,
);
let inner = RequestContextInner {
let inner = Box::new(RequestContextInner {
conn_info,
session_id,
protocol,
@@ -168,7 +168,7 @@ impl RequestContext {
disconnect_sender: LOG_CHAN_DISCONNECT.get().and_then(|tx| tx.upgrade()),
latency_timer: LatencyTimer::new(protocol),
disconnect_timestamp: None,
};
});
Self(TryLock::new(inner))
}
@@ -522,7 +522,7 @@ impl Drop for RequestContextInner {
}
}
pub struct DisconnectLogger(RequestContextInner);
pub struct DisconnectLogger(Box<RequestContextInner>);
impl Drop for DisconnectLogger {
fn drop(&mut self) {
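Boxing RequestContextInner (the "boxing requestcontext since it is large" commit) keeps RequestContext itself roughly pointer-sized, so futures and tasks that hold one embed a pointer rather than the full inner struct; only DisconnectLogger's type changes to match. A generic illustration, assuming only that the inner type is large (the 1 KiB array is a stand-in, not the real field list):

// Sketch: the boxed wrapper is pointer-sized; the unboxed one carries the full payload.
#![allow(dead_code)] // the fields are never read; this demo only compares sizes

struct Inner {
    payload: [u8; 1024], // stand-in for the many fields of the real inner struct
}

struct Unboxed(Inner);
struct Boxed(Box<Inner>);

fn main() {
    println!("Unboxed: {} bytes", std::mem::size_of::<Unboxed>()); // about 1024
    println!("Boxed:   {} bytes", std::mem::size_of::<Boxed>()); // one pointer
}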

View File

@@ -53,6 +53,25 @@ pub(crate) trait ConnectMechanism {
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
}
#[async_trait]
impl<T: ConnectMechanism + Sync> ConnectMechanism for &T {
type Connection = T::Connection;
type ConnectError = T::ConnectError;
type Error = T::Error;
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
T::connect_once(self, ctx, node_info, config).await
}
fn update_connect_config(&self, conf: &mut compute::ConnCfg) {
T::update_connect_config(self, conf);
}
}
#[async_trait]
pub(crate) trait ComputeConnectBackend {
async fn wake_compute(
@@ -105,8 +124,8 @@ impl ConnectMechanism for TcpMechanism<'_> {
#[tracing::instrument(skip_all)]
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
ctx: &RequestContext,
mechanism: &M,
user_info: &B,
mechanism: M,
backend: B,
wake_compute_retry_config: RetryConfig,
compute: &ComputeConfig,
) -> Result<M::Connection, M::Error>
@@ -116,9 +135,9 @@ where
{
let mut num_retries = 0;
let mut node_info =
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
wake_compute(&mut num_retries, ctx, &backend, wake_compute_retry_config).await?;
node_info.set_keys(user_info.get_keys());
node_info.set_keys(backend.get_keys());
mechanism.update_connect_config(&mut node_info.config);
// try once
@@ -159,7 +178,7 @@ where
let old_node_info = invalidate_cache(node_info);
// TODO: increment num_retries?
let mut node_info =
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
wake_compute(&mut num_retries, ctx, &backend, wake_compute_retry_config).await?;
node_info.reuse_settings(old_node_info);
mechanism.update_connect_config(&mut node_info.config);
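The new impl<T: ConnectMechanism + Sync> ConnectMechanism for &T above is the usual blanket-impl-for-references idiom: connect_to_compute can now take the mechanism (and backend) by value, while call sites that only hold a reference keep compiling. Reduced to a standalone sketch with illustrative names:

// Sketch: implementing the trait for &T lets a by-value API accept borrowed values too.
trait Mechanism {
    fn connect(&self) -> &'static str;
}

impl<T: Mechanism> Mechanism for &T {
    fn connect(&self) -> &'static str {
        T::connect(self)
    }
}

struct Tcp;

impl Mechanism for Tcp {
    fn connect(&self) -> &'static str {
        "tcp"
    }
}

fn run(mechanism: impl Mechanism) -> &'static str {
    mechanism.connect()
}

fn main() {
    let tcp = Tcp;
    assert_eq!(run(&tcp), "tcp"); // pass by reference
    assert_eq!(run(tcp), "tcp"); // pass by value
}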

View File

@@ -67,7 +67,6 @@ where
}
}
#[tracing::instrument(skip_all)]
pub async fn copy_bidirectional_client_compute<Client, Compute>(
client: &mut Client,
compute: &mut Compute,

View File

@@ -5,6 +5,7 @@ use pq_proto::{
};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{debug, info, warn};
use crate::auth::endpoint_sni;
@@ -51,7 +52,7 @@ impl ReportableError for HandshakeError {
pub(crate) enum HandshakeData<S> {
Startup(PqStream<Stream<S>>, StartupMessageParams),
Cancel(CancelKeyData),
Cancel(CancelKeyData, TaskTrackerToken),
}
/// Establish a (most probably, secure) connection with the client.
@@ -62,6 +63,7 @@ pub(crate) enum HandshakeData<S> {
pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
ctx: &RequestContext,
stream: S,
tracker: TaskTrackerToken,
mut tls: Option<&TlsConfig>,
record_handshake_error: bool,
) -> Result<HandshakeData<S>, HandshakeError> {
@@ -71,7 +73,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
let mut stream = PqStream::new(Stream::from_raw(stream));
let mut stream = PqStream::new(Stream::from_raw(stream), tracker);
loop {
let msg = stream.read_startup_packet().await?;
match msg {
@@ -157,15 +159,13 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
let (_, tls_server_end_point) =
tls.cert_resolver.resolve(conn_info.server_name());
stream = PqStream {
framed: Framed {
stream: Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
},
read_buf,
write_buf,
stream.framed = Framed {
stream: Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
},
read_buf,
write_buf,
};
}
}
@@ -248,7 +248,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
FeStartupPacket::CancelRequest(cancel_key_data) => {
info!(session_type = "cancellation", "successful handshake");
break Ok(HandshakeData::Cancel(cancel_key_data));
break Ok(HandshakeData::Cancel(cancel_key_data, stream.tracker));
}
}
}

View File

@@ -10,26 +10,27 @@ pub(crate) mod wake_compute;
use std::sync::Arc;
pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
use futures::{FutureExt, TryFutureExt};
use futures::TryFutureExt;
use itertools::Itertools;
use once_cell::sync::OnceCell;
use passthrough::passthrough;
use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams};
use regex::Regex;
use serde::{Deserialize, Serialize};
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
use smol_str::{SmolStr, format_smolstr};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{Instrument, debug, error, info, warn};
use self::connect_compute::{TcpMechanism, connect_to_compute};
use self::passthrough::ProxyPassthrough;
use crate::cancellation::{self, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::handshake::{HandshakeData, handshake};
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
@@ -70,7 +71,6 @@ pub async fn task_main(
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
@@ -84,12 +84,12 @@ pub async fn task_main(
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
connections.spawn(async move {
let tracker = connections.token();
tokio::spawn(async move {
let (socket, conn_info) = match read_proxy_protocol(socket).await {
Err(e) => {
warn!("per-client task finished with an error: {e:#}");
@@ -138,60 +138,41 @@ pub async fn task_main(
crate::metrics::Protocol::Tcp,
&config.region,
);
let span = ctx.span();
let mut ctx = Some(ctx);
let res = handle_client(
config,
auth_backend,
&ctx,
&mut ctx,
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter2,
conn_gauge,
cancellations,
tracker,
)
.instrument(ctx.span())
.boxed()
.instrument(span)
.await;
match res {
Err(e) => {
match (ctx, res) {
(None, _) => {}
(Some(ctx), Ok(())) => {
ctx.success();
}
(Some(ctx), Err(e)) => {
ctx.set_error_kind(e.get_error_kind());
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(
?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
}
}
});
}
connections.close();
cancellations.close();
drop(listener);
// Drain connections
connections.wait().await;
cancellations.wait().await;
Ok(())
}
@@ -258,46 +239,79 @@ impl ReportableError for ClientRequestError {
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
ctx: &RequestContext,
ctx_slot: &mut Option<RequestContext>,
cancellation_handler: Arc<CancellationHandler>,
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
);
tracker: TaskTrackerToken,
) -> Result<(), ClientRequestError> {
let cplane = match auth_backend {
auth::Backend::ControlPlane(cplane, ()) => &**cplane,
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
};
let protocol = ctx_slot.as_ref().expect("context must be set").protocol();
debug!(%protocol, "handling interactive connection from client");
let metrics = &Metrics::get().proxy;
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let request_gauge = metrics.connection_requests.guard(protocol);
let tls = config.tls_config.load();
let tls = tls.as_deref();
let handshake_result: Result<_, ClientRequestError> = async {
let tls = config.tls_config.load();
let tls = tls.as_deref();
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
let ctx = ctx_slot.as_ref().expect("context must be set");
let record_handshake_error = !ctx.has_private_peer_addr();
let data = {
let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
tokio::time::timeout(
config.handshake_timeout,
handshake(
ctx,
stream,
tracker,
mode.handshake_tls(tls),
record_handshake_error,
),
)
.await??
};
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx.clone();
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
match data {
HandshakeData::Startup(mut stream, params) => {
ctx.set_db_options(params.clone());
let host = mode.hostname(stream.get_ref());
let cn = tls.map(|tls| &tls.common_names);
// Extract credentials which we're going to use for auth.
let result = auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, host, cn);
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e, Some(ctx)).await?,
};
let session = cancellation_handler.get_key();
Ok(Some((stream, params, session, user_info)))
}
HandshakeData::Cancel(cancel_key_data, tracker) => {
let ctx = ctx_slot.take().expect("context must be set");
ctx.set_success();
let cancel_span = tracing::info_span!(parent: None, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
async move {
cancellation_handler_clone
// spawn a task to cancel the session, but don't wait for it
tokio::spawn(async move {
// ensure the proxy doesn't shut down until we complete this task.
let _tracker = tracker;
cancellation_handler
.cancel_session(
cancel_key_data,
ctx,
@@ -305,111 +319,108 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config.authentication_config.is_vpc_acccess_proxy,
auth_backend.get_api(),
)
.instrument(cancel_span)
.await
.inspect_err(|e| debug!(error = ?e, "cancel_session failed")).ok();
}.instrument(cancel_span)
});
.unwrap_or_else(|e| debug!(error = ?e, "cancel_session failed"));
});
return Ok(None);
Ok(None)
}
}
}
.await;
let Some((mut stream, params, session, user_info)) = handshake_result? else {
return Ok(());
};
drop(pause);
let ctx = ctx_slot.as_ref().expect("context must be set");
ctx.set_db_options(params.clone());
let auth_result: Result<_, ClientRequestError> = async {
let user = user_info.user.clone();
let hostname = mode.hostname(stream.get_ref());
match cplane
.authenticate(
ctx,
&mut stream,
user_info,
mode.allow_cleartext(),
&config.authentication_config,
endpoint_rate_limiter,
)
.await
{
Ok(auth_result) => Ok(auth_result),
Err(e) => {
let db = params.get("database");
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
stream
.throw_error(e, Some(ctx))
.instrument(params_span)
.await?
}
}
}
.await;
let common_names = tls.map(|tls| &tls.common_names);
let compute_creds = auth_result?;
// Extract credentials which we're going to use for auth.
let result = auth_backend
.as_ref()
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
.transpose();
let connect_result: Result<_, ClientRequestError> = async {
let compute_user_info = compute_creds.info.clone();
let params_compat = compute_user_info
.options
.get(NeonOptions::PARAMS_COMPAT)
.is_some();
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e, Some(ctx)).await?,
};
let user = user_info.get_user().to_owned();
let (user_info, _ip_allowlist) = match user_info
.authenticate(
let mut node = connect_to_compute(
ctx,
&mut stream,
mode.allow_cleartext(),
&config.authentication_config,
endpoint_rate_limiter,
TcpMechanism {
user_info: compute_user_info,
params_compat,
params: &params,
locks: &config.connect_compute_locks,
},
auth::ControlPlaneWakeCompute {
cplane,
creds: compute_creds,
},
config.wake_compute_retry_config,
&config.connect_to_compute,
)
.await
{
Ok(auth_result) => auth_result,
Err(e) => {
let db = params.get("database");
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
.or_else(|e| stream.throw_error(e, Some(ctx)))
.await?;
return stream
.throw_error(e, Some(ctx))
.instrument(params_span)
.await?;
}
};
session.write_cancel_key(node.cancel_closure.clone())?;
prepare_client_connection(&node, *session.key(), &mut stream).await?;
let compute_user_info = match &user_info {
auth::Backend::ControlPlane(_, info) => &info.info,
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
};
let params_compat = compute_user_info
.options
.get(NeonOptions::PARAMS_COMPAT)
.is_some();
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf, tracker) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
let mut node = connect_to_compute(
Ok((node, stream, tracker))
}
.await;
let (node, stream, tracker) = connect_result?;
let ctx = ctx_slot.take().expect("context must be set");
ctx.set_success();
tokio::spawn(passthrough(
ctx,
&TcpMechanism {
user_info: compute_user_info.clone(),
params_compat,
params: &params,
locks: &config.connect_compute_locks,
},
&user_info,
config.wake_compute_retry_config,
&config.connect_to_compute,
)
.or_else(|e| stream.throw_error(e, Some(ctx)))
.await?;
stream,
node,
session,
request_gauge,
conn_gauge,
tracker,
));
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let session = cancellation_handler_clone.get_key();
session.write_cancel_key(node.cancel_closure.clone())?;
prepare_client_connection(&node, *session.key(), &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};
Ok(Some(ProxyPassthrough {
client: stream,
aux: node.aux.clone(),
private_link_id,
compute: node,
session_id: ctx.session_id(),
cancel: session,
_req: request_gauge,
_conn: conn_gauge,
}))
Ok(())
}
/// Finish client connection initialization: confirm auth success, send params, etc.
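
The "forward to compute whatever data is left in the PqStream input buffer" comment above is what keeps pipeline-mode clients working: anything the client sent past the handshake must be written to the compute stream before the bidirectional copy starts. A minimal sketch of that hand-off, assuming only the `tokio` and `bytes` crates, with an in-memory duplex pipe standing in for the compute socket:

use bytes::BytesMut;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Bytes the client already sent past the handshake (a pipelined query),
    // still sitting in the framed reader's buffer when we unwrap it.
    let read_buf = BytesMut::from(&b"SELECT 1;"[..]);

    // In-memory stand-in for the compute connection.
    let (mut compute, mut compute_peer) = duplex(64);

    // Forward the leftover bytes first, then start the bidirectional copy;
    // skipping this step would silently drop the pipelined query.
    compute.write_all(&read_buf).await?;

    // The compute side sees the query as if the client had sent it directly.
    let mut seen = vec![0u8; read_buf.len()];
    compute_peer.read_exact(&mut seen).await?;
    assert_eq!(seen, &read_buf[..]);
    Ok(())
}
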

View File

@@ -1,5 +1,6 @@
use smol_str::SmolStr;
use smol_str::{SmolStr, ToSmolStr};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::debug;
use utils::measured_stream::MeasuredStream;
@@ -7,13 +8,14 @@ use super::copy_bidirectional::ErrorSource;
use crate::cancellation;
use crate::compute::PostgresConnection;
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
use crate::protocol2::ConnectionInfoExtra;
use crate::stream::Stream;
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
pub(crate) async fn proxy_pass(
client: impl AsyncRead + AsyncWrite + Unpin,
compute: impl AsyncRead + AsyncWrite + Unpin,
@@ -61,41 +63,53 @@ pub(crate) async fn proxy_pass(
Ok(())
}
pub(crate) struct ProxyPassthrough<S> {
pub(crate) client: Stream<S>,
pub(crate) compute: PostgresConnection,
pub(crate) aux: MetricsAuxInfo,
pub(crate) session_id: uuid::Uuid,
pub(crate) private_link_id: Option<SmolStr>,
pub(crate) cancel: cancellation::Session,
#[allow(clippy::too_many_arguments)]
pub(crate) async fn passthrough<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
ctx: RequestContext,
compute_config: &'static ComputeConfig,
pub(crate) _req: NumConnectionRequestsGuard<'static>,
pub(crate) _conn: NumClientConnectionsGuard<'static>,
}
client: Stream<S>,
compute: PostgresConnection,
cancel: cancellation::Session,
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
pub(crate) async fn proxy_pass(
self,
compute_config: &ComputeConfig,
) -> Result<(), ErrorSource> {
let res = proxy_pass(
self.client,
self.compute.stream,
self.aux,
self.private_link_id,
)
.await;
if let Err(err) = self
.compute
.cancel_closure
.try_cancel_query(compute_config)
.await
{
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
_req: NumConnectionRequestsGuard<'static>,
_conn: NumClientConnectionsGuard<'static>,
_tracker: TaskTrackerToken,
) {
let session_id = ctx.session_id();
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};
let _disconnect = ctx.log_connect();
let res = proxy_pass(client, compute.stream, compute.aux, private_link_id).await;
match res {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
tracing::warn!(
session_id = ?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
tracing::error!(
session_id = ?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
drop(self.cancel.remove_cancel_key()); // we don't need a result. If the queue is full, we just log the error
res
}
if let Err(err) = compute
.cancel_closure
.try_cancel_query(compute_config)
.await
{
tracing::warn!(session_id = ?session_id, ?err, "could not cancel the query in the database");
}
// we don't need a result. If the queue is full, we just log the error
drop(cancel.remove_cancel_key());
}
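
The `_tracker: TaskTrackerToken` parameter above is what ties these detached `tokio::spawn` tasks back into shutdown: `TaskTracker::wait()` only resolves once every outstanding token has been dropped. A small self-contained sketch of that mechanism (the sleep and log messages are placeholders for the real passthrough work):

use std::time::Duration;
use tokio_util::task::TaskTracker;

#[tokio::main]
async fn main() {
    let tracker = TaskTracker::new();

    // A detached task: the tracker doesn't own it, it merely holds a token.
    let token = tracker.token();
    tokio::spawn(async move {
        let _token = token; // dropped when this task returns
        tokio::time::sleep(Duration::from_millis(20)).await;
        println!("passthrough finished");
    });

    tracker.close();
    // Resolves only after the token above is dropped, i.e. after the task ends.
    tracker.wait().await;
    println!("safe to shut down");
}
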

View File

@@ -38,6 +38,7 @@ async fn proxy_mitm(
let (end_client, startup) = match handshake(
&RequestContext::test(),
client1,
TaskTracker::new().token(),
Some(&server_config1),
false,
)
@@ -45,7 +46,7 @@ async fn proxy_mitm(
.unwrap()
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(_) => panic!("cancellation not supported"),
HandshakeData::Cancel(_, _) => panic!("cancellation not supported"),
};
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);

View File

@@ -15,6 +15,7 @@ use rstest::rstest;
use rustls::crypto::ring;
use rustls::pki_types;
use tokio::io::DuplexStream;
use tokio_util::task::TaskTracker;
use tracing_test::traced_test;
use super::connect_compute::ConnectMechanism;
@@ -178,10 +179,12 @@ async fn dummy_proxy(
auth: impl TestAuth + Send,
) -> anyhow::Result<()> {
let (client, _) = read_proxy_protocol(client).await?;
let mut stream = match handshake(&RequestContext::test(), client, tls.as_ref(), false).await? {
HandshakeData::Startup(stream, _) => stream,
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
};
let t = TaskTracker::new().token();
let mut stream =
match handshake(&RequestContext::test(), client, t, tls.as_ref(), false).await? {
HandshakeData::Startup(stream, _) => stream,
HandshakeData::Cancel(_, _) => bail!("cancellation not supported"),
};
auth.authenticate(&mut stream).await?;
@@ -622,7 +625,7 @@ async fn connect_to_compute_success() {
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -636,7 +639,7 @@ async fn connect_to_compute_retry() {
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -651,7 +654,7 @@ async fn connect_to_compute_non_retry_1() {
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
.await
.unwrap_err();
mechanism.verify();
@@ -666,7 +669,7 @@ async fn connect_to_compute_non_retry_2() {
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -691,7 +694,7 @@ async fn connect_to_compute_non_retry_3() {
connect_to_compute(
&ctx,
&mechanism,
&user_info,
user_info,
wake_compute_retry_config,
&config,
)
@@ -709,7 +712,7 @@ async fn wake_retry() {
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -724,7 +727,7 @@ async fn wake_non_retry() {
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
.await
.unwrap_err();
mechanism.verify();
@@ -743,7 +746,7 @@ async fn fail_but_wake_invalidates_cache() {
let user = helper_create_connect_info(&mech);
let cfg = config();
connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg)
connect_to_compute(&ctx, &mech, user, cfg.retry, &cfg)
.await
.unwrap();
@@ -764,7 +767,7 @@ async fn fail_no_wake_skips_cache_invalidation() {
let user = helper_create_connect_info(&mech);
let cfg = config();
connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg)
connect_to_compute(&ctx, &mech, user, cfg.retry, &cfg)
.await
.unwrap();
@@ -785,7 +788,7 @@ async fn retry_but_wake_invalidates_cache() {
let user_info = helper_create_connect_info(&mechanism);
let cfg = config();
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
connect_to_compute(&ctx, &mechanism, user_info, cfg.retry, &cfg)
.await
.unwrap();
mechanism.verify();
@@ -808,7 +811,7 @@ async fn retry_no_wake_skips_invalidation() {
let user_info = helper_create_connect_info(&mechanism);
let cfg = config();
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
connect_to_compute(&ctx, &mechanism, user_info, cfg.retry, &cfg)
.await
.unwrap_err();
mechanism.verify();

View File

@@ -224,13 +224,13 @@ impl PoolingBackend {
let backend = self.auth_backend.as_ref().map(|()| keys);
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
},
&backend,
backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
)
@@ -268,13 +268,13 @@ impl PoolingBackend {
});
crate::proxy::connect_compute::connect_to_compute(
ctx,
&HyperMechanism {
HyperMechanism {
conn_id,
conn_info,
pool: self.http_conn_pool.clone(),
locks: &self.config.connect_compute_locks,
},
&backend,
backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
)

View File

@@ -41,7 +41,7 @@ use tokio::net::{TcpListener, TcpStream};
use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{Instrument, info, warn};
use crate::cancellation::CancellationHandler;
@@ -124,7 +124,6 @@ pub async fn task_main(
let connections = tokio_util::task::task_tracker::TaskTracker::new();
connections.close(); // allows `connections.wait` to complete
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
if let Err(e) = conn.set_nodelay(true) {
@@ -150,11 +149,11 @@ pub async fn task_main(
let conn_token = cancellation_token.child_token();
let tls_acceptor = tls_acceptor.clone();
let backend = backend.clone();
let connections2 = connections.clone();
let cancellation_handler = cancellation_handler.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellations = cancellations.clone();
connections.spawn(
let tracker = connections.token();
tokio::spawn(
async move {
let conn_token2 = conn_token.clone();
let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
@@ -181,8 +180,7 @@ pub async fn task_main(
Box::pin(connection_handler(
config,
backend,
connections2,
cancellations,
tracker,
cancellation_handler,
endpoint_rate_limiter,
conn_token,
@@ -305,8 +303,7 @@ async fn connection_startup(
async fn connection_handler(
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
connections: TaskTracker,
cancellations: TaskTracker,
tracker: TaskTrackerToken,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
@@ -347,19 +344,17 @@ async fn connection_handler(
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
// By spawning the future, we ensure it never gets cancelled until it decides to.
let cancellations = cancellations.clone();
let handler = connections.spawn(
let handler = tokio::spawn(
request_handler(
req,
config,
backend.clone(),
connections.clone(),
tracker.clone(),
cancellation_handler.clone(),
session_id,
conn_info2.clone(),
http_request_token,
endpoint_rate_limiter.clone(),
cancellations,
)
.in_current_span()
.map_ok_or_else(api_error_into_response, |r| r),
@@ -400,14 +395,13 @@ async fn request_handler(
mut request: hyper::Request<Incoming>,
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
ws_connections: TaskTracker,
tracker: TaskTrackerToken,
cancellation_handler: Arc<CancellationHandler>,
session_id: uuid::Uuid,
conn_info: ConnectionInfo,
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellations: TaskTracker,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
let host = request
.headers()
@@ -441,10 +435,17 @@ async fn request_handler(
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
.map_err(|e| ApiError::BadRequest(e.into()))?;
let cancellations = cancellations.clone();
ws_connections.spawn(
tokio::spawn(
async move {
if let Err(e) = websocket::serve_websocket(
let websocket = match websocket.await {
Err(e) => {
warn!("could not upgrade websocket connection: {e:#}");
return;
}
Ok(websocket) => websocket,
};
websocket::serve_websocket(
config,
backend.auth_backend,
ctx,
@@ -452,12 +453,9 @@ async fn request_handler(
cancellation_handler,
endpoint_rate_limiter,
host,
cancellations,
tracker,
)
.await
{
warn!("error in websocket connection: {e:#}");
}
.await;
}
.instrument(span),
);
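
The comment in the hunk above explains why `request_handler` is spawned rather than awaited in place: it is not cancel safe, and a spawned future keeps running even if whoever was awaiting its JoinHandle gives up. A minimal sketch of that pattern (the function name, timings, and messages are made up for illustration):

use std::time::Duration;
use tokio::time::{sleep, timeout};

// Stand-in for `request_handler`: work that must only stop at points it chooses.
async fn not_cancel_safe() -> &'static str {
    sleep(Duration::from_millis(50)).await;
    "done"
}

#[tokio::main]
async fn main() {
    // Spawning detaches the work from the caller: dropping or timing out the
    // JoinHandle does not cancel the task, it only stops us from waiting on it.
    let mut handle = tokio::spawn(not_cancel_safe());

    if timeout(Duration::from_millis(10), &mut handle).await.is_err() {
        // The caller gave up waiting, but the spawned future keeps running.
        println!("still running in the background");
    }

    // The result can still be collected later.
    assert_eq!(handle.await.unwrap(), "done");
}
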

View File

@@ -2,14 +2,14 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, ready};
use anyhow::Context as _;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use framed_websockets::{Frame, OpCode, WebSocketServer};
use futures::{Sink, Stream};
use hyper::upgrade::OnUpgrade;
use hyper::upgrade::Upgraded;
use hyper_util::rt::TokioIo;
use pin_project_lite::pin_project;
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::warn;
use crate::cancellation::CancellationHandler;
@@ -17,7 +17,7 @@ use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::proxy::{ClientMode, ErrorSource, handle_client};
use crate::proxy::{ClientMode, handle_client};
use crate::rate_limiter::EndpointRateLimiter;
pin_project! {
@@ -128,13 +128,12 @@ pub(crate) async fn serve_websocket(
config: &'static ProxyConfig,
auth_backend: &'static crate::auth::Backend<'static, ()>,
ctx: RequestContext,
websocket: OnUpgrade,
websocket: Upgraded,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
hostname: Option<String>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
tracker: TaskTrackerToken,
) {
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
let conn_gauge = Metrics::get()
@@ -142,36 +141,28 @@ pub(crate) async fn serve_websocket(
.client_connections
.guard(crate::metrics::Protocol::Ws);
let res = Box::pin(handle_client(
let mut ctx_slot = Some(ctx);
let res = handle_client(
config,
auth_backend,
&ctx,
&mut ctx_slot,
cancellation_handler,
WebSocketRw::new(websocket),
ClientMode::Websockets { hostname },
endpoint_rate_limiter,
conn_gauge,
cancellations,
))
tracker,
)
.await;
match res {
Err(e) => {
match (ctx_slot, res) {
(None, _) => {}
(Some(ctx), Err(e)) => {
ctx.set_error_kind(e.get_error_kind());
Err(e.into())
tracing::warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
(Some(ctx), Ok(())) => {
ctx.set_success();
Ok(())
}
Ok(Some(p)) => {
ctx.set_success();
ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => Ok(()),
Err(ErrorSource::Client(err)) => Err(err).context("client"),
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
}
}
}
}

View File

@@ -10,6 +10,7 @@ use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::server::TlsStream;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::debug;
use crate::control_plane::messages::ColdStartInfo;
@@ -24,19 +25,22 @@ use crate::tls::TlsServerEndPoint;
/// to pass random malformed bytes through the connection).
pub struct PqStream<S> {
pub(crate) framed: Framed<S>,
pub(crate) tracker: TaskTrackerToken,
}
impl<S> PqStream<S> {
/// Construct a new libpq protocol wrapper.
pub fn new(stream: S) -> Self {
pub fn new(stream: S, tracker: TaskTrackerToken) -> Self {
Self {
framed: Framed::new(stream),
tracker,
}
}
/// Extract the underlying stream and read buffer.
pub fn into_inner(self) -> (S, BytesMut) {
self.framed.into_inner()
pub fn into_inner(self) -> (S, BytesMut, TaskTrackerToken) {
let (stream, read) = self.framed.into_inner();
(stream, read, self.tracker)
}
/// Get a shared reference to the underlying stream.
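
A toy version of the `PqStream` change above, showing only the ownership flow: the tracker token travels with the wrapped stream and is handed back by `into_inner`, so whoever ends up owning the connection also keeps shutdown waiting. The real `PqStream` additionally carries the framed read buffer; `TrackedStream` here is a made-up name.

use tokio_util::task::TaskTracker;
use tokio_util::task::task_tracker::TaskTrackerToken;

struct TrackedStream<S> {
    inner: S,
    tracker: TaskTrackerToken,
}

impl<S> TrackedStream<S> {
    fn new(inner: S, tracker: TaskTrackerToken) -> Self {
        Self { inner, tracker }
    }

    fn into_inner(self) -> (S, TaskTrackerToken) {
        (self.inner, self.tracker)
    }
}

fn main() {
    let tracker = TaskTracker::new();
    let wrapped = TrackedStream::new("pretend socket", tracker.token());
    let (_stream, _token) = wrapped.into_inner();
    // `_token` keeps `tracker.wait()` pending until it is dropped.
}
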

View File

@@ -139,14 +139,6 @@ pub(crate) struct StorageControllerMetricGroup {
/// HTTP request status counters for handled requests
pub(crate) storage_controller_reconcile_long_running:
measured::CounterVec<ReconcileLongRunningLabelGroupSet>,
/// Indicator of safekeeper reconciler queue depth, broken down by safekeeper, excluding ongoing reconciles.
pub(crate) storage_controller_safkeeper_reconciles_queued:
measured::GaugeVec<SafekeeperReconcilerLabelGroupSet>,
/// Indicator of completed safekeeper reconciles, broken down by safekeeper.
pub(crate) storage_controller_safkeeper_reconciles_complete:
measured::CounterVec<SafekeeperReconcilerLabelGroupSet>,
}
impl StorageControllerMetrics {
@@ -265,17 +257,6 @@ pub(crate) enum Method {
Other,
}
#[derive(measured::LabelGroup, Clone)]
#[label(set = SafekeeperReconcilerLabelGroupSet)]
pub(crate) struct SafekeeperReconcilerLabelGroup<'a> {
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
pub(crate) sk_az: &'a str,
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
pub(crate) sk_node_id: &'a str,
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
pub(crate) sk_hostname: &'a str,
}
impl From<hyper::Method> for Method {
fn from(value: hyper::Method) -> Self {
if value == hyper::Method::GET {

View File

@@ -20,9 +20,7 @@ use utils::{
};
use crate::{
metrics::{METRICS_REGISTRY, SafekeeperReconcilerLabelGroup},
persistence::SafekeeperTimelineOpKind,
safekeeper::Safekeeper,
persistence::SafekeeperTimelineOpKind, safekeeper::Safekeeper,
safekeeper_client::SafekeeperClient,
};
@@ -220,26 +218,7 @@ impl ReconcilerHandle {
fn schedule_reconcile(&self, req: ScheduleRequest) {
let (cancel, token_id) = self.new_token_slot(req.tenant_id, req.timeline_id);
let hostname = req.safekeeper.skp.host.clone();
let sk_az = req.safekeeper.skp.availability_zone_id.clone();
let sk_node_id = req.safekeeper.get_id().to_string();
// We don't have direct access to the queue depth here, so increase it blindly by 1.
// We know that putting into the queue increases the queue depth. The receiver will
// update with the correct value once it processes the next item. To avoid races where we
// reduce before we increase, leaving the gauge with a 1 value for a long time, we
// increase it before putting into the queue.
let queued_gauge = &METRICS_REGISTRY
.metrics_group
.storage_controller_safkeeper_reconciles_queued;
let label_group = SafekeeperReconcilerLabelGroup {
sk_az: &sk_az,
sk_node_id: &sk_node_id,
sk_hostname: &hostname,
};
queued_gauge.inc(label_group.clone());
if let Err(err) = self.tx.send((req, cancel, token_id)) {
queued_gauge.set(label_group, 0);
tracing::info!("scheduling request onto {hostname} returned error: {err}");
}
}
@@ -304,18 +283,6 @@ impl SafekeeperReconciler {
continue;
}
let queued_gauge = &METRICS_REGISTRY
.metrics_group
.storage_controller_safkeeper_reconciles_queued;
queued_gauge.set(
SafekeeperReconcilerLabelGroup {
sk_az: &req.safekeeper.skp.availability_zone_id,
sk_node_id: &req.safekeeper.get_id().to_string(),
sk_hostname: &req.safekeeper.skp.host,
},
self.rx.len() as i64,
);
tokio::task::spawn(async move {
let kind = req.kind;
let tenant_id = req.tenant_id;
@@ -544,16 +511,6 @@ impl SafekeeperReconcilerInner {
req.generation,
)
.await;
let complete_counter = &METRICS_REGISTRY
.metrics_group
.storage_controller_safkeeper_reconciles_complete;
complete_counter.inc(SafekeeperReconcilerLabelGroup {
sk_az: &req.safekeeper.skp.availability_zone_id,
sk_node_id: &req.safekeeper.get_id().to_string(),
sk_hostname: &req.safekeeper.skp.host,
});
if let Err(err) = res {
tracing::info!(
"couldn't remove reconciliation request onto {} from persistence: {err:?}",

View File

@@ -536,14 +536,16 @@ class NeonLocalCli(AbstractNeonCli):
res.check_returncode()
return res
def endpoint_generate_jwt(self, endpoint_id: str, scope: list[ComputeClaimsScope]) -> str:
def endpoint_generate_jwt(
self, endpoint_id: str, scope: ComputeClaimsScope | None = None
) -> str:
"""
Generate a JWT for making requests to the endpoint's external HTTP
server.
"""
args = ["endpoint", "generate-jwt", endpoint_id]
for s in scope:
args += ["--scope", str(s)]
if scope:
args += ["--scope", str(scope)]
cmd = self.raw_cli(args)
cmd.check_returncode()

View File

@@ -4282,7 +4282,7 @@ class Endpoint(PgProtocol, LogUtils):
self.config(config_lines)
self.__jwt = self.generate_jwt([])
self.__jwt = self.generate_jwt()
return self
@@ -4329,7 +4329,7 @@ class Endpoint(PgProtocol, LogUtils):
return self
def generate_jwt(self, scope: list[ComputeClaimsScope]) -> str:
def generate_jwt(self, scope: ComputeClaimsScope | None = None) -> str:
"""
Generate a JWT for making requests to the endpoint's external HTTP
server.

View File

@@ -287,17 +287,6 @@ def test_pgdata_import_smoke(
with pytest.raises(psycopg2.errors.UndefinedTable):
br_initdb_endpoint.safe_psql(f"select * from {workload.table}")
# The storage controller might be overly eager and attempt to finalize
# the import before the task got a chance to exit.
env.storage_controller.allowed_errors.extend(
[
".*Call to node.*management API.*failed.*Import task still running.*",
]
)
for ps in env.pageservers:
ps.allowed_errors.extend([".*Error processing HTTP request.*Import task not done yet.*"])
@run_only_on_default_postgres(reason="PG version is irrelevant here")
def test_import_completion_on_restart(
@@ -482,17 +471,6 @@ def test_import_respects_timeline_lifecycle(
else:
raise RuntimeError(f"{action} param not recognized")
# The storage controller might be overly eager and attempt to finalize
# the import before the task got a chance to exit.
env.storage_controller.allowed_errors.extend(
[
".*Call to node.*management API.*failed.*Import task still running.*",
]
)
for ps in env.pageservers:
ps.allowed_errors.extend([".*Error processing HTTP request.*Import task not done yet.*"])
@skip_in_debug_build("Validation query takes too long in debug builds")
def test_import_chaos(

View File

@@ -124,9 +124,6 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver,
".*downloading failed, possibly for shutdown",
# {tenant_id=... timeline_id=...}:handle_pagerequests:handle_get_page_at_lsn_request{rel=1664/0/1260 blkno=0 req_lsn=0/149F0D8}: error reading relation or page version: Not found: will not become active. Current state: Stopping\n'
".*page_service.*will not become active.*",
# the following errors are possible when pageserver tries to ingest wal records despite being in unreadable state
".*wal_connection_manager.*layer file download failed: No file found.*",
".*wal_connection_manager.*could not ingest record.*",
]
)
@@ -159,45 +156,6 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver,
env.pageservers[2].id: ("Detached", None),
}
# Track all the attached locations with mode and generation
history: list[tuple[int, str, int | None]] = []
def may_read(pageserver: NeonPageserver, mode: str, generation: int | None) -> bool:
# Rules for when a pageserver may read:
# - our generation is higher than any previous
# - our generation is equal to previous, but no other pageserver
# in that generation has been AttachedSingle (i.e. allowed to compact/GC)
# - our generation is equal to previous, and the previous holder of this
# generation was the same node as we're attaching now.
#
# If these conditions are not met, then a read _might_ work, but the pageserver might
# also hit errors trying to download layers.
highest_historic_generation = max([i[2] for i in history if i[2] is not None], default=None)
if generation is None:
# We're not in an attached state, we may not read
return False
elif highest_historic_generation is not None and generation < highest_historic_generation:
# We are in an outdated generation, we may not read
return False
elif highest_historic_generation is not None and generation == highest_historic_generation:
# We are re-using a generation: if any pageserver other than this one
# has held AttachedSingle mode, this node may not read (because some other
# node may be doing GC/compaction).
if any(
i[1] == "AttachedSingle"
and i[2] == highest_historic_generation
and i[0] != pageserver.id
for i in history
):
log.info(
f"Skipping read on {pageserver.id} because other pageserver has been in AttachedSingle mode in generation {highest_historic_generation}"
)
return False
# Fall through: we have passed conditions for readability
return True
latest_attached = env.pageservers[0].id
for _i in range(0, 64):
@@ -241,10 +199,9 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver,
assert len(tenants) == 1
assert tenants[0]["generation"] == new_generation
if may_read(pageserver, last_state_ps[0], last_state_ps[1]):
log.info("Entering postgres...")
workload.churn_rows(rng.randint(128, 256), pageserver.id)
workload.validate(pageserver.id)
log.info("Entering postgres...")
workload.churn_rows(rng.randint(128, 256), pageserver.id)
workload.validate(pageserver.id)
elif last_state_ps[0].startswith("Attached"):
# The `storage_controller` will only re-attach on startup when a pageserver was the
# holder of the latest generation: otherwise the pageserver will revert to detached
@@ -284,16 +241,18 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver,
location_conf["generation"] = generation
pageserver.tenant_location_configure(tenant_id, location_conf)
last_state[pageserver.id] = (mode, generation)
may_read_this_generation = may_read(pageserver, mode, generation)
history.append((pageserver.id, mode, generation))
# It's only valid to connect to the last generation. Newer generations may yank layer
# files used in older generations.
last_generation = max(
[s[1] for s in last_state.values() if s[1] is not None], default=None
)
# This is a basic test: we are validating that the endpoint works properly _between_
# configuration changes. A stronger test would be to validate that clients see
# no errors while we are making the changes.
if may_read_this_generation:
if mode.startswith("Attached") and generation == last_generation:
# This is a basic test: we are validating that the endpoint works properly _between_
# configuration changes. A stronger test would be to validate that clients see
# no errors while we are making the changes.
workload.churn_rows(
rng.randint(128, 256), pageserver.id, upload=mode != "AttachedStale"
)
@@ -306,16 +265,9 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver,
assert gc_summary["remote_storage_errors"] == 0
assert gc_summary["indices_deleted"] > 0
# Attach all pageservers, in a higher generation than any previous. We will use the same
# gen for all, and AttachedMulti mode so that they do not interfere with one another.
generation = env.storage_controller.attach_hook_issue(tenant_id, env.pageservers[0].id)
# Attach all pageservers
for ps in env.pageservers:
location_conf = {
"mode": "AttachedMulti",
"secondary_conf": None,
"tenant_conf": {},
"generation": generation,
}
location_conf = {"mode": "AttachedMulti", "secondary_conf": None, "tenant_conf": {}}
ps.tenant_location_configure(tenant_id, location_conf)
# Confirm that all are readable