mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-27 10:00:38 +00:00
Merge pull request #8505 from neondatabase/rc/proxy/2024-07-25
Proxy release 2024-07-25
This commit is contained in:
@@ -717,8 +717,10 @@ mod tests {
|
||||
_ => panic!("wrong message"),
|
||||
}
|
||||
});
|
||||
let endpoint_rate_limiter =
|
||||
Arc::new(EndpointRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET));
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
|
||||
EndpointRateLimiter::DEFAULT,
|
||||
64,
|
||||
));
|
||||
|
||||
let _creds = auth_quirks(
|
||||
&mut ctx,
|
||||
@@ -767,8 +769,10 @@ mod tests {
|
||||
frontend::password_message(b"my-secret-password", &mut write).unwrap();
|
||||
client.write_all(&write).await.unwrap();
|
||||
});
|
||||
let endpoint_rate_limiter =
|
||||
Arc::new(EndpointRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET));
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
|
||||
EndpointRateLimiter::DEFAULT,
|
||||
64,
|
||||
));
|
||||
|
||||
let _creds = auth_quirks(
|
||||
&mut ctx,
|
||||
@@ -818,8 +822,10 @@ mod tests {
|
||||
client.write_all(&write).await.unwrap();
|
||||
});
|
||||
|
||||
let endpoint_rate_limiter =
|
||||
Arc::new(EndpointRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET));
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
|
||||
EndpointRateLimiter::DEFAULT,
|
||||
64,
|
||||
));
|
||||
|
||||
let creds = auth_quirks(
|
||||
&mut ctx,
|
||||
|
||||
@@ -22,7 +22,9 @@ use proxy::http;
|
||||
use proxy::http::health_server::AppMetrics;
|
||||
use proxy::metrics::Metrics;
|
||||
use proxy::rate_limiter::EndpointRateLimiter;
|
||||
use proxy::rate_limiter::LeakyBucketConfig;
|
||||
use proxy::rate_limiter::RateBucketInfo;
|
||||
use proxy::rate_limiter::WakeComputeRateLimiter;
|
||||
use proxy::redis::cancellation_publisher::RedisPublisherClient;
|
||||
use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
|
||||
use proxy::redis::elasticache;
|
||||
@@ -176,6 +178,9 @@ struct ProxyCliArgs {
|
||||
/// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections)
|
||||
#[clap(long)]
|
||||
redis_notifications: Option<String>,
|
||||
/// what from the available authentications type to use for the regional redis we have. Supported are "irsa" and "plain".
|
||||
#[clap(long, default_value = "irsa")]
|
||||
redis_auth_type: String,
|
||||
/// redis host for streaming connections (might be different from the notifications host)
|
||||
#[clap(long)]
|
||||
redis_host: Option<String>,
|
||||
@@ -319,24 +324,38 @@ async fn main() -> anyhow::Result<()> {
|
||||
),
|
||||
aws_credentials_provider,
|
||||
));
|
||||
let regional_redis_client = match (args.redis_host, args.redis_port) {
|
||||
(Some(host), Some(port)) => Some(
|
||||
ConnectionWithCredentialsProvider::new_with_credentials_provider(
|
||||
host,
|
||||
port,
|
||||
elasticache_credentials_provider.clone(),
|
||||
let regional_redis_client = match (args.redis_auth_type.as_str(), &args.redis_notifications) {
|
||||
("plain", redis_url) => match redis_url {
|
||||
None => {
|
||||
bail!("plain auth requires redis_notifications to be set");
|
||||
}
|
||||
Some(url) => Some(
|
||||
ConnectionWithCredentialsProvider::new_with_static_credentials(url.to_string()),
|
||||
),
|
||||
),
|
||||
(None, None) => {
|
||||
warn!("Redis events from console are disabled");
|
||||
None
|
||||
}
|
||||
},
|
||||
("irsa", _) => match (&args.redis_host, args.redis_port) {
|
||||
(Some(host), Some(port)) => Some(
|
||||
ConnectionWithCredentialsProvider::new_with_credentials_provider(
|
||||
host.to_string(),
|
||||
port,
|
||||
elasticache_credentials_provider.clone(),
|
||||
),
|
||||
),
|
||||
(None, None) => {
|
||||
warn!("irsa auth requires redis-host and redis-port to be set, continuing without regional_redis_client");
|
||||
None
|
||||
}
|
||||
_ => {
|
||||
bail!("redis-host and redis-port must be specified together");
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
bail!("redis-host and redis-port must be specified together");
|
||||
bail!("unknown auth type given");
|
||||
}
|
||||
};
|
||||
|
||||
let redis_notifications_client = if let Some(url) = args.redis_notifications {
|
||||
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url))
|
||||
Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.to_string()))
|
||||
} else {
|
||||
regional_redis_client.clone()
|
||||
};
|
||||
@@ -373,9 +392,24 @@ async fn main() -> anyhow::Result<()> {
|
||||
proxy::metrics::CancellationSource::FromClient,
|
||||
));
|
||||
|
||||
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
|
||||
RateBucketInfo::validate(&mut endpoint_rps_limit)?;
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(endpoint_rps_limit));
|
||||
// bit of a hack - find the min rps and max rps supported and turn it into
|
||||
// leaky bucket config instead
|
||||
let max = args
|
||||
.endpoint_rps_limit
|
||||
.iter()
|
||||
.map(|x| x.rps())
|
||||
.max_by(f64::total_cmp)
|
||||
.unwrap_or(EndpointRateLimiter::DEFAULT.max);
|
||||
let rps = args
|
||||
.endpoint_rps_limit
|
||||
.iter()
|
||||
.map(|x| x.rps())
|
||||
.min_by(f64::total_cmp)
|
||||
.unwrap_or(EndpointRateLimiter::DEFAULT.rps);
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
|
||||
LeakyBucketConfig { rps, max },
|
||||
64,
|
||||
));
|
||||
|
||||
// client facing tasks. these will exit on error or on cancellation
|
||||
// cancellation returns Ok(())
|
||||
@@ -577,7 +611,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
let mut wake_compute_rps_limit = args.wake_compute_limit.clone();
|
||||
RateBucketInfo::validate(&mut wake_compute_rps_limit)?;
|
||||
let wake_compute_endpoint_rate_limiter =
|
||||
Arc::new(EndpointRateLimiter::new(wake_compute_rps_limit));
|
||||
Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
|
||||
let api = console::provider::neon::Api::new(
|
||||
endpoint,
|
||||
caches,
|
||||
|
||||
@@ -307,7 +307,7 @@ impl ConnCfg {
|
||||
let pause = ctx.latency_timer.pause(crate::metrics::Waiting::Compute);
|
||||
let (client, connection) = self.0.connect_raw(stream, tls).await?;
|
||||
drop(pause);
|
||||
tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
|
||||
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
|
||||
let stream = connection.stream.into_inner();
|
||||
|
||||
info!(
|
||||
|
||||
@@ -12,7 +12,7 @@ use crate::{
|
||||
console::messages::{ColdStartInfo, Reason},
|
||||
http,
|
||||
metrics::{CacheOutcome, Metrics},
|
||||
rate_limiter::EndpointRateLimiter,
|
||||
rate_limiter::WakeComputeRateLimiter,
|
||||
scram, EndpointCacheKey,
|
||||
};
|
||||
use crate::{cache::Cached, context::RequestMonitoring};
|
||||
@@ -26,7 +26,7 @@ pub struct Api {
|
||||
endpoint: http::Endpoint,
|
||||
pub caches: &'static ApiCaches,
|
||||
pub locks: &'static ApiLocks<EndpointCacheKey>,
|
||||
pub wake_compute_endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
pub wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
|
||||
jwt: String,
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ impl Api {
|
||||
endpoint: http::Endpoint,
|
||||
caches: &'static ApiCaches,
|
||||
locks: &'static ApiLocks<EndpointCacheKey>,
|
||||
wake_compute_endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
|
||||
) -> Self {
|
||||
let jwt: String = match std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN") {
|
||||
Ok(v) => v,
|
||||
|
||||
@@ -181,8 +181,9 @@ pub async fn worker(
|
||||
let rx = futures::stream::poll_fn(move |cx| rx.poll_recv(cx));
|
||||
let rx = rx.map(RequestData::from);
|
||||
|
||||
let storage =
|
||||
GenericRemoteStorage::from_config(&remote_storage_config).context("remote storage init")?;
|
||||
let storage = GenericRemoteStorage::from_config(&remote_storage_config)
|
||||
.await
|
||||
.context("remote storage init")?;
|
||||
|
||||
let properties = WriterProperties::builder()
|
||||
.set_data_page_size_limit(config.parquet_upload_page_size)
|
||||
@@ -217,6 +218,7 @@ pub async fn worker(
|
||||
|
||||
let storage_disconnect =
|
||||
GenericRemoteStorage::from_config(&disconnect_events_storage_config)
|
||||
.await
|
||||
.context("remote storage for disconnect events init")?;
|
||||
let parquet_config_disconnect = parquet_config.clone();
|
||||
tokio::try_join!(
|
||||
@@ -545,7 +547,9 @@ mod tests {
|
||||
},
|
||||
timeout: std::time::Duration::from_secs(120),
|
||||
};
|
||||
let storage = GenericRemoteStorage::from_config(&remote_storage_config).unwrap();
|
||||
let storage = GenericRemoteStorage::from_config(&remote_storage_config)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
worker_inner(storage, rx, config).await.unwrap();
|
||||
|
||||
|
||||
@@ -3,4 +3,8 @@ mod limiter;
|
||||
pub use limit_algorithm::{
|
||||
aimd::Aimd, DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token,
|
||||
};
|
||||
pub use limiter::{BucketRateLimiter, EndpointRateLimiter, GlobalRateLimiter, RateBucketInfo};
|
||||
pub use limiter::{BucketRateLimiter, GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
|
||||
mod leaky_bucket;
|
||||
pub use leaky_bucket::{
|
||||
EndpointRateLimiter, LeakyBucketConfig, LeakyBucketRateLimiter, LeakyBucketState,
|
||||
};
|
||||
|
||||
171
proxy/src/rate_limiter/leaky_bucket.rs
Normal file
171
proxy/src/rate_limiter/leaky_bucket.rs
Normal file
@@ -0,0 +1,171 @@
|
||||
use std::{
|
||||
hash::Hash,
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
};
|
||||
|
||||
use ahash::RandomState;
|
||||
use dashmap::DashMap;
|
||||
use rand::{thread_rng, Rng};
|
||||
use tokio::time::Instant;
|
||||
use tracing::info;
|
||||
|
||||
use crate::intern::EndpointIdInt;
|
||||
|
||||
// Simple per-endpoint rate limiter.
|
||||
pub type EndpointRateLimiter = LeakyBucketRateLimiter<EndpointIdInt>;
|
||||
|
||||
pub struct LeakyBucketRateLimiter<Key> {
|
||||
map: DashMap<Key, LeakyBucketState, RandomState>,
|
||||
config: LeakyBucketConfig,
|
||||
access_count: AtomicUsize,
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
|
||||
pub const DEFAULT: LeakyBucketConfig = LeakyBucketConfig {
|
||||
rps: 600.0,
|
||||
max: 1500.0,
|
||||
};
|
||||
|
||||
pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self {
|
||||
Self {
|
||||
map: DashMap::with_hasher_and_shard_amount(RandomState::new(), shards),
|
||||
config,
|
||||
access_count: AtomicUsize::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check that number of connections to the endpoint is below `max_rps` rps.
|
||||
pub fn check(&self, key: K, n: u32) -> bool {
|
||||
let now = Instant::now();
|
||||
|
||||
if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
|
||||
self.do_gc(now);
|
||||
}
|
||||
|
||||
let mut entry = self.map.entry(key).or_insert_with(|| LeakyBucketState {
|
||||
time: now,
|
||||
filled: 0.0,
|
||||
});
|
||||
|
||||
entry.check(&self.config, now, n as f64)
|
||||
}
|
||||
|
||||
fn do_gc(&self, now: Instant) {
|
||||
info!(
|
||||
"cleaning up bucket rate limiter, current size = {}",
|
||||
self.map.len()
|
||||
);
|
||||
let n = self.map.shards().len();
|
||||
let shard = thread_rng().gen_range(0..n);
|
||||
self.map.shards()[shard]
|
||||
.write()
|
||||
.retain(|_, value| !value.get_mut().update(&self.config, now));
|
||||
}
|
||||
}
|
||||
|
||||
pub struct LeakyBucketConfig {
|
||||
pub rps: f64,
|
||||
pub max: f64,
|
||||
}
|
||||
|
||||
pub struct LeakyBucketState {
|
||||
filled: f64,
|
||||
time: Instant,
|
||||
}
|
||||
|
||||
impl LeakyBucketConfig {
|
||||
pub fn new(rps: f64, max: f64) -> Self {
|
||||
assert!(rps > 0.0, "rps must be positive");
|
||||
assert!(max > 0.0, "max must be positive");
|
||||
Self { rps, max }
|
||||
}
|
||||
}
|
||||
|
||||
impl LeakyBucketState {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
filled: 0.0,
|
||||
time: Instant::now(),
|
||||
}
|
||||
}
|
||||
|
||||
/// updates the timer and returns true if the bucket is empty
|
||||
fn update(&mut self, info: &LeakyBucketConfig, now: Instant) -> bool {
|
||||
let drain = now.duration_since(self.time);
|
||||
let drain = drain.as_secs_f64() * info.rps;
|
||||
|
||||
self.filled = (self.filled - drain).clamp(0.0, info.max);
|
||||
self.time = now;
|
||||
|
||||
self.filled == 0.0
|
||||
}
|
||||
|
||||
pub fn check(&mut self, info: &LeakyBucketConfig, now: Instant, n: f64) -> bool {
|
||||
self.update(info, now);
|
||||
|
||||
if self.filled + n > info.max {
|
||||
return false;
|
||||
}
|
||||
self.filled += n;
|
||||
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LeakyBucketState {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::time::Instant;
|
||||
|
||||
use super::{LeakyBucketConfig, LeakyBucketState};
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn check() {
|
||||
let info = LeakyBucketConfig::new(500.0, 2000.0);
|
||||
let mut bucket = LeakyBucketState::new();
|
||||
|
||||
// should work for 2000 requests this second
|
||||
for _ in 0..2000 {
|
||||
assert!(bucket.check(&info, Instant::now(), 1.0));
|
||||
}
|
||||
assert!(!bucket.check(&info, Instant::now(), 1.0));
|
||||
assert_eq!(bucket.filled, 2000.0);
|
||||
|
||||
// in 1ms we should drain 0.5 tokens.
|
||||
// make sure we don't lose any tokens
|
||||
tokio::time::advance(Duration::from_millis(1)).await;
|
||||
assert!(!bucket.check(&info, Instant::now(), 1.0));
|
||||
tokio::time::advance(Duration::from_millis(1)).await;
|
||||
assert!(bucket.check(&info, Instant::now(), 1.0));
|
||||
|
||||
// in 10ms we should drain 5 tokens
|
||||
tokio::time::advance(Duration::from_millis(10)).await;
|
||||
for _ in 0..5 {
|
||||
assert!(bucket.check(&info, Instant::now(), 1.0));
|
||||
}
|
||||
assert!(!bucket.check(&info, Instant::now(), 1.0));
|
||||
|
||||
// in 10s we should drain 5000 tokens
|
||||
// but cap is only 2000
|
||||
tokio::time::advance(Duration::from_secs(10)).await;
|
||||
for _ in 0..2000 {
|
||||
assert!(bucket.check(&info, Instant::now(), 1.0));
|
||||
}
|
||||
assert!(!bucket.check(&info, Instant::now(), 1.0));
|
||||
|
||||
// should sustain 500rps
|
||||
for _ in 0..2000 {
|
||||
tokio::time::advance(Duration::from_millis(10)).await;
|
||||
for _ in 0..5 {
|
||||
assert!(bucket.check(&info, Instant::now(), 1.0));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -61,7 +61,7 @@ impl GlobalRateLimiter {
|
||||
// Purposefully ignore user name and database name as clients can reconnect
|
||||
// with different names, so we'll end up sending some http requests to
|
||||
// the control plane.
|
||||
pub type EndpointRateLimiter = BucketRateLimiter<EndpointIdInt, StdRng, RandomState>;
|
||||
pub type WakeComputeRateLimiter = BucketRateLimiter<EndpointIdInt, StdRng, RandomState>;
|
||||
|
||||
pub struct BucketRateLimiter<Key, Rand = StdRng, Hasher = RandomState> {
|
||||
map: DashMap<Key, Vec<RateBucket>, Hasher>,
|
||||
@@ -103,7 +103,7 @@ pub struct RateBucketInfo {
|
||||
|
||||
impl std::fmt::Display for RateBucketInfo {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let rps = (self.max_rpi as u64) * 1000 / self.interval.as_millis() as u64;
|
||||
let rps = self.rps().floor() as u64;
|
||||
write!(f, "{rps}@{}", humantime::format_duration(self.interval))
|
||||
}
|
||||
}
|
||||
@@ -140,6 +140,10 @@ impl RateBucketInfo {
|
||||
Self::new(200, Duration::from_secs(600)),
|
||||
];
|
||||
|
||||
pub fn rps(&self) -> f64 {
|
||||
(self.max_rpi as f64) / self.interval.as_secs_f64()
|
||||
}
|
||||
|
||||
pub fn validate(info: &mut [Self]) -> anyhow::Result<()> {
|
||||
info.sort_unstable_by_key(|info| info.interval);
|
||||
let invalid = info
|
||||
@@ -245,7 +249,7 @@ mod tests {
|
||||
use rustc_hash::FxHasher;
|
||||
use tokio::time;
|
||||
|
||||
use super::{BucketRateLimiter, EndpointRateLimiter};
|
||||
use super::{BucketRateLimiter, WakeComputeRateLimiter};
|
||||
use crate::{intern::EndpointIdInt, rate_limiter::RateBucketInfo, EndpointId};
|
||||
|
||||
#[test]
|
||||
@@ -293,7 +297,7 @@ mod tests {
|
||||
.map(|s| s.parse().unwrap())
|
||||
.collect();
|
||||
RateBucketInfo::validate(&mut rates).unwrap();
|
||||
let limiter = EndpointRateLimiter::new(rates);
|
||||
let limiter = WakeComputeRateLimiter::new(rates);
|
||||
|
||||
let endpoint = EndpointId::from("ep-my-endpoint-1234");
|
||||
let endpoint = EndpointIdInt::from(endpoint);
|
||||
|
||||
@@ -106,7 +106,7 @@ impl RedisPublisherClient {
|
||||
cancel_key_data,
|
||||
session_id,
|
||||
}))?;
|
||||
self.client.publish(PROXY_CHANNEL_NAME, payload).await?;
|
||||
let _: () = self.client.publish(PROXY_CHANNEL_NAME, payload).await?;
|
||||
Ok(())
|
||||
}
|
||||
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
|
||||
|
||||
@@ -178,7 +178,7 @@ impl ConnectionWithCredentialsProvider {
|
||||
credentials_provider: Arc<CredentialsProvider>,
|
||||
) -> anyhow::Result<()> {
|
||||
let (user, password) = credentials_provider.provide_credentials().await?;
|
||||
redis::cmd("AUTH")
|
||||
let _: () = redis::cmd("AUTH")
|
||||
.arg(user)
|
||||
.arg(password)
|
||||
.query_async(con)
|
||||
|
||||
@@ -127,7 +127,7 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
|
||||
Cancel(cancel_session) => {
|
||||
tracing::Span::current().record(
|
||||
"session_id",
|
||||
&tracing::field::display(cancel_session.session_id),
|
||||
tracing::field::display(cancel_session.session_id),
|
||||
);
|
||||
Metrics::get()
|
||||
.proxy
|
||||
|
||||
@@ -241,7 +241,7 @@ impl ConnectMechanism for TokioMechanism {
|
||||
drop(pause);
|
||||
let (client, connection) = permit.release_result(res)?;
|
||||
|
||||
tracing::Span::current().record("pid", &tracing::field::display(client.get_process_id()));
|
||||
tracing::Span::current().record("pid", tracing::field::display(client.get_process_id()));
|
||||
Ok(poll_client(
|
||||
self.pool.clone(),
|
||||
ctx,
|
||||
|
||||
@@ -403,7 +403,7 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
|
||||
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
|
||||
tracing::Span::current().record(
|
||||
"pid",
|
||||
&tracing::field::display(client.inner.get_process_id()),
|
||||
tracing::field::display(client.inner.get_process_id()),
|
||||
);
|
||||
info!(
|
||||
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
|
||||
|
||||
@@ -18,7 +18,7 @@ use hyper1::Response;
|
||||
use hyper1::StatusCode;
|
||||
use hyper1::{HeaderMap, Request};
|
||||
use pq_proto::StartupMessageParamsBuilder;
|
||||
use serde_json::json;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use tokio::time;
|
||||
use tokio_postgres::error::DbError;
|
||||
@@ -32,6 +32,7 @@ use tokio_postgres::Transaction;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use typed_json::json;
|
||||
use url::Url;
|
||||
use utils::http::error::ApiError;
|
||||
|
||||
@@ -262,13 +263,8 @@ pub async fn handle(
|
||||
| SqlOverHttpError::Postgres(e) => e.as_db_error(),
|
||||
_ => None,
|
||||
};
|
||||
fn get<'a, T: serde::Serialize>(
|
||||
db: Option<&'a DbError>,
|
||||
x: impl FnOnce(&'a DbError) -> T,
|
||||
) -> Value {
|
||||
db.map(x)
|
||||
.and_then(|t| serde_json::to_value(t).ok())
|
||||
.unwrap_or_default()
|
||||
fn get<'a, T: Default>(db: Option<&'a DbError>, x: impl FnOnce(&'a DbError) -> T) -> T {
|
||||
db.map(x).unwrap_or_default()
|
||||
}
|
||||
|
||||
if let Some(db_error) = db_error {
|
||||
@@ -277,17 +273,11 @@ pub async fn handle(
|
||||
|
||||
let position = db_error.and_then(|db| db.position());
|
||||
let (position, internal_position, internal_query) = match position {
|
||||
Some(ErrorPosition::Original(position)) => (
|
||||
Value::String(position.to_string()),
|
||||
Value::Null,
|
||||
Value::Null,
|
||||
),
|
||||
Some(ErrorPosition::Internal { position, query }) => (
|
||||
Value::Null,
|
||||
Value::String(position.to_string()),
|
||||
Value::String(query.clone()),
|
||||
),
|
||||
None => (Value::Null, Value::Null, Value::Null),
|
||||
Some(ErrorPosition::Original(position)) => (Some(position.to_string()), None, None),
|
||||
Some(ErrorPosition::Internal { position, query }) => {
|
||||
(None, Some(position.to_string()), Some(query.clone()))
|
||||
}
|
||||
None => (None, None, None),
|
||||
};
|
||||
|
||||
let code = get(db_error, |db| db.code().code());
|
||||
@@ -577,10 +567,8 @@ async fn handle_inner(
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, "application/json");
|
||||
|
||||
//
|
||||
// Now execute the query and return the result
|
||||
//
|
||||
let result = match payload {
|
||||
// Now execute the query and return the result.
|
||||
let json_output = match payload {
|
||||
Payload::Single(stmt) => stmt.process(cancel, &mut client, parsed_headers).await?,
|
||||
Payload::Batch(statements) => {
|
||||
if parsed_headers.txn_read_only {
|
||||
@@ -604,11 +592,9 @@ async fn handle_inner(
|
||||
|
||||
let metrics = client.metrics();
|
||||
|
||||
// how could this possibly fail
|
||||
let body = serde_json::to_string(&result).expect("json serialization should not fail");
|
||||
let len = body.len();
|
||||
let len = json_output.len();
|
||||
let response = response
|
||||
.body(Full::new(Bytes::from(body)))
|
||||
.body(Full::new(Bytes::from(json_output)))
|
||||
// only fails if invalid status code or invalid header/values are given.
|
||||
// these are not user configurable so it cannot fail dynamically
|
||||
.expect("building response payload should not fail");
|
||||
@@ -630,7 +616,7 @@ impl QueryData {
|
||||
cancel: CancellationToken,
|
||||
client: &mut Client<tokio_postgres::Client>,
|
||||
parsed_headers: HttpHeaders,
|
||||
) -> Result<Value, SqlOverHttpError> {
|
||||
) -> Result<String, SqlOverHttpError> {
|
||||
let (inner, mut discard) = client.inner();
|
||||
let cancel_token = inner.cancel_token();
|
||||
|
||||
@@ -643,7 +629,10 @@ impl QueryData {
|
||||
// The query successfully completed.
|
||||
Either::Left((Ok((status, results)), __not_yet_cancelled)) => {
|
||||
discard.check_idle(status);
|
||||
Ok(results)
|
||||
|
||||
let json_output =
|
||||
serde_json::to_string(&results).expect("json serialization should not fail");
|
||||
Ok(json_output)
|
||||
}
|
||||
// The query failed with an error
|
||||
Either::Left((Err(e), __not_yet_cancelled)) => {
|
||||
@@ -661,7 +650,10 @@ impl QueryData {
|
||||
// query successed before it was cancelled.
|
||||
Ok(Ok((status, results))) => {
|
||||
discard.check_idle(status);
|
||||
Ok(results)
|
||||
|
||||
let json_output = serde_json::to_string(&results)
|
||||
.expect("json serialization should not fail");
|
||||
Ok(json_output)
|
||||
}
|
||||
// query failed or was cancelled.
|
||||
Ok(Err(error)) => {
|
||||
@@ -695,7 +687,7 @@ impl BatchQueryData {
|
||||
cancel: CancellationToken,
|
||||
client: &mut Client<tokio_postgres::Client>,
|
||||
parsed_headers: HttpHeaders,
|
||||
) -> Result<Value, SqlOverHttpError> {
|
||||
) -> Result<String, SqlOverHttpError> {
|
||||
info!("starting transaction");
|
||||
let (inner, mut discard) = client.inner();
|
||||
let cancel_token = inner.cancel_token();
|
||||
@@ -717,9 +709,9 @@ impl BatchQueryData {
|
||||
e
|
||||
})?;
|
||||
|
||||
let results =
|
||||
let json_output =
|
||||
match query_batch(cancel.child_token(), &transaction, self, parsed_headers).await {
|
||||
Ok(results) => {
|
||||
Ok(json_output) => {
|
||||
info!("commit");
|
||||
let status = transaction.commit().await.map_err(|e| {
|
||||
// if we cannot commit - for now don't return connection to pool
|
||||
@@ -728,7 +720,7 @@ impl BatchQueryData {
|
||||
e
|
||||
})?;
|
||||
discard.check_idle(status);
|
||||
results
|
||||
json_output
|
||||
}
|
||||
Err(SqlOverHttpError::Cancelled(_)) => {
|
||||
if let Err(err) = cancel_token.cancel_query(NoTls).await {
|
||||
@@ -752,7 +744,7 @@ impl BatchQueryData {
|
||||
}
|
||||
};
|
||||
|
||||
Ok(json!({ "results": results }))
|
||||
Ok(json_output)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -761,7 +753,7 @@ async fn query_batch(
|
||||
transaction: &Transaction<'_>,
|
||||
queries: BatchQueryData,
|
||||
parsed_headers: HttpHeaders,
|
||||
) -> Result<Vec<Value>, SqlOverHttpError> {
|
||||
) -> Result<String, SqlOverHttpError> {
|
||||
let mut results = Vec::with_capacity(queries.queries.len());
|
||||
let mut current_size = 0;
|
||||
for stmt in queries.queries {
|
||||
@@ -786,7 +778,11 @@ async fn query_batch(
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(results)
|
||||
|
||||
let results = json!({ "results": results });
|
||||
let json_output = serde_json::to_string(&results).expect("json serialization should not fail");
|
||||
|
||||
Ok(json_output)
|
||||
}
|
||||
|
||||
async fn query_to_json<T: GenericClient>(
|
||||
@@ -794,7 +790,7 @@ async fn query_to_json<T: GenericClient>(
|
||||
data: QueryData,
|
||||
current_size: &mut usize,
|
||||
parsed_headers: HttpHeaders,
|
||||
) -> Result<(ReadyForQueryStatus, Value), SqlOverHttpError> {
|
||||
) -> Result<(ReadyForQueryStatus, impl Serialize), SqlOverHttpError> {
|
||||
info!("executing query");
|
||||
let query_params = data.params;
|
||||
let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?);
|
||||
@@ -843,8 +839,8 @@ async fn query_to_json<T: GenericClient>(
|
||||
|
||||
for c in row_stream.columns() {
|
||||
fields.push(json!({
|
||||
"name": Value::String(c.name().to_owned()),
|
||||
"dataTypeID": Value::Number(c.type_().oid().into()),
|
||||
"name": c.name().to_owned(),
|
||||
"dataTypeID": c.type_().oid(),
|
||||
"tableID": c.table_oid(),
|
||||
"columnID": c.column_id(),
|
||||
"dataTypeSize": c.type_size(),
|
||||
@@ -862,15 +858,14 @@ async fn query_to_json<T: GenericClient>(
|
||||
.map(|row| pg_text_row_to_json(row, &columns, parsed_headers.raw_output, array_mode))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
// resulting JSON format is based on the format of node-postgres result
|
||||
Ok((
|
||||
ready,
|
||||
json!({
|
||||
"command": command_tag_name,
|
||||
"rowCount": command_tag_count,
|
||||
"rows": rows,
|
||||
"fields": fields,
|
||||
"rowAsArray": array_mode,
|
||||
}),
|
||||
))
|
||||
// Resulting JSON format is based on the format of node-postgres result.
|
||||
let results = json!({
|
||||
"command": command_tag_name.to_string(),
|
||||
"rowCount": command_tag_count,
|
||||
"rows": rows,
|
||||
"fields": fields,
|
||||
"rowAsArray": array_mode,
|
||||
});
|
||||
|
||||
Ok((ready, results))
|
||||
}
|
||||
|
||||
@@ -357,11 +357,15 @@ pub async fn task_backup(
|
||||
info!("metrics backup has shut down");
|
||||
}
|
||||
// Even if the remote storage is not configured, we still want to clear the metrics.
|
||||
let storage = backup_config
|
||||
.remote_storage_config
|
||||
.as_ref()
|
||||
.map(|config| GenericRemoteStorage::from_config(config).context("remote storage init"))
|
||||
.transpose()?;
|
||||
let storage = if let Some(config) = backup_config.remote_storage_config.as_ref() {
|
||||
Some(
|
||||
GenericRemoteStorage::from_config(config)
|
||||
.await
|
||||
.context("remote storage init")?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut ticker = tokio::time::interval(backup_config.interval);
|
||||
let mut prev = Utc::now();
|
||||
let hostname = hostname::get()?.as_os_str().to_string_lossy().into_owned();
|
||||
|
||||
@@ -111,7 +111,7 @@ mod tests {
|
||||
|
||||
let waiters = Arc::clone(&waiters);
|
||||
let notifier = tokio::spawn(async move {
|
||||
waiters.notify(key, Default::default())?;
|
||||
waiters.notify(key, ())?;
|
||||
Ok(())
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user