Compare commits

..

1 Commits

Author SHA1 Message Date
Christian Schwarz
5953ea12f6 DO NOT MERGE: diable materialized page cache for benchmarking 2024-03-23 17:40:25 +00:00
51 changed files with 204 additions and 606 deletions

View File

@@ -1121,16 +1121,10 @@ jobs:
run: |
if [[ "$GITHUB_REF_NAME" == "main" ]]; then
gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main -f branch=main -f dockerTag=${{needs.tag.outputs.build-tag}} -f deployPreprodRegion=false
elif [[ "$GITHUB_REF_NAME" == "release" ]]; then
gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main \
-f deployPgSniRouter=false \
-f deployProxy=false \
-f deployStorage=true \
-f deployStorageBroker=true \
-f branch=main \
-f dockerTag=${{needs.tag.outputs.build-tag}} \
-f deployPreprodRegion=true
# TODO: move deployPreprodRegion to release (`"$GITHUB_REF_NAME" == "release"` block), once Staging support different compute tag prefixes for different regions
gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main -f branch=main -f dockerTag=${{needs.tag.outputs.build-tag}} -f deployPreprodRegion=true
elif [[ "$GITHUB_REF_NAME" == "release" ]]; then
gh workflow --repo neondatabase/aws run deploy-prod.yml --ref main \
-f deployPgSniRouter=false \
-f deployProxy=false \
@@ -1139,15 +1133,6 @@ jobs:
-f branch=main \
-f dockerTag=${{needs.tag.outputs.build-tag}}
elif [[ "$GITHUB_REF_NAME" == "release-proxy" ]]; then
gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main \
-f deployPgSniRouter=true \
-f deployProxy=true \
-f deployStorage=false \
-f deployStorageBroker=false \
-f branch=main \
-f dockerTag=${{needs.tag.outputs.build-tag}} \
-f deployPreprodRegion=true
gh workflow --repo neondatabase/aws run deploy-proxy-prod.yml --ref main \
-f deployPgSniRouter=true \
-f deployProxy=true \

1
Cargo.lock generated
View File

@@ -4237,7 +4237,6 @@ dependencies = [
"consumption_metrics",
"dashmap",
"env_logger",
"fallible-iterator",
"futures",
"git-version",
"hashbrown 0.13.2",

View File

@@ -79,7 +79,6 @@ either = "1.8"
enum-map = "2.4.2"
enumset = "1.0.12"
fail = "0.5.0"
fallible-iterator = "0.2"
fs2 = "0.4.3"
futures = "0.3"
futures-core = "0.3"

View File

@@ -15,9 +15,9 @@ use metrics::launch_timestamp::{set_launch_timestamp_metric, LaunchTimestamp};
use pageserver::control_plane_client::ControlPlaneClient;
use pageserver::disk_usage_eviction_task::{self, launch_disk_usage_global_eviction_task};
use pageserver::metrics::{STARTUP_DURATION, STARTUP_IS_LOADING};
use pageserver::task_mgr::WALRECEIVER_RUNTIME;
use pageserver::tenant::{secondary, TenantSharedResources};
use remote_storage::GenericRemoteStorage;
use tokio::signal::unix::SignalKind;
use tokio::time::Instant;
use tracing::*;
@@ -28,7 +28,7 @@ use pageserver::{
deletion_queue::DeletionQueue,
http, page_cache, page_service, task_mgr,
task_mgr::TaskKind,
task_mgr::THE_RUNTIME,
task_mgr::{BACKGROUND_RUNTIME, COMPUTE_REQUEST_RUNTIME, MGMT_REQUEST_RUNTIME},
tenant::mgr,
virtual_file,
};
@@ -323,7 +323,7 @@ fn start_pageserver(
// Launch broker client
// The storage_broker::connect call needs to happen inside a tokio runtime thread.
let broker_client = THE_RUNTIME
let broker_client = WALRECEIVER_RUNTIME
.block_on(async {
// Note: we do not attempt connecting here (but validate endpoints sanity).
storage_broker::connect(conf.broker_endpoint.clone(), conf.broker_keepalive_interval)
@@ -391,7 +391,7 @@ fn start_pageserver(
conf,
);
if let Some(deletion_workers) = deletion_workers {
deletion_workers.spawn_with(THE_RUNTIME.handle());
deletion_workers.spawn_with(BACKGROUND_RUNTIME.handle());
}
// Up to this point no significant I/O has been done: this should have been fast. Record
@@ -423,7 +423,7 @@ fn start_pageserver(
// Scan the local 'tenants/' directory and start loading the tenants
let deletion_queue_client = deletion_queue.new_client();
let tenant_manager = THE_RUNTIME.block_on(mgr::init_tenant_mgr(
let tenant_manager = BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(
conf,
TenantSharedResources {
broker_client: broker_client.clone(),
@@ -435,7 +435,7 @@ fn start_pageserver(
))?;
let tenant_manager = Arc::new(tenant_manager);
THE_RUNTIME.spawn({
BACKGROUND_RUNTIME.spawn({
let shutdown_pageserver = shutdown_pageserver.clone();
let drive_init = async move {
// NOTE: unlike many futures in pageserver, this one is cancellation-safe
@@ -545,7 +545,7 @@ fn start_pageserver(
// Start up the service to handle HTTP mgmt API request. We created the
// listener earlier already.
{
let _rt_guard = THE_RUNTIME.enter();
let _rt_guard = MGMT_REQUEST_RUNTIME.enter();
let router_state = Arc::new(
http::routes::State::new(
@@ -569,6 +569,7 @@ fn start_pageserver(
.with_graceful_shutdown(task_mgr::shutdown_watcher());
task_mgr::spawn(
MGMT_REQUEST_RUNTIME.handle(),
TaskKind::HttpEndpointListener,
None,
None,
@@ -593,6 +594,7 @@ fn start_pageserver(
let local_disk_storage = conf.workdir.join("last_consumption_metrics.json");
task_mgr::spawn(
crate::BACKGROUND_RUNTIME.handle(),
TaskKind::MetricsCollection,
None,
None,
@@ -613,7 +615,6 @@ fn start_pageserver(
pageserver::consumption_metrics::collect_metrics(
metric_collection_endpoint,
&conf.metric_collection_bucket,
conf.metric_collection_interval,
conf.cached_metric_collection_interval,
conf.synthetic_size_calculation_interval,
@@ -641,6 +642,7 @@ fn start_pageserver(
DownloadBehavior::Error,
);
task_mgr::spawn(
COMPUTE_REQUEST_RUNTIME.handle(),
TaskKind::LibpqEndpointListener,
None,
None,
@@ -664,37 +666,42 @@ fn start_pageserver(
let mut shutdown_pageserver = Some(shutdown_pageserver.drop_guard());
// All started up! Now just sit and wait for shutdown signal.
{
THE_RUNTIME.block_on(async move {
let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt()).unwrap();
let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate()).unwrap();
let mut sigquit = tokio::signal::unix::signal(SignalKind::quit()).unwrap();
let signal = tokio::select! {
_ = sigquit.recv() => {
info!("Got signal SIGQUIT. Terminating in immediate shutdown mode",);
std::process::exit(111);
}
_ = sigint.recv() => { "SIGINT" },
_ = sigterm.recv() => { "SIGTERM" },
};
use signal_hook::consts::*;
let signal_handler = BACKGROUND_RUNTIME.spawn_blocking(move || {
let mut signals =
signal_hook::iterator::Signals::new([SIGINT, SIGTERM, SIGQUIT]).unwrap();
return signals
.forever()
.next()
.expect("forever() never returns None unless explicitly closed");
});
let signal = BACKGROUND_RUNTIME
.block_on(signal_handler)
.expect("join error");
match signal {
SIGQUIT => {
info!("Got signal {signal}. Terminating in immediate shutdown mode",);
std::process::exit(111);
}
SIGINT | SIGTERM => {
info!("Got signal {signal}. Terminating gracefully in fast shutdown mode",);
info!("Got signal {signal}. Terminating gracefully in fast shutdown mode",);
// This cancels the `shutdown_pageserver` cancellation tree.
// Right now that tree doesn't reach very far, and `task_mgr` is used instead.
// The plan is to change that over time.
shutdown_pageserver.take();
let bg_remote_storage = remote_storage.clone();
let bg_deletion_queue = deletion_queue.clone();
pageserver::shutdown_pageserver(
&tenant_manager,
bg_remote_storage.map(|_| bg_deletion_queue),
0,
)
.await;
unreachable!()
})
// This cancels the `shutdown_pageserver` cancellation tree.
// Right now that tree doesn't reach very far, and `task_mgr` is used instead.
// The plan is to change that over time.
shutdown_pageserver.take();
let bg_remote_storage = remote_storage.clone();
let bg_deletion_queue = deletion_queue.clone();
BACKGROUND_RUNTIME.block_on(pageserver::shutdown_pageserver(
&tenant_manager,
bg_remote_storage.map(|_| bg_deletion_queue),
0,
));
unreachable!()
}
_ => unreachable!(),
}
}
}

View File

@@ -234,7 +234,6 @@ pub struct PageServerConf {
// How often to send unchanged cached metrics to the metrics endpoint.
pub cached_metric_collection_interval: Duration,
pub metric_collection_endpoint: Option<Url>,
pub metric_collection_bucket: Option<RemoteStorageConfig>,
pub synthetic_size_calculation_interval: Duration,
pub disk_usage_based_eviction: Option<DiskUsageEvictionTaskConfig>,
@@ -374,7 +373,6 @@ struct PageServerConfigBuilder {
cached_metric_collection_interval: BuilderValue<Duration>,
metric_collection_endpoint: BuilderValue<Option<Url>>,
synthetic_size_calculation_interval: BuilderValue<Duration>,
metric_collection_bucket: BuilderValue<Option<RemoteStorageConfig>>,
disk_usage_based_eviction: BuilderValue<Option<DiskUsageEvictionTaskConfig>>,
@@ -457,8 +455,6 @@ impl PageServerConfigBuilder {
.expect("cannot parse default synthetic size calculation interval")),
metric_collection_endpoint: Set(DEFAULT_METRIC_COLLECTION_ENDPOINT),
metric_collection_bucket: Set(None),
disk_usage_based_eviction: Set(None),
test_remote_failures: Set(0),
@@ -590,13 +586,6 @@ impl PageServerConfigBuilder {
self.metric_collection_endpoint = BuilderValue::Set(metric_collection_endpoint)
}
pub fn metric_collection_bucket(
&mut self,
metric_collection_bucket: Option<RemoteStorageConfig>,
) {
self.metric_collection_bucket = BuilderValue::Set(metric_collection_bucket)
}
pub fn synthetic_size_calculation_interval(
&mut self,
synthetic_size_calculation_interval: Duration,
@@ -705,7 +694,6 @@ impl PageServerConfigBuilder {
metric_collection_interval,
cached_metric_collection_interval,
metric_collection_endpoint,
metric_collection_bucket,
synthetic_size_calculation_interval,
disk_usage_based_eviction,
test_remote_failures,
@@ -954,9 +942,6 @@ impl PageServerConf {
let endpoint = parse_toml_string(key, item)?.parse().context("failed to parse metric_collection_endpoint")?;
builder.metric_collection_endpoint(Some(endpoint));
},
"metric_collection_bucket" => {
builder.metric_collection_bucket(RemoteStorageConfig::from_toml(item)?)
}
"synthetic_size_calculation_interval" =>
builder.synthetic_size_calculation_interval(parse_toml_duration(key, item)?),
"test_remote_failures" => builder.test_remote_failures(parse_toml_u64(key, item)?),
@@ -1072,7 +1057,6 @@ impl PageServerConf {
metric_collection_interval: Duration::from_secs(60),
cached_metric_collection_interval: Duration::from_secs(60 * 60),
metric_collection_endpoint: defaults::DEFAULT_METRIC_COLLECTION_ENDPOINT,
metric_collection_bucket: None,
synthetic_size_calculation_interval: Duration::from_secs(60),
disk_usage_based_eviction: None,
test_remote_failures: 0,
@@ -1305,7 +1289,6 @@ background_task_maximum_delay = '334 s'
defaults::DEFAULT_CACHED_METRIC_COLLECTION_INTERVAL
)?,
metric_collection_endpoint: defaults::DEFAULT_METRIC_COLLECTION_ENDPOINT,
metric_collection_bucket: None,
synthetic_size_calculation_interval: humantime::parse_duration(
defaults::DEFAULT_SYNTHETIC_SIZE_CALCULATION_INTERVAL
)?,
@@ -1380,7 +1363,6 @@ background_task_maximum_delay = '334 s'
metric_collection_interval: Duration::from_secs(222),
cached_metric_collection_interval: Duration::from_secs(22200),
metric_collection_endpoint: Some(Url::parse("http://localhost:80/metrics")?),
metric_collection_bucket: None,
synthetic_size_calculation_interval: Duration::from_secs(333),
disk_usage_based_eviction: None,
test_remote_failures: 0,

View File

@@ -1,13 +1,12 @@
//! Periodically collect consumption metrics for all active tenants
//! and push them to a HTTP endpoint.
use crate::context::{DownloadBehavior, RequestContext};
use crate::task_mgr::{self, TaskKind};
use crate::task_mgr::{self, TaskKind, BACKGROUND_RUNTIME};
use crate::tenant::tasks::BackgroundLoopKind;
use crate::tenant::{mgr, LogicalSizeCalculationCause, PageReconstructError, Tenant};
use camino::Utf8PathBuf;
use consumption_metrics::EventType;
use pageserver_api::models::TenantState;
use remote_storage::{GenericRemoteStorage, RemoteStorageConfig};
use reqwest::Url;
use std::collections::HashMap;
use std::sync::Arc;
@@ -42,7 +41,6 @@ type Cache = HashMap<MetricsKey, (EventType, u64)>;
#[allow(clippy::too_many_arguments)]
pub async fn collect_metrics(
metric_collection_endpoint: &Url,
metric_collection_bucket: &Option<RemoteStorageConfig>,
metric_collection_interval: Duration,
_cached_metric_collection_interval: Duration,
synthetic_size_calculation_interval: Duration,
@@ -61,6 +59,7 @@ pub async fn collect_metrics(
let worker_ctx =
ctx.detached_child(TaskKind::CalculateSyntheticSize, DownloadBehavior::Download);
task_mgr::spawn(
BACKGROUND_RUNTIME.handle(),
TaskKind::CalculateSyntheticSize,
None,
None,
@@ -95,20 +94,6 @@ pub async fn collect_metrics(
.build()
.expect("Failed to create http client with timeout");
let bucket_client = if let Some(bucket_config) = metric_collection_bucket {
match GenericRemoteStorage::from_config(bucket_config) {
Ok(client) => Some(client),
Err(e) => {
// Non-fatal error: if we were given an invalid config, we will proceed
// with sending metrics over the network, but not to S3.
tracing::warn!("Invalid configuration for metric_collection_bucket: {e}");
None
}
}
} else {
None
};
let node_id = node_id.to_string();
loop {
@@ -133,18 +118,10 @@ pub async fn collect_metrics(
tracing::error!("failed to persist metrics to {path:?}: {e:#}");
}
}
if let Some(bucket_client) = &bucket_client {
let res =
upload::upload_metrics_bucket(bucket_client, &cancel, &node_id, &metrics).await;
if let Err(e) = res {
tracing::error!("failed to upload to S3: {e:#}");
}
}
};
let upload = async {
let res = upload::upload_metrics_http(
let res = upload::upload_metrics(
&client,
metric_collection_endpoint,
&cancel,
@@ -155,7 +132,7 @@ pub async fn collect_metrics(
.await;
if let Err(e) = res {
// serialization error which should never happen
tracing::error!("failed to upload via HTTP due to {e:#}");
tracing::error!("failed to upload due to {e:#}");
}
};

View File

@@ -1,9 +1,4 @@
use std::time::SystemTime;
use chrono::{DateTime, Utc};
use consumption_metrics::{Event, EventChunk, IdempotencyKey, CHUNK_SIZE};
use remote_storage::{GenericRemoteStorage, RemotePath};
use tokio::io::AsyncWriteExt;
use tokio_util::sync::CancellationToken;
use tracing::Instrument;
@@ -18,9 +13,8 @@ struct Ids {
pub(super) timeline_id: Option<TimelineId>,
}
/// Serialize and write metrics to an HTTP endpoint
#[tracing::instrument(skip_all, fields(metrics_total = %metrics.len()))]
pub(super) async fn upload_metrics_http(
pub(super) async fn upload_metrics(
client: &reqwest::Client,
metric_collection_endpoint: &reqwest::Url,
cancel: &CancellationToken,
@@ -80,60 +74,6 @@ pub(super) async fn upload_metrics_http(
Ok(())
}
/// Serialize and write metrics to a remote storage object
#[tracing::instrument(skip_all, fields(metrics_total = %metrics.len()))]
pub(super) async fn upload_metrics_bucket(
client: &GenericRemoteStorage,
cancel: &CancellationToken,
node_id: &str,
metrics: &[RawMetric],
) -> anyhow::Result<()> {
if metrics.is_empty() {
// Skip uploads if we have no metrics, so that readers don't have to handle the edge case
// of an empty object.
return Ok(());
}
// Compose object path
let datetime: DateTime<Utc> = SystemTime::now().into();
let ts_prefix = datetime.format("year=%Y/month=%m/day=%d/%H:%M:%SZ");
let path = RemotePath::from_string(&format!("{ts_prefix}_{node_id}.ndjson.gz"))?;
// Set up a gzip writer into a buffer
let mut compressed_bytes: Vec<u8> = Vec::new();
let compressed_writer = std::io::Cursor::new(&mut compressed_bytes);
let mut gzip_writer = async_compression::tokio::write::GzipEncoder::new(compressed_writer);
// Serialize and write into compressed buffer
let started_at = std::time::Instant::now();
for res in serialize_in_chunks(CHUNK_SIZE, metrics, node_id) {
let (_chunk, body) = res?;
gzip_writer.write_all(&body).await?;
}
gzip_writer.flush().await?;
gzip_writer.shutdown().await?;
let compressed_length = compressed_bytes.len();
// Write to remote storage
client
.upload_storage_object(
futures::stream::once(futures::future::ready(Ok(compressed_bytes.into()))),
compressed_length,
&path,
cancel,
)
.await?;
let elapsed = started_at.elapsed();
tracing::info!(
compressed_length,
elapsed_ms = elapsed.as_millis(),
"write metrics bucket at {path}",
);
Ok(())
}
// The return type is quite ugly, but we gain testability in isolation
fn serialize_in_chunks<'a, F>(
chunk_size: usize,

View File

@@ -173,6 +173,8 @@ impl ControlPlaneGenerationsApi for ControlPlaneClient {
register,
};
fail::fail_point!("control-plane-client-re-attach");
let response: ReAttachResponse = self.retry_http_forever(&re_attach_path, request).await?;
tracing::info!(
"Received re-attach response with {} tenants",
@@ -208,7 +210,7 @@ impl ControlPlaneGenerationsApi for ControlPlaneClient {
.collect(),
};
crate::tenant::pausable_failpoint!("control-plane-client-validate");
fail::fail_point!("control-plane-client-validate");
let response: ValidateResponse = self.retry_http_forever(&re_attach_path, request).await?;

View File

@@ -59,7 +59,7 @@ use utils::{completion, id::TimelineId};
use crate::{
config::PageServerConf,
metrics::disk_usage_based_eviction::METRICS,
task_mgr::{self, TaskKind},
task_mgr::{self, TaskKind, BACKGROUND_RUNTIME},
tenant::{
self,
mgr::TenantManager,
@@ -202,6 +202,7 @@ pub fn launch_disk_usage_global_eviction_task(
info!("launching disk usage based eviction task");
task_mgr::spawn(
BACKGROUND_RUNTIME.handle(),
TaskKind::DiskUsageEviction,
None,
None,

View File

@@ -180,6 +180,7 @@ pub async fn libpq_listener_main(
// only deal with a particular timeline, but we don't know which one
// yet.
task_mgr::spawn(
&tokio::runtime::Handle::current(),
TaskKind::PageRequestHandler,
None,
None,

View File

@@ -98,22 +98,42 @@ use utils::id::TimelineId;
// other operations, if the upload tasks e.g. get blocked on locks. It shouldn't
// happen, but still.
//
/// The single tokio runtime used by all pageserver code.
/// In the past, we had multiple runtimes, and in the future we should weed out
/// remaining references to this global field and rely on ambient runtime instead,
/// i.e., use `tokio::spawn` instead of `THE_RUNTIME.spawn()`, etc.
pub static THE_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
pub static COMPUTE_REQUEST_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
tokio::runtime::Builder::new_multi_thread()
.thread_name("compute request worker")
.enable_all()
.build()
.expect("Failed to create compute request runtime")
});
pub static MGMT_REQUEST_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
tokio::runtime::Builder::new_multi_thread()
.thread_name("mgmt request worker")
.enable_all()
.build()
.expect("Failed to create mgmt request runtime")
});
pub static WALRECEIVER_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
tokio::runtime::Builder::new_multi_thread()
.thread_name("walreceiver worker")
.enable_all()
.build()
.expect("Failed to create walreceiver runtime")
});
pub static BACKGROUND_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
tokio::runtime::Builder::new_multi_thread()
.thread_name("background op worker")
// if you change the number of worker threads please change the constant below
.enable_all()
.build()
.expect("Failed to create background op runtime")
});
pub(crate) static THE_RUNTIME_WORKER_THREADS: Lazy<usize> = Lazy::new(|| {
pub(crate) static BACKGROUND_RUNTIME_WORKER_THREADS: Lazy<usize> = Lazy::new(|| {
// force init and thus panics
let _ = THE_RUNTIME.handle();
let _ = BACKGROUND_RUNTIME.handle();
// replicates tokio-1.28.1::loom::sys::num_cpus which is not available publicly
// tokio would had already panicked for parsing errors or NotUnicode
//
@@ -305,6 +325,7 @@ struct PageServerTask {
/// Note: if shutdown_process_on_error is set to true failure
/// of the task will lead to shutdown of entire process
pub fn spawn<F>(
runtime: &tokio::runtime::Handle,
kind: TaskKind,
tenant_shard_id: Option<TenantShardId>,
timeline_id: Option<TimelineId>,
@@ -333,7 +354,7 @@ where
let task_name = name.to_string();
let task_cloned = Arc::clone(&task);
let join_handle = THE_RUNTIME.spawn(task_wrapper(
let join_handle = runtime.spawn(task_wrapper(
task_name,
task_id,
task_cloned,

View File

@@ -144,7 +144,6 @@ macro_rules! pausable_failpoint {
}
};
}
pub(crate) use pausable_failpoint;
pub mod blob_io;
pub mod block_io;
@@ -662,6 +661,7 @@ impl Tenant {
let tenant_clone = Arc::clone(&tenant);
let ctx = ctx.detached_child(TaskKind::Attach, DownloadBehavior::Warn);
task_mgr::spawn(
&tokio::runtime::Handle::current(),
TaskKind::Attach,
Some(tenant_shard_id),
None,

View File

@@ -482,6 +482,7 @@ impl DeleteTenantFlow {
let tenant_shard_id = tenant.tenant_shard_id;
task_mgr::spawn(
task_mgr::BACKGROUND_RUNTIME.handle(),
TaskKind::TimelineDeletionWorker,
Some(tenant_shard_id),
None,

View File

@@ -1850,6 +1850,7 @@ impl TenantManager {
let task_tenant_id = None;
task_mgr::spawn(
task_mgr::BACKGROUND_RUNTIME.handle(),
TaskKind::MgmtRequest,
task_tenant_id,
None,
@@ -2815,12 +2816,15 @@ pub(crate) fn immediate_gc(
// TODO: spawning is redundant now, need to hold the gate
task_mgr::spawn(
&tokio::runtime::Handle::current(),
TaskKind::GarbageCollector,
Some(tenant_shard_id),
Some(timeline_id),
&format!("timeline_gc_handler garbage collection run for tenant {tenant_shard_id} timeline {timeline_id}"),
false,
async move {
fail::fail_point!("immediate_gc_task_pre");
#[allow(unused_mut)]
let mut result = tenant
.gc_iteration(Some(timeline_id), gc_horizon, pitr, &cancel, &ctx)

View File

@@ -223,6 +223,7 @@ use crate::{
config::PageServerConf,
task_mgr,
task_mgr::TaskKind,
task_mgr::BACKGROUND_RUNTIME,
tenant::metadata::TimelineMetadata,
tenant::upload_queue::{
UploadOp, UploadQueue, UploadQueueInitialized, UploadQueueStopped, UploadTask,
@@ -306,6 +307,8 @@ pub enum PersistIndexPartWithDeletedFlagError {
pub struct RemoteTimelineClient {
conf: &'static PageServerConf,
runtime: tokio::runtime::Handle,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
generation: Generation,
@@ -338,6 +341,12 @@ impl RemoteTimelineClient {
) -> RemoteTimelineClient {
RemoteTimelineClient {
conf,
runtime: if cfg!(test) {
// remote_timeline_client.rs tests rely on current-thread runtime
tokio::runtime::Handle::current()
} else {
BACKGROUND_RUNTIME.handle().clone()
},
tenant_shard_id,
timeline_id,
generation,
@@ -1272,6 +1281,7 @@ impl RemoteTimelineClient {
let tenant_shard_id = self.tenant_shard_id;
let timeline_id = self.timeline_id;
task_mgr::spawn(
&self.runtime,
TaskKind::RemoteUploadTask,
Some(self.tenant_shard_id),
Some(self.timeline_id),
@@ -1866,6 +1876,7 @@ mod tests {
fn build_client(&self, generation: Generation) -> Arc<RemoteTimelineClient> {
Arc::new(RemoteTimelineClient {
conf: self.harness.conf,
runtime: tokio::runtime::Handle::current(),
tenant_shard_id: self.harness.tenant_shard_id,
timeline_id: TIMELINE_ID,
generation,

View File

@@ -8,7 +8,7 @@ use std::{sync::Arc, time::SystemTime};
use crate::{
config::PageServerConf,
disk_usage_eviction_task::DiskUsageEvictionInfo,
task_mgr::{self, TaskKind},
task_mgr::{self, TaskKind, BACKGROUND_RUNTIME},
virtual_file::MaybeFatalIo,
};
@@ -317,6 +317,7 @@ pub fn spawn_tasks(
tokio::sync::mpsc::channel::<CommandRequest<UploadCommand>>(16);
task_mgr::spawn(
BACKGROUND_RUNTIME.handle(),
TaskKind::SecondaryDownloads,
None,
None,
@@ -337,6 +338,7 @@ pub fn spawn_tasks(
);
task_mgr::spawn(
BACKGROUND_RUNTIME.handle(),
TaskKind::SecondaryUploads,
None,
None,

View File

@@ -1447,7 +1447,7 @@ impl LayerInner {
#[cfg(test)]
tokio::task::spawn(fut);
#[cfg(not(test))]
crate::task_mgr::THE_RUNTIME.spawn(fut);
crate::task_mgr::BACKGROUND_RUNTIME.spawn(fut);
}
/// Needed to use entered runtime in tests, but otherwise use BACKGROUND_RUNTIME.
@@ -1458,7 +1458,7 @@ impl LayerInner {
#[cfg(test)]
tokio::task::spawn_blocking(f);
#[cfg(not(test))]
crate::task_mgr::THE_RUNTIME.spawn_blocking(f);
crate::task_mgr::BACKGROUND_RUNTIME.spawn_blocking(f);
}
}

View File

@@ -8,7 +8,7 @@ use std::time::{Duration, Instant};
use crate::context::{DownloadBehavior, RequestContext};
use crate::metrics::TENANT_TASK_EVENTS;
use crate::task_mgr;
use crate::task_mgr::TaskKind;
use crate::task_mgr::{TaskKind, BACKGROUND_RUNTIME};
use crate::tenant::throttle::Stats;
use crate::tenant::timeline::CompactionError;
use crate::tenant::{Tenant, TenantState};
@@ -18,7 +18,7 @@ use utils::{backoff, completion};
static CONCURRENT_BACKGROUND_TASKS: once_cell::sync::Lazy<tokio::sync::Semaphore> =
once_cell::sync::Lazy::new(|| {
let total_threads = *crate::task_mgr::THE_RUNTIME_WORKER_THREADS;
let total_threads = *task_mgr::BACKGROUND_RUNTIME_WORKER_THREADS;
let permits = usize::max(
1,
// while a lot of the work is done on spawn_blocking, we still do
@@ -85,6 +85,7 @@ pub fn start_background_loops(
) {
let tenant_shard_id = tenant.tenant_shard_id;
task_mgr::spawn(
BACKGROUND_RUNTIME.handle(),
TaskKind::Compaction,
Some(tenant_shard_id),
None,
@@ -108,6 +109,7 @@ pub fn start_background_loops(
},
);
task_mgr::spawn(
BACKGROUND_RUNTIME.handle(),
TaskKind::GarbageCollector,
Some(tenant_shard_id),
None,

View File

@@ -1723,6 +1723,7 @@ impl Timeline {
initdb_optimization_count: 0,
};
task_mgr::spawn(
task_mgr::BACKGROUND_RUNTIME.handle(),
task_mgr::TaskKind::LayerFlushTask,
Some(self.tenant_shard_id),
Some(self.timeline_id),
@@ -2085,6 +2086,7 @@ impl Timeline {
DownloadBehavior::Download,
);
task_mgr::spawn(
task_mgr::BACKGROUND_RUNTIME.handle(),
task_mgr::TaskKind::InitialLogicalSizeCalculation,
Some(self.tenant_shard_id),
Some(self.timeline_id),
@@ -2262,6 +2264,7 @@ impl Timeline {
DownloadBehavior::Download,
);
task_mgr::spawn(
task_mgr::BACKGROUND_RUNTIME.handle(),
task_mgr::TaskKind::OndemandLogicalSizeCalculation,
Some(self.tenant_shard_id),
Some(self.timeline_id),
@@ -2852,15 +2855,7 @@ impl Timeline {
lsn: Lsn,
ctx: &RequestContext,
) -> Option<(Lsn, Bytes)> {
let cache = page_cache::get();
// FIXME: It's pointless to check the cache for things that are not 8kB pages.
// We should look at the key to determine if it's a cacheable object
let (lsn, read_guard) = cache
.lookup_materialized_page(self.tenant_shard_id, self.timeline_id, key, lsn, ctx)
.await?;
let img = Bytes::from(read_guard.to_vec());
Some((lsn, img))
return None;
}
async fn get_ready_ancestor_timeline(
@@ -3837,7 +3832,7 @@ impl Timeline {
};
let timer = self.metrics.garbage_collect_histo.start_timer();
pausable_failpoint!("before-timeline-gc");
fail_point!("before-timeline-gc");
// Is the timeline being deleted?
if self.is_stopping() {
@@ -4148,6 +4143,7 @@ impl Timeline {
let self_clone = Arc::clone(&self);
let task_id = task_mgr::spawn(
task_mgr::BACKGROUND_RUNTIME.handle(),
task_mgr::TaskKind::DownloadAllRemoteLayers,
Some(self.tenant_shard_id),
Some(self.timeline_id),

View File

@@ -443,6 +443,7 @@ impl DeleteTimelineFlow {
let timeline_id = timeline.timeline_id;
task_mgr::spawn(
task_mgr::BACKGROUND_RUNTIME.handle(),
TaskKind::TimelineDeletionWorker,
Some(tenant_shard_id),
Some(timeline_id),

View File

@@ -28,7 +28,7 @@ use tracing::{debug, error, info, info_span, instrument, warn, Instrument};
use crate::{
context::{DownloadBehavior, RequestContext},
pgdatadir_mapping::CollectKeySpaceError,
task_mgr::{self, TaskKind},
task_mgr::{self, TaskKind, BACKGROUND_RUNTIME},
tenant::{
tasks::BackgroundLoopKind, timeline::EvictionError, LogicalSizeCalculationCause, Tenant,
},
@@ -56,6 +56,7 @@ impl Timeline {
let self_clone = Arc::clone(self);
let background_tasks_can_start = background_tasks_can_start.cloned();
task_mgr::spawn(
BACKGROUND_RUNTIME.handle(),
TaskKind::Eviction,
Some(self.tenant_shard_id),
Some(self.timeline_id),

View File

@@ -24,7 +24,7 @@ mod connection_manager;
mod walreceiver_connection;
use crate::context::{DownloadBehavior, RequestContext};
use crate::task_mgr::{self, TaskKind};
use crate::task_mgr::{self, TaskKind, WALRECEIVER_RUNTIME};
use crate::tenant::debug_assert_current_span_has_tenant_and_timeline_id;
use crate::tenant::timeline::walreceiver::connection_manager::{
connection_manager_loop_step, ConnectionManagerState,
@@ -82,6 +82,7 @@ impl WalReceiver {
let loop_status = Arc::new(std::sync::RwLock::new(None));
let manager_status = Arc::clone(&loop_status);
task_mgr::spawn(
WALRECEIVER_RUNTIME.handle(),
TaskKind::WalReceiverManager,
Some(timeline.tenant_shard_id),
Some(timeline_id),
@@ -180,7 +181,7 @@ impl<E: Clone> TaskHandle<E> {
let (events_sender, events_receiver) = watch::channel(TaskStateUpdate::Started);
let cancellation_clone = cancellation.clone();
let join_handle = tokio::spawn(async move {
let join_handle = WALRECEIVER_RUNTIME.spawn(async move {
events_sender.send(TaskStateUpdate::Started).ok();
task(events_sender, cancellation_clone).await
// events_sender is dropped at some point during the .await above.

View File

@@ -11,6 +11,7 @@ use std::{
use anyhow::{anyhow, Context};
use bytes::BytesMut;
use chrono::{NaiveDateTime, Utc};
use fail::fail_point;
use futures::StreamExt;
use postgres::{error::SqlState, SimpleQueryMessage, SimpleQueryRow};
use postgres_ffi::WAL_SEGMENT_SIZE;
@@ -26,7 +27,9 @@ use super::TaskStateUpdate;
use crate::{
context::RequestContext,
metrics::{LIVE_CONNECTIONS_COUNT, WALRECEIVER_STARTED_CONNECTIONS, WAL_INGEST},
task_mgr::{self, TaskKind},
task_mgr,
task_mgr::TaskKind,
task_mgr::WALRECEIVER_RUNTIME,
tenant::{debug_assert_current_span_has_tenant_and_timeline_id, Timeline, WalReceiverInfo},
walingest::WalIngest,
walrecord::DecodedWALRecord,
@@ -160,6 +163,7 @@ pub(super) async fn handle_walreceiver_connection(
);
let connection_cancellation = cancellation.clone();
task_mgr::spawn(
WALRECEIVER_RUNTIME.handle(),
TaskKind::WalReceiverConnectionPoller,
Some(timeline.tenant_shard_id),
Some(timeline.timeline_id),
@@ -325,17 +329,7 @@ pub(super) async fn handle_walreceiver_connection(
filtered_records += 1;
}
// don't simply use pausable_failpoint here because its spawn_blocking slows
// slows down the tests too much.
fail::fail_point!("walreceiver-after-ingest-blocking");
if let Err(()) = (|| {
fail::fail_point!("walreceiver-after-ingest-pause-activate", |_| {
Err(())
});
Ok(())
})() {
pausable_failpoint!("walreceiver-after-ingest-pause");
}
fail_point!("walreceiver-after-ingest");
last_rec_lsn = lsn;

View File

@@ -312,7 +312,7 @@ pg_cluster_size(PG_FUNCTION_ARGS)
{
int64 size;
size = GetNeonCurrentClusterSize();
size = GetZenithCurrentClusterSize();
if (size == 0)
PG_RETURN_NULL();

View File

@@ -26,8 +26,6 @@ extern void pg_init_libpagestore(void);
extern void pg_init_walproposer(void);
extern uint64 BackpressureThrottlingTime(void);
extern void SetNeonCurrentClusterSize(uint64 size);
extern uint64 GetNeonCurrentClusterSize(void);
extern void replication_feedback_get_lsns(XLogRecPtr *writeLsn, XLogRecPtr *flushLsn, XLogRecPtr *applyLsn);
extern void PGDLLEXPORT WalProposerSync(int argc, char *argv[]);

View File

@@ -1831,7 +1831,7 @@ neon_extend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno,
reln->smgr_relpersistence == RELPERSISTENCE_PERMANENT &&
!IsAutoVacuumWorkerProcess())
{
uint64 current_size = GetNeonCurrentClusterSize();
uint64 current_size = GetZenithCurrentClusterSize();
if (current_size >= ((uint64) max_cluster_size) * 1024 * 1024)
ereport(ERROR,
@@ -1912,7 +1912,7 @@ neon_zeroextend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blocknum,
reln->smgr_relpersistence == RELPERSISTENCE_PERMANENT &&
!IsAutoVacuumWorkerProcess())
{
uint64 current_size = GetNeonCurrentClusterSize();
uint64 current_size = GetZenithCurrentClusterSize();
if (current_size >= ((uint64) max_cluster_size) * 1024 * 1024)
ereport(ERROR,

View File

@@ -287,7 +287,6 @@ typedef struct WalproposerShmemState
slock_t mutex;
term_t mineLastElectedTerm;
pg_atomic_uint64 backpressureThrottlingTime;
pg_atomic_uint64 currentClusterSize;
/* last feedback from each shard */
PageserverFeedback shard_ps_feedback[MAX_SHARDS];

View File

@@ -282,7 +282,6 @@ WalproposerShmemInit(void)
memset(walprop_shared, 0, WalproposerShmemSize());
SpinLockInit(&walprop_shared->mutex);
pg_atomic_init_u64(&walprop_shared->backpressureThrottlingTime, 0);
pg_atomic_init_u64(&walprop_shared->currentClusterSize, 0);
}
LWLockRelease(AddinShmemInitLock);
@@ -1973,7 +1972,7 @@ walprop_pg_process_safekeeper_feedback(WalProposer *wp, Safekeeper *sk)
/* Only one main shard sends non-zero currentClusterSize */
if (sk->appendResponse.ps_feedback.currentClusterSize > 0)
SetNeonCurrentClusterSize(sk->appendResponse.ps_feedback.currentClusterSize);
SetZenithCurrentClusterSize(sk->appendResponse.ps_feedback.currentClusterSize);
if (min_feedback.disk_consistent_lsn != standby_apply_lsn)
{
@@ -2095,18 +2094,6 @@ GetLogRepRestartLSN(WalProposer *wp)
return lrRestartLsn;
}
void SetNeonCurrentClusterSize(uint64 size)
{
pg_atomic_write_u64(&walprop_shared->currentClusterSize, size);
}
uint64 GetNeonCurrentClusterSize(void)
{
return pg_atomic_read_u64(&walprop_shared->currentClusterSize);
}
uint64 GetNeonCurrentClusterSize(void);
static const walproposer_api walprop_pg = {
.get_shmem_state = walprop_pg_get_shmem_state,
.start_streaming = walprop_pg_start_streaming,

View File

@@ -97,7 +97,6 @@ workspace_hack.workspace = true
[dev-dependencies]
camino-tempfile.workspace = true
fallible-iterator.workspace = true
rcgen.workspace = true
rstest.workspace = true
tokio-postgres-rustls.workspace = true

View File

@@ -408,228 +408,3 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
}
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use postgres_protocol::{
authentication::sasl::{ChannelBinding, ScramSha256},
message::{backend::Message as PgMessage, frontend},
};
use provider::AuthSecret;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use crate::{
auth::{ComputeUserInfoMaybeEndpoint, IpPattern},
config::AuthenticationConfig,
console::{
self,
provider::{self, CachedAllowedIps, CachedRoleSecret},
CachedNodeInfo,
},
context::RequestMonitoring,
proxy::NeonOptions,
scram::ServerSecret,
stream::{PqStream, Stream},
};
use super::auth_quirks;
struct Auth {
ips: Vec<IpPattern>,
secret: AuthSecret,
}
impl console::Api for Auth {
async fn get_role_secret(
&self,
_ctx: &mut RequestMonitoring,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError> {
Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
}
async fn get_allowed_ips_and_secret(
&self,
_ctx: &mut RequestMonitoring,
_user_info: &super::ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>
{
Ok((
CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())),
Some(CachedRoleSecret::new_uncached(Some(self.secret.clone()))),
))
}
async fn wake_compute(
&self,
_ctx: &mut RequestMonitoring,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
unimplemented!()
}
}
static CONFIG: &AuthenticationConfig = &AuthenticationConfig {
scram_protocol_timeout: std::time::Duration::from_secs(5),
};
async fn read_message(r: &mut (impl AsyncRead + Unpin), b: &mut BytesMut) -> PgMessage {
loop {
r.read_buf(&mut *b).await.unwrap();
if let Some(m) = PgMessage::parse(&mut *b).unwrap() {
break m;
}
}
}
#[tokio::test]
async fn auth_quirks_scram() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut ctx = RequestMonitoring::test();
let api = Auth {
ips: vec![],
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
};
let user_info = ComputeUserInfoMaybeEndpoint {
user: "conrad".into(),
endpoint_id: Some("endpoint".into()),
options: NeonOptions::default(),
};
let handle = tokio::spawn(async move {
let mut scram = ScramSha256::new(b"my-secret-password", ChannelBinding::unsupported());
let mut read = BytesMut::new();
// server should offer scram
match read_message(&mut client, &mut read).await {
PgMessage::AuthenticationSasl(a) => {
let options: Vec<&str> = a.mechanisms().collect().unwrap();
assert_eq!(options, ["SCRAM-SHA-256"]);
}
_ => panic!("wrong message"),
}
// client sends client-first-message
let mut write = BytesMut::new();
frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
client.write_all(&write).await.unwrap();
// server response with server-first-message
match read_message(&mut client, &mut read).await {
PgMessage::AuthenticationSaslContinue(a) => {
scram.update(a.data()).await.unwrap();
}
_ => panic!("wrong message"),
}
// client response with client-final-message
write.clear();
frontend::sasl_response(scram.message(), &mut write).unwrap();
client.write_all(&write).await.unwrap();
// server response with server-final-message
match read_message(&mut client, &mut read).await {
PgMessage::AuthenticationSaslFinal(a) => {
scram.finish(a.data()).unwrap();
}
_ => panic!("wrong message"),
}
});
let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, false, CONFIG)
.await
.unwrap();
handle.await.unwrap();
}
#[tokio::test]
async fn auth_quirks_cleartext() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut ctx = RequestMonitoring::test();
let api = Auth {
ips: vec![],
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
};
let user_info = ComputeUserInfoMaybeEndpoint {
user: "conrad".into(),
endpoint_id: Some("endpoint".into()),
options: NeonOptions::default(),
};
let handle = tokio::spawn(async move {
let mut read = BytesMut::new();
let mut write = BytesMut::new();
// server should offer cleartext
match read_message(&mut client, &mut read).await {
PgMessage::AuthenticationCleartextPassword => {}
_ => panic!("wrong message"),
}
// client responds with password
write.clear();
frontend::password_message(b"my-secret-password", &mut write).unwrap();
client.write_all(&write).await.unwrap();
});
let _creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, CONFIG)
.await
.unwrap();
handle.await.unwrap();
}
#[tokio::test]
async fn auth_quirks_password_hack() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server));
let mut ctx = RequestMonitoring::test();
let api = Auth {
ips: vec![],
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
};
let user_info = ComputeUserInfoMaybeEndpoint {
user: "conrad".into(),
endpoint_id: None,
options: NeonOptions::default(),
};
let handle = tokio::spawn(async move {
let mut read = BytesMut::new();
// server should offer cleartext
match read_message(&mut client, &mut read).await {
PgMessage::AuthenticationCleartextPassword => {}
_ => panic!("wrong message"),
}
// client responds with password
let mut write = BytesMut::new();
frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
.unwrap();
client.write_all(&write).await.unwrap();
});
let creds = auth_quirks(&mut ctx, &api, user_info, &mut stream, true, CONFIG)
.await
.unwrap();
assert_eq!(creds.info.endpoint, "my-endpoint");
handle.await.unwrap();
}
}

View File

@@ -211,19 +211,4 @@ mod tests {
Ok(())
}
#[tokio::test]
async fn cancel_session_noop_regression() {
let handler = CancellationHandler::<()>::new(Default::default(), "local");
handler
.cancel_session(
CancelKeyData {
backend_pid: 0,
cancel_key: 0,
},
Uuid::new_v4(),
)
.await
.unwrap();
}
}

View File

@@ -82,13 +82,14 @@ pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
/// A config for establishing a connection to compute node.
/// Eventually, `tokio_postgres` will be replaced with something better.
/// Newtype allows us to implement methods on top of it.
#[derive(Clone, Default)]
#[derive(Clone)]
#[repr(transparent)]
pub struct ConnCfg(Box<tokio_postgres::Config>);
/// Creation and initialization routines.
impl ConnCfg {
pub fn new() -> Self {
Self::default()
Self(Default::default())
}
/// Reuse password or auth keys from the other config.
@@ -164,6 +165,12 @@ impl std::ops::DerefMut for ConnCfg {
}
}
impl Default for ConnCfg {
fn default() -> Self {
Self::new()
}
}
impl ConnCfg {
/// Establish a raw TCP connection to the compute node.
async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> {

View File

@@ -6,7 +6,7 @@ pub mod messages;
/// Wrappers for console APIs and their mocks.
pub mod provider;
pub(crate) use provider::{errors, Api, AuthSecret, CachedNodeInfo, NodeInfo};
pub use provider::{errors, Api, AuthSecret, CachedNodeInfo, NodeInfo};
/// Various cache-related types.
pub mod caches {

View File

@@ -14,6 +14,7 @@ use crate::{
context::RequestMonitoring,
scram, EndpointCacheKey, ProjectId,
};
use async_trait::async_trait;
use dashmap::DashMap;
use std::{sync::Arc, time::Duration};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
@@ -325,7 +326,8 @@ pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPatt
/// This will allocate per each call, but the http requests alone
/// already require a few allocations, so it should be fine.
pub(crate) trait Api {
#[async_trait]
pub trait Api {
/// Get the client's auth secret for authentication.
/// Returns option because user not found situation is special.
/// We still have to mock the scram to avoid leaking information that user doesn't exist.
@@ -361,6 +363,7 @@ pub enum ConsoleBackend {
Test(Box<dyn crate::auth::backend::TestBackend>),
}
#[async_trait]
impl Api for ConsoleBackend {
async fn get_role_secret(
&self,

View File

@@ -8,6 +8,7 @@ use crate::console::provider::{CachedAllowedIps, CachedRoleSecret};
use crate::context::RequestMonitoring;
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
use crate::{auth::IpPattern, cache::Cached};
use async_trait::async_trait;
use futures::TryFutureExt;
use std::{str::FromStr, sync::Arc};
use thiserror::Error;
@@ -143,6 +144,7 @@ async fn get_execute_postgres_query(
Ok(Some(entry))
}
#[async_trait]
impl super::Api for Api {
#[tracing::instrument(skip_all)]
async fn get_role_secret(

View File

@@ -14,6 +14,7 @@ use crate::{
context::RequestMonitoring,
metrics::{ALLOWED_IPS_BY_CACHE_OUTCOME, ALLOWED_IPS_NUMBER},
};
use async_trait::async_trait;
use futures::TryFutureExt;
use std::sync::Arc;
use tokio::time::Instant;
@@ -167,6 +168,7 @@ impl Api {
}
}
#[async_trait]
impl super::Api for Api {
#[tracing::instrument(skip_all)]
async fn get_role_secret(

View File

@@ -1,5 +1,6 @@
use std::sync::Arc;
use async_trait::async_trait;
use pq_proto::CancelKeyData;
use redis::AsyncCommands;
use tokio::sync::Mutex;
@@ -12,8 +13,8 @@ use super::{
notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME},
};
#[async_trait]
pub trait CancellationPublisherMut: Send + Sync + 'static {
#[allow(async_fn_in_trait)]
async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
@@ -21,8 +22,8 @@ pub trait CancellationPublisherMut: Send + Sync + 'static {
) -> anyhow::Result<()>;
}
#[async_trait]
pub trait CancellationPublisher: Send + Sync + 'static {
#[allow(async_fn_in_trait)]
async fn try_publish(
&self,
cancel_key_data: CancelKeyData,
@@ -30,9 +31,10 @@ pub trait CancellationPublisher: Send + Sync + 'static {
) -> anyhow::Result<()>;
}
impl CancellationPublisher for () {
#[async_trait]
impl CancellationPublisherMut for () {
async fn try_publish(
&self,
&mut self,
_cancel_key_data: CancelKeyData,
_session_id: Uuid,
) -> anyhow::Result<()> {
@@ -40,16 +42,18 @@ impl CancellationPublisher for () {
}
}
impl<P: CancellationPublisher> CancellationPublisherMut for P {
#[async_trait]
impl<P: CancellationPublisherMut> CancellationPublisher for P {
async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
&self,
_cancel_key_data: CancelKeyData,
_session_id: Uuid,
) -> anyhow::Result<()> {
<P as CancellationPublisher>::try_publish(self, cancel_key_data, session_id).await
self.try_publish(_cancel_key_data, _session_id).await
}
}
#[async_trait]
impl<P: CancellationPublisher> CancellationPublisher for Option<P> {
async fn try_publish(
&self,
@@ -64,6 +68,7 @@ impl<P: CancellationPublisher> CancellationPublisher for Option<P> {
}
}
#[async_trait]
impl<P: CancellationPublisherMut> CancellationPublisher for Arc<Mutex<P>> {
async fn try_publish(
&self,
@@ -140,6 +145,7 @@ impl RedisPublisherClient {
}
}
#[async_trait]
impl CancellationPublisherMut for RedisPublisherClient {
async fn try_publish(
&mut self,

View File

@@ -3,7 +3,9 @@
use std::convert::Infallible;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use sha2::digest::FixedOutput;
use sha2::{Digest, Sha256};
use subtle::{Choice, ConstantTimeEq};
use tokio::task::yield_now;
use super::messages::{
@@ -11,7 +13,6 @@ use super::messages::{
};
use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use super::ScramKey;
use crate::config;
use crate::sasl::{self, ChannelBinding, Error as SaslError};
@@ -103,7 +104,7 @@ async fn pbkdf2(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
}
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
async fn derive_client_key(password: &[u8], salt: &[u8], iterations: u32) -> ScramKey {
async fn derive_keys(password: &[u8], salt: &[u8], iterations: u32) -> ([u8; 32], [u8; 32]) {
let salted_password = pbkdf2(password, salt, iterations).await;
let make_key = |name| {
@@ -115,7 +116,7 @@ async fn derive_client_key(password: &[u8], salt: &[u8], iterations: u32) -> Scr
<[u8; 32]>::from(key.into_bytes())
};
make_key(b"Client Key").into()
(make_key(b"Client Key"), make_key(b"Server Key"))
}
pub async fn exchange(
@@ -123,12 +124,21 @@ pub async fn exchange(
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
let salt = base64::decode(&secret.salt_base64)?;
let client_key = derive_client_key(password, &salt, secret.iterations).await;
let (client_key, server_key) = derive_keys(password, &salt, secret.iterations).await;
let stored_key: [u8; 32] = Sha256::default()
.chain_update(client_key)
.finalize_fixed()
.into();
if secret.is_password_invalid(&client_key).into() {
Ok(sasl::Outcome::Failure("password doesn't match"))
// constant time to not leak partial key match
let valid = stored_key.ct_eq(&secret.stored_key.as_bytes())
| server_key.ct_eq(&secret.server_key.as_bytes())
| Choice::from(secret.doomed as u8);
if valid.into() {
Ok(sasl::Outcome::Success(super::ScramKey::from(client_key)))
} else {
Ok(sasl::Outcome::Success(client_key))
Ok(sasl::Outcome::Failure("password doesn't match"))
}
}
@@ -210,7 +220,7 @@ impl SaslSentInner {
.derive_client_key(&client_final_message.proof);
// Auth fails either if keys don't match or it's pre-determined to fail.
if secret.is_password_invalid(&client_key).into() {
if client_key.sha256() != secret.stored_key || secret.doomed {
return Ok(sasl::Step::Failure("password doesn't match"));
}

View File

@@ -1,31 +1,17 @@
//! Tools for client/server/stored key management.
use subtle::ConstantTimeEq;
/// Faithfully taken from PostgreSQL.
pub const SCRAM_KEY_LEN: usize = 32;
/// One of the keys derived from the user's password.
/// We use the same structure for all keys, i.e.
/// `ClientKey`, `StoredKey`, and `ServerKey`.
#[derive(Clone, Default, Eq, Debug)]
#[derive(Clone, Default, PartialEq, Eq, Debug)]
#[repr(transparent)]
pub struct ScramKey {
bytes: [u8; SCRAM_KEY_LEN],
}
impl PartialEq for ScramKey {
fn eq(&self, other: &Self) -> bool {
self.ct_eq(other).into()
}
}
impl ConstantTimeEq for ScramKey {
fn ct_eq(&self, other: &Self) -> subtle::Choice {
self.bytes.ct_eq(&other.bytes)
}
}
impl ScramKey {
pub fn sha256(&self) -> Self {
super::sha256([self.as_ref()]).into()

View File

@@ -206,28 +206,6 @@ mod tests {
}
}
#[test]
fn parse_client_first_message_with_invalid_gs2_authz() {
assert!(ClientFirstMessage::parse("n,authzid,n=user,r=nonce").is_none())
}
#[test]
fn parse_client_first_message_with_extra_params() {
let msg = ClientFirstMessage::parse("n,,n=user,r=nonce,a=foo,b=bar,c=baz").unwrap();
assert_eq!(msg.bare, "n=user,r=nonce,a=foo,b=bar,c=baz");
assert_eq!(msg.username, "user");
assert_eq!(msg.nonce, "nonce");
assert_eq!(msg.cbind_flag, ChannelBinding::NotSupportedClient);
}
#[test]
fn parse_client_first_message_with_extra_params_invalid() {
// must be of the form `<ascii letter>=<...>`
assert!(ClientFirstMessage::parse("n,,n=user,r=nonce,abc=foo").is_none());
assert!(ClientFirstMessage::parse("n,,n=user,r=nonce,1=foo").is_none());
assert!(ClientFirstMessage::parse("n,,n=user,r=nonce,a").is_none());
}
#[test]
fn parse_client_final_message() {
let input = [

View File

@@ -1,7 +1,5 @@
//! Tools for SCRAM server secret management.
use subtle::{Choice, ConstantTimeEq};
use super::base64_decode_array;
use super::key::ScramKey;
@@ -42,11 +40,6 @@ impl ServerSecret {
Some(secret)
}
pub fn is_password_invalid(&self, client_key: &ScramKey) -> Choice {
// constant time to not leak partial key match
client_key.sha256().ct_ne(&self.stored_key) | Choice::from(self.doomed as u8)
}
/// To avoid revealing information to an attacker, we use a
/// mocked server secret even if the user doesn't exist.
/// See `auth-scram.c : mock_scram_secret` for details.

View File

@@ -244,7 +244,6 @@ impl SimulationApi {
mutex: 0,
mineLastElectedTerm: 0,
backpressureThrottlingTime: pg_atomic_uint64 { value: 0 },
currentClusterSize: pg_atomic_uint64 { value: 0 },
shard_ps_feedback: [empty_feedback; 128],
num_shards: 0,
min_ps_feedback: empty_feedback,

View File

@@ -116,7 +116,7 @@ def test_backpressure_received_lsn_lag(neon_env_builder: NeonEnvBuilder):
# Configure failpoint to slow down walreceiver ingest
with closing(env.pageserver.connect()) as psconn:
with psconn.cursor(cursor_factory=psycopg2.extras.DictCursor) as pscur:
pscur.execute("failpoints walreceiver-after-ingest-blocking=sleep(20)")
pscur.execute("failpoints walreceiver-after-ingest=sleep(20)")
# FIXME
# Wait for the check thread to start

View File

@@ -84,21 +84,3 @@ def test_hot_standby(neon_simple_env: NeonEnv):
# clean up
if slow_down_send:
sk_http.configure_failpoints(("sk-send-wal-replica-sleep", "off"))
def test_2_replicas_start(neon_simple_env: NeonEnv):
env = neon_simple_env
with env.endpoints.create_start(
branch_name="main",
endpoint_id="primary",
) as primary:
time.sleep(1)
with env.endpoints.new_replica_start(
origin=primary, endpoint_id="secondary1"
) as secondary1:
with env.endpoints.new_replica_start(
origin=primary, endpoint_id="secondary2"
) as secondary2:
wait_replica_caughtup(primary, secondary1)
wait_replica_caughtup(primary, secondary2)

View File

@@ -1,6 +1,4 @@
import gzip
import json
import os
import time
from dataclasses import dataclass
from pathlib import Path
@@ -12,11 +10,7 @@ from fixtures.neon_fixtures import (
NeonEnvBuilder,
wait_for_last_flush_lsn,
)
from fixtures.remote_storage import (
LocalFsStorage,
RemoteStorageKind,
remote_storage_to_toml_inline_table,
)
from fixtures.remote_storage import RemoteStorageKind
from fixtures.types import TenantId, TimelineId
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.request import Request
@@ -46,9 +40,6 @@ def test_metric_collection(
uploads.put((events, is_last == "true"))
return Response(status=200)
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
assert neon_env_builder.pageserver_remote_storage is not None
# Require collecting metrics frequently, since we change
# the timeline and want something to be logged about it.
#
@@ -57,11 +48,12 @@ def test_metric_collection(
neon_env_builder.pageserver_config_override = f"""
metric_collection_interval="1s"
metric_collection_endpoint="{metric_collection_endpoint}"
metric_collection_bucket={remote_storage_to_toml_inline_table(neon_env_builder.pageserver_remote_storage)}
cached_metric_collection_interval="0s"
synthetic_size_calculation_interval="3s"
"""
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
log.info(f"test_metric_collection endpoint is {metric_collection_endpoint}")
# mock http server that returns OK for the metrics
@@ -175,20 +167,6 @@ def test_metric_collection(
httpserver.check()
# Check that at least one bucket output object is present, and that all
# can be decompressed and decoded.
bucket_dumps = {}
assert isinstance(env.pageserver_remote_storage, LocalFsStorage)
for dirpath, _dirs, files in os.walk(env.pageserver_remote_storage.root):
for file in files:
file_path = os.path.join(dirpath, file)
log.info(file_path)
if file.endswith(".gz"):
bucket_dumps[file_path] = json.load(gzip.open(file_path))
assert len(bucket_dumps) >= 1
assert all("events" in data for data in bucket_dumps.values())
def test_metric_collection_cleans_up_tempfile(
httpserver: HTTPServer,

View File

@@ -1,9 +1,7 @@
import pytest
from fixtures.log_helper import log
from fixtures.neon_fixtures import NeonEnv, wait_replica_caughtup
@pytest.mark.xfail
def test_replication_start(neon_simple_env: NeonEnv):
env = neon_simple_env

View File

@@ -931,7 +931,7 @@ def test_timeline_logical_size_task_priority(neon_env_builder: NeonEnvBuilder):
env.pageserver.stop()
env.pageserver.start(
extra_env_vars={
"FAILPOINTS": "initial-size-calculation-permit-pause=pause;walreceiver-after-ingest-pause-activate=return(1);walreceiver-after-ingest-pause=pause"
"FAILPOINTS": "initial-size-calculation-permit-pause=pause;walreceiver-after-ingest=pause"
}
)
@@ -953,11 +953,7 @@ def test_timeline_logical_size_task_priority(neon_env_builder: NeonEnvBuilder):
assert details["current_logical_size_is_accurate"] is True
client.configure_failpoints(
[
("initial-size-calculation-permit-pause", "off"),
("walreceiver-after-ingest-pause-activate", "off"),
("walreceiver-after-ingest-pause", "off"),
]
[("initial-size-calculation-permit-pause", "off"), ("walreceiver-after-ingest", "off")]
)
@@ -987,7 +983,7 @@ def test_eager_attach_does_not_queue_up(neon_env_builder: NeonEnvBuilder):
# pause at logical size calculation, also pause before walreceiver can give feedback so it will give priority to logical size calculation
env.pageserver.start(
extra_env_vars={
"FAILPOINTS": "timeline-calculate-logical-size-pause=pause;walreceiver-after-ingest-pause-activate=return(1);walreceiver-after-ingest-pause=pause"
"FAILPOINTS": "timeline-calculate-logical-size-pause=pause;walreceiver-after-ingest=pause"
}
)
@@ -1033,11 +1029,7 @@ def test_eager_attach_does_not_queue_up(neon_env_builder: NeonEnvBuilder):
other_is_attaching()
client.configure_failpoints(
[
("timeline-calculate-logical-size-pause", "off"),
("walreceiver-after-ingest-pause-activate", "off"),
("walreceiver-after-ingest-pause", "off"),
]
[("timeline-calculate-logical-size-pause", "off"), ("walreceiver-after-ingest", "off")]
)
@@ -1067,7 +1059,7 @@ def test_lazy_attach_activation(neon_env_builder: NeonEnvBuilder, activation_met
# pause at logical size calculation, also pause before walreceiver can give feedback so it will give priority to logical size calculation
env.pageserver.start(
extra_env_vars={
"FAILPOINTS": "timeline-calculate-logical-size-pause=pause;walreceiver-after-ingest-pause-activate=return(1);walreceiver-after-ingest-pause=pause"
"FAILPOINTS": "timeline-calculate-logical-size-pause=pause;walreceiver-after-ingest=pause"
}
)
@@ -1119,11 +1111,3 @@ def test_lazy_attach_activation(neon_env_builder: NeonEnvBuilder, activation_met
delete_lazy_activating(lazy_tenant, env.pageserver, expect_attaching=True)
else:
raise RuntimeError(activation_method)
client.configure_failpoints(
[
("timeline-calculate-logical-size-pause", "off"),
("walreceiver-after-ingest-pause-activate", "off"),
("walreceiver-after-ingest-pause", "off"),
]
)

View File

@@ -1,5 +1,5 @@
{
"postgres-v16": "3946b2e2ea71d07af092099cb5bcae76a69b90d6",
"postgres-v15": "e7651e79c0c27fbddc3c724f5b9553222c28e395",
"postgres-v14": "748643b4683e9fe3b105011a6ba8a687d032cd65"
"postgres-v16": "90078947229aa7f9ac5f7ed4527b2c7386d5332b",
"postgres-v15": "80cef885add1af6741aa31944c7d2c84d8f9098f",
"postgres-v14": "3b09894ddb8825b50c963942059eab1a2a0b0a89"
}