proxy: remove dead code (#8847)
By marking everything possible as pub(crate), we find a few dead code candidates.
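For context on why this works: rustc's `dead_code` lint never fires for `pub` items, because they are part of the crate's external API and the compiler must assume a downstream crate could call them. A `pub(crate)` item is only reachable from inside the crate, so the compiler can see every call site and prove when there are none. A minimal sketch of the effect in a library crate (hypothetical names, not from the proxy codebase):

```rust
// lib.rs — building this crate warns only about the pub(crate) item.
pub mod helpers {
    // Public API surface: rustc must assume an external crate may call
    // this, so it can never be flagged as dead code.
    pub fn exported() {}

    // Crate-visible only: rustc sees all call sites, proves there are
    // none, and emits "warning: function `never_called` is never used".
    pub(crate) fn never_called() {}
}
```

Narrowing visibility across the whole crate lets the compiler surface such candidates mechanically; the deletions in the diff below (`get_endpoint`, `TimedLru::insert`, `get_ignoring_ttl`, `remove`, the unused `postgres_addr` field) are the result.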
@@ -4,17 +4,17 @@ pub mod backend;
 pub use backend::BackendType;
 
 mod credentials;
-pub use credentials::{
+pub(crate) use credentials::{
     check_peer_addr_is_in_list, endpoint_sni, ComputeUserInfoMaybeEndpoint,
     ComputeUserInfoParseError, IpPattern,
 };
 
 mod password_hack;
-pub use password_hack::parse_endpoint_param;
+pub(crate) use password_hack::parse_endpoint_param;
 use password_hack::PasswordHackPayload;
 
 mod flow;
-pub use flow::*;
+pub(crate) use flow::*;
 use tokio::time::error::Elapsed;
 
 use crate::{

@@ -25,11 +25,11 @@ use std::{io, net::IpAddr};
 use thiserror::Error;
 
 /// Convenience wrapper for the authentication error.
-pub type Result<T> = std::result::Result<T, AuthError>;
+pub(crate) type Result<T> = std::result::Result<T, AuthError>;
 
 /// Common authentication error.
 #[derive(Debug, Error)]
-pub enum AuthErrorImpl {
+pub(crate) enum AuthErrorImpl {
     #[error(transparent)]
     Link(#[from] backend::LinkAuthError),
 

@@ -77,30 +77,30 @@ pub enum AuthErrorImpl {
 
 #[derive(Debug, Error)]
 #[error(transparent)]
-pub struct AuthError(Box<AuthErrorImpl>);
+pub(crate) struct AuthError(Box<AuthErrorImpl>);
 
 impl AuthError {
-    pub fn bad_auth_method(name: impl Into<Box<str>>) -> Self {
+    pub(crate) fn bad_auth_method(name: impl Into<Box<str>>) -> Self {
         AuthErrorImpl::BadAuthMethod(name.into()).into()
     }
 
-    pub fn auth_failed(user: impl Into<Box<str>>) -> Self {
+    pub(crate) fn auth_failed(user: impl Into<Box<str>>) -> Self {
         AuthErrorImpl::AuthFailed(user.into()).into()
     }
 
-    pub fn ip_address_not_allowed(ip: IpAddr) -> Self {
+    pub(crate) fn ip_address_not_allowed(ip: IpAddr) -> Self {
         AuthErrorImpl::IpAddressNotAllowed(ip).into()
     }
 
-    pub fn too_many_connections() -> Self {
+    pub(crate) fn too_many_connections() -> Self {
         AuthErrorImpl::TooManyConnections.into()
     }
 
-    pub fn is_auth_failed(&self) -> bool {
+    pub(crate) fn is_auth_failed(&self) -> bool {
         matches!(self.0.as_ref(), AuthErrorImpl::AuthFailed(_))
     }
 
-    pub fn user_timeout(elapsed: Elapsed) -> Self {
+    pub(crate) fn user_timeout(elapsed: Elapsed) -> Self {
         AuthErrorImpl::UserTimeout(elapsed).into()
     }
 }
@@ -9,7 +9,7 @@ use std::sync::Arc;
 use std::time::Duration;
 
 use ipnet::{Ipv4Net, Ipv6Net};
-pub use link::LinkAuthError;
+pub(crate) use link::LinkAuthError;
 use local::LocalBackend;
 use tokio::io::{AsyncRead, AsyncWrite};
 use tokio_postgres::config::AuthKeys;

@@ -74,12 +74,12 @@ pub enum BackendType<'a, T, D> {
     Local(MaybeOwned<'a, LocalBackend>),
 }
 
-pub trait TestBackend: Send + Sync + 'static {
+#[cfg(test)]
+pub(crate) trait TestBackend: Send + Sync + 'static {
     fn wake_compute(&self) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
     fn get_allowed_ips_and_secret(
         &self,
     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), console::errors::GetAuthInfoError>;
     fn get_role_secret(&self) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError>;
 }
 
 impl std::fmt::Display for BackendType<'_, (), ()> {

@@ -105,7 +105,7 @@ impl std::fmt::Display for BackendType<'_, (), ()> {
 impl<T, D> BackendType<'_, T, D> {
     /// Very similar to [`std::option::Option::as_ref`].
     /// This helps us pass structured config to async tasks.
-    pub fn as_ref(&self) -> BackendType<'_, &T, &D> {
+    pub(crate) fn as_ref(&self) -> BackendType<'_, &T, &D> {
         match self {
             Self::Console(c, x) => BackendType::Console(MaybeOwned::Borrowed(c), x),
             Self::Link(c, x) => BackendType::Link(MaybeOwned::Borrowed(c), x),

@@ -118,7 +118,7 @@ impl<'a, T, D> BackendType<'a, T, D> {
     /// Very similar to [`std::option::Option::map`].
     /// Maps [`BackendType<T>`] to [`BackendType<R>`] by applying
     /// a function to a contained value.
-    pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> {
+    pub(crate) fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> {
         match self {
             Self::Console(c, x) => BackendType::Console(c, f(x)),
             Self::Link(c, x) => BackendType::Link(c, x),

@@ -129,7 +129,7 @@ impl<'a, T, D> BackendType<'a, T, D> {
 impl<'a, T, D, E> BackendType<'a, Result<T, E>, D> {
     /// Very similar to [`std::option::Option::transpose`].
     /// This is most useful for error handling.
-    pub fn transpose(self) -> Result<BackendType<'a, T, D>, E> {
+    pub(crate) fn transpose(self) -> Result<BackendType<'a, T, D>, E> {
         match self {
             Self::Console(c, x) => x.map(|x| BackendType::Console(c, x)),
             Self::Link(c, x) => Ok(BackendType::Link(c, x)),

@@ -138,31 +138,31 @@ impl<'a, T, D, E> BackendType<'a, Result<T, E>, D> {
     }
 }
 
-pub struct ComputeCredentials {
-    pub info: ComputeUserInfo,
-    pub keys: ComputeCredentialKeys,
+pub(crate) struct ComputeCredentials {
+    pub(crate) info: ComputeUserInfo,
+    pub(crate) keys: ComputeCredentialKeys,
 }
 
 #[derive(Debug, Clone)]
-pub struct ComputeUserInfoNoEndpoint {
-    pub user: RoleName,
-    pub options: NeonOptions,
+pub(crate) struct ComputeUserInfoNoEndpoint {
+    pub(crate) user: RoleName,
+    pub(crate) options: NeonOptions,
 }
 
 #[derive(Debug, Clone)]
-pub struct ComputeUserInfo {
-    pub endpoint: EndpointId,
-    pub user: RoleName,
-    pub options: NeonOptions,
+pub(crate) struct ComputeUserInfo {
+    pub(crate) endpoint: EndpointId,
+    pub(crate) user: RoleName,
+    pub(crate) options: NeonOptions,
 }
 
 impl ComputeUserInfo {
-    pub fn endpoint_cache_key(&self) -> EndpointCacheKey {
+    pub(crate) fn endpoint_cache_key(&self) -> EndpointCacheKey {
         self.options.get_cache_key(&self.endpoint)
     }
 }
 
-pub enum ComputeCredentialKeys {
+pub(crate) enum ComputeCredentialKeys {
     Password(Vec<u8>),
     AuthKeys(AuthKeys),
     None,

@@ -222,7 +222,7 @@ impl RateBucketInfo {
 }
 
 impl AuthenticationConfig {
-    pub fn check_rate_limit(
+    pub(crate) fn check_rate_limit(
         &self,
         ctx: &RequestMonitoring,
         config: &AuthenticationConfig,

@@ -404,17 +404,8 @@ async fn authenticate_with_secret(
 }
 
 impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
-    /// Get compute endpoint name from the credentials.
-    pub fn get_endpoint(&self) -> Option<EndpointId> {
-        match self {
-            Self::Console(_, user_info) => user_info.endpoint_id.clone(),
-            Self::Link(_, ()) => Some("link".into()),
-            Self::Local(_) => Some("local".into()),
-        }
-    }
-
     /// Get username from the credentials.
-    pub fn get_user(&self) -> &str {
+    pub(crate) fn get_user(&self) -> &str {
         match self {
             Self::Console(_, user_info) => &user_info.user,
             Self::Link(_, ()) => "link",

@@ -424,7 +415,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
 
     /// Authenticate the client via the requested backend, possibly using credentials.
     #[tracing::instrument(fields(allow_cleartext = allow_cleartext), skip_all)]
-    pub async fn authenticate(
+    pub(crate) async fn authenticate(
         self,
         ctx: &RequestMonitoring,
         client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,

@@ -471,7 +462,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
 }
 
 impl BackendType<'_, ComputeUserInfo, &()> {
-    pub async fn get_role_secret(
+    pub(crate) async fn get_role_secret(
         &self,
         ctx: &RequestMonitoring,
     ) -> Result<CachedRoleSecret, GetAuthInfoError> {

@@ -482,7 +473,7 @@ impl BackendType<'_, ComputeUserInfo, &()> {
         }
     }
 
-    pub async fn get_allowed_ips_and_secret(
+    pub(crate) async fn get_allowed_ips_and_secret(
         &self,
         ctx: &RequestMonitoring,
     ) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
@@ -17,7 +17,7 @@ use tracing::{info, warn};
 /// one round trip and *expensive* computations (>= 4096 HMAC iterations).
 /// These properties are benefical for serverless JS workers, so we
 /// use this mechanism for websocket connections.
-pub async fn authenticate_cleartext(
+pub(crate) async fn authenticate_cleartext(
     ctx: &RequestMonitoring,
     info: ComputeUserInfo,
     client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,

@@ -59,7 +59,7 @@ pub async fn authenticate_cleartext(
 /// Workaround for clients which don't provide an endpoint (project) name.
 /// Similar to [`authenticate_cleartext`], but there's a specific password format,
 /// and passwords are not yet validated (we don't know how to validate them!)
-pub async fn password_hack_no_authentication(
+pub(crate) async fn password_hack_no_authentication(
     ctx: &RequestMonitoring,
     info: ComputeUserInfoNoEndpoint,
     client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
@@ -22,27 +22,27 @@ const MAX_RENEW: Duration = Duration::from_secs(3600);
 const MAX_JWK_BODY_SIZE: usize = 64 * 1024;
 
 /// How to get the JWT auth rules
-pub trait FetchAuthRules: Clone + Send + Sync + 'static {
+pub(crate) trait FetchAuthRules: Clone + Send + Sync + 'static {
     fn fetch_auth_rules(
         &self,
         role_name: RoleName,
     ) -> impl Future<Output = anyhow::Result<Vec<AuthRule>>> + Send;
 }
 
-pub struct AuthRule {
-    pub id: String,
-    pub jwks_url: url::Url,
-    pub audience: Option<String>,
+pub(crate) struct AuthRule {
+    pub(crate) id: String,
+    pub(crate) jwks_url: url::Url,
+    pub(crate) audience: Option<String>,
 }
 
 #[derive(Default)]
-pub struct JwkCache {
+pub(crate) struct JwkCache {
     client: reqwest::Client,
 
     map: DashMap<(EndpointId, RoleName), Arc<JwkCacheEntryLock>>,
 }
 
-pub struct JwkCacheEntry {
+pub(crate) struct JwkCacheEntry {
     /// Should refetch at least every hour to verify when old keys have been removed.
     /// Should refetch when new key IDs are seen only every 5 minutes or so
     last_retrieved: Instant,

@@ -75,7 +75,7 @@ impl KeySet {
     }
 }
 
-pub struct JwkCacheEntryLock {
+pub(crate) struct JwkCacheEntryLock {
     cached: ArcSwapOption<JwkCacheEntry>,
     lookup: tokio::sync::Semaphore,
 }

@@ -309,7 +309,7 @@ impl JwkCacheEntryLock {
 }
 
 impl JwkCache {
-    pub async fn check_jwt<F: FetchAuthRules>(
+    pub(crate) async fn check_jwt<F: FetchAuthRules>(
         &self,
         ctx: &RequestMonitoring,
         endpoint: EndpointId,
@@ -13,7 +13,7 @@ use tokio_postgres::config::SslMode;
 use tracing::{info, info_span};
 
 #[derive(Debug, Error)]
-pub enum LinkAuthError {
+pub(crate) enum LinkAuthError {
     #[error(transparent)]
     WaiterRegister(#[from] waiters::RegisterError),
 

@@ -52,7 +52,7 @@ fn hello_message(redirect_uri: &reqwest::Url, session_id: &str) -> String {
     )
 }
 
-pub fn new_psql_session_id() -> String {
+pub(crate) fn new_psql_session_id() -> String {
     hex::encode(rand::random::<[u8; 8]>())
 }
@@ -16,16 +16,14 @@ use crate::{
 use super::jwt::{AuthRule, FetchAuthRules, JwkCache};
 
 pub struct LocalBackend {
-    pub jwks_cache: JwkCache,
-    pub postgres_addr: SocketAddr,
-    pub node_info: NodeInfo,
+    pub(crate) jwks_cache: JwkCache,
+    pub(crate) node_info: NodeInfo,
 }
 
 impl LocalBackend {
     pub fn new(postgres_addr: SocketAddr) -> Self {
         LocalBackend {
             jwks_cache: JwkCache::default(),
-            postgres_addr,
             node_info: NodeInfo {
                 config: {
                     let mut cfg = ConnCfg::new();

@@ -47,7 +45,7 @@ impl LocalBackend {
 }
 
 #[derive(Clone, Copy)]
-pub struct StaticAuthRules;
+pub(crate) struct StaticAuthRules;
 
 pub static JWKS_ROLE_MAP: ArcSwapOption<JwksRoleSettings> = ArcSwapOption::const_empty();
@@ -16,7 +16,7 @@ use thiserror::Error;
 use tracing::{info, warn};
 
 #[derive(Debug, Error, PartialEq, Eq, Clone)]
-pub enum ComputeUserInfoParseError {
+pub(crate) enum ComputeUserInfoParseError {
     #[error("Parameter '{0}' is missing in startup packet.")]
     MissingKey(&'static str),
 

@@ -51,20 +51,20 @@ impl ReportableError for ComputeUserInfoParseError {
 /// Various client credentials which we use for authentication.
 /// Note that we don't store any kind of client key or password here.
 #[derive(Debug, Clone, PartialEq, Eq)]
-pub struct ComputeUserInfoMaybeEndpoint {
-    pub user: RoleName,
-    pub endpoint_id: Option<EndpointId>,
-    pub options: NeonOptions,
+pub(crate) struct ComputeUserInfoMaybeEndpoint {
+    pub(crate) user: RoleName,
+    pub(crate) endpoint_id: Option<EndpointId>,
+    pub(crate) options: NeonOptions,
 }
 
 impl ComputeUserInfoMaybeEndpoint {
     #[inline]
-    pub fn endpoint(&self) -> Option<&str> {
+    pub(crate) fn endpoint(&self) -> Option<&str> {
         self.endpoint_id.as_deref()
     }
 }
 
-pub fn endpoint_sni(
+pub(crate) fn endpoint_sni(
     sni: &str,
     common_names: &HashSet<String>,
 ) -> Result<Option<EndpointId>, ComputeUserInfoParseError> {

@@ -83,7 +83,7 @@ pub fn endpoint_sni(
 }
 
 impl ComputeUserInfoMaybeEndpoint {
-    pub fn parse(
+    pub(crate) fn parse(
         ctx: &RequestMonitoring,
         params: &StartupMessageParams,
         sni: Option<&str>,

@@ -173,12 +173,12 @@ impl ComputeUserInfoMaybeEndpoint {
     }
 }
 
-pub fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool {
+pub(crate) fn check_peer_addr_is_in_list(peer_addr: &IpAddr, ip_list: &[IpPattern]) -> bool {
     ip_list.is_empty() || ip_list.iter().any(|pattern| check_ip(peer_addr, pattern))
 }
 
 #[derive(Debug, Clone, Eq, PartialEq)]
-pub enum IpPattern {
+pub(crate) enum IpPattern {
     Subnet(ipnet::IpNet),
     Range(IpAddr, IpAddr),
     Single(IpAddr),
@@ -17,17 +17,20 @@ use tokio::io::{AsyncRead, AsyncWrite};
 use tracing::info;
 
 /// Every authentication selector is supposed to implement this trait.
-pub trait AuthMethod {
+pub(crate) trait AuthMethod {
     /// Any authentication selector should provide initial backend message
     /// containing auth method name and parameters, e.g. md5 salt.
     fn first_message(&self, channel_binding: bool) -> BeMessage<'_>;
 }
 
 /// Initial state of [`AuthFlow`].
-pub struct Begin;
+pub(crate) struct Begin;
 
 /// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
-pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a RequestMonitoring);
+pub(crate) struct Scram<'a>(
+    pub(crate) &'a scram::ServerSecret,
+    pub(crate) &'a RequestMonitoring,
+);
 
 impl AuthMethod for Scram<'_> {
     #[inline(always)]

@@ -44,7 +47,7 @@ impl AuthMethod for Scram<'_> {
 
 /// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
 /// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
-pub struct PasswordHack;
+pub(crate) struct PasswordHack;
 
 impl AuthMethod for PasswordHack {
     #[inline(always)]

@@ -55,10 +58,10 @@ impl AuthMethod for PasswordHack {
 
 /// Use clear-text password auth called `password` in docs
 /// <https://www.postgresql.org/docs/current/auth-password.html>
-pub struct CleartextPassword {
-    pub pool: Arc<ThreadPool>,
-    pub endpoint: EndpointIdInt,
-    pub secret: AuthSecret,
+pub(crate) struct CleartextPassword {
+    pub(crate) pool: Arc<ThreadPool>,
+    pub(crate) endpoint: EndpointIdInt,
+    pub(crate) secret: AuthSecret,
 }
 
 impl AuthMethod for CleartextPassword {

@@ -70,7 +73,7 @@ impl AuthMethod for CleartextPassword {
 
 /// This wrapper for [`PqStream`] performs client authentication.
 #[must_use]
-pub struct AuthFlow<'a, S, State> {
+pub(crate) struct AuthFlow<'a, S, State> {
     /// The underlying stream which implements libpq's protocol.
     stream: &'a mut PqStream<Stream<S>>,
     /// State might contain ancillary data (see [`Self::begin`]).

@@ -81,7 +84,7 @@ pub struct AuthFlow<'a, S, State> {
 /// Initial state of the stream wrapper.
 impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
     /// Create a new wrapper for client authentication.
-    pub fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
+    pub(crate) fn new(stream: &'a mut PqStream<Stream<S>>) -> Self {
         let tls_server_end_point = stream.get_ref().tls_server_end_point();
 
         Self {

@@ -92,7 +95,7 @@ impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
     }
 
     /// Move to the next step by sending auth method's name & params to client.
-    pub async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
+    pub(crate) async fn begin<M: AuthMethod>(self, method: M) -> io::Result<AuthFlow<'a, S, M>> {
         self.stream
             .write_message(&method.first_message(self.tls_server_end_point.supported()))
             .await?;

@@ -107,7 +110,7 @@ impl<'a, S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
 
 impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
     /// Perform user authentication. Raise an error in case authentication failed.
-    pub async fn get_password(self) -> super::Result<PasswordHackPayload> {
+    pub(crate) async fn get_password(self) -> super::Result<PasswordHackPayload> {
         let msg = self.stream.read_password_message().await?;
         let password = msg
             .strip_suffix(&[0])

@@ -126,7 +129,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
 
 impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
     /// Perform user authentication. Raise an error in case authentication failed.
-    pub async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
+    pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
         let msg = self.stream.read_password_message().await?;
         let password = msg
             .strip_suffix(&[0])

@@ -151,7 +154,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
 /// Stream wrapper for handling [SCRAM](crate::scram) auth.
 impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
     /// Perform user authentication. Raise an error in case authentication failed.
-    pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
+    pub(crate) async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
         let Scram(secret, ctx) = self.state;
 
         // pause the timer while we communicate with the client
@@ -7,13 +7,13 @@ use bstr::ByteSlice;
 
 use crate::EndpointId;
 
-pub struct PasswordHackPayload {
-    pub endpoint: EndpointId,
-    pub password: Vec<u8>,
+pub(crate) struct PasswordHackPayload {
+    pub(crate) endpoint: EndpointId,
+    pub(crate) password: Vec<u8>,
 }
 
 impl PasswordHackPayload {
-    pub fn parse(bytes: &[u8]) -> Option<Self> {
+    pub(crate) fn parse(bytes: &[u8]) -> Option<Self> {
         // The format is `project=<utf-8>;<password-bytes>` or `project=<utf-8>$<password-bytes>`.
         let separators = [";", "$"];
         for sep in separators {

@@ -30,7 +30,7 @@ impl PasswordHackPayload {
     }
 }
 
-pub fn parse_endpoint_param(bytes: &str) -> Option<&str> {
+pub(crate) fn parse_endpoint_param(bytes: &str) -> Option<&str> {
     bytes
         .strip_prefix("project=")
         .or_else(|| bytes.strip_prefix("endpoint="))
@@ -1,7 +1,7 @@
-pub mod common;
-pub mod endpoints;
-pub mod project_info;
+pub(crate) mod common;
+pub(crate) mod endpoints;
+pub(crate) mod project_info;
 mod timed_lru;
 
-pub use common::{Cache, Cached};
-pub use timed_lru::TimedLru;
+pub(crate) use common::{Cache, Cached};
+pub(crate) use timed_lru::TimedLru;
proxy/src/cache/common.rs (vendored, 18 changes)
@@ -3,7 +3,7 @@ use std::ops::{Deref, DerefMut};
 /// A generic trait which exposes types of cache's key and value,
 /// as well as the notion of cache entry invalidation.
 /// This is useful for [`Cached`].
-pub trait Cache {
+pub(crate) trait Cache {
     /// Entry's key.
     type Key;
 

@@ -29,21 +29,21 @@ impl<C: Cache> Cache for &C {
 }
 
 /// Wrapper for convenient entry invalidation.
-pub struct Cached<C: Cache, V = <C as Cache>::Value> {
+pub(crate) struct Cached<C: Cache, V = <C as Cache>::Value> {
     /// Cache + lookup info.
-    pub token: Option<(C, C::LookupInfo<C::Key>)>,
+    pub(crate) token: Option<(C, C::LookupInfo<C::Key>)>,
 
     /// The value itself.
-    pub value: V,
+    pub(crate) value: V,
 }
 
 impl<C: Cache, V> Cached<C, V> {
     /// Place any entry into this wrapper; invalidation will be a no-op.
-    pub fn new_uncached(value: V) -> Self {
+    pub(crate) fn new_uncached(value: V) -> Self {
         Self { token: None, value }
     }
 
-    pub fn take_value(self) -> (Cached<C, ()>, V) {
+    pub(crate) fn take_value(self) -> (Cached<C, ()>, V) {
         (
             Cached {
                 token: self.token,

@@ -53,7 +53,7 @@ impl<C: Cache, V> Cached<C, V> {
         )
     }
 
-    pub fn map<U>(self, f: impl FnOnce(V) -> U) -> Cached<C, U> {
+    pub(crate) fn map<U>(self, f: impl FnOnce(V) -> U) -> Cached<C, U> {
         Cached {
             token: self.token,
             value: f(self.value),

@@ -61,7 +61,7 @@ impl<C: Cache, V> Cached<C, V> {
     }
 
     /// Drop this entry from a cache if it's still there.
-    pub fn invalidate(self) -> V {
+    pub(crate) fn invalidate(self) -> V {
         if let Some((cache, info)) = &self.token {
             cache.invalidate(info);
         }

@@ -69,7 +69,7 @@ impl<C: Cache, V> Cached<C, V> {
     }
 
     /// Tell if this entry is actually cached.
-    pub fn cached(&self) -> bool {
+    pub(crate) fn cached(&self) -> bool {
         self.token.is_some()
     }
 }
proxy/src/cache/endpoints.rs (vendored, 6 changes)
@@ -28,7 +28,7 @@ use crate::{
 };
 
 #[derive(Deserialize, Debug, Clone)]
-pub struct ControlPlaneEventKey {
+pub(crate) struct ControlPlaneEventKey {
     endpoint_created: Option<EndpointCreated>,
     branch_created: Option<BranchCreated>,
     project_created: Option<ProjectCreated>,

@@ -56,7 +56,7 @@ pub struct EndpointsCache {
 }
 
 impl EndpointsCache {
-    pub fn new(config: EndpointCacheConfig) -> Self {
+    pub(crate) fn new(config: EndpointCacheConfig) -> Self {
         Self {
             limiter: Arc::new(Mutex::new(GlobalRateLimiter::new(
                 config.limiter_info.clone(),

@@ -68,7 +68,7 @@ impl EndpointsCache {
             ready: AtomicBool::new(false),
         }
     }
-    pub async fn is_valid(&self, ctx: &RequestMonitoring, endpoint: &EndpointId) -> bool {
+    pub(crate) async fn is_valid(&self, ctx: &RequestMonitoring, endpoint: &EndpointId) -> bool {
         if !self.ready.load(Ordering::Acquire) {
             return true;
         }
proxy/src/cache/project_info.rs (vendored, 24 changes)
@@ -24,7 +24,7 @@ use crate::{
 use super::{Cache, Cached};
 
 #[async_trait]
-pub trait ProjectInfoCache {
+pub(crate) trait ProjectInfoCache {
     fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt);
     fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
     async fn decrement_active_listeners(&self);

@@ -37,7 +37,7 @@ struct Entry<T> {
 }
 
 impl<T> Entry<T> {
-    pub fn new(value: T) -> Self {
+    pub(crate) fn new(value: T) -> Self {
         Self {
             created_at: Instant::now(),
             value,

@@ -64,7 +64,7 @@ impl EndpointInfo {
             Some(t) => t < created_at,
         }
     }
-    pub fn get_role_secret(
+    pub(crate) fn get_role_secret(
         &self,
         role_name: RoleNameInt,
         valid_since: Instant,

@@ -81,7 +81,7 @@ impl EndpointInfo {
         None
     }
 
-    pub fn get_allowed_ips(
+    pub(crate) fn get_allowed_ips(
         &self,
         valid_since: Instant,
         ignore_cache_since: Option<Instant>,

@@ -96,10 +96,10 @@ impl EndpointInfo {
         }
         None
     }
-    pub fn invalidate_allowed_ips(&mut self) {
+    pub(crate) fn invalidate_allowed_ips(&mut self) {
         self.allowed_ips = None;
     }
-    pub fn invalidate_role_secret(&mut self, role_name: RoleNameInt) {
+    pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) {
         self.secret.remove(&role_name);
     }
 }

@@ -178,7 +178,7 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
 }
 
 impl ProjectInfoCacheImpl {
-    pub fn new(config: ProjectInfoCacheOptions) -> Self {
+    pub(crate) fn new(config: ProjectInfoCacheOptions) -> Self {
         Self {
             cache: DashMap::new(),
             project2ep: DashMap::new(),

@@ -189,7 +189,7 @@ impl ProjectInfoCacheImpl {
         }
     }
 
-    pub fn get_role_secret(
+    pub(crate) fn get_role_secret(
         &self,
         endpoint_id: &EndpointId,
         role_name: &RoleName,

@@ -212,7 +212,7 @@ impl ProjectInfoCacheImpl {
         }
         Some(Cached::new_uncached(value))
     }
-    pub fn get_allowed_ips(
+    pub(crate) fn get_allowed_ips(
         &self,
         endpoint_id: &EndpointId,
     ) -> Option<Cached<&Self, Arc<Vec<IpPattern>>>> {

@@ -230,7 +230,7 @@ impl ProjectInfoCacheImpl {
         }
         Some(Cached::new_uncached(value))
     }
-    pub fn insert_role_secret(
+    pub(crate) fn insert_role_secret(
        &self,
        project_id: ProjectIdInt,
        endpoint_id: EndpointIdInt,

@@ -247,7 +247,7 @@ impl ProjectInfoCacheImpl {
             entry.secret.insert(role_name, secret.into());
         }
     }
-    pub fn insert_allowed_ips(
+    pub(crate) fn insert_allowed_ips(
         &self,
         project_id: ProjectIdInt,
         endpoint_id: EndpointIdInt,

@@ -319,7 +319,7 @@ impl ProjectInfoCacheImpl {
 
 /// Lookup info for project info cache.
 /// This is used to invalidate cache entries.
-pub struct CachedLookupInfo {
+pub(crate) struct CachedLookupInfo {
     /// Search by this key.
     endpoint_id: EndpointIdInt,
     lookup_type: LookupType,
proxy/src/cache/timed_lru.rs (vendored, 45 changes)
@@ -39,7 +39,7 @@ use super::{common::Cached, *};
 ///
 /// * It's possible for an entry that has not yet expired entry to be evicted
 ///   before expired items. That's a bit wasteful, but probably fine in practice.
-pub struct TimedLru<K, V> {
+pub(crate) struct TimedLru<K, V> {
     /// Cache's name for tracing.
     name: &'static str,
 

@@ -72,7 +72,7 @@ struct Entry<T> {
 
 impl<K: Hash + Eq, V> TimedLru<K, V> {
     /// Construct a new LRU cache with timed entries.
-    pub fn new(
+    pub(crate) fn new(
         name: &'static str,
         capacity: usize,
         ttl: Duration,

@@ -207,11 +207,11 @@ impl<K: Hash + Eq, V> TimedLru<K, V> {
 }
 
 impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
-    pub fn insert_ttl(&self, key: K, value: V, ttl: Duration) {
+    pub(crate) fn insert_ttl(&self, key: K, value: V, ttl: Duration) {
         self.insert_raw_ttl(key, value, ttl, false);
     }
 
-    pub fn insert_unit(&self, key: K, value: V) -> (Option<V>, Cached<&Self, ()>) {
+    pub(crate) fn insert_unit(&self, key: K, value: V) -> (Option<V>, Cached<&Self, ()>) {
         let (created_at, old) = self.insert_raw(key.clone(), value);
 
         let cached = Cached {

@@ -221,22 +221,11 @@ impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
 
         (old, cached)
     }
-
-    pub fn insert(&self, key: K, value: V) -> (Option<V>, Cached<&Self>) {
-        let (created_at, old) = self.insert_raw(key.clone(), value.clone());
-
-        let cached = Cached {
-            token: Some((self, LookupInfo { created_at, key })),
-            value,
-        };
-
-        (old, cached)
-    }
 }
 
 impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
     /// Retrieve a cached entry in convenient wrapper.
-    pub fn get<Q>(&self, key: &Q) -> Option<timed_lru::Cached<&Self>>
+    pub(crate) fn get<Q>(&self, key: &Q) -> Option<timed_lru::Cached<&Self>>
     where
         K: Borrow<Q> + Clone,
         Q: Hash + Eq + ?Sized,

@@ -253,32 +242,10 @@ impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
             }
         })
     }
 
-    /// Retrieve a cached entry in convenient wrapper, ignoring its TTL.
-    pub fn get_ignoring_ttl<Q>(&self, key: &Q) -> Option<timed_lru::Cached<&Self>>
-    where
-        K: Borrow<Q>,
-        Q: Hash + Eq + ?Sized,
-    {
-        let mut cache = self.cache.lock();
-        cache
-            .get(key)
-            .map(|entry| Cached::new_uncached(entry.value.clone()))
-    }
-
-    /// Remove an entry from the cache.
-    pub fn remove<Q>(&self, key: &Q) -> Option<V>
-    where
-        K: Borrow<Q> + Clone,
-        Q: Hash + Eq + ?Sized,
-    {
-        let mut cache = self.cache.lock();
-        cache.remove(key).map(|entry| entry.value)
-    }
 }
 
 /// Lookup information for key invalidation.
-pub struct LookupInfo<K> {
+pub(crate) struct LookupInfo<K> {
     /// Time of creation of a cache [`Entry`].
     /// We use this during invalidation lookups to prevent eviction of a newer
     /// entry sharing the same key (it might've been inserted by a different
@@ -18,7 +18,7 @@ use crate::{
 
 pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
 pub type CancellationHandlerMain = CancellationHandler<Option<Arc<Mutex<RedisPublisherClient>>>>;
-pub type CancellationHandlerMainInternal = Option<Arc<Mutex<RedisPublisherClient>>>;
+pub(crate) type CancellationHandlerMainInternal = Option<Arc<Mutex<RedisPublisherClient>>>;
 
 /// Enables serving `CancelRequest`s.
 ///

@@ -32,7 +32,7 @@ pub struct CancellationHandler<P> {
 }
 
 #[derive(Debug, Error)]
-pub enum CancelError {
+pub(crate) enum CancelError {
     #[error("{0}")]
     IO(#[from] std::io::Error),
     #[error("{0}")]

@@ -53,7 +53,7 @@ impl ReportableError for CancelError {
 
 impl<P: CancellationPublisher> CancellationHandler<P> {
     /// Run async action within an ephemeral session identified by [`CancelKeyData`].
-    pub fn get_session(self: Arc<Self>) -> Session<P> {
+    pub(crate) fn get_session(self: Arc<Self>) -> Session<P> {
         // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
         // expose it and we don't want to do another roundtrip to query
         // for it. The client will be able to notice that this is not the

@@ -81,7 +81,7 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
     }
     /// Try to cancel a running query for the corresponding connection.
     /// If the cancellation key is not found, it will be published to Redis.
-    pub async fn cancel_session(
+    pub(crate) async fn cancel_session(
         &self,
         key: CancelKeyData,
         session_id: Uuid,

@@ -155,14 +155,14 @@ pub struct CancelClosure {
 }
 
 impl CancelClosure {
-    pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
+    pub(crate) fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
         Self {
             socket_addr,
             cancel_token,
         }
     }
     /// Cancels the query running on user's compute node.
-    pub async fn try_cancel_query(self) -> Result<(), CancelError> {
+    pub(crate) async fn try_cancel_query(self) -> Result<(), CancelError> {
         let socket = TcpStream::connect(self.socket_addr).await?;
         self.cancel_token.cancel_query_raw(socket, NoTls).await?;
         info!("query was cancelled");

@@ -171,7 +171,7 @@ impl CancelClosure {
 }
 
 /// Helper for registering query cancellation tokens.
-pub struct Session<P> {
+pub(crate) struct Session<P> {
     /// The user-facing key identifying this session.
     key: CancelKeyData,
     /// The [`CancelMap`] this session belongs to.

@@ -181,7 +181,7 @@ pub struct Session<P> {
 impl<P> Session<P> {
     /// Store the cancel token for the given session.
     /// This enables query cancellation in `crate::proxy::prepare_client_connection`.
-    pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
+    pub(crate) fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
         info!("enabling query cancellation for this session");
         self.cancellation_handler
             .map
@@ -23,7 +23,7 @@ use tracing::{error, info, warn};
 const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node";
 
 #[derive(Debug, Error)]
-pub enum ConnectionError {
+pub(crate) enum ConnectionError {
     /// This error doesn't seem to reveal any secrets; for instance,
     /// `tokio_postgres::error::Kind` doesn't contain ip addresses and such.
     #[error("{COULD_NOT_CONNECT}: {0}")]

@@ -86,22 +86,22 @@ impl ReportableError for ConnectionError {
 }
 
 /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
-pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
+pub(crate) type ScramKeys = tokio_postgres::config::ScramKeys<32>;
 
 /// A config for establishing a connection to compute node.
 /// Eventually, `tokio_postgres` will be replaced with something better.
 /// Newtype allows us to implement methods on top of it.
 #[derive(Clone, Default)]
-pub struct ConnCfg(Box<tokio_postgres::Config>);
+pub(crate) struct ConnCfg(Box<tokio_postgres::Config>);
 
 /// Creation and initialization routines.
 impl ConnCfg {
-    pub fn new() -> Self {
+    pub(crate) fn new() -> Self {
         Self::default()
     }
 
     /// Reuse password or auth keys from the other config.
-    pub fn reuse_password(&mut self, other: Self) {
+    pub(crate) fn reuse_password(&mut self, other: Self) {
         if let Some(password) = other.get_password() {
             self.password(password);
         }

@@ -111,7 +111,7 @@ impl ConnCfg {
         }
     }
 
-    pub fn get_host(&self) -> Result<Host, WakeComputeError> {
+    pub(crate) fn get_host(&self) -> Result<Host, WakeComputeError> {
         match self.0.get_hosts() {
             [tokio_postgres::config::Host::Tcp(s)] => Ok(s.into()),
             // we should not have multiple address or unix addresses.

@@ -122,7 +122,7 @@ impl ConnCfg {
     }
 
     /// Apply startup message params to the connection config.
-    pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
+    pub(crate) fn set_startup_params(&mut self, params: &StartupMessageParams) {
         // Only set `user` if it's not present in the config.
         // Link auth flow takes username from the console's response.
         if let (None, Some(user)) = (self.get_user(), params.get("user")) {

@@ -255,25 +255,25 @@ impl ConnCfg {
     }
 }
 
-pub struct PostgresConnection {
+pub(crate) struct PostgresConnection {
     /// Socket connected to a compute node.
-    pub stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
+    pub(crate) stream: tokio_postgres::maybe_tls_stream::MaybeTlsStream<
         tokio::net::TcpStream,
         tokio_postgres_rustls::RustlsStream<tokio::net::TcpStream>,
     >,
     /// PostgreSQL connection parameters.
-    pub params: std::collections::HashMap<String, String>,
+    pub(crate) params: std::collections::HashMap<String, String>,
     /// Query cancellation token.
-    pub cancel_closure: CancelClosure,
+    pub(crate) cancel_closure: CancelClosure,
     /// Labels for proxy's metrics.
-    pub aux: MetricsAuxInfo,
+    pub(crate) aux: MetricsAuxInfo,
 
     _guage: NumDbConnectionsGuard<'static>,
 }
 
 impl ConnCfg {
     /// Connect to a corresponding compute node.
-    pub async fn connect(
+    pub(crate) async fn connect(
         &self,
         ctx: &RequestMonitoring,
         allow_self_signed_compute: bool,
@@ -10,7 +10,7 @@ pub(crate) use provider::{errors, Api, AuthSecret, CachedNodeInfo, NodeInfo};
 
 /// Various cache-related types.
 pub mod caches {
-    pub use super::provider::{ApiCaches, NodeInfoCache};
+    pub use super::provider::ApiCaches;
 }
 
 /// Various cache-related types.
@@ -12,22 +12,22 @@ use crate::RoleName;
 /// Generic error response with human-readable description.
 /// Note that we can't always present it to user as is.
 #[derive(Debug, Deserialize, Clone)]
-pub struct ConsoleError {
-    pub error: Box<str>,
+pub(crate) struct ConsoleError {
+    pub(crate) error: Box<str>,
     #[serde(skip)]
-    pub http_status_code: http::StatusCode,
-    pub status: Option<Status>,
+    pub(crate) http_status_code: http::StatusCode,
+    pub(crate) status: Option<Status>,
 }
 
 impl ConsoleError {
-    pub fn get_reason(&self) -> Reason {
+    pub(crate) fn get_reason(&self) -> Reason {
         self.status
             .as_ref()
             .and_then(|s| s.details.error_info.as_ref())
             .map_or(Reason::Unknown, |e| e.reason)
     }
 
-    pub fn get_user_facing_message(&self) -> String {
+    pub(crate) fn get_user_facing_message(&self) -> String {
         use super::provider::errors::REQUEST_FAILED;
         self.status
             .as_ref()

@@ -88,27 +88,28 @@ impl CouldRetry for ConsoleError {
 }
 
 #[derive(Debug, Deserialize, Clone)]
-pub struct Status {
-    pub code: Box<str>,
-    pub message: Box<str>,
-    pub details: Details,
+#[allow(dead_code)]
+pub(crate) struct Status {
+    pub(crate) code: Box<str>,
+    pub(crate) message: Box<str>,
+    pub(crate) details: Details,
 }
 
 #[derive(Debug, Deserialize, Clone)]
-pub struct Details {
-    pub error_info: Option<ErrorInfo>,
-    pub retry_info: Option<RetryInfo>,
-    pub user_facing_message: Option<UserFacingMessage>,
+pub(crate) struct Details {
+    pub(crate) error_info: Option<ErrorInfo>,
+    pub(crate) retry_info: Option<RetryInfo>,
+    pub(crate) user_facing_message: Option<UserFacingMessage>,
 }
 
 #[derive(Copy, Clone, Debug, Deserialize)]
-pub struct ErrorInfo {
-    pub reason: Reason,
+pub(crate) struct ErrorInfo {
+    pub(crate) reason: Reason,
     // Schema could also have `metadata` field, but it's not structured. Skip it for now.
 }
 
 #[derive(Clone, Copy, Debug, Deserialize, Default)]
-pub enum Reason {
+pub(crate) enum Reason {
     /// RoleProtected indicates that the role is protected and the attempted operation is not permitted on protected roles.
     #[serde(rename = "ROLE_PROTECTED")]
     RoleProtected,

@@ -168,7 +169,7 @@ pub enum Reason {
 }
 
 impl Reason {
-    pub fn is_not_found(&self) -> bool {
+    pub(crate) fn is_not_found(self) -> bool {
         matches!(
             self,
             Reason::ResourceNotFound

@@ -178,7 +179,7 @@ impl Reason {
         )
     }
 
-    pub fn can_retry(&self) -> bool {
+    pub(crate) fn can_retry(self) -> bool {
         match self {
             // do not retry role protected errors
             // not a transitive error

@@ -208,22 +209,23 @@ impl Reason {
 }
 
 #[derive(Copy, Clone, Debug, Deserialize)]
-pub struct RetryInfo {
-    pub retry_delay_ms: u64,
+#[allow(dead_code)]
+pub(crate) struct RetryInfo {
+    pub(crate) retry_delay_ms: u64,
 }
 
 #[derive(Debug, Deserialize, Clone)]
-pub struct UserFacingMessage {
-    pub message: Box<str>,
+pub(crate) struct UserFacingMessage {
+    pub(crate) message: Box<str>,
 }
 
 /// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`].
 /// Returned by the `/proxy_get_role_secret` API method.
 #[derive(Deserialize)]
-pub struct GetRoleSecret {
-    pub role_secret: Box<str>,
-    pub allowed_ips: Option<Vec<IpPattern>>,
-    pub project_id: Option<ProjectIdInt>,
+pub(crate) struct GetRoleSecret {
+    pub(crate) role_secret: Box<str>,
+    pub(crate) allowed_ips: Option<Vec<IpPattern>>,
+    pub(crate) project_id: Option<ProjectIdInt>,
 }
 
 // Manually implement debug to omit sensitive info.

@@ -236,21 +238,21 @@ impl fmt::Debug for GetRoleSecret {
 /// Response which holds compute node's `host:port` pair.
 /// Returned by the `/proxy_wake_compute` API method.
 #[derive(Debug, Deserialize)]
-pub struct WakeCompute {
-    pub address: Box<str>,
-    pub aux: MetricsAuxInfo,
+pub(crate) struct WakeCompute {
+    pub(crate) address: Box<str>,
+    pub(crate) aux: MetricsAuxInfo,
 }
 
 /// Async response which concludes the link auth flow.
 /// Also known as `kickResponse` in the console.
 #[derive(Debug, Deserialize)]
-pub struct KickSession<'a> {
+pub(crate) struct KickSession<'a> {
     /// Session ID is assigned by the proxy.
-    pub session_id: &'a str,
+    pub(crate) session_id: &'a str,
 
     /// Compute node connection params.
     #[serde(deserialize_with = "KickSession::parse_db_info")]
-    pub result: DatabaseInfo,
+    pub(crate) result: DatabaseInfo,
 }
 
 impl KickSession<'_> {

@@ -273,15 +275,15 @@ impl KickSession<'_> {
 
 /// Compute node connection params.
 #[derive(Deserialize)]
-pub struct DatabaseInfo {
-    pub host: Box<str>,
-    pub port: u16,
-    pub dbname: Box<str>,
-    pub user: Box<str>,
+pub(crate) struct DatabaseInfo {
+    pub(crate) host: Box<str>,
+    pub(crate) port: u16,
+    pub(crate) dbname: Box<str>,
+    pub(crate) user: Box<str>,
     /// Console always provides a password, but it might
     /// be inconvenient for debug with local PG instance.
-    pub password: Option<Box<str>>,
-    pub aux: MetricsAuxInfo,
+    pub(crate) password: Option<Box<str>>,
+    pub(crate) aux: MetricsAuxInfo,
 }
 
 // Manually implement debug to omit sensitive info.

@@ -299,12 +301,12 @@ impl fmt::Debug for DatabaseInfo {
 /// Various labels for prometheus metrics.
 /// Also known as `ProxyMetricsAuxInfo` in the console.
 #[derive(Debug, Deserialize, Clone)]
-pub struct MetricsAuxInfo {
-    pub endpoint_id: EndpointIdInt,
-    pub project_id: ProjectIdInt,
-    pub branch_id: BranchIdInt,
+pub(crate) struct MetricsAuxInfo {
+    pub(crate) endpoint_id: EndpointIdInt,
+    pub(crate) project_id: ProjectIdInt,
+    pub(crate) branch_id: BranchIdInt,
     #[serde(default)]
-    pub cold_start_info: ColdStartInfo,
+    pub(crate) cold_start_info: ColdStartInfo,
 }
 
 #[derive(Debug, Default, Serialize, Deserialize, Clone, Copy, FixedCardinalityLabel)]

@@ -331,7 +333,7 @@ pub enum ColdStartInfo {
 }
 
 impl ColdStartInfo {
-    pub fn as_str(&self) -> &'static str {
+    pub(crate) fn as_str(self) -> &'static str {
         match self {
             ColdStartInfo::Unknown => "unknown",
             ColdStartInfo::Warm => "warm",
@@ -14,13 +14,13 @@ use tracing::{error, info, info_span, Instrument};
 static CPLANE_WAITERS: Lazy<Waiters<ComputeReady>> = Lazy::new(Default::default);
 
 /// Give caller an opportunity to wait for the cloud's reply.
-pub fn get_waiter(
+pub(crate) fn get_waiter(
     psql_session_id: impl Into<String>,
 ) -> Result<Waiter<'static, ComputeReady>, waiters::RegisterError> {
     CPLANE_WAITERS.register(psql_session_id.into())
 }
 
-pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::NotifyError> {
+pub(crate) fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::NotifyError> {
     CPLANE_WAITERS.notify(psql_session_id, msg)
 }

@@ -74,7 +74,7 @@ async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
 }
 
 /// A message received by `mgmt` when a compute node is ready.
-pub type ComputeReady = DatabaseInfo;
+pub(crate) type ComputeReady = DatabaseInfo;
 
 // TODO: replace with an http-based protocol.
 struct MgmtHandler;
@@ -23,7 +23,7 @@ use std::{hash::Hash, sync::Arc, time::Duration};
 use tokio::time::Instant;
 use tracing::info;
 
-pub mod errors {
+pub(crate) mod errors {
     use crate::{
         console::messages::{self, ConsoleError, Reason},
         error::{io_error, ErrorKind, ReportableError, UserFacingError},

@@ -34,11 +34,11 @@ pub mod errors {
     use super::ApiLockError;
 
     /// A go-to error message which doesn't leak any detail.
-    pub const REQUEST_FAILED: &str = "Console request failed";
+    pub(crate) const REQUEST_FAILED: &str = "Console request failed";
 
     /// Common console API error.
     #[derive(Debug, Error)]
-    pub enum ApiError {
+    pub(crate) enum ApiError {
         /// Error returned by the console itself.
         #[error("{REQUEST_FAILED} with {0}")]
         Console(ConsoleError),

@@ -50,7 +50,7 @@ pub mod errors {
 
     impl ApiError {
         /// Returns HTTP status code if it's the reason for failure.
-        pub fn get_reason(&self) -> messages::Reason {
+        pub(crate) fn get_reason(&self) -> messages::Reason {
             match self {
                 ApiError::Console(e) => e.get_reason(),
                 ApiError::Transport(_) => messages::Reason::Unknown,

@@ -146,7 +146,7 @@ pub mod errors {
     }
 
     #[derive(Debug, Error)]
-    pub enum GetAuthInfoError {
+    pub(crate) enum GetAuthInfoError {
         // We shouldn't include the actual secret here.
         #[error("Console responded with a malformed auth secret")]
         BadSecret,

@@ -183,7 +183,7 @@ pub mod errors {
     }
 
     #[derive(Debug, Error)]
-    pub enum WakeComputeError {
+    pub(crate) enum WakeComputeError {
         #[error("Console responded with a malformed compute address: {0}")]
         BadComputeAddress(Box<str>),

@@ -247,7 +247,7 @@ pub mod errors {
 
 /// Auth secret which is managed by the cloud.
 #[derive(Clone, Eq, PartialEq, Debug)]
-pub enum AuthSecret {
+pub(crate) enum AuthSecret {
     #[cfg(any(test, feature = "testing"))]
     /// Md5 hash of user's password.
     Md5([u8; 16]),

@@ -257,32 +257,32 @@ pub enum AuthSecret {
 }
 
 #[derive(Default)]
-pub struct AuthInfo {
-    pub secret: Option<AuthSecret>,
+pub(crate) struct AuthInfo {
+    pub(crate) secret: Option<AuthSecret>,
     /// List of IP addresses allowed for the autorization.
-    pub allowed_ips: Vec<IpPattern>,
+    pub(crate) allowed_ips: Vec<IpPattern>,
     /// Project ID. This is used for cache invalidation.
-    pub project_id: Option<ProjectIdInt>,
+    pub(crate) project_id: Option<ProjectIdInt>,
 }
 
 /// Info for establishing a connection to a compute node.
 /// This is what we get after auth succeeded, but not before!
 #[derive(Clone)]
-pub struct NodeInfo {
+pub(crate) struct NodeInfo {
     /// Compute node connection params.
     /// It's sad that we have to clone this, but this will improve
     /// once we migrate to a bespoke connection logic.
-    pub config: compute::ConnCfg,
+    pub(crate) config: compute::ConnCfg,
 
     /// Labels for proxy's metrics.
-    pub aux: MetricsAuxInfo,
+    pub(crate) aux: MetricsAuxInfo,
 
     /// Whether we should accept self-signed certificates (for testing)
-    pub allow_self_signed_compute: bool,
+    pub(crate) allow_self_signed_compute: bool,
 }
 
 impl NodeInfo {
-    pub async fn connect(
+    pub(crate) async fn connect(
         &self,
         ctx: &RequestMonitoring,
         timeout: Duration,

@@ -296,12 +296,12 @@ impl NodeInfo {
         )
         .await
     }
-    pub fn reuse_settings(&mut self, other: Self) {
+    pub(crate) fn reuse_settings(&mut self, other: Self) {
         self.allow_self_signed_compute = other.allow_self_signed_compute;
         self.config.reuse_password(other.config);
     }
 
-    pub fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
+    pub(crate) fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
         match keys {
             ComputeCredentialKeys::Password(password) => self.config.password(password),
             ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),

@@ -310,10 +310,10 @@ impl NodeInfo {
     }
 }
 
-pub type NodeInfoCache = TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ConsoleError>>>;
-pub type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
-pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
-pub type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
+pub(crate) type NodeInfoCache = TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ConsoleError>>>;
+pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
+pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
+pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
 
 /// This will allocate per each call, but the http requests alone
 /// already require a few allocations, so it should be fine.

@@ -350,6 +350,7 @@ pub enum ConsoleBackend {
     Postgres(mock::Api),
     /// Internal testing
     #[cfg(test)]
+    #[allow(private_interfaces)]
     Test(Box<dyn crate::auth::backend::TestBackend>),
 }
 

@@ -402,7 +403,7 @@ impl Api for ConsoleBackend {
 /// Various caches for [`console`](super).
 pub struct ApiCaches {
     /// Cache for the `wake_compute` API method.
-    pub node_info: NodeInfoCache,
+    pub(crate) node_info: NodeInfoCache,
     /// Cache which stores project_id -> endpoint_ids mapping.
     pub project_info: Arc<ProjectInfoCacheImpl>,
     /// List of all valid endpoints.

@@ -439,7 +440,7 @@ pub struct ApiLocks<K> {
 }
 
 #[derive(Debug, thiserror::Error)]
-pub enum ApiLockError {
+pub(crate) enum ApiLockError {
     #[error("timeout acquiring resource permit")]
     TimeoutError(#[from] tokio::time::error::Elapsed),
 }

@@ -471,7 +472,7 @@ impl<K: Hash + Eq + Clone> ApiLocks<K> {
         })
     }
 
-    pub async fn get_permit(&self, key: &K) -> Result<WakeComputePermit, ApiLockError> {
+    pub(crate) async fn get_permit(&self, key: &K) -> Result<WakeComputePermit, ApiLockError> {
         if self.config.initial_limit == 0 {
             return Ok(WakeComputePermit {
                 permit: Token::disabled(),

@@ -531,18 +532,18 @@ impl<K: Hash + Eq + Clone> ApiLocks<K> {
     }
 }
 
-pub struct WakeComputePermit {
+pub(crate) struct WakeComputePermit {
     permit: Token,
 }
 
 impl WakeComputePermit {
-    pub fn should_check_cache(&self) -> bool {
+    pub(crate) fn should_check_cache(&self) -> bool {
         !self.permit.is_disabled()
     }
-    pub fn release(self, outcome: Outcome) {
+    pub(crate) fn release(self, outcome: Outcome) {
         self.permit.release(outcome);
     }
-    pub fn release_result<T, E>(self, res: Result<T, E>) -> Result<T, E> {
+    pub(crate) fn release_result<T, E>(self, res: Result<T, E>) -> Result<T, E> {
         match res {
             Ok(_) => self.release(Outcome::Success),
             Err(_) => self.release(Outcome::Overload),
@@ -48,7 +48,7 @@ impl Api {
         Self { endpoint }
     }
 
-    pub fn url(&self) -> &str {
+    pub(crate) fn url(&self) -> &str {
         self.endpoint.as_str()
     }
@@ -25,8 +25,8 @@ use tracing::{debug, error, info, info_span, warn, Instrument};
 pub struct Api {
     endpoint: http::Endpoint,
     pub caches: &'static ApiCaches,
-    pub locks: &'static ApiLocks<EndpointCacheKey>,
-    pub wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
+    pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
+    pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
     jwt: String,
 }

@@ -51,7 +51,7 @@ impl Api {
         }
     }
 
-    pub fn url(&self) -> &str {
+    pub(crate) fn url(&self) -> &str {
         self.endpoint.url().as_str()
     }
@@ -22,8 +22,9 @@ use self::parquet::RequestData;

pub mod parquet;

pub static LOG_CHAN: OnceCell<mpsc::WeakUnboundedSender<RequestData>> = OnceCell::new();
pub static LOG_CHAN_DISCONNECT: OnceCell<mpsc::WeakUnboundedSender<RequestData>> = OnceCell::new();
pub(crate) static LOG_CHAN: OnceCell<mpsc::WeakUnboundedSender<RequestData>> = OnceCell::new();
pub(crate) static LOG_CHAN_DISCONNECT: OnceCell<mpsc::WeakUnboundedSender<RequestData>> =
    OnceCell::new();

/// Context data for a single request to connect to a database.
///
@@ -38,12 +39,12 @@ pub struct RequestMonitoring(
);

struct RequestMonitoringInner {
    pub peer_addr: IpAddr,
    pub session_id: Uuid,
    pub protocol: Protocol,
    pub(crate) peer_addr: IpAddr,
    pub(crate) session_id: Uuid,
    pub(crate) protocol: Protocol,
    first_packet: chrono::DateTime<Utc>,
    region: &'static str,
    pub span: Span,
    pub(crate) span: Span,

    // filled in as they are discovered
    project: Option<ProjectIdInt>,
@@ -63,14 +64,14 @@ struct RequestMonitoringInner {
    sender: Option<mpsc::UnboundedSender<RequestData>>,
    // This sender is only used to log the length of session in case of success.
    disconnect_sender: Option<mpsc::UnboundedSender<RequestData>>,
    pub latency_timer: LatencyTimer,
    pub(crate) latency_timer: LatencyTimer,
    // Whether the proxy decided that it's not a valid endpoint and rejected it before going to cplane.
    rejected: Option<bool>,
    disconnect_timestamp: Option<chrono::DateTime<Utc>>,
}

#[derive(Clone, Debug)]
pub enum AuthMethod {
pub(crate) enum AuthMethod {
    // aka link aka passwordless
    Web,
    ScramSha256,
@@ -125,11 +126,11 @@ impl RequestMonitoring {
    }

    #[cfg(test)]
    pub fn test() -> Self {
    pub(crate) fn test() -> Self {
        RequestMonitoring::new(Uuid::now_v7(), [127, 0, 0, 1].into(), Protocol::Tcp, "test")
    }

    pub fn console_application_name(&self) -> String {
    pub(crate) fn console_application_name(&self) -> String {
        let this = self.0.try_lock().expect("should not deadlock");
        format!(
            "{}/{}",
@@ -138,19 +139,19 @@ impl RequestMonitoring {
        )
    }

    pub fn set_rejected(&self, rejected: bool) {
    pub(crate) fn set_rejected(&self, rejected: bool) {
        let mut this = self.0.try_lock().expect("should not deadlock");
        this.rejected = Some(rejected);
    }

    pub fn set_cold_start_info(&self, info: ColdStartInfo) {
    pub(crate) fn set_cold_start_info(&self, info: ColdStartInfo) {
        self.0
            .try_lock()
            .expect("should not deadlock")
            .set_cold_start_info(info);
    }

    pub fn set_db_options(&self, options: StartupMessageParams) {
    pub(crate) fn set_db_options(&self, options: StartupMessageParams) {
        let mut this = self.0.try_lock().expect("should not deadlock");
        this.set_application(options.get("application_name").map(SmolStr::from));
        if let Some(user) = options.get("user") {
@@ -163,7 +164,7 @@ impl RequestMonitoring {
        this.pg_options = Some(options);
    }

    pub fn set_project(&self, x: MetricsAuxInfo) {
    pub(crate) fn set_project(&self, x: MetricsAuxInfo) {
        let mut this = self.0.try_lock().expect("should not deadlock");
        if this.endpoint_id.is_none() {
            this.set_endpoint_id(x.endpoint_id.as_str().into());
@@ -173,33 +174,33 @@ impl RequestMonitoring {
        this.set_cold_start_info(x.cold_start_info);
    }

    pub fn set_project_id(&self, project_id: ProjectIdInt) {
    pub(crate) fn set_project_id(&self, project_id: ProjectIdInt) {
        let mut this = self.0.try_lock().expect("should not deadlock");
        this.project = Some(project_id);
    }

    pub fn set_endpoint_id(&self, endpoint_id: EndpointId) {
    pub(crate) fn set_endpoint_id(&self, endpoint_id: EndpointId) {
        self.0
            .try_lock()
            .expect("should not deadlock")
            .set_endpoint_id(endpoint_id);
    }

    pub fn set_dbname(&self, dbname: DbName) {
    pub(crate) fn set_dbname(&self, dbname: DbName) {
        self.0
            .try_lock()
            .expect("should not deadlock")
            .set_dbname(dbname);
    }

    pub fn set_user(&self, user: RoleName) {
    pub(crate) fn set_user(&self, user: RoleName) {
        self.0
            .try_lock()
            .expect("should not deadlock")
            .set_user(user);
    }

    pub fn set_auth_method(&self, auth_method: AuthMethod) {
    pub(crate) fn set_auth_method(&self, auth_method: AuthMethod) {
        let mut this = self.0.try_lock().expect("should not deadlock");
        this.auth_method = Some(auth_method);
    }
@@ -211,7 +212,7 @@ impl RequestMonitoring {
            .has_private_peer_addr()
    }

    pub fn set_error_kind(&self, kind: ErrorKind) {
    pub(crate) fn set_error_kind(&self, kind: ErrorKind) {
        let mut this = self.0.try_lock().expect("should not deadlock");
        // Do not record errors from the private address to metrics.
        if !this.has_private_peer_addr() {
@@ -237,30 +238,30 @@ impl RequestMonitoring {
            .log_connect();
    }

    pub fn protocol(&self) -> Protocol {
    pub(crate) fn protocol(&self) -> Protocol {
        self.0.try_lock().expect("should not deadlock").protocol
    }

    pub fn span(&self) -> Span {
    pub(crate) fn span(&self) -> Span {
        self.0.try_lock().expect("should not deadlock").span.clone()
    }

    pub fn session_id(&self) -> Uuid {
    pub(crate) fn session_id(&self) -> Uuid {
        self.0.try_lock().expect("should not deadlock").session_id
    }

    pub fn peer_addr(&self) -> IpAddr {
    pub(crate) fn peer_addr(&self) -> IpAddr {
        self.0.try_lock().expect("should not deadlock").peer_addr
    }

    pub fn cold_start_info(&self) -> ColdStartInfo {
    pub(crate) fn cold_start_info(&self) -> ColdStartInfo {
        self.0
            .try_lock()
            .expect("should not deadlock")
            .cold_start_info
    }

    pub fn latency_timer_pause(&self, waiting_for: Waiting) -> LatencyTimerPause<'_> {
    pub(crate) fn latency_timer_pause(&self, waiting_for: Waiting) -> LatencyTimerPause<'_> {
        LatencyTimerPause {
            ctx: self,
            start: tokio::time::Instant::now(),
@@ -268,7 +269,7 @@ impl RequestMonitoring {
        }
    }

    pub fn success(&self) {
    pub(crate) fn success(&self) {
        self.0
            .try_lock()
            .expect("should not deadlock")
@@ -277,7 +278,7 @@ impl RequestMonitoring {
    }
}

pub struct LatencyTimerPause<'a> {
pub(crate) struct LatencyTimerPause<'a> {
    ctx: &'a RequestMonitoring,
    start: tokio::time::Instant,
    waiting_for: Waiting,
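The `RequestMonitoring` methods above all follow one pattern: a shared handle whose inner state is mutated through short `try_lock()` critical sections. A compilable stand-in illustrating just that pattern (field set is reduced; `Ctx`/`Inner` are hypothetical names):

```rust
use std::sync::{Arc, Mutex};

// Reduced stand-in for RequestMonitoringInner: per-request fields that are
// filled in as the connection is processed.
#[derive(Default, Debug)]
struct Inner {
    user: Option<String>,
    dbname: Option<String>,
}

// Mirrors the RequestMonitoring shape: Arc<Mutex<_>> with short, synchronous
// critical sections, hence try_lock().expect("should not deadlock").
#[derive(Clone, Default)]
struct Ctx(Arc<Mutex<Inner>>);

impl Ctx {
    fn set_user(&self, user: &str) {
        let mut this = self.0.try_lock().expect("should not deadlock");
        this.user = Some(user.to_owned());
    }
    fn set_dbname(&self, db: &str) {
        let mut this = self.0.try_lock().expect("should not deadlock");
        this.dbname = Some(db.to_owned());
    }
}

fn main() {
    let ctx = Ctx::default();
    ctx.set_user("alice");
    ctx.set_dbname("main");
    println!("{:?}", ctx.0.try_lock().expect("should not deadlock"));
}
```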
@@ -62,8 +62,8 @@ pub struct ParquetUploadArgs {
// But after FAILED_UPLOAD_WARN_THRESHOLD retries, we start to log it at WARN
// level instead, as repeated failures can mean a more serious problem. If it
// fails more than FAILED_UPLOAD_MAX_RETRIES times, we give up.
pub const FAILED_UPLOAD_WARN_THRESHOLD: u32 = 3;
pub const FAILED_UPLOAD_MAX_RETRIES: u32 = 10;
pub(crate) const FAILED_UPLOAD_WARN_THRESHOLD: u32 = 3;
pub(crate) const FAILED_UPLOAD_MAX_RETRIES: u32 = 10;

// the parquet crate leaves a lot to be desired...
// what follows is an attempt to write parquet files with minimal allocs.
@@ -73,7 +73,7 @@ pub const FAILED_UPLOAD_MAX_RETRIES: u32 = 10;
// * after each rowgroup write, we check the length of the file and upload to s3 if large enough

#[derive(parquet_derive::ParquetRecordWriter)]
pub struct RequestData {
pub(crate) struct RequestData {
    region: &'static str,
    protocol: &'static str,
    /// Must be UTC. The derive macro doesn't like the timezones
@@ -3,12 +3,12 @@ use std::{error::Error as StdError, fmt, io};
use measured::FixedCardinalityLabel;

/// Upcast (almost) any error into an opaque [`io::Error`].
pub fn io_error(e: impl Into<Box<dyn StdError + Send + Sync>>) -> io::Error {
pub(crate) fn io_error(e: impl Into<Box<dyn StdError + Send + Sync>>) -> io::Error {
    io::Error::new(io::ErrorKind::Other, e)
}

/// A small combinator for pluggable error logging.
pub fn log_error<E: fmt::Display>(e: E) -> E {
pub(crate) fn log_error<E: fmt::Display>(e: E) -> E {
    tracing::error!("{e}");
    e
}
@@ -19,7 +19,7 @@ pub fn log_error<E: fmt::Display>(e: E) -> E {
/// NOTE: This trait should not be implemented for [`anyhow::Error`], since it
/// is way too convenient and tends to proliferate all across the codebase,
/// ultimately leading to accidental leaks of sensitive data.
pub trait UserFacingError: ReportableError {
pub(crate) trait UserFacingError: ReportableError {
    /// Format the error for client, stripping all sensitive info.
    ///
    /// Although this might be a no-op for many types, it's highly
@@ -64,7 +64,7 @@ pub enum ErrorKind {
}

impl ErrorKind {
    pub fn to_metric_label(&self) -> &'static str {
    pub(crate) fn to_metric_label(self) -> &'static str {
        match self {
            ErrorKind::User => "user",
            ErrorKind::ClientDisconnect => "clientdisconnect",
@@ -78,7 +78,7 @@ impl ErrorKind {
    }
}

pub trait ReportableError: fmt::Display + Send + 'static {
pub(crate) trait ReportableError: fmt::Display + Send + 'static {
    fn get_error_kind(&self) -> ErrorKind;
}
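To make the `ReportableError` contract concrete: each error type maps itself to a coarse `ErrorKind`, which `to_metric_label` then turns into a metrics label. A self-contained sketch with stand-in copies of the trait and enum (the real definitions are the ones in this diff):

```rust
use std::fmt;

// Stand-ins mirroring the proxy's error plumbing.
#[derive(Clone, Copy, Debug)]
#[allow(dead_code)]
enum ErrorKind { User, Service }

trait ReportableError: fmt::Display + Send + 'static {
    fn get_error_kind(&self) -> ErrorKind;
}

#[derive(Debug)]
struct QuotaExceeded;

impl fmt::Display for QuotaExceeded {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "quota exceeded")
    }
}

impl ReportableError for QuotaExceeded {
    // A user-caused failure: counted under the "user" label, like
    // ErrorKind::User => "user" in to_metric_label above.
    fn get_error_kind(&self) -> ErrorKind { ErrorKind::User }
}

fn main() {
    let e = QuotaExceeded;
    println!("{e}: kind = {:?}", e.get_error_kind());
}
```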
@@ -12,9 +12,9 @@ use http_body_util::BodyExt;
use hyper1::body::Body;
use serde::de::DeserializeOwned;

pub use reqwest::{Request, Response, StatusCode};
pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
pub(crate) use reqwest::{Request, Response};
pub(crate) use reqwest_middleware::{ClientWithMiddleware, Error};
pub(crate) use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};

use crate::{
    metrics::{ConsoleRequest, Metrics},
@@ -35,7 +35,7 @@ pub fn new_client() -> ClientWithMiddleware {
        .build()
}

pub fn new_client_with_timeout(default_timout: Duration) -> ClientWithMiddleware {
pub(crate) fn new_client_with_timeout(default_timout: Duration) -> ClientWithMiddleware {
    let timeout_client = reqwest::ClientBuilder::new()
        .timeout(default_timout)
        .build()
@@ -77,20 +77,20 @@ impl Endpoint {
    }

    #[inline(always)]
    pub fn url(&self) -> &ApiUrl {
    pub(crate) fn url(&self) -> &ApiUrl {
        &self.endpoint
    }

    /// Return a [builder](RequestBuilder) for a `GET` request,
    /// appending a single `path` segment to the base endpoint URL.
    pub fn get(&self, path: &str) -> RequestBuilder {
    pub(crate) fn get(&self, path: &str) -> RequestBuilder {
        let mut url = self.endpoint.clone();
        url.path_segments_mut().push(path);
        self.client.get(url.into_inner())
    }

    /// Execute a [request](reqwest::Request).
    pub async fn execute(&self, request: Request) -> Result<Response, Error> {
    pub(crate) async fn execute(&self, request: Request) -> Result<Response, Error> {
        let _timer = Metrics::get()
            .proxy
            .console_request_latency
@@ -102,7 +102,7 @@ impl Endpoint {
    }
}

pub async fn parse_json_body_with_limit<D: DeserializeOwned>(
pub(crate) async fn parse_json_body_with_limit<D: DeserializeOwned>(
    mut b: impl Body<Data = Bytes, Error = reqwest::Error> + Unpin,
    limit: usize,
) -> anyhow::Result<D> {
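The idea behind `parse_json_body_with_limit` is to bound how much of a response body gets buffered before deserialization. A hedged sketch of the same guard over a plain `AsyncRead` (the real function iterates http body frames; `read_with_limit` is a hypothetical name):

```rust
use tokio::io::{AsyncRead, AsyncReadExt};

// Refuse to buffer more than `limit` bytes: take(limit + 1) lets us detect
// overflow by reading at most a single extra byte.
async fn read_with_limit<R: AsyncRead + Unpin>(
    mut r: R,
    limit: usize,
) -> anyhow::Result<Vec<u8>> {
    let mut buf = Vec::new();
    let n = (&mut r).take(limit as u64 + 1).read_to_end(&mut buf).await?;
    anyhow::ensure!(n <= limit, "response body exceeds {limit} bytes");
    Ok(buf)
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let body = read_with_limit(&[1u8, 2, 3][..], 8).await?;
    println!("read {} bytes", body.len());
    Ok(())
}
```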
@@ -29,10 +29,10 @@ impl<Id: InternId> std::fmt::Display for InternedString<Id> {
}

impl<Id: InternId> InternedString<Id> {
    pub fn as_str(&self) -> &'static str {
    pub(crate) fn as_str(&self) -> &'static str {
        Id::get_interner().inner.resolve(&self.inner)
    }
    pub fn get(s: &str) -> Option<Self> {
    pub(crate) fn get(s: &str) -> Option<Self> {
        Id::get_interner().get(s)
    }
}
@@ -78,7 +78,7 @@ impl<Id: InternId> serde::Serialize for InternedString<Id> {
}

impl<Id: InternId> StringInterner<Id> {
    pub fn new() -> Self {
    pub(crate) fn new() -> Self {
        StringInterner {
            inner: ThreadedRodeo::with_capacity_memory_limits_and_hasher(
                Capacity::new(2500, NonZeroUsize::new(1 << 16).unwrap()),
@@ -90,26 +90,24 @@ impl<Id: InternId> StringInterner<Id> {
        }
    }

    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    pub fn len(&self) -> usize {
    #[cfg(test)]
    fn len(&self) -> usize {
        self.inner.len()
    }

    pub fn current_memory_usage(&self) -> usize {
    #[cfg(test)]
    fn current_memory_usage(&self) -> usize {
        self.inner.current_memory_usage()
    }

    pub fn get_or_intern(&self, s: &str) -> InternedString<Id> {
    pub(crate) fn get_or_intern(&self, s: &str) -> InternedString<Id> {
        InternedString {
            inner: self.inner.get_or_intern(s),
            _id: PhantomData,
        }
    }

    pub fn get(&self, s: &str) -> Option<InternedString<Id>> {
    pub(crate) fn get(&self, s: &str) -> Option<InternedString<Id>> {
        Some(InternedString {
            inner: self.inner.get(s)?,
            _id: PhantomData,
@@ -132,14 +130,14 @@ impl<Id: InternId> Default for StringInterner<Id> {
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct RoleNameTag;
pub(crate) struct RoleNameTag;
impl InternId for RoleNameTag {
    fn get_interner() -> &'static StringInterner<Self> {
        pub static ROLE_NAMES: OnceLock<StringInterner<RoleNameTag>> = OnceLock::new();
        static ROLE_NAMES: OnceLock<StringInterner<RoleNameTag>> = OnceLock::new();
        ROLE_NAMES.get_or_init(Default::default)
    }
}
pub type RoleNameInt = InternedString<RoleNameTag>;
pub(crate) type RoleNameInt = InternedString<RoleNameTag>;
impl From<&RoleName> for RoleNameInt {
    fn from(value: &RoleName) -> Self {
        RoleNameTag::get_interner().get_or_intern(value)
@@ -150,7 +148,7 @@ impl From<&RoleName> for RoleNameInt {
pub struct EndpointIdTag;
impl InternId for EndpointIdTag {
    fn get_interner() -> &'static StringInterner<Self> {
        pub static ROLE_NAMES: OnceLock<StringInterner<EndpointIdTag>> = OnceLock::new();
        static ROLE_NAMES: OnceLock<StringInterner<EndpointIdTag>> = OnceLock::new();
        ROLE_NAMES.get_or_init(Default::default)
    }
}
@@ -170,7 +168,7 @@ impl From<EndpointId> for EndpointIdInt {
pub struct BranchIdTag;
impl InternId for BranchIdTag {
    fn get_interner() -> &'static StringInterner<Self> {
        pub static ROLE_NAMES: OnceLock<StringInterner<BranchIdTag>> = OnceLock::new();
        static ROLE_NAMES: OnceLock<StringInterner<BranchIdTag>> = OnceLock::new();
        ROLE_NAMES.get_or_init(Default::default)
    }
}
@@ -190,7 +188,7 @@ impl From<BranchId> for BranchIdInt {
pub struct ProjectIdTag;
impl InternId for ProjectIdTag {
    fn get_interner() -> &'static StringInterner<Self> {
        pub static ROLE_NAMES: OnceLock<StringInterner<ProjectIdTag>> = OnceLock::new();
        static ROLE_NAMES: OnceLock<StringInterner<ProjectIdTag>> = OnceLock::new();
        ROLE_NAMES.get_or_init(Default::default)
    }
}
@@ -217,7 +215,7 @@ mod tests {
    struct MyId;
    impl InternId for MyId {
        fn get_interner() -> &'static StringInterner<Self> {
            pub static ROLE_NAMES: OnceLock<StringInterner<MyId>> = OnceLock::new();
            pub(crate) static ROLE_NAMES: OnceLock<StringInterner<MyId>> = OnceLock::new();
            ROLE_NAMES.get_or_init(Default::default)
        }
    }
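For readers unfamiliar with the interner used above: it is backed by `lasso`'s `ThreadedRodeo`, which maps equal strings to one small copyable key. A minimal usage sketch against `lasso` directly (the tag/`InternId` wrapper from this diff is omitted):

```rust
use lasso::ThreadedRodeo;

fn main() {
    let rodeo = ThreadedRodeo::new();

    // Identical strings collapse to the same key and one allocation.
    let a = rodeo.get_or_intern("ep-winter-sound-12345");
    let b = rodeo.get_or_intern("ep-winter-sound-12345");
    assert_eq!(a, b);

    // get() only looks up, mirroring StringInterner::get above.
    assert_eq!(rodeo.get("missing"), None);

    // resolve() returns the shared &str for a key.
    println!("resolved: {}", rodeo.resolve(&a));
}
```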
@@ -157,7 +157,8 @@ macro_rules! smol_str_wrapper {
        pub struct $name(smol_str::SmolStr);

        impl $name {
            pub fn as_str(&self) -> &str {
            #[allow(unused)]
            pub(crate) fn as_str(&self) -> &str {
                self.0.as_str()
            }
        }
@@ -252,19 +253,19 @@ smol_str_wrapper!(Host);

// Endpoints are a bit tricky. Rarely, they might be branches or projects.
impl EndpointId {
    pub fn is_endpoint(&self) -> bool {
    pub(crate) fn is_endpoint(&self) -> bool {
        self.0.starts_with("ep-")
    }
    pub fn is_branch(&self) -> bool {
    pub(crate) fn is_branch(&self) -> bool {
        self.0.starts_with("br-")
    }
    pub fn is_project(&self) -> bool {
        !self.is_endpoint() && !self.is_branch()
    }
    pub fn as_branch(&self) -> BranchId {
    // pub(crate) fn is_project(&self) -> bool {
    //     !self.is_endpoint() && !self.is_branch()
    // }
    pub(crate) fn as_branch(&self) -> BranchId {
        BranchId(self.0.clone())
    }
    pub fn as_project(&self) -> ProjectId {
    pub(crate) fn as_project(&self) -> ProjectId {
        ProjectId(self.0.clone())
    }
}
@@ -2,14 +2,14 @@

use std::ffi::CStr;

pub fn split_cstr(bytes: &[u8]) -> Option<(&CStr, &[u8])> {
pub(crate) fn split_cstr(bytes: &[u8]) -> Option<(&CStr, &[u8])> {
    let cstr = CStr::from_bytes_until_nul(bytes).ok()?;
    let (_, other) = bytes.split_at(cstr.to_bytes_with_nul().len());
    Some((cstr, other))
}

/// See <https://doc.rust-lang.org/std/primitive.slice.html#method.split_array_ref>.
pub fn split_at_const<const N: usize>(bytes: &[u8]) -> Option<(&[u8; N], &[u8])> {
pub(crate) fn split_at_const<const N: usize>(bytes: &[u8]) -> Option<(&[u8; N], &[u8])> {
    (bytes.len() >= N).then(|| {
        let (head, tail) = bytes.split_at(N);
        (head.try_into().unwrap(), tail)
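`split_cstr` peels one NUL-terminated string off the front of a buffer and hands back the tail, which is how startup-packet fields are walked. A self-contained copy with a tiny usage check:

```rust
use std::ffi::CStr;

// Same body as the diff: find the first NUL, split the buffer after it.
fn split_cstr(bytes: &[u8]) -> Option<(&CStr, &[u8])> {
    let cstr = CStr::from_bytes_until_nul(bytes).ok()?;
    let (_, other) = bytes.split_at(cstr.to_bytes_with_nul().len());
    Some((cstr, other))
}

fn main() {
    // "user\0postgres" -> ("user", b"postgres")
    let buf = b"user\0postgres";
    let (head, rest) = split_cstr(buf).expect("buffer contains a NUL");
    assert_eq!(head.to_str().unwrap(), "user");
    assert_eq!(rest, &b"postgres"[..]);
    println!("ok");
}
```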
@@ -13,9 +13,9 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};

pin_project! {
    /// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough
    pub struct ChainRW<T> {
    pub(crate) struct ChainRW<T> {
        #[pin]
        pub inner: T,
        pub(crate) inner: T,
        buf: BytesMut,
    }
}
@@ -60,7 +60,7 @@ const HEADER: [u8; 12] = [
    0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
];

pub async fn read_proxy_protocol<T: AsyncRead + Unpin>(
pub(crate) async fn read_proxy_protocol<T: AsyncRead + Unpin>(
    mut read: T,
) -> std::io::Result<(ChainRW<T>, Option<SocketAddr>)> {
    let mut buf = BytesMut::with_capacity(128);
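The 12-byte `HEADER` constant above is the PROXY protocol v2 signature, i.e. the ASCII bytes `\r\n\r\n\0\r\nQUIT\n`. A one-line sanity check:

```rust
// 0x0D/0x0A are CR/LF; 0x51 0x55 0x49 0x54 spell "QUIT".
const HEADER: [u8; 12] = [
    0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
];

fn main() {
    assert_eq!(&HEADER, b"\r\n\r\n\x00\r\nQUIT\n");
    println!("signature matches");
}
```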
@@ -1,12 +1,12 @@
#[cfg(test)]
mod tests;

pub mod connect_compute;
pub(crate) mod connect_compute;
mod copy_bidirectional;
pub mod handshake;
pub mod passthrough;
pub mod retry;
pub mod wake_compute;
pub(crate) mod handshake;
pub(crate) mod passthrough;
pub(crate) mod retry;
pub(crate) mod wake_compute;
pub use copy_bidirectional::copy_bidirectional_client_compute;
pub use copy_bidirectional::ErrorSource;

@@ -170,21 +170,21 @@ pub async fn task_main(
    Ok(())
}

pub enum ClientMode {
pub(crate) enum ClientMode {
    Tcp,
    Websockets { hostname: Option<String> },
}

/// Abstracts the logic of handling TCP vs WS clients
impl ClientMode {
    pub fn allow_cleartext(&self) -> bool {
    pub(crate) fn allow_cleartext(&self) -> bool {
        match self {
            ClientMode::Tcp => false,
            ClientMode::Websockets { .. } => true,
        }
    }

    pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
    pub(crate) fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
        match self {
            ClientMode::Tcp => config.allow_self_signed_compute,
            ClientMode::Websockets { .. } => false,
@@ -213,7 +213,7 @@ impl ClientMode {
// 2. Handshake: handshake reports errors if it can, otherwise if the handshake fails due to protocol violation,
//    we cannot be sure the client even understands our error message
// 3. PrepareClient: The client disconnected, so we can't tell them anyway...
pub enum ClientRequestError {
pub(crate) enum ClientRequestError {
    #[error("{0}")]
    Cancellation(#[from] cancellation::CancelError),
    #[error("{0}")]
@@ -238,7 +238,7 @@ impl ReportableError for ClientRequestError {
    }
}

pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
    config: &'static ProxyConfig,
    ctx: &RequestMonitoring,
    cancellation_handler: Arc<CancellationHandlerMain>,
@@ -340,9 +340,9 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
        client: stream,
        aux: node.aux.clone(),
        compute: node,
        req: request_gauge,
        conn: conn_gauge,
        cancel: session,
        _req: request_gauge,
        _conn: conn_gauge,
        _cancel: session,
    }))
}

@@ -377,20 +377,20 @@ async fn prepare_client_connection<P>(
}

#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct NeonOptions(Vec<(SmolStr, SmolStr)>);
pub(crate) struct NeonOptions(Vec<(SmolStr, SmolStr)>);

impl NeonOptions {
    pub fn parse_params(params: &StartupMessageParams) -> Self {
    pub(crate) fn parse_params(params: &StartupMessageParams) -> Self {
        params
            .options_raw()
            .map(Self::parse_from_iter)
            .unwrap_or_default()
    }
    pub fn parse_options_raw(options: &str) -> Self {
    pub(crate) fn parse_options_raw(options: &str) -> Self {
        Self::parse_from_iter(StartupMessageParams::parse_options_raw(options))
    }

    pub fn is_ephemeral(&self) -> bool {
    pub(crate) fn is_ephemeral(&self) -> bool {
        // Currently, neon endpoint options are all reserved for ephemeral endpoints.
        !self.0.is_empty()
    }
@@ -404,7 +404,7 @@ impl NeonOptions {
        Self(options)
    }

    pub fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey {
    pub(crate) fn get_cache_key(&self, prefix: &str) -> EndpointCacheKey {
        // prefix + format!(" {k}:{v}")
        // kinda jank because SmolStr is immutable
        std::iter::once(prefix)
@@ -415,7 +415,7 @@ impl NeonOptions {
    /// <https://swagger.io/docs/specification/serialization/> DeepObject format
    /// `paramName[prop1]=value1&paramName[prop2]=value2&...`
    pub fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> {
    pub(crate) fn to_deep_object(&self) -> Vec<(SmolStr, SmolStr)> {
        self.0
            .iter()
            .map(|(k, v)| (format_smolstr!("options[{}]", k), v.clone()))
@@ -423,7 +423,7 @@ impl NeonOptions {
    }
}

pub fn neon_option(bytes: &str) -> Option<(&str, &str)> {
pub(crate) fn neon_option(bytes: &str) -> Option<(&str, &str)> {
    static RE: OnceCell<Regex> = OnceCell::new();
    let re = RE.get_or_init(|| Regex::new(r"^neon_(\w+):(.+)").unwrap());
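`neon_option` splits parameters of the form `neon_<key>:<value>` using the regex shown above. A standalone sketch of the same parse (assuming the `regex` crate, 1.9+ for `Captures::extract`; the real code caches the compiled regex in a `OnceCell`):

```rust
use regex::Regex;

// Returns Some((key, value)) for "neon_<key>:<value>", None otherwise.
fn neon_option(bytes: &str) -> Option<(&str, &str)> {
    let re = Regex::new(r"^neon_(\w+):(.+)").unwrap();
    let (_, [k, v]) = re.captures(bytes)?.extract();
    Some((k, v))
}

fn main() {
    assert_eq!(neon_option("neon_endpoint:ep-a-1"), Some(("endpoint", "ep-a-1")));
    assert_eq!(neon_option("application_name=psql"), None);
    println!("ok");
}
```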
@@ -25,7 +25,7 @@ const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
#[tracing::instrument(name = "invalidate_cache", skip_all)]
pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo {
pub(crate) fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo {
    let is_cached = node_info.cached();
    if is_cached {
        warn!("invalidating stalled compute node info cache entry");
@@ -41,7 +41,7 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo {
}

#[async_trait]
pub trait ConnectMechanism {
pub(crate) trait ConnectMechanism {
    type Connection;
    type ConnectError: ReportableError;
    type Error: From<Self::ConnectError>;
@@ -56,7 +56,7 @@ pub trait ConnectMechanism {
}

#[async_trait]
pub trait ComputeConnectBackend {
pub(crate) trait ComputeConnectBackend {
    async fn wake_compute(
        &self,
        ctx: &RequestMonitoring,
@@ -65,12 +65,12 @@ pub trait ComputeConnectBackend {
    fn get_keys(&self) -> &ComputeCredentialKeys;
}

pub struct TcpMechanism<'a> {
pub(crate) struct TcpMechanism<'a> {
    /// KV-dictionary with PostgreSQL connection params.
    pub params: &'a StartupMessageParams,
    pub(crate) params: &'a StartupMessageParams,

    /// connect_to_compute concurrency lock
    pub locks: &'static ApiLocks<Host>,
    pub(crate) locks: &'static ApiLocks<Host>,
}

#[async_trait]
@@ -98,7 +98,7 @@ impl ConnectMechanism for TcpMechanism<'_> {

/// Try to connect to the compute node, retrying if necessary.
#[tracing::instrument(skip_all)]
pub async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
pub(crate) async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
    ctx: &RequestMonitoring,
    mechanism: &M,
    user_info: &B,
@@ -14,7 +14,7 @@ enum TransferState {
}

#[derive(Debug)]
pub enum ErrorDirection {
pub(crate) enum ErrorDirection {
    Read(io::Error),
    Write(io::Error),
}

@@ -18,7 +18,7 @@ use crate::{
};

#[derive(Error, Debug)]
pub enum HandshakeError {
pub(crate) enum HandshakeError {
    #[error("data is sent before server replied with EncryptionResponse")]
    EarlyData,

@@ -57,7 +57,7 @@ impl ReportableError for HandshakeError {
    }
}

pub enum HandshakeData<S> {
pub(crate) enum HandshakeData<S> {
    Startup(PqStream<Stream<S>>, StartupMessageParams),
    Cancel(CancelKeyData),
}
@@ -67,7 +67,7 @@ pub enum HandshakeData<S> {
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
/// we also take extra care to propagate only select handshake errors to the client.
#[tracing::instrument(skip_all)]
pub async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
    ctx: &RequestMonitoring,
    stream: S,
    mut tls: Option<&TlsConfig>,

@@ -14,7 +14,7 @@ use super::copy_bidirectional::ErrorSource;

/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
pub async fn proxy_pass(
pub(crate) async fn proxy_pass(
    client: impl AsyncRead + AsyncWrite + Unpin,
    compute: impl AsyncRead + AsyncWrite + Unpin,
    aux: MetricsAuxInfo,
@@ -57,18 +57,18 @@ pub async fn proxy_pass(
    Ok(())
}

pub struct ProxyPassthrough<P, S> {
    pub client: Stream<S>,
    pub compute: PostgresConnection,
    pub aux: MetricsAuxInfo,
pub(crate) struct ProxyPassthrough<P, S> {
    pub(crate) client: Stream<S>,
    pub(crate) compute: PostgresConnection,
    pub(crate) aux: MetricsAuxInfo,

    pub req: NumConnectionRequestsGuard<'static>,
    pub conn: NumClientConnectionsGuard<'static>,
    pub cancel: cancellation::Session<P>,
    pub(crate) _req: NumConnectionRequestsGuard<'static>,
    pub(crate) _conn: NumClientConnectionsGuard<'static>,
    pub(crate) _cancel: cancellation::Session<P>,
}

impl<P, S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<P, S> {
    pub async fn proxy_pass(self) -> Result<(), ErrorSource> {
    pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> {
        let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
        if let Err(err) = self.compute.cancel_closure.try_cancel_query().await {
            tracing::error!(?err, "could not cancel the query in the database");
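The passthrough phase is, at its core, a bidirectional byte copy between the client and compute sockets. A hedged sketch using tokio's stock `copy_bidirectional` (the proxy's own `copy_bidirectional_client_compute` adds error attribution and metrics on top):

```rust
use tokio::io::{self, AsyncRead, AsyncWrite};

// Forward bytes both ways until both directions see EOF.
async fn pass(
    mut client: impl AsyncRead + AsyncWrite + Unpin,
    mut compute: impl AsyncRead + AsyncWrite + Unpin,
) -> io::Result<()> {
    let (to_compute, to_client) = io::copy_bidirectional(&mut client, &mut compute).await?;
    println!("forwarded {to_compute}B client->compute, {to_client}B compute->client");
    Ok(())
}

#[tokio::main]
async fn main() -> io::Result<()> {
    // duplex() gives in-memory socket pairs, handy for trying this out.
    let (a, b) = io::duplex(64);
    let (c, d) = io::duplex(64);
    drop((b, d)); // immediately EOF both sides so pass() returns
    pass(a, c).await
}
```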
@@ -2,18 +2,18 @@ use crate::{compute, config::RetryConfig};
use std::{error::Error, io};
use tokio::time;

pub trait CouldRetry {
pub(crate) trait CouldRetry {
    /// Returns true if the error could be retried
    fn could_retry(&self) -> bool;
}

pub trait ShouldRetryWakeCompute {
pub(crate) trait ShouldRetryWakeCompute {
    /// Returns true if we need to invalidate the cache for this node.
    /// If false, we can continue retrying with the current node cache.
    fn should_retry_wake_compute(&self) -> bool;
}

pub fn should_retry(err: &impl CouldRetry, num_retries: u32, config: RetryConfig) -> bool {
pub(crate) fn should_retry(err: &impl CouldRetry, num_retries: u32, config: RetryConfig) -> bool {
    num_retries < config.max_retries && err.could_retry()
}

@@ -101,7 +101,7 @@ impl ShouldRetryWakeCompute for compute::ConnectionError {
    }
}

pub fn retry_after(num_retries: u32, config: RetryConfig) -> time::Duration {
pub(crate) fn retry_after(num_retries: u32, config: RetryConfig) -> time::Duration {
    config
        .base_delay
        .mul_f64(config.backoff_factor.powi((num_retries as i32) - 1))
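`retry_after` is plain exponential backoff: the n-th retry waits `base_delay * backoff_factor^(n-1)`. Worked standalone:

```rust
use std::time::Duration;

// Same arithmetic as the diff, with the config fields passed directly.
fn retry_after(num_retries: u32, base_delay: Duration, backoff_factor: f64) -> Duration {
    base_delay.mul_f64(backoff_factor.powi(num_retries as i32 - 1))
}

fn main() {
    // e.g. base 100ms, factor 2.0 -> 100ms, 200ms, 400ms, 800ms
    for n in 1..=4 {
        println!("retry {n}: {:?}", retry_after(n, Duration::from_millis(100), 2.0));
    }
}
```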
@@ -11,14 +11,14 @@ use crate::auth::backend::{
    ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend,
};
use crate::config::{CertResolver, RetryConfig};
use crate::console::caches::NodeInfoCache;
use crate::console::messages::{ConsoleError, Details, MetricsAuxInfo, Status};
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend};
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend, NodeInfoCache};
use crate::console::{self, CachedNodeInfo, NodeInfo};
use crate::error::ErrorKind;
use crate::{http, sasl, scram, BranchId, EndpointId, ProjectId};
use crate::{sasl, scram, BranchId, EndpointId, ProjectId};
use anyhow::{bail, Context};
use async_trait::async_trait;
use http::StatusCode;
use retry::{retry_after, ShouldRetryWakeCompute};
use rstest::rstest;
use rustls::pki_types;
@@ -491,7 +491,7 @@ impl TestBackend for TestConnectMechanism {
            ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)),
            ConnectAction::WakeFail => {
                let err = console::errors::ApiError::Console(ConsoleError {
                    http_status_code: http::StatusCode::BAD_REQUEST,
                    http_status_code: StatusCode::BAD_REQUEST,
                    error: "TEST".into(),
                    status: None,
                });
@@ -500,7 +500,7 @@ impl TestBackend for TestConnectMechanism {
            }
            ConnectAction::WakeRetry => {
                let err = console::errors::ApiError::Console(ConsoleError {
                    http_status_code: http::StatusCode::BAD_REQUEST,
                    http_status_code: StatusCode::BAD_REQUEST,
                    error: "TEST".into(),
                    status: Some(Status {
                        code: "error".into(),
@@ -525,9 +525,6 @@ impl TestBackend for TestConnectMechanism {
    {
        unimplemented!("not used in tests")
    }
    fn get_role_secret(&self) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError> {
        unimplemented!("not used in tests")
    }
}

fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {

@@ -102,7 +102,7 @@ async fn proxy_mitm(
}

/// taken from tokio-postgres
pub async fn connect_tls<S, T>(mut stream: S, tls: T) -> T::Stream
pub(crate) async fn connect_tls<S, T>(mut stream: S, tls: T) -> T::Stream
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: TlsConnect<S>,

@@ -12,7 +12,7 @@ use tracing::{error, info, warn};

use super::connect_compute::ComputeConnectBackend;

pub async fn wake_compute<B: ComputeConnectBackend>(
pub(crate) async fn wake_compute<B: ComputeConnectBackend>(
    num_retries: &mut u32,
    ctx: &RequestMonitoring,
    api: &B,

@@ -1,10 +1,16 @@
mod leaky_bucket;
mod limit_algorithm;
mod limiter;
pub use limit_algorithm::{
    aimd::Aimd, DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token,

#[cfg(test)]
pub(crate) use limit_algorithm::aimd::Aimd;

pub(crate) use limit_algorithm::{
    DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token,
};
pub use limiter::{BucketRateLimiter, GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
mod leaky_bucket;
pub(crate) use limiter::GlobalRateLimiter;

pub use leaky_bucket::{
    EndpointRateLimiter, LeakyBucketConfig, LeakyBucketRateLimiter, LeakyBucketState,
};
pub use limiter::{BucketRateLimiter, RateBucketInfo, WakeComputeRateLimiter};

@@ -35,7 +35,7 @@ impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
    }

    /// Check that number of connections to the endpoint is below `max_rps` rps.
    pub fn check(&self, key: K, n: u32) -> bool {
    pub(crate) fn check(&self, key: K, n: u32) -> bool {
        let now = Instant::now();

        if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
@@ -73,8 +73,9 @@ pub struct LeakyBucketState {
    time: Instant,
}

#[cfg(test)]
impl LeakyBucketConfig {
    pub fn new(rps: f64, max: f64) -> Self {
    pub(crate) fn new(rps: f64, max: f64) -> Self {
        assert!(rps > 0.0, "rps must be positive");
        assert!(max > 0.0, "max must be positive");
        Self { rps, max }
@@ -82,7 +83,7 @@ impl LeakyBucketConfig {
}

impl LeakyBucketState {
    pub fn new() -> Self {
    pub(crate) fn new() -> Self {
        Self {
            filled: 0.0,
            time: Instant::now(),
@@ -100,7 +101,7 @@ impl LeakyBucketState {
        self.filled == 0.0
    }

    pub fn check(&mut self, info: &LeakyBucketConfig, now: Instant, n: f64) -> bool {
    pub(crate) fn check(&mut self, info: &LeakyBucketConfig, now: Instant, n: f64) -> bool {
        self.update(info, now);

        if self.filled + n > info.max {
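The `LeakyBucketState::check` logic above amounts to: drain the bucket at `rps` per second since the last update, then admit a request of weight `n` only if the level stays under `max`. A standalone restatement (details of the real `update` may differ):

```rust
use std::time::Instant;

// Self-contained leaky bucket with the same fields as the diff's state.
struct Bucket { rps: f64, max: f64, filled: f64, time: Instant }

impl Bucket {
    fn check(&mut self, now: Instant, n: f64) -> bool {
        // Drain according to elapsed time, clamped at empty.
        let drained = now.duration_since(self.time).as_secs_f64() * self.rps;
        self.filled = (self.filled - drained).max(0.0);
        self.time = now;
        if self.filled + n > self.max {
            return false; // over capacity: reject, don't fill
        }
        self.filled += n;
        true
    }
}

fn main() {
    let mut b = Bucket { rps: 10.0, max: 2.0, filled: 0.0, time: Instant::now() };
    let now = Instant::now();
    assert!(b.check(now, 1.0));
    assert!(b.check(now, 1.0));
    assert!(!b.check(now, 1.0)); // third request at the same instant is rejected
    println!("ok");
}
```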
@@ -8,13 +8,13 @@ use tokio::{

use self::aimd::Aimd;

pub mod aimd;
pub(crate) mod aimd;

/// Whether a job succeeded or failed as a result of congestion/overload.
///
/// Errors not considered to be caused by overload should be ignored.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Outcome {
pub(crate) enum Outcome {
    /// The job succeeded, or failed in a way unrelated to overload.
    Success,
    /// The job failed because of overload, e.g. it timed out or an explicit backpressure signal
@@ -23,14 +23,14 @@ pub enum Outcome {
}

/// An algorithm for controlling a concurrency limit.
pub trait LimitAlgorithm: Send + Sync + 'static {
pub(crate) trait LimitAlgorithm: Send + Sync + 'static {
    /// Update the concurrency limit in response to a new job completion.
    fn update(&self, old_limit: usize, sample: Sample) -> usize;
}

/// The result of a job (or jobs), including the [`Outcome`] (loss) and latency (delay).
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub struct Sample {
pub(crate) struct Sample {
    pub(crate) latency: Duration,
    /// Jobs in flight when the sample was taken.
    pub(crate) in_flight: usize,
@@ -39,7 +39,7 @@ pub struct Sample {

#[derive(Clone, Copy, Debug, Default, serde::Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum RateLimitAlgorithm {
pub(crate) enum RateLimitAlgorithm {
    #[default]
    Fixed,
    Aimd {
@@ -48,7 +48,7 @@ pub enum RateLimitAlgorithm {
    },
}

pub struct Fixed;
pub(crate) struct Fixed;

impl LimitAlgorithm for Fixed {
    fn update(&self, old_limit: usize, _sample: Sample) -> usize {
@@ -59,12 +59,12 @@ impl LimitAlgorithm for Fixed {
#[derive(Clone, Copy, Debug, serde::Deserialize, PartialEq)]
pub struct RateLimiterConfig {
    #[serde(flatten)]
    pub algorithm: RateLimitAlgorithm,
    pub initial_limit: usize,
    pub(crate) algorithm: RateLimitAlgorithm,
    pub(crate) initial_limit: usize,
}

impl RateLimiterConfig {
    pub fn create_rate_limit_algorithm(self) -> Box<dyn LimitAlgorithm> {
    pub(crate) fn create_rate_limit_algorithm(self) -> Box<dyn LimitAlgorithm> {
        match self.algorithm {
            RateLimitAlgorithm::Fixed => Box::new(Fixed),
            RateLimitAlgorithm::Aimd { conf } => Box::new(conf),
@@ -72,7 +72,7 @@ impl RateLimiterConfig {
    }
}

pub struct LimiterInner {
pub(crate) struct LimiterInner {
    alg: Box<dyn LimitAlgorithm>,
    available: usize,
    limit: usize,
@@ -114,7 +114,7 @@ impl LimiterInner {
///
/// The limit will be automatically adjusted based on observed latency (delay) and/or failures
/// caused by overload (loss).
pub struct DynamicLimiter {
pub(crate) struct DynamicLimiter {
    config: RateLimiterConfig,
    inner: Mutex<LimiterInner>,
    // to notify when a token is available
@@ -124,7 +124,7 @@ pub struct DynamicLimiter {
/// A concurrency token, required to run a job.
///
/// Release the token back to the [`DynamicLimiter`] after the job is complete.
pub struct Token {
pub(crate) struct Token {
    start: Instant,
    limiter: Option<Arc<DynamicLimiter>>,
}
@@ -133,14 +133,14 @@ pub struct Token {
///
/// Not guaranteed to be consistent under high concurrency.
#[derive(Debug, Clone, Copy)]
pub struct LimiterState {
#[cfg(test)]
struct LimiterState {
    limit: usize,
    in_flight: usize,
}

impl DynamicLimiter {
    /// Create a limiter with a given limit control algorithm.
    pub fn new(config: RateLimiterConfig) -> Arc<Self> {
    pub(crate) fn new(config: RateLimiterConfig) -> Arc<Self> {
        let ready = Notify::new();
        ready.notify_one();

@@ -157,7 +157,10 @@ impl DynamicLimiter {
    }

    /// Try to acquire a concurrency [Token], waiting for `duration` if there are none available.
    pub async fn acquire_timeout(self: &Arc<Self>, duration: Duration) -> Result<Token, Elapsed> {
    pub(crate) async fn acquire_timeout(
        self: &Arc<Self>,
        duration: Duration,
    ) -> Result<Token, Elapsed> {
        tokio::time::timeout(duration, self.acquire()).await?
    }

@@ -208,12 +211,10 @@ impl DynamicLimiter {
    }

    /// The current state of the limiter.
    pub fn state(&self) -> LimiterState {
    #[cfg(test)]
    fn state(&self) -> LimiterState {
        let inner = self.inner.lock();
        LimiterState {
            limit: inner.limit,
            in_flight: inner.in_flight,
        }
        LimiterState { limit: inner.limit }
    }
}

@@ -224,22 +225,22 @@ impl Token {
            limiter: Some(limiter),
        }
    }
    pub fn disabled() -> Self {
    pub(crate) fn disabled() -> Self {
        Self {
            start: Instant::now(),
            limiter: None,
        }
    }

    pub fn is_disabled(&self) -> bool {
    pub(crate) fn is_disabled(&self) -> bool {
        self.limiter.is_none()
    }

    pub fn release(mut self, outcome: Outcome) {
    pub(crate) fn release(mut self, outcome: Outcome) {
        self.release_mut(Some(outcome));
    }

    pub fn release_mut(&mut self, outcome: Option<Outcome>) {
    pub(crate) fn release_mut(&mut self, outcome: Option<Outcome>) {
        if let Some(limiter) = self.limiter.take() {
            limiter.release_inner(self.start, outcome);
        }
@@ -252,13 +253,10 @@ impl Drop for Token {
    }
}

#[cfg(test)]
impl LimiterState {
    /// The current concurrency limit.
    pub fn limit(&self) -> usize {
    fn limit(self) -> usize {
        self.limit
    }
    /// The number of jobs in flight.
    pub fn in_flight(&self) -> usize {
        self.in_flight
    }
}
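The `Token` lifecycle above (acquire with a timeout, run the job, report an `Outcome`) can be modelled with a plain semaphore; the difference is that `DynamicLimiter` also feeds the outcome into its limit algorithm. A toy model under that caveat:

```rust
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Semaphore;

#[derive(Debug)]
#[allow(dead_code)]
enum Outcome { Success, Overload }

#[tokio::main]
async fn main() -> Result<(), tokio::time::error::Elapsed> {
    // Fixed limit of 2 stands in for the adaptive limit.
    let limiter = Arc::new(Semaphore::new(2));

    // Mirrors acquire_timeout: wait for a slot, but only for so long.
    let permit = tokio::time::timeout(Duration::from_secs(1), limiter.acquire()).await?;
    let _permit = permit.expect("semaphore is never closed");

    let job: Result<(), ()> = Ok(()); // hypothetical wake-compute job
    let outcome = if job.is_ok() { Outcome::Success } else { Outcome::Overload };
    println!("released with {outcome:?}"); // permit drops here
    Ok(())
}
```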
@@ -10,17 +10,17 @@ use super::{LimitAlgorithm, Outcome, Sample};
///
/// Reduces available concurrency by a factor when load-based errors are detected.
#[derive(Clone, Copy, Debug, serde::Deserialize, PartialEq)]
pub struct Aimd {
pub(crate) struct Aimd {
    /// Minimum limit for AIMD algorithm.
    pub min: usize,
    pub(crate) min: usize,
    /// Maximum limit for AIMD algorithm.
    pub max: usize,
    pub(crate) max: usize,
    /// Multiplicative decrease: scale the limit down by this factor in case of error.
    pub dec: f32,
    pub(crate) dec: f32,
    /// Additive increase: raise the limit by this value in case of success.
    pub inc: usize,
    pub(crate) inc: usize,
    /// A threshold below which the limit won't be increased.
    pub utilisation: f32,
    pub(crate) utilisation: f32,
}

impl LimitAlgorithm for Aimd {
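Spelled out, the `Aimd` fields configure additive-increase/multiplicative-decrease: add `inc` on success while utilisation is high, multiply by `dec` on overload, clamped to `[min, max]`. A standalone sketch of that rule (the exact conditions in the real `update` may differ):

```rust
struct Aimd { min: usize, max: usize, dec: f32, inc: usize, utilisation: f32 }

impl Aimd {
    fn update(&self, limit: usize, in_flight: usize, overloaded: bool) -> usize {
        if overloaded {
            // Multiplicative decrease, clamped to min.
            ((limit as f32 * self.dec) as usize).max(self.min)
        } else if (in_flight as f32) >= (limit as f32) * self.utilisation {
            // Additive increase, clamped to max.
            (limit + self.inc).min(self.max)
        } else {
            limit // under-utilised: leave the limit alone
        }
    }
}

fn main() {
    let alg = Aimd { min: 1, max: 100, dec: 0.5, inc: 10, utilisation: 0.8 };
    let mut limit = 10;
    limit = alg.update(limit, 10, false); // busy + success -> 20
    limit = alg.update(limit, 20, true);  // overload -> 10
    println!("limit = {limit}");
}
```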
@@ -17,13 +17,13 @@ use tracing::info;

use crate::intern::EndpointIdInt;

pub struct GlobalRateLimiter {
pub(crate) struct GlobalRateLimiter {
    data: Vec<RateBucket>,
    info: Vec<RateBucketInfo>,
}

impl GlobalRateLimiter {
    pub fn new(info: Vec<RateBucketInfo>) -> Self {
    pub(crate) fn new(info: Vec<RateBucketInfo>) -> Self {
        Self {
            data: vec![
                RateBucket {
@@ -37,7 +37,7 @@ impl GlobalRateLimiter {
    }

    /// Check that number of connections is below `max_rps` rps.
    pub fn check(&mut self) -> bool {
    pub(crate) fn check(&mut self) -> bool {
        let now = Instant::now();

        let should_allow_request = self
@@ -96,9 +96,9 @@ impl RateBucket {

#[derive(Clone, Copy, PartialEq)]
pub struct RateBucketInfo {
    pub interval: Duration,
    pub(crate) interval: Duration,
    // requests per interval
    pub max_rpi: u32,
    pub(crate) max_rpi: u32,
}

impl std::fmt::Display for RateBucketInfo {
@@ -192,7 +192,7 @@ impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
    }

    /// Check that number of connections to the endpoint is below `max_rps` rps.
    pub fn check(&self, key: K, n: u32) -> bool {
    pub(crate) fn check(&self, key: K, n: u32) -> bool {
        // do a partial GC every 2k requests. This cleans up ~ 1/64th of the map.
        // worst case memory usage is about:
        //   = 2 * 2048 * 64 * (48B + 72B)
@@ -228,7 +228,7 @@ impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
    /// Clean the map. Simple strategy: remove all entries in a random shard.
    /// At worst, we'll double the effective max_rps during the cleanup.
    /// But that way deletion does not acquire a mutex on each entry access.
    pub fn do_gc(&self) {
    pub(crate) fn do_gc(&self) {
        info!(
            "cleaning up bucket rate limiter, current size = {}",
            self.map.len()
@@ -109,7 +109,7 @@ impl RedisPublisherClient {
        let _: () = self.client.publish(PROXY_CHANNEL_NAME, payload).await?;
        Ok(())
    }
    pub async fn try_connect(&mut self) -> anyhow::Result<()> {
    pub(crate) async fn try_connect(&mut self) -> anyhow::Result<()> {
        match self.client.connect().await {
            Ok(()) => {}
            Err(e) => {

@@ -81,7 +81,7 @@ impl ConnectionWithCredentialsProvider {
        redis::cmd("PING").query_async(con).await
    }

    pub async fn connect(&mut self) -> anyhow::Result<()> {
    pub(crate) async fn connect(&mut self) -> anyhow::Result<()> {
        let _guard = self.mutex.lock().await;
        if let Some(con) = self.con.as_mut() {
            match Self::ping(con).await {
@@ -149,7 +149,7 @@ impl ConnectionWithCredentialsProvider {

    // PubSub does not support credentials refresh.
    // Requires manual reconnection every 12h.
    pub async fn get_async_pubsub(&self) -> anyhow::Result<redis::aio::PubSub> {
    pub(crate) async fn get_async_pubsub(&self) -> anyhow::Result<redis::aio::PubSub> {
        Ok(self.get_client().await?.get_async_pubsub().await?)
    }

@@ -187,7 +187,10 @@ impl ConnectionWithCredentialsProvider {
    }
    /// Sends an already encoded (packed) command into the TCP socket and
    /// reads the single response from it.
    pub async fn send_packed_command(&mut self, cmd: &redis::Cmd) -> RedisResult<redis::Value> {
    pub(crate) async fn send_packed_command(
        &mut self,
        cmd: &redis::Cmd,
    ) -> RedisResult<redis::Value> {
        // Clone connection to avoid having to lock the ArcSwap in write mode
        let con = self.con.as_mut().ok_or(redis::RedisError::from((
            redis::ErrorKind::IoError,
@@ -199,7 +202,7 @@ impl ConnectionWithCredentialsProvider {
    /// Sends multiple already encoded (packed) commands into the TCP socket
    /// and reads `count` responses from it. This is used to implement
    /// pipelining.
    pub async fn send_packed_commands(
    pub(crate) async fn send_packed_commands(
        &mut self,
        cmd: &redis::Pipeline,
        offset: usize,

@@ -51,7 +51,7 @@ impl CredentialsProvider {
            credentials_provider,
        }
    }
    pub async fn provide_credentials(&self) -> anyhow::Result<(String, String)> {
    pub(crate) async fn provide_credentials(&self) -> anyhow::Result<(String, String)> {
        let aws_credentials = self
            .credentials_provider
            .provide_credentials()

@@ -58,9 +58,9 @@ pub(crate) struct PasswordUpdate {
}
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
pub(crate) struct CancelSession {
    pub region_id: Option<String>,
    pub cancel_key_data: CancelKeyData,
    pub session_id: Uuid,
    pub(crate) region_id: Option<String>,
    pub(crate) cancel_key_data: CancelKeyData,
    pub(crate) session_id: Uuid,
}

fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
@@ -89,7 +89,7 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
}

impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
    pub fn new(
    pub(crate) fn new(
        cache: Arc<C>,
        cancellation_handler: Arc<CancellationHandler<()>>,
        region_id: String,
@@ -100,10 +100,10 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
            region_id,
        }
    }
    pub async fn increment_active_listeners(&self) {
    pub(crate) async fn increment_active_listeners(&self) {
        self.cache.increment_active_listeners().await;
    }
    pub async fn decrement_active_listeners(&self) {
    pub(crate) async fn decrement_active_listeners(&self) {
        self.cache.decrement_active_listeners().await;
    }
    #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
@@ -14,13 +14,13 @@ use crate::error::{ReportableError, UserFacingError};
use std::io;
use thiserror::Error;

pub use channel_binding::ChannelBinding;
pub use messages::FirstMessage;
pub use stream::{Outcome, SaslStream};
pub(crate) use channel_binding::ChannelBinding;
pub(crate) use messages::FirstMessage;
pub(crate) use stream::{Outcome, SaslStream};

/// Fine-grained auth errors help in writing tests.
#[derive(Error, Debug)]
pub enum Error {
pub(crate) enum Error {
    #[error("Channel binding failed: {0}")]
    ChannelBindingFailed(&'static str),

@@ -64,11 +64,11 @@ impl ReportableError for Error {
}

/// A convenient result type for SASL exchange.
pub type Result<T> = std::result::Result<T, Error>;
pub(crate) type Result<T> = std::result::Result<T, Error>;

/// A result of one SASL exchange.
#[must_use]
pub enum Step<T, R> {
pub(crate) enum Step<T, R> {
    /// We should continue exchanging messages.
    Continue(T, String),
    /// The client has been authenticated successfully.
@@ -78,7 +78,7 @@ pub enum Step<T, R> {
}

/// Every SASL mechanism (e.g. [SCRAM](crate::scram)) is expected to implement this trait.
pub trait Mechanism: Sized {
pub(crate) trait Mechanism: Sized {
    /// What's produced as a result of successful authentication.
    type Output;

@@ -2,7 +2,7 @@

/// Channel binding flag (possibly with params).
#[derive(Debug, PartialEq, Eq)]
pub enum ChannelBinding<T> {
pub(crate) enum ChannelBinding<T> {
    /// Client doesn't support channel binding.
    NotSupportedClient,
    /// Client thinks server doesn't support channel binding.
@@ -12,7 +12,10 @@ pub enum ChannelBinding<T> {
}

impl<T> ChannelBinding<T> {
    pub fn and_then<R, E>(self, f: impl FnOnce(T) -> Result<R, E>) -> Result<ChannelBinding<R>, E> {
    pub(crate) fn and_then<R, E>(
        self,
        f: impl FnOnce(T) -> Result<R, E>,
    ) -> Result<ChannelBinding<R>, E> {
        Ok(match self {
            Self::NotSupportedClient => ChannelBinding::NotSupportedClient,
            Self::NotSupportedServer => ChannelBinding::NotSupportedServer,
@@ -23,7 +26,7 @@ impl<T> ChannelBinding<T> {

impl<'a> ChannelBinding<&'a str> {
    // NB: FromStr doesn't work with lifetimes
    pub fn parse(input: &'a str) -> Option<Self> {
    pub(crate) fn parse(input: &'a str) -> Option<Self> {
        Some(match input {
            "n" => Self::NotSupportedClient,
            "y" => Self::NotSupportedServer,
@@ -34,7 +37,7 @@ impl<'a> ChannelBinding<&'a str> {

impl<T: std::fmt::Display> ChannelBinding<T> {
    /// Encode channel binding data as base64 for subsequent checks.
    pub fn encode<'a, E>(
    pub(crate) fn encode<'a, E>(
        &self,
        get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>,
    ) -> Result<std::borrow::Cow<'static, str>, E> {

@@ -5,16 +5,16 @@ use pq_proto::{BeAuthenticationSaslMessage, BeMessage};

/// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage).
#[derive(Debug)]
pub struct FirstMessage<'a> {
pub(crate) struct FirstMessage<'a> {
    /// Authentication method, e.g. `"SCRAM-SHA-256"`.
    pub method: &'a str,
    pub(crate) method: &'a str,
    /// Initial client message.
    pub message: &'a str,
    pub(crate) message: &'a str,
}

impl<'a> FirstMessage<'a> {
    // NB: FromStr doesn't work with lifetimes
    pub fn parse(bytes: &'a [u8]) -> Option<Self> {
    pub(crate) fn parse(bytes: &'a [u8]) -> Option<Self> {
        let (method_cstr, tail) = split_cstr(bytes)?;
        let method = method_cstr.to_str().ok()?;

@@ -7,7 +7,7 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;

/// Abstracts away all peculiarities of the libpq's protocol.
pub struct SaslStream<'a, S> {
pub(crate) struct SaslStream<'a, S> {
    /// The underlying stream.
    stream: &'a mut PqStream<S>,
    /// Current password message we received from client.
@@ -17,7 +17,7 @@ pub struct SaslStream<'a, S> {
}

impl<'a, S> SaslStream<'a, S> {
    pub fn new(stream: &'a mut PqStream<S>, first: &'a str) -> Self {
    pub(crate) fn new(stream: &'a mut PqStream<S>, first: &'a str) -> Self {
        Self {
            stream,
            current: bytes::Bytes::new(),
@@ -53,7 +53,7 @@ impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
/// It's much easier to match on those two variants
/// than to peek into a noisy protocol error type.
#[must_use = "caller must explicitly check for success"]
pub enum Outcome<R> {
pub(crate) enum Outcome<R> {
    /// Authentication succeeded and produced some value.
    Success(R),
    /// Authentication failed (reason attached).
@@ -63,7 +63,7 @@ pub enum Outcome<R> {
impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
    /// Perform SASL message exchange according to the underlying algorithm
    /// until user is either authenticated or denied access.
    pub async fn authenticate<M: Mechanism>(
    pub(crate) async fn authenticate<M: Mechanism>(
        mut self,
        mut mechanism: M,
    ) -> super::Result<Outcome<M::Output>> {

@@ -15,9 +15,9 @@ mod secret;
mod signature;
pub mod threadpool;

pub use exchange::{exchange, Exchange};
pub use key::ScramKey;
pub use secret::ServerSecret;
pub(crate) use exchange::{exchange, Exchange};
pub(crate) use key::ScramKey;
pub(crate) use secret::ServerSecret;

use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};
@@ -26,8 +26,8 @@ const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";

/// A list of supported SCRAM methods.
pub const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256];
pub const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256];
pub(crate) const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256];
pub(crate) const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256];

/// Decode base64 into array without any heap allocations
fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N]> {
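A plausible use of the two constant lists above: advertise the `-PLUS` (channel-binding) mechanism only when `tls-server-end-point` data is available. This selection logic is an assumption for illustration, not code from this diff:

```rust
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256];
const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256];

// Hypothetical selection helper: only offer channel binding when the TLS
// layer can supply the tls-server-end-point certificate hash.
fn advertised(tls_server_end_point_available: bool) -> &'static [&'static str] {
    if tls_server_end_point_available {
        METHODS
    } else {
        METHODS_WITHOUT_PLUS
    }
}

fn main() {
    assert_eq!(advertised(true), &["SCRAM-SHA-256-PLUS", "SCRAM-SHA-256"][..]);
    assert_eq!(advertised(false), &["SCRAM-SHA-256"][..]);
}
```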
@@ -2,7 +2,7 @@ use std::hash::Hash;

/// estimator of hash jobs per second.
/// <https://en.wikipedia.org/wiki/Count%E2%80%93min_sketch>
pub struct CountMinSketch {
pub(crate) struct CountMinSketch {
    // one for each depth
    hashers: Vec<ahash::RandomState>,
    width: usize,
@@ -20,7 +20,7 @@ impl CountMinSketch {
    /// actual <= estimate
    /// estimate <= actual + ε * N with probability 1 - δ
    /// where N is the cardinality of the stream
    pub fn with_params(epsilon: f64, delta: f64) -> Self {
    pub(crate) fn with_params(epsilon: f64, delta: f64) -> Self {
        CountMinSketch::new(
            (std::f64::consts::E / epsilon).ceil() as usize,
            (1.0_f64 / delta).ln().ceil() as usize,
@@ -49,7 +49,7 @@ impl CountMinSketch {
        }
    }

    pub fn inc_and_return<T: Hash>(&mut self, t: &T, x: u32) -> u32 {
    pub(crate) fn inc_and_return<T: Hash>(&mut self, t: &T, x: u32) -> u32 {
        let mut min = u32::MAX;
        for row in 0..self.depth {
            let col = (self.hashers[row].hash_one(t) as usize) % self.width;
@@ -61,7 +61,7 @@ impl CountMinSketch {
        min
    }

    pub fn reset(&mut self) {
    pub(crate) fn reset(&mut self) {
        self.buckets.clear();
        self.buckets.resize(self.width * self.depth, 0);
    }
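`with_params` sizes the sketch from the accuracy targets: `width = ceil(e/ε)` and `depth = ceil(ln(1/δ))`, which yields the guarantee quoted in the doc comment. Worked numbers for 1% error at 99% confidence:

```rust
fn main() {
    let (epsilon, delta) = (0.01_f64, 0.01_f64);
    let width = (std::f64::consts::E / epsilon).ceil() as usize;
    let depth = (1.0_f64 / delta).ln().ceil() as usize;
    // 272 columns x 5 rows of u32 counters is roughly 5.4 KB.
    println!("width = {width}, depth = {depth}");
    assert_eq!((width, depth), (272, 5));
}
```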
@@ -56,14 +56,14 @@ enum ExchangeState {
|
||||
}
|
||||
|
||||
/// Server's side of SCRAM auth algorithm.
|
||||
pub struct Exchange<'a> {
|
||||
pub(crate) struct Exchange<'a> {
|
||||
state: ExchangeState,
|
||||
secret: &'a ServerSecret,
|
||||
tls_server_end_point: config::TlsServerEndPoint,
|
||||
}
|
||||
|
||||
impl<'a> Exchange<'a> {
|
||||
pub fn new(
|
||||
pub(crate) fn new(
|
||||
secret: &'a ServerSecret,
|
||||
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
|
||||
tls_server_end_point: config::TlsServerEndPoint,
|
||||
@@ -101,7 +101,7 @@ async fn derive_client_key(
|
||||
make_key(b"Client Key").into()
|
||||
}
|
||||
|
||||
pub async fn exchange(
|
||||
pub(crate) async fn exchange(
|
||||
pool: &ThreadPool,
|
||||
endpoint: EndpointIdInt,
|
||||
secret: &ServerSecret,
|
||||
|
||||
@@ -3,14 +3,14 @@
use subtle::ConstantTimeEq;

/// Faithfully taken from PostgreSQL.
pub const SCRAM_KEY_LEN: usize = 32;
pub(crate) const SCRAM_KEY_LEN: usize = 32;

/// One of the keys derived from the user's password.
/// We use the same structure for all keys, i.e.
/// `ClientKey`, `StoredKey`, and `ServerKey`.
#[derive(Clone, Default, Eq, Debug)]
#[repr(transparent)]
pub struct ScramKey {
pub(crate) struct ScramKey {
bytes: [u8; SCRAM_KEY_LEN],
}

@@ -27,11 +27,11 @@ impl ConstantTimeEq for ScramKey {
}

impl ScramKey {
pub fn sha256(&self) -> Self {
pub(crate) fn sha256(&self) -> Self {
super::sha256([self.as_ref()]).into()
}

pub fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] {
pub(crate) fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] {
self.bytes
}
}

@@ -8,7 +8,7 @@ use std::fmt;
use std::ops::Range;

/// Faithfully taken from PostgreSQL.
pub const SCRAM_RAW_NONCE_LEN: usize = 18;
pub(crate) const SCRAM_RAW_NONCE_LEN: usize = 18;

/// Although we ignore all extensions, we still have to validate the message.
fn validate_sasl_extensions<'a>(parts: impl Iterator<Item = &'a str>) -> Option<()> {
@@ -27,18 +27,18 @@ fn validate_sasl_extensions<'a>(parts: impl Iterator<Item = &'a str>) -> Option<
}

#[derive(Debug)]
pub struct ClientFirstMessage<'a> {
pub(crate) struct ClientFirstMessage<'a> {
/// `client-first-message-bare`.
pub bare: &'a str,
pub(crate) bare: &'a str,
/// Channel binding mode.
pub cbind_flag: ChannelBinding<&'a str>,
pub(crate) cbind_flag: ChannelBinding<&'a str>,
/// Client nonce.
pub nonce: &'a str,
pub(crate) nonce: &'a str,
}

impl<'a> ClientFirstMessage<'a> {
// NB: FromStr doesn't work with lifetimes
pub fn parse(input: &'a str) -> Option<Self> {
pub(crate) fn parse(input: &'a str) -> Option<Self> {
let mut parts = input.split(',');

let cbind_flag = ChannelBinding::parse(parts.next()?)?;
@@ -77,7 +77,7 @@ impl<'a> ClientFirstMessage<'a> {
}

/// Build a response to [`ClientFirstMessage`].
pub fn build_server_first_message(
pub(crate) fn build_server_first_message(
&self,
nonce: &[u8; SCRAM_RAW_NONCE_LEN],
salt_base64: &str,
@@ -101,20 +101,20 @@ impl<'a> ClientFirstMessage<'a> {
}

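For orientation, the strings these parsers consume are the standard SCRAM messages; the SCRAM-SHA-256 example exchange from RFC 7677 looks like:

client-first-message:  n,,n=user,r=rOprNGfwEbeRWgbNEkqO
server-first-message:  r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxF,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096
client-final-message:  c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxF,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ=

Here `bare` is everything after the `n,,` GS2 header, and `c=biws` is the base64 of that header echoed back in the final message for channel-binding verification.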
#[derive(Debug)]
pub struct ClientFinalMessage<'a> {
pub(crate) struct ClientFinalMessage<'a> {
/// `client-final-message-without-proof`.
pub without_proof: &'a str,
pub(crate) without_proof: &'a str,
/// Channel binding data (base64).
pub channel_binding: &'a str,
pub(crate) channel_binding: &'a str,
/// Combined client & server nonce.
pub nonce: &'a str,
pub(crate) nonce: &'a str,
/// Client auth proof.
pub proof: [u8; SCRAM_KEY_LEN],
pub(crate) proof: [u8; SCRAM_KEY_LEN],
}

impl<'a> ClientFinalMessage<'a> {
// NB: FromStr doesn't work with lifetimes
pub fn parse(input: &'a str) -> Option<Self> {
pub(crate) fn parse(input: &'a str) -> Option<Self> {
let (without_proof, proof) = input.rsplit_once(',')?;

let mut parts = without_proof.split(',');
@@ -135,7 +135,7 @@ impl<'a> ClientFinalMessage<'a> {
}

/// Build a response to [`ClientFinalMessage`].
pub fn build_server_final_message(
pub(crate) fn build_server_final_message(
&self,
signature_builder: SignatureBuilder<'_>,
server_key: &ScramKey,
@@ -153,7 +153,7 @@ impl<'a> ClientFinalMessage<'a> {

/// We need to keep a convenient representation of this
/// message for the next authentication step.
pub struct OwnedServerFirstMessage {
pub(crate) struct OwnedServerFirstMessage {
/// Owned `server-first-message`.
message: String,
/// Slice into `message`.
@@ -163,13 +163,13 @@ pub struct OwnedServerFirstMessage {
impl OwnedServerFirstMessage {
/// Extract combined nonce from the message.
#[inline(always)]
pub fn nonce(&self) -> &str {
pub(crate) fn nonce(&self) -> &str {
&self.message[self.nonce.clone()]
}

/// Get reference to a text representation of the message.
#[inline(always)]
pub fn as_str(&self) -> &str {
pub(crate) fn as_str(&self) -> &str {
&self.message
}
}

@@ -4,7 +4,7 @@ use hmac::{
};
use sha2::Sha256;

pub struct Pbkdf2 {
pub(crate) struct Pbkdf2 {
hmac: Hmac<Sha256>,
prev: GenericArray<u8, U32>,
hi: GenericArray<u8, U32>,
@@ -13,7 +13,7 @@ pub struct Pbkdf2 {

// inspired from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
impl Pbkdf2 {
pub fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self {
pub(crate) fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self {
let hmac =
Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");

@@ -33,11 +33,11 @@ impl Pbkdf2 {
}
}

pub fn cost(&self) -> u32 {
pub(crate) fn cost(&self) -> u32 {
(self.iterations).clamp(0, 4096)
}

pub fn turn(&mut self) -> std::task::Poll<[u8; 32]> {
pub(crate) fn turn(&mut self) -> std::task::Poll<[u8; 32]> {
let Self {
hmac,
prev,

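`turn` performs only a bounded number of HMAC rounds per call, which is what lets the thread pool interleave many in-flight derivations. A hypothetical driver loop (the pool's real scheduling re-queues pending jobs rather than spinning):

use std::task::Poll;

fn run_to_completion(mut job: Pbkdf2) -> [u8; 32] {
    loop {
        match job.turn() {
            Poll::Ready(derived_key) => return derived_key,
            Poll::Pending => continue, // a scheduler would yield or re-queue here
        }
    }
}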
@@ -8,22 +8,22 @@ use super::key::ScramKey;
/// Server secret is produced from user's password,
/// and is used throughout the authentication process.
#[derive(Clone, Eq, PartialEq, Debug)]
pub struct ServerSecret {
pub(crate) struct ServerSecret {
/// Number of iterations for `PBKDF2` function.
pub iterations: u32,
pub(crate) iterations: u32,
/// Salt used to hash user's password.
pub salt_base64: String,
pub(crate) salt_base64: String,
/// Hashed `ClientKey`.
pub stored_key: ScramKey,
pub(crate) stored_key: ScramKey,
/// Used by client to verify server's signature.
pub server_key: ScramKey,
pub(crate) server_key: ScramKey,
/// Should auth fail no matter what?
/// This is exactly the case for mocked secrets.
pub doomed: bool,
pub(crate) doomed: bool,
}

impl ServerSecret {
pub fn parse(input: &str) -> Option<Self> {
pub(crate) fn parse(input: &str) -> Option<Self> {
// SCRAM-SHA-256$<iterations>:<salt>$<storedkey>:<serverkey>
let s = input.strip_prefix("SCRAM-SHA-256$")?;
let (params, keys) = s.split_once('$')?;
@@ -42,7 +42,7 @@ impl ServerSecret {
Some(secret)
}

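The accepted input mirrors PostgreSQL's stored SCRAM verifier format. A standalone sketch of the same splitting, with illustrative (not real) base64 payloads:

fn split_scram_verifier(input: &str) -> Option<(u32, &str, &str, &str)> {
    // SCRAM-SHA-256$<iterations>:<salt>$<storedkey>:<serverkey>
    let s = input.strip_prefix("SCRAM-SHA-256$")?;
    let (params, keys) = s.split_once('$')?;
    let (iterations, salt_base64) = params.split_once(':')?;
    let (stored_key, server_key) = keys.split_once(':')?;
    Some((iterations.parse().ok()?, salt_base64, stored_key, server_key))
}

assert_eq!(
    split_scram_verifier("SCRAM-SHA-256$4096:c2FsdA==$c3RvcmVk:c2VydmVy"),
    Some((4096, "c2FsdA==", "c3RvcmVk", "c2VydmVy"))
);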
pub fn is_password_invalid(&self, client_key: &ScramKey) -> Choice {
pub(crate) fn is_password_invalid(&self, client_key: &ScramKey) -> Choice {
// constant time to not leak partial key match
client_key.sha256().ct_ne(&self.stored_key) | Choice::from(self.doomed as u8)
}
@@ -50,7 +50,7 @@ impl ServerSecret {
/// To avoid revealing information to an attacker, we use a
/// mocked server secret even if the user doesn't exist.
/// See `auth-scram.c : mock_scram_secret` for details.
pub fn mock(nonce: [u8; 32]) -> Self {
pub(crate) fn mock(nonce: [u8; 32]) -> Self {
Self {
// this doesn't reveal much information as we're going to use
// iteration count 1 for our generated passwords going forward.
@@ -66,7 +66,7 @@ impl ServerSecret {
/// Build a new server secret from the prerequisites.
/// XXX: We only use this function in tests.
#[cfg(test)]
pub async fn build(password: &str) -> Option<Self> {
pub(crate) async fn build(password: &str) -> Option<Self> {
Self::parse(&postgres_protocol::password::scram_sha_256(password.as_bytes()).await)
}
}

@@ -4,14 +4,14 @@ use super::key::{ScramKey, SCRAM_KEY_LEN};

/// A collection of message parts needed to derive the client's signature.
#[derive(Debug)]
pub struct SignatureBuilder<'a> {
pub client_first_message_bare: &'a str,
pub server_first_message: &'a str,
pub client_final_message_without_proof: &'a str,
pub(crate) struct SignatureBuilder<'a> {
pub(crate) client_first_message_bare: &'a str,
pub(crate) server_first_message: &'a str,
pub(crate) client_final_message_without_proof: &'a str,
}

impl SignatureBuilder<'_> {
pub fn build(&self, key: &ScramKey) -> Signature {
pub(crate) fn build(&self, key: &ScramKey) -> Signature {
let parts = [
self.client_first_message_bare.as_bytes(),
b",",
@@ -28,13 +28,13 @@ impl SignatureBuilder<'_> {
/// produces `ClientKey` that we need for authentication.
#[derive(Debug)]
#[repr(transparent)]
pub struct Signature {
pub(crate) struct Signature {
bytes: [u8; SCRAM_KEY_LEN],
}

impl Signature {
/// Derive `ClientKey` from client's signature and proof.
pub fn derive_client_key(&self, proof: &[u8; SCRAM_KEY_LEN]) -> ScramKey {
pub(crate) fn derive_client_key(&self, proof: &[u8; SCRAM_KEY_LEN]) -> ScramKey {
// This is how the proof is calculated:
//
// 1. sha256(ClientKey) -> StoredKey
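The recovery rests on the RFC 5802 identities: ClientSignature = HMAC(StoredKey, AuthMessage) and ClientProof = ClientKey XOR ClientSignature. XOR-ing the received proof with the recomputed signature therefore yields ClientKey, which the server then checks via sha256(ClientKey) == StoredKey. The XOR step alone, as a sketch:

fn recover_client_key(signature: &[u8; 32], proof: &[u8; 32]) -> [u8; 32] {
    let mut client_key = [0u8; 32];
    for i in 0..32 {
        client_key[i] = signature[i] ^ proof[i];
    }
    client_key
}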
@@ -68,7 +68,7 @@ impl ThreadPool {
pool
}

pub fn spawn_job(
pub(crate) fn spawn_job(
&self,
endpoint: EndpointIdInt,
pbkdf2: Pbkdf2,

@@ -25,8 +25,6 @@ use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder;
use rand::rngs::StdRng;
use rand::SeedableRng;
pub use reqwest_middleware::{ClientWithMiddleware, Error};
pub use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time::timeout;
use tokio_rustls::TlsAcceptor;
@@ -50,7 +48,7 @@ use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn, Instrument};
use utils::http::error::ApiError;

pub const SERVERLESS_DRIVER_SNI: &str = "api";
pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api";

pub async fn task_main(
config: &'static ProxyConfig,
@@ -178,9 +176,9 @@ pub async fn task_main(
Ok(())
}

pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + 'static {}
pub(crate) trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + 'static {}
impl<T: AsyncRead + AsyncWrite + Send + 'static> AsyncReadWrite for T {}
pub type AsyncRW = Pin<Box<dyn AsyncReadWrite>>;
pub(crate) type AsyncRW = Pin<Box<dyn AsyncReadWrite>>;

#[async_trait]
trait MaybeTlsAcceptor: Send + Sync + 'static {

@@ -29,14 +29,14 @@ use crate::{

use super::conn_pool::{poll_client, AuthData, Client, ConnInfo, GlobalConnPool};

pub struct PoolingBackend {
pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub config: &'static ProxyConfig,
pub endpoint_rate_limiter: Arc<EndpointRateLimiter>,
pub(crate) struct PoolingBackend {
pub(crate) pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
pub(crate) config: &'static ProxyConfig,
pub(crate) endpoint_rate_limiter: Arc<EndpointRateLimiter>,
}

impl PoolingBackend {
pub async fn authenticate_with_password(
pub(crate) async fn authenticate_with_password(
&self,
ctx: &RequestMonitoring,
config: &AuthenticationConfig,
@@ -98,7 +98,7 @@ impl PoolingBackend {
})
}

pub async fn authenticate_with_jwt(
pub(crate) async fn authenticate_with_jwt(
&self,
ctx: &RequestMonitoring,
user_info: &ComputeUserInfo,
@@ -135,7 +135,7 @@ impl PoolingBackend {
// we reuse the code from the usual proxy and we need to prepare a few structures
// that this code expects.
#[tracing::instrument(fields(pid = tracing::field::Empty), skip_all)]
pub async fn connect_to_compute(
pub(crate) async fn connect_to_compute(
&self,
ctx: &RequestMonitoring,
conn_info: ConnInfo,
@@ -175,7 +175,7 @@ impl PoolingBackend {
}

#[derive(Debug, thiserror::Error)]
pub enum HttpConnError {
pub(crate) enum HttpConnError {
#[error("pooled connection closed at inconsistent state")]
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
#[error("could not connect to compute")]

@@ -22,7 +22,7 @@ pub struct CancelSet {
hasher: Hasher,
}

pub struct CancelShard {
pub(crate) struct CancelShard {
tokens: IndexMap<uuid::Uuid, (Instant, CancellationToken), Hasher>,
}

@@ -40,7 +40,7 @@ impl CancelSet {
}
}

pub fn take(&self) -> Option<CancellationToken> {
pub(crate) fn take(&self) -> Option<CancellationToken> {
for _ in 0..4 {
if let Some(token) = self.take_raw(thread_rng().gen()) {
return Some(token);
@@ -50,12 +50,12 @@ impl CancelSet {
None
}

pub fn take_raw(&self, rng: usize) -> Option<CancellationToken> {
pub(crate) fn take_raw(&self, rng: usize) -> Option<CancellationToken> {
NonZeroUsize::new(self.shards.len())
.and_then(|len| self.shards[rng % len].lock().take(rng / len))
}

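One random draw does double duty in `take_raw`: the remainder selects a shard and the quotient remains a usable uniform value inside it, saving a second RNG call. Concretely (numbers illustrative):

let shards = 16usize;               // illustrative shard count
let rng = 0xDEAD_BEEF_usize;        // one draw from thread_rng
let shard_index = rng % shards;     // low bits pick the shard: 15
let inner = rng / shards;           // remaining bits stay uniform: 0x0DEA_DBEE
assert_eq!((shard_index, inner), (15, 0x0DEA_DBEE));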
pub fn insert(&self, id: uuid::Uuid, token: CancellationToken) -> CancelGuard<'_> {
pub(crate) fn insert(&self, id: uuid::Uuid, token: CancellationToken) -> CancelGuard<'_> {
let shard = NonZeroUsize::new(self.shards.len()).map(|len| {
let hash = self.hasher.hash_one(id) as usize;
let shard = &self.shards[hash % len];
@@ -88,7 +88,7 @@ impl CancelShard {
}
}

pub struct CancelGuard<'a> {
pub(crate) struct CancelGuard<'a> {
shard: Option<&'a Mutex<CancelShard>>,
id: Uuid,
}

@@ -30,25 +30,25 @@ use tracing::{info, info_span, Instrument};
use super::backend::HttpConnError;

#[derive(Debug, Clone)]
pub struct ConnInfo {
pub user_info: ComputeUserInfo,
pub dbname: DbName,
pub auth: AuthData,
pub(crate) struct ConnInfo {
pub(crate) user_info: ComputeUserInfo,
pub(crate) dbname: DbName,
pub(crate) auth: AuthData,
}

#[derive(Debug, Clone)]
pub enum AuthData {
pub(crate) enum AuthData {
Password(SmallVec<[u8; 16]>),
Jwt(String),
}

impl ConnInfo {
// hm, change to hasher to avoid cloning?
pub fn db_and_user(&self) -> (DbName, RoleName) {
pub(crate) fn db_and_user(&self) -> (DbName, RoleName) {
(self.dbname.clone(), self.user_info.user.clone())
}

pub fn endpoint_cache_key(&self) -> Option<EndpointCacheKey> {
pub(crate) fn endpoint_cache_key(&self) -> Option<EndpointCacheKey> {
// We don't want to cache http connections for ephemeral endpoints.
if self.user_info.options.is_ephemeral() {
None
@@ -79,7 +79,7 @@ struct ConnPoolEntry<C: ClientInnerExt> {

// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub struct EndpointConnPool<C: ClientInnerExt> {
pub(crate) struct EndpointConnPool<C: ClientInnerExt> {
pools: HashMap<(DbName, RoleName), DbUserConnPool<C>>,
total_conns: usize,
max_conns: usize,
@@ -198,7 +198,7 @@ impl<C: ClientInnerExt> Drop for EndpointConnPool<C> {
}
}

pub struct DbUserConnPool<C: ClientInnerExt> {
pub(crate) struct DbUserConnPool<C: ClientInnerExt> {
conns: Vec<ConnPoolEntry<C>>,
}

@@ -241,7 +241,7 @@ impl<C: ClientInnerExt> DbUserConnPool<C> {
}
}

pub struct GlobalConnPool<C: ClientInnerExt> {
pub(crate) struct GlobalConnPool<C: ClientInnerExt> {
// endpoint -> per-endpoint connection pool
//
// That should be a fairly contended map, so return reference to the per-endpoint
@@ -282,7 +282,7 @@ pub struct GlobalConnPoolOptions {
}

impl<C: ClientInnerExt> GlobalConnPool<C> {
pub fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
pub(crate) fn new(config: &'static crate::config::HttpConfig) -> Arc<Self> {
let shards = config.pool_options.pool_shards;
Arc::new(Self {
global_pool: DashMap::with_shard_amount(shards),
@@ -293,21 +293,21 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
}

#[cfg(test)]
pub fn get_global_connections_count(&self) -> usize {
pub(crate) fn get_global_connections_count(&self) -> usize {
self.global_connections_count
.load(atomic::Ordering::Relaxed)
}

pub fn get_idle_timeout(&self) -> Duration {
pub(crate) fn get_idle_timeout(&self) -> Duration {
self.config.pool_options.idle_timeout
}

pub fn shutdown(&self) {
pub(crate) fn shutdown(&self) {
// drops all strong references to endpoint-pools
self.global_pool.clear();
}

pub async fn gc_worker(&self, mut rng: impl Rng) {
pub(crate) async fn gc_worker(&self, mut rng: impl Rng) {
let epoch = self.config.pool_options.gc_epoch;
let mut interval = tokio::time::interval(epoch / (self.global_pool.shards().len()) as u32);
loop {
@@ -381,7 +381,7 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
}
}

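The interval divides the GC epoch by the shard count, so shards are swept round-robin and each one is visited once per epoch. With assumed numbers:

use std::time::Duration;

let gc_epoch = Duration::from_secs(60); // illustrative config value
let shards = 32u32;
let per_shard = gc_epoch / shards;
assert_eq!(per_shard, Duration::from_millis(1875)); // full sweep every 60s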
pub fn get(
pub(crate) fn get(
self: &Arc<Self>,
ctx: &RequestMonitoring,
conn_info: &ConnInfo,
@@ -468,7 +468,7 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
}
}

pub fn poll_client<C: ClientInnerExt>(
pub(crate) fn poll_client<C: ClientInnerExt>(
global_pool: Arc<GlobalConnPool<C>>,
ctx: &RequestMonitoring,
conn_info: ConnInfo,
@@ -596,7 +596,7 @@ impl<C: ClientInnerExt> Drop for ClientInner<C> {
}
}

pub trait ClientInnerExt: Sync + Send + 'static {
pub(crate) trait ClientInnerExt: Sync + Send + 'static {
fn is_closed(&self) -> bool;
fn get_process_id(&self) -> i32;
}
@@ -611,13 +611,13 @@ impl ClientInnerExt for tokio_postgres::Client {
}

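The two-method trait is the seam that lets the pool hold something other than a live `tokio_postgres::Client`. A hypothetical test double (not part of this diff):

struct MockClient {
    closed: bool,
}

impl ClientInnerExt for MockClient {
    fn is_closed(&self) -> bool {
        self.closed
    }
    fn get_process_id(&self) -> i32 {
        0 // no real backend process behind a mock
    }
}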
impl<C: ClientInnerExt> ClientInner<C> {
pub fn is_closed(&self) -> bool {
pub(crate) fn is_closed(&self) -> bool {
self.inner.is_closed()
}
}

impl<C: ClientInnerExt> Client<C> {
pub fn metrics(&self) -> Arc<MetricCounter> {
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
let aux = &self.inner.as_ref().unwrap().aux;
USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id,
@@ -626,14 +626,14 @@ impl<C: ClientInnerExt> Client<C> {
}
}

pub struct Client<C: ClientInnerExt> {
pub(crate) struct Client<C: ClientInnerExt> {
span: Span,
inner: Option<ClientInner<C>>,
conn_info: ConnInfo,
pool: Weak<RwLock<EndpointConnPool<C>>>,
}

pub struct Discard<'a, C: ClientInnerExt> {
pub(crate) struct Discard<'a, C: ClientInnerExt> {
conn_info: &'a ConnInfo,
pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
}
@@ -651,7 +651,7 @@ impl<C: ClientInnerExt> Client<C> {
pool,
}
}
pub fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
pub(crate) fn inner(&mut self) -> (&mut C, Discard<'_, C>) {
let Self {
inner,
pool,
@@ -664,13 +664,13 @@ impl<C: ClientInnerExt> Client<C> {
}

impl<C: ClientInnerExt> Discard<'_, C> {
pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) {
let conn_info = &self.conn_info;
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is not idle");
}
}
pub fn discard(&mut self) {
pub(crate) fn discard(&mut self) {
let conn_info = &self.conn_info;
if std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");

@@ -11,7 +11,7 @@ use serde::Serialize;
use utils::http::error::ApiError;

/// Like [`ApiError::into_response`]
pub fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
pub(crate) fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
match this {
ApiError::BadRequest(err) => HttpErrorBody::response_from_msg_and_status(
format!("{err:#?}"), // use debug printing so that we give the cause
@@ -59,7 +59,7 @@ pub fn api_error_into_response(this: ApiError) -> Response<Full<Bytes>> {
/// Same as [`utils::http::error::HttpErrorBody`]
#[derive(Serialize)]
struct HttpErrorBody {
pub msg: String,
pub(crate) msg: String,
}

impl HttpErrorBody {
@@ -80,7 +80,7 @@ impl HttpErrorBody {
}

/// Same as [`utils::http::json::json_response`]
pub fn json_response<T: Serialize>(
pub(crate) fn json_response<T: Serialize>(
status: StatusCode,
data: T,
) -> Result<Response<Full<Bytes>>, ApiError> {

@@ -8,7 +8,7 @@ use tokio_postgres::Row;
// Convert json non-string types to strings, so that they can be passed to Postgres
// as parameters.
//
pub fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
pub(crate) fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
json.iter().map(json_value_to_pg_text).collect()
}

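The conversion exists because query parameters travel to Postgres in the text format. Illustratively (the exact rules for nested values live in `json_value_to_pg_text`, so the expected output here is an assumption):

use serde_json::json;

let params = vec![json!(42), json!("hello"), json!(null), json!(true)];
let text = json_to_pg_text(params);
// plausibly: [Some("42"), Some("hello"), None, Some("true")]
// i.e. scalars keep their literal text form and null maps to SQL NULL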
@@ -61,7 +61,7 @@ fn json_array_to_pg_array(value: &Value) -> Option<String> {
}

#[derive(Debug, thiserror::Error)]
pub enum JsonConversionError {
pub(crate) enum JsonConversionError {
#[error("internal error compute returned invalid data: {0}")]
AsTextError(tokio_postgres::Error),
#[error("parse int error: {0}")]
@@ -77,7 +77,7 @@ pub enum JsonConversionError {
//
// Convert postgres row with text-encoded values to JSON object
//
pub fn pg_text_row_to_json(
pub(crate) fn pg_text_row_to_json(
row: &Row,
columns: &[Type],
raw_output: bool,

@@ -110,7 +110,7 @@ where
}

#[derive(Debug, thiserror::Error)]
pub enum ConnInfoError {
pub(crate) enum ConnInfoError {
#[error("invalid header: {0}")]
InvalidHeader(&'static HeaderName),
#[error("invalid connection string: {0}")]
@@ -246,7 +246,7 @@ fn get_conn_info(
}

// TODO: return different http error codes
pub async fn handle(
pub(crate) async fn handle(
config: &'static ProxyConfig,
ctx: RequestMonitoring,
request: Request<Incoming>,
@@ -359,7 +359,7 @@ pub async fn handle(
}

#[derive(Debug, thiserror::Error)]
pub enum SqlOverHttpError {
pub(crate) enum SqlOverHttpError {
#[error("{0}")]
ReadPayload(#[from] ReadPayloadError),
#[error("{0}")]
@@ -413,7 +413,7 @@ impl UserFacingError for SqlOverHttpError {
}

#[derive(Debug, thiserror::Error)]
pub enum ReadPayloadError {
pub(crate) enum ReadPayloadError {
#[error("could not read the HTTP request body: {0}")]
Read(#[from] hyper1::Error),
#[error("could not parse the HTTP request body: {0}")]
@@ -430,7 +430,7 @@ impl ReportableError for ReadPayloadError {
}

#[derive(Debug, thiserror::Error)]
pub enum SqlOverHttpCancel {
pub(crate) enum SqlOverHttpCancel {
#[error("query was cancelled")]
Postgres,
#[error("query was cancelled while stuck trying to connect to the database")]

@@ -27,7 +27,7 @@ use tracing::warn;
pin_project! {
/// This is a wrapper around a [`WebSocketStream`] that
/// implements [`AsyncRead`] and [`AsyncWrite`].
pub struct WebSocketRw<S> {
pub(crate) struct WebSocketRw<S> {
#[pin]
stream: WebSocketServer<S>,
recv: Bytes,
@@ -36,7 +36,7 @@ pin_project! {
}

impl<S> WebSocketRw<S> {
pub fn new(stream: WebSocketServer<S>) -> Self {
pub(crate) fn new(stream: WebSocketServer<S>) -> Self {
Self {
stream,
recv: Bytes::new(),
@@ -127,7 +127,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncBufRead for WebSocketRw<S> {
}
}

pub async fn serve_websocket(
pub(crate) async fn serve_websocket(
config: &'static ProxyConfig,
ctx: RequestMonitoring,
websocket: OnUpgrade,

@@ -35,7 +35,7 @@ impl<S> PqStream<S> {
}

/// Get a shared reference to the underlying stream.
pub fn get_ref(&self) -> &S {
pub(crate) fn get_ref(&self) -> &S {
self.framed.get_ref()
}
}
@@ -62,7 +62,7 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
.ok_or_else(err_connection)
}

pub async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
pub(crate) async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
match self.read_message().await? {
FeMessage::PasswordMessage(msg) => Ok(msg),
bad => Err(io::Error::new(
@@ -99,7 +99,10 @@ impl ReportableError for ReportedError {

impl<S: AsyncWrite + Unpin> PqStream<S> {
/// Write the message into an internal buffer, but don't flush the underlying stream.
pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
pub(crate) fn write_message_noflush(
&mut self,
message: &BeMessage<'_>,
) -> io::Result<&mut Self> {
self.framed
.write_message(message)
.map_err(ProtocolError::into_io_error)?;
@@ -114,7 +117,7 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
}

/// Flush the output buffer into the underlying stream.
pub async fn flush(&mut self) -> io::Result<&mut Self> {
pub(crate) async fn flush(&mut self) -> io::Result<&mut Self> {
self.framed.flush().await?;
Ok(self)
}
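The noflush/flush split lets a whole batch of backend messages go out in one write. A typical (hypothetical) call sequence; the message variants here are illustrative, not taken from this diff:

// queue several messages into the internal buffer...
stream
    .write_message_noflush(&BeMessage::AuthenticationOk)?
    .write_message_noflush(&BeMessage::ReadyForQuery)?;
// ...then hit the underlying socket once
stream.flush().await?;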
@@ -146,7 +149,7 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {

/// Write the error message using [`Self::write_message`], then re-throw it.
/// Trait [`UserFacingError`] acts as an allowlist for error types.
pub async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
pub(crate) async fn throw_error<T, E>(&mut self, error: E) -> Result<T, ReportedError>
where
E: UserFacingError + Into<anyhow::Error>,
{
@@ -200,7 +203,7 @@ impl<S> Stream<S> {
}
}

pub fn tls_server_end_point(&self) -> TlsServerEndPoint {
pub(crate) fn tls_server_end_point(&self) -> TlsServerEndPoint {
match self {
Stream::Raw { .. } => TlsServerEndPoint::Undefined,
Stream::Tls {

@@ -7,12 +7,12 @@ pub struct ApiUrl(url::Url);

impl ApiUrl {
/// Consume the wrapper and return inner [url](url::Url).
pub fn into_inner(self) -> url::Url {
pub(crate) fn into_inner(self) -> url::Url {
self.0
}

/// See [`url::Url::path_segments_mut`].
pub fn path_segments_mut(&mut self) -> url::PathSegmentsMut<'_> {
pub(crate) fn path_segments_mut(&mut self) -> url::PathSegmentsMut<'_> {
// We've already verified that it works during construction.
self.0.path_segments_mut().expect("bad API url")
}

|
||||
@@ -43,12 +43,12 @@ const DEFAULT_HTTP_REPORTING_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
/// so while the project-id is unique across regions the whole pipeline will work correctly
|
||||
/// because we enrich the event with project_id in the control-plane endpoint.
|
||||
#[derive(Eq, Hash, PartialEq, Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct Ids {
|
||||
pub endpoint_id: EndpointIdInt,
|
||||
pub branch_id: BranchIdInt,
|
||||
pub(crate) struct Ids {
|
||||
pub(crate) endpoint_id: EndpointIdInt,
|
||||
pub(crate) branch_id: BranchIdInt,
|
||||
}
|
||||
|
||||
pub trait MetricCounterRecorder {
|
||||
pub(crate) trait MetricCounterRecorder {
|
||||
/// Record that some bytes were sent from the proxy to the client
|
||||
fn record_egress(&self, bytes: u64);
|
||||
/// Record that some connections were opened
|
||||
@@ -92,7 +92,7 @@ impl MetricCounterReporter for MetricBackupCounter {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MetricCounter {
|
||||
pub(crate) struct MetricCounter {
|
||||
transmitted: AtomicU64,
|
||||
opened_connections: AtomicUsize,
|
||||
backup: Arc<MetricBackupCounter>,
|
||||
@@ -173,14 +173,14 @@ impl<C: MetricCounterReporter> Clearable for C {
|
||||
type FastHasher = std::hash::BuildHasherDefault<rustc_hash::FxHasher>;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Metrics {
|
||||
pub(crate) struct Metrics {
|
||||
endpoints: DashMap<Ids, Arc<MetricCounter>, FastHasher>,
|
||||
backup_endpoints: DashMap<Ids, Arc<MetricBackupCounter>, FastHasher>,
|
||||
}
|
||||
|
||||
impl Metrics {
|
||||
/// Register a new byte metrics counter for this endpoint
|
||||
pub fn register(&self, ids: Ids) -> Arc<MetricCounter> {
|
||||
pub(crate) fn register(&self, ids: Ids) -> Arc<MetricCounter> {
|
||||
let backup = if let Some(entry) = self.backup_endpoints.get(&ids) {
|
||||
entry.clone()
|
||||
} else {
|
||||
@@ -215,7 +215,7 @@ impl Metrics {
|
||||
}
|
||||
}
|
||||
|
||||
pub static USAGE_METRICS: Lazy<Metrics> = Lazy::new(Metrics::default);
|
||||
pub(crate) static USAGE_METRICS: Lazy<Metrics> = Lazy::new(Metrics::default);
|
||||
|
||||
pub async fn task_main(config: &MetricCollectionConfig) -> anyhow::Result<Infallible> {
|
||||
info!("metrics collector config: {config:?}");
|
||||
|
||||
@@ -7,13 +7,13 @@ use thiserror::Error;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum RegisterError {
|
||||
pub(crate) enum RegisterError {
|
||||
#[error("Waiter `{0}` already registered")]
|
||||
Occupied(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum NotifyError {
|
||||
pub(crate) enum NotifyError {
|
||||
#[error("Notify failed: waiter `{0}` not registered")]
|
||||
NotFound(String),
|
||||
|
||||
@@ -22,12 +22,12 @@ pub enum NotifyError {
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum WaitError {
|
||||
pub(crate) enum WaitError {
|
||||
#[error("Wait failed: channel hangup")]
|
||||
Hangup,
|
||||
}
|
||||
|
||||
pub struct Waiters<T>(pub(self) Mutex<HashMap<String, oneshot::Sender<T>>>);
|
||||
pub(crate) struct Waiters<T>(pub(self) Mutex<HashMap<String, oneshot::Sender<T>>>);
|
||||
|
||||
impl<T> Default for Waiters<T> {
|
||||
fn default() -> Self {
|
||||
@@ -36,7 +36,7 @@ impl<T> Default for Waiters<T> {
|
||||
}
|
||||
|
||||
impl<T> Waiters<T> {
|
||||
pub fn register(&self, key: String) -> Result<Waiter<'_, T>, RegisterError> {
|
||||
pub(crate) fn register(&self, key: String) -> Result<Waiter<'_, T>, RegisterError> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
self.0
|
||||
@@ -53,7 +53,7 @@ impl<T> Waiters<T> {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn notify(&self, key: &str, value: T) -> Result<(), NotifyError>
|
||||
pub(crate) fn notify(&self, key: &str, value: T) -> Result<(), NotifyError>
|
||||
where
|
||||
T: Send + Sync,
|
||||
{
|
||||
@@ -79,7 +79,7 @@ impl<'a, T> Drop for DropKey<'a, T> {
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
pub struct Waiter<'a, T> {
|
||||
pub(crate) struct Waiter<'a, T> {
|
||||
#[pin]
|
||||
receiver: oneshot::Receiver<T>,
|
||||
guard: DropKey<'a, T>,
|
||||
|
||||
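End to end, the pieces are meant to compose roughly like this (hypothetical flow, assuming `Waiter` resolves as a future once `notify` fires):

async fn demo(waiters: &Waiters<String>) -> anyhow::Result<String> {
    let waiter = waiters.register("session-1".to_string())?;
    // elsewhere, once the out-of-band action completes:
    waiters.notify("session-1", "done".to_string())?;
    Ok(waiter.await?) // receives the value delivered by notify
}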