Compare commits

..

8 Commits

45 changed files with 1145 additions and 1220 deletions

2
Cargo.lock generated
View File

@@ -2055,7 +2055,6 @@ dependencies = [
"axum-extra",
"camino",
"camino-tempfile",
"clap",
"futures",
"http-body-util",
"itertools 0.10.5",
@@ -5274,6 +5273,7 @@ dependencies = [
"tokio-rustls 0.26.2",
"tokio-tungstenite 0.21.0",
"tokio-util",
"toml",
"tracing",
"tracing-log",
"tracing-opentelemetry",

View File

@@ -8,7 +8,6 @@ anyhow.workspace = true
axum-extra.workspace = true
axum.workspace = true
camino.workspace = true
clap.workspace = true
futures.workspace = true
jsonwebtoken.workspace = true
prometheus.workspace = true

View File

@@ -3,8 +3,7 @@
//! This service is deployed either as a separate component or as part of compute image
//! for large computes.
mod app;
use anyhow::Context;
use clap::Parser;
use anyhow::{Context, bail};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tracing::info;
use utils::logging;
@@ -18,18 +17,6 @@ const fn listen() -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 51243)
}
#[derive(Parser)]
struct Args {
#[arg(exclusive = true)]
config_file: Option<String>,
#[arg(long, default_value = "false", requires = "config")]
/// to allow testing k8s helm chart where we don't have s3 credentials
no_s3_check_on_startup: bool,
#[arg(long, value_name = "FILE")]
/// inline config mode for k8s helm chart
config: Option<String>,
}
#[derive(serde::Deserialize)]
#[serde(tag = "type")]
struct Config {
@@ -50,16 +37,19 @@ async fn main() -> anyhow::Result<()> {
logging::Output::Stdout,
)?;
let args = Args::parse();
let config: Config = if let Some(config_path) = args.config_file {
info!("Reading config from {config_path}");
let config = std::fs::read_to_string(config_path)?;
// Allow either passing filename or inline config (for k8s helm chart)
let args: Vec<String> = std::env::args().skip(1).collect();
let config: Config = if args.len() == 1 && args[0].ends_with(".json") {
info!("Reading config from {}", args[0]);
let config = std::fs::read_to_string(args[0].clone())?;
serde_json::from_str(&config).context("parsing config")?
} else if let Some(config) = args.config {
} else if !args.is_empty() && args[0].starts_with("--config=") {
info!("Reading inline config");
serde_json::from_str(&config).context("parsing config")?
let config = args.join(" ");
let config = config.strip_prefix("--config=").unwrap();
serde_json::from_str(config).context("parsing config")?
} else {
anyhow::bail!("Supply either config file path or --config=inline-config");
bail!("Usage: endpoint_storage config.json or endpoint_storage --config=JSON");
};
info!("Reading pemfile from {}", config.pemfile.clone());
@@ -72,9 +62,7 @@ async fn main() -> anyhow::Result<()> {
let storage = remote_storage::GenericRemoteStorage::from_config(&config.storage_config).await?;
let cancel = tokio_util::sync::CancellationToken::new();
if !args.no_s3_check_on_startup {
app::check_storage_permissions(&storage, cancel.clone()).await?;
}
app::check_storage_permissions(&storage, cancel.clone()).await?;
let proxy = std::sync::Arc::new(endpoint_storage::Storage {
auth,

View File

@@ -1,3 +1,5 @@
use std::io;
use tokio::net::TcpStream;
use crate::client::SocketConfig;
@@ -6,7 +8,7 @@ use crate::tls::MakeTlsConnect;
use crate::{Error, cancel_query_raw, connect_socket};
pub(crate) async fn cancel_query<T>(
config: SocketConfig,
config: Option<SocketConfig>,
ssl_mode: SslMode,
tls: T,
process_id: i32,
@@ -15,6 +17,16 @@ pub(crate) async fn cancel_query<T>(
where
T: MakeTlsConnect<TcpStream>,
{
let config = match config {
Some(config) => config,
None => {
return Err(Error::connect(io::Error::new(
io::ErrorKind::InvalidInput,
"unknown host",
)));
}
};
let hostname = match &config.host {
Host::Tcp(host) => &**host,
};

View File

@@ -7,18 +7,11 @@ use crate::config::SslMode;
use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::{Error, cancel_query, cancel_query_raw};
/// The capability to request cancellation of in-progress queries on a
/// connection.
#[derive(Clone)]
pub struct CancelToken {
pub socket_config: SocketConfig,
pub raw: RawCancelToken,
}
/// The capability to request cancellation of in-progress queries on a
/// connection.
#[derive(Clone, Serialize, Deserialize)]
pub struct RawCancelToken {
pub struct CancelToken {
pub socket_config: Option<SocketConfig>,
pub ssl_mode: SslMode,
pub process_id: i32,
pub secret_key: i32,
@@ -43,16 +36,14 @@ impl CancelToken {
{
cancel_query::cancel_query(
self.socket_config.clone(),
self.raw.ssl_mode,
self.ssl_mode,
tls,
self.raw.process_id,
self.raw.secret_key,
self.process_id,
self.secret_key,
)
.await
}
}
impl RawCancelToken {
/// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new
/// connection itself.
pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>

View File

@@ -12,7 +12,6 @@ use postgres_protocol2::message::frontend;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use crate::cancel_token::RawCancelToken;
use crate::codec::{BackendMessages, FrontendMessage};
use crate::config::{Host, SslMode};
use crate::query::RowStream;
@@ -332,12 +331,10 @@ impl Client {
/// connection associated with this client.
pub fn cancel_token(&self) -> CancelToken {
CancelToken {
socket_config: self.socket_config.clone(),
raw: RawCancelToken {
ssl_mode: self.ssl_mode,
process_id: self.process_id,
secret_key: self.secret_key,
},
socket_config: Some(self.socket_config.clone()),
ssl_mode: self.ssl_mode,
process_id: self.process_id,
secret_key: self.secret_key,
}
}

View File

@@ -3,7 +3,7 @@
use postgres_protocol2::message::backend::ReadyForQueryBody;
pub use crate::cancel_token::{CancelToken, RawCancelToken};
pub use crate::cancel_token::CancelToken;
pub use crate::client::{Client, SocketConfig};
pub use crate::config::Config;
pub use crate::connect_raw::RawConnection;

View File

@@ -10,7 +10,7 @@ use std::sync::Arc;
use std::time::{Duration, SystemTime};
use std::{env, io};
use anyhow::{Context, Result, anyhow};
use anyhow::{Context, Result};
use azure_core::request_options::{IfMatchCondition, MaxResults, Metadata, Range};
use azure_core::{Continuable, HttpClient, RetryOptions, TransportOptions};
use azure_storage::StorageCredentials;
@@ -37,7 +37,6 @@ use crate::metrics::{AttemptOutcome, RequestKind, start_measuring_requests};
use crate::{
ConcurrencyLimiter, Download, DownloadError, DownloadKind, DownloadOpts, Listing, ListingMode,
ListingObject, RemotePath, RemoteStorage, StorageMetadata, TimeTravelError, TimeoutOrCancel,
Version, VersionKind,
};
pub struct AzureBlobStorage {
@@ -406,39 +405,6 @@ impl AzureBlobStorage {
pub fn container_name(&self) -> &str {
&self.container_name
}
async fn list_versions_with_permit(
&self,
_permit: &tokio::sync::SemaphorePermit<'_>,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
cancel: &CancellationToken,
) -> Result<crate::VersionListing, DownloadError> {
let customize_builder = |mut builder: ListBlobsBuilder| {
builder = builder.include_versions(true);
// We do not return this info back to `VersionListing` yet.
builder = builder.include_deleted(true);
builder
};
let kind = RequestKind::ListVersions;
let mut stream = std::pin::pin!(self.list_streaming_for_fn(
prefix,
mode,
max_keys,
cancel,
kind,
customize_builder
));
let mut combined: crate::VersionListing =
stream.next().await.expect("At least one item required")?;
while let Some(list) = stream.next().await {
let list = list?;
combined.versions.extend(list.versions.into_iter());
}
Ok(combined)
}
}
trait ListingCollector {
@@ -522,10 +488,27 @@ impl RemoteStorage for AzureBlobStorage {
max_keys: Option<NonZeroU32>,
cancel: &CancellationToken,
) -> std::result::Result<crate::VersionListing, DownloadError> {
let customize_builder = |mut builder: ListBlobsBuilder| {
builder = builder.include_versions(true);
builder
};
let kind = RequestKind::ListVersions;
let permit = self.permit(kind, cancel).await?;
self.list_versions_with_permit(&permit, prefix, mode, max_keys, cancel)
.await
let mut stream = std::pin::pin!(self.list_streaming_for_fn(
prefix,
mode,
max_keys,
cancel,
kind,
customize_builder
));
let mut combined: crate::VersionListing =
stream.next().await.expect("At least one item required")?;
while let Some(list) = stream.next().await {
let list = list?;
combined.versions.extend(list.versions.into_iter());
}
Ok(combined)
}
async fn head_object(
@@ -820,158 +803,14 @@ impl RemoteStorage for AzureBlobStorage {
async fn time_travel_recover(
&self,
prefix: Option<&RemotePath>,
timestamp: SystemTime,
done_if_after: SystemTime,
cancel: &CancellationToken,
_prefix: Option<&RemotePath>,
_timestamp: SystemTime,
_done_if_after: SystemTime,
_cancel: &CancellationToken,
) -> Result<(), TimeTravelError> {
let msg = "PLEASE NOTE: Azure Blob storage time-travel recovery may not work as expected "
.to_string()
+ "for some specific files. If a file gets deleted but then overwritten and we want to recover "
+ "to the time during the file was not present, this functionality will recover the file. Only "
+ "use the functionality for services that can tolerate this. For example, recovering a state of the "
+ "pageserver tenants.";
tracing::error!("{}", msg);
let kind = RequestKind::TimeTravel;
let permit = self.permit(kind, cancel).await?;
let mode = ListingMode::NoDelimiter;
let version_listing = self
.list_versions_with_permit(&permit, prefix, mode, None, cancel)
.await
.map_err(|err| match err {
DownloadError::Other(e) => TimeTravelError::Other(e),
DownloadError::Cancelled => TimeTravelError::Cancelled,
other => TimeTravelError::Other(other.into()),
})?;
let versions_and_deletes = version_listing.versions;
tracing::info!(
"Built list for time travel with {} versions and deletions",
versions_and_deletes.len()
);
// Work on the list of references instead of the objects directly,
// otherwise we get lifetime errors in the sort_by_key call below.
let mut versions_and_deletes = versions_and_deletes.iter().collect::<Vec<_>>();
versions_and_deletes.sort_by_key(|vd| (&vd.key, &vd.last_modified));
let mut vds_for_key = HashMap::<_, Vec<_>>::new();
for vd in &versions_and_deletes {
let Version { key, .. } = &vd;
let version_id = vd.version_id().map(|v| v.0.as_str());
if version_id == Some("null") {
return Err(TimeTravelError::Other(anyhow!(
"Received ListVersions response for key={key} with version_id='null', \
indicating either disabled versioning, or legacy objects with null version id values"
)));
}
tracing::trace!("Parsing version key={key} kind={:?}", vd.kind);
vds_for_key.entry(key).or_default().push(vd);
}
let warn_threshold = 3;
let max_retries = 10;
let is_permanent = |e: &_| matches!(e, TimeTravelError::Cancelled);
for (key, versions) in vds_for_key {
let last_vd = versions.last().unwrap();
let key = self.relative_path_to_name(key);
if last_vd.last_modified > done_if_after {
tracing::debug!("Key {key} has version later than done_if_after, skipping");
continue;
}
// the version we want to restore to.
let version_to_restore_to =
match versions.binary_search_by_key(&timestamp, |tpl| tpl.last_modified) {
Ok(v) => v,
Err(e) => e,
};
if version_to_restore_to == versions.len() {
tracing::debug!("Key {key} has no changes since timestamp, skipping");
continue;
}
let mut do_delete = false;
if version_to_restore_to == 0 {
// All versions more recent, so the key didn't exist at the specified time point.
tracing::debug!(
"All {} versions more recent for {key}, deleting",
versions.len()
);
do_delete = true;
} else {
match &versions[version_to_restore_to - 1] {
Version {
kind: VersionKind::Version(version_id),
..
} => {
let source_url = format!(
"{}/{}?versionid={}",
self.client
.url()
.map_err(|e| TimeTravelError::Other(anyhow!("{e}")))?,
key,
version_id.0
);
tracing::debug!(
"Promoting old version {} for {key} at {}...",
version_id.0,
source_url
);
backoff::retry(
|| async {
let blob_client = self.client.blob_client(key.clone());
let op = blob_client.copy(Url::from_str(&source_url).unwrap());
tokio::select! {
res = op => res.map_err(|e| TimeTravelError::Other(e.into())),
_ = cancel.cancelled() => Err(TimeTravelError::Cancelled),
}
},
is_permanent,
warn_threshold,
max_retries,
"copying object version for time_travel_recover",
cancel,
)
.await
.ok_or_else(|| TimeTravelError::Cancelled)
.and_then(|x| x)?;
tracing::info!(?version_id, %key, "Copied old version in Azure blob storage");
}
Version {
kind: VersionKind::DeletionMarker,
..
} => {
do_delete = true;
}
}
};
if do_delete {
if matches!(last_vd.kind, VersionKind::DeletionMarker) {
// Key has since been deleted (but there was some history), no need to do anything
tracing::debug!("Key {key} already deleted, skipping.");
} else {
tracing::debug!("Deleting {key}...");
self.delete(&RemotePath::from_string(&key).unwrap(), cancel)
.await
.map_err(|e| {
// delete_oid0 will use TimeoutOrCancel
if TimeoutOrCancel::caused_by_cancel(&e) {
TimeTravelError::Cancelled
} else {
TimeTravelError::Other(e)
}
})?;
}
}
}
Ok(())
// TODO use Azure point in time recovery feature for this
// https://learn.microsoft.com/en-us/azure/storage/blobs/point-in-time-restore-overview
Err(TimeTravelError::Unimplemented)
}
}

View File

@@ -1022,7 +1022,6 @@ impl RemoteStorage for S3Bucket {
let Version { key, .. } = &vd;
let version_id = vd.version_id().map(|v| v.0.as_str());
if version_id == Some("null") {
// TODO: check the behavior of using the SDK on a non-versioned container
return Err(TimeTravelError::Other(anyhow!(
"Received ListVersions response for key={key} with version_id='null', \
indicating either disabled versioning, or legacy objects with null version id values"

View File

@@ -573,8 +573,7 @@ fn start_pageserver(
tokio::sync::mpsc::unbounded_channel();
let deletion_queue_client = deletion_queue.new_client();
let background_purges = mgr::BackgroundPurges::default();
let tenant_manager = mgr::init(
let tenant_manager = BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(
conf,
background_purges.clone(),
TenantSharedResources {
@@ -585,10 +584,10 @@ fn start_pageserver(
basebackup_prepare_sender,
feature_resolver,
},
order,
shutdown_pageserver.clone(),
);
))?;
let tenant_manager = Arc::new(tenant_manager);
BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(tenant_manager.clone(), order))?;
let basebackup_cache = BasebackupCache::spawn(
BACKGROUND_RUNTIME.handle(),

View File

@@ -1,6 +1,5 @@
use std::{collections::HashMap, sync::Arc, time::Duration};
use pageserver_api::config::NodeMetadata;
use posthog_client_lite::{
CaptureEvent, FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError,
PostHogFlagFilterPropertyValue,
@@ -87,35 +86,7 @@ impl FeatureResolver {
}
}
}
// TODO: move this to a background task so that we don't block startup in case of slow disk
let metadata_path = conf.metadata_path();
match std::fs::read_to_string(&metadata_path) {
Ok(metadata_str) => match serde_json::from_str::<NodeMetadata>(&metadata_str) {
Ok(metadata) => {
properties.insert(
"hostname".to_string(),
PostHogFlagFilterPropertyValue::String(metadata.http_host),
);
if let Some(cplane_region) = metadata.other.get("region_id") {
if let Some(cplane_region) = cplane_region.as_str() {
// This region contains the cell number
properties.insert(
"neon_region".to_string(),
PostHogFlagFilterPropertyValue::String(
cplane_region.to_string(),
),
);
}
}
}
Err(e) => {
tracing::warn!("Failed to parse metadata.json: {}", e);
}
},
Err(e) => {
tracing::warn!("Failed to read metadata.json: {}", e);
}
}
// TODO: add pageserver URL.
Arc::new(properties)
};
let fake_tenants = {

View File

@@ -1053,15 +1053,6 @@ pub(crate) static TENANT_STATE_METRIC: Lazy<UIntGaugeVec> = Lazy::new(|| {
.expect("Failed to register pageserver_tenant_states_count metric")
});
pub(crate) static TIMELINE_STATE_METRIC: Lazy<UIntGaugeVec> = Lazy::new(|| {
register_uint_gauge_vec!(
"pageserver_timeline_states_count",
"Count of timelines per state",
&["state"]
)
.expect("Failed to register pageserver_timeline_states_count metric")
});
/// A set of broken tenants.
///
/// These are expected to be so rare that a set is fine. Set as in a new timeseries per each broken
@@ -3334,8 +3325,6 @@ impl TimelineMetrics {
&timeline_id,
);
TIMELINE_STATE_METRIC.with_label_values(&["active"]).inc();
TimelineMetrics {
tenant_id,
shard_id,
@@ -3490,8 +3479,6 @@ impl TimelineMetrics {
return;
}
TIMELINE_STATE_METRIC.with_label_values(&["active"]).dec();
let tenant_id = &self.tenant_id;
let timeline_id = &self.timeline_id;
let shard_id = &self.shard_id;

View File

@@ -89,8 +89,7 @@ use crate::l0_flush::L0FlushGlobalState;
use crate::metrics::{
BROKEN_TENANTS_SET, CIRCUIT_BREAKERS_BROKEN, CIRCUIT_BREAKERS_UNBROKEN, CONCURRENT_INITDBS,
INITDB_RUN_TIME, INITDB_SEMAPHORE_ACQUISITION_TIME, TENANT, TENANT_OFFLOADED_TIMELINES,
TENANT_STATE_METRIC, TENANT_SYNTHETIC_SIZE_METRIC, TIMELINE_STATE_METRIC,
remove_tenant_metrics,
TENANT_STATE_METRIC, TENANT_SYNTHETIC_SIZE_METRIC, remove_tenant_metrics,
};
use crate::task_mgr::TaskKind;
use crate::tenant::config::LocationMode;
@@ -545,28 +544,6 @@ pub struct OffloadedTimeline {
/// Part of the `OffloadedTimeline` object's lifecycle: this needs to be set before we drop it
pub deleted_from_ancestor: AtomicBool,
_metrics_guard: OffloadedTimelineMetricsGuard,
}
/// Increases the offloaded timeline count metric when created, and decreases when dropped.
struct OffloadedTimelineMetricsGuard;
impl OffloadedTimelineMetricsGuard {
fn new() -> Self {
TIMELINE_STATE_METRIC
.with_label_values(&["offloaded"])
.inc();
Self
}
}
impl Drop for OffloadedTimelineMetricsGuard {
fn drop(&mut self) {
TIMELINE_STATE_METRIC
.with_label_values(&["offloaded"])
.dec();
}
}
impl OffloadedTimeline {
@@ -599,8 +576,6 @@ impl OffloadedTimeline {
delete_progress: timeline.delete_progress.clone(),
deleted_from_ancestor: AtomicBool::new(false),
_metrics_guard: OffloadedTimelineMetricsGuard::new(),
})
}
fn from_manifest(tenant_shard_id: TenantShardId, manifest: &OffloadedTimelineManifest) -> Self {
@@ -620,7 +595,6 @@ impl OffloadedTimeline {
archived_at,
delete_progress: TimelineDeleteProgress::default(),
deleted_from_ancestor: AtomicBool::new(false),
_metrics_guard: OffloadedTimelineMetricsGuard::new(),
}
}
fn manifest(&self) -> OffloadedTimelineManifest {

View File

@@ -12,6 +12,7 @@ use anyhow::Context;
use camino::{Utf8DirEntry, Utf8Path, Utf8PathBuf};
use futures::StreamExt;
use itertools::Itertools;
use once_cell::sync::Lazy;
use pageserver_api::key::Key;
use pageserver_api::models::{DetachBehavior, LocationConfigMode};
use pageserver_api::shard::{
@@ -102,7 +103,7 @@ pub(crate) enum TenantsMap {
/// [`init_tenant_mgr`] is not done yet.
Initializing,
/// [`init_tenant_mgr`] is done, all on-disk tenants have been loaded.
/// New tenants can be added using [`TenantManager::tenant_map_acquire_slot`].
/// New tenants can be added using [`tenant_map_acquire_slot`].
Open(BTreeMap<TenantShardId, TenantSlot>),
/// The pageserver has entered shutdown mode via [`TenantManager::shutdown`].
/// Existing tenants are still accessible, but no new tenants can be created.
@@ -283,6 +284,9 @@ impl BackgroundPurges {
}
}
static TENANTS: Lazy<std::sync::RwLock<TenantsMap>> =
Lazy::new(|| std::sync::RwLock::new(TenantsMap::Initializing));
/// Responsible for storing and mutating the collection of all tenants
/// that this pageserver has state for.
///
@@ -293,7 +297,10 @@ impl BackgroundPurges {
/// and attached modes concurrently.
pub struct TenantManager {
conf: &'static PageServerConf,
tenants: std::sync::RwLock<TenantsMap>,
// TODO: currently this is a &'static pointing to TENANTs. When we finish refactoring
// out of that static variable, the TenantManager can own this.
// See https://github.com/neondatabase/neon/issues/5796
tenants: &'static std::sync::RwLock<TenantsMap>,
resources: TenantSharedResources,
// Long-running operations that happen outside of a [`Tenant`] lifetime should respect this token.
@@ -472,43 +479,21 @@ pub(crate) enum DeleteTenantError {
Other(#[from] anyhow::Error),
}
/// Initialize repositories at `Initializing` state.
pub fn init(
conf: &'static PageServerConf,
background_purges: BackgroundPurges,
resources: TenantSharedResources,
cancel: CancellationToken,
) -> TenantManager {
TenantManager {
conf,
tenants: std::sync::RwLock::new(TenantsMap::Initializing),
resources,
cancel,
background_purges,
}
}
/// Transition repositories from `Initializing` state to `Open` state with locally available timelines.
/// Initialize repositories with locally available timelines.
/// Timelines that are only partially available locally (remote storage has more data than this pageserver)
/// are scheduled for download and added to the tenant once download is completed.
#[instrument(skip_all)]
pub async fn init_tenant_mgr(
tenant_manager: Arc<TenantManager>,
conf: &'static PageServerConf,
background_purges: BackgroundPurges,
resources: TenantSharedResources,
init_order: InitializationOrder,
) -> anyhow::Result<()> {
debug_assert!(matches!(
*tenant_manager.tenants.read().unwrap(),
TenantsMap::Initializing
));
cancel: CancellationToken,
) -> anyhow::Result<TenantManager> {
let mut tenants = BTreeMap::new();
let ctx = RequestContext::todo_child(TaskKind::Startup, DownloadBehavior::Warn);
let conf = tenant_manager.conf;
let resources = &tenant_manager.resources;
let cancel = &tenant_manager.cancel;
let background_purges = &tenant_manager.background_purges;
// Initialize dynamic limits that depend on system resources
let system_memory =
sysinfo::System::new_with_specifics(sysinfo::RefreshKind::new().with_memory())
@@ -527,7 +512,7 @@ pub async fn init_tenant_mgr(
let tenant_configs = init_load_tenant_configs(conf).await;
// Determine which tenants are to be secondary or attached, and in which generation
let tenant_modes = init_load_generations(conf, &tenant_configs, resources, cancel).await?;
let tenant_modes = init_load_generations(conf, &tenant_configs, &resources, &cancel).await?;
tracing::info!(
"Attaching {} tenants at startup, warming up {} at a time",
@@ -684,10 +669,18 @@ pub async fn init_tenant_mgr(
info!("Processed {} local tenants at startup", tenants.len());
let mut tenant_map = tenant_manager.tenants.write().unwrap();
*tenant_map = TenantsMap::Open(tenants);
let mut tenants_map = TENANTS.write().unwrap();
assert!(matches!(&*tenants_map, &TenantsMap::Initializing));
Ok(())
*tenants_map = TenantsMap::Open(tenants);
Ok(TenantManager {
conf,
tenants: &TENANTS,
resources,
cancel: CancellationToken::new(),
background_purges,
})
}
/// Wrapper for Tenant::spawn that checks invariants before running
@@ -726,6 +719,142 @@ fn tenant_spawn(
)
}
async fn shutdown_all_tenants0(tenants: &std::sync::RwLock<TenantsMap>) {
let mut join_set = JoinSet::new();
#[cfg(all(debug_assertions, not(test)))]
{
// Check that our metrics properly tracked the size of the tenants map. This is a convenient location to check,
// as it happens implicitly at the end of tests etc.
let m = tenants.read().unwrap();
debug_assert_eq!(METRICS.slots_total(), m.len() as u64);
}
// Atomically, 1. create the shutdown tasks and 2. prevent creation of new tenants.
let (total_in_progress, total_attached) = {
let mut m = tenants.write().unwrap();
match &mut *m {
TenantsMap::Initializing => {
*m = TenantsMap::ShuttingDown(BTreeMap::default());
info!("tenants map is empty");
return;
}
TenantsMap::Open(tenants) => {
let mut shutdown_state = BTreeMap::new();
let mut total_in_progress = 0;
let mut total_attached = 0;
for (tenant_shard_id, v) in std::mem::take(tenants).into_iter() {
match v {
TenantSlot::Attached(t) => {
shutdown_state.insert(tenant_shard_id, TenantSlot::Attached(t.clone()));
join_set.spawn(
async move {
let res = {
let (_guard, shutdown_progress) = completion::channel();
t.shutdown(shutdown_progress, ShutdownMode::FreezeAndFlush).await
};
if let Err(other_progress) = res {
// join the another shutdown in progress
other_progress.wait().await;
}
// we cannot afford per tenant logging here, because if s3 is degraded, we are
// going to log too many lines
debug!("tenant successfully stopped");
}
.instrument(info_span!("shutdown", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug())),
);
total_attached += 1;
}
TenantSlot::Secondary(state) => {
// We don't need to wait for this individually per-tenant: the
// downloader task will be waited on eventually, this cancel
// is just to encourage it to drop out if it is doing work
// for this tenant right now.
state.cancel.cancel();
shutdown_state.insert(tenant_shard_id, TenantSlot::Secondary(state));
}
TenantSlot::InProgress(notify) => {
// InProgress tenants are not visible in TenantsMap::ShuttingDown: we will
// wait for their notifications to fire in this function.
join_set.spawn(async move {
notify.wait().await;
});
total_in_progress += 1;
}
}
}
*m = TenantsMap::ShuttingDown(shutdown_state);
(total_in_progress, total_attached)
}
TenantsMap::ShuttingDown(_) => {
error!(
"already shutting down, this function isn't supposed to be called more than once"
);
return;
}
}
};
let started_at = std::time::Instant::now();
info!(
"Waiting for {} InProgress tenants and {} Attached tenants to shut down",
total_in_progress, total_attached
);
let total = join_set.len();
let mut panicked = 0;
let mut buffering = true;
const BUFFER_FOR: std::time::Duration = std::time::Duration::from_millis(500);
let mut buffered = std::pin::pin!(tokio::time::sleep(BUFFER_FOR));
while !join_set.is_empty() {
tokio::select! {
Some(joined) = join_set.join_next() => {
match joined {
Ok(()) => {},
Err(join_error) if join_error.is_cancelled() => {
unreachable!("we are not cancelling any of the tasks");
}
Err(join_error) if join_error.is_panic() => {
// cannot really do anything, as this panic is likely a bug
panicked += 1;
}
Err(join_error) => {
warn!("unknown kind of JoinError: {join_error}");
}
}
if !buffering {
// buffer so that every 500ms since the first update (or starting) we'll log
// how far away we are; this is because we will get SIGKILL'd at 10s, and we
// are not able to log *then*.
buffering = true;
buffered.as_mut().reset(tokio::time::Instant::now() + BUFFER_FOR);
}
},
_ = &mut buffered, if buffering => {
buffering = false;
info!(remaining = join_set.len(), total, elapsed_ms = started_at.elapsed().as_millis(), "waiting for tenants to shutdown");
}
}
}
if panicked > 0 {
warn!(
panicked,
total, "observed panicks while shutting down tenants"
);
}
// caller will log how long we took
}
#[derive(thiserror::Error, Debug)]
pub(crate) enum UpsertLocationError {
#[error("Bad config request: {0}")]
@@ -927,8 +1056,7 @@ impl TenantManager {
// the tenant is inaccessible to the outside world while we are doing this, but that is sensible:
// the state is ill-defined while we're in transition. Transitions are async, but fast: we do
// not do significant I/O, and shutdowns should be prompt via cancellation tokens.
let mut slot_guard = self
.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)
let mut slot_guard = tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)
.map_err(|e| match e {
TenantSlotError::NotFound(_) => {
unreachable!("Called with mode Any")
@@ -1095,75 +1223,6 @@ impl TenantManager {
}
}
fn tenant_map_acquire_slot(
&self,
tenant_shard_id: &TenantShardId,
mode: TenantSlotAcquireMode,
) -> Result<SlotGuard, TenantSlotError> {
use TenantSlotAcquireMode::*;
METRICS.tenant_slot_writes.inc();
let mut locked = self.tenants.write().unwrap();
let span = tracing::info_span!("acquire_slot", tenant_id=%tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug());
let _guard = span.enter();
let m = match &mut *locked {
TenantsMap::Initializing => return Err(TenantMapError::StillInitializing.into()),
TenantsMap::ShuttingDown(_) => return Err(TenantMapError::ShuttingDown.into()),
TenantsMap::Open(m) => m,
};
use std::collections::btree_map::Entry;
let entry = m.entry(*tenant_shard_id);
match entry {
Entry::Vacant(v) => match mode {
MustExist => {
tracing::debug!("Vacant && MustExist: return NotFound");
Err(TenantSlotError::NotFound(*tenant_shard_id))
}
_ => {
let (completion, barrier) = utils::completion::channel();
let inserting = TenantSlot::InProgress(barrier);
METRICS.slot_inserted(&inserting);
v.insert(inserting);
tracing::debug!("Vacant, inserted InProgress");
Ok(SlotGuard::new(
*tenant_shard_id,
None,
completion,
&self.tenants,
))
}
},
Entry::Occupied(mut o) => {
// Apply mode-driven checks
match (o.get(), mode) {
(TenantSlot::InProgress(_), _) => {
tracing::debug!("Occupied, failing for InProgress");
Err(TenantSlotError::InProgress)
}
_ => {
// Happy case: the slot was not in any state that violated our mode
let (completion, barrier) = utils::completion::channel();
let in_progress = TenantSlot::InProgress(barrier);
METRICS.slot_inserted(&in_progress);
let old_value = o.insert(in_progress);
METRICS.slot_removed(&old_value);
tracing::debug!("Occupied, replaced with InProgress");
Ok(SlotGuard::new(
*tenant_shard_id,
Some(old_value),
completion,
&self.tenants,
))
}
}
}
}
}
/// Resetting a tenant is equivalent to detaching it, then attaching it again with the same
/// LocationConf that was last used to attach it. Optionally, the local file cache may be
/// dropped before re-attaching.
@@ -1180,8 +1239,7 @@ impl TenantManager {
drop_cache: bool,
ctx: &RequestContext,
) -> anyhow::Result<()> {
let mut slot_guard =
self.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
let mut slot_guard = tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
let Some(old_slot) = slot_guard.get_old_value() else {
anyhow::bail!("Tenant not found when trying to reset");
};
@@ -1330,8 +1388,7 @@ impl TenantManager {
Ok(())
}
let slot_guard =
self.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
let slot_guard = tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
match &slot_guard.old_value {
Some(TenantSlot::Attached(tenant)) => {
// Legacy deletion flow: the tenant remains attached, goes to Stopping state, and
@@ -1482,7 +1539,7 @@ impl TenantManager {
// Phase 2: Put the parent shard to InProgress and grab a reference to the parent Tenant
drop(tenant);
let mut parent_slot_guard =
self.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
let parent = match parent_slot_guard.get_old_value() {
Some(TenantSlot::Attached(t)) => t,
Some(TenantSlot::Secondary(_)) => anyhow::bail!("Tenant location in secondary mode"),
@@ -1786,145 +1843,7 @@ impl TenantManager {
pub(crate) async fn shutdown(&self) {
self.cancel.cancel();
self.shutdown_all_tenants0().await
}
async fn shutdown_all_tenants0(&self) {
let mut join_set = JoinSet::new();
#[cfg(all(debug_assertions, not(test)))]
{
// Check that our metrics properly tracked the size of the tenants map. This is a convenient location to check,
// as it happens implicitly at the end of tests etc.
let m = self.tenants.read().unwrap();
debug_assert_eq!(METRICS.slots_total(), m.len() as u64);
}
// Atomically, 1. create the shutdown tasks and 2. prevent creation of new tenants.
let (total_in_progress, total_attached) = {
let mut m = self.tenants.write().unwrap();
match &mut *m {
TenantsMap::Initializing => {
*m = TenantsMap::ShuttingDown(BTreeMap::default());
info!("tenants map is empty");
return;
}
TenantsMap::Open(tenants) => {
let mut shutdown_state = BTreeMap::new();
let mut total_in_progress = 0;
let mut total_attached = 0;
for (tenant_shard_id, v) in std::mem::take(tenants).into_iter() {
match v {
TenantSlot::Attached(t) => {
shutdown_state
.insert(tenant_shard_id, TenantSlot::Attached(t.clone()));
join_set.spawn(
async move {
let res = {
let (_guard, shutdown_progress) = completion::channel();
t.shutdown(shutdown_progress, ShutdownMode::FreezeAndFlush).await
};
if let Err(other_progress) = res {
// join the another shutdown in progress
other_progress.wait().await;
}
// we cannot afford per tenant logging here, because if s3 is degraded, we are
// going to log too many lines
debug!("tenant successfully stopped");
}
.instrument(info_span!("shutdown", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug())),
);
total_attached += 1;
}
TenantSlot::Secondary(state) => {
// We don't need to wait for this individually per-tenant: the
// downloader task will be waited on eventually, this cancel
// is just to encourage it to drop out if it is doing work
// for this tenant right now.
state.cancel.cancel();
shutdown_state
.insert(tenant_shard_id, TenantSlot::Secondary(state));
}
TenantSlot::InProgress(notify) => {
// InProgress tenants are not visible in TenantsMap::ShuttingDown: we will
// wait for their notifications to fire in this function.
join_set.spawn(async move {
notify.wait().await;
});
total_in_progress += 1;
}
}
}
*m = TenantsMap::ShuttingDown(shutdown_state);
(total_in_progress, total_attached)
}
TenantsMap::ShuttingDown(_) => {
error!(
"already shutting down, this function isn't supposed to be called more than once"
);
return;
}
}
};
let started_at = std::time::Instant::now();
info!(
"Waiting for {} InProgress tenants and {} Attached tenants to shut down",
total_in_progress, total_attached
);
let total = join_set.len();
let mut panicked = 0;
let mut buffering = true;
const BUFFER_FOR: std::time::Duration = std::time::Duration::from_millis(500);
let mut buffered = std::pin::pin!(tokio::time::sleep(BUFFER_FOR));
while !join_set.is_empty() {
tokio::select! {
Some(joined) = join_set.join_next() => {
match joined {
Ok(()) => {},
Err(join_error) if join_error.is_cancelled() => {
unreachable!("we are not cancelling any of the tasks");
}
Err(join_error) if join_error.is_panic() => {
// cannot really do anything, as this panic is likely a bug
panicked += 1;
}
Err(join_error) => {
warn!("unknown kind of JoinError: {join_error}");
}
}
if !buffering {
// buffer so that every 500ms since the first update (or starting) we'll log
// how far away we are; this is because we will get SIGKILL'd at 10s, and we
// are not able to log *then*.
buffering = true;
buffered.as_mut().reset(tokio::time::Instant::now() + BUFFER_FOR);
}
},
_ = &mut buffered, if buffering => {
buffering = false;
info!(remaining = join_set.len(), total, elapsed_ms = started_at.elapsed().as_millis(), "waiting for tenants to shutdown");
}
}
}
if panicked > 0 {
warn!(
panicked,
total, "observed panicks while shutting down tenants"
);
}
// caller will log how long we took
shutdown_all_tenants0(self.tenants).await
}
/// Detaches a tenant, and removes its local files asynchronously.
@@ -1970,12 +1889,12 @@ impl TenantManager {
.map(Some)
};
let mut removal_result = self
.remove_tenant_from_memory(
tenant_shard_id,
tenant_dir_rename_operation(tenant_shard_id),
)
.await;
let mut removal_result = remove_tenant_from_memory(
self.tenants,
tenant_shard_id,
tenant_dir_rename_operation(tenant_shard_id),
)
.await;
// If the tenant was not found, it was likely already removed. Attempt to remove the tenant
// directory on disk anyway. For example, during shard splits, we shut down and remove the
@@ -2029,16 +1948,17 @@ impl TenantManager {
) -> Result<HashSet<TimelineId>, detach_ancestor::Error> {
use detach_ancestor::Error;
let slot_guard = self
.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::MustExist)
.map_err(|e| {
use TenantSlotError::*;
let slot_guard =
tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::MustExist).map_err(
|e| {
use TenantSlotError::*;
match e {
MapState(TenantMapError::ShuttingDown) => Error::ShuttingDown,
NotFound(_) | InProgress | MapState(_) => Error::DetachReparent(e.into()),
}
})?;
match e {
MapState(TenantMapError::ShuttingDown) => Error::ShuttingDown,
NotFound(_) | InProgress | MapState(_) => Error::DetachReparent(e.into()),
}
},
)?;
let tenant = {
let old_slot = slot_guard
@@ -2371,80 +2291,6 @@ impl TenantManager {
other => ApiError::InternalServerError(anyhow::anyhow!(other)),
})
}
/// Stops and removes the tenant from memory, if it's not [`TenantState::Stopping`] already, bails otherwise.
/// Allows to remove other tenant resources manually, via `tenant_cleanup`.
/// If the cleanup fails, tenant will stay in memory in [`TenantState::Broken`] state, and another removal
async fn remove_tenant_from_memory<V, F>(
&self,
tenant_shard_id: TenantShardId,
tenant_cleanup: F,
) -> Result<V, TenantStateError>
where
F: std::future::Future<Output = anyhow::Result<V>>,
{
let mut slot_guard =
self.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::MustExist)?;
// allow pageserver shutdown to await for our completion
let (_guard, progress) = completion::channel();
// The SlotGuard allows us to manipulate the Tenant object without fear of some
// concurrent API request doing something else for the same tenant ID.
let attached_tenant = match slot_guard.get_old_value() {
Some(TenantSlot::Attached(tenant)) => {
// whenever we remove a tenant from memory, we don't want to flush and wait for upload
let shutdown_mode = ShutdownMode::Hard;
// shutdown is sure to transition tenant to stopping, and wait for all tasks to complete, so
// that we can continue safely to cleanup.
match tenant.shutdown(progress, shutdown_mode).await {
Ok(()) => {}
Err(_other) => {
// if pageserver shutdown or other detach/ignore is already ongoing, we don't want to
// wait for it but return an error right away because these are distinct requests.
slot_guard.revert();
return Err(TenantStateError::IsStopping(tenant_shard_id));
}
}
Some(tenant)
}
Some(TenantSlot::Secondary(secondary_state)) => {
tracing::info!("Shutting down in secondary mode");
secondary_state.shutdown().await;
None
}
Some(TenantSlot::InProgress(_)) => {
// Acquiring a slot guarantees its old value was not InProgress
unreachable!();
}
None => None,
};
match tenant_cleanup
.await
.with_context(|| format!("Failed to run cleanup for tenant {tenant_shard_id}"))
{
Ok(hook_value) => {
// Success: drop the old TenantSlot::Attached.
slot_guard
.drop_old_value()
.expect("We just called shutdown");
Ok(hook_value)
}
Err(e) => {
// If we had a Tenant, set it to Broken and put it back in the TenantsMap
if let Some(attached_tenant) = attached_tenant {
attached_tenant.set_broken(e.to_string()).await;
}
// Leave the broken tenant in the map
slot_guard.revert();
Err(TenantStateError::Other(e))
}
}
}
}
#[derive(Debug, thiserror::Error)]
@@ -2609,7 +2455,7 @@ pub(crate) enum TenantMapError {
/// this tenant to retry later, or wait for the InProgress state to end.
///
/// This structure enforces the important invariant that we do not have overlapping
/// tasks that will try to use local storage for a the same tenant ID: we enforce that
/// tasks that will try use local storage for a the same tenant ID: we enforce that
/// the previous contents of a slot have been shut down before the slot can be
/// left empty or used for something else
///
@@ -2622,7 +2468,7 @@ pub(crate) enum TenantMapError {
/// The `old_value` may be dropped before the SlotGuard is dropped, by calling
/// `drop_old_value`. It is an error to call this without shutting down
/// the conents of `old_value`.
pub(crate) struct SlotGuard<'a> {
pub(crate) struct SlotGuard {
tenant_shard_id: TenantShardId,
old_value: Option<TenantSlot>,
upserted: bool,
@@ -2630,23 +2476,19 @@ pub(crate) struct SlotGuard<'a> {
/// [`TenantSlot::InProgress`] carries the corresponding Barrier: it will
/// release any waiters as soon as this SlotGuard is dropped.
completion: utils::completion::Completion,
tenants: &'a std::sync::RwLock<TenantsMap>,
}
impl<'a> SlotGuard<'a> {
impl SlotGuard {
fn new(
tenant_shard_id: TenantShardId,
old_value: Option<TenantSlot>,
completion: utils::completion::Completion,
tenants: &'a std::sync::RwLock<TenantsMap>,
) -> Self {
Self {
tenant_shard_id,
old_value,
upserted: false,
completion,
tenants,
}
}
@@ -2670,8 +2512,8 @@ impl<'a> SlotGuard<'a> {
));
}
let replaced: Option<TenantSlot> = {
let mut locked = self.tenants.write().unwrap();
let replaced = {
let mut locked = TENANTS.write().unwrap();
if let TenantSlot::InProgress(_) = new_value {
// It is never expected to try and upsert InProgress via this path: it should
@@ -2779,7 +2621,7 @@ impl<'a> SlotGuard<'a> {
}
}
impl<'a> Drop for SlotGuard<'a> {
impl Drop for SlotGuard {
fn drop(&mut self) {
if self.upserted {
return;
@@ -2787,7 +2629,7 @@ impl<'a> Drop for SlotGuard<'a> {
// Our old value is already shutdown, or it never existed: it is safe
// for us to fully release the TenantSlot back into an empty state
let mut locked = self.tenants.write().unwrap();
let mut locked = TENANTS.write().unwrap();
let m = match &mut *locked {
TenantsMap::Initializing => {
@@ -2869,6 +2711,151 @@ enum TenantSlotAcquireMode {
MustExist,
}
fn tenant_map_acquire_slot(
tenant_shard_id: &TenantShardId,
mode: TenantSlotAcquireMode,
) -> Result<SlotGuard, TenantSlotError> {
tenant_map_acquire_slot_impl(tenant_shard_id, &TENANTS, mode)
}
fn tenant_map_acquire_slot_impl(
tenant_shard_id: &TenantShardId,
tenants: &std::sync::RwLock<TenantsMap>,
mode: TenantSlotAcquireMode,
) -> Result<SlotGuard, TenantSlotError> {
use TenantSlotAcquireMode::*;
METRICS.tenant_slot_writes.inc();
let mut locked = tenants.write().unwrap();
let span = tracing::info_span!("acquire_slot", tenant_id=%tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug());
let _guard = span.enter();
let m = match &mut *locked {
TenantsMap::Initializing => return Err(TenantMapError::StillInitializing.into()),
TenantsMap::ShuttingDown(_) => return Err(TenantMapError::ShuttingDown.into()),
TenantsMap::Open(m) => m,
};
use std::collections::btree_map::Entry;
let entry = m.entry(*tenant_shard_id);
match entry {
Entry::Vacant(v) => match mode {
MustExist => {
tracing::debug!("Vacant && MustExist: return NotFound");
Err(TenantSlotError::NotFound(*tenant_shard_id))
}
_ => {
let (completion, barrier) = utils::completion::channel();
let inserting = TenantSlot::InProgress(barrier);
METRICS.slot_inserted(&inserting);
v.insert(inserting);
tracing::debug!("Vacant, inserted InProgress");
Ok(SlotGuard::new(*tenant_shard_id, None, completion))
}
},
Entry::Occupied(mut o) => {
// Apply mode-driven checks
match (o.get(), mode) {
(TenantSlot::InProgress(_), _) => {
tracing::debug!("Occupied, failing for InProgress");
Err(TenantSlotError::InProgress)
}
_ => {
// Happy case: the slot was not in any state that violated our mode
let (completion, barrier) = utils::completion::channel();
let in_progress = TenantSlot::InProgress(barrier);
METRICS.slot_inserted(&in_progress);
let old_value = o.insert(in_progress);
METRICS.slot_removed(&old_value);
tracing::debug!("Occupied, replaced with InProgress");
Ok(SlotGuard::new(
*tenant_shard_id,
Some(old_value),
completion,
))
}
}
}
}
}
/// Stops and removes the tenant from memory, if it's not [`TenantState::Stopping`] already, bails otherwise.
/// Allows to remove other tenant resources manually, via `tenant_cleanup`.
/// If the cleanup fails, tenant will stay in memory in [`TenantState::Broken`] state, and another removal
/// operation would be needed to remove it.
async fn remove_tenant_from_memory<V, F>(
tenants: &std::sync::RwLock<TenantsMap>,
tenant_shard_id: TenantShardId,
tenant_cleanup: F,
) -> Result<V, TenantStateError>
where
F: std::future::Future<Output = anyhow::Result<V>>,
{
let mut slot_guard =
tenant_map_acquire_slot_impl(&tenant_shard_id, tenants, TenantSlotAcquireMode::MustExist)?;
// allow pageserver shutdown to await for our completion
let (_guard, progress) = completion::channel();
// The SlotGuard allows us to manipulate the Tenant object without fear of some
// concurrent API request doing something else for the same tenant ID.
let attached_tenant = match slot_guard.get_old_value() {
Some(TenantSlot::Attached(tenant)) => {
// whenever we remove a tenant from memory, we don't want to flush and wait for upload
let shutdown_mode = ShutdownMode::Hard;
// shutdown is sure to transition tenant to stopping, and wait for all tasks to complete, so
// that we can continue safely to cleanup.
match tenant.shutdown(progress, shutdown_mode).await {
Ok(()) => {}
Err(_other) => {
// if pageserver shutdown or other detach/ignore is already ongoing, we don't want to
// wait for it but return an error right away because these are distinct requests.
slot_guard.revert();
return Err(TenantStateError::IsStopping(tenant_shard_id));
}
}
Some(tenant)
}
Some(TenantSlot::Secondary(secondary_state)) => {
tracing::info!("Shutting down in secondary mode");
secondary_state.shutdown().await;
None
}
Some(TenantSlot::InProgress(_)) => {
// Acquiring a slot guarantees its old value was not InProgress
unreachable!();
}
None => None,
};
match tenant_cleanup
.await
.with_context(|| format!("Failed to run cleanup for tenant {tenant_shard_id}"))
{
Ok(hook_value) => {
// Success: drop the old TenantSlot::Attached.
slot_guard
.drop_old_value()
.expect("We just called shutdown");
Ok(hook_value)
}
Err(e) => {
// If we had a Tenant, set it to Broken and put it back in the TenantsMap
if let Some(attached_tenant) = attached_tenant {
attached_tenant.set_broken(e.to_string()).await;
}
// Leave the broken tenant in the map
slot_guard.revert();
Err(TenantStateError::Other(e))
}
}
}
use http_utils::error::ApiError;
use pageserver_api::models::TimelineGcRequest;
@@ -2879,15 +2866,11 @@ mod tests {
use std::collections::BTreeMap;
use std::sync::Arc;
use storage_broker::BrokerClientChannel;
use tracing::Instrument;
use super::super::harness::TenantHarness;
use super::TenantsMap;
use crate::tenant::{
TenantSharedResources,
mgr::{BackgroundPurges, TenantManager, TenantSlot},
};
use crate::tenant::mgr::TenantSlot;
#[tokio::test(start_paused = true)]
async fn shutdown_awaits_in_progress_tenant() {
@@ -2908,47 +2891,23 @@ mod tests {
let _e = span.enter();
let tenants = BTreeMap::from([(id, TenantSlot::Attached(t.clone()))]);
let tenants = Arc::new(std::sync::RwLock::new(TenantsMap::Open(tenants)));
// Invoke remove_tenant_from_memory with a cleanup hook that blocks until we manually
// permit it to proceed: that will stick the tenant in InProgress
let (basebackup_prepare_sender, _) = tokio::sync::mpsc::unbounded_channel::<
crate::basebackup_cache::BasebackupPrepareRequest,
>();
let tenant_manager = TenantManager {
tenants: std::sync::RwLock::new(TenantsMap::Open(tenants)),
conf: h.conf,
resources: TenantSharedResources {
broker_client: BrokerClientChannel::connect_lazy("foobar.com")
.await
.unwrap(),
remote_storage: h.remote_storage.clone(),
deletion_queue_client: h.deletion_queue.new_client(),
l0_flush_global_state: crate::l0_flush::L0FlushGlobalState::new(
h.conf.l0_flush.clone(),
),
basebackup_prepare_sender,
feature_resolver: crate::feature_resolver::FeatureResolver::new_disabled(),
},
cancel: tokio_util::sync::CancellationToken::new(),
background_purges: BackgroundPurges::default(),
};
let tenant_manager = Arc::new(tenant_manager);
let (until_cleanup_completed, can_complete_cleanup) = utils::completion::channel();
let (until_cleanup_started, cleanup_started) = utils::completion::channel();
let mut remove_tenant_from_memory_task = {
let tenant_manager = tenant_manager.clone();
let jh = tokio::spawn({
let tenants = tenants.clone();
async move {
let cleanup = async move {
drop(until_cleanup_started);
can_complete_cleanup.wait().await;
anyhow::Ok(())
};
tenant_manager.remove_tenant_from_memory(id, cleanup).await
super::remove_tenant_from_memory(&tenants, id, cleanup).await
}
.instrument(h.span())
});
@@ -2961,11 +2920,9 @@ mod tests {
let mut shutdown_task = {
let (until_shutdown_started, shutdown_started) = utils::completion::channel();
let tenant_manager = tenant_manager.clone();
let shutdown_task = tokio::spawn(async move {
drop(until_shutdown_started);
tenant_manager.shutdown_all_tenants0().await;
super::shutdown_all_tenants0(&tenants).await;
});
shutdown_started.wait().await;

View File

@@ -1092,15 +1092,13 @@ communicator_prefetch_register_bufferv(BufferTag tag, neon_request_lsns *frlsns,
MyPState->ring_last <= ring_index);
}
/* Internal version. Returns the ring index of the last block (result of this function is used only
* when nblocks==1)
*/
/* internal version. Returns the ring index */
static uint64
prefetch_register_bufferv(BufferTag tag, neon_request_lsns *frlsns,
BlockNumber nblocks, const bits8 *mask,
bool is_prefetch)
{
uint64 last_ring_index;
uint64 min_ring_index;
PrefetchRequest hashkey;
#ifdef USE_ASSERT_CHECKING
bool any_hits = false;
@@ -1124,12 +1122,13 @@ Retry:
MyPState->ring_unused - MyPState->ring_receive;
MyNeonCounters->getpage_prefetches_buffered =
MyPState->n_responses_buffered;
last_ring_index = UINT64_MAX;
min_ring_index = UINT64_MAX;
for (int i = 0; i < nblocks; i++)
{
PrefetchRequest *slot = NULL;
PrfHashEntry *entry = NULL;
uint64 ring_index;
neon_request_lsns *lsns;
if (PointerIsValid(mask) && BITMAP_ISSET(mask, i))
@@ -1153,12 +1152,12 @@ Retry:
if (entry != NULL)
{
slot = entry->slot;
last_ring_index = slot->my_ring_index;
Assert(slot == GetPrfSlot(last_ring_index));
ring_index = slot->my_ring_index;
Assert(slot == GetPrfSlot(ring_index));
Assert(slot->status != PRFS_UNUSED);
Assert(MyPState->ring_last <= last_ring_index &&
last_ring_index < MyPState->ring_unused);
Assert(MyPState->ring_last <= ring_index &&
ring_index < MyPState->ring_unused);
Assert(BufferTagsEqual(&slot->buftag, &hashkey.buftag));
/*
@@ -1170,9 +1169,9 @@ Retry:
if (!neon_prefetch_response_usable(lsns, slot))
{
/* Wait for the old request to finish and discard it */
if (!prefetch_wait_for(last_ring_index))
if (!prefetch_wait_for(ring_index))
goto Retry;
prefetch_set_unused(last_ring_index);
prefetch_set_unused(ring_index);
entry = NULL;
slot = NULL;
pgBufferUsage.prefetch.expired += 1;
@@ -1189,12 +1188,13 @@ Retry:
*/
if (slot->status == PRFS_TAG_REMAINS)
{
prefetch_set_unused(last_ring_index);
prefetch_set_unused(ring_index);
entry = NULL;
slot = NULL;
}
else
{
min_ring_index = Min(min_ring_index, ring_index);
/* The buffered request is good enough, return that index */
if (is_prefetch)
pgBufferUsage.prefetch.duplicates++;
@@ -1283,12 +1283,12 @@ Retry:
* The next buffer pointed to by `ring_unused` is now definitely empty, so
* we can insert the new request to it.
*/
last_ring_index = MyPState->ring_unused;
ring_index = MyPState->ring_unused;
Assert(MyPState->ring_last <= last_ring_index &&
last_ring_index <= MyPState->ring_unused);
Assert(MyPState->ring_last <= ring_index &&
ring_index <= MyPState->ring_unused);
slot = GetPrfSlotNoCheck(last_ring_index);
slot = GetPrfSlotNoCheck(ring_index);
Assert(slot->status == PRFS_UNUSED);
@@ -1298,9 +1298,11 @@ Retry:
*/
slot->buftag = hashkey.buftag;
slot->shard_no = get_shard_number(&tag);
slot->my_ring_index = last_ring_index;
slot->my_ring_index = ring_index;
slot->flags = 0;
min_ring_index = Min(min_ring_index, ring_index);
if (is_prefetch)
MyNeonCounters->getpage_prefetch_requests_total++;
else
@@ -1313,12 +1315,11 @@ Retry:
MyPState->ring_unused - MyPState->ring_receive;
Assert(any_hits);
Assert(last_ring_index != UINT64_MAX);
Assert(GetPrfSlot(last_ring_index)->status == PRFS_REQUESTED ||
GetPrfSlot(last_ring_index)->status == PRFS_RECEIVED);
Assert(MyPState->ring_last <= last_ring_index &&
last_ring_index < MyPState->ring_unused);
Assert(GetPrfSlot(min_ring_index)->status == PRFS_REQUESTED ||
GetPrfSlot(min_ring_index)->status == PRFS_RECEIVED);
Assert(MyPState->ring_last <= min_ring_index &&
min_ring_index < MyPState->ring_unused);
if (flush_every_n_requests > 0 &&
MyPState->ring_unused - MyPState->ring_flush >= flush_every_n_requests)
@@ -1334,7 +1335,7 @@ Retry:
MyPState->ring_flush = MyPState->ring_unused;
}
return last_ring_index;
return min_ring_index;
}
static bool

View File

@@ -1135,7 +1135,7 @@ VotesCollectedMset(WalProposer *wp, MemberSet *mset, Safekeeper **msk, StringInf
wp->propTermStartLsn = sk->voteResponse.flushLsn;
wp->donor = sk;
}
wp->truncateLsn = Max(sk->voteResponse.truncateLsn, wp->truncateLsn);
wp->truncateLsn = Max(wp->safekeeper[i].voteResponse.truncateLsn, wp->truncateLsn);
if (n_votes > 0)
appendStringInfoString(s, ", ");

View File

@@ -89,6 +89,7 @@ tokio-postgres = { workspace = true, optional = true }
tokio-rustls.workspace = true
tokio-util.workspace = true
tokio = { workspace = true, features = ["signal"] }
toml.workspace = true
tracing-subscriber.workspace = true
tracing-utils.workspace = true
tracing.workspace = true

View File

@@ -14,9 +14,9 @@ use crate::context::RequestContext;
use crate::control_plane::client::cplane_proxy_v1;
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::error::{ReportableError, UserFacingError};
use crate::pglb::connect_compute::ComputeConnectBackend;
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::proxy::wake_compute::WakeComputeBackend;
use crate::stream::PqStream;
use crate::types::RoleName;
use crate::{auth, compute, waiters};
@@ -109,7 +109,7 @@ impl ConsoleRedirectBackend {
pub struct ConsoleRedirectNodeInfo(pub(super) NodeInfo);
#[async_trait]
impl WakeComputeBackend for ConsoleRedirectNodeInfo {
impl ComputeConnectBackend for ConsoleRedirectNodeInfo {
async fn wake_compute(
&self,
_ctx: &RequestContext,

View File

@@ -25,9 +25,9 @@ use crate::control_plane::{
RoleAccessControl,
};
use crate::intern::EndpointIdInt;
use crate::pglb::connect_compute::ComputeConnectBackend;
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::proxy::wake_compute::WakeComputeBackend;
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::Stream;
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
@@ -407,13 +407,13 @@ impl Backend<'_, ComputeUserInfo> {
}
#[async_trait::async_trait]
impl WakeComputeBackend for Backend<'_, ComputeUserInfo> {
impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
async fn wake_compute(
&self,
ctx: &RequestContext,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
match self {
Self::ControlPlane(api, info) => api.wake_compute(ctx, info).await,
Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await,
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
}
}

View File

@@ -279,7 +279,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
},
proxy_protocol_v2: config::ProxyProtocolV2::Rejected,
handshake_timeout: Duration::from_secs(10),
region: "local".into(),
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
connect_compute_locks,
connect_to_compute: compute_config,

View File

@@ -28,9 +28,10 @@ use crate::context::RequestContext;
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::pqproto::FeStartupPacket;
use crate::protocol2::ConnectionInfo;
use crate::proxy::{ErrorSource, TlsRequired, copy_bidirectional_client_compute};
use crate::proxy::{
ErrorSource, TlsRequired, copy_bidirectional_client_compute, run_until_cancelled,
};
use crate::stream::{PqStream, Stream};
use crate::util::run_until_cancelled;
project_git_version!(GIT_VERSION);
@@ -236,7 +237,6 @@ pub(super) async fn task_main(
extra: None,
},
crate::metrics::Protocol::SniRouter,
"sni",
);
handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
}

View File

@@ -8,14 +8,15 @@ use std::time::Duration;
#[cfg(any(test, feature = "testing"))]
use anyhow::Context;
use anyhow::{bail, ensure};
use anyhow::{bail, anyhow};
use arc_swap::ArcSwapOption;
use futures::future::Either;
use remote_storage::RemoteStorageConfig;
use serde::Deserialize;
use tokio::net::TcpListener;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, info, warn};
use tracing::{Instrument, info};
use utils::sentry_init::init_sentry;
use utils::{project_build_tag, project_git_version};
@@ -39,7 +40,7 @@ use crate::serverless::cancel_set::CancelSet;
use crate::tls::client_config::compute_client_config_with_root_certs;
#[cfg(any(test, feature = "testing"))]
use crate::url::ApiUrl;
use crate::{auth, control_plane, http, serverless, usage_metrics};
use crate::{auth, control_plane, http, pglb, serverless, usage_metrics};
project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);
@@ -59,6 +60,262 @@ enum AuthBackendType {
Postgres,
}
#[derive(Deserialize)]
struct Root {
#[serde(flatten)]
legacy: LegacyModes,
introspection: Introspection,
}
#[derive(Deserialize)]
#[serde(untagged)]
enum LegacyModes {
Proxy {
pglb: Pglb,
neonkeeper: NeonKeeper,
http: Option<Http>,
pg_sni_router: Option<PgSniRouter>,
},
AuthBroker {
neonkeeper: NeonKeeper,
http: Http,
},
ConsoleRedirect {
console_redirect: ConsoleRedirect,
},
}
#[derive(Deserialize)]
struct Pglb {
listener: Listener,
}
#[derive(Deserialize)]
struct Listener {
/// address to bind to
addr: SocketAddr,
/// which header should we expect to see on this socket
/// from the load balancer
header: Option<ProxyHeader>,
/// certificates used for TLS.
/// first cert is the default.
/// TLS not used if no certs provided.
certs: Vec<KeyPair>,
/// Timeout to use for TLS handshake
timeout: Option<Duration>,
}
#[derive(Deserialize)]
enum ProxyHeader {
/// Accept the PROXY! protocol V2.
ProxyProtocolV2(ProxyProtocolV2Kind),
}
#[derive(Deserialize)]
enum ProxyProtocolV2Kind {
/// Expect AWS TLVs in the header.
Aws,
/// Expect Azure TLVs in the header.
Azure,
}
#[derive(Deserialize)]
struct KeyPair {
key: PathBuf,
cert: PathBuf,
}
#[derive(Deserialize)]
/// The service that authenticates all incoming connection attempts,
/// provides monitoring and also wakes computes.
struct NeonKeeper {
cplane: ControlPlaneBackend,
redis: Option<Redis>,
auth: Vec<AuthMechanism>,
/// map of endpoint->computeinfo
compute: Cache,
/// cache for GetEndpointAccessControls.
project_info_cache: config::ProjectInfoCacheOptions,
/// cache for all valid endpoints
endpoint_cache_config: config::EndpointCacheConfig,
request_log_export: Option<RequestLogExport>,
data_transfer_export: Option<DataTransferExport>,
}
#[derive(Deserialize)]
struct Redis {
/// Cancellation channel size (max queue size for redis kv client)
cancellation_ch_size: usize,
/// Cancellation ops batch size for redis
cancellation_batch_size: usize,
auth: RedisAuthentication,
}
#[derive(Deserialize)]
enum RedisAuthentication {
/// i don't remember what this stands for.
/// IAM roles for service accounts?
Irsa {
host: String,
port: u16,
cluster_name: Option<String>,
user_id: Option<String>,
aws_region: String,
},
Basic {
url: url::Url,
},
}
#[derive(Deserialize)]
struct PgSniRouter {
/// The listener to use to proxy connections to compute,
/// assuming the compute does not support TLS.
listener: Listener,
/// The listener to use to proxy connections to compute,
/// assuming the compute requires TLS.
listener_tls: Listener,
/// append this domain zone to the SNI hostname to get the destination address
dest: String,
}
#[derive(Deserialize)]
/// `psql -h pg.neon.tech`.
struct ConsoleRedirect {
/// Connection requests from clients.
listener: Listener,
/// Messages from control plane to accept the connection.
cplane: Listener,
/// The base url to use for redirects.
console: url::Url,
timeout: Duration,
}
#[derive(Deserialize)]
enum ControlPlaneBackend {
/// Use the HTTP API to access the control plane.
Http(url::Url),
/// Stub the control plane with a postgres instance.
#[cfg(feature = "testing")]
PostgresMock(url::Url),
}
#[derive(Deserialize)]
struct Http {
listener: Listener,
sql_over_http: SqlOverHttp,
// todo: move into Pglb.
websockets: Option<Websockets>,
}
#[derive(Deserialize)]
struct SqlOverHttp {
pool_max_conns_per_endpoint: usize,
pool_max_total_conns: usize,
pool_idle_timeout: Duration,
pool_gc_epoch: Duration,
pool_shards: usize,
client_conn_threshold: u64,
cancel_set_shards: usize,
timeout: Duration,
max_request_size_bytes: usize,
max_response_size_bytes: usize,
auth: Vec<AuthMechanism>,
}
#[derive(Deserialize)]
enum AuthMechanism {
Sasl {
/// timeout for SASL handshake
timeout: Duration,
},
CleartextPassword {
/// number of threads for the thread pool
threads: usize,
},
// add something about the jwks cache i guess.
Jwt {},
}
#[derive(Deserialize)]
struct Websockets {
auth: Vec<AuthMechanism>,
}
#[derive(Deserialize)]
/// The HTTP API used for internal monitoring.
struct Introspection {
listener: Listener,
}
#[derive(Deserialize)]
enum RequestLogExport {
Parquet {
location: RemoteStorageConfig,
disconnect: RemoteStorageConfig,
/// The region identifier to tag the entries with.
region: String,
/// How many rows to include in a row group
row_group_size: usize,
/// How large each column page should be in bytes
page_size: usize,
/// How large the total parquet file should be in bytes
size: i64,
/// How long to wait before forcing a file upload
maximum_duration: tokio::time::Duration,
// /// What level of compression to use
// compression: Compression,
},
}
#[derive(Deserialize)]
enum Cache {
/// Expire by LRU or by idle.
/// Note: "live" in "time-to-live" actually means idle here.
LruTtl {
/// Max number of entries.
size: usize,
/// Entry's time-to-live.
ttl: Duration,
},
}
#[derive(Deserialize)]
struct DataTransferExport {
/// http endpoint to receive periodic metric updates
endpoint: Option<String>,
/// how often metrics should be sent to a collection endpoint
interval: Option<String>,
/// interval for backup metric collection
backup_interval: std::time::Duration,
/// remote storage configuration for backup metric collection
/// Encoded as toml (same format as pageservers), eg
/// `{bucket_name='the-bucket',bucket_region='us-east-1',prefix_in_bucket='proxy',endpoint='http://minio:9000'}`
backup_remote_storage: Option<RemoteStorageConfig>,
/// chunk size for backup metric collection
/// Size of each event is no more than 400 bytes, so 2**22 is about 200MB before the compression.
backup_chunk_size: usize,
}
/// Neon proxy/router
#[derive(Parser)]
#[command(version = GIT_VERSION, about)]
@@ -120,12 +377,6 @@ struct ProxyCliArgs {
/// timeout for the TLS handshake
#[clap(long, default_value = "15s", value_parser = humantime::parse_duration)]
handshake_timeout: tokio::time::Duration,
/// http endpoint to receive periodic metric updates
#[clap(long)]
metric_collection_endpoint: Option<String>,
/// how often metrics should be sent to a collection endpoint
#[clap(long)]
metric_collection_interval: Option<String>,
/// cache for `wake_compute` api method (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
wake_compute_cache: String,
@@ -152,40 +403,31 @@ struct ProxyCliArgs {
/// Wake compute rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
wake_compute_limit: Vec<RateBucketInfo>,
/// Redis rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)]
redis_rps_limit: Vec<RateBucketInfo>,
/// Cancellation channel size (max queue size for redis kv client)
#[clap(long, default_value_t = 1024)]
cancellation_ch_size: usize,
/// Cancellation ops batch size for redis
#[clap(long, default_value_t = 8)]
cancellation_batch_size: usize,
/// cache for `allowed_ips` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
allowed_ips_cache: String,
/// cache for `role_secret` (use `size=0` to disable)
#[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)]
role_secret_cache: String,
/// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections)
/// redis url for plain authentication
#[clap(long, alias("redis-notifications"))]
redis_plain: Option<String>,
/// what from the available authentications type to use for redis. Supported are "irsa" and "plain".
#[clap(long)]
redis_notifications: Option<String>,
/// what from the available authentications type to use for the regional redis we have. Supported are "irsa" and "plain".
#[clap(long, default_value = "irsa")]
redis_auth_type: String,
/// redis host for streaming connections (might be different from the notifications host)
redis_auth_type: Option<String>,
/// redis host for irsa authentication
#[clap(long)]
redis_host: Option<String>,
/// redis port for streaming connections (might be different from the notifications host)
/// redis port for irsa authentication
#[clap(long)]
redis_port: Option<u16>,
/// redis cluster name, used in aws elasticache
/// redis cluster name for irsa authentication
#[clap(long)]
redis_cluster_name: Option<String>,
/// redis user_id, used in aws elasticache
/// redis user_id for irsa authentication
#[clap(long)]
redis_user_id: Option<String>,
/// aws region to retrieve credentials
/// aws region for irsa authentication
#[clap(long, default_value_t = String::new())]
aws_region: String,
/// cache for `project_info` (use `size=0` to disable)
@@ -197,6 +439,12 @@ struct ProxyCliArgs {
#[clap(flatten)]
parquet_upload: ParquetUploadArgs,
/// http endpoint to receive periodic metric updates
#[clap(long)]
metric_collection_endpoint: Option<String>,
/// how often metrics should be sent to a collection endpoint
#[clap(long)]
metric_collection_interval: Option<String>,
/// interval for backup metric collection
#[clap(long, default_value = "10m", value_parser = humantime::parse_duration)]
metric_backup_collection_interval: std::time::Duration,
@@ -209,6 +457,7 @@ struct ProxyCliArgs {
/// Size of each event is no more than 400 bytes, so 2**22 is about 200MB before the compression.
#[clap(long, default_value = "4194304")]
metric_backup_collection_chunk_size: usize,
/// Whether to retry the connection to the compute node
#[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
connect_to_compute_retry: String,
@@ -319,208 +568,120 @@ pub async fn run() -> anyhow::Result<()> {
}
};
let args = ProxyCliArgs::parse();
let config = build_config(&args)?;
let auth_backend = build_auth_backend(&args)?;
match auth_backend {
Either::Left(auth_backend) => info!("Authentication backend: {auth_backend}"),
Either::Right(auth_backend) => info!("Authentication backend: {auth_backend:?}"),
}
info!("Using region: {}", args.aws_region);
let (regional_redis_client, redis_notifications_client) = configure_redis(&args).await?;
// Check that we can bind to address before further initialization
info!("Starting http on {}", args.http);
let http_listener = TcpListener::bind(args.http).await?.into_std()?;
info!("Starting mgmt on {}", args.mgmt);
let mgmt_listener = TcpListener::bind(args.mgmt).await?;
let proxy_listener = if args.is_auth_broker {
None
} else {
info!("Starting proxy on {}", args.proxy);
Some(TcpListener::bind(args.proxy).await?)
};
let sni_router_listeners = {
let args = &args.pg_sni_router;
if args.dest.is_some() {
ensure!(
args.tls_key.is_some(),
"sni-router-tls-key must be provided"
);
ensure!(
args.tls_cert.is_some(),
"sni-router-tls-cert must be provided"
);
info!(
"Starting pg-sni-router on {} and {}",
args.listen, args.listen_tls
);
Some((
TcpListener::bind(args.listen).await?,
TcpListener::bind(args.listen_tls).await?,
))
} else {
None
}
};
// TODO: rename the argument to something like serverless.
// It now covers more than just websockets, it also covers SQL over HTTP.
let serverless_listener = if let Some(serverless_address) = args.wss {
info!("Starting wss on {serverless_address}");
Some(TcpListener::bind(serverless_address).await?)
} else if args.is_auth_broker {
bail!("wss arg must be present for auth-broker")
} else {
None
};
let cancellation_token = CancellationToken::new();
let redis_rps_limit = Vec::leak(args.redis_rps_limit.clone());
RateBucketInfo::validate(redis_rps_limit)?;
let redis_kv_client = regional_redis_client
.as_ref()
.map(|redis_publisher| RedisKVClient::new(redis_publisher.clone(), redis_rps_limit));
// channel size should be higher than redis client limit to avoid blocking
let cancel_ch_size = args.cancellation_ch_size;
let (tx_cancel, rx_cancel) = tokio::sync::mpsc::channel(cancel_ch_size);
let cancellation_handler = Arc::new(CancellationHandler::new(
&config.connect_to_compute,
Some(tx_cancel),
));
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
RateBucketInfo::to_leaky_bucket(&args.endpoint_rps_limit)
.unwrap_or(EndpointRateLimiter::DEFAULT),
64,
));
let config: Root = toml::from_str(&tokio::fs::read_to_string("proxy.toml").await?)?;
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
let mut client_tasks = JoinSet::new();
match auth_backend {
Either::Left(auth_backend) => {
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(crate::proxy::task_main(
config,
auth_backend,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
}
if let Some(serverless_listener) = serverless_listener {
client_tasks.spawn(serverless::task_main(
config,
auth_backend,
serverless_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
}
}
Either::Right(auth_backend) => {
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(crate::console_redirect_proxy::task_main(
config,
auth_backend,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
));
}
}
}
// spawn pg-sni-router mode.
if let Some((listen, listen_tls)) = sni_router_listeners {
let args = args.pg_sni_router;
let dest = args.dest.expect("already asserted it is set");
let key_path = args.tls_key.expect("already asserted it is set");
let cert_path = args.tls_cert.expect("already asserted it is set");
let tls_config = super::pg_sni_router::parse_tls(&key_path, &cert_path)?;
let dest = Arc::new(dest);
client_tasks.spawn(super::pg_sni_router::task_main(
dest.clone(),
tls_config.clone(),
None,
listen,
cancellation_token.clone(),
));
client_tasks.spawn(super::pg_sni_router::task_main(
dest,
tls_config,
Some(config.connect_to_compute.tls.clone()),
listen_tls,
cancellation_token.clone(),
));
}
client_tasks.spawn(crate::context::parquet::worker(
cancellation_token.clone(),
args.parquet_upload,
));
// maintenance tasks. these never return unless there's an error
let mut maintenance_tasks = JoinSet::new();
maintenance_tasks.spawn(crate::signals::handle(cancellation_token.clone(), || {}));
maintenance_tasks.spawn(http::health_server::task_main(
http_listener,
AppMetrics {
jemalloc,
neon_metrics,
proxy: crate::metrics::Metrics::get(),
},
));
maintenance_tasks.spawn(control_plane::mgmt::task_main(mgmt_listener));
if let Some(metrics_config) = &config.metric_collection {
// TODO: Add gc regardles of the metric collection being enabled.
maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
}
let cancellation_token = CancellationToken::new();
#[cfg_attr(not(any(test, feature = "testing")), expect(irrefutable_let_patterns))]
if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend {
if let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api {
match (redis_notifications_client, regional_redis_client.clone()) {
(None, None) => {}
(client1, client2) => {
let cache = api.caches.project_info.clone();
if let Some(client) = client1 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
args.region.clone(),
));
}
if let Some(client) = client2 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
args.region.clone(),
));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
match config.legacy {
LegacyModes::Proxy {
pglb,
neonkeeper,
http,
pg_sni_router,
} => {
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
// todo: use neonkeeper config.
EndpointRateLimiter::DEFAULT,
64,
));
info!("Starting proxy on {}", pglb.listener.addr);
let proxy_listener = TcpListener::bind(pglb.listener.addr).await?;
client_tasks.spawn(crate::proxy::task_main(
config,
auth_backend,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
if let Some(http) = http {
info!("Starting wss on {}", http.listener.addr);
let http_listener = TcpListener::bind(http.listener.addr).await?;
client_tasks.spawn(serverless::task_main(
config,
auth_backend,
http_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
};
if let Some(redis) = neonkeeper.redis {
let client = configure_redis(redis.auth);
}
if let Some(mut redis_kv_client) = redis_kv_client {
if let Some(sni_router) = pg_sni_router {
let listen = TcpListener::bind(sni_router.listener.addr).await?;
let listen_tls = TcpListener::bind(sni_router.listener_tls.addr).await?;
let [KeyPair { key, cert }] = sni_router
.listener
.certs
.try_into()
.map_err(|_| anyhow!("only 1 keypair is supported for pg-sni-router"))?;
let tls_config = super::pg_sni_router::parse_tls(&key, &cert)?;
let dest = Arc::new(sni_router.dest);
client_tasks.spawn(super::pg_sni_router::task_main(
dest.clone(),
tls_config.clone(),
None,
listen,
cancellation_token.clone(),
));
client_tasks.spawn(super::pg_sni_router::task_main(
dest,
tls_config,
Some(config.connect_to_compute.tls.clone()),
listen_tls,
cancellation_token.clone(),
));
}
match neonkeeper.request_log_export {
Some(RequestLogExport::Parquet {
location,
disconnect,
region,
row_group_size,
page_size,
size,
maximum_duration,
}) => {
client_tasks.spawn(crate::context::parquet::worker(
cancellation_token.clone(),
args.parquet_upload,
args.region,
));
}
None => {}
}
if let (ControlPlaneBackend::Http(api), Some(redis)) =
(neonkeeper.cplane, neonkeeper.redis)
{
// project info cache and invalidation of that cache.
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
// cancellation key management
let mut redis_kv_client = RedisKVClient::new(client.clone());
maintenance_tasks.spawn(async move {
redis_kv_client.try_connect().await?;
handle_cancel_messages(
@@ -537,18 +698,139 @@ pub async fn run() -> anyhow::Result<()> {
// so let's wait forever instead.
std::future::pending().await
});
}
if let Some(regional_redis_client) = regional_redis_client {
// listen for notifications of new projects/endpoints/branches
let cache = api.caches.endpoints_cache.clone();
let con = regional_redis_client;
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(con, cancellation_token.clone()).await }
async move { cache.do_read(client, cancellation_token.clone()).await }
.instrument(span),
);
}
}
LegacyModes::AuthBroker { neonkeeper, http } => {
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
// todo: use neonkeeper config.
EndpointRateLimiter::DEFAULT,
64,
));
info!("Starting wss on {}", http.listener.addr);
let http_listener = TcpListener::bind(http.listener.addr).await?;
if let Some(redis) = neonkeeper.redis {
let client = configure_redis(redis.auth);
}
client_tasks.spawn(serverless::task_main(
config,
auth_backend,
serverless_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
endpoint_rate_limiter.clone(),
));
match neonkeeper.request_log_export {
Some(RequestLogExport::Parquet {
location,
disconnect,
region,
row_group_size,
page_size,
size,
maximum_duration,
}) => {
client_tasks.spawn(crate::context::parquet::worker(
cancellation_token.clone(),
args.parquet_upload,
args.region,
));
}
None => {}
}
if let (ControlPlaneBackend::Http(api), Some(redis)) =
(neonkeeper.cplane, neonkeeper.redis)
{
// project info cache and invalidation of that cache.
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
// cancellation key management
let mut redis_kv_client = RedisKVClient::new(client.clone());
maintenance_tasks.spawn(async move {
redis_kv_client.try_connect().await?;
handle_cancel_messages(
&mut redis_kv_client,
rx_cancel,
args.cancellation_batch_size,
)
.await?;
drop(redis_kv_client);
// `handle_cancel_messages` was terminated due to the tx_cancel
// being dropped. this is not worthy of an error, and this task can only return `Err`,
// so let's wait forever instead.
std::future::pending().await
});
// listen for notifications of new projects/endpoints/branches
let cache = api.caches.endpoints_cache.clone();
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(client, cancellation_token.clone()).await }
.instrument(span),
);
}
}
LegacyModes::ConsoleRedirect { console_redirect } => {
info!("Starting proxy on {}", console_redirect.listener.addr);
let proxy_listener = TcpListener::bind(console_redirect.listener.addr).await?;
info!("Starting mgmt on {}", console_redirect.listener.addr);
let mgmt_listener = TcpListener::bind(console_redirect.listener.addr).await?;
client_tasks.spawn(crate::console_redirect_proxy::task_main(
config,
auth_backend,
proxy_listener,
cancellation_token.clone(),
cancellation_handler.clone(),
));
maintenance_tasks.spawn(control_plane::mgmt::task_main(mgmt_listener));
}
}
// Check that we can bind to address before further initialization
info!("Starting http on {}", config.introspection.listener.addr);
let http_listener = TcpListener::bind(config.introspection.listener.addr)
.await?
.into_std()?;
// channel size should be higher than redis client limit to avoid blocking
let cancel_ch_size = args.cancellation_ch_size;
let (tx_cancel, rx_cancel) = tokio::sync::mpsc::channel(cancel_ch_size);
let cancellation_handler = Arc::new(CancellationHandler::new(
&config.connect_to_compute,
Some(tx_cancel),
));
maintenance_tasks.spawn(crate::signals::handle(cancellation_token.clone(), || {}));
maintenance_tasks.spawn(http::health_server::task_main(
http_listener,
AppMetrics {
jemalloc,
neon_metrics,
proxy: crate::metrics::Metrics::get(),
},
));
if let Some(metrics_config) = &config.metric_collection {
// TODO: Add gc regardles of the metric collection being enabled.
maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
}
let maintenance = loop {
@@ -673,7 +955,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
authentication_config,
proxy_protocol_v2: args.proxy_protocol_v2,
handshake_timeout: args.handshake_timeout,
region: args.region.clone(),
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks,
connect_to_compute: compute_config,
@@ -833,58 +1114,45 @@ fn build_auth_backend(
}
}
async fn configure_redis(
args: &ProxyCliArgs,
) -> anyhow::Result<(
Option<ConnectionWithCredentialsProvider>,
Option<ConnectionWithCredentialsProvider>,
)> {
// TODO: untangle the config args
let regional_redis_client = match (args.redis_auth_type.as_str(), &args.redis_notifications) {
("plain", redis_url) => match redis_url {
None => {
bail!("plain auth requires redis_notifications to be set");
}
Some(url) => {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone()))
}
},
("irsa", _) => match (&args.redis_host, args.redis_port) {
(Some(host), Some(port)) => Some(
ConnectionWithCredentialsProvider::new_with_credentials_provider(
host.clone(),
port,
elasticache::CredentialsProvider::new(
args.aws_region.clone(),
args.redis_cluster_name.clone(),
args.redis_user_id.clone(),
)
.await,
),
),
(None, None) => {
// todo: upgrade to error?
warn!(
"irsa auth requires redis-host and redis-port to be set, continuing without regional_redis_client"
);
None
}
_ => {
bail!("redis-host and redis-port must be specified together");
}
},
_ => {
bail!("unknown auth type given");
async fn configure_redis(auth: RedisAuthentication) -> ConnectionWithCredentialsProvider {
match auth {
RedisAuthentication::Irsa {
host,
port,
cluster_name,
user_id,
aws_region,
} => ConnectionWithCredentialsProvider::new_with_credentials_provider(
host,
port,
elasticache::CredentialsProvider::new(aws_region, cluster_name, user_id).await,
),
RedisAuthentication::Basic { url } => {
ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone())
}
}
}
None => None,
};
let redis_notifications_client = if let Some(url) = &args.redis_notifications {
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
} else {
regional_redis_client.clone()
// let redis_notifications_client = if let Some(url) = &args.redis_notifications {
// Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
// } else {
// regional_redis_client.clone()
// };
Ok(redis_client)
}
None => None,
};
Ok((regional_redis_client, redis_notifications_client))
// let redis_notifications_client = if let Some(url) = &args.redis_notifications {
// Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url))
// } else {
// regional_redis_client.clone()
// };
Ok(redis_client)
}
#[cfg(test)]

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use anyhow::{Context, anyhow};
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use postgres_client::RawCancelToken;
use postgres_client::CancelToken;
use postgres_client::tls::MakeTlsConnect;
use redis::{Cmd, FromRedisValue, Value};
use serde::{Deserialize, Serialize};
@@ -33,6 +33,7 @@ const CANCEL_KEY_TTL: i64 = 1_209_600; // 2 weeks cancellation key expire time
pub enum CancelKeyOp {
StoreCancelKey {
key: String,
field: String,
value: String,
resp_tx: Option<oneshot::Sender<anyhow::Result<()>>>,
_guard: CancelChannelSizeGuard<'static>,
@@ -40,7 +41,7 @@ pub enum CancelKeyOp {
},
GetCancelData {
key: String,
resp_tx: oneshot::Sender<anyhow::Result<String>>,
resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
_guard: CancelChannelSizeGuard<'static>,
},
RemoveCancelKey {
@@ -119,6 +120,7 @@ impl CancelKeyOp {
match self {
CancelKeyOp::StoreCancelKey {
key,
field,
value,
resp_tx,
_guard,
@@ -126,7 +128,7 @@ impl CancelKeyOp {
} => {
let reply =
resp_tx.map(|resp_tx| CancelReplyOp::StoreCancelKey { resp_tx, _guard });
pipe.add_command(Cmd::hset(&key, "data", value), reply);
pipe.add_command(Cmd::hset(&key, field, value), reply);
pipe.add_command_no_reply(Cmd::expire(key, expire));
}
CancelKeyOp::GetCancelData {
@@ -135,7 +137,7 @@ impl CancelKeyOp {
_guard,
} => {
let reply = CancelReplyOp::GetCancelData { resp_tx, _guard };
pipe.add_command_with_reply(Cmd::hget(key, "data"), reply);
pipe.add_command_with_reply(Cmd::hgetall(key), reply);
}
CancelKeyOp::RemoveCancelKey {
key,
@@ -158,7 +160,7 @@ pub enum CancelReplyOp {
_guard: CancelChannelSizeGuard<'static>,
},
GetCancelData {
resp_tx: oneshot::Sender<anyhow::Result<String>>,
resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
_guard: CancelChannelSizeGuard<'static>,
},
RemoveCancelKey {
@@ -345,7 +347,7 @@ impl CancellationHandler {
_guard: Metrics::get()
.proxy
.cancel_channel_size
.guard(RedisMsgKind::HGet),
.guard(RedisMsgKind::HGetAll),
};
let Some(tx) = &self.tx else {
@@ -364,21 +366,32 @@ impl CancellationHandler {
CancelError::InternalError
})?;
let cancel_state_str: String = match result {
Ok(s) => s,
let cancel_state_str: Option<String> = match result {
Ok(mut state) => {
if state.len() == 1 {
Some(state.remove(0).1)
} else {
tracing::warn!("unexpected number of entries in cancel state: {state:?}");
return Err(CancelError::InternalError);
}
}
Err(e) => {
tracing::warn!("failed to receive cancel state from redis: {e}");
return Err(CancelError::InternalError);
}
};
let cancel_closure: CancelClosure =
serde_json::from_str(&cancel_state_str).map_err(|e| {
tracing::warn!("failed to deserialize cancel state: {e}");
CancelError::InternalError
})?;
Ok(Some(cancel_closure))
let cancel_state: Option<CancelClosure> = match cancel_state_str {
Some(state) => {
let cancel_closure: CancelClosure = serde_json::from_str(&state).map_err(|e| {
tracing::warn!("failed to deserialize cancel state: {e}");
CancelError::InternalError
})?;
Some(cancel_closure)
}
None => None,
};
Ok(cancel_state)
}
/// Try to cancel a running query for the corresponding connection.
/// If the cancellation key is not found, it will be published to Redis.
@@ -457,7 +470,7 @@ impl CancellationHandler {
#[derive(Clone, Serialize, Deserialize)]
pub struct CancelClosure {
socket_addr: SocketAddr,
cancel_token: RawCancelToken,
cancel_token: CancelToken,
hostname: String, // for pg_sni router
user_info: ComputeUserInfo,
}
@@ -465,7 +478,7 @@ pub struct CancelClosure {
impl CancelClosure {
pub(crate) fn new(
socket_addr: SocketAddr,
cancel_token: RawCancelToken,
cancel_token: CancelToken,
hostname: String,
user_info: ComputeUserInfo,
) -> Self {
@@ -525,6 +538,7 @@ impl Session {
let op = CancelKeyOp::StoreCancelKey {
key: self.redis_key.clone(),
field: "data".to_string(),
value: closure_json,
resp_tx: None,
_guard: Metrics::get()

View File

@@ -9,7 +9,7 @@ use itertools::Itertools;
use postgres_client::config::{AuthKeys, SslMode};
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use postgres_client::tls::MakeTlsConnect;
use postgres_client::{NoTls, RawCancelToken, RawConnection};
use postgres_client::{CancelToken, NoTls, RawConnection};
use postgres_protocol::message::backend::NoticeResponseBody;
use thiserror::Error;
use tokio::net::{TcpStream, lookup_host};
@@ -136,11 +136,11 @@ impl AuthInfo {
}
}
pub(crate) fn with_auth_keys(keys: ComputeCredentialKeys) -> Self {
pub(crate) fn with_auth_keys(keys: &ComputeCredentialKeys) -> Self {
Self {
auth: match keys {
ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
Some(Auth::Scram(Box::new(auth_keys)))
Some(Auth::Scram(Box::new(*auth_keys)))
}
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None,
},
@@ -327,7 +327,8 @@ impl ConnectInfo {
// Yet another reason to rework the connection establishing code.
let cancel_closure = CancelClosure::new(
socket_addr,
RawCancelToken {
CancelToken {
socket_config: None,
ssl_mode: self.ssl_mode,
process_id,
secret_key,

View File

@@ -22,7 +22,6 @@ pub struct ProxyConfig {
pub http_config: HttpConfig,
pub authentication_config: AuthenticationConfig,
pub proxy_protocol_v2: ProxyProtocolV2,
pub region: String,
pub handshake_timeout: Duration,
pub wake_compute_retry_config: RetryConfig,
pub connect_compute_locks: ApiLocks<Host>,
@@ -70,7 +69,7 @@ pub struct AuthenticationConfig {
pub console_redirect_confirmation_timeout: tokio::time::Duration,
}
#[derive(Debug)]
#[derive(Debug, serde::Deserialize)]
pub struct EndpointCacheConfig {
/// Batch size to receive all endpoints on the startup.
pub initial_batch_size: usize,
@@ -206,7 +205,7 @@ impl FromStr for CacheOptions {
}
/// Helper for cmdline cache options parsing.
#[derive(Debug)]
#[derive(Debug, serde::Deserialize)]
pub struct ProjectInfoCacheOptions {
/// Max number of entries.
pub size: usize,

View File

@@ -11,12 +11,13 @@ use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute};
use crate::pglb::handshake::{HandshakeData, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::{ClientRequestError, ErrorSource, prepare_client_connection};
use crate::util::run_until_cancelled;
use crate::proxy::{
ClientRequestError, ErrorSource, prepare_client_connection, run_until_cancelled,
};
pub async fn task_main(
config: &'static ProxyConfig,
@@ -89,12 +90,7 @@ pub async fn task_main(
}
}
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Tcp,
&config.region,
);
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
let res = handle_client(
config,

View File

@@ -46,7 +46,6 @@ struct RequestContextInner {
pub(crate) session_id: Uuid,
pub(crate) protocol: Protocol,
first_packet: chrono::DateTime<Utc>,
region: &'static str,
pub(crate) span: Span,
// filled in as they are discovered
@@ -94,7 +93,6 @@ impl Clone for RequestContext {
session_id: inner.session_id,
protocol: inner.protocol,
first_packet: inner.first_packet,
region: inner.region,
span: info_span!("background_task"),
project: inner.project,
@@ -124,12 +122,7 @@ impl Clone for RequestContext {
}
impl RequestContext {
pub fn new(
session_id: Uuid,
conn_info: ConnectionInfo,
protocol: Protocol,
region: &'static str,
) -> Self {
pub fn new(session_id: Uuid, conn_info: ConnectionInfo, protocol: Protocol) -> Self {
// TODO: be careful with long lived spans
let span = info_span!(
"connect_request",
@@ -145,7 +138,6 @@ impl RequestContext {
session_id,
protocol,
first_packet: Utc::now(),
region,
span,
project: None,
@@ -179,7 +171,7 @@ impl RequestContext {
let ip = IpAddr::from([127, 0, 0, 1]);
let addr = SocketAddr::new(ip, 5432);
let conn_info = ConnectionInfo { addr, extra: None };
RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test")
RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp)
}
pub(crate) fn console_application_name(&self) -> String {

View File

@@ -74,7 +74,7 @@ pub(crate) const FAILED_UPLOAD_MAX_RETRIES: u32 = 10;
#[derive(parquet_derive::ParquetRecordWriter)]
pub(crate) struct RequestData {
region: &'static str,
region: String,
protocol: &'static str,
/// Must be UTC. The derive macro doesn't like the timezones
timestamp: chrono::NaiveDateTime,
@@ -147,7 +147,7 @@ impl From<&RequestContextInner> for RequestData {
}),
jwt_issuer: value.jwt_issuer.clone(),
protocol: value.protocol.as_str(),
region: value.region,
region: String::new(),
error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
success: value.success,
cold_start_info: value.cold_start_info.as_str(),
@@ -167,6 +167,7 @@ impl From<&RequestContextInner> for RequestData {
pub async fn worker(
cancellation_token: CancellationToken,
config: ParquetUploadArgs,
region: String,
) -> anyhow::Result<()> {
let Some(remote_storage_config) = config.parquet_upload_remote_storage else {
tracing::warn!("parquet request upload: no s3 bucket configured");
@@ -232,12 +233,17 @@ pub async fn worker(
.context("remote storage for disconnect events init")?;
let parquet_config_disconnect = parquet_config.clone();
tokio::try_join!(
worker_inner(storage, rx, parquet_config),
worker_inner(storage_disconnect, rx_disconnect, parquet_config_disconnect)
worker_inner(storage, rx, parquet_config, &region),
worker_inner(
storage_disconnect,
rx_disconnect,
parquet_config_disconnect,
&region
)
)
.map(|_| ())
} else {
worker_inner(storage, rx, parquet_config).await
worker_inner(storage, rx, parquet_config, &region).await
}
}
@@ -257,6 +263,7 @@ async fn worker_inner(
storage: GenericRemoteStorage,
rx: impl Stream<Item = RequestData>,
config: ParquetConfig,
region: &str,
) -> anyhow::Result<()> {
#[cfg(any(test, feature = "testing"))]
let storage = if config.test_remote_failures > 0 {
@@ -277,7 +284,8 @@ async fn worker_inner(
let mut last_upload = time::Instant::now();
let mut len = 0;
while let Some(row) = rx.next().await {
while let Some(mut row) = rx.next().await {
region.clone_into(&mut row.region);
rows.push(row);
let force = last_upload.elapsed() > config.max_duration;
if rows.len() == config.rows_per_group || force {
@@ -533,7 +541,7 @@ mod tests {
auth_method: None,
jwt_issuer: None,
protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
region: "us-east-1",
region: String::new(),
error: None,
success: rng.r#gen(),
cold_start_info: "no",
@@ -565,7 +573,9 @@ mod tests {
.await
.unwrap();
worker_inner(storage, rx, config).await.unwrap();
worker_inner(storage, rx, config, "us-east-1")
.await
.unwrap();
let mut files = WalkDir::new(tmpdir.as_std_path())
.into_iter()

View File

@@ -106,5 +106,4 @@ mod tls;
mod types;
mod url;
mod usage_metrics;
mod util;
mod waiters;

View File

@@ -8,19 +8,19 @@ use crate::config::{ComputeConfig, RetryConfig};
use crate::context::RequestContext;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::{self, NodeInfo};
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
};
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry};
use crate::proxy::wake_compute::{WakeComputeBackend, wake_compute};
use crate::proxy::wake_compute::wake_compute;
use crate::types::Host;
/// If we couldn't connect, a cached connection info might be to blame
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
#[tracing::instrument(skip_all)]
#[tracing::instrument(name = "invalidate_cache", skip_all)]
pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> NodeInfo {
let is_cached = node_info.cached();
if is_cached {
@@ -49,6 +49,14 @@ pub(crate) trait ConnectMechanism {
) -> Result<Self::Connection, Self::ConnectError>;
}
#[async_trait]
pub(crate) trait ComputeConnectBackend {
async fn wake_compute(
&self,
ctx: &RequestContext,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError>;
}
pub(crate) struct TcpMechanism {
pub(crate) auth: AuthInfo,
/// connect_to_compute concurrency lock
@@ -83,7 +91,7 @@ impl ConnectMechanism for TcpMechanism {
/// Try to connect to the compute node, retrying if necessary.
#[tracing::instrument(skip_all)]
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: WakeComputeBackend>(
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
ctx: &RequestContext,
mechanism: &M,
user_info: &B,

View File

@@ -1,3 +1,4 @@
pub mod connect_compute;
pub mod copy_bidirectional;
pub mod handshake;
pub mod inprocess;

View File

@@ -1,10 +1,8 @@
#[cfg(test)]
mod tests;
pub(crate) mod connect_compute;
pub(crate) mod retry;
pub(crate) mod wake_compute;
use std::sync::Arc;
use futures::FutureExt;
@@ -23,16 +21,15 @@ use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute};
pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
use crate::types::EndpointCacheKey;
use crate::util::run_until_cancelled;
use crate::{auth, compute};
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
@@ -49,6 +46,21 @@ impl ReportableError for TlsRequired {
impl UserFacingError for TlsRequired {}
pub async fn run_until_cancelled<F: std::future::Future>(
f: F,
cancellation_token: &CancellationToken,
) -> Option<F::Output> {
match futures::future::select(
std::pin::pin!(f),
std::pin::pin!(cancellation_token.cancelled()),
)
.await
{
futures::future::Either::Left((f, _)) => Some(f),
futures::future::Either::Right(((), _)) => None,
}
}
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
@@ -122,12 +134,7 @@ pub async fn task_main(
}
}
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Tcp,
&config.region,
);
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
let res = handle_client(
config,
@@ -346,12 +353,12 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
}
};
let (cplane, creds) = match user_info {
auth::Backend::ControlPlane(cplane, creds) => (cplane, creds),
let creds = match &user_info {
auth::Backend::ControlPlane(_, creds) => creds,
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
};
let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some();
let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys);
let mut auth_info = compute::AuthInfo::with_auth_keys(&creds.keys);
auth_info.set_startup_params(&params, params_compat);
let res = connect_to_compute(
@@ -361,7 +368,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
auth: auth_info,
locks: &config.connect_compute_locks,
},
&auth::Backend::ControlPlane(cplane, creds.info),
&user_info,
config.wake_compute_retry_config,
&config.connect_to_compute,
)

View File

@@ -8,7 +8,7 @@ use std::time::Duration;
use anyhow::{Context, bail};
use async_trait::async_trait;
use http::StatusCode;
use postgres_client::config::SslMode;
use postgres_client::config::{AuthKeys, ScramKeys, SslMode};
use postgres_client::tls::{MakeTlsConnect, NoTls};
use retry::{ShouldRetryWakeCompute, retry_after};
use rstest::rstest;
@@ -19,13 +19,15 @@ use tracing_test::traced_test;
use super::retry::CouldRetry;
use super::*;
use crate::auth::backend::{ComputeUserInfo, MaybeOwned};
use crate::auth::backend::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned,
};
use crate::config::{ComputeConfig, RetryConfig};
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
use crate::error::ErrorKind;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::pglb::connect_compute::ConnectMechanism;
use crate::tls::client_config::compute_client_config_with_certs;
use crate::tls::server_config::CertResolver;
use crate::types::{BranchId, EndpointId, ProjectId};
@@ -573,13 +575,19 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn
fn helper_create_connect_info(
mechanism: &TestConnectMechanism,
) -> auth::Backend<'static, ComputeUserInfo> {
) -> auth::Backend<'static, ComputeCredentials> {
auth::Backend::ControlPlane(
MaybeOwned::Owned(ControlPlaneClient::Test(Box::new(mechanism.clone()))),
ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
ComputeCredentials {
info: ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
},
keys: ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(ScramKeys {
client_key: [0; 32],
server_key: [0; 32],
})),
},
)
}

View File

@@ -1,4 +1,3 @@
use async_trait::async_trait;
use tracing::{error, info};
use crate::config::RetryConfig;
@@ -9,6 +8,7 @@ use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailuresBreakdownGroup, Metrics, RetriesMetricGroup, RetryType,
};
use crate::pglb::connect_compute::ComputeConnectBackend;
use crate::proxy::retry::{retry_after, should_retry};
// Use macro to retain original callsite.
@@ -23,12 +23,7 @@ macro_rules! log_wake_compute_error {
};
}
#[async_trait]
pub(crate) trait WakeComputeBackend {
async fn wake_compute(&self, ctx: &RequestContext) -> Result<CachedNodeInfo, WakeComputeError>;
}
pub(crate) async fn wake_compute<B: WakeComputeBackend>(
pub(crate) async fn wake_compute<B: ComputeConnectBackend>(
num_retries: &mut u32,
ctx: &RequestContext,
api: &B,

View File

@@ -140,12 +140,6 @@ impl RateBucketInfo {
Self::new(200, Duration::from_secs(600)),
];
// For all the sessions will be cancel key. So this limit is essentially global proxy limit.
pub const DEFAULT_REDIS_SET: [Self; 2] = [
Self::new(100_000, Duration::from_secs(1)),
Self::new(50_000, Duration::from_secs(10)),
];
pub fn rps(&self) -> f64 {
(self.max_rpi as f64) / self.interval.as_secs_f64()
}

View File

@@ -1,15 +1,10 @@
use std::time::Duration;
use futures::FutureExt;
use redis::aio::ConnectionLike;
use redis::{Cmd, FromRedisValue, Pipeline, RedisResult};
use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
pub struct RedisKVClient {
client: ConnectionWithCredentialsProvider,
limiter: GlobalRateLimiter,
}
#[allow(async_fn_in_trait)]
@@ -30,49 +25,34 @@ impl Queryable for Cmd {
}
impl RedisKVClient {
pub fn new(client: ConnectionWithCredentialsProvider, info: &'static [RateBucketInfo]) -> Self {
Self {
client,
limiter: GlobalRateLimiter::new(info.into()),
}
pub fn new(client: ConnectionWithCredentialsProvider) -> Self {
Self { client }
}
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
self.client
.connect()
.boxed()
.await
.inspect_err(|e| tracing::error!("failed to connect to redis: {e}"))
match self.client.connect().await {
Ok(()) => {}
Err(e) => {
tracing::error!("failed to connect to redis: {e}");
return Err(e);
}
}
Ok(())
}
pub(crate) async fn query<T: FromRedisValue>(
&mut self,
q: &impl Queryable,
) -> anyhow::Result<T> {
if !self.limiter.check() {
tracing::info!("Rate limit exceeded. Skipping query");
return Err(anyhow::anyhow!("Rate limit exceeded"));
}
let e = match q.query(&mut self.client).await {
match q.query(&mut self.client).await {
Ok(t) => return Ok(t),
Err(e) => e,
};
tracing::error!("failed to run query: {e}");
match e.retry_method() {
redis::RetryMethod::Reconnect => {
tracing::info!("Redis client is disconnected. Reconnecting...");
self.try_connect().await?;
Err(e) => {
tracing::error!("failed to run query: {e}");
}
redis::RetryMethod::RetryImmediately => {}
redis::RetryMethod::WaitAndRetry => {
// somewhat arbitrary.
tokio::time::sleep(Duration::from_millis(100)).await;
}
_ => Err(e)?,
}
tracing::info!("Redis client is disconnected. Reconnecting...");
self.try_connect().await?;
Ok(q.query(&mut self.client).await?)
}
}

View File

@@ -141,29 +141,19 @@ where
struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
cache: Arc<C>,
region_id: String,
}
impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
fn clone(&self) -> Self {
Self {
cache: self.cache.clone(),
region_id: self.region_id.clone(),
}
}
}
impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
pub(crate) fn new(cache: Arc<C>, region_id: String) -> Self {
Self { cache, region_id }
}
pub(crate) async fn increment_active_listeners(&self) {
self.cache.increment_active_listeners().await;
}
pub(crate) async fn decrement_active_listeners(&self) {
self.cache.decrement_active_listeners().await;
pub(crate) fn new(cache: Arc<C>) -> Self {
Self { cache }
}
#[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
@@ -276,7 +266,7 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
}
let mut conn = match try_connect(&redis).await {
Ok(conn) => {
handler.increment_active_listeners().await;
handler.cache.increment_active_listeners().await;
conn
}
Err(e) => {
@@ -297,11 +287,11 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
}
}
if cancellation_token.is_cancelled() {
handler.decrement_active_listeners().await;
handler.cache.decrement_active_listeners().await;
return Ok(());
}
}
handler.decrement_active_listeners().await;
handler.cache.decrement_active_listeners().await;
}
}
@@ -310,12 +300,11 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
pub async fn task_main<C>(
redis: ConnectionWithCredentialsProvider,
cache: Arc<C>,
region_id: String,
) -> anyhow::Result<Infallible>
where
C: ProjectInfoCache + Send + Sync + 'static,
{
let handler = MessageHandler::new(cache, region_id);
let handler = MessageHandler::new(cache);
// 6h - 1m.
// There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.
let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60));

View File

@@ -21,7 +21,7 @@ use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool};
use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client};
use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo};
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
use crate::auth::{self, AuthError};
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
@@ -34,7 +34,7 @@ use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::control_plane::locks::ApiLocks;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::EndpointIdInt;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::pglb::connect_compute::ConnectMechanism;
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
use crate::rate_limiter::EndpointRateLimiter;
use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
@@ -180,15 +180,14 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| keys.info);
crate::proxy::connect_compute::connect_to_compute(
let backend = self.auth_backend.as_ref().map(|()| keys);
crate::pglb::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
keys: keys.keys,
},
&backend,
self.config.wake_compute_retry_config,
@@ -215,15 +214,18 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
debug!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| ComputeUserInfo {
user: conn_info.user_info.user.clone(),
endpoint: EndpointId::from(format!(
"{}{LOCAL_PROXY_SUFFIX}",
conn_info.user_info.endpoint.normalize()
)),
options: conn_info.user_info.options.clone(),
let backend = self.auth_backend.as_ref().map(|()| ComputeCredentials {
info: ComputeUserInfo {
user: conn_info.user_info.user.clone(),
endpoint: EndpointId::from(format!(
"{}{LOCAL_PROXY_SUFFIX}",
conn_info.user_info.endpoint.normalize()
)),
options: conn_info.user_info.options.clone(),
},
keys: crate::auth::backend::ComputeCredentialKeys::None,
});
crate::proxy::connect_compute::connect_to_compute(
crate::pglb::connect_compute::connect_to_compute(
ctx,
&HyperMechanism {
conn_id,
@@ -493,7 +495,6 @@ struct TokioMechanism {
pool: Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
keys: ComputeCredentialKeys,
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
@@ -519,10 +520,6 @@ impl ConnectMechanism for TokioMechanism {
.dbname(&self.conn_info.dbname)
.connect_timeout(compute_config.timeout);
if let ComputeCredentialKeys::AuthKeys(auth_keys) = self.keys {
config.auth_keys(auth_keys);
}
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(compute_config).await;
drop(pause);

View File

@@ -50,10 +50,10 @@ use crate::context::RequestContext;
use crate::ext::TaskExt;
use crate::metrics::Metrics;
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::run_until_cancelled;
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_util::{api_error_into_response, json_response};
use crate::util::run_until_cancelled;
pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
pub(crate) const AUTH_BROKER_SNI: &str = "apiauth";
@@ -417,12 +417,7 @@ async fn request_handler(
if config.http_config.accept_websockets
&& framed_websockets::upgrade::is_upgrade_request(&request)
{
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Ws,
&config.region,
);
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Ws);
ctx.set_user_agent(
request
@@ -462,12 +457,7 @@ async fn request_handler(
// Return the response so the spawned future can continue.
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Http,
&config.region,
);
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http);
let span = ctx.span();
let testodrome_id = request

View File

@@ -41,11 +41,10 @@ use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::http::{ReadBodyError, read_body_with_limit};
use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
use crate::pqproto::StartupMessageParams;
use crate::proxy::NeonOptions;
use crate::proxy::{NeonOptions, run_until_cancelled};
use crate::serverless::backend::HttpConnError;
use crate::types::{DbName, RoleName};
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
use crate::util::run_until_cancelled;
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]

View File

@@ -1,14 +0,0 @@
use std::pin::pin;
use futures::future::{Either, select};
use tokio_util::sync::CancellationToken;
pub async fn run_until_cancelled<F: Future>(
f: F,
cancellation_token: &CancellationToken,
) -> Option<F::Output> {
match select(pin!(f), pin!(cancellation_token.cancelled())).await {
Either::Left((f, _)) => Some(f),
Either::Right(((), _)) => None,
}
}

View File

@@ -77,7 +77,7 @@ class EndpointHttpClient(requests.Session):
status, err = json["status"], json.get("error")
assert status == "completed", f"{status}, error {err}"
wait_until(prewarmed, timeout=60)
wait_until(prewarmed)
def offload_lfc(self):
url = f"http://localhost:{self.external_port}/lfc/offload"

View File

@@ -4046,16 +4046,6 @@ def static_proxy(
"CREATE TABLE neon_control_plane.endpoints (endpoint_id VARCHAR(255) PRIMARY KEY, allowed_ips VARCHAR(255))"
)
vanilla_pg.stop()
vanilla_pg.edit_hba(
[
"local all all trust",
"host all all 127.0.0.1/32 scram-sha-256",
"host all all ::1/128 scram-sha-256",
]
)
vanilla_pg.start()
proxy_port = port_distributor.get_port()
mgmt_port = port_distributor.get_port()
http_port = port_distributor.get_port()

View File

@@ -16,7 +16,7 @@ if TYPE_CHECKING:
# Test restarting page server, while safekeeper and compute node keep
# running.
def test_pageserver_restarts_under_workload(neon_simple_env: NeonEnv, pg_bin: PgBin):
def test_pageserver_restarts_under_worload(neon_simple_env: NeonEnv, pg_bin: PgBin):
env = neon_simple_env
env.create_branch("test_pageserver_restarts")
endpoint = env.endpoints.create_start("test_pageserver_restarts")
@@ -28,11 +28,7 @@ def test_pageserver_restarts_under_workload(neon_simple_env: NeonEnv, pg_bin: Pg
pg_bin.run_capture(["pgbench", "-i", "-I", "dtGvp", f"-s{scale}", connstr])
pg_bin.run_capture(["pgbench", f"-T{n_restarts}", connstr])
thread = threading.Thread(
target=run_pgbench,
args=(endpoint.connstr(options="-cstatement_timeout=360s"),),
daemon=True,
)
thread = threading.Thread(target=run_pgbench, args=(endpoint.connstr(),), daemon=True)
thread.start()
for _ in range(n_restarts):

View File

@@ -19,15 +19,11 @@ TABLE_NAME = "neon_control_plane.endpoints"
async def test_proxy_psql_allowed_ips(static_proxy: NeonProxy, vanilla_pg: VanillaPostgres):
# Shouldn't be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('private-project', '8.8.8.8')",
user="proxy",
password="password",
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('private-project', '8.8.8.8')"
)
# Should be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('generic-project', '::1,127.0.0.1')",
user="proxy",
password="password",
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('generic-project', '::1,127.0.0.1')"
)
def check_cannot_connect(**kwargs):
@@ -64,9 +60,7 @@ async def test_proxy_http_allowed_ips(static_proxy: NeonProxy, vanilla_pg: Vanil
# Shouldn't be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('proxy', '8.8.8.8')",
user="proxy",
password="password",
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('proxy', '8.8.8.8')"
)
def query(status: int, query: str, *args):
@@ -81,8 +75,6 @@ async def test_proxy_http_allowed_ips(static_proxy: NeonProxy, vanilla_pg: Vanil
query(400, "select 1;") # ip address is not allowed
# Should be able to connect to this project
vanilla_pg.safe_psql(
f"UPDATE {TABLE_NAME} SET allowed_ips = '8.8.8.8,127.0.0.1' WHERE endpoint_id = 'proxy'",
user="proxy",
password="password",
f"UPDATE {TABLE_NAME} SET allowed_ips = '8.8.8.8,127.0.0.1' WHERE endpoint_id = 'proxy'"
)
query(200, "select 1;") # should work now