diff --git a/Cargo.lock b/Cargo.lock index 3ee261e885..d8bf04e87f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -753,6 +753,7 @@ dependencies = [ "axum", "axum-core", "bytes", + "form_urlencoded", "futures-util", "headers", "http 1.1.0", @@ -761,6 +762,8 @@ dependencies = [ "mime", "pin-project-lite", "serde", + "serde_html_form", + "serde_path_to_error", "tower 0.5.2", "tower-layer", "tower-service", @@ -6422,6 +6425,19 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "serde_html_form" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d2de91cf02bbc07cde38891769ccd5d4f073d22a40683aa4bc7a95781aaa2c4" +dependencies = [ + "form_urlencoded", + "indexmap 2.9.0", + "itoa", + "ryu", + "serde", +] + [[package]] name = "serde_json" version = "1.0.125" diff --git a/Cargo.toml b/Cargo.toml index a040010fb7..666ead7352 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -71,7 +71,7 @@ aws-credential-types = "1.2.0" aws-sigv4 = { version = "1.2", features = ["sign-http"] } aws-types = "1.3" axum = { version = "0.8.1", features = ["ws"] } -axum-extra = { version = "0.10.0", features = ["typed-header"] } +axum-extra = { version = "0.10.0", features = ["typed-header", "query"] } base64 = "0.13.0" bincode = "1.3" bindgen = "0.71" diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index bd6ed910be..f15538b157 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -785,7 +785,7 @@ impl ComputeNode { self.spawn_extension_stats_task(); if pspec.spec.autoprewarm { - self.prewarm_lfc(); + self.prewarm_lfc(None); } Ok(()) } diff --git a/compute_tools/src/compute_prewarm.rs b/compute_tools/src/compute_prewarm.rs index a6a84b3f1f..1c7a7bef60 100644 --- a/compute_tools/src/compute_prewarm.rs +++ b/compute_tools/src/compute_prewarm.rs @@ -25,11 +25,16 @@ struct EndpointStoragePair { } const KEY: &str = "lfc_state"; -impl TryFrom<&crate::compute::ParsedSpec> for EndpointStoragePair { - type Error = 
anyhow::Error; - fn try_from(pspec: &crate::compute::ParsedSpec) -> Result<Self> { - let Some(ref endpoint_id) = pspec.spec.endpoint_id else { - bail!("pspec.endpoint_id missing") +impl EndpointStoragePair { + /// endpoint_id is set to None while prewarming from other endpoint, see replica promotion + /// If not None, takes precedence over pspec.spec.endpoint_id + fn from_spec_and_endpoint( + pspec: &crate::compute::ParsedSpec, + endpoint_id: Option<String>, + ) -> Result<Self> { + let endpoint_id = endpoint_id.as_ref().or(pspec.spec.endpoint_id.as_ref()); + let Some(ref endpoint_id) = endpoint_id else { + bail!("pspec.endpoint_id missing, other endpoint_id not provided") }; let Some(ref base_uri) = pspec.endpoint_storage_addr else { bail!("pspec.endpoint_storage_addr missing") @@ -84,7 +89,7 @@ impl ComputeNode { } /// Returns false if there is a prewarm request ongoing, true otherwise - pub fn prewarm_lfc(self: &Arc<Self>) -> bool { + pub fn prewarm_lfc(self: &Arc<Self>, from_endpoint: Option<String>) -> bool { crate::metrics::LFC_PREWARM_REQUESTS.inc(); { let state = &mut self.state.lock().unwrap().lfc_prewarm_state; @@ -97,7 +102,7 @@ impl ComputeNode { let cloned = self.clone(); spawn(async move { - let Err(err) = cloned.prewarm_impl().await else { + let Err(err) = cloned.prewarm_impl(from_endpoint).await else { cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Completed; return; }; @@ -109,13 +114,14 @@ impl ComputeNode { true } - fn endpoint_storage_pair(&self) -> Result<EndpointStoragePair> { + /// from_endpoint: None for endpoint managed by this compute_ctl + fn endpoint_storage_pair(&self, from_endpoint: Option<String>) -> Result<EndpointStoragePair> { let state = self.state.lock().unwrap(); - state.pspec.as_ref().unwrap().try_into() + EndpointStoragePair::from_spec_and_endpoint(state.pspec.as_ref().unwrap(), from_endpoint) } - async fn prewarm_impl(&self) -> Result<()> { - let EndpointStoragePair { url, token } = self.endpoint_storage_pair()?; + async fn prewarm_impl(&self, from_endpoint: Option<String>) -> Result<()> { + let 
EndpointStoragePair { url, token } = self.endpoint_storage_pair(from_endpoint)?; info!(%url, "requesting LFC state from endpoint storage"); let request = Client::new().get(&url).bearer_auth(token); @@ -173,7 +179,7 @@ impl ComputeNode { } async fn offload_lfc_impl(&self) -> Result<()> { - let EndpointStoragePair { url, token } = self.endpoint_storage_pair()?; + let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?; info!(%url, "requesting LFC state from postgres"); let mut compressed = Vec::new(); diff --git a/compute_tools/src/http/routes/lfc.rs b/compute_tools/src/http/routes/lfc.rs index 07bcc6bfb7..e98bd781a2 100644 --- a/compute_tools/src/http/routes/lfc.rs +++ b/compute_tools/src/http/routes/lfc.rs @@ -2,6 +2,7 @@ use crate::compute_prewarm::LfcPrewarmStateWithProgress; use crate::http::JsonResponse; use axum::response::{IntoResponse, Response}; use axum::{Json, http::StatusCode}; +use axum_extra::extract::OptionalQuery; use compute_api::responses::LfcOffloadState; type Compute = axum::extract::State<Arc<ComputeNode>>; @@ -16,8 +17,16 @@ pub(in crate::http) async fn offload_state(compute: Compute) -> Json<LfcOffloadState> -pub(in crate::http) async fn prewarm(compute: Compute) -> Response { - if compute.prewarm_lfc() { +#[derive(serde::Deserialize)] +pub struct PrewarmQuery { + pub from_endpoint: String, +} + +pub(in crate::http) async fn prewarm( + compute: Compute, + OptionalQuery(query): OptionalQuery<PrewarmQuery>, +) -> Response { + if compute.prewarm_lfc(query.map(|q| q.from_endpoint)) { StatusCode::ACCEPTED.into_response() } else { JsonResponse::error( diff --git a/test_runner/fixtures/endpoint/http.py b/test_runner/fixtures/endpoint/http.py index 6d37dd1cb1..e2d405227b 100644 --- a/test_runner/fixtures/endpoint/http.py +++ b/test_runner/fixtures/endpoint/http.py @@ -69,8 +69,10 @@ class EndpointHttpClient(requests.Session): json: dict[str, str] = res.json() return json - def prewarm_lfc(self): - self.post(f"http://localhost:{self.external_port}/lfc/prewarm").raise_for_status() + def prewarm_lfc(self, from_endpoint_id: str | None = 
None): + url: str = f"http://localhost:{self.external_port}/lfc/prewarm" + params = {"from_endpoint": from_endpoint_id} if from_endpoint_id else dict() + self.post(url, params=params).raise_for_status() def prewarmed(): json = self.prewarm_lfc_status() diff --git a/test_runner/regress/test_lfc_prewarm.py b/test_runner/regress/test_lfc_prewarm.py index 82e1e9fcba..40a9b29296 100644 --- a/test_runner/regress/test_lfc_prewarm.py +++ b/test_runner/regress/test_lfc_prewarm.py @@ -188,7 +188,8 @@ def test_lfc_prewarm_under_workload(neon_simple_env: NeonEnv, query: LfcQueryMet pg_cur.execute("select pg_reload_conf()") if query is LfcQueryMethod.COMPUTE_CTL: - http_client.prewarm_lfc() + # Same thing as prewarm_lfc(), testing other method + http_client.prewarm_lfc(endpoint.endpoint_id) else: pg_cur.execute("select prewarm_local_cache(%s)", (lfc_state,))