Compare commits

...

9 Commits

Author SHA1 Message Date
Folke Behrens
c90b082222 Split handle_client and pass async callback for connect_once 2025-06-10 19:41:07 +02:00
Folke Behrens
0957c8ea69 Make a copy of connect_to_compute for pglb 2025-06-10 15:01:12 +02:00
Folke Behrens
6cea6560e9 Move task_main and handle_client to pglb 2025-06-10 14:59:12 +02:00
Folke Behrens
e38193c530 proxy: Move connect_to_compute back to proxy (#12181)
It's mostly responsible for waking, retrying, and caching. A new, thin
wrapper around compute_once will be PGLB's entry point
2025-06-10 11:23:03 +00:00
Konstantin Knizhnik
21949137ed Return last ring index instead of min_ring_index in prefetch_register_bufferv (#12039)
## Problem

See https://github.com/neondatabase/neon/issues/12018

Now `prefetch_register_bufferv` calculates min_ring_index of all vector
requests.
But because of pump prefetch state or connection failure, previous slots
can be already proceeded and reused.

## Summary of changes

Instead of returning minimal index, this function should return last
slot index.
Actually result of this function is used only in two places. A first
place just fort checking (and this check is redundant because the same
check is done in `prefetch_register_bufferv` itself.
And in the second place where index of filled slot is actually used,
there is just one request.
Sp fortunately this bug can cause only assert failure in debug build.

---------

Co-authored-by: Konstantin Knizhnik <knizhnik@neon.tech>
2025-06-10 10:09:46 +00:00
Trung Dinh
02f94edb60 Remove global static TENANTS (#12169)
## Problem
There is this TODO in code:
https://github.com/neondatabase/neon/blob/main/pageserver/src/tenant/mgr.rs#L300-L302
This is an old TODO by @jcsp.

## Summary of changes
This PR addresses the TODO. Specifically, it removes a global static
`TENANTS`. Instead the `TenantManager` now directly manages the tenant
map. Enhancing abstraction.

Essentially, this PR moves all module-level methods to inside the
implementation of `TenantManager`.
2025-06-10 09:26:40 +00:00
Conrad Ludgate
58327ef74d [proxy] fix sql-over-http password setting (#12177)
## Problem

Looks like our sql-over-http tests get to rely on "trust"
authentication, so the path that made sure the authkeys data was set was
never being hit.

## Summary of changes

Slight refactor to WakeComputeBackends, as well as making sure auth keys
are propagated. Fix tests to ensure passwords are tested.
2025-06-10 08:46:29 +00:00
Dmitrii Kovalkov
73be6bb736 fix(compute): use proper safekeeper in VotesCollectedMset (#12175)
## Problem
`VotesCollectedMset` uses the wrong safekeeper to update truncateLsn.
This led to some failed assert later in the code during running
safekeeper migration tests.
- Relates to https://github.com/neondatabase/neon/issues/11823

## Summary of changes
Use proper safekeeper to update truncateLsn in VotesCollectedMset
2025-06-10 07:16:42 +00:00
Alex Chi Z.
40d7583906 feat(pageserver): use hostname as feature flag resolver property (#12141)
## Problem

part of https://github.com/neondatabase/neon/issues/11813

## Summary of changes

Collect pageserver hostname property so that we can use it in the
PostHog UI. Not sure if this is the best way to do that -- open to
suggestions.

---------

Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-06-10 07:10:41 +00:00
26 changed files with 1047 additions and 773 deletions

View File

@@ -1022,6 +1022,7 @@ impl RemoteStorage for S3Bucket {
let Version { key, .. } = &vd;
let version_id = vd.version_id().map(|v| v.0.as_str());
if version_id == Some("null") {
// TODO: check the behavior of using the SDK on a non-versioned container
return Err(TimeTravelError::Other(anyhow!(
"Received ListVersions response for key={key} with version_id='null', \
indicating either disabled versioning, or legacy objects with null version id values"

View File

@@ -573,7 +573,8 @@ fn start_pageserver(
tokio::sync::mpsc::unbounded_channel();
let deletion_queue_client = deletion_queue.new_client();
let background_purges = mgr::BackgroundPurges::default();
let tenant_manager = BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(
let tenant_manager = mgr::init(
conf,
background_purges.clone(),
TenantSharedResources {
@@ -584,10 +585,10 @@ fn start_pageserver(
basebackup_prepare_sender,
feature_resolver,
},
order,
shutdown_pageserver.clone(),
))?;
);
let tenant_manager = Arc::new(tenant_manager);
BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(tenant_manager.clone(), order))?;
let basebackup_cache = BasebackupCache::spawn(
BACKGROUND_RUNTIME.handle(),

View File

@@ -1,5 +1,6 @@
use std::{collections::HashMap, sync::Arc, time::Duration};
use pageserver_api::config::NodeMetadata;
use posthog_client_lite::{
CaptureEvent, FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError,
PostHogFlagFilterPropertyValue,
@@ -86,7 +87,35 @@ impl FeatureResolver {
}
}
}
// TODO: add pageserver URL.
// TODO: move this to a background task so that we don't block startup in case of slow disk
let metadata_path = conf.metadata_path();
match std::fs::read_to_string(&metadata_path) {
Ok(metadata_str) => match serde_json::from_str::<NodeMetadata>(&metadata_str) {
Ok(metadata) => {
properties.insert(
"hostname".to_string(),
PostHogFlagFilterPropertyValue::String(metadata.http_host),
);
if let Some(cplane_region) = metadata.other.get("region_id") {
if let Some(cplane_region) = cplane_region.as_str() {
// This region contains the cell number
properties.insert(
"neon_region".to_string(),
PostHogFlagFilterPropertyValue::String(
cplane_region.to_string(),
),
);
}
}
}
Err(e) => {
tracing::warn!("Failed to parse metadata.json: {}", e);
}
},
Err(e) => {
tracing::warn!("Failed to read metadata.json: {}", e);
}
}
Arc::new(properties)
};
let fake_tenants = {

View File

@@ -12,7 +12,6 @@ use anyhow::Context;
use camino::{Utf8DirEntry, Utf8Path, Utf8PathBuf};
use futures::StreamExt;
use itertools::Itertools;
use once_cell::sync::Lazy;
use pageserver_api::key::Key;
use pageserver_api::models::{DetachBehavior, LocationConfigMode};
use pageserver_api::shard::{
@@ -103,7 +102,7 @@ pub(crate) enum TenantsMap {
/// [`init_tenant_mgr`] is not done yet.
Initializing,
/// [`init_tenant_mgr`] is done, all on-disk tenants have been loaded.
/// New tenants can be added using [`tenant_map_acquire_slot`].
/// New tenants can be added using [`TenantManager::tenant_map_acquire_slot`].
Open(BTreeMap<TenantShardId, TenantSlot>),
/// The pageserver has entered shutdown mode via [`TenantManager::shutdown`].
/// Existing tenants are still accessible, but no new tenants can be created.
@@ -284,9 +283,6 @@ impl BackgroundPurges {
}
}
static TENANTS: Lazy<std::sync::RwLock<TenantsMap>> =
Lazy::new(|| std::sync::RwLock::new(TenantsMap::Initializing));
/// Responsible for storing and mutating the collection of all tenants
/// that this pageserver has state for.
///
@@ -297,10 +293,7 @@ static TENANTS: Lazy<std::sync::RwLock<TenantsMap>> =
/// and attached modes concurrently.
pub struct TenantManager {
conf: &'static PageServerConf,
// TODO: currently this is a &'static pointing to TENANTs. When we finish refactoring
// out of that static variable, the TenantManager can own this.
// See https://github.com/neondatabase/neon/issues/5796
tenants: &'static std::sync::RwLock<TenantsMap>,
tenants: std::sync::RwLock<TenantsMap>,
resources: TenantSharedResources,
// Long-running operations that happen outside of a [`Tenant`] lifetime should respect this token.
@@ -479,21 +472,43 @@ pub(crate) enum DeleteTenantError {
Other(#[from] anyhow::Error),
}
/// Initialize repositories with locally available timelines.
/// Initialize repositories at `Initializing` state.
pub fn init(
conf: &'static PageServerConf,
background_purges: BackgroundPurges,
resources: TenantSharedResources,
cancel: CancellationToken,
) -> TenantManager {
TenantManager {
conf,
tenants: std::sync::RwLock::new(TenantsMap::Initializing),
resources,
cancel,
background_purges,
}
}
/// Transition repositories from `Initializing` state to `Open` state with locally available timelines.
/// Timelines that are only partially available locally (remote storage has more data than this pageserver)
/// are scheduled for download and added to the tenant once download is completed.
#[instrument(skip_all)]
pub async fn init_tenant_mgr(
conf: &'static PageServerConf,
background_purges: BackgroundPurges,
resources: TenantSharedResources,
tenant_manager: Arc<TenantManager>,
init_order: InitializationOrder,
cancel: CancellationToken,
) -> anyhow::Result<TenantManager> {
) -> anyhow::Result<()> {
debug_assert!(matches!(
*tenant_manager.tenants.read().unwrap(),
TenantsMap::Initializing
));
let mut tenants = BTreeMap::new();
let ctx = RequestContext::todo_child(TaskKind::Startup, DownloadBehavior::Warn);
let conf = tenant_manager.conf;
let resources = &tenant_manager.resources;
let cancel = &tenant_manager.cancel;
let background_purges = &tenant_manager.background_purges;
// Initialize dynamic limits that depend on system resources
let system_memory =
sysinfo::System::new_with_specifics(sysinfo::RefreshKind::new().with_memory())
@@ -512,7 +527,7 @@ pub async fn init_tenant_mgr(
let tenant_configs = init_load_tenant_configs(conf).await;
// Determine which tenants are to be secondary or attached, and in which generation
let tenant_modes = init_load_generations(conf, &tenant_configs, &resources, &cancel).await?;
let tenant_modes = init_load_generations(conf, &tenant_configs, resources, cancel).await?;
tracing::info!(
"Attaching {} tenants at startup, warming up {} at a time",
@@ -669,18 +684,10 @@ pub async fn init_tenant_mgr(
info!("Processed {} local tenants at startup", tenants.len());
let mut tenants_map = TENANTS.write().unwrap();
assert!(matches!(&*tenants_map, &TenantsMap::Initializing));
let mut tenant_map = tenant_manager.tenants.write().unwrap();
*tenant_map = TenantsMap::Open(tenants);
*tenants_map = TenantsMap::Open(tenants);
Ok(TenantManager {
conf,
tenants: &TENANTS,
resources,
cancel: CancellationToken::new(),
background_purges,
})
Ok(())
}
/// Wrapper for Tenant::spawn that checks invariants before running
@@ -719,142 +726,6 @@ fn tenant_spawn(
)
}
async fn shutdown_all_tenants0(tenants: &std::sync::RwLock<TenantsMap>) {
let mut join_set = JoinSet::new();
#[cfg(all(debug_assertions, not(test)))]
{
// Check that our metrics properly tracked the size of the tenants map. This is a convenient location to check,
// as it happens implicitly at the end of tests etc.
let m = tenants.read().unwrap();
debug_assert_eq!(METRICS.slots_total(), m.len() as u64);
}
// Atomically, 1. create the shutdown tasks and 2. prevent creation of new tenants.
let (total_in_progress, total_attached) = {
let mut m = tenants.write().unwrap();
match &mut *m {
TenantsMap::Initializing => {
*m = TenantsMap::ShuttingDown(BTreeMap::default());
info!("tenants map is empty");
return;
}
TenantsMap::Open(tenants) => {
let mut shutdown_state = BTreeMap::new();
let mut total_in_progress = 0;
let mut total_attached = 0;
for (tenant_shard_id, v) in std::mem::take(tenants).into_iter() {
match v {
TenantSlot::Attached(t) => {
shutdown_state.insert(tenant_shard_id, TenantSlot::Attached(t.clone()));
join_set.spawn(
async move {
let res = {
let (_guard, shutdown_progress) = completion::channel();
t.shutdown(shutdown_progress, ShutdownMode::FreezeAndFlush).await
};
if let Err(other_progress) = res {
// join the another shutdown in progress
other_progress.wait().await;
}
// we cannot afford per tenant logging here, because if s3 is degraded, we are
// going to log too many lines
debug!("tenant successfully stopped");
}
.instrument(info_span!("shutdown", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug())),
);
total_attached += 1;
}
TenantSlot::Secondary(state) => {
// We don't need to wait for this individually per-tenant: the
// downloader task will be waited on eventually, this cancel
// is just to encourage it to drop out if it is doing work
// for this tenant right now.
state.cancel.cancel();
shutdown_state.insert(tenant_shard_id, TenantSlot::Secondary(state));
}
TenantSlot::InProgress(notify) => {
// InProgress tenants are not visible in TenantsMap::ShuttingDown: we will
// wait for their notifications to fire in this function.
join_set.spawn(async move {
notify.wait().await;
});
total_in_progress += 1;
}
}
}
*m = TenantsMap::ShuttingDown(shutdown_state);
(total_in_progress, total_attached)
}
TenantsMap::ShuttingDown(_) => {
error!(
"already shutting down, this function isn't supposed to be called more than once"
);
return;
}
}
};
let started_at = std::time::Instant::now();
info!(
"Waiting for {} InProgress tenants and {} Attached tenants to shut down",
total_in_progress, total_attached
);
let total = join_set.len();
let mut panicked = 0;
let mut buffering = true;
const BUFFER_FOR: std::time::Duration = std::time::Duration::from_millis(500);
let mut buffered = std::pin::pin!(tokio::time::sleep(BUFFER_FOR));
while !join_set.is_empty() {
tokio::select! {
Some(joined) = join_set.join_next() => {
match joined {
Ok(()) => {},
Err(join_error) if join_error.is_cancelled() => {
unreachable!("we are not cancelling any of the tasks");
}
Err(join_error) if join_error.is_panic() => {
// cannot really do anything, as this panic is likely a bug
panicked += 1;
}
Err(join_error) => {
warn!("unknown kind of JoinError: {join_error}");
}
}
if !buffering {
// buffer so that every 500ms since the first update (or starting) we'll log
// how far away we are; this is because we will get SIGKILL'd at 10s, and we
// are not able to log *then*.
buffering = true;
buffered.as_mut().reset(tokio::time::Instant::now() + BUFFER_FOR);
}
},
_ = &mut buffered, if buffering => {
buffering = false;
info!(remaining = join_set.len(), total, elapsed_ms = started_at.elapsed().as_millis(), "waiting for tenants to shutdown");
}
}
}
if panicked > 0 {
warn!(
panicked,
total, "observed panicks while shutting down tenants"
);
}
// caller will log how long we took
}
#[derive(thiserror::Error, Debug)]
pub(crate) enum UpsertLocationError {
#[error("Bad config request: {0}")]
@@ -1056,7 +927,8 @@ impl TenantManager {
// the tenant is inaccessible to the outside world while we are doing this, but that is sensible:
// the state is ill-defined while we're in transition. Transitions are async, but fast: we do
// not do significant I/O, and shutdowns should be prompt via cancellation tokens.
let mut slot_guard = tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)
let mut slot_guard = self
.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)
.map_err(|e| match e {
TenantSlotError::NotFound(_) => {
unreachable!("Called with mode Any")
@@ -1223,6 +1095,75 @@ impl TenantManager {
}
}
fn tenant_map_acquire_slot(
&self,
tenant_shard_id: &TenantShardId,
mode: TenantSlotAcquireMode,
) -> Result<SlotGuard, TenantSlotError> {
use TenantSlotAcquireMode::*;
METRICS.tenant_slot_writes.inc();
let mut locked = self.tenants.write().unwrap();
let span = tracing::info_span!("acquire_slot", tenant_id=%tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug());
let _guard = span.enter();
let m = match &mut *locked {
TenantsMap::Initializing => return Err(TenantMapError::StillInitializing.into()),
TenantsMap::ShuttingDown(_) => return Err(TenantMapError::ShuttingDown.into()),
TenantsMap::Open(m) => m,
};
use std::collections::btree_map::Entry;
let entry = m.entry(*tenant_shard_id);
match entry {
Entry::Vacant(v) => match mode {
MustExist => {
tracing::debug!("Vacant && MustExist: return NotFound");
Err(TenantSlotError::NotFound(*tenant_shard_id))
}
_ => {
let (completion, barrier) = utils::completion::channel();
let inserting = TenantSlot::InProgress(barrier);
METRICS.slot_inserted(&inserting);
v.insert(inserting);
tracing::debug!("Vacant, inserted InProgress");
Ok(SlotGuard::new(
*tenant_shard_id,
None,
completion,
&self.tenants,
))
}
},
Entry::Occupied(mut o) => {
// Apply mode-driven checks
match (o.get(), mode) {
(TenantSlot::InProgress(_), _) => {
tracing::debug!("Occupied, failing for InProgress");
Err(TenantSlotError::InProgress)
}
_ => {
// Happy case: the slot was not in any state that violated our mode
let (completion, barrier) = utils::completion::channel();
let in_progress = TenantSlot::InProgress(barrier);
METRICS.slot_inserted(&in_progress);
let old_value = o.insert(in_progress);
METRICS.slot_removed(&old_value);
tracing::debug!("Occupied, replaced with InProgress");
Ok(SlotGuard::new(
*tenant_shard_id,
Some(old_value),
completion,
&self.tenants,
))
}
}
}
}
}
/// Resetting a tenant is equivalent to detaching it, then attaching it again with the same
/// LocationConf that was last used to attach it. Optionally, the local file cache may be
/// dropped before re-attaching.
@@ -1239,7 +1180,8 @@ impl TenantManager {
drop_cache: bool,
ctx: &RequestContext,
) -> anyhow::Result<()> {
let mut slot_guard = tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
let mut slot_guard =
self.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
let Some(old_slot) = slot_guard.get_old_value() else {
anyhow::bail!("Tenant not found when trying to reset");
};
@@ -1388,7 +1330,8 @@ impl TenantManager {
Ok(())
}
let slot_guard = tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
let slot_guard =
self.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
match &slot_guard.old_value {
Some(TenantSlot::Attached(tenant)) => {
// Legacy deletion flow: the tenant remains attached, goes to Stopping state, and
@@ -1539,7 +1482,7 @@ impl TenantManager {
// Phase 2: Put the parent shard to InProgress and grab a reference to the parent Tenant
drop(tenant);
let mut parent_slot_guard =
tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
self.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;
let parent = match parent_slot_guard.get_old_value() {
Some(TenantSlot::Attached(t)) => t,
Some(TenantSlot::Secondary(_)) => anyhow::bail!("Tenant location in secondary mode"),
@@ -1843,7 +1786,145 @@ impl TenantManager {
pub(crate) async fn shutdown(&self) {
self.cancel.cancel();
shutdown_all_tenants0(self.tenants).await
self.shutdown_all_tenants0().await
}
async fn shutdown_all_tenants0(&self) {
let mut join_set = JoinSet::new();
#[cfg(all(debug_assertions, not(test)))]
{
// Check that our metrics properly tracked the size of the tenants map. This is a convenient location to check,
// as it happens implicitly at the end of tests etc.
let m = self.tenants.read().unwrap();
debug_assert_eq!(METRICS.slots_total(), m.len() as u64);
}
// Atomically, 1. create the shutdown tasks and 2. prevent creation of new tenants.
let (total_in_progress, total_attached) = {
let mut m = self.tenants.write().unwrap();
match &mut *m {
TenantsMap::Initializing => {
*m = TenantsMap::ShuttingDown(BTreeMap::default());
info!("tenants map is empty");
return;
}
TenantsMap::Open(tenants) => {
let mut shutdown_state = BTreeMap::new();
let mut total_in_progress = 0;
let mut total_attached = 0;
for (tenant_shard_id, v) in std::mem::take(tenants).into_iter() {
match v {
TenantSlot::Attached(t) => {
shutdown_state
.insert(tenant_shard_id, TenantSlot::Attached(t.clone()));
join_set.spawn(
async move {
let res = {
let (_guard, shutdown_progress) = completion::channel();
t.shutdown(shutdown_progress, ShutdownMode::FreezeAndFlush).await
};
if let Err(other_progress) = res {
// join the another shutdown in progress
other_progress.wait().await;
}
// we cannot afford per tenant logging here, because if s3 is degraded, we are
// going to log too many lines
debug!("tenant successfully stopped");
}
.instrument(info_span!("shutdown", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug())),
);
total_attached += 1;
}
TenantSlot::Secondary(state) => {
// We don't need to wait for this individually per-tenant: the
// downloader task will be waited on eventually, this cancel
// is just to encourage it to drop out if it is doing work
// for this tenant right now.
state.cancel.cancel();
shutdown_state
.insert(tenant_shard_id, TenantSlot::Secondary(state));
}
TenantSlot::InProgress(notify) => {
// InProgress tenants are not visible in TenantsMap::ShuttingDown: we will
// wait for their notifications to fire in this function.
join_set.spawn(async move {
notify.wait().await;
});
total_in_progress += 1;
}
}
}
*m = TenantsMap::ShuttingDown(shutdown_state);
(total_in_progress, total_attached)
}
TenantsMap::ShuttingDown(_) => {
error!(
"already shutting down, this function isn't supposed to be called more than once"
);
return;
}
}
};
let started_at = std::time::Instant::now();
info!(
"Waiting for {} InProgress tenants and {} Attached tenants to shut down",
total_in_progress, total_attached
);
let total = join_set.len();
let mut panicked = 0;
let mut buffering = true;
const BUFFER_FOR: std::time::Duration = std::time::Duration::from_millis(500);
let mut buffered = std::pin::pin!(tokio::time::sleep(BUFFER_FOR));
while !join_set.is_empty() {
tokio::select! {
Some(joined) = join_set.join_next() => {
match joined {
Ok(()) => {},
Err(join_error) if join_error.is_cancelled() => {
unreachable!("we are not cancelling any of the tasks");
}
Err(join_error) if join_error.is_panic() => {
// cannot really do anything, as this panic is likely a bug
panicked += 1;
}
Err(join_error) => {
warn!("unknown kind of JoinError: {join_error}");
}
}
if !buffering {
// buffer so that every 500ms since the first update (or starting) we'll log
// how far away we are; this is because we will get SIGKILL'd at 10s, and we
// are not able to log *then*.
buffering = true;
buffered.as_mut().reset(tokio::time::Instant::now() + BUFFER_FOR);
}
},
_ = &mut buffered, if buffering => {
buffering = false;
info!(remaining = join_set.len(), total, elapsed_ms = started_at.elapsed().as_millis(), "waiting for tenants to shutdown");
}
}
}
if panicked > 0 {
warn!(
panicked,
total, "observed panicks while shutting down tenants"
);
}
// caller will log how long we took
}
/// Detaches a tenant, and removes its local files asynchronously.
@@ -1889,12 +1970,12 @@ impl TenantManager {
.map(Some)
};
let mut removal_result = remove_tenant_from_memory(
self.tenants,
tenant_shard_id,
tenant_dir_rename_operation(tenant_shard_id),
)
.await;
let mut removal_result = self
.remove_tenant_from_memory(
tenant_shard_id,
tenant_dir_rename_operation(tenant_shard_id),
)
.await;
// If the tenant was not found, it was likely already removed. Attempt to remove the tenant
// directory on disk anyway. For example, during shard splits, we shut down and remove the
@@ -1948,17 +2029,16 @@ impl TenantManager {
) -> Result<HashSet<TimelineId>, detach_ancestor::Error> {
use detach_ancestor::Error;
let slot_guard =
tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::MustExist).map_err(
|e| {
use TenantSlotError::*;
let slot_guard = self
.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::MustExist)
.map_err(|e| {
use TenantSlotError::*;
match e {
MapState(TenantMapError::ShuttingDown) => Error::ShuttingDown,
NotFound(_) | InProgress | MapState(_) => Error::DetachReparent(e.into()),
}
},
)?;
match e {
MapState(TenantMapError::ShuttingDown) => Error::ShuttingDown,
NotFound(_) | InProgress | MapState(_) => Error::DetachReparent(e.into()),
}
})?;
let tenant = {
let old_slot = slot_guard
@@ -2291,6 +2371,80 @@ impl TenantManager {
other => ApiError::InternalServerError(anyhow::anyhow!(other)),
})
}
/// Stops and removes the tenant from memory, if it's not [`TenantState::Stopping`] already, bails otherwise.
/// Allows to remove other tenant resources manually, via `tenant_cleanup`.
/// If the cleanup fails, tenant will stay in memory in [`TenantState::Broken`] state, and another removal
async fn remove_tenant_from_memory<V, F>(
&self,
tenant_shard_id: TenantShardId,
tenant_cleanup: F,
) -> Result<V, TenantStateError>
where
F: std::future::Future<Output = anyhow::Result<V>>,
{
let mut slot_guard =
self.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::MustExist)?;
// allow pageserver shutdown to await for our completion
let (_guard, progress) = completion::channel();
// The SlotGuard allows us to manipulate the Tenant object without fear of some
// concurrent API request doing something else for the same tenant ID.
let attached_tenant = match slot_guard.get_old_value() {
Some(TenantSlot::Attached(tenant)) => {
// whenever we remove a tenant from memory, we don't want to flush and wait for upload
let shutdown_mode = ShutdownMode::Hard;
// shutdown is sure to transition tenant to stopping, and wait for all tasks to complete, so
// that we can continue safely to cleanup.
match tenant.shutdown(progress, shutdown_mode).await {
Ok(()) => {}
Err(_other) => {
// if pageserver shutdown or other detach/ignore is already ongoing, we don't want to
// wait for it but return an error right away because these are distinct requests.
slot_guard.revert();
return Err(TenantStateError::IsStopping(tenant_shard_id));
}
}
Some(tenant)
}
Some(TenantSlot::Secondary(secondary_state)) => {
tracing::info!("Shutting down in secondary mode");
secondary_state.shutdown().await;
None
}
Some(TenantSlot::InProgress(_)) => {
// Acquiring a slot guarantees its old value was not InProgress
unreachable!();
}
None => None,
};
match tenant_cleanup
.await
.with_context(|| format!("Failed to run cleanup for tenant {tenant_shard_id}"))
{
Ok(hook_value) => {
// Success: drop the old TenantSlot::Attached.
slot_guard
.drop_old_value()
.expect("We just called shutdown");
Ok(hook_value)
}
Err(e) => {
// If we had a Tenant, set it to Broken and put it back in the TenantsMap
if let Some(attached_tenant) = attached_tenant {
attached_tenant.set_broken(e.to_string()).await;
}
// Leave the broken tenant in the map
slot_guard.revert();
Err(TenantStateError::Other(e))
}
}
}
}
#[derive(Debug, thiserror::Error)]
@@ -2455,7 +2609,7 @@ pub(crate) enum TenantMapError {
/// this tenant to retry later, or wait for the InProgress state to end.
///
/// This structure enforces the important invariant that we do not have overlapping
/// tasks that will try use local storage for a the same tenant ID: we enforce that
/// tasks that will try to use local storage for a the same tenant ID: we enforce that
/// the previous contents of a slot have been shut down before the slot can be
/// left empty or used for something else
///
@@ -2468,7 +2622,7 @@ pub(crate) enum TenantMapError {
/// The `old_value` may be dropped before the SlotGuard is dropped, by calling
/// `drop_old_value`. It is an error to call this without shutting down
/// the conents of `old_value`.
pub(crate) struct SlotGuard {
pub(crate) struct SlotGuard<'a> {
tenant_shard_id: TenantShardId,
old_value: Option<TenantSlot>,
upserted: bool,
@@ -2476,19 +2630,23 @@ pub(crate) struct SlotGuard {
/// [`TenantSlot::InProgress`] carries the corresponding Barrier: it will
/// release any waiters as soon as this SlotGuard is dropped.
completion: utils::completion::Completion,
tenants: &'a std::sync::RwLock<TenantsMap>,
}
impl SlotGuard {
impl<'a> SlotGuard<'a> {
fn new(
tenant_shard_id: TenantShardId,
old_value: Option<TenantSlot>,
completion: utils::completion::Completion,
tenants: &'a std::sync::RwLock<TenantsMap>,
) -> Self {
Self {
tenant_shard_id,
old_value,
upserted: false,
completion,
tenants,
}
}
@@ -2512,8 +2670,8 @@ impl SlotGuard {
));
}
let replaced = {
let mut locked = TENANTS.write().unwrap();
let replaced: Option<TenantSlot> = {
let mut locked = self.tenants.write().unwrap();
if let TenantSlot::InProgress(_) = new_value {
// It is never expected to try and upsert InProgress via this path: it should
@@ -2621,7 +2779,7 @@ impl SlotGuard {
}
}
impl Drop for SlotGuard {
impl<'a> Drop for SlotGuard<'a> {
fn drop(&mut self) {
if self.upserted {
return;
@@ -2629,7 +2787,7 @@ impl Drop for SlotGuard {
// Our old value is already shutdown, or it never existed: it is safe
// for us to fully release the TenantSlot back into an empty state
let mut locked = TENANTS.write().unwrap();
let mut locked = self.tenants.write().unwrap();
let m = match &mut *locked {
TenantsMap::Initializing => {
@@ -2711,151 +2869,6 @@ enum TenantSlotAcquireMode {
MustExist,
}
fn tenant_map_acquire_slot(
tenant_shard_id: &TenantShardId,
mode: TenantSlotAcquireMode,
) -> Result<SlotGuard, TenantSlotError> {
tenant_map_acquire_slot_impl(tenant_shard_id, &TENANTS, mode)
}
fn tenant_map_acquire_slot_impl(
tenant_shard_id: &TenantShardId,
tenants: &std::sync::RwLock<TenantsMap>,
mode: TenantSlotAcquireMode,
) -> Result<SlotGuard, TenantSlotError> {
use TenantSlotAcquireMode::*;
METRICS.tenant_slot_writes.inc();
let mut locked = tenants.write().unwrap();
let span = tracing::info_span!("acquire_slot", tenant_id=%tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug());
let _guard = span.enter();
let m = match &mut *locked {
TenantsMap::Initializing => return Err(TenantMapError::StillInitializing.into()),
TenantsMap::ShuttingDown(_) => return Err(TenantMapError::ShuttingDown.into()),
TenantsMap::Open(m) => m,
};
use std::collections::btree_map::Entry;
let entry = m.entry(*tenant_shard_id);
match entry {
Entry::Vacant(v) => match mode {
MustExist => {
tracing::debug!("Vacant && MustExist: return NotFound");
Err(TenantSlotError::NotFound(*tenant_shard_id))
}
_ => {
let (completion, barrier) = utils::completion::channel();
let inserting = TenantSlot::InProgress(barrier);
METRICS.slot_inserted(&inserting);
v.insert(inserting);
tracing::debug!("Vacant, inserted InProgress");
Ok(SlotGuard::new(*tenant_shard_id, None, completion))
}
},
Entry::Occupied(mut o) => {
// Apply mode-driven checks
match (o.get(), mode) {
(TenantSlot::InProgress(_), _) => {
tracing::debug!("Occupied, failing for InProgress");
Err(TenantSlotError::InProgress)
}
_ => {
// Happy case: the slot was not in any state that violated our mode
let (completion, barrier) = utils::completion::channel();
let in_progress = TenantSlot::InProgress(barrier);
METRICS.slot_inserted(&in_progress);
let old_value = o.insert(in_progress);
METRICS.slot_removed(&old_value);
tracing::debug!("Occupied, replaced with InProgress");
Ok(SlotGuard::new(
*tenant_shard_id,
Some(old_value),
completion,
))
}
}
}
}
}
/// Stops and removes the tenant from memory, if it's not [`TenantState::Stopping`] already, bails otherwise.
/// Allows to remove other tenant resources manually, via `tenant_cleanup`.
/// If the cleanup fails, tenant will stay in memory in [`TenantState::Broken`] state, and another removal
/// operation would be needed to remove it.
async fn remove_tenant_from_memory<V, F>(
tenants: &std::sync::RwLock<TenantsMap>,
tenant_shard_id: TenantShardId,
tenant_cleanup: F,
) -> Result<V, TenantStateError>
where
F: std::future::Future<Output = anyhow::Result<V>>,
{
let mut slot_guard =
tenant_map_acquire_slot_impl(&tenant_shard_id, tenants, TenantSlotAcquireMode::MustExist)?;
// allow pageserver shutdown to await for our completion
let (_guard, progress) = completion::channel();
// The SlotGuard allows us to manipulate the Tenant object without fear of some
// concurrent API request doing something else for the same tenant ID.
let attached_tenant = match slot_guard.get_old_value() {
Some(TenantSlot::Attached(tenant)) => {
// whenever we remove a tenant from memory, we don't want to flush and wait for upload
let shutdown_mode = ShutdownMode::Hard;
// shutdown is sure to transition tenant to stopping, and wait for all tasks to complete, so
// that we can continue safely to cleanup.
match tenant.shutdown(progress, shutdown_mode).await {
Ok(()) => {}
Err(_other) => {
// if pageserver shutdown or other detach/ignore is already ongoing, we don't want to
// wait for it but return an error right away because these are distinct requests.
slot_guard.revert();
return Err(TenantStateError::IsStopping(tenant_shard_id));
}
}
Some(tenant)
}
Some(TenantSlot::Secondary(secondary_state)) => {
tracing::info!("Shutting down in secondary mode");
secondary_state.shutdown().await;
None
}
Some(TenantSlot::InProgress(_)) => {
// Acquiring a slot guarantees its old value was not InProgress
unreachable!();
}
None => None,
};
match tenant_cleanup
.await
.with_context(|| format!("Failed to run cleanup for tenant {tenant_shard_id}"))
{
Ok(hook_value) => {
// Success: drop the old TenantSlot::Attached.
slot_guard
.drop_old_value()
.expect("We just called shutdown");
Ok(hook_value)
}
Err(e) => {
// If we had a Tenant, set it to Broken and put it back in the TenantsMap
if let Some(attached_tenant) = attached_tenant {
attached_tenant.set_broken(e.to_string()).await;
}
// Leave the broken tenant in the map
slot_guard.revert();
Err(TenantStateError::Other(e))
}
}
}
use http_utils::error::ApiError;
use pageserver_api::models::TimelineGcRequest;
@@ -2866,11 +2879,15 @@ mod tests {
use std::collections::BTreeMap;
use std::sync::Arc;
use storage_broker::BrokerClientChannel;
use tracing::Instrument;
use super::super::harness::TenantHarness;
use super::TenantsMap;
use crate::tenant::mgr::TenantSlot;
use crate::tenant::{
TenantSharedResources,
mgr::{BackgroundPurges, TenantManager, TenantSlot},
};
#[tokio::test(start_paused = true)]
async fn shutdown_awaits_in_progress_tenant() {
@@ -2891,23 +2908,47 @@ mod tests {
let _e = span.enter();
let tenants = BTreeMap::from([(id, TenantSlot::Attached(t.clone()))]);
let tenants = Arc::new(std::sync::RwLock::new(TenantsMap::Open(tenants)));
// Invoke remove_tenant_from_memory with a cleanup hook that blocks until we manually
// permit it to proceed: that will stick the tenant in InProgress
let (basebackup_prepare_sender, _) = tokio::sync::mpsc::unbounded_channel::<
crate::basebackup_cache::BasebackupPrepareRequest,
>();
let tenant_manager = TenantManager {
tenants: std::sync::RwLock::new(TenantsMap::Open(tenants)),
conf: h.conf,
resources: TenantSharedResources {
broker_client: BrokerClientChannel::connect_lazy("foobar.com")
.await
.unwrap(),
remote_storage: h.remote_storage.clone(),
deletion_queue_client: h.deletion_queue.new_client(),
l0_flush_global_state: crate::l0_flush::L0FlushGlobalState::new(
h.conf.l0_flush.clone(),
),
basebackup_prepare_sender,
feature_resolver: crate::feature_resolver::FeatureResolver::new_disabled(),
},
cancel: tokio_util::sync::CancellationToken::new(),
background_purges: BackgroundPurges::default(),
};
let tenant_manager = Arc::new(tenant_manager);
let (until_cleanup_completed, can_complete_cleanup) = utils::completion::channel();
let (until_cleanup_started, cleanup_started) = utils::completion::channel();
let mut remove_tenant_from_memory_task = {
let tenant_manager = tenant_manager.clone();
let jh = tokio::spawn({
let tenants = tenants.clone();
async move {
let cleanup = async move {
drop(until_cleanup_started);
can_complete_cleanup.wait().await;
anyhow::Ok(())
};
super::remove_tenant_from_memory(&tenants, id, cleanup).await
tenant_manager.remove_tenant_from_memory(id, cleanup).await
}
.instrument(h.span())
});
@@ -2920,9 +2961,11 @@ mod tests {
let mut shutdown_task = {
let (until_shutdown_started, shutdown_started) = utils::completion::channel();
let tenant_manager = tenant_manager.clone();
let shutdown_task = tokio::spawn(async move {
drop(until_shutdown_started);
super::shutdown_all_tenants0(&tenants).await;
tenant_manager.shutdown_all_tenants0().await;
});
shutdown_started.wait().await;

View File

@@ -1092,13 +1092,15 @@ communicator_prefetch_register_bufferv(BufferTag tag, neon_request_lsns *frlsns,
MyPState->ring_last <= ring_index);
}
/* internal version. Returns the ring index */
/* Internal version. Returns the ring index of the last block (result of this function is used only
* when nblocks==1)
*/
static uint64
prefetch_register_bufferv(BufferTag tag, neon_request_lsns *frlsns,
BlockNumber nblocks, const bits8 *mask,
bool is_prefetch)
{
uint64 min_ring_index;
uint64 last_ring_index;
PrefetchRequest hashkey;
#ifdef USE_ASSERT_CHECKING
bool any_hits = false;
@@ -1122,13 +1124,12 @@ Retry:
MyPState->ring_unused - MyPState->ring_receive;
MyNeonCounters->getpage_prefetches_buffered =
MyPState->n_responses_buffered;
last_ring_index = UINT64_MAX;
min_ring_index = UINT64_MAX;
for (int i = 0; i < nblocks; i++)
{
PrefetchRequest *slot = NULL;
PrfHashEntry *entry = NULL;
uint64 ring_index;
neon_request_lsns *lsns;
if (PointerIsValid(mask) && BITMAP_ISSET(mask, i))
@@ -1152,12 +1153,12 @@ Retry:
if (entry != NULL)
{
slot = entry->slot;
ring_index = slot->my_ring_index;
Assert(slot == GetPrfSlot(ring_index));
last_ring_index = slot->my_ring_index;
Assert(slot == GetPrfSlot(last_ring_index));
Assert(slot->status != PRFS_UNUSED);
Assert(MyPState->ring_last <= ring_index &&
ring_index < MyPState->ring_unused);
Assert(MyPState->ring_last <= last_ring_index &&
last_ring_index < MyPState->ring_unused);
Assert(BufferTagsEqual(&slot->buftag, &hashkey.buftag));
/*
@@ -1169,9 +1170,9 @@ Retry:
if (!neon_prefetch_response_usable(lsns, slot))
{
/* Wait for the old request to finish and discard it */
if (!prefetch_wait_for(ring_index))
if (!prefetch_wait_for(last_ring_index))
goto Retry;
prefetch_set_unused(ring_index);
prefetch_set_unused(last_ring_index);
entry = NULL;
slot = NULL;
pgBufferUsage.prefetch.expired += 1;
@@ -1188,13 +1189,12 @@ Retry:
*/
if (slot->status == PRFS_TAG_REMAINS)
{
prefetch_set_unused(ring_index);
prefetch_set_unused(last_ring_index);
entry = NULL;
slot = NULL;
}
else
{
min_ring_index = Min(min_ring_index, ring_index);
/* The buffered request is good enough, return that index */
if (is_prefetch)
pgBufferUsage.prefetch.duplicates++;
@@ -1283,12 +1283,12 @@ Retry:
* The next buffer pointed to by `ring_unused` is now definitely empty, so
* we can insert the new request to it.
*/
ring_index = MyPState->ring_unused;
last_ring_index = MyPState->ring_unused;
Assert(MyPState->ring_last <= ring_index &&
ring_index <= MyPState->ring_unused);
Assert(MyPState->ring_last <= last_ring_index &&
last_ring_index <= MyPState->ring_unused);
slot = GetPrfSlotNoCheck(ring_index);
slot = GetPrfSlotNoCheck(last_ring_index);
Assert(slot->status == PRFS_UNUSED);
@@ -1298,11 +1298,9 @@ Retry:
*/
slot->buftag = hashkey.buftag;
slot->shard_no = get_shard_number(&tag);
slot->my_ring_index = ring_index;
slot->my_ring_index = last_ring_index;
slot->flags = 0;
min_ring_index = Min(min_ring_index, ring_index);
if (is_prefetch)
MyNeonCounters->getpage_prefetch_requests_total++;
else
@@ -1315,11 +1313,12 @@ Retry:
MyPState->ring_unused - MyPState->ring_receive;
Assert(any_hits);
Assert(last_ring_index != UINT64_MAX);
Assert(GetPrfSlot(min_ring_index)->status == PRFS_REQUESTED ||
GetPrfSlot(min_ring_index)->status == PRFS_RECEIVED);
Assert(MyPState->ring_last <= min_ring_index &&
min_ring_index < MyPState->ring_unused);
Assert(GetPrfSlot(last_ring_index)->status == PRFS_REQUESTED ||
GetPrfSlot(last_ring_index)->status == PRFS_RECEIVED);
Assert(MyPState->ring_last <= last_ring_index &&
last_ring_index < MyPState->ring_unused);
if (flush_every_n_requests > 0 &&
MyPState->ring_unused - MyPState->ring_flush >= flush_every_n_requests)
@@ -1335,7 +1334,7 @@ Retry:
MyPState->ring_flush = MyPState->ring_unused;
}
return min_ring_index;
return last_ring_index;
}
static bool

View File

@@ -1135,7 +1135,7 @@ VotesCollectedMset(WalProposer *wp, MemberSet *mset, Safekeeper **msk, StringInf
wp->propTermStartLsn = sk->voteResponse.flushLsn;
wp->donor = sk;
}
wp->truncateLsn = Max(wp->safekeeper[i].voteResponse.truncateLsn, wp->truncateLsn);
wp->truncateLsn = Max(sk->voteResponse.truncateLsn, wp->truncateLsn);
if (n_votes > 0)
appendStringInfoString(s, ", ");

View File

@@ -14,9 +14,9 @@ use crate::context::RequestContext;
use crate::control_plane::client::cplane_proxy_v1;
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::error::{ReportableError, UserFacingError};
use crate::pglb::connect_compute::ComputeConnectBackend;
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::proxy::wake_compute::WakeComputeBackend;
use crate::stream::PqStream;
use crate::types::RoleName;
use crate::{auth, compute, waiters};
@@ -109,7 +109,7 @@ impl ConsoleRedirectBackend {
pub struct ConsoleRedirectNodeInfo(pub(super) NodeInfo);
#[async_trait]
impl ComputeConnectBackend for ConsoleRedirectNodeInfo {
impl WakeComputeBackend for ConsoleRedirectNodeInfo {
async fn wake_compute(
&self,
_ctx: &RequestContext,

View File

@@ -25,9 +25,9 @@ use crate::control_plane::{
RoleAccessControl,
};
use crate::intern::EndpointIdInt;
use crate::pglb::connect_compute::ComputeConnectBackend;
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::proxy::wake_compute::WakeComputeBackend;
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::Stream;
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
@@ -407,13 +407,13 @@ impl Backend<'_, ComputeUserInfo> {
}
#[async_trait::async_trait]
impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
impl WakeComputeBackend for Backend<'_, ComputeUserInfo> {
async fn wake_compute(
&self,
ctx: &RequestContext,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
match self {
Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await,
Self::ControlPlane(api, info) => api.wake_compute(ctx, info).await,
Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())),
}
}

View File

@@ -26,12 +26,12 @@ use utils::sentry_init::init_sentry;
use crate::context::RequestContext;
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::pglb::TlsRequired;
use crate::pqproto::FeStartupPacket;
use crate::protocol2::ConnectionInfo;
use crate::proxy::{
ErrorSource, TlsRequired, copy_bidirectional_client_compute, run_until_cancelled,
};
use crate::proxy::{ErrorSource, copy_bidirectional_client_compute};
use crate::stream::{PqStream, Stream};
use crate::util::run_until_cancelled;
project_git_version!(GIT_VERSION);

View File

@@ -410,7 +410,7 @@ pub async fn run() -> anyhow::Result<()> {
match auth_backend {
Either::Left(auth_backend) => {
if let Some(proxy_listener) = proxy_listener {
client_tasks.spawn(crate::proxy::task_main(
client_tasks.spawn(crate::pglb::task_main(
config,
auth_backend,
proxy_listener,

View File

@@ -103,6 +103,8 @@ pub enum Auth {
}
/// A config for authenticating to the compute node.
// TODO: avoid Clone
#[derive(Clone)]
pub(crate) struct AuthInfo {
/// None for local-proxy, as we use trust-based localhost auth.
/// Some for sql-over-http, ws, tcp, and in most cases for console-redirect.
@@ -136,11 +138,11 @@ impl AuthInfo {
}
}
pub(crate) fn with_auth_keys(keys: &ComputeCredentialKeys) -> Self {
pub(crate) fn with_auth_keys(keys: ComputeCredentialKeys) -> Self {
Self {
auth: match keys {
ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
Some(Auth::Scram(Box::new(*auth_keys)))
Some(Auth::Scram(Box::new(auth_keys)))
}
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None,
},

View File

@@ -11,13 +11,13 @@ use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute};
use crate::pglb::handshake::{HandshakeData, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::pglb::{ClientRequestError, ErrorSource};
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::{
ClientRequestError, ErrorSource, prepare_client_connection, run_until_cancelled,
};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::prepare_client_connection;
use crate::util::run_until_cancelled;
pub async fn task_main(
config: &'static ProxyConfig,

View File

@@ -106,4 +106,5 @@ mod tls;
mod types;
mod url;
mod usage_metrics;
mod util;
mod waiters;

View File

@@ -8,10 +8,10 @@ use crate::config::TlsConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::pglb::TlsRequired;
use crate::pqproto::{
BeMessage, CancelKeyData, FeStartupPacket, ProtocolVersion, StartupMessageParams,
};
use crate::proxy::TlsRequired;
use crate::stream::{PqStream, Stream, StreamUpgradeError};
use crate::tls::PG_ALPN_PROTOCOL;

View File

@@ -1,5 +1,343 @@
pub mod connect_compute;
pub mod copy_bidirectional;
pub mod handshake;
pub mod inprocess;
pub mod passthrough;
use std::sync::Arc;
use futures::FutureExt;
use smol_str::ToSmolStr;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, warn};
use crate::auth;
use crate::cancellation::{self, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumClientConnectionsGuard};
pub use crate::pglb::copy_bidirectional::ErrorSource;
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::proxy::connect_compute::{ConnectMechanism, TcpMechanism};
use crate::proxy::handle_connect_request;
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::Stream;
use crate::util::run_until_cancelled;
pub const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
#[derive(Error, Debug)]
#[error("{ERR_INSECURE_CONNECTION}")]
pub struct TlsRequired;
impl ReportableError for TlsRequired {
fn get_error_kind(&self) -> crate::error::ErrorKind {
crate::error::ErrorKind::User
}
}
impl UserFacingError for TlsRequired {}
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
}
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
let conn_gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Tcp);
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
connections.spawn(async move {
let (socket, conn_info) = match config.proxy_protocol_v2 {
ProxyProtocolV2::Required => {
match read_proxy_protocol(socket).await {
Err(e) => {
warn!("per-client task finished with an error: {e:#}");
return;
}
// our load balancers will not send any more data. let's just exit immediately
Ok((_socket, ConnectHeader::Local)) => {
debug!("healthcheck received");
return;
}
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
}
}
// ignore the header - it cannot be confused for a postgres or http connection so will
// error later.
ProxyProtocolV2::Rejected => (
socket,
ConnectionInfo {
addr: peer_addr,
extra: None,
},
),
};
match socket.set_nodelay(true) {
Ok(()) => {}
Err(e) => {
error!(
"per-client task finished with an error: failed to set socket option: {e:#}"
);
return;
}
}
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Tcp,
&config.region,
);
let res = handle_client(
config,
auth_backend,
&ctx,
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter2,
conn_gauge,
cancellations,
)
.instrument(ctx.span())
.boxed()
.await;
match res {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(
?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
}
}
});
}
connections.close();
cancellations.close();
drop(listener);
// Drain connections
connections.wait().await;
cancellations.wait().await;
Ok(())
}
pub(crate) enum ClientMode {
Tcp,
Websockets { hostname: Option<String> },
}
/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
pub(crate) fn allow_cleartext(&self) -> bool {
match self {
ClientMode::Tcp => false,
ClientMode::Websockets { .. } => true,
}
}
pub(crate) fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
match self {
ClientMode::Tcp => s.sni_hostname(),
ClientMode::Websockets { hostname } => hostname.as_deref(),
}
}
fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
match self {
ClientMode::Tcp => tls,
// TLS is None here if using websockets, because the connection is already encrypted.
ClientMode::Websockets { .. } => None,
}
}
}
#[derive(Debug, Error)]
// almost all errors should be reported to the user, but there's a few cases where we cannot
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
// we cannot be sure the client even understands our error message
// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
pub(crate) enum ClientRequestError {
#[error("{0}")]
Cancellation(#[from] cancellation::CancelError),
#[error("{0}")]
Handshake(#[from] HandshakeError),
#[error("{0}")]
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
#[error("{0}")]
PrepareClient(#[from] std::io::Error),
#[error("{0}")]
ReportedError(#[from] crate::stream::ReportedError),
}
impl ReportableError for ClientRequestError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ClientRequestError::Cancellation(e) => e.get_error_kind(),
ClientRequestError::Handshake(e) => e.get_error_kind(),
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
ClientRequestError::ReportedError(e) => e.get_error_kind(),
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
}
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
ctx: &RequestContext,
cancellation_handler: Arc<CancellationHandler>,
client: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
);
let metrics = &Metrics::get().proxy;
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let tls = config.tls_config.load();
let tls = tls.as_deref();
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, client, mode.handshake_tls(tls), record_handshake_error);
let (client, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
HandshakeData::Startup(client, params) => (client, params),
HandshakeData::Cancel(cancel_key_data) => {
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx.clone();
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
async move {
cancellation_handler_clone
.cancel_session(
cancel_key_data,
ctx,
config.authentication_config.ip_allowlist_check_enabled,
config.authentication_config.is_vpc_acccess_proxy,
auth_backend.get_api(),
)
.await
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
}.instrument(cancel_span)
});
return Ok(None);
}
};
drop(pause);
ctx.set_db_options(params.clone());
let common_names = tls.map(|tls| &tls.common_names);
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};
let (node, client, session) = handle_connect_request(
config,
auth_backend,
ctx,
cancellation_handler,
client,
&mode,
endpoint_rate_limiter,
&params,
common_names,
async |config, ctx, node_info, auth_info, creds, compute_config| {
TcpMechanism {
auth: auth_info.clone(),
locks: &config.connect_compute_locks,
user_info: creds.info.clone(),
}
.connect_once(ctx, node_info, compute_config)
.await
},
)
.await?;
Ok(Some(ProxyPassthrough {
client,
aux: node.aux.clone(),
private_link_id,
compute: node,
session_id: ctx.session_id(),
cancel: session,
_req: request_gauge,
_conn: conn_gauge,
}))
}

View File

@@ -2,9 +2,9 @@ use async_trait::async_trait;
use tokio::time;
use tracing::{debug, info, warn};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
use crate::compute::{self, AuthInfo, COULD_NOT_CONNECT, PostgresConnection};
use crate::config::{ComputeConfig, RetryConfig};
use crate::config::{ComputeConfig, ProxyConfig, RetryConfig};
use crate::context::RequestContext;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::locks::ApiLocks;
@@ -14,13 +14,13 @@ use crate::metrics::{
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
};
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry};
use crate::proxy::wake_compute::wake_compute;
use crate::proxy::wake_compute::{WakeComputeBackend, wake_compute};
use crate::types::Host;
/// If we couldn't connect, a cached connection info might be to blame
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
#[tracing::instrument(name = "invalidate_cache", skip_all)]
#[tracing::instrument(skip_all)]
pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> NodeInfo {
let is_cached = node_info.cached();
if is_cached {
@@ -49,14 +49,6 @@ pub(crate) trait ConnectMechanism {
) -> Result<Self::Connection, Self::ConnectError>;
}
#[async_trait]
pub(crate) trait ComputeConnectBackend {
async fn wake_compute(
&self,
ctx: &RequestContext,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError>;
}
pub(crate) struct TcpMechanism {
pub(crate) auth: AuthInfo,
/// connect_to_compute concurrency lock
@@ -91,7 +83,7 @@ impl ConnectMechanism for TcpMechanism {
/// Try to connect to the compute node, retrying if necessary.
#[tracing::instrument(skip_all)]
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: WakeComputeBackend>(
ctx: &RequestContext,
mechanism: &M,
user_info: &B,
@@ -191,3 +183,114 @@ where
drop(pause);
}
}
#[tracing::instrument(skip_all)]
pub(crate) async fn connect_to_compute_pglb<
F: AsyncFn(
&'static ProxyConfig,
&RequestContext,
&CachedNodeInfo,
&AuthInfo,
&ComputeCredentials,
&ComputeConfig,
) -> Result<PostgresConnection, compute::ConnectionError>,
B: WakeComputeBackend,
>(
config: &'static ProxyConfig,
ctx: &RequestContext,
connect_compute_fn: F,
user_info: &B,
auth_info: &AuthInfo,
creds: &ComputeCredentials,
wake_compute_retry_config: RetryConfig,
compute: &ComputeConfig,
) -> Result<PostgresConnection, compute::ConnectionError> {
let mut num_retries = 0;
let node_info =
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
// try once
let err = match connect_compute_fn(config, ctx, &node_info, &auth_info, &creds, compute).await {
Ok(res) => {
ctx.success();
Metrics::get().proxy.retries_metric.observe(
RetriesMetricGroup {
outcome: ConnectOutcome::Success,
retry_type: RetryType::ConnectToCompute,
},
num_retries.into(),
);
return Ok(res);
}
Err(e) => e,
};
debug!(error = ?err, COULD_NOT_CONNECT);
let node_info = if !node_info.cached() || !err.should_retry_wake_compute() {
// If we just recieved this from cplane and didn't get it from cache, we shouldn't retry.
// Do not need to retrieve a new node_info, just return the old one.
if should_retry(&err, num_retries, compute.retry) {
Metrics::get().proxy.retries_metric.observe(
RetriesMetricGroup {
outcome: ConnectOutcome::Failed,
retry_type: RetryType::ConnectToCompute,
},
num_retries.into(),
);
return Err(err.into());
}
node_info
} else {
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
debug!("compute node's state has likely changed; requesting a wake-up");
invalidate_cache(node_info);
// TODO: increment num_retries?
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?
};
// now that we have a new node, try connect to it repeatedly.
// this can error for a few reasons, for instance:
// * DNS connection settings haven't quite propagated yet
debug!("wake_compute success. attempting to connect");
num_retries = 1;
loop {
match connect_compute_fn(config, ctx, &node_info, &auth_info, &creds, compute).await {
Ok(res) => {
ctx.success();
Metrics::get().proxy.retries_metric.observe(
RetriesMetricGroup {
outcome: ConnectOutcome::Success,
retry_type: RetryType::ConnectToCompute,
},
num_retries.into(),
);
// TODO: is this necessary? We have a metric.
info!(?num_retries, "connected to compute node after");
return Ok(res);
}
Err(e) => {
if !should_retry(&e, num_retries, compute.retry) {
// Don't log an error here, caller will print the error
Metrics::get().proxy.retries_metric.observe(
RetriesMetricGroup {
outcome: ConnectOutcome::Failed,
retry_type: RetryType::ConnectToCompute,
},
num_retries.into(),
);
return Err(e.into());
}
warn!(error = ?e, num_retries, retriable = true, COULD_NOT_CONNECT);
}
}
let wait_duration = retry_after(num_retries, compute.retry);
num_retries += 1;
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::RetryTimeout);
time::sleep(wait_duration).await;
drop(pause);
}
}

View File

@@ -1,327 +1,61 @@
#[cfg(test)]
mod tests;
pub(crate) mod connect_compute;
pub(crate) mod retry;
pub(crate) mod wake_compute;
use std::collections::HashSet;
use std::sync::Arc;
use futures::FutureExt;
use itertools::Itertools;
use once_cell::sync::OnceCell;
use regex::Regex;
use serde::{Deserialize, Serialize};
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
use thiserror::Error;
use smol_str::{SmolStr, format_smolstr};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, warn};
use tracing::Instrument;
use crate::cancellation::{self, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::auth::backend::ComputeCredentials;
use crate::cancellation::{CancellationHandler, Session};
use crate::compute::{AuthInfo, PostgresConnection};
use crate::config::{ComputeConfig, ProxyConfig};
use crate::context::RequestContext;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute};
use crate::control_plane::CachedNodeInfo;
pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
use crate::pglb::{ClientMode, ClientRequestError};
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::proxy::connect_compute::connect_to_compute_pglb;
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
use crate::types::EndpointCacheKey;
use crate::{auth, compute};
const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)";
#[derive(Error, Debug)]
#[error("{ERR_INSECURE_CONNECTION}")]
pub struct TlsRequired;
impl ReportableError for TlsRequired {
fn get_error_kind(&self) -> crate::error::ErrorKind {
crate::error::ErrorKind::User
}
}
impl UserFacingError for TlsRequired {}
pub async fn run_until_cancelled<F: std::future::Future>(
f: F,
cancellation_token: &CancellationToken,
) -> Option<F::Output> {
match futures::future::select(
std::pin::pin!(f),
std::pin::pin!(cancellation_token.cancelled()),
)
.await
{
futures::future::Either::Left((f, _)) => Some(f),
futures::future::Either::Right(((), _)) => None,
}
}
pub async fn task_main(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
listener: tokio::net::TcpListener,
cancellation_token: CancellationToken,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
scopeguard::defer! {
info!("proxy has shut down");
}
// When set for the server socket, the keepalive setting
// will be inherited by all accepted client sockets.
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
{
let (socket, peer_addr) = accept_result?;
let conn_gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Tcp);
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
connections.spawn(async move {
let (socket, conn_info) = match config.proxy_protocol_v2 {
ProxyProtocolV2::Required => {
match read_proxy_protocol(socket).await {
Err(e) => {
warn!("per-client task finished with an error: {e:#}");
return;
}
// our load balancers will not send any more data. let's just exit immediately
Ok((_socket, ConnectHeader::Local)) => {
debug!("healthcheck received");
return;
}
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
}
}
// ignore the header - it cannot be confused for a postgres or http connection so will
// error later.
ProxyProtocolV2::Rejected => (
socket,
ConnectionInfo {
addr: peer_addr,
extra: None,
},
),
};
match socket.set_nodelay(true) {
Ok(()) => {}
Err(e) => {
error!(
"per-client task finished with an error: failed to set socket option: {e:#}"
);
return;
}
}
let ctx = RequestContext::new(
session_id,
conn_info,
crate::metrics::Protocol::Tcp,
&config.region,
);
let res = handle_client(
config,
auth_backend,
&ctx,
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter2,
conn_gauge,
cancellations,
)
.instrument(ctx.span())
.boxed()
.await;
match res {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(
?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
}
}
});
}
connections.close();
cancellations.close();
drop(listener);
// Drain connections
connections.wait().await;
cancellations.wait().await;
Ok(())
}
pub(crate) enum ClientMode {
Tcp,
Websockets { hostname: Option<String> },
}
/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
pub(crate) fn allow_cleartext(&self) -> bool {
match self {
ClientMode::Tcp => false,
ClientMode::Websockets { .. } => true,
}
}
fn hostname<'a, S>(&'a self, s: &'a Stream<S>) -> Option<&'a str> {
match self {
ClientMode::Tcp => s.sni_hostname(),
ClientMode::Websockets { hostname } => hostname.as_deref(),
}
}
fn handshake_tls<'a>(&self, tls: Option<&'a TlsConfig>) -> Option<&'a TlsConfig> {
match self {
ClientMode::Tcp => tls,
// TLS is None here if using websockets, because the connection is already encrypted.
ClientMode::Websockets { .. } => None,
}
}
}
#[derive(Debug, Error)]
// almost all errors should be reported to the user, but there's a few cases where we cannot
// 1. Cancellation: we are not allowed to tell the client any cancellation statuses for security reasons
// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
// we cannot be sure the client even understands our error message
// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
pub(crate) enum ClientRequestError {
#[error("{0}")]
Cancellation(#[from] cancellation::CancelError),
#[error("{0}")]
Handshake(#[from] HandshakeError),
#[error("{0}")]
HandshakeTimeout(#[from] tokio::time::error::Elapsed),
#[error("{0}")]
PrepareClient(#[from] std::io::Error),
#[error("{0}")]
ReportedError(#[from] crate::stream::ReportedError),
}
impl ReportableError for ClientRequestError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
ClientRequestError::Cancellation(e) => e.get_error_kind(),
ClientRequestError::Handshake(e) => e.get_error_kind(),
ClientRequestError::HandshakeTimeout(_) => crate::error::ErrorKind::RateLimit,
ClientRequestError::ReportedError(e) => e.get_error_kind(),
ClientRequestError::PrepareClient(_) => crate::error::ErrorKind::ClientDisconnect,
}
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
pub(crate) async fn handle_connect_request<
S: AsyncRead + AsyncWrite + Unpin + Send,
F: AsyncFn(
&'static ProxyConfig,
&RequestContext,
&CachedNodeInfo,
&AuthInfo,
&ComputeCredentials,
&ComputeConfig,
) -> Result<PostgresConnection, compute::ConnectionError>,
>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
ctx: &RequestContext,
cancellation_handler: Arc<CancellationHandler>,
stream: S,
mode: ClientMode,
mut client: PqStream<Stream<S>>,
mode: &ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
);
let metrics = &Metrics::get().proxy;
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let tls = config.tls_config.load();
let tls = tls.as_deref();
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx.clone();
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
async move {
cancellation_handler_clone
.cancel_session(
cancel_key_data,
ctx,
config.authentication_config.ip_allowlist_check_enabled,
config.authentication_config.is_vpc_acccess_proxy,
auth_backend.get_api(),
)
.await
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
}.instrument(cancel_span)
});
return Ok(None);
}
};
drop(pause);
ctx.set_db_options(params.clone());
let hostname = mode.hostname(stream.get_ref());
let common_names = tls.map(|tls| &tls.common_names);
params: &StartupMessageParams,
common_names: Option<&HashSet<String>>,
connect_compute_fn: F,
) -> Result<(PostgresConnection, Stream<S>, Session), ClientRequestError> {
// TODO: to pglb
let hostname = mode.hostname(client.get_ref());
// Extract credentials which we're going to use for auth.
let result = auth_backend
@@ -331,14 +65,14 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
};
let user = user_info.get_user().to_owned();
let user_info = match user_info
.authenticate(
ctx,
&mut stream,
&mut client,
mode.allow_cleartext(),
&config.authentication_config,
endpoint_rate_limiter,
@@ -351,29 +85,28 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
return Err(stream
return Err(client
.throw_error(e, Some(ctx))
.instrument(params_span)
.await)?;
}
};
let creds = match &user_info {
auth::Backend::ControlPlane(_, creds) => creds,
let (cplane, creds) = match user_info {
auth::Backend::ControlPlane(cplane, creds) => (cplane, creds),
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
};
let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some();
let mut auth_info = compute::AuthInfo::with_auth_keys(&creds.keys);
let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys);
auth_info.set_startup_params(&params, params_compat);
let res = connect_to_compute(
let res = connect_to_compute_pglb(
config,
ctx,
&TcpMechanism {
user_info: creds.info.clone(),
auth: auth_info,
locks: &config.connect_compute_locks,
},
&user_info,
connect_compute_fn,
&auth::Backend::ControlPlane(cplane, creds.info),
&auth_info,
&creds,
config.wake_compute_retry_config,
&config.connect_to_compute,
)
@@ -381,32 +114,17 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
let node = match res {
Ok(node) => node,
Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?,
Err(e) => Err(client.throw_error(e, Some(ctx)).await)?,
};
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let session = cancellation_handler_clone.get_key();
session.write_cancel_key(node.cancel_closure.clone())?;
prepare_client_connection(&node, *session.key(), &mut stream);
let stream = stream.flush_and_into_inner().await?;
prepare_client_connection(&node, *session.key(), &mut client);
let client = client.flush_and_into_inner().await?;
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};
Ok(Some(ProxyPassthrough {
client: stream,
aux: node.aux.clone(),
private_link_id,
compute: node,
session_id: ctx.session_id(),
cancel: session,
_req: request_gauge,
_conn: conn_gauge,
}))
Ok((node, client, session))
}
/// Finish client connection initialization: confirm auth success, send params, etc.
@@ -441,7 +159,7 @@ impl NeonOptions {
// proxy options:
/// `PARAMS_COMPAT` allows opting in to forwarding all startup parameters from client to compute.
const PARAMS_COMPAT: &str = "proxy_params_compat";
pub const PARAMS_COMPAT: &str = "proxy_params_compat";
// cplane options:

View File

@@ -3,12 +3,13 @@
mod mitm;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, bail};
use async_trait::async_trait;
use http::StatusCode;
use postgres_client::config::{AuthKeys, ScramKeys, SslMode};
use postgres_client::config::SslMode;
use postgres_client::tls::{MakeTlsConnect, NoTls};
use retry::{ShouldRetryWakeCompute, retry_after};
use rstest::rstest;
@@ -19,19 +20,21 @@ use tracing_test::traced_test;
use super::retry::CouldRetry;
use super::*;
use crate::auth::backend::{
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned,
};
use crate::config::{ComputeConfig, RetryConfig};
use crate::auth::backend::{ComputeUserInfo, MaybeOwned};
use crate::config::{ComputeConfig, RetryConfig, TlsConfig};
use crate::context::RequestContext;
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
use crate::error::ErrorKind;
use crate::pglb::connect_compute::ConnectMechanism;
use crate::error::{ErrorKind, ReportableError};
use crate::pglb::ERR_INSECURE_CONNECTION;
use crate::pglb::handshake::{HandshakeData, handshake};
use crate::proxy::connect_compute::{ConnectMechanism, connect_to_compute};
use crate::stream::Stream;
use crate::tls::client_config::compute_client_config_with_certs;
use crate::tls::server_config::CertResolver;
use crate::types::{BranchId, EndpointId, ProjectId};
use crate::{sasl, scram};
use crate::{auth, sasl, scram};
/// Generate a set of TLS certificates: CA + server.
fn generate_certs(
@@ -575,19 +578,13 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn
fn helper_create_connect_info(
mechanism: &TestConnectMechanism,
) -> auth::Backend<'static, ComputeCredentials> {
) -> auth::Backend<'static, ComputeUserInfo> {
auth::Backend::ControlPlane(
MaybeOwned::Owned(ControlPlaneClient::Test(Box::new(mechanism.clone()))),
ComputeCredentials {
info: ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
},
keys: ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(ScramKeys {
client_key: [0; 32],
server_key: [0; 32],
})),
ComputeUserInfo {
endpoint: "endpoint".into(),
user: "user".into(),
options: NeonOptions::parse_options_raw(""),
},
)
}

View File

@@ -1,3 +1,4 @@
use async_trait::async_trait;
use tracing::{error, info};
use crate::config::RetryConfig;
@@ -8,7 +9,6 @@ use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailuresBreakdownGroup, Metrics, RetriesMetricGroup, RetryType,
};
use crate::pglb::connect_compute::ComputeConnectBackend;
use crate::proxy::retry::{retry_after, should_retry};
// Use macro to retain original callsite.
@@ -23,7 +23,12 @@ macro_rules! log_wake_compute_error {
};
}
pub(crate) async fn wake_compute<B: ComputeConnectBackend>(
#[async_trait]
pub(crate) trait WakeComputeBackend {
async fn wake_compute(&self, ctx: &RequestContext) -> Result<CachedNodeInfo, WakeComputeError>;
}
pub(crate) async fn wake_compute<B: WakeComputeBackend>(
num_retries: &mut u32,
ctx: &RequestContext,
api: &B,

View File

@@ -21,7 +21,7 @@ use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool};
use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client};
use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo};
use crate::auth::{self, AuthError};
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
@@ -34,7 +34,7 @@ use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::control_plane::locks::ApiLocks;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::EndpointIdInt;
use crate::pglb::connect_compute::ConnectMechanism;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
use crate::rate_limiter::EndpointRateLimiter;
use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
@@ -180,14 +180,15 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| keys);
crate::pglb::connect_compute::connect_to_compute(
let backend = self.auth_backend.as_ref().map(|()| keys.info);
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
keys: keys.keys,
},
&backend,
self.config.wake_compute_retry_config,
@@ -214,18 +215,15 @@ impl PoolingBackend {
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
debug!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| ComputeCredentials {
info: ComputeUserInfo {
user: conn_info.user_info.user.clone(),
endpoint: EndpointId::from(format!(
"{}{LOCAL_PROXY_SUFFIX}",
conn_info.user_info.endpoint.normalize()
)),
options: conn_info.user_info.options.clone(),
},
keys: crate::auth::backend::ComputeCredentialKeys::None,
let backend = self.auth_backend.as_ref().map(|()| ComputeUserInfo {
user: conn_info.user_info.user.clone(),
endpoint: EndpointId::from(format!(
"{}{LOCAL_PROXY_SUFFIX}",
conn_info.user_info.endpoint.normalize()
)),
options: conn_info.user_info.options.clone(),
});
crate::pglb::connect_compute::connect_to_compute(
crate::proxy::connect_compute::connect_to_compute(
ctx,
&HyperMechanism {
conn_id,
@@ -495,6 +493,7 @@ struct TokioMechanism {
pool: Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
keys: ComputeCredentialKeys,
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
@@ -520,6 +519,10 @@ impl ConnectMechanism for TokioMechanism {
.dbname(&self.conn_info.dbname)
.connect_timeout(compute_config.timeout);
if let ComputeCredentialKeys::AuthKeys(auth_keys) = self.keys {
config.auth_keys(auth_keys);
}
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(compute_config).await;
drop(pause);

View File

@@ -50,10 +50,10 @@ use crate::context::RequestContext;
use crate::ext::TaskExt;
use crate::metrics::Metrics;
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::run_until_cancelled;
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_util::{api_error_into_response, json_response};
use crate::util::run_until_cancelled;
pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";
pub(crate) const AUTH_BROKER_SNI: &str = "apiauth";

View File

@@ -41,10 +41,11 @@ use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::http::{ReadBodyError, read_body_with_limit};
use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
use crate::pqproto::StartupMessageParams;
use crate::proxy::{NeonOptions, run_until_cancelled};
use crate::proxy::NeonOptions;
use crate::serverless::backend::HttpConnError;
use crate::types::{DbName, RoleName};
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
use crate::util::run_until_cancelled;
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]

View File

@@ -17,7 +17,8 @@ use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::proxy::{ClientMode, ErrorSource, handle_client};
use crate::pglb::{ClientMode, handle_client};
use crate::proxy::ErrorSource;
use crate::rate_limiter::EndpointRateLimiter;
pin_project! {

14
proxy/src/util.rs Normal file
View File

@@ -0,0 +1,14 @@
use std::pin::pin;
use futures::future::{Either, select};
use tokio_util::sync::CancellationToken;
pub async fn run_until_cancelled<F: Future>(
f: F,
cancellation_token: &CancellationToken,
) -> Option<F::Output> {
match select(pin!(f), pin!(cancellation_token.cancelled())).await {
Either::Left((f, _)) => Some(f),
Either::Right(((), _)) => None,
}
}

View File

@@ -4046,6 +4046,16 @@ def static_proxy(
"CREATE TABLE neon_control_plane.endpoints (endpoint_id VARCHAR(255) PRIMARY KEY, allowed_ips VARCHAR(255))"
)
vanilla_pg.stop()
vanilla_pg.edit_hba(
[
"local all all trust",
"host all all 127.0.0.1/32 scram-sha-256",
"host all all ::1/128 scram-sha-256",
]
)
vanilla_pg.start()
proxy_port = port_distributor.get_port()
mgmt_port = port_distributor.get_port()
http_port = port_distributor.get_port()

View File

@@ -19,11 +19,15 @@ TABLE_NAME = "neon_control_plane.endpoints"
async def test_proxy_psql_allowed_ips(static_proxy: NeonProxy, vanilla_pg: VanillaPostgres):
# Shouldn't be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('private-project', '8.8.8.8')"
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('private-project', '8.8.8.8')",
user="proxy",
password="password",
)
# Should be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('generic-project', '::1,127.0.0.1')"
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('generic-project', '::1,127.0.0.1')",
user="proxy",
password="password",
)
def check_cannot_connect(**kwargs):
@@ -60,7 +64,9 @@ async def test_proxy_http_allowed_ips(static_proxy: NeonProxy, vanilla_pg: Vanil
# Shouldn't be able to connect to this project
vanilla_pg.safe_psql(
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('proxy', '8.8.8.8')"
f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('proxy', '8.8.8.8')",
user="proxy",
password="password",
)
def query(status: int, query: str, *args):
@@ -75,6 +81,8 @@ async def test_proxy_http_allowed_ips(static_proxy: NeonProxy, vanilla_pg: Vanil
query(400, "select 1;") # ip address is not allowed
# Should be able to connect to this project
vanilla_pg.safe_psql(
f"UPDATE {TABLE_NAME} SET allowed_ips = '8.8.8.8,127.0.0.1' WHERE endpoint_id = 'proxy'"
f"UPDATE {TABLE_NAME} SET allowed_ips = '8.8.8.8,127.0.0.1' WHERE endpoint_id = 'proxy'",
user="proxy",
password="password",
)
query(200, "select 1;") # should work now