mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-12 07:52:55 +00:00
proxy: Fix some warnings by extended clippy checks (#8748)
* Missing blank lifetimes which is now deprecated. * Matching off unqualified enum variants that could act like variable. * Missing semicolons.
This commit is contained in:
@@ -113,38 +113,36 @@ impl<E: Into<AuthErrorImpl>> From<E> for AuthError {
|
||||
|
||||
impl UserFacingError for AuthError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use AuthErrorImpl::*;
|
||||
match self.0.as_ref() {
|
||||
Link(e) => e.to_string_client(),
|
||||
GetAuthInfo(e) => e.to_string_client(),
|
||||
Sasl(e) => e.to_string_client(),
|
||||
AuthFailed(_) => self.to_string(),
|
||||
BadAuthMethod(_) => self.to_string(),
|
||||
MalformedPassword(_) => self.to_string(),
|
||||
MissingEndpointName => self.to_string(),
|
||||
Io(_) => "Internal error".to_string(),
|
||||
IpAddressNotAllowed(_) => self.to_string(),
|
||||
TooManyConnections => self.to_string(),
|
||||
UserTimeout(_) => self.to_string(),
|
||||
AuthErrorImpl::Link(e) => e.to_string_client(),
|
||||
AuthErrorImpl::GetAuthInfo(e) => e.to_string_client(),
|
||||
AuthErrorImpl::Sasl(e) => e.to_string_client(),
|
||||
AuthErrorImpl::AuthFailed(_) => self.to_string(),
|
||||
AuthErrorImpl::BadAuthMethod(_) => self.to_string(),
|
||||
AuthErrorImpl::MalformedPassword(_) => self.to_string(),
|
||||
AuthErrorImpl::MissingEndpointName => self.to_string(),
|
||||
AuthErrorImpl::Io(_) => "Internal error".to_string(),
|
||||
AuthErrorImpl::IpAddressNotAllowed(_) => self.to_string(),
|
||||
AuthErrorImpl::TooManyConnections => self.to_string(),
|
||||
AuthErrorImpl::UserTimeout(_) => self.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for AuthError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
use AuthErrorImpl::*;
|
||||
match self.0.as_ref() {
|
||||
Link(e) => e.get_error_kind(),
|
||||
GetAuthInfo(e) => e.get_error_kind(),
|
||||
Sasl(e) => e.get_error_kind(),
|
||||
AuthFailed(_) => crate::error::ErrorKind::User,
|
||||
BadAuthMethod(_) => crate::error::ErrorKind::User,
|
||||
MalformedPassword(_) => crate::error::ErrorKind::User,
|
||||
MissingEndpointName => crate::error::ErrorKind::User,
|
||||
Io(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
IpAddressNotAllowed(_) => crate::error::ErrorKind::User,
|
||||
TooManyConnections => crate::error::ErrorKind::RateLimit,
|
||||
UserTimeout(_) => crate::error::ErrorKind::User,
|
||||
AuthErrorImpl::Link(e) => e.get_error_kind(),
|
||||
AuthErrorImpl::GetAuthInfo(e) => e.get_error_kind(),
|
||||
AuthErrorImpl::Sasl(e) => e.get_error_kind(),
|
||||
AuthErrorImpl::AuthFailed(_) => crate::error::ErrorKind::User,
|
||||
AuthErrorImpl::BadAuthMethod(_) => crate::error::ErrorKind::User,
|
||||
AuthErrorImpl::MalformedPassword(_) => crate::error::ErrorKind::User,
|
||||
AuthErrorImpl::MissingEndpointName => crate::error::ErrorKind::User,
|
||||
AuthErrorImpl::Io(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
AuthErrorImpl::IpAddressNotAllowed(_) => crate::error::ErrorKind::User,
|
||||
AuthErrorImpl::TooManyConnections => crate::error::ErrorKind::RateLimit,
|
||||
AuthErrorImpl::UserTimeout(_) => crate::error::ErrorKind::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,9 +80,8 @@ pub trait TestBackend: Send + Sync + 'static {
|
||||
|
||||
impl std::fmt::Display for BackendType<'_, (), ()> {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(api, _) => match &**api {
|
||||
Self::Console(api, _) => match &**api {
|
||||
ConsoleBackend::Console(endpoint) => {
|
||||
fmt.debug_tuple("Console").field(&endpoint.url()).finish()
|
||||
}
|
||||
@@ -93,7 +92,7 @@ impl std::fmt::Display for BackendType<'_, (), ()> {
|
||||
#[cfg(test)]
|
||||
ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
|
||||
},
|
||||
Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
|
||||
Self::Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -102,10 +101,9 @@ impl<T, D> BackendType<'_, T, D> {
|
||||
/// Very similar to [`std::option::Option::as_ref`].
|
||||
/// This helps us pass structured config to async tasks.
|
||||
pub fn as_ref(&self) -> BackendType<'_, &T, &D> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(c, x) => Console(MaybeOwned::Borrowed(c), x),
|
||||
Link(c, x) => Link(MaybeOwned::Borrowed(c), x),
|
||||
Self::Console(c, x) => BackendType::Console(MaybeOwned::Borrowed(c), x),
|
||||
Self::Link(c, x) => BackendType::Link(MaybeOwned::Borrowed(c), x),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -115,10 +113,9 @@ impl<'a, T, D> BackendType<'a, T, D> {
|
||||
/// Maps [`BackendType<T>`] to [`BackendType<R>`] by applying
|
||||
/// a function to a contained value.
|
||||
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(c, x) => Console(c, f(x)),
|
||||
Link(c, x) => Link(c, x),
|
||||
Self::Console(c, x) => BackendType::Console(c, f(x)),
|
||||
Self::Link(c, x) => BackendType::Link(c, x),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -126,10 +123,9 @@ impl<'a, T, D, E> BackendType<'a, Result<T, E>, D> {
|
||||
/// Very similar to [`std::option::Option::transpose`].
|
||||
/// This is most useful for error handling.
|
||||
pub fn transpose(self) -> Result<BackendType<'a, T, D>, E> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(c, x) => x.map(|x| Console(c, x)),
|
||||
Link(c, x) => Ok(Link(c, x)),
|
||||
Self::Console(c, x) => x.map(|x| BackendType::Console(c, x)),
|
||||
Self::Link(c, x) => Ok(BackendType::Link(c, x)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -293,7 +289,9 @@ async fn auth_quirks(
|
||||
ctx.set_endpoint_id(res.info.endpoint.clone());
|
||||
let password = match res.keys {
|
||||
ComputeCredentialKeys::Password(p) => p,
|
||||
_ => unreachable!("password hack should return a password"),
|
||||
ComputeCredentialKeys::AuthKeys(_) => {
|
||||
unreachable!("password hack should return a password")
|
||||
}
|
||||
};
|
||||
(res.info, Some(password))
|
||||
}
|
||||
@@ -400,21 +398,17 @@ async fn authenticate_with_secret(
|
||||
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
|
||||
/// Get compute endpoint name from the credentials.
|
||||
pub fn get_endpoint(&self) -> Option<EndpointId> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(_, user_info) => user_info.endpoint_id.clone(),
|
||||
Link(_, _) => Some("link".into()),
|
||||
Self::Console(_, user_info) => user_info.endpoint_id.clone(),
|
||||
Self::Link(_, _) => Some("link".into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get username from the credentials.
|
||||
pub fn get_user(&self) -> &str {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(_, user_info) => &user_info.user,
|
||||
Link(_, _) => "link",
|
||||
Self::Console(_, user_info) => &user_info.user,
|
||||
Self::Link(_, _) => "link",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -428,10 +422,8 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
|
||||
config: &'static AuthenticationConfig,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> auth::Result<BackendType<'a, ComputeCredentials, NodeInfo>> {
|
||||
use BackendType::*;
|
||||
|
||||
let res = match self {
|
||||
Console(api, user_info) => {
|
||||
Self::Console(api, user_info) => {
|
||||
info!(
|
||||
user = &*user_info.user,
|
||||
project = user_info.endpoint(),
|
||||
@@ -451,7 +443,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
|
||||
BackendType::Console(api, credentials)
|
||||
}
|
||||
// NOTE: this auth backend doesn't use client credentials.
|
||||
Link(url, _) => {
|
||||
Self::Link(url, _) => {
|
||||
info!("performing link authentication");
|
||||
|
||||
let info = link::authenticate(ctx, &url, client).await?;
|
||||
@@ -470,10 +462,9 @@ impl BackendType<'_, ComputeUserInfo, &()> {
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
) -> Result<CachedRoleSecret, GetAuthInfoError> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
|
||||
Link(_, _) => Ok(Cached::new_uncached(None)),
|
||||
Self::Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
|
||||
Self::Link(_, _) => Ok(Cached::new_uncached(None)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -481,10 +472,9 @@ impl BackendType<'_, ComputeUserInfo, &()> {
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
|
||||
Self::Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
Self::Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -495,18 +485,16 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
|
||||
Link(_, info) => Ok(Cached::new_uncached(info.clone())),
|
||||
Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
|
||||
Self::Link(_, info) => Ok(Cached::new_uncached(info.clone())),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
|
||||
match self {
|
||||
BackendType::Console(_, creds) => Some(&creds.keys),
|
||||
BackendType::Link(_, _) => None,
|
||||
Self::Console(_, creds) => Some(&creds.keys),
|
||||
Self::Link(_, _) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -517,18 +505,16 @@ impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
|
||||
Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"),
|
||||
Self::Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
|
||||
Self::Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
|
||||
match self {
|
||||
BackendType::Console(_, creds) => Some(&creds.keys),
|
||||
BackendType::Link(_, _) => None,
|
||||
Self::Console(_, creds) => Some(&creds.keys),
|
||||
Self::Link(_, _) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,7 +195,7 @@ impl JwkCacheEntryLock {
|
||||
|
||||
let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD)
|
||||
.context("Provided authentication token is not a valid JWT encoding")?;
|
||||
let header = serde_json::from_slice::<JWTHeader>(&header)
|
||||
let header = serde_json::from_slice::<JWTHeader<'_>>(&header)
|
||||
.context("Provided authentication token is not a valid JWT encoding")?;
|
||||
|
||||
let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)
|
||||
@@ -340,7 +340,7 @@ impl JwkRenewalPermit<'_> {
|
||||
}
|
||||
}
|
||||
|
||||
async fn acquire_permit(from: &Arc<JwkCacheEntryLock>) -> JwkRenewalPermit {
|
||||
async fn acquire_permit(from: &Arc<JwkCacheEntryLock>) -> JwkRenewalPermit<'_> {
|
||||
match from.lookup.acquire().await {
|
||||
Ok(permit) => {
|
||||
permit.forget();
|
||||
@@ -352,7 +352,7 @@ impl JwkRenewalPermit<'_> {
|
||||
}
|
||||
}
|
||||
|
||||
fn try_acquire_permit(from: &Arc<JwkCacheEntryLock>) -> Option<JwkRenewalPermit> {
|
||||
fn try_acquire_permit(from: &Arc<JwkCacheEntryLock>) -> Option<JwkRenewalPermit<'_>> {
|
||||
match from.lookup.try_acquire() {
|
||||
Ok(permit) => {
|
||||
permit.forget();
|
||||
|
||||
@@ -89,10 +89,12 @@ impl ComputeUserInfoMaybeEndpoint {
|
||||
sni: Option<&str>,
|
||||
common_names: Option<&HashSet<String>>,
|
||||
) -> Result<Self, ComputeUserInfoParseError> {
|
||||
use ComputeUserInfoParseError::*;
|
||||
|
||||
// Some parameters are stored in the startup message.
|
||||
let get_param = |key| params.get(key).ok_or(MissingKey(key));
|
||||
let get_param = |key| {
|
||||
params
|
||||
.get(key)
|
||||
.ok_or(ComputeUserInfoParseError::MissingKey(key))
|
||||
};
|
||||
let user: RoleName = get_param("user")?.into();
|
||||
|
||||
// Project name might be passed via PG's command-line options.
|
||||
@@ -122,11 +124,14 @@ impl ComputeUserInfoMaybeEndpoint {
|
||||
let endpoint = match (endpoint_option, endpoint_from_domain) {
|
||||
// Invariant: if we have both project name variants, they should match.
|
||||
(Some(option), Some(domain)) if option != domain => {
|
||||
Some(Err(InconsistentProjectNames { domain, option }))
|
||||
Some(Err(ComputeUserInfoParseError::InconsistentProjectNames {
|
||||
domain,
|
||||
option,
|
||||
}))
|
||||
}
|
||||
// Invariant: project name may not contain certain characters.
|
||||
(a, b) => a.or(b).map(|name| match project_name_valid(name.as_ref()) {
|
||||
false => Err(MalformedProjectName(name)),
|
||||
false => Err(ComputeUserInfoParseError::MalformedProjectName(name)),
|
||||
true => Ok(name),
|
||||
}),
|
||||
}
|
||||
@@ -186,7 +191,7 @@ impl<'de> serde::de::Deserialize<'de> for IpPattern {
|
||||
impl<'de> serde::de::Visitor<'de> for StrVisitor {
|
||||
type Value = IpPattern;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(formatter, "comma separated list with ip address, ip address range, or ip address subnet mask")
|
||||
}
|
||||
|
||||
|
||||
2
proxy/src/cache/common.rs
vendored
2
proxy/src/cache/common.rs
vendored
@@ -24,7 +24,7 @@ impl<C: Cache> Cache for &C {
|
||||
type LookupInfo<Key> = C::LookupInfo<Key>;
|
||||
|
||||
fn invalidate(&self, info: &Self::LookupInfo<Self::Key>) {
|
||||
C::invalidate(self, info)
|
||||
C::invalidate(self, info);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
2
proxy/src/cache/timed_lru.rs
vendored
2
proxy/src/cache/timed_lru.rs
vendored
@@ -58,7 +58,7 @@ impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
|
||||
type LookupInfo<Key> = LookupInfo<Key>;
|
||||
|
||||
fn invalidate(&self, info: &Self::LookupInfo<K>) {
|
||||
self.invalidate_raw(info)
|
||||
self.invalidate_raw(info);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -44,11 +44,10 @@ pub enum ConnectionError {
|
||||
|
||||
impl UserFacingError for ConnectionError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use ConnectionError::*;
|
||||
match self {
|
||||
// This helps us drop irrelevant library-specific prefixes.
|
||||
// TODO: propagate severity level and other parameters.
|
||||
Postgres(err) => match err.as_db_error() {
|
||||
ConnectionError::Postgres(err) => match err.as_db_error() {
|
||||
Some(err) => {
|
||||
let msg = err.message();
|
||||
|
||||
@@ -62,8 +61,8 @@ impl UserFacingError for ConnectionError {
|
||||
}
|
||||
None => err.to_string(),
|
||||
},
|
||||
WakeComputeError(err) => err.to_string_client(),
|
||||
TooManyConnectionAttempts(_) => {
|
||||
ConnectionError::WakeComputeError(err) => err.to_string_client(),
|
||||
ConnectionError::TooManyConnectionAttempts(_) => {
|
||||
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
|
||||
}
|
||||
_ => COULD_NOT_CONNECT.to_owned(),
|
||||
@@ -366,16 +365,16 @@ static TLS_ROOTS: OnceCell<Arc<rustls::RootCertStore>> = OnceCell::new();
|
||||
struct AcceptEverythingVerifier;
|
||||
impl ServerCertVerifier for AcceptEverythingVerifier {
|
||||
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
|
||||
use rustls::SignatureScheme::*;
|
||||
use rustls::SignatureScheme;
|
||||
// The schemes for which `SignatureScheme::supported_in_tls13` returns true.
|
||||
vec![
|
||||
ECDSA_NISTP521_SHA512,
|
||||
ECDSA_NISTP384_SHA384,
|
||||
ECDSA_NISTP256_SHA256,
|
||||
RSA_PSS_SHA512,
|
||||
RSA_PSS_SHA384,
|
||||
RSA_PSS_SHA256,
|
||||
ED25519,
|
||||
SignatureScheme::ECDSA_NISTP521_SHA512,
|
||||
SignatureScheme::ECDSA_NISTP384_SHA384,
|
||||
SignatureScheme::ECDSA_NISTP256_SHA256,
|
||||
SignatureScheme::RSA_PSS_SHA512,
|
||||
SignatureScheme::RSA_PSS_SHA384,
|
||||
SignatureScheme::RSA_PSS_SHA256,
|
||||
SignatureScheme::ED25519,
|
||||
]
|
||||
}
|
||||
fn verify_server_cert(
|
||||
|
||||
@@ -155,7 +155,7 @@ pub enum TlsServerEndPoint {
|
||||
}
|
||||
|
||||
impl TlsServerEndPoint {
|
||||
pub fn new(cert: &CertificateDer) -> anyhow::Result<Self> {
|
||||
pub fn new(cert: &CertificateDer<'_>) -> anyhow::Result<Self> {
|
||||
let sha256_oids = [
|
||||
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
|
||||
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
|
||||
@@ -278,7 +278,7 @@ impl CertResolver {
|
||||
impl rustls::server::ResolvesServerCert for CertResolver {
|
||||
fn resolve(
|
||||
&self,
|
||||
client_hello: rustls::server::ClientHello,
|
||||
client_hello: rustls::server::ClientHello<'_>,
|
||||
) -> Option<Arc<rustls::sign::CertifiedKey>> {
|
||||
self.resolve(client_hello.server_name()).map(|x| x.0)
|
||||
}
|
||||
@@ -559,7 +559,7 @@ impl RetryConfig {
|
||||
match key {
|
||||
"num_retries" => num_retries = Some(value.parse()?),
|
||||
"base_retry_wait_duration" => {
|
||||
base_retry_wait_duration = Some(humantime::parse_duration(value)?)
|
||||
base_retry_wait_duration = Some(humantime::parse_duration(value)?);
|
||||
}
|
||||
"retry_wait_exponent_base" => retry_wait_exponent_base = Some(value.parse()?),
|
||||
unknown => bail!("unknown key: {unknown}"),
|
||||
|
||||
@@ -22,16 +22,15 @@ impl ConsoleError {
|
||||
self.status
|
||||
.as_ref()
|
||||
.and_then(|s| s.details.error_info.as_ref())
|
||||
.map(|e| e.reason)
|
||||
.unwrap_or(Reason::Unknown)
|
||||
.map_or(Reason::Unknown, |e| e.reason)
|
||||
}
|
||||
|
||||
pub fn get_user_facing_message(&self) -> String {
|
||||
use super::provider::errors::REQUEST_FAILED;
|
||||
self.status
|
||||
.as_ref()
|
||||
.and_then(|s| s.details.user_facing_message.as_ref())
|
||||
.map(|m| m.message.clone().into())
|
||||
.unwrap_or_else(|| {
|
||||
.map_or_else(|| {
|
||||
// Ask @neondatabase/control-plane for review before adding more.
|
||||
match self.http_status_code {
|
||||
http::StatusCode::NOT_FOUND => {
|
||||
@@ -48,19 +47,18 @@ impl ConsoleError {
|
||||
}
|
||||
_ => REQUEST_FAILED.to_owned(),
|
||||
}
|
||||
})
|
||||
}, |m| m.message.clone().into())
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for ConsoleError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
let msg = self
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let msg: &str = self
|
||||
.status
|
||||
.as_ref()
|
||||
.and_then(|s| s.details.user_facing_message.as_ref())
|
||||
.map(|m| m.message.as_ref())
|
||||
.unwrap_or_else(|| &self.error);
|
||||
write!(f, "{}", msg)
|
||||
.map_or_else(|| self.error.as_ref(), |m| m.message.as_ref());
|
||||
write!(f, "{msg}")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -286,7 +284,7 @@ pub struct DatabaseInfo {
|
||||
|
||||
// Manually implement debug to omit sensitive info.
|
||||
impl fmt::Debug for DatabaseInfo {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("DatabaseInfo")
|
||||
.field("host", &self.host)
|
||||
.field("port", &self.port)
|
||||
@@ -373,7 +371,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
});
|
||||
let _: KickSession = serde_json::from_str(&json.to_string())?;
|
||||
let _: KickSession<'_> = serde_json::from_str(&json.to_string())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -93,7 +93,8 @@ impl postgres_backend::Handler<tokio::net::TcpStream> for MgmtHandler {
|
||||
}
|
||||
|
||||
fn try_process_query(pgb: &mut PostgresBackendTCP, query: &str) -> Result<(), QueryError> {
|
||||
let resp: KickSession = serde_json::from_str(query).context("Failed to parse query as json")?;
|
||||
let resp: KickSession<'_> =
|
||||
serde_json::from_str(query).context("Failed to parse query as json")?;
|
||||
|
||||
let span = info_span!("event", session_id = resp.session_id);
|
||||
let _enter = span.enter();
|
||||
|
||||
@@ -26,7 +26,7 @@ use tracing::info;
|
||||
pub mod errors {
|
||||
use crate::{
|
||||
console::messages::{self, ConsoleError, Reason},
|
||||
error::{io_error, ReportableError, UserFacingError},
|
||||
error::{io_error, ErrorKind, ReportableError, UserFacingError},
|
||||
proxy::retry::CouldRetry,
|
||||
};
|
||||
use thiserror::Error;
|
||||
@@ -51,21 +51,19 @@ pub mod errors {
|
||||
impl ApiError {
|
||||
/// Returns HTTP status code if it's the reason for failure.
|
||||
pub fn get_reason(&self) -> messages::Reason {
|
||||
use ApiError::*;
|
||||
match self {
|
||||
Console(e) => e.get_reason(),
|
||||
_ => messages::Reason::Unknown,
|
||||
ApiError::Console(e) => e.get_reason(),
|
||||
ApiError::Transport(_) => messages::Reason::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for ApiError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use ApiError::*;
|
||||
match self {
|
||||
// To minimize risks, only select errors are forwarded to users.
|
||||
Console(c) => c.get_user_facing_message(),
|
||||
_ => REQUEST_FAILED.to_owned(),
|
||||
ApiError::Console(c) => c.get_user_facing_message(),
|
||||
ApiError::Transport(_) => REQUEST_FAILED.to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -73,57 +71,53 @@ pub mod errors {
|
||||
impl ReportableError for ApiError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ApiError::Console(e) => {
|
||||
use crate::error::ErrorKind::*;
|
||||
match e.get_reason() {
|
||||
Reason::RoleProtected => User,
|
||||
Reason::ResourceNotFound => User,
|
||||
Reason::ProjectNotFound => User,
|
||||
Reason::EndpointNotFound => User,
|
||||
Reason::BranchNotFound => User,
|
||||
Reason::RateLimitExceeded => ServiceRateLimit,
|
||||
Reason::NonDefaultBranchComputeTimeExceeded => User,
|
||||
Reason::ActiveTimeQuotaExceeded => User,
|
||||
Reason::ComputeTimeQuotaExceeded => User,
|
||||
Reason::WrittenDataQuotaExceeded => User,
|
||||
Reason::DataTransferQuotaExceeded => User,
|
||||
Reason::LogicalSizeQuotaExceeded => User,
|
||||
Reason::ConcurrencyLimitReached => ControlPlane,
|
||||
Reason::LockAlreadyTaken => ControlPlane,
|
||||
Reason::RunningOperations => ControlPlane,
|
||||
Reason::Unknown => match &e {
|
||||
ConsoleError {
|
||||
http_status_code:
|
||||
http::StatusCode::NOT_FOUND | http::StatusCode::NOT_ACCEPTABLE,
|
||||
..
|
||||
} => crate::error::ErrorKind::User,
|
||||
ConsoleError {
|
||||
http_status_code: http::StatusCode::UNPROCESSABLE_ENTITY,
|
||||
error,
|
||||
..
|
||||
} if error.contains(
|
||||
"compute time quota of non-primary branches is exceeded",
|
||||
) =>
|
||||
{
|
||||
crate::error::ErrorKind::User
|
||||
}
|
||||
ConsoleError {
|
||||
http_status_code: http::StatusCode::LOCKED,
|
||||
error,
|
||||
..
|
||||
} if error.contains("quota exceeded")
|
||||
|| error.contains("the limit for current plan reached") =>
|
||||
{
|
||||
crate::error::ErrorKind::User
|
||||
}
|
||||
ConsoleError {
|
||||
http_status_code: http::StatusCode::TOO_MANY_REQUESTS,
|
||||
..
|
||||
} => crate::error::ErrorKind::ServiceRateLimit,
|
||||
ConsoleError { .. } => crate::error::ErrorKind::ControlPlane,
|
||||
},
|
||||
}
|
||||
}
|
||||
ApiError::Console(e) => match e.get_reason() {
|
||||
Reason::RoleProtected => ErrorKind::User,
|
||||
Reason::ResourceNotFound => ErrorKind::User,
|
||||
Reason::ProjectNotFound => ErrorKind::User,
|
||||
Reason::EndpointNotFound => ErrorKind::User,
|
||||
Reason::BranchNotFound => ErrorKind::User,
|
||||
Reason::RateLimitExceeded => ErrorKind::ServiceRateLimit,
|
||||
Reason::NonDefaultBranchComputeTimeExceeded => ErrorKind::User,
|
||||
Reason::ActiveTimeQuotaExceeded => ErrorKind::User,
|
||||
Reason::ComputeTimeQuotaExceeded => ErrorKind::User,
|
||||
Reason::WrittenDataQuotaExceeded => ErrorKind::User,
|
||||
Reason::DataTransferQuotaExceeded => ErrorKind::User,
|
||||
Reason::LogicalSizeQuotaExceeded => ErrorKind::User,
|
||||
Reason::ConcurrencyLimitReached => ErrorKind::ControlPlane,
|
||||
Reason::LockAlreadyTaken => ErrorKind::ControlPlane,
|
||||
Reason::RunningOperations => ErrorKind::ControlPlane,
|
||||
Reason::Unknown => match &e {
|
||||
ConsoleError {
|
||||
http_status_code:
|
||||
http::StatusCode::NOT_FOUND | http::StatusCode::NOT_ACCEPTABLE,
|
||||
..
|
||||
} => crate::error::ErrorKind::User,
|
||||
ConsoleError {
|
||||
http_status_code: http::StatusCode::UNPROCESSABLE_ENTITY,
|
||||
error,
|
||||
..
|
||||
} if error
|
||||
.contains("compute time quota of non-primary branches is exceeded") =>
|
||||
{
|
||||
crate::error::ErrorKind::User
|
||||
}
|
||||
ConsoleError {
|
||||
http_status_code: http::StatusCode::LOCKED,
|
||||
error,
|
||||
..
|
||||
} if error.contains("quota exceeded")
|
||||
|| error.contains("the limit for current plan reached") =>
|
||||
{
|
||||
crate::error::ErrorKind::User
|
||||
}
|
||||
ConsoleError {
|
||||
http_status_code: http::StatusCode::TOO_MANY_REQUESTS,
|
||||
..
|
||||
} => crate::error::ErrorKind::ServiceRateLimit,
|
||||
ConsoleError { .. } => crate::error::ErrorKind::ControlPlane,
|
||||
},
|
||||
},
|
||||
ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
|
||||
}
|
||||
}
|
||||
@@ -170,12 +164,11 @@ pub mod errors {
|
||||
|
||||
impl UserFacingError for GetAuthInfoError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use GetAuthInfoError::*;
|
||||
match self {
|
||||
// We absolutely should not leak any secrets!
|
||||
BadSecret => REQUEST_FAILED.to_owned(),
|
||||
Self::BadSecret => REQUEST_FAILED.to_owned(),
|
||||
// However, API might return a meaningful error.
|
||||
ApiError(e) => e.to_string_client(),
|
||||
Self::ApiError(e) => e.to_string_client(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -183,8 +176,8 @@ pub mod errors {
|
||||
impl ReportableError for GetAuthInfoError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
GetAuthInfoError::BadSecret => crate::error::ErrorKind::ControlPlane,
|
||||
GetAuthInfoError::ApiError(_) => crate::error::ErrorKind::ControlPlane,
|
||||
Self::BadSecret => crate::error::ErrorKind::ControlPlane,
|
||||
Self::ApiError(_) => crate::error::ErrorKind::ControlPlane,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -213,17 +206,16 @@ pub mod errors {
|
||||
|
||||
impl UserFacingError for WakeComputeError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use WakeComputeError::*;
|
||||
match self {
|
||||
// We shouldn't show user the address even if it's broken.
|
||||
// Besides, user is unlikely to care about this detail.
|
||||
BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
|
||||
Self::BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
|
||||
// However, API might return a meaningful error.
|
||||
ApiError(e) => e.to_string_client(),
|
||||
Self::ApiError(e) => e.to_string_client(),
|
||||
|
||||
TooManyConnections => self.to_string(),
|
||||
Self::TooManyConnections => self.to_string(),
|
||||
|
||||
TooManyConnectionAttempts(_) => {
|
||||
Self::TooManyConnectionAttempts(_) => {
|
||||
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
|
||||
}
|
||||
}
|
||||
@@ -233,10 +225,10 @@ pub mod errors {
|
||||
impl ReportableError for WakeComputeError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
WakeComputeError::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
|
||||
WakeComputeError::ApiError(e) => e.get_error_kind(),
|
||||
WakeComputeError::TooManyConnections => crate::error::ErrorKind::RateLimit,
|
||||
WakeComputeError::TooManyConnectionAttempts(e) => e.get_error_kind(),
|
||||
Self::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
|
||||
Self::ApiError(e) => e.get_error_kind(),
|
||||
Self::TooManyConnections => crate::error::ErrorKind::RateLimit,
|
||||
Self::TooManyConnectionAttempts(e) => e.get_error_kind(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -244,10 +236,10 @@ pub mod errors {
|
||||
impl CouldRetry for WakeComputeError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
WakeComputeError::BadComputeAddress(_) => false,
|
||||
WakeComputeError::ApiError(e) => e.could_retry(),
|
||||
WakeComputeError::TooManyConnections => false,
|
||||
WakeComputeError::TooManyConnectionAttempts(_) => false,
|
||||
Self::BadComputeAddress(_) => false,
|
||||
Self::ApiError(e) => e.could_retry(),
|
||||
Self::TooManyConnections => false,
|
||||
Self::TooManyConnectionAttempts(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -366,13 +358,14 @@ impl Api for ConsoleBackend {
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
|
||||
use ConsoleBackend::*;
|
||||
match self {
|
||||
Console(api) => api.get_role_secret(ctx, user_info).await,
|
||||
Self::Console(api) => api.get_role_secret(ctx, user_info).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Postgres(api) => api.get_role_secret(ctx, user_info).await,
|
||||
Self::Postgres(api) => api.get_role_secret(ctx, user_info).await,
|
||||
#[cfg(test)]
|
||||
Test(_) => unreachable!("this function should never be called in the test backend"),
|
||||
Self::Test(_) => {
|
||||
unreachable!("this function should never be called in the test backend")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -381,13 +374,12 @@ impl Api for ConsoleBackend {
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
|
||||
use ConsoleBackend::*;
|
||||
match self {
|
||||
Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
Self::Console(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
Self::Postgres(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
#[cfg(test)]
|
||||
Test(api) => api.get_allowed_ips_and_secret(),
|
||||
Self::Test(api) => api.get_allowed_ips_and_secret(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -396,14 +388,12 @@ impl Api for ConsoleBackend {
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, errors::WakeComputeError> {
|
||||
use ConsoleBackend::*;
|
||||
|
||||
match self {
|
||||
Console(api) => api.wake_compute(ctx, user_info).await,
|
||||
Self::Console(api) => api.wake_compute(ctx, user_info).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Postgres(api) => api.wake_compute(ctx, user_info).await,
|
||||
Self::Postgres(api) => api.wake_compute(ctx, user_info).await,
|
||||
#[cfg(test)]
|
||||
Test(api) => api.wake_compute(),
|
||||
Self::Test(api) => api.wake_compute(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -549,7 +539,7 @@ impl WakeComputePermit {
|
||||
!self.permit.is_disabled()
|
||||
}
|
||||
pub fn release(self, outcome: Outcome) {
|
||||
self.permit.release(outcome)
|
||||
self.permit.release(outcome);
|
||||
}
|
||||
pub fn release_result<T, E>(self, res: Result<T, E>) -> Result<T, E> {
|
||||
match res {
|
||||
|
||||
@@ -166,7 +166,7 @@ impl RequestMonitoring {
|
||||
pub fn set_project(&self, x: MetricsAuxInfo) {
|
||||
let mut this = self.0.try_lock().expect("should not deadlock");
|
||||
if this.endpoint_id.is_none() {
|
||||
this.set_endpoint_id(x.endpoint_id.as_str().into())
|
||||
this.set_endpoint_id(x.endpoint_id.as_str().into());
|
||||
}
|
||||
this.branch = Some(x.branch_id);
|
||||
this.project = Some(x.project_id);
|
||||
@@ -260,7 +260,7 @@ impl RequestMonitoring {
|
||||
.cold_start_info
|
||||
}
|
||||
|
||||
pub fn latency_timer_pause(&self, waiting_for: Waiting) -> LatencyTimerPause {
|
||||
pub fn latency_timer_pause(&self, waiting_for: Waiting) -> LatencyTimerPause<'_> {
|
||||
LatencyTimerPause {
|
||||
ctx: self,
|
||||
start: tokio::time::Instant::now(),
|
||||
@@ -273,7 +273,7 @@ impl RequestMonitoring {
|
||||
.try_lock()
|
||||
.expect("should not deadlock")
|
||||
.latency_timer
|
||||
.success()
|
||||
.success();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -328,7 +328,7 @@ impl RequestMonitoringInner {
|
||||
fn has_private_peer_addr(&self) -> bool {
|
||||
match self.peer_addr {
|
||||
IpAddr::V4(ip) => ip.is_private(),
|
||||
_ => false,
|
||||
IpAddr::V6(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -736,7 +736,7 @@ mod tests {
|
||||
while let Some(r) = s.next().await {
|
||||
tx.send(r).unwrap();
|
||||
}
|
||||
time::sleep(time::Duration::from_secs(70)).await
|
||||
time::sleep(time::Duration::from_secs(70)).await;
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ impl<'de, Id: InternId> serde::de::Deserialize<'de> for InternedString<Id> {
|
||||
impl<'de, Id: InternId> serde::de::Visitor<'de> for Visitor<Id> {
|
||||
type Value = InternedString<Id>;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
formatter.write_str("a string")
|
||||
}
|
||||
|
||||
|
||||
@@ -252,7 +252,7 @@ impl Drop for HttpEndpointPoolsGuard<'_> {
|
||||
}
|
||||
|
||||
impl HttpEndpointPools {
|
||||
pub fn guard(&self) -> HttpEndpointPoolsGuard {
|
||||
pub fn guard(&self) -> HttpEndpointPoolsGuard<'_> {
|
||||
self.http_pool_endpoints_registered_total.inc();
|
||||
HttpEndpointPoolsGuard {
|
||||
dec: &self.http_pool_endpoints_unregistered_total,
|
||||
|
||||
@@ -184,7 +184,7 @@ impl CopyBuffer {
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
res => res.map_err(ErrorDirection::Write),
|
||||
res @ Poll::Ready(_) => res.map_err(ErrorDirection::Write),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -82,9 +82,8 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
let mut stream = PqStream::new(Stream::from_raw(stream));
|
||||
loop {
|
||||
let msg = stream.read_startup_packet().await?;
|
||||
use FeStartupPacket::*;
|
||||
match msg {
|
||||
SslRequest { direct } => match stream.get_ref() {
|
||||
FeStartupPacket::SslRequest { direct } => match stream.get_ref() {
|
||||
Stream::Raw { .. } if !tried_ssl => {
|
||||
tried_ssl = true;
|
||||
|
||||
@@ -139,7 +138,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
|
||||
let tls_stream = accept.await.inspect_err(|_| {
|
||||
if record_handshake_error {
|
||||
Metrics::get().proxy.tls_handshake_failures.inc()
|
||||
Metrics::get().proxy.tls_handshake_failures.inc();
|
||||
}
|
||||
})?;
|
||||
|
||||
@@ -182,7 +181,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
_ => return Err(HandshakeError::ProtocolViolation),
|
||||
},
|
||||
GssEncRequest => match stream.get_ref() {
|
||||
FeStartupPacket::GssEncRequest => match stream.get_ref() {
|
||||
Stream::Raw { .. } if !tried_gss => {
|
||||
tried_gss = true;
|
||||
|
||||
@@ -191,7 +190,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
_ => return Err(HandshakeError::ProtocolViolation),
|
||||
},
|
||||
StartupMessage { params, version }
|
||||
FeStartupPacket::StartupMessage { params, version }
|
||||
if PG_PROTOCOL_EARLIEST <= version && version <= PG_PROTOCOL_LATEST =>
|
||||
{
|
||||
// Check that the config has been consumed during upgrade
|
||||
@@ -211,7 +210,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
break Ok(HandshakeData::Startup(stream, params));
|
||||
}
|
||||
// downgrade protocol version
|
||||
StartupMessage { params, version }
|
||||
FeStartupPacket::StartupMessage { params, version }
|
||||
if version.major() == 3 && version > PG_PROTOCOL_LATEST =>
|
||||
{
|
||||
warn!(?version, "unsupported minor version");
|
||||
@@ -241,7 +240,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
);
|
||||
break Ok(HandshakeData::Startup(stream, params));
|
||||
}
|
||||
StartupMessage { version, .. } => {
|
||||
FeStartupPacket::StartupMessage { version, .. } => {
|
||||
warn!(
|
||||
?version,
|
||||
session_type = "normal",
|
||||
@@ -249,7 +248,7 @@ pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
);
|
||||
return Err(HandshakeError::ProtocolViolation);
|
||||
}
|
||||
CancelRequest(cancel_key_data) => {
|
||||
FeStartupPacket::CancelRequest(cancel_key_data) => {
|
||||
info!(session_type = "cancellation", "successful handshake");
|
||||
break Ok(HandshakeData::Cancel(cancel_key_data));
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ async fn proxy_mitm(
|
||||
end_client.send(Bytes::from_static(b"R\0\0\0\x17\0\0\0\x0aSCRAM-SHA-256\0\0")).await.unwrap();
|
||||
continue;
|
||||
}
|
||||
end_client.send(message).await.unwrap()
|
||||
end_client.send(message).await.unwrap();
|
||||
}
|
||||
_ => break,
|
||||
}
|
||||
@@ -88,7 +88,7 @@ async fn proxy_mitm(
|
||||
end_server.send(buf.freeze()).await.unwrap();
|
||||
continue;
|
||||
}
|
||||
end_server.send(message).await.unwrap()
|
||||
end_server.send(message).await.unwrap();
|
||||
}
|
||||
_ => break,
|
||||
}
|
||||
|
||||
@@ -237,7 +237,7 @@ impl Token {
|
||||
}
|
||||
|
||||
pub fn release(mut self, outcome: Outcome) {
|
||||
self.release_mut(Some(outcome))
|
||||
self.release_mut(Some(outcome));
|
||||
}
|
||||
|
||||
pub fn release_mut(&mut self, outcome: Option<Outcome>) {
|
||||
@@ -249,7 +249,7 @@ impl Token {
|
||||
|
||||
impl Drop for Token {
|
||||
fn drop(&mut self) {
|
||||
self.release_mut(None)
|
||||
self.release_mut(None);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -25,9 +25,8 @@ pub struct Aimd {
|
||||
|
||||
impl LimitAlgorithm for Aimd {
|
||||
fn update(&self, old_limit: usize, sample: Sample) -> usize {
|
||||
use Outcome::*;
|
||||
match sample.outcome {
|
||||
Success => {
|
||||
Outcome::Success => {
|
||||
let utilisation = sample.in_flight as f32 / old_limit as f32;
|
||||
|
||||
if utilisation > self.utilisation {
|
||||
@@ -42,7 +41,7 @@ impl LimitAlgorithm for Aimd {
|
||||
old_limit
|
||||
}
|
||||
}
|
||||
Overload => {
|
||||
Outcome::Overload => {
|
||||
let limit = old_limit as f32 * self.dec;
|
||||
|
||||
// Floor instead of round, so the limit reduces even with small numbers.
|
||||
|
||||
@@ -98,7 +98,7 @@ impl ConnectionWithCredentialsProvider {
|
||||
info!("Establishing a new connection...");
|
||||
self.con = None;
|
||||
if let Some(f) = self.refresh_token_task.take() {
|
||||
f.abort()
|
||||
f.abort();
|
||||
}
|
||||
let mut con = self
|
||||
.get_client()
|
||||
|
||||
@@ -108,7 +108,6 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
|
||||
}
|
||||
#[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
|
||||
async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
|
||||
use Notification::*;
|
||||
let payload: String = msg.get_payload()?;
|
||||
tracing::debug!(?payload, "received a message payload");
|
||||
|
||||
@@ -124,7 +123,7 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
|
||||
};
|
||||
tracing::debug!(?msg, "received a message");
|
||||
match msg {
|
||||
Cancel(cancel_session) => {
|
||||
Notification::Cancel(cancel_session) => {
|
||||
tracing::Span::current().record(
|
||||
"session_id",
|
||||
tracing::field::display(cancel_session.session_id),
|
||||
@@ -153,12 +152,12 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
|
||||
}
|
||||
_ => {
|
||||
invalidate_cache(self.cache.clone(), msg.clone());
|
||||
if matches!(msg, AllowedIpsUpdate { .. }) {
|
||||
if matches!(msg, Notification::AllowedIpsUpdate { .. }) {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
.inc(RedisEventsCount::AllowedIpsUpdate);
|
||||
} else if matches!(msg, PasswordUpdate { .. }) {
|
||||
} else if matches!(msg, Notification::PasswordUpdate { .. }) {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.redis_events_count
|
||||
@@ -180,16 +179,16 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
|
||||
}
|
||||
|
||||
fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
|
||||
use Notification::*;
|
||||
match msg {
|
||||
AllowedIpsUpdate { allowed_ips_update } => {
|
||||
cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id)
|
||||
Notification::AllowedIpsUpdate { allowed_ips_update } => {
|
||||
cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id);
|
||||
}
|
||||
PasswordUpdate { password_update } => cache.invalidate_role_secret_for_project(
|
||||
password_update.project_id,
|
||||
password_update.role_name,
|
||||
),
|
||||
Cancel(_) => unreachable!("cancel message should be handled separately"),
|
||||
Notification::PasswordUpdate { password_update } => cache
|
||||
.invalidate_role_secret_for_project(
|
||||
password_update.project_id,
|
||||
password_update.role_name,
|
||||
),
|
||||
Notification::Cancel(_) => unreachable!("cancel message should be handled separately"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -42,10 +42,9 @@ pub enum Error {
|
||||
|
||||
impl UserFacingError for Error {
|
||||
fn to_string_client(&self) -> String {
|
||||
use Error::*;
|
||||
match self {
|
||||
ChannelBindingFailed(m) => m.to_string(),
|
||||
ChannelBindingBadMethod(m) => format!("unsupported channel binding method {m}"),
|
||||
Self::ChannelBindingFailed(m) => (*m).to_string(),
|
||||
Self::ChannelBindingBadMethod(m) => format!("unsupported channel binding method {m}"),
|
||||
_ => "authentication protocol violation".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,11 +13,10 @@ pub enum ChannelBinding<T> {
|
||||
|
||||
impl<T> ChannelBinding<T> {
|
||||
pub fn and_then<R, E>(self, f: impl FnOnce(T) -> Result<R, E>) -> Result<ChannelBinding<R>, E> {
|
||||
use ChannelBinding::*;
|
||||
Ok(match self {
|
||||
NotSupportedClient => NotSupportedClient,
|
||||
NotSupportedServer => NotSupportedServer,
|
||||
Required(x) => Required(f(x)?),
|
||||
Self::NotSupportedClient => ChannelBinding::NotSupportedClient,
|
||||
Self::NotSupportedServer => ChannelBinding::NotSupportedServer,
|
||||
Self::Required(x) => ChannelBinding::Required(f(x)?),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -25,11 +24,10 @@ impl<T> ChannelBinding<T> {
|
||||
impl<'a> ChannelBinding<&'a str> {
|
||||
// NB: FromStr doesn't work with lifetimes
|
||||
pub fn parse(input: &'a str) -> Option<Self> {
|
||||
use ChannelBinding::*;
|
||||
Some(match input {
|
||||
"n" => NotSupportedClient,
|
||||
"y" => NotSupportedServer,
|
||||
other => Required(other.strip_prefix("p=")?),
|
||||
"n" => Self::NotSupportedClient,
|
||||
"y" => Self::NotSupportedServer,
|
||||
other => Self::Required(other.strip_prefix("p=")?),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -40,17 +38,16 @@ impl<T: std::fmt::Display> ChannelBinding<T> {
|
||||
&self,
|
||||
get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>,
|
||||
) -> Result<std::borrow::Cow<'static, str>, E> {
|
||||
use ChannelBinding::*;
|
||||
Ok(match self {
|
||||
NotSupportedClient => {
|
||||
Self::NotSupportedClient => {
|
||||
// base64::encode("n,,")
|
||||
"biws".into()
|
||||
}
|
||||
NotSupportedServer => {
|
||||
Self::NotSupportedServer => {
|
||||
// base64::encode("y,,")
|
||||
"eSws".into()
|
||||
}
|
||||
Required(mode) => {
|
||||
Self::Required(mode) => {
|
||||
use std::io::Write;
|
||||
let mut cbind_input = vec![];
|
||||
write!(&mut cbind_input, "p={mode},,",).unwrap();
|
||||
|
||||
@@ -42,10 +42,9 @@ pub(super) enum ServerMessage<T> {
|
||||
|
||||
impl<'a> ServerMessage<&'a str> {
|
||||
pub(super) fn to_reply(&self) -> BeMessage<'a> {
|
||||
use BeAuthenticationSaslMessage::*;
|
||||
BeMessage::AuthenticationSasl(match self {
|
||||
ServerMessage::Continue(s) => Continue(s.as_bytes()),
|
||||
ServerMessage::Final(s) => Final(s.as_bytes()),
|
||||
ServerMessage::Continue(s) => BeAuthenticationSaslMessage::Continue(s.as_bytes()),
|
||||
ServerMessage::Final(s) => BeAuthenticationSaslMessage::Final(s.as_bytes()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,12 +137,12 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn round_trip() {
|
||||
run_round_trip_test("pencil", "pencil").await
|
||||
run_round_trip_test("pencil", "pencil").await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[should_panic(expected = "password doesn't match")]
|
||||
async fn failure() {
|
||||
run_round_trip_test("pencil", "eraser").await
|
||||
run_round_trip_test("pencil", "eraser").await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,8 +98,6 @@ mod tests {
|
||||
// q% of counts will be within p of the actual value
|
||||
let mut sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);
|
||||
|
||||
dbg!(sketch.buckets.len());
|
||||
|
||||
// insert a bunch of entries in a random order
|
||||
let mut ids2 = ids.clone();
|
||||
while !ids2.is_empty() {
|
||||
|
||||
@@ -210,23 +210,23 @@ impl sasl::Mechanism for Exchange<'_> {
|
||||
type Output = super::ScramKey;
|
||||
|
||||
fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
|
||||
use {sasl::Step::*, ExchangeState::*};
|
||||
use {sasl::Step, ExchangeState};
|
||||
match &self.state {
|
||||
Initial(init) => {
|
||||
ExchangeState::Initial(init) => {
|
||||
match init.transition(self.secret, &self.tls_server_end_point, input)? {
|
||||
Continue(sent, msg) => {
|
||||
self.state = SaltSent(sent);
|
||||
Ok(Continue(self, msg))
|
||||
Step::Continue(sent, msg) => {
|
||||
self.state = ExchangeState::SaltSent(sent);
|
||||
Ok(Step::Continue(self, msg))
|
||||
}
|
||||
Success(x, _) => match x {},
|
||||
Failure(msg) => Ok(Failure(msg)),
|
||||
Step::Success(x, _) => match x {},
|
||||
Step::Failure(msg) => Ok(Step::Failure(msg)),
|
||||
}
|
||||
}
|
||||
SaltSent(sent) => {
|
||||
ExchangeState::SaltSent(sent) => {
|
||||
match sent.transition(self.secret, &self.tls_server_end_point, input)? {
|
||||
Success(keys, msg) => Ok(Success(keys, msg)),
|
||||
Continue(x, _) => match x {},
|
||||
Failure(msg) => Ok(Failure(msg)),
|
||||
Step::Success(keys, msg) => Ok(Step::Success(keys, msg)),
|
||||
Step::Continue(x, _) => match x {},
|
||||
Step::Failure(msg) => Ok(Step::Failure(msg)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,7 +59,7 @@ impl<'a> ClientFirstMessage<'a> {
|
||||
|
||||
// https://github.com/postgres/postgres/blob/f83908798f78c4cafda217ca875602c88ea2ae28/src/backend/libpq/auth-scram.c#L13-L14
|
||||
if !username.is_empty() {
|
||||
tracing::warn!(username, "scram username provided, but is not expected")
|
||||
tracing::warn!(username, "scram username provided, but is not expected");
|
||||
// TODO(conrad):
|
||||
// return None;
|
||||
}
|
||||
@@ -137,7 +137,7 @@ impl<'a> ClientFinalMessage<'a> {
|
||||
/// Build a response to [`ClientFinalMessage`].
|
||||
pub fn build_server_final_message(
|
||||
&self,
|
||||
signature_builder: SignatureBuilder,
|
||||
signature_builder: SignatureBuilder<'_>,
|
||||
server_key: &ScramKey,
|
||||
) -> String {
|
||||
let mut buf = String::from("v=");
|
||||
@@ -212,7 +212,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn parse_client_first_message_with_invalid_gs2_authz() {
|
||||
assert!(ClientFirstMessage::parse("n,authzid,n=,r=nonce").is_none())
|
||||
assert!(ClientFirstMessage::parse("n,authzid,n=,r=nonce").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -84,6 +84,6 @@ mod tests {
|
||||
};
|
||||
|
||||
let expected = pbkdf2_hmac_array::<Sha256, 32>(pass, salt, 600000);
|
||||
assert_eq!(hash, expected)
|
||||
assert_eq!(hash, expected);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -270,7 +270,7 @@ fn thread_rt(pool: Arc<ThreadPool>, worker: Worker<JobSpec>, index: usize) {
|
||||
.inc(ThreadPoolWorkerId(index));
|
||||
|
||||
// skip for now
|
||||
worker.push(job)
|
||||
worker.push(job);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -316,6 +316,6 @@ mod tests {
|
||||
10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
|
||||
178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
|
||||
];
|
||||
assert_eq!(actual, expected)
|
||||
assert_eq!(actual, expected);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,7 +120,7 @@ pub async fn task_main(
|
||||
tracing::trace!("attempting to cancel a random connection");
|
||||
if let Some(token) = config.http_config.cancel_set.take() {
|
||||
tracing::debug!("cancelling a random connection");
|
||||
token.cancel()
|
||||
token.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,7 +198,7 @@ async fn connection_startup(
|
||||
let peer_addr = peer.unwrap_or(peer_addr).ip();
|
||||
let has_private_peer_addr = match peer_addr {
|
||||
IpAddr::V4(ip) => ip.is_private(),
|
||||
_ => false,
|
||||
IpAddr::V6(_) => false,
|
||||
};
|
||||
info!(?session_id, %peer_addr, "accepted new TCP connection");
|
||||
|
||||
|
||||
@@ -390,7 +390,7 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
|
||||
.write()
|
||||
.get_conn_entry(conn_info.db_and_user())
|
||||
{
|
||||
client = Some(entry.conn)
|
||||
client = Some(entry.conn);
|
||||
}
|
||||
let endpoint_pool = Arc::downgrade(&endpoint_pool);
|
||||
|
||||
@@ -662,13 +662,13 @@ impl<C: ClientInnerExt> Discard<'_, C> {
|
||||
pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
|
||||
let conn_info = &self.conn_info;
|
||||
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
|
||||
info!("pool: throwing away connection '{conn_info}' because connection is not idle")
|
||||
info!("pool: throwing away connection '{conn_info}' because connection is not idle");
|
||||
}
|
||||
}
|
||||
pub fn discard(&mut self) {
|
||||
let conn_info = &self.conn_info;
|
||||
if std::mem::take(self.pool).strong_count() > 0 {
|
||||
info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
|
||||
info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -234,7 +234,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
|
||||
.await
|
||||
.inspect_err(|_| {
|
||||
if record_handshake_error {
|
||||
Metrics::get().proxy.tls_handshake_failures.inc()
|
||||
Metrics::get().proxy.tls_handshake_failures.inc();
|
||||
}
|
||||
})?),
|
||||
Stream::Tls { .. } => Err(StreamUpgradeError::AlreadyTls),
|
||||
|
||||
@@ -12,7 +12,7 @@ impl ApiUrl {
|
||||
}
|
||||
|
||||
/// See [`url::Url::path_segments_mut`].
|
||||
pub fn path_segments_mut(&mut self) -> url::PathSegmentsMut {
|
||||
pub fn path_segments_mut(&mut self) -> url::PathSegmentsMut<'_> {
|
||||
// We've already verified that it works during construction.
|
||||
self.0.path_segments_mut().expect("bad API url")
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ impl<T> Default for Waiters<T> {
|
||||
}
|
||||
|
||||
impl<T> Waiters<T> {
|
||||
pub fn register(&self, key: String) -> Result<Waiter<T>, RegisterError> {
|
||||
pub fn register(&self, key: String) -> Result<Waiter<'_, T>, RegisterError> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
self.0
|
||||
|
||||
Reference in New Issue
Block a user