Nightly lints and small tweaks (#12456)

Let chains are available in 1.88 :D New clippy lints are coming up in future
releases.
Conrad Ludgate
2025-07-03 15:47:12 +01:00
committed by GitHub
parent 4db934407a
commit 03e604e432
32 changed files with 239 additions and 316 deletions
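For context on the commit message: a let chain joins what used to be several nested `if let`/`if` blocks into a single `if` whose conditions are chained with `&&`, which is the shape most hunks below take. A minimal before/after sketch, not taken from this repo, assuming Rust 1.88+ on edition 2024:

// Before: three levels of nesting just to reach the body.
fn first_even_before(values: Option<&[u32]>) -> Option<u32> {
    if let Some(vs) = values {
        if let Some(&v) = vs.first() {
            if v.is_multiple_of(2) {
                return Some(v);
            }
        }
    }
    None
}

// After: one condition, one body, one level of indentation.
fn first_even_after(values: Option<&[u32]>) -> Option<u32> {
    if let Some(vs) = values
        && let Some(&v) = vs.first()
        && v.is_multiple_of(2)
    {
        return Some(v);
    }
    None
}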

View File

@@ -52,7 +52,7 @@ pub(crate) async fn hi(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
}
// yield every ~250us
// hopefully reduces tail latencies
-if i % 1024 == 0 {
+if i.is_multiple_of(1024) {
yield_now().await
}
}
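A rough sketch of the pattern in this hunk, with hypothetical names (`run_iterations`/`hash_round` stand in for the real PBKDF2 work): a CPU-bound async loop periodically hands control back to the runtime so other tasks on the same worker thread are not starved. For unsigned integers, `is_multiple_of(n)` reads the same as `% n == 0` but states the intent and cannot panic when `n == 0`.

use tokio::task::yield_now;

async fn run_iterations(iterations: u32) {
    for i in 0..iterations {
        hash_round(i);
        // yield roughly every 1024 rounds; keeps tail latencies of other
        // tasks on this worker down without paying for yield_now() every time
        if i.is_multiple_of(1024) {
            yield_now().await;
        }
    }
}

fn hash_round(_i: u32) {
    // placeholder for the real per-iteration hashing work
}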

View File

@@ -90,7 +90,7 @@ pub struct InnerClient {
}
impl InnerClient {
-pub fn start(&mut self) -> Result<PartialQuery, Error> {
+pub fn start(&mut self) -> Result<PartialQuery<'_>, Error> {
self.responses.waiting += 1;
Ok(PartialQuery(Some(self)))
}
@@ -227,7 +227,7 @@ impl Client {
&mut self,
statement: &str,
params: I,
-) -> Result<RowStream, Error>
+) -> Result<RowStream<'_>, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
@@ -262,7 +262,7 @@ impl Client {
pub(crate) async fn simple_query_raw(
&mut self,
query: &str,
-) -> Result<SimpleQueryStream, Error> {
+) -> Result<SimpleQueryStream<'_>, Error> {
simple_query::simple_query(self.inner_mut(), query).await
}

View File

@@ -12,7 +12,11 @@ mod private {
/// This trait is "sealed", and cannot be implemented outside of this crate.
pub trait GenericClient: private::Sealed {
/// Like `Client::query_raw_txt`.
-async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
+async fn query_raw_txt<S, I>(
+&mut self,
+statement: &str,
+params: I,
+) -> Result<RowStream<'_>, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
@@ -22,7 +26,11 @@ pub trait GenericClient: private::Sealed {
impl private::Sealed for Client {}
impl GenericClient for Client {
-async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
+async fn query_raw_txt<S, I>(
+&mut self,
+statement: &str,
+params: I,
+) -> Result<RowStream<'_>, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
@@ -35,7 +43,11 @@ impl GenericClient for Client {
impl private::Sealed for Transaction<'_> {}
impl GenericClient for Transaction<'_> {
-async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
+async fn query_raw_txt<S, I>(
+&mut self,
+statement: &str,
+params: I,
+) -> Result<RowStream<'_>, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,

View File

@@ -47,7 +47,7 @@ impl<'a> Transaction<'a> {
&mut self,
statement: &str,
params: I,
-) -> Result<RowStream, Error>
+) -> Result<RowStream<'_>, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,

View File

@@ -164,21 +164,20 @@ async fn authenticate(
})?
.map_err(ConsoleRedirectError::from)?;
-if auth_config.ip_allowlist_check_enabled {
-if let Some(allowed_ips) = &db_info.allowed_ips {
-if !auth::check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips) {
-return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
-}
-}
+if auth_config.ip_allowlist_check_enabled
+&& let Some(allowed_ips) = &db_info.allowed_ips
+&& !auth::check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips)
+{
+return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
// Check if the access over the public internet is allowed, otherwise block. Note that
// the console redirect is not behind the VPC service endpoint, so we don't need to check
// the VPC endpoint ID.
-if let Some(public_access_allowed) = db_info.public_access_allowed {
-if !public_access_allowed {
-return Err(auth::AuthError::NetworkNotAllowed);
-}
+if let Some(public_access_allowed) = db_info.public_access_allowed
+&& !public_access_allowed
+{
+return Err(auth::AuthError::NetworkNotAllowed);
}
client.write_message(BeMessage::NoticeResponse("Connecting to database."));

View File

@@ -399,36 +399,36 @@ impl JwkCacheEntryLock {
tracing::debug!(?payload, "JWT signature valid with claims");
-if let Some(aud) = expected_audience {
-if payload.audience.0.iter().all(|s| s != aud) {
-return Err(JwtError::InvalidClaims(
-JwtClaimsError::InvalidJwtTokenAudience,
-));
-}
+if let Some(aud) = expected_audience
+&& payload.audience.0.iter().all(|s| s != aud)
+{
+return Err(JwtError::InvalidClaims(
+JwtClaimsError::InvalidJwtTokenAudience,
+));
}
let now = SystemTime::now();
-if let Some(exp) = payload.expiration {
-if now >= exp + CLOCK_SKEW_LEEWAY {
-return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired(
-exp.duration_since(SystemTime::UNIX_EPOCH)
-.unwrap_or_default()
-.as_secs(),
-)));
-}
+if let Some(exp) = payload.expiration
+&& now >= exp + CLOCK_SKEW_LEEWAY
+{
+return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired(
+exp.duration_since(SystemTime::UNIX_EPOCH)
+.unwrap_or_default()
+.as_secs(),
+)));
}
-if let Some(nbf) = payload.not_before {
-if nbf >= now + CLOCK_SKEW_LEEWAY {
-return Err(JwtError::InvalidClaims(
-JwtClaimsError::JwtTokenNotYetReadyToUse(
-nbf.duration_since(SystemTime::UNIX_EPOCH)
-.unwrap_or_default()
-.as_secs(),
-),
-));
-}
+if let Some(nbf) = payload.not_before
+&& nbf >= now + CLOCK_SKEW_LEEWAY
+{
+return Err(JwtError::InvalidClaims(
+JwtClaimsError::JwtTokenNotYetReadyToUse(
+nbf.duration_since(SystemTime::UNIX_EPOCH)
+.unwrap_or_default()
+.as_secs(),
+),
+));
}
Ok(ComputeCredentialKeys::JwtPayload(payloadb))
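A self-contained sketch of the time-claim checks above (the constant value and the error strings are illustrative, not the crate's): expiry and not-before are both compared against the current time with a fixed leeway, so small clock skew between the token issuer and the proxy does not produce spurious rejections.

use std::time::{Duration, SystemTime};

const CLOCK_SKEW_LEEWAY: Duration = Duration::from_secs(30); // illustrative value

fn check_time_claims(
    expiration: Option<SystemTime>,
    not_before: Option<SystemTime>,
) -> Result<(), &'static str> {
    let now = SystemTime::now();
    // expired only if we are past `exp` by more than the allowed skew
    if let Some(exp) = expiration
        && now >= exp + CLOCK_SKEW_LEEWAY
    {
        return Err("token has expired");
    }
    // not yet valid only if `nbf` is ahead of us by more than the allowed skew
    if let Some(nbf) = not_before
        && nbf >= now + CLOCK_SKEW_LEEWAY
    {
        return Err("token not yet ready to use");
    }
    Ok(())
}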

View File

@@ -345,15 +345,13 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
Err(e) => {
// The password could have been changed, so we invalidate the cache.
// We should only invalidate the cache if the TTL might have expired.
-if e.is_password_failed() {
-#[allow(irrefutable_let_patterns)]
-if let ControlPlaneClient::ProxyV1(api) = &*api {
-if let Some(ep) = &user_info.endpoint_id {
-api.caches
-.project_info
-.maybe_invalidate_role_secret(ep, &user_info.user);
-}
-}
+if e.is_password_failed()
+&& let ControlPlaneClient::ProxyV1(api) = &*api
+&& let Some(ep) = &user_info.endpoint_id
+{
+api.caches
+.project_info
+.maybe_invalidate_role_secret(ep, &user_info.user);
}
Err(e)

View File

@@ -7,9 +7,7 @@ use anyhow::bail;
use arc_swap::ArcSwapOption;
use camino::Utf8PathBuf;
use clap::Parser;
use futures::future::Either;
use tokio::net::TcpListener;
use tokio::sync::Notify;
use tokio::task::JoinSet;
@@ -22,9 +20,9 @@ use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::local::LocalBackend;
use crate::auth::{self};
use crate::cancellation::CancellationHandler;
-use crate::config::refresh_config_loop;
+use crate::config::{
+self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
+refresh_config_loop,
+};
use crate::control_plane::locks::ApiLocks;
use crate::http::health_server::AppMetrics;

View File

@@ -10,11 +10,15 @@ use std::time::Duration;
use anyhow::Context;
use anyhow::{bail, ensure};
use arc_swap::ArcSwapOption;
+#[cfg(any(test, feature = "testing"))]
+use camino::Utf8PathBuf;
use futures::future::Either;
use itertools::{Itertools, Position};
use rand::{Rng, thread_rng};
use remote_storage::RemoteStorageConfig;
use tokio::net::TcpListener;
+#[cfg(any(test, feature = "testing"))]
+use tokio::sync::Notify;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, error, info, warn};
@@ -47,10 +51,6 @@ use crate::tls::client_config::compute_client_config_with_root_certs;
#[cfg(any(test, feature = "testing"))]
use crate::url::ApiUrl;
use crate::{auth, control_plane, http, serverless, usage_metrics};
-#[cfg(any(test, feature = "testing"))]
-use camino::Utf8PathBuf;
-#[cfg(any(test, feature = "testing"))]
-use tokio::sync::Notify;
project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);
@@ -520,54 +520,51 @@ pub async fn run() -> anyhow::Result<()> {
maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
}
#[cfg_attr(not(any(test, feature = "testing")), expect(irrefutable_let_patterns))]
-if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend {
-if let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api {
-if let Some(client) = redis_client {
-// project info cache and invalidation of that cache.
-let cache = api.caches.project_info.clone();
-maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
-maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
+if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend
+&& let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api
+&& let Some(client) = redis_client
+{
+// project info cache and invalidation of that cache.
+let cache = api.caches.project_info.clone();
+maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
+maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
-// Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
-// This prevents immediate exit and pod restart,
-// which can cause hammering of the redis in case of connection issues.
-// cancellation key management
-let mut redis_kv_client = RedisKVClient::new(client.clone());
-for attempt in (0..3).with_position() {
-match redis_kv_client.try_connect().await {
-Ok(()) => {
-info!("Connected to Redis KV client");
-cancellation_handler.init_tx(BatchQueue::new(CancellationProcessor {
-client: redis_kv_client,
-batch_size: args.cancellation_batch_size,
-}));
+// Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
+// This prevents immediate exit and pod restart,
+// which can cause hammering of the redis in case of connection issues.
+// cancellation key management
+let mut redis_kv_client = RedisKVClient::new(client.clone());
+for attempt in (0..3).with_position() {
+match redis_kv_client.try_connect().await {
+Ok(()) => {
+info!("Connected to Redis KV client");
+cancellation_handler.init_tx(BatchQueue::new(CancellationProcessor {
+client: redis_kv_client,
+batch_size: args.cancellation_batch_size,
+}));
-break;
-}
-Err(e) => {
-error!("Failed to connect to Redis KV client: {e}");
-if matches!(attempt, Position::Last(_)) {
-bail!(
-"Failed to connect to Redis KV client after {} attempts",
-attempt.into_inner()
-);
-}
-let jitter = thread_rng().gen_range(0..100);
-tokio::time::sleep(Duration::from_millis(1000 + jitter)).await;
-}
-}
+break;
+}
+Err(e) => {
+error!("Failed to connect to Redis KV client: {e}");
+if matches!(attempt, Position::Last(_)) {
+bail!(
+"Failed to connect to Redis KV client after {} attempts",
+attempt.into_inner()
+);
+}
+let jitter = thread_rng().gen_range(0..100);
+tokio::time::sleep(Duration::from_millis(1000 + jitter)).await;
+}
-// listen for notifications of new projects/endpoints/branches
-let cache = api.caches.endpoints_cache.clone();
-let span = tracing::info_span!("endpoints_cache");
-maintenance_tasks.spawn(
-async move { cache.do_read(client, cancellation_token.clone()).await }
-.instrument(span),
-);
-}
-}
+// listen for notifications of new projects/endpoints/branches
+let cache = api.caches.endpoints_cache.clone();
+let span = tracing::info_span!("endpoints_cache");
+maintenance_tasks.spawn(
+async move { cache.do_read(client, cancellation_token.clone()).await }.instrument(span),
+);
}
let maintenance = loop {
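The comment in this hunk describes the retry policy: a handful of connection attempts roughly a second apart, with a little random jitter so a fleet of pods does not retry in lockstep, and a hard failure only after the final attempt. A generic sketch of that shape (the helper name and signature are made up; assumes tokio and the rand 0.8-style API already used in the diff):

use std::time::Duration;
use rand::Rng;

async fn connect_with_retries<T, E, F, Fut>(attempts: u32, mut connect: F) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
{
    let mut attempt = 0;
    loop {
        match connect().await {
            Ok(conn) => return Ok(conn),
            Err(e) => {
                attempt += 1;
                if attempt >= attempts {
                    // last attempt failed: give up and let the caller decide
                    return Err(e);
                }
                // ~1s plus 0..100ms of jitter between attempts
                let jitter = rand::thread_rng().gen_range(0..100);
                tokio::time::sleep(Duration::from_millis(1000 + jitter)).await;
            }
        }
    }
}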

View File

@@ -4,28 +4,26 @@ use std::time::Duration;
use anyhow::{Context, Ok, bail, ensure};
use arc_swap::ArcSwapOption;
use camino::{Utf8Path, Utf8PathBuf};
use clap::ValueEnum;
use compute_api::spec::LocalProxySpec;
use remote_storage::RemoteStorageConfig;
use thiserror::Error;
use tokio::sync::Notify;
use tracing::{debug, error, info, warn};
use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::local::JWKS_ROLE_MAP;
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use crate::ext::TaskExt;
use crate::intern::RoleNameInt;
use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::GlobalConnPoolOptions;
use crate::serverless::cancel_set::CancelSet;
pub use crate::tls::server_config::{TlsConfig, configure_tls};
-use crate::types::Host;
-use crate::auth::backend::local::JWKS_ROLE_MAP;
-use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
-use crate::ext::TaskExt;
-use crate::intern::RoleNameInt;
-use crate::types::RoleName;
-use camino::{Utf8Path, Utf8PathBuf};
-use compute_api::spec::LocalProxySpec;
-use thiserror::Error;
-use tokio::sync::Notify;
-use tracing::{debug, error, info, warn};
+use crate::types::{Host, RoleName};
pub struct ProxyConfig {
pub tls_config: ArcSwapOption<TlsConfig>,

View File

@@ -209,11 +209,9 @@ impl RequestContext {
if let Some(options_str) = options.get("options") {
// If not found directly, try to extract it from the options string
for option in options_str.split_whitespace() {
-if option.starts_with("neon_query_id:") {
-if let Some(value) = option.strip_prefix("neon_query_id:") {
-this.set_testodrome_id(value.into());
-break;
-}
+if let Some(value) = option.strip_prefix("neon_query_id:") {
+this.set_testodrome_id(value.into());
+break;
}
}
}

View File

@@ -250,10 +250,8 @@ impl NeonControlPlaneClient {
info!(duration = ?start.elapsed(), "received http response");
let body = parse_body::<WakeCompute>(response.status(), response.bytes().await?)?;
// Unfortunately, ownership won't let us use `Option::ok_or` here.
-let (host, port) = match parse_host_port(&body.address) {
-None => return Err(WakeComputeError::BadComputeAddress(body.address)),
-Some(x) => x,
+let Some((host, port)) = parse_host_port(&body.address) else {
+return Err(WakeComputeError::BadComputeAddress(body.address));
};
let host_addr = IpAddr::from_str(host).ok();
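The rewrite above is the `let ... else` form: when the `None` arm of a match only returns early, let-else drops the match boilerplate and keeps the happy path flat. A tiny sketch with a hypothetical parser (a stand-in for the crate's real `parse_host_port` helper):

// Hypothetical stand-in for the crate's parse_host_port helper.
fn parse_host_port(addr: &str) -> Option<(&str, u16)> {
    let (host, port) = addr.rsplit_once(':')?;
    Some((host, port.parse().ok()?))
}

fn split_address(addr: &str) -> Result<(&str, u16), String> {
    // Diverge early if parsing fails; `host` and `port` are bound otherwise.
    let Some((host, port)) = parse_host_port(addr) else {
        return Err(format!("bad compute address: {addr}"));
    };
    Ok((host, port))
}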

View File

@@ -271,18 +271,18 @@ where
});
// In case logging fails we generate a simpler JSON object.
-if let Err(err) = res {
-if let Ok(mut line) = serde_json::to_vec(&serde_json::json!( {
+if let Err(err) = res
+&& let Ok(mut line) = serde_json::to_vec(&serde_json::json!( {
"timestamp": now.to_rfc3339_opts(chrono::SecondsFormat::Micros, true),
"level": "ERROR",
"message": format_args!("cannot log event: {err:?}"),
"fields": {
"event": format_args!("{event:?}"),
},
-})) {
-line.push(b'\n');
-self.writer.make_writer().write_all(&line).ok();
-}
+}))
+{
+line.push(b'\n');
+self.writer.make_writer().write_all(&line).ok();
}
}
@@ -583,10 +583,11 @@ impl EventFormatter {
THREAD_ID.with(|tid| serializer.serialize_entry("thread_id", tid))?;
// TODO: tls cache? name could change
-if let Some(thread_name) = std::thread::current().name() {
-if !thread_name.is_empty() && thread_name != "tokio-runtime-worker" {
-serializer.serialize_entry("thread_name", thread_name)?;
-}
+if let Some(thread_name) = std::thread::current().name()
+&& !thread_name.is_empty()
+&& thread_name != "tokio-runtime-worker"
+{
+serializer.serialize_entry("thread_name", thread_name)?;
}
if let Some(task_id) = tokio::task::try_id() {
@@ -596,10 +597,10 @@ impl EventFormatter {
serializer.serialize_entry("target", meta.target())?;
// Skip adding module if it's the same as target.
-if let Some(module) = meta.module_path() {
-if module != meta.target() {
-serializer.serialize_entry("module", module)?;
-}
+if let Some(module) = meta.module_path()
+&& module != meta.target()
+{
+serializer.serialize_entry("module", module)?;
}
if let Some(file) = meta.file() {

View File

@@ -236,13 +236,6 @@ pub enum Bool {
False,
}
-#[derive(FixedCardinalityLabel, Copy, Clone)]
-#[label(singleton = "outcome")]
-pub enum Outcome {
-Success,
-Failed,
-}
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "outcome")]
pub enum CacheOutcome {

View File

@@ -90,27 +90,27 @@ where
// TODO: 1 info log, with a enum label for close direction.
// Early termination checks from compute to client.
-if let TransferState::Done(_) = compute_to_client {
-if let TransferState::Running(buf) = &client_to_compute {
-info!("Compute is done, terminate client");
-// Initiate shutdown
-client_to_compute = TransferState::ShuttingDown(buf.amt);
-client_to_compute_result =
-transfer_one_direction(cx, &mut client_to_compute, client, compute)
-.map_err(ErrorSource::from_client)?;
-}
+if let TransferState::Done(_) = compute_to_client
+&& let TransferState::Running(buf) = &client_to_compute
+{
+info!("Compute is done, terminate client");
+// Initiate shutdown
+client_to_compute = TransferState::ShuttingDown(buf.amt);
+client_to_compute_result =
+transfer_one_direction(cx, &mut client_to_compute, client, compute)
+.map_err(ErrorSource::from_client)?;
}
// Early termination checks from client to compute.
-if let TransferState::Done(_) = client_to_compute {
-if let TransferState::Running(buf) = &compute_to_client {
-info!("Client is done, terminate compute");
-// Initiate shutdown
-compute_to_client = TransferState::ShuttingDown(buf.amt);
-compute_to_client_result =
-transfer_one_direction(cx, &mut compute_to_client, compute, client)
-.map_err(ErrorSource::from_compute)?;
-}
+if let TransferState::Done(_) = client_to_compute
+&& let TransferState::Running(buf) = &compute_to_client
+{
+info!("Client is done, terminate compute");
+// Initiate shutdown
+compute_to_client = TransferState::ShuttingDown(buf.amt);
+compute_to_client_result =
+transfer_one_direction(cx, &mut compute_to_client, compute, client)
+.map_err(ErrorSource::from_compute)?;
}
// It is not a problem if ready! returns early ... (comment remains the same)

View File

@@ -39,7 +39,11 @@ impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
let config = config.map_or(self.default_config, Into::into);
-if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
+if self
+.access_count
+.fetch_add(1, Ordering::AcqRel)
+.is_multiple_of(2048)
+{
self.do_gc(now);
}

View File

@@ -211,7 +211,11 @@ impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
// worst case memory usage is about:
// = 2 * 2048 * 64 * (48B + 72B)
// = 30MB
-if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
+if self
+.access_count
+.fetch_add(1, Ordering::AcqRel)
+.is_multiple_of(2048)
+{
self.do_gc();
}
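Both rate-limiter hunks use the same amortized-cleanup trick: every caller bumps a shared atomic counter, and only every 2048th caller pays for the garbage-collection pass, so the cost is spread across requests. (The worst-case figure in the comment checks out: 2 * 2048 * 64 * (48 B + 72 B) is roughly 31 MB.) A stripped-down sketch with made-up names:

use std::sync::atomic::{AtomicUsize, Ordering};

struct Limiter {
    access_count: AtomicUsize,
}

impl Limiter {
    fn on_access(&self) {
        // fetch_add returns the previous value, so accesses 0, 2048, 4096, ...
        // are the ones that trigger a cleanup pass
        if self
            .access_count
            .fetch_add(1, Ordering::AcqRel)
            .is_multiple_of(2048)
        {
            self.do_gc();
        }
    }

    fn do_gc(&self) {
        // evict stale buckets here
    }
}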

View File

@@ -1,79 +0,0 @@
use core::net::IpAddr;
use std::sync::Arc;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::pqproto::CancelKeyData;
pub trait CancellationPublisherMut: Send + Sync + 'static {
#[allow(async_fn_in_trait)]
async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()>;
}
pub trait CancellationPublisher: Send + Sync + 'static {
#[allow(async_fn_in_trait)]
async fn try_publish(
&self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()>;
}
impl CancellationPublisher for () {
async fn try_publish(
&self,
_cancel_key_data: CancelKeyData,
_session_id: Uuid,
_peer_addr: IpAddr,
) -> anyhow::Result<()> {
Ok(())
}
}
impl<P: CancellationPublisher> CancellationPublisherMut for P {
async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()> {
<P as CancellationPublisher>::try_publish(self, cancel_key_data, session_id, peer_addr)
.await
}
}
impl<P: CancellationPublisher> CancellationPublisher for Option<P> {
async fn try_publish(
&self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()> {
if let Some(p) = self {
p.try_publish(cancel_key_data, session_id, peer_addr).await
} else {
Ok(())
}
}
}
impl<P: CancellationPublisherMut> CancellationPublisher for Arc<Mutex<P>> {
async fn try_publish(
&self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()> {
self.lock()
.await
.try_publish(cancel_key_data, session_id, peer_addr)
.await
}
}

View File

@@ -1,11 +1,11 @@
-use std::convert::Infallible;
-use std::sync::{Arc, atomic::AtomicBool, atomic::Ordering};
+use std::sync::Arc;
+use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use futures::FutureExt;
use redis::aio::{ConnectionLike, MultiplexedConnection};
use redis::{ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult};
-use tokio::task::JoinHandle;
+use tokio::task::AbortHandle;
use tracing::{error, info, warn};
use super::elasticache::CredentialsProvider;
@@ -32,7 +32,7 @@ pub struct ConnectionWithCredentialsProvider {
credentials: Credentials,
// TODO: with more load on the connection, we should consider using a connection pool
con: Option<MultiplexedConnection>,
-refresh_token_task: Option<JoinHandle<Infallible>>,
+refresh_token_task: Option<AbortHandle>,
mutex: tokio::sync::Mutex<()>,
credentials_refreshed: Arc<AtomicBool>,
}
@@ -127,7 +127,7 @@ impl ConnectionWithCredentialsProvider {
credentials_provider,
credentials_refreshed,
));
-self.refresh_token_task = Some(f);
+self.refresh_token_task = Some(f.abort_handle());
}
match Self::ping(&mut con).await {
Ok(()) => {
@@ -179,7 +179,7 @@ impl ConnectionWithCredentialsProvider {
mut con: MultiplexedConnection,
credentials_provider: Arc<CredentialsProvider>,
credentials_refreshed: Arc<AtomicBool>,
-) -> Infallible {
+) -> ! {
loop {
// The connection lives for 12h, for the sanity check we refresh it every hour.
tokio::time::sleep(Duration::from_secs(60 * 60)).await;
@@ -244,7 +244,7 @@ impl ConnectionLike for ConnectionWithCredentialsProvider {
&'a mut self,
cmd: &'a redis::Cmd,
) -> redis::RedisFuture<'a, redis::Value> {
-(async move { self.send_packed_command(cmd).await }).boxed()
+self.send_packed_command(cmd).boxed()
}
fn req_packed_commands<'a>(
@@ -253,10 +253,10 @@ impl ConnectionLike for ConnectionWithCredentialsProvider {
offset: usize,
count: usize,
) -> redis::RedisFuture<'a, Vec<redis::Value>> {
-(async move { self.send_packed_commands(cmd, offset, count).await }).boxed()
+self.send_packed_commands(cmd, offset, count).boxed()
}
fn get_db(&self) -> i64 {
-0
+self.con.as_ref().map_or(0, |c| c.get_db())
}
}
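The `JoinHandle<Infallible>` to `AbortHandle` switch above reflects how the refresh task is actually used: it loops forever (hence the `-> !` return type), so the owner never joins it and only needs the ability to abort it. A hedged sketch of that ownership pattern, with invented names:

use std::time::Duration;
use tokio::task::AbortHandle;

struct CredentialsRefresher {
    refresh_task: Option<AbortHandle>,
}

impl CredentialsRefresher {
    fn start(&mut self) {
        let join_handle = tokio::spawn(async {
            loop {
                // e.g. refresh long-lived credentials once an hour
                tokio::time::sleep(Duration::from_secs(60 * 60)).await;
                // ... fetch new credentials here ...
            }
        });
        // Keep only the abort handle; dropping the JoinHandle detaches the
        // task but leaves it running in the background.
        self.refresh_task = Some(join_handle.abort_handle());
    }

    fn stop(&mut self) {
        if let Some(task) = self.refresh_task.take() {
            task.abort();
        }
    }
}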

View File

@@ -1,4 +1,3 @@
-pub mod cancellation_publisher;
pub mod connection_with_credentials_provider;
pub mod elasticache;
pub mod keys;

View File

@@ -54,9 +54,7 @@ impl<T: std::fmt::Display> ChannelBinding<T> {
"eSws".into()
}
Self::Required(mode) => {
-use std::io::Write;
-let mut cbind_input = vec![];
-write!(&mut cbind_input, "p={mode},,",).unwrap();
+let mut cbind_input = format!("p={mode},,",).into_bytes();
cbind_input.extend_from_slice(get_cbind_data(mode)?);
BASE64_STANDARD.encode(&cbind_input).into()
}

View File

@@ -107,7 +107,7 @@ pub(crate) async fn exchange(
secret: &ServerSecret,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
-let salt = BASE64_STANDARD.decode(&secret.salt_base64)?;
+let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
if secret.is_password_invalid(&client_key).into() {

View File

@@ -87,13 +87,20 @@ impl<'a> ClientFirstMessage<'a> {
salt_base64: &str,
iterations: u32,
) -> OwnedServerFirstMessage {
-use std::fmt::Write;
+let mut message = String::with_capacity(128);
+message.push_str("r=");
-let mut message = String::new();
-write!(&mut message, "r={}", self.nonce).unwrap();
+// write combined nonce
+let combined_nonce_start = message.len();
+message.push_str(self.nonce);
BASE64_STANDARD.encode_string(nonce, &mut message);
-let combined_nonce = 2..message.len();
-write!(&mut message, ",s={salt_base64},i={iterations}").unwrap();
+let combined_nonce = combined_nonce_start..message.len();
+// write salt and iterations
+message.push_str(",s=");
+message.push_str(salt_base64);
+message.push_str(",i=");
+message.push_str(itoa::Buffer::new().format(iterations));
// This design guarantees that it's impossible to create a
// server-first-message without receiving a client-first-message

View File

@@ -14,7 +14,7 @@ pub(crate) struct ServerSecret {
/// Number of iterations for `PBKDF2` function.
pub(crate) iterations: u32,
/// Salt used to hash user's password.
-pub(crate) salt_base64: String,
+pub(crate) salt_base64: Box<str>,
/// Hashed `ClientKey`.
pub(crate) stored_key: ScramKey,
/// Used by client to verify server's signature.
@@ -35,7 +35,7 @@ impl ServerSecret {
let secret = ServerSecret {
iterations: iterations.parse().ok()?,
-salt_base64: salt.to_owned(),
+salt_base64: salt.into(),
stored_key: base64_decode_array(stored_key)?.into(),
server_key: base64_decode_array(server_key)?.into(),
doomed: false,
@@ -58,7 +58,7 @@ impl ServerSecret {
// iteration count 1 for our generated passwords going forward.
// PG16 users can set iteration count=1 already today.
iterations: 1,
-salt_base64: BASE64_STANDARD.encode(nonce),
+salt_base64: BASE64_STANDARD.encode(nonce).into_boxed_str(),
stored_key: ScramKey::default(),
server_key: ScramKey::default(),
doomed: true,
@@ -88,7 +88,7 @@ mod tests {
let parsed = ServerSecret::parse(&secret).unwrap();
assert_eq!(parsed.iterations, iterations);
-assert_eq!(parsed.salt_base64, salt);
+assert_eq!(&*parsed.salt_base64, salt);
assert_eq!(BASE64_STANDARD.encode(parsed.stored_key), stored_key);
assert_eq!(BASE64_STANDARD.encode(parsed.server_key), server_key);

View File

@@ -137,7 +137,7 @@ impl Future for JobSpec {
let state = state.as_mut().expect("should be set on thread startup");
state.tick = state.tick.wrapping_add(1);
-if state.tick % SKETCH_RESET_INTERVAL == 0 {
+if state.tick.is_multiple_of(SKETCH_RESET_INTERVAL) {
state.countmin.reset();
}

View File

@@ -349,11 +349,11 @@ impl PoolingBackend {
debug!("setting up backend session state");
// initiates the auth session
-if !disable_pg_session_jwt {
-if let Err(e) = client.batch_execute("select auth.init();").await {
-discard.discard();
-return Err(e.into());
-}
+if !disable_pg_session_jwt
+&& let Err(e) = client.batch_execute("select auth.init();").await
+{
+discard.discard();
+return Err(e.into());
}
info!("backend session state initialized");

View File

@@ -148,11 +148,10 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
}
// remove from connection pool
-if let Some(pool) = pool.clone().upgrade() {
-if pool.write().remove_client(db_user.clone(), conn_id) {
+if let Some(pool) = pool.clone().upgrade()
+&& pool.write().remove_client(db_user.clone(), conn_id) {
info!("closed connection removed");
-}
}
Poll::Ready(())
}).await;

View File

@@ -2,6 +2,8 @@ use std::collections::VecDeque;
use std::sync::atomic::{self, AtomicUsize};
use std::sync::{Arc, Weak};
+use bytes::Bytes;
+use http_body_util::combinators::BoxBody;
use hyper::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
@@ -20,8 +22,6 @@ use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::protocol2::ConnectionInfoExtra;
use crate::types::EndpointCacheKey;
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
-use bytes::Bytes;
-use http_body_util::combinators::BoxBody;
pub(crate) type Send = http2::SendRequest<BoxBody<Bytes, hyper::Error>>;
pub(crate) type Connect =
@@ -240,10 +240,10 @@ pub(crate) fn poll_http2_client(
}
// remove from connection pool
-if let Some(pool) = pool.clone().upgrade() {
-if pool.write().remove_conn(conn_id) {
-info!("closed connection removed");
-}
+if let Some(pool) = pool.clone().upgrade()
+&& pool.write().remove_conn(conn_id)
+{
+info!("closed connection removed");
}
}
.instrument(span),

View File

@@ -12,8 +12,7 @@ use serde::Serialize;
use url::Url;
use uuid::Uuid;
-use super::conn_pool::AuthData;
-use super::conn_pool::ConnInfoWithAuth;
+use super::conn_pool::{AuthData, ConnInfoWithAuth};
use super::conn_pool_lib::ConnInfo;
use super::error::{ConnInfoError, Credentials};
use crate::auth::backend::ComputeUserInfo;

View File

@@ -249,11 +249,10 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
}
// remove from connection pool
-if let Some(pool) = pool.clone().upgrade() {
-if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
+if let Some(pool) = pool.clone().upgrade()
+&& pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
info!("closed connection removed");
-}
}
Poll::Ready(())
}).await;

View File

@@ -1,23 +1,25 @@
+use std::pin::pin;
+use std::sync::Arc;
use bytes::Bytes;
use futures::future::{Either, select, try_join};
use futures::{StreamExt, TryFutureExt};
-use http::{Method, header::AUTHORIZATION};
-use http_body_util::{BodyExt, Full, combinators::BoxBody};
+use http::Method;
+use http::header::AUTHORIZATION;
+use http_body_util::combinators::BoxBody;
+use http_body_util::{BodyExt, Full};
use http_utils::error::ApiError;
use hyper::body::Incoming;
-use hyper::{
-Request, Response, StatusCode, header,
-http::{HeaderName, HeaderValue},
-};
+use hyper::http::{HeaderName, HeaderValue};
+use hyper::{Request, Response, StatusCode, header};
use indexmap::IndexMap;
use postgres_client::error::{DbError, ErrorPosition, SqlState};
use postgres_client::{
GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction,
};
use serde::Serialize;
-use serde_json::{Value, value::RawValue};
-use std::pin::pin;
-use std::sync::Arc;
+use serde_json::Value;
+use serde_json::value::RawValue;
use tokio::time::{self, Instant};
use tokio_util::sync::CancellationToken;
use tracing::{Level, debug, error, info};
@@ -33,7 +35,6 @@ use super::http_util::{
};
use super::json::{JsonConversionError, json_to_pg_text, pg_text_row_to_json};
use crate::auth::backend::ComputeCredentialKeys;
use crate::config::{HttpConfig, ProxyConfig};
use crate::context::RequestContext;
use crate::error::{ErrorKind, ReportableError, UserFacingError};

View File

@@ -199,27 +199,27 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
let probe_msg;
let mut msg = &*msg;
-if let Some(ctx) = ctx {
-if ctx.get_testodrome_id().is_some() {
-let tag = match error_kind {
-ErrorKind::User => "client",
-ErrorKind::ClientDisconnect => "client",
-ErrorKind::RateLimit => "proxy",
-ErrorKind::ServiceRateLimit => "proxy",
-ErrorKind::Quota => "proxy",
-ErrorKind::Service => "proxy",
-ErrorKind::ControlPlane => "controlplane",
-ErrorKind::Postgres => "other",
-ErrorKind::Compute => "compute",
-};
-probe_msg = typed_json::json!({
-"tag": tag,
-"msg": msg,
-"cold_start_info": ctx.cold_start_info(),
-})
-.to_string();
-msg = &probe_msg;
-}
+if let Some(ctx) = ctx
+&& ctx.get_testodrome_id().is_some()
+{
+let tag = match error_kind {
+ErrorKind::User => "client",
+ErrorKind::ClientDisconnect => "client",
+ErrorKind::RateLimit => "proxy",
+ErrorKind::ServiceRateLimit => "proxy",
+ErrorKind::Quota => "proxy",
+ErrorKind::Service => "proxy",
+ErrorKind::ControlPlane => "controlplane",
+ErrorKind::Postgres => "other",
+ErrorKind::Compute => "compute",
+};
+probe_msg = typed_json::json!({
+"tag": tag,
+"msg": msg,
+"cold_start_info": ctx.cold_start_info(),
+})
+.to_string();
+msg = &probe_msg;
}
// TODO: either preserve the error code from postgres, or assign error codes to proxy errors.