Compare commits

..

3 Commits

Author SHA1 Message Date
Erik Grinaker
90e3313dbb wip 2025-06-01 13:33:58 +02:00
Erik Grinaker
232591e457 Fix test build 2025-05-29 12:00:40 +02:00
Erik Grinaker
8daf272561 pageserver: initial gRPC page service implementation 2025-05-28 18:10:29 +02:00
40 changed files with 1220 additions and 1190 deletions

2
Cargo.lock generated
View File

@@ -4305,6 +4305,7 @@ dependencies = [
"hashlink",
"hex",
"hex-literal",
"http 1.1.0",
"http-utils",
"humantime",
"humantime-serde",
@@ -4367,6 +4368,7 @@ dependencies = [
"toml_edit",
"tonic 0.13.1",
"tonic-reflection",
"tower 0.5.2",
"tracing",
"tracing-utils",
"twox-hash",

View File

@@ -354,9 +354,6 @@ pub struct ShardImportProgressV1 {
pub completed: usize,
/// Hash of the plan
pub import_plan_hash: u64,
/// Soft limit for the job size
/// This needs to remain constant throughout the import
pub job_soft_size_limit: usize,
}
impl ShardImportStatus {
@@ -1934,7 +1931,7 @@ pub enum PagestreamFeMessage {
}
// Wrapped in libpq CopyData
#[derive(strum_macros::EnumProperty)]
#[derive(Debug, strum_macros::EnumProperty)]
pub enum PagestreamBeMessage {
Exists(PagestreamExistsResponse),
Nblocks(PagestreamNblocksResponse),

View File

@@ -34,6 +34,7 @@ fail.workspace = true
futures.workspace = true
hashlink.workspace = true
hex.workspace = true
http.workspace = true
http-utils.workspace = true
humantime-serde.workspace = true
humantime.workspace = true
@@ -93,6 +94,7 @@ tokio-util.workspace = true
toml_edit = { workspace = true, features = [ "serde" ] }
tonic.workspace = true
tonic-reflection.workspace = true
tower.workspace = true
tracing.workspace = true
tracing-utils.workspace = true
url.workspace = true

View File

@@ -584,6 +584,7 @@ impl TryFrom<GetSlruSegmentResponse> for proto::GetSlruSegmentResponse {
type Error = ProtocolError;
fn try_from(segment: GetSlruSegmentResponse) -> Result<Self, Self::Error> {
// TODO: can a segment legitimately be empty?
if segment.is_empty() {
return Err(ProtocolError::Missing("segment"));
}

View File

@@ -804,7 +804,7 @@ fn start_pageserver(
} else {
None
},
basebackup_cache.clone(),
basebackup_cache,
);
// Spawn a Pageserver gRPC server task. It will spawn separate tasks for
@@ -816,12 +816,10 @@ fn start_pageserver(
let mut page_service_grpc = None;
if let Some(grpc_listener) = grpc_listener {
page_service_grpc = Some(page_service::spawn_grpc(
conf,
tenant_manager.clone(),
grpc_auth,
otel_guard.as_ref().map(|g| g.dispatch.clone()),
grpc_listener,
basebackup_cache,
)?);
}

View File

@@ -837,30 +837,7 @@ async fn collect_eviction_candidates(
continue;
}
let info = tl.get_local_layers_for_disk_usage_eviction().await;
debug!(
tenant_id=%tl.tenant_shard_id.tenant_id,
shard_id=%tl.tenant_shard_id.shard_slug(),
timeline_id=%tl.timeline_id,
"timeline resident layers count: {}", info.resident_layers.len()
);
tenant_candidates.extend(info.resident_layers.into_iter());
max_layer_size = max_layer_size.max(info.max_layer_size.unwrap_or(0));
if cancel.is_cancelled() {
return Ok(EvictionCandidates::Cancelled);
}
}
// Also consider layers of timelines being imported for eviction
for tl in tenant.list_importing_timelines() {
let info = tl.timeline.get_local_layers_for_disk_usage_eviction().await;
debug!(
tenant_id=%tl.timeline.tenant_shard_id.tenant_id,
shard_id=%tl.timeline.tenant_shard_id.shard_slug(),
timeline_id=%tl.timeline.timeline_id,
"timeline resident layers count: {}", info.resident_layers.len()
);
debug!(tenant_id=%tl.tenant_shard_id.tenant_id, shard_id=%tl.tenant_shard_id.shard_slug(), timeline_id=%tl.timeline_id, "timeline resident layers count: {}", info.resident_layers.len());
tenant_candidates.extend(info.resident_layers.into_iter());
max_layer_size = max_layer_size.max(info.max_layer_size.unwrap_or(0));

File diff suppressed because it is too large Load Diff

View File

@@ -300,7 +300,7 @@ pub struct TenantShard {
/// as in progress.
/// * Imported timelines are removed when the storage controller calls the post timeline
/// import activation endpoint.
timelines_importing: std::sync::Mutex<HashMap<TimelineId, Arc<ImportingTimeline>>>,
timelines_importing: std::sync::Mutex<HashMap<TimelineId, ImportingTimeline>>,
/// The last tenant manifest known to be in remote storage. None if the manifest has not yet
/// been either downloaded or uploaded. Always Some after tenant attach.
@@ -672,7 +672,6 @@ pub enum MaybeOffloaded {
pub enum TimelineOrOffloaded {
Timeline(Arc<Timeline>),
Offloaded(Arc<OffloadedTimeline>),
Importing(Arc<ImportingTimeline>),
}
impl TimelineOrOffloaded {
@@ -684,9 +683,6 @@ impl TimelineOrOffloaded {
TimelineOrOffloaded::Offloaded(offloaded) => {
TimelineOrOffloadedArcRef::Offloaded(offloaded)
}
TimelineOrOffloaded::Importing(importing) => {
TimelineOrOffloadedArcRef::Importing(importing)
}
}
}
pub fn tenant_shard_id(&self) -> TenantShardId {
@@ -699,16 +695,12 @@ impl TimelineOrOffloaded {
match self {
TimelineOrOffloaded::Timeline(timeline) => &timeline.delete_progress,
TimelineOrOffloaded::Offloaded(offloaded) => &offloaded.delete_progress,
TimelineOrOffloaded::Importing(importing) => &importing.delete_progress,
}
}
fn maybe_remote_client(&self) -> Option<Arc<RemoteTimelineClient>> {
match self {
TimelineOrOffloaded::Timeline(timeline) => Some(timeline.remote_client.clone()),
TimelineOrOffloaded::Offloaded(_offloaded) => None,
TimelineOrOffloaded::Importing(importing) => {
Some(importing.timeline.remote_client.clone())
}
}
}
}
@@ -716,7 +708,6 @@ impl TimelineOrOffloaded {
pub enum TimelineOrOffloadedArcRef<'a> {
Timeline(&'a Arc<Timeline>),
Offloaded(&'a Arc<OffloadedTimeline>),
Importing(&'a Arc<ImportingTimeline>),
}
impl TimelineOrOffloadedArcRef<'_> {
@@ -724,14 +715,12 @@ impl TimelineOrOffloadedArcRef<'_> {
match self {
TimelineOrOffloadedArcRef::Timeline(timeline) => timeline.tenant_shard_id,
TimelineOrOffloadedArcRef::Offloaded(offloaded) => offloaded.tenant_shard_id,
TimelineOrOffloadedArcRef::Importing(importing) => importing.timeline.tenant_shard_id,
}
}
pub fn timeline_id(&self) -> TimelineId {
match self {
TimelineOrOffloadedArcRef::Timeline(timeline) => timeline.timeline_id,
TimelineOrOffloadedArcRef::Offloaded(offloaded) => offloaded.timeline_id,
TimelineOrOffloadedArcRef::Importing(importing) => importing.timeline.timeline_id,
}
}
}
@@ -748,12 +737,6 @@ impl<'a> From<&'a Arc<OffloadedTimeline>> for TimelineOrOffloadedArcRef<'a> {
}
}
impl<'a> From<&'a Arc<ImportingTimeline>> for TimelineOrOffloadedArcRef<'a> {
fn from(timeline: &'a Arc<ImportingTimeline>) -> Self {
Self::Importing(timeline)
}
}
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum GetTimelineError {
#[error("Timeline is shutting down")]
@@ -1806,25 +1789,20 @@ impl TenantShard {
},
) => {
let timeline_id = timeline.timeline_id;
let import_task_gate = Gate::default();
let import_task_guard = import_task_gate.enter().unwrap();
let import_task_handle =
tokio::task::spawn(self.clone().create_timeline_import_pgdata_task(
timeline.clone(),
import_pgdata,
guard,
import_task_guard,
ctx.detached_child(TaskKind::ImportPgdata, DownloadBehavior::Warn),
));
let prev = self.timelines_importing.lock().unwrap().insert(
timeline_id,
Arc::new(ImportingTimeline {
ImportingTimeline {
timeline: timeline.clone(),
import_task_handle,
import_task_gate,
delete_progress: TimelineDeleteProgress::default(),
}),
},
);
assert!(prev.is_none());
@@ -2442,17 +2420,6 @@ impl TenantShard {
.collect()
}
/// Lists timelines the tenant contains.
/// It's up to callers to omit certain timelines that are not considered ready for use.
pub fn list_importing_timelines(&self) -> Vec<Arc<ImportingTimeline>> {
self.timelines_importing
.lock()
.unwrap()
.values()
.map(Arc::clone)
.collect()
}
/// Lists timelines the tenant manages, including offloaded ones.
///
/// It's up to callers to omit certain timelines that are not considered ready for use.
@@ -2886,25 +2853,19 @@ impl TenantShard {
let (timeline, timeline_create_guard) = uninit_timeline.finish_creation_myself();
let import_task_gate = Gate::default();
let import_task_guard = import_task_gate.enter().unwrap();
let import_task_handle = tokio::spawn(self.clone().create_timeline_import_pgdata_task(
timeline.clone(),
index_part,
timeline_create_guard,
import_task_guard,
timeline_ctx.detached_child(TaskKind::ImportPgdata, DownloadBehavior::Warn),
));
let prev = self.timelines_importing.lock().unwrap().insert(
timeline.timeline_id,
Arc::new(ImportingTimeline {
ImportingTimeline {
timeline: timeline.clone(),
import_task_handle,
import_task_gate,
delete_progress: TimelineDeleteProgress::default(),
}),
},
);
// Idempotency is enforced higher up the stack
@@ -2963,7 +2924,6 @@ impl TenantShard {
timeline: Arc<Timeline>,
index_part: import_pgdata::index_part_format::Root,
timeline_create_guard: TimelineCreateGuard,
_import_task_guard: GateGuard,
ctx: RequestContext,
) {
debug_assert_current_span_has_tenant_and_timeline_id();
@@ -3875,9 +3835,6 @@ impl TenantShard {
.build_timeline_client(offloaded.timeline_id, self.remote_storage.clone());
Arc::new(remote_client)
}
TimelineOrOffloadedArcRef::Importing(_) => {
unreachable!("Importing timelines are not included in the iterator")
}
};
// Shut down the timeline's remote client: this means that the indices we write
@@ -5087,14 +5044,6 @@ impl TenantShard {
info!("timeline already exists but is offloaded");
Err(CreateTimelineError::Conflict)
}
Err(TimelineExclusionError::AlreadyExists {
existing: TimelineOrOffloaded::Importing(_existing),
..
}) => {
// If there's a timeline already importing, then we would hit
// the [`TimelineExclusionError::AlreadyCreating`] branch above.
unreachable!("Importing timelines hold the creation guard")
}
Err(TimelineExclusionError::AlreadyExists {
existing: TimelineOrOffloaded::Timeline(existing),
arg,

View File

@@ -1348,21 +1348,6 @@ impl RemoteTimelineClient {
Ok(())
}
pub(crate) fn schedule_unlinking_of_layers_from_index_part<I>(
self: &Arc<Self>,
names: I,
) -> Result<(), NotInitialized>
where
I: IntoIterator<Item = LayerName>,
{
let mut guard = self.upload_queue.lock().unwrap();
let upload_queue = guard.initialized_mut()?;
self.schedule_unlinking_of_layers_from_index_part0(upload_queue, names);
Ok(())
}
/// Update the remote index file, removing the to-be-deleted files from the index,
/// allowing scheduling of actual deletions later.
fn schedule_unlinking_of_layers_from_index_part0<I>(

View File

@@ -206,8 +206,8 @@ pub struct GcCompactionQueue {
}
static CONCURRENT_GC_COMPACTION_TASKS: Lazy<Arc<Semaphore>> = Lazy::new(|| {
// Only allow one timeline on one pageserver to run gc compaction at a time.
Arc::new(Semaphore::new(1))
// Only allow two timelines on one pageserver to run gc compaction at a time.
Arc::new(Semaphore::new(2))
});
impl GcCompactionQueue {

View File

@@ -121,7 +121,6 @@ async fn remove_maybe_offloaded_timeline_from_tenant(
// This observes the locking order between timelines and timelines_offloaded
let mut timelines = tenant.timelines.lock().unwrap();
let mut timelines_offloaded = tenant.timelines_offloaded.lock().unwrap();
let mut timelines_importing = tenant.timelines_importing.lock().unwrap();
let offloaded_children_exist = timelines_offloaded
.iter()
.any(|(_, entry)| entry.ancestor_timeline_id == Some(timeline.timeline_id()));
@@ -151,12 +150,8 @@ async fn remove_maybe_offloaded_timeline_from_tenant(
.expect("timeline that we were deleting was concurrently removed from 'timelines_offloaded' map");
offloaded_timeline.delete_from_ancestor_with_timelines(&timelines);
}
TimelineOrOffloaded::Importing(importing) => {
timelines_importing.remove(&importing.timeline.timeline_id);
}
}
drop(timelines_importing);
drop(timelines_offloaded);
drop(timelines);
@@ -208,17 +203,8 @@ impl DeleteTimelineFlow {
guard.mark_in_progress()?;
// Now that the Timeline is in Stopping state, request all the related tasks to shut down.
// TODO(vlad): shut down imported timeline here
match &timeline {
TimelineOrOffloaded::Timeline(timeline) => {
timeline.shutdown(super::ShutdownMode::Hard).await;
}
TimelineOrOffloaded::Importing(importing) => {
importing.shutdown().await;
}
TimelineOrOffloaded::Offloaded(_offloaded) => {
// Nothing to shut down in this case
}
if let TimelineOrOffloaded::Timeline(timeline) = &timeline {
timeline.shutdown(super::ShutdownMode::Hard).await;
}
tenant.gc_block.before_delete(&timeline.timeline_id());
@@ -403,18 +389,10 @@ impl DeleteTimelineFlow {
Err(anyhow::anyhow!("failpoint: timeline-delete-before-rm"))?
});
match timeline {
TimelineOrOffloaded::Timeline(timeline) => {
delete_local_timeline_directory(conf, tenant.tenant_shard_id, timeline).await;
}
TimelineOrOffloaded::Importing(importing) => {
delete_local_timeline_directory(conf, tenant.tenant_shard_id, &importing.timeline)
.await;
}
TimelineOrOffloaded::Offloaded(_offloaded) => {
// Offloaded timelines have no local state
// TODO: once we persist offloaded information, delete the timeline from there, too
}
// Offloaded timelines have no local state
// TODO: once we persist offloaded information, delete the timeline from there, too
if let TimelineOrOffloaded::Timeline(timeline) = timeline {
delete_local_timeline_directory(conf, tenant.tenant_shard_id, timeline).await;
}
fail::fail_point!("timeline-delete-after-rm", |_| {
@@ -473,16 +451,12 @@ pub(super) fn make_timeline_delete_guard(
// For more context see this discussion: `https://github.com/neondatabase/neon/pull/4552#discussion_r1253437346`
let timelines = tenant.timelines.lock().unwrap();
let timelines_offloaded = tenant.timelines_offloaded.lock().unwrap();
let timelines_importing = tenant.timelines_importing.lock().unwrap();
let timeline = match timelines.get(&timeline_id) {
Some(t) => TimelineOrOffloaded::Timeline(Arc::clone(t)),
None => match timelines_offloaded.get(&timeline_id) {
Some(t) => TimelineOrOffloaded::Offloaded(Arc::clone(t)),
None => match timelines_importing.get(&timeline_id) {
Some(t) => TimelineOrOffloaded::Importing(Arc::clone(t)),
None => return Err(DeleteTimelineError::NotFound),
},
None => return Err(DeleteTimelineError::NotFound),
},
};

View File

@@ -8,10 +8,8 @@ use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tracing::info;
use utils::lsn::Lsn;
use utils::pausable_failpoint;
use utils::sync::gate::Gate;
use super::{Timeline, TimelineDeleteProgress};
use super::Timeline;
use crate::context::RequestContext;
use crate::controller_upcall_client::{StorageControllerUpcallApi, StorageControllerUpcallClient};
use crate::tenant::metadata::TimelineMetadata;
@@ -21,23 +19,15 @@ mod importbucket_client;
mod importbucket_format;
pub(crate) mod index_part_format;
pub struct ImportingTimeline {
pub(crate) struct ImportingTimeline {
pub import_task_handle: JoinHandle<()>,
pub import_task_gate: Gate,
pub timeline: Arc<Timeline>,
pub delete_progress: TimelineDeleteProgress,
}
impl std::fmt::Debug for ImportingTimeline {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "ImportingTimeline<{}>", self.timeline.timeline_id)
}
}
impl ImportingTimeline {
pub async fn shutdown(&self) {
pub(crate) async fn shutdown(self) {
self.import_task_handle.abort();
self.import_task_gate.close().await;
let _ = self.import_task_handle.await;
self.timeline.remote_client.shutdown().await;
}
@@ -111,8 +101,6 @@ pub async fn doit(
.schedule_index_upload_for_file_changes()?;
timeline.remote_client.wait_completion().await?;
pausable_failpoint!("import-timeline-pre-success-notify-pausable");
// Communicate that shard is done.
// Ensure at-least-once delivery of the upcall to storage controller
// before we mark the task as done and never come here again.

View File

@@ -30,7 +30,6 @@
use std::collections::HashSet;
use std::hash::{Hash, Hasher};
use std::num::NonZeroUsize;
use std::ops::Range;
use std::sync::Arc;
@@ -101,24 +100,8 @@ async fn run_v1(
tasks: Vec::default(),
};
// Use the job size limit encoded in the progress if we are resuming an import.
// This ensures that imports have stable plans even if the pageserver config changes.
let import_config = {
match &import_progress {
Some(progress) => {
let base = &timeline.conf.timeline_import_config;
TimelineImportConfig {
import_job_soft_size_limit: NonZeroUsize::new(progress.job_soft_size_limit)
.unwrap(),
import_job_concurrency: base.import_job_concurrency,
import_job_checkpoint_threshold: base.import_job_checkpoint_threshold,
}
}
None => timeline.conf.timeline_import_config.clone(),
}
};
let plan = planner.plan(&import_config).await?;
let import_config = &timeline.conf.timeline_import_config;
let plan = planner.plan(import_config).await?;
// Hash the plan and compare with the hash of the plan we got back from the storage controller.
// If the two match, it means that the planning stage had the same output.
@@ -143,7 +126,7 @@ async fn run_v1(
pausable_failpoint!("import-timeline-pre-execute-pausable");
let start_from_job_idx = import_progress.map(|progress| progress.completed);
plan.execute(timeline, start_from_job_idx, plan_hash, &import_config, ctx)
plan.execute(timeline, start_from_job_idx, plan_hash, import_config, ctx)
.await
}
@@ -470,7 +453,6 @@ impl Plan {
jobs: jobs_in_plan,
completed: last_completed_job_idx,
import_plan_hash,
job_soft_size_limit: import_config.import_job_soft_size_limit.into(),
};
timeline.remote_client.schedule_index_upload_for_file_changes()?;
@@ -982,15 +964,6 @@ impl ChunkProcessingJob {
.cloned();
match existing_layer {
Some(existing) => {
// Unlink the remote layer from the index without scheduling its deletion.
// When `existing_layer` drops [`LayerInner::drop`] will schedule its deletion from
// remote storage, but that assumes that the layer was unlinked from the index first.
timeline
.remote_client
.schedule_unlinking_of_layers_from_index_part(std::iter::once(
existing.layer_desc().layer_name(),
))?;
guard.open_mut()?.rewrite_layers(
&[(existing.clone(), resident_layer.clone())],
&[],

View File

@@ -17,7 +17,9 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info, warn};
use crate::auth::credentials::check_peer_addr_is_in_list;
use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
use crate::auth::{
self, AuthError, ComputeUserInfoMaybeEndpoint, IpPattern, validate_password_and_exchange,
};
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
@@ -135,6 +137,16 @@ impl<'a, T> Backend<'a, T> {
}
}
}
impl<'a, T, E> Backend<'a, Result<T, E>> {
/// Very similar to [`std::option::Option::transpose`].
/// This is most useful for error handling.
pub(crate) fn transpose(self) -> Result<Backend<'a, T>, E> {
match self {
Self::ControlPlane(c, x) => x.map(|x| Backend::ControlPlane(c, x)),
Self::Local(l) => Ok(Backend::Local(l)),
}
}
}
pub(crate) struct ComputeCredentials {
pub(crate) info: ComputeUserInfo,
@@ -272,7 +284,7 @@ async fn auth_quirks(
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<ComputeCredentials> {
) -> auth::Result<(ComputeCredentials, Option<Vec<IpPattern>>)> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
@@ -289,12 +301,15 @@ async fn auth_quirks(
debug!("fetching authentication info and allowlists");
// check allowed list
if config.ip_allowlist_check_enabled {
let allowed_ips = if config.ip_allowlist_check_enabled {
let allowed_ips = api.get_allowed_ips(ctx, &info).await?;
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
}
allowed_ips
} else {
Cached::new_uncached(Arc::new(vec![]))
};
// check if a VPC endpoint ID is coming in and if yes, if it's allowed
let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?;
@@ -353,7 +368,7 @@ async fn auth_quirks(
)
.await
{
Ok(keys) => Ok(keys),
Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))),
Err(e) => {
if e.is_password_failed() {
// The password could have been changed, so we invalidate the cache.
@@ -405,39 +420,53 @@ async fn authenticate_with_secret(
classic::authenticate(ctx, info, client, config, secret).await
}
impl ControlPlaneClient {
impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
/// Get username from the credentials.
pub(crate) fn get_user(&self) -> &str {
match self {
Self::ControlPlane(_, user_info) => &user_info.user,
Self::Local(_) => "local",
}
}
/// Authenticate the client via the requested backend, possibly using credentials.
#[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
pub(crate) async fn authenticate(
&self,
self,
ctx: &RequestContext,
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
user_info: ComputeUserInfoMaybeEndpoint,
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<ComputeCredentials> {
debug!(
user = &*user_info.user,
project = user_info.endpoint(),
"performing authentication using the console"
);
) -> auth::Result<(Backend<'a, ComputeCredentials>, Option<Vec<IpPattern>>)> {
let res = match self {
Self::ControlPlane(api, user_info) => {
debug!(
user = &*user_info.user,
project = user_info.endpoint(),
"performing authentication using the console"
);
let credentials = auth_quirks(
ctx,
self,
user_info,
client,
allow_cleartext,
config,
endpoint_rate_limiter,
)
.await?;
let (credentials, ip_allowlist) = auth_quirks(
ctx,
&*api,
user_info,
client,
allow_cleartext,
config,
endpoint_rate_limiter,
)
.await?;
Ok((Backend::ControlPlane(api, credentials), ip_allowlist))
}
Self::Local(_) => {
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"));
}
};
// TODO: replace with some metric
info!("user successfully authenticated");
Ok(credentials)
res
}
}
@@ -507,25 +536,6 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
}
}
pub struct ControlPlaneWakeCompute<'a> {
pub cplane: &'a ControlPlaneClient,
pub creds: ComputeCredentials,
}
#[async_trait::async_trait]
impl ComputeConnectBackend for ControlPlaneWakeCompute<'_> {
async fn wake_compute(
&self,
ctx: &RequestContext,
) -> Result<CachedNodeInfo, control_plane::errors::WakeComputeError> {
self.cplane.wake_compute(ctx, &self.creds.info).await
}
fn get_keys(&self) -> &ComputeCredentialKeys {
&self.creds.keys
}
}
#[cfg(test)]
mod tests {
#![allow(clippy::unimplemented, clippy::unwrap_used)]
@@ -542,7 +552,6 @@ mod tests {
use postgres_protocol::message::backend::Message as PgMessage;
use postgres_protocol::message::frontend;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use tokio_util::task::TaskTracker;
use super::jwt::JwkCache;
use super::{AuthRateLimiter, auth_quirks};
@@ -693,7 +702,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_scram() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
let mut stream = PqStream::new(Stream::from_raw(server));
let ctx = RequestContext::test();
let api = Auth {
@@ -775,7 +784,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_cleartext() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
let mut stream = PqStream::new(Stream::from_raw(server));
let ctx = RequestContext::test();
let api = Auth {
@@ -829,7 +838,7 @@ mod tests {
#[tokio::test]
async fn auth_quirks_password_hack() {
let (mut client, server) = tokio::io::duplex(1024);
let mut stream = PqStream::new(Stream::from_raw(server), TaskTracker::new().token());
let mut stream = PqStream::new(Stream::from_raw(server));
let ctx = RequestContext::test();
let api = Auth {
@@ -878,7 +887,7 @@ mod tests {
.await
.unwrap();
assert_eq!(creds.info.endpoint, "my-endpoint");
assert_eq!(creds.0.info.endpoint, "my-endpoint");
handle.await.unwrap();
}

View File

@@ -1,7 +1,7 @@
//! Client authentication mechanisms.
pub mod backend;
pub use backend::{Backend, ControlPlaneWakeCompute};
pub use backend::Backend;
mod credentials;
pub(crate) use credentials::{

View File

@@ -18,7 +18,6 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio_rustls::TlsConnector;
use tokio_util::sync::CancellationToken;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{Instrument, error, info};
use utils::project_git_version;
use utils::sentry_init::init_sentry;
@@ -227,8 +226,7 @@ pub(super) async fn task_main(
let dest_suffix = Arc::clone(&dest_suffix);
let compute_tls_config = compute_tls_config.clone();
let tracker = connections.token();
tokio::spawn(
connections.spawn(
async move {
socket
.set_nodelay(true)
@@ -251,7 +249,6 @@ pub(super) async fn task_main(
compute_tls_config,
tls_server_end_point,
socket,
tracker,
)
.await
}
@@ -277,11 +274,10 @@ const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmod
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
ctx: &RequestContext,
raw_stream: S,
tracker: TaskTrackerToken,
tls_config: Arc<rustls::ServerConfig>,
tls_server_end_point: TlsServerEndPoint,
) -> anyhow::Result<(Stream<S>, TaskTrackerToken)> {
let mut stream = PqStream::new(Stream::from_raw(raw_stream), tracker);
) -> anyhow::Result<Stream<S>> {
let mut stream = PqStream::new(Stream::from_raw(raw_stream));
let msg = stream.read_startup_packet().await?;
use pq_proto::FeStartupPacket::SslRequest;
@@ -295,7 +291,7 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.
let (raw, read_buf, tracker) = stream.into_inner();
let (raw, read_buf) = stream.into_inner();
// TODO: Normally, client doesn't send any data before
// server says TLS handshake is ok and read_buf is empty.
// However, you could imagine pipelining of postgres
@@ -306,16 +302,13 @@ async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
bail!("data is sent before server replied with EncryptionResponse");
}
Ok((
Stream::Tls {
tls: Box::new(
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?,
),
tls_server_end_point,
},
tracker,
))
Ok(Stream::Tls {
tls: Box::new(
raw.upgrade(tls_config, !ctx.has_private_peer_addr())
.await?,
),
tls_server_end_point,
})
}
unexpected => {
info!(
@@ -336,10 +329,8 @@ async fn handle_client(
compute_tls_config: Option<Arc<rustls::ClientConfig>>,
tls_server_end_point: TlsServerEndPoint,
stream: impl AsyncRead + AsyncWrite + Unpin,
tracker: TaskTrackerToken,
) -> anyhow::Result<()> {
let (mut tls_stream, _tracker) =
ssl_handshake(&ctx, stream, tracker, tls_config, tls_server_end_point).await?;
let mut tls_stream = ssl_handshake(&ctx, stream, tls_config, tls_server_end_point).await?;
// Cut off first part of the SNI domain
// We receive required destination details in the format of

View File

@@ -323,7 +323,7 @@ impl CancellationHandler {
}
}
pub(crate) fn get_key(self: Arc<Self>) -> Session {
pub(crate) fn get_key(self: &Arc<Self>) -> Session {
// we intentionally generate a random "backend pid" and "secret key" here.
// we use the corresponding u64 as an identifier for the
// actual endpoint+pid+secret for postgres/pgbouncer.
@@ -340,7 +340,7 @@ impl CancellationHandler {
Session {
key,
redis_key,
cancellation_handler: self,
cancellation_handler: Arc::clone(self),
}
}

View File

@@ -1,9 +1,8 @@
use std::sync::Arc;
use futures::TryFutureExt;
use futures::{FutureExt, TryFutureExt};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{Instrument, debug, error, info};
use crate::auth::backend::ConsoleRedirectBackend;
@@ -15,8 +14,10 @@ use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute};
use crate::proxy::handshake::{HandshakeData, handshake};
use crate::proxy::passthrough::passthrough;
use crate::proxy::{ClientRequestError, prepare_client_connection, run_until_cancelled};
use crate::proxy::passthrough::ProxyPassthrough;
use crate::proxy::{
ClientRequestError, ErrorSource, prepare_client_connection, run_until_cancelled,
};
pub async fn task_main(
config: &'static ProxyConfig,
@@ -34,6 +35,7 @@ pub async fn task_main(
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
@@ -47,11 +49,11 @@ pub async fn task_main(
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
let tracker = connections.token();
tokio::spawn(async move {
connections.spawn(async move {
let (socket, peer_addr) = match read_proxy_protocol(socket).await {
Err(e) => {
error!("per-client task finished with an error: {e:#}");
@@ -101,80 +103,99 @@ pub async fn task_main(
&config.region,
);
let span = ctx.span();
let mut slot = Some(ctx);
let res = handle_client(
config,
backend,
&mut slot,
&ctx,
cancellation_handler,
socket,
conn_gauge,
tracker,
cancellations,
)
.instrument(span)
.instrument(ctx.span())
.boxed()
.await;
match (slot, res) {
(None, _) => {}
(Some(ctx), Ok(())) => {
ctx.success();
}
(Some(ctx), Err(e)) => {
match res {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
tracing::warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
error!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
}
}
});
}
connections.close();
cancellations.close();
drop(listener);
// Drain connections
connections.wait().await;
cancellations.wait().await;
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
backend: &'static ConsoleRedirectBackend,
ctx_slot: &mut Option<RequestContext>,
ctx: &RequestContext,
cancellation_handler: Arc<CancellationHandler>,
stream: S,
conn_gauge: NumClientConnectionsGuard<'static>,
tracker: TaskTrackerToken,
) -> Result<(), ClientRequestError> {
let protocol = ctx_slot.as_ref().expect("context must be set").protocol();
debug!(%protocol, "handling interactive connection from client");
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
);
let metrics = &Metrics::get().proxy;
let request_gauge = metrics.connection_requests.guard(protocol);
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let tls = config.tls_config.load();
let tls = tls.as_deref();
let data = {
let ctx = ctx_slot.as_ref().expect("context must be set");
let record_handshake_error = !ctx.has_private_peer_addr();
let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, tracker, tls, record_handshake_error);
tokio::time::timeout(config.handshake_timeout, do_handshake).await??
};
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, tls, record_handshake_error);
let (mut stream, params) = match data {
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data, tracker) => {
HandshakeData::Cancel(cancel_key_data) => {
// spawn a task to cancel the session, but don't wait for it
tokio::spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx_slot.take().expect("context must be set");
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx.clone();
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
async move {
let _tracker = tracker;
cancellation_handler_clone
.cancel_session(
cancel_key_data,
@@ -184,17 +205,15 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'st
backend.get_api(),
)
.await
.inspect_err(|e| debug!(error = ?e, "cancel_session failed"))
.ok();
}
.instrument(cancel_span)
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
}.instrument(cancel_span)
});
return Ok(());
return Ok(None);
}
};
drop(pause);
let ctx = ctx_slot.as_ref().expect("context must be set");
ctx.set_db_options(params.clone());
let (node_info, user_info, _ip_allowlist) = match backend
@@ -209,13 +228,13 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'st
let mut node = connect_to_compute(
ctx,
TcpMechanism {
&TcpMechanism {
user_info,
params_compat: true,
params: &params,
locks: &config.connect_compute_locks,
},
node_info,
&node_info,
config.wake_compute_retry_config,
&config.connect_to_compute,
)
@@ -233,22 +252,17 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'st
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf, tracker) = stream.into_inner();
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
let ctx = ctx_slot.take().expect("context must be set");
ctx.set_success();
tokio::spawn(passthrough(
ctx,
&config.connect_to_compute,
stream,
node,
session,
request_gauge,
conn_gauge,
tracker,
));
Ok(())
Ok(Some(ProxyPassthrough {
client: stream,
aux: node.aux.clone(),
private_link_id: None,
compute: node,
session_id: ctx.session_id(),
cancel: session,
_req: request_gauge,
_conn: conn_gauge,
}))
}

View File

@@ -38,7 +38,7 @@ pub struct RequestContext(
/// I would typically use a RefCell but that would break the `Send` requirements
/// so we need something with thread-safety. `TryLock` is a cheap alternative
/// that offers similar semantics to a `RefCell` but with synchronisation.
TryLock<Box<RequestContextInner>>,
TryLock<RequestContextInner>,
);
struct RequestContextInner {
@@ -89,7 +89,7 @@ pub(crate) enum AuthMethod {
impl Clone for RequestContext {
fn clone(&self) -> Self {
let inner = self.0.try_lock().expect("should not deadlock");
let new = Box::new(RequestContextInner {
let new = RequestContextInner {
conn_info: inner.conn_info.clone(),
session_id: inner.session_id,
protocol: inner.protocol,
@@ -117,7 +117,7 @@ impl Clone for RequestContext {
disconnect_sender: None,
latency_timer: LatencyTimer::noop(inner.protocol),
disconnect_timestamp: inner.disconnect_timestamp,
});
};
Self(TryLock::new(new))
}
@@ -140,7 +140,7 @@ impl RequestContext {
role = tracing::field::Empty,
);
let inner = Box::new(RequestContextInner {
let inner = RequestContextInner {
conn_info,
session_id,
protocol,
@@ -168,7 +168,7 @@ impl RequestContext {
disconnect_sender: LOG_CHAN_DISCONNECT.get().and_then(|tx| tx.upgrade()),
latency_timer: LatencyTimer::new(protocol),
disconnect_timestamp: None,
});
};
Self(TryLock::new(inner))
}
@@ -522,7 +522,7 @@ impl Drop for RequestContextInner {
}
}
pub struct DisconnectLogger(Box<RequestContextInner>);
pub struct DisconnectLogger(RequestContextInner);
impl Drop for DisconnectLogger {
fn drop(&mut self) {

View File

@@ -53,25 +53,6 @@ pub(crate) trait ConnectMechanism {
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
}
#[async_trait]
impl<T: ConnectMechanism + Sync> ConnectMechanism for &T {
type Connection = T::Connection;
type ConnectError = T::ConnectError;
type Error = T::Error;
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
T::connect_once(self, ctx, node_info, config).await
}
fn update_connect_config(&self, conf: &mut compute::ConnCfg) {
T::update_connect_config(self, conf);
}
}
#[async_trait]
pub(crate) trait ComputeConnectBackend {
async fn wake_compute(
@@ -124,8 +105,8 @@ impl ConnectMechanism for TcpMechanism<'_> {
#[tracing::instrument(skip_all)]
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
ctx: &RequestContext,
mechanism: M,
backend: B,
mechanism: &M,
user_info: &B,
wake_compute_retry_config: RetryConfig,
compute: &ComputeConfig,
) -> Result<M::Connection, M::Error>
@@ -135,9 +116,9 @@ where
{
let mut num_retries = 0;
let mut node_info =
wake_compute(&mut num_retries, ctx, &backend, wake_compute_retry_config).await?;
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
node_info.set_keys(backend.get_keys());
node_info.set_keys(user_info.get_keys());
mechanism.update_connect_config(&mut node_info.config);
// try once
@@ -178,7 +159,7 @@ where
let old_node_info = invalidate_cache(node_info);
// TODO: increment num_retries?
let mut node_info =
wake_compute(&mut num_retries, ctx, &backend, wake_compute_retry_config).await?;
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
node_info.reuse_settings(old_node_info);
mechanism.update_connect_config(&mut node_info.config);

View File

@@ -67,6 +67,7 @@ where
}
}
#[tracing::instrument(skip_all)]
pub async fn copy_bidirectional_client_compute<Client, Compute>(
client: &mut Client,
compute: &mut Compute,

View File

@@ -5,7 +5,6 @@ use pq_proto::{
};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{debug, info, warn};
use crate::auth::endpoint_sni;
@@ -52,7 +51,7 @@ impl ReportableError for HandshakeError {
pub(crate) enum HandshakeData<S> {
Startup(PqStream<Stream<S>>, StartupMessageParams),
Cancel(CancelKeyData, TaskTrackerToken),
Cancel(CancelKeyData),
}
/// Establish a (most probably, secure) connection with the client.
@@ -63,7 +62,6 @@ pub(crate) enum HandshakeData<S> {
pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
ctx: &RequestContext,
stream: S,
tracker: TaskTrackerToken,
mut tls: Option<&TlsConfig>,
record_handshake_error: bool,
) -> Result<HandshakeData<S>, HandshakeError> {
@@ -73,7 +71,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
const PG_PROTOCOL_EARLIEST: ProtocolVersion = ProtocolVersion::new(3, 0);
const PG_PROTOCOL_LATEST: ProtocolVersion = ProtocolVersion::new(3, 0);
let mut stream = PqStream::new(Stream::from_raw(stream), tracker);
let mut stream = PqStream::new(Stream::from_raw(stream));
loop {
let msg = stream.read_startup_packet().await?;
match msg {
@@ -159,13 +157,15 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
let (_, tls_server_end_point) =
tls.cert_resolver.resolve(conn_info.server_name());
stream.framed = Framed {
stream: Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
stream = PqStream {
framed: Framed {
stream: Stream::Tls {
tls: Box::new(tls_stream),
tls_server_end_point,
},
read_buf,
write_buf,
},
read_buf,
write_buf,
};
}
}
@@ -248,7 +248,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
FeStartupPacket::CancelRequest(cancel_key_data) => {
info!(session_type = "cancellation", "successful handshake");
break Ok(HandshakeData::Cancel(cancel_key_data, stream.tracker));
break Ok(HandshakeData::Cancel(cancel_key_data));
}
}
}

View File

@@ -10,27 +10,26 @@ pub(crate) mod wake_compute;
use std::sync::Arc;
pub use copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
use futures::TryFutureExt;
use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use once_cell::sync::OnceCell;
use passthrough::passthrough;
use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams};
use regex::Regex;
use serde::{Deserialize, Serialize};
use smol_str::{SmolStr, format_smolstr};
use smol_str::{SmolStr, ToSmolStr, format_smolstr};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::{Instrument, debug, error, info, warn};
use self::connect_compute::{TcpMechanism, connect_to_compute};
use self::passthrough::ProxyPassthrough;
use crate::cancellation::{self, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol};
use crate::proxy::handshake::{HandshakeData, handshake};
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
@@ -71,6 +70,7 @@ pub async fn task_main(
socket2::SockRef::from(&listener).set_keepalive(true)?;
let connections = tokio_util::task::task_tracker::TaskTracker::new();
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(accept_result) =
run_until_cancelled(listener.accept(), &cancellation_token).await
@@ -84,12 +84,12 @@ pub async fn task_main(
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
debug!(protocol = "tcp", %session_id, "accepted new TCP connection");
let endpoint_rate_limiter2 = endpoint_rate_limiter.clone();
let tracker = connections.token();
tokio::spawn(async move {
connections.spawn(async move {
let (socket, conn_info) = match read_proxy_protocol(socket).await {
Err(e) => {
warn!("per-client task finished with an error: {e:#}");
@@ -138,41 +138,60 @@ pub async fn task_main(
crate::metrics::Protocol::Tcp,
&config.region,
);
let span = ctx.span();
let mut ctx = Some(ctx);
let res = handle_client(
config,
auth_backend,
&mut ctx,
&ctx,
cancellation_handler,
socket,
ClientMode::Tcp,
endpoint_rate_limiter2,
conn_gauge,
tracker,
cancellations,
)
.instrument(span)
.instrument(ctx.span())
.boxed()
.await;
match (ctx, res) {
(None, _) => {}
(Some(ctx), Ok(())) => {
ctx.success();
}
(Some(ctx), Err(e)) => {
match res {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
}
Ok(None) => {
ctx.set_success();
}
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(
?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
}
}
});
}
connections.close();
cancellations.close();
drop(listener);
// Drain connections
connections.wait().await;
cancellations.wait().await;
Ok(())
}
@@ -239,79 +258,46 @@ impl ReportableError for ClientRequestError {
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
ctx_slot: &mut Option<RequestContext>,
ctx: &RequestContext,
cancellation_handler: Arc<CancellationHandler>,
stream: S,
mode: ClientMode,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
conn_gauge: NumClientConnectionsGuard<'static>,
tracker: TaskTrackerToken,
) -> Result<(), ClientRequestError> {
let cplane = match auth_backend {
auth::Backend::ControlPlane(cplane, ()) => &**cplane,
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
};
let protocol = ctx_slot.as_ref().expect("context must be set").protocol();
debug!(%protocol, "handling interactive connection from client");
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> Result<Option<ProxyPassthrough<S>>, ClientRequestError> {
debug!(
protocol = %ctx.protocol(),
"handling interactive connection from client"
);
let metrics = &Metrics::get().proxy;
let request_gauge = metrics.connection_requests.guard(protocol);
let proto = ctx.protocol();
let request_gauge = metrics.connection_requests.guard(proto);
let handshake_result: Result<_, ClientRequestError> = async {
let tls = config.tls_config.load();
let tls = tls.as_deref();
let tls = config.tls_config.load();
let tls = tls.as_deref();
let ctx = ctx_slot.as_ref().expect("context must be set");
let record_handshake_error = !ctx.has_private_peer_addr();
let data = {
let _pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
tokio::time::timeout(
config.handshake_timeout,
handshake(
ctx,
stream,
tracker,
mode.handshake_tls(tls),
record_handshake_error,
),
)
.await??
};
let record_handshake_error = !ctx.has_private_peer_addr();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
match data {
HandshakeData::Startup(mut stream, params) => {
ctx.set_db_options(params.clone());
let host = mode.hostname(stream.get_ref());
let cn = tls.map(|tls| &tls.common_names);
// Extract credentials which we're going to use for auth.
let result = auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, host, cn);
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e, Some(ctx)).await?,
};
let session = cancellation_handler.get_key();
Ok(Some((stream, params, session, user_info)))
}
HandshakeData::Cancel(cancel_key_data, tracker) => {
let ctx = ctx_slot.take().expect("context must be set");
ctx.set_success();
let cancel_span = tracing::info_span!(parent: None, "cancel_session", session_id = ?ctx.session_id());
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let ctx = ctx.clone();
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?ctx.session_id());
cancel_span.follows_from(tracing::Span::current());
// spawn a task to cancel the session, but don't wait for it
tokio::spawn(async move {
// ensure the proxy doesn't shutdown until we complete this task.
let _tracker = tracker;
cancellation_handler
async move {
cancellation_handler_clone
.cancel_session(
cancel_key_data,
ctx,
@@ -319,108 +305,111 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send + 'st
config.authentication_config.is_vpc_acccess_proxy,
auth_backend.get_api(),
)
.instrument(cancel_span)
.await
.unwrap_or_else(|e| debug!(error = ?e, "cancel_session failed"));
});
.inspect_err(|e | debug!(error = ?e, "cancel_session failed")).ok();
}.instrument(cancel_span)
});
Ok(None)
}
return Ok(None);
}
}
.await;
let Some((mut stream, params, session, user_info)) = handshake_result? else {
return Ok(());
};
let ctx = ctx_slot.as_ref().expect("context must be set");
drop(pause);
let auth_result: Result<_, ClientRequestError> = async {
let user = user_info.user.clone();
ctx.set_db_options(params.clone());
match cplane
.authenticate(
ctx,
&mut stream,
user_info,
mode.allow_cleartext(),
&config.authentication_config,
endpoint_rate_limiter,
)
.await
{
Ok(auth_result) => Ok(auth_result),
Err(e) => {
let db = params.get("database");
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
stream
.throw_error(e, Some(ctx))
.instrument(params_span)
.await?
}
}
}
.await;
let hostname = mode.hostname(stream.get_ref());
let compute_creds = auth_result?;
let common_names = tls.map(|tls| &tls.common_names);
let connect_result: Result<_, ClientRequestError> = async {
let compute_user_info = compute_creds.info.clone();
let params_compat = compute_user_info
.options
.get(NeonOptions::PARAMS_COMPAT)
.is_some();
// Extract credentials which we're going to use for auth.
let result = auth_backend
.as_ref()
.map(|()| auth::ComputeUserInfoMaybeEndpoint::parse(ctx, &params, hostname, common_names))
.transpose();
let mut node = connect_to_compute(
let user_info = match result {
Ok(user_info) => user_info,
Err(e) => stream.throw_error(e, Some(ctx)).await?,
};
let user = user_info.get_user().to_owned();
let (user_info, _ip_allowlist) = match user_info
.authenticate(
ctx,
TcpMechanism {
user_info: compute_user_info,
params_compat,
params: &params,
locks: &config.connect_compute_locks,
},
auth::ControlPlaneWakeCompute {
cplane,
creds: compute_creds,
},
config.wake_compute_retry_config,
&config.connect_to_compute,
&mut stream,
mode.allow_cleartext(),
&config.authentication_config,
endpoint_rate_limiter,
)
.or_else(|e| stream.throw_error(e, Some(ctx)))
.await?;
.await
{
Ok(auth_result) => auth_result,
Err(e) => {
let db = params.get("database");
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);
session.write_cancel_key(node.cancel_closure.clone())?;
prepare_client_connection(&node, *session.key(), &mut stream).await?;
return stream
.throw_error(e, Some(ctx))
.instrument(params_span)
.await?;
}
};
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf, tracker) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
let compute_user_info = match &user_info {
auth::Backend::ControlPlane(_, info) => &info.info,
auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"),
};
let params_compat = compute_user_info
.options
.get(NeonOptions::PARAMS_COMPAT)
.is_some();
Ok((node, stream, tracker))
}
.await;
let (node, stream, tracker) = connect_result?;
let ctx = ctx_slot.take().expect("context must be set");
ctx.set_success();
tokio::spawn(passthrough(
let mut node = connect_to_compute(
ctx,
&TcpMechanism {
user_info: compute_user_info.clone(),
params_compat,
params: &params,
locks: &config.connect_compute_locks,
},
&user_info,
config.wake_compute_retry_config,
&config.connect_to_compute,
stream,
node,
session,
request_gauge,
conn_gauge,
tracker,
));
)
.or_else(|e| stream.throw_error(e, Some(ctx)))
.await?;
Ok(())
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let session = cancellation_handler_clone.get_key();
session.write_cancel_key(node.cancel_closure.clone())?;
prepare_client_connection(&node, *session.key(), &mut stream).await?;
// Before proxy passing, forward to compute whatever data is left in the
// PqStream input buffer. Normally there is none, but our serverless npm
// driver in pipeline mode sends startup, password and first query
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};
Ok(Some(ProxyPassthrough {
client: stream,
aux: node.aux.clone(),
private_link_id,
compute: node,
session_id: ctx.session_id(),
cancel: session,
_req: request_gauge,
_conn: conn_gauge,
}))
}
/// Finish client connection initialization: confirm auth success, send params, etc.

View File

@@ -1,6 +1,5 @@
use smol_str::{SmolStr, ToSmolStr};
use smol_str::SmolStr;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::debug;
use utils::measured_stream::MeasuredStream;
@@ -8,14 +7,13 @@ use super::copy_bidirectional::ErrorSource;
use crate::cancellation;
use crate::compute::PostgresConnection;
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
use crate::protocol2::ConnectionInfoExtra;
use crate::stream::Stream;
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
pub(crate) async fn proxy_pass(
client: impl AsyncRead + AsyncWrite + Unpin,
compute: impl AsyncRead + AsyncWrite + Unpin,
@@ -63,53 +61,41 @@ pub(crate) async fn proxy_pass(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn passthrough<S: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
ctx: RequestContext,
compute_config: &'static ComputeConfig,
pub(crate) struct ProxyPassthrough<S> {
pub(crate) client: Stream<S>,
pub(crate) compute: PostgresConnection,
pub(crate) aux: MetricsAuxInfo,
pub(crate) session_id: uuid::Uuid,
pub(crate) private_link_id: Option<SmolStr>,
pub(crate) cancel: cancellation::Session,
client: Stream<S>,
compute: PostgresConnection,
cancel: cancellation::Session,
_req: NumConnectionRequestsGuard<'static>,
_conn: NumClientConnectionsGuard<'static>,
_tracker: TaskTrackerToken,
) {
let session_id = ctx.session_id();
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};
let _disconnect = ctx.log_connect();
let res = proxy_pass(client, compute.stream, compute.aux, private_link_id).await;
match res {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
tracing::warn!(
session_id = ?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
tracing::error!(
session_id = ?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
if let Err(err) = compute
.cancel_closure
.try_cancel_query(compute_config)
.await
{
tracing::warn!(session_id = ?session_id, ?err, "could not cancel the query in the database");
}
// we don't need a result. If the queue is full, we just log the error
drop(cancel.remove_cancel_key());
pub(crate) _req: NumConnectionRequestsGuard<'static>,
pub(crate) _conn: NumClientConnectionsGuard<'static>,
}
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
pub(crate) async fn proxy_pass(
self,
compute_config: &ComputeConfig,
) -> Result<(), ErrorSource> {
let res = proxy_pass(
self.client,
self.compute.stream,
self.aux,
self.private_link_id,
)
.await;
if let Err(err) = self
.compute
.cancel_closure
.try_cancel_query(compute_config)
.await
{
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
}
drop(self.cancel.remove_cancel_key()); // we don't need a result. If the queue is full, we just log the error
res
}
}

View File

@@ -38,7 +38,6 @@ async fn proxy_mitm(
let (end_client, startup) = match handshake(
&RequestContext::test(),
client1,
TaskTracker::new().token(),
Some(&server_config1),
false,
)
@@ -46,7 +45,7 @@ async fn proxy_mitm(
.unwrap()
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(_, _) => panic!("cancellation not supported"),
HandshakeData::Cancel(_) => panic!("cancellation not supported"),
};
let mut end_server = tokio_util::codec::Framed::new(end_server, PgFrame);

View File

@@ -15,7 +15,6 @@ use rstest::rstest;
use rustls::crypto::ring;
use rustls::pki_types;
use tokio::io::DuplexStream;
use tokio_util::task::TaskTracker;
use tracing_test::traced_test;
use super::connect_compute::ConnectMechanism;
@@ -179,12 +178,10 @@ async fn dummy_proxy(
auth: impl TestAuth + Send,
) -> anyhow::Result<()> {
let (client, _) = read_proxy_protocol(client).await?;
let t = TaskTracker::new().token();
let mut stream =
match handshake(&RequestContext::test(), client, t, tls.as_ref(), false).await? {
HandshakeData::Startup(stream, _) => stream,
HandshakeData::Cancel(_, _) => bail!("cancellation not supported"),
};
let mut stream = match handshake(&RequestContext::test(), client, tls.as_ref(), false).await? {
HandshakeData::Startup(stream, _) => stream,
HandshakeData::Cancel(_) => bail!("cancellation not supported"),
};
auth.authenticate(&mut stream).await?;
@@ -625,7 +622,7 @@ async fn connect_to_compute_success() {
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -639,7 +636,7 @@ async fn connect_to_compute_retry() {
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -654,7 +651,7 @@ async fn connect_to_compute_non_retry_1() {
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap_err();
mechanism.verify();
@@ -669,7 +666,7 @@ async fn connect_to_compute_non_retry_2() {
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -694,7 +691,7 @@ async fn connect_to_compute_non_retry_3() {
connect_to_compute(
&ctx,
&mechanism,
user_info,
&user_info,
wake_compute_retry_config,
&config,
)
@@ -712,7 +709,7 @@ async fn wake_retry() {
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap();
mechanism.verify();
@@ -727,7 +724,7 @@ async fn wake_non_retry() {
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
let user_info = helper_create_connect_info(&mechanism);
let config = config();
connect_to_compute(&ctx, &mechanism, user_info, config.retry, &config)
connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config)
.await
.unwrap_err();
mechanism.verify();
@@ -746,7 +743,7 @@ async fn fail_but_wake_invalidates_cache() {
let user = helper_create_connect_info(&mech);
let cfg = config();
connect_to_compute(&ctx, &mech, user, cfg.retry, &cfg)
connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg)
.await
.unwrap();
@@ -767,7 +764,7 @@ async fn fail_no_wake_skips_cache_invalidation() {
let user = helper_create_connect_info(&mech);
let cfg = config();
connect_to_compute(&ctx, &mech, user, cfg.retry, &cfg)
connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg)
.await
.unwrap();
@@ -788,7 +785,7 @@ async fn retry_but_wake_invalidates_cache() {
let user_info = helper_create_connect_info(&mechanism);
let cfg = config();
connect_to_compute(&ctx, &mechanism, user_info, cfg.retry, &cfg)
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
.await
.unwrap();
mechanism.verify();
@@ -811,7 +808,7 @@ async fn retry_no_wake_skips_invalidation() {
let user_info = helper_create_connect_info(&mechanism);
let cfg = config();
connect_to_compute(&ctx, &mechanism, user_info, cfg.retry, &cfg)
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
.await
.unwrap_err();
mechanism.verify();

View File

@@ -224,13 +224,13 @@ impl PoolingBackend {
let backend = self.auth_backend.as_ref().map(|()| keys);
crate::proxy::connect_compute::connect_to_compute(
ctx,
TokioMechanism {
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
},
backend,
&backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
)
@@ -268,13 +268,13 @@ impl PoolingBackend {
});
crate::proxy::connect_compute::connect_to_compute(
ctx,
HyperMechanism {
&HyperMechanism {
conn_id,
conn_info,
pool: self.http_conn_pool.clone(),
locks: &self.config.connect_compute_locks,
},
backend,
&backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
)

View File

@@ -41,7 +41,7 @@ use tokio::net::{TcpListener, TcpStream};
use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tokio_util::task::TaskTracker;
use tracing::{Instrument, info, warn};
use crate::cancellation::CancellationHandler;
@@ -124,6 +124,7 @@ pub async fn task_main(
let connections = tokio_util::task::task_tracker::TaskTracker::new();
connections.close(); // allows `connections.wait to complete`
let cancellations = tokio_util::task::task_tracker::TaskTracker::new();
while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
if let Err(e) = conn.set_nodelay(true) {
@@ -149,11 +150,11 @@ pub async fn task_main(
let conn_token = cancellation_token.child_token();
let tls_acceptor = tls_acceptor.clone();
let backend = backend.clone();
let connections2 = connections.clone();
let cancellation_handler = cancellation_handler.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let tracker = connections.token();
tokio::spawn(
let cancellations = cancellations.clone();
connections.spawn(
async move {
let conn_token2 = conn_token.clone();
let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
@@ -180,7 +181,8 @@ pub async fn task_main(
Box::pin(connection_handler(
config,
backend,
tracker,
connections2,
cancellations,
cancellation_handler,
endpoint_rate_limiter,
conn_token,
@@ -303,7 +305,8 @@ async fn connection_startup(
async fn connection_handler(
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
tracker: TaskTrackerToken,
connections: TaskTracker,
cancellations: TaskTracker,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
@@ -344,17 +347,19 @@ async fn connection_handler(
// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
// By spawning the future, we ensure it never gets cancelled until it decides to.
let handler = tokio::spawn(
let cancellations = cancellations.clone();
let handler = connections.spawn(
request_handler(
req,
config,
backend.clone(),
tracker.clone(),
connections.clone(),
cancellation_handler.clone(),
session_id,
conn_info2.clone(),
http_request_token,
endpoint_rate_limiter.clone(),
cancellations,
)
.in_current_span()
.map_ok_or_else(api_error_into_response, |r| r),
@@ -395,13 +400,14 @@ async fn request_handler(
mut request: hyper::Request<Incoming>,
config: &'static ProxyConfig,
backend: Arc<PoolingBackend>,
tracker: TaskTrackerToken,
ws_connections: TaskTracker,
cancellation_handler: Arc<CancellationHandler>,
session_id: uuid::Uuid,
conn_info: ConnectionInfo,
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellations: TaskTracker,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
let host = request
.headers()
@@ -435,17 +441,10 @@ async fn request_handler(
let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request)
.map_err(|e| ApiError::BadRequest(e.into()))?;
tokio::spawn(
let cancellations = cancellations.clone();
ws_connections.spawn(
async move {
let websocket = match websocket.await {
Err(e) => {
warn!("could not upgrade websocket connection: {e:#}");
return;
}
Ok(websocket) => websocket,
};
websocket::serve_websocket(
if let Err(e) = websocket::serve_websocket(
config,
backend.auth_backend,
ctx,
@@ -453,9 +452,12 @@ async fn request_handler(
cancellation_handler,
endpoint_rate_limiter,
host,
tracker,
cancellations,
)
.await;
.await
{
warn!("error in websocket connection: {e:#}");
}
}
.instrument(span),
);

View File

@@ -2,14 +2,14 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, ready};
use anyhow::Context as _;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use framed_websockets::{Frame, OpCode, WebSocketServer};
use futures::{Sink, Stream};
use hyper::upgrade::Upgraded;
use hyper::upgrade::OnUpgrade;
use hyper_util::rt::TokioIo;
use pin_project_lite::pin_project;
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::warn;
use crate::cancellation::CancellationHandler;
@@ -17,7 +17,7 @@ use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::Metrics;
use crate::proxy::{ClientMode, handle_client};
use crate::proxy::{ClientMode, ErrorSource, handle_client};
use crate::rate_limiter::EndpointRateLimiter;
pin_project! {
@@ -128,12 +128,13 @@ pub(crate) async fn serve_websocket(
config: &'static ProxyConfig,
auth_backend: &'static crate::auth::Backend<'static, ()>,
ctx: RequestContext,
websocket: Upgraded,
websocket: OnUpgrade,
cancellation_handler: Arc<CancellationHandler>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
hostname: Option<String>,
tracker: TaskTrackerToken,
) {
cancellations: tokio_util::task::task_tracker::TaskTracker,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));
let conn_gauge = Metrics::get()
@@ -141,28 +142,36 @@ pub(crate) async fn serve_websocket(
.client_connections
.guard(crate::metrics::Protocol::Ws);
let mut ctx_slot = Some(ctx);
let res = handle_client(
let res = Box::pin(handle_client(
config,
auth_backend,
&mut ctx_slot,
&ctx,
cancellation_handler,
WebSocketRw::new(websocket),
ClientMode::Websockets { hostname },
endpoint_rate_limiter,
conn_gauge,
tracker,
)
cancellations,
))
.await;
match (ctx_slot, res) {
(None, _) => {}
(Some(ctx), Err(e)) => {
match res {
Err(e) => {
ctx.set_error_kind(e.get_error_kind());
tracing::warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}");
Err(e.into())
}
(Some(ctx), Ok(())) => {
Ok(None) => {
ctx.set_success();
Ok(())
}
Ok(Some(p)) => {
ctx.set_success();
ctx.log_connect();
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => Ok(()),
Err(ErrorSource::Client(err)) => Err(err).context("client"),
Err(ErrorSource::Compute(err)) => Err(err).context("compute"),
}
}
}
}

View File

@@ -10,7 +10,6 @@ use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::server::TlsStream;
use tokio_util::task::task_tracker::TaskTrackerToken;
use tracing::debug;
use crate::control_plane::messages::ColdStartInfo;
@@ -25,22 +24,19 @@ use crate::tls::TlsServerEndPoint;
/// to pass random malformed bytes through the connection).
pub struct PqStream<S> {
pub(crate) framed: Framed<S>,
pub(crate) tracker: TaskTrackerToken,
}
impl<S> PqStream<S> {
/// Construct a new libpq protocol wrapper.
pub fn new(stream: S, tracker: TaskTrackerToken) -> Self {
pub fn new(stream: S) -> Self {
Self {
framed: Framed::new(stream),
tracker,
}
}
/// Extract the underlying stream and read buffer.
pub fn into_inner(self) -> (S, BytesMut, TaskTrackerToken) {
let (stream, read) = self.framed.into_inner();
(stream, read, self.tracker)
pub fn into_inner(self) -> (S, BytesMut) {
self.framed.into_inner()
}
/// Get a shared reference to the underlying stream.

View File

@@ -44,7 +44,6 @@ struct GlobalTimelinesState {
// on-demand timeline creation from recreating deleted timelines. This is only soft-enforced, as
// this map is dropped on restart.
tombstones: HashMap<TenantTimelineId, Instant>,
tenant_tombstones: HashMap<TenantId, Instant>,
conf: Arc<SafeKeeperConf>,
broker_active_set: Arc<TimelinesSet>,
@@ -82,25 +81,10 @@ impl GlobalTimelinesState {
}
}
fn has_tombstone(&self, ttid: &TenantTimelineId) -> bool {
self.tombstones.contains_key(ttid) || self.tenant_tombstones.contains_key(&ttid.tenant_id)
}
/// Removes all blocking tombstones for the given timeline ID.
/// Returns `true` if there have been actual changes.
fn remove_tombstone(&mut self, ttid: &TenantTimelineId) -> bool {
self.tombstones.remove(ttid).is_some()
|| self.tenant_tombstones.remove(&ttid.tenant_id).is_some()
}
fn delete(&mut self, ttid: TenantTimelineId) {
self.timelines.remove(&ttid);
self.tombstones.insert(ttid, Instant::now());
}
fn add_tenant_tombstone(&mut self, tenant_id: TenantId) {
self.tenant_tombstones.insert(tenant_id, Instant::now());
}
}
/// A struct used to manage access to the global timelines map.
@@ -115,7 +99,6 @@ impl GlobalTimelines {
state: Mutex::new(GlobalTimelinesState {
timelines: HashMap::new(),
tombstones: HashMap::new(),
tenant_tombstones: HashMap::new(),
conf,
broker_active_set: Arc::new(TimelinesSet::default()),
global_rate_limiter: RateLimiter::new(1, 1),
@@ -262,7 +245,7 @@ impl GlobalTimelines {
return Ok(timeline);
}
if state.has_tombstone(&ttid) {
if state.tombstones.contains_key(&ttid) {
anyhow::bail!("Timeline {ttid} is deleted, refusing to recreate");
}
@@ -312,14 +295,13 @@ impl GlobalTimelines {
_ => {}
}
if check_tombstone {
if state.has_tombstone(&ttid) {
if state.tombstones.contains_key(&ttid) {
anyhow::bail!("timeline {ttid} is deleted, refusing to recreate");
}
} else {
// We may be have been asked to load a timeline that was previously deleted (e.g. from `pull_timeline.rs`). We trust
// that the human doing this manual intervention knows what they are doing, and remove its tombstone.
// It's also possible that we enter this when the tenant has been deleted, even if the timeline itself has never existed.
if state.remove_tombstone(&ttid) {
if state.tombstones.remove(&ttid).is_some() {
warn!("un-deleted timeline {ttid}");
}
}
@@ -500,7 +482,6 @@ impl GlobalTimelines {
let tli_res = {
let state = self.state.lock().unwrap();
// Do NOT check tenant tombstones here: those were set earlier
if state.tombstones.contains_key(ttid) {
// Presence of a tombstone guarantees that a previous deletion has completed and there is no work to do.
info!("Timeline {ttid} was already deleted");
@@ -576,10 +557,6 @@ impl GlobalTimelines {
action: DeleteOrExclude,
) -> Result<HashMap<TenantTimelineId, TimelineDeleteResult>> {
info!("deleting all timelines for tenant {}", tenant_id);
// Adding a tombstone before getting the timelines to prevent new timeline additions
self.state.lock().unwrap().add_tenant_tombstone(*tenant_id);
let to_delete = self.get_all_for_tenant(*tenant_id);
let mut err = None;
@@ -623,9 +600,6 @@ impl GlobalTimelines {
state
.tombstones
.retain(|_, v| now.duration_since(*v) < *tombstone_ttl);
state
.tenant_tombstones
.retain(|_, v| now.duration_since(*v) < *tombstone_ttl);
}
}

View File

@@ -482,10 +482,6 @@ async fn handle_tenant_timeline_delete(
ForwardOutcome::NotForwarded(_req) => {}
};
service
.maybe_delete_timeline_import(tenant_id, timeline_id)
.await?;
// For timeline deletions, which both implement an "initially return 202, then 404 once
// we're done" semantic, we wrap with a retry loop to expose a simpler API upstream.
async fn deletion_wrapper<R, F>(service: Arc<Service>, f: F) -> Result<Response<Body>, ApiError>

View File

@@ -99,8 +99,8 @@ use crate::tenant_shard::{
ScheduleOptimization, ScheduleOptimizationAction, TenantShard,
};
use crate::timeline_import::{
FinalizingImport, ImportResult, ShardImportStatuses, TimelineImport,
TimelineImportFinalizeError, TimelineImportState, UpcallClient,
ImportResult, ShardImportStatuses, TimelineImport, TimelineImportFinalizeError,
TimelineImportState, UpcallClient,
};
const WAITER_FILL_DRAIN_POLL_TIMEOUT: Duration = Duration::from_millis(500);
@@ -232,9 +232,6 @@ struct ServiceState {
/// Queue of tenants who are waiting for concurrency limits to permit them to reconcile
delayed_reconcile_rx: tokio::sync::mpsc::Receiver<TenantShardId>,
/// Tracks ongoing timeline import finalization tasks
imports_finalizing: BTreeMap<(TenantId, TimelineId), FinalizingImport>,
}
/// Transform an error from a pageserver into an error to return to callers of a storage
@@ -311,7 +308,6 @@ impl ServiceState {
scheduler,
ongoing_operation: None,
delayed_reconcile_rx,
imports_finalizing: Default::default(),
}
}
@@ -4101,58 +4097,13 @@ impl Service {
///
/// If this method gets pre-empted by shut down, it will be called again at start-up (on-going
/// imports are stored in the database).
///
/// # Cancel-Safety
/// Not cancel safe.
/// If the caller stops polling, the import will not be removed from
/// [`ServiceState::imports_finalizing`].
#[instrument(skip_all, fields(
tenant_id=%import.tenant_id,
timeline_id=%import.timeline_id,
))]
async fn finalize_timeline_import(
self: &Arc<Self>,
import: TimelineImport,
) -> Result<(), TimelineImportFinalizeError> {
let tenant_timeline = (import.tenant_id, import.timeline_id);
let (_finalize_import_guard, cancel) = {
let mut locked = self.inner.write().unwrap();
let gate = Gate::default();
let cancel = CancellationToken::default();
let guard = gate.enter().unwrap();
locked.imports_finalizing.insert(
tenant_timeline,
FinalizingImport {
gate,
cancel: cancel.clone(),
},
);
(guard, cancel)
};
let res = tokio::select! {
res = self.finalize_timeline_import_impl(import) => {
res
},
_ = cancel.cancelled() => {
Err(TimelineImportFinalizeError::Cancelled)
}
};
let mut locked = self.inner.write().unwrap();
locked.imports_finalizing.remove(&tenant_timeline);
res
}
async fn finalize_timeline_import_impl(
self: &Arc<Self>,
import: TimelineImport,
) -> Result<(), TimelineImportFinalizeError> {
tracing::info!("Finalizing timeline import");
@@ -4352,46 +4303,6 @@ impl Service {
.await;
}
/// Delete a timeline import if it exists
///
/// Firstly, delete the entry from the database. Any updates
/// from pageservers after the update will fail with a 404, so the
/// import cannot progress into finalizing state if it's not there already.
/// Secondly, cancel the finalization if one is in progress.
pub(crate) async fn maybe_delete_timeline_import(
self: &Arc<Self>,
tenant_id: TenantId,
timeline_id: TimelineId,
) -> Result<(), DatabaseError> {
let tenant_has_ongoing_import = {
let locked = self.inner.read().unwrap();
locked
.tenants
.range(TenantShardId::tenant_range(tenant_id))
.any(|(_tid, shard)| shard.importing == TimelineImportState::Importing)
};
if !tenant_has_ongoing_import {
return Ok(());
}
self.persistence
.delete_timeline_import(tenant_id, timeline_id)
.await?;
let maybe_finalizing = {
let mut locked = self.inner.write().unwrap();
locked.imports_finalizing.remove(&(tenant_id, timeline_id))
};
if let Some(finalizing) = maybe_finalizing {
finalizing.cancel.cancel();
finalizing.gate.close().await;
}
Ok(())
}
pub(crate) async fn tenant_timeline_archival_config(
&self,
tenant_id: TenantId,
@@ -8627,9 +8538,8 @@ impl Service {
Some(ShardCount(new_shard_count))
}
/// Fetches the top tenant shards from every available node, in descending order of
/// max logical size. Offline nodes are skipped, and any errors from available nodes
/// will be logged and ignored.
/// Fetches the top tenant shards from every node, in descending order of
/// max logical size. Any node errors will be logged and ignored.
async fn get_top_tenant_shards(
&self,
request: &TopTenantShardsRequest,
@@ -8640,7 +8550,6 @@ impl Service {
.unwrap()
.nodes
.values()
.filter(|node| node.is_available())
.cloned()
.collect_vec();

View File

@@ -7,7 +7,6 @@ use serde::{Deserialize, Serialize};
use pageserver_api::models::{ShardImportProgress, ShardImportStatus};
use tokio_util::sync::CancellationToken;
use utils::sync::gate::Gate;
use utils::{
id::{TenantId, TimelineId},
shard::ShardIndex,
@@ -56,8 +55,6 @@ pub(crate) enum TimelineImportUpdateFollowUp {
pub(crate) enum TimelineImportFinalizeError {
#[error("Shut down interrupted import finalize")]
ShuttingDown,
#[error("Import finalization was cancelled")]
Cancelled,
#[error("Mismatched shard detected during import finalize: {0}")]
MismatchedShards(ShardIndex),
}
@@ -167,11 +164,6 @@ impl TimelineImport {
}
}
pub(crate) struct FinalizingImport {
pub(crate) gate: Gate,
pub(crate) cancel: CancellationToken,
}
pub(crate) type ImportResult = Result<(), String>;
pub(crate) struct UpcallClient {

View File

@@ -1,4 +1,3 @@
import json
import os
import shutil
import subprocess
@@ -12,7 +11,6 @@ from _pytest.config import Config
from fixtures.log_helper import log
from fixtures.neon_cli import AbstractNeonCli
from fixtures.neon_fixtures import Endpoint, VanillaPostgres
from fixtures.pg_version import PgVersion
from fixtures.remote_storage import MockS3Server
@@ -163,57 +161,3 @@ def fast_import(
f.write(fi.cmd.stderr)
log.info("Written logs to %s", test_output_dir)
def mock_import_bucket(vanilla_pg: VanillaPostgres, path: Path):
"""
Mock the import S3 bucket into a local directory for a provided vanilla PG instance.
"""
assert not vanilla_pg.is_running()
path.mkdir()
# what cplane writes before scheduling fast_import
specpath = path / "spec.json"
specpath.write_text(json.dumps({"branch_id": "somebranch", "project_id": "someproject"}))
# what fast_import writes
vanilla_pg.pgdatadir.rename(path / "pgdata")
statusdir = path / "status"
statusdir.mkdir()
(statusdir / "pgdata").write_text(json.dumps({"done": True}))
(statusdir / "fast_import").write_text(json.dumps({"command": "pgdata", "done": True}))
def populate_vanilla_pg(vanilla_pg: VanillaPostgres, target_relblock_size: int) -> int:
assert vanilla_pg.is_running()
vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser")
# fillfactor so we don't need to produce that much data
# 900 byte per row is > 10% => 1 row per page
vanilla_pg.safe_psql("""create table t (data char(900)) with (fillfactor = 10)""")
nrows = 0
while True:
relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')")
log.info(
f"relblock size: {relblock_size / 8192} pages (target: {target_relblock_size // 8192}) pages"
)
if relblock_size >= target_relblock_size:
break
addrows = int((target_relblock_size - relblock_size) // 8192)
assert addrows >= 1, "forward progress"
vanilla_pg.safe_psql(
f"insert into t select generate_series({nrows + 1}, {nrows + addrows})"
)
nrows += addrows
return nrows
def validate_import_from_vanilla_pg(endpoint: Endpoint, nrows: int):
assert endpoint.safe_psql_many(
[
"set effective_io_concurrency=32;",
"SET statement_timeout='300s';",
"select count(*), sum(data::bigint)::bigint from t",
]
) == [[], [], [(nrows, nrows * (nrows + 1) // 2)]]

View File

@@ -2337,22 +2337,6 @@ class NeonStorageController(MetricsGetter, LogUtils):
headers=self.headers(TokenScope.ADMIN),
)
def import_status(
self, tenant_shard_id: TenantShardId, timeline_id: TimelineId, generation: int
):
payload = {
"tenant_shard_id": str(tenant_shard_id),
"timeline_id": str(timeline_id),
"generation": generation,
}
self.request(
"GET",
f"{self.api}/upcall/v1/timeline_import_status",
headers=self.headers(TokenScope.GENERATIONS_API),
json=payload,
)
def reconcile_all(self):
r = self.request(
"POST",
@@ -2829,11 +2813,6 @@ class NeonPageserver(PgProtocol, LogUtils):
if self.running:
self.http_client().configure_failpoints([(name, action)])
def clear_persistent_failpoint(self, name: str):
del self._persistent_failpoints[name]
if self.running:
self.http_client().configure_failpoints([(name, "off")])
def timeline_dir(
self,
tenant_shard_id: TenantId | TenantShardId,

View File

@@ -675,7 +675,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
def timeline_delete(
self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, **kwargs
) -> int:
):
"""
Note that deletion is not instant, it is scheduled and performed mostly in the background.
So if you need to wait for it to complete use `timeline_delete_wait_completed`.
@@ -688,8 +688,6 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
res_json = res.json()
assert res_json is None
return res.status_code
def timeline_gc(
self,
tenant_id: TenantId | TenantShardId,

View File

@@ -1,41 +1,31 @@
from __future__ import annotations
import enum
import json
import time
from collections import Counter
from dataclasses import dataclass
from enum import StrEnum
from threading import Event
from typing import TYPE_CHECKING
import pytest
from fixtures.common_types import Lsn, TenantId, TimelineId
from fixtures.fast_import import mock_import_bucket, populate_vanilla_pg
from fixtures.log_helper import log
from fixtures.neon_fixtures import (
NeonEnv,
NeonEnvBuilder,
NeonPageserver,
PgBin,
VanillaPostgres,
wait_for_last_flush_lsn,
)
from fixtures.pageserver.http import (
ImportPgdataIdemptencyKey,
)
from fixtures.pageserver.utils import wait_for_upload_queue_empty
from fixtures.remote_storage import RemoteStorageKind
from fixtures.utils import human_bytes, run_only_on_default_postgres, wait_until
from werkzeug.wrappers.response import Response
from fixtures.utils import human_bytes, wait_until
if TYPE_CHECKING:
from collections.abc import Iterable
from typing import Any
from fixtures.pageserver.http import PageserverHttpClient
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.request import Request
GLOBAL_LRU_LOG_LINE = "tenant_min_resident_size-respecting LRU would not relieve pressure, evicting more following global LRU policy"
@@ -174,7 +164,6 @@ class EvictionEnv:
min_avail_bytes,
mock_behavior,
eviction_order: EvictionOrder,
wait_logical_size: bool = True,
):
"""
Starts pageserver up with mocked statvfs setup. The startup is
@@ -212,12 +201,11 @@ class EvictionEnv:
pageserver.start()
# we now do initial logical size calculation on startup, which on debug builds can fight with disk usage based eviction
if wait_logical_size:
for tenant_id, timeline_id in self.timelines:
tenant_ps = self.neon_env.get_tenant_pageserver(tenant_id)
# Pageserver may be none if we are currently not attached anywhere, e.g. during secondary eviction test
if tenant_ps is not None:
tenant_ps.http_client().timeline_wait_logical_size(tenant_id, timeline_id)
for tenant_id, timeline_id in self.timelines:
tenant_ps = self.neon_env.get_tenant_pageserver(tenant_id)
# Pageserver may be none if we are currently not attached anywhere, e.g. during secondary eviction test
if tenant_ps is not None:
tenant_ps.http_client().timeline_wait_logical_size(tenant_id, timeline_id)
def statvfs_called():
pageserver.assert_log_contains(".*running mocked statvfs.*")
@@ -894,121 +882,3 @@ def test_secondary_mode_eviction(eviction_env_ha: EvictionEnv):
assert total_size - post_eviction_total_size >= evict_bytes, (
"we requested at least evict_bytes worth of free space"
)
@run_only_on_default_postgres(reason="PG version is irrelevant here")
def test_import_timeline_disk_pressure_eviction(
neon_env_builder: NeonEnvBuilder,
vanilla_pg: VanillaPostgres,
make_httpserver: HTTPServer,
pg_bin: PgBin,
):
"""
TODO
"""
# Set up mock control plane HTTP server to listen for import completions
import_completion_signaled = Event()
def handler(request: Request) -> Response:
log.info(f"control plane /import_complete request: {request.json}")
import_completion_signaled.set()
return Response(json.dumps({}), status=200)
cplane_mgmt_api_server = make_httpserver
cplane_mgmt_api_server.expect_request(
"/storage/api/v1/import_complete", method="PUT"
).respond_with_handler(handler)
# Plug the cplane mock in
neon_env_builder.control_plane_hooks_api = (
f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/storage/api/v1/"
)
# The import will specifiy a local filesystem path mocking remote storage
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
vanilla_pg.start()
target_relblock_size = 1024 * 1024 * 128
populate_vanilla_pg(vanilla_pg, target_relblock_size)
vanilla_pg.stop()
env = neon_env_builder.init_configs()
env.start()
importbucket_path = neon_env_builder.repo_dir / "test_import_completion_bucket"
mock_import_bucket(vanilla_pg, importbucket_path)
tenant_id = TenantId.generate()
timeline_id = TimelineId.generate()
idempotency = ImportPgdataIdemptencyKey.random()
eviction_env = EvictionEnv(
timelines=[(tenant_id, timeline_id)],
neon_env=env,
pageserver_http=env.pageserver.http_client(),
layer_size=5 * 1024 * 1024, # Doesn't apply here
pg_bin=pg_bin, # Not used here
pgbench_init_lsns={}, # Not used here
)
# Pause before delivering the final notification to storcon.
# This keeps the import in progress.
failpoint_name = "import-timeline-pre-success-notify-pausable"
env.pageserver.add_persistent_failpoint(failpoint_name, "pause")
env.storage_controller.tenant_create(tenant_id)
env.storage_controller.timeline_create(
tenant_id,
{
"new_timeline_id": str(timeline_id),
"import_pgdata": {
"idempotency_key": str(idempotency),
"location": {"LocalFs": {"path": str(importbucket_path.absolute())}},
},
},
)
def hit_failpoint():
log.info("Checking log for pattern...")
try:
assert env.pageserver.log_contains(f".*at failpoint {failpoint_name}.*")
except Exception:
log.exception("Failed to find pattern in log")
raise
wait_until(hit_failpoint)
assert not import_completion_signaled.is_set()
env.pageserver.stop()
total_size, _, _ = eviction_env.timelines_du(env.pageserver)
blocksize = 512
total_blocks = (total_size + (blocksize - 1)) // blocksize
eviction_env.pageserver_start_with_disk_usage_eviction(
env.pageserver,
period="1s",
max_usage_pct=33,
min_avail_bytes=0,
mock_behavior={
"type": "Success",
"blocksize": blocksize,
"total_blocks": total_blocks,
# Only count layer files towards used bytes in the mock_statvfs.
# This avoids accounting for metadata files & tenant conf in the tests.
"name_filter": ".*__.*",
},
eviction_order=EvictionOrder.RELATIVE_ORDER_SPARE,
wait_logical_size=False,
)
wait_until(lambda: env.pageserver.assert_log_contains(".*disk usage pressure relieved"))
env.pageserver.clear_persistent_failpoint(failpoint_name)
def cplane_notified():
assert import_completion_signaled.is_set()
wait_until(cplane_notified)
env.pageserver.allowed_errors.append(r".* running disk usage based eviction due to pressure.*")

View File

@@ -12,19 +12,13 @@ import psycopg2
import psycopg2.errors
import pytest
from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId
from fixtures.fast_import import (
FastImport,
mock_import_bucket,
populate_vanilla_pg,
validate_import_from_vanilla_pg,
)
from fixtures.fast_import import FastImport
from fixtures.log_helper import log
from fixtures.neon_fixtures import (
NeonEnvBuilder,
PageserverImportConfig,
PgBin,
PgProtocol,
StorageControllerApiException,
StorageControllerMigrationConfig,
VanillaPostgres,
)
@@ -65,6 +59,24 @@ smoke_params = [
]
def mock_import_bucket(vanilla_pg: VanillaPostgres, path: Path):
"""
Mock the import S3 bucket into a local directory for a provided vanilla PG instance.
"""
assert not vanilla_pg.is_running()
path.mkdir()
# what cplane writes before scheduling fast_import
specpath = path / "spec.json"
specpath.write_text(json.dumps({"branch_id": "somebranch", "project_id": "someproject"}))
# what fast_import writes
vanilla_pg.pgdatadir.rename(path / "pgdata")
statusdir = path / "status"
statusdir.mkdir()
(statusdir / "pgdata").write_text(json.dumps({"done": True}))
(statusdir / "fast_import").write_text(json.dumps({"command": "pgdata", "done": True}))
@skip_in_debug_build("MULTIPLE_RELATION_SEGMENTS has non trivial amount of data")
@pytest.mark.parametrize("shard_count,stripe_size,rel_block_size", smoke_params)
def test_pgdata_import_smoke(
@@ -119,6 +131,10 @@ def test_pgdata_import_smoke(
# Put data in vanilla pg
#
vanilla_pg.start()
vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser")
log.info("create relblock data")
if rel_block_size == RelBlockSize.ONE_STRIPE_SIZE:
target_relblock_size = stripe_size * 8192
elif rel_block_size == RelBlockSize.TWO_STRPES_PER_SHARD:
@@ -129,8 +145,45 @@ def test_pgdata_import_smoke(
else:
raise ValueError
vanilla_pg.start()
rows_inserted = populate_vanilla_pg(vanilla_pg, target_relblock_size)
# fillfactor so we don't need to produce that much data
# 900 byte per row is > 10% => 1 row per page
vanilla_pg.safe_psql("""create table t (data char(900)) with (fillfactor = 10)""")
nrows = 0
while True:
relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')")
log.info(
f"relblock size: {relblock_size / 8192} pages (target: {target_relblock_size // 8192}) pages"
)
if relblock_size >= target_relblock_size:
break
addrows = int((target_relblock_size - relblock_size) // 8192)
assert addrows >= 1, "forward progress"
vanilla_pg.safe_psql(
f"insert into t select generate_series({nrows + 1}, {nrows + addrows})"
)
nrows += addrows
expect_nrows = nrows
expect_sum = (
(nrows) * (nrows + 1) // 2
) # https://stackoverflow.com/questions/43901484/sum-of-the-integers-from-1-to-n
def validate_vanilla_equivalence(ep):
# TODO: would be nicer to just compare pgdump
# Enable IO concurrency for batching on large sequential scan, to avoid making
# this test unnecessarily onerous on CPU. Especially on debug mode, it's still
# pretty onerous though, so increase statement_timeout to avoid timeouts.
assert ep.safe_psql_many(
[
"set effective_io_concurrency=32;",
"SET statement_timeout='300s';",
"select count(*), sum(data::bigint)::bigint from t",
]
) == [[], [], [(expect_nrows, expect_sum)]]
validate_vanilla_equivalence(vanilla_pg)
vanilla_pg.stop()
#
@@ -221,14 +274,14 @@ def test_pgdata_import_smoke(
config_lines=ep_config,
)
validate_import_from_vanilla_pg(ro_endpoint, rows_inserted)
validate_vanilla_equivalence(ro_endpoint)
# ensure the import survives restarts
ro_endpoint.stop()
env.pageserver.stop(immediate=True)
env.pageserver.start()
ro_endpoint.start()
validate_import_from_vanilla_pg(ro_endpoint, rows_inserted)
validate_vanilla_equivalence(ro_endpoint)
#
# validate the layer files in each shard only have the shard-specific data
@@ -268,7 +321,7 @@ def test_pgdata_import_smoke(
child_workload = workload.branch(timeline_id=child_timeline_id, branch_name="br-tip")
child_workload.validate()
validate_import_from_vanilla_pg(child_workload.endpoint(), rows_inserted)
validate_vanilla_equivalence(child_workload.endpoint())
# ... at the initdb lsn
_ = env.create_branch(
@@ -283,7 +336,7 @@ def test_pgdata_import_smoke(
tenant_id=tenant_id,
config_lines=ep_config,
)
validate_import_from_vanilla_pg(br_initdb_endpoint, rows_inserted)
validate_vanilla_equivalence(br_initdb_endpoint)
with pytest.raises(psycopg2.errors.UndefinedTable):
br_initdb_endpoint.safe_psql(f"select * from {workload.table}")
@@ -370,12 +423,8 @@ def test_import_completion_on_restart(
@run_only_on_default_postgres(reason="PG version is irrelevant here")
@pytest.mark.parametrize("action", ["restart", "delete"])
def test_import_respects_timeline_lifecycle(
neon_env_builder: NeonEnvBuilder,
vanilla_pg: VanillaPostgres,
make_httpserver: HTTPServer,
action: str,
def test_import_respects_tenant_shutdown(
neon_env_builder: NeonEnvBuilder, vanilla_pg: VanillaPostgres, make_httpserver: HTTPServer
):
"""
Validate that importing timelines respect the usual timeline life cycle:
@@ -443,33 +492,16 @@ def test_import_respects_timeline_lifecycle(
wait_until(hit_failpoint)
assert not import_completion_signaled.is_set()
if action == "restart":
# Restart the pageserver while an import job is in progress.
# This clears the failpoint and we expect that the import starts up afresh
# after the restart and eventually completes.
env.pageserver.stop()
env.pageserver.start()
# Restart the pageserver while an import job is in progress.
# This clears the failpoint and we expect that the import starts up afresh
# after the restart and eventually completes.
env.pageserver.stop()
env.pageserver.start()
def cplane_notified():
assert import_completion_signaled.is_set()
def cplane_notified():
assert import_completion_signaled.is_set()
wait_until(cplane_notified)
elif action == "delete":
status = env.storage_controller.pageserver_api().timeline_delete(tenant_id, timeline_id)
assert status == 200
timeline_path = env.pageserver.timeline_dir(tenant_id, timeline_id)
assert not timeline_path.exists(), "Timeline dir exists after deletion"
shard_zero = TenantShardId(tenant_id, 0, 0)
location = env.storage_controller.inspect(shard_zero)
assert location is not None
generation = location[0]
with pytest.raises(StorageControllerApiException, match="not found"):
env.storage_controller.import_status(shard_zero, timeline_id, generation)
else:
raise RuntimeError(f"{action} param not recognized")
wait_until(cplane_notified)
@skip_in_debug_build("Validation query takes too long in debug builds")
@@ -524,8 +556,23 @@ def test_import_chaos(
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
vanilla_pg.start()
vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser")
vanilla_pg.safe_psql("""create table t (data char(900)) with (fillfactor = 10)""")
inserted_rows = populate_vanilla_pg(vanilla_pg, TARGET_RELBOCK_SIZE)
nrows = 0
while True:
relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')")
log.info(
f"relblock size: {relblock_size / 8192} pages (target: {TARGET_RELBOCK_SIZE // 8192}) pages"
)
if relblock_size >= TARGET_RELBOCK_SIZE:
break
addrows = int((TARGET_RELBOCK_SIZE - relblock_size) // 8192)
assert addrows >= 1, "forward progress"
vanilla_pg.safe_psql(
f"insert into t select generate_series({nrows + 1}, {nrows + addrows})"
)
nrows += addrows
vanilla_pg.stop()
@@ -693,7 +740,13 @@ def test_import_chaos(
endpoint = env.endpoints.create_start(branch_name=import_branch_name, tenant_id=tenant_id)
# Validate the imported data is legit
validate_import_from_vanilla_pg(endpoint, inserted_rows)
assert endpoint.safe_psql_many(
[
"set effective_io_concurrency=32;",
"SET statement_timeout='300s';",
"select count(*), sum(data::bigint)::bigint from t",
]
) == [[], [], [(nrows, nrows * (nrows + 1) // 2)]]
endpoint.stop()

View File

@@ -4192,10 +4192,10 @@ def test_storcon_create_delete_sk_down(
# ensure the safekeeper deleted the timeline
def timeline_deleted_on_active_sks():
env.safekeepers[0].assert_log_contains(
f"((deleting timeline|Timeline) {tenant_id}/{child_timeline_id} (from disk|was already deleted)|DELETE.*tenant/{tenant_id} .*status: 200 OK)"
f"deleting timeline {tenant_id}/{child_timeline_id} from disk"
)
env.safekeepers[2].assert_log_contains(
f"((deleting timeline|Timeline) {tenant_id}/{child_timeline_id} (from disk|was already deleted)|DELETE.*tenant/{tenant_id} .*status: 200 OK)"
f"deleting timeline {tenant_id}/{child_timeline_id} from disk"
)
wait_until(timeline_deleted_on_active_sks)
@@ -4210,7 +4210,7 @@ def test_storcon_create_delete_sk_down(
# ensure that there is log msgs for the third safekeeper too
def timeline_deleted_on_sk():
env.safekeepers[1].assert_log_contains(
f"((deleting timeline|Timeline) {tenant_id}/{child_timeline_id} (from disk|was already deleted)|DELETE.*tenant/{tenant_id} .*status: 200 OK)"
f"deleting timeline {tenant_id}/{child_timeline_id} from disk"
)
wait_until(timeline_deleted_on_sk)