From aeda82a0105f18393e8d56d7ff2f6202059edde6 Mon Sep 17 00:00:00 2001 From: Joonas Koivunen Date: Mon, 12 Feb 2024 11:57:29 +0200 Subject: [PATCH 01/20] fix(heavier_once_cell): assertion failure can be hit (#6722) @problame noticed that the `tokio::sync::AcquireError` branch assertion can be hit like in the added test. We haven't seen this yet in production, but I'd prefer not to see it there. There `take_and_deinit` is being used, but this race must be quite timing sensitive. Rework of earlier: #6652. --- libs/utils/src/sync/heavier_once_cell.rs | 174 ++++++++++++++++++----- 1 file changed, 138 insertions(+), 36 deletions(-) diff --git a/libs/utils/src/sync/heavier_once_cell.rs b/libs/utils/src/sync/heavier_once_cell.rs index 0ccaf4e716..0773abba2d 100644 --- a/libs/utils/src/sync/heavier_once_cell.rs +++ b/libs/utils/src/sync/heavier_once_cell.rs @@ -69,37 +69,44 @@ impl OnceCell { F: FnOnce(InitPermit) -> Fut, Fut: std::future::Future>, { - let sem = { + loop { + let sem = { + let guard = self.inner.lock().unwrap(); + if guard.value.is_some() { + return Ok(Guard(guard)); + } + guard.init_semaphore.clone() + }; + + { + let permit = { + // increment the count for the duration of queued + let _guard = CountWaitingInitializers::start(self); + sem.acquire().await + }; + + let Ok(permit) = permit else { + let guard = self.inner.lock().unwrap(); + if !Arc::ptr_eq(&sem, &guard.init_semaphore) { + // there was a take_and_deinit in between + continue; + } + assert!( + guard.value.is_some(), + "semaphore got closed, must be initialized" + ); + return Ok(Guard(guard)); + }; + + permit.forget(); + } + + let permit = InitPermit(sem); + let (value, _permit) = factory(permit).await?; + let guard = self.inner.lock().unwrap(); - if guard.value.is_some() { - return Ok(Guard(guard)); - } - guard.init_semaphore.clone() - }; - let permit = { - // increment the count for the duration of queued - let _guard = CountWaitingInitializers::start(self); - sem.acquire_owned().await - }; - - match permit { - Ok(permit) => { - let permit = InitPermit(permit); - let (value, _permit) = factory(permit).await?; - - let guard = self.inner.lock().unwrap(); - - Ok(Self::set0(value, guard)) - } - Err(_closed) => { - let guard = self.inner.lock().unwrap(); - assert!( - guard.value.is_some(), - "semaphore got closed, must be initialized" - ); - return Ok(Guard(guard)); - } + return Ok(Self::set0(value, guard)); } } @@ -197,27 +204,41 @@ impl<'a, T> Guard<'a, T> { /// [`OnceCell::get_or_init`] will wait on it to complete. pub fn take_and_deinit(&mut self) -> (T, InitPermit) { let mut swapped = Inner::default(); - let permit = swapped - .init_semaphore - .clone() - .try_acquire_owned() - .expect("we just created this"); + let sem = swapped.init_semaphore.clone(); + // acquire and forget right away, moving the control over to InitPermit + sem.try_acquire().expect("we just created this").forget(); std::mem::swap(&mut *self.0, &mut swapped); swapped .value - .map(|v| (v, InitPermit(permit))) + .map(|v| (v, InitPermit(sem))) .expect("guard is not created unless value has been initialized") } } /// Type held by OnceCell (de)initializing task. -pub struct InitPermit(tokio::sync::OwnedSemaphorePermit); +/// +/// On drop, this type will return the permit. 
+pub struct InitPermit(Arc); + +impl Drop for InitPermit { + fn drop(&mut self) { + assert_eq!( + self.0.available_permits(), + 0, + "InitPermit should only exist as the unique permit" + ); + self.0.add_permits(1); + } +} #[cfg(test)] mod tests { + use futures::Future; + use super::*; use std::{ convert::Infallible, + pin::{pin, Pin}, sync::atomic::{AtomicUsize, Ordering}, time::Duration, }; @@ -380,4 +401,85 @@ mod tests { .unwrap(); assert_eq!(*g, "now initialized"); } + + #[tokio::test(start_paused = true)] + async fn reproduce_init_take_deinit_race() { + init_take_deinit_scenario(|cell, factory| { + Box::pin(async { + cell.get_or_init(factory).await.unwrap(); + }) + }) + .await; + } + + type BoxedInitFuture = Pin>>>; + type BoxedInitFunction = Box BoxedInitFuture>; + + /// Reproduce an assertion failure. + /// + /// This has interesting generics to be generic between `get_or_init` and `get_mut_or_init`. + /// We currently only have one, but the structure is kept. + async fn init_take_deinit_scenario(init_way: F) + where + F: for<'a> Fn( + &'a OnceCell<&'static str>, + BoxedInitFunction<&'static str, Infallible>, + ) -> Pin + 'a>>, + { + let cell = OnceCell::default(); + + // acquire the init_semaphore only permit to drive initializing tasks in order to waiting + // on the same semaphore. + let permit = cell + .inner + .lock() + .unwrap() + .init_semaphore + .clone() + .try_acquire_owned() + .unwrap(); + + let mut t1 = pin!(init_way( + &cell, + Box::new(|permit| Box::pin(async move { Ok(("t1", permit)) })), + )); + + let mut t2 = pin!(init_way( + &cell, + Box::new(|permit| Box::pin(async move { Ok(("t2", permit)) })), + )); + + // drive t2 first to the init_semaphore -- the timeout will be hit once t2 future can + // no longer make progress + tokio::select! { + _ = &mut t2 => unreachable!("it cannot get permit"), + _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {} + } + + // followed by t1 in the init_semaphore + tokio::select! { + _ = &mut t1 => unreachable!("it cannot get permit"), + _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {} + } + + // now let t2 proceed and initialize + drop(permit); + t2.await; + + let (s, permit) = { cell.get().unwrap().take_and_deinit() }; + assert_eq!("t2", s); + + // now originally t1 would see the semaphore it has as closed. it cannot yet get a permit from + // the new one. + tokio::select! { + _ = &mut t1 => unreachable!("it cannot get permit"), + _ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {} + } + + // only now we get to initialize it + drop(permit); + t1.await; + + assert_eq!("t1", *cell.get().unwrap()); + } } From c77411e9035ac38925652bf1f772b333acb0b9ac Mon Sep 17 00:00:00 2001 From: Joonas Koivunen Date: Mon, 12 Feb 2024 14:52:20 +0200 Subject: [PATCH 02/20] cleanup around `attach` (#6621) The smaller changes I found while looking around #6584. - rustfmt was not able to format handle_timeline_create - fix Generation::get_suffix always allocating - Generation was missing a `#[track_caller]` for panicky method - attach has a lot of issues, but even with this PR it cannot be formatted by rustfmt - moved the `preload` span to be on top of `attach` -- it is awaited inline - make disconnected panic! or unreachable! 
into expect, expect_err --- libs/utils/src/generation.rs | 41 ++++- pageserver/src/http/routes.rs | 76 +++++---- pageserver/src/tenant.rs | 199 +++++++++++------------ pageserver/src/tenant/delete.rs | 8 +- pageserver/src/tenant/timeline/delete.rs | 9 +- 5 files changed, 177 insertions(+), 156 deletions(-) diff --git a/libs/utils/src/generation.rs b/libs/utils/src/generation.rs index 46eadee1da..6f6c46cfeb 100644 --- a/libs/utils/src/generation.rs +++ b/libs/utils/src/generation.rs @@ -54,12 +54,10 @@ impl Generation { } #[track_caller] - pub fn get_suffix(&self) -> String { + pub fn get_suffix(&self) -> impl std::fmt::Display { match self { - Self::Valid(v) => { - format!("-{:08x}", v) - } - Self::None => "".into(), + Self::Valid(v) => GenerationFileSuffix(Some(*v)), + Self::None => GenerationFileSuffix(None), Self::Broken => { panic!("Tried to use a broken generation"); } @@ -90,6 +88,7 @@ impl Generation { } } + #[track_caller] pub fn next(&self) -> Generation { match self { Self::Valid(n) => Self::Valid(*n + 1), @@ -107,6 +106,18 @@ impl Generation { } } +struct GenerationFileSuffix(Option); + +impl std::fmt::Display for GenerationFileSuffix { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(g) = self.0 { + write!(f, "-{g:08x}") + } else { + Ok(()) + } + } +} + impl Serialize for Generation { fn serialize(&self, serializer: S) -> Result where @@ -164,4 +175,24 @@ mod test { assert!(Generation::none() < Generation::new(0)); assert!(Generation::none() < Generation::new(1)); } + + #[test] + fn suffix_is_stable() { + use std::fmt::Write as _; + + // the suffix must remain stable through-out the pageserver remote storage evolution and + // not be changed accidentially without thinking about migration + let examples = [ + (line!(), Generation::None, ""), + (line!(), Generation::Valid(0), "-00000000"), + (line!(), Generation::Valid(u32::MAX), "-ffffffff"), + ]; + + let mut s = String::new(); + for (line, gen, expected) in examples { + s.clear(); + write!(s, "{}", &gen.get_suffix()).expect("string grows"); + assert_eq!(s, expected, "example on {line}"); + } + } } diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index af9a3c7301..4be8ee9892 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -488,7 +488,9 @@ async fn timeline_create_handler( let state = get_state(&request); async { - let tenant = state.tenant_manager.get_attached_tenant_shard(tenant_shard_id, false)?; + let tenant = state + .tenant_manager + .get_attached_tenant_shard(tenant_shard_id, false)?; tenant.wait_to_become_active(ACTIVE_TENANT_TIMEOUT).await?; @@ -498,48 +500,62 @@ async fn timeline_create_handler( tracing::info!("bootstrapping"); } - match tenant.create_timeline( - new_timeline_id, - request_data.ancestor_timeline_id.map(TimelineId::from), - request_data.ancestor_start_lsn, - request_data.pg_version.unwrap_or(crate::DEFAULT_PG_VERSION), - request_data.existing_initdb_timeline_id, - state.broker_client.clone(), - &ctx, - ) - .await { + match tenant + .create_timeline( + new_timeline_id, + request_data.ancestor_timeline_id, + request_data.ancestor_start_lsn, + request_data.pg_version.unwrap_or(crate::DEFAULT_PG_VERSION), + request_data.existing_initdb_timeline_id, + state.broker_client.clone(), + &ctx, + ) + .await + { Ok(new_timeline) => { // Created. Construct a TimelineInfo for it. 
- let timeline_info = build_timeline_info_common(&new_timeline, &ctx, tenant::timeline::GetLogicalSizePriority::User) - .await - .map_err(ApiError::InternalServerError)?; + let timeline_info = build_timeline_info_common( + &new_timeline, + &ctx, + tenant::timeline::GetLogicalSizePriority::User, + ) + .await + .map_err(ApiError::InternalServerError)?; json_response(StatusCode::CREATED, timeline_info) } Err(_) if tenant.cancel.is_cancelled() => { // In case we get some ugly error type during shutdown, cast it into a clean 503. - json_response(StatusCode::SERVICE_UNAVAILABLE, HttpErrorBody::from_msg("Tenant shutting down".to_string())) - } - Err(tenant::CreateTimelineError::Conflict | tenant::CreateTimelineError::AlreadyCreating) => { - json_response(StatusCode::CONFLICT, ()) - } - Err(tenant::CreateTimelineError::AncestorLsn(err)) => { - json_response(StatusCode::NOT_ACCEPTABLE, HttpErrorBody::from_msg( - format!("{err:#}") - )) - } - Err(e @ tenant::CreateTimelineError::AncestorNotActive) => { - json_response(StatusCode::SERVICE_UNAVAILABLE, HttpErrorBody::from_msg(e.to_string())) - } - Err(tenant::CreateTimelineError::ShuttingDown) => { - json_response(StatusCode::SERVICE_UNAVAILABLE,HttpErrorBody::from_msg("tenant shutting down".to_string())) + json_response( + StatusCode::SERVICE_UNAVAILABLE, + HttpErrorBody::from_msg("Tenant shutting down".to_string()), + ) } + Err( + tenant::CreateTimelineError::Conflict + | tenant::CreateTimelineError::AlreadyCreating, + ) => json_response(StatusCode::CONFLICT, ()), + Err(tenant::CreateTimelineError::AncestorLsn(err)) => json_response( + StatusCode::NOT_ACCEPTABLE, + HttpErrorBody::from_msg(format!("{err:#}")), + ), + Err(e @ tenant::CreateTimelineError::AncestorNotActive) => json_response( + StatusCode::SERVICE_UNAVAILABLE, + HttpErrorBody::from_msg(e.to_string()), + ), + Err(tenant::CreateTimelineError::ShuttingDown) => json_response( + StatusCode::SERVICE_UNAVAILABLE, + HttpErrorBody::from_msg("tenant shutting down".to_string()), + ), Err(tenant::CreateTimelineError::Other(err)) => Err(ApiError::InternalServerError(err)), } } .instrument(info_span!("timeline_create", tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug(), - timeline_id = %new_timeline_id, lsn=?request_data.ancestor_start_lsn, pg_version=?request_data.pg_version)) + timeline_id = %new_timeline_id, + lsn=?request_data.ancestor_start_lsn, + pg_version=?request_data.pg_version + )) .await } diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 4446c410b0..d946c57118 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -644,10 +644,10 @@ impl Tenant { // The attach task will carry a GateGuard, so that shutdown() reliably waits for it to drop out if // we shut down while attaching. 
- let Ok(attach_gate_guard) = tenant.gate.enter() else { - // We just created the Tenant: nothing else can have shut it down yet - unreachable!(); - }; + let attach_gate_guard = tenant + .gate + .enter() + .expect("We just created the Tenant: nothing else can have shut it down yet"); // Do all the hard work in the background let tenant_clone = Arc::clone(&tenant); @@ -755,36 +755,27 @@ impl Tenant { AttachType::Normal }; - let preload_timer = TENANT.preload.start_timer(); - let preload = match mode { - SpawnMode::Create => { - // Don't count the skipped preload into the histogram of preload durations - preload_timer.stop_and_discard(); + let preload = match (&mode, &remote_storage) { + (SpawnMode::Create, _) => { None }, - SpawnMode::Normal => { - match &remote_storage { - Some(remote_storage) => Some( - match tenant_clone - .preload(remote_storage, task_mgr::shutdown_token()) - .instrument( - tracing::info_span!(parent: None, "attach_preload", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug()), - ) - .await { - Ok(p) => { - preload_timer.observe_duration(); - p - } - , - Err(e) => { - make_broken(&tenant_clone, anyhow::anyhow!(e)); - return Ok(()); - } - }, - ), - None => None, + (SpawnMode::Normal, Some(remote_storage)) => { + let _preload_timer = TENANT.preload.start_timer(); + let res = tenant_clone + .preload(remote_storage, task_mgr::shutdown_token()) + .await; + match res { + Ok(p) => Some(p), + Err(e) => { + make_broken(&tenant_clone, anyhow::anyhow!(e)); + return Ok(()); + } } } + (SpawnMode::Normal, None) => { + let _preload_timer = TENANT.preload.start_timer(); + None + } }; // Remote preload is complete. @@ -820,36 +811,37 @@ impl Tenant { info!("ready for backgound jobs barrier"); } - match DeleteTenantFlow::resume_from_attach( + let deleted = DeleteTenantFlow::resume_from_attach( deletion, &tenant_clone, preload, tenants, &ctx, ) - .await - { - Err(err) => { - make_broken(&tenant_clone, anyhow::anyhow!(err)); - return Ok(()); - } - Ok(()) => return Ok(()), + .await; + + if let Err(e) = deleted { + make_broken(&tenant_clone, anyhow::anyhow!(e)); } + + return Ok(()); } // We will time the duration of the attach phase unless this is a creation (attach will do no work) - let attach_timer = match mode { - SpawnMode::Create => None, - SpawnMode::Normal => {Some(TENANT.attach.start_timer())} + let attached = { + let _attach_timer = match mode { + SpawnMode::Create => None, + SpawnMode::Normal => {Some(TENANT.attach.start_timer())} + }; + tenant_clone.attach(preload, mode, &ctx).await }; - match tenant_clone.attach(preload, mode, &ctx).await { + + match attached { Ok(()) => { info!("attach finished, activating"); - if let Some(t)= attach_timer {t.observe_duration();} tenant_clone.activate(broker_client, None, &ctx); } Err(e) => { - if let Some(t)= attach_timer {t.observe_duration();} make_broken(&tenant_clone, anyhow::anyhow!(e)); } } @@ -862,34 +854,26 @@ impl Tenant { // logical size calculations: if logical size calculation semaphore is saturated, // then warmup will wait for that before proceeding to the next tenant. 
if let AttachType::Warmup(_permit) = attach_type { - let mut futs = FuturesUnordered::new(); - let timelines: Vec<_> = tenant_clone.timelines.lock().unwrap().values().cloned().collect(); - for t in timelines { - futs.push(t.await_initial_logical_size()) - } + let mut futs: FuturesUnordered<_> = tenant_clone.timelines.lock().unwrap().values().cloned().map(|t| t.await_initial_logical_size()).collect(); tracing::info!("Waiting for initial logical sizes while warming up..."); - while futs.next().await.is_some() { - - } + while futs.next().await.is_some() {} tracing::info!("Warm-up complete"); } Ok(()) } - .instrument({ - let span = tracing::info_span!(parent: None, "attach", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), gen=?generation); - span.follows_from(Span::current()); - span - }), + .instrument(tracing::info_span!(parent: None, "attach", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), gen=?generation)), ); Ok(tenant) } + #[instrument(skip_all)] pub(crate) async fn preload( self: &Arc, remote_storage: &GenericRemoteStorage, cancel: CancellationToken, ) -> anyhow::Result { + span::debug_assert_current_span_has_tenant_id(); // Get list of remote timelines // download index files for every tenant timeline info!("listing remote timelines"); @@ -3982,6 +3966,8 @@ pub(crate) mod harness { } } + #[cfg(test)] + #[derive(Debug)] enum LoadMode { Local, Remote, @@ -4064,7 +4050,7 @@ pub(crate) mod harness { info_span!("TenantHarness", tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug()) } - pub async fn load(&self) -> (Arc, RequestContext) { + pub(crate) async fn load(&self) -> (Arc, RequestContext) { let ctx = RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error); ( self.try_load(&ctx) @@ -4074,31 +4060,31 @@ pub(crate) mod harness { ) } - fn remote_empty(&self) -> bool { - let tenant_path = self.conf.tenant_path(&self.tenant_shard_id); - let remote_tenant_dir = self - .remote_fs_dir - .join(tenant_path.strip_prefix(&self.conf.workdir).unwrap()); - if std::fs::metadata(&remote_tenant_dir).is_err() { - return true; - } - - match std::fs::read_dir(remote_tenant_dir) - .unwrap() - .flatten() - .next() - { - Some(entry) => { - tracing::debug!( - "remote_empty: not empty, found file {}", - entry.file_name().to_string_lossy(), - ); - false - } - None => true, - } + /// For tests that specifically want to exercise the local load path, which does + /// not use remote storage. + pub(crate) async fn try_load_local( + &self, + ctx: &RequestContext, + ) -> anyhow::Result> { + self.do_try_load(ctx, LoadMode::Local).await } + /// The 'load' in this function is either a local load or a normal attachment, + pub(crate) async fn try_load(&self, ctx: &RequestContext) -> anyhow::Result> { + // If we have nothing in remote storage, must use load_local instead of attach: attach + // will error out if there are no timelines. + // + // See https://github.com/neondatabase/neon/issues/5456 for how we will eliminate + // this weird state of a Tenant which exists but doesn't have any timelines. 
+ let mode = match self.remote_empty() { + true => LoadMode::Local, + false => LoadMode::Remote, + }; + + self.do_try_load(ctx, mode).await + } + + #[instrument(skip_all, fields(tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug(), ?mode))] async fn do_try_load( &self, ctx: &RequestContext, @@ -4125,20 +4111,13 @@ pub(crate) mod harness { match mode { LoadMode::Local => { - tenant - .load_local(ctx) - .instrument(info_span!("try_load", tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug())) - .await?; + tenant.load_local(ctx).await?; } LoadMode::Remote => { let preload = tenant .preload(&self.remote_storage, CancellationToken::new()) - .instrument(info_span!("try_load_preload", tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug())) - .await?; - tenant - .attach(Some(preload), SpawnMode::Normal, ctx) - .instrument(info_span!("try_load", tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug())) .await?; + tenant.attach(Some(preload), SpawnMode::Normal, ctx).await?; } } @@ -4149,25 +4128,29 @@ pub(crate) mod harness { Ok(tenant) } - /// For tests that specifically want to exercise the local load path, which does - /// not use remote storage. - pub async fn try_load_local(&self, ctx: &RequestContext) -> anyhow::Result> { - self.do_try_load(ctx, LoadMode::Local).await - } + fn remote_empty(&self) -> bool { + let tenant_path = self.conf.tenant_path(&self.tenant_shard_id); + let remote_tenant_dir = self + .remote_fs_dir + .join(tenant_path.strip_prefix(&self.conf.workdir).unwrap()); + if std::fs::metadata(&remote_tenant_dir).is_err() { + return true; + } - /// The 'load' in this function is either a local load or a normal attachment, - pub async fn try_load(&self, ctx: &RequestContext) -> anyhow::Result> { - // If we have nothing in remote storage, must use load_local instead of attach: attach - // will error out if there are no timelines. - // - // See https://github.com/neondatabase/neon/issues/5456 for how we will eliminate - // this weird state of a Tenant which exists but doesn't have any timelines. 
- let mode = match self.remote_empty() { - true => LoadMode::Local, - false => LoadMode::Remote, - }; - - self.do_try_load(ctx, mode).await + match std::fs::read_dir(remote_tenant_dir) + .unwrap() + .flatten() + .next() + { + Some(entry) => { + tracing::debug!( + "remote_empty: not empty, found file {}", + entry.file_name().to_string_lossy(), + ); + false + } + None => true, + } } pub fn timeline_path(&self, timeline_id: &TimelineId) -> Utf8PathBuf { diff --git a/pageserver/src/tenant/delete.rs b/pageserver/src/tenant/delete.rs index 7c35914b61..0e192b577c 100644 --- a/pageserver/src/tenant/delete.rs +++ b/pageserver/src/tenant/delete.rs @@ -6,7 +6,7 @@ use pageserver_api::{models::TenantState, shard::TenantShardId}; use remote_storage::{GenericRemoteStorage, RemotePath}; use tokio::sync::OwnedMutexGuard; use tokio_util::sync::CancellationToken; -use tracing::{error, instrument, Instrument, Span}; +use tracing::{error, instrument, Instrument}; use utils::{backoff, completion, crashsafe, fs_ext, id::TimelineId}; @@ -496,11 +496,7 @@ impl DeleteTenantFlow { }; Ok(()) } - .instrument({ - let span = tracing::info_span!(parent: None, "delete_tenant", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug()); - span.follows_from(Span::current()); - span - }), + .instrument(tracing::info_span!(parent: None, "delete_tenant", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug())), ); } diff --git a/pageserver/src/tenant/timeline/delete.rs b/pageserver/src/tenant/timeline/delete.rs index 88d7ce61dd..dc499197b0 100644 --- a/pageserver/src/tenant/timeline/delete.rs +++ b/pageserver/src/tenant/timeline/delete.rs @@ -6,7 +6,7 @@ use std::{ use anyhow::Context; use pageserver_api::{models::TimelineState, shard::TenantShardId}; use tokio::sync::OwnedMutexGuard; -use tracing::{debug, error, info, instrument, warn, Instrument, Span}; +use tracing::{debug, error, info, instrument, warn, Instrument}; use utils::{crashsafe, fs_ext, id::TimelineId}; use crate::{ @@ -541,12 +541,7 @@ impl DeleteTimelineFlow { }; Ok(()) } - .instrument({ - let span = - tracing::info_span!(parent: None, "delete_timeline", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(),timeline_id=%timeline_id); - span.follows_from(Span::current()); - span - }), + .instrument(tracing::info_span!(parent: None, "delete_timeline", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(),timeline_id=%timeline_id)), ); } From 020e607637fe00ec869fd6eb71dfa732ae501b37 Mon Sep 17 00:00:00 2001 From: Anna Khanova <32508607+khanova@users.noreply.github.com> Date: Mon, 12 Feb 2024 14:04:46 +0100 Subject: [PATCH 03/20] Proxy: copy bidirectional fork (#6720) ## Problem `tokio::io::copy_bidirectional` doesn't close the connection once one of the sides closes it. It's not really suitable for the postgres protocol. ## Summary of changes Fork `copy_bidirectional` and initiate a shutdown for both connections. 
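As a concrete illustration of the new behaviour, here is a minimal sketch (not code from this patch): it assumes the forked `copy_bidirectional` below is in scope and uses in-memory `duplex` pipes in place of the real client and compute sockets.

```rust
use tokio::io::AsyncWriteExt;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Two in-memory pipes stand in for the client and compute connections.
    let (mut client, mut client_side) = tokio::io::duplex(64);
    let (mut compute_side, mut compute) = tokio::io::duplex(64);

    client.write_all(b"hello").await?;
    client.shutdown().await?; // the client half-closes its side
    compute.write_all(b"world").await?; // the compute side stays open

    // The forked helper notices the client EOF, shuts down the compute-bound
    // write half as well, and returns. `tokio::io::copy_bidirectional` would
    // keep waiting here until the compute side also closed its connection.
    let (client_to_compute, compute_to_client) =
        copy_bidirectional(&mut client_side, &mut compute_side).await?;

    assert_eq!(client_to_compute, 5); // "hello" made it through
    assert_eq!(compute_to_client, 5); // "world" was flushed before the shutdown
    Ok(())
}
```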
--------- Co-authored-by: Conrad Ludgate --- proxy/src/proxy.rs | 1 + proxy/src/proxy/copy_bidirectional.rs | 256 ++++++++++++++++++++++++++ proxy/src/proxy/passthrough.rs | 2 +- 3 files changed, 258 insertions(+), 1 deletion(-) create mode 100644 proxy/src/proxy/copy_bidirectional.rs diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 50e22ec72a..77aadb6f28 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -2,6 +2,7 @@ mod tests; pub mod connect_compute; +mod copy_bidirectional; pub mod handshake; pub mod passthrough; pub mod retry; diff --git a/proxy/src/proxy/copy_bidirectional.rs b/proxy/src/proxy/copy_bidirectional.rs new file mode 100644 index 0000000000..2ecc1151da --- /dev/null +++ b/proxy/src/proxy/copy_bidirectional.rs @@ -0,0 +1,256 @@ +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +use std::future::poll_fn; +use std::io; +use std::pin::Pin; +use std::task::{ready, Context, Poll}; + +#[derive(Debug)] +enum TransferState { + Running(CopyBuffer), + ShuttingDown(u64), + Done(u64), +} + +fn transfer_one_direction( + cx: &mut Context<'_>, + state: &mut TransferState, + r: &mut A, + w: &mut B, +) -> Poll> +where + A: AsyncRead + AsyncWrite + Unpin + ?Sized, + B: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ + let mut r = Pin::new(r); + let mut w = Pin::new(w); + loop { + match state { + TransferState::Running(buf) => { + let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?; + *state = TransferState::ShuttingDown(count); + } + TransferState::ShuttingDown(count) => { + ready!(w.as_mut().poll_shutdown(cx))?; + *state = TransferState::Done(*count); + } + TransferState::Done(count) => return Poll::Ready(Ok(*count)), + } + } +} + +pub(super) async fn copy_bidirectional( + a: &mut A, + b: &mut B, +) -> Result<(u64, u64), std::io::Error> +where + A: AsyncRead + AsyncWrite + Unpin + ?Sized, + B: AsyncRead + AsyncWrite + Unpin + ?Sized, +{ + let mut a_to_b = TransferState::Running(CopyBuffer::new()); + let mut b_to_a = TransferState::Running(CopyBuffer::new()); + + poll_fn(|cx| { + let mut a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?; + let mut b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?; + + // Early termination checks + if let TransferState::Done(_) = a_to_b { + if let TransferState::Running(buf) = &b_to_a { + // Initiate shutdown + b_to_a = TransferState::ShuttingDown(buf.amt); + b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?; + } + } + if let TransferState::Done(_) = b_to_a { + if let TransferState::Running(buf) = &a_to_b { + // Initiate shutdown + a_to_b = TransferState::ShuttingDown(buf.amt); + a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?; + } + } + + // It is not a problem if ready! returns early ... 
(comment remains the same) + let a_to_b = ready!(a_to_b_result); + let b_to_a = ready!(b_to_a_result); + + Poll::Ready(Ok((a_to_b, b_to_a))) + }) + .await +} + +#[derive(Debug)] +pub(super) struct CopyBuffer { + read_done: bool, + need_flush: bool, + pos: usize, + cap: usize, + amt: u64, + buf: Box<[u8]>, +} +const DEFAULT_BUF_SIZE: usize = 8 * 1024; + +impl CopyBuffer { + pub(super) fn new() -> Self { + Self { + read_done: false, + need_flush: false, + pos: 0, + cap: 0, + amt: 0, + buf: vec![0; DEFAULT_BUF_SIZE].into_boxed_slice(), + } + } + + fn poll_fill_buf( + &mut self, + cx: &mut Context<'_>, + reader: Pin<&mut R>, + ) -> Poll> + where + R: AsyncRead + ?Sized, + { + let me = &mut *self; + let mut buf = ReadBuf::new(&mut me.buf); + buf.set_filled(me.cap); + + let res = reader.poll_read(cx, &mut buf); + if let Poll::Ready(Ok(())) = res { + let filled_len = buf.filled().len(); + me.read_done = me.cap == filled_len; + me.cap = filled_len; + } + res + } + + fn poll_write_buf( + &mut self, + cx: &mut Context<'_>, + mut reader: Pin<&mut R>, + mut writer: Pin<&mut W>, + ) -> Poll> + where + R: AsyncRead + ?Sized, + W: AsyncWrite + ?Sized, + { + let me = &mut *self; + match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) { + Poll::Pending => { + // Top up the buffer towards full if we can read a bit more + // data - this should improve the chances of a large write + if !me.read_done && me.cap < me.buf.len() { + ready!(me.poll_fill_buf(cx, reader.as_mut()))?; + } + Poll::Pending + } + res => res, + } + } + + pub(super) fn poll_copy( + &mut self, + cx: &mut Context<'_>, + mut reader: Pin<&mut R>, + mut writer: Pin<&mut W>, + ) -> Poll> + where + R: AsyncRead + ?Sized, + W: AsyncWrite + ?Sized, + { + loop { + // If our buffer is empty, then we need to read some data to + // continue. + if self.pos == self.cap && !self.read_done { + self.pos = 0; + self.cap = 0; + + match self.poll_fill_buf(cx, reader.as_mut()) { + Poll::Ready(Ok(())) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => { + // Try flushing when the reader has no progress to avoid deadlock + // when the reader depends on buffered writer. + if self.need_flush { + ready!(writer.as_mut().poll_flush(cx))?; + self.need_flush = false; + } + + return Poll::Pending; + } + } + } + + // If our buffer has some data, let's write it out! + while self.pos < self.cap { + let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?; + if i == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::WriteZero, + "write zero byte into writer", + ))); + } else { + self.pos += i; + self.amt += i as u64; + self.need_flush = true; + } + } + + // If pos larger than cap, this loop will never stop. + // In particular, user's wrong poll_write implementation returning + // incorrect written length may lead to thread blocking. + debug_assert!( + self.pos <= self.cap, + "writer returned length larger than input slice" + ); + + // If we've written all the data and we've seen EOF, flush out the + // data and finish the transfer. 
+ if self.pos == self.cap && self.read_done { + ready!(writer.as_mut().poll_flush(cx))?; + return Poll::Ready(Ok(self.amt)); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::io::AsyncWriteExt; + + #[tokio::test] + async fn test_early_termination_a_to_d() { + let (mut a_mock, mut b_mock) = tokio::io::duplex(8); // Create a mock duplex stream + let (mut c_mock, mut d_mock) = tokio::io::duplex(32); // Create a mock duplex stream + + // Simulate 'a' finishing while there's still data for 'b' + a_mock.write_all(b"hello").await.unwrap(); + a_mock.shutdown().await.unwrap(); + d_mock.write_all(b"Neon Serverless Postgres").await.unwrap(); + + let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap(); + + // Assert correct transferred amounts + let (a_to_d_count, d_to_a_count) = result; + assert_eq!(a_to_d_count, 5); // 'hello' was transferred + assert!(d_to_a_count <= 8); // response only partially transferred or not at all + } + + #[tokio::test] + async fn test_early_termination_d_to_a() { + let (mut a_mock, mut b_mock) = tokio::io::duplex(32); // Create a mock duplex stream + let (mut c_mock, mut d_mock) = tokio::io::duplex(8); // Create a mock duplex stream + + // Simulate 'a' finishing while there's still data for 'b' + d_mock.write_all(b"hello").await.unwrap(); + d_mock.shutdown().await.unwrap(); + a_mock.write_all(b"Neon Serverless Postgres").await.unwrap(); + + let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap(); + + // Assert correct transferred amounts + let (a_to_d_count, d_to_a_count) = result; + assert_eq!(d_to_a_count, 5); // 'hello' was transferred + assert!(a_to_d_count <= 8); // response only partially transferred or not at all + } +} diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index b7018c6fb5..c98f68d8d1 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -45,7 +45,7 @@ pub async fn proxy_pass( // Starting from here we only proxy the client's traffic. info!("performing the proxy pass..."); - let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?; + let _ = crate::proxy::copy_bidirectional::copy_bidirectional(&mut client, &mut compute).await?; Ok(()) } From 98ec5c5c466158fcb10394303077132efa680690 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 12 Feb 2024 13:14:06 +0000 Subject: [PATCH 04/20] proxy: some more parquet data (#6711) ## Summary of changes add auth_method and database to the parquet logs --- proxy/src/auth/backend.rs | 8 ++-- proxy/src/auth/backend/classic.rs | 8 ++-- proxy/src/auth/backend/hacks.rs | 12 +++-- proxy/src/auth/backend/link.rs | 2 + proxy/src/auth/credentials.rs | 3 ++ proxy/src/auth/flow.rs | 17 ++++++- proxy/src/context.rs | 23 ++++++++- proxy/src/context/parquet.rs | 69 ++++++++++++++++----------- proxy/src/proxy/tests.rs | 2 +- proxy/src/serverless/sql_over_http.rs | 9 +++- 10 files changed, 104 insertions(+), 49 deletions(-) diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index fa2782bee3..c9f21f1cf5 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -194,8 +194,7 @@ async fn auth_quirks( // We now expect to see a very specific payload in the place of password. 
let (info, unauthenticated_password) = match user_info.try_into() { Err(info) => { - let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer) - .await?; + let res = hacks::password_hack_no_authentication(ctx, info, client).await?; ctx.set_endpoint_id(res.info.endpoint.clone()); tracing::Span::current().record("ep", &tracing::field::display(&res.info.endpoint)); @@ -276,11 +275,12 @@ async fn authenticate_with_secret( // Perform cleartext auth if we're allowed to do that. // Currently, we use it for websocket connections (latency). if allow_cleartext { - return hacks::authenticate_cleartext(info, client, &mut ctx.latency_timer, secret).await; + ctx.set_auth_method(crate::context::AuthMethod::Cleartext); + return hacks::authenticate_cleartext(ctx, info, client, secret).await; } // Finally, proceed with the main auth flow (SCRAM-based). - classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await + classic::authenticate(ctx, info, client, config, secret).await } impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index 745dd75107..e855843bc3 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -4,7 +4,7 @@ use crate::{ compute, config::AuthenticationConfig, console::AuthSecret, - metrics::LatencyTimer, + context::RequestMonitoring, sasl, stream::{PqStream, Stream}, }; @@ -12,10 +12,10 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, warn}; pub(super) async fn authenticate( + ctx: &mut RequestMonitoring, creds: ComputeUserInfo, client: &mut PqStream>, config: &'static AuthenticationConfig, - latency_timer: &mut LatencyTimer, secret: AuthSecret, ) -> auth::Result> { let flow = AuthFlow::new(client); @@ -27,13 +27,11 @@ pub(super) async fn authenticate( } AuthSecret::Scram(secret) => { info!("auth endpoint chooses SCRAM"); - let scram = auth::Scram(&secret); + let scram = auth::Scram(&secret, &mut *ctx); let auth_outcome = tokio::time::timeout( config.scram_protocol_timeout, async { - // pause the timer while we communicate with the client - let _paused = latency_timer.pause(); flow.begin(scram).await.map_err(|error| { warn!(?error, "error sending scram acknowledgement"); diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index b6c1a92d3c..9f60b709d4 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -4,7 +4,7 @@ use super::{ use crate::{ auth::{self, AuthFlow}, console::AuthSecret, - metrics::LatencyTimer, + context::RequestMonitoring, sasl, stream::{self, Stream}, }; @@ -16,15 +16,16 @@ use tracing::{info, warn}; /// These properties are benefical for serverless JS workers, so we /// use this mechanism for websocket connections. 
pub async fn authenticate_cleartext( + ctx: &mut RequestMonitoring, info: ComputeUserInfo, client: &mut stream::PqStream>, - latency_timer: &mut LatencyTimer, secret: AuthSecret, ) -> auth::Result> { warn!("cleartext auth flow override is enabled, proceeding"); + ctx.set_auth_method(crate::context::AuthMethod::Cleartext); // pause the timer while we communicate with the client - let _paused = latency_timer.pause(); + let _paused = ctx.latency_timer.pause(); let auth_outcome = AuthFlow::new(client) .begin(auth::CleartextPassword(secret)) @@ -47,14 +48,15 @@ pub async fn authenticate_cleartext( /// Similar to [`authenticate_cleartext`], but there's a specific password format, /// and passwords are not yet validated (we don't know how to validate them!) pub async fn password_hack_no_authentication( + ctx: &mut RequestMonitoring, info: ComputeUserInfoNoEndpoint, client: &mut stream::PqStream>, - latency_timer: &mut LatencyTimer, ) -> auth::Result>> { warn!("project not specified, resorting to the password hack auth flow"); + ctx.set_auth_method(crate::context::AuthMethod::Cleartext); // pause the timer while we communicate with the client - let _paused = latency_timer.pause(); + let _paused = ctx.latency_timer.pause(); let payload = AuthFlow::new(client) .begin(auth::PasswordHack) diff --git a/proxy/src/auth/backend/link.rs b/proxy/src/auth/backend/link.rs index c71637dd1a..bf9ebf4c18 100644 --- a/proxy/src/auth/backend/link.rs +++ b/proxy/src/auth/backend/link.rs @@ -61,6 +61,8 @@ pub(super) async fn authenticate( link_uri: &reqwest::Url, client: &mut PqStream, ) -> auth::Result { + ctx.set_auth_method(crate::context::AuthMethod::Web); + // registering waiter can fail if we get unlucky with rng. // just try again. let (psql_session_id, waiter) = loop { diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index d32609e44c..d318b3be54 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -99,6 +99,9 @@ impl ComputeUserInfoMaybeEndpoint { // record the values if we have them ctx.set_application(params.get("application_name").map(SmolStr::from)); ctx.set_user(user.clone()); + if let Some(dbname) = params.get("database") { + ctx.set_dbname(dbname.into()); + } // Project name might be passed via PG's command-line options. let endpoint_option = params diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index c2783e236c..dce73138c6 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -4,9 +4,11 @@ use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload}; use crate::{ config::TlsServerEndPoint, console::AuthSecret, + context::RequestMonitoring, sasl, scram, stream::{PqStream, Stream}, }; +use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS}; use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be}; use std::io; use tokio::io::{AsyncRead, AsyncWrite}; @@ -23,7 +25,7 @@ pub trait AuthMethod { pub struct Begin; /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`]. -pub struct Scram<'a>(pub &'a scram::ServerSecret); +pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a mut RequestMonitoring); impl AuthMethod for Scram<'_> { #[inline(always)] @@ -138,6 +140,11 @@ impl AuthFlow<'_, S, CleartextPassword> { impl AuthFlow<'_, S, Scram<'_>> { /// Perform user authentication. Raise an error in case authentication failed. 
pub async fn authenticate(self) -> super::Result> { + let Scram(secret, ctx) = self.state; + + // pause the timer while we communicate with the client + let _paused = ctx.latency_timer.pause(); + // Initial client message contains the chosen auth method's name. let msg = self.stream.read_password_message().await?; let sasl = sasl::FirstMessage::parse(&msg) @@ -148,9 +155,15 @@ impl AuthFlow<'_, S, Scram<'_>> { return Err(super::AuthError::bad_auth_method(sasl.method)); } + match sasl.method { + SCRAM_SHA_256 => ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256), + SCRAM_SHA_256_PLUS => { + ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256Plus) + } + _ => {} + } info!("client chooses {}", sasl.method); - let secret = self.state.0; let outcome = sasl::SaslStream::new(self.stream, sasl.message) .authenticate(scram::Exchange::new( secret, diff --git a/proxy/src/context.rs b/proxy/src/context.rs index d2bf3f68d3..0cea53ae63 100644 --- a/proxy/src/context.rs +++ b/proxy/src/context.rs @@ -11,7 +11,7 @@ use crate::{ console::messages::MetricsAuxInfo, error::ErrorKind, metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND}, - BranchId, EndpointId, ProjectId, RoleName, + BranchId, DbName, EndpointId, ProjectId, RoleName, }; pub mod parquet; @@ -34,9 +34,11 @@ pub struct RequestMonitoring { project: Option, branch: Option, endpoint_id: Option, + dbname: Option, user: Option, application: Option, error_kind: Option, + pub(crate) auth_method: Option, success: bool, // extra @@ -45,6 +47,15 @@ pub struct RequestMonitoring { pub latency_timer: LatencyTimer, } +#[derive(Clone, Debug)] +pub enum AuthMethod { + // aka link aka passwordless + Web, + ScramSha256, + ScramSha256Plus, + Cleartext, +} + impl RequestMonitoring { pub fn new( session_id: Uuid, @@ -62,9 +73,11 @@ impl RequestMonitoring { project: None, branch: None, endpoint_id: None, + dbname: None, user: None, application: None, error_kind: None, + auth_method: None, success: false, sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()), @@ -106,10 +119,18 @@ impl RequestMonitoring { self.application = app.or_else(|| self.application.clone()); } + pub fn set_dbname(&mut self, dbname: DbName) { + self.dbname = Some(dbname); + } + pub fn set_user(&mut self, user: RoleName) { self.user = Some(user); } + pub fn set_auth_method(&mut self, auth_method: AuthMethod) { + self.auth_method = Some(auth_method); + } + pub fn set_error_kind(&mut self, kind: ErrorKind) { ERROR_BY_KIND .with_label_values(&[kind.to_metric_label()]) diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index 0fe46915bc..ad22829183 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -84,8 +84,10 @@ struct RequestData { username: Option, application_name: Option, endpoint_id: Option, + database: Option, project: Option, branch: Option, + auth_method: Option<&'static str>, error: Option<&'static str>, /// Success is counted if we form a HTTP response with sql rows inside /// Or if we make it to proxy_pass @@ -104,8 +106,15 @@ impl From for RequestData { username: value.user.as_deref().map(String::from), application_name: value.application.as_deref().map(String::from), endpoint_id: value.endpoint_id.as_deref().map(String::from), + database: value.dbname.as_deref().map(String::from), project: value.project.as_deref().map(String::from), branch: value.branch.as_deref().map(String::from), + auth_method: value.auth_method.as_ref().map(|x| match x { + super::AuthMethod::Web => "web", + 
super::AuthMethod::ScramSha256 => "scram_sha_256", + super::AuthMethod::ScramSha256Plus => "scram_sha_256_plus", + super::AuthMethod::Cleartext => "cleartext", + }), protocol: value.protocol, region: value.region, error: value.error_kind.as_ref().map(|e| e.to_metric_label()), @@ -431,8 +440,10 @@ mod tests { application_name: Some("test".to_owned()), username: Some(hex::encode(rng.gen::<[u8; 4]>())), endpoint_id: Some(hex::encode(rng.gen::<[u8; 16]>())), + database: Some(hex::encode(rng.gen::<[u8; 16]>())), project: Some(hex::encode(rng.gen::<[u8; 16]>())), branch: Some(hex::encode(rng.gen::<[u8; 16]>())), + auth_method: None, protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)], region: "us-east-1", error: None, @@ -505,15 +516,15 @@ mod tests { assert_eq!( file_stats, [ - (1087635, 3, 6000), - (1087288, 3, 6000), - (1087444, 3, 6000), - (1087572, 3, 6000), - (1087468, 3, 6000), - (1087500, 3, 6000), - (1087533, 3, 6000), - (1087566, 3, 6000), - (362671, 1, 2000) + (1313727, 3, 6000), + (1313720, 3, 6000), + (1313780, 3, 6000), + (1313737, 3, 6000), + (1313867, 3, 6000), + (1313709, 3, 6000), + (1313501, 3, 6000), + (1313737, 3, 6000), + (438118, 1, 2000) ], ); @@ -543,11 +554,11 @@ mod tests { assert_eq!( file_stats, [ - (1028637, 5, 10000), - (1031969, 5, 10000), - (1019900, 5, 10000), - (1020365, 5, 10000), - (1025010, 5, 10000) + (1219459, 5, 10000), + (1225609, 5, 10000), + (1227403, 5, 10000), + (1226765, 5, 10000), + (1218043, 5, 10000) ], ); @@ -579,11 +590,11 @@ mod tests { assert_eq!( file_stats, [ - (1210770, 6, 12000), - (1211036, 6, 12000), - (1210990, 6, 12000), - (1210861, 6, 12000), - (202073, 1, 2000) + (1205106, 5, 10000), + (1204837, 5, 10000), + (1205130, 5, 10000), + (1205118, 5, 10000), + (1205373, 5, 10000) ], ); @@ -608,15 +619,15 @@ mod tests { assert_eq!( file_stats, [ - (1087635, 3, 6000), - (1087288, 3, 6000), - (1087444, 3, 6000), - (1087572, 3, 6000), - (1087468, 3, 6000), - (1087500, 3, 6000), - (1087533, 3, 6000), - (1087566, 3, 6000), - (362671, 1, 2000) + (1313727, 3, 6000), + (1313720, 3, 6000), + (1313780, 3, 6000), + (1313737, 3, 6000), + (1313867, 3, 6000), + (1313709, 3, 6000), + (1313501, 3, 6000), + (1313737, 3, 6000), + (438118, 1, 2000) ], ); @@ -653,7 +664,7 @@ mod tests { // files are smaller than the size threshold, but they took too long to fill so were flushed early assert_eq!( file_stats, - [(545264, 2, 3001), (545025, 2, 3000), (544857, 2, 2999)], + [(658383, 2, 3001), (658097, 2, 3000), (657893, 2, 2999)], ); tmpdir.close().unwrap(); diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 3e961afb41..5bb43c0375 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -144,7 +144,7 @@ impl TestAuth for Scram { stream: &mut PqStream>, ) -> anyhow::Result<()> { let outcome = auth::AuthFlow::new(stream) - .begin(auth::Scram(&self.0)) + .begin(auth::Scram(&self.0, &mut RequestMonitoring::test())) .await? 
.authenticate() .await?; diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 54424360c4..e9f868d51e 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -36,6 +36,7 @@ use crate::error::ReportableError; use crate::metrics::HTTP_CONTENT_LENGTH; use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE; use crate::proxy::NeonOptions; +use crate::DbName; use crate::RoleName; use super::backend::PoolingBackend; @@ -117,6 +118,9 @@ fn get_conn_info( headers: &HeaderMap, tls: &TlsConfig, ) -> Result { + // HTTP only uses cleartext (for now and likely always) + ctx.set_auth_method(crate::context::AuthMethod::Cleartext); + let connection_string = headers .get("Neon-Connection-String") .ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))? @@ -134,7 +138,8 @@ fn get_conn_info( .path_segments() .ok_or(ConnInfoError::MissingDbName)?; - let dbname = url_path.next().ok_or(ConnInfoError::InvalidDbName)?; + let dbname: DbName = url_path.next().ok_or(ConnInfoError::InvalidDbName)?.into(); + ctx.set_dbname(dbname.clone()); let username = RoleName::from(urlencoding::decode(connection_url.username())?); if username.is_empty() { @@ -174,7 +179,7 @@ fn get_conn_info( Ok(ConnInfo { user_info, - dbname: dbname.into(), + dbname, password: match password { std::borrow::Cow::Borrowed(b) => b.into(), std::borrow::Cow::Owned(b) => b.into(), From 242dd8398c8d6728270c8d8c2a0b45dae480cb97 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Mon, 12 Feb 2024 15:58:55 +0100 Subject: [PATCH 05/20] refactor(blob_io): use owned buffers (#6660) This PR refactors the `blob_io` code away from using slices towards taking owned buffers and return them after use. Using owned buffers will eventually allow us to use io_uring for writes. part of https://github.com/neondatabase/neon/issues/6663 Depends on https://github.com/neondatabase/tokio-epoll-uring/pull/43 The high level scheme is as follows: - call writing functions with the `BoundedBuf` - return the underlying `BoundedBuf::Buf` for potential reuse in the caller NB: Invoking `BoundedBuf::slice(..)` will return a slice that _includes the uninitialized portion of `BoundedBuf`_. I.e., the portion between `bytes_init()` and `bytes_total()`. It's a safe API that actually permits access to uninitialized memory. Not great. Another wrinkle is that it panics if the range has length 0. However, I don't want to switch away from the `BoundedBuf` API, since it's what tokio-uring uses. We can always weed this out later by replacing `BoundedBuf` with our own type. 
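To make the high-level scheme above concrete, here is a caller-side sketch (a hypothetical helper, not part of this PR; it assumes the buffered `BlobWriter<true>` and uses `Vec<u8>`, which implements `BoundedBuf` with `Buf = Vec<u8>`):

```rust
// Serialize each record into the same owned allocation, hand it to the writer,
// and take the buffer back from the returned tuple for the next iteration.
async fn append_records(
    writer: &mut BlobWriter<true>,
    records: &[Vec<u8>],
) -> std::io::Result<Vec<u64>> {
    let mut offsets = Vec::with_capacity(records.len());
    let mut buf: Vec<u8> = Vec::new();
    for rec in records {
        buf.clear();
        buf.extend_from_slice(rec);
        // write_blob consumes the buffer and returns it together with the offset.
        let (returned, res) = writer.write_blob(buf).await;
        offsets.push(res?);
        buf = returned; // ownership comes back to the caller for reuse
    }
    Ok(offsets)
}
```

The delta-layer `put_value_bytes` change in this PR follows the same buffer-in, buffer-out shape.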
Created an issue so we don't forget: https://github.com/neondatabase/tokio-epoll-uring/issues/46 --- Cargo.lock | 5 +- pageserver/src/tenant/blob_io.rs | 121 +++++++++++++----- .../src/tenant/storage_layer/delta_layer.rs | 26 ++-- .../src/tenant/storage_layer/image_layer.rs | 8 +- .../tenant/storage_layer/inmemory_layer.rs | 8 +- pageserver/src/tenant/timeline.rs | 2 +- 6 files changed, 115 insertions(+), 55 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 83afdaf66f..520163e41b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5740,7 +5740,7 @@ dependencies = [ [[package]] name = "tokio-epoll-uring" version = "0.1.0" -source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#d6a1c93442fb6b3a5bec490204961134e54925dc" +source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#868d2c42b5d54ca82fead6e8f2f233b69a540d3e" dependencies = [ "futures", "nix 0.26.4", @@ -6265,8 +6265,9 @@ dependencies = [ [[package]] name = "uring-common" version = "0.1.0" -source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#d6a1c93442fb6b3a5bec490204961134e54925dc" +source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#868d2c42b5d54ca82fead6e8f2f233b69a540d3e" dependencies = [ + "bytes", "io-uring", "libc", ] diff --git a/pageserver/src/tenant/blob_io.rs b/pageserver/src/tenant/blob_io.rs index 6de2e95055..e2ff12665a 100644 --- a/pageserver/src/tenant/blob_io.rs +++ b/pageserver/src/tenant/blob_io.rs @@ -11,6 +11,9 @@ //! len < 128: 0XXXXXXX //! len >= 128: 1XXXXXXX XXXXXXXX XXXXXXXX XXXXXXXX //! +use bytes::{BufMut, BytesMut}; +use tokio_epoll_uring::{BoundedBuf, Slice}; + use crate::context::RequestContext; use crate::page_cache::PAGE_SZ; use crate::tenant::block_io::BlockCursor; @@ -100,6 +103,8 @@ pub struct BlobWriter { offset: u64, /// A buffer to save on write calls, only used if BUFFERED=true buf: Vec, + /// We do tiny writes for the length headers; they need to be in an owned buffer; + io_buf: Option, } impl BlobWriter { @@ -108,6 +113,7 @@ impl BlobWriter { inner, offset: start_offset, buf: Vec::with_capacity(Self::CAPACITY), + io_buf: Some(BytesMut::new()), } } @@ -117,14 +123,28 @@ impl BlobWriter { const CAPACITY: usize = if BUFFERED { PAGE_SZ } else { 0 }; - #[inline(always)] /// Writes the given buffer directly to the underlying `VirtualFile`. /// You need to make sure that the internal buffer is empty, otherwise /// data will be written in wrong order. 
- async fn write_all_unbuffered(&mut self, src_buf: &[u8]) -> Result<(), Error> { - self.inner.write_all(src_buf).await?; - self.offset += src_buf.len() as u64; - Ok(()) + #[inline(always)] + async fn write_all_unbuffered( + &mut self, + src_buf: B, + ) -> (B::Buf, Result<(), Error>) { + let src_buf_len = src_buf.bytes_init(); + let (src_buf, res) = if src_buf_len > 0 { + let src_buf = src_buf.slice(0..src_buf_len); + let res = self.inner.write_all(&src_buf).await; + let src_buf = Slice::into_inner(src_buf); + (src_buf, res) + } else { + let res = self.inner.write_all(&[]).await; + (Slice::into_inner(src_buf.slice_full()), res) + }; + if let Ok(()) = &res { + self.offset += src_buf_len as u64; + } + (src_buf, res) } #[inline(always)] @@ -146,62 +166,91 @@ impl BlobWriter { } /// Internal, possibly buffered, write function - async fn write_all(&mut self, mut src_buf: &[u8]) -> Result<(), Error> { + async fn write_all(&mut self, src_buf: B) -> (B::Buf, Result<(), Error>) { if !BUFFERED { assert!(self.buf.is_empty()); - self.write_all_unbuffered(src_buf).await?; - return Ok(()); + return self.write_all_unbuffered(src_buf).await; } let remaining = Self::CAPACITY - self.buf.len(); + let src_buf_len = src_buf.bytes_init(); + if src_buf_len == 0 { + return (Slice::into_inner(src_buf.slice_full()), Ok(())); + } + let mut src_buf = src_buf.slice(0..src_buf_len); // First try to copy as much as we can into the buffer if remaining > 0 { - let copied = self.write_into_buffer(src_buf); - src_buf = &src_buf[copied..]; + let copied = self.write_into_buffer(&src_buf); + src_buf = src_buf.slice(copied..); } // Then, if the buffer is full, flush it out if self.buf.len() == Self::CAPACITY { - self.flush_buffer().await?; + if let Err(e) = self.flush_buffer().await { + return (Slice::into_inner(src_buf), Err(e)); + } } // Finally, write the tail of src_buf: // If it wholly fits into the buffer without // completely filling it, then put it there. // If not, write it out directly. - if !src_buf.is_empty() { + let src_buf = if !src_buf.is_empty() { assert_eq!(self.buf.len(), 0); if src_buf.len() < Self::CAPACITY { - let copied = self.write_into_buffer(src_buf); + let copied = self.write_into_buffer(&src_buf); // We just verified above that src_buf fits into our internal buffer. assert_eq!(copied, src_buf.len()); + Slice::into_inner(src_buf) } else { - self.write_all_unbuffered(src_buf).await?; + let (src_buf, res) = self.write_all_unbuffered(src_buf).await; + if let Err(e) = res { + return (src_buf, Err(e)); + } + src_buf } - } - Ok(()) + } else { + Slice::into_inner(src_buf) + }; + (src_buf, Ok(())) } /// Write a blob of data. Returns the offset that it was written to, /// which can be used to retrieve the data later. - pub async fn write_blob(&mut self, srcbuf: &[u8]) -> Result { + pub async fn write_blob(&mut self, srcbuf: B) -> (B::Buf, Result) { let offset = self.offset; - if srcbuf.len() < 128 { - // Short blob. Write a 1-byte length header - let len_buf = srcbuf.len() as u8; - self.write_all(&[len_buf]).await?; - } else { - // Write a 4-byte length header - if srcbuf.len() > 0x7fff_ffff { - return Err(Error::new( - ErrorKind::Other, - format!("blob too large ({} bytes)", srcbuf.len()), - )); + let len = srcbuf.bytes_init(); + + let mut io_buf = self.io_buf.take().expect("we always put it back below"); + io_buf.clear(); + let (io_buf, hdr_res) = async { + if len < 128 { + // Short blob. 
Write a 1-byte length header + io_buf.put_u8(len as u8); + self.write_all(io_buf).await + } else { + // Write a 4-byte length header + if len > 0x7fff_ffff { + return ( + io_buf, + Err(Error::new( + ErrorKind::Other, + format!("blob too large ({} bytes)", len), + )), + ); + } + let mut len_buf = (len as u32).to_be_bytes(); + len_buf[0] |= 0x80; + io_buf.extend_from_slice(&len_buf[..]); + self.write_all(io_buf).await } - let mut len_buf = ((srcbuf.len()) as u32).to_be_bytes(); - len_buf[0] |= 0x80; - self.write_all(&len_buf).await?; } - self.write_all(srcbuf).await?; - Ok(offset) + .await; + self.io_buf = Some(io_buf); + match hdr_res { + Ok(_) => (), + Err(e) => return (Slice::into_inner(srcbuf.slice(..)), Err(e)), + } + let (srcbuf, res) = self.write_all(srcbuf).await; + (srcbuf, res.map(|_| offset)) } } @@ -248,12 +297,14 @@ mod tests { let file = VirtualFile::create(pathbuf.as_path()).await?; let mut wtr = BlobWriter::::new(file, 0); for blob in blobs.iter() { - let offs = wtr.write_blob(blob).await?; + let (_, res) = wtr.write_blob(blob.clone()).await; + let offs = res?; offsets.push(offs); } // Write out one page worth of zeros so that we can // read again with read_blk - let offs = wtr.write_blob(&vec![0; PAGE_SZ]).await?; + let (_, res) = wtr.write_blob(vec![0; PAGE_SZ]).await; + let offs = res?; println!("Writing final blob at offs={offs}"); wtr.flush_buffer().await?; } diff --git a/pageserver/src/tenant/storage_layer/delta_layer.rs b/pageserver/src/tenant/storage_layer/delta_layer.rs index 2a51884c0b..7a5dc7a59f 100644 --- a/pageserver/src/tenant/storage_layer/delta_layer.rs +++ b/pageserver/src/tenant/storage_layer/delta_layer.rs @@ -416,27 +416,31 @@ impl DeltaLayerWriterInner { /// The values must be appended in key, lsn order. /// async fn put_value(&mut self, key: Key, lsn: Lsn, val: Value) -> anyhow::Result<()> { - self.put_value_bytes(key, lsn, &Value::ser(&val)?, val.will_init()) - .await + let (_, res) = self + .put_value_bytes(key, lsn, Value::ser(&val)?, val.will_init()) + .await; + res } async fn put_value_bytes( &mut self, key: Key, lsn: Lsn, - val: &[u8], + val: Vec, will_init: bool, - ) -> anyhow::Result<()> { + ) -> (Vec, anyhow::Result<()>) { assert!(self.lsn_range.start <= lsn); - - let off = self.blob_writer.write_blob(val).await?; + let (val, res) = self.blob_writer.write_blob(val).await; + let off = match res { + Ok(off) => off, + Err(e) => return (val, Err(anyhow::anyhow!(e))), + }; let blob_ref = BlobRef::new(off, will_init); let delta_key = DeltaKey::from_key_lsn(&key, lsn); - self.tree.append(&delta_key.0, blob_ref.0)?; - - Ok(()) + let res = self.tree.append(&delta_key.0, blob_ref.0); + (val, res.map_err(|e| anyhow::anyhow!(e))) } fn size(&self) -> u64 { @@ -587,9 +591,9 @@ impl DeltaLayerWriter { &mut self, key: Key, lsn: Lsn, - val: &[u8], + val: Vec, will_init: bool, - ) -> anyhow::Result<()> { + ) -> (Vec, anyhow::Result<()>) { self.inner .as_mut() .unwrap() diff --git a/pageserver/src/tenant/storage_layer/image_layer.rs b/pageserver/src/tenant/storage_layer/image_layer.rs index c62e6aed51..1ad195032d 100644 --- a/pageserver/src/tenant/storage_layer/image_layer.rs +++ b/pageserver/src/tenant/storage_layer/image_layer.rs @@ -528,9 +528,11 @@ impl ImageLayerWriterInner { /// /// The page versions must be appended in blknum order. 
/// - async fn put_image(&mut self, key: Key, img: &[u8]) -> anyhow::Result<()> { + async fn put_image(&mut self, key: Key, img: Bytes) -> anyhow::Result<()> { ensure!(self.key_range.contains(&key)); - let off = self.blob_writer.write_blob(img).await?; + let (_img, res) = self.blob_writer.write_blob(img).await; + // TODO: re-use the buffer for `img` further upstack + let off = res?; let mut keybuf: [u8; KEY_SIZE] = [0u8; KEY_SIZE]; key.write_to_byte_slice(&mut keybuf); @@ -659,7 +661,7 @@ impl ImageLayerWriter { /// /// The page versions must be appended in blknum order. /// - pub async fn put_image(&mut self, key: Key, img: &[u8]) -> anyhow::Result<()> { + pub async fn put_image(&mut self, key: Key, img: Bytes) -> anyhow::Result<()> { self.inner.as_mut().unwrap().put_image(key, img).await } diff --git a/pageserver/src/tenant/storage_layer/inmemory_layer.rs b/pageserver/src/tenant/storage_layer/inmemory_layer.rs index 7c9103eea8..c597b15533 100644 --- a/pageserver/src/tenant/storage_layer/inmemory_layer.rs +++ b/pageserver/src/tenant/storage_layer/inmemory_layer.rs @@ -383,9 +383,11 @@ impl InMemoryLayer { for (lsn, pos) in vec_map.as_slice() { cursor.read_blob_into_buf(*pos, &mut buf, &ctx).await?; let will_init = Value::des(&buf)?.will_init(); - delta_layer_writer - .put_value_bytes(key, *lsn, &buf, will_init) - .await?; + let res; + (buf, res) = delta_layer_writer + .put_value_bytes(key, *lsn, buf, will_init) + .await; + res?; } } diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index f96679ca69..74676277d5 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -3328,7 +3328,7 @@ impl Timeline { } }; - image_layer_writer.put_image(img_key, &img).await?; + image_layer_writer.put_image(img_key, img).await?; } } From 789a71c4ee6722f26ae4929a10e1316568e2006f Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 12 Feb 2024 15:03:45 +0000 Subject: [PATCH 06/20] proxy: add more http logging (#6726) ## Problem hard to see where time is taken during HTTP flow. ## Summary of changes add a lot more for query state. 
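The extra logging leans on `tracing` span fields rather than repeating ids on every log line. A minimal sketch of that pattern (not the actual proxy handler; assumes the `tracing` and `uuid` crates, with a stand-in function name): the field is declared `Empty` on the span and recorded once the value is known, so every event emitted inside the span carries it.

```rust
use tracing::{info, instrument, Span};

// Declare `conn_id` up front as an empty span field; it gets filled in later.
#[instrument(name = "sql-over-http", skip_all, fields(conn_id = tracing::field::Empty))]
fn handle_request() {
    // ... authentication / pool lookup elided ...

    // Once a connection is opened (or reused), attach its id to the current
    // span so all subsequent log lines for this request include `conn_id`.
    let conn_id = uuid::Uuid::new_v4();
    Span::current().record("conn_id", tracing::field::display(conn_id));

    info!("pool: opening a new connection");
}
```

With the id recorded on the span, per-call `conn_id = %...` arguments on individual `info!` invocations become unnecessary.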
add a conn_id field to the sql-over-http span --- proxy/src/metrics.rs | 5 ++-- proxy/src/serverless/backend.rs | 8 +++---- proxy/src/serverless/conn_pool.rs | 22 +++++------------- proxy/src/serverless/sql_over_http.rs | 33 +++++++++++++++++++++++---- 4 files changed, 41 insertions(+), 27 deletions(-) diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index ccf89f9b05..f7f162a075 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -200,8 +200,9 @@ impl LatencyTimer { pub fn success(&mut self) { // stop the stopwatch and record the time that we have accumulated - let start = self.start.take().expect("latency timer should be started"); - self.accumulated += start.elapsed(); + if let Some(start) = self.start.take() { + self.accumulated += start.elapsed(); + } // success self.outcome = "success"; diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 8285da68d7..156002006d 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -1,7 +1,7 @@ use std::{sync::Arc, time::Duration}; use async_trait::async_trait; -use tracing::info; +use tracing::{field::display, info}; use crate::{ auth::{backend::ComputeCredentialKeys, check_peer_addr_is_in_list, AuthError}, @@ -15,7 +15,7 @@ use crate::{ proxy::connect_compute::ConnectMechanism, }; -use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool, APP_NAME}; +use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool}; pub struct PoolingBackend { pub pool: Arc>, @@ -81,8 +81,8 @@ impl PoolingBackend { return Ok(client); } let conn_id = uuid::Uuid::new_v4(); - info!(%conn_id, "pool: opening a new connection '{conn_info}'"); - ctx.set_application(Some(APP_NAME)); + tracing::Span::current().record("conn_id", display(conn_id)); + info!("pool: opening a new connection '{conn_info}'"); let backend = self .config .auth_backend diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index f4e5b145c5..53e7c1c2ee 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -4,7 +4,6 @@ use metrics::IntCounterPairGuard; use parking_lot::RwLock; use rand::Rng; use smallvec::SmallVec; -use smol_str::SmolStr; use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration}; use std::{ fmt, @@ -31,8 +30,6 @@ use tracing::{info, info_span, Instrument}; use super::backend::HttpConnError; -pub const APP_NAME: SmolStr = SmolStr::new_inline("/sql_over_http"); - #[derive(Debug, Clone)] pub struct ConnInfo { pub user_info: ComputeUserInfo, @@ -379,12 +376,13 @@ impl GlobalConnPool { info!("pool: cached connection '{conn_info}' is closed, opening a new one"); return Ok(None); } else { - info!("pool: reusing connection '{conn_info}'"); - client.session.send(ctx.session_id)?; + tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id)); tracing::Span::current().record( "pid", &tracing::field::display(client.inner.get_process_id()), ); + info!("pool: reusing connection '{conn_info}'"); + client.session.send(ctx.session_id)?; ctx.latency_timer.pool_hit(); ctx.latency_timer.success(); return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool))); @@ -577,7 +575,6 @@ pub struct Client { } pub struct Discard<'a, C: ClientInnerExt> { - conn_id: uuid::Uuid, conn_info: &'a ConnInfo, pool: &'a mut Weak>>, } @@ -603,14 +600,7 @@ impl Client { span: _, } = self; let inner = inner.as_mut().expect("client inner should not be removed"); - ( - &mut inner.inner, - Discard { - pool, - 
conn_info, - conn_id: inner.conn_id, - }, - ) + (&mut inner.inner, Discard { pool, conn_info }) } pub fn check_idle(&mut self, status: ReadyForQueryStatus) { @@ -625,13 +615,13 @@ impl Discard<'_, C> { pub fn check_idle(&mut self, status: ReadyForQueryStatus) { let conn_info = &self.conn_info; if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 { - info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is not idle") + info!("pool: throwing away connection '{conn_info}' because connection is not idle") } } pub fn discard(&mut self) { let conn_info = &self.conn_info; if std::mem::take(self.pool).strong_count() > 0 { - info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is potentially in a broken state") + info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state") } } } diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index e9f868d51e..ecb72abe73 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -36,6 +36,7 @@ use crate::error::ReportableError; use crate::metrics::HTTP_CONTENT_LENGTH; use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE; use crate::proxy::NeonOptions; +use crate::serverless::backend::HttpConnError; use crate::DbName; use crate::RoleName; @@ -305,7 +306,14 @@ pub async fn handle( Ok(response) } -#[instrument(name = "sql-over-http", fields(pid = tracing::field::Empty), skip_all)] +#[instrument( + name = "sql-over-http", + skip_all, + fields( + pid = tracing::field::Empty, + conn_id = tracing::field::Empty + ) +)] async fn handle_inner( config: &'static ProxyConfig, ctx: &mut RequestMonitoring, @@ -359,12 +367,10 @@ async fn handle_inner( let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE); let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE); - let paused = ctx.latency_timer.pause(); let request_content_length = match request.body().size_hint().upper() { Some(v) => v, None => MAX_REQUEST_SIZE + 1, }; - drop(paused); info!(request_content_length, "request size in bytes"); HTTP_CONTENT_LENGTH.observe(request_content_length as f64); @@ -380,15 +386,20 @@ async fn handle_inner( let body = hyper::body::to_bytes(request.into_body()) .await .map_err(anyhow::Error::from)?; + info!(length = body.len(), "request payload read"); let payload: Payload = serde_json::from_slice(&body)?; Ok::(payload) // Adjust error type accordingly }; let authenticate_and_connect = async { let keys = backend.authenticate(ctx, &conn_info).await?; - backend + let client = backend .connect_to_compute(ctx, conn_info, keys, !allow_pool) - .await + .await?; + // not strictly necessary to mark success here, + // but it's just insurance for if we forget it somewhere else + ctx.latency_timer.success(); + Ok::<_, HttpConnError>(client) }; // Run both operations in parallel @@ -420,6 +431,7 @@ async fn handle_inner( results } Payload::Batch(statements) => { + info!("starting transaction"); let (inner, mut discard) = client.inner(); let mut builder = inner.build_transaction(); if let Some(isolation_level) = txn_isolation_level { @@ -449,6 +461,7 @@ async fn handle_inner( .await { Ok(results) => { + info!("commit"); let status = transaction.commit().await.map_err(|e| { // if we cannot commit - for now don't return connection to pool // TODO: get a query status from the error @@ -459,6 +472,7 @@ async fn handle_inner( 
results } Err(err) => { + info!("rollback"); let status = transaction.rollback().await.map_err(|e| { // if we cannot rollback - for now don't return connection to pool // TODO: get a query status from the error @@ -533,8 +547,10 @@ async fn query_to_json( raw_output: bool, default_array_mode: bool, ) -> anyhow::Result<(ReadyForQueryStatus, Value)> { + info!("executing query"); let query_params = data.params; let row_stream = client.query_raw_txt(&data.query, query_params).await?; + info!("finished executing query"); // Manually drain the stream into a vector to leave row_stream hanging // around to get a command tag. Also check that the response is not too @@ -569,6 +585,13 @@ async fn query_to_json( } .and_then(|s| s.parse::().ok()); + info!( + rows = rows.len(), + ?ready, + command_tag, + "finished reading rows" + ); + let mut fields = vec![]; let mut columns = vec![]; From 7ea593db2292324e136d3325cd96217c9d652395 Mon Sep 17 00:00:00 2001 From: Joonas Koivunen Date: Mon, 12 Feb 2024 17:13:35 +0200 Subject: [PATCH 07/20] refactor(LayerManager): resident layers query (#6634) Refactor out layer accesses so that we can have easy access to resident layers, which are needed for number of cases instead of layers for eviction. Simplifies the heatmap building by only using Layers, not RemoteTimelineClient. Cc: #5331 --- .../src/tenant/remote_timeline_client.rs | 17 ---- pageserver/src/tenant/storage_layer.rs | 8 +- pageserver/src/tenant/storage_layer/layer.rs | 4 - pageserver/src/tenant/timeline.rs | 97 ++++++------------- .../src/tenant/timeline/eviction_task.rs | 7 +- .../src/tenant/timeline/layer_manager.rs | 45 ++++++--- 6 files changed, 74 insertions(+), 104 deletions(-) diff --git a/pageserver/src/tenant/remote_timeline_client.rs b/pageserver/src/tenant/remote_timeline_client.rs index e17dea01a8..483f53d5c8 100644 --- a/pageserver/src/tenant/remote_timeline_client.rs +++ b/pageserver/src/tenant/remote_timeline_client.rs @@ -1700,23 +1700,6 @@ impl RemoteTimelineClient { } } } - - pub(crate) fn get_layers_metadata( - &self, - layers: Vec, - ) -> anyhow::Result>> { - let q = self.upload_queue.lock().unwrap(); - let q = match &*q { - UploadQueue::Stopped(_) | UploadQueue::Uninitialized => { - anyhow::bail!("queue is in state {}", q.as_str()) - } - UploadQueue::Initialized(inner) => inner, - }; - - let decorated = layers.into_iter().map(|l| q.latest_files.get(&l).cloned()); - - Ok(decorated.collect()) - } } pub fn remote_timelines_path(tenant_shard_id: &TenantShardId) -> RemotePath { diff --git a/pageserver/src/tenant/storage_layer.rs b/pageserver/src/tenant/storage_layer.rs index 6e9a4932d8..2d92baccbe 100644 --- a/pageserver/src/tenant/storage_layer.rs +++ b/pageserver/src/tenant/storage_layer.rs @@ -257,6 +257,12 @@ impl LayerAccessStats { ret } + /// Get the latest access timestamp, falling back to latest residence event, further falling + /// back to `SystemTime::now` for a usable timestamp for eviction. + pub(crate) fn latest_activity_or_now(&self) -> SystemTime { + self.latest_activity().unwrap_or_else(SystemTime::now) + } + /// Get the latest access timestamp, falling back to latest residence event. /// /// This function can only return `None` if there has not yet been a call to the @@ -271,7 +277,7 @@ impl LayerAccessStats { /// that that type can only be produced by inserting into the layer map. 
/// /// [`record_residence_event`]: Self::record_residence_event - pub(crate) fn latest_activity(&self) -> Option { + fn latest_activity(&self) -> Option { let locked = self.0.lock().unwrap(); let inner = &locked.for_eviction_policy; match inner.last_accesses.recent() { diff --git a/pageserver/src/tenant/storage_layer/layer.rs b/pageserver/src/tenant/storage_layer/layer.rs index dd9de99477..bfcc031863 100644 --- a/pageserver/src/tenant/storage_layer/layer.rs +++ b/pageserver/src/tenant/storage_layer/layer.rs @@ -1413,10 +1413,6 @@ impl ResidentLayer { &self.owner.0.path } - pub(crate) fn access_stats(&self) -> &LayerAccessStats { - self.owner.access_stats() - } - pub(crate) fn metadata(&self) -> LayerFileMetadata { self.owner.metadata() } diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 74676277d5..625be7a644 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -12,6 +12,7 @@ use bytes::Bytes; use camino::{Utf8Path, Utf8PathBuf}; use enumset::EnumSet; use fail::fail_point; +use futures::stream::StreamExt; use itertools::Itertools; use pageserver_api::{ keyspace::{key_range_size, KeySpaceAccum}, @@ -105,7 +106,7 @@ use self::logical_size::LogicalSize; use self::walreceiver::{WalReceiver, WalReceiverConf}; use super::config::TenantConf; -use super::remote_timeline_client::index::{IndexLayerMetadata, IndexPart}; +use super::remote_timeline_client::index::IndexPart; use super::remote_timeline_client::RemoteTimelineClient; use super::secondary::heatmap::{HeatMapLayer, HeatMapTimeline}; use super::{debug_assert_current_span_has_tenant_and_timeline_id, AttachedTenantConf}; @@ -1458,7 +1459,7 @@ impl Timeline { generation, shard_identity, pg_version, - layers: Arc::new(tokio::sync::RwLock::new(LayerManager::create())), + layers: Default::default(), wanted_image_layers: Mutex::new(None), walredo_mgr, @@ -2283,45 +2284,28 @@ impl Timeline { /// should treat this as a cue to simply skip doing any heatmap uploading /// for this timeline. pub(crate) async fn generate_heatmap(&self) -> Option { - let eviction_info = self.get_local_layers_for_disk_usage_eviction().await; + // no point in heatmaps without remote client + let _remote_client = self.remote_client.as_ref()?; - let remote_client = match &self.remote_client { - Some(c) => c, - None => return None, - }; + if !self.is_active() { + return None; + } - let layer_file_names = eviction_info - .resident_layers - .iter() - .map(|l| l.layer.get_name()) - .collect::>(); + let guard = self.layers.read().await; - let decorated = match remote_client.get_layers_metadata(layer_file_names) { - Ok(d) => d, - Err(_) => { - // Getting metadata only fails on Timeline in bad state. 
- return None; - } - }; + let resident = guard.resident_layers().map(|layer| { + let last_activity_ts = layer.access_stats().latest_activity_or_now(); - let heatmap_layers = std::iter::zip( - eviction_info.resident_layers.into_iter(), - decorated.into_iter(), - ) - .filter_map(|(layer, remote_info)| { - remote_info.map(|remote_info| { - HeatMapLayer::new( - layer.layer.get_name(), - IndexLayerMetadata::from(remote_info), - layer.last_activity_ts, - ) - }) + HeatMapLayer::new( + layer.layer_desc().filename(), + layer.metadata().into(), + last_activity_ts, + ) }); - Some(HeatMapTimeline::new( - self.timeline_id, - heatmap_layers.collect(), - )) + let layers = resident.collect().await; + + Some(HeatMapTimeline::new(self.timeline_id, layers)) } } @@ -4662,41 +4646,24 @@ impl Timeline { /// Returns non-remote layers for eviction. pub(crate) async fn get_local_layers_for_disk_usage_eviction(&self) -> DiskUsageEvictionInfo { let guard = self.layers.read().await; - let layers = guard.layer_map(); - let mut max_layer_size: Option = None; - let mut resident_layers = Vec::new(); - for l in layers.iter_historic_layers() { - let file_size = l.file_size(); - max_layer_size = max_layer_size.map_or(Some(file_size), |m| Some(m.max(file_size))); + let resident_layers = guard + .resident_layers() + .map(|layer| { + let file_size = layer.layer_desc().file_size; + max_layer_size = max_layer_size.map_or(Some(file_size), |m| Some(m.max(file_size))); - let l = guard.get_from_desc(&l); + let last_activity_ts = layer.access_stats().latest_activity_or_now(); - let l = match l.keep_resident().await { - Ok(Some(l)) => l, - Ok(None) => continue, - Err(e) => { - // these should not happen, but we cannot make them statically impossible right - // now. - tracing::warn!(layer=%l, "failed to keep the layer resident: {e:#}"); - continue; + EvictionCandidate { + layer: layer.into(), + last_activity_ts, + relative_last_activity: finite_f32::FiniteF32::ZERO, } - }; - - let last_activity_ts = l.access_stats().latest_activity().unwrap_or_else(|| { - // We only use this fallback if there's an implementation error. - // `latest_activity` already does rate-limited warn!() log. - debug!(layer=%l, "last_activity returns None, using SystemTime::now"); - SystemTime::now() - }); - - resident_layers.push(EvictionCandidate { - layer: l.drop_eviction_guard().into(), - last_activity_ts, - relative_last_activity: finite_f32::FiniteF32::ZERO, - }); - } + }) + .collect() + .await; DiskUsageEvictionInfo { max_layer_size, diff --git a/pageserver/src/tenant/timeline/eviction_task.rs b/pageserver/src/tenant/timeline/eviction_task.rs index 9bdd52e809..d87f78e35f 100644 --- a/pageserver/src/tenant/timeline/eviction_task.rs +++ b/pageserver/src/tenant/timeline/eviction_task.rs @@ -239,12 +239,7 @@ impl Timeline { } }; - let last_activity_ts = hist_layer.access_stats().latest_activity().unwrap_or_else(|| { - // We only use this fallback if there's an implementation error. - // `latest_activity` already does rate-limited warn!() log. 
- debug!(layer=%hist_layer, "last_activity returns None, using SystemTime::now"); - SystemTime::now() - }); + let last_activity_ts = hist_layer.access_stats().latest_activity_or_now(); let no_activity_for = match now.duration_since(last_activity_ts) { Ok(d) => d, diff --git a/pageserver/src/tenant/timeline/layer_manager.rs b/pageserver/src/tenant/timeline/layer_manager.rs index e38f5be209..ebcdcfdb4d 100644 --- a/pageserver/src/tenant/timeline/layer_manager.rs +++ b/pageserver/src/tenant/timeline/layer_manager.rs @@ -1,4 +1,5 @@ use anyhow::{bail, ensure, Context, Result}; +use futures::StreamExt; use pageserver_api::shard::TenantShardId; use std::{collections::HashMap, sync::Arc}; use tracing::trace; @@ -20,19 +21,13 @@ use crate::{ }; /// Provides semantic APIs to manipulate the layer map. +#[derive(Default)] pub(crate) struct LayerManager { layer_map: LayerMap, layer_fmgr: LayerFileManager, } impl LayerManager { - pub(crate) fn create() -> Self { - Self { - layer_map: LayerMap::default(), - layer_fmgr: LayerFileManager::new(), - } - } - pub(crate) fn get_from_desc(&self, desc: &PersistentLayerDesc) -> Layer { self.layer_fmgr.get_from_desc(desc) } @@ -246,6 +241,32 @@ impl LayerManager { layer.delete_on_drop(); } + pub(crate) fn resident_layers(&self) -> impl futures::stream::Stream + '_ { + // for small layer maps, we most likely have all resident, but for larger more are likely + // to be evicted assuming lots of layers correlated with longer lifespan. + + let layers = self + .layer_map() + .iter_historic_layers() + .map(|desc| self.get_from_desc(&desc)); + + let layers = futures::stream::iter(layers); + + layers.filter_map(|layer| async move { + // TODO(#6028): this query does not really need to see the ResidentLayer + match layer.keep_resident().await { + Ok(Some(layer)) => Some(layer.drop_eviction_guard()), + Ok(None) => None, + Err(e) => { + // these should not happen, but we cannot make them statically impossible right + // now. + tracing::warn!(%layer, "failed to keep the layer resident: {e:#}"); + None + } + } + }) + } + pub(crate) fn contains(&self, layer: &Layer) -> bool { self.layer_fmgr.contains(layer) } @@ -253,6 +274,12 @@ impl LayerManager { pub(crate) struct LayerFileManager(HashMap); +impl Default for LayerFileManager { + fn default() -> Self { + Self(HashMap::default()) + } +} + impl LayerFileManager { fn get_from_desc(&self, desc: &PersistentLayerDesc) -> T { // The assumption for the `expect()` is that all code maintains the following invariant: @@ -275,10 +302,6 @@ impl LayerFileManager { self.0.contains_key(&layer.layer_desc().key()) } - pub(crate) fn new() -> Self { - Self(HashMap::new()) - } - pub(crate) fn remove(&mut self, layer: &T) { let present = self.0.remove(&layer.layer_desc().key()); if present.is_none() && cfg!(debug_assertions) { From 8b8ff88e4b0e1a1b1c14f0edbe50e0c6236afa93 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Mon, 12 Feb 2024 16:25:33 +0100 Subject: [PATCH 08/20] GH actions: label to disable CI runs completely (#6677) I don't want my very-early-draft PRs to trigger any CI runs. So, add a label `run-no-ci`, and piggy-back on the `check-permissions` job. 
--- .github/workflows/actionlint.yml | 1 + .github/workflows/build_and_test.yml | 2 +- .github/workflows/neon_extra_builds.yml | 2 ++ 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/actionlint.yml b/.github/workflows/actionlint.yml index 584828c1d0..c290ff88e2 100644 --- a/.github/workflows/actionlint.yml +++ b/.github/workflows/actionlint.yml @@ -17,6 +17,7 @@ concurrency: jobs: actionlint: + if: ${{ !contains(github.event.pull_request.labels.*.name, 'run-no-ci') }} runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 078916e1ea..6e4020a1b8 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -26,8 +26,8 @@ env: jobs: check-permissions: + if: ${{ !contains(github.event.pull_request.labels.*.name, 'run-no-ci') }} runs-on: ubuntu-latest - steps: - name: Disallow PRs from forks if: | diff --git a/.github/workflows/neon_extra_builds.yml b/.github/workflows/neon_extra_builds.yml index c90ef60074..ff2a3a040a 100644 --- a/.github/workflows/neon_extra_builds.yml +++ b/.github/workflows/neon_extra_builds.yml @@ -117,6 +117,7 @@ jobs: check-linux-arm-build: timeout-minutes: 90 + if: ${{ !contains(github.event.pull_request.labels.*.name, 'run-no-ci') }} runs-on: [ self-hosted, dev, arm64 ] env: @@ -237,6 +238,7 @@ jobs: check-codestyle-rust-arm: timeout-minutes: 90 + if: ${{ !contains(github.event.pull_request.labels.*.name, 'run-no-ci') }} runs-on: [ self-hosted, dev, arm64 ] container: From a1f37cba1c790e5b89958fb7df13cde39429add8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arpad=20M=C3=BCller?= Date: Mon, 12 Feb 2024 19:15:21 +0100 Subject: [PATCH 09/20] Add test that runs the S3 scrubber (#6641) In #6079 it was found that there is no test that executes the scrubber. We now add such a test, which does the following things: * create a tenant, write some data * run the scrubber * remove the tenant * run the scrubber again Each time, the scrubber runs the scan-metadata command. Before #6079 we would have errored, now we don't. 
Fixes #6080 --- test_runner/fixtures/neon_fixtures.py | 8 ++-- .../regress/test_pageserver_generations.py | 4 +- .../regress/test_pageserver_secondary.py | 2 +- test_runner/regress/test_tenant_delete.py | 40 ++++++++++++++++++- 4 files changed, 45 insertions(+), 9 deletions(-) diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index faa8effe10..26f2b999a6 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -899,7 +899,7 @@ class NeonEnvBuilder: if self.scrub_on_exit: try: - S3Scrubber(self.test_output_dir, self).scan_metadata() + S3Scrubber(self).scan_metadata() except Exception as e: log.error(f"Error during remote storage scrub: {e}") cleanup_error = e @@ -3659,9 +3659,9 @@ class SafekeeperHttpClient(requests.Session): class S3Scrubber: - def __init__(self, log_dir: Path, env: NeonEnvBuilder): + def __init__(self, env: NeonEnvBuilder, log_dir: Optional[Path] = None): self.env = env - self.log_dir = log_dir + self.log_dir = log_dir or env.test_output_dir def scrubber_cli(self, args: list[str], timeout) -> str: assert isinstance(self.env.pageserver_remote_storage, S3Storage) @@ -3682,7 +3682,7 @@ class S3Scrubber: args = base_args + args (output_path, stdout, status_code) = subprocess_capture( - self.log_dir, + self.env.test_output_dir, args, echo_stderr=True, echo_stdout=True, diff --git a/test_runner/regress/test_pageserver_generations.py b/test_runner/regress/test_pageserver_generations.py index 725ed63d1c..de9f3b6945 100644 --- a/test_runner/regress/test_pageserver_generations.py +++ b/test_runner/regress/test_pageserver_generations.py @@ -265,9 +265,7 @@ def test_generations_upgrade(neon_env_builder: NeonEnvBuilder): # Having written a mixture of generation-aware and legacy index_part.json, # ensure the scrubber handles the situation as expected. 
- metadata_summary = S3Scrubber( - neon_env_builder.test_output_dir, neon_env_builder - ).scan_metadata() + metadata_summary = S3Scrubber(neon_env_builder).scan_metadata() assert metadata_summary["tenant_count"] == 1 # Scrubber should have seen our timeline assert metadata_summary["timeline_count"] == 1 assert metadata_summary["timeline_shard_count"] == 1 diff --git a/test_runner/regress/test_pageserver_secondary.py b/test_runner/regress/test_pageserver_secondary.py index 293152dd62..aec989252c 100644 --- a/test_runner/regress/test_pageserver_secondary.py +++ b/test_runner/regress/test_pageserver_secondary.py @@ -498,7 +498,7 @@ def test_secondary_downloads(neon_env_builder: NeonEnvBuilder): # Scrub the remote storage # ======================== # This confirms that the scrubber isn't upset by the presence of the heatmap - S3Scrubber(neon_env_builder.test_output_dir, neon_env_builder).scan_metadata() + S3Scrubber(neon_env_builder).scan_metadata() # Detach secondary and delete tenant # =================================== diff --git a/test_runner/regress/test_tenant_delete.py b/test_runner/regress/test_tenant_delete.py index b4e5a550f3..e928ea8bb1 100644 --- a/test_runner/regress/test_tenant_delete.py +++ b/test_runner/regress/test_tenant_delete.py @@ -9,6 +9,7 @@ from fixtures.log_helper import log from fixtures.neon_fixtures import ( NeonEnvBuilder, PgBin, + S3Scrubber, last_flush_lsn_upload, wait_for_last_flush_lsn, ) @@ -19,12 +20,13 @@ from fixtures.pageserver.utils import ( assert_prefix_not_empty, poll_for_remote_storage_iterations, tenant_delete_wait_completed, + wait_for_upload, wait_tenant_status_404, wait_until_tenant_active, wait_until_tenant_state, ) from fixtures.remote_storage import RemoteStorageKind, available_s3_storages, s3_storage -from fixtures.types import TenantId, TimelineId +from fixtures.types import Lsn, TenantId, TimelineId from fixtures.utils import run_pg_bench_small, wait_until from requests.exceptions import ReadTimeout @@ -669,3 +671,39 @@ def test_tenant_delete_races_timeline_creation( # Zero tenants remain (we deleted the default tenant) assert ps_http.get_metric_value("pageserver_tenant_manager_slots") == 0 + + +def test_tenant_delete_scrubber(pg_bin: PgBin, neon_env_builder: NeonEnvBuilder): + """ + Validate that creating and then deleting the tenant both survives the scrubber, + and that one can run the scrubber without problems. + """ + + remote_storage_kind = RemoteStorageKind.MOCK_S3 + neon_env_builder.enable_pageserver_remote_storage(remote_storage_kind) + scrubber = S3Scrubber(neon_env_builder) + env = neon_env_builder.init_start(initial_tenant_conf=MANY_SMALL_LAYERS_TENANT_CONFIG) + + ps_http = env.pageserver.http_client() + # create a tenant separate from the main tenant so that we have one remaining + # after we deleted it, as the scrubber treats empty buckets as an error. 
+ (tenant_id, timeline_id) = env.neon_cli.create_tenant() + + with env.endpoints.create_start("main", tenant_id=tenant_id) as endpoint: + run_pg_bench_small(pg_bin, endpoint.connstr()) + last_flush_lsn = Lsn(endpoint.safe_psql("SELECT pg_current_wal_flush_lsn()")[0][0]) + ps_http.timeline_checkpoint(tenant_id, timeline_id) + wait_for_upload(ps_http, tenant_id, timeline_id, last_flush_lsn) + env.stop() + + result = scrubber.scan_metadata() + assert result["with_warnings"] == [] + + env.start() + ps_http = env.pageserver.http_client() + iterations = poll_for_remote_storage_iterations(remote_storage_kind) + tenant_delete_wait_completed(ps_http, tenant_id, iterations) + env.stop() + + scrubber.scan_metadata() + assert result["with_warnings"] == [] From fac50a6264fb8ee59778d0720ba799a24c46695a Mon Sep 17 00:00:00 2001 From: Anna Khanova <32508607+khanova@users.noreply.github.com> Date: Mon, 12 Feb 2024 19:41:02 +0100 Subject: [PATCH 10/20] Proxy refactor auth+connect (#6708) ## Problem Not really a problem, just refactoring. ## Summary of changes Separate authenticate from wake compute. Do not call wake compute second time if we managed to connect to postgres or if we got it not from cache. --- proxy/src/auth.rs | 5 - proxy/src/auth/backend.rs | 146 ++++++++++++++++------------- proxy/src/auth/backend/classic.rs | 2 +- proxy/src/auth/backend/hacks.rs | 6 +- proxy/src/bin/proxy.rs | 2 +- proxy/src/compute.rs | 8 +- proxy/src/config.rs | 2 +- proxy/src/console/provider.rs | 33 ++++++- proxy/src/console/provider/mock.rs | 4 +- proxy/src/error.rs | 12 ++- proxy/src/proxy.rs | 13 +-- proxy/src/proxy/connect_compute.rs | 67 ++++++++----- proxy/src/proxy/tests.rs | 142 +++++++++++++++++++++------- proxy/src/proxy/wake_compute.rs | 16 +--- proxy/src/serverless/backend.rs | 40 +++----- 15 files changed, 307 insertions(+), 191 deletions(-) diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index 48de4e2353..c8028d1bf0 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -36,9 +36,6 @@ pub enum AuthErrorImpl { #[error(transparent)] GetAuthInfo(#[from] console::errors::GetAuthInfoError), - #[error(transparent)] - WakeCompute(#[from] console::errors::WakeComputeError), - /// SASL protocol errors (includes [SCRAM](crate::scram)). 
#[error(transparent)] Sasl(#[from] crate::sasl::Error), @@ -119,7 +116,6 @@ impl UserFacingError for AuthError { match self.0.as_ref() { Link(e) => e.to_string_client(), GetAuthInfo(e) => e.to_string_client(), - WakeCompute(e) => e.to_string_client(), Sasl(e) => e.to_string_client(), AuthFailed(_) => self.to_string(), BadAuthMethod(_) => self.to_string(), @@ -139,7 +135,6 @@ impl ReportableError for AuthError { match self.0.as_ref() { Link(e) => e.get_error_kind(), GetAuthInfo(e) => e.get_error_kind(), - WakeCompute(e) => e.get_error_kind(), Sasl(e) => e.get_error_kind(), AuthFailed(_) => crate::error::ErrorKind::User, BadAuthMethod(_) => crate::error::ErrorKind::User, diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index c9f21f1cf5..47c1dc4e92 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -10,9 +10,9 @@ use crate::auth::validate_password_and_exchange; use crate::cache::Cached; use crate::console::errors::GetAuthInfoError; use crate::console::provider::{CachedRoleSecret, ConsoleBackend}; -use crate::console::AuthSecret; +use crate::console::{AuthSecret, NodeInfo}; use crate::context::RequestMonitoring; -use crate::proxy::wake_compute::wake_compute; +use crate::proxy::connect_compute::ComputeConnectBackend; use crate::proxy::NeonOptions; use crate::stream::Stream; use crate::{ @@ -26,7 +26,6 @@ use crate::{ stream, url, }; use crate::{scram, EndpointCacheKey, EndpointId, RoleName}; -use futures::TryFutureExt; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::info; @@ -56,11 +55,11 @@ impl std::ops::Deref for MaybeOwned<'_, T> { /// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`], /// this helps us provide the credentials only to those auth /// backends which require them for the authentication process. -pub enum BackendType<'a, T> { +pub enum BackendType<'a, T, D> { /// Cloud API (V2). Console(MaybeOwned<'a, ConsoleBackend>, T), /// Authentication via a web browser. - Link(MaybeOwned<'a, url::ApiUrl>), + Link(MaybeOwned<'a, url::ApiUrl>, D), } pub trait TestBackend: Send + Sync + 'static { @@ -71,7 +70,7 @@ pub trait TestBackend: Send + Sync + 'static { fn get_role_secret(&self) -> Result; } -impl std::fmt::Display for BackendType<'_, ()> { +impl std::fmt::Display for BackendType<'_, (), ()> { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { use BackendType::*; match self { @@ -86,51 +85,50 @@ impl std::fmt::Display for BackendType<'_, ()> { #[cfg(test)] ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(), }, - Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(), + Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(), } } } -impl BackendType<'_, T> { +impl BackendType<'_, T, D> { /// Very similar to [`std::option::Option::as_ref`]. /// This helps us pass structured config to async tasks. - pub fn as_ref(&self) -> BackendType<'_, &T> { + pub fn as_ref(&self) -> BackendType<'_, &T, &D> { use BackendType::*; match self { Console(c, x) => Console(MaybeOwned::Borrowed(c), x), - Link(c) => Link(MaybeOwned::Borrowed(c)), + Link(c, x) => Link(MaybeOwned::Borrowed(c), x), } } } -impl<'a, T> BackendType<'a, T> { +impl<'a, T, D> BackendType<'a, T, D> { /// Very similar to [`std::option::Option::map`]. /// Maps [`BackendType`] to [`BackendType`] by applying /// a function to a contained value. 
- pub fn map(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R> { + pub fn map(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> { use BackendType::*; match self { Console(c, x) => Console(c, f(x)), - Link(c) => Link(c), + Link(c, x) => Link(c, x), } } } - -impl<'a, T, E> BackendType<'a, Result> { +impl<'a, T, D, E> BackendType<'a, Result, D> { /// Very similar to [`std::option::Option::transpose`]. /// This is most useful for error handling. - pub fn transpose(self) -> Result, E> { + pub fn transpose(self) -> Result, E> { use BackendType::*; match self { Console(c, x) => x.map(|x| Console(c, x)), - Link(c) => Ok(Link(c)), + Link(c, x) => Ok(Link(c, x)), } } } -pub struct ComputeCredentials { +pub struct ComputeCredentials { pub info: ComputeUserInfo, - pub keys: T, + pub keys: ComputeCredentialKeys, } #[derive(Debug, Clone)] @@ -153,7 +151,6 @@ impl ComputeUserInfo { } pub enum ComputeCredentialKeys { - #[cfg(any(test, feature = "testing"))] Password(Vec), AuthKeys(AuthKeys), } @@ -188,7 +185,7 @@ async fn auth_quirks( client: &mut stream::PqStream>, allow_cleartext: bool, config: &'static AuthenticationConfig, -) -> auth::Result> { +) -> auth::Result { // If there's no project so far, that entails that client doesn't // support SNI or other means of passing the endpoint (project) name. // We now expect to see a very specific payload in the place of password. @@ -198,8 +195,11 @@ async fn auth_quirks( ctx.set_endpoint_id(res.info.endpoint.clone()); tracing::Span::current().record("ep", &tracing::field::display(&res.info.endpoint)); - - (res.info, Some(res.keys)) + let password = match res.keys { + ComputeCredentialKeys::Password(p) => p, + _ => unreachable!("password hack should return a password"), + }; + (res.info, Some(password)) } Ok(info) => (info, None), }; @@ -253,7 +253,7 @@ async fn authenticate_with_secret( unauthenticated_password: Option>, allow_cleartext: bool, config: &'static AuthenticationConfig, -) -> auth::Result> { +) -> auth::Result { if let Some(password) = unauthenticated_password { let auth_outcome = validate_password_and_exchange(&password, secret)?; let keys = match auth_outcome { @@ -283,14 +283,14 @@ async fn authenticate_with_secret( classic::authenticate(ctx, info, client, config, secret).await } -impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { +impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> { /// Get compute endpoint name from the credentials. 
pub fn get_endpoint(&self) -> Option { use BackendType::*; match self { Console(_, user_info) => user_info.endpoint_id.clone(), - Link(_) => Some("link".into()), + Link(_, _) => Some("link".into()), } } @@ -300,7 +300,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { match self { Console(_, user_info) => &user_info.user, - Link(_) => "link", + Link(_, _) => "link", } } @@ -312,7 +312,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { client: &mut stream::PqStream>, allow_cleartext: bool, config: &'static AuthenticationConfig, - ) -> auth::Result<(CachedNodeInfo, BackendType<'a, ComputeUserInfo>)> { + ) -> auth::Result> { use BackendType::*; let res = match self { @@ -323,33 +323,17 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { "performing authentication using the console" ); - let compute_credentials = + let credentials = auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?; - - let mut num_retries = 0; - let mut node = - wake_compute(&mut num_retries, ctx, &api, &compute_credentials.info).await?; - - ctx.set_project(node.aux.clone()); - - match compute_credentials.keys { - #[cfg(any(test, feature = "testing"))] - ComputeCredentialKeys::Password(password) => node.config.password(password), - ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys), - }; - - (node, BackendType::Console(api, compute_credentials.info)) + BackendType::Console(api, credentials) } // NOTE: this auth backend doesn't use client credentials. - Link(url) => { + Link(url, _) => { info!("performing link authentication"); - let node_info = link::authenticate(ctx, &url, client).await?; + let info = link::authenticate(ctx, &url, client).await?; - ( - CachedNodeInfo::new_uncached(node_info), - BackendType::Link(url), - ) + BackendType::Link(url, info) } }; @@ -358,7 +342,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> { } } -impl BackendType<'_, ComputeUserInfo> { +impl BackendType<'_, ComputeUserInfo, &()> { pub async fn get_role_secret( &self, ctx: &mut RequestMonitoring, @@ -366,7 +350,7 @@ impl BackendType<'_, ComputeUserInfo> { use BackendType::*; match self { Console(api, user_info) => api.get_role_secret(ctx, user_info).await, - Link(_) => Ok(Cached::new_uncached(None)), + Link(_, _) => Ok(Cached::new_uncached(None)), } } @@ -377,21 +361,51 @@ impl BackendType<'_, ComputeUserInfo> { use BackendType::*; match self { Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await, - Link(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), - } - } - - /// When applicable, wake the compute node, gaining its connection info in the process. - /// The link auth flow doesn't support this, so we return [`None`] in that case. 
- pub async fn wake_compute( - &self, - ctx: &mut RequestMonitoring, - ) -> Result, console::errors::WakeComputeError> { - use BackendType::*; - - match self { - Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await, - Link(_) => Ok(None), + Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)), + } + } +} + +#[async_trait::async_trait] +impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> { + async fn wake_compute( + &self, + ctx: &mut RequestMonitoring, + ) -> Result { + use BackendType::*; + + match self { + Console(api, creds) => api.wake_compute(ctx, &creds.info).await, + Link(_, info) => Ok(Cached::new_uncached(info.clone())), + } + } + + fn get_keys(&self) -> Option<&ComputeCredentialKeys> { + match self { + BackendType::Console(_, creds) => Some(&creds.keys), + BackendType::Link(_, _) => None, + } + } +} + +#[async_trait::async_trait] +impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> { + async fn wake_compute( + &self, + ctx: &mut RequestMonitoring, + ) -> Result { + use BackendType::*; + + match self { + Console(api, creds) => api.wake_compute(ctx, &creds.info).await, + Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"), + } + } + + fn get_keys(&self) -> Option<&ComputeCredentialKeys> { + match self { + BackendType::Console(_, creds) => Some(&creds.keys), + BackendType::Link(_, _) => None, } } } diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index e855843bc3..d075331846 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -17,7 +17,7 @@ pub(super) async fn authenticate( client: &mut PqStream>, config: &'static AuthenticationConfig, secret: AuthSecret, -) -> auth::Result> { +) -> auth::Result { let flow = AuthFlow::new(client); let scram_keys = match secret { #[cfg(any(test, feature = "testing"))] diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 9f60b709d4..26cf7a01f2 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -20,7 +20,7 @@ pub async fn authenticate_cleartext( info: ComputeUserInfo, client: &mut stream::PqStream>, secret: AuthSecret, -) -> auth::Result> { +) -> auth::Result { warn!("cleartext auth flow override is enabled, proceeding"); ctx.set_auth_method(crate::context::AuthMethod::Cleartext); @@ -51,7 +51,7 @@ pub async fn password_hack_no_authentication( ctx: &mut RequestMonitoring, info: ComputeUserInfoNoEndpoint, client: &mut stream::PqStream>, -) -> auth::Result>> { +) -> auth::Result { warn!("project not specified, resorting to the password hack auth flow"); ctx.set_auth_method(crate::context::AuthMethod::Cleartext); @@ -73,6 +73,6 @@ pub async fn password_hack_no_authentication( options: info.options, endpoint: payload.endpoint, }, - keys: payload.password, + keys: ComputeCredentialKeys::Password(payload.password), }) } diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 8fbcb56758..00a229c135 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -383,7 +383,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { } AuthBackend::Link => { let url = args.uri.parse()?; - auth::BackendType::Link(MaybeOwned::Owned(url)) + auth::BackendType::Link(MaybeOwned::Owned(url), ()) } }; let http_config = HttpConfig { diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 83940d80ec..b61c1fb9ef 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -1,7 
+1,7 @@ use crate::{ auth::parse_endpoint_param, cancellation::CancelClosure, - console::errors::WakeComputeError, + console::{errors::WakeComputeError, messages::MetricsAuxInfo}, context::RequestMonitoring, error::{ReportableError, UserFacingError}, metrics::NUM_DB_CONNECTIONS_GAUGE, @@ -93,7 +93,7 @@ impl ConnCfg { } /// Reuse password or auth keys from the other config. - pub fn reuse_password(&mut self, other: &Self) { + pub fn reuse_password(&mut self, other: Self) { if let Some(password) = other.get_password() { self.password(password); } @@ -253,6 +253,8 @@ pub struct PostgresConnection { pub params: std::collections::HashMap, /// Query cancellation token. pub cancel_closure: CancelClosure, + /// Labels for proxy's metrics. + pub aux: MetricsAuxInfo, _guage: IntCounterPairGuard, } @@ -263,6 +265,7 @@ impl ConnCfg { &self, ctx: &mut RequestMonitoring, allow_self_signed_compute: bool, + aux: MetricsAuxInfo, timeout: Duration, ) -> Result { let (socket_addr, stream, host) = self.connect_raw(timeout).await?; @@ -297,6 +300,7 @@ impl ConnCfg { stream, params, cancel_closure, + aux, _guage: NUM_DB_CONNECTIONS_GAUGE .with_label_values(&[ctx.protocol]) .guard(), diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 31c9228b35..5fcb537834 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -13,7 +13,7 @@ use x509_parser::oid_registry; pub struct ProxyConfig { pub tls_config: Option, - pub auth_backend: auth::BackendType<'static, ()>, + pub auth_backend: auth::BackendType<'static, (), ()>, pub metric_collection: Option, pub allow_self_signed_compute: bool, pub http_config: HttpConfig, diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index e5cad42753..640444d14e 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -4,7 +4,10 @@ pub mod neon; use super::messages::MetricsAuxInfo; use crate::{ - auth::{backend::ComputeUserInfo, IpPattern}, + auth::{ + backend::{ComputeCredentialKeys, ComputeUserInfo}, + IpPattern, + }, cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru}, compute, config::{CacheOptions, ProjectInfoCacheOptions}, @@ -261,6 +264,34 @@ pub struct NodeInfo { pub allow_self_signed_compute: bool, } +impl NodeInfo { + pub async fn connect( + &self, + ctx: &mut RequestMonitoring, + timeout: Duration, + ) -> Result { + self.config + .connect( + ctx, + self.allow_self_signed_compute, + self.aux.clone(), + timeout, + ) + .await + } + pub fn reuse_settings(&mut self, other: Self) { + self.allow_self_signed_compute = other.allow_self_signed_compute; + self.config.reuse_password(other.config); + } + + pub fn set_keys(&mut self, keys: &ComputeCredentialKeys) { + match keys { + ComputeCredentialKeys::Password(password) => self.config.password(password), + ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys), + }; + } +} + pub type NodeInfoCache = TimedLru; pub type CachedNodeInfo = Cached<&'static NodeInfoCache>; pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option>; diff --git a/proxy/src/console/provider/mock.rs b/proxy/src/console/provider/mock.rs index 79a04f255d..0579ef6fc4 100644 --- a/proxy/src/console/provider/mock.rs +++ b/proxy/src/console/provider/mock.rs @@ -176,9 +176,7 @@ impl super::Api for Api { _ctx: &mut RequestMonitoring, _user_info: &ComputeUserInfo, ) -> Result { - self.do_wake_compute() - .map_ok(CachedNodeInfo::new_uncached) - .await + self.do_wake_compute().map_ok(Cached::new_uncached).await } } diff --git a/proxy/src/error.rs 
b/proxy/src/error.rs index eafe92bf48..69fe1ebc12 100644 --- a/proxy/src/error.rs +++ b/proxy/src/error.rs @@ -29,7 +29,7 @@ pub trait UserFacingError: ReportableError { } } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, Eq, PartialEq)] pub enum ErrorKind { /// Wrong password, unknown endpoint, protocol violation, etc... User, @@ -90,3 +90,13 @@ impl ReportableError for tokio::time::error::Elapsed { ErrorKind::RateLimit } } + +impl ReportableError for tokio_postgres::error::Error { + fn get_error_kind(&self) -> ErrorKind { + if self.as_db_error().is_some() { + ErrorKind::Postgres + } else { + ErrorKind::Compute + } + } +} diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 77aadb6f28..5f65de4c98 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -163,14 +163,14 @@ pub enum ClientMode { /// Abstracts the logic of handling TCP vs WS clients impl ClientMode { - fn allow_cleartext(&self) -> bool { + pub fn allow_cleartext(&self) -> bool { match self { ClientMode::Tcp => false, ClientMode::Websockets { .. } => true, } } - fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool { + pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool { match self { ClientMode::Tcp => config.allow_self_signed_compute, ClientMode::Websockets { .. } => false, @@ -287,7 +287,7 @@ pub async fn handle_client( } let user = user_info.get_user().to_owned(); - let (mut node_info, user_info) = match user_info + let user_info = match user_info .authenticate( ctx, &mut stream, @@ -306,14 +306,11 @@ pub async fn handle_client( } }; - node_info.allow_self_signed_compute = mode.allow_self_signed_compute(config); - - let aux = node_info.aux.clone(); let mut node = connect_to_compute( ctx, &TcpMechanism { params: ¶ms }, - node_info, &user_info, + mode.allow_self_signed_compute(config), ) .or_else(|e| stream.throw_error(e)) .await?; @@ -330,8 +327,8 @@ pub async fn handle_client( Ok(Some(ProxyPassthrough { client: stream, + aux: node.aux.clone(), compute: node, - aux, req: _request_gauge, conn: _client_gauge, })) diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index b9346aa743..6e57caf998 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -1,8 +1,9 @@ use crate::{ - auth, + auth::backend::ComputeCredentialKeys, compute::{self, PostgresConnection}, - console::{self, errors::WakeComputeError}, + console::{self, errors::WakeComputeError, CachedNodeInfo, NodeInfo}, context::RequestMonitoring, + error::ReportableError, metrics::NUM_CONNECTION_FAILURES, proxy::{ retry::{retry_after, ShouldRetry}, @@ -20,7 +21,7 @@ const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2); /// (e.g. the compute node's address might've changed at the wrong time). /// Invalidate the cache entry (if any) to prevent subsequent errors. 
#[tracing::instrument(name = "invalidate_cache", skip_all)] -pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg { +pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo { let is_cached = node_info.cached(); if is_cached { warn!("invalidating stalled compute node info cache entry"); @@ -31,13 +32,13 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg }; NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc(); - node_info.invalidate().config + node_info.invalidate() } #[async_trait] pub trait ConnectMechanism { type Connection; - type ConnectError; + type ConnectError: ReportableError; type Error: From; async fn connect_once( &self, @@ -49,6 +50,16 @@ pub trait ConnectMechanism { fn update_connect_config(&self, conf: &mut compute::ConnCfg); } +#[async_trait] +pub trait ComputeConnectBackend { + async fn wake_compute( + &self, + ctx: &mut RequestMonitoring, + ) -> Result; + + fn get_keys(&self) -> Option<&ComputeCredentialKeys>; +} + pub struct TcpMechanism<'a> { /// KV-dictionary with PostgreSQL connection params. pub params: &'a StartupMessageParams, @@ -67,11 +78,7 @@ impl ConnectMechanism for TcpMechanism<'_> { node_info: &console::CachedNodeInfo, timeout: time::Duration, ) -> Result { - let allow_self_signed_compute = node_info.allow_self_signed_compute; - node_info - .config - .connect(ctx, allow_self_signed_compute, timeout) - .await + node_info.connect(ctx, timeout).await } fn update_connect_config(&self, config: &mut compute::ConnCfg) { @@ -82,16 +89,23 @@ impl ConnectMechanism for TcpMechanism<'_> { /// Try to connect to the compute node, retrying if necessary. /// This function might update `node_info`, so we take it by `&mut`. #[tracing::instrument(skip_all)] -pub async fn connect_to_compute( +pub async fn connect_to_compute( ctx: &mut RequestMonitoring, mechanism: &M, - mut node_info: console::CachedNodeInfo, - user_info: &auth::BackendType<'_, auth::backend::ComputeUserInfo>, + user_info: &B, + allow_self_signed_compute: bool, ) -> Result where M::ConnectError: ShouldRetry + std::fmt::Debug, M::Error: From, { + let mut num_retries = 0; + let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?; + if let Some(keys) = user_info.get_keys() { + node_info.set_keys(keys); + } + node_info.allow_self_signed_compute = allow_self_signed_compute; + // let mut node_info = credentials.get_node_info(ctx, user_info).await?; mechanism.update_connect_config(&mut node_info.config); // try once @@ -108,28 +122,31 @@ where error!(error = ?err, "could not connect to compute node"); - let mut num_retries = 1; - - match user_info { - auth::BackendType::Console(api, info) => { + let node_info = + if err.get_error_kind() == crate::error::ErrorKind::Postgres || !node_info.cached() { + // If the error is Postgres, that means that we managed to connect to the compute node, but there was an error. + // Do not need to retrieve a new node_info, just return the old one. 
+ if !err.should_retry(num_retries) { + return Err(err.into()); + } + node_info + } else { // if we failed to connect, it's likely that the compute node was suspended, wake a new compute node info!("compute node's state has likely changed; requesting a wake-up"); - ctx.latency_timer.cache_miss(); - let config = invalidate_cache(node_info); - node_info = wake_compute(&mut num_retries, ctx, api, info).await?; + let old_node_info = invalidate_cache(node_info); + let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?; + node_info.reuse_settings(old_node_info); - node_info.config.reuse_password(&config); mechanism.update_connect_config(&mut node_info.config); - } - // nothing to do? - auth::BackendType::Link(_) => {} - }; + node_info + }; // now that we have a new node, try connect to it repeatedly. // this can error for a few reasons, for instance: // * DNS connection settings haven't quite propagated yet info!("wake_compute success. attempting to connect"); + num_retries = 1; loop { match mechanism .connect_once(ctx, &node_info, CONNECT_TIMEOUT) diff --git a/proxy/src/proxy/tests.rs b/proxy/src/proxy/tests.rs index 5bb43c0375..efbd661bbf 100644 --- a/proxy/src/proxy/tests.rs +++ b/proxy/src/proxy/tests.rs @@ -2,13 +2,19 @@ mod mitm; +use std::time::Duration; + use super::connect_compute::ConnectMechanism; use super::retry::ShouldRetry; use super::*; -use crate::auth::backend::{ComputeUserInfo, MaybeOwned, TestBackend}; +use crate::auth::backend::{ + ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend, +}; use crate::config::CertResolver; +use crate::console::caches::NodeInfoCache; use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend}; use crate::console::{self, CachedNodeInfo, NodeInfo}; +use crate::error::ErrorKind; use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT}; use crate::{auth, http, sasl, scram}; use async_trait::async_trait; @@ -369,12 +375,15 @@ enum ConnectAction { Connect, Retry, Fail, + RetryPg, + FailPg, } #[derive(Clone)] struct TestConnectMechanism { counter: Arc>, sequence: Vec, + cache: &'static NodeInfoCache, } impl TestConnectMechanism { @@ -393,6 +402,12 @@ impl TestConnectMechanism { Self { counter: Arc::new(std::sync::Mutex::new(0)), sequence, + cache: Box::leak(Box::new(NodeInfoCache::new( + "test", + 1, + Duration::from_secs(100), + false, + ))), } } } @@ -403,6 +418,13 @@ struct TestConnection; #[derive(Debug)] struct TestConnectError { retryable: bool, + kind: crate::error::ErrorKind, +} + +impl ReportableError for TestConnectError { + fn get_error_kind(&self) -> crate::error::ErrorKind { + self.kind + } } impl std::fmt::Display for TestConnectError { @@ -436,8 +458,22 @@ impl ConnectMechanism for TestConnectMechanism { *counter += 1; match action { ConnectAction::Connect => Ok(TestConnection), - ConnectAction::Retry => Err(TestConnectError { retryable: true }), - ConnectAction::Fail => Err(TestConnectError { retryable: false }), + ConnectAction::Retry => Err(TestConnectError { + retryable: true, + kind: ErrorKind::Compute, + }), + ConnectAction::Fail => Err(TestConnectError { + retryable: false, + kind: ErrorKind::Compute, + }), + ConnectAction::FailPg => Err(TestConnectError { + retryable: false, + kind: ErrorKind::Postgres, + }), + ConnectAction::RetryPg => Err(TestConnectError { + retryable: true, + kind: ErrorKind::Postgres, + }), x => panic!("expecting action {:?}, connect is called instead", x), } } @@ -451,7 +487,7 @@ impl TestBackend for TestConnectMechanism { let 
action = self.sequence[*counter]; *counter += 1; match action { - ConnectAction::Wake => Ok(helper_create_cached_node_info()), + ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)), ConnectAction::WakeFail => { let err = console::errors::ApiError::Console { status: http::StatusCode::FORBIDDEN, @@ -483,37 +519,41 @@ impl TestBackend for TestConnectMechanism { } } -fn helper_create_cached_node_info() -> CachedNodeInfo { +fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo { let node = NodeInfo { config: compute::ConnCfg::new(), aux: Default::default(), allow_self_signed_compute: false, }; - CachedNodeInfo::new_uncached(node) + let (_, node) = cache.insert("key".into(), node); + node } fn helper_create_connect_info( mechanism: &TestConnectMechanism, -) -> (CachedNodeInfo, auth::BackendType<'static, ComputeUserInfo>) { - let cache = helper_create_cached_node_info(); +) -> auth::BackendType<'static, ComputeCredentials, &()> { let user_info = auth::BackendType::Console( MaybeOwned::Owned(ConsoleBackend::Test(Box::new(mechanism.clone()))), - ComputeUserInfo { - endpoint: "endpoint".into(), - user: "user".into(), - options: NeonOptions::parse_options_raw(""), + ComputeCredentials { + info: ComputeUserInfo { + endpoint: "endpoint".into(), + user: "user".into(), + options: NeonOptions::parse_options_raw(""), + }, + keys: ComputeCredentialKeys::Password("password".into()), }, ); - (cache, user_info) + user_info } #[tokio::test] async fn connect_to_compute_success() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Connect]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![Wake, Connect]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap(); mechanism.verify(); @@ -521,24 +561,52 @@ async fn connect_to_compute_success() { #[tokio::test] async fn connect_to_compute_retry() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Connect]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap(); mechanism.verify(); } +#[tokio::test] +async fn connect_to_compute_retry_pg() { + let _ = env_logger::try_init(); + use ConnectAction::*; + let mut ctx = RequestMonitoring::test(); + let mechanism = TestConnectMechanism::new(vec![Wake, RetryPg, Connect]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) + .await + .unwrap(); + mechanism.verify(); +} + +#[tokio::test] +async fn connect_to_compute_fail_pg() { + let _ = env_logger::try_init(); + use ConnectAction::*; + let mut ctx = RequestMonitoring::test(); + let mechanism = TestConnectMechanism::new(vec![Wake, FailPg]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) + .await + .unwrap_err(); + mechanism.verify(); +} + /// Test that we don't retry if the error is not 
retryable. #[tokio::test] async fn connect_to_compute_non_retry_1() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Fail]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap_err(); mechanism.verify(); @@ -547,11 +615,12 @@ async fn connect_to_compute_non_retry_1() { /// Even for non-retryable errors, we should retry at least once. #[tokio::test] async fn connect_to_compute_non_retry_2() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Fail, Wake, Retry, Connect]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap(); mechanism.verify(); @@ -560,15 +629,16 @@ async fn connect_to_compute_non_retry_2() { /// Retry for at most `NUM_RETRIES_CONNECT` times. #[tokio::test] async fn connect_to_compute_non_retry_3() { + let _ = env_logger::try_init(); assert_eq!(NUM_RETRIES_CONNECT, 16); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); let mechanism = TestConnectMechanism::new(vec![ - Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, - Retry, Retry, Retry, Retry, /* the 17th time */ Retry, + Wake, Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, + Retry, Retry, Retry, Retry, Retry, /* the 17th time */ Retry, ]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap_err(); mechanism.verify(); @@ -577,11 +647,12 @@ async fn connect_to_compute_non_retry_3() { /// Should retry wake compute. #[tokio::test] async fn wake_retry() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Retry, WakeRetry, Wake, Connect]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap(); mechanism.verify(); @@ -590,11 +661,12 @@ async fn wake_retry() { /// Wake failed with a non-retryable error. 
#[tokio::test] async fn wake_non_retry() { + let _ = env_logger::try_init(); use ConnectAction::*; let mut ctx = RequestMonitoring::test(); - let mechanism = TestConnectMechanism::new(vec![Retry, WakeFail]); - let (cache, user_info) = helper_create_connect_info(&mechanism); - connect_to_compute(&mut ctx, &mechanism, cache, &user_info) + let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]); + let user_info = helper_create_connect_info(&mechanism); + connect_to_compute(&mut ctx, &mechanism, &user_info, false) .await .unwrap_err(); mechanism.verify(); diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs index 925727bdab..2c593451b4 100644 --- a/proxy/src/proxy/wake_compute.rs +++ b/proxy/src/proxy/wake_compute.rs @@ -1,9 +1,4 @@ -use crate::auth::backend::ComputeUserInfo; -use crate::console::{ - errors::WakeComputeError, - provider::{CachedNodeInfo, ConsoleBackend}, - Api, -}; +use crate::console::{errors::WakeComputeError, provider::CachedNodeInfo}; use crate::context::RequestMonitoring; use crate::metrics::{bool_to_str, NUM_WAKEUP_FAILURES}; use crate::proxy::retry::retry_after; @@ -11,17 +6,16 @@ use hyper::StatusCode; use std::ops::ControlFlow; use tracing::{error, warn}; +use super::connect_compute::ComputeConnectBackend; use super::retry::ShouldRetry; -/// wake a compute (or retrieve an existing compute session from cache) -pub async fn wake_compute( +pub async fn wake_compute( num_retries: &mut u32, ctx: &mut RequestMonitoring, - api: &ConsoleBackend, - info: &ComputeUserInfo, + api: &B, ) -> Result { loop { - let wake_res = api.wake_compute(ctx, info).await; + let wake_res = api.wake_compute(ctx).await; match handle_try_wake(wake_res, *num_retries) { Err(e) => { error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node"); diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 156002006d..6f93f86d5f 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -4,7 +4,7 @@ use async_trait::async_trait; use tracing::{field::display, info}; use crate::{ - auth::{backend::ComputeCredentialKeys, check_peer_addr_is_in_list, AuthError}, + auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError}, compute, config::ProxyConfig, console::{ @@ -27,7 +27,7 @@ impl PoolingBackend { &self, ctx: &mut RequestMonitoring, conn_info: &ConnInfo, - ) -> Result { + ) -> Result { let user_info = conn_info.user_info.clone(); let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone()); let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?; @@ -49,13 +49,17 @@ impl PoolingBackend { }; let auth_outcome = crate::auth::validate_password_and_exchange(&conn_info.password, secret)?; - match auth_outcome { + let res = match auth_outcome { crate::sasl::Outcome::Success(key) => Ok(key), crate::sasl::Outcome::Failure(reason) => { info!("auth backend failed with an error: {reason}"); Err(AuthError::auth_failed(&*conn_info.user_info.user)) } - } + }; + res.map(|key| ComputeCredentials { + info: user_info, + keys: key, + }) } // Wake up the destination if needed. 
Code here is a bit involved because @@ -66,7 +70,7 @@ impl PoolingBackend { &self, ctx: &mut RequestMonitoring, conn_info: ConnInfo, - keys: ComputeCredentialKeys, + keys: ComputeCredentials, force_new: bool, ) -> Result, HttpConnError> { let maybe_client = if !force_new { @@ -82,26 +86,8 @@ impl PoolingBackend { } let conn_id = uuid::Uuid::new_v4(); tracing::Span::current().record("conn_id", display(conn_id)); - info!("pool: opening a new connection '{conn_info}'"); - let backend = self - .config - .auth_backend - .as_ref() - .map(|_| conn_info.user_info.clone()); - - let mut node_info = backend - .wake_compute(ctx) - .await? - .ok_or(HttpConnError::NoComputeInfo)?; - - match keys { - #[cfg(any(test, feature = "testing"))] - ComputeCredentialKeys::Password(password) => node_info.config.password(password), - ComputeCredentialKeys::AuthKeys(auth_keys) => node_info.config.auth_keys(auth_keys), - }; - - ctx.set_project(node_info.aux.clone()); - + info!(%conn_id, "pool: opening a new connection '{conn_info}'"); + let backend = self.config.auth_backend.as_ref().map(|_| keys); crate::proxy::connect_compute::connect_to_compute( ctx, &TokioMechanism { @@ -109,8 +95,8 @@ impl PoolingBackend { conn_info, pool: self.pool.clone(), }, - node_info, &backend, + false, // do not allow self signed compute for http flow ) .await } @@ -129,8 +115,6 @@ pub enum HttpConnError { AuthError(#[from] AuthError), #[error("wake_compute returned error")] WakeCompute(#[from] WakeComputeError), - #[error("wake_compute returned nothing")] - NoComputeInfo, } struct TokioMechanism { From 4be2223a4cd80fdc40c37aab2206bb6f505dc008 Mon Sep 17 00:00:00 2001 From: Arthur Petukhovsky Date: Mon, 12 Feb 2024 20:29:57 +0000 Subject: [PATCH 11/20] Discrete event simulation for safekeepers (#5804) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR contains the first version of a [FoundationDB-like](https://www.youtube.com/watch?v=4fFDFbi3toc) simulation testing for safekeeper and walproposer. ### desim This is a core "framework" for running determenistic simulation. It operates on threads, allowing to test syncronous code (like walproposer). `libs/desim/src/executor.rs` contains implementation of a determenistic thread execution. This is achieved by blocking all threads, and each time allowing only a single thread to make an execution step. All executor's threads are blocked using `yield_me(after_ms)` function. This function is called when a thread wants to sleep or wait for an external notification (like blocking on a channel until it has a ready message). `libs/desim/src/chan.rs` contains implementation of a channel (basic sync primitive). It has unlimited capacity and any thread can push or read messages to/from it. `libs/desim/src/network.rs` has a very naive implementation of a network (only reliable TCP-like connections are supported for now), that can have arbitrary delays for each package and failure injections for breaking connections with some probability. `libs/desim/src/world.rs` ties everything together, to have a concept of virtual nodes that can have network connections between them. ### walproposer_sim Has everything to run walproposer and safekeepers in a simulation. `safekeeper.rs` reimplements all necesary stuff from `receive_wal.rs`, `send_wal.rs` and `timelines_global_map.rs`. `walproposer_api.rs` implements all walproposer callback to use simulation library. 
`simulation.rs` defines a schedule – a set of events like `restart ` or `write_wal` that should happen at time ``. It also has code to spawn walproposer/safekeeper threads and provide config to them. ### tests `simple_test.rs` has tests that just start walproposer and 3 safekeepers together in a simulation, and tests that they are not crashing right away. `misc_test.rs` has tests checking more advanced simulation cases, like crashing or restarting threads, testing memory deallocation, etc. `random_test.rs` is the main test, it checks thousands of random seeds (schedules) for correctness. It roughly corresponds to running a real python integration test in an environment with very unstable network and cpu, but in a determenistic way (each seed results in the same execution log) and much much faster. Closes #547 --------- Co-authored-by: Arseny Sher --- Cargo.lock | 20 + Cargo.toml | 2 + libs/desim/Cargo.toml | 18 + libs/desim/README.md | 7 + libs/desim/src/chan.rs | 108 +++ libs/desim/src/executor.rs | 483 +++++++++++++ libs/desim/src/lib.rs | 8 + libs/desim/src/network.rs | 451 ++++++++++++ libs/desim/src/node_os.rs | 54 ++ libs/desim/src/options.rs | 50 ++ libs/desim/src/proto.rs | 63 ++ libs/desim/src/time.rs | 129 ++++ libs/desim/src/world.rs | 180 +++++ libs/desim/tests/reliable_copy_test.rs | 244 +++++++ libs/postgres_ffi/src/xlog_utils.rs | 10 +- libs/walproposer/build.rs | 4 + libs/walproposer/src/api_bindings.rs | 20 +- libs/walproposer/src/walproposer.rs | 45 +- pageserver/src/walingest.rs | 2 +- pgxn/neon/walproposer.c | 15 +- pgxn/neon/walproposer.h | 9 + safekeeper/Cargo.toml | 7 + safekeeper/tests/misc_test.rs | 155 ++++ safekeeper/tests/random_test.rs | 56 ++ safekeeper/tests/simple_test.rs | 45 ++ .../tests/walproposer_sim/block_storage.rs | 57 ++ safekeeper/tests/walproposer_sim/log.rs | 77 ++ safekeeper/tests/walproposer_sim/mod.rs | 8 + .../tests/walproposer_sim/safekeeper.rs | 410 +++++++++++ .../tests/walproposer_sim/safekeeper_disk.rs | 278 +++++++ .../tests/walproposer_sim/simulation.rs | 436 +++++++++++ .../tests/walproposer_sim/simulation_logs.rs | 187 +++++ .../tests/walproposer_sim/walproposer_api.rs | 676 ++++++++++++++++++ .../tests/walproposer_sim/walproposer_disk.rs | 314 ++++++++ 34 files changed, 4603 insertions(+), 25 deletions(-) create mode 100644 libs/desim/Cargo.toml create mode 100644 libs/desim/README.md create mode 100644 libs/desim/src/chan.rs create mode 100644 libs/desim/src/executor.rs create mode 100644 libs/desim/src/lib.rs create mode 100644 libs/desim/src/network.rs create mode 100644 libs/desim/src/node_os.rs create mode 100644 libs/desim/src/options.rs create mode 100644 libs/desim/src/proto.rs create mode 100644 libs/desim/src/time.rs create mode 100644 libs/desim/src/world.rs create mode 100644 libs/desim/tests/reliable_copy_test.rs create mode 100644 safekeeper/tests/misc_test.rs create mode 100644 safekeeper/tests/random_test.rs create mode 100644 safekeeper/tests/simple_test.rs create mode 100644 safekeeper/tests/walproposer_sim/block_storage.rs create mode 100644 safekeeper/tests/walproposer_sim/log.rs create mode 100644 safekeeper/tests/walproposer_sim/mod.rs create mode 100644 safekeeper/tests/walproposer_sim/safekeeper.rs create mode 100644 safekeeper/tests/walproposer_sim/safekeeper_disk.rs create mode 100644 safekeeper/tests/walproposer_sim/simulation.rs create mode 100644 safekeeper/tests/walproposer_sim/simulation_logs.rs create mode 100644 safekeeper/tests/walproposer_sim/walproposer_api.rs create mode 100644 
safekeeper/tests/walproposer_sim/walproposer_disk.rs diff --git a/Cargo.lock b/Cargo.lock index 520163e41b..f11c774016 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1639,6 +1639,22 @@ dependencies = [ "rusticata-macros", ] +[[package]] +name = "desim" +version = "0.1.0" +dependencies = [ + "anyhow", + "bytes", + "hex", + "parking_lot 0.12.1", + "rand 0.8.5", + "scopeguard", + "smallvec", + "tracing", + "utils", + "workspace_hack", +] + [[package]] name = "diesel" version = "2.1.4" @@ -4827,6 +4843,7 @@ dependencies = [ "clap", "const_format", "crc32c", + "desim", "fail", "fs2", "futures", @@ -4842,6 +4859,7 @@ dependencies = [ "postgres_backend", "postgres_ffi", "pq_proto", + "rand 0.8.5", "regex", "remote_storage", "reqwest", @@ -4862,8 +4880,10 @@ dependencies = [ "tokio-util", "toml_edit", "tracing", + "tracing-subscriber", "url", "utils", + "walproposer", "workspace_hack", ] diff --git a/Cargo.toml b/Cargo.toml index ebc3dfa7b1..8df9ca9988 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ members = [ "libs/pageserver_api", "libs/postgres_ffi", "libs/safekeeper_api", + "libs/desim", "libs/utils", "libs/consumption_metrics", "libs/postgres_backend", @@ -203,6 +204,7 @@ postgres_ffi = { version = "0.1", path = "./libs/postgres_ffi/" } pq_proto = { version = "0.1", path = "./libs/pq_proto/" } remote_storage = { version = "0.1", path = "./libs/remote_storage/" } safekeeper_api = { version = "0.1", path = "./libs/safekeeper_api" } +desim = { version = "0.1", path = "./libs/desim" } storage_broker = { version = "0.1", path = "./storage_broker/" } # Note: main broker code is inside the binary crate, so linking with the library shouldn't be heavy. tenant_size_model = { version = "0.1", path = "./libs/tenant_size_model/" } tracing-utils = { version = "0.1", path = "./libs/tracing-utils/" } diff --git a/libs/desim/Cargo.toml b/libs/desim/Cargo.toml new file mode 100644 index 0000000000..6f442d8243 --- /dev/null +++ b/libs/desim/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "desim" +version = "0.1.0" +edition.workspace = true +license.workspace = true + +[dependencies] +anyhow.workspace = true +rand.workspace = true +tracing.workspace = true +bytes.workspace = true +utils.workspace = true +parking_lot.workspace = true +hex.workspace = true +scopeguard.workspace = true +smallvec = { workspace = true, features = ["write"] } + +workspace_hack.workspace = true diff --git a/libs/desim/README.md b/libs/desim/README.md new file mode 100644 index 0000000000..80568ebb1b --- /dev/null +++ b/libs/desim/README.md @@ -0,0 +1,7 @@ +# Discrete Event SIMulator + +This is a library for running simulations of distributed systems. The main idea is borrowed from [FoundationDB](https://www.youtube.com/watch?v=4fFDFbi3toc). + +Each node runs as a separate thread. This library was not optimized for speed yet, but it's already much faster than running usual intergration tests in real time, because it uses virtual simulation time and can fast-forward time to skip intervals where all nodes are doing nothing but sleeping or waiting for something. + +The original purpose for this library is to test walproposer and safekeeper implementation working together, in a scenarios close to the real world environment. This simulator is determenistic and can inject failures in networking without waiting minutes of wall-time to trigger timeout, which makes it easier to find bugs in our consensus implementation compared to using integration tests. 
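Not part of the patch, but a minimal usage sketch may help when reading the code below: it wires two nodes together through the World, NodeOs and Chan APIs introduced in this library, following the shape of reliable_copy_test.rs further down. The seed, delays and time limit are illustrative values.

use std::sync::Arc;

use desim::options::{Delay, NetworkOptions};
use desim::proto::{AnyMessage, NetEvent, NodeEvent};
use desim::world::World;

fn main() {
    let network = NetworkOptions {
        keepalive_timeout: Some(100),
        connect_delay: Delay::fixed(1),
        send_delay: Delay::fixed(1),
    };
    // The same seed always produces the same execution, because every
    // scheduling decision is derived from the world's RNG and virtual clock.
    let world = Arc::new(World::new(1234, Arc::new(network)));

    let client = world.new_node();
    let server = world.new_node();
    let server_id = server.id;

    server.launch(|os| {
        // Wait for an incoming connection, then for one message on it.
        if let NodeEvent::Accept(tcp) = os.node_events().recv() {
            if let NetEvent::Message(msg) = tcp.recv_chan().recv() {
                os.log_event(format!("server received {:?}", msg));
            }
        }
    });

    client.launch(move |os| {
        os.open_tcp(server_id).send(AnyMessage::Just32(42));
    });

    // Drive the simulation until all threads are permanently blocked or the
    // virtual time limit is reached; wall-clock time is irrelevant here.
    while world.step() && world.now() < 1_000 {}
}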
diff --git a/libs/desim/src/chan.rs b/libs/desim/src/chan.rs new file mode 100644 index 0000000000..6661d59871 --- /dev/null +++ b/libs/desim/src/chan.rs @@ -0,0 +1,108 @@ +use std::{collections::VecDeque, sync::Arc}; + +use parking_lot::{Mutex, MutexGuard}; + +use crate::executor::{self, PollSome, Waker}; + +/// FIFO channel with blocking send and receive. Can be cloned and shared between threads. +/// Blocking functions should be used only from threads that are managed by the executor. +pub struct Chan { + shared: Arc>, +} + +impl Clone for Chan { + fn clone(&self) -> Self { + Chan { + shared: self.shared.clone(), + } + } +} + +impl Default for Chan { + fn default() -> Self { + Self::new() + } +} + +impl Chan { + pub fn new() -> Chan { + Chan { + shared: Arc::new(State { + queue: Mutex::new(VecDeque::new()), + waker: Waker::new(), + }), + } + } + + /// Get a message from the front of the queue, block if the queue is empty. + /// If not called from the executor thread, it can block forever. + pub fn recv(&self) -> T { + self.shared.recv() + } + + /// Panic if the queue is empty. + pub fn must_recv(&self) -> T { + self.shared + .try_recv() + .expect("message should've been ready") + } + + /// Get a message from the front of the queue, return None if the queue is empty. + /// Never blocks. + pub fn try_recv(&self) -> Option { + self.shared.try_recv() + } + + /// Send a message to the back of the queue. + pub fn send(&self, t: T) { + self.shared.send(t); + } +} + +struct State { + queue: Mutex>, + waker: Waker, +} + +impl State { + fn send(&self, t: T) { + self.queue.lock().push_back(t); + self.waker.wake_all(); + } + + fn try_recv(&self) -> Option { + let mut q = self.queue.lock(); + q.pop_front() + } + + fn recv(&self) -> T { + // interrupt the receiver to prevent consuming everything at once + executor::yield_me(0); + + let mut queue = self.queue.lock(); + if let Some(t) = queue.pop_front() { + return t; + } + loop { + self.waker.wake_me_later(); + if let Some(t) = queue.pop_front() { + return t; + } + MutexGuard::unlocked(&mut queue, || { + executor::yield_me(-1); + }); + } + } +} + +impl PollSome for Chan { + /// Schedules a wakeup for the current thread. + fn wake_me(&self) { + self.shared.waker.wake_me_later(); + } + + /// Checks if chan has any pending messages. + fn has_some(&self) -> bool { + !self.shared.queue.lock().is_empty() + } +} diff --git a/libs/desim/src/executor.rs b/libs/desim/src/executor.rs new file mode 100644 index 0000000000..9d44bd7741 --- /dev/null +++ b/libs/desim/src/executor.rs @@ -0,0 +1,483 @@ +use std::{ + panic::AssertUnwindSafe, + sync::{ + atomic::{AtomicBool, AtomicU32, AtomicU8, Ordering}, + mpsc, Arc, OnceLock, + }, + thread::JoinHandle, +}; + +use tracing::{debug, error, trace}; + +use crate::time::Timing; + +/// Stores status of the running threads. Threads are registered in the runtime upon creation +/// and deregistered upon termination. +pub struct Runtime { + // stores handles to all threads that are currently running + threads: Vec, + // stores current time and pending wakeups + clock: Arc, + // thread counter + thread_counter: AtomicU32, + // Thread step counter -- how many times all threads has been actually + // stepped (note that all world/time/executor/thread have slightly different + // meaning of steps). For observability. + pub step_counter: u64, +} + +impl Runtime { + /// Init new runtime, no running threads. 
+ pub fn new(clock: Arc) -> Self { + Self { + threads: Vec::new(), + clock, + thread_counter: AtomicU32::new(0), + step_counter: 0, + } + } + + /// Spawn a new thread and register it in the runtime. + pub fn spawn(&mut self, f: F) -> ExternalHandle + where + F: FnOnce() + Send + 'static, + { + let (tx, rx) = mpsc::channel(); + + let clock = self.clock.clone(); + let tid = self.thread_counter.fetch_add(1, Ordering::SeqCst); + debug!("spawning thread-{}", tid); + + let join = std::thread::spawn(move || { + let _guard = tracing::info_span!("", tid).entered(); + + let res = std::panic::catch_unwind(AssertUnwindSafe(|| { + with_thread_context(|ctx| { + assert!(ctx.clock.set(clock).is_ok()); + ctx.id.store(tid, Ordering::SeqCst); + tx.send(ctx.clone()).expect("failed to send thread context"); + // suspend thread to put it to `threads` in sleeping state + ctx.yield_me(0); + }); + + // start user-provided function + f(); + })); + debug!("thread finished"); + + if let Err(e) = res { + with_thread_context(|ctx| { + if !ctx.allow_panic.load(std::sync::atomic::Ordering::SeqCst) { + error!("thread panicked, terminating the process: {:?}", e); + std::process::exit(1); + } + + debug!("thread panicked: {:?}", e); + let mut result = ctx.result.lock(); + if result.0 == -1 { + *result = (256, format!("thread panicked: {:?}", e)); + } + }); + } + + with_thread_context(|ctx| { + ctx.finish_me(); + }); + }); + + let ctx = rx.recv().expect("failed to receive thread context"); + let handle = ThreadHandle::new(ctx.clone(), join); + + self.threads.push(handle); + + ExternalHandle { ctx } + } + + /// Returns true if there are any unfinished activity, such as running thread or pending events. + /// Otherwise returns false, which means all threads are blocked forever. + pub fn step(&mut self) -> bool { + trace!("runtime step"); + + // have we run any thread? + let mut ran = false; + + self.threads.retain(|thread: &ThreadHandle| { + let res = thread.ctx.wakeup.compare_exchange( + PENDING_WAKEUP, + NO_WAKEUP, + Ordering::SeqCst, + Ordering::SeqCst, + ); + if res.is_err() { + // thread has no pending wakeups, leaving as is + return true; + } + ran = true; + + trace!("entering thread-{}", thread.ctx.tid()); + let status = thread.step(); + self.step_counter += 1; + trace!( + "out of thread-{} with status {:?}", + thread.ctx.tid(), + status + ); + + if status == Status::Sleep { + true + } else { + trace!("thread has finished"); + // removing the thread from the list + false + } + }); + + if !ran { + trace!("no threads were run, stepping clock"); + if let Some(ctx_to_wake) = self.clock.step() { + trace!("waking up thread-{}", ctx_to_wake.tid()); + ctx_to_wake.inc_wake(); + } else { + return false; + } + } + + true + } + + /// Kill all threads. This is done by setting a flag in each thread context and waking it up. + pub fn crash_all_threads(&mut self) { + for thread in self.threads.iter() { + thread.ctx.crash_stop(); + } + + // all threads should be finished after a few steps + while !self.threads.is_empty() { + self.step(); + } + } +} + +impl Drop for Runtime { + fn drop(&mut self) { + debug!("dropping the runtime"); + self.crash_all_threads(); + } +} + +#[derive(Clone)] +pub struct ExternalHandle { + ctx: Arc, +} + +impl ExternalHandle { + /// Returns true if thread has finished execution. + pub fn is_finished(&self) -> bool { + let status = self.ctx.mutex.lock(); + *status == Status::Finished + } + + /// Returns exitcode and message, which is available after thread has finished execution. 
+ pub fn result(&self) -> (i32, String) { + let result = self.ctx.result.lock(); + result.clone() + } + + /// Returns thread id. + pub fn id(&self) -> u32 { + self.ctx.id.load(Ordering::SeqCst) + } + + /// Sets a flag to crash thread on the next wakeup. + pub fn crash_stop(&self) { + self.ctx.crash_stop(); + } +} + +struct ThreadHandle { + ctx: Arc, + _join: JoinHandle<()>, +} + +impl ThreadHandle { + /// Create a new [`ThreadHandle`] and wait until thread will enter [`Status::Sleep`] state. + fn new(ctx: Arc, join: JoinHandle<()>) -> Self { + let mut status = ctx.mutex.lock(); + // wait until thread will go into the first yield + while *status != Status::Sleep { + ctx.condvar.wait(&mut status); + } + drop(status); + + Self { ctx, _join: join } + } + + /// Allows thread to execute one step of its execution. + /// Returns [`Status`] of the thread after the step. + fn step(&self) -> Status { + let mut status = self.ctx.mutex.lock(); + assert!(matches!(*status, Status::Sleep)); + + *status = Status::Running; + self.ctx.condvar.notify_all(); + + while *status == Status::Running { + self.ctx.condvar.wait(&mut status); + } + + *status + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum Status { + /// Thread is running. + Running, + /// Waiting for event to complete, will be resumed by the executor step, once wakeup flag is set. + Sleep, + /// Thread finished execution. + Finished, +} + +const NO_WAKEUP: u8 = 0; +const PENDING_WAKEUP: u8 = 1; + +pub struct ThreadContext { + id: AtomicU32, + // used to block thread until it is woken up + mutex: parking_lot::Mutex, + condvar: parking_lot::Condvar, + // used as a flag to indicate runtime that thread is ready to be woken up + wakeup: AtomicU8, + clock: OnceLock>, + // execution result, set by exit() call + result: parking_lot::Mutex<(i32, String)>, + // determines if process should be killed on receiving panic + allow_panic: AtomicBool, + // acts as a signal that thread should crash itself on the next wakeup + crash_request: AtomicBool, +} + +impl ThreadContext { + pub(crate) fn new() -> Self { + Self { + id: AtomicU32::new(0), + mutex: parking_lot::Mutex::new(Status::Running), + condvar: parking_lot::Condvar::new(), + wakeup: AtomicU8::new(NO_WAKEUP), + clock: OnceLock::new(), + result: parking_lot::Mutex::new((-1, String::new())), + allow_panic: AtomicBool::new(false), + crash_request: AtomicBool::new(false), + } + } +} + +// Functions for executor to control thread execution. +impl ThreadContext { + /// Set atomic flag to indicate that thread is ready to be woken up. + fn inc_wake(&self) { + self.wakeup.store(PENDING_WAKEUP, Ordering::SeqCst); + } + + /// Internal function used for event queues. + pub(crate) fn schedule_wakeup(self: &Arc, after_ms: u64) { + self.clock + .get() + .unwrap() + .schedule_wakeup(after_ms, self.clone()); + } + + fn tid(&self) -> u32 { + self.id.load(Ordering::SeqCst) + } + + fn crash_stop(&self) { + let status = self.mutex.lock(); + if *status == Status::Finished { + debug!( + "trying to crash thread-{}, which is already finished", + self.tid() + ); + return; + } + assert!(matches!(*status, Status::Sleep)); + drop(status); + + self.allow_panic.store(true, Ordering::SeqCst); + self.crash_request.store(true, Ordering::SeqCst); + // set a wakeup + self.inc_wake(); + // it will panic on the next wakeup + } +} + +// Internal functions. +impl ThreadContext { + /// Blocks thread until it's woken up by the executor. If `after_ms` is 0, is will be + /// woken on the next step. 
If `after_ms` > 0, wakeup is scheduled after that time. + /// Otherwise wakeup is not scheduled inside `yield_me`, and should be arranged before + /// calling this function. + fn yield_me(self: &Arc, after_ms: i64) { + let mut status = self.mutex.lock(); + assert!(matches!(*status, Status::Running)); + + match after_ms.cmp(&0) { + std::cmp::Ordering::Less => { + // block until something wakes us up + } + std::cmp::Ordering::Equal => { + // tell executor that we are ready to be woken up + self.inc_wake(); + } + std::cmp::Ordering::Greater => { + // schedule wakeup + self.clock + .get() + .unwrap() + .schedule_wakeup(after_ms as u64, self.clone()); + } + } + + *status = Status::Sleep; + self.condvar.notify_all(); + + // wait until executor wakes us up + while *status != Status::Running { + self.condvar.wait(&mut status); + } + + if self.crash_request.load(Ordering::SeqCst) { + panic!("crashed by request"); + } + } + + /// Called only once, exactly before thread finishes execution. + fn finish_me(&self) { + let mut status = self.mutex.lock(); + assert!(matches!(*status, Status::Running)); + + *status = Status::Finished; + { + let mut result = self.result.lock(); + if result.0 == -1 { + *result = (0, "finished normally".to_owned()); + } + } + self.condvar.notify_all(); + } +} + +/// Invokes the given closure with a reference to the current thread [`ThreadContext`]. +#[inline(always)] +fn with_thread_context(f: impl FnOnce(&Arc) -> T) -> T { + thread_local!(static THREAD_DATA: Arc = Arc::new(ThreadContext::new())); + THREAD_DATA.with(f) +} + +/// Waker is used to wake up threads that are blocked on condition. +/// It keeps track of contexts [`Arc`] and can increment the counter +/// of several contexts to send a notification. +pub struct Waker { + // contexts that are waiting for a notification + contexts: parking_lot::Mutex; 8]>>, +} + +impl Default for Waker { + fn default() -> Self { + Self::new() + } +} + +impl Waker { + pub fn new() -> Self { + Self { + contexts: parking_lot::Mutex::new(smallvec::SmallVec::new()), + } + } + + /// Subscribe current thread to receive a wake notification later. + pub fn wake_me_later(&self) { + with_thread_context(|ctx| { + self.contexts.lock().push(ctx.clone()); + }); + } + + /// Wake up all threads that are waiting for a notification and clear the list. + pub fn wake_all(&self) { + let mut v = self.contexts.lock(); + for ctx in v.iter() { + ctx.inc_wake(); + } + v.clear(); + } +} + +/// See [`ThreadContext::yield_me`]. +pub fn yield_me(after_ms: i64) { + with_thread_context(|ctx| ctx.yield_me(after_ms)) +} + +/// Get current time. +pub fn now() -> u64 { + with_thread_context(|ctx| ctx.clock.get().unwrap().now()) +} + +pub fn exit(code: i32, msg: String) { + with_thread_context(|ctx| { + ctx.allow_panic.store(true, Ordering::SeqCst); + let mut result = ctx.result.lock(); + *result = (code, msg); + panic!("exit"); + }); +} + +pub(crate) fn get_thread_ctx() -> Arc { + with_thread_context(|ctx| ctx.clone()) +} + +/// Trait for polling channels until they have something. +pub trait PollSome { + /// Schedule wakeup for message arrival. + fn wake_me(&self); + + /// Check if channel has a ready message. + fn has_some(&self) -> bool; +} + +/// Blocks current thread until one of the channels has a ready message. Returns +/// index of the channel that has a message. If timeout is reached, returns None. +/// +/// Negative timeout means block forever. Zero timeout means check channels and return +/// immediately. Positive timeout means block until timeout is reached. 
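// Illustrative, not part of the patch: waiting on two channels with a 10 ms
// virtual-time budget, from a thread managed by the executor (assumes `a` and
// `b` are Chan<u32> clones):
//
//     let chans: Vec<Box<dyn PollSome>> = vec![Box::new(a.clone()), Box::new(b.clone())];
//     match epoll_chans(&chans, 10) {
//         Some(0) => println!("a has a message: {}", a.must_recv()),
//         Some(1) => println!("b has a message: {}", b.must_recv()),
//         Some(_) => unreachable!(),
//         None => println!("nothing arrived within 10 ms of virtual time"),
//     }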
+pub fn epoll_chans(chans: &[Box], timeout: i64) -> Option { + let deadline = if timeout < 0 { + 0 + } else { + now() + timeout as u64 + }; + + loop { + for chan in chans { + chan.wake_me() + } + + for (i, chan) in chans.iter().enumerate() { + if chan.has_some() { + return Some(i); + } + } + + if timeout < 0 { + // block until wakeup + yield_me(-1); + } else { + let current_time = now(); + if current_time >= deadline { + return None; + } + + yield_me((deadline - current_time) as i64); + } + } +} diff --git a/libs/desim/src/lib.rs b/libs/desim/src/lib.rs new file mode 100644 index 0000000000..14f5a885c5 --- /dev/null +++ b/libs/desim/src/lib.rs @@ -0,0 +1,8 @@ +pub mod chan; +pub mod executor; +pub mod network; +pub mod node_os; +pub mod options; +pub mod proto; +pub mod time; +pub mod world; diff --git a/libs/desim/src/network.rs b/libs/desim/src/network.rs new file mode 100644 index 0000000000..e15a714daa --- /dev/null +++ b/libs/desim/src/network.rs @@ -0,0 +1,451 @@ +use std::{ + cmp::Ordering, + collections::{BinaryHeap, VecDeque}, + fmt::{self, Debug}, + ops::DerefMut, + sync::{mpsc, Arc}, +}; + +use parking_lot::{ + lock_api::{MappedMutexGuard, MutexGuard}, + Mutex, RawMutex, +}; +use rand::rngs::StdRng; +use tracing::debug; + +use crate::{ + executor::{self, ThreadContext}, + options::NetworkOptions, + proto::NetEvent, + proto::NodeEvent, +}; + +use super::{chan::Chan, proto::AnyMessage}; + +pub struct NetworkTask { + options: Arc, + connections: Mutex>, + /// min-heap of connections having something to deliver. + events: Mutex>, + task_context: Arc, +} + +impl NetworkTask { + pub fn start_new(options: Arc, tx: mpsc::Sender>) { + let ctx = executor::get_thread_ctx(); + let task = Arc::new(Self { + options, + connections: Mutex::new(Vec::new()), + events: Mutex::new(BinaryHeap::new()), + task_context: ctx, + }); + + // send the task upstream + tx.send(task.clone()).unwrap(); + + // start the task + task.start(); + } + + pub fn start_new_connection(self: &Arc, rng: StdRng, dst_accept: Chan) -> TCP { + let now = executor::now(); + let connection_id = self.connections.lock().len(); + + let vc = VirtualConnection { + connection_id, + dst_accept, + dst_sockets: [Chan::new(), Chan::new()], + state: Mutex::new(ConnectionState { + buffers: [NetworkBuffer::new(None), NetworkBuffer::new(Some(now))], + rng, + }), + }; + vc.schedule_timeout(self); + vc.send_connect(self); + + let recv_chan = vc.dst_sockets[0].clone(); + self.connections.lock().push(vc); + + TCP { + net: self.clone(), + conn_id: connection_id, + dir: 0, + recv_chan, + } + } +} + +// private functions +impl NetworkTask { + /// Schedule to wakeup network task (self) `after_ms` later to deliver + /// messages of connection `id`. + fn schedule(&self, id: usize, after_ms: u64) { + self.events.lock().push(Event { + time: executor::now() + after_ms, + conn_id: id, + }); + self.task_context.schedule_wakeup(after_ms); + } + + /// Get locked connection `id`. 
+ fn get(&self, id: usize) -> MappedMutexGuard<'_, RawMutex, VirtualConnection> { + MutexGuard::map(self.connections.lock(), |connections| { + connections.get_mut(id).unwrap() + }) + } + + fn collect_pending_events(&self, now: u64, vec: &mut Vec) { + vec.clear(); + let mut events = self.events.lock(); + while let Some(event) = events.peek() { + if event.time > now { + break; + } + let event = events.pop().unwrap(); + vec.push(event); + } + } + + fn start(self: &Arc) { + debug!("started network task"); + + let mut events = Vec::new(); + loop { + let now = executor::now(); + self.collect_pending_events(now, &mut events); + + for event in events.drain(..) { + let conn = self.get(event.conn_id); + conn.process(self); + } + + // block until wakeup + executor::yield_me(-1); + } + } +} + +// 0 - from node(0) to node(1) +// 1 - from node(1) to node(0) +type MessageDirection = u8; + +fn sender_str(dir: MessageDirection) -> &'static str { + match dir { + 0 => "client", + 1 => "server", + _ => unreachable!(), + } +} + +fn receiver_str(dir: MessageDirection) -> &'static str { + match dir { + 0 => "server", + 1 => "client", + _ => unreachable!(), + } +} + +/// Virtual connection between two nodes. +/// Node 0 is the creator of the connection (client), +/// and node 1 is the acceptor (server). +struct VirtualConnection { + connection_id: usize, + /// one-off chan, used to deliver Accept message to dst + dst_accept: Chan, + /// message sinks + dst_sockets: [Chan; 2], + state: Mutex, +} + +struct ConnectionState { + buffers: [NetworkBuffer; 2], + rng: StdRng, +} + +impl VirtualConnection { + /// Notify the future about the possible timeout. + fn schedule_timeout(&self, net: &NetworkTask) { + if let Some(timeout) = net.options.keepalive_timeout { + net.schedule(self.connection_id, timeout); + } + } + + /// Send the handshake (Accept) to the server. + fn send_connect(&self, net: &NetworkTask) { + let now = executor::now(); + let mut state = self.state.lock(); + let delay = net.options.connect_delay.delay(&mut state.rng); + let buffer = &mut state.buffers[0]; + assert!(buffer.buf.is_empty()); + assert!(!buffer.recv_closed); + assert!(!buffer.send_closed); + assert!(buffer.last_recv.is_none()); + + let delay = if let Some(ms) = delay { + ms + } else { + debug!("NET: TCP #{} dropped connect", self.connection_id); + buffer.send_closed = true; + return; + }; + + // Send a message into the future. + buffer + .buf + .push_back((now + delay, AnyMessage::InternalConnect)); + net.schedule(self.connection_id, delay); + } + + /// Transmit some of the messages from the buffer to the nodes. + fn process(&self, net: &Arc) { + let now = executor::now(); + + let mut state = self.state.lock(); + + for direction in 0..2 { + self.process_direction( + net, + state.deref_mut(), + now, + direction as MessageDirection, + &self.dst_sockets[direction ^ 1], + ); + } + + // Close the one side of the connection by timeout if the node + // has not received any messages for a long time. 
+ if let Some(timeout) = net.options.keepalive_timeout { + let mut to_close = [false, false]; + for direction in 0..2 { + let buffer = &mut state.buffers[direction]; + if buffer.recv_closed { + continue; + } + if let Some(last_recv) = buffer.last_recv { + if now - last_recv >= timeout { + debug!( + "NET: connection {} timed out at {}", + self.connection_id, + receiver_str(direction as MessageDirection) + ); + let node_idx = direction ^ 1; + to_close[node_idx] = true; + } + } + } + drop(state); + + for (node_idx, should_close) in to_close.iter().enumerate() { + if *should_close { + self.close(node_idx); + } + } + } + } + + /// Process messages in the buffer in the given direction. + fn process_direction( + &self, + net: &Arc, + state: &mut ConnectionState, + now: u64, + direction: MessageDirection, + to_socket: &Chan, + ) { + let buffer = &mut state.buffers[direction as usize]; + if buffer.recv_closed { + assert!(buffer.buf.is_empty()); + } + + while !buffer.buf.is_empty() && buffer.buf.front().unwrap().0 <= now { + let msg = buffer.buf.pop_front().unwrap().1; + + buffer.last_recv = Some(now); + self.schedule_timeout(net); + + if let AnyMessage::InternalConnect = msg { + // TODO: assert to_socket is the server + let server_to_client = TCP { + net: net.clone(), + conn_id: self.connection_id, + dir: direction ^ 1, + recv_chan: to_socket.clone(), + }; + // special case, we need to deliver new connection to a separate channel + self.dst_accept.send(NodeEvent::Accept(server_to_client)); + } else { + to_socket.send(NetEvent::Message(msg)); + } + } + } + + /// Try to send a message to the buffer, optionally dropping it and + /// determining delivery timestamp. + fn send(&self, net: &NetworkTask, direction: MessageDirection, msg: AnyMessage) { + let now = executor::now(); + let mut state = self.state.lock(); + + let (delay, close) = if let Some(ms) = net.options.send_delay.delay(&mut state.rng) { + (ms, false) + } else { + (0, true) + }; + + let buffer = &mut state.buffers[direction as usize]; + if buffer.send_closed { + debug!( + "NET: TCP #{} dropped message {:?} (broken pipe)", + self.connection_id, msg + ); + return; + } + + if close { + debug!( + "NET: TCP #{} dropped message {:?} (pipe just broke)", + self.connection_id, msg + ); + buffer.send_closed = true; + return; + } + + if buffer.recv_closed { + debug!( + "NET: TCP #{} dropped message {:?} (recv closed)", + self.connection_id, msg + ); + return; + } + + // Send a message into the future. + buffer.buf.push_back((now + delay, msg)); + net.schedule(self.connection_id, delay); + } + + /// Close the connection. Only one side of the connection will be closed, + /// and no further messages will be delivered. The other side will not be notified. + fn close(&self, node_idx: usize) { + let mut state = self.state.lock(); + let recv_buffer = &mut state.buffers[1 ^ node_idx]; + if recv_buffer.recv_closed { + debug!( + "NET: TCP #{} closed twice at {}", + self.connection_id, + sender_str(node_idx as MessageDirection), + ); + return; + } + + debug!( + "NET: TCP #{} closed at {}", + self.connection_id, + sender_str(node_idx as MessageDirection), + ); + recv_buffer.recv_closed = true; + for msg in recv_buffer.buf.drain(..) { + debug!( + "NET: TCP #{} dropped message {:?} (closed)", + self.connection_id, msg + ); + } + + let send_buffer = &mut state.buffers[node_idx]; + send_buffer.send_closed = true; + drop(state); + + // TODO: notify the other side? 
+ + self.dst_sockets[node_idx].send(NetEvent::Closed); + } +} + +struct NetworkBuffer { + /// Messages paired with time of delivery + buf: VecDeque<(u64, AnyMessage)>, + /// True if the connection is closed on the receiving side, + /// i.e. no more messages from the buffer will be delivered. + recv_closed: bool, + /// True if the connection is closed on the sending side, + /// i.e. no more messages will be added to the buffer. + send_closed: bool, + /// Last time a message was delivered from the buffer. + /// If None, it means that the server is the receiver and + /// it has not yet aware of this connection (i.e. has not + /// received the Accept). + last_recv: Option, +} + +impl NetworkBuffer { + fn new(last_recv: Option) -> Self { + Self { + buf: VecDeque::new(), + recv_closed: false, + send_closed: false, + last_recv, + } + } +} + +/// Single end of a bidirectional network stream without reordering (TCP-like). +/// Reads are implemented using channels, writes go to the buffer inside VirtualConnection. +pub struct TCP { + net: Arc, + conn_id: usize, + dir: MessageDirection, + recv_chan: Chan, +} + +impl Debug for TCP { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "TCP #{} ({})", self.conn_id, sender_str(self.dir),) + } +} + +impl TCP { + /// Send a message to the other side. It's guaranteed that it will not arrive + /// before the arrival of all messages sent earlier. + pub fn send(&self, msg: AnyMessage) { + let conn = self.net.get(self.conn_id); + conn.send(&self.net, self.dir, msg); + } + + /// Get a channel to receive incoming messages. + pub fn recv_chan(&self) -> Chan { + self.recv_chan.clone() + } + + pub fn connection_id(&self) -> usize { + self.conn_id + } + + pub fn close(&self) { + let conn = self.net.get(self.conn_id); + conn.close(self.dir as usize); + } +} +struct Event { + time: u64, + conn_id: usize, +} + +// BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here +// to get that. +impl PartialOrd for Event { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Event { + fn cmp(&self, other: &Self) -> Ordering { + (other.time, other.conn_id).cmp(&(self.time, self.conn_id)) + } +} + +impl PartialEq for Event { + fn eq(&self, other: &Self) -> bool { + (other.time, other.conn_id) == (self.time, self.conn_id) + } +} + +impl Eq for Event {} diff --git a/libs/desim/src/node_os.rs b/libs/desim/src/node_os.rs new file mode 100644 index 0000000000..7744a9f5e1 --- /dev/null +++ b/libs/desim/src/node_os.rs @@ -0,0 +1,54 @@ +use std::sync::Arc; + +use rand::Rng; + +use crate::proto::NodeEvent; + +use super::{ + chan::Chan, + network::TCP, + world::{Node, NodeId, World}, +}; + +/// Abstraction with all functions (aka syscalls) available to the node. +#[derive(Clone)] +pub struct NodeOs { + world: Arc, + internal: Arc, +} + +impl NodeOs { + pub fn new(world: Arc, internal: Arc) -> NodeOs { + NodeOs { world, internal } + } + + /// Get the node id. + pub fn id(&self) -> NodeId { + self.internal.id + } + + /// Opens a bidirectional connection with the other node. Always successful. + pub fn open_tcp(&self, dst: NodeId) -> TCP { + self.world.open_tcp(dst) + } + + /// Returns a channel to receive node events (socket Accept and internal messages). + pub fn node_events(&self) -> Chan { + self.internal.node_events() + } + + /// Get current time. + pub fn now(&self) -> u64 { + self.world.now() + } + + /// Generate a random number in range [0, max). 
+ pub fn random(&self, max: u64) -> u64 { + self.internal.rng.lock().gen_range(0..max) + } + + /// Append a new event to the world event log. + pub fn log_event(&self, data: String) { + self.internal.log_event(data) + } +} diff --git a/libs/desim/src/options.rs b/libs/desim/src/options.rs new file mode 100644 index 0000000000..5da7c2c482 --- /dev/null +++ b/libs/desim/src/options.rs @@ -0,0 +1,50 @@ +use rand::{rngs::StdRng, Rng}; + +/// Describes random delays and failures. Delay will be uniformly distributed in [min, max]. +/// Connection failure will occur with the probablity fail_prob. +#[derive(Clone, Debug)] +pub struct Delay { + pub min: u64, + pub max: u64, + pub fail_prob: f64, // [0; 1] +} + +impl Delay { + /// Create a struct with no delay, no failures. + pub fn empty() -> Delay { + Delay { + min: 0, + max: 0, + fail_prob: 0.0, + } + } + + /// Create a struct with a fixed delay. + pub fn fixed(ms: u64) -> Delay { + Delay { + min: ms, + max: ms, + fail_prob: 0.0, + } + } + + /// Generate a random delay in range [min, max]. Return None if the + /// message should be dropped. + pub fn delay(&self, rng: &mut StdRng) -> Option { + if rng.gen_bool(self.fail_prob) { + return None; + } + Some(rng.gen_range(self.min..=self.max)) + } +} + +/// Describes network settings. All network packets will be subjected to the same delays and failures. +#[derive(Clone, Debug)] +pub struct NetworkOptions { + /// Connection will be automatically closed after this timeout if no data is received. + pub keepalive_timeout: Option, + /// New connections will be delayed by this amount of time. + pub connect_delay: Delay, + /// Each message will be delayed by this amount of time. + pub send_delay: Delay, +} diff --git a/libs/desim/src/proto.rs b/libs/desim/src/proto.rs new file mode 100644 index 0000000000..92a7e8a27d --- /dev/null +++ b/libs/desim/src/proto.rs @@ -0,0 +1,63 @@ +use std::fmt::Debug; + +use bytes::Bytes; +use utils::lsn::Lsn; + +use crate::{network::TCP, world::NodeId}; + +/// Internal node events. +#[derive(Debug)] +pub enum NodeEvent { + Accept(TCP), + Internal(AnyMessage), +} + +/// Events that are coming from a network socket. +#[derive(Clone, Debug)] +pub enum NetEvent { + Message(AnyMessage), + Closed, +} + +/// Custom events generated throughout the simulation. Can be used by the test to verify the correctness. +#[derive(Debug)] +pub struct SimEvent { + pub time: u64, + pub node: NodeId, + pub data: String, +} + +/// Umbrella type for all possible flavours of messages. These events can be sent over network +/// or to an internal node events channel. +#[derive(Clone)] +pub enum AnyMessage { + /// Not used, empty placeholder. + None, + /// Used internally for notifying node about new incoming connection. 
+ InternalConnect, + Just32(u32), + ReplCell(ReplCell), + Bytes(Bytes), + LSN(u64), +} + +impl Debug for AnyMessage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AnyMessage::None => write!(f, "None"), + AnyMessage::InternalConnect => write!(f, "InternalConnect"), + AnyMessage::Just32(v) => write!(f, "Just32({})", v), + AnyMessage::ReplCell(v) => write!(f, "ReplCell({:?})", v), + AnyMessage::Bytes(v) => write!(f, "Bytes({})", hex::encode(v)), + AnyMessage::LSN(v) => write!(f, "LSN({})", Lsn(*v)), + } + } +} + +/// Used in reliable_copy_test.rs +#[derive(Clone, Debug)] +pub struct ReplCell { + pub value: u32, + pub client_id: u32, + pub seqno: u32, +} diff --git a/libs/desim/src/time.rs b/libs/desim/src/time.rs new file mode 100644 index 0000000000..7bb71db95c --- /dev/null +++ b/libs/desim/src/time.rs @@ -0,0 +1,129 @@ +use std::{ + cmp::Ordering, + collections::BinaryHeap, + ops::DerefMut, + sync::{ + atomic::{AtomicU32, AtomicU64}, + Arc, + }, +}; + +use parking_lot::Mutex; +use tracing::trace; + +use crate::executor::ThreadContext; + +/// Holds current time and all pending wakeup events. +pub struct Timing { + /// Current world's time. + current_time: AtomicU64, + /// Pending timers. + queue: Mutex>, + /// Global nonce. Makes picking events from binary heap queue deterministic + /// by appending a number to events with the same timestamp. + nonce: AtomicU32, + /// Used to schedule fake events. + fake_context: Arc, +} + +impl Default for Timing { + fn default() -> Self { + Self::new() + } +} + +impl Timing { + /// Create a new empty clock with time set to 0. + pub fn new() -> Timing { + Timing { + current_time: AtomicU64::new(0), + queue: Mutex::new(BinaryHeap::new()), + nonce: AtomicU32::new(0), + fake_context: Arc::new(ThreadContext::new()), + } + } + + /// Return the current world's time. + pub fn now(&self) -> u64 { + self.current_time.load(std::sync::atomic::Ordering::SeqCst) + } + + /// Tick-tock the global clock. Return the event ready to be processed + /// or move the clock forward and then return the event. + pub(crate) fn step(&self) -> Option> { + let mut queue = self.queue.lock(); + + if queue.is_empty() { + // no future events + return None; + } + + if !self.is_event_ready(queue.deref_mut()) { + let next_time = queue.peek().unwrap().time; + self.current_time + .store(next_time, std::sync::atomic::Ordering::SeqCst); + trace!("rewind time to {}", next_time); + assert!(self.is_event_ready(queue.deref_mut())); + } + + Some(queue.pop().unwrap().wake_context) + } + + /// Append an event to the queue, to wakeup the thread in `ms` milliseconds. + pub(crate) fn schedule_wakeup(&self, ms: u64, wake_context: Arc) { + self.nonce.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + let nonce = self.nonce.load(std::sync::atomic::Ordering::SeqCst); + self.queue.lock().push(Pending { + time: self.now() + ms, + nonce, + wake_context, + }) + } + + /// Append a fake event to the queue, to prevent clocks from skipping this time. + pub fn schedule_fake(&self, ms: u64) { + self.queue.lock().push(Pending { + time: self.now() + ms, + nonce: 0, + wake_context: self.fake_context.clone(), + }); + } + + /// Return true if there is a ready event. + fn is_event_ready(&self, queue: &mut BinaryHeap) -> bool { + queue.peek().map_or(false, |x| x.time <= self.now()) + } + + /// Clear all pending events. 
+ pub(crate) fn clear(&self) { + self.queue.lock().clear(); + } +} + +struct Pending { + time: u64, + nonce: u32, + wake_context: Arc, +} + +// BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here +// to get that. +impl PartialOrd for Pending { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Pending { + fn cmp(&self, other: &Self) -> Ordering { + (other.time, other.nonce).cmp(&(self.time, self.nonce)) + } +} + +impl PartialEq for Pending { + fn eq(&self, other: &Self) -> bool { + (other.time, other.nonce) == (self.time, self.nonce) + } +} + +impl Eq for Pending {} diff --git a/libs/desim/src/world.rs b/libs/desim/src/world.rs new file mode 100644 index 0000000000..7d60be04b5 --- /dev/null +++ b/libs/desim/src/world.rs @@ -0,0 +1,180 @@ +use parking_lot::Mutex; +use rand::{rngs::StdRng, SeedableRng}; +use std::{ + ops::DerefMut, + sync::{mpsc, Arc}, +}; + +use crate::{ + executor::{ExternalHandle, Runtime}, + network::NetworkTask, + options::NetworkOptions, + proto::{NodeEvent, SimEvent}, + time::Timing, +}; + +use super::{chan::Chan, network::TCP, node_os::NodeOs}; + +pub type NodeId = u32; + +/// World contains simulation state. +pub struct World { + nodes: Mutex>>, + /// Random number generator. + rng: Mutex, + /// Internal event log. + events: Mutex>, + /// Separate task that processes all network messages. + network_task: Arc, + /// Runtime for running threads and moving time. + runtime: Mutex, + /// To get current time. + timing: Arc, +} + +impl World { + pub fn new(seed: u64, options: Arc) -> World { + let timing = Arc::new(Timing::new()); + let mut runtime = Runtime::new(timing.clone()); + + let (tx, rx) = mpsc::channel(); + + runtime.spawn(move || { + // create and start network background thread, and send it back via the channel + NetworkTask::start_new(options, tx) + }); + + // wait for the network task to start + while runtime.step() {} + + let network_task = rx.recv().unwrap(); + + World { + nodes: Mutex::new(Vec::new()), + rng: Mutex::new(StdRng::seed_from_u64(seed)), + events: Mutex::new(Vec::new()), + network_task, + runtime: Mutex::new(runtime), + timing, + } + } + + pub fn step(&self) -> bool { + self.runtime.lock().step() + } + + pub fn get_thread_step_count(&self) -> u64 { + self.runtime.lock().step_counter + } + + /// Create a new random number generator. + pub fn new_rng(&self) -> StdRng { + let mut rng = self.rng.lock(); + StdRng::from_rng(rng.deref_mut()).unwrap() + } + + /// Create a new node. + pub fn new_node(self: &Arc) -> Arc { + let mut nodes = self.nodes.lock(); + let id = nodes.len() as NodeId; + let node = Arc::new(Node::new(id, self.clone(), self.new_rng())); + nodes.push(node.clone()); + node + } + + /// Get an internal node state by id. + fn get_node(&self, id: NodeId) -> Option> { + let nodes = self.nodes.lock(); + let num = id as usize; + if num < nodes.len() { + Some(nodes[num].clone()) + } else { + None + } + } + + pub fn stop_all(&self) { + self.runtime.lock().crash_all_threads(); + } + + /// Returns a writable end of a TCP connection, to send src->dst messages. + pub fn open_tcp(self: &Arc, dst: NodeId) -> TCP { + // TODO: replace unwrap() with /dev/null socket. + let dst = self.get_node(dst).unwrap(); + let dst_accept = dst.node_events.lock().clone(); + + let rng = self.new_rng(); + self.network_task.start_new_connection(rng, dst_accept) + } + + /// Get current time. + pub fn now(&self) -> u64 { + self.timing.now() + } + + /// Get a copy of the internal clock. 
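// Illustrative, not part of the patch: via the clock handle returned below,
// tests can pin points on the virtual timeline so the executor stops there
// instead of fast-forwarding past them, e.g.
//
//     world.clock().schedule_fake(100); // make the clock stop at now() + 100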
+ pub fn clock(&self) -> Arc { + self.timing.clone() + } + + pub fn add_event(&self, node: NodeId, data: String) { + let time = self.now(); + self.events.lock().push(SimEvent { time, node, data }); + } + + pub fn take_events(&self) -> Vec { + let mut events = self.events.lock(); + let mut res = Vec::new(); + std::mem::swap(&mut res, &mut events); + res + } + + pub fn deallocate(&self) { + self.stop_all(); + self.timing.clear(); + self.nodes.lock().clear(); + } +} + +/// Internal node state. +pub struct Node { + pub id: NodeId, + node_events: Mutex>, + world: Arc, + pub(crate) rng: Mutex, +} + +impl Node { + pub fn new(id: NodeId, world: Arc, rng: StdRng) -> Node { + Node { + id, + node_events: Mutex::new(Chan::new()), + world, + rng: Mutex::new(rng), + } + } + + /// Spawn a new thread with this node context. + pub fn launch(self: &Arc, f: impl FnOnce(NodeOs) + Send + 'static) -> ExternalHandle { + let node = self.clone(); + let world = self.world.clone(); + self.world.runtime.lock().spawn(move || { + f(NodeOs::new(world, node.clone())); + }) + } + + /// Returns a channel to receive Accepts and internal messages. + pub fn node_events(&self) -> Chan { + self.node_events.lock().clone() + } + + /// This will drop all in-flight Accept messages. + pub fn replug_node_events(&self, chan: Chan) { + *self.node_events.lock() = chan; + } + + /// Append event to the world's log. + pub fn log_event(&self, data: String) { + self.world.add_event(self.id, data) + } +} diff --git a/libs/desim/tests/reliable_copy_test.rs b/libs/desim/tests/reliable_copy_test.rs new file mode 100644 index 0000000000..cf7bff8f5a --- /dev/null +++ b/libs/desim/tests/reliable_copy_test.rs @@ -0,0 +1,244 @@ +//! Simple test to verify that simulator is working. +#[cfg(test)] +mod reliable_copy_test { + use anyhow::Result; + use desim::executor::{self, PollSome}; + use desim::options::{Delay, NetworkOptions}; + use desim::proto::{NetEvent, NodeEvent, ReplCell}; + use desim::world::{NodeId, World}; + use desim::{node_os::NodeOs, proto::AnyMessage}; + use parking_lot::Mutex; + use std::sync::Arc; + use tracing::info; + + /// Disk storage trait and implementation. + pub trait Storage { + fn flush_pos(&self) -> u32; + fn flush(&mut self) -> Result<()>; + fn write(&mut self, t: T); + } + + #[derive(Clone)] + pub struct SharedStorage { + pub state: Arc>>, + } + + impl SharedStorage { + pub fn new() -> Self { + Self { + state: Arc::new(Mutex::new(InMemoryStorage::new())), + } + } + } + + impl Storage for SharedStorage { + fn flush_pos(&self) -> u32 { + self.state.lock().flush_pos + } + + fn flush(&mut self) -> Result<()> { + executor::yield_me(0); + self.state.lock().flush() + } + + fn write(&mut self, t: T) { + executor::yield_me(0); + self.state.lock().write(t); + } + } + + pub struct InMemoryStorage { + pub data: Vec, + pub flush_pos: u32, + } + + impl InMemoryStorage { + pub fn new() -> Self { + Self { + data: Vec::new(), + flush_pos: 0, + } + } + + pub fn flush(&mut self) -> Result<()> { + self.flush_pos = self.data.len() as u32; + Ok(()) + } + + pub fn write(&mut self, t: T) { + self.data.push(t); + } + } + + /// Server implementation. 
+ pub fn run_server(os: NodeOs, mut storage: Box>) { + info!("started server"); + + let node_events = os.node_events(); + let mut epoll_vec: Vec> = vec![Box::new(node_events.clone())]; + let mut sockets = vec![]; + + loop { + let index = executor::epoll_chans(&epoll_vec, -1).unwrap(); + + if index == 0 { + let node_event = node_events.must_recv(); + info!("got node event: {:?}", node_event); + if let NodeEvent::Accept(tcp) = node_event { + tcp.send(AnyMessage::Just32(storage.flush_pos())); + epoll_vec.push(Box::new(tcp.recv_chan())); + sockets.push(tcp); + } + continue; + } + + let recv_chan = sockets[index - 1].recv_chan(); + let socket = &sockets[index - 1]; + + let event = recv_chan.must_recv(); + info!("got event: {:?}", event); + if let NetEvent::Message(AnyMessage::ReplCell(cell)) = event { + if cell.seqno != storage.flush_pos() { + info!("got out of order data: {:?}", cell); + continue; + } + storage.write(cell.value); + storage.flush().unwrap(); + socket.send(AnyMessage::Just32(storage.flush_pos())); + } + } + } + + /// Client copies all data from array to the remote node. + pub fn run_client(os: NodeOs, data: &[ReplCell], dst: NodeId) { + info!("started client"); + + let mut delivered = 0; + + let mut sock = os.open_tcp(dst); + let mut recv_chan = sock.recv_chan(); + + while delivered < data.len() { + let num = &data[delivered]; + info!("sending data: {:?}", num.clone()); + sock.send(AnyMessage::ReplCell(num.clone())); + + // loop { + let event = recv_chan.recv(); + match event { + NetEvent::Message(AnyMessage::Just32(flush_pos)) => { + if flush_pos == 1 + delivered as u32 { + delivered += 1; + } + } + NetEvent::Closed => { + info!("connection closed, reestablishing"); + sock = os.open_tcp(dst); + recv_chan = sock.recv_chan(); + } + _ => {} + } + + // } + } + + let sock = os.open_tcp(dst); + for num in data { + info!("sending data: {:?}", num.clone()); + sock.send(AnyMessage::ReplCell(num.clone())); + } + + info!("sent all data and finished client"); + } + + /// Run test simulations. 
+ #[test] + fn sim_example_reliable_copy() { + utils::logging::init( + utils::logging::LogFormat::Test, + utils::logging::TracingErrorLayerEnablement::Disabled, + utils::logging::Output::Stdout, + ) + .expect("logging init failed"); + + let delay = Delay { + min: 1, + max: 60, + fail_prob: 0.4, + }; + + let network = NetworkOptions { + keepalive_timeout: Some(50), + connect_delay: delay.clone(), + send_delay: delay.clone(), + }; + + for seed in 0..20 { + let u32_data: [u32; 5] = [1, 2, 3, 4, 5]; + let data = u32_to_cells(&u32_data, 1); + let world = Arc::new(World::new(seed, Arc::new(network.clone()))); + + start_simulation(Options { + world, + time_limit: 1_000_000, + client_fn: Box::new(move |os, server_id| run_client(os, &data, server_id)), + u32_data, + }); + } + } + + pub struct Options { + pub world: Arc, + pub time_limit: u64, + pub u32_data: [u32; 5], + pub client_fn: Box, + } + + pub fn start_simulation(options: Options) { + let world = options.world; + + let client_node = world.new_node(); + let server_node = world.new_node(); + let server_id = server_node.id; + + // start the client thread + client_node.launch(move |os| { + let client_fn = options.client_fn; + client_fn(os, server_id); + }); + + // start the server thread + let shared_storage = SharedStorage::new(); + let server_storage = shared_storage.clone(); + server_node.launch(move |os| run_server(os, Box::new(server_storage))); + + while world.step() && world.now() < options.time_limit {} + + let disk_data = shared_storage.state.lock().data.clone(); + assert!(verify_data(&disk_data, &options.u32_data[..])); + } + + pub fn u32_to_cells(data: &[u32], client_id: u32) -> Vec { + let mut res = Vec::new(); + for (i, _) in data.iter().enumerate() { + res.push(ReplCell { + client_id, + seqno: i as u32, + value: data[i], + }); + } + res + } + + fn verify_data(disk_data: &[u32], data: &[u32]) -> bool { + if disk_data.len() != data.len() { + return false; + } + for i in 0..data.len() { + if disk_data[i] != data[i] { + return false; + } + } + true + } +} diff --git a/libs/postgres_ffi/src/xlog_utils.rs b/libs/postgres_ffi/src/xlog_utils.rs index a863fad269..977653848d 100644 --- a/libs/postgres_ffi/src/xlog_utils.rs +++ b/libs/postgres_ffi/src/xlog_utils.rs @@ -431,11 +431,11 @@ pub fn generate_wal_segment(segno: u64, system_id: u64, lsn: Lsn) -> Result anyhow::Result<()> { println!("cargo:rustc-link-lib=static=walproposer"); println!("cargo:rustc-link-search={walproposer_lib_search_str}"); + // Rebuild crate when libwalproposer.a changes + println!("cargo:rerun-if-changed={walproposer_lib_search_str}/libwalproposer.a"); + let pg_config_bin = pg_install_abs.join("v16").join("bin").join("pg_config"); let inc_server_path: String = if pg_config_bin.exists() { let output = Command::new(pg_config_bin) @@ -79,6 +82,7 @@ fn main() -> anyhow::Result<()> { .allowlist_function("WalProposerBroadcast") .allowlist_function("WalProposerPoll") .allowlist_function("WalProposerFree") + .allowlist_function("SafekeeperStateDesiredEvents") .allowlist_var("DEBUG5") .allowlist_var("DEBUG4") .allowlist_var("DEBUG3") diff --git a/libs/walproposer/src/api_bindings.rs b/libs/walproposer/src/api_bindings.rs index 1f7bf952dc..8317e2fa03 100644 --- a/libs/walproposer/src/api_bindings.rs +++ b/libs/walproposer/src/api_bindings.rs @@ -22,6 +22,7 @@ use crate::bindings::WalProposerExecStatusType; use crate::bindings::WalproposerShmemState; use crate::bindings::XLogRecPtr; use crate::walproposer::ApiImpl; +use crate::walproposer::StreamingCallback; use 
crate::walproposer::WaitResult; extern "C" fn get_shmem_state(wp: *mut WalProposer) -> *mut WalproposerShmemState { @@ -36,7 +37,8 @@ extern "C" fn start_streaming(wp: *mut WalProposer, startpos: XLogRecPtr) { unsafe { let callback_data = (*(*wp).config).callback_data; let api = callback_data as *mut Box; - (*api).start_streaming(startpos) + let callback = StreamingCallback::new(wp); + (*api).start_streaming(startpos, &callback); } } @@ -134,19 +136,18 @@ extern "C" fn conn_async_read( unsafe { let callback_data = (*(*(*sk).wp).config).callback_data; let api = callback_data as *mut Box; - let (res, result) = (*api).conn_async_read(&mut (*sk)); // This function has guarantee that returned buf will be valid until // the next call. So we can store a Vec in each Safekeeper and reuse // it on the next call. let mut inbuf = take_vec_u8(&mut (*sk).inbuf).unwrap_or_default(); - inbuf.clear(); - inbuf.extend_from_slice(res); + + let result = (*api).conn_async_read(&mut (*sk), &mut inbuf); // Put a Vec back to sk->inbuf and return data ptr. + *amount = inbuf.len() as i32; *buf = store_vec_u8(&mut (*sk).inbuf, inbuf); - *amount = res.len() as i32; result } @@ -182,6 +183,10 @@ extern "C" fn recovery_download(wp: *mut WalProposer, sk: *mut Safekeeper) -> bo unsafe { let callback_data = (*(*(*sk).wp).config).callback_data; let api = callback_data as *mut Box; + + // currently `recovery_download` is always called right after election + (*api).after_election(&mut (*wp)); + (*api).recovery_download(&mut (*wp), &mut (*sk)) } } @@ -277,7 +282,8 @@ extern "C" fn wait_event_set( } WaitResult::Timeout => { *event_sk = std::ptr::null_mut(); - *events = crate::bindings::WL_TIMEOUT; + // WaitEventSetWait returns 0 for timeout. + *events = 0; 0 } WaitResult::Network(sk, event_mask) => { @@ -340,7 +346,7 @@ extern "C" fn log_internal( } } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum Level { Debug5, Debug4, diff --git a/libs/walproposer/src/walproposer.rs b/libs/walproposer/src/walproposer.rs index 8ab8fb1a07..13fade220c 100644 --- a/libs/walproposer/src/walproposer.rs +++ b/libs/walproposer/src/walproposer.rs @@ -1,13 +1,13 @@ use std::ffi::CString; use postgres_ffi::WAL_SEGMENT_SIZE; -use utils::id::TenantTimelineId; +use utils::{id::TenantTimelineId, lsn::Lsn}; use crate::{ api_bindings::{create_api, take_vec_u8, Level}, bindings::{ - NeonWALReadResult, Safekeeper, WalProposer, WalProposerConfig, WalProposerCreate, - WalProposerFree, WalProposerStart, + NeonWALReadResult, Safekeeper, WalProposer, WalProposerBroadcast, WalProposerConfig, + WalProposerCreate, WalProposerFree, WalProposerPoll, WalProposerStart, }, }; @@ -16,11 +16,11 @@ use crate::{ /// /// Refer to `pgxn/neon/walproposer.h` for documentation. 
pub trait ApiImpl { - fn get_shmem_state(&self) -> &mut crate::bindings::WalproposerShmemState { + fn get_shmem_state(&self) -> *mut crate::bindings::WalproposerShmemState { todo!() } - fn start_streaming(&self, _startpos: u64) { + fn start_streaming(&self, _startpos: u64, _callback: &StreamingCallback) { todo!() } @@ -70,7 +70,11 @@ pub trait ApiImpl { todo!() } - fn conn_async_read(&self, _sk: &mut Safekeeper) -> (&[u8], crate::bindings::PGAsyncReadResult) { + fn conn_async_read( + &self, + _sk: &mut Safekeeper, + _vec: &mut Vec, + ) -> crate::bindings::PGAsyncReadResult { todo!() } @@ -151,12 +155,14 @@ pub trait ApiImpl { } } +#[derive(Debug)] pub enum WaitResult { Latch, Timeout, Network(*mut Safekeeper, u32), } +#[derive(Clone)] pub struct Config { /// Tenant and timeline id pub ttid: TenantTimelineId, @@ -242,6 +248,24 @@ impl Drop for Wrapper { } } +pub struct StreamingCallback { + wp: *mut WalProposer, +} + +impl StreamingCallback { + pub fn new(wp: *mut WalProposer) -> StreamingCallback { + StreamingCallback { wp } + } + + pub fn broadcast(&self, startpos: Lsn, endpos: Lsn) { + unsafe { WalProposerBroadcast(self.wp, startpos.0, endpos.0) } + } + + pub fn poll(&self) { + unsafe { WalProposerPoll(self.wp) } + } +} + #[cfg(test)] mod tests { use core::panic; @@ -344,14 +368,13 @@ mod tests { fn conn_async_read( &self, _: &mut crate::bindings::Safekeeper, - ) -> (&[u8], crate::bindings::PGAsyncReadResult) { + vec: &mut Vec, + ) -> crate::bindings::PGAsyncReadResult { println!("conn_async_read"); let reply = self.next_safekeeper_reply(); println!("conn_async_read result: {:?}", reply); - ( - reply, - crate::bindings::PGAsyncReadResult_PG_ASYNC_READ_SUCCESS, - ) + vec.extend_from_slice(reply); + crate::bindings::PGAsyncReadResult_PG_ASYNC_READ_SUCCESS } fn conn_blocking_write(&self, _: &mut crate::bindings::Safekeeper, buf: &[u8]) -> bool { diff --git a/pageserver/src/walingest.rs b/pageserver/src/walingest.rs index 93d1dcab35..12ceac0191 100644 --- a/pageserver/src/walingest.rs +++ b/pageserver/src/walingest.rs @@ -346,7 +346,7 @@ impl WalIngest { let info = decoded.xl_info & pg_constants::XLR_RMGR_INFO_MASK; if info == pg_constants::XLOG_LOGICAL_MESSAGE { - let xlrec = XlLogicalMessage::decode(&mut buf); + let xlrec = crate::walrecord::XlLogicalMessage::decode(&mut buf); let prefix = std::str::from_utf8(&buf[0..xlrec.prefix_size - 1])?; let message = &buf[xlrec.prefix_size..xlrec.prefix_size + xlrec.message_size]; if prefix == "neon-test" { diff --git a/pgxn/neon/walproposer.c b/pgxn/neon/walproposer.c index 171af7d2aa..0d5007ef73 100644 --- a/pgxn/neon/walproposer.c +++ b/pgxn/neon/walproposer.c @@ -688,7 +688,7 @@ RecvAcceptorGreeting(Safekeeper *sk) if (!AsyncReadMessage(sk, (AcceptorProposerMessage *) &sk->greetResponse)) return; - wp_log(LOG, "received AcceptorGreeting from safekeeper %s:%s", sk->host, sk->port); + wp_log(LOG, "received AcceptorGreeting from safekeeper %s:%s, term=" INT64_FORMAT, sk->host, sk->port, sk->greetResponse.term); /* Protocol is all good, move to voting. 
*/ sk->state = SS_VOTING; @@ -922,6 +922,7 @@ static void DetermineEpochStartLsn(WalProposer *wp) { TermHistory *dth; + int n_ready = 0; wp->propEpochStartLsn = InvalidXLogRecPtr; wp->donorEpoch = 0; @@ -932,6 +933,8 @@ DetermineEpochStartLsn(WalProposer *wp) { if (wp->safekeeper[i].state == SS_IDLE) { + n_ready++; + if (GetEpoch(&wp->safekeeper[i]) > wp->donorEpoch || (GetEpoch(&wp->safekeeper[i]) == wp->donorEpoch && wp->safekeeper[i].voteResponse.flushLsn > wp->propEpochStartLsn)) @@ -958,6 +961,16 @@ DetermineEpochStartLsn(WalProposer *wp) } } + if (n_ready < wp->quorum) + { + /* + * This is a rare case that can be triggered if safekeeper has voted and disconnected. + * In this case, its state will not be SS_IDLE and its vote cannot be used, because + * we clean up `voteResponse` in `ShutdownConnection`. + */ + wp_log(FATAL, "missing majority of votes, collected %d, expected %d, got %d", wp->n_votes, wp->quorum, n_ready); + } + /* * If propEpochStartLsn is 0, it means flushLsn is 0 everywhere, we are bootstrapping * and nothing was committed yet. Start streaming then from the basebackup LSN. diff --git a/pgxn/neon/walproposer.h b/pgxn/neon/walproposer.h index 688d8e6e52..53820f6e1b 100644 --- a/pgxn/neon/walproposer.h +++ b/pgxn/neon/walproposer.h @@ -486,6 +486,8 @@ typedef struct walproposer_api * * On success, the data is placed in *buf. It is valid until the next call * to this function. + * + * Returns PG_ASYNC_READ_FAIL on closed connection. */ PGAsyncReadResult (*conn_async_read) (Safekeeper *sk, char **buf, int *amount); @@ -532,6 +534,13 @@ typedef struct walproposer_api * Returns 0 if timeout is reached, 1 if some event happened. Updates * events mask to indicate events and sets sk to the safekeeper which has * an event. + * + * On timeout, events is set to WL_NO_EVENTS. On socket event, events is + * set to WL_SOCKET_READABLE and/or WL_SOCKET_WRITEABLE. When socket is + * closed, events is set to WL_SOCKET_READABLE. + * + * WL_SOCKET_WRITEABLE is usually set only when we need to flush the buffer. + * It can be returned only if caller asked for this event in the last *_event_set call. */ int (*wait_event_set) (WalProposer *wp, long timeout, Safekeeper **sk, uint32 *events); diff --git a/safekeeper/Cargo.toml b/safekeeper/Cargo.toml index 364cad7892..cb4a1def1f 100644 --- a/safekeeper/Cargo.toml +++ b/safekeeper/Cargo.toml @@ -61,3 +61,10 @@ tokio-stream.workspace = true utils.workspace = true workspace_hack.workspace = true + +[dev-dependencies] +walproposer.workspace = true +rand.workspace = true +desim.workspace = true +tracing.workspace = true +tracing-subscriber = { workspace = true, features = ["json"] } diff --git a/safekeeper/tests/misc_test.rs b/safekeeper/tests/misc_test.rs new file mode 100644 index 0000000000..8e5b17a143 --- /dev/null +++ b/safekeeper/tests/misc_test.rs @@ -0,0 +1,155 @@ +use std::sync::Arc; + +use tracing::{info, warn}; +use utils::lsn::Lsn; + +use crate::walproposer_sim::{ + log::{init_logger, init_tracing_logger}, + simulation::{generate_network_opts, generate_schedule, Schedule, TestAction, TestConfig}, +}; + +pub mod walproposer_sim; + +// Test that simulation supports restarting (crashing) safekeepers. 
+#[test] +fn crash_safekeeper() { + let clock = init_logger(); + let config = TestConfig::new(Some(clock)); + let test = config.start(1337); + + let lsn = test.sync_safekeepers().unwrap(); + assert_eq!(lsn, Lsn(0)); + info!("Sucessfully synced empty safekeepers at 0/0"); + + let mut wp = test.launch_walproposer(lsn); + + // Write some WAL and crash safekeeper 0 without waiting for replication. + test.poll_for_duration(30); + wp.write_tx(3); + test.servers[0].restart(); + + // Wait some time, so that walproposer can reconnect. + test.poll_for_duration(2000); +} + +// Test that walproposer can be crashed (stopped). +#[test] +fn test_simple_restart() { + let clock = init_logger(); + let config = TestConfig::new(Some(clock)); + let test = config.start(1337); + + let lsn = test.sync_safekeepers().unwrap(); + assert_eq!(lsn, Lsn(0)); + info!("Sucessfully synced empty safekeepers at 0/0"); + + let mut wp = test.launch_walproposer(lsn); + + test.poll_for_duration(30); + wp.write_tx(3); + test.poll_for_duration(100); + + wp.stop(); + drop(wp); + + let lsn = test.sync_safekeepers().unwrap(); + info!("Sucessfully synced safekeepers at {}", lsn); +} + +// Test runnning a simple schedule, restarting everything a several times. +#[test] +fn test_simple_schedule() -> anyhow::Result<()> { + let clock = init_logger(); + let mut config = TestConfig::new(Some(clock)); + config.network.keepalive_timeout = Some(100); + let test = config.start(1337); + + let schedule: Schedule = vec![ + (0, TestAction::RestartWalProposer), + (50, TestAction::WriteTx(5)), + (100, TestAction::RestartSafekeeper(0)), + (100, TestAction::WriteTx(5)), + (110, TestAction::RestartSafekeeper(1)), + (110, TestAction::WriteTx(5)), + (120, TestAction::RestartSafekeeper(2)), + (120, TestAction::WriteTx(5)), + (201, TestAction::RestartWalProposer), + (251, TestAction::RestartSafekeeper(0)), + (251, TestAction::RestartSafekeeper(1)), + (251, TestAction::RestartSafekeeper(2)), + (251, TestAction::WriteTx(5)), + (255, TestAction::WriteTx(5)), + (1000, TestAction::WriteTx(5)), + ]; + + test.run_schedule(&schedule)?; + info!("Test finished, stopping all threads"); + test.world.deallocate(); + + Ok(()) +} + +// Test that simulation can process 10^4 transactions. +#[test] +fn test_many_tx() -> anyhow::Result<()> { + let clock = init_logger(); + let config = TestConfig::new(Some(clock)); + let test = config.start(1337); + + let mut schedule: Schedule = vec![]; + for i in 0..100 { + schedule.push((i * 10, TestAction::WriteTx(100))); + } + + test.run_schedule(&schedule)?; + info!("Test finished, stopping all threads"); + test.world.stop_all(); + + let events = test.world.take_events(); + info!("Events: {:?}", events); + let last_commit_lsn = events + .iter() + .filter_map(|event| { + if event.data.starts_with("commit_lsn;") { + let lsn: u64 = event.data.split(';').nth(1).unwrap().parse().unwrap(); + return Some(lsn); + } + None + }) + .last() + .unwrap(); + + let initdb_lsn = 21623024; + let diff = last_commit_lsn - initdb_lsn; + info!("Last commit lsn: {}, diff: {}", last_commit_lsn, diff); + // each tx is at least 8 bytes, it's written a 100 times for in a loop for 100 times + assert!(diff > 100 * 100 * 8); + Ok(()) +} + +// Checks that we don't have nasty circular dependencies, preventing Arc from deallocating. +// This test doesn't really assert anything, you need to run it manually to check if there +// is any issue. 
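A brief aside before the next test: `test_many_tx` above recovers the final commit LSN by scanning the world's event log for entries of the form `commit_lsn;<lsn>`. Factored out into a helper (the name is hypothetical; the parsing mirrors the inline code), that lookup is:

```rust
use desim::proto::SimEvent;

// Last "commit_lsn;<lsn>" entry recorded via Node::log_event, if any.
fn last_commit_lsn(events: &[SimEvent]) -> Option<u64> {
    events
        .iter()
        .filter(|e| e.data.starts_with("commit_lsn;"))
        .filter_map(|e| e.data.split(';').nth(1)?.parse::<u64>().ok())
        .last()
}
```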
+#[test] +fn test_res_dealloc() -> anyhow::Result<()> { + let clock = init_tracing_logger(true); + let mut config = TestConfig::new(Some(clock)); + + let seed = 123456; + config.network = generate_network_opts(seed); + let test = config.start(seed); + warn!("Running test with seed {}", seed); + + let schedule = generate_schedule(seed); + info!("schedule: {:?}", schedule); + test.run_schedule(&schedule).unwrap(); + test.world.stop_all(); + + let world = test.world.clone(); + drop(test); + info!("world strong count: {}", Arc::strong_count(&world)); + world.deallocate(); + info!("world strong count: {}", Arc::strong_count(&world)); + + Ok(()) +} diff --git a/safekeeper/tests/random_test.rs b/safekeeper/tests/random_test.rs new file mode 100644 index 0000000000..6c6f6a8c96 --- /dev/null +++ b/safekeeper/tests/random_test.rs @@ -0,0 +1,56 @@ +use rand::Rng; +use tracing::{info, warn}; + +use crate::walproposer_sim::{ + log::{init_logger, init_tracing_logger}, + simulation::{generate_network_opts, generate_schedule, TestConfig}, + simulation_logs::validate_events, +}; + +pub mod walproposer_sim; + +// Generates 2000 random seeds and runs a schedule for each of them. +// If you seed this test fail, please report the last seed to the +// @safekeeper team. +#[test] +fn test_random_schedules() -> anyhow::Result<()> { + let clock = init_logger(); + let mut config = TestConfig::new(Some(clock)); + + for _ in 0..2000 { + let seed: u64 = rand::thread_rng().gen(); + config.network = generate_network_opts(seed); + + let test = config.start(seed); + warn!("Running test with seed {}", seed); + + let schedule = generate_schedule(seed); + test.run_schedule(&schedule).unwrap(); + validate_events(test.world.take_events()); + test.world.deallocate(); + } + + Ok(()) +} + +// After you found a seed that fails, you can insert this seed here +// and run the test to see the full debug output. +#[test] +fn test_one_schedule() -> anyhow::Result<()> { + let clock = init_tracing_logger(true); + let mut config = TestConfig::new(Some(clock)); + + let seed = 11047466935058776390; + config.network = generate_network_opts(seed); + info!("network: {:?}", config.network); + let test = config.start(seed); + warn!("Running test with seed {}", seed); + + let schedule = generate_schedule(seed); + info!("schedule: {:?}", schedule); + test.run_schedule(&schedule).unwrap(); + validate_events(test.world.take_events()); + test.world.deallocate(); + + Ok(()) +} diff --git a/safekeeper/tests/simple_test.rs b/safekeeper/tests/simple_test.rs new file mode 100644 index 0000000000..0be9d0deef --- /dev/null +++ b/safekeeper/tests/simple_test.rs @@ -0,0 +1,45 @@ +use tracing::info; +use utils::lsn::Lsn; + +use crate::walproposer_sim::{log::init_logger, simulation::TestConfig}; + +pub mod walproposer_sim; + +// Check that first start of sync_safekeepers() returns 0/0 on empty safekeepers. +#[test] +fn sync_empty_safekeepers() { + let clock = init_logger(); + let config = TestConfig::new(Some(clock)); + let test = config.start(1337); + + let lsn = test.sync_safekeepers().unwrap(); + assert_eq!(lsn, Lsn(0)); + info!("Sucessfully synced empty safekeepers at 0/0"); + + let lsn = test.sync_safekeepers().unwrap(); + assert_eq!(lsn, Lsn(0)); + info!("Sucessfully synced (again) empty safekeepers at 0/0"); +} + +// Check that there are no panics when we are writing and streaming WAL to safekeepers. 
+#[test] +fn run_walproposer_generate_wal() { + let clock = init_logger(); + let config = TestConfig::new(Some(clock)); + let test = config.start(1337); + + let lsn = test.sync_safekeepers().unwrap(); + assert_eq!(lsn, Lsn(0)); + info!("Sucessfully synced empty safekeepers at 0/0"); + + let mut wp = test.launch_walproposer(lsn); + + // wait for walproposer to start + test.poll_for_duration(30); + + // just write some WAL + for _ in 0..100 { + wp.write_tx(1); + test.poll_for_duration(5); + } +} diff --git a/safekeeper/tests/walproposer_sim/block_storage.rs b/safekeeper/tests/walproposer_sim/block_storage.rs new file mode 100644 index 0000000000..468c02ad2f --- /dev/null +++ b/safekeeper/tests/walproposer_sim/block_storage.rs @@ -0,0 +1,57 @@ +use std::collections::HashMap; + +const BLOCK_SIZE: usize = 8192; + +/// A simple in-memory implementation of a block storage. Can be used to implement external +/// storage in tests. +pub struct BlockStorage { + blocks: HashMap, +} + +impl Default for BlockStorage { + fn default() -> Self { + Self::new() + } +} + +impl BlockStorage { + pub fn new() -> Self { + BlockStorage { + blocks: HashMap::new(), + } + } + + pub fn read(&self, pos: u64, buf: &mut [u8]) { + let mut buf_offset = 0; + let mut storage_pos = pos; + while buf_offset < buf.len() { + let block_id = storage_pos / BLOCK_SIZE as u64; + let block = self.blocks.get(&block_id).unwrap_or(&[0; BLOCK_SIZE]); + let block_offset = storage_pos % BLOCK_SIZE as u64; + let block_len = BLOCK_SIZE as u64 - block_offset; + let buf_len = buf.len() - buf_offset; + let copy_len = std::cmp::min(block_len as usize, buf_len); + buf[buf_offset..buf_offset + copy_len] + .copy_from_slice(&block[block_offset as usize..block_offset as usize + copy_len]); + buf_offset += copy_len; + storage_pos += copy_len as u64; + } + } + + pub fn write(&mut self, pos: u64, buf: &[u8]) { + let mut buf_offset = 0; + let mut storage_pos = pos; + while buf_offset < buf.len() { + let block_id = storage_pos / BLOCK_SIZE as u64; + let block = self.blocks.entry(block_id).or_insert([0; BLOCK_SIZE]); + let block_offset = storage_pos % BLOCK_SIZE as u64; + let block_len = BLOCK_SIZE as u64 - block_offset; + let buf_len = buf.len() - buf_offset; + let copy_len = std::cmp::min(block_len as usize, buf_len); + block[block_offset as usize..block_offset as usize + copy_len] + .copy_from_slice(&buf[buf_offset..buf_offset + copy_len]); + buf_offset += copy_len; + storage_pos += copy_len as u64 + } + } +} diff --git a/safekeeper/tests/walproposer_sim/log.rs b/safekeeper/tests/walproposer_sim/log.rs new file mode 100644 index 0000000000..870f30de4f --- /dev/null +++ b/safekeeper/tests/walproposer_sim/log.rs @@ -0,0 +1,77 @@ +use std::{fmt, sync::Arc}; + +use desim::time::Timing; +use once_cell::sync::OnceCell; +use parking_lot::Mutex; +use tracing_subscriber::fmt::{format::Writer, time::FormatTime}; + +/// SimClock can be plugged into tracing logger to print simulation time. 
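Stepping back to `BlockStorage` above before the logging helpers: the store is sparse, zero-filled, and addressed in 8 KiB blocks, and reads and writes may straddle block boundaries. A small smoke test (illustrative only, not part of the patch) shows the contract:

```rust
#[test]
fn block_storage_smoke() {
    let mut wal = BlockStorage::new();

    // a write that starts inside block 0 and spills across the 8 KiB boundaries
    let payload = vec![0xAAu8; 10_000];
    wal.write(8_000, &payload);

    let mut readback = vec![0u8; 10_000];
    wal.read(8_000, &mut readback);
    assert_eq!(payload, readback);

    // regions that were never written read back as zeroes
    let mut hole = [0u8; 16];
    wal.read(1 << 20, &mut hole);
    assert_eq!(hole, [0u8; 16]);
}
```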
+#[derive(Clone)] +pub struct SimClock { + clock_ptr: Arc>>>, +} + +impl Default for SimClock { + fn default() -> Self { + SimClock { + clock_ptr: Arc::new(Mutex::new(None)), + } + } +} + +impl SimClock { + pub fn set_clock(&self, clock: Arc) { + *self.clock_ptr.lock() = Some(clock); + } +} + +impl FormatTime for SimClock { + fn format_time(&self, w: &mut Writer<'_>) -> fmt::Result { + let clock = self.clock_ptr.lock(); + + if let Some(clock) = clock.as_ref() { + let now = clock.now(); + write!(w, "[{}]", now) + } else { + write!(w, "[?]") + } + } +} + +static LOGGING_DONE: OnceCell = OnceCell::new(); + +/// Returns ptr to clocks attached to tracing logger to update them when the +/// world is (re)created. +pub fn init_tracing_logger(debug_enabled: bool) -> SimClock { + LOGGING_DONE + .get_or_init(|| { + let clock = SimClock::default(); + let base_logger = tracing_subscriber::fmt() + .with_target(false) + // prefix log lines with simulated time timestamp + .with_timer(clock.clone()) + // .with_ansi(true) TODO + .with_max_level(match debug_enabled { + true => tracing::Level::DEBUG, + false => tracing::Level::WARN, + }) + .with_writer(std::io::stdout); + base_logger.init(); + + // logging::replace_panic_hook_with_tracing_panic_hook().forget(); + + if !debug_enabled { + std::panic::set_hook(Box::new(|_| {})); + } + + clock + }) + .clone() +} + +pub fn init_logger() -> SimClock { + // RUST_TRACEBACK envvar controls whether we print all logs or only warnings. + let debug_enabled = std::env::var("RUST_TRACEBACK").is_ok(); + + init_tracing_logger(debug_enabled) +} diff --git a/safekeeper/tests/walproposer_sim/mod.rs b/safekeeper/tests/walproposer_sim/mod.rs new file mode 100644 index 0000000000..ec560dcb3b --- /dev/null +++ b/safekeeper/tests/walproposer_sim/mod.rs @@ -0,0 +1,8 @@ +pub mod block_storage; +pub mod log; +pub mod safekeeper; +pub mod safekeeper_disk; +pub mod simulation; +pub mod simulation_logs; +pub mod walproposer_api; +pub mod walproposer_disk; diff --git a/safekeeper/tests/walproposer_sim/safekeeper.rs b/safekeeper/tests/walproposer_sim/safekeeper.rs new file mode 100644 index 0000000000..1945b9d0cb --- /dev/null +++ b/safekeeper/tests/walproposer_sim/safekeeper.rs @@ -0,0 +1,410 @@ +//! Safekeeper communication endpoint to WAL proposer (compute node). +//! Gets messages from the network, passes them down to consensus module and +//! sends replies back. + +use std::{collections::HashMap, sync::Arc, time::Duration}; + +use anyhow::{bail, Result}; +use bytes::{Bytes, BytesMut}; +use camino::Utf8PathBuf; +use desim::{ + executor::{self, PollSome}, + network::TCP, + node_os::NodeOs, + proto::{AnyMessage, NetEvent, NodeEvent}, +}; +use hyper::Uri; +use safekeeper::{ + safekeeper::{ProposerAcceptorMessage, SafeKeeper, ServerInfo, UNKNOWN_SERVER_VERSION}, + state::TimelinePersistentState, + timeline::TimelineError, + wal_storage::Storage, + SafeKeeperConf, +}; +use tracing::{debug, info_span}; +use utils::{ + id::{NodeId, TenantId, TenantTimelineId, TimelineId}, + lsn::Lsn, +}; + +use super::safekeeper_disk::{DiskStateStorage, DiskWALStorage, SafekeeperDisk, TimelineDisk}; + +struct SharedState { + sk: SafeKeeper, + disk: Arc, +} + +struct GlobalMap { + timelines: HashMap, + conf: SafeKeeperConf, + disk: Arc, +} + +impl GlobalMap { + /// Restores global state from disk. 
+ fn new(disk: Arc, conf: SafeKeeperConf) -> Result { + let mut timelines = HashMap::new(); + + for (&ttid, disk) in disk.timelines.lock().iter() { + debug!("loading timeline {}", ttid); + let state = disk.state.lock().clone(); + + if state.server.wal_seg_size == 0 { + bail!(TimelineError::UninitializedWalSegSize(ttid)); + } + + if state.server.pg_version == UNKNOWN_SERVER_VERSION { + bail!(TimelineError::UninitialinzedPgVersion(ttid)); + } + + if state.commit_lsn < state.local_start_lsn { + bail!( + "commit_lsn {} is higher than local_start_lsn {}", + state.commit_lsn, + state.local_start_lsn + ); + } + + let control_store = DiskStateStorage::new(disk.clone()); + let wal_store = DiskWALStorage::new(disk.clone(), &control_store)?; + + let sk = SafeKeeper::new(control_store, wal_store, conf.my_id)?; + timelines.insert( + ttid, + SharedState { + sk, + disk: disk.clone(), + }, + ); + } + + Ok(Self { + timelines, + conf, + disk, + }) + } + + fn create(&mut self, ttid: TenantTimelineId, server_info: ServerInfo) -> Result<()> { + if self.timelines.contains_key(&ttid) { + bail!("timeline {} already exists", ttid); + } + + debug!("creating new timeline {}", ttid); + + let commit_lsn = Lsn::INVALID; + let local_start_lsn = Lsn::INVALID; + + let state = + TimelinePersistentState::new(&ttid, server_info, vec![], commit_lsn, local_start_lsn); + + if state.server.wal_seg_size == 0 { + bail!(TimelineError::UninitializedWalSegSize(ttid)); + } + + if state.server.pg_version == UNKNOWN_SERVER_VERSION { + bail!(TimelineError::UninitialinzedPgVersion(ttid)); + } + + if state.commit_lsn < state.local_start_lsn { + bail!( + "commit_lsn {} is higher than local_start_lsn {}", + state.commit_lsn, + state.local_start_lsn + ); + } + + let disk_timeline = self.disk.put_state(&ttid, state); + let control_store = DiskStateStorage::new(disk_timeline.clone()); + let wal_store = DiskWALStorage::new(disk_timeline.clone(), &control_store)?; + + let sk = SafeKeeper::new(control_store, wal_store, self.conf.my_id)?; + + self.timelines.insert( + ttid, + SharedState { + sk, + disk: disk_timeline, + }, + ); + Ok(()) + } + + fn get(&mut self, ttid: &TenantTimelineId) -> &mut SharedState { + self.timelines.get_mut(ttid).expect("timeline must exist") + } + + fn has_tli(&self, ttid: &TenantTimelineId) -> bool { + self.timelines.contains_key(ttid) + } +} + +/// State of a single connection to walproposer. 
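`GlobalMap::create` above refuses timelines whose `ServerInfo` is not fully populated (zero `wal_seg_size` or `UNKNOWN_SERVER_VERSION`). In the simulator this information normally arrives with the walproposer greeting; a hand-rolled value that passes those checks might look like the sketch below (the concrete numbers are illustrative assumptions, not taken from the patch):

```rust
use postgres_ffi::WAL_SEGMENT_SIZE;
use safekeeper::safekeeper::ServerInfo;

// A ServerInfo that satisfies GlobalMap::create's validation: both
// wal_seg_size and pg_version must be set to non-default values.
fn test_server_info() -> ServerInfo {
    ServerInfo {
        pg_version: 160000, // PG_VERSION_NUM-style value (assumed), != UNKNOWN_SERVER_VERSION
        system_id: 0,
        wal_seg_size: WAL_SEGMENT_SIZE as u32,
    }
}
```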
+struct ConnState { + tcp: TCP, + + greeting: bool, + ttid: TenantTimelineId, + flush_pending: bool, + + runtime: tokio::runtime::Runtime, +} + +pub fn run_server(os: NodeOs, disk: Arc) -> Result<()> { + let _enter = info_span!("safekeeper", id = os.id()).entered(); + debug!("started server"); + os.log_event("started;safekeeper".to_owned()); + let conf = SafeKeeperConf { + workdir: Utf8PathBuf::from("."), + my_id: NodeId(os.id() as u64), + listen_pg_addr: String::new(), + listen_http_addr: String::new(), + no_sync: false, + broker_endpoint: "/".parse::().unwrap(), + broker_keepalive_interval: Duration::from_secs(0), + heartbeat_timeout: Duration::from_secs(0), + remote_storage: None, + max_offloader_lag_bytes: 0, + wal_backup_enabled: false, + listen_pg_addr_tenant_only: None, + advertise_pg_addr: None, + availability_zone: None, + peer_recovery_enabled: false, + backup_parallel_jobs: 0, + pg_auth: None, + pg_tenant_only_auth: None, + http_auth: None, + current_thread_runtime: false, + }; + + let mut global = GlobalMap::new(disk, conf.clone())?; + let mut conns: HashMap = HashMap::new(); + + for (&_ttid, shared_state) in global.timelines.iter_mut() { + let flush_lsn = shared_state.sk.wal_store.flush_lsn(); + let commit_lsn = shared_state.sk.state.commit_lsn; + os.log_event(format!("tli_loaded;{};{}", flush_lsn.0, commit_lsn.0)); + } + + let node_events = os.node_events(); + let mut epoll_vec: Vec> = vec![]; + let mut epoll_idx: Vec = vec![]; + + // TODO: batch events processing (multiple events per tick) + loop { + epoll_vec.clear(); + epoll_idx.clear(); + + // node events channel + epoll_vec.push(Box::new(node_events.clone())); + epoll_idx.push(0); + + // tcp connections + for conn in conns.values() { + epoll_vec.push(Box::new(conn.tcp.recv_chan())); + epoll_idx.push(conn.tcp.connection_id()); + } + + // waiting for the next message + let index = executor::epoll_chans(&epoll_vec, -1).unwrap(); + + if index == 0 { + // got a new connection + match node_events.must_recv() { + NodeEvent::Accept(tcp) => { + conns.insert( + tcp.connection_id(), + ConnState { + tcp, + greeting: false, + ttid: TenantTimelineId::empty(), + flush_pending: false, + runtime: tokio::runtime::Builder::new_current_thread().build()?, + }, + ); + } + NodeEvent::Internal(_) => unreachable!(), + } + continue; + } + + let connection_id = epoll_idx[index]; + let conn = conns.get_mut(&connection_id).unwrap(); + let mut next_event = Some(conn.tcp.recv_chan().must_recv()); + + loop { + let event = match next_event { + Some(event) => event, + None => break, + }; + + match event { + NetEvent::Message(msg) => { + let res = conn.process_any(msg, &mut global); + if res.is_err() { + debug!("conn {:?} error: {:#}", connection_id, res.unwrap_err()); + conns.remove(&connection_id); + break; + } + } + NetEvent::Closed => { + // TODO: remove from conns? + } + } + + next_event = conn.tcp.recv_chan().try_recv(); + } + + conns.retain(|_, conn| { + let res = conn.flush(&mut global); + if res.is_err() { + debug!("conn {:?} error: {:?}", conn.tcp, res); + } + res.is_ok() + }); + } +} + +impl ConnState { + /// Process a message from the network. It can be START_REPLICATION request or a valid ProposerAcceptorMessage message. 
+ fn process_any(&mut self, any: AnyMessage, global: &mut GlobalMap) -> Result<()> { + if let AnyMessage::Bytes(copy_data) = any { + let repl_prefix = b"START_REPLICATION "; + if !self.greeting && copy_data.starts_with(repl_prefix) { + self.process_start_replication(copy_data.slice(repl_prefix.len()..), global)?; + bail!("finished processing START_REPLICATION") + } + + let msg = ProposerAcceptorMessage::parse(copy_data)?; + debug!("got msg: {:?}", msg); + self.process(msg, global) + } else { + bail!("unexpected message, expected AnyMessage::Bytes"); + } + } + + /// Process START_REPLICATION request. + fn process_start_replication( + &mut self, + copy_data: Bytes, + global: &mut GlobalMap, + ) -> Result<()> { + // format is " " + let str = String::from_utf8(copy_data.to_vec())?; + + let mut parts = str.split(' '); + let tenant_id = parts.next().unwrap().parse::()?; + let timeline_id = parts.next().unwrap().parse::()?; + let start_lsn = parts.next().unwrap().parse::()?; + let end_lsn = parts.next().unwrap().parse::()?; + + let ttid = TenantTimelineId::new(tenant_id, timeline_id); + let shared_state = global.get(&ttid); + + // read bytes from start_lsn to end_lsn + let mut buf = vec![0; (end_lsn - start_lsn) as usize]; + shared_state.disk.wal.lock().read(start_lsn, &mut buf); + + // send bytes to the client + self.tcp.send(AnyMessage::Bytes(Bytes::from(buf))); + Ok(()) + } + + /// Get or create a timeline. + fn init_timeline( + &mut self, + ttid: TenantTimelineId, + server_info: ServerInfo, + global: &mut GlobalMap, + ) -> Result<()> { + self.ttid = ttid; + if global.has_tli(&ttid) { + return Ok(()); + } + + global.create(ttid, server_info) + } + + /// Process a ProposerAcceptorMessage. + fn process(&mut self, msg: ProposerAcceptorMessage, global: &mut GlobalMap) -> Result<()> { + if !self.greeting { + self.greeting = true; + + match msg { + ProposerAcceptorMessage::Greeting(ref greeting) => { + tracing::info!( + "start handshake with walproposer {:?} {:?}", + self.tcp, + greeting + ); + let server_info = ServerInfo { + pg_version: greeting.pg_version, + system_id: greeting.system_id, + wal_seg_size: greeting.wal_seg_size, + }; + let ttid = TenantTimelineId::new(greeting.tenant_id, greeting.timeline_id); + self.init_timeline(ttid, server_info, global)? + } + _ => { + bail!("unexpected message {msg:?} instead of greeting"); + } + } + } + + let tli = global.get(&self.ttid); + + match msg { + ProposerAcceptorMessage::AppendRequest(append_request) => { + self.flush_pending = true; + self.process_sk_msg( + tli, + &ProposerAcceptorMessage::NoFlushAppendRequest(append_request), + )?; + } + other => { + self.process_sk_msg(tli, &other)?; + } + } + + Ok(()) + } + + /// Process FlushWAL if needed. 
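For reference, the space-separated request that `process_start_replication` above parses can be produced with the same `Display` impls it parses back; the helper below is a hypothetical counterpart on the sending side, not something the patch adds:

```rust
use bytes::Bytes;
use utils::{id::TenantTimelineId, lsn::Lsn};

// "START_REPLICATION <tenant_id> <timeline_id> <start_lsn> <end_lsn>",
// matching the prefix check and the field order expected by the parser above.
fn start_replication_request(ttid: &TenantTimelineId, start: Lsn, end: Lsn) -> Bytes {
    Bytes::from(format!(
        "START_REPLICATION {} {} {} {}",
        ttid.tenant_id, ttid.timeline_id, start, end
    ))
}
```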
+ fn flush(&mut self, global: &mut GlobalMap) -> Result<()> { + // TODO: try to add extra flushes in simulation, to verify that extra flushes don't break anything + if !self.flush_pending { + return Ok(()); + } + self.flush_pending = false; + let shared_state = global.get(&self.ttid); + self.process_sk_msg(shared_state, &ProposerAcceptorMessage::FlushWAL) + } + + /// Make safekeeper process a message and send a reply to the TCP + fn process_sk_msg( + &mut self, + shared_state: &mut SharedState, + msg: &ProposerAcceptorMessage, + ) -> Result<()> { + let mut reply = self.runtime.block_on(shared_state.sk.process_msg(msg))?; + if let Some(reply) = &mut reply { + // TODO: if this is AppendResponse, fill in proper hot standby feedback and disk consistent lsn + + let mut buf = BytesMut::with_capacity(128); + reply.serialize(&mut buf)?; + + self.tcp.send(AnyMessage::Bytes(buf.into())); + } + Ok(()) + } +} + +impl Drop for ConnState { + fn drop(&mut self) { + debug!("dropping conn: {:?}", self.tcp); + if !std::thread::panicking() { + self.tcp.close(); + } + // TODO: clean up non-fsynced WAL + } +} diff --git a/safekeeper/tests/walproposer_sim/safekeeper_disk.rs b/safekeeper/tests/walproposer_sim/safekeeper_disk.rs new file mode 100644 index 0000000000..35bca325aa --- /dev/null +++ b/safekeeper/tests/walproposer_sim/safekeeper_disk.rs @@ -0,0 +1,278 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use parking_lot::Mutex; +use safekeeper::state::TimelinePersistentState; +use utils::id::TenantTimelineId; + +use super::block_storage::BlockStorage; + +use std::{ops::Deref, time::Instant}; + +use anyhow::Result; +use bytes::{Buf, BytesMut}; +use futures::future::BoxFuture; +use postgres_ffi::{waldecoder::WalStreamDecoder, XLogSegNo}; +use safekeeper::{control_file, metrics::WalStorageMetrics, wal_storage}; +use tracing::{debug, info}; +use utils::lsn::Lsn; + +/// All safekeeper state that is usually saved to disk. +pub struct SafekeeperDisk { + pub timelines: Mutex>>, +} + +impl Default for SafekeeperDisk { + fn default() -> Self { + Self::new() + } +} + +impl SafekeeperDisk { + pub fn new() -> Self { + SafekeeperDisk { + timelines: Mutex::new(HashMap::new()), + } + } + + pub fn put_state( + &self, + ttid: &TenantTimelineId, + state: TimelinePersistentState, + ) -> Arc { + self.timelines + .lock() + .entry(*ttid) + .and_modify(|e| { + let mut mu = e.state.lock(); + *mu = state.clone(); + }) + .or_insert_with(|| { + Arc::new(TimelineDisk { + state: Mutex::new(state), + wal: Mutex::new(BlockStorage::new()), + }) + }) + .clone() + } +} + +/// Control file state and WAL storage. +pub struct TimelineDisk { + pub state: Mutex, + pub wal: Mutex, +} + +/// Implementation of `control_file::Storage` trait. +pub struct DiskStateStorage { + persisted_state: TimelinePersistentState, + disk: Arc, + last_persist_at: Instant, +} + +impl DiskStateStorage { + pub fn new(disk: Arc) -> Self { + let guard = disk.state.lock(); + let state = guard.clone(); + drop(guard); + DiskStateStorage { + persisted_state: state, + disk, + last_persist_at: Instant::now(), + } + } +} + +#[async_trait::async_trait] +impl control_file::Storage for DiskStateStorage { + /// Persist safekeeper state on disk and update internal state. + async fn persist(&mut self, s: &TimelinePersistentState) -> Result<()> { + self.persisted_state = s.clone(); + *self.disk.state.lock() = s.clone(); + Ok(()) + } + + /// Timestamp of last persist. 
+ fn last_persist_at(&self) -> Instant { + // TODO: don't rely on it in tests + self.last_persist_at + } +} + +impl Deref for DiskStateStorage { + type Target = TimelinePersistentState; + + fn deref(&self) -> &Self::Target { + &self.persisted_state + } +} + +/// Implementation of `wal_storage::Storage` trait. +pub struct DiskWALStorage { + /// Written to disk, but possibly still in the cache and not fully persisted. + /// Also can be ahead of record_lsn, if happen to be in the middle of a WAL record. + write_lsn: Lsn, + + /// The LSN of the last WAL record written to disk. Still can be not fully flushed. + write_record_lsn: Lsn, + + /// The LSN of the last WAL record flushed to disk. + flush_record_lsn: Lsn, + + /// Decoder is required for detecting boundaries of WAL records. + decoder: WalStreamDecoder, + + /// Bytes of WAL records that are not yet written to disk. + unflushed_bytes: BytesMut, + + /// Contains BlockStorage for WAL. + disk: Arc, +} + +impl DiskWALStorage { + pub fn new(disk: Arc, state: &TimelinePersistentState) -> Result { + let write_lsn = if state.commit_lsn == Lsn(0) { + Lsn(0) + } else { + Self::find_end_of_wal(disk.clone(), state.commit_lsn)? + }; + + let flush_lsn = write_lsn; + Ok(DiskWALStorage { + write_lsn, + write_record_lsn: flush_lsn, + flush_record_lsn: flush_lsn, + decoder: WalStreamDecoder::new(flush_lsn, 16), + unflushed_bytes: BytesMut::new(), + disk, + }) + } + + fn find_end_of_wal(disk: Arc, start_lsn: Lsn) -> Result { + let mut buf = [0; 8192]; + let mut pos = start_lsn.0; + let mut decoder = WalStreamDecoder::new(start_lsn, 16); + let mut result = start_lsn; + loop { + disk.wal.lock().read(pos, &mut buf); + pos += buf.len() as u64; + decoder.feed_bytes(&buf); + + loop { + match decoder.poll_decode() { + Ok(Some(record)) => result = record.0, + Err(e) => { + debug!( + "find_end_of_wal reached end at {:?}, decode error: {:?}", + result, e + ); + return Ok(result); + } + Ok(None) => break, // need more data + } + } + } + } +} + +#[async_trait::async_trait] +impl wal_storage::Storage for DiskWALStorage { + /// LSN of last durably stored WAL record. + fn flush_lsn(&self) -> Lsn { + self.flush_record_lsn + } + + /// Write piece of WAL from buf to disk, but not necessarily sync it. + async fn write_wal(&mut self, startpos: Lsn, buf: &[u8]) -> Result<()> { + if self.write_lsn != startpos { + panic!("write_wal called with wrong startpos"); + } + + self.unflushed_bytes.extend_from_slice(buf); + self.write_lsn += buf.len() as u64; + + if self.decoder.available() != startpos { + info!( + "restart decoder from {} to {}", + self.decoder.available(), + startpos, + ); + self.decoder = WalStreamDecoder::new(startpos, 16); + } + self.decoder.feed_bytes(buf); + loop { + match self.decoder.poll_decode()? { + None => break, // no full record yet + Some((lsn, _rec)) => { + self.write_record_lsn = lsn; + } + } + } + + Ok(()) + } + + /// Truncate WAL at specified LSN, which must be the end of WAL record. 
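`DiskWALStorage` above keeps three positions: `write_lsn` (bytes accepted), `write_record_lsn` (end of the last complete record decoded), and `flush_record_lsn` (what `flush_lsn()` reports as durable). A minimal sketch of how a caller drives it through the `wal_storage::Storage` trait, under the assumption that `start` equals the storage's current write position (otherwise `write_wal` panics):

```rust
use anyhow::Result;
use safekeeper::wal_storage::Storage as _;
use utils::lsn::Lsn;

// Append one chunk and make it durable. flush_lsn() only advances to the end
// of the last complete WAL record, so it may trail start + chunk.len().
async fn append_and_flush(store: &mut DiskWALStorage, start: Lsn, chunk: &[u8]) -> Result<()> {
    store.write_wal(start, chunk).await?;
    store.flush_wal().await?;
    assert!(store.flush_lsn().0 <= start.0 + chunk.len() as u64);
    Ok(())
}
```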
+ async fn truncate_wal(&mut self, end_pos: Lsn) -> Result<()> { + if self.write_lsn != Lsn(0) && end_pos > self.write_lsn { + panic!( + "truncate_wal called on non-written WAL, write_lsn={}, end_pos={}", + self.write_lsn, end_pos + ); + } + + self.flush_wal().await?; + + // write zeroes to disk from end_pos until self.write_lsn + let buf = [0; 8192]; + let mut pos = end_pos.0; + while pos < self.write_lsn.0 { + self.disk.wal.lock().write(pos, &buf); + pos += buf.len() as u64; + } + + self.write_lsn = end_pos; + self.write_record_lsn = end_pos; + self.flush_record_lsn = end_pos; + self.unflushed_bytes.clear(); + self.decoder = WalStreamDecoder::new(end_pos, 16); + + Ok(()) + } + + /// Durably store WAL on disk, up to the last written WAL record. + async fn flush_wal(&mut self) -> Result<()> { + if self.flush_record_lsn == self.write_record_lsn { + // no need to do extra flush + return Ok(()); + } + + let num_bytes = self.write_record_lsn.0 - self.flush_record_lsn.0; + + self.disk.wal.lock().write( + self.flush_record_lsn.0, + &self.unflushed_bytes[..num_bytes as usize], + ); + self.unflushed_bytes.advance(num_bytes as usize); + self.flush_record_lsn = self.write_record_lsn; + + Ok(()) + } + + /// Remove all segments <= given segno. Returns function doing that as we + /// want to perform it without timeline lock. + fn remove_up_to(&self, _segno_up_to: XLogSegNo) -> BoxFuture<'static, anyhow::Result<()>> { + Box::pin(async move { Ok(()) }) + } + + /// Release resources associated with the storage -- technically, close FDs. + /// Currently we don't remove timelines until restart (#3146), so need to + /// spare descriptors. This would be useful for temporary tli detach as + /// well. + fn close(&mut self) {} + + /// Get metrics for this timeline. + fn get_metrics(&self) -> WalStorageMetrics { + WalStorageMetrics::default() + } +} diff --git a/safekeeper/tests/walproposer_sim/simulation.rs b/safekeeper/tests/walproposer_sim/simulation.rs new file mode 100644 index 0000000000..0d7aaf517b --- /dev/null +++ b/safekeeper/tests/walproposer_sim/simulation.rs @@ -0,0 +1,436 @@ +use std::{cell::Cell, str::FromStr, sync::Arc}; + +use crate::walproposer_sim::{safekeeper::run_server, walproposer_api::SimulationApi}; +use desim::{ + executor::{self, ExternalHandle}, + node_os::NodeOs, + options::{Delay, NetworkOptions}, + proto::{AnyMessage, NodeEvent}, + world::Node, + world::World, +}; +use rand::{Rng, SeedableRng}; +use tracing::{debug, info_span, warn}; +use utils::{id::TenantTimelineId, lsn::Lsn}; +use walproposer::walproposer::{Config, Wrapper}; + +use super::{ + log::SimClock, safekeeper_disk::SafekeeperDisk, walproposer_api, + walproposer_disk::DiskWalProposer, +}; + +/// Simulated safekeeper node. +pub struct SafekeeperNode { + pub node: Arc, + pub id: u32, + pub disk: Arc, + pub thread: Cell, +} + +impl SafekeeperNode { + /// Create and start a safekeeper at the specified Node. + pub fn new(node: Arc) -> Self { + let disk = Arc::new(SafekeeperDisk::new()); + let thread = Cell::new(SafekeeperNode::launch(disk.clone(), node.clone())); + + Self { + id: node.id, + node, + disk, + thread, + } + } + + fn launch(disk: Arc, node: Arc) -> ExternalHandle { + // start the server thread + node.launch(move |os| { + run_server(os, disk).expect("server should finish without errors"); + }) + } + + /// Restart the safekeeper. 
+ pub fn restart(&self) { + let new_thread = SafekeeperNode::launch(self.disk.clone(), self.node.clone()); + let old_thread = self.thread.replace(new_thread); + old_thread.crash_stop(); + } +} + +/// Simulated walproposer node. +pub struct WalProposer { + thread: ExternalHandle, + node: Arc, + disk: Arc, + sync_safekeepers: bool, +} + +impl WalProposer { + /// Generic start function for both modes. + fn start( + os: NodeOs, + disk: Arc, + ttid: TenantTimelineId, + addrs: Vec, + lsn: Option, + ) { + let sync_safekeepers = lsn.is_none(); + + let _enter = if sync_safekeepers { + info_span!("sync", started = executor::now()).entered() + } else { + info_span!("walproposer", started = executor::now()).entered() + }; + + os.log_event(format!("started;walproposer;{}", sync_safekeepers as i32)); + + let config = Config { + ttid, + safekeepers_list: addrs, + safekeeper_reconnect_timeout: 1000, + safekeeper_connection_timeout: 5000, + sync_safekeepers, + }; + let args = walproposer_api::Args { + os, + config: config.clone(), + disk, + redo_start_lsn: lsn, + }; + let api = SimulationApi::new(args); + let wp = Wrapper::new(Box::new(api), config); + wp.start(); + } + + /// Start walproposer in a sync_safekeepers mode. + pub fn launch_sync(ttid: TenantTimelineId, addrs: Vec, node: Arc) -> Self { + debug!("sync_safekeepers started at node {}", node.id); + let disk = DiskWalProposer::new(); + let disk_wp = disk.clone(); + + // start the client thread + let handle = node.launch(move |os| { + WalProposer::start(os, disk_wp, ttid, addrs, None); + }); + + Self { + thread: handle, + node, + disk, + sync_safekeepers: true, + } + } + + /// Start walproposer in a normal mode. + pub fn launch_walproposer( + ttid: TenantTimelineId, + addrs: Vec, + node: Arc, + lsn: Lsn, + ) -> Self { + debug!("walproposer started at node {}", node.id); + let disk = DiskWalProposer::new(); + disk.lock().reset_to(lsn); + let disk_wp = disk.clone(); + + // start the client thread + let handle = node.launch(move |os| { + WalProposer::start(os, disk_wp, ttid, addrs, Some(lsn)); + }); + + Self { + thread: handle, + node, + disk, + sync_safekeepers: false, + } + } + + pub fn write_tx(&mut self, cnt: usize) { + let start_lsn = self.disk.lock().flush_rec_ptr(); + + for _ in 0..cnt { + self.disk + .lock() + .insert_logical_message("prefix", b"message") + .expect("failed to generate logical message"); + } + + let end_lsn = self.disk.lock().flush_rec_ptr(); + + // log event + self.node + .log_event(format!("write_wal;{};{};{}", start_lsn.0, end_lsn.0, cnt)); + + // now we need to set "Latch" in walproposer + self.node + .node_events() + .send(NodeEvent::Internal(AnyMessage::Just32(0))); + } + + pub fn stop(&self) { + self.thread.crash_stop(); + } +} + +/// Holds basic simulation settings, such as network options. +pub struct TestConfig { + pub network: NetworkOptions, + pub timeout: u64, + pub clock: Option, +} + +impl TestConfig { + /// Create a new TestConfig with default settings. + pub fn new(clock: Option) -> Self { + Self { + network: NetworkOptions { + keepalive_timeout: Some(2000), + connect_delay: Delay { + min: 1, + max: 5, + fail_prob: 0.0, + }, + send_delay: Delay { + min: 1, + max: 5, + fail_prob: 0.0, + }, + }, + timeout: 1_000 * 10, + clock, + } + } + + /// Start a new simulation with the specified seed. 
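One convention worth spelling out before `TestConfig::start` below: the safekeeper "connection strings" are synthetic. The host part is always the literal `node`, and the port doubles as the desim `NodeId`, which is how `SafekeeperConn` later recovers the node to dial. A hypothetical pair of helpers (not part of the patch) captures both directions:

```rust
use desim::world::NodeId;

// "node:<id>": the host is a placeholder, the port is the desim NodeId.
fn sk_address(id: NodeId) -> String {
    format!("node:{}", id)
}

fn node_id_from_address(addr: &str) -> NodeId {
    addr.split(':').nth(1).unwrap().parse().unwrap()
}
```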
+ pub fn start(&self, seed: u64) -> Test { + let world = Arc::new(World::new(seed, Arc::new(self.network.clone()))); + + if let Some(clock) = &self.clock { + clock.set_clock(world.clock()); + } + + let servers = [ + SafekeeperNode::new(world.new_node()), + SafekeeperNode::new(world.new_node()), + SafekeeperNode::new(world.new_node()), + ]; + + let server_ids = [servers[0].id, servers[1].id, servers[2].id]; + let safekeepers_addrs = server_ids.map(|id| format!("node:{}", id)).to_vec(); + + let ttid = TenantTimelineId::generate(); + + Test { + world, + servers, + sk_list: safekeepers_addrs, + ttid, + timeout: self.timeout, + } + } +} + +/// Holds simulation state. +pub struct Test { + pub world: Arc, + pub servers: [SafekeeperNode; 3], + pub sk_list: Vec, + pub ttid: TenantTimelineId, + pub timeout: u64, +} + +impl Test { + /// Start a sync_safekeepers thread and wait for it to finish. + pub fn sync_safekeepers(&self) -> anyhow::Result { + let wp = self.launch_sync_safekeepers(); + + // poll until exit or timeout + let time_limit = self.timeout; + while self.world.step() && self.world.now() < time_limit && !wp.thread.is_finished() {} + + if !wp.thread.is_finished() { + anyhow::bail!("timeout or idle stuck"); + } + + let res = wp.thread.result(); + if res.0 != 0 { + anyhow::bail!("non-zero exitcode: {:?}", res); + } + let lsn = Lsn::from_str(&res.1)?; + Ok(lsn) + } + + /// Spawn a new sync_safekeepers thread. + pub fn launch_sync_safekeepers(&self) -> WalProposer { + WalProposer::launch_sync(self.ttid, self.sk_list.clone(), self.world.new_node()) + } + + /// Spawn a new walproposer thread. + pub fn launch_walproposer(&self, lsn: Lsn) -> WalProposer { + let lsn = if lsn.0 == 0 { + // usual LSN after basebackup + Lsn(21623024) + } else { + lsn + }; + + WalProposer::launch_walproposer(self.ttid, self.sk_list.clone(), self.world.new_node(), lsn) + } + + /// Execute the simulation for the specified duration. + pub fn poll_for_duration(&self, duration: u64) { + let time_limit = std::cmp::min(self.world.now() + duration, self.timeout); + while self.world.step() && self.world.now() < time_limit {} + } + + /// Execute the simulation together with events defined in some schedule. 
+ pub fn run_schedule(&self, schedule: &Schedule) -> anyhow::Result<()> { + // scheduling empty events so that world will stop in those points + { + let clock = self.world.clock(); + + let now = self.world.now(); + for (time, _) in schedule { + if *time < now { + continue; + } + clock.schedule_fake(*time - now); + } + } + + let mut wp = self.launch_sync_safekeepers(); + + let mut skipped_tx = 0; + let mut started_tx = 0; + + let mut schedule_ptr = 0; + + loop { + if wp.sync_safekeepers && wp.thread.is_finished() { + let res = wp.thread.result(); + if res.0 != 0 { + warn!("sync non-zero exitcode: {:?}", res); + debug!("restarting sync_safekeepers"); + // restart the sync_safekeepers + wp = self.launch_sync_safekeepers(); + continue; + } + let lsn = Lsn::from_str(&res.1)?; + debug!("sync_safekeepers finished at LSN {}", lsn); + wp = self.launch_walproposer(lsn); + debug!("walproposer started at thread {}", wp.thread.id()); + } + + let now = self.world.now(); + while schedule_ptr < schedule.len() && schedule[schedule_ptr].0 <= now { + if now != schedule[schedule_ptr].0 { + warn!("skipped event {:?} at {}", schedule[schedule_ptr], now); + } + + let action = &schedule[schedule_ptr].1; + match action { + TestAction::WriteTx(size) => { + if !wp.sync_safekeepers && !wp.thread.is_finished() { + started_tx += *size; + wp.write_tx(*size); + debug!("written {} transactions", size); + } else { + skipped_tx += size; + debug!("skipped {} transactions", size); + } + } + TestAction::RestartSafekeeper(id) => { + debug!("restarting safekeeper {}", id); + self.servers[*id].restart(); + } + TestAction::RestartWalProposer => { + debug!("restarting sync_safekeepers"); + wp.stop(); + wp = self.launch_sync_safekeepers(); + } + } + schedule_ptr += 1; + } + + if schedule_ptr == schedule.len() { + break; + } + let next_event_time = schedule[schedule_ptr].0; + + // poll until the next event + if wp.thread.is_finished() { + while self.world.step() && self.world.now() < next_event_time {} + } else { + while self.world.step() + && self.world.now() < next_event_time + && !wp.thread.is_finished() + {} + } + } + + debug!( + "finished schedule, total steps: {}", + self.world.get_thread_step_count() + ); + debug!("skipped_tx: {}", skipped_tx); + debug!("started_tx: {}", started_tx); + + Ok(()) + } +} + +#[derive(Debug, Clone)] +pub enum TestAction { + WriteTx(usize), + RestartSafekeeper(usize), + RestartWalProposer, +} + +pub type Schedule = Vec<(u64, TestAction)>; + +pub fn generate_schedule(seed: u64) -> Schedule { + let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + let mut schedule = Vec::new(); + let mut time = 0; + + let cnt = rng.gen_range(1..100); + + for _ in 0..cnt { + time += rng.gen_range(0..500); + let action = match rng.gen_range(0..3) { + 0 => TestAction::WriteTx(rng.gen_range(1..10)), + 1 => TestAction::RestartSafekeeper(rng.gen_range(0..3)), + 2 => TestAction::RestartWalProposer, + _ => unreachable!(), + }; + schedule.push((time, action)); + } + + schedule +} + +pub fn generate_network_opts(seed: u64) -> NetworkOptions { + let mut rng = rand::rngs::StdRng::seed_from_u64(seed); + + let timeout = rng.gen_range(100..2000); + let max_delay = rng.gen_range(1..2 * timeout); + let min_delay = rng.gen_range(1..=max_delay); + + let max_fail_prob = rng.gen_range(0.0..0.9); + let connect_fail_prob = rng.gen_range(0.0..max_fail_prob); + let send_fail_prob = rng.gen_range(0.0..connect_fail_prob); + + NetworkOptions { + keepalive_timeout: Some(timeout), + connect_delay: Delay { + min: min_delay, + max: max_delay, + 
fail_prob: connect_fail_prob, + }, + send_delay: Delay { + min: min_delay, + max: max_delay, + fail_prob: send_fail_prob, + }, + } +} diff --git a/safekeeper/tests/walproposer_sim/simulation_logs.rs b/safekeeper/tests/walproposer_sim/simulation_logs.rs new file mode 100644 index 0000000000..38885e5dd0 --- /dev/null +++ b/safekeeper/tests/walproposer_sim/simulation_logs.rs @@ -0,0 +1,187 @@ +use desim::proto::SimEvent; +use tracing::debug; + +#[derive(Debug, Clone, PartialEq, Eq)] +enum NodeKind { + Unknown, + Safekeeper, + WalProposer, +} + +impl Default for NodeKind { + fn default() -> Self { + Self::Unknown + } +} + +/// Simulation state of walproposer/safekeeper, derived from the simulation logs. +#[derive(Clone, Debug, Default)] +struct NodeInfo { + kind: NodeKind, + + // walproposer + is_sync: bool, + term: u64, + epoch_lsn: u64, + + // safekeeper + commit_lsn: u64, + flush_lsn: u64, +} + +impl NodeInfo { + fn init_kind(&mut self, kind: NodeKind) { + if self.kind == NodeKind::Unknown { + self.kind = kind; + } else { + assert!(self.kind == kind); + } + } + + fn started(&mut self, data: &str) { + let mut parts = data.split(';'); + assert!(parts.next().unwrap() == "started"); + match parts.next().unwrap() { + "safekeeper" => { + self.init_kind(NodeKind::Safekeeper); + } + "walproposer" => { + self.init_kind(NodeKind::WalProposer); + let is_sync: u8 = parts.next().unwrap().parse().unwrap(); + self.is_sync = is_sync != 0; + } + _ => unreachable!(), + } + } +} + +/// Global state of the simulation, derived from the simulation logs. +#[derive(Debug, Default)] +struct GlobalState { + nodes: Vec, + commit_lsn: u64, + write_lsn: u64, + max_write_lsn: u64, + + written_wal: u64, + written_records: u64, +} + +impl GlobalState { + fn new() -> Self { + Default::default() + } + + fn get(&mut self, id: u32) -> &mut NodeInfo { + let id = id as usize; + if id >= self.nodes.len() { + self.nodes.resize(id + 1, NodeInfo::default()); + } + &mut self.nodes[id] + } +} + +/// Try to find inconsistencies in the simulation log. +pub fn validate_events(events: Vec) { + const INITDB_LSN: u64 = 21623024; + + let hook = std::panic::take_hook(); + scopeguard::defer_on_success! 
{ + std::panic::set_hook(hook); + }; + + let mut state = GlobalState::new(); + state.max_write_lsn = INITDB_LSN; + + for event in events { + debug!("{:?}", event); + + let node = state.get(event.node); + if event.data.starts_with("started;") { + node.started(&event.data); + continue; + } + assert!(node.kind != NodeKind::Unknown); + + // drop reference to unlock state + let mut node = node.clone(); + + let mut parts = event.data.split(';'); + match node.kind { + NodeKind::Safekeeper => match parts.next().unwrap() { + "tli_loaded" => { + let flush_lsn: u64 = parts.next().unwrap().parse().unwrap(); + let commit_lsn: u64 = parts.next().unwrap().parse().unwrap(); + node.flush_lsn = flush_lsn; + node.commit_lsn = commit_lsn; + } + _ => unreachable!(), + }, + NodeKind::WalProposer => { + match parts.next().unwrap() { + "prop_elected" => { + let prop_lsn: u64 = parts.next().unwrap().parse().unwrap(); + let prop_term: u64 = parts.next().unwrap().parse().unwrap(); + let prev_lsn: u64 = parts.next().unwrap().parse().unwrap(); + let prev_term: u64 = parts.next().unwrap().parse().unwrap(); + + assert!(prop_lsn >= prev_lsn); + assert!(prop_term >= prev_term); + + assert!(prop_lsn >= state.commit_lsn); + + if prop_lsn > state.write_lsn { + assert!(prop_lsn <= state.max_write_lsn); + debug!( + "moving write_lsn up from {} to {}", + state.write_lsn, prop_lsn + ); + state.write_lsn = prop_lsn; + } + if prop_lsn < state.write_lsn { + debug!( + "moving write_lsn down from {} to {}", + state.write_lsn, prop_lsn + ); + state.write_lsn = prop_lsn; + } + + node.epoch_lsn = prop_lsn; + node.term = prop_term; + } + "write_wal" => { + assert!(!node.is_sync); + let start_lsn: u64 = parts.next().unwrap().parse().unwrap(); + let end_lsn: u64 = parts.next().unwrap().parse().unwrap(); + let cnt: u64 = parts.next().unwrap().parse().unwrap(); + + let size = end_lsn - start_lsn; + state.written_wal += size; + state.written_records += cnt; + + // TODO: If we allow writing WAL before winning the election + + assert!(start_lsn >= state.commit_lsn); + assert!(end_lsn >= start_lsn); + // assert!(start_lsn == state.write_lsn); + state.write_lsn = end_lsn; + + if end_lsn > state.max_write_lsn { + state.max_write_lsn = end_lsn; + } + } + "commit_lsn" => { + let lsn: u64 = parts.next().unwrap().parse().unwrap(); + assert!(lsn >= state.commit_lsn); + state.commit_lsn = lsn; + } + _ => unreachable!(), + } + } + _ => unreachable!(), + } + + // update the node in the state struct + *state.get(event.node) = node; + } +} diff --git a/safekeeper/tests/walproposer_sim/walproposer_api.rs b/safekeeper/tests/walproposer_sim/walproposer_api.rs new file mode 100644 index 0000000000..746cac019e --- /dev/null +++ b/safekeeper/tests/walproposer_sim/walproposer_api.rs @@ -0,0 +1,676 @@ +use std::{ + cell::{RefCell, RefMut, UnsafeCell}, + ffi::CStr, + sync::Arc, +}; + +use bytes::Bytes; +use desim::{ + executor::{self, PollSome}, + network::TCP, + node_os::NodeOs, + proto::{AnyMessage, NetEvent, NodeEvent}, + world::NodeId, +}; +use tracing::debug; +use utils::lsn::Lsn; +use walproposer::{ + api_bindings::Level, + bindings::{ + pg_atomic_uint64, NeonWALReadResult, PageserverFeedback, SafekeeperStateDesiredEvents, + WL_SOCKET_READABLE, WL_SOCKET_WRITEABLE, + }, + walproposer::{ApiImpl, Config}, +}; + +use super::walproposer_disk::DiskWalProposer; + +/// Special state for each wp->sk connection. +struct SafekeeperConn { + host: String, + port: String, + node_id: NodeId, + // socket is Some(..) 
equals to connection is established + socket: Option, + // connection is in progress + is_connecting: bool, + // START_WAL_PUSH is in progress + is_start_wal_push: bool, + // pointer to Safekeeper in walproposer for callbacks + raw_ptr: *mut walproposer::bindings::Safekeeper, +} + +impl SafekeeperConn { + pub fn new(host: String, port: String) -> Self { + // port number is the same as NodeId + let port_num = port.parse::().unwrap(); + Self { + host, + port, + node_id: port_num, + socket: None, + is_connecting: false, + is_start_wal_push: false, + raw_ptr: std::ptr::null_mut(), + } + } +} + +/// Simulation version of a postgres WaitEventSet. At pos 0 there is always +/// a special NodeEvents channel, which is used as a latch. +struct EventSet { + os: NodeOs, + // all pollable channels, 0 is always NodeEvent channel + chans: Vec>, + // 0 is always nullptr + sk_ptrs: Vec<*mut walproposer::bindings::Safekeeper>, + // event mask for each channel + masks: Vec, +} + +impl EventSet { + pub fn new(os: NodeOs) -> Self { + let node_events = os.node_events(); + Self { + os, + chans: vec![Box::new(node_events)], + sk_ptrs: vec![std::ptr::null_mut()], + masks: vec![WL_SOCKET_READABLE], + } + } + + /// Leaves all readable channels at the beginning of the array. + fn sort_readable(&mut self) -> usize { + let mut cnt = 1; + for i in 1..self.chans.len() { + if self.masks[i] & WL_SOCKET_READABLE != 0 { + self.chans.swap(i, cnt); + self.sk_ptrs.swap(i, cnt); + self.masks.swap(i, cnt); + cnt += 1; + } + } + cnt + } + + fn update_event_set(&mut self, conn: &SafekeeperConn, event_mask: u32) { + let index = self + .sk_ptrs + .iter() + .position(|&ptr| ptr == conn.raw_ptr) + .expect("safekeeper should exist in event set"); + self.masks[index] = event_mask; + } + + fn add_safekeeper(&mut self, sk: &SafekeeperConn, event_mask: u32) { + for ptr in self.sk_ptrs.iter() { + assert!(*ptr != sk.raw_ptr); + } + + self.chans.push(Box::new( + sk.socket + .as_ref() + .expect("socket should not be closed") + .recv_chan(), + )); + self.sk_ptrs.push(sk.raw_ptr); + self.masks.push(event_mask); + } + + fn remove_safekeeper(&mut self, sk: &SafekeeperConn) { + let index = self.sk_ptrs.iter().position(|&ptr| ptr == sk.raw_ptr); + if index.is_none() { + debug!("remove_safekeeper: sk={:?} not found", sk.raw_ptr); + return; + } + let index = index.unwrap(); + + self.chans.remove(index); + self.sk_ptrs.remove(index); + self.masks.remove(index); + + // to simulate the actual behaviour + self.refresh_event_set(); + } + + /// Updates all masks to match the result of a SafekeeperStateDesiredEvents. + fn refresh_event_set(&mut self) { + for (i, mask) in self.masks.iter_mut().enumerate() { + if i == 0 { + continue; + } + + let mut mask_sk: u32 = 0; + let mut mask_nwr: u32 = 0; + unsafe { SafekeeperStateDesiredEvents(self.sk_ptrs[i], &mut mask_sk, &mut mask_nwr) }; + + if mask_sk != *mask { + debug!( + "refresh_event_set: sk={:?}, old_mask={:#b}, new_mask={:#b}", + self.sk_ptrs[i], *mask, mask_sk + ); + *mask = mask_sk; + } + } + } + + /// Wait for events on all channels. 
+ fn wait(&mut self, timeout_millis: i64) -> walproposer::walproposer::WaitResult { + // all channels are always writeable + for (i, mask) in self.masks.iter().enumerate() { + if *mask & WL_SOCKET_WRITEABLE != 0 { + return walproposer::walproposer::WaitResult::Network( + self.sk_ptrs[i], + WL_SOCKET_WRITEABLE, + ); + } + } + + let cnt = self.sort_readable(); + + let slice = &self.chans[0..cnt]; + match executor::epoll_chans(slice, timeout_millis) { + None => walproposer::walproposer::WaitResult::Timeout, + Some(0) => { + let msg = self.os.node_events().must_recv(); + match msg { + NodeEvent::Internal(AnyMessage::Just32(0)) => { + // got a notification about new WAL available + } + NodeEvent::Internal(_) => unreachable!(), + NodeEvent::Accept(_) => unreachable!(), + } + walproposer::walproposer::WaitResult::Latch + } + Some(index) => walproposer::walproposer::WaitResult::Network( + self.sk_ptrs[index], + WL_SOCKET_READABLE, + ), + } + } +} + +/// This struct handles all calls from walproposer into walproposer_api. +pub struct SimulationApi { + os: NodeOs, + safekeepers: RefCell>, + disk: Arc, + redo_start_lsn: Option, + shmem: UnsafeCell, + config: Config, + event_set: RefCell>, +} + +pub struct Args { + pub os: NodeOs, + pub config: Config, + pub disk: Arc, + pub redo_start_lsn: Option, +} + +impl SimulationApi { + pub fn new(args: Args) -> Self { + // initialize connection state for each safekeeper + let sk_conns = args + .config + .safekeepers_list + .iter() + .map(|s| { + SafekeeperConn::new( + s.split(':').next().unwrap().to_string(), + s.split(':').nth(1).unwrap().to_string(), + ) + }) + .collect::>(); + + Self { + os: args.os, + safekeepers: RefCell::new(sk_conns), + disk: args.disk, + redo_start_lsn: args.redo_start_lsn, + shmem: UnsafeCell::new(walproposer::bindings::WalproposerShmemState { + mutex: 0, + feedback: PageserverFeedback { + currentClusterSize: 0, + last_received_lsn: 0, + disk_consistent_lsn: 0, + remote_consistent_lsn: 0, + replytime: 0, + }, + mineLastElectedTerm: 0, + backpressureThrottlingTime: pg_atomic_uint64 { value: 0 }, + }), + config: args.config, + event_set: RefCell::new(None), + } + } + + /// Get SafekeeperConn for the given Safekeeper. 
+ fn get_conn(&self, sk: &mut walproposer::bindings::Safekeeper) -> RefMut<'_, SafekeeperConn> { + let sk_port = unsafe { CStr::from_ptr(sk.port).to_str().unwrap() }; + let state = self.safekeepers.borrow_mut(); + RefMut::map(state, |v| { + v.iter_mut() + .find(|conn| conn.port == sk_port) + .expect("safekeeper conn not found by port") + }) + } +} + +impl ApiImpl for SimulationApi { + fn get_current_timestamp(&self) -> i64 { + debug!("get_current_timestamp"); + // PG TimestampTZ is microseconds, but simulation unit is assumed to be + // milliseconds, so add 10^3 + self.os.now() as i64 * 1000 + } + + fn conn_status( + &self, + _: &mut walproposer::bindings::Safekeeper, + ) -> walproposer::bindings::WalProposerConnStatusType { + debug!("conn_status"); + // break the connection with a 10% chance + if self.os.random(100) < 10 { + walproposer::bindings::WalProposerConnStatusType_WP_CONNECTION_BAD + } else { + walproposer::bindings::WalProposerConnStatusType_WP_CONNECTION_OK + } + } + + fn conn_connect_start(&self, sk: &mut walproposer::bindings::Safekeeper) { + debug!("conn_connect_start"); + let mut conn = self.get_conn(sk); + + assert!(conn.socket.is_none()); + let socket = self.os.open_tcp(conn.node_id); + conn.socket = Some(socket); + conn.raw_ptr = sk; + conn.is_connecting = true; + } + + fn conn_connect_poll( + &self, + _: &mut walproposer::bindings::Safekeeper, + ) -> walproposer::bindings::WalProposerConnectPollStatusType { + debug!("conn_connect_poll"); + // TODO: break the connection here + walproposer::bindings::WalProposerConnectPollStatusType_WP_CONN_POLLING_OK + } + + fn conn_send_query(&self, sk: &mut walproposer::bindings::Safekeeper, query: &str) -> bool { + debug!("conn_send_query: {}", query); + self.get_conn(sk).is_start_wal_push = true; + true + } + + fn conn_get_query_result( + &self, + _: &mut walproposer::bindings::Safekeeper, + ) -> walproposer::bindings::WalProposerExecStatusType { + debug!("conn_get_query_result"); + // TODO: break the connection here + walproposer::bindings::WalProposerExecStatusType_WP_EXEC_SUCCESS_COPYBOTH + } + + fn conn_async_read( + &self, + sk: &mut walproposer::bindings::Safekeeper, + vec: &mut Vec, + ) -> walproposer::bindings::PGAsyncReadResult { + debug!("conn_async_read"); + let mut conn = self.get_conn(sk); + + let socket = if let Some(socket) = conn.socket.as_mut() { + socket + } else { + // socket is already closed + return walproposer::bindings::PGAsyncReadResult_PG_ASYNC_READ_FAIL; + }; + + let msg = socket.recv_chan().try_recv(); + + match msg { + None => { + // no message is ready + walproposer::bindings::PGAsyncReadResult_PG_ASYNC_READ_TRY_AGAIN + } + Some(NetEvent::Closed) => { + // connection is closed + debug!("conn_async_read: connection is closed"); + conn.socket = None; + walproposer::bindings::PGAsyncReadResult_PG_ASYNC_READ_FAIL + } + Some(NetEvent::Message(msg)) => { + // got a message + let b = match msg { + desim::proto::AnyMessage::Bytes(b) => b, + _ => unreachable!(), + }; + vec.extend_from_slice(&b); + walproposer::bindings::PGAsyncReadResult_PG_ASYNC_READ_SUCCESS + } + } + } + + fn conn_blocking_write(&self, sk: &mut walproposer::bindings::Safekeeper, buf: &[u8]) -> bool { + let mut conn = self.get_conn(sk); + debug!("conn_blocking_write to {}: {:?}", conn.node_id, buf); + let socket = conn.socket.as_mut().unwrap(); + socket.send(desim::proto::AnyMessage::Bytes(Bytes::copy_from_slice(buf))); + true + } + + fn conn_async_write( + &self, + sk: &mut walproposer::bindings::Safekeeper, + buf: &[u8], + ) -> 
walproposer::bindings::PGAsyncWriteResult { + let mut conn = self.get_conn(sk); + debug!("conn_async_write to {}: {:?}", conn.node_id, buf); + if let Some(socket) = conn.socket.as_mut() { + socket.send(desim::proto::AnyMessage::Bytes(Bytes::copy_from_slice(buf))); + } else { + // connection is already closed + debug!("conn_async_write: writing to a closed socket!"); + // TODO: maybe we should return error here? + } + walproposer::bindings::PGAsyncWriteResult_PG_ASYNC_WRITE_SUCCESS + } + + fn wal_reader_allocate(&self, _: &mut walproposer::bindings::Safekeeper) -> NeonWALReadResult { + debug!("wal_reader_allocate"); + walproposer::bindings::NeonWALReadResult_NEON_WALREAD_SUCCESS + } + + fn wal_read( + &self, + _sk: &mut walproposer::bindings::Safekeeper, + buf: &mut [u8], + startpos: u64, + ) -> NeonWALReadResult { + self.disk.lock().read(startpos, buf); + walproposer::bindings::NeonWALReadResult_NEON_WALREAD_SUCCESS + } + + fn init_event_set(&self, _: &mut walproposer::bindings::WalProposer) { + debug!("init_event_set"); + let new_event_set = EventSet::new(self.os.clone()); + let old_event_set = self.event_set.replace(Some(new_event_set)); + assert!(old_event_set.is_none()); + } + + fn update_event_set(&self, sk: &mut walproposer::bindings::Safekeeper, event_mask: u32) { + debug!( + "update_event_set, sk={:?}, events_mask={:#b}", + sk as *mut walproposer::bindings::Safekeeper, event_mask + ); + let conn = self.get_conn(sk); + + self.event_set + .borrow_mut() + .as_mut() + .unwrap() + .update_event_set(&conn, event_mask); + } + + fn add_safekeeper_event_set( + &self, + sk: &mut walproposer::bindings::Safekeeper, + event_mask: u32, + ) { + debug!( + "add_safekeeper_event_set, sk={:?}, events_mask={:#b}", + sk as *mut walproposer::bindings::Safekeeper, event_mask + ); + + self.event_set + .borrow_mut() + .as_mut() + .unwrap() + .add_safekeeper(&self.get_conn(sk), event_mask); + } + + fn rm_safekeeper_event_set(&self, sk: &mut walproposer::bindings::Safekeeper) { + debug!( + "rm_safekeeper_event_set, sk={:?}", + sk as *mut walproposer::bindings::Safekeeper, + ); + + self.event_set + .borrow_mut() + .as_mut() + .unwrap() + .remove_safekeeper(&self.get_conn(sk)); + } + + fn active_state_update_event_set(&self, sk: &mut walproposer::bindings::Safekeeper) { + debug!("active_state_update_event_set"); + + assert!(sk.state == walproposer::bindings::SafekeeperState_SS_ACTIVE); + self.event_set + .borrow_mut() + .as_mut() + .unwrap() + .refresh_event_set(); + } + + fn wal_reader_events(&self, _sk: &mut walproposer::bindings::Safekeeper) -> u32 { + 0 + } + + fn wait_event_set( + &self, + _: &mut walproposer::bindings::WalProposer, + timeout_millis: i64, + ) -> walproposer::walproposer::WaitResult { + // TODO: handle multiple stages as part of the simulation (e.g. 
connect, start_wal_push, etc) + let mut conns = self.safekeepers.borrow_mut(); + for conn in conns.iter_mut() { + if conn.socket.is_some() && conn.is_connecting { + conn.is_connecting = false; + debug!("wait_event_set, connecting to {}:{}", conn.host, conn.port); + return walproposer::walproposer::WaitResult::Network( + conn.raw_ptr, + WL_SOCKET_READABLE | WL_SOCKET_WRITEABLE, + ); + } + if conn.socket.is_some() && conn.is_start_wal_push { + conn.is_start_wal_push = false; + debug!( + "wait_event_set, start wal push to {}:{}", + conn.host, conn.port + ); + return walproposer::walproposer::WaitResult::Network( + conn.raw_ptr, + WL_SOCKET_READABLE, + ); + } + } + drop(conns); + + let res = self + .event_set + .borrow_mut() + .as_mut() + .unwrap() + .wait(timeout_millis); + + debug!( + "wait_event_set, timeout_millis={}, res={:?}", + timeout_millis, res, + ); + res + } + + fn strong_random(&self, buf: &mut [u8]) -> bool { + debug!("strong_random"); + buf.fill(0); + true + } + + fn finish_sync_safekeepers(&self, lsn: u64) { + debug!("finish_sync_safekeepers, lsn={}", lsn); + executor::exit(0, Lsn(lsn).to_string()); + } + + fn log_internal(&self, _wp: &mut walproposer::bindings::WalProposer, level: Level, msg: &str) { + debug!("wp_log[{}] {}", level, msg); + if level == Level::Fatal || level == Level::Panic { + if msg.contains("rejects our connection request with term") { + // collected quorum with lower term, then got rejected by next connected safekeeper + executor::exit(1, msg.to_owned()); + } + if msg.contains("collected propEpochStartLsn") && msg.contains(", but basebackup LSN ") + { + // sync-safekeepers collected wrong quorum, walproposer collected another quorum + executor::exit(1, msg.to_owned()); + } + if msg.contains("failed to download WAL for logical replicaiton") { + // Recovery connection broken and recovery was failed + executor::exit(1, msg.to_owned()); + } + if msg.contains("missing majority of votes, collected") { + // Voting bug when safekeeper disconnects after voting + executor::exit(1, msg.to_owned()); + } + panic!("unknown FATAL error from walproposer: {}", msg); + } + } + + fn after_election(&self, wp: &mut walproposer::bindings::WalProposer) { + let prop_lsn = wp.propEpochStartLsn; + let prop_term = wp.propTerm; + + let mut prev_lsn: u64 = 0; + let mut prev_term: u64 = 0; + + unsafe { + let history = wp.propTermHistory.entries; + let len = wp.propTermHistory.n_entries as usize; + if len > 1 { + let entry = *history.wrapping_add(len - 2); + prev_lsn = entry.lsn; + prev_term = entry.term; + } + } + + let msg = format!( + "prop_elected;{};{};{};{}", + prop_lsn, prop_term, prev_lsn, prev_term + ); + + debug!(msg); + self.os.log_event(msg); + } + + fn get_redo_start_lsn(&self) -> u64 { + debug!("get_redo_start_lsn -> {:?}", self.redo_start_lsn); + self.redo_start_lsn.expect("redo_start_lsn is not set").0 + } + + fn get_shmem_state(&self) -> *mut walproposer::bindings::WalproposerShmemState { + self.shmem.get() + } + + fn start_streaming( + &self, + startpos: u64, + callback: &walproposer::walproposer::StreamingCallback, + ) { + let disk = &self.disk; + let disk_lsn = disk.lock().flush_rec_ptr().0; + debug!("start_streaming at {} (disk_lsn={})", startpos, disk_lsn); + if startpos < disk_lsn { + debug!("startpos < disk_lsn, it means we wrote some transaction even before streaming started"); + } + assert!(startpos <= disk_lsn); + let mut broadcasted = Lsn(startpos); + + loop { + let available = disk.lock().flush_rec_ptr(); + assert!(available >= broadcasted); + 
callback.broadcast(broadcasted, available); + broadcasted = available; + callback.poll(); + } + } + + fn process_safekeeper_feedback( + &self, + wp: &mut walproposer::bindings::WalProposer, + commit_lsn: u64, + ) { + debug!("process_safekeeper_feedback, commit_lsn={}", commit_lsn); + if commit_lsn > wp.lastSentCommitLsn { + self.os.log_event(format!("commit_lsn;{}", commit_lsn)); + } + } + + fn get_flush_rec_ptr(&self) -> u64 { + let lsn = self.disk.lock().flush_rec_ptr(); + debug!("get_flush_rec_ptr: {}", lsn); + lsn.0 + } + + fn recovery_download( + &self, + wp: &mut walproposer::bindings::WalProposer, + sk: &mut walproposer::bindings::Safekeeper, + ) -> bool { + let mut startpos = wp.truncateLsn; + let endpos = wp.propEpochStartLsn; + + if startpos == endpos { + debug!("recovery_download: nothing to download"); + return true; + } + + debug!("recovery_download from {} to {}", startpos, endpos,); + + let replication_prompt = format!( + "START_REPLICATION {} {} {} {}", + self.config.ttid.tenant_id, self.config.ttid.timeline_id, startpos, endpos, + ); + let async_conn = self.get_conn(sk); + + let conn = self.os.open_tcp(async_conn.node_id); + conn.send(desim::proto::AnyMessage::Bytes(replication_prompt.into())); + + let chan = conn.recv_chan(); + while startpos < endpos { + let event = chan.recv(); + match event { + NetEvent::Closed => { + debug!("connection closed in recovery"); + break; + } + NetEvent::Message(AnyMessage::Bytes(b)) => { + debug!("got recovery bytes from safekeeper"); + self.disk.lock().write(startpos, &b); + startpos += b.len() as u64; + } + NetEvent::Message(_) => unreachable!(), + } + } + + debug!("recovery finished at {}", startpos); + + startpos == endpos + } + + fn conn_finish(&self, sk: &mut walproposer::bindings::Safekeeper) { + let mut conn = self.get_conn(sk); + debug!("conn_finish to {}", conn.node_id); + if let Some(socket) = conn.socket.as_mut() { + socket.close(); + } else { + // connection is already closed + } + conn.socket = None; + } + + fn conn_error_message(&self, _sk: &mut walproposer::bindings::Safekeeper) -> String { + "connection is closed, probably".into() + } +} diff --git a/safekeeper/tests/walproposer_sim/walproposer_disk.rs b/safekeeper/tests/walproposer_sim/walproposer_disk.rs new file mode 100644 index 0000000000..aa329bd2f0 --- /dev/null +++ b/safekeeper/tests/walproposer_sim/walproposer_disk.rs @@ -0,0 +1,314 @@ +use std::{ffi::CString, sync::Arc}; + +use byteorder::{LittleEndian, WriteBytesExt}; +use crc32c::crc32c_append; +use parking_lot::{Mutex, MutexGuard}; +use postgres_ffi::{ + pg_constants::{ + RM_LOGICALMSG_ID, XLOG_LOGICAL_MESSAGE, XLP_LONG_HEADER, XLR_BLOCK_ID_DATA_LONG, + XLR_BLOCK_ID_DATA_SHORT, + }, + v16::{ + wal_craft_test_export::{XLogLongPageHeaderData, XLogPageHeaderData, XLOG_PAGE_MAGIC}, + xlog_utils::{ + XLogSegNoOffsetToRecPtr, XlLogicalMessage, XLOG_RECORD_CRC_OFFS, + XLOG_SIZE_OF_XLOG_LONG_PHD, XLOG_SIZE_OF_XLOG_RECORD, XLOG_SIZE_OF_XLOG_SHORT_PHD, + XLP_FIRST_IS_CONTRECORD, + }, + XLogRecord, + }, + WAL_SEGMENT_SIZE, XLOG_BLCKSZ, +}; +use utils::lsn::Lsn; + +use super::block_storage::BlockStorage; + +/// Simulation implementation of walproposer WAL storage. 
+pub struct DiskWalProposer { + state: Mutex, +} + +impl DiskWalProposer { + pub fn new() -> Arc { + Arc::new(DiskWalProposer { + state: Mutex::new(State { + internal_available_lsn: Lsn(0), + prev_lsn: Lsn(0), + disk: BlockStorage::new(), + }), + }) + } + + pub fn lock(&self) -> MutexGuard { + self.state.lock() + } +} + +pub struct State { + // flush_lsn + internal_available_lsn: Lsn, + // needed for WAL generation + prev_lsn: Lsn, + // actual WAL storage + disk: BlockStorage, +} + +impl State { + pub fn read(&self, pos: u64, buf: &mut [u8]) { + self.disk.read(pos, buf); + // TODO: fail on reading uninitialized data + } + + pub fn write(&mut self, pos: u64, buf: &[u8]) { + self.disk.write(pos, buf); + } + + /// Update the internal available LSN to the given value. + pub fn reset_to(&mut self, lsn: Lsn) { + self.internal_available_lsn = lsn; + } + + /// Get current LSN. + pub fn flush_rec_ptr(&self) -> Lsn { + self.internal_available_lsn + } + + /// Generate a new WAL record at the current LSN. + pub fn insert_logical_message(&mut self, prefix: &str, msg: &[u8]) -> anyhow::Result<()> { + let prefix_cstr = CString::new(prefix)?; + let prefix_bytes = prefix_cstr.as_bytes_with_nul(); + + let lm = XlLogicalMessage { + db_id: 0, + transactional: 0, + prefix_size: prefix_bytes.len() as ::std::os::raw::c_ulong, + message_size: msg.len() as ::std::os::raw::c_ulong, + }; + + let record_bytes = lm.encode(); + let rdatas: Vec<&[u8]> = vec![&record_bytes, prefix_bytes, msg]; + insert_wal_record(self, rdatas, RM_LOGICALMSG_ID, XLOG_LOGICAL_MESSAGE) + } +} + +fn insert_wal_record( + state: &mut State, + rdatas: Vec<&[u8]>, + rmid: u8, + info: u8, +) -> anyhow::Result<()> { + // bytes right after the header, in the same rdata block + let mut scratch = Vec::new(); + let mainrdata_len: usize = rdatas.iter().map(|rdata| rdata.len()).sum(); + + if mainrdata_len > 0 { + if mainrdata_len > 255 { + scratch.push(XLR_BLOCK_ID_DATA_LONG); + // TODO: verify endiness + let _ = scratch.write_u32::(mainrdata_len as u32); + } else { + scratch.push(XLR_BLOCK_ID_DATA_SHORT); + scratch.push(mainrdata_len as u8); + } + } + + let total_len: u32 = (XLOG_SIZE_OF_XLOG_RECORD + scratch.len() + mainrdata_len) as u32; + let size = maxalign(total_len); + assert!(size as usize > XLOG_SIZE_OF_XLOG_RECORD); + + let start_bytepos = recptr_to_bytepos(state.internal_available_lsn); + let end_bytepos = start_bytepos + size as u64; + + let start_recptr = bytepos_to_recptr(start_bytepos); + let end_recptr = bytepos_to_recptr(end_bytepos); + + assert!(recptr_to_bytepos(start_recptr) == start_bytepos); + assert!(recptr_to_bytepos(end_recptr) == end_bytepos); + + let mut crc = crc32c_append(0, &scratch); + for rdata in &rdatas { + crc = crc32c_append(crc, rdata); + } + + let mut header = XLogRecord { + xl_tot_len: total_len, + xl_xid: 0, + xl_prev: state.prev_lsn.0, + xl_info: info, + xl_rmid: rmid, + __bindgen_padding_0: [0u8; 2usize], + xl_crc: crc, + }; + + // now we have the header and can finish the crc + let header_bytes = header.encode()?; + let crc = crc32c_append(crc, &header_bytes[0..XLOG_RECORD_CRC_OFFS]); + header.xl_crc = crc; + + let mut header_bytes = header.encode()?.to_vec(); + assert!(header_bytes.len() == XLOG_SIZE_OF_XLOG_RECORD); + + header_bytes.extend_from_slice(&scratch); + + // finish rdatas + let mut rdatas = rdatas; + rdatas.insert(0, &header_bytes); + + write_walrecord_to_disk(state, total_len as u64, rdatas, start_recptr, end_recptr)?; + + state.internal_available_lsn = end_recptr; + state.prev_lsn = start_recptr; 
+ Ok(()) +} + +fn write_walrecord_to_disk( + state: &mut State, + total_len: u64, + rdatas: Vec<&[u8]>, + start: Lsn, + end: Lsn, +) -> anyhow::Result<()> { + let mut curr_ptr = start; + let mut freespace = insert_freespace(curr_ptr); + let mut written: usize = 0; + + assert!(freespace >= std::mem::size_of::()); + + for mut rdata in rdatas { + while rdata.len() >= freespace { + assert!( + curr_ptr.segment_offset(WAL_SEGMENT_SIZE) >= XLOG_SIZE_OF_XLOG_SHORT_PHD + || freespace == 0 + ); + + state.write(curr_ptr.0, &rdata[..freespace]); + rdata = &rdata[freespace..]; + written += freespace; + curr_ptr = Lsn(curr_ptr.0 + freespace as u64); + + let mut new_page = XLogPageHeaderData { + xlp_magic: XLOG_PAGE_MAGIC as u16, + xlp_info: XLP_BKP_REMOVABLE, + xlp_tli: 1, + xlp_pageaddr: curr_ptr.0, + xlp_rem_len: (total_len - written as u64) as u32, + ..Default::default() // Put 0 in padding fields. + }; + if new_page.xlp_rem_len > 0 { + new_page.xlp_info |= XLP_FIRST_IS_CONTRECORD; + } + + if curr_ptr.segment_offset(WAL_SEGMENT_SIZE) == 0 { + new_page.xlp_info |= XLP_LONG_HEADER; + let long_page = XLogLongPageHeaderData { + std: new_page, + xlp_sysid: 0, + xlp_seg_size: WAL_SEGMENT_SIZE as u32, + xlp_xlog_blcksz: XLOG_BLCKSZ as u32, + }; + let header_bytes = long_page.encode()?; + assert!(header_bytes.len() == XLOG_SIZE_OF_XLOG_LONG_PHD); + state.write(curr_ptr.0, &header_bytes); + curr_ptr = Lsn(curr_ptr.0 + header_bytes.len() as u64); + } else { + let header_bytes = new_page.encode()?; + assert!(header_bytes.len() == XLOG_SIZE_OF_XLOG_SHORT_PHD); + state.write(curr_ptr.0, &header_bytes); + curr_ptr = Lsn(curr_ptr.0 + header_bytes.len() as u64); + } + freespace = insert_freespace(curr_ptr); + } + + assert!( + curr_ptr.segment_offset(WAL_SEGMENT_SIZE) >= XLOG_SIZE_OF_XLOG_SHORT_PHD + || rdata.is_empty() + ); + state.write(curr_ptr.0, rdata); + curr_ptr = Lsn(curr_ptr.0 + rdata.len() as u64); + written += rdata.len(); + freespace -= rdata.len(); + } + + assert!(written == total_len as usize); + curr_ptr.0 = maxalign(curr_ptr.0); + assert!(curr_ptr == end); + Ok(()) +} + +fn maxalign(size: T) -> T +where + T: std::ops::BitAnd + + std::ops::Add + + std::ops::Not + + From, +{ + (size + T::from(7)) & !T::from(7) +} + +fn insert_freespace(ptr: Lsn) -> usize { + if ptr.block_offset() == 0 { + 0 + } else { + (XLOG_BLCKSZ as u64 - ptr.block_offset()) as usize + } +} + +const XLP_BKP_REMOVABLE: u16 = 0x0004; +const USABLE_BYTES_IN_PAGE: u64 = (XLOG_BLCKSZ - XLOG_SIZE_OF_XLOG_SHORT_PHD) as u64; +const USABLE_BYTES_IN_SEGMENT: u64 = ((WAL_SEGMENT_SIZE / XLOG_BLCKSZ) as u64 + * USABLE_BYTES_IN_PAGE) + - (XLOG_SIZE_OF_XLOG_RECORD - XLOG_SIZE_OF_XLOG_SHORT_PHD) as u64; + +fn bytepos_to_recptr(bytepos: u64) -> Lsn { + let fullsegs = bytepos / USABLE_BYTES_IN_SEGMENT; + let mut bytesleft = bytepos % USABLE_BYTES_IN_SEGMENT; + + let seg_offset = if bytesleft < (XLOG_BLCKSZ - XLOG_SIZE_OF_XLOG_SHORT_PHD) as u64 { + // fits on first page of segment + bytesleft + XLOG_SIZE_OF_XLOG_SHORT_PHD as u64 + } else { + // account for the first page on segment with long header + bytesleft -= (XLOG_BLCKSZ - XLOG_SIZE_OF_XLOG_SHORT_PHD) as u64; + let fullpages = bytesleft / USABLE_BYTES_IN_PAGE; + bytesleft %= USABLE_BYTES_IN_PAGE; + + XLOG_BLCKSZ as u64 + + fullpages * XLOG_BLCKSZ as u64 + + bytesleft + + XLOG_SIZE_OF_XLOG_SHORT_PHD as u64 + }; + + Lsn(XLogSegNoOffsetToRecPtr( + fullsegs, + seg_offset as u32, + WAL_SEGMENT_SIZE, + )) +} + +fn recptr_to_bytepos(ptr: Lsn) -> u64 { + let fullsegs = 
ptr.segment_number(WAL_SEGMENT_SIZE); + let offset = ptr.segment_offset(WAL_SEGMENT_SIZE) as u64; + + let fullpages = offset / XLOG_BLCKSZ as u64; + let offset = offset % XLOG_BLCKSZ as u64; + + if fullpages == 0 { + fullsegs * USABLE_BYTES_IN_SEGMENT + + if offset > 0 { + assert!(offset >= XLOG_SIZE_OF_XLOG_SHORT_PHD as u64); + offset - XLOG_SIZE_OF_XLOG_SHORT_PHD as u64 + } else { + 0 + } + } else { + fullsegs * USABLE_BYTES_IN_SEGMENT + + (XLOG_BLCKSZ - XLOG_SIZE_OF_XLOG_SHORT_PHD) as u64 + + (fullpages - 1) * USABLE_BYTES_IN_PAGE + + if offset > 0 { + assert!(offset >= XLOG_SIZE_OF_XLOG_SHORT_PHD as u64); + offset - XLOG_SIZE_OF_XLOG_SHORT_PHD as u64 + } else { + 0 + } + } +} From a8eb4042baa6ca1ae4268a1f1b22a89941b0d942 Mon Sep 17 00:00:00 2001 From: John Spray Date: Tue, 13 Feb 2024 07:00:50 +0000 Subject: [PATCH 12/20] tests: test_secondary_mode_eviction: avoid use of mocked statvfs (#6698) ## Problem Test sometimes fails with `used_blocks > total_blocks`, because when using mocked statvfs with the total blocks set to the size of data on disk before starting, we are implicitly asserting that nothing at all can be written to disk between startup and calling statvfs. Related: https://github.com/neondatabase/neon/issues/6511 ## Summary of changes - Use HTTP API to invoke disk usage eviction instead of mocked statvfs --- .../regress/test_disk_usage_eviction.py | 33 +++---------------- 1 file changed, 5 insertions(+), 28 deletions(-) diff --git a/test_runner/regress/test_disk_usage_eviction.py b/test_runner/regress/test_disk_usage_eviction.py index 061c57c88b..eb4e370ea7 100644 --- a/test_runner/regress/test_disk_usage_eviction.py +++ b/test_runner/regress/test_disk_usage_eviction.py @@ -893,37 +893,14 @@ def test_secondary_mode_eviction(eviction_env_ha: EvictionEnv): # in its heatmap ps_secondary.http_client().tenant_secondary_download(tenant_id) - # Configure the secondary pageserver to have a phony small disk size - ps_secondary.stop() total_size, _, _ = env.timelines_du(ps_secondary) - blocksize = 512 - total_blocks = (total_size + (blocksize - 1)) // blocksize + evict_bytes = total_size // 3 - min_avail_bytes = total_size // 3 - - env.pageserver_start_with_disk_usage_eviction( - ps_secondary, - period="1s", - max_usage_pct=100, - min_avail_bytes=min_avail_bytes, - mock_behavior={ - "type": "Success", - "blocksize": blocksize, - "total_blocks": total_blocks, - # Only count layer files towards used bytes in the mock_statvfs. - # This avoids accounting for metadata files & tenant conf in the tests. - "name_filter": ".*__.*", - }, - eviction_order=EvictionOrder.ABSOLUTE_ORDER, - ) - - def relieved_log_message(): - assert ps_secondary.log_contains(".*disk usage pressure relieved") - - wait_until(10, 1, relieved_log_message) + response = ps_secondary.http_client().disk_usage_eviction_run({"evict_bytes": evict_bytes}) + log.info(f"{response}") post_eviction_total_size, _, _ = env.timelines_du(ps_secondary) assert ( - total_size - post_eviction_total_size >= min_avail_bytes - ), "we requested at least min_avail_bytes worth of free space" + total_size - post_eviction_total_size >= evict_bytes + ), "we requested at least evict_bytes worth of free space" From 331935df91abe03a1e8a081bc96b6ef871f71bb1 Mon Sep 17 00:00:00 2001 From: Anna Khanova <32508607+khanova@users.noreply.github.com> Date: Tue, 13 Feb 2024 17:58:58 +0100 Subject: [PATCH 13/20] Proxy: send cancel notifications to all instances (#6719) ## Problem If cancel request ends up on the wrong proxy instance, it doesn't take an effect. 
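For illustration, the fix routes cancellations through a shared Redis pub/sub channel so that whichever proxy instance actually holds the cancel key can act on it. The sketch below is a simplified outline of that flow: the channel name matches the `PROXY_CHANNEL_NAME` constant added in the diff, the `redis` crate calls are the ones the diff itself uses, and the payload is left as an opaque serialized cancel message rather than the real `Notification` type.

```rust
use redis::AsyncCommands;

// Matches PROXY_CHANNEL_NAME introduced in proxy/src/redis/notifications.rs below.
const PROXY_CHANNEL: &str = "neondb-proxy-to-proxy-updates";

/// The instance that received the CancelRequest publishes it for every pod to see.
async fn broadcast_cancel(client: &redis::Client, payload: String) -> anyhow::Result<()> {
    let mut conn = client.get_async_connection().await?;
    let _: () = conn.publish(PROXY_CHANNEL, payload).await?;
    Ok(())
}

/// Every instance subscribes to the same channel and applies the cancel keys it holds locally.
async fn subscribe_cancels(client: &redis::Client) -> anyhow::Result<redis::aio::PubSub> {
    let mut pubsub = client.get_async_connection().await?.into_pubsub();
    pubsub.subscribe(PROXY_CHANNEL).await?;
    Ok(pubsub)
}
```
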
## Summary of changes Send redis notifications to all proxy pods about the cancel request. Related issue: https://github.com/neondatabase/neon/issues/5839, https://github.com/neondatabase/cloud/issues/10262 --- Cargo.lock | 7 +- Cargo.toml | 2 +- libs/pq_proto/Cargo.toml | 1 + libs/pq_proto/src/lib.rs | 3 +- proxy/src/bin/proxy.rs | 32 ++++- proxy/src/cancellation.rs | 109 ++++++++++++++--- proxy/src/config.rs | 1 + proxy/src/metrics.rs | 9 ++ proxy/src/proxy.rs | 16 +-- proxy/src/rate_limiter.rs | 2 +- proxy/src/rate_limiter/limiter.rs | 38 ++++++ proxy/src/redis.rs | 1 + proxy/src/redis/notifications.rs | 197 +++++++++++++++++++++++------- proxy/src/redis/publisher.rs | 80 ++++++++++++ proxy/src/serverless.rs | 13 +- proxy/src/serverless/websocket.rs | 6 +- workspace_hack/Cargo.toml | 4 +- 17 files changed, 432 insertions(+), 89 deletions(-) create mode 100644 proxy/src/redis/publisher.rs diff --git a/Cargo.lock b/Cargo.lock index f11c774016..45a313a72b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2263,11 +2263,11 @@ dependencies = [ [[package]] name = "hashlink" -version = "0.8.2" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0761a1b9491c4f2e3d66aa0f62d0fba0af9a0e2852e4d48ea506632a4b56e6aa" +checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" dependencies = [ - "hashbrown 0.13.2", + "hashbrown 0.14.0", ] [[package]] @@ -3952,6 +3952,7 @@ dependencies = [ "pin-project-lite", "postgres-protocol", "rand 0.8.5", + "serde", "thiserror", "tokio", "tracing", diff --git a/Cargo.toml b/Cargo.toml index 8df9ca9988..8952f7627f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -81,7 +81,7 @@ futures-core = "0.3" futures-util = "0.3" git-version = "0.3" hashbrown = "0.13" -hashlink = "0.8.1" +hashlink = "0.8.4" hdrhistogram = "7.5.2" hex = "0.4" hex-literal = "0.4" diff --git a/libs/pq_proto/Cargo.toml b/libs/pq_proto/Cargo.toml index b286eb0358..6eeb3bafef 100644 --- a/libs/pq_proto/Cargo.toml +++ b/libs/pq_proto/Cargo.toml @@ -13,5 +13,6 @@ rand.workspace = true tokio.workspace = true tracing.workspace = true thiserror.workspace = true +serde.workspace = true workspace_hack.workspace = true diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index c52a21bcd3..522b65f5d1 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -7,6 +7,7 @@ pub mod framed; use byteorder::{BigEndian, ReadBytesExt}; use bytes::{Buf, BufMut, Bytes, BytesMut}; +use serde::{Deserialize, Serialize}; use std::{borrow::Cow, collections::HashMap, fmt, io, str}; // re-export for use in utils pageserver_feedback.rs @@ -123,7 +124,7 @@ impl StartupMessageParams { } } -#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] pub struct CancelKeyData { pub backend_pid: i32, pub cancel_key: i32, diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 00a229c135..b3d4fc0411 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -1,6 +1,8 @@ use futures::future::Either; use proxy::auth; use proxy::auth::backend::MaybeOwned; +use proxy::cancellation::CancelMap; +use proxy::cancellation::CancellationHandler; use proxy::config::AuthenticationConfig; use proxy::config::CacheOptions; use proxy::config::HttpConfig; @@ -12,6 +14,7 @@ use proxy::rate_limiter::EndpointRateLimiter; use proxy::rate_limiter::RateBucketInfo; use proxy::rate_limiter::RateLimiterConfig; use proxy::redis::notifications; +use proxy::redis::publisher::RedisPublisherClient; use 
proxy::serverless::GlobalConnPoolOptions; use proxy::usage_metrics; @@ -22,6 +25,7 @@ use std::net::SocketAddr; use std::pin::pin; use std::sync::Arc; use tokio::net::TcpListener; +use tokio::sync::Mutex; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; use tracing::info; @@ -129,6 +133,9 @@ struct ProxyCliArgs { /// Can be given multiple times for different bucket sizes. #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)] endpoint_rps_limit: Vec, + /// Redis rate limiter max number of requests per second. + #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)] + redis_rps_limit: Vec, /// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`. #[clap(long, default_value_t = 100)] initial_limit: usize, @@ -225,6 +232,19 @@ async fn main() -> anyhow::Result<()> { let cancellation_token = CancellationToken::new(); let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit)); + let cancel_map = CancelMap::default(); + let redis_publisher = match &args.redis_notifications { + Some(url) => Some(Arc::new(Mutex::new(RedisPublisherClient::new( + url, + args.region.clone(), + &config.redis_rps_limit, + )?))), + None => None, + }; + let cancellation_handler = Arc::new(CancellationHandler::new( + cancel_map.clone(), + redis_publisher, + )); // client facing tasks. these will exit on error or on cancellation // cancellation returns Ok(()) @@ -234,6 +254,7 @@ async fn main() -> anyhow::Result<()> { proxy_listener, cancellation_token.clone(), endpoint_rate_limiter.clone(), + cancellation_handler.clone(), )); // TODO: rename the argument to something like serverless. @@ -248,6 +269,7 @@ async fn main() -> anyhow::Result<()> { serverless_listener, cancellation_token.clone(), endpoint_rate_limiter.clone(), + cancellation_handler.clone(), )); } @@ -271,7 +293,12 @@ async fn main() -> anyhow::Result<()> { let cache = api.caches.project_info.clone(); if let Some(url) = args.redis_notifications { info!("Starting redis notifications listener ({url})"); - maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone())); + maintenance_tasks.spawn(notifications::task_main( + url.to_owned(), + cache.clone(), + cancel_map.clone(), + args.region.clone(), + )); } maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); } @@ -403,6 +430,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { let mut endpoint_rps_limit = args.endpoint_rps_limit.clone(); RateBucketInfo::validate(&mut endpoint_rps_limit)?; + let mut redis_rps_limit = args.redis_rps_limit.clone(); + RateBucketInfo::validate(&mut redis_rps_limit)?; let config = Box::leak(Box::new(ProxyConfig { tls_config, @@ -414,6 +443,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { require_client_ip: args.require_client_ip, disable_ip_check_for_http: args.disable_ip_check_for_http, endpoint_rps_limit, + redis_rps_limit, handshake_timeout: args.handshake_timeout, // TODO: add this argument region: args.region.clone(), diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index fe614628d8..93a77bc4ae 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -1,16 +1,28 @@ +use async_trait::async_trait; use dashmap::DashMap; use pq_proto::CancelKeyData; use std::{net::SocketAddr, sync::Arc}; use thiserror::Error; use tokio::net::TcpStream; +use tokio::sync::Mutex; use tokio_postgres::{CancelToken, NoTls}; use tracing::info; +use uuid::Uuid; -use 
crate::error::ReportableError; +use crate::{ + error::ReportableError, metrics::NUM_CANCELLATION_REQUESTS, + redis::publisher::RedisPublisherClient, +}; + +pub type CancelMap = Arc>>; /// Enables serving `CancelRequest`s. -#[derive(Default)] -pub struct CancelMap(DashMap>); +/// +/// If there is a `RedisPublisherClient` available, it will be used to publish the cancellation key to other proxy instances. +pub struct CancellationHandler { + map: CancelMap, + redis_client: Option>>, +} #[derive(Debug, Error)] pub enum CancelError { @@ -32,15 +44,43 @@ impl ReportableError for CancelError { } } -impl CancelMap { +impl CancellationHandler { + pub fn new(map: CancelMap, redis_client: Option>>) -> Self { + Self { map, redis_client } + } /// Cancel a running query for the corresponding connection. - pub async fn cancel_session(&self, key: CancelKeyData) -> Result<(), CancelError> { + pub async fn cancel_session( + &self, + key: CancelKeyData, + session_id: Uuid, + ) -> Result<(), CancelError> { + let from = "from_client"; // NB: we should immediately release the lock after cloning the token. - let Some(cancel_closure) = self.0.get(&key).and_then(|x| x.clone()) else { + let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else { tracing::warn!("query cancellation key not found: {key}"); + if let Some(redis_client) = &self.redis_client { + NUM_CANCELLATION_REQUESTS + .with_label_values(&[from, "not_found"]) + .inc(); + info!("publishing cancellation key to Redis"); + match redis_client.lock().await.try_publish(key, session_id).await { + Ok(()) => { + info!("cancellation key successfuly published to Redis"); + } + Err(e) => { + tracing::error!("failed to publish a message: {e}"); + return Err(CancelError::IO(std::io::Error::new( + std::io::ErrorKind::Other, + e.to_string(), + ))); + } + } + } return Ok(()); }; - + NUM_CANCELLATION_REQUESTS + .with_label_values(&[from, "found"]) + .inc(); info!("cancelling query per user's request using key {key}"); cancel_closure.try_cancel_query().await } @@ -57,7 +97,7 @@ impl CancelMap { // Random key collisions are unlikely to happen here, but they're still possible, // which is why we have to take care not to rewrite an existing key. 
- match self.0.entry(key) { + match self.map.entry(key) { dashmap::mapref::entry::Entry::Occupied(_) => continue, dashmap::mapref::entry::Entry::Vacant(e) => { e.insert(None); @@ -69,18 +109,46 @@ impl CancelMap { info!("registered new query cancellation key {key}"); Session { key, - cancel_map: self, + cancellation_handler: self, } } #[cfg(test)] fn contains(&self, session: &Session) -> bool { - self.0.contains_key(&session.key) + self.map.contains_key(&session.key) } #[cfg(test)] fn is_empty(&self) -> bool { - self.0.is_empty() + self.map.is_empty() + } +} + +#[async_trait] +pub trait NotificationsCancellationHandler { + async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError>; +} + +#[async_trait] +impl NotificationsCancellationHandler for CancellationHandler { + async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError> { + let from = "from_redis"; + let cancel_closure = self.map.get(&key).and_then(|x| x.clone()); + match cancel_closure { + Some(cancel_closure) => { + NUM_CANCELLATION_REQUESTS + .with_label_values(&[from, "found"]) + .inc(); + cancel_closure.try_cancel_query().await + } + None => { + NUM_CANCELLATION_REQUESTS + .with_label_values(&[from, "not_found"]) + .inc(); + tracing::warn!("query cancellation key not found: {key}"); + Ok(()) + } + } } } @@ -115,7 +183,7 @@ pub struct Session { /// The user-facing key identifying this session. key: CancelKeyData, /// The [`CancelMap`] this session belongs to. - cancel_map: Arc, + cancellation_handler: Arc, } impl Session { @@ -123,7 +191,9 @@ impl Session { /// This enables query cancellation in `crate::proxy::prepare_client_connection`. pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData { info!("enabling query cancellation for this session"); - self.cancel_map.0.insert(self.key, Some(cancel_closure)); + self.cancellation_handler + .map + .insert(self.key, Some(cancel_closure)); self.key } @@ -131,7 +201,7 @@ impl Session { impl Drop for Session { fn drop(&mut self) { - self.cancel_map.0.remove(&self.key); + self.cancellation_handler.map.remove(&self.key); info!("dropped query cancellation key {}", &self.key); } } @@ -142,13 +212,16 @@ mod tests { #[tokio::test] async fn check_session_drop() -> anyhow::Result<()> { - let cancel_map: Arc = Default::default(); + let cancellation_handler = Arc::new(CancellationHandler { + map: CancelMap::default(), + redis_client: None, + }); - let session = cancel_map.clone().get_session(); - assert!(cancel_map.contains(&session)); + let session = cancellation_handler.clone().get_session(); + assert!(cancellation_handler.contains(&session)); drop(session); // Check that the session has been dropped. 
- assert!(cancel_map.is_empty()); + assert!(cancellation_handler.is_empty()); Ok(()) } diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 5fcb537834..9f276c3c24 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -21,6 +21,7 @@ pub struct ProxyConfig { pub require_client_ip: bool, pub disable_ip_check_for_http: bool, pub endpoint_rps_limit: Vec, + pub redis_rps_limit: Vec, pub region: String, pub handshake_timeout: Duration, } diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index f7f162a075..66031f5eb2 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -152,6 +152,15 @@ pub static NUM_OPEN_CLIENTS_IN_HTTP_POOL: Lazy = Lazy::new(|| { .unwrap() }); +pub static NUM_CANCELLATION_REQUESTS: Lazy = Lazy::new(|| { + register_int_counter_vec!( + "proxy_cancellation_requests_total", + "Number of cancellation requests (per found/not_found).", + &["source", "kind"], + ) + .unwrap() +}); + #[derive(Clone)] pub struct LatencyTimer { // time since the stopwatch was started diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 5f65de4c98..ce77098a5f 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -10,7 +10,7 @@ pub mod wake_compute; use crate::{ auth, - cancellation::{self, CancelMap}, + cancellation::{self, CancellationHandler}, compute, config::{ProxyConfig, TlsConfig}, context::RequestMonitoring, @@ -62,6 +62,7 @@ pub async fn task_main( listener: tokio::net::TcpListener, cancellation_token: CancellationToken, endpoint_rate_limiter: Arc, + cancellation_handler: Arc, ) -> anyhow::Result<()> { scopeguard::defer! { info!("proxy has shut down"); @@ -72,7 +73,6 @@ pub async fn task_main( socket2::SockRef::from(&listener).set_keepalive(true)?; let connections = tokio_util::task::task_tracker::TaskTracker::new(); - let cancel_map = Arc::new(CancelMap::default()); while let Some(accept_result) = run_until_cancelled(listener.accept(), &cancellation_token).await @@ -80,7 +80,7 @@ pub async fn task_main( let (socket, peer_addr) = accept_result?; let session_id = uuid::Uuid::new_v4(); - let cancel_map = Arc::clone(&cancel_map); + let cancellation_handler = Arc::clone(&cancellation_handler); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); let session_span = info_span!( @@ -113,7 +113,7 @@ pub async fn task_main( let res = handle_client( config, &mut ctx, - cancel_map, + cancellation_handler, socket, ClientMode::Tcp, endpoint_rate_limiter, @@ -227,7 +227,7 @@ impl ReportableError for ClientRequestError { pub async fn handle_client( config: &'static ProxyConfig, ctx: &mut RequestMonitoring, - cancel_map: Arc, + cancellation_handler: Arc, stream: S, mode: ClientMode, endpoint_rate_limiter: Arc, @@ -253,8 +253,8 @@ pub async fn handle_client( match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? { HandshakeData::Startup(stream, params) => (stream, params), HandshakeData::Cancel(cancel_key_data) => { - return Ok(cancel_map - .cancel_session(cancel_key_data) + return Ok(cancellation_handler + .cancel_session(cancel_key_data, ctx.session_id) .await .map(|()| None)?) 
} @@ -315,7 +315,7 @@ pub async fn handle_client( .or_else(|e| stream.throw_error(e)) .await?; - let session = cancel_map.get_session(); + let session = cancellation_handler.get_session(); prepare_client_connection(&node, &session, &mut stream).await?; // Before proxy passing, forward to compute whatever data is left in the diff --git a/proxy/src/rate_limiter.rs b/proxy/src/rate_limiter.rs index b26386d159..f0da4ead23 100644 --- a/proxy/src/rate_limiter.rs +++ b/proxy/src/rate_limiter.rs @@ -4,4 +4,4 @@ mod limiter; pub use aimd::Aimd; pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig}; pub use limiter::Limiter; -pub use limiter::{EndpointRateLimiter, RateBucketInfo}; +pub use limiter::{EndpointRateLimiter, RateBucketInfo, RedisRateLimiter}; diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index cbae72711c..3181060e2f 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -22,6 +22,44 @@ use super::{ RateLimiterConfig, }; +pub struct RedisRateLimiter { + data: Vec, + info: &'static [RateBucketInfo], +} + +impl RedisRateLimiter { + pub fn new(info: &'static [RateBucketInfo]) -> Self { + Self { + data: vec![ + RateBucket { + start: Instant::now(), + count: 0, + }; + info.len() + ], + info, + } + } + + /// Check that number of connections is below `max_rps` rps. + pub fn check(&mut self) -> bool { + let now = Instant::now(); + + let should_allow_request = self + .data + .iter_mut() + .zip(self.info) + .all(|(bucket, info)| bucket.should_allow_request(info, now)); + + if should_allow_request { + // only increment the bucket counts if the request will actually be accepted + self.data.iter_mut().for_each(RateBucket::inc); + } + + should_allow_request + } +} + // Simple per-endpoint rate limiter. // // Check that number of connections to the endpoint is below `max_rps` rps. 
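A self-contained sketch of the fixed-window bucket logic that `RedisRateLimiter::check` builds on (field names simplified; this is not the actual `RateBucketInfo` type, just the admission rule it encodes):

```rust
use std::time::{Duration, Instant};

// Simplified stand-in for one (RateBucket, RateBucketInfo) pair:
// allow at most `max` requests per `interval`.
struct Bucket {
    start: Instant,
    count: u32,
    interval: Duration,
    max: u32,
}

impl Bucket {
    fn should_allow_request(&mut self) -> bool {
        let now = Instant::now();
        if now.duration_since(self.start) >= self.interval {
            // Window expired: start a fresh one.
            self.start = now;
            self.count = 0;
        }
        self.count < self.max
    }

    fn inc(&mut self) {
        self.count += 1;
    }
}

// The limiter applies the same rule to several buckets (e.g. a per-second and a
// per-minute one) and increments them only if all of them would admit the request,
// matching the `check` implementation above.
fn check(buckets: &mut [Bucket]) -> bool {
    let allow = buckets.iter_mut().all(Bucket::should_allow_request);
    if allow {
        buckets.iter_mut().for_each(Bucket::inc);
    }
    allow
}
```
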
diff --git a/proxy/src/redis.rs b/proxy/src/redis.rs index c2a91bed97..35d6db074e 100644 --- a/proxy/src/redis.rs +++ b/proxy/src/redis.rs @@ -1 +1,2 @@ pub mod notifications; +pub mod publisher; diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 158884aa17..b8297a206c 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -1,38 +1,44 @@ use std::{convert::Infallible, sync::Arc}; use futures::StreamExt; +use pq_proto::CancelKeyData; use redis::aio::PubSub; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; use crate::{ cache::project_info::ProjectInfoCache, + cancellation::{CancelMap, CancellationHandler, NotificationsCancellationHandler}, intern::{ProjectIdInt, RoleNameInt}, }; -const CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; +const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; +pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates"; const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20); const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20); -struct ConsoleRedisClient { +struct RedisConsumerClient { client: redis::Client, } -impl ConsoleRedisClient { +impl RedisConsumerClient { pub fn new(url: &str) -> anyhow::Result { let client = redis::Client::open(url)?; Ok(Self { client }) } async fn try_connect(&self) -> anyhow::Result { let mut conn = self.client.get_async_connection().await?.into_pubsub(); - tracing::info!("subscribing to a channel `{CHANNEL_NAME}`"); - conn.subscribe(CHANNEL_NAME).await?; + tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`"); + conn.subscribe(CPLANE_CHANNEL_NAME).await?; + tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`"); + conn.subscribe(PROXY_CHANNEL_NAME).await?; Ok(conn) } } -#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] #[serde(tag = "topic", content = "data")] -enum Notification { +pub(crate) enum Notification { #[serde( rename = "/allowed_ips_updated", deserialize_with = "deserialize_json_string" @@ -45,16 +51,25 @@ enum Notification { deserialize_with = "deserialize_json_string" )] PasswordUpdate { password_update: PasswordUpdate }, + #[serde(rename = "/cancel_session")] + Cancel(CancelSession), } -#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] -struct AllowedIpsUpdate { +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +pub(crate) struct AllowedIpsUpdate { project_id: ProjectIdInt, } -#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] -struct PasswordUpdate { +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +pub(crate) struct PasswordUpdate { project_id: ProjectIdInt, role_name: RoleNameInt, } +#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +pub(crate) struct CancelSession { + pub region_id: Option, + pub cancel_key_data: CancelKeyData, + pub session_id: Uuid, +} + fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result where T: for<'de2> serde::Deserialize<'de2>, @@ -64,6 +79,88 @@ where serde_json::from_str(&s).map_err(::custom) } +struct MessageHandler< + C: ProjectInfoCache + Send + Sync + 'static, + H: NotificationsCancellationHandler + Send + Sync + 'static, +> { + cache: Arc, + cancellation_handler: Arc, + region_id: String, +} + +impl< + C: ProjectInfoCache + Send + Sync + 'static, + H: NotificationsCancellationHandler + Send + Sync + 'static, + > MessageHandler +{ + pub fn new(cache: Arc, 
cancellation_handler: Arc, region_id: String) -> Self { + Self { + cache, + cancellation_handler, + region_id, + } + } + pub fn disable_ttl(&self) { + self.cache.disable_ttl(); + } + pub fn enable_ttl(&self) { + self.cache.enable_ttl(); + } + #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))] + async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> { + use Notification::*; + let payload: String = msg.get_payload()?; + tracing::debug!(?payload, "received a message payload"); + + let msg: Notification = match serde_json::from_str(&payload) { + Ok(msg) => msg, + Err(e) => { + tracing::error!("broken message: {e}"); + return Ok(()); + } + }; + tracing::debug!(?msg, "received a message"); + match msg { + Cancel(cancel_session) => { + tracing::Span::current().record( + "session_id", + &tracing::field::display(cancel_session.session_id), + ); + if let Some(cancel_region) = cancel_session.region_id { + // If the message is not for this region, ignore it. + if cancel_region != self.region_id { + return Ok(()); + } + } + // This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message. + match self + .cancellation_handler + .cancel_session_no_publish(cancel_session.cancel_key_data) + .await + { + Ok(()) => {} + Err(e) => { + tracing::error!("failed to cancel session: {e}"); + } + } + } + _ => { + invalidate_cache(self.cache.clone(), msg.clone()); + // It might happen that the invalid entry is on the way to be cached. + // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds. + // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message. + let cache = self.cache.clone(); + tokio::spawn(async move { + tokio::time::sleep(INVALIDATION_LAG).await; + invalidate_cache(cache, msg); + }); + } + } + + Ok(()) + } +} + fn invalidate_cache(cache: Arc, msg: Notification) { use Notification::*; match msg { @@ -74,50 +171,33 @@ fn invalidate_cache(cache: Arc, msg: Notification) { password_update.project_id, password_update.role_name, ), + Cancel(_) => unreachable!("cancel message should be handled separately"), } } -#[tracing::instrument(skip(cache))] -fn handle_message(msg: redis::Msg, cache: Arc) -> anyhow::Result<()> -where - C: ProjectInfoCache + Send + Sync + 'static, -{ - let payload: String = msg.get_payload()?; - tracing::debug!(?payload, "received a message payload"); - - let msg: Notification = match serde_json::from_str(&payload) { - Ok(msg) => msg, - Err(e) => { - tracing::error!("broken message: {e}"); - return Ok(()); - } - }; - tracing::debug!(?msg, "received a message"); - invalidate_cache(cache.clone(), msg.clone()); - // It might happen that the invalid entry is on the way to be cached. - // To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds. - // TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message. - tokio::spawn(async move { - tokio::time::sleep(INVALIDATION_LAG).await; - invalidate_cache(cache, msg.clone()); - }); - - Ok(()) -} - /// Handle console's invalidation messages. 
#[tracing::instrument(name = "console_notifications", skip_all)] -pub async fn task_main(url: String, cache: Arc) -> anyhow::Result +pub async fn task_main( + url: String, + cache: Arc, + cancel_map: CancelMap, + region_id: String, +) -> anyhow::Result where C: ProjectInfoCache + Send + Sync + 'static, { cache.enable_ttl(); + let handler = MessageHandler::new( + cache, + Arc::new(CancellationHandler::new(cancel_map, None)), + region_id, + ); loop { - let redis = ConsoleRedisClient::new(&url)?; + let redis = RedisConsumerClient::new(&url)?; let conn = match redis.try_connect().await { Ok(conn) => { - cache.disable_ttl(); + handler.disable_ttl(); conn } Err(e) => { @@ -130,7 +210,7 @@ where }; let mut stream = conn.into_on_message(); while let Some(msg) = stream.next().await { - match handle_message(msg, cache.clone()) { + match handler.handle_message(msg).await { Ok(()) => {} Err(e) => { tracing::error!("failed to handle message: {e}, will try to reconnect"); @@ -138,7 +218,7 @@ where } } } - cache.enable_ttl(); + handler.enable_ttl(); } } @@ -198,6 +278,33 @@ mod tests { } ); + Ok(()) + } + #[test] + fn parse_cancel_session() -> anyhow::Result<()> { + let cancel_key_data = CancelKeyData { + backend_pid: 42, + cancel_key: 41, + }; + let uuid = uuid::Uuid::new_v4(); + let msg = Notification::Cancel(CancelSession { + cancel_key_data, + region_id: None, + session_id: uuid, + }); + let text = serde_json::to_string(&msg)?; + let result: Notification = serde_json::from_str(&text)?; + assert_eq!(msg, result); + + let msg = Notification::Cancel(CancelSession { + cancel_key_data, + region_id: Some("region".to_string()), + session_id: uuid, + }); + let text = serde_json::to_string(&msg)?; + let result: Notification = serde_json::from_str(&text)?; + assert_eq!(msg, result,); + Ok(()) } } diff --git a/proxy/src/redis/publisher.rs b/proxy/src/redis/publisher.rs new file mode 100644 index 0000000000..f85593afdd --- /dev/null +++ b/proxy/src/redis/publisher.rs @@ -0,0 +1,80 @@ +use pq_proto::CancelKeyData; +use redis::AsyncCommands; +use uuid::Uuid; + +use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter}; + +use super::notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME}; + +pub struct RedisPublisherClient { + client: redis::Client, + publisher: Option, + region_id: String, + limiter: RedisRateLimiter, +} + +impl RedisPublisherClient { + pub fn new( + url: &str, + region_id: String, + info: &'static [RateBucketInfo], + ) -> anyhow::Result { + let client = redis::Client::open(url)?; + Ok(Self { + client, + publisher: None, + region_id, + limiter: RedisRateLimiter::new(info), + }) + } + pub async fn try_publish( + &mut self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()> { + if !self.limiter.check() { + tracing::info!("Rate limit exceeded. Skipping cancellation message"); + return Err(anyhow::anyhow!("Rate limit exceeded")); + } + match self.publish(cancel_key_data, session_id).await { + Ok(()) => return Ok(()), + Err(e) => { + tracing::error!("failed to publish a message: {e}"); + self.publisher = None; + } + } + tracing::info!("Publisher is disconnected. 
Reconnectiong..."); + self.try_connect().await?; + self.publish(cancel_key_data, session_id).await + } + + async fn publish( + &mut self, + cancel_key_data: CancelKeyData, + session_id: Uuid, + ) -> anyhow::Result<()> { + let conn = self + .publisher + .as_mut() + .ok_or_else(|| anyhow::anyhow!("not connected"))?; + let payload = serde_json::to_string(&Notification::Cancel(CancelSession { + region_id: Some(self.region_id.clone()), + cancel_key_data, + session_id, + }))?; + conn.publish(PROXY_CHANNEL_NAME, payload).await?; + Ok(()) + } + pub async fn try_connect(&mut self) -> anyhow::Result<()> { + match self.client.get_async_connection().await { + Ok(conn) => { + self.publisher = Some(conn); + } + Err(e) => { + tracing::error!("failed to connect to redis: {e}"); + return Err(e.into()); + } + } + Ok(()) + } +} diff --git a/proxy/src/serverless.rs b/proxy/src/serverless.rs index a20600b94a..ee3e91495b 100644 --- a/proxy/src/serverless.rs +++ b/proxy/src/serverless.rs @@ -24,7 +24,7 @@ use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE; use crate::protocol2::{ProxyProtocolAccept, WithClientIp}; use crate::rate_limiter::EndpointRateLimiter; use crate::serverless::backend::PoolingBackend; -use crate::{cancellation::CancelMap, config::ProxyConfig}; +use crate::{cancellation::CancellationHandler, config::ProxyConfig}; use futures::StreamExt; use hyper::{ server::{ @@ -50,6 +50,7 @@ pub async fn task_main( ws_listener: TcpListener, cancellation_token: CancellationToken, endpoint_rate_limiter: Arc, + cancellation_handler: Arc, ) -> anyhow::Result<()> { scopeguard::defer! { info!("websocket server has shut down"); @@ -115,7 +116,7 @@ pub async fn task_main( let backend = backend.clone(); let ws_connections = ws_connections.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); - + let cancellation_handler = cancellation_handler.clone(); async move { let peer_addr = match client_addr { Some(addr) => addr, @@ -127,9 +128,9 @@ pub async fn task_main( let backend = backend.clone(); let ws_connections = ws_connections.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); + let cancellation_handler = cancellation_handler.clone(); async move { - let cancel_map = Arc::new(CancelMap::default()); let session_id = uuid::Uuid::new_v4(); request_handler( @@ -137,7 +138,7 @@ pub async fn task_main( config, backend, ws_connections, - cancel_map, + cancellation_handler, session_id, peer_addr.ip(), endpoint_rate_limiter, @@ -205,7 +206,7 @@ async fn request_handler( config: &'static ProxyConfig, backend: Arc, ws_connections: TaskTracker, - cancel_map: Arc, + cancellation_handler: Arc, session_id: uuid::Uuid, peer_addr: IpAddr, endpoint_rate_limiter: Arc, @@ -232,7 +233,7 @@ async fn request_handler( config, ctx, websocket, - cancel_map, + cancellation_handler, host, endpoint_rate_limiter, ) diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 062dd440b2..24f2bb7e8c 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -1,5 +1,5 @@ use crate::{ - cancellation::CancelMap, + cancellation::CancellationHandler, config::ProxyConfig, context::RequestMonitoring, error::{io_error, ReportableError}, @@ -133,7 +133,7 @@ pub async fn serve_websocket( config: &'static ProxyConfig, mut ctx: RequestMonitoring, websocket: HyperWebsocket, - cancel_map: Arc, + cancellation_handler: Arc, hostname: Option, endpoint_rate_limiter: Arc, ) -> anyhow::Result<()> { @@ -141,7 +141,7 @@ pub async fn serve_websocket( let res = handle_client( config, 
&mut ctx, - cancel_map, + cancellation_handler, WebSocketRw::new(websocket), ClientMode::Websockets { hostname }, endpoint_rate_limiter, diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 8e9cc43152..e808fabbe7 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -38,7 +38,7 @@ futures-io = { version = "0.3" } futures-sink = { version = "0.3" } futures-util = { version = "0.3", features = ["channel", "io", "sink"] } getrandom = { version = "0.2", default-features = false, features = ["std"] } -hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", default-features = false, features = ["raw"] } +hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["raw"] } hashbrown-594e8ee84c453af0 = { package = "hashbrown", version = "0.13", features = ["raw"] } hex = { version = "0.4", features = ["serde"] } hmac = { version = "0.12", default-features = false, features = ["reset"] } @@ -91,7 +91,7 @@ cc = { version = "1", default-features = false, features = ["parallel"] } chrono = { version = "0.4", default-features = false, features = ["clock", "serde", "wasmbind"] } either = { version = "1" } getrandom = { version = "0.2", default-features = false, features = ["std"] } -hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", default-features = false, features = ["raw"] } +hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["raw"] } indexmap = { version = "1", default-features = false, features = ["std"] } itertools = { version = "0.10" } libc = { version = "0.2", features = ["extra_traits", "use_std"] } From 7fa732c96c6382fd0468991b40f922348e653d3c Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Tue, 13 Feb 2024 18:46:25 +0100 Subject: [PATCH 14/20] refactor(virtual_file): take owned buffer in VirtualFile::write_all (#6664) Building atop #6660 , this PR converts VirtualFile::write_all to owned buffers. 
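For readers skimming the series, here is a minimal sketch (not the pageserver code; the types and names are illustrative) of the calling convention this patch introduces: `write_all` takes ownership of the buffer and hands it back together with the result, so callers such as `flush_buffer` below reclaim and reuse the buffer instead of borrowing it across an await point.

```rust
// Sketch of the owned-buffer convention behind the new `write_all` signature.
// `Writer` stands in for `VirtualFile`; only the shape of the API matches the diff.
struct Writer {
    written: Vec<u8>,
}

impl Writer {
    // Returns the buffer to the caller along with the number of bytes written.
    async fn write_all(&mut self, buf: Vec<u8>) -> (Vec<u8>, std::io::Result<usize>) {
        self.written.extend_from_slice(&buf);
        let n = buf.len();
        (buf, Ok(n))
    }
}

struct BufferedWriter {
    inner: Writer,
    buf: Vec<u8>,
}

impl BufferedWriter {
    // Mirrors the `flush_buffer` hunk below: take the buffer out of `self`,
    // move it into the writer, then put the (cleared) buffer back for reuse.
    async fn flush_buffer(&mut self) -> std::io::Result<()> {
        let buf = std::mem::take(&mut self.buf);
        let (mut buf, res) = self.inner.write_all(buf).await;
        res?;
        buf.clear();
        self.buf = buf;
        Ok(())
    }
}
```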
Part of https://github.com/neondatabase/neon/issues/6663 --- pageserver/src/deletion_queue.rs | 4 +- pageserver/src/tenant.rs | 4 +- pageserver/src/tenant/blob_io.rs | 26 ++++---- pageserver/src/tenant/metadata.rs | 2 +- pageserver/src/tenant/secondary/downloader.rs | 2 +- .../src/tenant/storage_layer/delta_layer.rs | 30 +++------ .../src/tenant/storage_layer/image_layer.rs | 30 +++------ pageserver/src/virtual_file.rs | 66 ++++++++++++------- 8 files changed, 81 insertions(+), 83 deletions(-) diff --git a/pageserver/src/deletion_queue.rs b/pageserver/src/deletion_queue.rs index da1da9331a..9046fe881b 100644 --- a/pageserver/src/deletion_queue.rs +++ b/pageserver/src/deletion_queue.rs @@ -234,7 +234,7 @@ impl DeletionHeader { let header_bytes = serde_json::to_vec(self).context("serialize deletion header")?; let header_path = conf.deletion_header_path(); let temp_path = path_with_suffix_extension(&header_path, TEMP_SUFFIX); - VirtualFile::crashsafe_overwrite(&header_path, &temp_path, &header_bytes) + VirtualFile::crashsafe_overwrite(&header_path, &temp_path, header_bytes) .await .maybe_fatal_err("save deletion header")?; @@ -325,7 +325,7 @@ impl DeletionList { let temp_path = path_with_suffix_extension(&path, TEMP_SUFFIX); let bytes = serde_json::to_vec(self).expect("Failed to serialize deletion list"); - VirtualFile::crashsafe_overwrite(&path, &temp_path, &bytes) + VirtualFile::crashsafe_overwrite(&path, &temp_path, bytes) .await .maybe_fatal_err("save deletion list") .map_err(Into::into) diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index d946c57118..9f1f188bf2 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -2880,7 +2880,7 @@ impl Tenant { let config_path = config_path.to_owned(); tokio::task::spawn_blocking(move || { Handle::current().block_on(async move { - let conf_content = conf_content.as_bytes(); + let conf_content = conf_content.into_bytes(); VirtualFile::crashsafe_overwrite(&config_path, &temp_path, conf_content) .await .with_context(|| { @@ -2917,7 +2917,7 @@ impl Tenant { let target_config_path = target_config_path.to_owned(); tokio::task::spawn_blocking(move || { Handle::current().block_on(async move { - let conf_content = conf_content.as_bytes(); + let conf_content = conf_content.into_bytes(); VirtualFile::crashsafe_overwrite(&target_config_path, &temp_path, conf_content) .await .with_context(|| { diff --git a/pageserver/src/tenant/blob_io.rs b/pageserver/src/tenant/blob_io.rs index e2ff12665a..ec70bdc679 100644 --- a/pageserver/src/tenant/blob_io.rs +++ b/pageserver/src/tenant/blob_io.rs @@ -131,27 +131,23 @@ impl BlobWriter { &mut self, src_buf: B, ) -> (B::Buf, Result<(), Error>) { - let src_buf_len = src_buf.bytes_init(); - let (src_buf, res) = if src_buf_len > 0 { - let src_buf = src_buf.slice(0..src_buf_len); - let res = self.inner.write_all(&src_buf).await; - let src_buf = Slice::into_inner(src_buf); - (src_buf, res) - } else { - let res = self.inner.write_all(&[]).await; - (Slice::into_inner(src_buf.slice_full()), res) + let (src_buf, res) = self.inner.write_all(src_buf).await; + let nbytes = match res { + Ok(nbytes) => nbytes, + Err(e) => return (src_buf, Err(e)), }; - if let Ok(()) = &res { - self.offset += src_buf_len as u64; - } - (src_buf, res) + self.offset += nbytes as u64; + (src_buf, Ok(())) } #[inline(always)] /// Flushes the internal buffer to the underlying `VirtualFile`. 
pub async fn flush_buffer(&mut self) -> Result<(), Error> { - self.inner.write_all(&self.buf).await?; - self.buf.clear(); + let buf = std::mem::take(&mut self.buf); + let (mut buf, res) = self.inner.write_all(buf).await; + res?; + buf.clear(); + self.buf = buf; Ok(()) } diff --git a/pageserver/src/tenant/metadata.rs b/pageserver/src/tenant/metadata.rs index 6fb86c65e2..dcbe781f90 100644 --- a/pageserver/src/tenant/metadata.rs +++ b/pageserver/src/tenant/metadata.rs @@ -279,7 +279,7 @@ pub async fn save_metadata( let path = conf.metadata_path(tenant_shard_id, timeline_id); let temp_path = path_with_suffix_extension(&path, TEMP_FILE_SUFFIX); let metadata_bytes = data.to_bytes().context("serialize metadata")?; - VirtualFile::crashsafe_overwrite(&path, &temp_path, &metadata_bytes) + VirtualFile::crashsafe_overwrite(&path, &temp_path, metadata_bytes) .await .context("write metadata")?; Ok(()) diff --git a/pageserver/src/tenant/secondary/downloader.rs b/pageserver/src/tenant/secondary/downloader.rs index 0666e104f8..c23416a7f0 100644 --- a/pageserver/src/tenant/secondary/downloader.rs +++ b/pageserver/src/tenant/secondary/downloader.rs @@ -486,7 +486,7 @@ impl<'a> TenantDownloader<'a> { let heatmap_path_bg = heatmap_path.clone(); tokio::task::spawn_blocking(move || { tokio::runtime::Handle::current().block_on(async move { - VirtualFile::crashsafe_overwrite(&heatmap_path_bg, &temp_path, &heatmap_bytes).await + VirtualFile::crashsafe_overwrite(&heatmap_path_bg, &temp_path, heatmap_bytes).await }) }) .await diff --git a/pageserver/src/tenant/storage_layer/delta_layer.rs b/pageserver/src/tenant/storage_layer/delta_layer.rs index 7a5dc7a59f..9a7bcbcebe 100644 --- a/pageserver/src/tenant/storage_layer/delta_layer.rs +++ b/pageserver/src/tenant/storage_layer/delta_layer.rs @@ -461,7 +461,8 @@ impl DeltaLayerWriterInner { file.seek(SeekFrom::Start(index_start_blk as u64 * PAGE_SZ as u64)) .await?; for buf in block_buf.blocks { - file.write_all(buf.as_ref()).await?; + let (_buf, res) = file.write_all(buf).await; + res?; } assert!(self.lsn_range.start < self.lsn_range.end); // Fill in the summary on blk 0 @@ -476,17 +477,12 @@ impl DeltaLayerWriterInner { index_root_blk, }; - let mut buf = smallvec::SmallVec::<[u8; PAGE_SZ]>::new(); + let mut buf = Vec::with_capacity(PAGE_SZ); + // TODO: could use smallvec here but it's a pain with Slice Summary::ser_into(&summary, &mut buf)?; - if buf.spilled() { - // This is bad as we only have one free block for the summary - warn!( - "Used more than one page size for summary buffer: {}", - buf.len() - ); - } file.seek(SeekFrom::Start(0)).await?; - file.write_all(&buf).await?; + let (_buf, res) = file.write_all(buf).await; + res?; let metadata = file .metadata() @@ -679,18 +675,12 @@ impl DeltaLayer { let new_summary = rewrite(actual_summary); - let mut buf = smallvec::SmallVec::<[u8; PAGE_SZ]>::new(); + let mut buf = Vec::with_capacity(PAGE_SZ); + // TODO: could use smallvec here, but it's a pain with Slice Summary::ser_into(&new_summary, &mut buf).context("serialize")?; - if buf.spilled() { - // The code in DeltaLayerWriterInner just warn!()s for this. - // It should probably error out as well. 
- return Err(RewriteSummaryError::Other(anyhow::anyhow!( - "Used more than one page size for summary buffer: {}", - buf.len() - ))); - } file.seek(SeekFrom::Start(0)).await?; - file.write_all(&buf).await?; + let (_buf, res) = file.write_all(buf).await; + res?; Ok(()) } } diff --git a/pageserver/src/tenant/storage_layer/image_layer.rs b/pageserver/src/tenant/storage_layer/image_layer.rs index 1ad195032d..458131b572 100644 --- a/pageserver/src/tenant/storage_layer/image_layer.rs +++ b/pageserver/src/tenant/storage_layer/image_layer.rs @@ -341,18 +341,12 @@ impl ImageLayer { let new_summary = rewrite(actual_summary); - let mut buf = smallvec::SmallVec::<[u8; PAGE_SZ]>::new(); + let mut buf = Vec::with_capacity(PAGE_SZ); + // TODO: could use smallvec here but it's a pain with Slice Summary::ser_into(&new_summary, &mut buf).context("serialize")?; - if buf.spilled() { - // The code in ImageLayerWriterInner just warn!()s for this. - // It should probably error out as well. - return Err(RewriteSummaryError::Other(anyhow::anyhow!( - "Used more than one page size for summary buffer: {}", - buf.len() - ))); - } file.seek(SeekFrom::Start(0)).await?; - file.write_all(&buf).await?; + let (_buf, res) = file.write_all(buf).await; + res?; Ok(()) } } @@ -555,7 +549,8 @@ impl ImageLayerWriterInner { .await?; let (index_root_blk, block_buf) = self.tree.finish()?; for buf in block_buf.blocks { - file.write_all(buf.as_ref()).await?; + let (_buf, res) = file.write_all(buf).await; + res?; } // Fill in the summary on blk 0 @@ -570,17 +565,12 @@ impl ImageLayerWriterInner { index_root_blk, }; - let mut buf = smallvec::SmallVec::<[u8; PAGE_SZ]>::new(); + let mut buf = Vec::with_capacity(PAGE_SZ); + // TODO: could use smallvec here but it's a pain with Slice Summary::ser_into(&summary, &mut buf)?; - if buf.spilled() { - // This is bad as we only have one free block for the summary - warn!( - "Used more than one page size for summary buffer: {}", - buf.len() - ); - } file.seek(SeekFrom::Start(0)).await?; - file.write_all(&buf).await?; + let (_buf, res) = file.write_all(buf).await; + res?; let metadata = file .metadata() diff --git a/pageserver/src/virtual_file.rs b/pageserver/src/virtual_file.rs index 059a6596d3..6cff748d42 100644 --- a/pageserver/src/virtual_file.rs +++ b/pageserver/src/virtual_file.rs @@ -19,7 +19,7 @@ use once_cell::sync::OnceCell; use pageserver_api::shard::TenantShardId; use std::fs::{self, File}; use std::io::{Error, ErrorKind, Seek, SeekFrom}; -use tokio_epoll_uring::IoBufMut; +use tokio_epoll_uring::{BoundedBuf, IoBufMut, Slice}; use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd}; use std::os::unix::fs::FileExt; @@ -410,10 +410,10 @@ impl VirtualFile { /// step, the tmp path is renamed to the final path. As renames are /// atomic, a crash during the write operation will never leave behind a /// partially written file. - pub async fn crashsafe_overwrite( + pub async fn crashsafe_overwrite( final_path: &Utf8Path, tmp_path: &Utf8Path, - content: &[u8], + content: B, ) -> std::io::Result<()> { let Some(final_path_parent) = final_path.parent() else { return Err(std::io::Error::from_raw_os_error( @@ -430,7 +430,8 @@ impl VirtualFile { .create_new(true), ) .await?; - file.write_all(content).await?; + let (_content, res) = file.write_all(content).await; + res?; file.sync_all().await?; drop(file); // before the rename, that's important! 
// renames are atomic @@ -601,23 +602,36 @@ impl VirtualFile { Ok(()) } - pub async fn write_all(&mut self, mut buf: &[u8]) -> Result<(), Error> { + /// Writes `buf.slice(0..buf.bytes_init())`. + /// Returns the IoBuf that is underlying the BoundedBuf `buf`. + /// I.e., the returned value's `bytes_init()` method returns something different than the `bytes_init()` that was passed in. + /// It's quite brittle and easy to mis-use, so, we return the size in the Ok() variant. + pub async fn write_all(&mut self, buf: B) -> (B::Buf, Result) { + let nbytes = buf.bytes_init(); + if nbytes == 0 { + return (Slice::into_inner(buf.slice_full()), Ok(0)); + } + let mut buf = buf.slice(0..nbytes); while !buf.is_empty() { - match self.write(buf).await { + // TODO: push `Slice` further down + match self.write(&buf).await { Ok(0) => { - return Err(Error::new( - std::io::ErrorKind::WriteZero, - "failed to write whole buffer", - )); + return ( + Slice::into_inner(buf), + Err(Error::new( + std::io::ErrorKind::WriteZero, + "failed to write whole buffer", + )), + ); } Ok(n) => { - buf = &buf[n..]; + buf = buf.slice(n..); } Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {} - Err(e) => return Err(e), + Err(e) => return (Slice::into_inner(buf), Err(e)), } } - Ok(()) + (Slice::into_inner(buf), Ok(nbytes)) } async fn write(&mut self, buf: &[u8]) -> Result { @@ -676,7 +690,6 @@ where F: FnMut(tokio_epoll_uring::Slice, u64) -> Fut, Fut: std::future::Future, std::io::Result)>, { - use tokio_epoll_uring::BoundedBuf; let mut buf: tokio_epoll_uring::Slice = buf.slice_full(); // includes all the uninitialized memory while buf.bytes_total() != 0 { let res; @@ -1063,10 +1076,19 @@ mod tests { MaybeVirtualFile::File(file) => file.seek(pos), } } - async fn write_all(&mut self, buf: &[u8]) -> Result<(), Error> { + async fn write_all(&mut self, buf: B) -> Result<(), Error> { match self { - MaybeVirtualFile::VirtualFile(file) => file.write_all(buf).await, - MaybeVirtualFile::File(file) => file.write_all(buf), + MaybeVirtualFile::VirtualFile(file) => { + let (_buf, res) = file.write_all(buf).await; + res.map(|_| ()) + } + MaybeVirtualFile::File(file) => { + let buf_len = buf.bytes_init(); + if buf_len == 0 { + return Ok(()); + } + file.write_all(&buf.slice(0..buf_len)) + } } } @@ -1141,7 +1163,7 @@ mod tests { .to_owned(), ) .await?; - file_a.write_all(b"foobar").await?; + file_a.write_all(b"foobar".to_vec()).await?; // cannot read from a file opened in write-only mode let _ = file_a.read_string().await.unwrap_err(); @@ -1150,7 +1172,7 @@ mod tests { let mut file_a = openfunc(path_a, OpenOptions::new().read(true).to_owned()).await?; // cannot write to a file opened in read-only mode - let _ = file_a.write_all(b"bar").await.unwrap_err(); + let _ = file_a.write_all(b"bar".to_vec()).await.unwrap_err(); // Try simple read assert_eq!("foobar", file_a.read_string().await?); @@ -1293,7 +1315,7 @@ mod tests { let path = testdir.join("myfile"); let tmp_path = testdir.join("myfile.tmp"); - VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"foo") + VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"foo".to_vec()) .await .unwrap(); let mut file = MaybeVirtualFile::from(VirtualFile::open(&path).await.unwrap()); @@ -1302,7 +1324,7 @@ mod tests { assert!(!tmp_path.exists()); drop(file); - VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"bar") + VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"bar".to_vec()) .await .unwrap(); let mut file = MaybeVirtualFile::from(VirtualFile::open(&path).await.unwrap()); @@ -1324,7 +1346,7 
@@ mod tests { std::fs::write(&tmp_path, "some preexisting junk that should be removed").unwrap(); assert!(tmp_path.exists()); - VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"foo") + VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"foo".to_vec()) .await .unwrap(); From b6e070bf85c6f4fa204d36ae2d761db30b47d277 Mon Sep 17 00:00:00 2001 From: Konstantin Knizhnik Date: Tue, 13 Feb 2024 20:41:17 +0200 Subject: [PATCH 15/20] Do not perform fast exit for catalog pages in redo filter (#6730) ## Problem See https://github.com/neondatabase/neon/issues/6674 Current implementation of `neon_redo_read_buffer_filter` performs fast exist for catalog pages: ``` /* * Out of an abundance of caution, we always run redo on shared catalogs, * regardless of whether the block is stored in shared buffers. See also * this function's top comment. */ if (!OidIsValid(NInfoGetDbOid(rinfo))) return false; */ as a result last written lsn and relation size for FSM fork are not correctly updated for catalog relations. ## Summary of changes Do not perform fast path return for catalog relations. ## Checklist before requesting a review - [ ] I have performed a self-review of my code. - [ ] If it is a core feature, I have added thorough tests. - [ ] Do we need to implement analytics? if so did you add the relevant metrics to the dashboard? - [ ] If this PR requires public announcement, mark it with /release-notes label and add several sentences in this section. ## Checklist before merging - [ ] Do not forget to reformat commit message to not include the above checklist Co-authored-by: Konstantin Knizhnik --- pgxn/neon/pagestore_smgr.c | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/pgxn/neon/pagestore_smgr.c b/pgxn/neon/pagestore_smgr.c index 63e8b8dc1f..213e396328 100644 --- a/pgxn/neon/pagestore_smgr.c +++ b/pgxn/neon/pagestore_smgr.c @@ -3079,14 +3079,6 @@ neon_redo_read_buffer_filter(XLogReaderState *record, uint8 block_id) XLogRecGetBlockTag(record, block_id, &rinfo, &forknum, &blkno); #endif - /* - * Out of an abundance of caution, we always run redo on shared catalogs, - * regardless of whether the block is stored in shared buffers. See also - * this function's top comment. - */ - if (!OidIsValid(NInfoGetDbOid(rinfo))) - return false; - CopyNRelFileInfoToBufTag(tag, rinfo); tag.forkNum = forknum; tag.blockNum = blkno; @@ -3100,17 +3092,28 @@ neon_redo_read_buffer_filter(XLogReaderState *record, uint8 block_id) */ LWLockAcquire(partitionLock, LW_SHARED); - /* Try to find the relevant buffer */ - buffer = BufTableLookup(&tag, hash); - - no_redo_needed = buffer < 0; + /* + * Out of an abundance of caution, we always run redo on shared catalogs, + * regardless of whether the block is stored in shared buffers. See also + * this function's top comment. 
+ */ + if (!OidIsValid(NInfoGetDbOid(rinfo))) + { + no_redo_needed = false; + } + else + { + /* Try to find the relevant buffer */ + buffer = BufTableLookup(&tag, hash); + no_redo_needed = buffer < 0; + } /* In both cases st lwlsn past this WAL record */ SetLastWrittenLSNForBlock(end_recptr, rinfo, forknum, blkno); /* * we don't have the buffer in memory, update lwLsn past this record, also - * evict page fro file cache + * evict page from file cache */ if (no_redo_needed) lfc_evict(rinfo, forknum, blkno); From ee7bbdda0e58af4350a6886544cd75f3cc1b2de9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arpad=20M=C3=BCller?= Date: Wed, 14 Feb 2024 02:12:00 +0100 Subject: [PATCH 16/20] Create new metric for directory counts (#6736) There is O(n^2) issues due to how we store these directories (#6626), so it's good to keep an eye on them and ensure the numbers stay low. The new per-timeline metric `pageserver_directory_entries_count` isn't perfect, namely we don't calculate it every time we attach the timeline, but only if there is an actual change. Also, it is a collective metric over multiple scalars. Lastly, we only emit the metric if it is above a certain threshold. However, the metric still give a feel for the general size of the timeline. We care less for small values as the metric is mainly there to detect and track tenants with large directory counts. We also expose the directory counts in `TimelineInfo` so that one can get the detailed size distribution directly via the pageserver's API. Related: #6642 , https://github.com/neondatabase/cloud/issues/10273 --- libs/pageserver_api/src/models.rs | 2 + libs/pageserver_api/src/reltag.rs | 1 + pageserver/src/http/routes.rs | 1 + pageserver/src/metrics.rs | 34 +++++++++++++++- pageserver/src/pgdatadir_mapping.rs | 62 +++++++++++++++++++++++++++++ pageserver/src/tenant/timeline.rs | 39 +++++++++++++++++- test_runner/fixtures/metrics.py | 1 + 7 files changed, 137 insertions(+), 3 deletions(-) diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index 46324efd43..1226eaa312 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -494,6 +494,8 @@ pub struct TimelineInfo { pub current_logical_size: u64, pub current_logical_size_is_accurate: bool, + pub directory_entries_counts: Vec, + /// Sum of the size of all layer files. /// If a layer is present in both local FS and S3, it counts only once. 
pub current_physical_size: Option, // is None when timeline is Unloaded diff --git a/libs/pageserver_api/src/reltag.rs b/libs/pageserver_api/src/reltag.rs index 8eb848a514..38693ab847 100644 --- a/libs/pageserver_api/src/reltag.rs +++ b/libs/pageserver_api/src/reltag.rs @@ -124,6 +124,7 @@ impl RelTag { Ord, strum_macros::EnumIter, strum_macros::FromRepr, + enum_map::Enum, )] #[repr(u8)] pub enum SlruKind { diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index 4be8ee9892..c354cc9ab6 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -422,6 +422,7 @@ async fn build_timeline_info_common( tenant::timeline::logical_size::Accuracy::Approximate => false, tenant::timeline::logical_size::Accuracy::Exact => true, }, + directory_entries_counts: timeline.get_directory_metrics().to_vec(), current_physical_size, current_logical_size_non_incremental: None, timeline_dir_layer_file_size_sum: None, diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index 98c98ef6e7..c2b1eafc3a 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -602,6 +602,15 @@ pub(crate) mod initial_logical_size { }); } +static DIRECTORY_ENTRIES_COUNT: Lazy = Lazy::new(|| { + register_uint_gauge_vec!( + "pageserver_directory_entries_count", + "Sum of the entries in pageserver-stored directory listings", + &["tenant_id", "shard_id", "timeline_id"] + ) + .expect("failed to define a metric") +}); + pub(crate) static TENANT_STATE_METRIC: Lazy = Lazy::new(|| { register_uint_gauge_vec!( "pageserver_tenant_states_count", @@ -1809,6 +1818,7 @@ pub(crate) struct TimelineMetrics { resident_physical_size_gauge: UIntGauge, /// copy of LayeredTimeline.current_logical_size pub current_logical_size_gauge: UIntGauge, + pub directory_entries_count_gauge: Lazy UIntGauge>>, pub num_persistent_files_created: IntCounter, pub persistent_bytes_written: IntCounter, pub evictions: IntCounter, @@ -1818,12 +1828,12 @@ pub(crate) struct TimelineMetrics { impl TimelineMetrics { pub fn new( tenant_shard_id: &TenantShardId, - timeline_id: &TimelineId, + timeline_id_raw: &TimelineId, evictions_with_low_residence_duration_builder: EvictionsWithLowResidenceDurationBuilder, ) -> Self { let tenant_id = tenant_shard_id.tenant_id.to_string(); let shard_id = format!("{}", tenant_shard_id.shard_slug()); - let timeline_id = timeline_id.to_string(); + let timeline_id = timeline_id_raw.to_string(); let flush_time_histo = StorageTimeMetrics::new( StorageTimeOperation::LayerFlush, &tenant_id, @@ -1876,6 +1886,22 @@ impl TimelineMetrics { let current_logical_size_gauge = CURRENT_LOGICAL_SIZE .get_metric_with_label_values(&[&tenant_id, &shard_id, &timeline_id]) .unwrap(); + // TODO use impl Trait syntax here once we have ability to use it: https://github.com/rust-lang/rust/issues/63065 + let directory_entries_count_gauge_closure = { + let tenant_shard_id = *tenant_shard_id; + let timeline_id_raw = *timeline_id_raw; + move || { + let tenant_id = tenant_shard_id.tenant_id.to_string(); + let shard_id = format!("{}", tenant_shard_id.shard_slug()); + let timeline_id = timeline_id_raw.to_string(); + let gauge: UIntGauge = DIRECTORY_ENTRIES_COUNT + .get_metric_with_label_values(&[&tenant_id, &shard_id, &timeline_id]) + .unwrap(); + gauge + } + }; + let directory_entries_count_gauge: Lazy UIntGauge>> = + Lazy::new(Box::new(directory_entries_count_gauge_closure)); let num_persistent_files_created = NUM_PERSISTENT_FILES_CREATED .get_metric_with_label_values(&[&tenant_id, &shard_id, &timeline_id]) 
.unwrap(); @@ -1902,6 +1928,7 @@ impl TimelineMetrics { last_record_gauge, resident_physical_size_gauge, current_logical_size_gauge, + directory_entries_count_gauge, num_persistent_files_created, persistent_bytes_written, evictions, @@ -1944,6 +1971,9 @@ impl Drop for TimelineMetrics { RESIDENT_PHYSICAL_SIZE.remove_label_values(&[tenant_id, &shard_id, timeline_id]); } let _ = CURRENT_LOGICAL_SIZE.remove_label_values(&[tenant_id, &shard_id, timeline_id]); + if let Some(metric) = Lazy::get(&DIRECTORY_ENTRIES_COUNT) { + let _ = metric.remove_label_values(&[tenant_id, &shard_id, timeline_id]); + } let _ = NUM_PERSISTENT_FILES_CREATED.remove_label_values(&[tenant_id, &shard_id, timeline_id]); let _ = PERSISTENT_BYTES_WRITTEN.remove_label_values(&[tenant_id, &shard_id, timeline_id]); diff --git a/pageserver/src/pgdatadir_mapping.rs b/pageserver/src/pgdatadir_mapping.rs index f1d18c0146..5f80ea9b5e 100644 --- a/pageserver/src/pgdatadir_mapping.rs +++ b/pageserver/src/pgdatadir_mapping.rs @@ -14,6 +14,7 @@ use crate::span::debug_assert_current_span_has_tenant_and_timeline_id_no_shard_i use crate::walrecord::NeonWalRecord; use anyhow::{ensure, Context}; use bytes::{Buf, Bytes, BytesMut}; +use enum_map::Enum; use pageserver_api::key::{ dbdir_key_range, is_rel_block_key, is_slru_block_key, rel_block_to_key, rel_dir_to_key, rel_key_range, rel_size_to_key, relmap_file_key, slru_block_to_key, slru_dir_to_key, @@ -155,6 +156,7 @@ impl Timeline { pending_updates: HashMap::new(), pending_deletions: Vec::new(), pending_nblocks: 0, + pending_directory_entries: Vec::new(), lsn, } } @@ -868,6 +870,7 @@ pub struct DatadirModification<'a> { pending_updates: HashMap>, pending_deletions: Vec<(Range, Lsn)>, pending_nblocks: i64, + pending_directory_entries: Vec<(DirectoryKind, usize)>, } impl<'a> DatadirModification<'a> { @@ -899,6 +902,7 @@ impl<'a> DatadirModification<'a> { let buf = DbDirectory::ser(&DbDirectory { dbdirs: HashMap::new(), })?; + self.pending_directory_entries.push((DirectoryKind::Db, 0)); self.put(DBDIR_KEY, Value::Image(buf.into())); // Create AuxFilesDirectory @@ -907,16 +911,24 @@ impl<'a> DatadirModification<'a> { let buf = TwoPhaseDirectory::ser(&TwoPhaseDirectory { xids: HashSet::new(), })?; + self.pending_directory_entries + .push((DirectoryKind::TwoPhase, 0)); self.put(TWOPHASEDIR_KEY, Value::Image(buf.into())); let buf: Bytes = SlruSegmentDirectory::ser(&SlruSegmentDirectory::default())?.into(); let empty_dir = Value::Image(buf); self.put(slru_dir_to_key(SlruKind::Clog), empty_dir.clone()); + self.pending_directory_entries + .push((DirectoryKind::SlruSegment(SlruKind::Clog), 0)); self.put( slru_dir_to_key(SlruKind::MultiXactMembers), empty_dir.clone(), ); + self.pending_directory_entries + .push((DirectoryKind::SlruSegment(SlruKind::Clog), 0)); self.put(slru_dir_to_key(SlruKind::MultiXactOffsets), empty_dir); + self.pending_directory_entries + .push((DirectoryKind::SlruSegment(SlruKind::MultiXactOffsets), 0)); Ok(()) } @@ -1017,6 +1029,7 @@ impl<'a> DatadirModification<'a> { let buf = RelDirectory::ser(&RelDirectory { rels: HashSet::new(), })?; + self.pending_directory_entries.push((DirectoryKind::Rel, 0)); self.put( rel_dir_to_key(spcnode, dbnode), Value::Image(Bytes::from(buf)), @@ -1039,6 +1052,8 @@ impl<'a> DatadirModification<'a> { if !dir.xids.insert(xid) { anyhow::bail!("twophase file for xid {} already exists", xid); } + self.pending_directory_entries + .push((DirectoryKind::TwoPhase, dir.xids.len())); self.put( TWOPHASEDIR_KEY, 
Value::Image(Bytes::from(TwoPhaseDirectory::ser(&dir)?)), @@ -1074,6 +1089,8 @@ impl<'a> DatadirModification<'a> { let mut dir = DbDirectory::des(&buf)?; if dir.dbdirs.remove(&(spcnode, dbnode)).is_some() { let buf = DbDirectory::ser(&dir)?; + self.pending_directory_entries + .push((DirectoryKind::Db, dir.dbdirs.len())); self.put(DBDIR_KEY, Value::Image(buf.into())); } else { warn!( @@ -1111,6 +1128,8 @@ impl<'a> DatadirModification<'a> { // Didn't exist. Update dbdir dbdir.dbdirs.insert((rel.spcnode, rel.dbnode), false); let buf = DbDirectory::ser(&dbdir).context("serialize db")?; + self.pending_directory_entries + .push((DirectoryKind::Db, dbdir.dbdirs.len())); self.put(DBDIR_KEY, Value::Image(buf.into())); // and create the RelDirectory @@ -1125,6 +1144,10 @@ impl<'a> DatadirModification<'a> { if !rel_dir.rels.insert((rel.relnode, rel.forknum)) { return Err(RelationError::AlreadyExists); } + + self.pending_directory_entries + .push((DirectoryKind::Rel, rel_dir.rels.len())); + self.put( rel_dir_key, Value::Image(Bytes::from( @@ -1216,6 +1239,9 @@ impl<'a> DatadirModification<'a> { let buf = self.get(dir_key, ctx).await?; let mut dir = RelDirectory::des(&buf)?; + self.pending_directory_entries + .push((DirectoryKind::Rel, dir.rels.len())); + if dir.rels.remove(&(rel.relnode, rel.forknum)) { self.put(dir_key, Value::Image(Bytes::from(RelDirectory::ser(&dir)?))); } else { @@ -1251,6 +1277,8 @@ impl<'a> DatadirModification<'a> { if !dir.segments.insert(segno) { anyhow::bail!("slru segment {kind:?}/{segno} already exists"); } + self.pending_directory_entries + .push((DirectoryKind::SlruSegment(kind), dir.segments.len())); self.put( dir_key, Value::Image(Bytes::from(SlruSegmentDirectory::ser(&dir)?)), @@ -1295,6 +1323,8 @@ impl<'a> DatadirModification<'a> { if !dir.segments.remove(&segno) { warn!("slru segment {:?}/{} does not exist", kind, segno); } + self.pending_directory_entries + .push((DirectoryKind::SlruSegment(kind), dir.segments.len())); self.put( dir_key, Value::Image(Bytes::from(SlruSegmentDirectory::ser(&dir)?)), @@ -1325,6 +1355,8 @@ impl<'a> DatadirModification<'a> { if !dir.xids.remove(&xid) { warn!("twophase file for xid {} does not exist", xid); } + self.pending_directory_entries + .push((DirectoryKind::TwoPhase, dir.xids.len())); self.put( TWOPHASEDIR_KEY, Value::Image(Bytes::from(TwoPhaseDirectory::ser(&dir)?)), @@ -1340,6 +1372,8 @@ impl<'a> DatadirModification<'a> { let buf = AuxFilesDirectory::ser(&AuxFilesDirectory { files: HashMap::new(), })?; + self.pending_directory_entries + .push((DirectoryKind::AuxFiles, 0)); self.put(AUX_FILES_KEY, Value::Image(Bytes::from(buf))); Ok(()) } @@ -1366,6 +1400,9 @@ impl<'a> DatadirModification<'a> { } else { dir.files.insert(path, Bytes::copy_from_slice(content)); } + self.pending_directory_entries + .push((DirectoryKind::AuxFiles, dir.files.len())); + self.put( AUX_FILES_KEY, Value::Image(Bytes::from( @@ -1427,6 +1464,10 @@ impl<'a> DatadirModification<'a> { self.pending_nblocks = 0; } + for (kind, count) in std::mem::take(&mut self.pending_directory_entries) { + writer.update_directory_entries_count(kind, count as u64); + } + Ok(()) } @@ -1464,6 +1505,10 @@ impl<'a> DatadirModification<'a> { writer.update_current_logical_size(pending_nblocks * i64::from(BLCKSZ)); } + for (kind, count) in std::mem::take(&mut self.pending_directory_entries) { + writer.update_directory_entries_count(kind, count as u64); + } + Ok(()) } @@ -1588,6 +1633,23 @@ struct SlruSegmentDirectory { segments: HashSet, } +#[derive(Copy, Clone, PartialEq, Eq, Debug, 
enum_map::Enum)] +#[repr(u8)] +pub(crate) enum DirectoryKind { + Db, + TwoPhase, + Rel, + AuxFiles, + SlruSegment(SlruKind), +} + +impl DirectoryKind { + pub(crate) const KINDS_NUM: usize = ::LENGTH; + pub(crate) fn offset(&self) -> usize { + self.into_usize() + } +} + static ZERO_PAGE: Bytes = Bytes::from_static(&[0u8; BLCKSZ as usize]); #[allow(clippy::bool_assert_comparison)] diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 625be7a644..87cf0ac6ea 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -14,6 +14,7 @@ use enumset::EnumSet; use fail::fail_point; use futures::stream::StreamExt; use itertools::Itertools; +use once_cell::sync::Lazy; use pageserver_api::{ keyspace::{key_range_size, KeySpaceAccum}, models::{ @@ -34,17 +35,22 @@ use tokio_util::sync::CancellationToken; use tracing::*; use utils::sync::gate::Gate; -use std::collections::{BTreeMap, BinaryHeap, HashMap, HashSet}; use std::ops::{Deref, Range}; use std::pin::pin; use std::sync::atomic::Ordering as AtomicOrdering; use std::sync::{Arc, Mutex, RwLock, Weak}; use std::time::{Duration, Instant, SystemTime}; +use std::{ + array, + collections::{BTreeMap, BinaryHeap, HashMap, HashSet}, + sync::atomic::AtomicU64, +}; use std::{ cmp::{max, min, Ordering}, ops::ControlFlow, }; +use crate::pgdatadir_mapping::DirectoryKind; use crate::tenant::timeline::logical_size::CurrentLogicalSize; use crate::tenant::{ layer_map::{LayerMap, SearchResult}, @@ -258,6 +264,8 @@ pub struct Timeline { // in `crate::page_service` writes these metrics. pub(crate) query_metrics: crate::metrics::SmgrQueryTimePerTimeline, + directory_metrics: [AtomicU64; DirectoryKind::KINDS_NUM], + /// Ensures layers aren't frozen by checkpointer between /// [`Timeline::get_layer_for_write`] and layer reads. /// Locked automatically by [`TimelineWriter`] and checkpointer. @@ -790,6 +798,10 @@ impl Timeline { self.metrics.resident_physical_size_get() } + pub(crate) fn get_directory_metrics(&self) -> [u64; DirectoryKind::KINDS_NUM] { + array::from_fn(|idx| self.directory_metrics[idx].load(AtomicOrdering::Relaxed)) + } + /// /// Wait until WAL has been received and processed up to this LSN. /// @@ -1496,6 +1508,8 @@ impl Timeline { &timeline_id, ), + directory_metrics: array::from_fn(|_| AtomicU64::new(0)), + flush_loop_state: Mutex::new(FlushLoopState::NotStarted), layer_flush_start_tx, @@ -2264,6 +2278,29 @@ impl Timeline { } } + pub(crate) fn update_directory_entries_count(&self, kind: DirectoryKind, count: u64) { + self.directory_metrics[kind.offset()].store(count, AtomicOrdering::Relaxed); + let aux_metric = + self.directory_metrics[DirectoryKind::AuxFiles.offset()].load(AtomicOrdering::Relaxed); + + let sum_of_entries = self + .directory_metrics + .iter() + .map(|v| v.load(AtomicOrdering::Relaxed)) + .sum(); + // Set a high general threshold and a lower threshold for the auxiliary files, + // as we can have large numbers of relations in the db directory. 
+ const SUM_THRESHOLD: u64 = 5000; + const AUX_THRESHOLD: u64 = 1000; + if sum_of_entries >= SUM_THRESHOLD || aux_metric >= AUX_THRESHOLD { + self.metrics + .directory_entries_count_gauge + .set(sum_of_entries); + } else if let Some(metric) = Lazy::get(&self.metrics.directory_entries_count_gauge) { + metric.set(sum_of_entries); + } + } + async fn find_layer(&self, layer_file_name: &str) -> Option { let guard = self.layers.read().await; for historic_layer in guard.layer_map().iter_historic_layers() { diff --git a/test_runner/fixtures/metrics.py b/test_runner/fixtures/metrics.py index ef41774289..418370c3ab 100644 --- a/test_runner/fixtures/metrics.py +++ b/test_runner/fixtures/metrics.py @@ -96,5 +96,6 @@ PAGESERVER_PER_TENANT_METRICS: Tuple[str, ...] = ( "pageserver_evictions_total", "pageserver_evictions_with_low_residence_duration_total", *PAGESERVER_PER_TENANT_REMOTE_TIMELINE_CLIENT_METRICS, + # "pageserver_directory_entries_count", -- only used if above a certain threshold # "pageserver_broken_tenants_count" -- used only for broken ) From a5114a99b275b52fc7a512e62a7f80a5a103433d Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Wed, 14 Feb 2024 10:34:58 +0200 Subject: [PATCH 17/20] Create a symlink from pg_dynshmem to /dev/shm See included comment and issue https://github.com/neondatabase/autoscaling/issues/800 for details. This has no effect, unless you set "dynamic_shared_memory_type = mmap" in postgresql.conf. --- compute_tools/src/compute.rs | 44 +++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index 993b5725a4..83db8e09ec 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::env; use std::fs; use std::io::BufRead; -use std::os::unix::fs::PermissionsExt; +use std::os::unix::fs::{symlink, PermissionsExt}; use std::path::Path; use std::process::{Command, Stdio}; use std::str::FromStr; @@ -634,6 +634,48 @@ impl ComputeNode { // Update pg_hba.conf received with basebackup. update_pg_hba(pgdata_path)?; + // Place pg_dynshmem under /dev/shm. This allows us to use + // 'dynamic_shared_memory_type = mmap' so that the files are placed in + // /dev/shm, similar to how 'dynamic_shared_memory_type = posix' works. + // + // Why on earth don't we just stick to the 'posix' default, you might + // ask. It turns out that making large allocations with 'posix' doesn't + // work very well with autoscaling. The behavior we want is that: + // + // 1. You can make large DSM allocations, larger than the current RAM + // size of the VM, without errors + // + // 2. If the allocated memory is really used, the VM is scaled up + // automatically to accommodate that + // + // We try to make that possible by having swap in the VM. But with the + // default 'posix' DSM implementation, we fail step 1, even when there's + // plenty of swap available. PostgreSQL uses posix_fallocate() to create + // the shmem segment, which is really just a file in /dev/shm in Linux, + // but posix_fallocate() on tmpfs returns ENOMEM if the size is larger + // than available RAM. + // + // Using 'dynamic_shared_memory_type = mmap' works around that, because + // the Postgres 'mmap' DSM implementation doesn't use + // posix_fallocate(). Instead, it uses repeated calls to write(2) to + // fill the file with zeros. It's weird that that differs between + // 'posix' and 'mmap', but we take advantage of it. 
When the file is + // filled slowly with write(2), the kernel allows it to grow larger, as + // long as there's swap available. + // + // In short, using 'dynamic_shared_memory_type = mmap' allows us one DSM + // segment to be larger than currently available RAM. But because we + // don't want to store it on a real file, which the kernel would try to + // flush to disk, so symlink pg_dynshm to /dev/shm. + // + // We don't set 'dynamic_shared_memory_type = mmap' here, we let the + // control plane control that option. If 'mmap' is not used, this + // symlink doesn't affect anything. + // + // See https://github.com/neondatabase/autoscaling/issues/800 + std::fs::remove_dir(pgdata_path.join("pg_dynshmem"))?; + symlink("/dev/shm/", pgdata_path.join("pg_dynshmem"))?; + match spec.mode { ComputeMode::Primary => {} ComputeMode::Replica | ComputeMode::Static(..) => { From a97b54e3b9e692532962d65b89b7e5f67a9c28a4 Mon Sep 17 00:00:00 2001 From: Heikki Linnakangas Date: Wed, 14 Feb 2024 10:35:59 +0200 Subject: [PATCH 18/20] Cherry-pick Postgres bugfix to 'mmap' DSM implementation Cherry-pick Upstream commit fbf9a7ac4d to neon stable branches. We'll get it in the next PostgreSQL minor release anyway, but we need it now, if we want to start using the 'mmap' implementation. See https://github.com/neondatabase/autoscaling/issues/800 for the plans on doing that. --- vendor/postgres-v14 | 2 +- vendor/postgres-v15 | 2 +- vendor/postgres-v16 | 2 +- vendor/revisions.json | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/vendor/postgres-v14 b/vendor/postgres-v14 index 018fb05201..9dd9956c55 160000 --- a/vendor/postgres-v14 +++ b/vendor/postgres-v14 @@ -1 +1 @@ -Subproject commit 018fb052011081dc2733d3118d12e5c36df6eba1 +Subproject commit 9dd9956c55ffbbd9abe77d10382453757fedfcf5 diff --git a/vendor/postgres-v15 b/vendor/postgres-v15 index 6ee78a3c29..ca2def9993 160000 --- a/vendor/postgres-v15 +++ b/vendor/postgres-v15 @@ -1 +1 @@ -Subproject commit 6ee78a3c29e33cafd85ba09568b6b5eb031d29b9 +Subproject commit ca2def999368d9df098a637234ad5a9003189463 diff --git a/vendor/postgres-v16 b/vendor/postgres-v16 index 550cdd26d4..9c37a49884 160000 --- a/vendor/postgres-v16 +++ b/vendor/postgres-v16 @@ -1 +1 @@ -Subproject commit 550cdd26d445afdd26b15aa93c8c2f3dc52f8361 +Subproject commit 9c37a4988463a97d9cacb321acf3828b09823269 diff --git a/vendor/revisions.json b/vendor/revisions.json index 91ebb8cb34..72bc0d7e0d 100644 --- a/vendor/revisions.json +++ b/vendor/revisions.json @@ -1,5 +1,5 @@ { - "postgres-v16": "550cdd26d445afdd26b15aa93c8c2f3dc52f8361", - "postgres-v15": "6ee78a3c29e33cafd85ba09568b6b5eb031d29b9", - "postgres-v14": "018fb052011081dc2733d3118d12e5c36df6eba1" + "postgres-v16": "9c37a4988463a97d9cacb321acf3828b09823269", + "postgres-v15": "ca2def999368d9df098a637234ad5a9003189463", + "postgres-v14": "9dd9956c55ffbbd9abe77d10382453757fedfcf5" } From a9ec4eb4fc7777a529ff8c5ede814dd657390e58 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Wed, 14 Feb 2024 10:26:32 +0000 Subject: [PATCH 19/20] hold cancel session (#6750) ## Problem In a recent refactor, we accidentally dropped the cancel session early ## Summary of changes Hold the cancel session during proxy passthrough --- proxy/src/proxy.rs | 1 + proxy/src/proxy/passthrough.rs | 2 ++ 2 files changed, 3 insertions(+) diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index ce77098a5f..8a9445303a 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -331,6 +331,7 @@ pub async fn handle_client( compute: node, req: 
_request_gauge, conn: _client_gauge, + cancel: session, })) } diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index c98f68d8d1..73c170fc0b 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -1,4 +1,5 @@ use crate::{ + cancellation, compute::PostgresConnection, console::messages::MetricsAuxInfo, metrics::NUM_BYTES_PROXIED_COUNTER, @@ -57,6 +58,7 @@ pub struct ProxyPassthrough { pub req: IntCounterPairGuard, pub conn: IntCounterPairGuard, + pub cancel: cancellation::Session, } impl ProxyPassthrough { From f39b0fce9b24a049208e74cc7d2a6b006b487839 Mon Sep 17 00:00:00 2001 From: John Spray Date: Wed, 14 Feb 2024 10:57:01 +0000 Subject: [PATCH 20/20] Revert #6666 "tests: try to make restored-datadir comparison tests not flaky" (#6751) The #6666 change appears to have made the test fail more often. PR https://github.com/neondatabase/neon/pull/6712 should re-instate this change, along with its change to make the overall flow more reliable. This reverts commit 568f91420a9c677e77aeb736cb3f995a85f0b106. --- test_runner/fixtures/neon_fixtures.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 26f2b999a6..04af73c327 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -3967,27 +3967,24 @@ def list_files_to_compare(pgdata_dir: Path) -> List[str]: # pg is the existing and running compute node, that we want to compare with a basebackup def check_restored_datadir_content(test_output_dir: Path, env: NeonEnv, endpoint: Endpoint): - pg_bin = PgBin(test_output_dir, env.pg_distrib_dir, env.pg_version) - # Get the timeline ID. We need it for the 'basebackup' command timeline_id = TimelineId(endpoint.safe_psql("SHOW neon.timeline_id")[0][0]) + # many tests already checkpoint, but do it just in case + with closing(endpoint.connect()) as conn: + with conn.cursor() as cur: + cur.execute("CHECKPOINT") + + # wait for pageserver to catch up + wait_for_last_flush_lsn(env, endpoint, endpoint.tenant_id, timeline_id) # stop postgres to ensure that files won't change endpoint.stop() - # Read the shutdown checkpoint's LSN - pg_controldata_path = os.path.join(pg_bin.pg_bin_path, "pg_controldata") - cmd = f"{pg_controldata_path} -D {endpoint.pgdata_dir}" - result = subprocess.run(cmd, capture_output=True, text=True, shell=True) - checkpoint_lsn = re.findall( - "Latest checkpoint location:\\s+([0-9A-F]+/[0-9A-F]+)", result.stdout - )[0] - log.debug(f"last checkpoint at {checkpoint_lsn}") - # Take a basebackup from pageserver restored_dir_path = env.repo_dir / f"{endpoint.endpoint_id}_restored_datadir" restored_dir_path.mkdir(exist_ok=True) + pg_bin = PgBin(test_output_dir, env.pg_distrib_dir, env.pg_version) psql_path = os.path.join(pg_bin.pg_bin_path, "psql") pageserver_id = env.attachment_service.locate(endpoint.tenant_id)[0]["node_id"] @@ -3995,7 +3992,7 @@ def check_restored_datadir_content(test_output_dir: Path, env: NeonEnv, endpoint {psql_path} \ --no-psqlrc \ postgres://localhost:{env.get_pageserver(pageserver_id).service_port.pg} \ - -c 'basebackup {endpoint.tenant_id} {timeline_id} {checkpoint_lsn}' \ + -c 'basebackup {endpoint.tenant_id} {timeline_id}' \ | tar -x -C {restored_dir_path} """
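Looping back to the cancel-session fix in patch 19 (#6750): the one-line change works because the session is closed by its destructor, so moving it into `ProxyPassthrough` ties its lifetime to the proxied connection rather than to the handshake function that created it. A minimal, self-contained sketch of that shape (illustrative types only, not the proxy's actual `cancellation::Session`):

```rust
// Dropping `Session` is what ends the cancellation registration; before the fix
// it was dropped as soon as the handshake returned.
struct Session(u64);

impl Drop for Session {
    fn drop(&mut self) {
        // The real proxy would deregister the cancel key here.
        println!("cancel session {} closed", self.0);
    }
}

struct ProxyPassthrough {
    cancel: Session, // held for as long as bytes are being proxied
}

fn handshake() -> ProxyPassthrough {
    let session = Session(42);
    // Returning the session inside the passthrough keeps it alive; leaving it as
    // a local variable here is the early-drop bug the patch fixes.
    ProxyPassthrough { cancel: session }
}

fn main() {
    let p = handshake();
    // ... proxy traffic ...
    drop(p); // only now does the cancel session close
}
```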