mirror of
https://github.com/neondatabase/neon.git
synced 2025-12-23 06:09:59 +00:00
chore(proxy): disallow unwrap and unimplemented (#10142)
As the title says, I updated the lint rules to no longer allow unwrap or unimplemented. Three special cases: * Tests are allowed to use them * std::sync::Mutex lock().unwrap() is common because it's usually correct to continue panicking on poison * `tokio::spawn_blocking(...).await.unwrap()` is common because it will only error if the blocking fn panics, so continuing the panic is also correct I've introduced two extension traits to help with these last two, that are a bit more explicit so they don't need an expect message every time.
This commit is contained in:
@@ -776,6 +776,7 @@ impl From<&jose_jwk::Key> for KeyType {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use std::future::IntoFuture;
|
||||
use std::net::SocketAddr;
|
||||
|
||||
@@ -463,6 +463,8 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#![allow(clippy::unimplemented, clippy::unwrap_used)]
|
||||
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -250,6 +250,7 @@ fn project_name_valid(name: &str) -> bool {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use serde_json::json;
|
||||
use ComputeUserInfoParseError::*;
|
||||
|
||||
4
proxy/src/cache/endpoints.rs
vendored
4
proxy/src/cache/endpoints.rs
vendored
@@ -12,6 +12,7 @@ use tracing::info;
|
||||
|
||||
use crate::config::EndpointCacheConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::ext::LockExt;
|
||||
use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt};
|
||||
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
|
||||
use crate::rate_limiter::GlobalRateLimiter;
|
||||
@@ -96,7 +97,7 @@ impl EndpointsCache {
|
||||
|
||||
// If the limiter allows, we can pretend like it's valid
|
||||
// (incase it is, due to redis channel lag).
|
||||
if self.limiter.lock().unwrap().check() {
|
||||
if self.limiter.lock_propagate_poison().check() {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -258,6 +259,7 @@ impl EndpointsCache {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
1
proxy/src/cache/project_info.rs
vendored
1
proxy/src/cache/project_info.rs
vendored
@@ -365,6 +365,7 @@ impl Cache for ProjectInfoCacheImpl {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::scram::ServerSecret;
|
||||
|
||||
@@ -13,6 +13,7 @@ use uuid::Uuid;
|
||||
|
||||
use crate::auth::{check_peer_addr_is_in_list, IpPattern};
|
||||
use crate::error::ReportableError;
|
||||
use crate::ext::LockExt;
|
||||
use crate::metrics::{CancellationRequest, CancellationSource, Metrics};
|
||||
use crate::rate_limiter::LeakyBucketRateLimiter;
|
||||
use crate::redis::cancellation_publisher::{
|
||||
@@ -114,7 +115,7 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
|
||||
IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
|
||||
IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
|
||||
};
|
||||
if !self.limiter.lock().unwrap().check(subnet_key, 1) {
|
||||
if !self.limiter.lock_propagate_poison().check(subnet_key, 1) {
|
||||
// log only the subnet part of the IP address to know which subnet is rate limited
|
||||
tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
|
||||
Metrics::get()
|
||||
@@ -283,6 +284,7 @@ impl<P> Drop for Session<P> {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
@@ -221,15 +221,10 @@ impl CertResolver {
|
||||
) -> anyhow::Result<()> {
|
||||
let priv_key = {
|
||||
let key_bytes = std::fs::read(key_path)
|
||||
.context(format!("Failed to read TLS keys at '{key_path}'"))?;
|
||||
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]).collect_vec();
|
||||
|
||||
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
|
||||
PrivateKeyDer::Pkcs8(
|
||||
keys.pop()
|
||||
.unwrap()
|
||||
.context(format!("Failed to parse TLS keys at '{key_path}'"))?,
|
||||
)
|
||||
.with_context(|| format!("Failed to read TLS keys at '{key_path}'"))?;
|
||||
rustls_pemfile::private_key(&mut &key_bytes[..])
|
||||
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
|
||||
.with_context(|| format!("Failed to parse TLS keys at '{key_path}'"))?
|
||||
};
|
||||
|
||||
let cert_chain_bytes = std::fs::read(cert_path)
|
||||
|
||||
@@ -23,6 +23,7 @@ use utils::backoff;
|
||||
use super::{RequestContextInner, LOG_CHAN};
|
||||
use crate::config::remote_storage_from_toml;
|
||||
use crate::context::LOG_CHAN_DISCONNECT;
|
||||
use crate::ext::TaskExt;
|
||||
|
||||
#[derive(clap::Args, Clone, Debug)]
|
||||
pub struct ParquetUploadArgs {
|
||||
@@ -171,7 +172,9 @@ pub async fn worker(
|
||||
};
|
||||
|
||||
let (tx, mut rx) = mpsc::unbounded_channel();
|
||||
LOG_CHAN.set(tx.downgrade()).unwrap();
|
||||
LOG_CHAN
|
||||
.set(tx.downgrade())
|
||||
.expect("only one worker should set the channel");
|
||||
|
||||
// setup row stream that will close on cancellation
|
||||
let cancellation_token2 = cancellation_token.clone();
|
||||
@@ -207,7 +210,9 @@ pub async fn worker(
|
||||
config.parquet_upload_disconnect_events_remote_storage
|
||||
{
|
||||
let (tx_disconnect, mut rx_disconnect) = mpsc::unbounded_channel();
|
||||
LOG_CHAN_DISCONNECT.set(tx_disconnect.downgrade()).unwrap();
|
||||
LOG_CHAN_DISCONNECT
|
||||
.set(tx_disconnect.downgrade())
|
||||
.expect("only one worker should set the channel");
|
||||
|
||||
// setup row stream that will close on cancellation
|
||||
tokio::spawn(async move {
|
||||
@@ -326,7 +331,7 @@ where
|
||||
Ok::<_, parquet::errors::ParquetError>((rows, w, rg_meta))
|
||||
})
|
||||
.await
|
||||
.unwrap()?;
|
||||
.propagate_task_panic()?;
|
||||
|
||||
rows.clear();
|
||||
Ok((rows, w, rg_meta))
|
||||
@@ -352,7 +357,7 @@ async fn upload_parquet(
|
||||
Ok((buffer, metadata))
|
||||
})
|
||||
.await
|
||||
.unwrap()?;
|
||||
.propagate_task_panic()?;
|
||||
|
||||
let data = buffer.split().freeze();
|
||||
|
||||
@@ -409,6 +414,7 @@ async fn upload_parquet(
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use std::net::Ipv4Addr;
|
||||
use std::num::NonZeroUsize;
|
||||
|
||||
@@ -102,7 +102,9 @@ impl MockControlPlane {
|
||||
Some(s) => {
|
||||
info!("got allowed_ips: {s}");
|
||||
s.split(',')
|
||||
.map(|s| IpPattern::from_str(s).unwrap())
|
||||
.map(|s| {
|
||||
IpPattern::from_str(s).expect("mocked ip pattern should be correct")
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
None => vec![],
|
||||
|
||||
41
proxy/src/ext.rs
Normal file
41
proxy/src/ext.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use std::panic::resume_unwind;
|
||||
use std::sync::{Mutex, MutexGuard};
|
||||
|
||||
use tokio::task::JoinError;
|
||||
|
||||
pub(crate) trait LockExt<T> {
|
||||
fn lock_propagate_poison(&self) -> MutexGuard<'_, T>;
|
||||
}
|
||||
|
||||
impl<T> LockExt<T> for Mutex<T> {
|
||||
/// Lock the mutex and panic if the mutex was poisoned.
|
||||
#[track_caller]
|
||||
fn lock_propagate_poison(&self) -> MutexGuard<'_, T> {
|
||||
match self.lock() {
|
||||
Ok(guard) => guard,
|
||||
// poison occurs when another thread panicked while holding the lock guard.
|
||||
// since panicking is often unrecoverable, propagating the poison panic is reasonable.
|
||||
Err(poison) => panic!("{poison}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait TaskExt<T> {
|
||||
fn propagate_task_panic(self) -> T;
|
||||
}
|
||||
|
||||
impl<T> TaskExt<T> for Result<T, JoinError> {
|
||||
/// Unwrap the result and panic if the inner task panicked.
|
||||
/// Also panics if the task was cancelled
|
||||
#[track_caller]
|
||||
fn propagate_task_panic(self) -> T {
|
||||
match self {
|
||||
Ok(t) => t,
|
||||
// Using resume_unwind prevents the panic hook being called twice.
|
||||
// Since we use this for structured concurrency, there is only
|
||||
// 1 logical panic, so this is more correct.
|
||||
Err(e) if e.is_panic() => resume_unwind(e.into_panic()),
|
||||
Err(e) => panic!("unexpected task error: {e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,7 @@ use utils::http::error::ApiError;
|
||||
use utils::http::json::json_response;
|
||||
use utils::http::{RouterBuilder, RouterService};
|
||||
|
||||
use crate::ext::{LockExt, TaskExt};
|
||||
use crate::jemalloc;
|
||||
|
||||
async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
@@ -76,7 +77,7 @@ async fn prometheus_metrics_handler(
|
||||
let body = tokio::task::spawn_blocking(move || {
|
||||
let _span = span.entered();
|
||||
|
||||
let mut state = state.lock().unwrap();
|
||||
let mut state = state.lock_propagate_poison();
|
||||
let PrometheusHandler { encoder, metrics } = &mut *state;
|
||||
|
||||
metrics
|
||||
@@ -94,13 +95,13 @@ async fn prometheus_metrics_handler(
|
||||
body
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
.propagate_task_panic();
|
||||
|
||||
let response = Response::builder()
|
||||
.status(200)
|
||||
.header(CONTENT_TYPE, "text/plain; version=0.0.4")
|
||||
.body(Body::from(body))
|
||||
.unwrap();
|
||||
.expect("response headers should be valid");
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
@@ -83,7 +83,7 @@ impl<Id: InternId> StringInterner<Id> {
|
||||
pub(crate) fn new() -> Self {
|
||||
StringInterner {
|
||||
inner: ThreadedRodeo::with_capacity_memory_limits_and_hasher(
|
||||
Capacity::new(2500, NonZeroUsize::new(1 << 16).unwrap()),
|
||||
Capacity::new(2500, NonZeroUsize::new(1 << 16).expect("value is nonzero")),
|
||||
// unbounded
|
||||
MemoryLimits::for_memory_usage(usize::MAX),
|
||||
BuildHasherDefault::<FxHasher>::default(),
|
||||
@@ -207,6 +207,7 @@ impl From<ProjectId> for ProjectIdInt {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use std::sync::OnceLock;
|
||||
|
||||
|
||||
@@ -22,8 +22,8 @@
|
||||
clippy::string_add,
|
||||
clippy::string_to_string,
|
||||
clippy::todo,
|
||||
// TODO: consider clippy::unimplemented
|
||||
// TODO: consider clippy::unwrap_used
|
||||
clippy::unimplemented,
|
||||
clippy::unwrap_used,
|
||||
)]
|
||||
// List of permanently allowed lints.
|
||||
#![allow(
|
||||
@@ -82,6 +82,7 @@ pub mod console_redirect_proxy;
|
||||
pub mod context;
|
||||
pub mod control_plane;
|
||||
pub mod error;
|
||||
mod ext;
|
||||
pub mod http;
|
||||
pub mod intern;
|
||||
pub mod jemalloc;
|
||||
|
||||
@@ -18,8 +18,16 @@ pub async fn init() -> anyhow::Result<LoggingGuard> {
|
||||
let env_filter = EnvFilter::builder()
|
||||
.with_default_directive(LevelFilter::INFO.into())
|
||||
.from_env_lossy()
|
||||
.add_directive("aws_config=info".parse().unwrap())
|
||||
.add_directive("azure_core::policies::transport=off".parse().unwrap());
|
||||
.add_directive(
|
||||
"aws_config=info"
|
||||
.parse()
|
||||
.expect("this should be a valid filter directive"),
|
||||
)
|
||||
.add_directive(
|
||||
"azure_core::policies::transport=off"
|
||||
.parse()
|
||||
.expect("this should be a valid filter directive"),
|
||||
);
|
||||
|
||||
let fmt_layer = tracing_subscriber::fmt::layer()
|
||||
.with_ansi(false)
|
||||
|
||||
@@ -8,14 +8,6 @@ pub(crate) fn split_cstr(bytes: &[u8]) -> Option<(&CStr, &[u8])> {
|
||||
Some((cstr, other))
|
||||
}
|
||||
|
||||
/// See <https://doc.rust-lang.org/std/primitive.slice.html#method.split_array_ref>.
|
||||
pub(crate) fn split_at_const<const N: usize>(bytes: &[u8]) -> Option<(&[u8; N], &[u8])> {
|
||||
(bytes.len() >= N).then(|| {
|
||||
let (head, tail) = bytes.split_at(N);
|
||||
(head.try_into().unwrap(), tail)
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -33,11 +25,4 @@ mod tests {
|
||||
assert_eq!(cstr.to_bytes(), b"foo");
|
||||
assert_eq!(rest, b"bar");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_split_at_const() {
|
||||
assert!(split_at_const::<0>(b"").is_some());
|
||||
assert!(split_at_const::<1>(b"").is_none());
|
||||
assert!(matches!(split_at_const::<1>(b"ok"), Some((b"o", b"k"))));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -396,6 +396,7 @@ impl NetworkEndianIpv6 {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
|
||||
@@ -257,6 +257,7 @@ impl CopyBuffer {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
|
||||
@@ -488,7 +488,7 @@ impl NeonOptions {
|
||||
|
||||
pub(crate) fn neon_option(bytes: &str) -> Option<(&str, &str)> {
|
||||
static RE: OnceCell<Regex> = OnceCell::new();
|
||||
let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap());
|
||||
let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").expect("regex should be correct"));
|
||||
|
||||
let cap = re.captures(bytes)?;
|
||||
let (_, [k, v]) = cap.extract();
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
//! A group of high-level tests for connection establishing logic and auth.
|
||||
#![allow(clippy::unimplemented, clippy::unwrap_used)]
|
||||
|
||||
mod mitm;
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ impl From<LeakyBucketConfig> for utils::leaky_bucket::LeakyBucketConfig {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(clippy::float_cmp)]
|
||||
#[allow(clippy::float_cmp, clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
|
||||
@@ -63,6 +63,7 @@ impl LimitAlgorithm for Aimd {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ use rand::{Rng, SeedableRng};
|
||||
use tokio::time::{Duration, Instant};
|
||||
use tracing::info;
|
||||
|
||||
use crate::ext::LockExt;
|
||||
use crate::intern::EndpointIdInt;
|
||||
|
||||
pub struct GlobalRateLimiter {
|
||||
@@ -246,12 +247,13 @@ impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
|
||||
let n = self.map.shards().len();
|
||||
// this lock is ok as the periodic cycle of do_gc makes this very unlikely to collide
|
||||
// (impossible, infact, unless we have 2048 threads)
|
||||
let shard = self.rand.lock().unwrap().gen_range(0..n);
|
||||
let shard = self.rand.lock_propagate_poison().gen_range(0..n);
|
||||
self.map.shards()[shard].write().clear();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use std::hash::BuildHasherDefault;
|
||||
use std::time::Duration;
|
||||
|
||||
@@ -69,7 +69,11 @@ impl ConnectionWithCredentialsProvider {
|
||||
|
||||
pub fn new_with_static_credentials<T: IntoConnectionInfo>(params: T) -> Self {
|
||||
Self {
|
||||
credentials: Credentials::Static(params.into_connection_info().unwrap()),
|
||||
credentials: Credentials::Static(
|
||||
params
|
||||
.into_connection_info()
|
||||
.expect("static configured redis credentials should be a valid format"),
|
||||
),
|
||||
con: None,
|
||||
refresh_token_task: None,
|
||||
mutex: tokio::sync::Mutex::new(()),
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
use pq_proto::{BeAuthenticationSaslMessage, BeMessage};
|
||||
|
||||
use crate::parse::{split_at_const, split_cstr};
|
||||
use crate::parse::split_cstr;
|
||||
|
||||
/// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage).
|
||||
#[derive(Debug)]
|
||||
@@ -19,7 +19,7 @@ impl<'a> FirstMessage<'a> {
|
||||
let (method_cstr, tail) = split_cstr(bytes)?;
|
||||
let method = method_cstr.to_str().ok()?;
|
||||
|
||||
let (len_bytes, bytes) = split_at_const(tail)?;
|
||||
let (len_bytes, bytes) = tail.split_first_chunk()?;
|
||||
let len = u32::from_be_bytes(*len_bytes) as usize;
|
||||
if len != bytes.len() {
|
||||
return None;
|
||||
@@ -51,6 +51,7 @@ impl<'a> ServerMessage<&'a str> {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
@@ -185,6 +185,7 @@ impl fmt::Debug for OwnedServerFirstMessage {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
@@ -57,6 +57,7 @@ fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use super::threadpool::ThreadPool;
|
||||
use super::{Exchange, ServerSecret};
|
||||
|
||||
@@ -72,6 +72,7 @@ impl ServerSecret {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
@@ -33,14 +33,11 @@ thread_local! {
|
||||
}
|
||||
|
||||
impl ThreadPool {
|
||||
pub fn new(n_workers: u8) -> Arc<Self> {
|
||||
pub fn new(mut n_workers: u8) -> Arc<Self> {
|
||||
// rayon would be nice here, but yielding in rayon does not work well afaict.
|
||||
|
||||
if n_workers == 0 {
|
||||
return Arc::new(Self {
|
||||
runtime: None,
|
||||
metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
|
||||
});
|
||||
n_workers = 1;
|
||||
}
|
||||
|
||||
Arc::new_cyclic(|pool| {
|
||||
@@ -66,7 +63,7 @@ impl ThreadPool {
|
||||
});
|
||||
})
|
||||
.build()
|
||||
.unwrap();
|
||||
.expect("password threadpool runtime should be configured correctly");
|
||||
|
||||
Self {
|
||||
runtime: Some(runtime),
|
||||
@@ -79,7 +76,7 @@ impl ThreadPool {
|
||||
JobHandle(
|
||||
self.runtime
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.expect("runtime is always set")
|
||||
.spawn(JobSpec { pbkdf2, endpoint }),
|
||||
)
|
||||
}
|
||||
@@ -87,7 +84,10 @@ impl ThreadPool {
|
||||
|
||||
impl Drop for ThreadPool {
|
||||
fn drop(&mut self) {
|
||||
self.runtime.take().unwrap().shutdown_background();
|
||||
self.runtime
|
||||
.take()
|
||||
.expect("runtime is always set")
|
||||
.shutdown_background();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -268,7 +268,11 @@ impl PoolingBackend {
|
||||
|
||||
if !self.local_pool.initialized(&conn_info) {
|
||||
// only install and grant usage one at a time.
|
||||
let _permit = local_backend.initialize.acquire().await.unwrap();
|
||||
let _permit = local_backend
|
||||
.initialize
|
||||
.acquire()
|
||||
.await
|
||||
.expect("semaphore should never be closed");
|
||||
|
||||
// check again for race
|
||||
if !self.local_pool.initialized(&conn_info) {
|
||||
|
||||
@@ -186,8 +186,8 @@ impl ClientDataRemote {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use std::mem;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
|
||||
use super::*;
|
||||
@@ -269,39 +269,33 @@ mod tests {
|
||||
assert_eq!(0, pool.get_global_connections_count());
|
||||
}
|
||||
{
|
||||
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
|
||||
client.do_drop().unwrap()();
|
||||
mem::forget(client); // drop the client
|
||||
let client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
|
||||
drop(client);
|
||||
assert_eq!(1, pool.get_global_connections_count());
|
||||
}
|
||||
{
|
||||
let mut closed_client = Client::new(
|
||||
let closed_client = Client::new(
|
||||
create_inner_with(MockClient::new(true)),
|
||||
conn_info.clone(),
|
||||
ep_pool.clone(),
|
||||
);
|
||||
closed_client.do_drop().unwrap()();
|
||||
mem::forget(closed_client); // drop the client
|
||||
// The closed client shouldn't be added to the pool.
|
||||
drop(closed_client);
|
||||
assert_eq!(1, pool.get_global_connections_count());
|
||||
}
|
||||
let is_closed: Arc<AtomicBool> = Arc::new(false.into());
|
||||
{
|
||||
let mut client = Client::new(
|
||||
let client = Client::new(
|
||||
create_inner_with(MockClient(is_closed.clone())),
|
||||
conn_info.clone(),
|
||||
ep_pool.clone(),
|
||||
);
|
||||
client.do_drop().unwrap()();
|
||||
mem::forget(client); // drop the client
|
||||
|
||||
drop(client);
|
||||
// The client should be added to the pool.
|
||||
assert_eq!(2, pool.get_global_connections_count());
|
||||
}
|
||||
{
|
||||
let mut client = Client::new(create_inner(), conn_info, ep_pool);
|
||||
client.do_drop().unwrap()();
|
||||
mem::forget(client); // drop the client
|
||||
let client = Client::new(create_inner(), conn_info, ep_pool);
|
||||
drop(client);
|
||||
|
||||
// The client shouldn't be added to the pool. Because the ep-pool is full.
|
||||
assert_eq!(2, pool.get_global_connections_count());
|
||||
@@ -319,15 +313,13 @@ mod tests {
|
||||
&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
|
||||
);
|
||||
{
|
||||
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
|
||||
client.do_drop().unwrap()();
|
||||
mem::forget(client); // drop the client
|
||||
let client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
|
||||
drop(client);
|
||||
assert_eq!(3, pool.get_global_connections_count());
|
||||
}
|
||||
{
|
||||
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
|
||||
client.do_drop().unwrap()();
|
||||
mem::forget(client); // drop the client
|
||||
let client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
|
||||
drop(client);
|
||||
|
||||
// The client shouldn't be added to the pool. Because the global pool is full.
|
||||
assert_eq!(3, pool.get_global_connections_count());
|
||||
|
||||
@@ -187,19 +187,22 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
|
||||
pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInnerCommon<C>) {
|
||||
let conn_id = client.get_conn_id();
|
||||
let pool_name = pool.read().get_name().to_string();
|
||||
let (max_conn, conn_count, pool_name) = {
|
||||
let pool = pool.read();
|
||||
(
|
||||
pool.global_pool_size_max_conns,
|
||||
pool.global_connections_count
|
||||
.load(atomic::Ordering::Relaxed),
|
||||
pool.get_name().to_string(),
|
||||
)
|
||||
};
|
||||
|
||||
if client.inner.is_closed() {
|
||||
info!(%conn_id, "{}: throwing away connection '{conn_info}' because connection is closed", pool_name);
|
||||
return;
|
||||
}
|
||||
|
||||
let global_max_conn = pool.read().global_pool_size_max_conns;
|
||||
if pool
|
||||
.read()
|
||||
.global_connections_count
|
||||
.load(atomic::Ordering::Relaxed)
|
||||
>= global_max_conn
|
||||
{
|
||||
if conn_count >= max_conn {
|
||||
info!(%conn_id, "{}: throwing away connection '{conn_info}' because pool is full", pool_name);
|
||||
return;
|
||||
}
|
||||
@@ -633,35 +636,29 @@ impl<C: ClientInnerExt> Client<C> {
|
||||
}
|
||||
|
||||
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
|
||||
let aux = &self.inner.as_ref().unwrap().aux;
|
||||
let aux = &self
|
||||
.inner
|
||||
.as_ref()
|
||||
.expect("client inner should not be removed")
|
||||
.aux;
|
||||
USAGE_METRICS.register(Ids {
|
||||
endpoint_id: aux.endpoint_id,
|
||||
branch_id: aux.branch_id,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn do_drop(&mut self) -> Option<impl FnOnce() + use<C>> {
|
||||
impl<C: ClientInnerExt> Drop for Client<C> {
|
||||
fn drop(&mut self) {
|
||||
let conn_info = self.conn_info.clone();
|
||||
let client = self
|
||||
.inner
|
||||
.take()
|
||||
.expect("client inner should not be removed");
|
||||
if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
|
||||
let current_span = self.span.clone();
|
||||
let _current_span = self.span.enter();
|
||||
// return connection to the pool
|
||||
return Some(move || {
|
||||
let _span = current_span.enter();
|
||||
EndpointConnPool::put(&conn_pool, &conn_info, client);
|
||||
});
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> Drop for Client<C> {
|
||||
fn drop(&mut self) {
|
||||
if let Some(drop) = self.do_drop() {
|
||||
tokio::task::spawn_blocking(drop);
|
||||
EndpointConnPool::put(&conn_pool, &conn_info, client);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,11 +81,14 @@ impl HttpErrorBody {
|
||||
.header(http::header::CONTENT_TYPE, "application/json")
|
||||
// we do not have nested maps with non string keys so serialization shouldn't fail
|
||||
.body(
|
||||
Full::new(Bytes::from(serde_json::to_string(self).unwrap()))
|
||||
.map_err(|x| match x {})
|
||||
.boxed(),
|
||||
Full::new(Bytes::from(
|
||||
serde_json::to_string(self)
|
||||
.expect("serialising HttpErrorBody should never fail"),
|
||||
))
|
||||
.map_err(|x| match x {})
|
||||
.boxed(),
|
||||
)
|
||||
.unwrap()
|
||||
.expect("content-type header should be valid")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -204,7 +204,10 @@ fn pg_array_parse_inner(
|
||||
|
||||
if c == '\\' {
|
||||
escaped = true;
|
||||
(i, c) = pg_array_chr.next().unwrap();
|
||||
let Some(x) = pg_array_chr.next() else {
|
||||
return Err(JsonConversionError::UnbalancedArray);
|
||||
};
|
||||
(i, c) = x;
|
||||
}
|
||||
|
||||
match c {
|
||||
@@ -253,6 +256,7 @@ fn pg_array_parse_inner(
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use serde_json::json;
|
||||
|
||||
|
||||
@@ -179,7 +179,6 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
|
||||
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
|
||||
});
|
||||
let pool = Arc::downgrade(&global_pool);
|
||||
let pool_clone = pool.clone();
|
||||
|
||||
let db_user = conn_info.db_and_user();
|
||||
let idle = global_pool.get_idle_timeout();
|
||||
@@ -273,11 +272,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
|
||||
}),
|
||||
};
|
||||
|
||||
Client::new(
|
||||
inner,
|
||||
conn_info,
|
||||
Arc::downgrade(&pool_clone.upgrade().unwrap().global_pool),
|
||||
)
|
||||
Client::new(inner, conn_info, Arc::downgrade(&global_pool.global_pool))
|
||||
}
|
||||
|
||||
impl ClientInnerCommon<postgres_client::Client> {
|
||||
@@ -321,7 +316,8 @@ fn resign_jwt(sk: &SigningKey, payload: &[u8], jti: u64) -> Result<String, HttpC
|
||||
let mut buffer = itoa::Buffer::new();
|
||||
|
||||
// encode the jti integer to a json rawvalue
|
||||
let jti = serde_json::from_str::<&RawValue>(buffer.format(jti)).unwrap();
|
||||
let jti = serde_json::from_str::<&RawValue>(buffer.format(jti))
|
||||
.expect("itoa formatted integer should be guaranteed valid json");
|
||||
|
||||
// update the jti in-place
|
||||
let payload =
|
||||
@@ -368,6 +364,7 @@ fn sign_jwt(sk: &SigningKey, payload: &[u8]) -> String {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use p256::ecdsa::SigningKey;
|
||||
use typed_json::json;
|
||||
|
||||
@@ -46,6 +46,7 @@ use utils::http::error::ApiError;
|
||||
use crate::cancellation::CancellationHandlerMain;
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2};
|
||||
use crate::context::RequestContext;
|
||||
use crate::ext::TaskExt;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::protocol2::{read_proxy_protocol, ChainRW, ConnectHeader, ConnectionInfo};
|
||||
use crate::proxy::run_until_cancelled;
|
||||
@@ -84,7 +85,7 @@ pub async fn task_main(
|
||||
cancellation_token.cancelled().await;
|
||||
tokio::task::spawn_blocking(move || conn_pool.shutdown())
|
||||
.await
|
||||
.unwrap();
|
||||
.propagate_task_panic();
|
||||
}
|
||||
});
|
||||
|
||||
@@ -104,7 +105,7 @@ pub async fn task_main(
|
||||
cancellation_token.cancelled().await;
|
||||
tokio::task::spawn_blocking(move || http_conn_pool.shutdown())
|
||||
.await
|
||||
.unwrap();
|
||||
.propagate_task_panic();
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -1110,6 +1110,7 @@ impl Discard<'_> {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
@@ -178,6 +178,7 @@ pub(crate) async fn serve_websocket(
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use std::pin::pin;
|
||||
|
||||
|
||||
@@ -50,6 +50,7 @@ impl std::fmt::Display for ApiUrl {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
@@ -407,6 +407,7 @@ async fn upload_backup_events(
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[expect(clippy::unwrap_used)]
|
||||
mod tests {
|
||||
use std::fs;
|
||||
use std::io::BufReader;
|
||||
|
||||
Reference in New Issue
Block a user