From 0acb604fa3793a0ae99238c026bf3c40391ae461 Mon Sep 17 00:00:00 2001 From: Joonas Koivunen Date: Tue, 4 Jun 2024 14:19:36 +0300 Subject: [PATCH] test: no missed wakeups, cancellation and timeout flow to downloads (#7863) I suspected a wakeup could be lost with `remote_storage::support::DownloadStream` if the cancellation and inner stream wakeups happen simultaneously. The next poll would only return the cancellation error without setting the wakeup. There is no lost wakeup because the single future for getting the cancellation error is consumed when the value is ready, and a new future is created for the *next* value. The new future is always polled. Similarly, if only the `Stream::poll_next` is being used after a `Some(_)` value has been yielded, it makes no sense to have an expectation of a wakeup for the *(N+1)th* stream value already set because when a value is wanted, `Stream::poll_next` will be called. A test is added to show that the above is true. Additionally, there was a question of these cancellations and timeouts flowing to attached or secondary tenant downloads. A test is added to show that this, in fact, happens. Lastly, a warning message is logged when a download stream is polled after a timeout or cancellation error (currently unexpected) so we can rule it out while troubleshooting. --- libs/remote_storage/src/support.rs | 50 ++++- .../tenant/remote_timeline_client/download.rs | 5 + pageserver/src/tenant/secondary/downloader.rs | 7 +- pageserver/src/tenant/secondary/scheduler.rs | 7 +- test_runner/fixtures/remote_storage.py | 5 + test_runner/regress/test_ondemand_download.py | 201 +++++++++++++++++- 6 files changed, 263 insertions(+), 12 deletions(-) diff --git a/libs/remote_storage/src/support.rs b/libs/remote_storage/src/support.rs index d146b5445b..1ed9ed9305 100644 --- a/libs/remote_storage/src/support.rs +++ b/libs/remote_storage/src/support.rs @@ -78,6 +78,10 @@ where let e = Err(std::io::Error::from(e)); return Poll::Ready(Some(e)); } + } else { + // this would be perfectly valid behaviour for doing a graceful completion on the + // download for example, but not one we expect to do right now. + tracing::warn!("continuing polling after having cancelled or timeouted"); } this.inner.poll_next(cx) @@ -89,13 +93,22 @@ where } /// Fires only on the first cancel or timeout, not on both. -pub(crate) async fn cancel_or_timeout( +pub(crate) fn cancel_or_timeout( timeout: Duration, cancel: CancellationToken, -) -> TimeoutOrCancel { - tokio::select! { - _ = tokio::time::sleep(timeout) => TimeoutOrCancel::Timeout, - _ = cancel.cancelled() => TimeoutOrCancel::Cancel, +) -> impl std::future::Future + 'static { + // futures are lazy, they don't do anything before being polled. + // + // "precalculate" the wanted deadline before returning the future, so that we can use pause + // failpoint to trigger a timeout in test. + let deadline = tokio::time::Instant::now() + timeout; + async move { + tokio::select! { + _ = tokio::time::sleep_until(deadline) => TimeoutOrCancel::Timeout, + _ = cancel.cancelled() => { + TimeoutOrCancel::Cancel + }, + } } } @@ -172,4 +185,31 @@ mod tests { _ = tokio::time::sleep(Duration::from_secs(121)) => {}, } } + + #[tokio::test] + async fn notified_but_pollable_after() { + let inner = futures::stream::once(futures::future::ready(Ok(bytes::Bytes::from_static( + b"hello world", + )))); + let timeout = Duration::from_secs(120); + let cancel = CancellationToken::new(); + + cancel.cancel(); + let stream = DownloadStream::new(cancel_or_timeout(timeout, cancel.clone()), inner); + let mut stream = std::pin::pin!(stream); + + let next = stream.next().await; + let ioe = next.unwrap().unwrap_err(); + assert!( + matches!( + ioe.get_ref().unwrap().downcast_ref::(), + Some(&DownloadError::Cancelled) + ), + "{ioe:?}" + ); + + let next = stream.next().await; + let bytes = next.unwrap().unwrap(); + assert_eq!(&b"hello world"[..], bytes); + } } diff --git a/pageserver/src/tenant/remote_timeline_client/download.rs b/pageserver/src/tenant/remote_timeline_client/download.rs index bd75f980e8..d0385e4aee 100644 --- a/pageserver/src/tenant/remote_timeline_client/download.rs +++ b/pageserver/src/tenant/remote_timeline_client/download.rs @@ -28,6 +28,7 @@ use crate::TEMP_FILE_SUFFIX; use remote_storage::{DownloadError, GenericRemoteStorage, ListingMode, RemotePath}; use utils::crashsafe::path_with_suffix_extension; use utils::id::{TenantId, TimelineId}; +use utils::pausable_failpoint; use super::index::{IndexPart, LayerFileMetadata}; use super::{ @@ -152,6 +153,8 @@ async fn download_object<'a>( let download = storage.download(src_path, cancel).await?; + pausable_failpoint!("before-downloading-layer-stream-pausable"); + let mut buf_writer = tokio::io::BufWriter::with_capacity(super::BUFFER_SIZE, destination_file); @@ -199,6 +202,8 @@ async fn download_object<'a>( let mut download = storage.download(src_path, cancel).await?; + pausable_failpoint!("before-downloading-layer-stream-pausable"); + // TODO: use vectored write (writev) once supported by tokio-epoll-uring. // There's chunks_vectored() on the stream. let (bytes_amount, destination_file) = async { diff --git a/pageserver/src/tenant/secondary/downloader.rs b/pageserver/src/tenant/secondary/downloader.rs index 5c915d6b53..62803c7838 100644 --- a/pageserver/src/tenant/secondary/downloader.rs +++ b/pageserver/src/tenant/secondary/downloader.rs @@ -1000,7 +1000,7 @@ impl<'a> TenantDownloader<'a> { layer.name, layer.metadata.file_size ); - let downloaded_bytes = match download_layer_file( + let downloaded_bytes = download_layer_file( self.conf, self.remote_storage, *tenant_shard_id, @@ -1011,8 +1011,9 @@ impl<'a> TenantDownloader<'a> { &self.secondary_state.cancel, ctx, ) - .await - { + .await; + + let downloaded_bytes = match downloaded_bytes { Ok(bytes) => bytes, Err(DownloadError::NotFound) => { // A heatmap might be out of date and refer to a layer that doesn't exist any more. diff --git a/pageserver/src/tenant/secondary/scheduler.rs b/pageserver/src/tenant/secondary/scheduler.rs index 0ec1c7872a..28cf2125df 100644 --- a/pageserver/src/tenant/secondary/scheduler.rs +++ b/pageserver/src/tenant/secondary/scheduler.rs @@ -334,8 +334,11 @@ where let tenant_shard_id = job.get_tenant_shard_id(); let barrier = if let Some(barrier) = self.get_running(tenant_shard_id) { - tracing::info!(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), - "Command already running, waiting for it"); + tracing::info!( + tenant_id=%tenant_shard_id.tenant_id, + shard_id=%tenant_shard_id.shard_slug(), + "Command already running, waiting for it" + ); barrier } else { let running = self.spawn_now(job); diff --git a/test_runner/fixtures/remote_storage.py b/test_runner/fixtures/remote_storage.py index ee18c53b52..6f6526d3fc 100644 --- a/test_runner/fixtures/remote_storage.py +++ b/test_runner/fixtures/remote_storage.py @@ -171,6 +171,8 @@ class S3Storage: """Is this MOCK_S3 (false) or REAL_S3 (true)""" real: bool endpoint: Optional[str] = None + """formatting deserialized with humantime crate, for example "1s".""" + custom_timeout: Optional[str] = None def access_env_vars(self) -> Dict[str, str]: if self.aws_profile is not None: @@ -208,6 +210,9 @@ class S3Storage: if self.endpoint is not None: rv["endpoint"] = self.endpoint + if self.custom_timeout is not None: + rv["timeout"] = self.custom_timeout + return rv def to_toml_inline_table(self) -> str: diff --git a/test_runner/regress/test_ondemand_download.py b/test_runner/regress/test_ondemand_download.py index 6fe23846c7..4a25dfd874 100644 --- a/test_runner/regress/test_ondemand_download.py +++ b/test_runner/regress/test_ondemand_download.py @@ -3,8 +3,10 @@ import time from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor from typing import Any, DefaultDict, Dict, Tuple +import pytest from fixtures.common_types import Lsn from fixtures.log_helper import log from fixtures.neon_fixtures import ( @@ -13,7 +15,7 @@ from fixtures.neon_fixtures import ( last_flush_lsn_upload, wait_for_last_flush_lsn, ) -from fixtures.pageserver.http import PageserverHttpClient +from fixtures.pageserver.http import PageserverApiException, PageserverHttpClient from fixtures.pageserver.utils import ( assert_tenant_state, wait_for_last_record_lsn, @@ -21,7 +23,7 @@ from fixtures.pageserver.utils import ( wait_for_upload_queue_empty, wait_until_tenant_active, ) -from fixtures.remote_storage import RemoteStorageKind +from fixtures.remote_storage import RemoteStorageKind, S3Storage, s3_storage from fixtures.utils import query_scalar, wait_until @@ -656,5 +658,200 @@ def test_compaction_downloads_on_demand_with_image_creation(neon_env_builder: Ne assert dict(kinds_after) == {"Delta": 4, "Image": 1} +def test_layer_download_cancelled_by_config_location(neon_env_builder: NeonEnvBuilder): + """ + Demonstrates that tenant shutdown will cancel on-demand download and secondary doing warmup. + """ + neon_env_builder.enable_pageserver_remote_storage(s3_storage()) + + # turn off background tasks so that they don't interfere with the downloads + env = neon_env_builder.init_start( + initial_tenant_conf={ + "gc_period": "0s", + "compaction_period": "0s", + } + ) + client = env.pageserver.http_client() + failpoint = "before-downloading-layer-stream-pausable" + client.configure_failpoints((failpoint, "pause")) + + env.pageserver.allowed_errors.extend( + [ + ".*downloading failed, possibly for shutdown.*", + ] + ) + + info = client.layer_map_info(env.initial_tenant, env.initial_timeline) + assert len(info.delta_layers()) == 1 + + layer = info.delta_layers()[0] + + client.tenant_heatmap_upload(env.initial_tenant) + + # evict the initdb layer so we can download it + client.evict_layer(env.initial_tenant, env.initial_timeline, layer.layer_file_name) + + with ThreadPoolExecutor(max_workers=2) as exec: + download = exec.submit( + client.download_layer, + env.initial_tenant, + env.initial_timeline, + layer.layer_file_name, + ) + + _, offset = wait_until( + 20, 0.5, lambda: env.pageserver.assert_log_contains(f"at failpoint {failpoint}") + ) + + location_conf = {"mode": "Detached", "tenant_conf": {}} + # assume detach removes the layers + detach = exec.submit(client.tenant_location_conf, env.initial_tenant, location_conf) + + _, offset = wait_until( + 20, + 0.5, + lambda: env.pageserver.assert_log_contains( + "closing is taking longer than expected", offset + ), + ) + + client.configure_failpoints((failpoint, "off")) + + with pytest.raises( + PageserverApiException, match="downloading failed, possibly for shutdown" + ): + download.result() + + env.pageserver.assert_log_contains(".*downloading failed, possibly for shutdown.*") + + detach.result() + + client.configure_failpoints((failpoint, "pause")) + + _, offset = wait_until( + 20, + 0.5, + lambda: env.pageserver.assert_log_contains(f"cfg failpoint: {failpoint} pause", offset), + ) + + location_conf = { + "mode": "Secondary", + "secondary_conf": {"warm": True}, + "tenant_conf": {}, + } + + client.tenant_location_conf(env.initial_tenant, location_conf) + + warmup = exec.submit(client.tenant_secondary_download, env.initial_tenant, wait_ms=30000) + + _, offset = wait_until( + 20, + 0.5, + lambda: env.pageserver.assert_log_contains(f"at failpoint {failpoint}", offset), + ) + + client.configure_failpoints((failpoint, "off")) + location_conf = {"mode": "Detached", "tenant_conf": {}} + client.tenant_location_conf(env.initial_tenant, location_conf) + + client.configure_failpoints((failpoint, "off")) + + # here we have nothing in the log, but we see that the warmup and conf location update worked + warmup.result() + + +def test_layer_download_timeouted(neon_env_builder: NeonEnvBuilder): + """ + Pause using a pausable_failpoint longer than the client timeout to simulate the timeout happening. + """ + neon_env_builder.enable_pageserver_remote_storage(s3_storage()) + assert isinstance(neon_env_builder.pageserver_remote_storage, S3Storage) + neon_env_builder.pageserver_remote_storage.custom_timeout = "1s" + + # turn off background tasks so that they don't interfere with the downloads + env = neon_env_builder.init_start( + initial_tenant_conf={ + "gc_period": "0s", + "compaction_period": "0s", + } + ) + client = env.pageserver.http_client() + failpoint = "before-downloading-layer-stream-pausable" + client.configure_failpoints((failpoint, "pause")) + + info = client.layer_map_info(env.initial_tenant, env.initial_timeline) + assert len(info.delta_layers()) == 1 + + layer = info.delta_layers()[0] + + client.tenant_heatmap_upload(env.initial_tenant) + + # evict so we can download it + client.evict_layer(env.initial_tenant, env.initial_timeline, layer.layer_file_name) + + with ThreadPoolExecutor(max_workers=2) as exec: + download = exec.submit( + client.download_layer, + env.initial_tenant, + env.initial_timeline, + layer.layer_file_name, + ) + + _, offset = wait_until( + 20, 0.5, lambda: env.pageserver.assert_log_contains(f"at failpoint {failpoint}") + ) + # ensure enough time while paused to trip the timeout + time.sleep(2) + + client.configure_failpoints((failpoint, "off")) + download.result() + + _, offset = env.pageserver.assert_log_contains( + ".*failed, will retry \\(attempt 0\\): timeout.*" + ) + _, offset = env.pageserver.assert_log_contains(".*succeeded after [0-9]+ retries.*", offset) + + client.evict_layer(env.initial_tenant, env.initial_timeline, layer.layer_file_name) + + client.configure_failpoints((failpoint, "pause")) + + # capture the next offset for a new synchronization with the failpoint + _, offset = wait_until( + 20, + 0.5, + lambda: env.pageserver.assert_log_contains(f"cfg failpoint: {failpoint} pause", offset), + ) + + location_conf = { + "mode": "Secondary", + "secondary_conf": {"warm": True}, + "tenant_conf": {}, + } + + client.tenant_location_conf( + env.initial_tenant, + location_conf, + ) + + started = time.time() + + warmup = exec.submit(client.tenant_secondary_download, env.initial_tenant, wait_ms=30000) + # ensure enough time while paused to trip the timeout + time.sleep(2) + + client.configure_failpoints((failpoint, "off")) + + warmup.result() + + elapsed = time.time() - started + + _, offset = env.pageserver.assert_log_contains( + ".*failed, will retry \\(attempt 0\\): timeout.*", offset + ) + _, offset = env.pageserver.assert_log_contains(".*succeeded after [0-9]+ retries.*", offset) + + assert elapsed < 30, "too long passed: {elapsed=}" + + def stringify(conf: Dict[str, Any]) -> Dict[str, str]: return dict(map(lambda x: (x[0], str(x[1])), conf.items()))