Compare commits

..

3 Commits

Author SHA1 Message Date
John Spray
ff6143975a lint 2025-07-02 13:44:13 +01:00
Haoyu Huang
fbbf578c6b SK: re-elect leader when backup lag is high (#781)
We observe that the offloader fails to upload a segment due to race
conditions on XLOG SWITCH and PG start streaming WALs. wal_backup task
continously failing to upload a full segment while the segment remains
partial on the disk.

The consequence is that commit_lsn for all SKs move forward but
backup_lsn stays the same. Then, all SKs run out of disk space.

See go/sk-ood-xlog-switch for more details.

To mitigate this issue, we will re-elect a new offloader if the current
offloader is lagging behind too much.
Each SK makes the decision locally but they are aware of each other's
commit and backup lsns.

The new algorithm is
- determine_offloader will pick a SK. say SK-1.
- Each SK checks
-- if commit_lsn - back_lsn > threshold,
-- -- remove SK-1 from the candidate and call determine_offloader again.

SK-1 will step down and all SKs will elect the same leader again.
After the backup is caught up, the leader will become SK-1 again.

This also helps when SK-1 is slow to backup.

I'll set the reelect backup lag to 4 GB later. Setting to 128 MB in dev
to trigger the code more frequently.

DEV.

(cherry picked from commit 7286f79f9536380d321e2442318bd8a631269499)
2025-07-02 08:32:45 +01:00
Alex Chi Z.
6c77638ea1 feat(storcon): retrieve feature flag and pass to pageservers (#12324)
## Problem

part of https://github.com/neondatabase/neon/issues/11813

## Summary of changes

It costs $$$ to directly retrieve the feature flags from the pageserver.
Therefore, this patch adds new APIs to retrieve the spec from the
storcon and updates it via pageserver.

* Storcon retrieves the feature flag and send it to the pageservers.
* If the feature flag gets updated outside of the normal refresh loop of
the pageserver, pageserver won't fetch the flags on its own as long as
the last updated time <= refresh_period.

Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-06-25 14:58:18 +00:00
36 changed files with 922 additions and 287 deletions

2
Cargo.lock generated
View File

@@ -6815,6 +6815,7 @@ dependencies = [
"hex",
"http-utils",
"humantime",
"humantime-serde",
"hyper 0.14.30",
"itertools 0.10.5",
"json-structural-diff",
@@ -6825,6 +6826,7 @@ dependencies = [
"pageserver_api",
"pageserver_client",
"postgres_connection",
"posthog_client_lite",
"rand 0.8.5",
"regex",
"reqwest",

View File

@@ -12,6 +12,7 @@ use std::{env, fs};
use anyhow::{Context, bail};
use clap::ValueEnum;
use pageserver_api::config::PostHogConfig;
use pem::Pem;
use postgres_backend::AuthType;
use reqwest::{Certificate, Url};
@@ -213,6 +214,8 @@ pub struct NeonStorageControllerConf {
pub timeline_safekeeper_count: Option<i64>,
pub posthog_config: Option<PostHogConfig>,
pub kick_secondary_downloads: Option<bool>,
}
@@ -245,6 +248,7 @@ impl Default for NeonStorageControllerConf {
use_https_safekeeper_api: false,
use_local_compute_notifications: true,
timeline_safekeeper_count: None,
posthog_config: None,
kick_secondary_downloads: None,
}
}

View File

@@ -642,6 +642,18 @@ impl StorageController {
args.push(format!("--timeline-safekeeper-count={sk_cnt}"));
}
let mut envs = vec![
("LD_LIBRARY_PATH".to_owned(), pg_lib_dir.to_string()),
("DYLD_LIBRARY_PATH".to_owned(), pg_lib_dir.to_string()),
];
if let Some(posthog_config) = &self.config.posthog_config {
envs.push((
"POSTHOG_CONFIG".to_string(),
serde_json::to_string(posthog_config)?,
));
}
println!("Starting storage controller");
background_process::start_process(
@@ -649,10 +661,7 @@ impl StorageController {
&instance_dir,
&self.env.storage_controller_bin(),
args,
vec![
("LD_LIBRARY_PATH".to_owned(), pg_lib_dir.to_string()),
("DYLD_LIBRARY_PATH".to_owned(), pg_lib_dir.to_string()),
],
envs,
background_process::InitialPidFile::Create(self.pid_file(start_args.instance_id)),
&start_args.start_timeout,
|| async {

View File

@@ -63,7 +63,8 @@ impl Display for NodeMetadata {
}
}
/// PostHog integration config.
/// PostHog integration config. This is used in pageserver, storcon, and neon_local.
/// Ensure backward compatibility when adding new fields.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct PostHogConfig {
/// PostHog project ID
@@ -76,7 +77,9 @@ pub struct PostHogConfig {
pub private_api_url: String,
/// Public API URL
pub public_api_url: String,
/// Refresh interval for the feature flag spec
/// Refresh interval for the feature flag spec.
/// The storcon will push the feature flag spec to the pageserver. If the pageserver does not receive
/// the spec for `refresh_interval`, it will fetch the spec from the PostHog API.
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "humantime_serde")]
pub refresh_interval: Option<Duration>,

View File

@@ -1,17 +1,22 @@
//! A background loop that fetches feature flags from PostHog and updates the feature store.
use std::{sync::Arc, time::Duration};
use std::{
sync::Arc,
time::{Duration, SystemTime},
};
use arc_swap::ArcSwap;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, info_span};
use crate::{CaptureEvent, FeatureStore, PostHogClient, PostHogClientConfig};
use crate::{
CaptureEvent, FeatureStore, LocalEvaluationResponse, PostHogClient, PostHogClientConfig,
};
/// A background loop that fetches feature flags from PostHog and updates the feature store.
pub struct FeatureResolverBackgroundLoop {
posthog_client: PostHogClient,
feature_store: ArcSwap<FeatureStore>,
feature_store: ArcSwap<(SystemTime, Arc<FeatureStore>)>,
cancel: CancellationToken,
}
@@ -19,11 +24,35 @@ impl FeatureResolverBackgroundLoop {
pub fn new(config: PostHogClientConfig, shutdown_pageserver: CancellationToken) -> Self {
Self {
posthog_client: PostHogClient::new(config),
feature_store: ArcSwap::new(Arc::new(FeatureStore::new())),
feature_store: ArcSwap::new(Arc::new((
SystemTime::UNIX_EPOCH,
Arc::new(FeatureStore::new()),
))),
cancel: shutdown_pageserver,
}
}
/// Update the feature store with a new feature flag spec bypassing the normal refresh loop.
pub fn update(&self, spec: String) -> anyhow::Result<()> {
let resp: LocalEvaluationResponse = serde_json::from_str(&spec)?;
self.update_feature_store_nofail(resp, "http_propagate");
Ok(())
}
fn update_feature_store_nofail(&self, resp: LocalEvaluationResponse, source: &'static str) {
let project_id = self.posthog_client.config.project_id.parse::<u64>().ok();
match FeatureStore::new_with_flags(resp.flags, project_id) {
Ok(feature_store) => {
self.feature_store
.store(Arc::new((SystemTime::now(), Arc::new(feature_store))));
tracing::info!("Feature flag updated from {}", source);
}
Err(e) => {
tracing::warn!("Cannot process feature flag spec from {}: {}", source, e);
}
}
}
pub fn spawn(
self: Arc<Self>,
handle: &tokio::runtime::Handle,
@@ -47,6 +76,17 @@ impl FeatureResolverBackgroundLoop {
_ = ticker.tick() => {}
_ = cancel.cancelled() => break
}
{
let last_update = this.feature_store.load().0;
if let Ok(elapsed) = last_update.elapsed() {
if elapsed < refresh_period {
tracing::debug!(
"Skipping feature flag refresh because it's too soon"
);
continue;
}
}
}
let resp = match this
.posthog_client
.get_feature_flags_local_evaluation()
@@ -58,16 +98,7 @@ impl FeatureResolverBackgroundLoop {
continue;
}
};
let project_id = this.posthog_client.config.project_id.parse::<u64>().ok();
match FeatureStore::new_with_flags(resp.flags, project_id) {
Ok(feature_store) => {
this.feature_store.store(Arc::new(feature_store));
tracing::info!("Feature flag updated");
}
Err(e) => {
tracing::warn!("Cannot process feature flag spec: {}", e);
}
}
this.update_feature_store_nofail(resp, "refresh_loop");
}
tracing::info!("PostHog feature resolver stopped");
}
@@ -92,6 +123,6 @@ impl FeatureResolverBackgroundLoop {
}
pub fn feature_store(&self) -> Arc<FeatureStore> {
self.feature_store.load_full()
self.feature_store.load().1.clone()
}
}

View File

@@ -544,17 +544,8 @@ impl PostHogClient {
self.config.server_api_key.starts_with("phs_")
}
/// Fetch the feature flag specs from the server.
///
/// This is unfortunately an undocumented API at:
/// - <https://posthog.com/docs/api/feature-flags#get-api-projects-project_id-feature_flags-local_evaluation>
/// - <https://posthog.com/docs/feature-flags/local-evaluation>
///
/// The handling logic in [`FeatureStore`] mostly follows the Python API implementation.
/// See `_compute_flag_locally` in <https://github.com/PostHog/posthog-python/blob/master/posthog/client.py>
pub async fn get_feature_flags_local_evaluation(
&self,
) -> anyhow::Result<LocalEvaluationResponse> {
/// Get the raw JSON spec, same as `get_feature_flags_local_evaluation` but without parsing.
pub async fn get_feature_flags_local_evaluation_raw(&self) -> anyhow::Result<String> {
// BASE_URL/api/projects/:project_id/feature_flags/local_evaluation
// with bearer token of self.server_api_key
// OR
@@ -588,7 +579,22 @@ impl PostHogClient {
body
));
}
Ok(serde_json::from_str(&body)?)
Ok(body)
}
/// Fetch the feature flag specs from the server.
///
/// This is unfortunately an undocumented API at:
/// - <https://posthog.com/docs/api/feature-flags#get-api-projects-project_id-feature_flags-local_evaluation>
/// - <https://posthog.com/docs/feature-flags/local-evaluation>
///
/// The handling logic in [`FeatureStore`] mostly follows the Python API implementation.
/// See `_compute_flag_locally` in <https://github.com/PostHog/posthog-python/blob/master/posthog/client.py>
pub async fn get_feature_flags_local_evaluation(
&self,
) -> Result<LocalEvaluationResponse, anyhow::Error> {
let raw = self.get_feature_flags_local_evaluation_raw().await?;
Ok(serde_json::from_str(&raw)?)
}
/// Capture an event. This will only be used to report the feature flag usage back to PostHog, though

View File

@@ -1,13 +1,12 @@
use std::net::IpAddr;
use postgres_protocol2::message::backend::Message;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use crate::client::SocketConfig;
use crate::codec::BackendMessage;
use crate::config::{Host, SslMode};
use crate::config::Host;
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
use crate::connect_tls::connect_tls;
@@ -47,7 +46,13 @@ where
{
let socket = connect_socket(host_addr, host, port, config.connect_timeout).await?;
let stream = connect_tls(socket, config.ssl_mode, tls).await?;
let raw = connect_raw(stream, config).await?;
let RawConnection {
stream,
parameters,
delayed_notice,
process_id,
secret_key,
} = connect_raw(stream, config).await?;
let socket_config = SocketConfig {
host_addr,
@@ -56,46 +61,24 @@ where
connect_timeout: config.connect_timeout,
};
Ok(raw.into_managed_conn(socket_config, config.ssl_mode))
}
impl<S, T> RawConnection<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
pub fn into_managed_conn(
self,
socket_config: SocketConfig,
ssl_mode: SslMode,
) -> (Client, Connection<S, T>) {
let RawConnection {
stream,
parameters,
delayed_notice,
process_id,
secret_key,
} = self;
let (client_tx, conn_rx) = mpsc::unbounded_channel();
let (conn_tx, client_rx) = mpsc::channel(4);
let client = Client::new(
client_tx,
client_rx,
socket_config,
ssl_mode,
process_id,
secret_key,
);
// delayed notices are always sent as "Async" messages.
let delayed = delayed_notice
.into_iter()
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
.collect();
let connection = Connection::new(stream, delayed, parameters, conn_tx, conn_rx);
(client, connection)
}
let (client_tx, conn_rx) = mpsc::unbounded_channel();
let (conn_tx, client_rx) = mpsc::channel(4);
let client = Client::new(
client_tx,
client_rx,
socket_config,
config.ssl_mode,
process_id,
secret_key,
);
// delayed notices are always sent as "Async" messages.
let delayed = delayed_notice
.into_iter()
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
.collect();
let connection = Connection::new(stream, delayed, parameters, conn_tx, conn_rx);
Ok((client, connection))
}

View File

@@ -844,4 +844,13 @@ impl Client {
.await
.map_err(Error::ReceiveBody)
}
pub async fn update_feature_flag_spec(&self, spec: String) -> Result<()> {
let uri = format!("{}/v1/feature_flag_spec", self.mgmt_api_endpoint);
self.request(Method::POST, uri, spec)
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
}

View File

@@ -31,6 +31,13 @@ impl FeatureResolver {
}
}
pub fn update(&self, spec: String) -> anyhow::Result<()> {
if let Some(inner) = &self.inner {
inner.update(spec)?;
}
Ok(())
}
pub fn spawn(
conf: &PageServerConf,
shutdown_pageserver: CancellationToken,

View File

@@ -3743,6 +3743,20 @@ async fn force_override_feature_flag_for_testing_delete(
json_response(StatusCode::OK, ())
}
async fn update_feature_flag_spec(
mut request: Request<Body>,
_cancel: CancellationToken,
) -> Result<Response<Body>, ApiError> {
check_permission(&request, None)?;
let body = json_request(&mut request).await?;
let state = get_state(&request);
state
.feature_resolver
.update(body)
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
/// Common functionality of all the HTTP API handlers.
///
/// - Adds a tracing span to each request (by `request_span`)
@@ -4128,5 +4142,8 @@ pub fn make_router(
.delete("/v1/feature_flag/:flag_key", |r| {
testing_api_handler("force override feature flag - delete", r, force_override_feature_flag_for_testing_delete)
})
.post("/v1/feature_flag_spec", |r| {
api_handler(r, update_feature_flag_spec)
})
.any(handler_404))
}

View File

@@ -350,7 +350,7 @@ impl CancellationHandler {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CancelClosure {
socket_addr: SocketAddr,
pub cancel_token: RawCancelToken,
cancel_token: RawCancelToken,
hostname: String, // for pg_sni router
user_info: ComputeUserInfo,
}

View File

@@ -86,14 +86,6 @@ pub(crate) enum ConnectionError {
#[error("error acquiring resource permit: {0}")]
TooManyConnectionAttempts(#[from] ApiLockError),
#[cfg(test)]
#[error("retryable: {retryable}, wakeable: {wakeable}, kind: {kind:?}")]
TestError {
retryable: bool,
wakeable: bool,
kind: crate::error::ErrorKind,
},
}
impl UserFacingError for ConnectionError {
@@ -104,8 +96,6 @@ impl UserFacingError for ConnectionError {
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
}
ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(),
#[cfg(test)]
ConnectionError::TestError { .. } => self.to_string(),
}
}
}
@@ -116,8 +106,6 @@ impl ReportableError for ConnectionError {
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
#[cfg(test)]
ConnectionError::TestError { kind, .. } => *kind,
}
}
}
@@ -264,19 +252,6 @@ impl AuthInfo {
.await?;
drop(pause);
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
info!(
compute_id = %compute.aux.compute_id,
pid = connection.process_id,
cold_start_info = ctx.cold_start_info().as_str(),
query_id = ctx.get_testodrome_id().as_deref(),
sslmode = ?compute.ssl_mode,
"connected to compute node at {} ({}) latency={}",
compute.hostname,
compute.socket_addr,
ctx.get_proxy_latency(),
);
let RawConnection {
stream: _,
parameters,
@@ -285,6 +260,8 @@ impl AuthInfo {
secret_key,
} = connection;
tracing::Span::current().record("pid", tracing::field::display(process_id));
// NB: CancelToken is supposed to hold socket_addr, but we use connect_raw.
// Yet another reason to rework the connection establishing code.
let cancel_closure = CancelClosure::new(
@@ -311,7 +288,6 @@ impl ConnectInfo {
async fn connect_raw(
&self,
config: &ComputeConfig,
direct: bool,
) -> Result<(SocketAddr, MaybeTlsStream<TcpStream, RustlsStream>), TlsError> {
let timeout = config.timeout;
@@ -354,7 +330,7 @@ impl ConnectInfo {
match connect_once(&*addrs).await {
Ok((sockaddr, stream)) => Ok((
sockaddr,
tls::connect_tls(stream, self.ssl_mode, config, host, direct).await?,
tls::connect_tls(stream, self.ssl_mode, config, host).await?,
)),
Err(err) => {
warn!("couldn't connect to compute node at {host}:{port}: {err}");
@@ -381,7 +357,7 @@ pub struct PostgresSettings {
pub struct ComputeConnection {
/// Socket connected to a compute node.
pub stream: MaybeRustlsStream,
pub stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
/// Labels for proxy's metrics.
pub aux: MetricsAuxInfo,
pub hostname: Host,
@@ -397,12 +373,23 @@ impl ConnectInfo {
ctx: &RequestContext,
aux: &MetricsAuxInfo,
config: &ComputeConfig,
direct: bool,
) -> Result<ComputeConnection, ConnectionError> {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (socket_addr, stream) = self.connect_raw(config, direct).await?;
let (socket_addr, stream) = self.connect_raw(config).await?;
drop(pause);
tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id));
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
info!(
cold_start_info = ctx.cold_start_info().as_str(),
"connected to compute node at {} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
self.host,
self.ssl_mode,
ctx.get_proxy_latency(),
ctx.get_testodrome_id().unwrap_or_default(),
);
let connection = ComputeConnection {
stream,
socket_addr,

View File

@@ -11,6 +11,8 @@ use crate::proxy::retry::CouldRetry;
#[derive(Debug, Error)]
pub enum TlsError {
#[error(transparent)]
Dns(#[from] InvalidDnsNameError),
#[error(transparent)]
Connection(#[from] std::io::Error),
#[error("TLS required but not provided")]
@@ -20,6 +22,7 @@ pub enum TlsError {
impl CouldRetry for TlsError {
fn could_retry(&self) -> bool {
match self {
TlsError::Dns(_) => false,
TlsError::Connection(err) => err.could_retry(),
// perhaps compute didn't realise it supports TLS?
TlsError::Required => true,
@@ -32,7 +35,6 @@ pub async fn connect_tls<S, T>(
mode: SslMode,
tls: &T,
host: &str,
direct: bool,
) -> Result<MaybeTlsStream<S, T::Stream>, TlsError>
where
S: AsyncRead + AsyncWrite + Unpin + Send,
@@ -47,7 +49,7 @@ where
SslMode::Prefer | SslMode::Require => {}
}
if !direct && !request_tls(&mut stream).await? {
if !request_tls(&mut stream).await? {
if SslMode::Require == mode {
return Err(TlsError::Required);
}
@@ -55,6 +57,7 @@ where
return Ok(MaybeTlsStream::Raw(stream));
}
let c = tls.make_tls_connect(host).map_err(std::io::Error::other)?;
Ok(MaybeTlsStream::Tls(c.connect(stream).boxed().await?))
Ok(MaybeTlsStream::Tls(
tls.make_tls_connect(host)?.connect(stream).boxed().await?,
))
}

View File

@@ -222,7 +222,6 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
ctx,
&TcpMechanism {
locks: &config.connect_compute_locks,
direct: false,
},
&node_info,
config.wake_compute_retry_config,

View File

@@ -263,12 +263,7 @@ impl NeonControlPlaneClient {
None => SslMode::Disable,
};
let host = match body.server_name {
Some(host) => {
if rustls::pki_types::DnsName::try_from_str(&host).is_err() {
return Err(WakeComputeError::BadComputeAddress(host.into_boxed_str()));
}
host.into()
}
Some(host) => host.into(),
None => host.into(),
};

View File

@@ -77,9 +77,8 @@ impl NodeInfo {
&self,
ctx: &RequestContext,
config: &ComputeConfig,
direct: bool,
) -> Result<compute::ComputeConnection, compute::ConnectionError> {
self.conn_info.connect(ctx, &self.aux, config, direct).await
self.conn_info.connect(ctx, &self.aux, config).await
}
}

View File

@@ -1,15 +1,18 @@
use async_trait::async_trait;
use tokio::time;
use tracing::{debug, info, warn};
use crate::compute::{self, COULD_NOT_CONNECT, ComputeConnection};
use crate::config::{ComputeConfig, RetryConfig};
use crate::context::RequestContext;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::{self, NodeInfo};
use crate::error::ReportableError;
use crate::metrics::{
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
};
use crate::proxy::retry::{ShouldRetryWakeCompute, retry_after, should_retry};
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry};
use crate::proxy::wake_compute::{WakeComputeBackend, wake_compute};
use crate::types::Host;
@@ -32,34 +35,42 @@ pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> Node
node_info.invalidate()
}
#[async_trait]
pub(crate) trait ConnectMechanism {
type Connection;
type ConnectError: ReportableError;
type Error: From<Self::ConnectError>;
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, compute::ConnectionError>;
) -> Result<Self::Connection, Self::ConnectError>;
}
pub(crate) struct TcpMechanism {
/// connect_to_compute concurrency lock
pub(crate) locks: &'static ApiLocks<Host>,
// whether to negotiate TLS for postgres protocol.
pub(crate) direct: bool,
}
#[async_trait]
impl ConnectMechanism for TcpMechanism {
type Connection = ComputeConnection;
type ConnectError = compute::ConnectionError;
type Error = compute::ConnectionError;
#[tracing::instrument(skip_all, fields(
pid = tracing::field::Empty,
compute_id = tracing::field::Empty
))]
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
config: &ComputeConfig,
) -> Result<ComputeConnection, compute::ConnectionError> {
) -> Result<ComputeConnection, Self::Error> {
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
permit.release_result(node_info.connect(ctx, config, self.direct).await)
permit.release_result(node_info.connect(ctx, config).await)
}
}
@@ -71,7 +82,11 @@ pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: WakeComputeBacken
user_info: &B,
wake_compute_retry_config: RetryConfig,
compute: &ComputeConfig,
) -> Result<M::Connection, compute::ConnectionError> {
) -> Result<M::Connection, M::Error>
where
M::ConnectError: CouldRetry + ShouldRetryWakeCompute + std::fmt::Debug,
M::Error: From<WakeComputeError>,
{
let mut num_retries = 0;
let node_info =
wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?;
@@ -105,7 +120,7 @@ pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: WakeComputeBacken
},
num_retries.into(),
);
return Err(err);
return Err(err.into());
}
node_info
} else {
@@ -146,7 +161,7 @@ pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: WakeComputeBacken
},
num_retries.into(),
);
return Err(e);
return Err(e.into());
}
warn!(error = ?e, num_retries, retriable = true, COULD_NOT_CONNECT);

View File

@@ -358,7 +358,6 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
ctx,
&TcpMechanism {
locks: &config.connect_compute_locks,
direct: false,
},
&auth::Backend::ControlPlane(cplane, creds.info.clone()),
config.wake_compute_retry_config,

View File

@@ -1,3 +1,4 @@
use std::error::Error;
use std::io;
use tokio::time;
@@ -30,25 +31,85 @@ impl CouldRetry for io::Error {
}
}
impl CouldRetry for postgres_client::error::DbError {
fn could_retry(&self) -> bool {
use postgres_client::error::SqlState;
matches!(
self.code(),
&SqlState::CONNECTION_FAILURE
| &SqlState::CONNECTION_EXCEPTION
| &SqlState::CONNECTION_DOES_NOT_EXIST
| &SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION,
)
}
}
impl ShouldRetryWakeCompute for postgres_client::error::DbError {
fn should_retry_wake_compute(&self) -> bool {
use postgres_client::error::SqlState;
// Here are errors that happens after the user successfully authenticated to the database.
// TODO: there are pgbouncer errors that should be retried, but they are not listed here.
let non_retriable_pg_errors = matches!(
self.code(),
&SqlState::TOO_MANY_CONNECTIONS
| &SqlState::OUT_OF_MEMORY
| &SqlState::SYNTAX_ERROR
| &SqlState::T_R_SERIALIZATION_FAILURE
| &SqlState::INVALID_CATALOG_NAME
| &SqlState::INVALID_SCHEMA_NAME
| &SqlState::INVALID_PARAMETER_VALUE,
);
if non_retriable_pg_errors {
return false;
}
// PGBouncer errors that should not trigger a wake_compute retry.
if self.code() == &SqlState::PROTOCOL_VIOLATION {
// Source for the error message:
// https://github.com/pgbouncer/pgbouncer/blob/f15997fe3effe3a94ba8bcc1ea562e6117d1a131/src/client.c#L1070
return !self
.message()
.contains("no more connections allowed (max_client_conn)");
}
true
}
}
impl CouldRetry for postgres_client::Error {
fn could_retry(&self) -> bool {
if let Some(io_err) = self.source().and_then(|x| x.downcast_ref()) {
io::Error::could_retry(io_err)
} else if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) {
postgres_client::error::DbError::could_retry(db_err)
} else {
false
}
}
}
impl ShouldRetryWakeCompute for postgres_client::Error {
fn should_retry_wake_compute(&self) -> bool {
if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) {
postgres_client::error::DbError::should_retry_wake_compute(db_err)
} else {
// likely an IO error. Possible the compute has shutdown and the
// cache is stale.
true
}
}
}
impl CouldRetry for compute::ConnectionError {
fn could_retry(&self) -> bool {
match self {
compute::ConnectionError::TlsError(err) => err.could_retry(),
compute::ConnectionError::WakeComputeError(err) => err.could_retry(),
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
#[cfg(test)]
compute::ConnectionError::TestError { retryable, .. } => *retryable,
}
}
}
impl ShouldRetryWakeCompute for compute::ConnectionError {
fn should_retry_wake_compute(&self) -> bool {
match self {
// the cache entry was not checked for validity
compute::ConnectionError::TooManyConnectionAttempts(_) => false,
#[cfg(test)]
compute::ConnectionError::TestError { wakeable, .. } => *wakeable,
_ => true,
}
}
@@ -59,3 +120,56 @@ pub(crate) fn retry_after(num_retries: u32, config: RetryConfig) -> time::Durati
.base_delay
.mul_f64(config.backoff_factor.powi((num_retries as i32) - 1))
}
#[cfg(test)]
mod tests {
use postgres_client::error::{DbError, SqlState};
use super::ShouldRetryWakeCompute;
#[test]
fn should_retry_wake_compute_for_db_error() {
// These SQLStates should NOT trigger a wake_compute retry.
let non_retry_states = [
SqlState::TOO_MANY_CONNECTIONS,
SqlState::OUT_OF_MEMORY,
SqlState::SYNTAX_ERROR,
SqlState::T_R_SERIALIZATION_FAILURE,
SqlState::INVALID_CATALOG_NAME,
SqlState::INVALID_SCHEMA_NAME,
SqlState::INVALID_PARAMETER_VALUE,
];
for state in non_retry_states {
let err = DbError::new_test_error(state.clone(), "oops".to_string());
assert!(
!err.should_retry_wake_compute(),
"State {state:?} unexpectedly retried"
);
}
// Errors coming from pgbouncer should not trigger a wake_compute retry
let non_retry_pgbouncer_errors = ["no more connections allowed (max_client_conn)"];
for error in non_retry_pgbouncer_errors {
let err = DbError::new_test_error(SqlState::PROTOCOL_VIOLATION, error.to_string());
assert!(
!err.should_retry_wake_compute(),
"PGBouncer error {error:?} unexpectedly retried"
);
}
// These SQLStates should trigger a wake_compute retry.
let retry_states = [
SqlState::CONNECTION_FAILURE,
SqlState::CONNECTION_EXCEPTION,
SqlState::CONNECTION_DOES_NOT_EXIST,
SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION,
];
for state in retry_states {
let err = DbError::new_test_error(state.clone(), "oops".to_string());
assert!(
err.should_retry_wake_compute(),
"State {state:?} unexpectedly skipped retry"
);
}
}
}

View File

@@ -10,7 +10,7 @@ use async_trait::async_trait;
use http::StatusCode;
use postgres_client::config::SslMode;
use postgres_client::tls::{MakeTlsConnect, NoTls};
use retry::retry_after;
use retry::{ShouldRetryWakeCompute, retry_after};
use rstest::rstest;
use rustls::crypto::ring;
use rustls::pki_types;
@@ -20,7 +20,6 @@ use tracing_test::traced_test;
use super::retry::CouldRetry;
use super::*;
use crate::auth::backend::{ComputeUserInfo, MaybeOwned};
use crate::compute::ConnectionError;
use crate::config::{ComputeConfig, RetryConfig};
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
@@ -424,36 +423,71 @@ impl TestConnectMechanism {
#[derive(Debug)]
struct TestConnection;
#[derive(Debug)]
struct TestConnectError {
retryable: bool,
wakeable: bool,
kind: crate::error::ErrorKind,
}
impl ReportableError for TestConnectError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
self.kind
}
}
impl std::fmt::Display for TestConnectError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}
impl std::error::Error for TestConnectError {}
impl CouldRetry for TestConnectError {
fn could_retry(&self) -> bool {
self.retryable
}
}
impl ShouldRetryWakeCompute for TestConnectError {
fn should_retry_wake_compute(&self) -> bool {
self.wakeable
}
}
#[async_trait]
impl ConnectMechanism for TestConnectMechanism {
type Connection = TestConnection;
type ConnectError = TestConnectError;
type Error = anyhow::Error;
async fn connect_once(
&self,
_ctx: &RequestContext,
_node_info: &control_plane::CachedNodeInfo,
_config: &ComputeConfig,
) -> Result<Self::Connection, ConnectionError> {
) -> Result<Self::Connection, Self::ConnectError> {
let mut counter = self.counter.lock().unwrap();
let action = self.sequence[*counter];
*counter += 1;
match action {
ConnectAction::Connect => Ok(TestConnection),
ConnectAction::Retry => Err(ConnectionError::TestError {
ConnectAction::Retry => Err(TestConnectError {
retryable: true,
wakeable: true,
kind: ErrorKind::Compute,
}),
ConnectAction::RetryNoWake => Err(ConnectionError::TestError {
ConnectAction::RetryNoWake => Err(TestConnectError {
retryable: true,
wakeable: false,
kind: ErrorKind::Compute,
}),
ConnectAction::Fail => Err(ConnectionError::TestError {
ConnectAction::Fail => Err(TestConnectError {
retryable: false,
wakeable: true,
kind: ErrorKind::Compute,
}),
ConnectAction::FailNoWake => Err(ConnectionError::TestError {
ConnectAction::FailNoWake => Err(TestConnectError {
retryable: false,
wakeable: false,
kind: ErrorKind::Compute,

View File

@@ -1,12 +1,17 @@
use std::io;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use ed25519_dalek::SigningKey;
use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer};
use jose_jwk::jose_b64;
use postgres_client::SocketConfig;
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use postgres_client::config::SslMode;
use rand::rngs::OsRng;
use rustls::pki_types::{DnsName, ServerName};
use tokio::net::{TcpStream, lookup_host};
use tokio_rustls::TlsConnector;
use tracing::field::display;
use tracing::{debug, info};
@@ -18,19 +23,21 @@ use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnP
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo};
use crate::auth::{self, AuthError};
use crate::compute::{self, ComputeConnection};
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
};
use crate::config::ProxyConfig;
use crate::config::{ComputeConfig, ProxyConfig};
use crate::context::RequestContext;
use crate::control_plane::CachedNodeInfo;
use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::control_plane::locks::ApiLocks;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::EndpointIdInt;
use crate::proxy::connect_compute::TcpMechanism;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
use crate::rate_limiter::EndpointRateLimiter;
use crate::types::{EndpointId, LOCAL_PROXY_SUFFIX};
use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX};
pub(crate) struct PoolingBackend {
pub(crate) http_conn_pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
@@ -150,6 +157,11 @@ impl PoolingBackend {
// Wake up the destination if needed. Code here is a bit involved because
// we reuse the code from the usual proxy and we need to prepare few structures
// that this code expects.
#[tracing::instrument(skip_all, fields(
pid = tracing::field::Empty,
compute_id = tracing::field::Empty,
conn_id = tracing::field::Empty,
))]
pub(crate) async fn connect_to_compute(
&self,
ctx: &RequestContext,
@@ -169,24 +181,30 @@ impl PoolingBackend {
return Ok(client);
}
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| keys.info);
let connection = crate::proxy::connect_compute::connect_to_compute(
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TcpMechanism {
&TokioMechanism {
conn_id,
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
direct: false,
keys: keys.keys,
},
&backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
)
.await?;
authenticate(ctx, &self.pool, &conn_info, keys.keys, connection, conn_id).await
.await
}
// Wake up the destination if needed
#[tracing::instrument(skip_all, fields(
compute_id = tracing::field::Empty,
conn_id = tracing::field::Empty,
))]
pub(crate) async fn connect_to_local_proxy(
&self,
ctx: &RequestContext,
@@ -198,6 +216,7 @@ impl PoolingBackend {
}
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
debug!(%conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| ComputeUserInfo {
user: conn_info.user_info.user.clone(),
@@ -207,19 +226,19 @@ impl PoolingBackend {
)),
options: conn_info.user_info.options.clone(),
});
let connection = crate::proxy::connect_compute::connect_to_compute(
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TcpMechanism {
&HyperMechanism {
conn_id,
conn_info,
pool: self.http_conn_pool.clone(),
locks: &self.config.connect_compute_locks,
direct: true,
},
&backend,
self.config.wake_compute_retry_config,
&self.config.connect_to_compute,
)
.await?;
h2handshake(ctx, &self.http_conn_pool, &conn_info, connection, conn_id).await
.await
}
/// Connect to postgres over localhost.
@@ -229,6 +248,10 @@ impl PoolingBackend {
/// # Panics
///
/// Panics if called with a non-local_proxy backend.
#[tracing::instrument(skip_all, fields(
pid = tracing::field::Empty,
conn_id = tracing::field::Empty,
))]
pub(crate) async fn connect_to_local_postgres(
&self,
ctx: &RequestContext,
@@ -350,8 +373,6 @@ fn create_random_jwk() -> (SigningKey, jose_jwk::Key) {
pub(crate) enum HttpConnError {
#[error("pooled connection closed at inconsistent state")]
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
#[error("could not connect to compute")]
ConnectError(#[from] compute::ConnectionError),
#[error("could not connect to postgres in compute")]
PostgresConnectionError(#[from] postgres_client::Error),
#[error("could not connect to local-proxy in compute")]
@@ -373,6 +394,8 @@ pub(crate) enum HttpConnError {
#[derive(Debug, thiserror::Error)]
pub(crate) enum LocalProxyConnError {
#[error("error with connection to local-proxy")]
Io(#[source] std::io::Error),
#[error("could not establish h2 connection")]
H2(#[from] hyper::Error),
}
@@ -380,7 +403,6 @@ pub(crate) enum LocalProxyConnError {
impl ReportableError for HttpConnError {
fn get_error_kind(&self) -> ErrorKind {
match self {
HttpConnError::ConnectError(_) => ErrorKind::Compute,
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
@@ -397,7 +419,6 @@ impl ReportableError for HttpConnError {
impl UserFacingError for HttpConnError {
fn to_string_client(&self) -> String {
match self {
HttpConnError::ConnectError(p) => p.to_string_client(),
HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(),
HttpConnError::PostgresConnectionError(p) => p.to_string(),
HttpConnError::LocalProxyConnectionError(p) => p.to_string(),
@@ -413,9 +434,36 @@ impl UserFacingError for HttpConnError {
}
}
impl CouldRetry for HttpConnError {
fn could_retry(&self) -> bool {
match self {
HttpConnError::PostgresConnectionError(e) => e.could_retry(),
HttpConnError::LocalProxyConnectionError(e) => e.could_retry(),
HttpConnError::ComputeCtl(_) => false,
HttpConnError::ConnectionClosedAbruptly(_) => false,
HttpConnError::JwtPayloadError(_) => false,
HttpConnError::GetAuthInfo(_) => false,
HttpConnError::AuthError(_) => false,
HttpConnError::WakeCompute(_) => false,
HttpConnError::TooManyConnectionAttempts(_) => false,
}
}
}
impl ShouldRetryWakeCompute for HttpConnError {
fn should_retry_wake_compute(&self) -> bool {
match self {
HttpConnError::PostgresConnectionError(e) => e.should_retry_wake_compute(),
// we never checked cache validity
HttpConnError::TooManyConnectionAttempts(_) => false,
_ => true,
}
}
}
impl ReportableError for LocalProxyConnError {
fn get_error_kind(&self) -> ErrorKind {
match self {
LocalProxyConnError::Io(_) => ErrorKind::Compute,
LocalProxyConnError::H2(_) => ErrorKind::Compute,
}
}
@@ -427,106 +475,208 @@ impl UserFacingError for LocalProxyConnError {
}
}
async fn authenticate(
ctx: &RequestContext,
pool: &Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
conn_info: &ConnInfo,
keys: ComputeCredentialKeys,
compute: ComputeConnection,
conn_id: uuid::Uuid,
) -> Result<Client<postgres_client::Client>, HttpConnError> {
// client config with stubbed connect info.
let mut config = postgres_client::Config::new(String::new(), 0);
config
.user(&conn_info.user_info.user)
.dbname(&conn_info.dbname);
if let ComputeCredentialKeys::AuthKeys(auth_keys) = keys {
config.auth_keys(auth_keys);
impl CouldRetry for LocalProxyConnError {
fn could_retry(&self) -> bool {
match self {
LocalProxyConnError::Io(_) => false,
LocalProxyConnError::H2(_) => false,
}
}
}
impl ShouldRetryWakeCompute for LocalProxyConnError {
fn should_retry_wake_compute(&self) -> bool {
match self {
LocalProxyConnError::Io(_) => false,
LocalProxyConnError::H2(_) => false,
}
}
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let connection = config.authenticate(compute.stream).await?;
drop(pause);
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
info!(
compute_id = %compute.aux.compute_id,
pid = connection.process_id,
cold_start_info = ctx.cold_start_info().as_str(),
query_id = ctx.get_testodrome_id().as_deref(),
sslmode = ?compute.ssl_mode,
%conn_id,
"connected to compute node at {} ({}) latency={}",
compute.hostname,
compute.socket_addr,
ctx.get_proxy_latency(),
);
let (client, connection) = connection.into_managed_conn(
SocketConfig {
host_addr: Some(compute.socket_addr.ip()),
host: postgres_client::config::Host::Tcp(compute.hostname.to_string()),
port: compute.socket_addr.port(),
connect_timeout: None,
},
compute.ssl_mode,
);
Ok(poll_client(
pool.clone(),
ctx,
conn_info.clone(),
client,
connection,
conn_id,
compute.aux,
))
}
async fn h2handshake(
ctx: &RequestContext,
pool: &Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
conn_info: &ConnInfo,
compute: ComputeConnection,
struct TokioMechanism {
pool: Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
) -> Result<http_conn_pool::Client<Send>, HttpConnError> {
let stream = match compute.stream {
MaybeTlsStream::Raw(tcp) => Box::pin(tcp) as AsyncRW,
MaybeTlsStream::Tls(tls) => Box::into_pin(tls.0) as AsyncRW,
keys: ComputeCredentialKeys,
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
}
#[async_trait]
impl ConnectMechanism for TokioMechanism {
type Connection = Client<postgres_client::Client>;
type ConnectError = HttpConnError;
type Error = HttpConnError;
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &CachedNodeInfo,
compute_config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
let mut config = node_info.conn_info.to_postgres_client_config();
let config = config
.user(&self.conn_info.user_info.user)
.dbname(&self.conn_info.dbname)
.connect_timeout(compute_config.timeout);
if let ComputeCredentialKeys::AuthKeys(auth_keys) = self.keys {
config.auth_keys(auth_keys);
}
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let res = config.connect(compute_config).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
tracing::Span::current().record(
"compute_id",
tracing::field::display(&node_info.aux.compute_id),
);
if let Some(query_id) = ctx.get_testodrome_id() {
info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id);
}
Ok(poll_client(
self.pool.clone(),
ctx,
self.conn_info.clone(),
client,
connection,
self.conn_id,
node_info.aux.clone(),
))
}
}
struct HyperMechanism {
pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
}
#[async_trait]
impl ConnectMechanism for HyperMechanism {
type Connection = http_conn_pool::Client<Send>;
type ConnectError = HttpConnError;
type Error = HttpConnError;
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, Self::ConnectError> {
let host_addr = node_info.conn_info.host_addr;
let host = &node_info.conn_info.host;
let permit = self.locks.get_permit(host).await?;
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let tls = if node_info.conn_info.ssl_mode == SslMode::Disable {
None
} else {
Some(&config.tls)
};
let port = node_info.conn_info.port;
let res = connect_http2(host_addr, host, port, config.timeout, tls).await;
drop(pause);
let (client, connection) = permit.release_result(res)?;
tracing::Span::current().record(
"compute_id",
tracing::field::display(&node_info.aux.compute_id),
);
if let Some(query_id) = ctx.get_testodrome_id() {
info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id);
}
Ok(poll_http2_client(
self.pool.clone(),
ctx,
&self.conn_info,
client,
connection,
self.conn_id,
node_info.aux.clone(),
))
}
}
async fn connect_http2(
host_addr: Option<IpAddr>,
host: &str,
port: u16,
timeout: Duration,
tls: Option<&Arc<rustls::ClientConfig>>,
) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> {
let addrs = match host_addr {
Some(addr) => vec![SocketAddr::new(addr, port)],
None => lookup_host((host, port))
.await
.map_err(LocalProxyConnError::Io)?
.collect(),
};
let mut last_err = None;
let mut addrs = addrs.into_iter();
let stream = loop {
let Some(addr) = addrs.next() else {
return Err(last_err.unwrap_or_else(|| {
LocalProxyConnError::Io(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve any addresses",
))
}));
};
match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
Ok(Ok(stream)) => {
stream.set_nodelay(true).map_err(LocalProxyConnError::Io)?;
break stream;
}
Ok(Err(e)) => {
last_err = Some(LocalProxyConnError::Io(e));
}
Err(e) => {
last_err = Some(LocalProxyConnError::Io(io::Error::new(
io::ErrorKind::TimedOut,
e,
)));
}
}
};
let stream = if let Some(tls) = tls {
let host = DnsName::try_from(host)
.map_err(io::Error::other)
.map_err(LocalProxyConnError::Io)?
.to_owned();
let stream = TlsConnector::from(tls.clone())
.connect(ServerName::DnsName(host), stream)
.await
.map_err(LocalProxyConnError::Io)?;
Box::pin(stream) as AsyncRW
} else {
Box::pin(stream) as AsyncRW
};
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
.timer(TokioTimer::new())
.keep_alive_interval(Duration::from_secs(20))
.keep_alive_while_idle(true)
.keep_alive_timeout(Duration::from_secs(5))
.handshake(TokioIo::new(stream))
.await
.map_err(LocalProxyConnError::H2)?;
drop(pause);
.await?;
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
info!(
compute_id = %compute.aux.compute_id,
cold_start_info = ctx.cold_start_info().as_str(),
query_id = ctx.get_testodrome_id().as_deref(),
sslmode = ?compute.ssl_mode,
%conn_id,
"connected to compute node at {} ({}) latency={}",
compute.hostname,
compute.socket_addr,
ctx.get_proxy_latency(),
);
Ok(poll_http2_client(
pool.clone(),
ctx,
conn_info,
client,
connection,
conn_id,
compute.aux,
))
Ok((client, connection))
}

View File

@@ -69,7 +69,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
let mut session_id = ctx.session_id();
let (tx, mut rx) = tokio::sync::watch::channel(session_id);
let span = info_span!(parent: None, "connection", %conn_id, pid=client.get_process_id(), compute_id=%aux.compute_id);
let span = info_span!(parent: None, "connection", %conn_id);
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");

View File

@@ -518,14 +518,15 @@ impl<C: ClientInnerExt> GlobalConnPool<C, EndpointConnPool<C>> {
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
return Ok(None);
}
info!(
conn_id = %client.get_conn_id(),
pid = client.inner.get_process_id(),
compute_id = &*client.aux.compute_id,
tracing::Span::current()
.record("conn_id", tracing::field::display(client.get_conn_id()));
tracing::Span::current().record(
"pid",
tracing::field::display(client.inner.get_process_id()),
);
debug!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
query_id = ctx.get_testodrome_id().as_deref(),
"reusing connection: latency={}",
ctx.get_proxy_latency(),
"pool: reusing connection '{conn_info}'"
);
match client.get_data() {

View File

@@ -6,7 +6,7 @@ use hyper::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use smol_str::ToSmolStr;
use tracing::{Instrument, error, info, info_span};
use tracing::{Instrument, debug, error, info, info_span};
use super::AsyncRW;
use super::backend::HttpConnError;
@@ -115,6 +115,7 @@ impl<C: ClientInnerExt + Clone> Drop for HttpConnPool<C> {
}
impl<C: ClientInnerExt + Clone> GlobalConnPool<C, HttpConnPool<C>> {
#[expect(unused_results)]
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestContext,
@@ -131,13 +132,10 @@ impl<C: ClientInnerExt + Clone> GlobalConnPool<C, HttpConnPool<C>> {
return result;
};
info!(
conn_id = %client.conn.conn_id,
compute_id = &*client.conn.aux.compute_id,
tracing::Span::current().record("conn_id", tracing::field::display(client.conn.conn_id));
debug!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
query_id = ctx.get_testodrome_id().as_deref(),
"reusing connection: latency={}",
ctx.get_proxy_latency(),
"pool: reusing connection '{conn_info}'"
);
ctx.set_cold_start_info(ColdStartInfo::HttpPoolHit);
ctx.success();
@@ -199,7 +197,7 @@ pub(crate) fn poll_http2_client(
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let session_id = ctx.session_id();
let span = info_span!(parent: None, "connection", %conn_id, compute_id=%aux.compute_id);
let span = info_span!(parent: None, "connection", %conn_id);
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
@@ -231,8 +229,6 @@ pub(crate) fn poll_http2_client(
tokio::spawn(
async move {
info!("new local proxy connection");
let _conn_gauge = conn_gauge;
let res = connection.await;
match res {

View File

@@ -30,7 +30,7 @@ use serde_json::value::RawValue;
use tokio::net::TcpStream;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, error, info, info_span, warn};
use tracing::{Instrument, debug, error, info, info_span, warn};
use super::backend::HttpConnError;
use super::conn_pool_lib::{
@@ -107,13 +107,15 @@ impl<C: ClientInnerExt> LocalConnPool<C> {
return Ok(None);
}
info!(
pid = client.inner.get_process_id(),
conn_id = %client.get_conn_id(),
tracing::Span::current()
.record("conn_id", tracing::field::display(client.get_conn_id()));
tracing::Span::current().record(
"pid",
tracing::field::display(client.inner.get_process_id()),
);
debug!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
query_id = ctx.get_testodrome_id().as_deref(),
"reusing connection: latency={}",
ctx.get_proxy_latency(),
"local_pool: reusing connection '{conn_info}'"
);
match client.get_data() {

View File

@@ -60,7 +60,7 @@ mod private {
}
}
pub struct RustlsStream<S>(pub Box<TlsStream<S>>);
pub struct RustlsStream<S>(Box<TlsStream<S>>);
impl<S> postgres_client::tls::TlsStream for RustlsStream<S>
where

View File

@@ -18,7 +18,8 @@ use metrics::set_build_info_metric;
use remote_storage::RemoteStorageConfig;
use safekeeper::defaults::{
DEFAULT_CONTROL_FILE_SAVE_INTERVAL, DEFAULT_EVICTION_MIN_RESIDENT, DEFAULT_HEARTBEAT_TIMEOUT,
DEFAULT_HTTP_LISTEN_ADDR, DEFAULT_MAX_OFFLOADER_LAG_BYTES, DEFAULT_PARTIAL_BACKUP_CONCURRENCY,
DEFAULT_HTTP_LISTEN_ADDR, DEFAULT_MAX_OFFLOADER_LAG_BYTES,
DEFAULT_MAX_REELECT_OFFLOADER_LAG_BYTES, DEFAULT_PARTIAL_BACKUP_CONCURRENCY,
DEFAULT_PARTIAL_BACKUP_TIMEOUT, DEFAULT_PG_LISTEN_ADDR, DEFAULT_SSL_CERT_FILE,
DEFAULT_SSL_CERT_RELOAD_PERIOD, DEFAULT_SSL_KEY_FILE,
};
@@ -138,6 +139,11 @@ struct Args {
/// Safekeeper won't be elected for WAL offloading if it is lagging for more than this value in bytes
#[arg(long, default_value_t = DEFAULT_MAX_OFFLOADER_LAG_BYTES)]
max_offloader_lag: u64,
/* BEGIN_HADRON */
/// Safekeeper will re-elect a new offloader if the current backup lagging for more than this value in bytes
#[arg(long, default_value_t = DEFAULT_MAX_REELECT_OFFLOADER_LAG_BYTES)]
max_reelect_offloader_lag_bytes: u64,
/* END_HADRON */
/// Number of max parallel WAL segments to be offloaded to remote storage.
#[arg(long, default_value = "5")]
wal_backup_parallel_jobs: usize,
@@ -391,6 +397,9 @@ async fn main() -> anyhow::Result<()> {
peer_recovery_enabled: args.peer_recovery,
remote_storage: args.remote_storage,
max_offloader_lag_bytes: args.max_offloader_lag,
/* BEGIN_HADRON */
max_reelect_offloader_lag_bytes: args.max_reelect_offloader_lag_bytes,
/* END_HADRON */
wal_backup_enabled: !args.disable_wal_backup,
backup_parallel_jobs: args.wal_backup_parallel_jobs,
pg_auth,

View File

@@ -61,6 +61,9 @@ pub mod defaults {
pub const DEFAULT_HEARTBEAT_TIMEOUT: &str = "5000ms";
pub const DEFAULT_MAX_OFFLOADER_LAG_BYTES: u64 = 128 * (1 << 20);
/* BEGIN_HADRON */
pub const DEFAULT_MAX_REELECT_OFFLOADER_LAG_BYTES: u64 = 128 * (1 << 20);
/* END_HADRON */
pub const DEFAULT_PARTIAL_BACKUP_TIMEOUT: &str = "15m";
pub const DEFAULT_CONTROL_FILE_SAVE_INTERVAL: &str = "300s";
pub const DEFAULT_PARTIAL_BACKUP_CONCURRENCY: &str = "5";
@@ -99,6 +102,9 @@ pub struct SafeKeeperConf {
pub peer_recovery_enabled: bool,
pub remote_storage: Option<RemoteStorageConfig>,
pub max_offloader_lag_bytes: u64,
/* BEGIN_HADRON */
pub max_reelect_offloader_lag_bytes: u64,
/* END_HADRON */
pub backup_parallel_jobs: usize,
pub wal_backup_enabled: bool,
pub pg_auth: Option<Arc<JwtAuth>>,
@@ -151,6 +157,9 @@ impl SafeKeeperConf {
sk_auth_token: None,
heartbeat_timeout: Duration::new(5, 0),
max_offloader_lag_bytes: defaults::DEFAULT_MAX_OFFLOADER_LAG_BYTES,
/* BEGIN_HADRON */
max_reelect_offloader_lag_bytes: defaults::DEFAULT_MAX_REELECT_OFFLOADER_LAG_BYTES,
/* END_HADRON */
current_thread_runtime: false,
walsenders_keep_horizon: false,
partial_backup_timeout: Duration::from_secs(0),

View File

@@ -138,6 +138,15 @@ pub static BACKUP_ERRORS: Lazy<IntCounter> = Lazy::new(|| {
)
.expect("Failed to register safekeeper_backup_errors_total counter")
});
/* BEGIN_HADRON */
pub static BACKUP_REELECT_LEADER_COUNT: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"safekeeper_backup_reelect_leader_total",
"Number of times the backup leader was reelected"
)
.expect("Failed to register safekeeper_backup_reelect_leader_total counter")
});
/* END_HADRON */
pub static BROKER_PUSH_ALL_UPDATES_SECONDS: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"safekeeper_broker_push_update_seconds",

View File

@@ -26,7 +26,9 @@ use utils::id::{NodeId, TenantTimelineId};
use utils::lsn::Lsn;
use utils::{backoff, pausable_failpoint};
use crate::metrics::{BACKED_UP_SEGMENTS, BACKUP_ERRORS, WAL_BACKUP_TASKS};
use crate::metrics::{
BACKED_UP_SEGMENTS, BACKUP_ERRORS, BACKUP_REELECT_LEADER_COUNT, WAL_BACKUP_TASKS,
};
use crate::timeline::WalResidentTimeline;
use crate::timeline_manager::{Manager, StateSnapshot};
use crate::{SafeKeeperConf, WAL_BACKUP_RUNTIME};
@@ -70,8 +72,7 @@ pub(crate) async fn update_task(
need_backup: bool,
state: &StateSnapshot,
) {
let (offloader, election_dbg_str) =
determine_offloader(&state.peers, state.backup_lsn, mgr.tli.ttid, &mgr.conf);
let (offloader, election_dbg_str) = hadron_determine_offloader(mgr, state);
let elected_me = Some(mgr.conf.my_id) == offloader;
let should_task_run = need_backup && elected_me;
@@ -127,6 +128,71 @@ async fn shut_down_task(entry: &mut Option<WalBackupTaskHandle>) {
}
}
/* BEGIN_HADRON */
// On top of the neon determine_offloader, we also check if the current offloader is lagging behind too much.
// If it is, we re-elect a new offloader. This mitigates the below issue. It also helps distribute the load across SKs.
//
// We observe that the offloader fails to upload a segment due to race conditions on XLOG SWITCH and PG start streaming WALs.
// wal_backup task continously failing to upload a full segment while the segment remains partial on the disk.
// The consequence is that commit_lsn for all SKs move forward but backup_lsn stays the same. Then, all SKs run out of disk space.
// See go/sk-ood-xlog-switch for more details.
//
// To mitigate this issue, we will re-elect a new offloader if the current offloader is lagging behind too much.
// Each SK makes the decision locally but they are aware of each other's commit and backup lsns.
//
// determine_offloader will pick a SK. say SK-1.
// Each SK checks
// -- if commit_lsn - back_lsn > threshold,
// -- -- remove SK-1 from the candidate and call determine_offloader again.
// SK-1 will step down and all SKs will elect the same leader again.
// After the backup is caught up, the leader will become SK-1 again.
fn hadron_determine_offloader(mgr: &Manager, state: &StateSnapshot) -> (Option<NodeId>, String) {
let mut offloader: Option<NodeId>;
let mut election_dbg_str: String;
let caughtup_peers_count: usize;
(offloader, election_dbg_str, caughtup_peers_count) =
determine_offloader(&state.peers, state.backup_lsn, mgr.tli.ttid, &mgr.conf);
if offloader.is_none() || caughtup_peers_count <= 1 {
return (offloader, election_dbg_str);
}
let offloader_sk_id = offloader.unwrap();
let backup_lag = state.commit_lsn.checked_sub(state.backup_lsn);
if backup_lag.is_none() {
info!("Backup lag is None. Skipping re-election.");
return (offloader, election_dbg_str);
}
let backup_lag = backup_lag.unwrap().0;
if backup_lag < mgr.conf.max_reelect_offloader_lag_bytes {
info!(
"Backup lag {} is lower than the threshold {}. Skipping re-election.",
backup_lag, mgr.conf.max_reelect_offloader_lag_bytes
);
return (offloader, election_dbg_str);
}
info!(
"Electing a new leader: Backup lag is too high backup lsn lag {} threshold {}: {}",
backup_lag, mgr.conf.max_reelect_offloader_lag_bytes, election_dbg_str
);
BACKUP_REELECT_LEADER_COUNT.inc();
// Remove the current offloader if lag is too high.
let new_peers: Vec<_> = state
.peers
.iter()
.filter(|p| p.sk_id != offloader_sk_id)
.cloned()
.collect();
(offloader, election_dbg_str, _) =
determine_offloader(&new_peers, state.backup_lsn, mgr.tli.ttid, &mgr.conf);
(offloader, election_dbg_str)
}
/* END_HADRON */
/// The goal is to ensure that normally only one safekeepers offloads. However,
/// it is fine (and inevitable, as s3 doesn't provide CAS) that for some short
/// time we have several ones as they PUT the same files. Also,
@@ -141,13 +207,13 @@ fn determine_offloader(
wal_backup_lsn: Lsn,
ttid: TenantTimelineId,
conf: &SafeKeeperConf,
) -> (Option<NodeId>, String) {
) -> (Option<NodeId>, String, usize) {
// TODO: remove this once we fill newly joined safekeepers since backup_lsn.
let capable_peers = alive_peers
.iter()
.filter(|p| p.local_start_lsn <= wal_backup_lsn);
match capable_peers.clone().map(|p| p.commit_lsn).max() {
None => (None, "no connected peers to elect from".to_string()),
None => (None, "no connected peers to elect from".to_string(), 0),
Some(max_commit_lsn) => {
let threshold = max_commit_lsn
.checked_sub(conf.max_offloader_lag_bytes)
@@ -175,6 +241,7 @@ fn determine_offloader(
capable_peers_dbg,
caughtup_peers.len()
),
caughtup_peers.len(),
)
}
}

View File

@@ -159,6 +159,9 @@ pub fn run_server(os: NodeOs, disk: Arc<SafekeeperDisk>) -> Result<()> {
heartbeat_timeout: Duration::from_secs(0),
remote_storage: None,
max_offloader_lag_bytes: 0,
/* BEGIN_HADRON */
max_reelect_offloader_lag_bytes: 0,
/* END_HADRON */
wal_backup_enabled: false,
listen_pg_addr_tenant_only: None,
advertise_pg_addr: None,

View File

@@ -27,6 +27,7 @@ governor.workspace = true
hex.workspace = true
hyper0.workspace = true
humantime.workspace = true
humantime-serde.workspace = true
itertools.workspace = true
json-structural-diff.workspace = true
lasso.workspace = true
@@ -34,6 +35,7 @@ once_cell.workspace = true
pageserver_api.workspace = true
pageserver_client.workspace = true
postgres_connection.workspace = true
posthog_client_lite.workspace = true
rand.workspace = true
reqwest = { workspace = true, features = ["stream"] }
routerify.workspace = true

View File

@@ -14,11 +14,13 @@ use http_utils::tls_certs::ReloadingCertificateResolver;
use hyper0::Uri;
use metrics::BuildInfo;
use metrics::launch_timestamp::LaunchTimestamp;
use pageserver_api::config::PostHogConfig;
use reqwest::Certificate;
use storage_controller::http::make_router;
use storage_controller::metrics::preinitialize_metrics;
use storage_controller::persistence::Persistence;
use storage_controller::service::chaos_injector::ChaosInjector;
use storage_controller::service::feature_flag::FeatureFlagService;
use storage_controller::service::{
Config, HEARTBEAT_INTERVAL_DEFAULT, LONG_RECONCILE_THRESHOLD_DEFAULT,
MAX_OFFLINE_INTERVAL_DEFAULT, MAX_WARMING_UP_INTERVAL_DEFAULT,
@@ -252,6 +254,8 @@ struct Secrets {
peer_jwt_token: Option<String>,
}
const POSTHOG_CONFIG_ENV: &str = "POSTHOG_CONFIG";
impl Secrets {
const DATABASE_URL_ENV: &'static str = "DATABASE_URL";
const PAGESERVER_JWT_TOKEN_ENV: &'static str = "PAGESERVER_JWT_TOKEN";
@@ -409,6 +413,18 @@ async fn async_main() -> anyhow::Result<()> {
None => Vec::new(),
};
let posthog_config = if let Ok(json) = std::env::var(POSTHOG_CONFIG_ENV) {
let res: Result<PostHogConfig, _> = serde_json::from_str(&json);
if let Ok(config) = res {
Some(config)
} else {
tracing::warn!("Invalid posthog config: {json}");
None
}
} else {
None
};
let config = Config {
pageserver_jwt_token: secrets.pageserver_jwt_token,
safekeeper_jwt_token: secrets.safekeeper_jwt_token,
@@ -455,6 +471,7 @@ async fn async_main() -> anyhow::Result<()> {
timelines_onto_safekeepers: args.timelines_onto_safekeepers,
use_local_compute_notifications: args.use_local_compute_notifications,
timeline_safekeeper_count: args.timeline_safekeeper_count,
posthog_config: posthog_config.clone(),
#[cfg(feature = "testing")]
kick_secondary_downloads: args.kick_secondary_downloads,
};
@@ -537,6 +554,23 @@ async fn async_main() -> anyhow::Result<()> {
)
});
let feature_flag_task = if let Some(posthog_config) = posthog_config {
let service = service.clone();
let cancel = CancellationToken::new();
let cancel_bg = cancel.clone();
let task = tokio::task::spawn(
async move {
let feature_flag_service = FeatureFlagService::new(service, posthog_config);
let feature_flag_service = Arc::new(feature_flag_service);
feature_flag_service.run(cancel_bg).await
}
.instrument(tracing::info_span!("feature_flag_service")),
);
Some((task, cancel))
} else {
None
};
// Wait until we receive a signal
let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())?;
let mut sigquit = tokio::signal::unix::signal(SignalKind::quit())?;
@@ -584,6 +618,12 @@ async fn async_main() -> anyhow::Result<()> {
chaos_jh.await.ok();
}
// If we were running the feature flag service, stop that so that we're not calling into Service while it shuts down
if let Some((feature_flag_task, feature_flag_cancel)) = feature_flag_task {
feature_flag_cancel.cancel();
feature_flag_task.await.ok();
}
service.shutdown().await;
tracing::info!("Service shutdown complete");

View File

@@ -376,4 +376,13 @@ impl PageserverClient {
.await
)
}
pub(crate) async fn update_feature_flag_spec(&self, spec: String) -> Result<()> {
measured_request!(
"update_feature_flag_spec",
crate::metrics::Method::Post,
&self.node_id_label,
self.inner.update_feature_flag_spec(spec).await
)
}
}

View File

@@ -1,5 +1,6 @@
pub mod chaos_injector;
mod context_iterator;
pub mod feature_flag;
pub(crate) mod safekeeper_reconciler;
mod safekeeper_service;
@@ -25,6 +26,7 @@ use futures::stream::FuturesUnordered;
use http_utils::error::ApiError;
use hyper::Uri;
use itertools::Itertools;
use pageserver_api::config::PostHogConfig;
use pageserver_api::controller_api::{
AvailabilityZone, MetadataHealthRecord, MetadataHealthUpdateRequest, NodeAvailability,
NodeRegisterRequest, NodeSchedulingPolicy, NodeShard, NodeShardResponse, PlacementPolicy,
@@ -471,6 +473,9 @@ pub struct Config {
/// Safekeepers will be choosen from different availability zones.
pub timeline_safekeeper_count: i64,
/// PostHog integration config
pub posthog_config: Option<PostHogConfig>,
#[cfg(feature = "testing")]
pub kick_secondary_downloads: bool,
}

View File

@@ -0,0 +1,117 @@
use std::{sync::Arc, time::Duration};
use futures::StreamExt;
use pageserver_api::config::PostHogConfig;
use pageserver_client::mgmt_api;
use posthog_client_lite::{PostHogClient, PostHogClientConfig};
use reqwest::StatusCode;
use tokio::time::MissedTickBehavior;
use tokio_util::sync::CancellationToken;
use crate::{pageserver_client::PageserverClient, service::Service};
pub struct FeatureFlagService {
service: Arc<Service>,
config: PostHogConfig,
client: PostHogClient,
http_client: reqwest::Client,
}
const DEFAULT_POSTHOG_REFRESH_INTERVAL: Duration = Duration::from_secs(30);
impl FeatureFlagService {
pub fn new(service: Arc<Service>, config: PostHogConfig) -> Self {
let client = PostHogClient::new(PostHogClientConfig {
project_id: config.project_id.clone(),
server_api_key: config.server_api_key.clone(),
client_api_key: config.client_api_key.clone(),
private_api_url: config.private_api_url.clone(),
public_api_url: config.public_api_url.clone(),
});
Self {
service,
config,
client,
http_client: reqwest::Client::new(),
}
}
async fn refresh(self: Arc<Self>, cancel: CancellationToken) -> Result<(), anyhow::Error> {
let nodes = {
let inner = self.service.inner.read().unwrap();
inner.nodes.clone()
};
let feature_flag_spec = self.client.get_feature_flags_local_evaluation_raw().await?;
let stream = futures::stream::iter(nodes.values().cloned()).map(|node| {
let this = self.clone();
let feature_flag_spec = feature_flag_spec.clone();
async move {
let res = async {
let client = PageserverClient::new(
node.get_id(),
this.http_client.clone(),
node.base_url(),
// TODO: what if we rotate the token during storcon lifetime?
this.service.config.pageserver_jwt_token.as_deref(),
);
client.update_feature_flag_spec(feature_flag_spec).await?;
tracing::info!(
"Updated {}({}) with feature flag spec",
node.get_id(),
node.base_url()
);
Ok::<_, mgmt_api::Error>(())
};
if let Err(e) = res.await {
if let mgmt_api::Error::ApiError(status, _) = e {
if status == StatusCode::NOT_FOUND {
// This is expected during deployments where the API is not available, so we can ignore it
return;
}
}
tracing::warn!(
"Failed to update feature flag spec for {}: {e}",
node.get_id()
);
}
}
});
let mut stream = stream.buffer_unordered(8);
while stream.next().await.is_some() {
if cancel.is_cancelled() {
return Ok(());
}
}
Ok(())
}
pub async fn run(self: Arc<Self>, cancel: CancellationToken) {
let refresh_interval = self
.config
.refresh_interval
.unwrap_or(DEFAULT_POSTHOG_REFRESH_INTERVAL);
let mut interval = tokio::time::interval(refresh_interval);
interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
tracing::info!(
"Starting feature flag service with refresh interval: {:?}",
refresh_interval
);
loop {
tokio::select! {
_ = interval.tick() => {}
_ = cancel.cancelled() => {
break;
}
}
let res = self.clone().refresh(cancel.clone()).await;
if let Err(e) = res {
tracing::error!("Failed to refresh feature flags: {e:#?}");
}
}
}
}