Compare commits

..

3 Commits

Author SHA1 Message Date
Aleksandr Sarantsev
c8475ed008 Introduce flag for deletion API 2025-07-08 17:20:15 +04:00
Aleksandr Sarantsev
2f3fc7cb57 Fix keep-failing reconciles test & add logs (#12497)
## Problem

Test is flaky due to the following warning in the logs:

```
Keeping extra secondaries: can't determine which of [NodeId(1), NodeId(2)] to remove (some nodes offline?)
```

Some nodes being offline is expected behavior in this test.

## Summary of changes

- Added `Keeping extra secondaries` to the list of allowed errors
- Improved logging for better debugging experience

Co-authored-by: Aleksandr Sarantsev <aleksandr.sarantsev@databricks.com>
2025-07-08 08:51:50 +00:00
Folke Behrens
e65d5f7369 proxy: Remove the endpoint filter cache (#12488)
## Problem

The endpoint filter cache is still unused because it's not yet reliable
enough to be used. It only consumes a lot of memory.

## Summary of changes

Remove the code. Needs a new design.

neondatabase/cloud#30634
2025-07-07 17:46:33 +00:00
18 changed files with 175 additions and 685 deletions

View File

@@ -75,6 +75,12 @@ enum Command {
NodeStartDelete {
#[arg(long)]
node_id: NodeId,
/// When `force` is true, skip waiting for shards to prewarm during migration.
/// This can significantly speed up node deletion since prewarming all shards
/// can take considerable time, but may result in slower initial access to
/// migrated shards until they warm up naturally.
#[arg(long)]
force: bool,
},
/// Cancel deletion of the specified pageserver and wait for `timeout`
/// for the operation to be canceled. May be retried.
@@ -933,13 +939,14 @@ async fn main() -> anyhow::Result<()> {
.dispatch::<(), ()>(Method::DELETE, format!("control/v1/node/{node_id}"), None)
.await?;
}
Command::NodeStartDelete { node_id } => {
Command::NodeStartDelete { node_id, force } => {
let query = if force {
format!("control/v1/node/{node_id}/delete?force=true")
} else {
format!("control/v1/node/{node_id}/delete")
};
storcon_client
.dispatch::<(), ()>(
Method::PUT,
format!("control/v1/node/{node_id}/delete"),
None,
)
.dispatch::<(), ()>(Method::PUT, query, None)
.await?;
println!("Delete started for {node_id}");
}

View File

@@ -21,7 +21,7 @@ use tokio::net::TcpListener;
use tokio::sync::Notify;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, error, info, warn};
use tracing::{error, info, warn};
use utils::sentry_init::init_sentry;
use utils::{project_build_tag, project_git_version};
@@ -195,7 +195,9 @@ struct ProxyCliArgs {
#[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)]
project_info_cache: String,
/// cache for all valid endpoints
#[clap(long, default_value = config::EndpointCacheConfig::CACHE_DEFAULT_OPTIONS)]
// TODO: remove after a couple of releases.
#[clap(long, default_value_t = String::new())]
#[deprecated]
endpoint_cache_config: String,
#[clap(flatten)]
parquet_upload: ParquetUploadArgs,
@@ -558,13 +560,6 @@ pub async fn run() -> anyhow::Result<()> {
}
}
}
// listen for notifications of new projects/endpoints/branches
let cache = api.caches.endpoints_cache.clone();
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(client, cancellation_token.clone()).await }.instrument(span),
);
}
let maintenance = loop {
@@ -712,18 +707,15 @@ fn build_auth_backend(
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
let project_info_cache_config: ProjectInfoCacheOptions =
args.project_info_cache.parse()?;
let endpoint_cache_config: config::EndpointCacheConfig =
args.endpoint_cache_config.parse()?;
info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
info!(
"Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
);
info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}");
let caches = Box::leak(Box::new(control_plane::caches::ApiCaches::new(
wake_compute_cache_config,
project_info_cache_config,
endpoint_cache_config,
)));
let config::ConcurrencyLockOptions {
@@ -793,18 +785,15 @@ fn build_auth_backend(
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
let project_info_cache_config: ProjectInfoCacheOptions =
args.project_info_cache.parse()?;
let endpoint_cache_config: config::EndpointCacheConfig =
args.endpoint_cache_config.parse()?;
info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
info!(
"Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
);
info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}");
let caches = Box::leak(Box::new(control_plane::caches::ApiCaches::new(
wake_compute_cache_config,
project_info_cache_config,
endpoint_cache_config,
)));
let config::ConcurrencyLockOptions {

View File

@@ -1,283 +0,0 @@
use std::convert::Infallible;
use std::future::pending;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use clashmap::ClashSet;
use redis::streams::{StreamReadOptions, StreamReadReply};
use redis::{AsyncCommands, FromRedisValue, Value};
use serde::Deserialize;
use tokio_util::sync::CancellationToken;
use tracing::info;
use crate::config::EndpointCacheConfig;
use crate::context::RequestContext;
use crate::ext::LockExt;
use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
use crate::rate_limiter::GlobalRateLimiter;
use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::types::EndpointId;
// TODO: this could be an enum, but events in Redis need to be fixed first.
// ProjectCreated was sent with type:branch_created. So we ignore type.
#[derive(Deserialize, Debug, Clone, PartialEq)]
struct ControlPlaneEvent {
endpoint_created: Option<EndpointCreated>,
branch_created: Option<BranchCreated>,
project_created: Option<ProjectCreated>,
#[serde(rename = "type")]
_type: Option<String>,
}
#[derive(Deserialize, Debug, Clone, PartialEq)]
struct EndpointCreated {
endpoint_id: EndpointIdInt,
}
#[derive(Deserialize, Debug, Clone, PartialEq)]
struct BranchCreated {
branch_id: BranchIdInt,
}
#[derive(Deserialize, Debug, Clone, PartialEq)]
struct ProjectCreated {
project_id: ProjectIdInt,
}
impl TryFrom<&Value> for ControlPlaneEvent {
type Error = anyhow::Error;
fn try_from(value: &Value) -> Result<Self, Self::Error> {
let json = String::from_redis_value(value)?;
Ok(serde_json::from_str(&json)?)
}
}
pub struct EndpointsCache {
config: EndpointCacheConfig,
endpoints: ClashSet<EndpointIdInt>,
branches: ClashSet<BranchIdInt>,
projects: ClashSet<ProjectIdInt>,
ready: AtomicBool,
limiter: Arc<Mutex<GlobalRateLimiter>>,
}
impl EndpointsCache {
pub(crate) fn new(config: EndpointCacheConfig) -> Self {
Self {
limiter: Arc::new(Mutex::new(GlobalRateLimiter::new(
config.limiter_info.clone(),
))),
config,
endpoints: ClashSet::new(),
branches: ClashSet::new(),
projects: ClashSet::new(),
ready: AtomicBool::new(false),
}
}
pub(crate) fn is_valid(&self, ctx: &RequestContext, endpoint: &EndpointId) -> bool {
if !self.ready.load(Ordering::Acquire) {
// the endpoint cache is not yet fully initialised.
return true;
}
if !self.should_reject(endpoint) {
ctx.set_rejected(false);
return true;
}
// report that we might want to reject this endpoint
ctx.set_rejected(true);
// If cache is disabled, just collect the metrics and return.
if self.config.disable_cache {
return true;
}
// If the limiter allows, we can pretend like it's valid
// (incase it is, due to redis channel lag).
if self.limiter.lock_propagate_poison().check() {
return true;
}
// endpoint not found, and there's too much load.
false
}
fn should_reject(&self, endpoint: &EndpointId) -> bool {
if endpoint.is_endpoint() {
let Some(endpoint) = EndpointIdInt::get(endpoint) else {
// if we haven't interned this endpoint, it's not in the cache.
return true;
};
!self.endpoints.contains(&endpoint)
} else if endpoint.is_branch() {
let Some(branch) = BranchIdInt::get(endpoint) else {
// if we haven't interned this branch, it's not in the cache.
return true;
};
!self.branches.contains(&branch)
} else {
let Some(project) = ProjectIdInt::get(endpoint) else {
// if we haven't interned this project, it's not in the cache.
return true;
};
!self.projects.contains(&project)
}
}
fn insert_event(&self, event: ControlPlaneEvent) {
if let Some(endpoint_created) = event.endpoint_created {
self.endpoints.insert(endpoint_created.endpoint_id);
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::EndpointCreated);
} else if let Some(branch_created) = event.branch_created {
self.branches.insert(branch_created.branch_id);
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::BranchCreated);
} else if let Some(project_created) = event.project_created {
self.projects.insert(project_created.project_id);
Metrics::get()
.proxy
.redis_events_count
.inc(RedisEventsCount::ProjectCreated);
}
}
pub async fn do_read(
&self,
mut con: ConnectionWithCredentialsProvider,
cancellation_token: CancellationToken,
) -> anyhow::Result<Infallible> {
let mut last_id = "0-0".to_string();
loop {
if let Err(e) = con.connect().await {
tracing::error!("error connecting to redis: {:?}", e);
self.ready.store(false, Ordering::Release);
}
if let Err(e) = self.read_from_stream(&mut con, &mut last_id).await {
tracing::error!("error reading from redis: {:?}", e);
self.ready.store(false, Ordering::Release);
}
if cancellation_token.is_cancelled() {
info!("cancellation token is cancelled, exiting");
// Maintenance tasks run forever. Sleep forever when canceled.
pending::<()>().await;
}
tokio::time::sleep(self.config.retry_interval).await;
}
}
async fn read_from_stream(
&self,
con: &mut ConnectionWithCredentialsProvider,
last_id: &mut String,
) -> anyhow::Result<()> {
tracing::info!("reading endpoints/branches/projects from redis");
self.batch_read(
con,
StreamReadOptions::default().count(self.config.initial_batch_size),
last_id,
true,
)
.await?;
tracing::info!("ready to filter user requests");
self.ready.store(true, Ordering::Release);
self.batch_read(
con,
StreamReadOptions::default()
.count(self.config.default_batch_size)
.block(self.config.xread_timeout.as_millis() as usize),
last_id,
false,
)
.await
}
async fn batch_read(
&self,
conn: &mut ConnectionWithCredentialsProvider,
opts: StreamReadOptions,
last_id: &mut String,
return_when_finish: bool,
) -> anyhow::Result<()> {
let mut total: usize = 0;
loop {
let mut res: StreamReadReply = conn
.xread_options(&[&self.config.stream_name], &[last_id.as_str()], &opts)
.await?;
if res.keys.is_empty() {
if return_when_finish {
if total != 0 {
break;
}
anyhow::bail!(
"Redis stream {} is empty, cannot be used to filter endpoints",
self.config.stream_name
);
}
// If we are not returning when finish, we should wait for more data.
continue;
}
if res.keys.len() != 1 {
anyhow::bail!("Cannot read from redis stream {}", self.config.stream_name);
}
let key = res.keys.pop().expect("Checked length above");
let len = key.ids.len();
for stream_id in key.ids {
total += 1;
for value in stream_id.map.values() {
match value.try_into() {
Ok(event) => self.insert_event(event),
Err(err) => {
Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
channel: &self.config.stream_name,
});
tracing::error!("error parsing value {value:?}: {err:?}");
}
}
}
if total.is_power_of_two() {
tracing::debug!("endpoints read {}", total);
}
*last_id = stream_id.id;
}
if return_when_finish && len <= self.config.default_batch_size {
break;
}
}
tracing::info!("read {} endpoints/branches/projects from redis", total);
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_control_plane_event() {
let s = r#"{"branch_created":null,"endpoint_created":{"endpoint_id":"ep-rapid-thunder-w0qqw2q9"},"project_created":null,"type":"endpoint_created"}"#;
let endpoint_id: EndpointId = "ep-rapid-thunder-w0qqw2q9".into();
assert_eq!(
serde_json::from_str::<ControlPlaneEvent>(s).unwrap(),
ControlPlaneEvent {
endpoint_created: Some(EndpointCreated {
endpoint_id: endpoint_id.into(),
}),
branch_created: None,
project_created: None,
_type: Some("endpoint_created".into()),
}
);
}
}

View File

@@ -1,5 +1,4 @@
pub(crate) mod common;
pub(crate) mod endpoints;
pub(crate) mod project_info;
mod timed_lru;

View File

@@ -18,7 +18,7 @@ use crate::control_plane::locks::ApiLocks;
use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use crate::ext::TaskExt;
use crate::intern::RoleNameInt;
use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig};
use crate::rate_limiter::{RateLimitAlgorithm, RateLimiterConfig};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::GlobalConnPoolOptions;
use crate::serverless::cancel_set::CancelSet;
@@ -80,79 +80,6 @@ pub struct AuthenticationConfig {
pub console_redirect_confirmation_timeout: tokio::time::Duration,
}
#[derive(Debug)]
pub struct EndpointCacheConfig {
/// Batch size to receive all endpoints on the startup.
pub initial_batch_size: usize,
/// Batch size to receive endpoints.
pub default_batch_size: usize,
/// Timeouts for the stream read operation.
pub xread_timeout: Duration,
/// Stream name to read from.
pub stream_name: String,
/// Limiter info (to distinguish when to enable cache).
pub limiter_info: Vec<RateBucketInfo>,
/// Disable cache.
/// If true, cache is ignored, but reports all statistics.
pub disable_cache: bool,
/// Retry interval for the stream read operation.
pub retry_interval: Duration,
}
impl EndpointCacheConfig {
/// Default options for [`crate::control_plane::NodeInfoCache`].
/// Notice that by default the limiter is empty, which means that cache is disabled.
pub const CACHE_DEFAULT_OPTIONS: &'static str = "initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s,retry_interval=1s";
/// Parse cache options passed via cmdline.
/// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
fn parse(options: &str) -> anyhow::Result<Self> {
let mut initial_batch_size = None;
let mut default_batch_size = None;
let mut xread_timeout = None;
let mut stream_name = None;
let mut limiter_info = vec![];
let mut disable_cache = false;
let mut retry_interval = None;
for option in options.split(',') {
let (key, value) = option
.split_once('=')
.with_context(|| format!("bad key-value pair: {option}"))?;
match key {
"initial_batch_size" => initial_batch_size = Some(value.parse()?),
"default_batch_size" => default_batch_size = Some(value.parse()?),
"xread_timeout" => xread_timeout = Some(humantime::parse_duration(value)?),
"stream_name" => stream_name = Some(value.to_string()),
"limiter_info" => limiter_info.push(RateBucketInfo::from_str(value)?),
"disable_cache" => disable_cache = value.parse()?,
"retry_interval" => retry_interval = Some(humantime::parse_duration(value)?),
unknown => bail!("unknown key: {unknown}"),
}
}
RateBucketInfo::validate(&mut limiter_info)?;
Ok(Self {
initial_batch_size: initial_batch_size.context("missing `initial_batch_size`")?,
default_batch_size: default_batch_size.context("missing `default_batch_size`")?,
xread_timeout: xread_timeout.context("missing `xread_timeout`")?,
stream_name: stream_name.context("missing `stream_name`")?,
disable_cache,
limiter_info,
retry_interval: retry_interval.context("missing `retry_interval`")?,
})
}
}
impl FromStr for EndpointCacheConfig {
type Err = anyhow::Error;
fn from_str(options: &str) -> Result<Self, Self::Err> {
let error = || format!("failed to parse endpoint cache options '{options}'");
Self::parse(options).with_context(error)
}
}
#[derive(Debug)]
pub struct MetricBackupCollectionConfig {
pub remote_storage_config: Option<RemoteStorageConfig>,

View File

@@ -7,7 +7,7 @@ use once_cell::sync::OnceCell;
use smol_str::SmolStr;
use tokio::sync::mpsc;
use tracing::field::display;
use tracing::{Span, debug, error, info_span};
use tracing::{Span, error, info_span};
use try_lock::TryLock;
use uuid::Uuid;
@@ -15,10 +15,7 @@ use self::parquet::RequestData;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::error::ErrorKind;
use crate::intern::{BranchIdInt, ProjectIdInt};
use crate::metrics::{
ConnectOutcome, InvalidEndpointsGroup, LatencyAccumulated, LatencyTimer, Metrics, Protocol,
Waiting,
};
use crate::metrics::{LatencyAccumulated, LatencyTimer, Metrics, Protocol, Waiting};
use crate::pqproto::StartupMessageParams;
use crate::protocol2::{ConnectionInfo, ConnectionInfoExtra};
use crate::types::{DbName, EndpointId, RoleName};
@@ -70,8 +67,6 @@ struct RequestContextInner {
// This sender is only used to log the length of session in case of success.
disconnect_sender: Option<mpsc::UnboundedSender<RequestData>>,
pub(crate) latency_timer: LatencyTimer,
// Whether proxy decided that it's not a valid endpoint end rejected it before going to cplane.
rejected: Option<bool>,
disconnect_timestamp: Option<chrono::DateTime<Utc>>,
}
@@ -106,7 +101,6 @@ impl Clone for RequestContext {
auth_method: inner.auth_method.clone(),
jwt_issuer: inner.jwt_issuer.clone(),
success: inner.success,
rejected: inner.rejected,
cold_start_info: inner.cold_start_info,
pg_options: inner.pg_options.clone(),
testodrome_query_id: inner.testodrome_query_id.clone(),
@@ -151,7 +145,6 @@ impl RequestContext {
auth_method: None,
jwt_issuer: None,
success: false,
rejected: None,
cold_start_info: ColdStartInfo::Unknown,
pg_options: None,
testodrome_query_id: None,
@@ -183,11 +176,6 @@ impl RequestContext {
)
}
pub(crate) fn set_rejected(&self, rejected: bool) {
let mut this = self.0.try_lock().expect("should not deadlock");
this.rejected = Some(rejected);
}
pub(crate) fn set_cold_start_info(&self, info: ColdStartInfo) {
self.0
.try_lock()
@@ -461,38 +449,6 @@ impl RequestContextInner {
}
fn log_connect(&mut self) {
let outcome = if self.success {
ConnectOutcome::Success
} else {
ConnectOutcome::Failed
};
// TODO: get rid of entirely/refactor
// check for false positives
// AND false negatives
if let Some(rejected) = self.rejected {
let ep = self
.endpoint_id
.as_ref()
.map(|x| x.as_str())
.unwrap_or_default();
// This makes sense only if cache is disabled
debug!(
?outcome,
?rejected,
?ep,
"check endpoint is valid with outcome"
);
Metrics::get()
.proxy
.invalid_endpoints_total
.inc(InvalidEndpointsGroup {
protocol: self.protocol,
rejected: rejected.into(),
outcome,
});
}
if let Some(tx) = self.sender.take() {
// If type changes, this error handling needs to be updated.
let tx: mpsc::UnboundedSender<RequestData> = tx;

View File

@@ -159,13 +159,6 @@ impl NeonControlPlaneClient {
ctx: &RequestContext,
endpoint: &EndpointId,
) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
if !self
.caches
.endpoints_cache
.is_valid(ctx, &endpoint.normalize())
{
return Err(GetEndpointJwksError::EndpointNotFound);
}
let request_id = ctx.session_id().to_string();
async {
let request = self
@@ -300,11 +293,6 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
return Ok(secret);
}
if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) {
info!("endpoint is not valid, skipping the request");
return Err(GetAuthInfoError::UnknownEndpoint);
}
let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;
let control = EndpointAccessControl {
@@ -346,11 +334,6 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
return Ok(control);
}
if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) {
info!("endpoint is not valid, skipping the request");
return Err(GetAuthInfoError::UnknownEndpoint);
}
let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;
let control = EndpointAccessControl {

View File

@@ -13,9 +13,8 @@ use tracing::{debug, info};
use super::{EndpointAccessControl, RoleAccessControl};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
use crate::cache::endpoints::EndpointsCache;
use crate::cache::project_info::ProjectInfoCacheImpl;
use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions};
use crate::config::{CacheOptions, ProjectInfoCacheOptions};
use crate::context::RequestContext;
use crate::control_plane::{CachedNodeInfo, ControlPlaneApi, NodeInfoCache, errors};
use crate::error::ReportableError;
@@ -121,15 +120,12 @@ pub struct ApiCaches {
pub(crate) node_info: NodeInfoCache,
/// Cache which stores project_id -> endpoint_ids mapping.
pub project_info: Arc<ProjectInfoCacheImpl>,
/// List of all valid endpoints.
pub endpoints_cache: Arc<EndpointsCache>,
}
impl ApiCaches {
pub fn new(
wake_compute_cache_config: CacheOptions,
project_info_cache_config: ProjectInfoCacheOptions,
endpoint_cache_config: EndpointCacheConfig,
) -> Self {
Self {
node_info: NodeInfoCache::new(
@@ -139,7 +135,6 @@ impl ApiCaches {
true,
),
project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
}
}
}

View File

@@ -99,10 +99,6 @@ pub(crate) enum GetAuthInfoError {
#[error(transparent)]
ApiError(ControlPlaneError),
/// Proxy does not know about the endpoint in advanced
#[error("endpoint not found in endpoint cache")]
UnknownEndpoint,
}
// This allows more useful interactions than `#[from]`.
@@ -119,8 +115,6 @@ impl UserFacingError for GetAuthInfoError {
Self::BadSecret => REQUEST_FAILED.to_owned(),
// However, API might return a meaningful error.
Self::ApiError(e) => e.to_string_client(),
// pretend like control plane returned an error.
Self::UnknownEndpoint => REQUEST_FAILED.to_owned(),
}
}
}
@@ -130,8 +124,6 @@ impl ReportableError for GetAuthInfoError {
match self {
Self::BadSecret => crate::error::ErrorKind::ControlPlane,
Self::ApiError(_) => crate::error::ErrorKind::ControlPlane,
// we only apply endpoint filtering if control plane is under high load.
Self::UnknownEndpoint => crate::error::ErrorKind::ServiceRateLimit,
}
}
}
@@ -200,9 +192,6 @@ impl CouldRetry for WakeComputeError {
#[derive(Debug, Error)]
pub enum GetEndpointJwksError {
#[error("endpoint not found")]
EndpointNotFound,
#[error("failed to build control plane request: {0}")]
RequestBuild(#[source] reqwest::Error),

View File

@@ -16,44 +16,6 @@ use super::LeakyBucketConfig;
use crate::ext::LockExt;
use crate::intern::EndpointIdInt;
pub struct GlobalRateLimiter {
data: Vec<RateBucket>,
info: Vec<RateBucketInfo>,
}
impl GlobalRateLimiter {
pub fn new(info: Vec<RateBucketInfo>) -> Self {
Self {
data: vec![
RateBucket {
start: Instant::now(),
count: 0,
};
info.len()
],
info,
}
}
/// Check that number of connections is below `max_rps` rps.
pub fn check(&mut self) -> bool {
let now = Instant::now();
let should_allow_request = self
.data
.iter_mut()
.zip(&self.info)
.all(|(bucket, info)| bucket.should_allow_request(info, now, 1));
if should_allow_request {
// only increment the bucket counts if the request will actually be accepted
self.data.iter_mut().for_each(|b| b.inc(1));
}
should_allow_request
}
}
// Simple per-endpoint rate limiter.
//
// Check that number of connections to the endpoint is below `max_rps` rps.

View File

@@ -8,4 +8,4 @@ pub(crate) use limit_algorithm::aimd::Aimd;
pub(crate) use limit_algorithm::{
DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token,
};
pub use limiter::{GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
pub use limiter::{RateBucketInfo, WakeComputeRateLimiter};

View File

@@ -1,112 +1,60 @@
use postgres_client::Row;
use postgres_client::types::{Kind, Type};
use serde::Deserialize;
use serde::de::{Deserializer, IgnoredAny, Visitor};
use serde_json::value::RawValue;
use serde_json::{Map, Value};
//
// Convert json non-string types to strings, so that they can be passed to Postgres
// as parameters.
//
pub(crate) fn json_to_pg_text(json: Vec<Box<RawValue>>) -> Vec<Option<String>> {
json.into_iter()
.map(|raw| {
match raw.get().as_bytes() {
// special handling for null.
b"null" => None,
// remove the escape characters from the string.
[b'"', ..] => {
Some(String::deserialize(&*raw).expect("json should be a valid string"))
}
[b'[', ..] => {
let mut output = String::with_capacity(raw.get().len());
raw.deserialize_seq(PgArrayVisitor(&raw, &mut output))
.expect("json should be a valid");
Some(output)
}
// write all other values out directly
_ => Some(<Box<str>>::from(raw).into()),
}
})
.collect()
pub(crate) fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
json.iter().map(json_value_to_pg_text).collect()
}
struct PgArrayVisitor<'de, 'a>(&'de RawValue, &'a mut String);
fn json_value_to_pg_text(value: &Value) -> Option<String> {
match value {
// special care for nulls
Value::Null => None,
impl PgArrayVisitor<'_, '_> {
#[inline]
#[allow(clippy::unnecessary_wraps)]
fn raw<E>(self) -> Result<(), E> {
self.1.push_str(self.0.get());
Ok(())
// convert to text with escaping
v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),
// avoid escaping here, as we pass this as a parameter
Value::String(s) => Some(s.clone()),
// special care for arrays
Value::Array(_) => json_array_to_pg_array(value),
}
}
impl<'de> Visitor<'de> for PgArrayVisitor<'de, '_> {
type Value = ();
//
// Serialize a JSON array to a Postgres array. Contrary to the strings in the params
// in the array we need to escape the strings. Postgres is okay with arrays of form
// '{1,"2",3}'::int[], so we don't check that array holds values of the same type, leaving
// it for Postgres to check.
//
// Example of the same escaping in node-postgres: packages/pg/lib/utils.js
//
fn json_array_to_pg_array(value: &Value) -> Option<String> {
match value {
// special care for nulls
Value::Null => None,
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("any valid JSON value")
}
// convert to text with escaping
// here string needs to be escaped, as it is part of the array
v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()),
v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())),
// special care for nulls
fn visit_none<E>(self) -> Result<Self::Value, E> {
self.1.push_str("NULL");
Ok(())
}
fn visit_unit<E>(self) -> Result<Self::Value, E> {
self.1.push_str("NULL");
Ok(())
}
// recurse into array
Value::Array(arr) => {
let vals = arr
.iter()
.map(json_array_to_pg_array)
.map(|v| v.unwrap_or_else(|| "NULL".to_string()))
.collect::<Vec<_>>()
.join(",");
// convert to text with escaping
fn visit_bool<E>(self, _: bool) -> Result<Self::Value, E> {
self.raw()
}
fn visit_i64<E>(self, _: i64) -> Result<Self::Value, E> {
self.raw()
}
fn visit_u64<E>(self, _: u64) -> Result<Self::Value, E> {
self.raw()
}
fn visit_i128<E>(self, _: i128) -> Result<Self::Value, E> {
self.raw()
}
fn visit_u128<E>(self, _: u128) -> Result<Self::Value, E> {
self.raw()
}
fn visit_f64<E>(self, _: f64) -> Result<Self::Value, E> {
self.raw()
}
fn visit_str<E>(self, _: &str) -> Result<Self::Value, E> {
self.raw()
}
// an object needs re-escaping
fn visit_map<A: serde::de::MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> {
while map.next_entry::<IgnoredAny, IgnoredAny>()?.is_some() {}
let s = serde_json::to_string(self.0.get()).expect("a string should be valid json");
self.1.push_str(&s);
Ok(())
}
// write an array
fn visit_seq<A: serde::de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
self.1.push('{');
let mut comma = false;
while let Some(val) = seq.next_element::<&'de RawValue>()? {
if comma {
self.1.push(',');
}
comma = true;
val.deserialize_any(PgArrayVisitor(val, self.1))
.expect("all json values are valid");
Some(format!("{{{vals}}}"))
}
self.1.push('}');
Ok(())
}
}
@@ -436,14 +384,6 @@ mod tests {
use super::*;
fn json_to_pg_text(json: Vec<serde_json::Value>) -> Vec<Option<String>> {
let json = json
.into_iter()
.map(|value| serde_json::from_str(&value.to_string()).unwrap())
.collect();
super::json_to_pg_text(json)
}
#[test]
fn test_atomic_types_to_pg_params() {
let json = vec![Value::Bool(true), Value::Bool(false)];

View File

@@ -18,6 +18,7 @@ use postgres_client::{
GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction,
};
use serde::Serialize;
use serde_json::Value;
use serde_json::value::RawValue;
use tokio::time::{self, Instant};
use tokio_util::sync::CancellationToken;
@@ -47,8 +48,9 @@ use crate::util::run_until_cancelled;
#[serde(rename_all = "camelCase")]
struct QueryData {
query: String,
#[serde(deserialize_with = "bytes_to_pg_text")]
#[serde(default)]
params: Vec<Box<RawValue>>,
params: Vec<Option<String>>,
#[serde(default)]
array_mode: Option<bool>,
}
@@ -58,6 +60,8 @@ struct BatchQueryData {
queries: Vec<QueryData>,
}
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum Payload {
Single(QueryData),
Batch(BatchQueryData),
@@ -65,6 +69,15 @@ enum Payload {
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
where
D: serde::de::Deserializer<'de>,
{
// TODO: consider avoiding the allocation here.
let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
Ok(json_to_pg_text(json))
}
pub(crate) async fn handle(
config: &'static ProxyConfig,
ctx: RequestContext,
@@ -486,14 +499,7 @@ async fn handle_db_inner(
.observe(HttpDirection::Request, body.len() as f64);
debug!(length = body.len(), "request payload read");
// try unbatched, then try batched.
let payload = if let Ok(batch) = serde_json::from_slice(&body) {
Payload::Batch(batch)
} else {
Payload::Single(serde_json::from_slice(&body)?)
};
let payload: Payload = serde_json::from_slice(&body)?;
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
}
.map_err(SqlOverHttpError::from),
@@ -881,8 +887,7 @@ async fn query_to_json<T: GenericClient>(
) -> Result<(ReadyForQueryStatus, impl Serialize + use<T>), SqlOverHttpError> {
let query_start = Instant::now();
let query_params = json_to_pg_text(data.params);
let query_params = data.params;
let mut row_stream = client
.query_raw_txt(&data.query, query_params)
.await
@@ -1028,38 +1033,55 @@ mod tests {
#[test]
fn test_payload() {
let payload = "{\"query\":\"SELECT * FROM users WHERE name = ?\",\"params\":[\"test\"],\"arrayMode\":true}";
let QueryData {
query,
params,
array_mode,
} = serde_json::from_str(payload).unwrap();
let deserialized_payload: Payload = serde_json::from_str(payload).unwrap();
assert_eq!(query, "SELECT * FROM users WHERE name = ?");
assert_eq!(params[0].get(), "\"test\"");
assert!(array_mode.unwrap());
match deserialized_payload {
Payload::Single(QueryData {
query,
params,
array_mode,
}) => {
assert_eq!(query, "SELECT * FROM users WHERE name = ?");
assert_eq!(params, vec![Some(String::from("test"))]);
assert!(array_mode.unwrap());
}
Payload::Batch(_) => {
panic!("deserialization failed: case with single query, one param, and array mode")
}
}
let payload = "{\"queries\":[{\"query\":\"SELECT * FROM users0 WHERE name = ?\",\"params\":[\"test0\"], \"arrayMode\":false},{\"query\":\"SELECT * FROM users1 WHERE name = ?\",\"params\":[\"test1\"],\"arrayMode\":true}]}";
let BatchQueryData { queries } = serde_json::from_str(payload).unwrap();
let deserialized_payload: Payload = serde_json::from_str(payload).unwrap();
assert_eq!(queries.len(), 2);
for (i, query) in queries.into_iter().enumerate() {
assert_eq!(
query.query,
format!("SELECT * FROM users{i} WHERE name = ?")
);
assert_eq!(query.params[0].get(), &format!("\"test{i}\""));
assert_eq!(query.array_mode.unwrap(), i > 0);
match deserialized_payload {
Payload::Batch(BatchQueryData { queries }) => {
assert_eq!(queries.len(), 2);
for (i, query) in queries.into_iter().enumerate() {
assert_eq!(
query.query,
format!("SELECT * FROM users{i} WHERE name = ?")
);
assert_eq!(query.params, vec![Some(format!("test{i}"))]);
assert_eq!(query.array_mode.unwrap(), i > 0);
}
}
Payload::Single(_) => panic!("deserialization failed: case with multiple queries"),
}
let payload = "{\"query\":\"SELECT 1\"}";
let QueryData {
query,
params,
array_mode,
} = serde_json::from_str(payload).unwrap();
let deserialized_payload: Payload = serde_json::from_str(payload).unwrap();
assert_eq!(query, "SELECT 1");
assert!(params.is_empty());
assert!(array_mode.is_none());
match deserialized_payload {
Payload::Single(QueryData {
query,
params,
array_mode,
}) => {
assert_eq!(query, "SELECT 1");
assert_eq!(params, vec![]);
assert!(array_mode.is_none());
}
Payload::Batch(_) => panic!("deserialization failed: case with only one query"),
}
}
}

View File

@@ -107,13 +107,3 @@ smol_str_wrapper!(DbName);
// postgres hostname, will likely be a port:ip addr
smol_str_wrapper!(Host);
// Endpoints are a bit tricky. Rare they might be branches or projects.
impl EndpointId {
pub(crate) fn is_endpoint(&self) -> bool {
self.0.starts_with("ep-")
}
pub(crate) fn is_branch(&self) -> bool {
self.0.starts_with("br-")
}
}

View File

@@ -1066,9 +1066,10 @@ async fn handle_node_delete(req: Request<Body>) -> Result<Response<Body>, ApiErr
let state = get_state(&req);
let node_id: NodeId = parse_request_param(&req, "node_id")?;
let force: bool = parse_query_param(&req, "force")?.unwrap_or(false);
json_response(
StatusCode::OK,
state.service.start_node_delete(node_id).await?,
state.service.start_node_delete(node_id, force).await?,
)
}

View File

@@ -7165,6 +7165,7 @@ impl Service {
self: &Arc<Self>,
node_id: NodeId,
policy_on_start: NodeSchedulingPolicy,
force: bool,
cancel: CancellationToken,
) -> Result<(), OperationError> {
let reconciler_config = ReconcilerConfigBuilder::new(ReconcilerPriority::Normal).build();
@@ -7172,23 +7173,28 @@ impl Service {
let mut waiters: Vec<ReconcilerWaiter> = Vec::new();
let mut tid_iter = create_shared_shard_iterator(self.clone());
let process_cancel = || async {
// Attempt to restore the node to its original scheduling policy
match self
.node_configure(node_id, None, Some(policy_on_start))
.await
{
Ok(()) => Err(OperationError::Cancelled),
Err(err) => {
Err(OperationError::FinalizeError(
format!(
"Failed to finalise delete cancel of {} by setting scheduling policy to {}: {}",
node_id, String::from(policy_on_start), err
)
.into(),
))
}
}
};
while !tid_iter.finished() {
if cancel.is_cancelled() {
match self
.node_configure(node_id, None, Some(policy_on_start))
.await
{
Ok(()) => return Err(OperationError::Cancelled),
Err(err) => {
return Err(OperationError::FinalizeError(
format!(
"Failed to finalise delete cancel of {} by setting scheduling policy to {}: {}",
node_id, String::from(policy_on_start), err
)
.into(),
));
}
}
return process_cancel().await;
}
operation_utils::validate_node_state(
@@ -7249,13 +7255,24 @@ impl Service {
)
}
// Do not wait for any reconciliations to finish if the deletion has been forced.
let waiter = self.maybe_configured_reconcile_shard(
tenant_shard,
nodes,
reconciler_config,
);
if let Some(some) = waiter {
waiters.push(some);
if force {
// Here we remove an existing observed location for the node we're removing, and it will
// not be re-added by a reconciler's completion because we filter out removed nodes in
// process_result.
//
// Note that we update the shard's observed state _after_ calling maybe_configured_reconcile_shard:
// that means any reconciles we spawned will know about the node we're deleting,
// enabling them to do live migrations if it's still online.
tenant_shard.observed.locations.remove(&node_id);
} else if let Some(waiter) = waiter {
waiters.push(waiter);
}
}
}
@@ -7269,21 +7286,7 @@ impl Service {
while !waiters.is_empty() {
if cancel.is_cancelled() {
match self
.node_configure(node_id, None, Some(policy_on_start))
.await
{
Ok(()) => return Err(OperationError::Cancelled),
Err(err) => {
return Err(OperationError::FinalizeError(
format!(
"Failed to finalise drain cancel of {} by setting scheduling policy to {}: {}",
node_id, String::from(policy_on_start), err
)
.into(),
));
}
}
return process_cancel().await;
}
tracing::info!("Awaiting {} pending delete reconciliations", waiters.len());
@@ -7888,6 +7891,7 @@ impl Service {
pub(crate) async fn start_node_delete(
self: &Arc<Self>,
node_id: NodeId,
force: bool,
) -> Result<(), ApiError> {
let (ongoing_op, node_policy, schedulable_nodes_count) = {
let locked = self.inner.read().unwrap();
@@ -7957,7 +7961,7 @@ impl Service {
tracing::info!("Delete background operation starting");
let res = service
.delete_node(node_id, policy_on_start, cancel)
.delete_node(node_id, policy_on_start, force, cancel)
.await;
match res {
Ok(()) => {

View File

@@ -2084,11 +2084,14 @@ class NeonStorageController(MetricsGetter, LogUtils):
headers=self.headers(TokenScope.ADMIN),
)
def node_delete(self, node_id):
def node_delete(self, node_id, force: bool = False):
log.info(f"node_delete({node_id})")
query = f"{self.api}/control/v1/node/{node_id}/delete"
if force:
query += "?force=true"
self.request(
"PUT",
f"{self.api}/control/v1/node/{node_id}/delete",
query,
headers=self.headers(TokenScope.ADMIN),
)

View File

@@ -1034,16 +1034,19 @@ def test_storage_controller_compute_hook_keep_failing(
alive_pageservers = [p for p in env.pageservers if p.id != banned_tenant_ps.id]
# Stop pageserver and ban tenant to trigger failed reconciliation
log.info(f"Banning tenant {banned_tenant} and stopping pageserver {banned_tenant_ps.id}")
status_by_tenant[banned_tenant] = 423
banned_tenant_ps.stop()
env.storage_controller.allowed_errors.append(NOTIFY_BLOCKED_LOG)
env.storage_controller.allowed_errors.extend(NOTIFY_FAILURE_LOGS)
env.storage_controller.allowed_errors.append(".*Keeping extra secondaries.*")
env.storage_controller.allowed_errors.append(".*Shard reconciliation is keep-failing.*")
env.storage_controller.node_configure(banned_tenant_ps.id, {"availability": "Offline"})
# Migrate all allowed tenant shards to the first alive pageserver
# to trigger storage controller optimizations due to affinity rules
for shard_number in range(shard_count):
log.info(f"Migrating shard {shard_number} of {allowed_tenant} to {alive_pageservers[0].id}")
env.storage_controller.tenant_shard_migrate(
TenantShardId(allowed_tenant, shard_number, shard_count),
alive_pageservers[0].id,
@@ -2618,7 +2621,7 @@ def test_storage_controller_node_deletion(
wait_until(assert_shards_migrated)
log.info(f"Deleting pageserver {victim.id}")
env.storage_controller.node_delete_old(victim.id)
env.storage_controller.node_delete(victim.id, force=True)
if not while_offline:
@@ -2631,7 +2634,10 @@ def test_storage_controller_node_deletion(
wait_until(assert_victim_evacuated)
# The node should be gone from the list API
assert victim.id not in [n["id"] for n in env.storage_controller.node_list()]
def assert_victim_gone():
assert victim.id not in [n["id"] for n in env.storage_controller.node_list()]
wait_until(assert_victim_gone)
# No tenants should refer to the node in their intent
for tenant_id in tenant_ids:
@@ -3262,10 +3268,10 @@ def test_ps_unavailable_after_delete(neon_env_builder: NeonEnvBuilder):
assert_nodes_count(3)
ps = env.pageservers[0]
env.storage_controller.node_delete_old(ps.id)
env.storage_controller.node_delete(ps.id, force=True)
# After deletion, the node count must be reduced
assert_nodes_count(2)
wait_until(lambda: assert_nodes_count(2))
# Running pageserver CLI init in a separate thread
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: