From df4e37b7cc1ba9fb9b21988a3ecd5b0590f33d53 Mon Sep 17 00:00:00 2001 From: Mikhail Date: Thu, 31 Jul 2025 12:51:19 +0100 Subject: [PATCH 1/3] Report timespans for promotion and prewarm (#12730) - Return sub-actions time spans for prewarm, prewarm offload, and promotion in http handlers. - Set `synchronous_standby_names=walproposer` for promoted endpoints. Otherwise, walproposer on promoted standby ignores reply from safekeeper and is stuck on lsn COMMIT eternally. --- compute_tools/src/compute.rs | 2 +- compute_tools/src/compute_prewarm.rs | 107 ++++++------ compute_tools/src/compute_promote.rs | 161 ++++++++++--------- compute_tools/src/http/openapi_spec.yaml | 31 +++- compute_tools/src/http/routes/lfc.rs | 5 +- compute_tools/src/http/routes/promote.rs | 13 +- libs/compute_api/src/responses.rs | 38 +++-- test_runner/regress/test_replica_promotes.py | 1 + 8 files changed, 209 insertions(+), 149 deletions(-) diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index 1df837e1e6..c0f0289e06 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -2780,7 +2780,7 @@ LIMIT 100", // 4. We start again and try to prewarm with the state from 2. instead of the previous complete state if matches!( prewarm_state, - LfcPrewarmState::Completed + LfcPrewarmState::Completed { .. } | LfcPrewarmState::NotPrewarmed | LfcPrewarmState::Skipped ) { diff --git a/compute_tools/src/compute_prewarm.rs b/compute_tools/src/compute_prewarm.rs index 82cb28f1ac..c7051d35d0 100644 --- a/compute_tools/src/compute_prewarm.rs +++ b/compute_tools/src/compute_prewarm.rs @@ -7,19 +7,11 @@ use http::StatusCode; use reqwest::Client; use std::mem::replace; use std::sync::Arc; +use std::time::Instant; use tokio::{io::AsyncReadExt, select, spawn}; use tokio_util::sync::CancellationToken; use tracing::{error, info}; -#[derive(serde::Serialize, Default)] -pub struct LfcPrewarmStateWithProgress { - #[serde(flatten)] - base: LfcPrewarmState, - total: i32, - prewarmed: i32, - skipped: i32, -} - /// A pair of url and a token to query endpoint storage for LFC prewarm-related tasks struct EndpointStoragePair { url: String, @@ -28,7 +20,7 @@ struct EndpointStoragePair { const KEY: &str = "lfc_state"; impl EndpointStoragePair { - /// endpoint_id is set to None while prewarming from other endpoint, see replica promotion + /// endpoint_id is set to None while prewarming from other endpoint, see compute_promote.rs /// If not None, takes precedence over pspec.spec.endpoint_id fn from_spec_and_endpoint( pspec: &crate::compute::ParsedSpec, @@ -54,36 +46,8 @@ impl EndpointStoragePair { } impl ComputeNode { - // If prewarm failed, we want to get overall number of segments as well as done ones. - // However, this function should be reliable even if querying postgres failed. - pub async fn lfc_prewarm_state(&self) -> LfcPrewarmStateWithProgress { - info!("requesting LFC prewarm state from postgres"); - let mut state = LfcPrewarmStateWithProgress::default(); - { - state.base = self.state.lock().unwrap().lfc_prewarm_state.clone(); - } - - let client = match ComputeNode::get_maintenance_client(&self.tokio_conn_conf).await { - Ok(client) => client, - Err(err) => { - error!(%err, "connecting to postgres"); - return state; - } - }; - let row = match client - .query_one("select * from neon.get_prewarm_info()", &[]) - .await - { - Ok(row) => row, - Err(err) => { - error!(%err, "querying LFC prewarm status"); - return state; - } - }; - state.total = row.try_get(0).unwrap_or_default(); - state.prewarmed = row.try_get(1).unwrap_or_default(); - state.skipped = row.try_get(2).unwrap_or_default(); - state + pub async fn lfc_prewarm_state(&self) -> LfcPrewarmState { + self.state.lock().unwrap().lfc_prewarm_state.clone() } pub fn lfc_offload_state(&self) -> LfcOffloadState { @@ -133,7 +97,6 @@ impl ComputeNode { } /// Request LFC state from endpoint storage and load corresponding pages into Postgres. - /// Returns a result with `false` if the LFC state is not found in endpoint storage. async fn prewarm_impl( &self, from_endpoint: Option, @@ -148,6 +111,7 @@ impl ComputeNode { fail::fail_point!("compute-prewarm", |_| bail!("compute-prewarm failpoint")); info!(%url, "requesting LFC state from endpoint storage"); + let mut now = Instant::now(); let request = Client::new().get(&url).bearer_auth(storage_token); let response = select! { _ = token.cancelled() => return Ok(LfcPrewarmState::Cancelled), @@ -160,6 +124,8 @@ impl ComputeNode { StatusCode::NOT_FOUND => return Ok(LfcPrewarmState::Skipped), status => bail!("{status} querying endpoint storage"), } + let state_download_time_ms = now.elapsed().as_millis() as u32; + now = Instant::now(); let mut uncompressed = Vec::new(); let lfc_state = select! { @@ -174,6 +140,8 @@ impl ComputeNode { read = decoder.read_to_end(&mut uncompressed) => read } .context("decoding LFC state")?; + let uncompress_time_ms = now.elapsed().as_millis() as u32; + now = Instant::now(); let uncompressed_len = uncompressed.len(); info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}"); @@ -196,15 +164,34 @@ impl ComputeNode { } .context("loading LFC state into postgres") .map(|_| ())?; + let prewarm_time_ms = now.elapsed().as_millis() as u32; - Ok(LfcPrewarmState::Completed) + let row = client + .query_one("select * from neon.get_prewarm_info()", &[]) + .await + .context("querying prewarm info")?; + let total = row.try_get(0).unwrap_or_default(); + let prewarmed = row.try_get(1).unwrap_or_default(); + let skipped = row.try_get(2).unwrap_or_default(); + + Ok(LfcPrewarmState::Completed { + total, + prewarmed, + skipped, + state_download_time_ms, + uncompress_time_ms, + prewarm_time_ms, + }) } /// If offload request is ongoing, return false, true otherwise pub fn offload_lfc(self: &Arc) -> bool { { let state = &mut self.state.lock().unwrap().lfc_offload_state; - if replace(state, LfcOffloadState::Offloading) == LfcOffloadState::Offloading { + if matches!( + replace(state, LfcOffloadState::Offloading), + LfcOffloadState::Offloading + ) { return false; } } @@ -216,7 +203,10 @@ impl ComputeNode { pub async fn offload_lfc_async(self: &Arc) { { let state = &mut self.state.lock().unwrap().lfc_offload_state; - if replace(state, LfcOffloadState::Offloading) == LfcOffloadState::Offloading { + if matches!( + replace(state, LfcOffloadState::Offloading), + LfcOffloadState::Offloading + ) { return; } } @@ -234,7 +224,6 @@ impl ComputeNode { LfcOffloadState::Failed { error } } }; - self.state.lock().unwrap().lfc_offload_state = state; } @@ -242,6 +231,7 @@ impl ComputeNode { let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?; info!(%url, "requesting LFC state from Postgres"); + let mut now = Instant::now(); let row = ComputeNode::get_maintenance_client(&self.tokio_conn_conf) .await .context("connecting to postgres")? @@ -255,25 +245,36 @@ impl ComputeNode { info!(%url, "empty LFC state, not exporting"); return Ok(LfcOffloadState::Skipped); }; + let state_query_time_ms = now.elapsed().as_millis() as u32; + now = Instant::now(); let mut compressed = Vec::new(); ZstdEncoder::new(state) .read_to_end(&mut compressed) .await .context("compressing LFC state")?; + let compress_time_ms = now.elapsed().as_millis() as u32; + now = Instant::now(); let compressed_len = compressed.len(); - info!(%url, "downloaded LFC state, compressed size {compressed_len}, writing to endpoint storage"); + info!(%url, "downloaded LFC state, compressed size {compressed_len}"); let request = Client::new().put(url).bearer_auth(token).body(compressed); - match request.send().await { - Ok(res) if res.status() == StatusCode::OK => Ok(LfcOffloadState::Completed), - Ok(res) => bail!( - "Request to endpoint storage failed with status: {}", - res.status() - ), - Err(err) => Err(err).context("writing to endpoint storage"), + let response = request + .send() + .await + .context("writing to endpoint storage")?; + let state_upload_time_ms = now.elapsed().as_millis() as u32; + let status = response.status(); + if status != StatusCode::OK { + bail!("request to endpoint storage failed: {status}"); } + + Ok(LfcOffloadState::Completed { + compress_time_ms, + state_query_time_ms, + state_upload_time_ms, + }) } pub fn cancel_prewarm(self: &Arc) { diff --git a/compute_tools/src/compute_promote.rs b/compute_tools/src/compute_promote.rs index 29195b60e9..15b5fcfb46 100644 --- a/compute_tools/src/compute_promote.rs +++ b/compute_tools/src/compute_promote.rs @@ -1,32 +1,24 @@ use crate::compute::ComputeNode; -use anyhow::{Context, Result, bail}; +use anyhow::{Context, bail}; use compute_api::responses::{LfcPrewarmState, PromoteConfig, PromoteState}; -use compute_api::spec::ComputeMode; -use itertools::Itertools; -use std::collections::HashMap; -use std::{sync::Arc, time::Duration}; -use tokio::time::sleep; +use std::time::Instant; use tracing::info; -use utils::lsn::Lsn; impl ComputeNode { - /// Returns only when promote fails or succeeds. If a network error occurs - /// and http client disconnects, this does not stop promotion, and subsequent - /// calls block until promote finishes. + /// Returns only when promote fails or succeeds. If http client calling this function + /// disconnects, this does not stop promotion, and subsequent calls block until promote finishes. /// Called by control plane on secondary after primary endpoint is terminated /// Has a failpoint "compute-promotion" - pub async fn promote(self: &Arc, cfg: PromoteConfig) -> PromoteState { - let cloned = self.clone(); - let promote_fn = async move || { - let Err(err) = cloned.promote_impl(cfg).await else { - return PromoteState::Completed; - }; - tracing::error!(%err, "promoting"); - PromoteState::Failed { - error: format!("{err:#}"), + pub async fn promote(self: &std::sync::Arc, cfg: PromoteConfig) -> PromoteState { + let this = self.clone(); + let promote_fn = async move || match this.promote_impl(cfg).await { + Ok(state) => state, + Err(err) => { + tracing::error!(%err, "promoting replica"); + let error = format!("{err:#}"); + PromoteState::Failed { error } } }; - let start_promotion = || { let (tx, rx) = tokio::sync::watch::channel(PromoteState::NotPromoted); tokio::spawn(async move { tx.send(promote_fn().await) }); @@ -34,36 +26,31 @@ impl ComputeNode { }; let mut task; - // self.state is unlocked after block ends so we lock it in promote_impl - // and task.changed() is reached + // promote_impl locks self.state so we need to unlock it before calling task.changed() { - task = self - .state - .lock() - .unwrap() - .promote_state - .get_or_insert_with(start_promotion) - .clone() + let promote_state = &mut self.state.lock().unwrap().promote_state; + task = promote_state.get_or_insert_with(start_promotion).clone() + } + if task.changed().await.is_err() { + let error = "promote sender dropped".to_string(); + return PromoteState::Failed { error }; } - task.changed().await.expect("promote sender dropped"); task.borrow().clone() } - async fn promote_impl(&self, mut cfg: PromoteConfig) -> Result<()> { + async fn promote_impl(&self, cfg: PromoteConfig) -> anyhow::Result { { let state = self.state.lock().unwrap(); let mode = &state.pspec.as_ref().unwrap().spec.mode; - if *mode != ComputeMode::Replica { - bail!("{} is not replica", mode.to_type_str()); + if *mode != compute_api::spec::ComputeMode::Replica { + bail!("compute mode \"{}\" is not replica", mode.to_type_str()); } - - // we don't need to query Postgres so not self.lfc_prewarm_state() match &state.lfc_prewarm_state { - LfcPrewarmState::NotPrewarmed | LfcPrewarmState::Prewarming => { - bail!("prewarm not requested or pending") + status @ (LfcPrewarmState::NotPrewarmed | LfcPrewarmState::Prewarming) => { + bail!("compute {status}") } LfcPrewarmState::Failed { error } => { - tracing::warn!(%error, "replica prewarm failed") + tracing::warn!(%error, "compute prewarm failed") } _ => {} } @@ -72,9 +59,10 @@ impl ComputeNode { let client = ComputeNode::get_maintenance_client(&self.tokio_conn_conf) .await .context("connecting to postgres")?; + let mut now = Instant::now(); let primary_lsn = cfg.wal_flush_lsn; - let mut last_wal_replay_lsn: Lsn = Lsn::INVALID; + let mut standby_lsn = utils::lsn::Lsn::INVALID; const RETRIES: i32 = 20; for i in 0..=RETRIES { let row = client @@ -82,16 +70,18 @@ impl ComputeNode { .await .context("getting last replay lsn")?; let lsn: u64 = row.get::(0).into(); - last_wal_replay_lsn = lsn.into(); - if last_wal_replay_lsn >= primary_lsn { + standby_lsn = lsn.into(); + if standby_lsn >= primary_lsn { break; } - info!("Try {i}, replica lsn {last_wal_replay_lsn}, primary lsn {primary_lsn}"); - sleep(Duration::from_secs(1)).await; + info!(%standby_lsn, %primary_lsn, "catching up, try {i}"); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; } - if last_wal_replay_lsn < primary_lsn { + if standby_lsn < primary_lsn { bail!("didn't catch up with primary in {RETRIES} retries"); } + let lsn_wait_time_ms = now.elapsed().as_millis() as u32; + now = Instant::now(); // using $1 doesn't work with ALTER SYSTEM SET let safekeepers_sql = format!( @@ -102,27 +92,33 @@ impl ComputeNode { .query(&safekeepers_sql, &[]) .await .context("setting safekeepers")?; + client + .query( + "ALTER SYSTEM SET synchronous_standby_names=walproposer", + &[], + ) + .await + .context("setting synchronous_standby_names")?; client .query("SELECT pg_catalog.pg_reload_conf()", &[]) .await .context("reloading postgres config")?; #[cfg(feature = "testing")] - fail::fail_point!("compute-promotion", |_| { - bail!("promotion configured to fail because of a failpoint") - }); + fail::fail_point!("compute-promotion", |_| bail!( + "compute-promotion failpoint" + )); let row = client .query_one("SELECT * FROM pg_catalog.pg_promote()", &[]) .await .context("pg_promote")?; if !row.get::(0) { - bail!("pg_promote() returned false"); + bail!("pg_promote() failed"); } + let pg_promote_time_ms = now.elapsed().as_millis() as u32; + let now = Instant::now(); - let client = ComputeNode::get_maintenance_client(&self.tokio_conn_conf) - .await - .context("connecting to postgres")?; let row = client .query_one("SHOW transaction_read_only", &[]) .await @@ -131,36 +127,47 @@ impl ComputeNode { bail!("replica in read only mode after promotion"); } + // Already checked validity in http handler + #[allow(unused_mut)] + let mut new_pspec = crate::compute::ParsedSpec::try_from(cfg.spec).expect("invalid spec"); { let mut state = self.state.lock().unwrap(); - let spec = &mut state.pspec.as_mut().unwrap().spec; - spec.mode = ComputeMode::Primary; - let new_conf = cfg.spec.cluster.postgresql_conf.as_mut().unwrap(); - let existing_conf = spec.cluster.postgresql_conf.as_ref().unwrap(); - Self::merge_spec(new_conf, existing_conf); + + // Local setup has different ports for pg process (port=) for primary and secondary. + // Primary is stopped so we need secondary's "port" value + #[cfg(feature = "testing")] + { + let old_spec = &state.pspec.as_ref().unwrap().spec; + let Some(old_conf) = old_spec.cluster.postgresql_conf.as_ref() else { + bail!("pspec.spec.cluster.postgresql_conf missing for endpoint"); + }; + let set: std::collections::HashMap<&str, &str> = old_conf + .split_terminator('\n') + .map(|e| e.split_once("=").expect("invalid item")) + .collect(); + + let Some(new_conf) = new_pspec.spec.cluster.postgresql_conf.as_mut() else { + bail!("pspec.spec.cluster.postgresql_conf missing for supplied config"); + }; + new_conf.push_str(&format!("port={}\n", set["port"])); + } + + tracing::debug!("applied spec: {:#?}", new_pspec.spec); + if self.params.lakebase_mode { + ComputeNode::set_spec(&self.params, &mut state, new_pspec); + } else { + state.pspec = Some(new_pspec); + } } + info!("applied new spec, reconfiguring as primary"); - self.reconfigure() - } + self.reconfigure()?; + let reconfigure_time_ms = now.elapsed().as_millis() as u32; - /// Merge old and new Postgres conf specs to apply on secondary. - /// Change new spec's port and safekeepers since they are supplied - /// differenly - fn merge_spec(new_conf: &mut String, existing_conf: &str) { - let mut new_conf_set: HashMap<&str, &str> = new_conf - .split_terminator('\n') - .map(|e| e.split_once("=").expect("invalid item")) - .collect(); - new_conf_set.remove("neon.safekeepers"); - - let existing_conf_set: HashMap<&str, &str> = existing_conf - .split_terminator('\n') - .map(|e| e.split_once("=").expect("invalid item")) - .collect(); - new_conf_set.insert("port", existing_conf_set["port"]); - *new_conf = new_conf_set - .iter() - .map(|(k, v)| format!("{k}={v}")) - .join("\n"); + Ok(PromoteState::Completed { + lsn_wait_time_ms, + pg_promote_time_ms, + reconfigure_time_ms, + }) } } diff --git a/compute_tools/src/http/openapi_spec.yaml b/compute_tools/src/http/openapi_spec.yaml index 27e610a87d..82e61acfdc 100644 --- a/compute_tools/src/http/openapi_spec.yaml +++ b/compute_tools/src/http/openapi_spec.yaml @@ -617,9 +617,6 @@ components: type: object required: - status - - total - - prewarmed - - skipped properties: status: description: LFC prewarm status @@ -637,6 +634,15 @@ components: skipped: description: Pages processed but not prewarmed type: integer + state_download_time_ms: + description: Time it takes to download LFC state to compute + type: integer + uncompress_time_ms: + description: Time it takes to uncompress LFC state + type: integer + prewarm_time_ms: + description: Time it takes to prewarm LFC state in Postgres + type: integer LfcOffloadState: type: object @@ -650,6 +656,16 @@ components: error: description: LFC offload error, if any type: string + state_query_time_ms: + description: Time it takes to get LFC state from Postgres + type: integer + compress_time_ms: + description: Time it takes to compress LFC state + type: integer + state_upload_time_ms: + description: Time it takes to upload LFC state to endpoint storage + type: integer + PromoteState: type: object @@ -663,6 +679,15 @@ components: error: description: Promote error, if any type: string + lsn_wait_time_ms: + description: Time it takes for secondary to catch up with primary WAL flush LSN + type: integer + pg_promote_time_ms: + description: Time it takes to call pg_promote on secondary + type: integer + reconfigure_time_ms: + description: Time it takes to reconfigure promoted secondary + type: integer SetRoleGrantsRequest: type: object diff --git a/compute_tools/src/http/routes/lfc.rs b/compute_tools/src/http/routes/lfc.rs index 7483198723..6ad216778e 100644 --- a/compute_tools/src/http/routes/lfc.rs +++ b/compute_tools/src/http/routes/lfc.rs @@ -1,12 +1,11 @@ -use crate::compute_prewarm::LfcPrewarmStateWithProgress; use crate::http::JsonResponse; use axum::response::{IntoResponse, Response}; use axum::{Json, http::StatusCode}; use axum_extra::extract::OptionalQuery; -use compute_api::responses::LfcOffloadState; +use compute_api::responses::{LfcOffloadState, LfcPrewarmState}; type Compute = axum::extract::State>; -pub(in crate::http) async fn prewarm_state(compute: Compute) -> Json { +pub(in crate::http) async fn prewarm_state(compute: Compute) -> Json { Json(compute.lfc_prewarm_state().await) } diff --git a/compute_tools/src/http/routes/promote.rs b/compute_tools/src/http/routes/promote.rs index 7ca3464b63..865de0da91 100644 --- a/compute_tools/src/http/routes/promote.rs +++ b/compute_tools/src/http/routes/promote.rs @@ -1,11 +1,22 @@ use crate::http::JsonResponse; use axum::extract::Json; +use compute_api::responses::PromoteConfig; use http::StatusCode; pub(in crate::http) async fn promote( compute: axum::extract::State>, - Json(cfg): Json, + Json(cfg): Json, ) -> axum::response::Response { + // Return early at the cost of extra parsing spec + let pspec = match crate::compute::ParsedSpec::try_from(cfg.spec) { + Ok(p) => p, + Err(e) => return JsonResponse::error(StatusCode::BAD_REQUEST, e), + }; + + let cfg = PromoteConfig { + spec: pspec.spec, + wal_flush_lsn: cfg.wal_flush_lsn, + }; let state = compute.promote(cfg).await; if let compute_api::responses::PromoteState::Failed { error: _ } = state { return JsonResponse::create_response(StatusCode::INTERNAL_SERVER_ERROR, state); diff --git a/libs/compute_api/src/responses.rs b/libs/compute_api/src/responses.rs index a918644e4c..a61d418dd1 100644 --- a/libs/compute_api/src/responses.rs +++ b/libs/compute_api/src/responses.rs @@ -1,10 +1,9 @@ //! Structs representing the JSON formats used in the compute_ctl's HTTP API. -use std::fmt::Display; - use chrono::{DateTime, Utc}; use jsonwebtoken::jwk::JwkSet; use serde::{Deserialize, Serialize, Serializer}; +use std::fmt::Display; use crate::privilege::Privilege; use crate::spec::{ComputeSpec, Database, ExtVersion, PgIdent, Role}; @@ -49,7 +48,7 @@ pub struct ExtensionInstallResponse { /// Status of the LFC prewarm process. The same state machine is reused for /// both autoprewarm (prewarm after compute/Postgres start using the previously /// stored LFC state) and explicit prewarming via API. -#[derive(Serialize, Default, Debug, Clone, PartialEq)] +#[derive(Serialize, Default, Debug, Clone)] #[serde(tag = "status", rename_all = "snake_case")] pub enum LfcPrewarmState { /// Default value when compute boots up. @@ -59,7 +58,14 @@ pub enum LfcPrewarmState { Prewarming, /// We found requested LFC state in the endpoint storage and /// completed prewarming successfully. - Completed, + Completed { + total: i32, + prewarmed: i32, + skipped: i32, + state_download_time_ms: u32, + uncompress_time_ms: u32, + prewarm_time_ms: u32, + }, /// Unexpected error happened during prewarming. Note, `Not Found 404` /// response from the endpoint storage is explicitly excluded here /// because it can normally happen on the first compute start, @@ -84,7 +90,7 @@ impl Display for LfcPrewarmState { match self { LfcPrewarmState::NotPrewarmed => f.write_str("NotPrewarmed"), LfcPrewarmState::Prewarming => f.write_str("Prewarming"), - LfcPrewarmState::Completed => f.write_str("Completed"), + LfcPrewarmState::Completed { .. } => f.write_str("Completed"), LfcPrewarmState::Skipped => f.write_str("Skipped"), LfcPrewarmState::Failed { error } => write!(f, "Error({error})"), LfcPrewarmState::Cancelled => f.write_str("Cancelled"), @@ -92,26 +98,36 @@ impl Display for LfcPrewarmState { } } -#[derive(Serialize, Default, Debug, Clone, PartialEq)] +#[derive(Serialize, Default, Debug, Clone)] #[serde(tag = "status", rename_all = "snake_case")] pub enum LfcOffloadState { #[default] NotOffloaded, Offloading, - Completed, + Completed { + state_query_time_ms: u32, + compress_time_ms: u32, + state_upload_time_ms: u32, + }, Failed { error: String, }, + /// LFC state was empty so it wasn't offloaded Skipped, } -#[derive(Serialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Debug, Clone)] #[serde(tag = "status", rename_all = "snake_case")] -/// Response of /promote pub enum PromoteState { NotPromoted, - Completed, - Failed { error: String }, + Completed { + lsn_wait_time_ms: u32, + pg_promote_time_ms: u32, + reconfigure_time_ms: u32, + }, + Failed { + error: String, + }, } #[derive(Deserialize, Default, Debug)] diff --git a/test_runner/regress/test_replica_promotes.py b/test_runner/regress/test_replica_promotes.py index 9415d6886c..26165b40cd 100644 --- a/test_runner/regress/test_replica_promotes.py +++ b/test_runner/regress/test_replica_promotes.py @@ -145,6 +145,7 @@ def test_replica_promote(neon_simple_env: NeonEnv, method: PromoteMethod): stop_and_check_lsn(secondary, None) if method == PromoteMethod.COMPUTE_CTL: + log.info("Restarting primary to check new config") secondary.stop() # In production, compute ultimately receives new compute spec from cplane. secondary.respec(mode="Primary") From 312a74f11fdf5a8689837834be4c5ce0cc95cebf Mon Sep 17 00:00:00 2001 From: Dmitrii Kovalkov <34828390+DimasKovas@users.noreply.github.com> Date: Thu, 31 Jul 2025 16:40:32 +0400 Subject: [PATCH 2/3] storcon: implement safekeeper_migrate_abort handler (#12705) ## Problem Right now if we commit a joint configuration to DB, there is no way back. The only way to get the clean mconf is to continue the migration. The RFC also described an abort mechanism, which allows to abort current migration and revert mconf change. It might be needed if the migration is stuck and cannot have any progress, e.g. if the sk we are migrating to went down during the migration. This PR implements this abort algorithm. - Closes: https://databricks.atlassian.net/browse/LKB-899 - Closes: https://github.com/neondatabase/neon/issues/12549 ## Summary of changes - Implement `safekeeper_migrate_abort` handler with the algorithm described in RFC - Add `timeline-safekeeper-migrate-abort` subcommand to `storcon_cli` - Add test for the migration abort algorithm. --- control_plane/storcon_cli/src/main.rs | 18 +++ storage_controller/src/http.rs | 28 ++++ .../src/service/safekeeper_service.rs | 123 +++++++++++++++++- test_runner/fixtures/neon_fixtures.py | 13 ++ .../regress/test_safekeeper_migration.py | 88 +++++++++++++ 5 files changed, 266 insertions(+), 4 deletions(-) diff --git a/control_plane/storcon_cli/src/main.rs b/control_plane/storcon_cli/src/main.rs index a4d1030488..635b1858ec 100644 --- a/control_plane/storcon_cli/src/main.rs +++ b/control_plane/storcon_cli/src/main.rs @@ -303,6 +303,13 @@ enum Command { #[arg(long, required = true, value_delimiter = ',')] new_sk_set: Vec, }, + /// Abort ongoing safekeeper migration. + TimelineSafekeeperMigrateAbort { + #[arg(long)] + tenant_id: TenantId, + #[arg(long)] + timeline_id: TimelineId, + }, } #[derive(Parser)] @@ -1396,6 +1403,17 @@ async fn main() -> anyhow::Result<()> { ) .await?; } + Command::TimelineSafekeeperMigrateAbort { + tenant_id, + timeline_id, + } => { + let path = + format!("v1/tenant/{tenant_id}/timeline/{timeline_id}/safekeeper_migrate_abort"); + + storcon_client + .dispatch::<(), ()>(Method::POST, path, None) + .await?; + } } Ok(()) diff --git a/storage_controller/src/http.rs b/storage_controller/src/http.rs index ff73719adb..b40da4fd65 100644 --- a/storage_controller/src/http.rs +++ b/storage_controller/src/http.rs @@ -644,6 +644,7 @@ async fn handle_tenant_timeline_safekeeper_migrate( req: Request, ) -> Result, ApiError> { let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?; + // TODO(diko): it's not PS operation, there should be a different permission scope. check_permissions(&req, Scope::PageServerApi)?; maybe_rate_limit(&req, tenant_id).await; @@ -665,6 +666,23 @@ async fn handle_tenant_timeline_safekeeper_migrate( json_response(StatusCode::OK, ()) } +async fn handle_tenant_timeline_safekeeper_migrate_abort( + service: Arc, + req: Request, +) -> Result, ApiError> { + let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?; + let timeline_id: TimelineId = parse_request_param(&req, "timeline_id")?; + // TODO(diko): it's not PS operation, there should be a different permission scope. + check_permissions(&req, Scope::PageServerApi)?; + maybe_rate_limit(&req, tenant_id).await; + + service + .tenant_timeline_safekeeper_migrate_abort(tenant_id, timeline_id) + .await?; + + json_response(StatusCode::OK, ()) +} + async fn handle_tenant_timeline_lsn_lease( service: Arc, req: Request, @@ -2611,6 +2629,16 @@ pub fn make_router( ) }, ) + .post( + "/v1/tenant/:tenant_id/timeline/:timeline_id/safekeeper_migrate_abort", + |r| { + tenant_service_handler( + r, + handle_tenant_timeline_safekeeper_migrate_abort, + RequestName("v1_tenant_timeline_safekeeper_migrate_abort"), + ) + }, + ) // LSN lease passthrough to all shards .post( "/v1/tenant/:tenant_id/timeline/:timeline_id/lsn_lease", diff --git a/storage_controller/src/service/safekeeper_service.rs b/storage_controller/src/service/safekeeper_service.rs index fab1342d5d..689d341b6a 100644 --- a/storage_controller/src/service/safekeeper_service.rs +++ b/storage_controller/src/service/safekeeper_service.rs @@ -1230,10 +1230,7 @@ impl Service { } // It it is the same new_sk_set, we can continue the migration (retry). } else { - let prev_finished = timeline.cplane_notified_generation == timeline.generation - && timeline.sk_set_notified_generation == timeline.generation; - - if !prev_finished { + if !is_migration_finished(&timeline) { // The previous migration is committed, but the finish step failed. // Safekeepers/cplane might not know about the last membership configuration. // Retry the finish step to ensure smooth migration. @@ -1545,6 +1542,8 @@ impl Service { timeline_id: TimelineId, timeline: &TimelinePersistence, ) -> Result<(), ApiError> { + tracing::info!(generation=?timeline.generation, sk_set=?timeline.sk_set, new_sk_set=?timeline.new_sk_set, "retrying finish safekeeper migration"); + if timeline.new_sk_set.is_some() { // Logical error, should never happen. return Err(ApiError::InternalServerError(anyhow::anyhow!( @@ -1624,4 +1623,120 @@ impl Service { Ok(wal_positions[quorum_size - 1]) } + + /// Abort ongoing safekeeper migration. + pub(crate) async fn tenant_timeline_safekeeper_migrate_abort( + self: &Arc, + tenant_id: TenantId, + timeline_id: TimelineId, + ) -> Result<(), ApiError> { + // TODO(diko): per-tenant lock is too wide. Consider introducing per-timeline locks. + let _tenant_lock = trace_shared_lock( + &self.tenant_op_locks, + tenant_id, + TenantOperations::TimelineSafekeeperMigrate, + ) + .await; + + // Fetch current timeline configuration from the configuration storage. + let timeline = self + .persistence + .get_timeline(tenant_id, timeline_id) + .await?; + + let Some(timeline) = timeline else { + return Err(ApiError::NotFound( + anyhow::anyhow!( + "timeline {tenant_id}/{timeline_id} doesn't exist in timelines table" + ) + .into(), + )); + }; + + let mut generation = SafekeeperGeneration::new(timeline.generation as u32); + + let Some(new_sk_set) = &timeline.new_sk_set else { + // No new_sk_set -> no active migration that we can abort. + tracing::info!("timeline has no active migration"); + + if !is_migration_finished(&timeline) { + // The last migration is committed, but the finish step failed. + // Safekeepers/cplane might not know about the last membership configuration. + // Retry the finish step to make the timeline state clean. + self.finish_safekeeper_migration_retry(tenant_id, timeline_id, &timeline) + .await?; + } + return Ok(()); + }; + + tracing::info!(sk_set=?timeline.sk_set, ?new_sk_set, ?generation, "aborting timeline migration"); + + let cur_safekeepers = self.get_safekeepers(&timeline.sk_set)?; + let new_safekeepers = self.get_safekeepers(new_sk_set)?; + + let cur_sk_member_set = + Self::make_member_set(&cur_safekeepers).map_err(ApiError::InternalServerError)?; + + // Increment current generation and remove new_sk_set from the timeline to abort the migration. + generation = generation.next(); + + let mconf = membership::Configuration { + generation, + members: cur_sk_member_set, + new_members: None, + }; + + // Exclude safekeepers which were added during the current migration. + let cur_ids: HashSet = cur_safekeepers.iter().map(|sk| sk.get_id()).collect(); + let exclude_safekeepers = new_safekeepers + .into_iter() + .filter(|sk| !cur_ids.contains(&sk.get_id())) + .collect::>(); + + let exclude_requests = exclude_safekeepers + .iter() + .map(|sk| TimelinePendingOpPersistence { + sk_id: sk.skp.id, + tenant_id: tenant_id.to_string(), + timeline_id: timeline_id.to_string(), + generation: generation.into_inner() as i32, + op_kind: SafekeeperTimelineOpKind::Exclude, + }) + .collect::>(); + + let cur_sk_set = cur_safekeepers + .iter() + .map(|sk| sk.get_id()) + .collect::>(); + + // Persist new mconf and exclude requests. + self.persistence + .update_timeline_membership( + tenant_id, + timeline_id, + generation, + &cur_sk_set, + None, + &exclude_requests, + ) + .await?; + + // At this point we have already commited the abort, but still need to notify + // cplane/safekeepers with the new mconf. That's what finish_safekeeper_migration does. + self.finish_safekeeper_migration( + tenant_id, + timeline_id, + &cur_safekeepers, + &mconf, + &exclude_safekeepers, + ) + .await?; + + Ok(()) + } +} + +fn is_migration_finished(timeline: &TimelinePersistence) -> bool { + timeline.cplane_notified_generation == timeline.generation + && timeline.sk_set_notified_generation == timeline.generation } diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index c3dfc78218..41213d374a 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -2323,6 +2323,19 @@ class NeonStorageController(MetricsGetter, LogUtils): response.raise_for_status() log.info(f"migrate_safekeepers success: {response.json()}") + def abort_safekeeper_migration( + self, + tenant_id: TenantId, + timeline_id: TimelineId, + ): + response = self.request( + "POST", + f"{self.api}/v1/tenant/{tenant_id}/timeline/{timeline_id}/safekeeper_migrate_abort", + headers=self.headers(TokenScope.PAGE_SERVER_API), + ) + response.raise_for_status() + log.info(f"abort_safekeeper_migration success: {response.json()}") + def locate(self, tenant_id: TenantId) -> list[dict[str, Any]]: """ :return: list of {"shard_id": "", "node_id": int, "listen_pg_addr": str, "listen_pg_port": int, "listen_http_addr": str, "listen_http_port": int} diff --git a/test_runner/regress/test_safekeeper_migration.py b/test_runner/regress/test_safekeeper_migration.py index 97a6ece446..ba067b97de 100644 --- a/test_runner/regress/test_safekeeper_migration.py +++ b/test_runner/regress/test_safekeeper_migration.py @@ -460,3 +460,91 @@ def test_pull_from_most_advanced_sk(neon_env_builder: NeonEnvBuilder): ep.start(safekeeper_generation=5, safekeepers=new_sk_set2) assert ep.safe_psql("SELECT * FROM t") == [(0,), (1,)] + + +def test_abort_safekeeper_migration(neon_env_builder: NeonEnvBuilder): + """ + Test that safekeeper migration can be aborted. + 1. Insert failpoints and ensure the abort successfully reverts the timeline state. + 2. Check that endpoint is operational after the abort. + """ + neon_env_builder.num_safekeepers = 2 + neon_env_builder.storage_controller_config = { + "timelines_onto_safekeepers": True, + "timeline_safekeeper_count": 1, + } + env = neon_env_builder.init_start() + env.pageserver.allowed_errors.extend(PAGESERVER_ALLOWED_ERRORS) + + mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline) + assert len(mconf["sk_set"]) == 1 + cur_sk = mconf["sk_set"][0] + cur_gen = 1 + + ep = env.endpoints.create("main", tenant_id=env.initial_tenant) + ep.start(safekeeper_generation=1, safekeepers=mconf["sk_set"]) + ep.safe_psql("CREATE EXTENSION neon_test_utils;") + ep.safe_psql("CREATE TABLE t(a int)") + ep.safe_psql("INSERT INTO t VALUES (1)") + + another_sk = [sk.id for sk in env.safekeepers if sk.id != cur_sk][0] + + failpoints = [ + "sk-migration-after-step-3", + "sk-migration-after-step-4", + "sk-migration-after-step-5", + "sk-migration-after-step-7", + ] + + for fp in failpoints: + env.storage_controller.configure_failpoints((fp, "return(1)")) + + with pytest.raises(StorageControllerApiException, match=f"failpoint {fp}"): + env.storage_controller.migrate_safekeepers( + env.initial_tenant, env.initial_timeline, [another_sk] + ) + cur_gen += 1 + + env.storage_controller.configure_failpoints((fp, "off")) + + # We should have a joint mconf after the failure. + mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline) + assert mconf["generation"] == cur_gen + assert mconf["sk_set"] == [cur_sk] + assert mconf["new_sk_set"] == [another_sk] + + env.storage_controller.abort_safekeeper_migration(env.initial_tenant, env.initial_timeline) + cur_gen += 1 + + # Abort should revert the timeline to the previous sk_set and increment the generation. + mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline) + assert mconf["generation"] == cur_gen + assert mconf["sk_set"] == [cur_sk] + assert mconf["new_sk_set"] is None + + assert ep.safe_psql("SHOW neon.safekeepers")[0][0].startswith(f"g#{cur_gen}:") + ep.safe_psql(f"INSERT INTO t VALUES ({cur_gen})") + + # After step-8 the final mconf is committed and the migration is not abortable anymore. + # So the abort should not abort anything. + env.storage_controller.configure_failpoints(("sk-migration-after-step-8", "return(1)")) + + with pytest.raises(StorageControllerApiException, match="failpoint sk-migration-after-step-8"): + env.storage_controller.migrate_safekeepers( + env.initial_tenant, env.initial_timeline, [another_sk] + ) + cur_gen += 2 + + env.storage_controller.configure_failpoints((fp, "off")) + + env.storage_controller.abort_safekeeper_migration(env.initial_tenant, env.initial_timeline) + + # The migration is fully committed, no abort should have been performed. + mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline) + assert mconf["generation"] == cur_gen + assert mconf["sk_set"] == [another_sk] + assert mconf["new_sk_set"] is None + + ep.safe_psql(f"INSERT INTO t VALUES ({cur_gen})") + ep.clear_buffers() + assert ep.safe_psql("SELECT * FROM t") == [(i + 1,) for i in range(cur_gen) if i % 2 == 0] From d96cea191708fbfbbb3ee4599bc9e4eb391f67c1 Mon Sep 17 00:00:00 2001 From: Ruslan Talpa Date: Thu, 31 Jul 2025 16:05:09 +0300 Subject: [PATCH 3/3] [proxy] handle options request in rest broker (cors headers) (#12744) ## Problem rest broker needs to respond with the correct cors headers for the api to be usable from other domains ## Summary of changes added a code path in rest broker to handle the OPTIONS requests --------- Co-authored-by: Ruslan Talpa --- proxy/src/serverless/rest.rs | 224 +++++++++++++++++++++++++---------- 1 file changed, 159 insertions(+), 65 deletions(-) diff --git a/proxy/src/serverless/rest.rs b/proxy/src/serverless/rest.rs index 0c3d2c958d..9f98e87272 100644 --- a/proxy/src/serverless/rest.rs +++ b/proxy/src/serverless/rest.rs @@ -5,12 +5,17 @@ use std::sync::Arc; use bytes::Bytes; use http::Method; -use http::header::{AUTHORIZATION, CONTENT_TYPE, HOST}; +use http::header::{ + ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN, + ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_REQUEST_HEADERS, ALLOW, + AUTHORIZATION, CONTENT_TYPE, HOST, ORIGIN, +}; use http_body_util::combinators::BoxBody; -use http_body_util::{BodyExt, Full}; +use http_body_util::{BodyExt, Empty, Full}; use http_utils::error::ApiError; use hyper::body::Incoming; -use hyper::http::{HeaderName, HeaderValue}; +use hyper::http::response::Builder; +use hyper::http::{HeaderMap, HeaderName, HeaderValue}; use hyper::{Request, Response, StatusCode}; use indexmap::IndexMap; use moka::sync::Cache; @@ -67,6 +72,15 @@ use crate::util::deserialize_json_string; static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#; const INTROSPECTION_SQL: &str = POSTGRESQL_INTROSPECTION_SQL; +const HEADER_VALUE_ALLOW_ALL_ORIGINS: HeaderValue = HeaderValue::from_static("*"); +// CORS headers values +const ACCESS_CONTROL_ALLOW_METHODS_VALUE: HeaderValue = + HeaderValue::from_static("GET, POST, PATCH, PUT, DELETE, OPTIONS"); +const ACCESS_CONTROL_MAX_AGE_VALUE: HeaderValue = HeaderValue::from_static("86400"); +const ACCESS_CONTROL_EXPOSE_HEADERS_VALUE: HeaderValue = HeaderValue::from_static( + "Content-Encoding, Content-Location, Content-Range, Content-Type, Date, Location, Server, Transfer-Encoding, Range-Unit", +); +const ACCESS_CONTROL_ALLOW_HEADERS_VALUE: HeaderValue = HeaderValue::from_static("Authorization"); // A wrapper around the DbSchema that allows for self-referencing #[self_referencing] @@ -137,6 +151,8 @@ pub struct ApiConfig { pub role_claim_key: String, #[serde(default, deserialize_with = "deserialize_comma_separated_option")] pub db_extra_search_path: Option>, + #[serde(default, deserialize_with = "deserialize_comma_separated_option")] + pub server_cors_allowed_origins: Option>, } // The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint @@ -165,7 +181,13 @@ impl DbSchemaCache { } } - pub async fn get_cached_or_remote( + pub fn get_cached( + &self, + endpoint_id: &EndpointCacheKey, + ) -> Option> { + count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id)) + } + pub async fn get_remote( &self, endpoint_id: &EndpointCacheKey, auth_header: &HeaderValue, @@ -174,47 +196,42 @@ impl DbSchemaCache { ctx: &RequestContext, config: &'static ProxyConfig, ) -> Result, RestError> { - let cache_result = count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id)); - match cache_result { - Some(v) => Ok(v), - None => { - info!("db_schema cache miss for endpoint: {:?}", endpoint_id); - let remote_value = self - .get_remote(auth_header, connection_string, client, ctx, config) - .await; - let (api_config, schema_owned) = match remote_value { - Ok((api_config, schema_owned)) => (api_config, schema_owned), - Err(e @ RestError::SchemaTooLarge) => { - // for the case where the schema is too large, we cache an empty dummy value - // all the other requests will fail without triggering the introspection query - let schema_owned = serde_json::from_str::(EMPTY_JSON_SCHEMA) - .map_err(|e| JsonDeserialize { source: e })?; + info!("db_schema cache miss for endpoint: {:?}", endpoint_id); + let remote_value = self + .internal_get_remote(auth_header, connection_string, client, ctx, config) + .await; + let (api_config, schema_owned) = match remote_value { + Ok((api_config, schema_owned)) => (api_config, schema_owned), + Err(e @ RestError::SchemaTooLarge) => { + // for the case where the schema is too large, we cache an empty dummy value + // all the other requests will fail without triggering the introspection query + let schema_owned = serde_json::from_str::(EMPTY_JSON_SCHEMA) + .map_err(|e| JsonDeserialize { source: e })?; - let api_config = ApiConfig { - db_schemas: vec![], - db_anon_role: None, - db_max_rows: None, - db_allowed_select_functions: vec![], - role_claim_key: String::new(), - db_extra_search_path: None, - }; - let value = Arc::new((api_config, schema_owned)); - count_cache_insert(CacheKind::Schema); - self.0.insert(endpoint_id.clone(), value); - return Err(e); - } - Err(e) => { - return Err(e); - } + let api_config = ApiConfig { + db_schemas: vec![], + db_anon_role: None, + db_max_rows: None, + db_allowed_select_functions: vec![], + role_claim_key: String::new(), + db_extra_search_path: None, + server_cors_allowed_origins: None, }; let value = Arc::new((api_config, schema_owned)); count_cache_insert(CacheKind::Schema); - self.0.insert(endpoint_id.clone(), value.clone()); - Ok(value) + self.0.insert(endpoint_id.clone(), value); + return Err(e); } - } + Err(e) => { + return Err(e); + } + }; + let value = Arc::new((api_config, schema_owned)); + count_cache_insert(CacheKind::Schema); + self.0.insert(endpoint_id.clone(), value.clone()); + Ok(value) } - pub async fn get_remote( + async fn internal_get_remote( &self, auth_header: &HeaderValue, connection_string: &str, @@ -531,7 +548,7 @@ pub(crate) async fn handle( ) -> Result>, ApiError> { let result = handle_inner(cancel, config, &ctx, request, backend).await; - let mut response = match result { + let response = match result { Ok(r) => { ctx.set_success(); @@ -640,9 +657,6 @@ pub(crate) async fn handle( } }; - response - .headers_mut() - .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*")); Ok(response) } @@ -722,6 +736,37 @@ async fn handle_inner( } } +fn apply_common_cors_headers( + response: &mut Builder, + request_headers: &HeaderMap, + allowed_origins: Option<&Vec>, +) { + let request_origin = request_headers + .get(ORIGIN) + .map(|v| v.to_str().unwrap_or("")); + + let response_allow_origin = match (request_origin, allowed_origins) { + (Some(or), Some(allowed_origins)) => { + if allowed_origins.iter().any(|o| o == or) { + Some(HeaderValue::from_str(or).unwrap_or(HEADER_VALUE_ALLOW_ALL_ORIGINS)) + } else { + None + } + } + (Some(_), None) => Some(HEADER_VALUE_ALLOW_ALL_ORIGINS), + _ => None, + }; + if let Some(h) = response.headers_mut() { + h.insert( + ACCESS_CONTROL_EXPOSE_HEADERS, + ACCESS_CONTROL_EXPOSE_HEADERS_VALUE, + ); + if let Some(origin) = response_allow_origin { + h.insert(ACCESS_CONTROL_ALLOW_ORIGIN, origin); + } + } +} + #[allow(clippy::too_many_arguments)] async fn handle_rest_inner( config: &'static ProxyConfig, @@ -733,12 +778,6 @@ async fn handle_rest_inner( jwt: String, backend: Arc, ) -> Result>, RestError> { - // validate the jwt token - let jwt_parsed = backend - .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) - .await - .map_err(HttpConnError::from)?; - let db_schema_cache = config .rest_config @@ -754,28 +793,83 @@ async fn handle_rest_inner( message: "Failed to get endpoint cache key".to_string(), }))?; - let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?; - let (parts, originial_body) = request.into_parts(); + // try and get the cached entry for this endpoint + // it contains the api config and the introspected db schema + let cached_entry = db_schema_cache.get_cached(&endpoint_cache_key); + + let allowed_origins = cached_entry + .as_ref() + .and_then(|arc| arc.0.server_cors_allowed_origins.as_ref()); + + let mut response = Response::builder(); + apply_common_cors_headers(&mut response, &parts.headers, allowed_origins); + + // handle the OPTIONS request + if parts.method == Method::OPTIONS { + let allowed_headers = parts + .headers + .get(ACCESS_CONTROL_REQUEST_HEADERS) + .and_then(|a| a.to_str().ok()) + .filter(|v| !v.is_empty()) + .map_or_else( + || "Authorization".to_string(), + |v| format!("{v}, Authorization"), + ); + return response + .status(StatusCode::OK) + .header( + ACCESS_CONTROL_ALLOW_METHODS, + ACCESS_CONTROL_ALLOW_METHODS_VALUE, + ) + .header(ACCESS_CONTROL_MAX_AGE, ACCESS_CONTROL_MAX_AGE_VALUE) + .header( + ACCESS_CONTROL_ALLOW_HEADERS, + HeaderValue::from_str(&allowed_headers) + .unwrap_or(ACCESS_CONTROL_ALLOW_HEADERS_VALUE), + ) + .header(ALLOW, ACCESS_CONTROL_ALLOW_METHODS_VALUE) + .body(Empty::new().map_err(|x| match x {}).boxed()) + .map_err(|e| { + RestError::SubzeroCore(InternalError { + message: e.to_string(), + }) + }); + } + + // validate the jwt token + let jwt_parsed = backend + .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) + .await + .map_err(HttpConnError::from)?; + let auth_header = parts .headers .get(AUTHORIZATION) .ok_or(RestError::SubzeroCore(InternalError { message: "Authorization header is required".to_string(), }))?; + let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?; - let entry = db_schema_cache - .get_cached_or_remote( - &endpoint_cache_key, - auth_header, - connection_string, - &mut client, - ctx, - config, - ) - .await?; + let entry = match cached_entry { + Some(e) => e, + None => { + // if not cached, get the remote entry (will run the introspection query) + db_schema_cache + .get_remote( + &endpoint_cache_key, + auth_header, + connection_string, + &mut client, + ctx, + config, + ) + .await? + } + }; let (api_config, db_schema_owned) = entry.as_ref(); + let db_schema = db_schema_owned.borrow_schema(); let db_schemas = &api_config.db_schemas; // list of schemas available for the api @@ -999,8 +1093,8 @@ async fn handle_rest_inner( let _metrics = client.metrics(ctx); // FIXME: is everything in the context set correctly? // send the request to the local proxy - let response = make_raw_local_proxy_request(&mut client, headers, req_body).await?; - let (parts, body) = response.into_parts(); + let proxy_response = make_raw_local_proxy_request(&mut client, headers, req_body).await?; + let (response_parts, body) = proxy_response.into_parts(); let max_response = config.http_config.max_response_size_bytes; let bytes = read_body_with_limit(body, max_response) @@ -1009,7 +1103,7 @@ async fn handle_rest_inner( // if the response status is greater than 399, then it is an error // FIXME: check if there are other error codes or shapes of the response - if parts.status.as_u16() > 399 { + if response_parts.status.as_u16() > 399 { // turn this postgres error from the json into PostgresError let postgres_error = serde_json::from_slice(&bytes) .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?; @@ -1175,7 +1269,7 @@ async fn handle_rest_inner( .boxed(); // build the response - let mut response = Response::builder() + response = response .status(StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)) .header(CONTENT_TYPE, http_content_type);