diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index e9f8092a29..bd495d2316 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -1,6 +1,6 @@ //! Main entry point for the Page Server executable. -use std::{env, path::Path, str::FromStr, thread}; +use std::{env, path::Path, str::FromStr}; use tracing::*; use zenith_utils::{auth::JwtAuth, logging, postgres_backend::AuthType, tcp_listener, GIT_VERSION}; @@ -12,7 +12,9 @@ use daemonize::Daemonize; use pageserver::{ branches, config::{defaults::*, PageServerConf}, - http, page_cache, page_service, remote_storage, tenant_mgr, virtual_file, LOG_FILE_NAME, + http, page_cache, page_service, remote_storage, tenant_mgr, thread_mgr, + thread_mgr::ThreadKind, + virtual_file, LOG_FILE_NAME, }; use zenith_utils::http::endpoint; use zenith_utils::postgres_backend; @@ -169,7 +171,7 @@ fn start_pageserver(conf: &'static PageServerConf, daemonize: bool) -> Result<() ); let pageserver_listener = tcp_listener::bind(conf.listen_pg_addr.clone())?; - // XXX: Don't spawn any threads before daemonizing! + // NB: Don't spawn any threads before daemonizing! if daemonize { info!("daemonizing..."); @@ -195,16 +197,9 @@ fn start_pageserver(conf: &'static PageServerConf, daemonize: bool) -> Result<() } let signals = signals::install_shutdown_handlers()?; - let (async_shutdown_tx, async_shutdown_rx) = tokio::sync::watch::channel(()); - let mut threads = Vec::new(); - - let sync_startup = remote_storage::start_local_timeline_sync(conf, async_shutdown_rx) + let sync_startup = remote_storage::start_local_timeline_sync(conf) .context("Failed to set up local files sync with external storage")?; - if let Some(handle) = sync_startup.sync_loop_handle { - threads.push(handle); - } - // Initialize tenant manager. tenant_mgr::set_timeline_states(conf, sync_startup.initial_timeline_states); @@ -221,25 +216,27 @@ fn start_pageserver(conf: &'static PageServerConf, daemonize: bool) -> Result<() // Spawn a new thread for the http endpoint // bind before launching separate thread so the error reported before startup exits - let cloned = auth.clone(); - threads.push( - thread::Builder::new() - .name("http_endpoint_thread".into()) - .spawn(move || { - let router = http::make_router(conf, cloned); - endpoint::serve_thread_main(router, http_listener) - })?, - ); + let auth_cloned = auth.clone(); + thread_mgr::spawn( + ThreadKind::HttpEndpointListener, + None, + None, + "http_endpoint_thread", + move || { + let router = http::make_router(conf, auth_cloned); + endpoint::serve_thread_main(router, http_listener, thread_mgr::shutdown_watcher()) + }, + )?; - // Spawn a thread to listen for connections. It will spawn further threads + // Spawn a thread to listen for libpq connections. It will spawn further threads // for each connection. - threads.push( - thread::Builder::new() - .name("Page Service thread".into()) - .spawn(move || { - page_service::thread_main(conf, auth, pageserver_listener, conf.auth_type) - })?, - ); + thread_mgr::spawn( + ThreadKind::LibpqEndpointListener, + None, + None, + "libpq endpoint thread", + move || page_service::thread_main(conf, auth, pageserver_listener, conf.auth_type), + )?; signals.handle(|signal| match signal { Signal::Quit => { @@ -255,21 +252,38 @@ fn start_pageserver(conf: &'static PageServerConf, daemonize: bool) -> Result<() "Got {}. 
Terminating gracefully in fast shutdown mode", signal.name() ); - - async_shutdown_tx.send(())?; - postgres_backend::set_pgbackend_shutdown_requested(); - tenant_mgr::shutdown_all_tenants()?; - endpoint::shutdown(); - - for handle in std::mem::take(&mut threads) { - handle - .join() - .expect("thread panicked") - .expect("thread exited with an error"); - } - - info!("Shut down successfully completed"); - std::process::exit(0); + shutdown_pageserver(); + unreachable!() } }) } + +fn shutdown_pageserver() { + // Shut down the libpq endpoint thread. This prevents new connections from + // being accepted. + thread_mgr::shutdown_threads(Some(ThreadKind::LibpqEndpointListener), None, None); + + // Shut down any page service threads. + postgres_backend::set_pgbackend_shutdown_requested(); + thread_mgr::shutdown_threads(Some(ThreadKind::PageRequestHandler), None, None); + + // Shut down all the tenants. This flushes everything to disk and kills + // the checkpoint and GC threads. + tenant_mgr::shutdown_all_tenants(); + + // Stop syncing with remote storage. + // + // FIXME: Does this wait for the sync thread to finish syncing what's queued up? + // Should it? + thread_mgr::shutdown_threads(Some(ThreadKind::StorageSync), None, None); + + // Shut down the HTTP endpoint last, so that you can still check the server's + // status while it's shutting down. + thread_mgr::shutdown_threads(Some(ThreadKind::HttpEndpointListener), None, None); + + // There should be nothing left, but let's be sure + thread_mgr::shutdown_threads(None, None, None); + + info!("Shut down successfully completed"); + std::process::exit(0); +} diff --git a/pageserver/src/layered_repository.rs b/pageserver/src/layered_repository.rs index 71a34d0157..5b8ad61c5f 100644 --- a/pageserver/src/layered_repository.rs +++ b/pageserver/src/layered_repository.rs @@ -40,9 +40,8 @@ use crate::repository::{ BlockNumber, GcResult, Repository, RepositoryTimeline, Timeline, TimelineSyncState, TimelineWriter, ZenithWalRecord, }; -use crate::tenant_mgr; +use crate::thread_mgr; use crate::virtual_file::VirtualFile; -use crate::walreceiver; use crate::walreceiver::IS_WAL_RECEIVER; use crate::walredo::WalRedoManager; use crate::CheckpointConfig; @@ -286,19 +285,7 @@ impl Repository for LayeredRepository { Ok(()) } - // Wait for all threads to complete and persist repository data before pageserver shutdown. - fn shutdown(&self) -> Result<()> { - trace!("LayeredRepository shutdown for tenant {}", self.tenantid); - - let timelines = self.timelines.lock().unwrap(); - for (timelineid, timeline) in timelines.iter() { - shutdown_timeline(self.tenantid, *timelineid, timeline)?; - } - - Ok(()) - } - - // TODO this method currentlly does not do anything to prevent (or react to) state updates between a sync task schedule and a sync task end (that causes this update). + // TODO this method currently does not do anything to prevent (or react to) state updates between a sync task schedule and a sync task end (that causes this update). // Sync task is enqueued and can error and be rescheduled, so some significant time may pass between the events. // /// Reacts on the timeline sync state change, changing pageserver's memory state for this timeline (unload or load of the timeline files). 
@@ -309,7 +296,7 @@ impl Repository for LayeredRepository { ) -> Result<()> { let mut timelines_accessor = self.timelines.lock().unwrap(); - let timeline_to_shutdown = match new_state { + match new_state { TimelineSyncState::Ready(_) => { let reloaded_timeline = self.init_local_timeline(timeline_id, &mut timelines_accessor)?; @@ -329,10 +316,6 @@ impl Repository for LayeredRepository { }; drop(timelines_accessor); - if let Some(timeline) = timeline_to_shutdown { - shutdown_timeline(self.tenantid, timeline_id, &timeline)?; - } - Ok(()) } @@ -358,30 +341,6 @@ impl Repository for LayeredRepository { } } -fn shutdown_timeline( - tenant_id: ZTenantId, - timeline_id: ZTimelineId, - timeline: &LayeredTimelineEntry, -) -> Result<(), anyhow::Error> { - match timeline { - LayeredTimelineEntry::Local(timeline) => { - timeline - .upload_relishes - .store(false, atomic::Ordering::Relaxed); - walreceiver::stop_wal_receiver(tenant_id, timeline_id); - trace!("repo shutdown. checkpoint timeline {}", timeline_id); - // Do not reconstruct pages to reduce shutdown time - timeline.checkpoint(CheckpointConfig::Flush)?; - //TODO Wait for walredo process to shutdown too - } - LayeredTimelineEntry::Remote { .. } => warn!( - "Skipping shutdown of a remote timeline {} for tenant {}", - timeline_id, tenant_id - ), - } - Ok(()) -} - #[derive(Clone)] enum LayeredTimelineEntry { Local(Arc), @@ -652,8 +611,10 @@ impl LayeredRepository { // Ok, we now know all the branch points. // Perform GC for each timeline. for timelineid in timelineids { - if tenant_mgr::shutdown_requested() { - return Ok(totals); + if thread_mgr::is_shutdown_requested() { + // We were requested to shut down. Stop and return with the progress we + // made. + break; } // We have already loaded all timelines above diff --git a/pageserver/src/lib.rs b/pageserver/src/lib.rs index 23691ea130..3a68f56187 100644 --- a/pageserver/src/lib.rs +++ b/pageserver/src/lib.rs @@ -11,6 +11,7 @@ pub mod remote_storage; pub mod repository; pub mod tenant_mgr; pub mod tenant_threads; +pub mod thread_mgr; pub mod virtual_file; pub mod walingest; pub mod walreceiver; diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index c6de34b839..70ba7ec927 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -14,12 +14,11 @@ use anyhow::{anyhow, bail, ensure, Context, Result}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use lazy_static::lazy_static; use regex::Regex; +use std::io; use std::net::TcpListener; use std::str; use std::str::FromStr; use std::sync::Arc; -use std::thread; -use std::{io, net::TcpStream}; use tracing::*; use zenith_metrics::{register_histogram_vec, HistogramVec}; use zenith_utils::auth::{self, JwtAuth}; @@ -39,6 +38,8 @@ use crate::config::PageServerConf; use crate::relish::*; use crate::repository::Timeline; use crate::tenant_mgr; +use crate::thread_mgr; +use crate::thread_mgr::ThreadKind; use crate::walreceiver; use crate::CheckpointConfig; @@ -189,36 +190,61 @@ pub fn thread_main( listener: TcpListener, auth_type: AuthType, ) -> anyhow::Result<()> { - let mut join_handles = Vec::new(); + listener.set_nonblocking(true)?; + let basic_rt = tokio::runtime::Builder::new_current_thread() + .enable_io() + .build()?; - while !tenant_mgr::shutdown_requested() { - let (socket, peer_addr) = listener.accept()?; - debug!("accepted connection from {}", peer_addr); - let local_auth = auth.clone(); + let tokio_listener = { + let _guard = basic_rt.enter(); + tokio::net::TcpListener::from_std(listener) + }?; - match 
thread::Builder::new() - .name("serving Page Service thread".into()) - .spawn(move || { - if let Err(err) = page_service_conn_main(conf, local_auth, socket, auth_type) { - error!("page server thread exited with error: {:?}", err); + // Wait for a new connection to arrive, or for server shutdown. + while let Some(res) = basic_rt.block_on(async { + let shutdown_watcher = thread_mgr::shutdown_watcher(); + tokio::select! { + biased; + + _ = shutdown_watcher => { + // We were requested to shut down. + None + } + + res = tokio_listener.accept() => { + Some(res) + } + } + }) { + match res { + Ok((socket, peer_addr)) => { + // Connection established. Spawn a new thread to handle it. + debug!("accepted connection from {}", peer_addr); + let local_auth = auth.clone(); + + // PageRequestHandler threads are not associated with any particular + // timeline in the thread manager. In practice most connections will + // only deal with a particular timeline, but we don't know which one + // yet. + if let Err(err) = thread_mgr::spawn( + ThreadKind::PageRequestHandler, + None, + None, + "serving Page Service thread", + move || page_service_conn_main(conf, local_auth, socket, auth_type), + ) { + // Thread creation failed. Log the error and continue. + error!("could not spawn page service thread: {:?}", err); } - }) { - Ok(handle) => { - // FIXME: There is no mechanism to remove the handle from this list - // when a thread exits - join_handles.push(handle); } Err(err) => { - // Thread creation failed. Log the error and continue. - error!(%err, "could not spawn page service thread"); + // accept() failed. Log the error, and loop back to retry on next connection. + error!("accept() failed: {:?}", err); } } } - debug!("page_service loop terminated. wait for connections to cancel"); - for handle in join_handles.into_iter() { - handle.join().unwrap(); - } + debug!("page_service loop terminated"); Ok(()) } @@ -226,10 +252,10 @@ pub fn thread_main( fn page_service_conn_main( conf: &'static PageServerConf, auth: Option>, - socket: TcpStream, + socket: tokio::net::TcpStream, auth_type: AuthType, ) -> anyhow::Result<()> { - // Immediatsely increment the gauge, then create a job to decrement it on thread exit. + // Immediately increment the gauge, then create a job to decrement it on thread exit. // One of the pros of `defer!` is that this will *most probably* // get called, even in presence of panics. let gauge = crate::LIVE_CONNECTIONS_COUNT.with_label_values(&["page_service"]); @@ -238,6 +264,15 @@ fn page_service_conn_main( gauge.dec(); } + // We use Tokio to accept the connection, but the rest of the code works with a + // regular socket. Convert. 
+ let socket = socket + .into_std() + .context("could not convert tokio::net:TcpStream to std::net::TcpStream")?; + socket + .set_nonblocking(false) + .context("could not put socket to blocking mode")?; + socket .set_nodelay(true) .context("could not set TCP_NODELAY")?; @@ -296,7 +331,7 @@ impl PageServerHandler { /* switch client to COPYBOTH */ pgb.write_message(&BeMessage::CopyBothResponse)?; - while !tenant_mgr::shutdown_requested() { + while !thread_mgr::is_shutdown_requested() { match pgb.read_message() { Ok(message) => { if let Some(message) = message { diff --git a/pageserver/src/remote_storage.rs b/pageserver/src/remote_storage.rs index 5d53d703ec..6c948aaeb2 100644 --- a/pageserver/src/remote_storage.rs +++ b/pageserver/src/remote_storage.rs @@ -89,11 +89,10 @@ use std::{ collections::HashMap, ffi, fs, path::{Path, PathBuf}, - thread, }; use anyhow::{bail, Context}; -use tokio::{io, sync}; +use tokio::io; use tracing::{error, info}; use zenith_utils::zid::{ZTenantId, ZTimelineId}; @@ -125,8 +124,6 @@ pub struct SyncStartupData { /// To reuse the local file scan logic, the timeline states are returned even if no sync loop get started during init: /// in this case, no remote files exist and all local timelines with correct metadata files are considered ready. pub initial_timeline_states: HashMap>, - /// A handle to the sync loop, if it was started from the configuration provided. - pub sync_loop_handle: Option>>, } /// Based on the config, initiates the remote storage connection and starts a separate thread @@ -135,7 +132,6 @@ pub struct SyncStartupData { /// Along with that, scans tenant files local and remote (if the sync gets enabled) to check the initial timeline states. pub fn start_local_timeline_sync( config: &'static PageServerConf, - shutdown_hook: sync::watch::Receiver<()>, ) -> anyhow::Result { let local_timeline_files = local_tenant_timeline_files(config) .context("Failed to collect local tenant timeline files")?; @@ -143,7 +139,6 @@ pub fn start_local_timeline_sync( match &config.remote_storage_config { Some(storage_config) => match &storage_config.storage { RemoteStorageKind::LocalFs(root) => storage_sync::spawn_storage_sync_thread( - shutdown_hook, config, local_timeline_files, LocalFs::new(root.clone(), &config.workdir)?, @@ -151,7 +146,6 @@ pub fn start_local_timeline_sync( storage_config.max_sync_errors, ), RemoteStorageKind::AwsS3(s3_config) => storage_sync::spawn_storage_sync_thread( - shutdown_hook, config, local_timeline_files, S3::new(s3_config, &config.workdir)?, @@ -179,7 +173,6 @@ pub fn start_local_timeline_sync( } Ok(SyncStartupData { initial_timeline_states, - sync_loop_handle: None, }) } } diff --git a/pageserver/src/remote_storage/storage_sync.rs b/pageserver/src/remote_storage/storage_sync.rs index 5fc99f8228..9eab337d81 100644 --- a/pageserver/src/remote_storage/storage_sync.rs +++ b/pageserver/src/remote_storage/storage_sync.rs @@ -80,7 +80,6 @@ use std::{ num::{NonZeroU32, NonZeroUsize}, path::{Path, PathBuf}, sync::Arc, - thread, }; use anyhow::{bail, Context}; @@ -91,7 +90,6 @@ use tokio::{ runtime::Runtime, sync::{ mpsc::{self, UnboundedReceiver}, - watch::Receiver, RwLock, }, time::{Duration, Instant}, @@ -111,7 +109,7 @@ use super::{RemoteStorage, SyncStartupData, TimelineSyncId}; use crate::{ config::PageServerConf, layered_repository::metadata::TimelineMetadata, remote_storage::storage_sync::compression::read_archive_header, repository::TimelineSyncState, - tenant_mgr::set_timeline_states, + tenant_mgr::set_timeline_states, thread_mgr, 
thread_mgr::ThreadKind, }; use zenith_metrics::{register_histogram_vec, register_int_gauge, HistogramVec, IntGauge}; @@ -351,7 +349,6 @@ pub(super) fn spawn_storage_sync_thread< P: std::fmt::Debug + Send + Sync + 'static, S: RemoteStorage + Send + Sync + 'static, >( - shutdown_hook: Receiver<()>, conf: &'static PageServerConf, local_timeline_files: HashMap)>, storage: S, @@ -385,12 +382,14 @@ pub(super) fn spawn_storage_sync_thread< let initial_timeline_states = schedule_first_sync_tasks(&remote_index, local_timeline_files); - let handle = thread::Builder::new() - .name("Remote storage sync thread".to_string()) - .spawn(move || { + thread_mgr::spawn( + ThreadKind::StorageSync, + None, + None, + "Remote storage sync thread", + move || { storage_sync_loop( runtime, - shutdown_hook, conf, receiver, remote_index, @@ -398,11 +397,11 @@ pub(super) fn spawn_storage_sync_thread< max_concurrent_sync, max_sync_errors, ) - }) - .context("Failed to spawn remote storage sync thread")?; + }, + ) + .context("Failed to spawn remote storage sync thread")?; Ok(SyncStartupData { initial_timeline_states, - sync_loop_handle: Some(handle), }) } @@ -417,7 +416,6 @@ fn storage_sync_loop< S: RemoteStorage + Send + Sync + 'static, >( runtime: Runtime, - mut shutdown_hook: Receiver<()>, conf: &'static PageServerConf, mut receiver: UnboundedReceiver, index: RemoteTimelineIndex, @@ -437,7 +435,7 @@ fn storage_sync_loop< max_sync_errors, ) .instrument(debug_span!("storage_sync_loop_step")) => LoopStep::NewStates(new_timeline_states), - _ = shutdown_hook.changed() => LoopStep::Shutdown, + _ = thread_mgr::shutdown_watcher() => LoopStep::Shutdown, } }); diff --git a/pageserver/src/repository.rs b/pageserver/src/repository.rs index 8ddc2b4952..8ffee4d17c 100644 --- a/pageserver/src/repository.rs +++ b/pageserver/src/repository.rs @@ -19,8 +19,6 @@ pub type BlockNumber = u32; /// A repository corresponds to one .zenith directory. One repository holds multiple /// timelines, forked off from the same initial call to 'initdb'. pub trait Repository: Send + Sync { - fn shutdown(&self) -> Result<()>; - /// Updates timeline based on the new sync state, received from the remote storage synchronization. /// See [`crate::remote_storage`] for more details about the synchronization. fn set_timeline_state( diff --git a/pageserver/src/tenant_mgr.rs b/pageserver/src/tenant_mgr.rs index 368e2fdbad..7b8e8fe373 100644 --- a/pageserver/src/tenant_mgr.rs +++ b/pageserver/src/tenant_mgr.rs @@ -5,15 +5,16 @@ use crate::branches; use crate::config::PageServerConf; use crate::layered_repository::LayeredRepository; use crate::repository::{Repository, Timeline, TimelineSyncState}; -use crate::tenant_threads; +use crate::thread_mgr; +use crate::thread_mgr::ThreadKind; use crate::walredo::PostgresRedoManager; +use crate::CheckpointConfig; use anyhow::{anyhow, bail, Context, Result}; use lazy_static::lazy_static; use log::*; use serde::{Deserialize, Serialize}; use std::collections::{hash_map, HashMap}; use std::fmt; -use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex, MutexGuard}; use zenith_utils::zid::{ZTenantId, ZTimelineId}; @@ -23,7 +24,7 @@ lazy_static! 
{ struct Tenant { state: TenantState, - repo: Option>, + repo: Arc, } #[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq)] @@ -56,8 +57,6 @@ fn access_tenants() -> MutexGuard<'static, HashMap> { TENANTS.lock().unwrap() } -static SHUTDOWN_REQUESTED: AtomicBool = AtomicBool::new(false); - /// Updates tenants' repositories, changing their timelines state in memory. pub fn set_timeline_states( conf: &'static PageServerConf, @@ -73,28 +72,7 @@ pub fn set_timeline_states( let mut m = access_tenants(); for (tenant_id, timeline_states) in timeline_states { - let tenant = m.entry(tenant_id).or_insert_with(|| Tenant { - state: TenantState::Idle, - repo: None, - }); - if let Err(e) = put_timelines_into_tenant(conf, tenant, tenant_id, timeline_states) { - error!( - "Failed to update timeline states for tenant {}: {:?}", - tenant_id, e - ); - } - } -} - -fn put_timelines_into_tenant( - conf: &'static PageServerConf, - tenant: &mut Tenant, - tenant_id: ZTenantId, - timeline_states: HashMap, -) -> anyhow::Result<()> { - let repo = match tenant.repo.as_ref() { - Some(repo) => Arc::clone(repo), - None => { + let tenant = m.entry(tenant_id).or_insert_with(|| { // Set up a WAL redo manager, for applying WAL records. let walredo_mgr = PostgresRedoManager::new(conf, tenant_id); @@ -105,13 +83,43 @@ fn put_timelines_into_tenant( tenant_id, conf.remote_storage_config.is_some(), )); - tenant.repo = Some(Arc::clone(&repo)); - repo + Tenant { + state: TenantState::Idle, + repo, + } + }); + if let Err(e) = put_timelines_into_tenant(tenant, tenant_id, timeline_states) { + error!( + "Failed to update timeline states for tenant {}: {:?}", + tenant_id, e + ); } - }; + } +} +fn put_timelines_into_tenant( + tenant: &mut Tenant, + tenant_id: ZTenantId, + timeline_states: HashMap, +) -> anyhow::Result<()> { for (timeline_id, timeline_state) in timeline_states { - repo.set_timeline_state(timeline_id, timeline_state) + // If the timeline is being put into any other state than Ready, + // stop any threads operating on it. + // + // FIXME: This is racy. A page service thread could just get + // handle on the Timeline, before we call set_timeline_state() + if !matches!(timeline_state, TimelineSyncState::Ready(_)) { + thread_mgr::shutdown_threads(None, Some(tenant_id), Some(timeline_id)); + + // Should we run a final checkpoint to flush all the data to + // disk? Doesn't seem necessary; all of the states other than + // Ready imply that the data on local disk is corrupt or incomplete, + // and we don't want to flush that to disk. + } + + tenant + .repo + .set_timeline_state(timeline_id, timeline_state) .with_context(|| { format!( "Failed to update timeline {} state to {:?}", @@ -123,29 +131,49 @@ fn put_timelines_into_tenant( Ok(()) } -// Check this flag in the thread loops to know when to exit -pub fn shutdown_requested() -> bool { - SHUTDOWN_REQUESTED.load(Ordering::Relaxed) -} - -pub fn shutdown_all_tenants() -> Result<()> { - SHUTDOWN_REQUESTED.swap(true, Ordering::Relaxed); - - let tenantids = list_tenantids()?; - - for tenantid in &tenantids { - set_tenant_state(*tenantid, TenantState::Stopping)?; +/// +/// Shut down all tenants. This runs as part of pageserver shutdown. 
+/// +pub fn shutdown_all_tenants() { + let mut m = access_tenants(); + let mut tenantids = Vec::new(); + for (tenantid, tenant) in m.iter_mut() { + tenant.state = TenantState::Stopping; + tenantids.push(*tenantid) } + drop(m); + thread_mgr::shutdown_threads(Some(ThreadKind::WalReceiver), None, None); + thread_mgr::shutdown_threads(Some(ThreadKind::GarbageCollector), None, None); + thread_mgr::shutdown_threads(Some(ThreadKind::Checkpointer), None, None); + + // Ok, no background threads running anymore. Flush any remaining data in + // memory to disk. + // + // We assume that any incoming connections that might request pages from + // the repository have already been terminated by the caller, so there + // should be no more activity in any of the repositories. + // + // On error, log it but continue with the shutdown for other tenants. for tenantid in tenantids { - // Wait for checkpointer and GC to finish their job - tenant_threads::wait_for_tenant_threads_to_stop(tenantid); - - let repo = get_repository_for_tenant(tenantid)?; debug!("shutdown tenant {}", tenantid); - repo.shutdown()?; + match get_repository_for_tenant(tenantid) { + Ok(repo) => { + if let Err(err) = repo.checkpoint_iteration(CheckpointConfig::Flush) { + error!( + "Could not checkpoint tenant {} during shutdown: {:?}", + tenantid, err + ); + } + } + Err(err) => { + error!( + "Could not get repository for tenant {} during shutdown: {:?}", + tenantid, err + ); + } + } } - Ok(()) } pub fn create_repository_for_tenant( @@ -153,7 +181,7 @@ pub fn create_repository_for_tenant( tenantid: ZTenantId, ) -> Result<()> { let wal_redo_manager = Arc::new(PostgresRedoManager::new(conf, tenantid)); - let repo = Some(branches::create_repo(conf, tenantid, wal_redo_manager)?); + let repo = branches::create_repo(conf, tenantid, wal_redo_manager)?; match access_tenants().entry(tenantid) { hash_map::Entry::Occupied(_) => bail!("tenant {} already exists", tenantid), @@ -172,22 +200,51 @@ pub fn get_tenant_state(tenantid: ZTenantId) -> Option { Some(access_tenants().get(&tenantid)?.state) } -pub fn set_tenant_state(tenantid: ZTenantId, newstate: TenantState) -> Result { +/// +/// Change the state of a tenant to Active and launch its checkpointer and GC +/// threads. If the tenant was already in Active state or Stopping, does nothing. +/// +pub fn activate_tenant(conf: &'static PageServerConf, tenantid: ZTenantId) -> Result<()> { let mut m = access_tenants(); - let tenant = m.get_mut(&tenantid); + let tenant = m + .get_mut(&tenantid) + .ok_or_else(|| anyhow!("Tenant not found for id {}", tenantid))?; - match tenant { - Some(tenant) => { - if newstate == TenantState::Idle && tenant.state != TenantState::Active { - // Only Active tenant can become Idle - return Ok(tenant.state); - } - info!("set_tenant_state: {} -> {}", tenant.state, newstate); - tenant.state = newstate; - Ok(tenant.state) + info!("activating tenant {}", tenantid); + + match tenant.state { + // If the tenant is already active, nothing to do. + TenantState::Active => {} + + // If it's Idle, launch the checkpointer and GC threads + TenantState::Idle => { + thread_mgr::spawn( + ThreadKind::Checkpointer, + Some(tenantid), + None, + "Checkpointer thread", + move || crate::tenant_threads::checkpoint_loop(tenantid, conf), + )?; + + // FIXME: if we fail to launch the GC thread, but already launched the + // checkpointer, we're in a strange state. 
+ + thread_mgr::spawn( + ThreadKind::GarbageCollector, + Some(tenantid), + None, + "GC thread", + move || crate::tenant_threads::gc_loop(tenantid, conf), + )?; + + tenant.state = TenantState::Active; + } + + TenantState::Stopping => { + // don't re-activate it if it's being stopped } - None => bail!("Tenant not found for id {}", tenantid), } + Ok(()) } pub fn get_repository_for_tenant(tenantid: ZTenantId) -> Result> { @@ -196,10 +253,7 @@ pub fn get_repository_for_tenant(tenantid: ZTenantId) -> Result Ok(Arc::clone(repo)), - None => bail!("Repository for tenant {} is not yet valid", tenantid), - } + Ok(Arc::clone(&tenant.repo)) } pub fn get_timeline_for_tenant( @@ -212,16 +266,6 @@ pub fn get_timeline_for_tenant( .ok_or_else(|| anyhow!("cannot fetch timeline {}", timelineid)) } -fn list_tenantids() -> Result> { - access_tenants() - .iter() - .map(|v| { - let (tenantid, _) = v; - Ok(*tenantid) - }) - .collect() -} - #[derive(Serialize, Deserialize, Clone)] pub struct TenantInfo { #[serde(with = "hex")] diff --git a/pageserver/src/tenant_threads.rs b/pageserver/src/tenant_threads.rs index afcd313ea1..062af9f1ad 100644 --- a/pageserver/src/tenant_threads.rs +++ b/pageserver/src/tenant_threads.rs @@ -5,88 +5,14 @@ use crate::tenant_mgr; use crate::tenant_mgr::TenantState; use crate::CheckpointConfig; use anyhow::Result; -use lazy_static::lazy_static; -use std::collections::HashMap; -use std::sync::Mutex; -use std::thread::JoinHandle; use std::time::Duration; use tracing::*; -use zenith_metrics::{register_int_gauge_vec, IntGaugeVec}; use zenith_utils::zid::ZTenantId; -struct TenantHandleEntry { - checkpointer_handle: Option>, - gc_handle: Option>, -} - -// Preserve handles to wait for thread completion -// at shutdown -lazy_static! { - static ref TENANT_HANDLES: Mutex> = - Mutex::new(HashMap::new()); -} - -lazy_static! { - static ref TENANT_THREADS_COUNT: IntGaugeVec = register_int_gauge_vec!( - "tenant_threads_count", - "Number of live tenant threads", - &["tenant_thread_type"] - ) - .expect("failed to define a metric"); -} - -// Launch checkpointer and GC for the tenant. -// It's possible that the threads are running already, -// if so, just don't spawn new ones. 
-pub fn start_tenant_threads(conf: &'static PageServerConf, tenantid: ZTenantId) { - let mut handles = TENANT_HANDLES.lock().unwrap(); - let h = handles - .entry(tenantid) - .or_insert_with(|| TenantHandleEntry { - checkpointer_handle: None, - gc_handle: None, - }); - - if h.checkpointer_handle.is_none() { - h.checkpointer_handle = std::thread::Builder::new() - .name("Checkpointer thread".into()) - .spawn(move || { - checkpoint_loop(tenantid, conf).expect("Checkpointer thread died"); - }) - .ok(); - } - - if h.gc_handle.is_none() { - h.gc_handle = std::thread::Builder::new() - .name("GC thread".into()) - .spawn(move || { - gc_loop(tenantid, conf).expect("GC thread died"); - }) - .ok(); - } -} - -pub fn wait_for_tenant_threads_to_stop(tenantid: ZTenantId) { - let mut handles = TENANT_HANDLES.lock().unwrap(); - if let Some(h) = handles.get_mut(&tenantid) { - h.checkpointer_handle.take().map(JoinHandle::join); - trace!("checkpointer for tenant {} has stopped", tenantid); - h.gc_handle.take().map(JoinHandle::join); - trace!("gc for tenant {} has stopped", tenantid); - } - handles.remove(&tenantid); -} - /// /// Checkpointer thread's main loop /// -fn checkpoint_loop(tenantid: ZTenantId, conf: &'static PageServerConf) -> Result<()> { - let gauge = TENANT_THREADS_COUNT.with_label_values(&["checkpointer"]); - gauge.inc(); - scopeguard::defer! { - gauge.dec(); - } - +pub fn checkpoint_loop(tenantid: ZTenantId, conf: &'static PageServerConf) -> Result<()> { loop { if tenant_mgr::get_tenant_state(tenantid) != Some(TenantState::Active) { break; @@ -112,13 +38,7 @@ fn checkpoint_loop(tenantid: ZTenantId, conf: &'static PageServerConf) -> Result /// /// GC thread's main loop /// -fn gc_loop(tenantid: ZTenantId, conf: &'static PageServerConf) -> Result<()> { - let gauge = TENANT_THREADS_COUNT.with_label_values(&["gc"]); - gauge.inc(); - scopeguard::defer! { - gauge.dec(); - } - +pub fn gc_loop(tenantid: ZTenantId, conf: &'static PageServerConf) -> Result<()> { loop { if tenant_mgr::get_tenant_state(tenantid) != Some(TenantState::Active) { break; diff --git a/pageserver/src/thread_mgr.rs b/pageserver/src/thread_mgr.rs new file mode 100644 index 0000000000..a51f0909ca --- /dev/null +++ b/pageserver/src/thread_mgr.rs @@ -0,0 +1,284 @@ +//! +//! This module provides centralized handling of threads in the Page Server. +//! +//! We provide a few basic facilities: +//! - A global registry of threads that lists what kind of threads they are, and +//! which tenant or timeline they are working on +//! +//! - The ability to request a thread to shut down. +//! +//! +//! # How it works? +//! +//! There is a global hashmap of all the threads (`THREADS`). Whenever a new +//! thread is spawned, a PageServerThread entry is added there, and when a +//! thread dies, it removes itself from the hashmap. If you want to kill a +//! thread, you can scan the hashmap to find it. +//! +//! # Thread shutdown +//! +//! To kill a thread, we rely on co-operation from the victim. Each thread is +//! expected to periodically call the `is_shutdown_requested()` function, and +//! if it returns true, exit gracefully. In addition to that, when waiting for +//! the network or other long-running operation, you can use +//! `shutdown_watcher()` function to get a Future that will become ready if +//! the current thread has been requested to shut down. You can use that with +//! Tokio select!(), but note that it relies on thread-local storage, so it +//! will only work with the "current-thread" Tokio runtime! +//! +//! +//! 
TODO: This would be a good place to also handle panics in a somewhat sane way. +//! Depending on what thread panics, we might want to kill the whole server, or +//! only a single tenant or timeline. +//! + +use std::cell::RefCell; +use std::collections::HashMap; +use std::panic; +use std::panic::AssertUnwindSafe; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; +use std::thread; +use std::thread::JoinHandle; + +use tokio::sync::watch; + +use tracing::{info, warn}; + +use lazy_static::lazy_static; + +use zenith_utils::zid::{ZTenantId, ZTimelineId}; + +lazy_static! { + /// Each thread that we track is associated with a "thread ID". It's just + /// an increasing number that we assign, not related to any system thread + /// id. + static ref NEXT_THREAD_ID: AtomicU64 = AtomicU64::new(1); + + /// Global registry of threads + static ref THREADS: Mutex>> = Mutex::new(HashMap::new()); +} + +// There is a Tokio watch channel for each thread, which can be used to signal the +// thread that it needs to shut down. This thread local variable holds the receiving +// end of the channel. The sender is kept in the global registry, so that anyone +// can send the signal to request thread shutdown. +thread_local!(static SHUTDOWN_RX: RefCell>> = RefCell::new(None)); + +// Each thread holds reference to its own PageServerThread here. +thread_local!(static CURRENT_THREAD: RefCell>> = RefCell::new(None)); + +/// +/// There are many kinds of threads in the system. Some are associated with a particular +/// tenant or timeline, while others are global. +/// +/// Note that we don't try to limit how may threads of a certain kind can be running +/// at the same time. +/// +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum ThreadKind { + // libpq listener thread. It just accepts connection and spawns a + // PageRequestHandler thread for each connection. + LibpqEndpointListener, + + // HTTP endpoint listener. + HttpEndpointListener, + + // Thread that handles a single connection. A PageRequestHandler thread + // starts detached from any particular tenant or timeline, but it can be + // associated with one later, after receiving a command from the client. + PageRequestHandler, + + // Thread that connects to a safekeeper to fetch WAL for one timeline. + WalReceiver, + + // Thread that handles checkpointing of all timelines for a tenant. + Checkpointer, + + // Thread that handles GC of a tenant + GarbageCollector, + + // Thread for synchronizing pageserver relish data with the remote storage. + // Shared by all tenants. + StorageSync, +} + +struct PageServerThread { + _thread_id: u64, + + kind: ThreadKind, + + /// Tenant and timeline that this thread is associated with. + tenant_id: Option, + timeline_id: Option, + + name: String, + + // To request thread shutdown, set the flag, and send a dummy message to the + // channel to notify it. + shutdown_requested: AtomicBool, + shutdown_tx: watch::Sender<()>, + + /// Handle for waiting for the thread to exit. It can be None, if the + /// the thread has already exited. 
+ join_handle: Mutex>>, +} + +/// Launch a new thread +pub fn spawn( + kind: ThreadKind, + tenant_id: Option, + timeline_id: Option, + name: &str, + f: F, +) -> std::io::Result<()> +where + F: FnOnce() -> Result<(), E> + Send + 'static, +{ + let (shutdown_tx, shutdown_rx) = watch::channel(()); + let thread_id = NEXT_THREAD_ID.fetch_add(1, Ordering::Relaxed); + let thread = PageServerThread { + _thread_id: thread_id, + kind, + tenant_id, + timeline_id, + name: name.to_string(), + + shutdown_requested: AtomicBool::new(false), + shutdown_tx, + + join_handle: Mutex::new(None), + }; + + let thread_rc = Arc::new(thread); + + let mut jh_guard = thread_rc.join_handle.lock().unwrap(); + + THREADS + .lock() + .unwrap() + .insert(thread_id, Arc::clone(&thread_rc)); + + let thread_rc2 = Arc::clone(&thread_rc); + let join_handle = match thread::Builder::new() + .name(name.to_string()) + .spawn(move || thread_wrapper(thread_id, thread_rc2, shutdown_rx, f)) + { + Ok(handle) => handle, + Err(err) => { + // Could not spawn the thread. Remove the entry + THREADS.lock().unwrap().remove(&thread_id); + return Err(err); + } + }; + *jh_guard = Some(join_handle); + drop(jh_guard); + + // The thread is now running. Nothing more to do here + Ok(()) +} + +/// This wrapper function runs in a newly-spawned thread. It initializes the +/// thread-local variables and calls the payload function +fn thread_wrapper( + thread_id: u64, + thread: Arc, + shutdown_rx: watch::Receiver<()>, + f: F, +) where + F: FnOnce() -> Result<(), E> + Send + 'static, +{ + SHUTDOWN_RX.with(|rx| { + *rx.borrow_mut() = Some(shutdown_rx); + }); + CURRENT_THREAD.with(|ct| { + *ct.borrow_mut() = Some(thread); + }); + + // We use AssertUnwindSafe here so that the payload function + // doesn't need to be UnwindSafe. We don't do anything after the + // unwinding that would expose us to unwind-unsafe behavior. + let result = panic::catch_unwind(AssertUnwindSafe(f)); + + // Remove our entry from the global hashmap. + THREADS.lock().unwrap().remove(&thread_id); + + // If the thread payload panic'd, exit with the panic. + if let Err(err) = result { + panic::resume_unwind(err); + } +} + +/// Is there a thread running that matches the criteria + +/// Signal and wait for threads to shut down. +/// +/// +/// The arguments are used to select the threads to kill. Any None arguments are +/// ignored. For example, to shut down all WalReceiver threads: +/// +/// shutdown_threads(Some(ThreadKind::WalReceiver), None, None) +/// +/// Or to shut down all threads for given timeline: +/// +/// shutdown_threads(None, Some(timelineid), None) +/// +pub fn shutdown_threads( + kind: Option, + tenant_id: Option, + timeline_id: Option, +) { + let mut victim_threads = Vec::new(); + + let threads = THREADS.lock().unwrap(); + for thread in threads.values() { + if (kind.is_none() || Some(thread.kind) == kind) + && (tenant_id.is_none() || thread.tenant_id == tenant_id) + && (timeline_id.is_none() || thread.timeline_id == timeline_id) + { + thread.shutdown_requested.store(true, Ordering::Relaxed); + // FIXME: handle error? + let _ = thread.shutdown_tx.send(()); + victim_threads.push(Arc::clone(thread)); + } + } + drop(threads); + + for thread in victim_threads { + info!("waiting for {} to shut down", thread.name); + if let Some(join_handle) = thread.join_handle.lock().unwrap().take() { + let _ = join_handle.join(); + } else { + // The thread had not even fully started yet. 
Or it was shut down + // concurrently and alrady exited + } + } +} + +/// A Future that can be used to check if the current thread has been requested to +/// shut down. +pub async fn shutdown_watcher() { + let _ = SHUTDOWN_RX + .with(|rx| { + rx.borrow() + .as_ref() + .expect("shutdown_requested() called in an unexpected thread") + .clone() + }) + .changed() + .await; +} + +/// Has the current thread been requested to shut down? +pub fn is_shutdown_requested() -> bool { + CURRENT_THREAD.with(|ct| { + if let Some(ct) = ct.borrow().as_ref() { + ct.shutdown_requested.load(Ordering::Relaxed) + } else { + if !cfg!(test) { + warn!("is_shutdown_requested() called in an unexpected thread"); + } + false + } + }) +} diff --git a/pageserver/src/walreceiver.rs b/pageserver/src/walreceiver.rs index 43c50746bd..35aa636b1f 100644 --- a/pageserver/src/walreceiver.rs +++ b/pageserver/src/walreceiver.rs @@ -7,8 +7,8 @@ use crate::config::PageServerConf; use crate::tenant_mgr; -use crate::tenant_mgr::TenantState; -use crate::tenant_threads; +use crate::thread_mgr; +use crate::thread_mgr::ThreadKind; use crate::walingest::WalIngest; use anyhow::{bail, Context, Error, Result}; use lazy_static::lazy_static; @@ -19,12 +19,9 @@ use postgres_types::PgLsn; use std::cell::Cell; use std::collections::HashMap; use std::str::FromStr; -use std::thread; -use std::thread::JoinHandle; use std::thread_local; use std::time::SystemTime; use tokio::pin; -use tokio::sync::oneshot; use tokio_postgres::replication::ReplicationStream; use tokio_postgres::{Client, NoTls, SimpleQueryMessage, SimpleQueryRow}; use tokio_stream::StreamExt; @@ -38,9 +35,6 @@ use zenith_utils::zid::ZTimelineId; // struct WalReceiverEntry { wal_producer_connstr: String, - wal_receiver_handle: Option>, - wal_receiver_interrupt_sender: Option>, - tenantid: ZTenantId, } lazy_static! { @@ -55,50 +49,9 @@ thread_local! { pub(crate) static IS_WAL_RECEIVER: Cell = Cell::new(false); } -// Wait for walreceiver to stop -// Now it stops when pageserver shutdown is requested. -// In future we can make this more granular and send shutdown signals -// per tenant/timeline to cancel inactive walreceivers. -// TODO deal with blocking pg connections -pub fn stop_wal_receiver(tenantid: ZTenantId, timelineid: ZTimelineId) { - let mut receivers = WAL_RECEIVERS.lock(); - - if let Some(r) = receivers.get_mut(&(tenantid, timelineid)) { - match r.wal_receiver_interrupt_sender.take() { - Some(s) => { - if s.send(()).is_err() { - warn!("wal receiver interrupt signal already sent"); - } - } - None => { - warn!("wal_receiver_interrupt_sender is missing, wal recever shouldn't be running") - } - } - - info!("waiting for wal receiver to stop"); - let handle = r.wal_receiver_handle.take(); - // do not hold the lock while joining the handle (deadlock is possible otherwise) - drop(receivers); - // there is no timeout or try_join option available so in case of a bug this can hang forever - handle.map(JoinHandle::join); - } -} - fn drop_wal_receiver(tenantid: ZTenantId, timelineid: ZTimelineId) { let mut receivers = WAL_RECEIVERS.lock(); receivers.remove(&(tenantid, timelineid)); - - // Check if it was the last walreceiver of the tenant. - // TODO now we store one WalReceiverEntry per timeline, - // so this iterator looks a bit strange. 
- for (_timelineid, entry) in receivers.iter() { - if entry.tenantid == tenantid { - return; - } - } - - // When last walreceiver of the tenant is gone, change state to Idle - tenant_mgr::set_tenant_state(tenantid, TenantState::Idle).unwrap(); } // Launch a new WAL receiver, or tell one that's running about change in connection string @@ -115,26 +68,24 @@ pub fn launch_wal_receiver( receiver.wal_producer_connstr = wal_producer_connstr.into(); } None => { - let (tx, rx) = tokio::sync::oneshot::channel::<()>(); - - let wal_receiver_handle = thread::Builder::new() - .name("WAL receiver thread".into()) - .spawn(move || { + thread_mgr::spawn( + ThreadKind::WalReceiver, + Some(tenantid), + Some(timelineid), + "WAL receiver thread", + move || { IS_WAL_RECEIVER.with(|c| c.set(true)); - thread_main(conf, tenantid, timelineid, rx); - })?; + thread_main(conf, tenantid, timelineid) + }, + )?; let receiver = WalReceiverEntry { wal_producer_connstr: wal_producer_connstr.into(), - wal_receiver_handle: Some(wal_receiver_handle), - wal_receiver_interrupt_sender: Some(tx), - tenantid, }; receivers.insert((tenantid, timelineid), receiver); // Update tenant state and start tenant threads, if they are not running yet. - tenant_mgr::set_tenant_state(tenantid, TenantState::Active)?; - tenant_threads::start_tenant_threads(conf, tenantid); + tenant_mgr::activate_tenant(conf, tenantid)?; } }; Ok(()) @@ -158,8 +109,7 @@ fn thread_main( conf: &'static PageServerConf, tenantid: ZTenantId, timelineid: ZTimelineId, - interrupt_receiver: oneshot::Receiver<()>, -) { +) -> Result<()> { let _enter = info_span!("WAL receiver", timeline = %timelineid, tenant = %tenantid).entered(); info!("WAL receiver thread started"); @@ -168,13 +118,7 @@ fn thread_main( // Make a connection to the WAL safekeeper, or directly to the primary PostgreSQL server, // and start streaming WAL from it. - let res = walreceiver_main( - conf, - tenantid, - timelineid, - &wal_producer_connstr, - interrupt_receiver, - ); + let res = walreceiver_main(conf, tenantid, timelineid, &wal_producer_connstr); // TODO cleanup info messages if let Err(e) = res { @@ -189,6 +133,7 @@ fn thread_main( // Drop it from list of active WAL_RECEIVERS // so that next callmemaybe request launched a new thread drop_wal_receiver(tenantid, timelineid); + Ok(()) } fn walreceiver_main( @@ -196,7 +141,6 @@ fn walreceiver_main( tenantid: ZTenantId, timelineid: ZTimelineId, wal_producer_connstr: &str, - mut interrupt_receiver: oneshot::Receiver<()>, ) -> Result<(), Error> { // Connect to the database in replication mode. info!("connecting to {:?}", wal_producer_connstr); @@ -273,12 +217,15 @@ fn walreceiver_main( let mut walingest = WalIngest::new(&*timeline, startpoint)?; while let Some(replication_message) = runtime.block_on(async { + let shutdown_watcher = thread_mgr::shutdown_watcher(); tokio::select! 
{ - replication_message = physical_stream.next() => replication_message, - _ = &mut interrupt_receiver => { + // check for shutdown first + biased; + _ = shutdown_watcher => { info!("walreceiver interrupted"); None } + replication_message = physical_stream.next() => replication_message, } }) { let replication_message = replication_message?; diff --git a/proxy/src/main.rs b/proxy/src/main.rs index 0f7a6981f1..cad47156f5 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -117,7 +117,11 @@ fn main() -> anyhow::Result<()> { .name("Http thread".into()) .spawn(move || { let router = http::make_router(); - endpoint::serve_thread_main(router, http_listener) + endpoint::serve_thread_main( + router, + http_listener, + std::future::pending(), // never shut down + ) })?, // Spawn a thread to listen for connections. It will spawn further threads // for each connection. diff --git a/walkeeper/src/bin/safekeeper.rs b/walkeeper/src/bin/safekeeper.rs index 5ba2ed360f..23919d34d7 100644 --- a/walkeeper/src/bin/safekeeper.rs +++ b/walkeeper/src/bin/safekeeper.rs @@ -197,7 +197,12 @@ fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> { .spawn(|| { // TODO authentication let router = http::make_router(conf_); - endpoint::serve_thread_main(router, http_listener).unwrap(); + endpoint::serve_thread_main( + router, + http_listener, + std::future::pending(), // never shut down + ) + .unwrap(); })?, ); diff --git a/zenith_utils/src/http/endpoint.rs b/zenith_utils/src/http/endpoint.rs index ffb798fe83..0be08f45e1 100644 --- a/zenith_utils/src/http/endpoint.rs +++ b/zenith_utils/src/http/endpoint.rs @@ -8,22 +8,15 @@ use lazy_static::lazy_static; use routerify::ext::RequestExt; use routerify::RequestInfo; use routerify::{Middleware, Router, RouterBuilder, RouterService}; -use std::net::TcpListener; use tracing::info; use zenith_metrics::{new_common_metric_name, register_int_counter, IntCounter}; use zenith_metrics::{Encoder, TextEncoder}; -use std::sync::Mutex; -use tokio::sync::oneshot::Sender; +use std::future::Future; +use std::net::TcpListener; use super::error::ApiError; -lazy_static! { - /// Channel used to send shutdown signal - wrapped in an Option to allow - /// it to be taken by value (since oneshot channels consume themselves on send) - static ref SHUTDOWN_SENDER: Mutex>> = Mutex::new(None); -} - lazy_static! { static ref SERVE_METRICS_COUNT: IntCounter = register_int_counter!( new_common_metric_name("serve_metrics_count"), @@ -153,17 +146,20 @@ pub fn check_permission(req: &Request, tenantid: Option) -> Res } } -/// Initiate graceful shutdown of the http endpoint -pub fn shutdown() { - if let Some(tx) = SHUTDOWN_SENDER.lock().unwrap().take() { - let _ = tx.send(()); - } -} - -pub fn serve_thread_main( +/// +/// Start listening for HTTP requests on given socket. +/// +/// 'shutdown_future' can be used to stop. If the Future becomes +/// ready, we stop listening for new requests, and the function returns. +/// +pub fn serve_thread_main( router_builder: RouterBuilder, listener: TcpListener, -) -> anyhow::Result<()> { + shutdown_future: S, +) -> anyhow::Result<()> +where + S: Future + Send + Sync, +{ info!("Starting a http endpoint at {}", listener.local_addr()?); // Create a Service from the router above to handle incoming requests. @@ -176,14 +172,9 @@ pub fn serve_thread_main( let _guard = runtime.enter(); - let (send, recv) = tokio::sync::oneshot::channel::<()>(); - *SHUTDOWN_SENDER.lock().unwrap() = Some(send); - let server = Server::from_tcp(listener)? 
         .serve(service)
-        .with_graceful_shutdown(async {
-            recv.await.ok();
-        });
+        .with_graceful_shutdown(shutdown_future);
 
     runtime.block_on(server)?;
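
For illustration, a minimal sketch of how a synchronous background thread would typically be written against the thread_mgr API introduced in this patch, assuming the code lives inside the pageserver crate; launch_example_worker and do_one_unit_of_work are hypothetical names, not part of the change:

use anyhow::Result;
use zenith_utils::zid::ZTenantId;

use crate::thread_mgr::{self, ThreadKind};

// Hypothetical launcher. Registering through thread_mgr::spawn() is what makes the
// thread visible to shutdown_threads(); the closure's error is discarded by the
// wrapper, so anything worth reporting should be logged inside the thread itself.
pub fn launch_example_worker(tenant_id: ZTenantId) -> std::io::Result<()> {
    thread_mgr::spawn(
        ThreadKind::GarbageCollector, // use the kind that matches the real thread
        Some(tenant_id),              // tenant this thread works on, if any
        None,                         // timeline this thread works on, if any
        "example worker thread",
        move || -> Result<()> {
            // Synchronous threads co-operate with shutdown by polling the flag
            // between units of work.
            while !thread_mgr::is_shutdown_requested() {
                do_one_unit_of_work(tenant_id)?;
            }
            Ok(())
        },
    )
}

// Hypothetical unit of work; stands in for checkpointing, GC, etc.
fn do_one_unit_of_work(_tenant_id: ZTenantId) -> Result<()> {
    std::thread::sleep(std::time::Duration::from_secs(1));
    Ok(())
}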
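
And a rough sketch of the cooperative-shutdown pattern for threads that wait on async I/O, mirroring what the libpq listener and WAL receiver loops above do; example_async_worker and do_async_work are made-up names, and the body is assumed to run inside a thread started via thread_mgr::spawn():

use anyhow::Result;

use crate::thread_mgr;

// shutdown_watcher() reads the shutdown channel from thread-local storage, so it
// must be awaited on the same OS thread that thread_mgr spawned, i.e. on a
// current-thread runtime, as page_service.rs and walreceiver.rs do above.
fn example_async_worker() -> Result<()> {
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()?;

    runtime.block_on(async {
        loop {
            tokio::select! {
                // Check for shutdown first, like the walreceiver loop.
                biased;
                _ = thread_mgr::shutdown_watcher() => break,
                _ = do_async_work() => {} // hypothetical future, e.g. reading a socket
            }
        }
    });
    Ok(())
}

// Hypothetical async work item.
async fn do_async_work() {
    tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}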