Mirror of https://github.com/neondatabase/neon.git, synced 2026-02-01 17:50:38 +00:00

Compare commits: cloneable/ ... conrad/pro (8 Commits)
| Author | SHA1 | Date |
|---|---|---|
| | 4ada80d915 | |
| | fd263a0c23 | |
| | 26dc39053e | |
| | 1f62ee5f5c | |
| | e78254657a | |
| | 640500aa6d | |
| | b0c712f63f | |
| | f84e73c323 | |
Cargo.lock (generated), 2 changes
@@ -2055,7 +2055,6 @@ dependencies = [
"axum-extra",
"camino",
"camino-tempfile",
"clap",
"futures",
"http-body-util",
"itertools 0.10.5",
@@ -5274,6 +5273,7 @@ dependencies = [
"tokio-rustls 0.26.2",
"tokio-tungstenite 0.21.0",
"tokio-util",
"toml",
"tracing",
"tracing-log",
"tracing-opentelemetry",

@@ -8,7 +8,6 @@ anyhow.workspace = true
axum-extra.workspace = true
axum.workspace = true
camino.workspace = true
clap.workspace = true
futures.workspace = true
jsonwebtoken.workspace = true
prometheus.workspace = true
@@ -3,8 +3,7 @@
//! This service is deployed either as a separate component or as part of compute image
//! for large computes.
mod app;
use anyhow::Context;
use clap::Parser;
use anyhow::{Context, bail};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tracing::info;
use utils::logging;
@@ -18,18 +17,6 @@ const fn listen() -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 51243)
}

#[derive(Parser)]
struct Args {
#[arg(exclusive = true)]
config_file: Option<String>,
#[arg(long, default_value = "false", requires = "config")]
/// to allow testing k8s helm chart where we don't have s3 credentials
no_s3_check_on_startup: bool,
#[arg(long, value_name = "FILE")]
/// inline config mode for k8s helm chart
config: Option<String>,
}

#[derive(serde::Deserialize)]
#[serde(tag = "type")]
struct Config {
@@ -50,16 +37,19 @@ async fn main() -> anyhow::Result<()> {
logging::Output::Stdout,
)?;

let args = Args::parse();
let config: Config = if let Some(config_path) = args.config_file {
info!("Reading config from {config_path}");
let config = std::fs::read_to_string(config_path)?;
// Allow either passing filename or inline config (for k8s helm chart)
let args: Vec<String> = std::env::args().skip(1).collect();
let config: Config = if args.len() == 1 && args[0].ends_with(".json") {
info!("Reading config from {}", args[0]);
let config = std::fs::read_to_string(args[0].clone())?;
serde_json::from_str(&config).context("parsing config")?
} else if let Some(config) = args.config {
} else if !args.is_empty() && args[0].starts_with("--config=") {
info!("Reading inline config");
serde_json::from_str(&config).context("parsing config")?
let config = args.join(" ");
let config = config.strip_prefix("--config=").unwrap();
serde_json::from_str(config).context("parsing config")?
} else {
anyhow::bail!("Supply either config file path or --config=inline-config");
bail!("Usage: endpoint_storage config.json or endpoint_storage --config=JSON");
};

info!("Reading pemfile from {}", config.pemfile.clone());
@@ -72,9 +62,7 @@ async fn main() -> anyhow::Result<()> {

let storage = remote_storage::GenericRemoteStorage::from_config(&config.storage_config).await?;
let cancel = tokio_util::sync::CancellationToken::new();
if !args.no_s3_check_on_startup {
app::check_storage_permissions(&storage, cancel.clone()).await?;
}
app::check_storage_permissions(&storage, cancel.clone()).await?;

let proxy = std::sync::Arc::new(endpoint_storage::Storage {
auth,
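The hunks above swap the clap-derived `Args` struct for manual `std::env::args()` handling, so the binary accepts either a JSON config file path or an inline `--config=JSON` argument, and the `no_s3_check_on_startup` gate around the startup permission check disappears with it. A minimal, hedged sketch of that dual-mode parsing, with a simplified `Config` (only a `pemfile` field; the real struct carries storage settings and a `type` tag as well):

```rust
use anyhow::{Context, bail};

#[derive(serde::Deserialize)]
struct Config {
    pemfile: String, // placeholder: the real Config has more fields
}

fn load_config() -> anyhow::Result<Config> {
    // Either "endpoint_storage config.json" or "endpoint_storage --config={...}".
    let args: Vec<String> = std::env::args().skip(1).collect();
    if args.len() == 1 && args[0].ends_with(".json") {
        let text = std::fs::read_to_string(&args[0]).context("reading config file")?;
        serde_json::from_str(&text).context("parsing config")
    } else if !args.is_empty() && args[0].starts_with("--config=") {
        // Re-join so inline JSON that the shell split on spaces still parses.
        let joined = args.join(" ");
        let inline = joined.strip_prefix("--config=").expect("checked above");
        serde_json::from_str(inline).context("parsing config")
    } else {
        bail!("Usage: endpoint_storage config.json or endpoint_storage --config=JSON");
    }
}
```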
@@ -10,7 +10,7 @@ use std::sync::Arc;
|
||||
use std::time::{Duration, SystemTime};
|
||||
use std::{env, io};
|
||||
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use anyhow::{Context, Result};
|
||||
use azure_core::request_options::{IfMatchCondition, MaxResults, Metadata, Range};
|
||||
use azure_core::{Continuable, HttpClient, RetryOptions, TransportOptions};
|
||||
use azure_storage::StorageCredentials;
|
||||
@@ -37,7 +37,6 @@ use crate::metrics::{AttemptOutcome, RequestKind, start_measuring_requests};
|
||||
use crate::{
|
||||
ConcurrencyLimiter, Download, DownloadError, DownloadKind, DownloadOpts, Listing, ListingMode,
|
||||
ListingObject, RemotePath, RemoteStorage, StorageMetadata, TimeTravelError, TimeoutOrCancel,
|
||||
Version, VersionKind,
|
||||
};
|
||||
|
||||
pub struct AzureBlobStorage {
|
||||
@@ -406,39 +405,6 @@ impl AzureBlobStorage {
|
||||
pub fn container_name(&self) -> &str {
|
||||
&self.container_name
|
||||
}
|
||||
|
||||
async fn list_versions_with_permit(
|
||||
&self,
|
||||
_permit: &tokio::sync::SemaphorePermit<'_>,
|
||||
prefix: Option<&RemotePath>,
|
||||
mode: ListingMode,
|
||||
max_keys: Option<NonZeroU32>,
|
||||
cancel: &CancellationToken,
|
||||
) -> Result<crate::VersionListing, DownloadError> {
|
||||
let customize_builder = |mut builder: ListBlobsBuilder| {
|
||||
builder = builder.include_versions(true);
|
||||
// We do not return this info back to `VersionListing` yet.
|
||||
builder = builder.include_deleted(true);
|
||||
builder
|
||||
};
|
||||
let kind = RequestKind::ListVersions;
|
||||
|
||||
let mut stream = std::pin::pin!(self.list_streaming_for_fn(
|
||||
prefix,
|
||||
mode,
|
||||
max_keys,
|
||||
cancel,
|
||||
kind,
|
||||
customize_builder
|
||||
));
|
||||
let mut combined: crate::VersionListing =
|
||||
stream.next().await.expect("At least one item required")?;
|
||||
while let Some(list) = stream.next().await {
|
||||
let list = list?;
|
||||
combined.versions.extend(list.versions.into_iter());
|
||||
}
|
||||
Ok(combined)
|
||||
}
|
||||
}
|
||||
|
||||
trait ListingCollector {
|
||||
@@ -522,10 +488,27 @@ impl RemoteStorage for AzureBlobStorage {
|
||||
max_keys: Option<NonZeroU32>,
|
||||
cancel: &CancellationToken,
|
||||
) -> std::result::Result<crate::VersionListing, DownloadError> {
|
||||
let customize_builder = |mut builder: ListBlobsBuilder| {
|
||||
builder = builder.include_versions(true);
|
||||
builder
|
||||
};
|
||||
let kind = RequestKind::ListVersions;
|
||||
let permit = self.permit(kind, cancel).await?;
|
||||
self.list_versions_with_permit(&permit, prefix, mode, max_keys, cancel)
|
||||
.await
|
||||
|
||||
let mut stream = std::pin::pin!(self.list_streaming_for_fn(
|
||||
prefix,
|
||||
mode,
|
||||
max_keys,
|
||||
cancel,
|
||||
kind,
|
||||
customize_builder
|
||||
));
|
||||
let mut combined: crate::VersionListing =
|
||||
stream.next().await.expect("At least one item required")?;
|
||||
while let Some(list) = stream.next().await {
|
||||
let list = list?;
|
||||
combined.versions.extend(list.versions.into_iter());
|
||||
}
|
||||
Ok(combined)
|
||||
}
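The `list_versions` code above collapses a paginated stream of listings into a single `VersionListing`: it takes the first page and extends it with every following one. A self-contained sketch of that combine step, with a simplified stand-in for the crate's `VersionListing` type:

```rust
use futures::StreamExt;

// Simplified stand-in for crate::VersionListing.
struct VersionListing {
    versions: Vec<String>,
}

/// Drain a stream of per-page listings into one combined listing.
/// The first page must exist; later pages are appended as they arrive.
async fn combine_pages<S>(pages: S) -> anyhow::Result<VersionListing>
where
    S: futures::Stream<Item = anyhow::Result<VersionListing>>,
{
    let mut pages = std::pin::pin!(pages);
    let mut combined = pages.next().await.expect("at least one page required")?;
    while let Some(page) = pages.next().await {
        combined.versions.extend(page?.versions);
    }
    Ok(combined)
}
```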
|
||||
|
||||
async fn head_object(
|
||||
@@ -820,158 +803,14 @@ impl RemoteStorage for AzureBlobStorage {
|
||||
|
||||
async fn time_travel_recover(
|
||||
&self,
|
||||
prefix: Option<&RemotePath>,
|
||||
timestamp: SystemTime,
|
||||
done_if_after: SystemTime,
|
||||
cancel: &CancellationToken,
|
||||
_prefix: Option<&RemotePath>,
|
||||
_timestamp: SystemTime,
|
||||
_done_if_after: SystemTime,
|
||||
_cancel: &CancellationToken,
|
||||
) -> Result<(), TimeTravelError> {
|
||||
let msg = "PLEASE NOTE: Azure Blob storage time-travel recovery may not work as expected "
|
||||
.to_string()
|
||||
+ "for some specific files. If a file gets deleted but then overwritten and we want to recover "
|
||||
+ "to the time during the file was not present, this functionality will recover the file. Only "
|
||||
+ "use the functionality for services that can tolerate this. For example, recovering a state of the "
|
||||
+ "pageserver tenants.";
|
||||
tracing::error!("{}", msg);
|
||||
|
||||
let kind = RequestKind::TimeTravel;
|
||||
let permit = self.permit(kind, cancel).await?;
|
||||
|
||||
let mode = ListingMode::NoDelimiter;
|
||||
let version_listing = self
|
||||
.list_versions_with_permit(&permit, prefix, mode, None, cancel)
|
||||
.await
|
||||
.map_err(|err| match err {
|
||||
DownloadError::Other(e) => TimeTravelError::Other(e),
|
||||
DownloadError::Cancelled => TimeTravelError::Cancelled,
|
||||
other => TimeTravelError::Other(other.into()),
|
||||
})?;
|
||||
let versions_and_deletes = version_listing.versions;
|
||||
|
||||
tracing::info!(
|
||||
"Built list for time travel with {} versions and deletions",
|
||||
versions_and_deletes.len()
|
||||
);
|
||||
|
||||
// Work on the list of references instead of the objects directly,
|
||||
// otherwise we get lifetime errors in the sort_by_key call below.
|
||||
let mut versions_and_deletes = versions_and_deletes.iter().collect::<Vec<_>>();
|
||||
|
||||
versions_and_deletes.sort_by_key(|vd| (&vd.key, &vd.last_modified));
|
||||
|
||||
let mut vds_for_key = HashMap::<_, Vec<_>>::new();
|
||||
|
||||
for vd in &versions_and_deletes {
|
||||
let Version { key, .. } = &vd;
|
||||
let version_id = vd.version_id().map(|v| v.0.as_str());
|
||||
if version_id == Some("null") {
|
||||
return Err(TimeTravelError::Other(anyhow!(
|
||||
"Received ListVersions response for key={key} with version_id='null', \
|
||||
indicating either disabled versioning, or legacy objects with null version id values"
|
||||
)));
|
||||
}
|
||||
tracing::trace!("Parsing version key={key} kind={:?}", vd.kind);
|
||||
|
||||
vds_for_key.entry(key).or_default().push(vd);
|
||||
}
|
||||
|
||||
let warn_threshold = 3;
|
||||
let max_retries = 10;
|
||||
let is_permanent = |e: &_| matches!(e, TimeTravelError::Cancelled);
|
||||
|
||||
for (key, versions) in vds_for_key {
|
||||
let last_vd = versions.last().unwrap();
|
||||
let key = self.relative_path_to_name(key);
|
||||
if last_vd.last_modified > done_if_after {
|
||||
tracing::debug!("Key {key} has version later than done_if_after, skipping");
|
||||
continue;
|
||||
}
|
||||
// the version we want to restore to.
|
||||
let version_to_restore_to =
|
||||
match versions.binary_search_by_key(×tamp, |tpl| tpl.last_modified) {
|
||||
Ok(v) => v,
|
||||
Err(e) => e,
|
||||
};
|
||||
if version_to_restore_to == versions.len() {
|
||||
tracing::debug!("Key {key} has no changes since timestamp, skipping");
|
||||
continue;
|
||||
}
|
||||
let mut do_delete = false;
|
||||
if version_to_restore_to == 0 {
|
||||
// All versions more recent, so the key didn't exist at the specified time point.
|
||||
tracing::debug!(
|
||||
"All {} versions more recent for {key}, deleting",
|
||||
versions.len()
|
||||
);
|
||||
do_delete = true;
|
||||
} else {
|
||||
match &versions[version_to_restore_to - 1] {
|
||||
Version {
|
||||
kind: VersionKind::Version(version_id),
|
||||
..
|
||||
} => {
|
||||
let source_url = format!(
|
||||
"{}/{}?versionid={}",
|
||||
self.client
|
||||
.url()
|
||||
.map_err(|e| TimeTravelError::Other(anyhow!("{e}")))?,
|
||||
key,
|
||||
version_id.0
|
||||
);
|
||||
tracing::debug!(
|
||||
"Promoting old version {} for {key} at {}...",
|
||||
version_id.0,
|
||||
source_url
|
||||
);
|
||||
backoff::retry(
|
||||
|| async {
|
||||
let blob_client = self.client.blob_client(key.clone());
|
||||
let op = blob_client.copy(Url::from_str(&source_url).unwrap());
|
||||
tokio::select! {
|
||||
res = op => res.map_err(|e| TimeTravelError::Other(e.into())),
|
||||
_ = cancel.cancelled() => Err(TimeTravelError::Cancelled),
|
||||
}
|
||||
},
|
||||
is_permanent,
|
||||
warn_threshold,
|
||||
max_retries,
|
||||
"copying object version for time_travel_recover",
|
||||
cancel,
|
||||
)
|
||||
.await
|
||||
.ok_or_else(|| TimeTravelError::Cancelled)
|
||||
.and_then(|x| x)?;
|
||||
tracing::info!(?version_id, %key, "Copied old version in Azure blob storage");
|
||||
}
|
||||
Version {
|
||||
kind: VersionKind::DeletionMarker,
|
||||
..
|
||||
} => {
|
||||
do_delete = true;
|
||||
}
|
||||
}
|
||||
};
|
||||
if do_delete {
|
||||
if matches!(last_vd.kind, VersionKind::DeletionMarker) {
|
||||
// Key has since been deleted (but there was some history), no need to do anything
|
||||
tracing::debug!("Key {key} already deleted, skipping.");
|
||||
} else {
|
||||
tracing::debug!("Deleting {key}...");
|
||||
|
||||
self.delete(&RemotePath::from_string(&key).unwrap(), cancel)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
// delete_oid0 will use TimeoutOrCancel
|
||||
if TimeoutOrCancel::caused_by_cancel(&e) {
|
||||
TimeTravelError::Cancelled
|
||||
} else {
|
||||
TimeTravelError::Other(e)
|
||||
}
|
||||
})?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
// TODO use Azure point in time recovery feature for this
|
||||
// https://learn.microsoft.com/en-us/azure/storage/blobs/point-in-time-restore-overview
|
||||
Err(TimeTravelError::Unimplemented)
|
||||
}
|
||||
}
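The hunk above removes the ListVersions-based `time_travel_recover` implementation for Azure and returns `TimeTravelError::Unimplemented` instead, pointing at Azure's native point-in-time restore. The heart of the removed logic was deciding, per key, which version to promote: a binary search over the key's versions sorted by `last_modified`. A hedged sketch of just that selection step, with a simplified `Ver` type standing in for the crate's `Version` (the real code additionally skips keys whose newest version is later than `done_if_after`):

```rust
use std::time::SystemTime;

// Simplified stand-in for the crate's Version type.
struct Ver {
    last_modified: SystemTime,
    is_delete_marker: bool,
}

/// What to do with a key when restoring to `timestamp`, given its versions
/// sorted by ascending `last_modified`.
enum Action<'a> {
    Skip,             // no changes since the target timestamp
    Delete,           // key did not exist at the target timestamp
    Promote(&'a Ver), // copy this old version back over the key
}

fn restore_action(versions: &[Ver], timestamp: SystemTime) -> Action<'_> {
    // Index of the first version at or after the restore timestamp.
    let idx = versions.partition_point(|v| v.last_modified < timestamp);
    if idx == versions.len() {
        Action::Skip
    } else if idx == 0 || versions[idx - 1].is_delete_marker {
        Action::Delete
    } else {
        Action::Promote(&versions[idx - 1])
    }
}
```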
|
||||
|
||||
|
||||
@@ -1022,7 +1022,6 @@ impl RemoteStorage for S3Bucket {
let Version { key, .. } = &vd;
let version_id = vd.version_id().map(|v| v.0.as_str());
if version_id == Some("null") {
// TODO: check the behavior of using the SDK on a non-versioned container
return Err(TimeTravelError::Other(anyhow!(
"Received ListVersions response for key={key} with version_id='null', \
indicating either disabled versioning, or legacy objects with null version id values"
@@ -573,8 +573,7 @@ fn start_pageserver(
tokio::sync::mpsc::unbounded_channel();
let deletion_queue_client = deletion_queue.new_client();
let background_purges = mgr::BackgroundPurges::default();

let tenant_manager = mgr::init(
let tenant_manager = BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(
conf,
background_purges.clone(),
TenantSharedResources {
@@ -585,10 +584,10 @@ fn start_pageserver(
basebackup_prepare_sender,
feature_resolver,
},
order,
shutdown_pageserver.clone(),
);
))?;
let tenant_manager = Arc::new(tenant_manager);
BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(tenant_manager.clone(), order))?;

let basebackup_cache = BasebackupCache::spawn(
BACKGROUND_RUNTIME.handle(),
@@ -1,6 +1,5 @@
use std::{collections::HashMap, sync::Arc, time::Duration};

use pageserver_api::config::NodeMetadata;
use posthog_client_lite::{
CaptureEvent, FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError,
PostHogFlagFilterPropertyValue,
@@ -87,35 +86,7 @@ impl FeatureResolver {
}
}
}
// TODO: move this to a background task so that we don't block startup in case of slow disk
let metadata_path = conf.metadata_path();
match std::fs::read_to_string(&metadata_path) {
Ok(metadata_str) => match serde_json::from_str::<NodeMetadata>(&metadata_str) {
Ok(metadata) => {
properties.insert(
"hostname".to_string(),
PostHogFlagFilterPropertyValue::String(metadata.http_host),
);
if let Some(cplane_region) = metadata.other.get("region_id") {
if let Some(cplane_region) = cplane_region.as_str() {
// This region contains the cell number
properties.insert(
"neon_region".to_string(),
PostHogFlagFilterPropertyValue::String(
cplane_region.to_string(),
),
);
}
}
}
Err(e) => {
tracing::warn!("Failed to parse metadata.json: {}", e);
}
},
Err(e) => {
tracing::warn!("Failed to read metadata.json: {}", e);
}
}
// TODO: add pageserver URL.
Arc::new(properties)
};
let fake_tenants = {
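The removed block read the pageserver's `metadata.json` and exposed `hostname` and `neon_region` as PostHog feature-flag properties. A minimal sketch of that extraction using plain `serde_json`; the JSON key names here (`http_host`, `region_id`) are inferred from the hunk, and the real code goes through the typed `NodeMetadata` from `pageserver_api`:

```rust
use std::collections::HashMap;

fn node_properties(metadata_json: &str) -> HashMap<String, String> {
    let mut props = HashMap::new();
    match serde_json::from_str::<serde_json::Value>(metadata_json) {
        Ok(metadata) => {
            if let Some(host) = metadata.get("http_host").and_then(|v| v.as_str()) {
                props.insert("hostname".to_string(), host.to_string());
            }
            // region_id is a free-form extra field; it also encodes the cell number.
            if let Some(region) = metadata.get("region_id").and_then(|v| v.as_str()) {
                props.insert("neon_region".to_string(), region.to_string());
            }
        }
        Err(e) => eprintln!("Failed to parse metadata.json: {e}"),
    }
    props
}
```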
@@ -1053,15 +1053,6 @@ pub(crate) static TENANT_STATE_METRIC: Lazy<UIntGaugeVec> = Lazy::new(|| {
.expect("Failed to register pageserver_tenant_states_count metric")
});

pub(crate) static TIMELINE_STATE_METRIC: Lazy<UIntGaugeVec> = Lazy::new(|| {
register_uint_gauge_vec!(
"pageserver_timeline_states_count",
"Count of timelines per state",
&["state"]
)
.expect("Failed to register pageserver_timeline_states_count metric")
});

/// A set of broken tenants.
///
/// These are expected to be so rare that a set is fine. Set as in a new timeseries per each broken
@@ -3334,8 +3325,6 @@ impl TimelineMetrics {
&timeline_id,
);

TIMELINE_STATE_METRIC.with_label_values(&["active"]).inc();

TimelineMetrics {
tenant_id,
shard_id,
@@ -3490,8 +3479,6 @@ impl TimelineMetrics {
return;
}

TIMELINE_STATE_METRIC.with_label_values(&["active"]).dec();

let tenant_id = &self.tenant_id;
let timeline_id = &self.timeline_id;
let shard_id = &self.shard_id;
@@ -89,8 +89,7 @@ use crate::l0_flush::L0FlushGlobalState;
use crate::metrics::{
BROKEN_TENANTS_SET, CIRCUIT_BREAKERS_BROKEN, CIRCUIT_BREAKERS_UNBROKEN, CONCURRENT_INITDBS,
INITDB_RUN_TIME, INITDB_SEMAPHORE_ACQUISITION_TIME, TENANT, TENANT_OFFLOADED_TIMELINES,
TENANT_STATE_METRIC, TENANT_SYNTHETIC_SIZE_METRIC, TIMELINE_STATE_METRIC,
remove_tenant_metrics,
TENANT_STATE_METRIC, TENANT_SYNTHETIC_SIZE_METRIC, remove_tenant_metrics,
};
use crate::task_mgr::TaskKind;
use crate::tenant::config::LocationMode;
@@ -545,28 +544,6 @@ pub struct OffloadedTimeline {

/// Part of the `OffloadedTimeline` object's lifecycle: this needs to be set before we drop it
pub deleted_from_ancestor: AtomicBool,

_metrics_guard: OffloadedTimelineMetricsGuard,
}

/// Increases the offloaded timeline count metric when created, and decreases when dropped.
struct OffloadedTimelineMetricsGuard;

impl OffloadedTimelineMetricsGuard {
fn new() -> Self {
TIMELINE_STATE_METRIC
.with_label_values(&["offloaded"])
.inc();
Self
}
}

impl Drop for OffloadedTimelineMetricsGuard {
fn drop(&mut self) {
TIMELINE_STATE_METRIC
.with_label_values(&["offloaded"])
.dec();
}
}

impl OffloadedTimeline {
@@ -599,8 +576,6 @@ impl OffloadedTimeline {

delete_progress: timeline.delete_progress.clone(),
deleted_from_ancestor: AtomicBool::new(false),

_metrics_guard: OffloadedTimelineMetricsGuard::new(),
})
}
fn from_manifest(tenant_shard_id: TenantShardId, manifest: &OffloadedTimelineManifest) -> Self {
@@ -620,7 +595,6 @@ impl OffloadedTimeline {
archived_at,
delete_progress: TimelineDeleteProgress::default(),
deleted_from_ancestor: AtomicBool::new(false),
_metrics_guard: OffloadedTimelineMetricsGuard::new(),
}
}
fn manifest(&self) -> OffloadedTimelineManifest {
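The `OffloadedTimelineMetricsGuard` in this hunk is a small RAII helper: holding one inside `OffloadedTimeline` keeps the `offloaded` label of the timeline-state gauge in sync with the number of live objects, including on every early-return path, because the decrement lives in `Drop`. A standalone sketch of the same pattern using the plain `prometheus` crate (the in-tree code uses its own `register_uint_gauge_vec!` wrapper instead):

```rust
use once_cell::sync::Lazy;
use prometheus::{IntGaugeVec, register_int_gauge_vec};

static TIMELINE_STATES: Lazy<IntGaugeVec> = Lazy::new(|| {
    register_int_gauge_vec!(
        "pageserver_timeline_states_count",
        "Count of timelines per state",
        &["state"]
    )
    .expect("failed to register gauge")
});

/// Increments the "offloaded" gauge on construction and decrements it on drop,
/// so the metric can never leak even when the owner is discarded early.
struct OffloadedGuard;

impl OffloadedGuard {
    fn new() -> Self {
        TIMELINE_STATES.with_label_values(&["offloaded"]).inc();
        OffloadedGuard
    }
}

impl Drop for OffloadedGuard {
    fn drop(&mut self) {
        TIMELINE_STATES.with_label_values(&["offloaded"]).dec();
    }
}
```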
@@ -12,6 +12,7 @@ use anyhow::Context;
|
||||
use camino::{Utf8DirEntry, Utf8Path, Utf8PathBuf};
|
||||
use futures::StreamExt;
|
||||
use itertools::Itertools;
|
||||
use once_cell::sync::Lazy;
|
||||
use pageserver_api::key::Key;
|
||||
use pageserver_api::models::{DetachBehavior, LocationConfigMode};
|
||||
use pageserver_api::shard::{
|
||||
@@ -102,7 +103,7 @@ pub(crate) enum TenantsMap {
|
||||
/// [`init_tenant_mgr`] is not done yet.
|
||||
Initializing,
|
||||
/// [`init_tenant_mgr`] is done, all on-disk tenants have been loaded.
|
||||
/// New tenants can be added using [`TenantManager::tenant_map_acquire_slot`].
|
||||
/// New tenants can be added using [`tenant_map_acquire_slot`].
|
||||
Open(BTreeMap<TenantShardId, TenantSlot>),
|
||||
/// The pageserver has entered shutdown mode via [`TenantManager::shutdown`].
|
||||
/// Existing tenants are still accessible, but no new tenants can be created.
|
||||
@@ -283,6 +284,9 @@ impl BackgroundPurges {
|
||||
}
|
||||
}
|
||||
|
||||
static TENANTS: Lazy<std::sync::RwLock<TenantsMap>> =
|
||||
Lazy::new(|| std::sync::RwLock::new(TenantsMap::Initializing));
|
||||
|
||||
/// Responsible for storing and mutating the collection of all tenants
|
||||
/// that this pageserver has state for.
|
||||
///
|
||||
@@ -293,7 +297,10 @@ impl BackgroundPurges {
|
||||
/// and attached modes concurrently.
|
||||
pub struct TenantManager {
|
||||
conf: &'static PageServerConf,
|
||||
tenants: std::sync::RwLock<TenantsMap>,
|
||||
// TODO: currently this is a &'static pointing to TENANTs. When we finish refactoring
|
||||
// out of that static variable, the TenantManager can own this.
|
||||
// See https://github.com/neondatabase/neon/issues/5796
|
||||
tenants: &'static std::sync::RwLock<TenantsMap>,
|
||||
resources: TenantSharedResources,
|
||||
|
||||
// Long-running operations that happen outside of a [`Tenant`] lifetime should respect this token.
|
||||
@@ -472,43 +479,21 @@ pub(crate) enum DeleteTenantError {
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
/// Initialize repositories at `Initializing` state.
|
||||
pub fn init(
|
||||
conf: &'static PageServerConf,
|
||||
background_purges: BackgroundPurges,
|
||||
resources: TenantSharedResources,
|
||||
cancel: CancellationToken,
|
||||
) -> TenantManager {
|
||||
TenantManager {
|
||||
conf,
|
||||
tenants: std::sync::RwLock::new(TenantsMap::Initializing),
|
||||
resources,
|
||||
cancel,
|
||||
background_purges,
|
||||
}
|
||||
}
|
||||
|
||||
/// Transition repositories from `Initializing` state to `Open` state with locally available timelines.
|
||||
/// Initialize repositories with locally available timelines.
|
||||
/// Timelines that are only partially available locally (remote storage has more data than this pageserver)
|
||||
/// are scheduled for download and added to the tenant once download is completed.
|
||||
#[instrument(skip_all)]
|
||||
pub async fn init_tenant_mgr(
|
||||
tenant_manager: Arc<TenantManager>,
|
||||
conf: &'static PageServerConf,
|
||||
background_purges: BackgroundPurges,
|
||||
resources: TenantSharedResources,
|
||||
init_order: InitializationOrder,
|
||||
) -> anyhow::Result<()> {
|
||||
debug_assert!(matches!(
|
||||
*tenant_manager.tenants.read().unwrap(),
|
||||
TenantsMap::Initializing
|
||||
));
|
||||
cancel: CancellationToken,
|
||||
) -> anyhow::Result<TenantManager> {
|
||||
let mut tenants = BTreeMap::new();
|
||||
|
||||
let ctx = RequestContext::todo_child(TaskKind::Startup, DownloadBehavior::Warn);
|
||||
|
||||
let conf = tenant_manager.conf;
|
||||
let resources = &tenant_manager.resources;
|
||||
let cancel = &tenant_manager.cancel;
|
||||
let background_purges = &tenant_manager.background_purges;
|
||||
|
||||
// Initialize dynamic limits that depend on system resources
|
||||
let system_memory =
|
||||
sysinfo::System::new_with_specifics(sysinfo::RefreshKind::new().with_memory())
|
||||
@@ -527,7 +512,7 @@ pub async fn init_tenant_mgr(
|
||||
let tenant_configs = init_load_tenant_configs(conf).await;
|
||||
|
||||
// Determine which tenants are to be secondary or attached, and in which generation
|
||||
let tenant_modes = init_load_generations(conf, &tenant_configs, resources, cancel).await?;
|
||||
let tenant_modes = init_load_generations(conf, &tenant_configs, &resources, &cancel).await?;
|
||||
|
||||
tracing::info!(
|
||||
"Attaching {} tenants at startup, warming up {} at a time",
|
||||
@@ -684,10 +669,18 @@ pub async fn init_tenant_mgr(
|
||||
|
||||
info!("Processed {} local tenants at startup", tenants.len());
|
||||
|
||||
let mut tenant_map = tenant_manager.tenants.write().unwrap();
|
||||
*tenant_map = TenantsMap::Open(tenants);
|
||||
let mut tenants_map = TENANTS.write().unwrap();
|
||||
assert!(matches!(&*tenants_map, &TenantsMap::Initializing));
|
||||
|
||||
Ok(())
|
||||
*tenants_map = TenantsMap::Open(tenants);
|
||||
|
||||
Ok(TenantManager {
|
||||
conf,
|
||||
tenants: &TENANTS,
|
||||
resources,
|
||||
cancel: CancellationToken::new(),
|
||||
background_purges,
|
||||
})
|
||||
}
|
||||
|
||||
/// Wrapper for Tenant::spawn that checks invariants before running
|
||||
@@ -726,6 +719,142 @@ fn tenant_spawn(
|
||||
)
|
||||
}
|
||||
|
||||
async fn shutdown_all_tenants0(tenants: &std::sync::RwLock<TenantsMap>) {
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
#[cfg(all(debug_assertions, not(test)))]
|
||||
{
|
||||
// Check that our metrics properly tracked the size of the tenants map. This is a convenient location to check,
|
||||
// as it happens implicitly at the end of tests etc.
|
||||
let m = tenants.read().unwrap();
|
||||
debug_assert_eq!(METRICS.slots_total(), m.len() as u64);
|
||||
}
|
||||
|
||||
// Atomically, 1. create the shutdown tasks and 2. prevent creation of new tenants.
|
||||
let (total_in_progress, total_attached) = {
|
||||
let mut m = tenants.write().unwrap();
|
||||
match &mut *m {
|
||||
TenantsMap::Initializing => {
|
||||
*m = TenantsMap::ShuttingDown(BTreeMap::default());
|
||||
info!("tenants map is empty");
|
||||
return;
|
||||
}
|
||||
TenantsMap::Open(tenants) => {
|
||||
let mut shutdown_state = BTreeMap::new();
|
||||
let mut total_in_progress = 0;
|
||||
let mut total_attached = 0;
|
||||
|
||||
for (tenant_shard_id, v) in std::mem::take(tenants).into_iter() {
|
||||
match v {
|
||||
TenantSlot::Attached(t) => {
|
||||
shutdown_state.insert(tenant_shard_id, TenantSlot::Attached(t.clone()));
|
||||
join_set.spawn(
|
||||
async move {
|
||||
let res = {
|
||||
let (_guard, shutdown_progress) = completion::channel();
|
||||
t.shutdown(shutdown_progress, ShutdownMode::FreezeAndFlush).await
|
||||
};
|
||||
|
||||
if let Err(other_progress) = res {
|
||||
// join the another shutdown in progress
|
||||
other_progress.wait().await;
|
||||
}
|
||||
|
||||
// we cannot afford per tenant logging here, because if s3 is degraded, we are
|
||||
// going to log too many lines
|
||||
debug!("tenant successfully stopped");
|
||||
}
|
||||
.instrument(info_span!("shutdown", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug())),
|
||||
);
|
||||
|
||||
total_attached += 1;
|
||||
}
|
||||
TenantSlot::Secondary(state) => {
|
||||
// We don't need to wait for this individually per-tenant: the
|
||||
// downloader task will be waited on eventually, this cancel
|
||||
// is just to encourage it to drop out if it is doing work
|
||||
// for this tenant right now.
|
||||
state.cancel.cancel();
|
||||
|
||||
shutdown_state.insert(tenant_shard_id, TenantSlot::Secondary(state));
|
||||
}
|
||||
TenantSlot::InProgress(notify) => {
|
||||
// InProgress tenants are not visible in TenantsMap::ShuttingDown: we will
|
||||
// wait for their notifications to fire in this function.
|
||||
join_set.spawn(async move {
|
||||
notify.wait().await;
|
||||
});
|
||||
|
||||
total_in_progress += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
*m = TenantsMap::ShuttingDown(shutdown_state);
|
||||
(total_in_progress, total_attached)
|
||||
}
|
||||
TenantsMap::ShuttingDown(_) => {
|
||||
error!(
|
||||
"already shutting down, this function isn't supposed to be called more than once"
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let started_at = std::time::Instant::now();
|
||||
|
||||
info!(
|
||||
"Waiting for {} InProgress tenants and {} Attached tenants to shut down",
|
||||
total_in_progress, total_attached
|
||||
);
|
||||
|
||||
let total = join_set.len();
|
||||
let mut panicked = 0;
|
||||
let mut buffering = true;
|
||||
const BUFFER_FOR: std::time::Duration = std::time::Duration::from_millis(500);
|
||||
let mut buffered = std::pin::pin!(tokio::time::sleep(BUFFER_FOR));
|
||||
|
||||
while !join_set.is_empty() {
|
||||
tokio::select! {
|
||||
Some(joined) = join_set.join_next() => {
|
||||
match joined {
|
||||
Ok(()) => {},
|
||||
Err(join_error) if join_error.is_cancelled() => {
|
||||
unreachable!("we are not cancelling any of the tasks");
|
||||
}
|
||||
Err(join_error) if join_error.is_panic() => {
|
||||
// cannot really do anything, as this panic is likely a bug
|
||||
panicked += 1;
|
||||
}
|
||||
Err(join_error) => {
|
||||
warn!("unknown kind of JoinError: {join_error}");
|
||||
}
|
||||
}
|
||||
if !buffering {
|
||||
// buffer so that every 500ms since the first update (or starting) we'll log
|
||||
// how far away we are; this is because we will get SIGKILL'd at 10s, and we
|
||||
// are not able to log *then*.
|
||||
buffering = true;
|
||||
buffered.as_mut().reset(tokio::time::Instant::now() + BUFFER_FOR);
|
||||
}
|
||||
},
|
||||
_ = &mut buffered, if buffering => {
|
||||
buffering = false;
|
||||
info!(remaining = join_set.len(), total, elapsed_ms = started_at.elapsed().as_millis(), "waiting for tenants to shutdown");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if panicked > 0 {
|
||||
warn!(
|
||||
panicked,
|
||||
total, "observed panicks while shutting down tenants"
|
||||
);
|
||||
}
|
||||
|
||||
// caller will log how long we took
|
||||
}
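The free-standing `shutdown_all_tenants0` above drives every per-tenant shutdown through a `JoinSet` and rate-limits its progress logging with a resettable sleep, so a degraded remote store neither floods the log nor leaves it silent right before an external SIGKILL would land. A condensed, hedged sketch of just that wait loop, with the tenant-specific bookkeeping omitted:

```rust
use std::time::Duration;
use tokio::task::JoinSet;
use tracing::{info, warn};

async fn wait_for_shutdown(mut join_set: JoinSet<()>) {
    let started_at = std::time::Instant::now();
    let total = join_set.len();
    let mut panicked = 0usize;
    let mut buffering = true;
    const BUFFER_FOR: Duration = Duration::from_millis(500);
    let mut buffered = std::pin::pin!(tokio::time::sleep(BUFFER_FOR));

    while !join_set.is_empty() {
        tokio::select! {
            Some(joined) = join_set.join_next() => {
                if let Err(e) = joined {
                    if e.is_panic() { panicked += 1; } else { warn!("JoinError: {e}"); }
                }
                if !buffering {
                    // Log at most every 500ms after the first completion.
                    buffering = true;
                    buffered.as_mut().reset(tokio::time::Instant::now() + BUFFER_FOR);
                }
            }
            _ = &mut buffered, if buffering => {
                buffering = false;
                info!(remaining = join_set.len(), total,
                      elapsed_ms = started_at.elapsed().as_millis() as u64,
                      "waiting for tenants to shut down");
            }
        }
    }
    if panicked > 0 {
        warn!(panicked, total, "observed panics while shutting down tenants");
    }
}
```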
|
||||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub(crate) enum UpsertLocationError {
|
||||
#[error("Bad config request: {0}")]
|
||||
@@ -927,8 +1056,7 @@ impl TenantManager {
|
||||
// the tenant is inaccessible to the outside world while we are doing this, but that is sensible:
|
||||
// the state is ill-defined while we're in transition. Transitions are async, but fast: we do
|
||||
// not do significant I/O, and shutdowns should be prompt via cancellation tokens.
|
||||
let mut slot_guard = self
|
||||
.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)
|
||||
let mut slot_guard = tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)
|
||||
.map_err(|e| match e {
|
||||
TenantSlotError::NotFound(_) => {
|
||||
unreachable!("Called with mode Any")
|
||||
@@ -1095,75 +1223,6 @@ impl TenantManager {
|
||||
}
|
||||
}
|
||||
|
||||
fn tenant_map_acquire_slot(
|
||||
&self,
|
||||
tenant_shard_id: &TenantShardId,
|
||||
mode: TenantSlotAcquireMode,
|
||||
) -> Result<SlotGuard, TenantSlotError> {
|
||||
use TenantSlotAcquireMode::*;
|
||||
METRICS.tenant_slot_writes.inc();
|
||||
|
||||
let mut locked = self.tenants.write().unwrap();
|
||||
let span = tracing::info_span!("acquire_slot", tenant_id=%tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug());
|
||||
let _guard = span.enter();
|
||||
|
||||
let m = match &mut *locked {
|
||||
TenantsMap::Initializing => return Err(TenantMapError::StillInitializing.into()),
|
||||
TenantsMap::ShuttingDown(_) => return Err(TenantMapError::ShuttingDown.into()),
|
||||
TenantsMap::Open(m) => m,
|
||||
};
|
||||
|
||||
use std::collections::btree_map::Entry;
|
||||
|
||||
let entry = m.entry(*tenant_shard_id);
|
||||
|
||||
match entry {
|
||||
Entry::Vacant(v) => match mode {
|
||||
MustExist => {
|
||||
tracing::debug!("Vacant && MustExist: return NotFound");
|
||||
Err(TenantSlotError::NotFound(*tenant_shard_id))
|
||||
}
|
||||
_ => {
|
||||
let (completion, barrier) = utils::completion::channel();
|
||||
let inserting = TenantSlot::InProgress(barrier);
|
||||
METRICS.slot_inserted(&inserting);
|
||||
v.insert(inserting);
|
||||
tracing::debug!("Vacant, inserted InProgress");
|
||||
Ok(SlotGuard::new(
|
||||
*tenant_shard_id,
|
||||
None,
|
||||
completion,
|
||||
&self.tenants,
|
||||
))
|
||||
}
|
||||
},
|
||||
Entry::Occupied(mut o) => {
|
||||
// Apply mode-driven checks
|
||||
match (o.get(), mode) {
|
||||
(TenantSlot::InProgress(_), _) => {
|
||||
tracing::debug!("Occupied, failing for InProgress");
|
||||
Err(TenantSlotError::InProgress)
|
||||
}
|
||||
_ => {
|
||||
// Happy case: the slot was not in any state that violated our mode
|
||||
let (completion, barrier) = utils::completion::channel();
|
||||
let in_progress = TenantSlot::InProgress(barrier);
|
||||
METRICS.slot_inserted(&in_progress);
|
||||
let old_value = o.insert(in_progress);
|
||||
METRICS.slot_removed(&old_value);
|
||||
tracing::debug!("Occupied, replaced with InProgress");
|
||||
Ok(SlotGuard::new(
|
||||
*tenant_shard_id,
|
||||
Some(old_value),
|
||||
completion,
|
||||
&self.tenants,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Resetting a tenant is equivalent to detaching it, then attaching it again with the same
|
||||
/// LocationConf that was last used to attach it. Optionally, the local file cache may be
|
||||
/// dropped before re-attaching.
|
||||
@@ -1180,8 +1239,7 @@ impl TenantManager {
|
||||
drop_cache: bool,
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut slot_guard =
|
||||
self.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
|
||||
let mut slot_guard = tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
|
||||
let Some(old_slot) = slot_guard.get_old_value() else {
|
||||
anyhow::bail!("Tenant not found when trying to reset");
|
||||
};
|
||||
@@ -1330,8 +1388,7 @@ impl TenantManager {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
let slot_guard =
|
||||
self.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
|
||||
let slot_guard = tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
|
||||
match &slot_guard.old_value {
|
||||
Some(TenantSlot::Attached(tenant)) => {
|
||||
// Legacy deletion flow: the tenant remains attached, goes to Stopping state, and
|
||||
@@ -1482,7 +1539,7 @@ impl TenantManager {
|
||||
// Phase 2: Put the parent shard to InProgress and grab a reference to the parent Tenant
|
||||
drop(tenant);
|
||||
let mut parent_slot_guard =
|
||||
self.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
|
||||
tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
|
||||
let parent = match parent_slot_guard.get_old_value() {
|
||||
Some(TenantSlot::Attached(t)) => t,
|
||||
Some(TenantSlot::Secondary(_)) => anyhow::bail!("Tenant location in secondary mode"),
|
||||
@@ -1786,145 +1843,7 @@ impl TenantManager {
|
||||
pub(crate) async fn shutdown(&self) {
|
||||
self.cancel.cancel();
|
||||
|
||||
self.shutdown_all_tenants0().await
|
||||
}
|
||||
|
||||
async fn shutdown_all_tenants0(&self) {
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
#[cfg(all(debug_assertions, not(test)))]
|
||||
{
|
||||
// Check that our metrics properly tracked the size of the tenants map. This is a convenient location to check,
|
||||
// as it happens implicitly at the end of tests etc.
|
||||
let m = self.tenants.read().unwrap();
|
||||
debug_assert_eq!(METRICS.slots_total(), m.len() as u64);
|
||||
}
|
||||
|
||||
// Atomically, 1. create the shutdown tasks and 2. prevent creation of new tenants.
|
||||
let (total_in_progress, total_attached) = {
|
||||
let mut m = self.tenants.write().unwrap();
|
||||
match &mut *m {
|
||||
TenantsMap::Initializing => {
|
||||
*m = TenantsMap::ShuttingDown(BTreeMap::default());
|
||||
info!("tenants map is empty");
|
||||
return;
|
||||
}
|
||||
TenantsMap::Open(tenants) => {
|
||||
let mut shutdown_state = BTreeMap::new();
|
||||
let mut total_in_progress = 0;
|
||||
let mut total_attached = 0;
|
||||
|
||||
for (tenant_shard_id, v) in std::mem::take(tenants).into_iter() {
|
||||
match v {
|
||||
TenantSlot::Attached(t) => {
|
||||
shutdown_state
|
||||
.insert(tenant_shard_id, TenantSlot::Attached(t.clone()));
|
||||
join_set.spawn(
|
||||
async move {
|
||||
let res = {
|
||||
let (_guard, shutdown_progress) = completion::channel();
|
||||
t.shutdown(shutdown_progress, ShutdownMode::FreezeAndFlush).await
|
||||
};
|
||||
|
||||
if let Err(other_progress) = res {
|
||||
// join the another shutdown in progress
|
||||
other_progress.wait().await;
|
||||
}
|
||||
|
||||
// we cannot afford per tenant logging here, because if s3 is degraded, we are
|
||||
// going to log too many lines
|
||||
debug!("tenant successfully stopped");
|
||||
}
|
||||
.instrument(info_span!("shutdown", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug())),
|
||||
);
|
||||
|
||||
total_attached += 1;
|
||||
}
|
||||
TenantSlot::Secondary(state) => {
|
||||
// We don't need to wait for this individually per-tenant: the
|
||||
// downloader task will be waited on eventually, this cancel
|
||||
// is just to encourage it to drop out if it is doing work
|
||||
// for this tenant right now.
|
||||
state.cancel.cancel();
|
||||
|
||||
shutdown_state
|
||||
.insert(tenant_shard_id, TenantSlot::Secondary(state));
|
||||
}
|
||||
TenantSlot::InProgress(notify) => {
|
||||
// InProgress tenants are not visible in TenantsMap::ShuttingDown: we will
|
||||
// wait for their notifications to fire in this function.
|
||||
join_set.spawn(async move {
|
||||
notify.wait().await;
|
||||
});
|
||||
|
||||
total_in_progress += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
*m = TenantsMap::ShuttingDown(shutdown_state);
|
||||
(total_in_progress, total_attached)
|
||||
}
|
||||
TenantsMap::ShuttingDown(_) => {
|
||||
error!(
|
||||
"already shutting down, this function isn't supposed to be called more than once"
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let started_at = std::time::Instant::now();
|
||||
|
||||
info!(
|
||||
"Waiting for {} InProgress tenants and {} Attached tenants to shut down",
|
||||
total_in_progress, total_attached
|
||||
);
|
||||
|
||||
let total = join_set.len();
|
||||
let mut panicked = 0;
|
||||
let mut buffering = true;
|
||||
const BUFFER_FOR: std::time::Duration = std::time::Duration::from_millis(500);
|
||||
let mut buffered = std::pin::pin!(tokio::time::sleep(BUFFER_FOR));
|
||||
|
||||
while !join_set.is_empty() {
|
||||
tokio::select! {
|
||||
Some(joined) = join_set.join_next() => {
|
||||
match joined {
|
||||
Ok(()) => {},
|
||||
Err(join_error) if join_error.is_cancelled() => {
|
||||
unreachable!("we are not cancelling any of the tasks");
|
||||
}
|
||||
Err(join_error) if join_error.is_panic() => {
|
||||
// cannot really do anything, as this panic is likely a bug
|
||||
panicked += 1;
|
||||
}
|
||||
Err(join_error) => {
|
||||
warn!("unknown kind of JoinError: {join_error}");
|
||||
}
|
||||
}
|
||||
if !buffering {
|
||||
// buffer so that every 500ms since the first update (or starting) we'll log
|
||||
// how far away we are; this is because we will get SIGKILL'd at 10s, and we
|
||||
// are not able to log *then*.
|
||||
buffering = true;
|
||||
buffered.as_mut().reset(tokio::time::Instant::now() + BUFFER_FOR);
|
||||
}
|
||||
},
|
||||
_ = &mut buffered, if buffering => {
|
||||
buffering = false;
|
||||
info!(remaining = join_set.len(), total, elapsed_ms = started_at.elapsed().as_millis(), "waiting for tenants to shutdown");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if panicked > 0 {
|
||||
warn!(
|
||||
panicked,
|
||||
total, "observed panicks while shutting down tenants"
|
||||
);
|
||||
}
|
||||
|
||||
// caller will log how long we took
|
||||
shutdown_all_tenants0(self.tenants).await
|
||||
}
|
||||
|
||||
/// Detaches a tenant, and removes its local files asynchronously.
|
||||
@@ -1970,12 +1889,12 @@ impl TenantManager {
|
||||
.map(Some)
|
||||
};
|
||||
|
||||
let mut removal_result = self
|
||||
.remove_tenant_from_memory(
|
||||
tenant_shard_id,
|
||||
tenant_dir_rename_operation(tenant_shard_id),
|
||||
)
|
||||
.await;
|
||||
let mut removal_result = remove_tenant_from_memory(
|
||||
self.tenants,
|
||||
tenant_shard_id,
|
||||
tenant_dir_rename_operation(tenant_shard_id),
|
||||
)
|
||||
.await;
|
||||
|
||||
// If the tenant was not found, it was likely already removed. Attempt to remove the tenant
|
||||
// directory on disk anyway. For example, during shard splits, we shut down and remove the
|
||||
@@ -2029,16 +1948,17 @@ impl TenantManager {
|
||||
) -> Result<HashSet<TimelineId>, detach_ancestor::Error> {
|
||||
use detach_ancestor::Error;
|
||||
|
||||
let slot_guard = self
|
||||
.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::MustExist)
|
||||
.map_err(|e| {
|
||||
use TenantSlotError::*;
|
||||
let slot_guard =
|
||||
tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::MustExist).map_err(
|
||||
|e| {
|
||||
use TenantSlotError::*;
|
||||
|
||||
match e {
|
||||
MapState(TenantMapError::ShuttingDown) => Error::ShuttingDown,
|
||||
NotFound(_) | InProgress | MapState(_) => Error::DetachReparent(e.into()),
|
||||
}
|
||||
})?;
|
||||
match e {
|
||||
MapState(TenantMapError::ShuttingDown) => Error::ShuttingDown,
|
||||
NotFound(_) | InProgress | MapState(_) => Error::DetachReparent(e.into()),
|
||||
}
|
||||
},
|
||||
)?;
|
||||
|
||||
let tenant = {
|
||||
let old_slot = slot_guard
|
||||
@@ -2371,80 +2291,6 @@ impl TenantManager {
|
||||
other => ApiError::InternalServerError(anyhow::anyhow!(other)),
|
||||
})
|
||||
}
|
||||
|
||||
/// Stops and removes the tenant from memory, if it's not [`TenantState::Stopping`] already, bails otherwise.
|
||||
/// Allows to remove other tenant resources manually, via `tenant_cleanup`.
|
||||
/// If the cleanup fails, tenant will stay in memory in [`TenantState::Broken`] state, and another removal
|
||||
async fn remove_tenant_from_memory<V, F>(
|
||||
&self,
|
||||
tenant_shard_id: TenantShardId,
|
||||
tenant_cleanup: F,
|
||||
) -> Result<V, TenantStateError>
|
||||
where
|
||||
F: std::future::Future<Output = anyhow::Result<V>>,
|
||||
{
|
||||
let mut slot_guard =
|
||||
self.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::MustExist)?;
|
||||
|
||||
// allow pageserver shutdown to await for our completion
|
||||
let (_guard, progress) = completion::channel();
|
||||
|
||||
// The SlotGuard allows us to manipulate the Tenant object without fear of some
|
||||
// concurrent API request doing something else for the same tenant ID.
|
||||
let attached_tenant = match slot_guard.get_old_value() {
|
||||
Some(TenantSlot::Attached(tenant)) => {
|
||||
// whenever we remove a tenant from memory, we don't want to flush and wait for upload
|
||||
let shutdown_mode = ShutdownMode::Hard;
|
||||
|
||||
// shutdown is sure to transition tenant to stopping, and wait for all tasks to complete, so
|
||||
// that we can continue safely to cleanup.
|
||||
match tenant.shutdown(progress, shutdown_mode).await {
|
||||
Ok(()) => {}
|
||||
Err(_other) => {
|
||||
// if pageserver shutdown or other detach/ignore is already ongoing, we don't want to
|
||||
// wait for it but return an error right away because these are distinct requests.
|
||||
slot_guard.revert();
|
||||
return Err(TenantStateError::IsStopping(tenant_shard_id));
|
||||
}
|
||||
}
|
||||
Some(tenant)
|
||||
}
|
||||
Some(TenantSlot::Secondary(secondary_state)) => {
|
||||
tracing::info!("Shutting down in secondary mode");
|
||||
secondary_state.shutdown().await;
|
||||
None
|
||||
}
|
||||
Some(TenantSlot::InProgress(_)) => {
|
||||
// Acquiring a slot guarantees its old value was not InProgress
|
||||
unreachable!();
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
match tenant_cleanup
|
||||
.await
|
||||
.with_context(|| format!("Failed to run cleanup for tenant {tenant_shard_id}"))
|
||||
{
|
||||
Ok(hook_value) => {
|
||||
// Success: drop the old TenantSlot::Attached.
|
||||
slot_guard
|
||||
.drop_old_value()
|
||||
.expect("We just called shutdown");
|
||||
|
||||
Ok(hook_value)
|
||||
}
|
||||
Err(e) => {
|
||||
// If we had a Tenant, set it to Broken and put it back in the TenantsMap
|
||||
if let Some(attached_tenant) = attached_tenant {
|
||||
attached_tenant.set_broken(e.to_string()).await;
|
||||
}
|
||||
// Leave the broken tenant in the map
|
||||
slot_guard.revert();
|
||||
|
||||
Err(TenantStateError::Other(e))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
@@ -2609,7 +2455,7 @@ pub(crate) enum TenantMapError {
|
||||
/// this tenant to retry later, or wait for the InProgress state to end.
|
||||
///
|
||||
/// This structure enforces the important invariant that we do not have overlapping
|
||||
/// tasks that will try to use local storage for a the same tenant ID: we enforce that
|
||||
/// tasks that will try use local storage for a the same tenant ID: we enforce that
|
||||
/// the previous contents of a slot have been shut down before the slot can be
|
||||
/// left empty or used for something else
|
||||
///
|
||||
@@ -2622,7 +2468,7 @@ pub(crate) enum TenantMapError {
|
||||
/// The `old_value` may be dropped before the SlotGuard is dropped, by calling
|
||||
/// `drop_old_value`. It is an error to call this without shutting down
|
||||
/// the conents of `old_value`.
|
||||
pub(crate) struct SlotGuard<'a> {
|
||||
pub(crate) struct SlotGuard {
|
||||
tenant_shard_id: TenantShardId,
|
||||
old_value: Option<TenantSlot>,
|
||||
upserted: bool,
|
||||
@@ -2630,23 +2476,19 @@ pub(crate) struct SlotGuard<'a> {
|
||||
/// [`TenantSlot::InProgress`] carries the corresponding Barrier: it will
|
||||
/// release any waiters as soon as this SlotGuard is dropped.
|
||||
completion: utils::completion::Completion,
|
||||
|
||||
tenants: &'a std::sync::RwLock<TenantsMap>,
|
||||
}
|
||||
|
||||
impl<'a> SlotGuard<'a> {
|
||||
impl SlotGuard {
|
||||
fn new(
|
||||
tenant_shard_id: TenantShardId,
|
||||
old_value: Option<TenantSlot>,
|
||||
completion: utils::completion::Completion,
|
||||
tenants: &'a std::sync::RwLock<TenantsMap>,
|
||||
) -> Self {
|
||||
Self {
|
||||
tenant_shard_id,
|
||||
old_value,
|
||||
upserted: false,
|
||||
completion,
|
||||
tenants,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2670,8 +2512,8 @@ impl<'a> SlotGuard<'a> {
|
||||
));
|
||||
}
|
||||
|
||||
let replaced: Option<TenantSlot> = {
|
||||
let mut locked = self.tenants.write().unwrap();
|
||||
let replaced = {
|
||||
let mut locked = TENANTS.write().unwrap();
|
||||
|
||||
if let TenantSlot::InProgress(_) = new_value {
|
||||
// It is never expected to try and upsert InProgress via this path: it should
|
||||
@@ -2779,7 +2621,7 @@ impl<'a> SlotGuard<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Drop for SlotGuard<'a> {
|
||||
impl Drop for SlotGuard {
|
||||
fn drop(&mut self) {
|
||||
if self.upserted {
|
||||
return;
|
||||
@@ -2787,7 +2629,7 @@ impl<'a> Drop for SlotGuard<'a> {
|
||||
// Our old value is already shutdown, or it never existed: it is safe
|
||||
// for us to fully release the TenantSlot back into an empty state
|
||||
|
||||
let mut locked = self.tenants.write().unwrap();
|
||||
let mut locked = TENANTS.write().unwrap();
|
||||
|
||||
let m = match &mut *locked {
|
||||
TenantsMap::Initializing => {
|
||||
@@ -2869,6 +2711,151 @@ enum TenantSlotAcquireMode {
|
||||
MustExist,
|
||||
}
|
||||
|
||||
fn tenant_map_acquire_slot(
|
||||
tenant_shard_id: &TenantShardId,
|
||||
mode: TenantSlotAcquireMode,
|
||||
) -> Result<SlotGuard, TenantSlotError> {
|
||||
tenant_map_acquire_slot_impl(tenant_shard_id, &TENANTS, mode)
|
||||
}
|
||||
|
||||
fn tenant_map_acquire_slot_impl(
|
||||
tenant_shard_id: &TenantShardId,
|
||||
tenants: &std::sync::RwLock<TenantsMap>,
|
||||
mode: TenantSlotAcquireMode,
|
||||
) -> Result<SlotGuard, TenantSlotError> {
|
||||
use TenantSlotAcquireMode::*;
|
||||
METRICS.tenant_slot_writes.inc();
|
||||
|
||||
let mut locked = tenants.write().unwrap();
|
||||
let span = tracing::info_span!("acquire_slot", tenant_id=%tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug());
|
||||
let _guard = span.enter();
|
||||
|
||||
let m = match &mut *locked {
|
||||
TenantsMap::Initializing => return Err(TenantMapError::StillInitializing.into()),
|
||||
TenantsMap::ShuttingDown(_) => return Err(TenantMapError::ShuttingDown.into()),
|
||||
TenantsMap::Open(m) => m,
|
||||
};
|
||||
|
||||
use std::collections::btree_map::Entry;
|
||||
|
||||
let entry = m.entry(*tenant_shard_id);
|
||||
|
||||
match entry {
|
||||
Entry::Vacant(v) => match mode {
|
||||
MustExist => {
|
||||
tracing::debug!("Vacant && MustExist: return NotFound");
|
||||
Err(TenantSlotError::NotFound(*tenant_shard_id))
|
||||
}
|
||||
_ => {
|
||||
let (completion, barrier) = utils::completion::channel();
|
||||
let inserting = TenantSlot::InProgress(barrier);
|
||||
METRICS.slot_inserted(&inserting);
|
||||
v.insert(inserting);
|
||||
tracing::debug!("Vacant, inserted InProgress");
|
||||
Ok(SlotGuard::new(*tenant_shard_id, None, completion))
|
||||
}
|
||||
},
|
||||
Entry::Occupied(mut o) => {
|
||||
// Apply mode-driven checks
|
||||
match (o.get(), mode) {
|
||||
(TenantSlot::InProgress(_), _) => {
|
||||
tracing::debug!("Occupied, failing for InProgress");
|
||||
Err(TenantSlotError::InProgress)
|
||||
}
|
||||
_ => {
|
||||
// Happy case: the slot was not in any state that violated our mode
|
||||
let (completion, barrier) = utils::completion::channel();
|
||||
let in_progress = TenantSlot::InProgress(barrier);
|
||||
METRICS.slot_inserted(&in_progress);
|
||||
let old_value = o.insert(in_progress);
|
||||
METRICS.slot_removed(&old_value);
|
||||
tracing::debug!("Occupied, replaced with InProgress");
|
||||
Ok(SlotGuard::new(
|
||||
*tenant_shard_id,
|
||||
Some(old_value),
|
||||
completion,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
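`tenant_map_acquire_slot_impl` above centres on one idea: before touching a tenant, swap an `InProgress` marker into the shared map so no other request can operate on the same `TenantShardId` concurrently. A stripped-down sketch of that exchange; the real `InProgress` carries a completion barrier so waiters can block on it, and the returned `SlotGuard` later restores or drops the old value:

```rust
use std::collections::BTreeMap;
use std::sync::RwLock;

// Stand-ins for the real types in the hunk above.
type TenantShardId = u64;
enum TenantSlot {
    Attached(&'static str),
    InProgress,
}
#[derive(Debug)]
enum SlotError {
    InProgress,
}

/// Replace whatever occupies the slot with an `InProgress` marker while the
/// caller works on it; the previous value is handed back so it can be shut
/// down (or restored on failure). Concurrent callers see `InProgress` and
/// fail fast instead of racing on the same tenant.
fn acquire_slot(
    tenants: &RwLock<BTreeMap<TenantShardId, TenantSlot>>,
    id: TenantShardId,
) -> Result<Option<TenantSlot>, SlotError> {
    let mut map = tenants.write().unwrap();
    if matches!(map.get(&id), Some(TenantSlot::InProgress)) {
        return Err(SlotError::InProgress);
    }
    Ok(map.insert(id, TenantSlot::InProgress))
}
```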
|
||||
|
||||
/// Stops and removes the tenant from memory, if it's not [`TenantState::Stopping`] already, bails otherwise.
|
||||
/// Allows to remove other tenant resources manually, via `tenant_cleanup`.
|
||||
/// If the cleanup fails, tenant will stay in memory in [`TenantState::Broken`] state, and another removal
|
||||
/// operation would be needed to remove it.
|
||||
async fn remove_tenant_from_memory<V, F>(
|
||||
tenants: &std::sync::RwLock<TenantsMap>,
|
||||
tenant_shard_id: TenantShardId,
|
||||
tenant_cleanup: F,
|
||||
) -> Result<V, TenantStateError>
|
||||
where
|
||||
F: std::future::Future<Output = anyhow::Result<V>>,
|
||||
{
|
||||
let mut slot_guard =
|
||||
tenant_map_acquire_slot_impl(&tenant_shard_id, tenants, TenantSlotAcquireMode::MustExist)?;
|
||||
|
||||
// allow pageserver shutdown to await for our completion
|
||||
let (_guard, progress) = completion::channel();
|
||||
|
||||
// The SlotGuard allows us to manipulate the Tenant object without fear of some
|
||||
// concurrent API request doing something else for the same tenant ID.
|
||||
let attached_tenant = match slot_guard.get_old_value() {
|
||||
Some(TenantSlot::Attached(tenant)) => {
|
||||
// whenever we remove a tenant from memory, we don't want to flush and wait for upload
|
||||
let shutdown_mode = ShutdownMode::Hard;
|
||||
|
||||
// shutdown is sure to transition tenant to stopping, and wait for all tasks to complete, so
|
||||
// that we can continue safely to cleanup.
|
||||
match tenant.shutdown(progress, shutdown_mode).await {
|
||||
Ok(()) => {}
|
||||
Err(_other) => {
|
||||
// if pageserver shutdown or other detach/ignore is already ongoing, we don't want to
|
||||
// wait for it but return an error right away because these are distinct requests.
|
||||
slot_guard.revert();
|
||||
return Err(TenantStateError::IsStopping(tenant_shard_id));
|
||||
}
|
||||
}
|
||||
Some(tenant)
|
||||
}
|
||||
Some(TenantSlot::Secondary(secondary_state)) => {
|
||||
tracing::info!("Shutting down in secondary mode");
|
||||
secondary_state.shutdown().await;
|
||||
None
|
||||
}
|
||||
Some(TenantSlot::InProgress(_)) => {
|
||||
// Acquiring a slot guarantees its old value was not InProgress
|
||||
unreachable!();
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
match tenant_cleanup
|
||||
.await
|
||||
.with_context(|| format!("Failed to run cleanup for tenant {tenant_shard_id}"))
|
||||
{
|
||||
Ok(hook_value) => {
|
||||
// Success: drop the old TenantSlot::Attached.
|
||||
slot_guard
|
||||
.drop_old_value()
|
||||
.expect("We just called shutdown");
|
||||
|
||||
Ok(hook_value)
|
||||
}
|
||||
Err(e) => {
|
||||
// If we had a Tenant, set it to Broken and put it back in the TenantsMap
|
||||
if let Some(attached_tenant) = attached_tenant {
|
||||
attached_tenant.set_broken(e.to_string()).await;
|
||||
}
|
||||
// Leave the broken tenant in the map
|
||||
slot_guard.revert();
|
||||
|
||||
Err(TenantStateError::Other(e))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
use http_utils::error::ApiError;
|
||||
use pageserver_api::models::TimelineGcRequest;
|
||||
|
||||
@@ -2879,15 +2866,11 @@ mod tests {
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use storage_broker::BrokerClientChannel;
|
||||
use tracing::Instrument;
|
||||
|
||||
use super::super::harness::TenantHarness;
|
||||
use super::TenantsMap;
|
||||
use crate::tenant::{
|
||||
TenantSharedResources,
|
||||
mgr::{BackgroundPurges, TenantManager, TenantSlot},
|
||||
};
|
||||
use crate::tenant::mgr::TenantSlot;
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn shutdown_awaits_in_progress_tenant() {
|
||||
@@ -2908,47 +2891,23 @@ mod tests {
|
||||
let _e = span.enter();
|
||||
|
||||
let tenants = BTreeMap::from([(id, TenantSlot::Attached(t.clone()))]);
|
||||
let tenants = Arc::new(std::sync::RwLock::new(TenantsMap::Open(tenants)));
|
||||
|
||||
// Invoke remove_tenant_from_memory with a cleanup hook that blocks until we manually
|
||||
// permit it to proceed: that will stick the tenant in InProgress
|
||||
|
||||
let (basebackup_prepare_sender, _) = tokio::sync::mpsc::unbounded_channel::<
|
||||
crate::basebackup_cache::BasebackupPrepareRequest,
|
||||
>();
|
||||
|
||||
let tenant_manager = TenantManager {
|
||||
tenants: std::sync::RwLock::new(TenantsMap::Open(tenants)),
|
||||
conf: h.conf,
|
||||
resources: TenantSharedResources {
|
||||
broker_client: BrokerClientChannel::connect_lazy("foobar.com")
|
||||
.await
|
||||
.unwrap(),
|
||||
remote_storage: h.remote_storage.clone(),
|
||||
deletion_queue_client: h.deletion_queue.new_client(),
|
||||
l0_flush_global_state: crate::l0_flush::L0FlushGlobalState::new(
|
||||
h.conf.l0_flush.clone(),
|
||||
),
|
||||
basebackup_prepare_sender,
|
||||
feature_resolver: crate::feature_resolver::FeatureResolver::new_disabled(),
|
||||
},
|
||||
cancel: tokio_util::sync::CancellationToken::new(),
|
||||
background_purges: BackgroundPurges::default(),
|
||||
};
|
||||
|
||||
let tenant_manager = Arc::new(tenant_manager);
|
||||
|
||||
let (until_cleanup_completed, can_complete_cleanup) = utils::completion::channel();
|
||||
let (until_cleanup_started, cleanup_started) = utils::completion::channel();
|
||||
let mut remove_tenant_from_memory_task = {
|
||||
let tenant_manager = tenant_manager.clone();
|
||||
let jh = tokio::spawn({
|
||||
let tenants = tenants.clone();
|
||||
async move {
|
||||
let cleanup = async move {
|
||||
drop(until_cleanup_started);
|
||||
can_complete_cleanup.wait().await;
|
||||
anyhow::Ok(())
|
||||
};
|
||||
tenant_manager.remove_tenant_from_memory(id, cleanup).await
|
||||
super::remove_tenant_from_memory(&tenants, id, cleanup).await
|
||||
}
|
||||
.instrument(h.span())
|
||||
});
|
||||
@@ -2961,11 +2920,9 @@ mod tests {
|
||||
let mut shutdown_task = {
|
||||
let (until_shutdown_started, shutdown_started) = utils::completion::channel();
|
||||
|
||||
let tenant_manager = tenant_manager.clone();
|
||||
|
||||
let shutdown_task = tokio::spawn(async move {
|
||||
drop(until_shutdown_started);
|
||||
tenant_manager.shutdown_all_tenants0().await;
|
||||
super::shutdown_all_tenants0(&tenants).await;
|
||||
});
|
||||
|
||||
shutdown_started.wait().await;
|
||||
|
||||
@@ -1092,15 +1092,13 @@ communicator_prefetch_register_bufferv(BufferTag tag, neon_request_lsns *frlsns,
MyPState->ring_last <= ring_index);
}

/* Internal version. Returns the ring index of the last block (result of this function is used only
* when nblocks==1)
*/
/* internal version. Returns the ring index */
static uint64
prefetch_register_bufferv(BufferTag tag, neon_request_lsns *frlsns,
BlockNumber nblocks, const bits8 *mask,
bool is_prefetch)
{
uint64 last_ring_index;
uint64 min_ring_index;
PrefetchRequest hashkey;
#ifdef USE_ASSERT_CHECKING
bool any_hits = false;
@@ -1124,12 +1122,13 @@ Retry:
MyPState->ring_unused - MyPState->ring_receive;
MyNeonCounters->getpage_prefetches_buffered =
MyPState->n_responses_buffered;
last_ring_index = UINT64_MAX;

min_ring_index = UINT64_MAX;
for (int i = 0; i < nblocks; i++)
{
PrefetchRequest *slot = NULL;
PrfHashEntry *entry = NULL;
uint64 ring_index;
neon_request_lsns *lsns;

if (PointerIsValid(mask) && BITMAP_ISSET(mask, i))
@@ -1153,12 +1152,12 @@ Retry:
if (entry != NULL)
{
slot = entry->slot;
last_ring_index = slot->my_ring_index;
Assert(slot == GetPrfSlot(last_ring_index));
ring_index = slot->my_ring_index;
Assert(slot == GetPrfSlot(ring_index));

Assert(slot->status != PRFS_UNUSED);
Assert(MyPState->ring_last <= last_ring_index &&
last_ring_index < MyPState->ring_unused);
Assert(MyPState->ring_last <= ring_index &&
ring_index < MyPState->ring_unused);
Assert(BufferTagsEqual(&slot->buftag, &hashkey.buftag));

/*
@@ -1170,9 +1169,9 @@ Retry:
if (!neon_prefetch_response_usable(lsns, slot))
{
/* Wait for the old request to finish and discard it */
if (!prefetch_wait_for(last_ring_index))
if (!prefetch_wait_for(ring_index))
goto Retry;
prefetch_set_unused(last_ring_index);
prefetch_set_unused(ring_index);
entry = NULL;
slot = NULL;
pgBufferUsage.prefetch.expired += 1;
@@ -1189,12 +1188,13 @@ Retry:
*/
if (slot->status == PRFS_TAG_REMAINS)
{
prefetch_set_unused(last_ring_index);
prefetch_set_unused(ring_index);
entry = NULL;
slot = NULL;
}
else
{
min_ring_index = Min(min_ring_index, ring_index);
/* The buffered request is good enough, return that index */
if (is_prefetch)
pgBufferUsage.prefetch.duplicates++;
@@ -1283,12 +1283,12 @@ Retry:
* The next buffer pointed to by `ring_unused` is now definitely empty, so
* we can insert the new request to it.
*/
last_ring_index = MyPState->ring_unused;
ring_index = MyPState->ring_unused;

Assert(MyPState->ring_last <= last_ring_index &&
last_ring_index <= MyPState->ring_unused);
Assert(MyPState->ring_last <= ring_index &&
ring_index <= MyPState->ring_unused);

slot = GetPrfSlotNoCheck(last_ring_index);
slot = GetPrfSlotNoCheck(ring_index);

Assert(slot->status == PRFS_UNUSED);

@@ -1298,9 +1298,11 @@ Retry:
*/
slot->buftag = hashkey.buftag;
slot->shard_no = get_shard_number(&tag);
slot->my_ring_index = last_ring_index;
slot->my_ring_index = ring_index;
slot->flags = 0;

min_ring_index = Min(min_ring_index, ring_index);

if (is_prefetch)
MyNeonCounters->getpage_prefetch_requests_total++;
else
@@ -1313,12 +1315,11 @@ Retry:
MyPState->ring_unused - MyPState->ring_receive;

Assert(any_hits);
Assert(last_ring_index != UINT64_MAX);

Assert(GetPrfSlot(last_ring_index)->status == PRFS_REQUESTED ||
GetPrfSlot(last_ring_index)->status == PRFS_RECEIVED);
Assert(MyPState->ring_last <= last_ring_index &&
last_ring_index < MyPState->ring_unused);
Assert(GetPrfSlot(min_ring_index)->status == PRFS_REQUESTED ||
GetPrfSlot(min_ring_index)->status == PRFS_RECEIVED);
Assert(MyPState->ring_last <= min_ring_index &&
min_ring_index < MyPState->ring_unused);

if (flush_every_n_requests > 0 &&
MyPState->ring_unused - MyPState->ring_flush >= flush_every_n_requests)
@@ -1334,7 +1335,7 @@ Retry:
MyPState->ring_flush = MyPState->ring_unused;
}

return last_ring_index;
return min_ring_index;
}
|
||||
|
||||
static bool
|
||||
|
||||
@@ -1135,7 +1135,7 @@ VotesCollectedMset(WalProposer *wp, MemberSet *mset, Safekeeper **msk, StringInf
			wp->propTermStartLsn = sk->voteResponse.flushLsn;
			wp->donor = sk;
		}
		wp->truncateLsn = Max(sk->voteResponse.truncateLsn, wp->truncateLsn);
		wp->truncateLsn = Max(wp->safekeeper[i].voteResponse.truncateLsn, wp->truncateLsn);

		if (n_votes > 0)
			appendStringInfoString(s, ", ");

@@ -89,6 +89,7 @@ tokio-postgres = { workspace = true, optional = true }
tokio-rustls.workspace = true
tokio-util.workspace = true
tokio = { workspace = true, features = ["signal"] }
toml.workspace = true
tracing-subscriber.workspace = true
tracing-utils.workspace = true
tracing.workspace = true

@@ -14,9 +14,9 @@ use crate::context::RequestContext;
|
||||
use crate::control_plane::client::cplane_proxy_v1;
|
||||
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::pglb::connect_compute::ComputeConnectBackend;
|
||||
use crate::pqproto::BeMessage;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::proxy::wake_compute::WakeComputeBackend;
|
||||
use crate::stream::PqStream;
|
||||
use crate::types::RoleName;
|
||||
use crate::{auth, compute, waiters};
|
||||
@@ -109,7 +109,7 @@ impl ConsoleRedirectBackend {
|
||||
pub struct ConsoleRedirectNodeInfo(pub(super) NodeInfo);
|
||||
|
||||
#[async_trait]
|
||||
impl WakeComputeBackend for ConsoleRedirectNodeInfo {
|
||||
impl ComputeConnectBackend for ConsoleRedirectNodeInfo {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
_ctx: &RequestContext,
|
||||
|
||||
@@ -25,9 +25,9 @@ use crate::control_plane::{
|
||||
RoleAccessControl,
|
||||
};
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::pglb::connect_compute::ComputeConnectBackend;
|
||||
use crate::pqproto::BeMessage;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::proxy::wake_compute::WakeComputeBackend;
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::stream::Stream;
|
||||
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
|
||||
@@ -407,13 +407,13 @@ impl Backend<'_, ComputeUserInfo> {
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl WakeComputeBackend for Backend<'_, ComputeUserInfo> {
|
||||
impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
|
||||
match self {
|
||||
Self::ControlPlane(api, info) => api.wake_compute(ctx, info).await,
|
||||
Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await,
|
||||
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -279,7 +279,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
|
||||
},
|
||||
proxy_protocol_v2: config::ProxyProtocolV2::Rejected,
|
||||
handshake_timeout: Duration::from_secs(10),
|
||||
region: "local".into(),
|
||||
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
|
||||
connect_compute_locks,
|
||||
connect_to_compute: compute_config,
|
||||
|
||||
@@ -26,12 +26,12 @@ use utils::sentry_init::init_sentry;
|
||||
|
||||
use crate::context::RequestContext;
|
||||
use crate::metrics::{Metrics, ThreadPoolMetrics};
|
||||
use crate::pglb::TlsRequired;
|
||||
use crate::pqproto::FeStartupPacket;
|
||||
use crate::protocol2::ConnectionInfo;
|
||||
use crate::proxy::{ErrorSource, copy_bidirectional_client_compute};
|
||||
use crate::proxy::{
|
||||
ErrorSource, TlsRequired, copy_bidirectional_client_compute, run_until_cancelled,
|
||||
};
|
||||
use crate::stream::{PqStream, Stream};
|
||||
use crate::util::run_until_cancelled;
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
|
||||
@@ -237,7 +237,6 @@ pub(super) async fn task_main(
|
||||
extra: None,
|
||||
},
|
||||
crate::metrics::Protocol::SniRouter,
|
||||
"sni",
|
||||
);
|
||||
handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
|
||||
}
|
||||
|
||||
@@ -8,14 +8,15 @@ use std::time::Duration;
|
||||
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
use anyhow::Context;
|
||||
use anyhow::{bail, ensure};
|
||||
use anyhow::{bail, anyhow};
|
||||
use arc_swap::ArcSwapOption;
|
||||
use futures::future::Either;
|
||||
use remote_storage::RemoteStorageConfig;
|
||||
use serde::Deserialize;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, info, warn};
|
||||
use tracing::{Instrument, info};
|
||||
use utils::sentry_init::init_sentry;
|
||||
use utils::{project_build_tag, project_git_version};
|
||||
|
||||
@@ -39,7 +40,7 @@ use crate::serverless::cancel_set::CancelSet;
|
||||
use crate::tls::client_config::compute_client_config_with_root_certs;
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
use crate::url::ApiUrl;
|
||||
use crate::{auth, control_plane, http, serverless, usage_metrics};
|
||||
use crate::{auth, control_plane, http, pglb, serverless, usage_metrics};
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
project_build_tag!(BUILD_TAG);
|
||||
@@ -59,6 +60,262 @@ enum AuthBackendType {
    Postgres,
}

#[derive(Deserialize)]
struct Root {
    #[serde(flatten)]
    legacy: LegacyModes,
    introspection: Introspection,
}

#[derive(Deserialize)]
#[serde(untagged)]
enum LegacyModes {
    Proxy {
        pglb: Pglb,
        neonkeeper: NeonKeeper,
        http: Option<Http>,
        pg_sni_router: Option<PgSniRouter>,
    },
    AuthBroker {
        neonkeeper: NeonKeeper,
        http: Http,
    },
    ConsoleRedirect {
        console_redirect: ConsoleRedirect,
    },
}

#[derive(Deserialize)]
struct Pglb {
    listener: Listener,
}

#[derive(Deserialize)]
struct Listener {
    /// address to bind to
    addr: SocketAddr,
    /// which header should we expect to see on this socket
    /// from the load balancer
    header: Option<ProxyHeader>,

    /// certificates used for TLS.
    /// first cert is the default.
    /// TLS not used if no certs provided.
    certs: Vec<KeyPair>,

    /// Timeout to use for TLS handshake
    timeout: Option<Duration>,
}

#[derive(Deserialize)]
enum ProxyHeader {
    /// Accept the PROXY protocol V2.
    ProxyProtocolV2(ProxyProtocolV2Kind),
}

#[derive(Deserialize)]
enum ProxyProtocolV2Kind {
    /// Expect AWS TLVs in the header.
    Aws,
    /// Expect Azure TLVs in the header.
    Azure,
}

#[derive(Deserialize)]
struct KeyPair {
    key: PathBuf,
    cert: PathBuf,
}

#[derive(Deserialize)]
/// The service that authenticates all incoming connection attempts,
/// provides monitoring and also wakes computes.
struct NeonKeeper {
    cplane: ControlPlaneBackend,
    redis: Option<Redis>,
    auth: Vec<AuthMechanism>,

    /// map of endpoint->computeinfo
    compute: Cache,
    /// cache for GetEndpointAccessControls.
    project_info_cache: config::ProjectInfoCacheOptions,
    /// cache for all valid endpoints
    endpoint_cache_config: config::EndpointCacheConfig,

    request_log_export: Option<RequestLogExport>,
    data_transfer_export: Option<DataTransferExport>,
}

#[derive(Deserialize)]
struct Redis {
    /// Cancellation channel size (max queue size for redis kv client)
    cancellation_ch_size: usize,
    /// Cancellation ops batch size for redis
    cancellation_batch_size: usize,

    auth: RedisAuthentication,
}

#[derive(Deserialize)]
enum RedisAuthentication {
    /// IRSA: IAM Roles for Service Accounts (AWS).
    Irsa {
        host: String,
        port: u16,
        cluster_name: Option<String>,
        user_id: Option<String>,
        aws_region: String,
    },
    Basic {
        url: url::Url,
    },
}

#[derive(Deserialize)]
struct PgSniRouter {
    /// The listener to use to proxy connections to compute,
    /// assuming the compute does not support TLS.
    listener: Listener,

    /// The listener to use to proxy connections to compute,
    /// assuming the compute requires TLS.
    listener_tls: Listener,

    /// append this domain zone to the SNI hostname to get the destination address
    dest: String,
}

#[derive(Deserialize)]
/// `psql -h pg.neon.tech`.
struct ConsoleRedirect {
    /// Connection requests from clients.
    listener: Listener,
    /// Messages from control plane to accept the connection.
    cplane: Listener,

    /// The base url to use for redirects.
    console: url::Url,

    timeout: Duration,
}

#[derive(Deserialize)]
enum ControlPlaneBackend {
    /// Use the HTTP API to access the control plane.
    Http(url::Url),
    /// Stub the control plane with a postgres instance.
    #[cfg(feature = "testing")]
    PostgresMock(url::Url),
}

#[derive(Deserialize)]
struct Http {
    listener: Listener,
    sql_over_http: SqlOverHttp,

    // todo: move into Pglb.
    websockets: Option<Websockets>,
}

#[derive(Deserialize)]
struct SqlOverHttp {
    pool_max_conns_per_endpoint: usize,
    pool_max_total_conns: usize,
    pool_idle_timeout: Duration,
    pool_gc_epoch: Duration,
    pool_shards: usize,

    client_conn_threshold: u64,
    cancel_set_shards: usize,

    timeout: Duration,
    max_request_size_bytes: usize,
    max_response_size_bytes: usize,

    auth: Vec<AuthMechanism>,
}

#[derive(Deserialize)]
enum AuthMechanism {
    Sasl {
        /// timeout for SASL handshake
        timeout: Duration,
    },
    CleartextPassword {
        /// number of threads for the thread pool
        threads: usize,
    },
    // TODO: add JWKS cache settings here.
    Jwt {},
}

#[derive(Deserialize)]
struct Websockets {
    auth: Vec<AuthMechanism>,
}

#[derive(Deserialize)]
/// The HTTP API used for internal monitoring.
struct Introspection {
    listener: Listener,
}

#[derive(Deserialize)]
enum RequestLogExport {
    Parquet {
        location: RemoteStorageConfig,
        disconnect: RemoteStorageConfig,

        /// The region identifier to tag the entries with.
        region: String,

        /// How many rows to include in a row group
        row_group_size: usize,

        /// How large each column page should be in bytes
        page_size: usize,

        /// How large the total parquet file should be in bytes
        size: i64,

        /// How long to wait before forcing a file upload
        maximum_duration: tokio::time::Duration,
        // /// What level of compression to use
        // compression: Compression,
    },
}

#[derive(Deserialize)]
enum Cache {
    /// Expire by LRU or by idle.
    /// Note: "live" in "time-to-live" actually means idle here.
    LruTtl {
        /// Max number of entries.
        size: usize,
        /// Entry's time-to-live.
        ttl: Duration,
    },
}

#[derive(Deserialize)]
struct DataTransferExport {
    /// http endpoint to receive periodic metric updates
    endpoint: Option<String>,
    /// how often metrics should be sent to a collection endpoint
    interval: Option<String>,

    /// interval for backup metric collection
    backup_interval: std::time::Duration,
    /// remote storage configuration for backup metric collection
    /// Encoded as toml (same format as pageservers), eg
    /// `{bucket_name='the-bucket',bucket_region='us-east-1',prefix_in_bucket='proxy',endpoint='http://minio:9000'}`
    backup_remote_storage: Option<RemoteStorageConfig>,
    /// chunk size for backup metric collection
    /// Size of each event is no more than 400 bytes, so 2**22 is about 200MB before the compression.
    backup_chunk_size: usize,
}

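// Editor's sketch, not part of this change: a minimal, self-contained illustration of
// the serde pattern used by `Root`/`LegacyModes` above, where an #[serde(untagged)]
// enum is #[serde(flatten)]-ed into the root struct so the deployment mode is inferred
// from which top-level TOML tables are present. All names below are illustrative
// stand-ins, not the real proxy types.
#[derive(serde::Deserialize, Debug)]
struct DemoRoot {
    #[serde(flatten)]
    mode: DemoMode,
    introspection: DemoListener,
}

#[derive(serde::Deserialize, Debug)]
#[serde(untagged)]
enum DemoMode {
    Proxy { pglb: DemoListener },
    AuthBroker { http: DemoListener },
}

#[derive(serde::Deserialize, Debug)]
struct DemoListener {
    addr: std::net::SocketAddr,
}

fn demo_config_shape() -> Result<(), toml::de::Error> {
    let cfg: DemoRoot = toml::from_str(
        r#"
            [pglb]
            addr = "0.0.0.0:5432"

            [introspection]
            addr = "127.0.0.1:7001"
        "#,
    )?;
    // The presence of a `pglb` table (and absence of `http`) selects DemoMode::Proxy.
    println!("{cfg:?}");
    Ok(())
}
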
/// Neon proxy/router
|
||||
#[derive(Parser)]
|
||||
#[command(version = GIT_VERSION, about)]
|
||||
@@ -120,12 +377,6 @@ struct ProxyCliArgs {
|
||||
/// timeout for the TLS handshake
|
||||
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
|
||||
handshake_timeout: tokio::time::Duration,
|
||||
/// http endpoint to receive periodic metric updates
|
||||
#[clap(long)]
|
||||
metric_collection_endpoint: Option<String>,
|
||||
/// how often metrics should be sent to a collection endpoint
|
||||
#[clap(long)]
|
||||
metric_collection_interval: Option<String>,
|
||||
/// cache for `wake_compute` api method (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
wake_compute_cache: String,
|
||||
@@ -152,40 +403,31 @@ struct ProxyCliArgs {
|
||||
/// Wake compute rate limiter max number of requests per second.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
|
||||
wake_compute_limit: Vec<RateBucketInfo>,
|
||||
/// Redis rate limiter max number of requests per second.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)]
|
||||
redis_rps_limit: Vec<RateBucketInfo>,
|
||||
/// Cancellation channel size (max queue size for redis kv client)
|
||||
#[clap(long, default_value_t = 1024)]
|
||||
cancellation_ch_size: usize,
|
||||
/// Cancellation ops batch size for redis
|
||||
#[clap(long, default_value_t = 8)]
|
||||
cancellation_batch_size: usize,
|
||||
/// cache for `allowed_ips` (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
allowed_ips_cache: String,
|
||||
/// cache for `role_secret` (use `size=0` to disable)
|
||||
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
role_secret_cache: String,
|
||||
/// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections)
|
||||
/// redis url for plain authentication
|
||||
#[clap(long, alias("redis-notifications"))]
|
||||
redis_plain: Option<String>,
|
||||
/// which of the available authentication types to use for redis. Supported are "irsa" and "plain".
|
||||
#[clap(long)]
|
||||
redis_notifications: Option<String>,
|
||||
/// which of the available authentication types to use for the regional redis. Supported are "irsa" and "plain".
|
||||
#[clap(long, default_value = "irsa")]
|
||||
redis_auth_type: String,
|
||||
/// redis host for streaming connections (might be different from the notifications host)
|
||||
redis_auth_type: Option<String>,
|
||||
/// redis host for irsa authentication
|
||||
#[clap(long)]
|
||||
redis_host: Option<String>,
|
||||
/// redis port for streaming connections (might be different from the notifications host)
|
||||
/// redis port for irsa authentication
|
||||
#[clap(long)]
|
||||
redis_port: Option<u16>,
|
||||
/// redis cluster name, used in aws elasticache
|
||||
/// redis cluster name for irsa authentication
|
||||
#[clap(long)]
|
||||
redis_cluster_name: Option<String>,
|
||||
/// redis user_id, used in aws elasticache
|
||||
/// redis user_id for irsa authentication
|
||||
#[clap(long)]
|
||||
redis_user_id: Option<String>,
|
||||
/// aws region to retrieve credentials
|
||||
/// aws region for irsa authentication
|
||||
#[clap(long, default_value_t = String::new())]
|
||||
aws_region: String,
|
||||
/// cache for `project_info` (use `size=0` to disable)
|
||||
@@ -197,6 +439,12 @@ struct ProxyCliArgs {
|
||||
#[clap(flatten)]
|
||||
parquet_upload: ParquetUploadArgs,
|
||||
|
||||
/// http endpoint to receive periodic metric updates
|
||||
#[clap(long)]
|
||||
metric_collection_endpoint: Option<String>,
|
||||
/// how often metrics should be sent to a collection endpoint
|
||||
#[clap(long)]
|
||||
metric_collection_interval: Option<String>,
|
||||
/// interval for backup metric collection
|
||||
#[clap(long, default_value = "10m", value_parser = humantime::parse_duration)]
|
||||
metric_backup_collection_interval: std::time::Duration,
|
||||
@@ -209,6 +457,7 @@ struct ProxyCliArgs {
|
||||
/// Size of each event is no more than 400 bytes, so 2**22 is about 200MB before the compression.
|
||||
#[clap(long, default_value = "4194304")]
|
||||
metric_backup_collection_chunk_size: usize,
|
||||
|
||||
/// Whether to retry the connection to the compute node
|
||||
#[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
|
||||
connect_to_compute_retry: String,
|
||||
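// Editor's sketch, not part of this change: the `value_parser = humantime::parse_duration`
// pattern used by the duration flags above, reduced to a self-contained example.
// `DemoArgs`, the function name, and the flag value are illustrative only.
#[derive(clap::Parser)]
struct DemoArgs {
    /// interval for backup metric collection
    #[clap(long, default_value = "10m", value_parser = humantime::parse_duration)]
    metric_backup_collection_interval: std::time::Duration,
}

fn demo_parse_interval() {
    use clap::Parser;
    // clap derives the kebab-case flag name from the field name.
    let args = DemoArgs::parse_from(["demo", "--metric-backup-collection-interval", "90s"]);
    assert_eq!(
        args.metric_backup_collection_interval,
        std::time::Duration::from_secs(90)
    );
}
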
@@ -319,208 +568,120 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
}
|
||||
};
|
||||
|
||||
let args = ProxyCliArgs::parse();
|
||||
let config = build_config(&args)?;
|
||||
let auth_backend = build_auth_backend(&args)?;
|
||||
|
||||
match auth_backend {
|
||||
Either::Left(auth_backend) => info!("Authentication backend: {auth_backend}"),
|
||||
Either::Right(auth_backend) => info!("Authentication backend: {auth_backend:?}"),
|
||||
}
|
||||
info!("Using region: {}", args.aws_region);
|
||||
let (regional_redis_client, redis_notifications_client) = configure_redis(&args).await?;
|
||||
|
||||
// Check that we can bind to address before further initialization
|
||||
info!("Starting http on {}", args.http);
|
||||
let http_listener = TcpListener::bind(args.http).await?.into_std()?;
|
||||
|
||||
info!("Starting mgmt on {}", args.mgmt);
|
||||
let mgmt_listener = TcpListener::bind(args.mgmt).await?;
|
||||
|
||||
let proxy_listener = if args.is_auth_broker {
|
||||
None
|
||||
} else {
|
||||
info!("Starting proxy on {}", args.proxy);
|
||||
Some(TcpListener::bind(args.proxy).await?)
|
||||
};
|
||||
|
||||
let sni_router_listeners = {
|
||||
let args = &args.pg_sni_router;
|
||||
if args.dest.is_some() {
|
||||
ensure!(
|
||||
args.tls_key.is_some(),
|
||||
"sni-router-tls-key must be provided"
|
||||
);
|
||||
ensure!(
|
||||
args.tls_cert.is_some(),
|
||||
"sni-router-tls-cert must be provided"
|
||||
);
|
||||
|
||||
info!(
|
||||
"Starting pg-sni-router on {} and {}",
|
||||
args.listen, args.listen_tls
|
||||
);
|
||||
|
||||
Some((
|
||||
TcpListener::bind(args.listen).await?,
|
||||
TcpListener::bind(args.listen_tls).await?,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: rename the argument to something like serverless.
|
||||
// It now covers more than just websockets, it also covers SQL over HTTP.
|
||||
let serverless_listener = if let Some(serverless_address) = args.wss {
|
||||
info!("Starting wss on {serverless_address}");
|
||||
Some(TcpListener::bind(serverless_address).await?)
|
||||
} else if args.is_auth_broker {
|
||||
bail!("wss arg must be present for auth-broker")
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let cancellation_token = CancellationToken::new();
|
||||
|
||||
let redis_rps_limit = Vec::leak(args.redis_rps_limit.clone());
|
||||
RateBucketInfo::validate(redis_rps_limit)?;
|
||||
|
||||
let redis_kv_client = regional_redis_client
|
||||
.as_ref()
|
||||
.map(|redis_publisher| RedisKVClient::new(redis_publisher.clone(), redis_rps_limit));
|
||||
|
||||
// channel size should be higher than redis client limit to avoid blocking
|
||||
let cancel_ch_size = args.cancellation_ch_size;
|
||||
let (tx_cancel, rx_cancel) = tokio::sync::mpsc::channel(cancel_ch_size);
|
||||
let cancellation_handler = Arc::new(CancellationHandler::new(
|
||||
&config.connect_to_compute,
|
||||
Some(tx_cancel),
|
||||
));
|
||||
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
|
||||
RateBucketInfo::to_leaky_bucket(&args.endpoint_rps_limit)
|
||||
.unwrap_or(EndpointRateLimiter::DEFAULT),
|
||||
64,
|
||||
));
|
||||
let config: Root = toml::from_str(&tokio::fs::read_to_string("proxy.toml").await?)?;
|
||||
|
||||
// client facing tasks. these will exit on error or on cancellation
|
||||
// cancellation returns Ok(())
|
||||
let mut client_tasks = JoinSet::new();
|
||||
match auth_backend {
|
||||
Either::Left(auth_backend) => {
|
||||
if let Some(proxy_listener) = proxy_listener {
|
||||
client_tasks.spawn(crate::pglb::task_main(
|
||||
config,
|
||||
auth_backend,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(serverless_listener) = serverless_listener {
|
||||
client_tasks.spawn(serverless::task_main(
|
||||
config,
|
||||
auth_backend,
|
||||
serverless_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Either::Right(auth_backend) => {
|
||||
if let Some(proxy_listener) = proxy_listener {
|
||||
client_tasks.spawn(crate::console_redirect_proxy::task_main(
|
||||
config,
|
||||
auth_backend,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// spawn pg-sni-router mode.
|
||||
if let Some((listen, listen_tls)) = sni_router_listeners {
|
||||
let args = args.pg_sni_router;
|
||||
let dest = args.dest.expect("already asserted it is set");
|
||||
let key_path = args.tls_key.expect("already asserted it is set");
|
||||
let cert_path = args.tls_cert.expect("already asserted it is set");
|
||||
|
||||
let tls_config = super::pg_sni_router::parse_tls(&key_path, &cert_path)?;
|
||||
|
||||
let dest = Arc::new(dest);
|
||||
|
||||
client_tasks.spawn(super::pg_sni_router::task_main(
|
||||
dest.clone(),
|
||||
tls_config.clone(),
|
||||
None,
|
||||
listen,
|
||||
cancellation_token.clone(),
|
||||
));
|
||||
|
||||
client_tasks.spawn(super::pg_sni_router::task_main(
|
||||
dest,
|
||||
tls_config,
|
||||
Some(config.connect_to_compute.tls.clone()),
|
||||
listen_tls,
|
||||
cancellation_token.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
client_tasks.spawn(crate::context::parquet::worker(
|
||||
cancellation_token.clone(),
|
||||
args.parquet_upload,
|
||||
));
|
||||
|
||||
// maintenance tasks. these never return unless there's an error
|
||||
let mut maintenance_tasks = JoinSet::new();
|
||||
maintenance_tasks.spawn(crate::signals::handle(cancellation_token.clone(), || {}));
|
||||
maintenance_tasks.spawn(http::health_server::task_main(
|
||||
http_listener,
|
||||
AppMetrics {
|
||||
jemalloc,
|
||||
neon_metrics,
|
||||
proxy: crate::metrics::Metrics::get(),
|
||||
},
|
||||
));
|
||||
maintenance_tasks.spawn(control_plane::mgmt::task_main(mgmt_listener));
|
||||
|
||||
if let Some(metrics_config) = &config.metric_collection {
|
||||
// TODO: Add gc regardless of the metric collection being enabled.
|
||||
maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
|
||||
}
|
||||
let cancellation_token = CancellationToken::new();
|
||||
|
||||
#[cfg_attr(not(any(test, feature = "testing")), expect(irrefutable_let_patterns))]
|
||||
if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend {
|
||||
if let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api {
|
||||
match (redis_notifications_client, regional_redis_client.clone()) {
|
||||
(None, None) => {}
|
||||
(client1, client2) => {
|
||||
let cache = api.caches.project_info.clone();
|
||||
if let Some(client) = client1 {
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
client,
|
||||
cache.clone(),
|
||||
args.region.clone(),
|
||||
));
|
||||
}
|
||||
if let Some(client) = client2 {
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
client,
|
||||
cache.clone(),
|
||||
args.region.clone(),
|
||||
));
|
||||
}
|
||||
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
|
||||
}
|
||||
match config.legacy {
|
||||
LegacyModes::Proxy {
|
||||
pglb,
|
||||
neonkeeper,
|
||||
http,
|
||||
pg_sni_router,
|
||||
} => {
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
|
||||
// todo: use neonkeeper config.
|
||||
EndpointRateLimiter::DEFAULT,
|
||||
64,
|
||||
));
|
||||
|
||||
info!("Starting proxy on {}", pglb.listener.addr);
|
||||
let proxy_listener = TcpListener::bind(pglb.listener.addr).await?;
|
||||
|
||||
client_tasks.spawn(crate::proxy::task_main(
|
||||
config,
|
||||
auth_backend,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
|
||||
if let Some(http) = http {
|
||||
info!("Starting wss on {}", http.listener.addr);
|
||||
let http_listener = TcpListener::bind(http.listener.addr).await?;
|
||||
|
||||
client_tasks.spawn(serverless::task_main(
|
||||
config,
|
||||
auth_backend,
|
||||
http_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
};
|
||||
|
||||
if let Some(redis) = neonkeeper.redis {
|
||||
let client = configure_redis(redis.auth);
|
||||
}
|
||||
|
||||
if let Some(mut redis_kv_client) = redis_kv_client {
|
||||
if let Some(sni_router) = pg_sni_router {
|
||||
let listen = TcpListener::bind(sni_router.listener.addr).await?;
|
||||
let listen_tls = TcpListener::bind(sni_router.listener_tls.addr).await?;
|
||||
|
||||
let [KeyPair { key, cert }] = sni_router
|
||||
.listener
|
||||
.certs
|
||||
.try_into()
|
||||
.map_err(|_| anyhow!("only 1 keypair is supported for pg-sni-router"))?;
|
||||
|
||||
let tls_config = super::pg_sni_router::parse_tls(&key, &cert)?;
|
||||
|
||||
let dest = Arc::new(sni_router.dest);
|
||||
|
||||
client_tasks.spawn(super::pg_sni_router::task_main(
|
||||
dest.clone(),
|
||||
tls_config.clone(),
|
||||
None,
|
||||
listen,
|
||||
cancellation_token.clone(),
|
||||
));
|
||||
|
||||
client_tasks.spawn(super::pg_sni_router::task_main(
|
||||
dest,
|
||||
tls_config,
|
||||
Some(config.connect_to_compute.tls.clone()),
|
||||
listen_tls,
|
||||
cancellation_token.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
match neonkeeper.request_log_export {
|
||||
Some(RequestLogExport::Parquet {
|
||||
location,
|
||||
disconnect,
|
||||
region,
|
||||
row_group_size,
|
||||
page_size,
|
||||
size,
|
||||
maximum_duration,
|
||||
}) => {
|
||||
client_tasks.spawn(crate::context::parquet::worker(
|
||||
cancellation_token.clone(),
|
||||
args.parquet_upload,
|
||||
args.region,
|
||||
));
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
|
||||
if let (ControlPlaneBackend::Http(api), Some(redis)) =
|
||||
(neonkeeper.cplane, neonkeeper.redis)
|
||||
{
|
||||
// project info cache and invalidation of that cache.
|
||||
let cache = api.caches.project_info.clone();
|
||||
maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
|
||||
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
|
||||
|
||||
// cancellation key management
|
||||
let mut redis_kv_client = RedisKVClient::new(client.clone());
|
||||
maintenance_tasks.spawn(async move {
|
||||
redis_kv_client.try_connect().await?;
|
||||
handle_cancel_messages(
|
||||
@@ -537,18 +698,139 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
// so let's wait forever instead.
|
||||
std::future::pending().await
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(regional_redis_client) = regional_redis_client {
|
||||
// listen for notifications of new projects/endpoints/branches
|
||||
let cache = api.caches.endpoints_cache.clone();
|
||||
let con = regional_redis_client;
|
||||
let span = tracing::info_span!("endpoints_cache");
|
||||
maintenance_tasks.spawn(
|
||||
async move { cache.do_read(con, cancellation_token.clone()).await }
|
||||
async move { cache.do_read(client, cancellation_token.clone()).await }
|
||||
.instrument(span),
|
||||
);
|
||||
}
|
||||
}
|
||||
LegacyModes::AuthBroker { neonkeeper, http } => {
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
|
||||
// todo: use neonkeeper config.
|
||||
EndpointRateLimiter::DEFAULT,
|
||||
64,
|
||||
));
|
||||
|
||||
info!("Starting wss on {}", http.listener.addr);
|
||||
let http_listener = TcpListener::bind(http.listener.addr).await?;
|
||||
|
||||
if let Some(redis) = neonkeeper.redis {
|
||||
let client = configure_redis(redis.auth);
|
||||
}
|
||||
|
||||
client_tasks.spawn(serverless::task_main(
|
||||
config,
|
||||
auth_backend,
|
||||
serverless_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
));
|
||||
|
||||
match neonkeeper.request_log_export {
|
||||
Some(RequestLogExport::Parquet {
|
||||
location,
|
||||
disconnect,
|
||||
region,
|
||||
row_group_size,
|
||||
page_size,
|
||||
size,
|
||||
maximum_duration,
|
||||
}) => {
|
||||
client_tasks.spawn(crate::context::parquet::worker(
|
||||
cancellation_token.clone(),
|
||||
args.parquet_upload,
|
||||
args.region,
|
||||
));
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
|
||||
if let (ControlPlaneBackend::Http(api), Some(redis)) =
|
||||
(neonkeeper.cplane, neonkeeper.redis)
|
||||
{
|
||||
// project info cache and invalidation of that cache.
|
||||
let cache = api.caches.project_info.clone();
|
||||
maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
|
||||
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
|
||||
|
||||
// cancellation key management
|
||||
let mut redis_kv_client = RedisKVClient::new(client.clone());
|
||||
maintenance_tasks.spawn(async move {
|
||||
redis_kv_client.try_connect().await?;
|
||||
handle_cancel_messages(
|
||||
&mut redis_kv_client,
|
||||
rx_cancel,
|
||||
args.cancellation_batch_size,
|
||||
)
|
||||
.await?;
|
||||
|
||||
drop(redis_kv_client);
|
||||
|
||||
// `handle_cancel_messages` was terminated due to the tx_cancel
|
||||
// being dropped. this is not worthy of an error, and this task can only return `Err`,
|
||||
// so let's wait forever instead.
|
||||
std::future::pending().await
|
||||
});
|
||||
|
||||
// listen for notifications of new projects/endpoints/branches
|
||||
let cache = api.caches.endpoints_cache.clone();
|
||||
let span = tracing::info_span!("endpoints_cache");
|
||||
maintenance_tasks.spawn(
|
||||
async move { cache.do_read(client, cancellation_token.clone()).await }
|
||||
.instrument(span),
|
||||
);
|
||||
}
|
||||
}
|
||||
LegacyModes::ConsoleRedirect { console_redirect } => {
|
||||
info!("Starting proxy on {}", console_redirect.listener.addr);
|
||||
let proxy_listener = TcpListener::bind(console_redirect.listener.addr).await?;
|
||||
|
||||
info!("Starting mgmt on {}", console_redirect.listener.addr);
|
||||
let mgmt_listener = TcpListener::bind(console_redirect.listener.addr).await?;
|
||||
|
||||
client_tasks.spawn(crate::console_redirect_proxy::task_main(
|
||||
config,
|
||||
auth_backend,
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
cancellation_handler.clone(),
|
||||
));
|
||||
maintenance_tasks.spawn(control_plane::mgmt::task_main(mgmt_listener));
|
||||
}
|
||||
}
|
||||
|
||||
// Check that we can bind to address before further initialization
|
||||
info!("Starting http on {}", config.introspection.listener.addr);
|
||||
let http_listener = TcpListener::bind(config.introspection.listener.addr)
|
||||
.await?
|
||||
.into_std()?;
|
||||
|
||||
// channel size should be higher than redis client limit to avoid blocking
|
||||
let cancel_ch_size = args.cancellation_ch_size;
|
||||
let (tx_cancel, rx_cancel) = tokio::sync::mpsc::channel(cancel_ch_size);
|
||||
let cancellation_handler = Arc::new(CancellationHandler::new(
|
||||
&config.connect_to_compute,
|
||||
Some(tx_cancel),
|
||||
));
|
||||
|
||||
maintenance_tasks.spawn(crate::signals::handle(cancellation_token.clone(), || {}));
|
||||
maintenance_tasks.spawn(http::health_server::task_main(
|
||||
http_listener,
|
||||
AppMetrics {
|
||||
jemalloc,
|
||||
neon_metrics,
|
||||
proxy: crate::metrics::Metrics::get(),
|
||||
},
|
||||
));
|
||||
|
||||
if let Some(metrics_config) = &config.metric_collection {
|
||||
// TODO: Add gc regardless of the metric collection being enabled.
|
||||
maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
|
||||
}
|
||||
|
||||
let maintenance = loop {
|
||||
@@ -673,7 +955,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
authentication_config,
|
||||
proxy_protocol_v2: args.proxy_protocol_v2,
|
||||
handshake_timeout: args.handshake_timeout,
|
||||
region: args.region.clone(),
|
||||
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
|
||||
connect_compute_locks,
|
||||
connect_to_compute: compute_config,
|
||||
@@ -833,58 +1114,45 @@ fn build_auth_backend(
|
||||
}
|
||||
}
|
||||
|
||||
async fn configure_redis(
|
||||
args: &ProxyCliArgs,
|
||||
) -> anyhow::Result<(
|
||||
Option<ConnectionWithCredentialsProvider>,
|
||||
Option<ConnectionWithCredentialsProvider>,
|
||||
)> {
|
||||
// TODO: untangle the config args
|
||||
let regional_redis_client = match (args.redis_auth_type.as_str(), &args.redis_notifications) {
|
||||
("plain", redis_url) => match redis_url {
|
||||
None => {
|
||||
bail!("plain auth requires redis_notifications to be set");
|
||||
}
|
||||
Some(url) => {
|
||||
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone()))
|
||||
}
|
||||
},
|
||||
("irsa", _) => match (&args.redis_host, args.redis_port) {
|
||||
(Some(host), Some(port)) => Some(
|
||||
ConnectionWithCredentialsProvider::new_with_credentials_provider(
|
||||
host.clone(),
|
||||
port,
|
||||
elasticache::CredentialsProvider::new(
|
||||
args.aws_region.clone(),
|
||||
args.redis_cluster_name.clone(),
|
||||
args.redis_user_id.clone(),
|
||||
)
|
||||
.await,
|
||||
),
|
||||
),
|
||||
(None, None) => {
|
||||
// todo: upgrade to error?
|
||||
warn!(
|
||||
"irsa auth requires redis-host and redis-port to be set, continuing without regional_redis_client"
|
||||
);
|
||||
None
|
||||
}
|
||||
_ => {
|
||||
bail!("redis-host and redis-port must be specified together");
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
bail!("unknown auth type given");
|
||||
async fn configure_redis(auth: RedisAuthentication) -> ConnectionWithCredentialsProvider {
|
||||
match auth {
|
||||
RedisAuthentication::Irsa {
|
||||
host,
|
||||
port,
|
||||
cluster_name,
|
||||
user_id,
|
||||
aws_region,
|
||||
} => ConnectionWithCredentialsProvider::new_with_credentials_provider(
|
||||
host,
|
||||
port,
|
||||
elasticache::CredentialsProvider::new(aws_region, cluster_name, user_id).await,
|
||||
),
|
||||
RedisAuthentication::Basic { url } => {
|
||||
ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
let redis_notifications_client = if let Some(url) = &args.redis_notifications {
|
||||
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
|
||||
} else {
|
||||
regional_redis_client.clone()
|
||||
// let redis_notifications_client = if let Some(url) = &args.redis_notifications {
|
||||
// Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
|
||||
// } else {
|
||||
// regional_redis_client.clone()
|
||||
// };
|
||||
|
||||
Ok(redis_client)
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
Ok((regional_redis_client, redis_notifications_client))
|
||||
// let redis_notifications_client = if let Some(url) = &args.redis_notifications {
|
||||
// Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
|
||||
// } else {
|
||||
// regional_redis_client.clone()
|
||||
// };
|
||||
|
||||
Ok(redis_client)
|
||||
}
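// Editor's sketch, not part of this change: how an externally tagged enum such as
// `RedisAuthentication` above is spelled in TOML, with the variant name as a table key.
// The types below are stand-ins, not the real proxy config structs.
#[derive(serde::Deserialize, Debug, PartialEq)]
enum DemoRedisAuth {
    Irsa { host: String, port: u16, aws_region: String },
    Basic { url: String },
}

#[derive(serde::Deserialize, Debug)]
struct DemoRedis {
    auth: DemoRedisAuth,
}

fn demo_redis_auth_shape() -> Result<(), toml::de::Error> {
    let cfg: DemoRedis = toml::from_str(
        r#"
            [auth.Basic]
            url = "redis://127.0.0.1:6379"
        "#,
    )?;
    assert_eq!(
        cfg.auth,
        DemoRedisAuth::Basic { url: "redis://127.0.0.1:6379".into() }
    );
    Ok(())
}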
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -103,8 +103,6 @@ pub enum Auth {
|
||||
}
|
||||
|
||||
/// A config for authenticating to the compute node.
|
||||
// TODO: avoid Clone
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct AuthInfo {
|
||||
/// None for local-proxy, as we use trust-based localhost auth.
|
||||
/// Some for sql-over-http, ws, tcp, and in most cases for console-redirect.
|
||||
@@ -138,11 +136,11 @@ impl AuthInfo {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn with_auth_keys(keys: ComputeCredentialKeys) -> Self {
|
||||
pub(crate) fn with_auth_keys(keys: &ComputeCredentialKeys) -> Self {
|
||||
Self {
|
||||
auth: match keys {
|
||||
ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
|
||||
Some(Auth::Scram(Box::new(auth_keys)))
|
||||
Some(Auth::Scram(Box::new(*auth_keys)))
|
||||
}
|
||||
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None,
|
||||
},
|
||||
|
||||
@@ -22,7 +22,6 @@ pub struct ProxyConfig {
|
||||
pub http_config: HttpConfig,
|
||||
pub authentication_config: AuthenticationConfig,
|
||||
pub proxy_protocol_v2: ProxyProtocolV2,
|
||||
pub region: String,
|
||||
pub handshake_timeout: Duration,
|
||||
pub wake_compute_retry_config: RetryConfig,
|
||||
pub connect_compute_locks: ApiLocks<Host>,
|
||||
@@ -70,7 +69,7 @@ pub struct AuthenticationConfig {
|
||||
pub console_redirect_confirmation_timeout: tokio::time::Duration,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
pub struct EndpointCacheConfig {
|
||||
/// Batch size to receive all endpoints on the startup.
|
||||
pub initial_batch_size: usize,
|
||||
@@ -206,7 +205,7 @@ impl FromStr for CacheOptions {
|
||||
}
|
||||
|
||||
/// Helper for cmdline cache options parsing.
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
pub struct ProjectInfoCacheOptions {
|
||||
/// Max number of entries.
|
||||
pub size: usize,
|
||||
|
||||
@@ -11,13 +11,13 @@ use crate::config::{ProxyConfig, ProxyProtocolV2};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
use crate::pglb::handshake::{HandshakeData, handshake};
|
||||
use crate::pglb::passthrough::ProxyPassthrough;
|
||||
use crate::pglb::{ClientRequestError, ErrorSource};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
||||
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
|
||||
use crate::proxy::prepare_client_connection;
|
||||
use crate::util::run_until_cancelled;
|
||||
use crate::proxy::{
|
||||
ClientRequestError, ErrorSource, prepare_client_connection, run_until_cancelled,
|
||||
};
|
||||
|
||||
pub async fn task_main(
|
||||
config: &'static ProxyConfig,
|
||||
@@ -90,12 +90,7 @@ pub async fn task_main(
|
||||
}
|
||||
}
|
||||
|
||||
let ctx = RequestContext::new(
|
||||
session_id,
|
||||
conn_info,
|
||||
crate::metrics::Protocol::Tcp,
|
||||
&config.region,
|
||||
);
|
||||
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
|
||||
|
||||
let res = handle_client(
|
||||
config,
|
||||
|
||||
@@ -46,7 +46,6 @@ struct RequestContextInner {
|
||||
pub(crate) session_id: Uuid,
|
||||
pub(crate) protocol: Protocol,
|
||||
first_packet: chrono::DateTime<Utc>,
|
||||
region: &'static str,
|
||||
pub(crate) span: Span,
|
||||
|
||||
// filled in as they are discovered
|
||||
@@ -94,7 +93,6 @@ impl Clone for RequestContext {
|
||||
session_id: inner.session_id,
|
||||
protocol: inner.protocol,
|
||||
first_packet: inner.first_packet,
|
||||
region: inner.region,
|
||||
span: info_span!("background_task"),
|
||||
|
||||
project: inner.project,
|
||||
@@ -124,12 +122,7 @@ impl Clone for RequestContext {
|
||||
}
|
||||
|
||||
impl RequestContext {
|
||||
pub fn new(
|
||||
session_id: Uuid,
|
||||
conn_info: ConnectionInfo,
|
||||
protocol: Protocol,
|
||||
region: &'static str,
|
||||
) -> Self {
|
||||
pub fn new(session_id: Uuid, conn_info: ConnectionInfo, protocol: Protocol) -> Self {
|
||||
// TODO: be careful with long lived spans
|
||||
let span = info_span!(
|
||||
"connect_request",
|
||||
@@ -145,7 +138,6 @@ impl RequestContext {
|
||||
session_id,
|
||||
protocol,
|
||||
first_packet: Utc::now(),
|
||||
region,
|
||||
span,
|
||||
|
||||
project: None,
|
||||
@@ -179,7 +171,7 @@ impl RequestContext {
|
||||
let ip = IpAddr::from([127, 0, 0, 1]);
|
||||
let addr = SocketAddr::new(ip, 5432);
|
||||
let conn_info = ConnectionInfo { addr, extra: None };
|
||||
RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test")
|
||||
RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp)
|
||||
}
|
||||
|
||||
pub(crate) fn console_application_name(&self) -> String {
|
||||
|
||||
@@ -74,7 +74,7 @@ pub(crate) const FAILED_UPLOAD_MAX_RETRIES: u32 = 10;
|
||||
|
||||
#[derive(parquet_derive::ParquetRecordWriter)]
|
||||
pub(crate) struct RequestData {
|
||||
region: &'static str,
|
||||
region: String,
|
||||
protocol: &'static str,
|
||||
/// Must be UTC. The derive macro doesn't like the timezones
|
||||
timestamp: chrono::NaiveDateTime,
|
||||
@@ -147,7 +147,7 @@ impl From<&RequestContextInner> for RequestData {
|
||||
}),
|
||||
jwt_issuer: value.jwt_issuer.clone(),
|
||||
protocol: value.protocol.as_str(),
|
||||
region: value.region,
|
||||
region: String::new(),
|
||||
error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
|
||||
success: value.success,
|
||||
cold_start_info: value.cold_start_info.as_str(),
|
||||
@@ -167,6 +167,7 @@ impl From<&RequestContextInner> for RequestData {
|
||||
pub async fn worker(
|
||||
cancellation_token: CancellationToken,
|
||||
config: ParquetUploadArgs,
|
||||
region: String,
|
||||
) -> anyhow::Result<()> {
|
||||
let Some(remote_storage_config) = config.parquet_upload_remote_storage else {
|
||||
tracing::warn!("parquet request upload: no s3 bucket configured");
|
||||
@@ -232,12 +233,17 @@ pub async fn worker(
|
||||
.context("remote storage for disconnect events init")?;
|
||||
let parquet_config_disconnect = parquet_config.clone();
|
||||
tokio::try_join!(
|
||||
worker_inner(storage, rx, parquet_config),
|
||||
worker_inner(storage_disconnect, rx_disconnect, parquet_config_disconnect)
|
||||
worker_inner(storage, rx, parquet_config, ®ion),
|
||||
worker_inner(
|
||||
storage_disconnect,
|
||||
rx_disconnect,
|
||||
parquet_config_disconnect,
|
||||
®ion
|
||||
)
|
||||
)
|
||||
.map(|_| ())
|
||||
} else {
|
||||
worker_inner(storage, rx, parquet_config).await
|
||||
worker_inner(storage, rx, parquet_config, ®ion).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -257,6 +263,7 @@ async fn worker_inner(
|
||||
storage: GenericRemoteStorage,
|
||||
rx: impl Stream<Item = RequestData>,
|
||||
config: ParquetConfig,
|
||||
region: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
let storage = if config.test_remote_failures > 0 {
|
||||
@@ -277,7 +284,8 @@ async fn worker_inner(
|
||||
let mut last_upload = time::Instant::now();
|
||||
|
||||
let mut len = 0;
|
||||
while let Some(row) = rx.next().await {
|
||||
while let Some(mut row) = rx.next().await {
|
||||
region.clone_into(&mut row.region);
|
||||
rows.push(row);
|
||||
let force = last_upload.elapsed() > config.max_duration;
|
||||
if rows.len() == config.rows_per_group || force {
|
||||
@@ -533,7 +541,7 @@ mod tests {
|
||||
auth_method: None,
|
||||
jwt_issuer: None,
|
||||
protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
|
||||
region: "us-east-1",
|
||||
region: String::new(),
|
||||
error: None,
|
||||
success: rng.r#gen(),
|
||||
cold_start_info: "no",
|
||||
@@ -565,7 +573,9 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
worker_inner(storage, rx, config).await.unwrap();
|
||||
worker_inner(storage, rx, config, "us-east-1")
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut files = WalkDir::new(tmpdir.as_std_path())
|
||||
.into_iter()
|
||||
|
||||
@@ -106,5 +106,4 @@ mod tls;
|
||||
mod types;
|
||||
mod url;
|
||||
mod usage_metrics;
|
||||
mod util;
|
||||
mod waiters;
|
||||
|
||||
@@ -2,9 +2,9 @@ use async_trait::async_trait;
|
||||
use tokio::time;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::compute::{self, AuthInfo, COULD_NOT_CONNECT, PostgresConnection};
|
||||
use crate::config::{ComputeConfig, ProxyConfig, RetryConfig};
|
||||
use crate::config::{ComputeConfig, RetryConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::errors::WakeComputeError;
|
||||
use crate::control_plane::locks::ApiLocks;
|
||||
@@ -14,13 +14,13 @@ use crate::metrics::{
|
||||
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
|
||||
};
|
||||
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry};
|
||||
use crate::proxy::wake_compute::{WakeComputeBackend, wake_compute};
|
||||
use crate::proxy::wake_compute::wake_compute;
|
||||
use crate::types::Host;
|
||||
|
||||
/// If we couldn't connect, a cached connection info might be to blame
|
||||
/// (e.g. the compute node's address might've changed at the wrong time).
|
||||
/// Invalidate the cache entry (if any) to prevent subsequent errors.
|
||||
#[tracing::instrument(skip_all)]
|
||||
#[tracing::instrument(name = "invalidate_cache", skip_all)]
|
||||
pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> NodeInfo {
|
||||
let is_cached = node_info.cached();
|
||||
if is_cached {
|
||||
@@ -49,6 +49,14 @@ pub(crate) trait ConnectMechanism {
|
||||
) -> Result<Self::Connection, Self::ConnectError>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub(crate) trait ComputeConnectBackend {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError>;
|
||||
}
|
||||
|
||||
pub(crate) struct TcpMechanism {
|
||||
pub(crate) auth: AuthInfo,
|
||||
/// connect_to_compute concurrency lock
|
||||
@@ -83,7 +91,7 @@ impl ConnectMechanism for TcpMechanism {
|
||||
|
||||
/// Try to connect to the compute node, retrying if necessary.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: WakeComputeBackend>(
|
||||
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
|
||||
ctx: &RequestContext,
|
||||
mechanism: &M,
|
||||
user_info: &B,
|
||||
@@ -183,114 +191,3 @@ where
|
||||
drop(pause);
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(crate) async fn connect_to_compute_pglb<
|
||||
F: AsyncFn(
|
||||
&'static ProxyConfig,
|
||||
&RequestContext,
|
||||
&CachedNodeInfo,
|
||||
&AuthInfo,
|
||||
&ComputeCredentials,
|
||||
&ComputeConfig,
|
||||
) -> Result<PostgresConnection, compute::ConnectionError>,
|
||||
B: WakeComputeBackend,
|
||||
>(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &RequestContext,
|
||||
connect_compute_fn: F,
|
||||
user_info: &B,
|
||||
auth_info: &AuthInfo,
|
||||
creds: &ComputeCredentials,
|
||||
wake_compute_retry_config: RetryConfig,
|
||||
compute: &ComputeConfig,
|
||||
) -> Result<PostgresConnection, compute::ConnectionError> {
|
||||
let mut num_retries = 0;
|
||||
let node_info =
|
||||
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
|
||||
|
||||
// try once
|
||||
let err = match connect_compute_fn(config, ctx, &node_info, &auth_info, &creds, compute).await {
|
||||
Ok(res) => {
|
||||
ctx.success();
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
RetriesMetricGroup {
|
||||
outcome: ConnectOutcome::Success,
|
||||
retry_type: RetryType::ConnectToCompute,
|
||||
},
|
||||
num_retries.into(),
|
||||
);
|
||||
return Ok(res);
|
||||
}
|
||||
Err(e) => e,
|
||||
};
|
||||
|
||||
debug!(error = ?err, COULD_NOT_CONNECT);
|
||||
|
||||
let node_info = if !node_info.cached() || !err.should_retry_wake_compute() {
|
||||
// If we just received this from cplane and didn't get it from cache, we shouldn't retry.
|
||||
// Do not need to retrieve a new node_info, just return the old one.
|
||||
if should_retry(&err, num_retries, compute.retry) {
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
RetriesMetricGroup {
|
||||
outcome: ConnectOutcome::Failed,
|
||||
retry_type: RetryType::ConnectToCompute,
|
||||
},
|
||||
num_retries.into(),
|
||||
);
|
||||
return Err(err.into());
|
||||
}
|
||||
node_info
|
||||
} else {
|
||||
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
|
||||
debug!("compute node's state has likely changed; requesting a wake-up");
|
||||
invalidate_cache(node_info);
|
||||
// TODO: increment num_retries?
|
||||
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?
|
||||
};
|
||||
|
||||
// now that we have a new node, try connect to it repeatedly.
|
||||
// this can error for a few reasons, for instance:
|
||||
// * DNS connection settings haven't quite propagated yet
|
||||
debug!("wake_compute success. attempting to connect");
|
||||
num_retries = 1;
|
||||
loop {
|
||||
match connect_compute_fn(config, ctx, &node_info, &auth_info, &creds, compute).await {
|
||||
Ok(res) => {
|
||||
ctx.success();
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
RetriesMetricGroup {
|
||||
outcome: ConnectOutcome::Success,
|
||||
retry_type: RetryType::ConnectToCompute,
|
||||
},
|
||||
num_retries.into(),
|
||||
);
|
||||
// TODO: is this necessary? We have a metric.
|
||||
info!(?num_retries, "connected to compute node after");
|
||||
return Ok(res);
|
||||
}
|
||||
Err(e) => {
|
||||
if !should_retry(&e, num_retries, compute.retry) {
|
||||
// Don't log an error here, caller will print the error
|
||||
Metrics::get().proxy.retries_metric.observe(
|
||||
RetriesMetricGroup {
|
||||
outcome: ConnectOutcome::Failed,
|
||||
retry_type: RetryType::ConnectToCompute,
|
||||
},
|
||||
num_retries.into(),
|
||||
);
|
||||
return Err(e.into());
|
||||
}
|
||||
|
||||
warn!(error = ?e, num_retries, retriable = true, COULD_NOT_CONNECT);
|
||||
}
|
||||
}
|
||||
|
||||
let wait_duration = retry_after(num_retries, compute.retry);
|
||||
num_retries += 1;
|
||||
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout);
|
||||
time::sleep(wait_duration).await;
|
||||
drop(pause);
|
||||
}
|
||||
}
@@ -8,10 +8,10 @@ use crate::config::TlsConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::pglb::TlsRequired;
use crate::pqproto::{
BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
};
use crate::proxy::TlsRequired;
use crate::stream::{PqStream, Stream, StreamUpgradeError};
use crate::tls::PG_ALPN_PROTOCOL;

@@ -1,343 +1,5 @@
pub mod connect_compute;
pub mod copy_bidirectional;
pub mod handshake;
pub mod inprocess;
pub mod passthrough;

use std::sync::Arc;

use futures::FutureExt;
use smol_str::ToSmolStr;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, warn};

use crate::auth;
use crate::cancellation::{self, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumClientConnectionsGuard};
pub use crate::pglb::copy_bidirectional::ErrorSource;
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::proxy::connect_compute::{ConnectMechanism, TcpMechanism};
use crate::proxy::handle_connect_request;
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::Stream;
use crate::util::run_until_cancelled;

pub const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";

#[derive(Error, Debug)]
#[error("{ERR_INSECURE_CONNECTION}")]
pub struct TlsRequired;

impl ReportableError for TlsRequired {
fn get_error_kind(&self) -> crate::error::ErrorKind {
crate::error::ErrorKind::User
}
}

impl UserFacingError for TlsRequired {}

pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
}

// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;

let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();

while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;

let conn_gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Tcp);

let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();

debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();

connections.spawn(async move {
let (socket, conn_info) = match config.proxy_protocol_v2 {
ProxyProtocolV2::Required => {
match read_proxy_protocol(socket).await {
Err(e) => {
warn!("per-client task finished with an error: {e:#}");
return;
}
// our load balancers will not send any more data. let's just exit immediately
Ok((_socket, ConnectHeader::Local)) => {
debug!("healthcheck received");
return;
}
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
}
}
// ignore the header - it cannot be confused for a postgres or http connection so will
// error later.
ProxyProtocolV2::Rejected => (
socket,
ConnectionInfo {
addr: peer_addr,
extra: None,
},
),
};

match socket.set_nodelay(true) {
Ok(()) => {}
Err(e) => {
error!(
"per-client task finished with an error: failed to set socket option: {e:#}"
);
return;
}
}

let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Tcp,
&config.region,
);

let res = handle_client(
config,
auth_backend,
&ctx,
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter2,
conn_gauge,
cancellations,
)
.instrument(ctx.span())
.boxed()
.await;

match res {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(
?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
}
}
});
}

connections.close();
cancellations.close();
drop(listener);

// Drain connections
connections.wait().await;
cancellations.wait().await;

Ok(())
}
pub(crate) enum ClientMode {
Tcp,
Websockets { hostname: Option<String> },
}

/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
pub(crate) fn allow_cleartext(&self) -> bool {
match self {
ClientMode::Tcp => false,
ClientMode::Websockets { .. } => true,
}
}

pub(crate) fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
match self {
ClientMode::Tcp => s.sni_hostname(),
ClientMode::Websockets { hostname } => hostname.as_deref(),
}
}

fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
match self {
ClientMode::Tcp => tls,
// TLS is None here if using websockets, because the connection is already encrypted.
ClientMode::Websockets { .. } => None,
}
}
}

#[derive(Debug, Error)]
// almost all errors should be reported to the user, but there's a few cases where we cannot
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
// we cannot be sure the client even understands our error message
// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
pub(crate) enum ClientRequestError {
#[error("{0}")]
Cancellation(#[from] cancellation::CancelError),
#[error("{0}")]
Handshake(#[from] HandshakeError),
#[error("{0}")]
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
#[error("{0}")]
PrepareClient(#[from] std::io::Error),
#[error("{0}")]
ReportedError(#[from] crate::stream::ReportedError),
}

impl ReportableError for ClientRequestError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ClientRequestError::Cancellation(e) => e.get_error_kind(),
ClientRequestError::Handshake(e) => e.get_error_kind(),
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
ClientRequestError::ReportedError(e) => e.get_error_kind(),
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
}
}
}

#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
ctx: &RequestContext,
cancellation_handler: Arc<CancellationHandler>,
client: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
);

let metrics = &Metrics::get().proxy;
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);

let tls = config.tls_config.load();
let tls = tls.as_deref();

let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, client, mode.handshake_tls(tls), record_handshake_error);

let (client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
HandshakeData::Startup(client, params) => (client, params),
HandshakeData::Cancel(cancel_key_data) => {
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx.clone();
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
async move {
cancellation_handler_clone
.cancel_session(
cancel_key_data,
ctx,
config.authentication_config.ip_allowlist_check_enabled,
config.authentication_config.is_vpc_acccess_proxy,
auth_backend.get_api(),
)
.await
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
}.instrument(cancel_span)
});

return Ok(None);
}
};
drop(pause);

ctx.set_db_options(params.clone());

let common_names = tls.map(|tls| &tls.common_names);

let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};

let (node, client, session) = handle_connect_request(
config,
auth_backend,
ctx,
cancellation_handler,
client,
&mode,
endpoint_rate_limiter,
&params,
common_names,
async |config, ctx, node_info, auth_info, creds, compute_config| {
TcpMechanism {
auth: auth_info.clone(),
locks: &config.connect_compute_locks,
user_info: creds.info.clone(),
}
.connect_once(ctx, node_info, compute_config)
.await
},
)
.await?;

Ok(Some(ProxyPassthrough {
client,
aux: node.aux.clone(),
private_link_id,
compute: node,
session_id: ctx.session_id(),
cancel: session,
_req: request_gauge,
_conn: conn_gauge,
}))
}
@@ -1,61 +1,322 @@
#[cfg(test)]
mod tests;

pub(crate) mod connect_compute;
pub(crate) mod retry;
pub(crate) mod wake_compute;

use std::collections::HashSet;
use std::sync::Arc;

use futures::FutureExt;
use itertools::Itertools;
use once_cell::sync::OnceCell;
use regex::Regex;
use serde::{Deserialize, Serialize};
use smol_str::{SmolStr, format_smolstr};
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::Instrument;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, warn};

use crate::auth::backend::ComputeCredentials;
use crate::cancellation::{CancellationHandler, Session};
use crate::compute::{AuthInfo, PostgresConnection};
use crate::config::{ComputeConfig, ProxyConfig};
use crate::cancellation::{self, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
use crate::control_plane::CachedNodeInfo;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute};
pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
use crate::pglb::{ClientMode, ClientRequestError};
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
use crate::proxy::connect_compute::connect_to_compute_pglb;
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
use crate::types::EndpointCacheKey;
use crate::{auth, compute};

const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";

#[derive(Error, Debug)]
#[error("{ERR_INSECURE_CONNECTION}")]
pub struct TlsRequired;

impl ReportableError for TlsRequired {
fn get_error_kind(&self) -> crate::error::ErrorKind {
crate::error::ErrorKind::User
}
}

impl UserFacingError for TlsRequired {}

pub async fn run_until_cancelled<F: std::future::Future>(
f: F,
cancellation_token: &CancellationToken,
) -> Option<F::Output> {
match futures::future::select(
std::pin::pin!(f),
std::pin::pin!(cancellation_token.cancelled()),
)
.await
{
futures::future::Either::Left((f, _)) => Some(f),
futures::future::Either::Right(((), _)) => None,
}
}
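A hypothetical caller of the `run_until_cancelled` helper defined above (the listener setup is an illustration, not code from this diff): the loop drains accepted connections until the token fires, at which point the helper yields None and the loop exits cleanly.

use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;

// Assumes `run_until_cancelled` (as defined above) is in scope.
async fn accept_loop(listener: TcpListener, token: CancellationToken) -> std::io::Result<()> {
    while let Some(accepted) = run_until_cancelled(listener.accept(), &token).await {
        let (_socket, peer_addr) = accepted?;
        println!("accepted connection from {peer_addr}");
    }
    // Token was cancelled; stop accepting and shut down.
    Ok(())
}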
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
}

// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;

let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();

while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;

let conn_gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Tcp);

let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();

debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();

connections.spawn(async move {
let (socket, conn_info) = match config.proxy_protocol_v2 {
ProxyProtocolV2::Required => {
match read_proxy_protocol(socket).await {
Err(e) => {
warn!("per-client task finished with an error: {e:#}");
return;
}
// our load balancers will not send any more data. let's just exit immediately
Ok((_socket, ConnectHeader::Local)) => {
debug!("healthcheck received");
return;
}
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
}
}
// ignore the header - it cannot be confused for a postgres or http connection so will
// error later.
ProxyProtocolV2::Rejected => (
socket,
ConnectionInfo {
addr: peer_addr,
extra: None,
},
),
};

match socket.set_nodelay(true) {
Ok(()) => {}
Err(e) => {
error!(
"per-client task finished with an error: failed to set socket option: {e:#}"
);
return;
}
}

let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);

let res = handle_client(
config,
auth_backend,
&ctx,
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter2,
conn_gauge,
cancellations,
)
.instrument(ctx.span())
.boxed()
.await;

match res {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(
?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
}
}
});
}

connections.close();
cancellations.close();
drop(listener);

// Drain connections
connections.wait().await;
cancellations.wait().await;

Ok(())
}
pub(crate) enum ClientMode {
Tcp,
Websockets { hostname: Option<String> },
}

/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
pub(crate) fn allow_cleartext(&self) -> bool {
match self {
ClientMode::Tcp => false,
ClientMode::Websockets { .. } => true,
}
}

fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
match self {
ClientMode::Tcp => s.sni_hostname(),
ClientMode::Websockets { hostname } => hostname.as_deref(),
}
}

fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
match self {
ClientMode::Tcp => tls,
// TLS is None here if using websockets, because the connection is already encrypted.
ClientMode::Websockets { .. } => None,
}
}
}

#[derive(Debug, Error)]
// almost all errors should be reported to the user, but there's a few cases where we cannot
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
// we cannot be sure the client even understands our error message
// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
pub(crate) enum ClientRequestError {
#[error("{0}")]
Cancellation(#[from] cancellation::CancelError),
#[error("{0}")]
Handshake(#[from] HandshakeError),
#[error("{0}")]
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
#[error("{0}")]
PrepareClient(#[from] std::io::Error),
#[error("{0}")]
ReportedError(#[from] crate::stream::ReportedError),
}

impl ReportableError for ClientRequestError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ClientRequestError::Cancellation(e) => e.get_error_kind(),
ClientRequestError::Handshake(e) => e.get_error_kind(),
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
ClientRequestError::ReportedError(e) => e.get_error_kind(),
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
}
}
}

#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_connect_request<
S: AsyncRead + AsyncWrite + Unpin + Send,
F: AsyncFn(
&'static ProxyConfig,
&RequestContext,
&CachedNodeInfo,
&AuthInfo,
&ComputeCredentials,
&ComputeConfig,
) -> Result<PostgresConnection, compute::ConnectionError>,
>(
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
ctx: &RequestContext,
cancellation_handler: Arc<CancellationHandler>,
mut client: PqStream<Stream<S>>,
mode: &ClientMode,
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
params: &StartupMessageParams,
common_names: Option<&HashSet<String>>,
connect_compute_fn: F,
) -> Result<(PostgresConnection, Stream<S>, Session), ClientRequestError> {
// TODO: to pglb
let hostname = mode.hostname(client.get_ref());
conn_gauge: NumClientConnectionsGuard<'static>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
);

let metrics = &Metrics::get().proxy;
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);

let tls = config.tls_config.load();
let tls = tls.as_deref();

let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);

let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx.clone();
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
async move {
cancellation_handler_clone
.cancel_session(
cancel_key_data,
ctx,
config.authentication_config.ip_allowlist_check_enabled,
config.authentication_config.is_vpc_acccess_proxy,
auth_backend.get_api(),
)
.await
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
}.instrument(cancel_span)
});

return Ok(None);
}
};
drop(pause);

ctx.set_db_options(params.clone());

let hostname = mode.hostname(stream.get_ref());

let common_names = tls.map(|tls| &tls.common_names);

// Extract credentials which we're going to use for auth.
let result = auth_backend
@@ -65,14 +326,14 @@ pub(crate) async fn handle_connect_request<

let user_info = match result {
Ok(user_info) => user_info,
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
};

let user = user_info.get_user().to_owned();
let user_info = match user_info
.authenticate(
ctx,
&mut client,
&mut stream,
mode.allow_cleartext(),
&config.authentication_config,
endpoint_rate_limiter,
@@ -85,28 +346,29 @@ pub(crate) async fn handle_connect_request<
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);

return Err(client
return Err(stream
.throw_error(e, Some(ctx))
.instrument(params_span)
.await)?;
}
};

let (cplane, creds) = match user_info {
auth::Backend::ControlPlane(cplane, creds) => (cplane, creds),
let creds = match &user_info {
auth::Backend::ControlPlane(_, creds) => creds,
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
};
let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some();
let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys);
let mut auth_info = compute::AuthInfo::with_auth_keys(&creds.keys);
auth_info.set_startup_params(&params, params_compat);

let res = connect_to_compute_pglb(
config,
let res = connect_to_compute(
ctx,
connect_compute_fn,
&auth::Backend::ControlPlane(cplane, creds.info),
&auth_info,
&creds,
&TcpMechanism {
user_info: creds.info.clone(),
auth: auth_info,
locks: &config.connect_compute_locks,
},
&user_info,
config.wake_compute_retry_config,
&config.connect_to_compute,
)
@@ -114,17 +376,32 @@ pub(crate) async fn handle_connect_request<

let node = match res {
Ok(node) => node,
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
};

let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let session = cancellation_handler_clone.get_key();

session.write_cancel_key(node.cancel_closure.clone())?;
prepare_client_connection(&node, *session.key(), &mut client);
let client = client.flush_and_into_inner().await?;
prepare_client_connection(&node, *session.key(), &mut stream);
let stream = stream.flush_and_into_inner().await?;

Ok((node, client, session))
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};

Ok(Some(ProxyPassthrough {
client: stream,
aux: node.aux.clone(),
private_link_id,
compute: node,
session_id: ctx.session_id(),
cancel: session,
_req: request_gauge,
_conn: conn_gauge,
}))
}

/// Finish client connection initialization: confirm auth success, send params, etc.
@@ -159,7 +436,7 @@ impl NeonOptions {
// proxy options:

/// `PARAMS_COMPAT` allows opting in to forwarding all startup parameters from client to compute.
pub const PARAMS_COMPAT: &str = "proxy_params_compat";
const PARAMS_COMPAT: &str = "proxy_params_compat";

// cplane options:

@@ -3,13 +3,12 @@

mod mitm;

use std::sync::Arc;
use std::time::Duration;

use anyhow::{Context, bail};
use async_trait::async_trait;
use http::StatusCode;
use postgres_client::config::SslMode;
use postgres_client::config::{AuthKeys, ScramKeys, SslMode};
use postgres_client::tls::{MakeTlsConnect, NoTls};
use retry::{ShouldRetryWakeCompute, retry_after};
use rstest::rstest;
@@ -20,21 +19,19 @@ use tracing_test::traced_test;

use super::retry::CouldRetry;
use super::*;
use crate::auth::backend::{ComputeUserInfo, MaybeOwned};
use crate::config::{ComputeConfig, RetryConfig, TlsConfig};
use crate::context::RequestContext;
use crate::auth::backend::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned,
};
use crate::config::{ComputeConfig, RetryConfig};
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
use crate::error::{ErrorKind, ReportableError};
use crate::pglb::ERR_INSECURE_CONNECTION;
use crate::pglb::handshake::{HandshakeData, handshake};
use crate::proxy::connect_compute::{ConnectMechanism, connect_to_compute};
use crate::stream::Stream;
use crate::error::ErrorKind;
use crate::pglb::connect_compute::ConnectMechanism;
use crate::tls::client_config::compute_client_config_with_certs;
use crate::tls::server_config::CertResolver;
use crate::types::{BranchId, EndpointId, ProjectId};
use crate::{auth, sasl, scram};
use crate::{sasl, scram};

/// Generate a set of TLS certificates: CA + server.
fn generate_certs(
@@ -578,13 +575,19 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn

fn helper_create_connect_info(
mechanism: &TestConnectMechanism,
) -> auth::Backend<'static, ComputeUserInfo> {
) -> auth::Backend<'static, ComputeCredentials> {
auth::Backend::ControlPlane(
MaybeOwned::Owned(ControlPlaneClient::Test(Box::new(mechanism.clone()))),
ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
ComputeCredentials {
info: ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
},
keys: ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(ScramKeys {
client_key: [0; 32],
server_key: [0; 32],
})),
},
)
}
@@ -1,4 +1,3 @@
use async_trait::async_trait;
use tracing::{error, info};

use crate::config::RetryConfig;
@@ -9,6 +8,7 @@ use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailuresBreakdownGroup, Metrics, RetriesMetricGroup, RetryType,
};
use crate::pglb::connect_compute::ComputeConnectBackend;
use crate::proxy::retry::{retry_after, should_retry};

// Use macro to retain original callsite.
@@ -23,12 +23,7 @@ macro_rules! log_wake_compute_error {
};
}

#[async_trait]
pub(crate) trait WakeComputeBackend {
async fn wake_compute(&self, ctx: &RequestContext) -> Result<CachedNodeInfo, WakeComputeError>;
}

pub(crate) async fn wake_compute<B: WakeComputeBackend>(
pub(crate) async fn wake_compute<B: ComputeConnectBackend>(
num_retries: &mut u32,
ctx: &RequestContext,
api: &B,

@@ -140,12 +140,6 @@ impl RateBucketInfo {
Self::new(200, Duration::from_secs(600)),
];

// For all the sessions will be cancel key. So this limit is essentially global proxy limit.
pub const DEFAULT_REDIS_SET: [Self; 2] = [
Self::new(100_000, Duration::from_secs(1)),
Self::new(50_000, Duration::from_secs(10)),
];

pub fn rps(&self) -> f64 {
(self.max_rpi as f64) / self.interval.as_secs_f64()
}
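As a quick sanity check on the rps() arithmetic above, here is a standalone sketch that recomputes the requests-per-second implied by the two DEFAULT_REDIS_SET buckets shown in this hunk; the helper below is an illustration and does not reimplement RateBucketInfo itself.

use std::time::Duration;

// rps = max requests per interval / interval length in seconds.
fn rps(max_rpi: u32, interval: Duration) -> f64 {
    (max_rpi as f64) / interval.as_secs_f64()
}

fn main() {
    // Self::new(100_000, 1s)  -> 100_000 requests/second
    assert_eq!(rps(100_000, Duration::from_secs(1)), 100_000.0);
    // Self::new(50_000, 10s)  ->   5_000 requests/second
    assert_eq!(rps(50_000, Duration::from_secs(10)), 5_000.0);
    println!("rate bucket arithmetic checks out");
}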
@@ -2,11 +2,9 @@ use redis::aio::ConnectionLike;
use redis::{Cmd, FromRedisValue, Pipeline, RedisResult};

use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};

pub struct RedisKVClient {
client: ConnectionWithCredentialsProvider,
limiter: GlobalRateLimiter,
}

#[allow(async_fn_in_trait)]
@@ -27,11 +25,8 @@ impl Queryable for Cmd {
}

impl RedisKVClient {
pub fn new(client: ConnectionWithCredentialsProvider, info: &'static [RateBucketInfo]) -> Self {
Self {
client,
limiter: GlobalRateLimiter::new(info.into()),
}
pub fn new(client: ConnectionWithCredentialsProvider) -> Self {
Self { client }
}

pub async fn try_connect(&mut self) -> anyhow::Result<()> {
@@ -49,11 +44,6 @@ impl RedisKVClient {
&mut self,
q: &impl Queryable,
) -> anyhow::Result<T> {
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping query");
return Err(anyhow::anyhow!("Rate limit exceeded"));
}

match q.query(&mut self.client).await {
Ok(t) => return Ok(t),
Err(e) => {

@@ -141,29 +141,19 @@
struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
cache: Arc<C>,
region_id: String,
}

impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
fn clone(&self) -> Self {
Self {
cache: self.cache.clone(),
region_id: self.region_id.clone(),
}
}
}

impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
pub(crate) fn new(cache: Arc<C>, region_id: String) -> Self {
Self { cache, region_id }
}

pub(crate) async fn increment_active_listeners(&self) {
self.cache.increment_active_listeners().await;
}

pub(crate) async fn decrement_active_listeners(&self) {
self.cache.decrement_active_listeners().await;
pub(crate) fn new(cache: Arc<C>) -> Self {
Self { cache }
}

#[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
@@ -276,7 +266,7 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
}
let mut conn = match try_connect(&redis).await {
Ok(conn) => {
handler.increment_active_listeners().await;
handler.cache.increment_active_listeners().await;
conn
}
Err(e) => {
@@ -297,11 +287,11 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
}
}
if cancellation_token.is_cancelled() {
handler.decrement_active_listeners().await;
handler.cache.decrement_active_listeners().await;
return Ok(());
}
}
handler.decrement_active_listeners().await;
handler.cache.decrement_active_listeners().await;
}
}

@@ -310,12 +300,11 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
pub async fn task_main<C>(
redis: ConnectionWithCredentialsProvider,
cache: Arc<C>,
region_id: String,
) -> anyhow::Result<Infallible>
where
C: ProjectInfoCache + Send + Sync + 'static,
{
let handler = MessageHandler::new(cache, region_id);
let handler = MessageHandler::new(cache);
// 6h - 1m.
// There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));
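(For reference on the comment above: 6 * 60 * 60 - 60 = 21,540 seconds, i.e. one minute short of six hours, which is what gives consecutive runs their one-minute overlap.)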
@@ -21,7 +21,7 @@ use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool};
use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client};
use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo};
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
use crate::auth::{self, AuthError};
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
@@ -34,7 +34,7 @@ use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::control_plane::locks::ApiLocks;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::EndpointIdInt;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::pglb::connect_compute::ConnectMechanism;
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
use crate::rate_limiter::EndpointRateLimiter;
use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
@@ -180,15 +180,14 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| keys.info);
crate::proxy::connect_compute::connect_to_compute(
let backend = self.auth_backend.as_ref().map(|()| keys);
crate::pglb::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
keys: keys.keys,
},
&backend,
self.config.wake_compute_retry_config,
@@ -215,15 +214,18 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
debug!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| ComputeUserInfo {
user: conn_info.user_info.user.clone(),
endpoint: EndpointId::from(format!(
"{}{LOCAL_PROXY_SUFFIX}",
conn_info.user_info.endpoint.normalize()
)),
options: conn_info.user_info.options.clone(),
let backend = self.auth_backend.as_ref().map(|()| ComputeCredentials {
info: ComputeUserInfo {
user: conn_info.user_info.user.clone(),
endpoint: EndpointId::from(format!(
"{}{LOCAL_PROXY_SUFFIX}",
conn_info.user_info.endpoint.normalize()
)),
options: conn_info.user_info.options.clone(),
},
keys: crate::auth::backend::ComputeCredentialKeys::None,
});
crate::proxy::connect_compute::connect_to_compute(
crate::pglb::connect_compute::connect_to_compute(
ctx,
&HyperMechanism {
conn_id,
@@ -493,7 +495,6 @@ struct TokioMechanism {
pool: Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
keys: ComputeCredentialKeys,

/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
@@ -519,10 +520,6 @@ impl ConnectMechanism for TokioMechanism {
.dbname(&self.conn_info.dbname)
.connect_timeout(compute_config.timeout);

if let ComputeCredentialKeys::AuthKeys(auth_keys) = self.keys {
config.auth_keys(auth_keys);
}

let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(compute_config).await;
drop(pause);

@@ -50,10 +50,10 @@ use crate::context::RequestContext;
use crate::ext::TaskExt;
use crate::metrics::Metrics;
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::run_until_cancelled;
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_util::{api_error_into_response, json_response};
use crate::util::run_until_cancelled;

pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
pub(crate) const AUTH_BROKER_SNI: &str = "apiauth";
@@ -417,12 +417,7 @@ async fn request_handler(
if config.http_config.accept_websockets
&& framed_websockets::upgrade::is_upgrade_request(&request)
{
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Ws,
&config.region,
);
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Ws);

ctx.set_user_agent(
request
@@ -462,12 +457,7 @@ async fn request_handler(
// Return the response so the spawned future can continue.
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Http,
&config.region,
);
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http);
let span = ctx.span();

let testodrome_id = request

@@ -41,11 +41,10 @@ use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::http::{ReadBodyError, read_body_with_limit};
use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
use crate::pqproto::StartupMessageParams;
use crate::proxy::NeonOptions;
use crate::proxy::{NeonOptions, run_until_cancelled};
use crate::serverless::backend::HttpConnError;
use crate::types::{DbName, RoleName};
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
use crate::util::run_until_cancelled;

#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]

@@ -17,8 +17,7 @@ use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::pglb::{ClientMode, handle_client};
use crate::proxy::ErrorSource;
use crate::proxy::{ClientMode, ErrorSource, handle_client};
use crate::rate_limiter::EndpointRateLimiter;

pin_project! {

@@ -1,14 +0,0 @@
use std::pin::pin;

use futures::future::{Either, select};
use tokio_util::sync::CancellationToken;

pub async fn run_until_cancelled<F: Future>(
f: F,
cancellation_token: &CancellationToken,
) -> Option<F::Output> {
match select(pin!(f), pin!(cancellation_token.cancelled())).await {
Either::Left((f, _)) => Some(f),
Either::Right(((), _)) => None,
}
}
@@ -77,7 +77,7 @@ class EndpointHttpClient(requests.Session):
status, err = json["status"], json.get("error")
assert status == "completed", f"{status}, error {err}"

wait_until(prewarmed, timeout=60)
wait_until(prewarmed)

def offload_lfc(self):
url = f"http://localhost:{self.external_port}/lfc/offload"

@@ -4046,16 +4046,6 @@ def static_proxy(
"CREATE TABLE neon_control_plane.endpoints (endpoint_id VARCHAR(255) PRIMARY KEY, allowed_ips VARCHAR(255))"
)

vanilla_pg.stop()
vanilla_pg.edit_hba(
[
"local all all trust",
"host all all 127.0.0.1/32 scram-sha-256",
"host all all ::1/128 scram-sha-256",
]
)
vanilla_pg.start()

proxy_port = port_distributor.get_port()
mgmt_port = port_distributor.get_port()
http_port = port_distributor.get_port()

@@ -16,7 +16,7 @@ if TYPE_CHECKING:

# Test restarting page server, while safekeeper and compute node keep
# running.
def test_pageserver_restarts_under_workload(neon_simple_env: NeonEnv, pg_bin: PgBin):
def test_pageserver_restarts_under_worload(neon_simple_env: NeonEnv, pg_bin: PgBin):
env = neon_simple_env
env.create_branch("test_pageserver_restarts")
endpoint = env.endpoints.create_start("test_pageserver_restarts")
@@ -28,11 +28,7 @@ def test_pageserver_restarts_under_workload(neon_simple_env: NeonEnv, pg_bin: Pg
pg_bin.run_capture(["pgbench", "-i", "-I", "dtGvp", f"-s{scale}", connstr])
pg_bin.run_capture(["pgbench", f"-T{n_restarts}", connstr])

thread = threading.Thread(
target=run_pgbench,
args=(endpoint.connstr(options="-cstatement_timeout=360s"),),
daemon=True,
)
thread = threading.Thread(target=run_pgbench, args=(endpoint.connstr(),), daemon=True)
thread.start()

for _ in range(n_restarts):

@@ -19,15 +19,11 @@ TABLE_NAME = "neon_control_plane.endpoints"
async def test_proxy_psql_allowed_ips(static_proxy: NeonProxy, vanilla_pg: VanillaPostgres):
# Shouldn't be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('private-project', '8.8.8.8')",
user="proxy",
password="password",
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('private-project', '8.8.8.8')"
)
# Should be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('generic-project', '::1,127.0.0.1')",
user="proxy",
password="password",
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('generic-project', '::1,127.0.0.1')"
)

def check_cannot_connect(**kwargs):
@@ -64,9 +60,7 @@ async def test_proxy_http_allowed_ips(static_proxy: NeonProxy, vanilla_pg: Vanil

# Shouldn't be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('proxy', '8.8.8.8')",
user="proxy",
password="password",
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('proxy', '8.8.8.8')"
)

def query(status: int, query: str, *args):
@@ -81,8 +75,6 @@ async def test_proxy_http_allowed_ips(static_proxy: NeonProxy, vanilla_pg: Vanil
query(400, "select 1;") # ip address is not allowed
# Should be able to connect to this project
vanilla_pg.safe_psql(
f"UPDATE {TABLE_NAME} SET allowed_ips = '8.8.8.8,127.0.0.1' WHERE endpoint_id = 'proxy'",
user="proxy",
password="password",
f"UPDATE {TABLE_NAME} SET allowed_ips = '8.8.8.8,127.0.0.1' WHERE endpoint_id = 'proxy'"
)
query(200, "select 1;") # should work now