diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index f553bf3c1e..5b1cdb9805 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -32,6 +32,7 @@ use std::sync::{Arc, Condvar, Mutex, RwLock}; use std::time::{Duration, Instant}; use std::{env, fs}; use tokio::{spawn, sync::watch, task::JoinHandle, time}; +use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info, instrument, warn}; use url::Url; use utils::id::{TenantId, TimelineId}; @@ -192,6 +193,7 @@ pub struct ComputeState { pub startup_span: Option, pub lfc_prewarm_state: LfcPrewarmState, + pub lfc_prewarm_token: CancellationToken, pub lfc_offload_state: LfcOffloadState, /// WAL flush LSN that is set after terminating Postgres and syncing safekeepers if @@ -217,6 +219,7 @@ impl ComputeState { lfc_offload_state: LfcOffloadState::default(), terminate_flush_lsn: None, promote_state: None, + lfc_prewarm_token: CancellationToken::new(), } } diff --git a/compute_tools/src/compute_prewarm.rs b/compute_tools/src/compute_prewarm.rs index 97e62c1c80..82cb28f1ac 100644 --- a/compute_tools/src/compute_prewarm.rs +++ b/compute_tools/src/compute_prewarm.rs @@ -7,7 +7,8 @@ use http::StatusCode; use reqwest::Client; use std::mem::replace; use std::sync::Arc; -use tokio::{io::AsyncReadExt, spawn}; +use tokio::{io::AsyncReadExt, select, spawn}; +use tokio_util::sync::CancellationToken; use tracing::{error, info}; #[derive(serde::Serialize, Default)] @@ -92,34 +93,35 @@ impl ComputeNode { /// If there is a prewarm request ongoing, return `false`, `true` otherwise. /// Has a failpoint "compute-prewarm" pub fn prewarm_lfc(self: &Arc, from_endpoint: Option) -> bool { + let token: CancellationToken; { - let state = &mut self.state.lock().unwrap().lfc_prewarm_state; - if let LfcPrewarmState::Prewarming = replace(state, LfcPrewarmState::Prewarming) { + let state = &mut self.state.lock().unwrap(); + token = state.lfc_prewarm_token.clone(); + if let LfcPrewarmState::Prewarming = + replace(&mut state.lfc_prewarm_state, LfcPrewarmState::Prewarming) + { return false; } } crate::metrics::LFC_PREWARMS.inc(); - let cloned = self.clone(); + let this = self.clone(); spawn(async move { - let state = match cloned.prewarm_impl(from_endpoint).await { - Ok(true) => LfcPrewarmState::Completed, - Ok(false) => { - info!( - "skipping LFC prewarm because LFC state is not found in endpoint storage" - ); - LfcPrewarmState::Skipped - } + let prewarm_state = match this.prewarm_impl(from_endpoint, token).await { + Ok(state) => state, Err(err) => { crate::metrics::LFC_PREWARM_ERRORS.inc(); error!(%err, "could not prewarm LFC"); - LfcPrewarmState::Failed { - error: format!("{err:#}"), - } + let error = format!("{err:#}"); + LfcPrewarmState::Failed { error } } }; - cloned.state.lock().unwrap().lfc_prewarm_state = state; + let state = &mut this.state.lock().unwrap(); + if let LfcPrewarmState::Cancelled = prewarm_state { + state.lfc_prewarm_token = CancellationToken::new(); + } + state.lfc_prewarm_state = prewarm_state; }); true } @@ -132,47 +134,70 @@ impl ComputeNode { /// Request LFC state from endpoint storage and load corresponding pages into Postgres. /// Returns a result with `false` if the LFC state is not found in endpoint storage. - async fn prewarm_impl(&self, from_endpoint: Option) -> Result { - let EndpointStoragePair { url, token } = self.endpoint_storage_pair(from_endpoint)?; + async fn prewarm_impl( + &self, + from_endpoint: Option, + token: CancellationToken, + ) -> Result { + let EndpointStoragePair { + url, + token: storage_token, + } = self.endpoint_storage_pair(from_endpoint)?; #[cfg(feature = "testing")] - fail::fail_point!("compute-prewarm", |_| { - bail!("prewarm configured to fail because of a failpoint") - }); + fail::fail_point!("compute-prewarm", |_| bail!("compute-prewarm failpoint")); info!(%url, "requesting LFC state from endpoint storage"); - let request = Client::new().get(&url).bearer_auth(token); - let res = request.send().await.context("querying endpoint storage")?; - match res.status() { + let request = Client::new().get(&url).bearer_auth(storage_token); + let response = select! { + _ = token.cancelled() => return Ok(LfcPrewarmState::Cancelled), + response = request.send() => response + } + .context("querying endpoint storage")?; + + match response.status() { StatusCode::OK => (), - StatusCode::NOT_FOUND => { - return Ok(false); - } + StatusCode::NOT_FOUND => return Ok(LfcPrewarmState::Skipped), status => bail!("{status} querying endpoint storage"), } let mut uncompressed = Vec::new(); - let lfc_state = res - .bytes() - .await - .context("getting request body from endpoint storage")?; - ZstdDecoder::new(lfc_state.iter().as_slice()) - .read_to_end(&mut uncompressed) - .await - .context("decoding LFC state")?; + let lfc_state = select! { + _ = token.cancelled() => return Ok(LfcPrewarmState::Cancelled), + lfc_state = response.bytes() => lfc_state + } + .context("getting request body from endpoint storage")?; + + let mut decoder = ZstdDecoder::new(lfc_state.iter().as_slice()); + select! { + _ = token.cancelled() => return Ok(LfcPrewarmState::Cancelled), + read = decoder.read_to_end(&mut uncompressed) => read + } + .context("decoding LFC state")?; + let uncompressed_len = uncompressed.len(); + info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}"); - info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}, loading into Postgres"); - - ComputeNode::get_maintenance_client(&self.tokio_conn_conf) + // Client connection and prewarm info querying are fast and therefore don't need + // cancellation + let client = ComputeNode::get_maintenance_client(&self.tokio_conn_conf) .await - .context("connecting to postgres")? - .query_one("select neon.prewarm_local_cache($1)", &[&uncompressed]) - .await - .context("loading LFC state into postgres") - .map(|_| ())?; + .context("connecting to postgres")?; + let pg_token = client.cancel_token(); - Ok(true) + let params: Vec<&(dyn postgres_types::ToSql + Sync)> = vec![&uncompressed]; + select! { + res = client.query_one("select neon.prewarm_local_cache($1)", ¶ms) => res, + _ = token.cancelled() => { + pg_token.cancel_query(postgres::NoTls).await + .context("cancelling neon.prewarm_local_cache()")?; + return Ok(LfcPrewarmState::Cancelled) + } + } + .context("loading LFC state into postgres") + .map(|_| ())?; + + Ok(LfcPrewarmState::Completed) } /// If offload request is ongoing, return false, true otherwise @@ -200,20 +225,20 @@ impl ComputeNode { async fn offload_lfc_with_state_update(&self) { crate::metrics::LFC_OFFLOADS.inc(); - - let Err(err) = self.offload_lfc_impl().await else { - self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Completed; - return; + let state = match self.offload_lfc_impl().await { + Ok(state) => state, + Err(err) => { + crate::metrics::LFC_OFFLOAD_ERRORS.inc(); + error!(%err, "could not offload LFC"); + let error = format!("{err:#}"); + LfcOffloadState::Failed { error } + } }; - crate::metrics::LFC_OFFLOAD_ERRORS.inc(); - error!(%err, "could not offload LFC state to endpoint storage"); - self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Failed { - error: format!("{err:#}"), - }; + self.state.lock().unwrap().lfc_offload_state = state; } - async fn offload_lfc_impl(&self) -> Result<()> { + async fn offload_lfc_impl(&self) -> Result { let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?; info!(%url, "requesting LFC state from Postgres"); @@ -228,7 +253,7 @@ impl ComputeNode { .context("deserializing LFC state")?; let Some(state) = state else { info!(%url, "empty LFC state, not exporting"); - return Ok(()); + return Ok(LfcOffloadState::Skipped); }; let mut compressed = Vec::new(); @@ -242,7 +267,7 @@ impl ComputeNode { let request = Client::new().put(url).bearer_auth(token).body(compressed); match request.send().await { - Ok(res) if res.status() == StatusCode::OK => Ok(()), + Ok(res) if res.status() == StatusCode::OK => Ok(LfcOffloadState::Completed), Ok(res) => bail!( "Request to endpoint storage failed with status: {}", res.status() @@ -250,4 +275,8 @@ impl ComputeNode { Err(err) => Err(err).context("writing to endpoint storage"), } } + + pub fn cancel_prewarm(self: &Arc) { + self.state.lock().unwrap().lfc_prewarm_token.cancel(); + } } diff --git a/compute_tools/src/http/openapi_spec.yaml b/compute_tools/src/http/openapi_spec.yaml index ab729d62b5..27e610a87d 100644 --- a/compute_tools/src/http/openapi_spec.yaml +++ b/compute_tools/src/http/openapi_spec.yaml @@ -139,6 +139,15 @@ paths: application/json: schema: $ref: "#/components/schemas/LfcPrewarmState" + delete: + tags: + - Prewarm + summary: Cancel ongoing LFC prewarm + description: "" + operationId: cancelLfcPrewarm + responses: + 202: + description: Prewarm cancelled /lfc/offload: post: @@ -636,7 +645,7 @@ components: properties: status: description: LFC offload status - enum: [not_offloaded, offloading, completed, failed] + enum: [not_offloaded, offloading, completed, skipped, failed] type: string error: description: LFC offload error, if any diff --git a/compute_tools/src/http/routes/lfc.rs b/compute_tools/src/http/routes/lfc.rs index e98bd781a2..7483198723 100644 --- a/compute_tools/src/http/routes/lfc.rs +++ b/compute_tools/src/http/routes/lfc.rs @@ -46,3 +46,8 @@ pub(in crate::http) async fn offload(compute: Compute) -> Response { ) } } + +pub(in crate::http) async fn cancel_prewarm(compute: Compute) -> StatusCode { + compute.cancel_prewarm(); + StatusCode::ACCEPTED +} diff --git a/compute_tools/src/http/server.rs b/compute_tools/src/http/server.rs index 2fd3121f4f..869fdef11d 100644 --- a/compute_tools/src/http/server.rs +++ b/compute_tools/src/http/server.rs @@ -99,7 +99,12 @@ impl From<&Server> for Router> { ); let authenticated_router = Router::>::new() - .route("/lfc/prewarm", get(lfc::prewarm_state).post(lfc::prewarm)) + .route( + "/lfc/prewarm", + get(lfc::prewarm_state) + .post(lfc::prewarm) + .delete(lfc::cancel_prewarm), + ) .route("/lfc/offload", get(lfc::offload_state).post(lfc::offload)) .route("/promote", post(promote::promote)) .route("/check_writability", post(check_writability::is_writable)) diff --git a/libs/compute_api/src/responses.rs b/libs/compute_api/src/responses.rs index a27301e45e..a918644e4c 100644 --- a/libs/compute_api/src/responses.rs +++ b/libs/compute_api/src/responses.rs @@ -68,11 +68,15 @@ pub enum LfcPrewarmState { /// We tried to fetch the corresponding LFC state from the endpoint storage, /// but received `Not Found 404`. This should normally happen only during the /// first endpoint start after creation with `autoprewarm: true`. + /// This may also happen if LFC is turned off or not initialized /// /// During the orchestrated prewarm via API, when a caller explicitly /// provides the LFC state key to prewarm from, it's the caller responsibility /// to handle this status as an error state in this case. Skipped, + /// LFC prewarm was cancelled. Some pages in LFC cache may be prewarmed if query + /// has started working before cancellation + Cancelled, } impl Display for LfcPrewarmState { @@ -83,6 +87,7 @@ impl Display for LfcPrewarmState { LfcPrewarmState::Completed => f.write_str("Completed"), LfcPrewarmState::Skipped => f.write_str("Skipped"), LfcPrewarmState::Failed { error } => write!(f, "Error({error})"), + LfcPrewarmState::Cancelled => f.write_str("Cancelled"), } } } @@ -97,6 +102,7 @@ pub enum LfcOffloadState { Failed { error: String, }, + Skipped, } #[derive(Serialize, Debug, Clone, PartialEq)] diff --git a/test_runner/fixtures/endpoint/http.py b/test_runner/fixtures/endpoint/http.py index d235ac2143..c77a372017 100644 --- a/test_runner/fixtures/endpoint/http.py +++ b/test_runner/fixtures/endpoint/http.py @@ -78,20 +78,26 @@ class EndpointHttpClient(requests.Session): json: dict[str, str] = res.json() return json - def prewarm_lfc(self, from_endpoint_id: str | None = None): + def prewarm_lfc(self, from_endpoint_id: str | None = None) -> dict[str, str]: """ Prewarm LFC cache from given endpoint and wait till it finishes or errors """ params = {"from_endpoint": from_endpoint_id} if from_endpoint_id else dict() self.post(self.prewarm_url, params=params).raise_for_status() - self.prewarm_lfc_wait() + return self.prewarm_lfc_wait() - def prewarm_lfc_wait(self): + def cancel_prewarm_lfc(self): + """ + Cancel LFC prewarm if any is ongoing + """ + self.delete(self.prewarm_url).raise_for_status() + + def prewarm_lfc_wait(self) -> dict[str, str]: """ Wait till LFC prewarm returns with error or success. If prewarm was not requested before calling this function, it will error """ - statuses = "failed", "completed", "skipped" + statuses = "failed", "completed", "skipped", "cancelled" def prewarmed(): json = self.prewarm_lfc_status() @@ -101,6 +107,7 @@ class EndpointHttpClient(requests.Session): wait_until(prewarmed, timeout=60) res = self.prewarm_lfc_status() assert res["status"] != "failed", res + return res def offload_lfc_status(self) -> dict[str, str]: res = self.get(self.offload_url) @@ -108,29 +115,31 @@ class EndpointHttpClient(requests.Session): json: dict[str, str] = res.json() return json - def offload_lfc(self): + def offload_lfc(self) -> dict[str, str]: """ Offload LFC cache to endpoint storage and wait till offload finishes or errors """ self.post(self.offload_url).raise_for_status() - self.offload_lfc_wait() + return self.offload_lfc_wait() - def offload_lfc_wait(self): + def offload_lfc_wait(self) -> dict[str, str]: """ Wait till LFC offload returns with error or success. If offload was not requested before calling this function, it will error """ + statuses = "failed", "completed", "skipped" def offloaded(): json = self.offload_lfc_status() status, err = json["status"], json.get("error") - assert status in ["failed", "completed"], f"{status}, {err=}" + assert status in statuses, f"{status}, {err=}" wait_until(offloaded, timeout=60) res = self.offload_lfc_status() assert res["status"] != "failed", res + return res - def promote(self, promote_spec: dict[str, Any], disconnect: bool = False): + def promote(self, promote_spec: dict[str, Any], disconnect: bool = False) -> dict[str, str]: url = f"http://localhost:{self.external_port}/promote" if disconnect: try: # send first request to start promote and disconnect diff --git a/test_runner/regress/test_lfc_prewarm.py b/test_runner/regress/test_lfc_prewarm.py index 2bbe8c3e97..a96f18177c 100644 --- a/test_runner/regress/test_lfc_prewarm.py +++ b/test_runner/regress/test_lfc_prewarm.py @@ -1,6 +1,6 @@ import random -import threading from enum import StrEnum +from threading import Thread from time import sleep from typing import Any @@ -47,19 +47,23 @@ def offload_lfc(method: PrewarmMethod, client: EndpointHttpClient, cur: Cursor) # With autoprewarm, we need to be sure LFC was offloaded after all writes # finish, so we sleep. Otherwise we'll have less prewarmed pages than we want sleep(AUTOOFFLOAD_INTERVAL_SECS) - client.offload_lfc_wait() - return + offload_res = client.offload_lfc_wait() + log.info(offload_res) + return offload_res if method == PrewarmMethod.COMPUTE_CTL: status = client.prewarm_lfc_status() assert status["status"] == "not_prewarmed" assert "error" not in status - client.offload_lfc() + offload_res = client.offload_lfc() + log.info(offload_res) assert client.prewarm_lfc_status()["status"] == "not_prewarmed" + parsed = prom_parse(client) desired = {OFFLOAD_LABEL: 1, PREWARM_LABEL: 0, OFFLOAD_ERR_LABEL: 0, PREWARM_ERR_LABEL: 0} assert parsed == desired, f"{parsed=} != {desired=}" - return + + return offload_res raise AssertionError(f"{method} not in PrewarmMethod") @@ -68,21 +72,30 @@ def prewarm_endpoint( method: PrewarmMethod, client: EndpointHttpClient, cur: Cursor, lfc_state: str | None ): if method == PrewarmMethod.AUTOPREWARM: - client.prewarm_lfc_wait() + prewarm_res = client.prewarm_lfc_wait() + log.info(prewarm_res) elif method == PrewarmMethod.COMPUTE_CTL: - client.prewarm_lfc() + prewarm_res = client.prewarm_lfc() + log.info(prewarm_res) + return prewarm_res elif method == PrewarmMethod.POSTGRES: cur.execute("select neon.prewarm_local_cache(%s)", (lfc_state,)) -def check_prewarmed( +def check_prewarmed_contains( method: PrewarmMethod, client: EndpointHttpClient, desired_status: dict[str, str | int] ): if method == PrewarmMethod.AUTOPREWARM: - assert client.prewarm_lfc_status() == desired_status + prewarm_status = client.prewarm_lfc_status() + for k in desired_status: + assert desired_status[k] == prewarm_status[k] + assert prom_parse(client)[PREWARM_LABEL] == 1 elif method == PrewarmMethod.COMPUTE_CTL: - assert client.prewarm_lfc_status() == desired_status + prewarm_status = client.prewarm_lfc_status() + for k in desired_status: + assert desired_status[k] == prewarm_status[k] + desired = {OFFLOAD_LABEL: 0, PREWARM_LABEL: 1, PREWARM_ERR_LABEL: 0, OFFLOAD_ERR_LABEL: 0} assert prom_parse(client) == desired @@ -149,9 +162,6 @@ def test_lfc_prewarm(neon_simple_env: NeonEnv, method: PrewarmMethod): log.info(f"Used LFC size: {lfc_used_pages}") pg_cur.execute("select * from neon.get_prewarm_info()") total, prewarmed, skipped, _ = pg_cur.fetchall()[0] - log.info(f"Prewarm info: {total=} {prewarmed=} {skipped=}") - progress = (prewarmed + skipped) * 100 // total - log.info(f"Prewarm progress: {progress}%") assert lfc_used_pages > 10000 assert total > 0 assert prewarmed > 0 @@ -161,7 +171,54 @@ def test_lfc_prewarm(neon_simple_env: NeonEnv, method: PrewarmMethod): assert lfc_cur.fetchall()[0][0] == n_records * (n_records + 1) / 2 desired = {"status": "completed", "total": total, "prewarmed": prewarmed, "skipped": skipped} - check_prewarmed(method, client, desired) + check_prewarmed_contains(method, client, desired) + + +@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping") +def test_lfc_prewarm_cancel(neon_simple_env: NeonEnv): + """ + Test we can cancel LFC prewarm and prewarm successfully after + """ + env = neon_simple_env + n_records = 1000000 + cfg = [ + "autovacuum = off", + "shared_buffers=1MB", + "neon.max_file_cache_size=1GB", + "neon.file_cache_size_limit=1GB", + "neon.file_cache_prewarm_limit=1000", + ] + endpoint = env.endpoints.create_start(branch_name="main", config_lines=cfg) + + pg_conn = endpoint.connect() + pg_cur = pg_conn.cursor() + pg_cur.execute("create schema neon; create extension neon with schema neon") + pg_cur.execute("create database lfc") + + lfc_conn = endpoint.connect(dbname="lfc") + lfc_cur = lfc_conn.cursor() + log.info(f"Inserting {n_records} rows") + lfc_cur.execute("create table t(pk integer primary key, payload text default repeat('?', 128))") + lfc_cur.execute(f"insert into t (pk) values (generate_series(1,{n_records}))") + log.info(f"Inserted {n_records} rows") + + client = endpoint.http_client() + method = PrewarmMethod.COMPUTE_CTL + offload_lfc(method, client, pg_cur) + + endpoint.stop() + endpoint.start() + + thread = Thread(target=lambda: prewarm_endpoint(method, client, pg_cur, None)) + thread.start() + # wait 2 seconds to ensure we cancel prewarm SQL query + sleep(2) + client.cancel_prewarm_lfc() + thread.join() + assert client.prewarm_lfc_status()["status"] == "cancelled" + + prewarm_endpoint(method, client, pg_cur, None) + assert client.prewarm_lfc_status()["status"] == "completed" @pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping") @@ -178,9 +235,8 @@ def test_lfc_prewarm_empty(neon_simple_env: NeonEnv): cur = conn.cursor() cur.execute("create schema neon; create extension neon with schema neon") method = PrewarmMethod.COMPUTE_CTL - offload_lfc(method, client, cur) - prewarm_endpoint(method, client, cur, None) - assert client.prewarm_lfc_status()["status"] == "skipped" + assert offload_lfc(method, client, cur)["status"] == "skipped" + assert prewarm_endpoint(method, client, cur, None)["status"] == "skipped" # autoprewarm isn't needed as we prewarm manually @@ -251,11 +307,11 @@ def test_lfc_prewarm_under_workload(neon_simple_env: NeonEnv, method: PrewarmMet workload_threads = [] for _ in range(n_threads): - t = threading.Thread(target=workload) + t = Thread(target=workload) workload_threads.append(t) t.start() - prewarm_thread = threading.Thread(target=prewarm) + prewarm_thread = Thread(target=prewarm) prewarm_thread.start() def prewarmed():