mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-21 07:00:38 +00:00
Compare commits
3 Commits
conrad/laz
...
ephemerals
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c8475ed008 | ||
|
|
2f3fc7cb57 | ||
|
|
e65d5f7369 |
@@ -75,6 +75,12 @@ enum Command {
|
||||
NodeStartDelete {
|
||||
#[arg(long)]
|
||||
node_id: NodeId,
|
||||
/// When `force` is true, skip waiting for shards to prewarm during migration.
|
||||
/// This can significantly speed up node deletion since prewarming all shards
|
||||
/// can take considerable time, but may result in slower initial access to
|
||||
/// migrated shards until they warm up naturally.
|
||||
#[arg(long)]
|
||||
force: bool,
|
||||
},
|
||||
/// Cancel deletion of the specified pageserver and wait for `timeout`
|
||||
/// for the operation to be canceled. May be retried.
|
||||
@@ -933,13 +939,14 @@ async fn main() -> anyhow::Result<()> {
|
||||
.dispatch::<(), ()>(Method::DELETE, format!("control/v1/node/{node_id}"), None)
|
||||
.await?;
|
||||
}
|
||||
Command::NodeStartDelete { node_id } => {
|
||||
Command::NodeStartDelete { node_id, force } => {
|
||||
let query = if force {
|
||||
format!("control/v1/node/{node_id}/delete?force=true")
|
||||
} else {
|
||||
format!("control/v1/node/{node_id}/delete")
|
||||
};
|
||||
storcon_client
|
||||
.dispatch::<(), ()>(
|
||||
Method::PUT,
|
||||
format!("control/v1/node/{node_id}/delete"),
|
||||
None,
|
||||
)
|
||||
.dispatch::<(), ()>(Method::PUT, query, None)
|
||||
.await?;
|
||||
println!("Delete started for {node_id}");
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ use tokio::net::TcpListener;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument, error, info, warn};
|
||||
use tracing::{error, info, warn};
|
||||
use utils::sentry_init::init_sentry;
|
||||
use utils::{project_build_tag, project_git_version};
|
||||
|
||||
@@ -195,7 +195,9 @@ struct ProxyCliArgs {
|
||||
#[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)]
|
||||
project_info_cache: String,
|
||||
/// cache for all valid endpoints
|
||||
#[clap(long, default_value = config::EndpointCacheConfig::CACHE_DEFAULT_OPTIONS)]
|
||||
// TODO: remove after a couple of releases.
|
||||
#[clap(long, default_value_t = String::new())]
|
||||
#[deprecated]
|
||||
endpoint_cache_config: String,
|
||||
#[clap(flatten)]
|
||||
parquet_upload: ParquetUploadArgs,
|
||||
@@ -558,13 +560,6 @@ pub async fn run() -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// listen for notifications of new projects/endpoints/branches
|
||||
let cache = api.caches.endpoints_cache.clone();
|
||||
let span = tracing::info_span!("endpoints_cache");
|
||||
maintenance_tasks.spawn(
|
||||
async move { cache.do_read(client, cancellation_token.clone()).await }.instrument(span),
|
||||
);
|
||||
}
|
||||
|
||||
let maintenance = loop {
|
||||
@@ -712,18 +707,15 @@ fn build_auth_backend(
|
||||
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
|
||||
let project_info_cache_config: ProjectInfoCacheOptions =
|
||||
args.project_info_cache.parse()?;
|
||||
let endpoint_cache_config: config::EndpointCacheConfig =
|
||||
args.endpoint_cache_config.parse()?;
|
||||
|
||||
info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
|
||||
info!(
|
||||
"Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
|
||||
);
|
||||
info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}");
|
||||
|
||||
let caches = Box::leak(Box::new(control_plane::caches::ApiCaches::new(
|
||||
wake_compute_cache_config,
|
||||
project_info_cache_config,
|
||||
endpoint_cache_config,
|
||||
)));
|
||||
|
||||
let config::ConcurrencyLockOptions {
|
||||
@@ -793,18 +785,15 @@ fn build_auth_backend(
|
||||
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
|
||||
let project_info_cache_config: ProjectInfoCacheOptions =
|
||||
args.project_info_cache.parse()?;
|
||||
let endpoint_cache_config: config::EndpointCacheConfig =
|
||||
args.endpoint_cache_config.parse()?;
|
||||
|
||||
info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
|
||||
info!(
|
||||
"Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
|
||||
);
|
||||
info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}");
|
||||
|
||||
let caches = Box::leak(Box::new(control_plane::caches::ApiCaches::new(
|
||||
wake_compute_cache_config,
|
||||
project_info_cache_config,
|
||||
endpoint_cache_config,
|
||||
)));
|
||||
|
||||
let config::ConcurrencyLockOptions {
|
||||
|
||||
283
proxy/src/cache/endpoints.rs
vendored
283
proxy/src/cache/endpoints.rs
vendored
@@ -1,283 +0,0 @@
|
||||
use std::convert::Infallible;
|
||||
use std::future::pending;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use clashmap::ClashSet;
|
||||
use redis::streams::{StreamReadOptions, StreamReadReply};
|
||||
use redis::{AsyncCommands, FromRedisValue, Value};
|
||||
use serde::Deserialize;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::info;
|
||||
|
||||
use crate::config::EndpointCacheConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::ext::LockExt;
|
||||
use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
|
||||
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
|
||||
use crate::rate_limiter::GlobalRateLimiter;
|
||||
use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
|
||||
use crate::types::EndpointId;
|
||||
|
||||
// TODO: this could be an enum, but events in Redis need to be fixed first.
|
||||
// ProjectCreated was sent with type:branch_created. So we ignore type.
|
||||
#[derive(Deserialize, Debug, Clone, PartialEq)]
|
||||
struct ControlPlaneEvent {
|
||||
endpoint_created: Option<EndpointCreated>,
|
||||
branch_created: Option<BranchCreated>,
|
||||
project_created: Option<ProjectCreated>,
|
||||
#[serde(rename = "type")]
|
||||
_type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, PartialEq)]
|
||||
struct EndpointCreated {
|
||||
endpoint_id: EndpointIdInt,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, PartialEq)]
|
||||
struct BranchCreated {
|
||||
branch_id: BranchIdInt,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, PartialEq)]
|
||||
struct ProjectCreated {
|
||||
project_id: ProjectIdInt,
|
||||
}
|
||||
|
||||
impl TryFrom<&Value> for ControlPlaneEvent {
|
||||
type Error = anyhow::Error;
|
||||
fn try_from(value: &Value) -> Result<Self, Self::Error> {
|
||||
let json = String::from_redis_value(value)?;
|
||||
Ok(serde_json::from_str(&json)?)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EndpointsCache {
|
||||
config: EndpointCacheConfig,
|
||||
endpoints: ClashSet<EndpointIdInt>,
|
||||
branches: ClashSet<BranchIdInt>,
|
||||
projects: ClashSet<ProjectIdInt>,
|
||||
ready: AtomicBool,
|
||||
limiter: Arc<Mutex<GlobalRateLimiter>>,
|
||||
}
|
||||
|
||||
impl EndpointsCache {
|
||||
pub(crate) fn new(config: EndpointCacheConfig) -> Self {
|
||||
Self {
|
||||
limiter: Arc::new(Mutex::new(GlobalRateLimiter::new(
|
||||
config.limiter_info.clone(),
|
||||
))),
|
||||
config,
|
||||
endpoints: ClashSet::new(),
|
||||
branches: ClashSet::new(),
|
||||
projects: ClashSet::new(),
|
||||
ready: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_valid(&self, ctx: &RequestContext, endpoint: &EndpointId) -> bool {
|
||||
if !self.ready.load(Ordering::Acquire) {
|
||||
// the endpoint cache is not yet fully initialised.
|
||||
return true;
|
||||
}
|
||||
|
||||
if !self.should_reject(endpoint) {
|
||||
ctx.set_rejected(false);
|
||||
return true;
|
||||
}
|
||||
|
||||
// report that we might want to reject this endpoint
|
||||
ctx.set_rejected(true);
|
||||
|
||||
// If cache is disabled, just collect the metrics and return.
|
||||
if self.config.disable_cache {
|
||||
return true;
|
||||
}
|
||||
|
||||
// If the limiter allows, we can pretend like it's valid
|
||||
// (incase it is, due to redis channel lag).
|
||||
if self.limiter.lock_propagate_poison().check() {
|
||||
return true;
|
||||
}
|
||||
|
||||
// endpoint not found, and there's too much load.
|
||||
false
|
||||
}
|
||||
|
||||
fn should_reject(&self, endpoint: &EndpointId) -> bool {
|
||||
if endpoint.is_endpoint() {
|
||||
let Some(endpoint) = EndpointIdInt::get(endpoint) else {
|
||||
// if we haven't interned this endpoint, it's not in the cache.
|
||||
return true;
|
||||
};
|
||||
!self.endpoints.contains(&endpoint)
|
||||
} else if endpoint.is_branch() {
|
||||
let Some(branch) = BranchIdInt::get(endpoint) else {
|
||||
// if we haven't interned this branch, it's not in the cache.
|
||||
return true;
|
||||
};
|
||||
!self.branches.contains(&branch)
|
||||
} else {
|
||||
let Some(project) = ProjectIdInt::get(endpoint) else {
|
||||
// if we haven't interned this project, it's not in the cache.
|
||||
return true;
|
||||
};
|
||||
!self.projects.contains(&project)
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_event(&self, event: ControlPlaneEvent) {
|
||||
if let Some(endpoint_created) = event.endpoint_created {
|
||||
self.endpoints.insert(endpoint_created.endpoint_id);
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::EndpointCreated);
|
||||
} else if let Some(branch_created) = event.branch_created {
|
||||
self.branches.insert(branch_created.branch_id);
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::BranchCreated);
|
||||
} else if let Some(project_created) = event.project_created {
|
||||
self.projects.insert(project_created.project_id);
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::ProjectCreated);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn do_read(
|
||||
&self,
|
||||
mut con: ConnectionWithCredentialsProvider,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> anyhow::Result<Infallible> {
|
||||
let mut last_id = "0-0".to_string();
|
||||
loop {
|
||||
if let Err(e) = con.connect().await {
|
||||
tracing::error!("error connecting to redis: {:?}", e);
|
||||
self.ready.store(false, Ordering::Release);
|
||||
}
|
||||
if let Err(e) = self.read_from_stream(&mut con, &mut last_id).await {
|
||||
tracing::error!("error reading from redis: {:?}", e);
|
||||
self.ready.store(false, Ordering::Release);
|
||||
}
|
||||
if cancellation_token.is_cancelled() {
|
||||
info!("cancellation token is cancelled, exiting");
|
||||
// Maintenance tasks run forever. Sleep forever when canceled.
|
||||
pending::<()>().await;
|
||||
}
|
||||
tokio::time::sleep(self.config.retry_interval).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_from_stream(
|
||||
&self,
|
||||
con: &mut ConnectionWithCredentialsProvider,
|
||||
last_id: &mut String,
|
||||
) -> anyhow::Result<()> {
|
||||
tracing::info!("reading endpoints/branches/projects from redis");
|
||||
self.batch_read(
|
||||
con,
|
||||
StreamReadOptions::default().count(self.config.initial_batch_size),
|
||||
last_id,
|
||||
true,
|
||||
)
|
||||
.await?;
|
||||
tracing::info!("ready to filter user requests");
|
||||
self.ready.store(true, Ordering::Release);
|
||||
self.batch_read(
|
||||
con,
|
||||
StreamReadOptions::default()
|
||||
.count(self.config.default_batch_size)
|
||||
.block(self.config.xread_timeout.as_millis() as usize),
|
||||
last_id,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn batch_read(
|
||||
&self,
|
||||
conn: &mut ConnectionWithCredentialsProvider,
|
||||
opts: StreamReadOptions,
|
||||
last_id: &mut String,
|
||||
return_when_finish: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut total: usize = 0;
|
||||
loop {
|
||||
let mut res: StreamReadReply = conn
|
||||
.xread_options(&[&self.config.stream_name], &[last_id.as_str()], &opts)
|
||||
.await?;
|
||||
|
||||
if res.keys.is_empty() {
|
||||
if return_when_finish {
|
||||
if total != 0 {
|
||||
break;
|
||||
}
|
||||
anyhow::bail!(
|
||||
"Redis stream {} is empty, cannot be used to filter endpoints",
|
||||
self.config.stream_name
|
||||
);
|
||||
}
|
||||
// If we are not returning when finish, we should wait for more data.
|
||||
continue;
|
||||
}
|
||||
if res.keys.len() != 1 {
|
||||
anyhow::bail!("Cannot read from redis stream {}", self.config.stream_name);
|
||||
}
|
||||
|
||||
let key = res.keys.pop().expect("Checked length above");
|
||||
let len = key.ids.len();
|
||||
for stream_id in key.ids {
|
||||
total += 1;
|
||||
for value in stream_id.map.values() {
|
||||
match value.try_into() {
|
||||
Ok(event) => self.insert_event(event),
|
||||
Err(err) => {
|
||||
Metrics::get().proxy.redis_errors_total.inc(RedisErrors {
|
||||
channel: &self.config.stream_name,
|
||||
});
|
||||
tracing::error!("error parsing value {value:?}: {err:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
if total.is_power_of_two() {
|
||||
tracing::debug!("endpoints read {}", total);
|
||||
}
|
||||
*last_id = stream_id.id;
|
||||
}
|
||||
if return_when_finish && len <= self.config.default_batch_size {
|
||||
break;
|
||||
}
|
||||
}
|
||||
tracing::info!("read {} endpoints/branches/projects from redis", total);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_control_plane_event() {
|
||||
let s = r#"{"branch_created":null,"endpoint_created":{"endpoint_id":"ep-rapid-thunder-w0qqw2q9"},"project_created":null,"type":"endpoint_created"}"#;
|
||||
|
||||
let endpoint_id: EndpointId = "ep-rapid-thunder-w0qqw2q9".into();
|
||||
|
||||
assert_eq!(
|
||||
serde_json::from_str::<ControlPlaneEvent>(s).unwrap(),
|
||||
ControlPlaneEvent {
|
||||
endpoint_created: Some(EndpointCreated {
|
||||
endpoint_id: endpoint_id.into(),
|
||||
}),
|
||||
branch_created: None,
|
||||
project_created: None,
|
||||
_type: Some("endpoint_created".into()),
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
1
proxy/src/cache/mod.rs
vendored
1
proxy/src/cache/mod.rs
vendored
@@ -1,5 +1,4 @@
|
||||
pub(crate) mod common;
|
||||
pub(crate) mod endpoints;
|
||||
pub(crate) mod project_info;
|
||||
mod timed_lru;
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ use crate::control_plane::locks::ApiLocks;
|
||||
use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
|
||||
use crate::ext::TaskExt;
|
||||
use crate::intern::RoleNameInt;
|
||||
use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig};
|
||||
use crate::rate_limiter::{RateLimitAlgorithm, RateLimiterConfig};
|
||||
use crate::scram::threadpool::ThreadPool;
|
||||
use crate::serverless::GlobalConnPoolOptions;
|
||||
use crate::serverless::cancel_set::CancelSet;
|
||||
@@ -80,79 +80,6 @@ pub struct AuthenticationConfig {
|
||||
pub console_redirect_confirmation_timeout: tokio::time::Duration,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EndpointCacheConfig {
|
||||
/// Batch size to receive all endpoints on the startup.
|
||||
pub initial_batch_size: usize,
|
||||
/// Batch size to receive endpoints.
|
||||
pub default_batch_size: usize,
|
||||
/// Timeouts for the stream read operation.
|
||||
pub xread_timeout: Duration,
|
||||
/// Stream name to read from.
|
||||
pub stream_name: String,
|
||||
/// Limiter info (to distinguish when to enable cache).
|
||||
pub limiter_info: Vec<RateBucketInfo>,
|
||||
/// Disable cache.
|
||||
/// If true, cache is ignored, but reports all statistics.
|
||||
pub disable_cache: bool,
|
||||
/// Retry interval for the stream read operation.
|
||||
pub retry_interval: Duration,
|
||||
}
|
||||
|
||||
impl EndpointCacheConfig {
|
||||
/// Default options for [`crate::control_plane::NodeInfoCache`].
|
||||
/// Notice that by default the limiter is empty, which means that cache is disabled.
|
||||
pub const CACHE_DEFAULT_OPTIONS: &'static str = "initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s,retry_interval=1s";
|
||||
|
||||
/// Parse cache options passed via cmdline.
|
||||
/// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
|
||||
fn parse(options: &str) -> anyhow::Result<Self> {
|
||||
let mut initial_batch_size = None;
|
||||
let mut default_batch_size = None;
|
||||
let mut xread_timeout = None;
|
||||
let mut stream_name = None;
|
||||
let mut limiter_info = vec![];
|
||||
let mut disable_cache = false;
|
||||
let mut retry_interval = None;
|
||||
|
||||
for option in options.split(',') {
|
||||
let (key, value) = option
|
||||
.split_once('=')
|
||||
.with_context(|| format!("bad key-value pair: {option}"))?;
|
||||
|
||||
match key {
|
||||
"initial_batch_size" => initial_batch_size = Some(value.parse()?),
|
||||
"default_batch_size" => default_batch_size = Some(value.parse()?),
|
||||
"xread_timeout" => xread_timeout = Some(humantime::parse_duration(value)?),
|
||||
"stream_name" => stream_name = Some(value.to_string()),
|
||||
"limiter_info" => limiter_info.push(RateBucketInfo::from_str(value)?),
|
||||
"disable_cache" => disable_cache = value.parse()?,
|
||||
"retry_interval" => retry_interval = Some(humantime::parse_duration(value)?),
|
||||
unknown => bail!("unknown key: {unknown}"),
|
||||
}
|
||||
}
|
||||
RateBucketInfo::validate(&mut limiter_info)?;
|
||||
|
||||
Ok(Self {
|
||||
initial_batch_size: initial_batch_size.context("missing `initial_batch_size`")?,
|
||||
default_batch_size: default_batch_size.context("missing `default_batch_size`")?,
|
||||
xread_timeout: xread_timeout.context("missing `xread_timeout`")?,
|
||||
stream_name: stream_name.context("missing `stream_name`")?,
|
||||
disable_cache,
|
||||
limiter_info,
|
||||
retry_interval: retry_interval.context("missing `retry_interval`")?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for EndpointCacheConfig {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(options: &str) -> Result<Self, Self::Err> {
|
||||
let error = || format!("failed to parse endpoint cache options '{options}'");
|
||||
Self::parse(options).with_context(error)
|
||||
}
|
||||
}
|
||||
#[derive(Debug)]
|
||||
pub struct MetricBackupCollectionConfig {
|
||||
pub remote_storage_config: Option<RemoteStorageConfig>,
|
||||
|
||||
@@ -7,7 +7,7 @@ use once_cell::sync::OnceCell;
|
||||
use smol_str::SmolStr;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::field::display;
|
||||
use tracing::{Span, debug, error, info_span};
|
||||
use tracing::{Span, error, info_span};
|
||||
use try_lock::TryLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -15,10 +15,7 @@ use self::parquet::RequestData;
|
||||
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
|
||||
use crate::error::ErrorKind;
|
||||
use crate::intern::{BranchIdInt, ProjectIdInt};
|
||||
use crate::metrics::{
|
||||
ConnectOutcome, InvalidEndpointsGroup, LatencyAccumulated, LatencyTimer, Metrics, Protocol,
|
||||
Waiting,
|
||||
};
|
||||
use crate::metrics::{LatencyAccumulated, LatencyTimer, Metrics, Protocol, Waiting};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::protocol2::{ConnectionInfo, ConnectionInfoExtra};
|
||||
use crate::types::{DbName, EndpointId, RoleName};
|
||||
@@ -70,8 +67,6 @@ struct RequestContextInner {
|
||||
// This sender is only used to log the length of session in case of success.
|
||||
disconnect_sender: Option<mpsc::UnboundedSender<RequestData>>,
|
||||
pub(crate) latency_timer: LatencyTimer,
|
||||
// Whether proxy decided that it's not a valid endpoint end rejected it before going to cplane.
|
||||
rejected: Option<bool>,
|
||||
disconnect_timestamp: Option<chrono::DateTime<Utc>>,
|
||||
}
|
||||
|
||||
@@ -106,7 +101,6 @@ impl Clone for RequestContext {
|
||||
auth_method: inner.auth_method.clone(),
|
||||
jwt_issuer: inner.jwt_issuer.clone(),
|
||||
success: inner.success,
|
||||
rejected: inner.rejected,
|
||||
cold_start_info: inner.cold_start_info,
|
||||
pg_options: inner.pg_options.clone(),
|
||||
testodrome_query_id: inner.testodrome_query_id.clone(),
|
||||
@@ -151,7 +145,6 @@ impl RequestContext {
|
||||
auth_method: None,
|
||||
jwt_issuer: None,
|
||||
success: false,
|
||||
rejected: None,
|
||||
cold_start_info: ColdStartInfo::Unknown,
|
||||
pg_options: None,
|
||||
testodrome_query_id: None,
|
||||
@@ -183,11 +176,6 @@ impl RequestContext {
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn set_rejected(&self, rejected: bool) {
|
||||
let mut this = self.0.try_lock().expect("should not deadlock");
|
||||
this.rejected = Some(rejected);
|
||||
}
|
||||
|
||||
pub(crate) fn set_cold_start_info(&self, info: ColdStartInfo) {
|
||||
self.0
|
||||
.try_lock()
|
||||
@@ -461,38 +449,6 @@ impl RequestContextInner {
|
||||
}
|
||||
|
||||
fn log_connect(&mut self) {
|
||||
let outcome = if self.success {
|
||||
ConnectOutcome::Success
|
||||
} else {
|
||||
ConnectOutcome::Failed
|
||||
};
|
||||
|
||||
// TODO: get rid of entirely/refactor
|
||||
// check for false positives
|
||||
// AND false negatives
|
||||
if let Some(rejected) = self.rejected {
|
||||
let ep = self
|
||||
.endpoint_id
|
||||
.as_ref()
|
||||
.map(|x| x.as_str())
|
||||
.unwrap_or_default();
|
||||
// This makes sense only if cache is disabled
|
||||
debug!(
|
||||
?outcome,
|
||||
?rejected,
|
||||
?ep,
|
||||
"check endpoint is valid with outcome"
|
||||
);
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.invalid_endpoints_total
|
||||
.inc(InvalidEndpointsGroup {
|
||||
protocol: self.protocol,
|
||||
rejected: rejected.into(),
|
||||
outcome,
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(tx) = self.sender.take() {
|
||||
// If type changes, this error handling needs to be updated.
|
||||
let tx: mpsc::UnboundedSender<RequestData> = tx;
|
||||
|
||||
@@ -159,13 +159,6 @@ impl NeonControlPlaneClient {
|
||||
ctx: &RequestContext,
|
||||
endpoint: &EndpointId,
|
||||
) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
|
||||
if !self
|
||||
.caches
|
||||
.endpoints_cache
|
||||
.is_valid(ctx, &endpoint.normalize())
|
||||
{
|
||||
return Err(GetEndpointJwksError::EndpointNotFound);
|
||||
}
|
||||
let request_id = ctx.session_id().to_string();
|
||||
async {
|
||||
let request = self
|
||||
@@ -300,11 +293,6 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
return Ok(secret);
|
||||
}
|
||||
|
||||
if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) {
|
||||
info!("endpoint is not valid, skipping the request");
|
||||
return Err(GetAuthInfoError::UnknownEndpoint);
|
||||
}
|
||||
|
||||
let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;
|
||||
|
||||
let control = EndpointAccessControl {
|
||||
@@ -346,11 +334,6 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
return Ok(control);
|
||||
}
|
||||
|
||||
if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) {
|
||||
info!("endpoint is not valid, skipping the request");
|
||||
return Err(GetAuthInfoError::UnknownEndpoint);
|
||||
}
|
||||
|
||||
let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;
|
||||
|
||||
let control = EndpointAccessControl {
|
||||
|
||||
@@ -13,9 +13,8 @@ use tracing::{debug, info};
|
||||
use super::{EndpointAccessControl, RoleAccessControl};
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
|
||||
use crate::cache::endpoints::EndpointsCache;
|
||||
use crate::cache::project_info::ProjectInfoCacheImpl;
|
||||
use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions};
|
||||
use crate::config::{CacheOptions, ProjectInfoCacheOptions};
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::{CachedNodeInfo, ControlPlaneApi, NodeInfoCache, errors};
|
||||
use crate::error::ReportableError;
|
||||
@@ -121,15 +120,12 @@ pub struct ApiCaches {
|
||||
pub(crate) node_info: NodeInfoCache,
|
||||
/// Cache which stores project_id -> endpoint_ids mapping.
|
||||
pub project_info: Arc<ProjectInfoCacheImpl>,
|
||||
/// List of all valid endpoints.
|
||||
pub endpoints_cache: Arc<EndpointsCache>,
|
||||
}
|
||||
|
||||
impl ApiCaches {
|
||||
pub fn new(
|
||||
wake_compute_cache_config: CacheOptions,
|
||||
project_info_cache_config: ProjectInfoCacheOptions,
|
||||
endpoint_cache_config: EndpointCacheConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
node_info: NodeInfoCache::new(
|
||||
@@ -139,7 +135,6 @@ impl ApiCaches {
|
||||
true,
|
||||
),
|
||||
project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
|
||||
endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,10 +99,6 @@ pub(crate) enum GetAuthInfoError {
|
||||
|
||||
#[error(transparent)]
|
||||
ApiError(ControlPlaneError),
|
||||
|
||||
/// Proxy does not know about the endpoint in advanced
|
||||
#[error("endpoint not found in endpoint cache")]
|
||||
UnknownEndpoint,
|
||||
}
|
||||
|
||||
// This allows more useful interactions than `#[from]`.
|
||||
@@ -119,8 +115,6 @@ impl UserFacingError for GetAuthInfoError {
|
||||
Self::BadSecret => REQUEST_FAILED.to_owned(),
|
||||
// However, API might return a meaningful error.
|
||||
Self::ApiError(e) => e.to_string_client(),
|
||||
// pretend like control plane returned an error.
|
||||
Self::UnknownEndpoint => REQUEST_FAILED.to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -130,8 +124,6 @@ impl ReportableError for GetAuthInfoError {
|
||||
match self {
|
||||
Self::BadSecret => crate::error::ErrorKind::ControlPlane,
|
||||
Self::ApiError(_) => crate::error::ErrorKind::ControlPlane,
|
||||
// we only apply endpoint filtering if control plane is under high load.
|
||||
Self::UnknownEndpoint => crate::error::ErrorKind::ServiceRateLimit,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -200,9 +192,6 @@ impl CouldRetry for WakeComputeError {
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum GetEndpointJwksError {
|
||||
#[error("endpoint not found")]
|
||||
EndpointNotFound,
|
||||
|
||||
#[error("failed to build control plane request: {0}")]
|
||||
RequestBuild(#[source] reqwest::Error),
|
||||
|
||||
|
||||
@@ -16,44 +16,6 @@ use super::LeakyBucketConfig;
|
||||
use crate::ext::LockExt;
|
||||
use crate::intern::EndpointIdInt;
|
||||
|
||||
pub struct GlobalRateLimiter {
|
||||
data: Vec<RateBucket>,
|
||||
info: Vec<RateBucketInfo>,
|
||||
}
|
||||
|
||||
impl GlobalRateLimiter {
|
||||
pub fn new(info: Vec<RateBucketInfo>) -> Self {
|
||||
Self {
|
||||
data: vec![
|
||||
RateBucket {
|
||||
start: Instant::now(),
|
||||
count: 0,
|
||||
};
|
||||
info.len()
|
||||
],
|
||||
info,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check that number of connections is below `max_rps` rps.
|
||||
pub fn check(&mut self) -> bool {
|
||||
let now = Instant::now();
|
||||
|
||||
let should_allow_request = self
|
||||
.data
|
||||
.iter_mut()
|
||||
.zip(&self.info)
|
||||
.all(|(bucket, info)| bucket.should_allow_request(info, now, 1));
|
||||
|
||||
if should_allow_request {
|
||||
// only increment the bucket counts if the request will actually be accepted
|
||||
self.data.iter_mut().for_each(|b| b.inc(1));
|
||||
}
|
||||
|
||||
should_allow_request
|
||||
}
|
||||
}
|
||||
|
||||
// Simple per-endpoint rate limiter.
|
||||
//
|
||||
// Check that number of connections to the endpoint is below `max_rps` rps.
|
||||
|
||||
@@ -8,4 +8,4 @@ pub(crate) use limit_algorithm::aimd::Aimd;
|
||||
pub(crate) use limit_algorithm::{
|
||||
DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token,
|
||||
};
|
||||
pub use limiter::{GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
|
||||
pub use limiter::{RateBucketInfo, WakeComputeRateLimiter};
|
||||
|
||||
@@ -1,112 +1,60 @@
|
||||
use postgres_client::Row;
|
||||
use postgres_client::types::{Kind, Type};
|
||||
use serde::Deserialize;
|
||||
use serde::de::{Deserializer, IgnoredAny, Visitor};
|
||||
use serde_json::value::RawValue;
|
||||
use serde_json::{Map, Value};
|
||||
|
||||
//
|
||||
// Convert json non-string types to strings, so that they can be passed to Postgres
|
||||
// as parameters.
|
||||
//
|
||||
pub(crate) fn json_to_pg_text(json: Vec<Box<RawValue>>) -> Vec<Option<String>> {
|
||||
json.into_iter()
|
||||
.map(|raw| {
|
||||
match raw.get().as_bytes() {
|
||||
// special handling for null.
|
||||
b"null" => None,
|
||||
// remove the escape characters from the string.
|
||||
[b'"', ..] => {
|
||||
Some(String::deserialize(&*raw).expect("json should be a valid string"))
|
||||
}
|
||||
[b'[', ..] => {
|
||||
let mut output = String::with_capacity(raw.get().len());
|
||||
raw.deserialize_seq(PgArrayVisitor(&raw, &mut output))
|
||||
.expect("json should be a valid");
|
||||
Some(output)
|
||||
}
|
||||
// write all other values out directly
|
||||
_ => Some(<Box<str>>::from(raw).into()),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
pub(crate) fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
|
||||
json.iter().map(json_value_to_pg_text).collect()
|
||||
}
|
||||
|
||||
struct PgArrayVisitor<'de, 'a>(&'de RawValue, &'a mut String);
|
||||
fn json_value_to_pg_text(value: &Value) -> Option<String> {
|
||||
match value {
|
||||
// special care for nulls
|
||||
Value::Null => None,
|
||||
|
||||
impl PgArrayVisitor<'_, '_> {
|
||||
#[inline]
|
||||
#[allow(clippy::unnecessary_wraps)]
|
||||
fn raw<E>(self) -> Result<(), E> {
|
||||
self.1.push_str(self.0.get());
|
||||
Ok(())
|
||||
// convert to text with escaping
|
||||
v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),
|
||||
|
||||
// avoid escaping here, as we pass this as a parameter
|
||||
Value::String(s) => Some(s.clone()),
|
||||
|
||||
// special care for arrays
|
||||
Value::Array(_) => json_array_to_pg_array(value),
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Visitor<'de> for PgArrayVisitor<'de, '_> {
|
||||
type Value = ();
|
||||
//
|
||||
// Serialize a JSON array to a Postgres array. Contrary to the strings in the params
|
||||
// in the array we need to escape the strings. Postgres is okay with arrays of form
|
||||
// '{1,"2",3}'::int[], so we don't check that array holds values of the same type, leaving
|
||||
// it for Postgres to check.
|
||||
//
|
||||
// Example of the same escaping in node-postgres: packages/pg/lib/utils.js
|
||||
//
|
||||
fn json_array_to_pg_array(value: &Value) -> Option<String> {
|
||||
match value {
|
||||
// special care for nulls
|
||||
Value::Null => None,
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("any valid JSON value")
|
||||
}
|
||||
// convert to text with escaping
|
||||
// here string needs to be escaped, as it is part of the array
|
||||
v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()),
|
||||
v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())),
|
||||
|
||||
// special care for nulls
|
||||
fn visit_none<E>(self) -> Result<Self::Value, E> {
|
||||
self.1.push_str("NULL");
|
||||
Ok(())
|
||||
}
|
||||
fn visit_unit<E>(self) -> Result<Self::Value, E> {
|
||||
self.1.push_str("NULL");
|
||||
Ok(())
|
||||
}
|
||||
// recurse into array
|
||||
Value::Array(arr) => {
|
||||
let vals = arr
|
||||
.iter()
|
||||
.map(json_array_to_pg_array)
|
||||
.map(|v| v.unwrap_or_else(|| "NULL".to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
.join(",");
|
||||
|
||||
// convert to text with escaping
|
||||
fn visit_bool<E>(self, _: bool) -> Result<Self::Value, E> {
|
||||
self.raw()
|
||||
}
|
||||
fn visit_i64<E>(self, _: i64) -> Result<Self::Value, E> {
|
||||
self.raw()
|
||||
}
|
||||
fn visit_u64<E>(self, _: u64) -> Result<Self::Value, E> {
|
||||
self.raw()
|
||||
}
|
||||
fn visit_i128<E>(self, _: i128) -> Result<Self::Value, E> {
|
||||
self.raw()
|
||||
}
|
||||
fn visit_u128<E>(self, _: u128) -> Result<Self::Value, E> {
|
||||
self.raw()
|
||||
}
|
||||
fn visit_f64<E>(self, _: f64) -> Result<Self::Value, E> {
|
||||
self.raw()
|
||||
}
|
||||
fn visit_str<E>(self, _: &str) -> Result<Self::Value, E> {
|
||||
self.raw()
|
||||
}
|
||||
|
||||
// an object needs re-escaping
|
||||
fn visit_map<A: serde::de::MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> {
|
||||
while map.next_entry::<IgnoredAny, IgnoredAny>()?.is_some() {}
|
||||
|
||||
let s = serde_json::to_string(self.0.get()).expect("a string should be valid json");
|
||||
self.1.push_str(&s);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// write an array
|
||||
fn visit_seq<A: serde::de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
|
||||
self.1.push('{');
|
||||
let mut comma = false;
|
||||
while let Some(val) = seq.next_element::<&'de RawValue>()? {
|
||||
if comma {
|
||||
self.1.push(',');
|
||||
}
|
||||
comma = true;
|
||||
|
||||
val.deserialize_any(PgArrayVisitor(val, self.1))
|
||||
.expect("all json values are valid");
|
||||
Some(format!("{{{vals}}}"))
|
||||
}
|
||||
self.1.push('}');
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -436,14 +384,6 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
fn json_to_pg_text(json: Vec<serde_json::Value>) -> Vec<Option<String>> {
|
||||
let json = json
|
||||
.into_iter()
|
||||
.map(|value| serde_json::from_str(&value.to_string()).unwrap())
|
||||
.collect();
|
||||
super::json_to_pg_text(json)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_atomic_types_to_pg_params() {
|
||||
let json = vec![Value::Bool(true), Value::Bool(false)];
|
||||
|
||||
@@ -18,6 +18,7 @@ use postgres_client::{
|
||||
GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction,
|
||||
};
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::value::RawValue;
|
||||
use tokio::time::{self, Instant};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -47,8 +48,9 @@ use crate::util::run_until_cancelled;
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct QueryData {
|
||||
query: String,
|
||||
#[serde(deserialize_with = "bytes_to_pg_text")]
|
||||
#[serde(default)]
|
||||
params: Vec<Box<RawValue>>,
|
||||
params: Vec<Option<String>>,
|
||||
#[serde(default)]
|
||||
array_mode: Option<bool>,
|
||||
}
|
||||
@@ -58,6 +60,8 @@ struct BatchQueryData {
|
||||
queries: Vec<QueryData>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum Payload {
|
||||
Single(QueryData),
|
||||
Batch(BatchQueryData),
|
||||
@@ -65,6 +69,15 @@ enum Payload {
|
||||
|
||||
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
|
||||
|
||||
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
|
||||
where
|
||||
D: serde::de::Deserializer<'de>,
|
||||
{
|
||||
// TODO: consider avoiding the allocation here.
|
||||
let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
|
||||
Ok(json_to_pg_text(json))
|
||||
}
|
||||
|
||||
pub(crate) async fn handle(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: RequestContext,
|
||||
@@ -486,14 +499,7 @@ async fn handle_db_inner(
|
||||
.observe(HttpDirection::Request, body.len() as f64);
|
||||
|
||||
debug!(length = body.len(), "request payload read");
|
||||
|
||||
// try unbatched, then try batched.
|
||||
let payload = if let Ok(batch) = serde_json::from_slice(&body) {
|
||||
Payload::Batch(batch)
|
||||
} else {
|
||||
Payload::Single(serde_json::from_slice(&body)?)
|
||||
};
|
||||
|
||||
let payload: Payload = serde_json::from_slice(&body)?;
|
||||
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
|
||||
}
|
||||
.map_err(SqlOverHttpError::from),
|
||||
@@ -881,8 +887,7 @@ async fn query_to_json<T: GenericClient>(
|
||||
) -> Result<(ReadyForQueryStatus, impl Serialize + use<T>), SqlOverHttpError> {
|
||||
let query_start = Instant::now();
|
||||
|
||||
let query_params = json_to_pg_text(data.params);
|
||||
|
||||
let query_params = data.params;
|
||||
let mut row_stream = client
|
||||
.query_raw_txt(&data.query, query_params)
|
||||
.await
|
||||
@@ -1028,38 +1033,55 @@ mod tests {
|
||||
#[test]
|
||||
fn test_payload() {
|
||||
let payload = "{\"query\":\"SELECT * FROM users WHERE name = ?\",\"params\":[\"test\"],\"arrayMode\":true}";
|
||||
let QueryData {
|
||||
query,
|
||||
params,
|
||||
array_mode,
|
||||
} = serde_json::from_str(payload).unwrap();
|
||||
let deserialized_payload: Payload = serde_json::from_str(payload).unwrap();
|
||||
|
||||
assert_eq!(query, "SELECT * FROM users WHERE name = ?");
|
||||
assert_eq!(params[0].get(), "\"test\"");
|
||||
assert!(array_mode.unwrap());
|
||||
match deserialized_payload {
|
||||
Payload::Single(QueryData {
|
||||
query,
|
||||
params,
|
||||
array_mode,
|
||||
}) => {
|
||||
assert_eq!(query, "SELECT * FROM users WHERE name = ?");
|
||||
assert_eq!(params, vec![Some(String::from("test"))]);
|
||||
assert!(array_mode.unwrap());
|
||||
}
|
||||
Payload::Batch(_) => {
|
||||
panic!("deserialization failed: case with single query, one param, and array mode")
|
||||
}
|
||||
}
|
||||
|
||||
let payload = "{\"queries\":[{\"query\":\"SELECT * FROM users0 WHERE name = ?\",\"params\":[\"test0\"], \"arrayMode\":false},{\"query\":\"SELECT * FROM users1 WHERE name = ?\",\"params\":[\"test1\"],\"arrayMode\":true}]}";
|
||||
let BatchQueryData { queries } = serde_json::from_str(payload).unwrap();
|
||||
let deserialized_payload: Payload = serde_json::from_str(payload).unwrap();
|
||||
|
||||
assert_eq!(queries.len(), 2);
|
||||
for (i, query) in queries.into_iter().enumerate() {
|
||||
assert_eq!(
|
||||
query.query,
|
||||
format!("SELECT * FROM users{i} WHERE name = ?")
|
||||
);
|
||||
assert_eq!(query.params[0].get(), &format!("\"test{i}\""));
|
||||
assert_eq!(query.array_mode.unwrap(), i > 0);
|
||||
match deserialized_payload {
|
||||
Payload::Batch(BatchQueryData { queries }) => {
|
||||
assert_eq!(queries.len(), 2);
|
||||
for (i, query) in queries.into_iter().enumerate() {
|
||||
assert_eq!(
|
||||
query.query,
|
||||
format!("SELECT * FROM users{i} WHERE name = ?")
|
||||
);
|
||||
assert_eq!(query.params, vec![Some(format!("test{i}"))]);
|
||||
assert_eq!(query.array_mode.unwrap(), i > 0);
|
||||
}
|
||||
}
|
||||
Payload::Single(_) => panic!("deserialization failed: case with multiple queries"),
|
||||
}
|
||||
|
||||
let payload = "{\"query\":\"SELECT 1\"}";
|
||||
let QueryData {
|
||||
query,
|
||||
params,
|
||||
array_mode,
|
||||
} = serde_json::from_str(payload).unwrap();
|
||||
let deserialized_payload: Payload = serde_json::from_str(payload).unwrap();
|
||||
|
||||
assert_eq!(query, "SELECT 1");
|
||||
assert!(params.is_empty());
|
||||
assert!(array_mode.is_none());
|
||||
match deserialized_payload {
|
||||
Payload::Single(QueryData {
|
||||
query,
|
||||
params,
|
||||
array_mode,
|
||||
}) => {
|
||||
assert_eq!(query, "SELECT 1");
|
||||
assert_eq!(params, vec![]);
|
||||
assert!(array_mode.is_none());
|
||||
}
|
||||
Payload::Batch(_) => panic!("deserialization failed: case with only one query"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,13 +107,3 @@ smol_str_wrapper!(DbName);
|
||||
|
||||
// postgres hostname, will likely be a port:ip addr
|
||||
smol_str_wrapper!(Host);
|
||||
|
||||
// Endpoints are a bit tricky. Rare they might be branches or projects.
|
||||
impl EndpointId {
|
||||
pub(crate) fn is_endpoint(&self) -> bool {
|
||||
self.0.starts_with("ep-")
|
||||
}
|
||||
pub(crate) fn is_branch(&self) -> bool {
|
||||
self.0.starts_with("br-")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1066,9 +1066,10 @@ async fn handle_node_delete(req: Request<Body>) -> Result<Response<Body>, ApiErr
|
||||
|
||||
let state = get_state(&req);
|
||||
let node_id: NodeId = parse_request_param(&req, "node_id")?;
|
||||
let force: bool = parse_query_param(&req, "force")?.unwrap_or(false);
|
||||
json_response(
|
||||
StatusCode::OK,
|
||||
state.service.start_node_delete(node_id).await?,
|
||||
state.service.start_node_delete(node_id, force).await?,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -7165,6 +7165,7 @@ impl Service {
|
||||
self: &Arc<Self>,
|
||||
node_id: NodeId,
|
||||
policy_on_start: NodeSchedulingPolicy,
|
||||
force: bool,
|
||||
cancel: CancellationToken,
|
||||
) -> Result<(), OperationError> {
|
||||
let reconciler_config = ReconcilerConfigBuilder::new(ReconcilerPriority::Normal).build();
|
||||
@@ -7172,23 +7173,28 @@ impl Service {
|
||||
let mut waiters: Vec<ReconcilerWaiter> = Vec::new();
|
||||
let mut tid_iter = create_shared_shard_iterator(self.clone());
|
||||
|
||||
let process_cancel = || async {
|
||||
// Attempt to restore the node to its original scheduling policy
|
||||
match self
|
||||
.node_configure(node_id, None, Some(policy_on_start))
|
||||
.await
|
||||
{
|
||||
Ok(()) => Err(OperationError::Cancelled),
|
||||
Err(err) => {
|
||||
Err(OperationError::FinalizeError(
|
||||
format!(
|
||||
"Failed to finalise delete cancel of {} by setting scheduling policy to {}: {}",
|
||||
node_id, String::from(policy_on_start), err
|
||||
)
|
||||
.into(),
|
||||
))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
while !tid_iter.finished() {
|
||||
if cancel.is_cancelled() {
|
||||
match self
|
||||
.node_configure(node_id, None, Some(policy_on_start))
|
||||
.await
|
||||
{
|
||||
Ok(()) => return Err(OperationError::Cancelled),
|
||||
Err(err) => {
|
||||
return Err(OperationError::FinalizeError(
|
||||
format!(
|
||||
"Failed to finalise delete cancel of {} by setting scheduling policy to {}: {}",
|
||||
node_id, String::from(policy_on_start), err
|
||||
)
|
||||
.into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
return process_cancel().await;
|
||||
}
|
||||
|
||||
operation_utils::validate_node_state(
|
||||
@@ -7249,13 +7255,24 @@ impl Service {
|
||||
)
|
||||
}
|
||||
|
||||
// Do not wait for any reconciliations to finish if the deletion has been forced.
|
||||
let waiter = self.maybe_configured_reconcile_shard(
|
||||
tenant_shard,
|
||||
nodes,
|
||||
reconciler_config,
|
||||
);
|
||||
if let Some(some) = waiter {
|
||||
waiters.push(some);
|
||||
|
||||
if force {
|
||||
// Here we remove an existing observed location for the node we're removing, and it will
|
||||
// not be re-added by a reconciler's completion because we filter out removed nodes in
|
||||
// process_result.
|
||||
//
|
||||
// Note that we update the shard's observed state _after_ calling maybe_configured_reconcile_shard:
|
||||
// that means any reconciles we spawned will know about the node we're deleting,
|
||||
// enabling them to do live migrations if it's still online.
|
||||
tenant_shard.observed.locations.remove(&node_id);
|
||||
} else if let Some(waiter) = waiter {
|
||||
waiters.push(waiter);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7269,21 +7286,7 @@ impl Service {
|
||||
|
||||
while !waiters.is_empty() {
|
||||
if cancel.is_cancelled() {
|
||||
match self
|
||||
.node_configure(node_id, None, Some(policy_on_start))
|
||||
.await
|
||||
{
|
||||
Ok(()) => return Err(OperationError::Cancelled),
|
||||
Err(err) => {
|
||||
return Err(OperationError::FinalizeError(
|
||||
format!(
|
||||
"Failed to finalise drain cancel of {} by setting scheduling policy to {}: {}",
|
||||
node_id, String::from(policy_on_start), err
|
||||
)
|
||||
.into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
return process_cancel().await;
|
||||
}
|
||||
|
||||
tracing::info!("Awaiting {} pending delete reconciliations", waiters.len());
|
||||
@@ -7888,6 +7891,7 @@ impl Service {
|
||||
pub(crate) async fn start_node_delete(
|
||||
self: &Arc<Self>,
|
||||
node_id: NodeId,
|
||||
force: bool,
|
||||
) -> Result<(), ApiError> {
|
||||
let (ongoing_op, node_policy, schedulable_nodes_count) = {
|
||||
let locked = self.inner.read().unwrap();
|
||||
@@ -7957,7 +7961,7 @@ impl Service {
|
||||
|
||||
tracing::info!("Delete background operation starting");
|
||||
let res = service
|
||||
.delete_node(node_id, policy_on_start, cancel)
|
||||
.delete_node(node_id, policy_on_start, force, cancel)
|
||||
.await;
|
||||
match res {
|
||||
Ok(()) => {
|
||||
|
||||
@@ -2084,11 +2084,14 @@ class NeonStorageController(MetricsGetter, LogUtils):
|
||||
headers=self.headers(TokenScope.ADMIN),
|
||||
)
|
||||
|
||||
def node_delete(self, node_id):
|
||||
def node_delete(self, node_id, force: bool = False):
|
||||
log.info(f"node_delete({node_id})")
|
||||
query = f"{self.api}/control/v1/node/{node_id}/delete"
|
||||
if force:
|
||||
query += "?force=true"
|
||||
self.request(
|
||||
"PUT",
|
||||
f"{self.api}/control/v1/node/{node_id}/delete",
|
||||
query,
|
||||
headers=self.headers(TokenScope.ADMIN),
|
||||
)
|
||||
|
||||
|
||||
@@ -1034,16 +1034,19 @@ def test_storage_controller_compute_hook_keep_failing(
|
||||
alive_pageservers = [p for p in env.pageservers if p.id != banned_tenant_ps.id]
|
||||
|
||||
# Stop pageserver and ban tenant to trigger failed reconciliation
|
||||
log.info(f"Banning tenant {banned_tenant} and stopping pageserver {banned_tenant_ps.id}")
|
||||
status_by_tenant[banned_tenant] = 423
|
||||
banned_tenant_ps.stop()
|
||||
env.storage_controller.allowed_errors.append(NOTIFY_BLOCKED_LOG)
|
||||
env.storage_controller.allowed_errors.extend(NOTIFY_FAILURE_LOGS)
|
||||
env.storage_controller.allowed_errors.append(".*Keeping extra secondaries.*")
|
||||
env.storage_controller.allowed_errors.append(".*Shard reconciliation is keep-failing.*")
|
||||
env.storage_controller.node_configure(banned_tenant_ps.id, {"availability": "Offline"})
|
||||
|
||||
# Migrate all allowed tenant shards to the first alive pageserver
|
||||
# to trigger storage controller optimizations due to affinity rules
|
||||
for shard_number in range(shard_count):
|
||||
log.info(f"Migrating shard {shard_number} of {allowed_tenant} to {alive_pageservers[0].id}")
|
||||
env.storage_controller.tenant_shard_migrate(
|
||||
TenantShardId(allowed_tenant, shard_number, shard_count),
|
||||
alive_pageservers[0].id,
|
||||
@@ -2618,7 +2621,7 @@ def test_storage_controller_node_deletion(
|
||||
wait_until(assert_shards_migrated)
|
||||
|
||||
log.info(f"Deleting pageserver {victim.id}")
|
||||
env.storage_controller.node_delete_old(victim.id)
|
||||
env.storage_controller.node_delete(victim.id, force=True)
|
||||
|
||||
if not while_offline:
|
||||
|
||||
@@ -2631,7 +2634,10 @@ def test_storage_controller_node_deletion(
|
||||
wait_until(assert_victim_evacuated)
|
||||
|
||||
# The node should be gone from the list API
|
||||
assert victim.id not in [n["id"] for n in env.storage_controller.node_list()]
|
||||
def assert_victim_gone():
|
||||
assert victim.id not in [n["id"] for n in env.storage_controller.node_list()]
|
||||
|
||||
wait_until(assert_victim_gone)
|
||||
|
||||
# No tenants should refer to the node in their intent
|
||||
for tenant_id in tenant_ids:
|
||||
@@ -3262,10 +3268,10 @@ def test_ps_unavailable_after_delete(neon_env_builder: NeonEnvBuilder):
|
||||
assert_nodes_count(3)
|
||||
|
||||
ps = env.pageservers[0]
|
||||
env.storage_controller.node_delete_old(ps.id)
|
||||
env.storage_controller.node_delete(ps.id, force=True)
|
||||
|
||||
# After deletion, the node count must be reduced
|
||||
assert_nodes_count(2)
|
||||
wait_until(lambda: assert_nodes_count(2))
|
||||
|
||||
# Running pageserver CLI init in a separate thread
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
||||
|
||||
Reference in New Issue
Block a user