mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-15 17:32:56 +00:00
proxy: experiment with idea to split crates
This commit is contained in:
31
Cargo.lock
generated
31
Cargo.lock
generated
@@ -4472,7 +4472,7 @@ dependencies = [
|
||||
"postgres-protocol",
|
||||
"postgres_backend",
|
||||
"pq_proto",
|
||||
"prometheus",
|
||||
"proxy-sasl",
|
||||
"rand 0.8.5",
|
||||
"rand_distr",
|
||||
"rcgen",
|
||||
@@ -4525,6 +4525,35 @@ dependencies = [
|
||||
"x509-parser",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proxy-sasl"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"anyhow",
|
||||
"base64 0.13.1",
|
||||
"bytes",
|
||||
"crossbeam-deque",
|
||||
"hmac",
|
||||
"itertools 0.10.5",
|
||||
"lasso",
|
||||
"measured",
|
||||
"parking_lot 0.12.1",
|
||||
"pbkdf2",
|
||||
"postgres-protocol",
|
||||
"pq_proto",
|
||||
"rand 0.8.5",
|
||||
"rustls 0.22.4",
|
||||
"sha2",
|
||||
"subtle",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"uuid",
|
||||
"workspace_hack",
|
||||
"x509-parser",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quick-xml"
|
||||
version = "0.31.0"
|
||||
|
||||
@@ -10,6 +10,7 @@ members = [
|
||||
"pageserver/client",
|
||||
"pageserver/pagebench",
|
||||
"proxy/core",
|
||||
"proxy/sasl",
|
||||
"safekeeper",
|
||||
"storage_broker",
|
||||
"storage_controller",
|
||||
|
||||
@@ -9,6 +9,8 @@ default = []
|
||||
testing = []
|
||||
|
||||
[dependencies]
|
||||
proxy-sasl = { version = "0.1", path = "../sasl" }
|
||||
|
||||
ahash.workspace = true
|
||||
anyhow.workspace = true
|
||||
arc-swap.workspace = true
|
||||
@@ -59,7 +61,7 @@ parquet_derive.workspace = true
|
||||
pin-project-lite.workspace = true
|
||||
postgres_backend.workspace = true
|
||||
pq_proto.workspace = true
|
||||
prometheus.workspace = true
|
||||
# prometheus.workspace = true
|
||||
rand.workspace = true
|
||||
regex.workspace = true
|
||||
remote_storage = { version = "0.1", path = "../../libs/remote_storage/" }
|
||||
|
||||
@@ -38,7 +38,7 @@ pub enum AuthErrorImpl {
|
||||
|
||||
/// SASL protocol errors (includes [SCRAM](crate::scram)).
|
||||
#[error(transparent)]
|
||||
Sasl(#[from] crate::sasl::Error),
|
||||
Sasl(#[from] proxy_sasl::sasl::Error),
|
||||
|
||||
#[error("Unsupported authentication method: {0}")]
|
||||
BadAuthMethod(Box<str>),
|
||||
@@ -148,3 +148,28 @@ impl ReportableError for AuthError {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for proxy_sasl::sasl::Error {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self {
|
||||
proxy_sasl::sasl::Error::ChannelBindingFailed(m) => m.to_string(),
|
||||
proxy_sasl::sasl::Error::ChannelBindingBadMethod(m) => {
|
||||
format!("unsupported channel binding method {m}")
|
||||
}
|
||||
_ => "authentication protocol violation".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for proxy_sasl::sasl::Error {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
proxy_sasl::sasl::Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User,
|
||||
proxy_sasl::sasl::Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
|
||||
proxy_sasl::sasl::Error::BadClientMessage(_) => crate::error::ErrorKind::User,
|
||||
proxy_sasl::sasl::Error::MissingBinding => crate::error::ErrorKind::Service,
|
||||
proxy_sasl::sasl::Error::Base64(_) => crate::error::ErrorKind::ControlPlane,
|
||||
proxy_sasl::sasl::Error::Io(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ use std::time::Duration;
|
||||
|
||||
use ipnet::{Ipv4Net, Ipv6Net};
|
||||
pub use link::LinkAuthError;
|
||||
use proxy_sasl::scram;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_postgres::config::AuthKeys;
|
||||
use tracing::{info, warn};
|
||||
@@ -36,7 +37,7 @@ use crate::{
|
||||
},
|
||||
stream, url,
|
||||
};
|
||||
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
|
||||
use crate::{EndpointCacheKey, EndpointId, RoleName};
|
||||
|
||||
/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
|
||||
pub enum MaybeOwned<'a, T> {
|
||||
@@ -371,8 +372,8 @@ async fn authenticate_with_secret(
|
||||
let auth_outcome =
|
||||
validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
|
||||
let keys = match auth_outcome {
|
||||
crate::sasl::Outcome::Success(key) => key,
|
||||
crate::sasl::Outcome::Failure(reason) => {
|
||||
proxy_sasl::sasl::Outcome::Success(key) => key,
|
||||
proxy_sasl::sasl::Outcome::Failure(reason) => {
|
||||
info!("auth backend failed with an error: {reason}");
|
||||
return Err(auth::AuthError::auth_failed(&*info.user));
|
||||
}
|
||||
@@ -558,9 +559,9 @@ mod tests {
|
||||
context::RequestMonitoring,
|
||||
proxy::NeonOptions,
|
||||
rate_limiter::{EndpointRateLimiter, RateBucketInfo},
|
||||
scram::{threadpool::ThreadPool, ServerSecret},
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
use proxy_sasl::scram::{threadpool::ThreadPool, ServerSecret};
|
||||
|
||||
use super::{auth_quirks, AuthRateLimiter};
|
||||
|
||||
@@ -669,7 +670,11 @@ mod tests {
|
||||
let ctx = RequestMonitoring::test();
|
||||
let api = Auth {
|
||||
ips: vec![],
|
||||
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
|
||||
secret: AuthSecret::Scram(
|
||||
ServerSecret::build_test_secret("my-secret-password")
|
||||
.await
|
||||
.unwrap(),
|
||||
),
|
||||
};
|
||||
|
||||
let user_info = ComputeUserInfoMaybeEndpoint {
|
||||
@@ -746,7 +751,11 @@ mod tests {
|
||||
let ctx = RequestMonitoring::test();
|
||||
let api = Auth {
|
||||
ips: vec![],
|
||||
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
|
||||
secret: AuthSecret::Scram(
|
||||
ServerSecret::build_test_secret("my-secret-password")
|
||||
.await
|
||||
.unwrap(),
|
||||
),
|
||||
};
|
||||
|
||||
let user_info = ComputeUserInfoMaybeEndpoint {
|
||||
@@ -798,7 +807,11 @@ mod tests {
|
||||
let ctx = RequestMonitoring::test();
|
||||
let api = Auth {
|
||||
ips: vec![],
|
||||
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
|
||||
secret: AuthSecret::Scram(
|
||||
ServerSecret::build_test_secret("my-secret-password")
|
||||
.await
|
||||
.unwrap(),
|
||||
),
|
||||
};
|
||||
|
||||
let user_info = ComputeUserInfoMaybeEndpoint {
|
||||
|
||||
@@ -5,9 +5,9 @@ use crate::{
|
||||
config::AuthenticationConfig,
|
||||
console::AuthSecret,
|
||||
context::RequestMonitoring,
|
||||
sasl,
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
use proxy_sasl::sasl;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, warn};
|
||||
|
||||
|
||||
@@ -7,9 +7,9 @@ use crate::{
|
||||
console::AuthSecret,
|
||||
context::RequestMonitoring,
|
||||
intern::EndpointIdInt,
|
||||
sasl,
|
||||
stream::{self, Stream},
|
||||
};
|
||||
use proxy_sasl::sasl;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, warn};
|
||||
|
||||
|
||||
@@ -2,16 +2,17 @@
|
||||
|
||||
use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
|
||||
use crate::{
|
||||
config::TlsServerEndPoint,
|
||||
console::AuthSecret,
|
||||
context::RequestMonitoring,
|
||||
intern::EndpointIdInt,
|
||||
sasl,
|
||||
scram::{self, threadpool::ThreadPool},
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
|
||||
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
|
||||
use proxy_sasl::{
|
||||
sasl,
|
||||
scram::{self, threadpool::ThreadPool, TlsServerEndPoint},
|
||||
};
|
||||
use std::{io, sync::Arc};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
@@ -56,7 +57,7 @@ impl AuthMethod for PasswordHack {
|
||||
/// Use clear-text password auth called `password` in docs
|
||||
/// <https://www.postgresql.org/docs/current/auth-password.html>
|
||||
pub struct CleartextPassword {
|
||||
pub pool: Arc<ThreadPool>,
|
||||
pub pool: Arc<ThreadPool<EndpointIdInt>>,
|
||||
pub endpoint: EndpointIdInt,
|
||||
pub secret: AuthSecret,
|
||||
}
|
||||
@@ -174,7 +175,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
}
|
||||
info!("client chooses {}", sasl.method);
|
||||
|
||||
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
|
||||
let outcome = sasl::SaslStream::new(&mut self.stream.framed, sasl.message)
|
||||
.authenticate(scram::Exchange::new(
|
||||
secret,
|
||||
rand::random,
|
||||
@@ -191,7 +192,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
}
|
||||
|
||||
pub(crate) async fn validate_password_and_exchange(
|
||||
pool: &ThreadPool,
|
||||
pool: &ThreadPool<EndpointIdInt>,
|
||||
endpoint: EndpointIdInt,
|
||||
password: &[u8],
|
||||
secret: AuthSecret,
|
||||
@@ -206,7 +207,8 @@ pub(crate) async fn validate_password_and_exchange(
|
||||
}
|
||||
// perform scram authentication as both client and server to validate the keys
|
||||
AuthSecret::Scram(scram_secret) => {
|
||||
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
|
||||
let outcome =
|
||||
proxy_sasl::scram::exchange(pool, endpoint, &scram_secret, password).await?;
|
||||
|
||||
let client_key = match outcome {
|
||||
sasl::Outcome::Success(client_key) => client_key,
|
||||
|
||||
@@ -7,10 +7,11 @@ use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use futures::future::Either;
|
||||
use itertools::Itertools;
|
||||
use proxy::config::TlsServerEndPoint;
|
||||
use proxy::context::RequestMonitoring;
|
||||
use proxy::metrics::{Metrics, ThreadPoolMetrics};
|
||||
use proxy::metrics::Metrics;
|
||||
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
|
||||
use proxy_sasl::scram::threadpool::ThreadPoolMetrics;
|
||||
use proxy_sasl::scram::TlsServerEndPoint;
|
||||
use rustls::pki_types::PrivateKeyDer;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
|
||||
@@ -30,7 +30,6 @@ use proxy::redis::cancellation_publisher::RedisPublisherClient;
|
||||
use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
|
||||
use proxy::redis::elasticache;
|
||||
use proxy::redis::notifications;
|
||||
use proxy::scram::threadpool::ThreadPool;
|
||||
use proxy::serverless::cancel_set::CancelSet;
|
||||
use proxy::serverless::GlobalConnPoolOptions;
|
||||
use proxy::usage_metrics;
|
||||
@@ -38,6 +37,7 @@ use proxy::usage_metrics;
|
||||
use anyhow::bail;
|
||||
use proxy::config::{self, ProxyConfig};
|
||||
use proxy::serverless;
|
||||
use proxy_sasl::scram::threadpool::ThreadPool;
|
||||
use remote_storage::RemoteStorageConfig;
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::pin;
|
||||
@@ -607,7 +607,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
timeout,
|
||||
epoch,
|
||||
&Metrics::get().wake_compute_lock,
|
||||
)?));
|
||||
)));
|
||||
tokio::spawn(locks.garbage_collect_worker());
|
||||
|
||||
let url = args.auth_endpoint.parse()?;
|
||||
@@ -658,7 +658,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
timeout,
|
||||
epoch,
|
||||
&Metrics::get().proxy.connect_compute_lock,
|
||||
)?;
|
||||
);
|
||||
|
||||
let http_config = HttpConfig {
|
||||
pool_options: GlobalConnPoolOptions {
|
||||
|
||||
3
proxy/core/src/cache/project_info.rs
vendored
3
proxy/core/src/cache/project_info.rs
vendored
@@ -371,7 +371,8 @@ impl Cache for ProjectInfoCacheImpl {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{scram::ServerSecret, ProjectId};
|
||||
use crate::ProjectId;
|
||||
use proxy_sasl::scram::ServerSecret;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_project_info_cache_settings() {
|
||||
|
||||
@@ -1,27 +1,26 @@
|
||||
use crate::{
|
||||
auth::{self, backend::AuthRateLimiter},
|
||||
console::locks::ApiLocks,
|
||||
intern::EndpointIdInt,
|
||||
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
|
||||
scram::threadpool::ThreadPool,
|
||||
serverless::{cancel_set::CancelSet, GlobalConnPoolOptions},
|
||||
Host,
|
||||
};
|
||||
|
||||
use anyhow::{bail, ensure, Context, Ok};
|
||||
use itertools::Itertools;
|
||||
use proxy_sasl::scram::{threadpool::ThreadPool, TlsServerEndPoint};
|
||||
use remote_storage::RemoteStorageConfig;
|
||||
use rustls::{
|
||||
crypto::ring::sign,
|
||||
pki_types::{CertificateDer, PrivateKeyDer},
|
||||
};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
str::FromStr,
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use tracing::{error, info};
|
||||
use x509_parser::oid_registry;
|
||||
|
||||
pub struct ProxyConfig {
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
@@ -58,7 +57,7 @@ pub struct HttpConfig {
|
||||
}
|
||||
|
||||
pub struct AuthenticationConfig {
|
||||
pub thread_pool: Arc<ThreadPool>,
|
||||
pub thread_pool: Arc<ThreadPool<EndpointIdInt>>,
|
||||
pub scram_protocol_timeout: tokio::time::Duration,
|
||||
pub rate_limiter_enabled: bool,
|
||||
pub rate_limiter: AuthRateLimiter,
|
||||
@@ -126,66 +125,6 @@ pub fn configure_tls(
|
||||
})
|
||||
}
|
||||
|
||||
/// Channel binding parameter
|
||||
///
|
||||
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
|
||||
/// Description: The hash of the TLS server's certificate as it
|
||||
/// appears, octet for octet, in the server's Certificate message. Note
|
||||
/// that the Certificate message contains a certificate_list, in which
|
||||
/// the first element is the server's certificate.
|
||||
///
|
||||
/// The hash function is to be selected as follows:
|
||||
///
|
||||
/// * if the certificate's signatureAlgorithm uses a single hash
|
||||
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
|
||||
///
|
||||
/// * if the certificate's signatureAlgorithm uses a single hash
|
||||
/// function and that hash function neither MD5 nor SHA-1, then use
|
||||
/// the hash function associated with the certificate's
|
||||
/// signatureAlgorithm;
|
||||
///
|
||||
/// * if the certificate's signatureAlgorithm uses no hash functions or
|
||||
/// uses multiple hash functions, then this channel binding type's
|
||||
/// channel bindings are undefined at this time (updates to is channel
|
||||
/// binding type may occur to address this issue if it ever arises).
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum TlsServerEndPoint {
|
||||
Sha256([u8; 32]),
|
||||
Undefined,
|
||||
}
|
||||
|
||||
impl TlsServerEndPoint {
|
||||
pub fn new(cert: &CertificateDer) -> anyhow::Result<Self> {
|
||||
let sha256_oids = [
|
||||
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
|
||||
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
|
||||
oid_registry::OID_PKCS1_SHA256WITHRSA,
|
||||
];
|
||||
|
||||
let pem = x509_parser::parse_x509_certificate(cert)
|
||||
.context("Failed to parse PEM object from cerficiate")?
|
||||
.1;
|
||||
|
||||
info!(subject = %pem.subject, "parsing TLS certificate");
|
||||
|
||||
let reg = oid_registry::OidRegistry::default().with_all_crypto();
|
||||
let oid = pem.signature_algorithm.oid();
|
||||
let alg = reg.get(oid);
|
||||
if sha256_oids.contains(oid) {
|
||||
let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
|
||||
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
|
||||
Ok(Self::Sha256(tls_server_end_point))
|
||||
} else {
|
||||
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
|
||||
Ok(Self::Undefined)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn supported(&self) -> bool {
|
||||
!matches!(self, TlsServerEndPoint::Undefined)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct CertResolver {
|
||||
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,
|
||||
|
||||
@@ -16,9 +16,10 @@ use crate::{
|
||||
intern::ProjectIdInt,
|
||||
metrics::ApiLockMetrics,
|
||||
rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token},
|
||||
scram, EndpointCacheKey,
|
||||
EndpointCacheKey,
|
||||
};
|
||||
use dashmap::DashMap;
|
||||
use proxy_sasl::scram;
|
||||
use std::{hash::Hash, sync::Arc, time::Duration};
|
||||
use tokio::time::Instant;
|
||||
use tracing::info;
|
||||
@@ -469,15 +470,15 @@ impl<K: Hash + Eq + Clone> ApiLocks<K> {
|
||||
timeout: Duration,
|
||||
epoch: std::time::Duration,
|
||||
metrics: &'static ApiLockMetrics,
|
||||
) -> prometheus::Result<Self> {
|
||||
Ok(Self {
|
||||
) -> Self {
|
||||
Self {
|
||||
name,
|
||||
node_locks: DashMap::with_shard_amount(shards),
|
||||
config,
|
||||
timeout,
|
||||
epoch,
|
||||
metrics,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_permit(&self, key: &K) -> Result<WakeComputePermit, ApiLockError> {
|
||||
|
||||
@@ -5,7 +5,7 @@ use super::{
|
||||
AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
|
||||
};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
|
||||
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, url::ApiUrl};
|
||||
use crate::{auth::IpPattern, cache::Cached};
|
||||
use crate::{
|
||||
console::{
|
||||
@@ -15,6 +15,7 @@ use crate::{
|
||||
BranchId, EndpointId, ProjectId,
|
||||
};
|
||||
use futures::TryFutureExt;
|
||||
use proxy_sasl::scram;
|
||||
use std::{str::FromStr, sync::Arc};
|
||||
use thiserror::Error;
|
||||
use tokio_postgres::{config::SslMode, Client};
|
||||
|
||||
@@ -13,10 +13,11 @@ use crate::{
|
||||
http,
|
||||
metrics::{CacheOutcome, Metrics},
|
||||
rate_limiter::WakeComputeRateLimiter,
|
||||
scram, EndpointCacheKey,
|
||||
EndpointCacheKey,
|
||||
};
|
||||
use crate::{cache::Cached, context::RequestMonitoring};
|
||||
use futures::TryFutureExt;
|
||||
use proxy_sasl::scram;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::config::SslMode;
|
||||
|
||||
@@ -21,13 +21,13 @@ pub mod intern;
|
||||
pub mod jemalloc;
|
||||
pub mod logging;
|
||||
pub mod metrics;
|
||||
pub mod parse;
|
||||
// pub mod parse;
|
||||
pub mod protocol2;
|
||||
pub mod proxy;
|
||||
pub mod rate_limiter;
|
||||
pub mod redis;
|
||||
pub mod sasl;
|
||||
pub mod scram;
|
||||
// pub mod sasl;
|
||||
// pub mod scram;
|
||||
pub mod serverless;
|
||||
pub mod stream;
|
||||
pub mod url;
|
||||
|
||||
@@ -2,13 +2,14 @@ use std::sync::{Arc, OnceLock};
|
||||
|
||||
use lasso::ThreadedRodeo;
|
||||
use measured::{
|
||||
label::{FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue, StaticLabelSet},
|
||||
label::StaticLabelSet,
|
||||
metric::{histogram::Thresholds, name::MetricName},
|
||||
Counter, CounterVec, FixedCardinalityLabel, Gauge, GaugeVec, Histogram, HistogramVec,
|
||||
LabelGroup, MetricGroup,
|
||||
Counter, CounterVec, FixedCardinalityLabel, Gauge, Histogram, HistogramVec, LabelGroup,
|
||||
MetricGroup,
|
||||
};
|
||||
use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLog, HyperLogLogVec};
|
||||
|
||||
use proxy_sasl::scram::threadpool::ThreadPoolMetrics;
|
||||
use tokio::time::{self, Instant};
|
||||
|
||||
use crate::console::messages::ColdStartInfo;
|
||||
@@ -546,78 +547,3 @@ pub enum RedisEventsCount {
|
||||
PasswordUpdate,
|
||||
AllowedIpsUpdate,
|
||||
}
|
||||
|
||||
pub struct ThreadPoolWorkers(usize);
|
||||
pub struct ThreadPoolWorkerId(pub usize);
|
||||
|
||||
impl LabelValue for ThreadPoolWorkerId {
|
||||
fn visit<V: measured::label::LabelVisitor>(&self, v: V) -> V::Output {
|
||||
v.write_int(self.0 as i64)
|
||||
}
|
||||
}
|
||||
|
||||
impl LabelGroup for ThreadPoolWorkerId {
|
||||
fn visit_values(&self, v: &mut impl measured::label::LabelGroupVisitor) {
|
||||
v.write_value(LabelName::from_str("worker"), self);
|
||||
}
|
||||
}
|
||||
|
||||
impl LabelGroupSet for ThreadPoolWorkers {
|
||||
type Group<'a> = ThreadPoolWorkerId;
|
||||
|
||||
fn cardinality(&self) -> Option<usize> {
|
||||
Some(self.0)
|
||||
}
|
||||
|
||||
fn encode_dense(&self, value: Self::Unique) -> Option<usize> {
|
||||
Some(value)
|
||||
}
|
||||
|
||||
fn decode_dense(&self, value: usize) -> Self::Group<'_> {
|
||||
ThreadPoolWorkerId(value)
|
||||
}
|
||||
|
||||
type Unique = usize;
|
||||
|
||||
fn encode(&self, value: Self::Group<'_>) -> Option<Self::Unique> {
|
||||
Some(value.0)
|
||||
}
|
||||
|
||||
fn decode(&self, value: &Self::Unique) -> Self::Group<'_> {
|
||||
ThreadPoolWorkerId(*value)
|
||||
}
|
||||
}
|
||||
|
||||
impl LabelSet for ThreadPoolWorkers {
|
||||
type Value<'a> = ThreadPoolWorkerId;
|
||||
|
||||
fn dynamic_cardinality(&self) -> Option<usize> {
|
||||
Some(self.0)
|
||||
}
|
||||
|
||||
fn encode(&self, value: Self::Value<'_>) -> Option<usize> {
|
||||
(value.0 < self.0).then_some(value.0)
|
||||
}
|
||||
|
||||
fn decode(&self, value: usize) -> Self::Value<'_> {
|
||||
ThreadPoolWorkerId(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl FixedCardinalitySet for ThreadPoolWorkers {
|
||||
fn cardinality(&self) -> usize {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(MetricGroup)]
|
||||
#[metric(new(workers: usize))]
|
||||
pub struct ThreadPoolMetrics {
|
||||
pub injector_queue_depth: Gauge,
|
||||
#[metric(init = GaugeVec::with_label_set(ThreadPoolWorkers(workers)))]
|
||||
pub worker_queue_depth: GaugeVec<ThreadPoolWorkers>,
|
||||
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
|
||||
pub worker_task_turns_total: CounterVec<ThreadPoolWorkers>,
|
||||
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
|
||||
pub worker_task_skips_total: CounterVec<ThreadPoolWorkers>,
|
||||
}
|
||||
|
||||
@@ -16,9 +16,10 @@ use crate::console::messages::{ConsoleError, Details, MetricsAuxInfo, Status};
|
||||
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend};
|
||||
use crate::console::{self, CachedNodeInfo, NodeInfo};
|
||||
use crate::error::ErrorKind;
|
||||
use crate::{http, sasl, scram, BranchId, EndpointId, ProjectId};
|
||||
use crate::{http, BranchId, EndpointId, ProjectId};
|
||||
use anyhow::{bail, Context};
|
||||
use async_trait::async_trait;
|
||||
use proxy_sasl::{sasl, scram};
|
||||
use retry::{retry_after, ShouldRetryWakeCompute};
|
||||
use rstest::rstest;
|
||||
use rustls::pki_types;
|
||||
@@ -137,7 +138,7 @@ struct Scram(scram::ServerSecret);
|
||||
|
||||
impl Scram {
|
||||
async fn new(password: &str) -> anyhow::Result<Self> {
|
||||
let secret = scram::ServerSecret::build(password)
|
||||
let secret = scram::ServerSecret::build_test_secret(password)
|
||||
.await
|
||||
.context("failed to generate scram secret")?;
|
||||
Ok(Scram(secret))
|
||||
|
||||
@@ -79,11 +79,11 @@ impl PoolingBackend {
|
||||
)
|
||||
.await?;
|
||||
let res = match auth_outcome {
|
||||
crate::sasl::Outcome::Success(key) => {
|
||||
proxy_sasl::sasl::Outcome::Success(key) => {
|
||||
info!("user successfully authenticated");
|
||||
Ok(key)
|
||||
}
|
||||
crate::sasl::Outcome::Failure(reason) => {
|
||||
proxy_sasl::sasl::Outcome::Failure(reason) => {
|
||||
info!("auth backend failed with an error: {reason}");
|
||||
Err(AuthError::auth_failed(&*conn_info.user_info.user))
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
use crate::config::TlsServerEndPoint;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::metrics::Metrics;
|
||||
use bytes::BytesMut;
|
||||
|
||||
use pq_proto::framed::{ConnectionError, Framed};
|
||||
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
|
||||
use proxy_sasl::scram::TlsServerEndPoint;
|
||||
use rustls::ServerConfig;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
128
proxy/sasl/Cargo.toml
Normal file
128
proxy/sasl/Cargo.toml
Normal file
@@ -0,0 +1,128 @@
|
||||
[package]
|
||||
name = "proxy-sasl"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[features]
|
||||
default = []
|
||||
testing = []
|
||||
|
||||
[dependencies]
|
||||
ahash.workspace = true
|
||||
anyhow.workspace = true
|
||||
# arc-swap.workspace = true
|
||||
# async-compression.workspace = true
|
||||
# async-trait.workspace = true
|
||||
# atomic-take.workspace = true
|
||||
# aws-config.workspace = true
|
||||
# aws-sdk-iam.workspace = true
|
||||
# aws-sigv4.workspace = true
|
||||
# aws-types.workspace = true
|
||||
base64.workspace = true
|
||||
# bstr.workspace = true
|
||||
bytes = { workspace = true, features = ["serde"] }
|
||||
# camino.workspace = true
|
||||
# chrono.workspace = true
|
||||
# clap.workspace = true
|
||||
# consumption_metrics.workspace = true
|
||||
crossbeam-deque.workspace = true
|
||||
# dashmap.workspace = true
|
||||
# env_logger.workspace = true
|
||||
# framed-websockets.workspace = true
|
||||
# futures.workspace = true
|
||||
# git-version.workspace = true
|
||||
# hashbrown.workspace = true
|
||||
# hashlink.workspace = true
|
||||
# hex.workspace = true
|
||||
hmac.workspace = true
|
||||
# hostname.workspace = true
|
||||
# http.workspace = true
|
||||
# humantime.workspace = true
|
||||
# humantime-serde.workspace = true
|
||||
# hyper.workspace = true
|
||||
# hyper1 = { package = "hyper", version = "1.2", features = ["server"] }
|
||||
# hyper-util = { version = "0.1", features = ["server", "http1", "http2", "tokio"] }
|
||||
# http-body-util = { version = "0.1" }
|
||||
# indexmap.workspace = true
|
||||
# ipnet.workspace = true
|
||||
itertools.workspace = true
|
||||
lasso = { workspace = true, features = ["multi-threaded"] }
|
||||
# md5.workspace = true
|
||||
measured = { workspace = true, features = ["lasso"] }
|
||||
# metrics.workspace = true
|
||||
# once_cell.workspace = true
|
||||
# opentelemetry.workspace = true
|
||||
parking_lot.workspace = true
|
||||
# parquet.workspace = true
|
||||
# parquet_derive.workspace = true
|
||||
# pin-project-lite.workspace = true
|
||||
# postgres_backend.workspace = true
|
||||
pq_proto.workspace = true
|
||||
# prometheus.workspace = true
|
||||
rand.workspace = true
|
||||
# regex.workspace = true
|
||||
# remote_storage = { version = "0.1", path = "../../libs/remote_storage/" }
|
||||
# reqwest.workspace = true
|
||||
# reqwest-middleware = { workspace = true, features = ["json"] }
|
||||
# reqwest-retry.workspace = true
|
||||
# reqwest-tracing.workspace = true
|
||||
# routerify.workspace = true
|
||||
# rustc-hash.workspace = true
|
||||
# rustls-pemfile.workspace = true
|
||||
rustls.workspace = true
|
||||
# scopeguard.workspace = true
|
||||
# serde.workspace = true
|
||||
# serde_json.workspace = true
|
||||
sha2 = { workspace = true, features = ["asm", "oid"] }
|
||||
# smol_str.workspace = true
|
||||
# smallvec.workspace = true
|
||||
# socket2.workspace = true
|
||||
subtle.workspace = true
|
||||
# task-local-extensions.workspace = true
|
||||
thiserror.workspace = true
|
||||
# tikv-jemallocator.workspace = true
|
||||
# tikv-jemalloc-ctl = { workspace = true, features = ["use_std"] }
|
||||
# tokio-postgres.workspace = true
|
||||
# tokio-postgres-rustls.workspace = true
|
||||
# tokio-rustls.workspace = true
|
||||
# tokio-util.workspace = true
|
||||
tokio = { workspace = true, features = ["signal"] }
|
||||
# tower-service.workspace = true
|
||||
# tracing-opentelemetry.workspace = true
|
||||
# tracing-subscriber.workspace = true
|
||||
# tracing-utils.workspace = true
|
||||
tracing.workspace = true
|
||||
# try-lock.workspace = true
|
||||
# typed-json.workspace = true
|
||||
# url.workspace = true
|
||||
# urlencoding.workspace = true
|
||||
# utils.workspace = true
|
||||
# uuid.workspace = true
|
||||
# rustls-native-certs.workspace = true
|
||||
x509-parser.workspace = true
|
||||
postgres-protocol.workspace = true
|
||||
# redis.workspace = true
|
||||
|
||||
# # jwt stuff
|
||||
# jose-jwa = "0.1.2"
|
||||
# jose-jwk = { version = "0.1.2", features = ["p256", "p384", "rsa"] }
|
||||
# signature = "2"
|
||||
# ecdsa = "0.16"
|
||||
# p256 = "0.13"
|
||||
# rsa = "0.9"
|
||||
|
||||
workspace_hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
# camino-tempfile.workspace = true
|
||||
# fallible-iterator.workspace = true
|
||||
# tokio-tungstenite.workspace = true
|
||||
pbkdf2 = { workspace = true, features = ["simple", "std"] }
|
||||
# rcgen.workspace = true
|
||||
# rstest.workspace = true
|
||||
# tokio-postgres-rustls.workspace = true
|
||||
# walkdir.workspace = true
|
||||
# rand_distr = "0.4"
|
||||
|
||||
uuid.workspace = true
|
||||
3
proxy/sasl/src/lib.rs
Normal file
3
proxy/sasl/src/lib.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
mod parse;
|
||||
pub mod sasl;
|
||||
pub mod scram;
|
||||
43
proxy/sasl/src/parse.rs
Normal file
43
proxy/sasl/src/parse.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
//! Small parsing helpers.
|
||||
|
||||
use std::ffi::CStr;
|
||||
|
||||
pub fn split_cstr(bytes: &[u8]) -> Option<(&CStr, &[u8])> {
|
||||
let cstr = CStr::from_bytes_until_nul(bytes).ok()?;
|
||||
let (_, other) = bytes.split_at(cstr.to_bytes_with_nul().len());
|
||||
Some((cstr, other))
|
||||
}
|
||||
|
||||
/// See <https://doc.rust-lang.org/std/primitive.slice.html#method.split_array_ref>.
|
||||
pub fn split_at_const<const N: usize>(bytes: &[u8]) -> Option<(&[u8; N], &[u8])> {
|
||||
(bytes.len() >= N).then(|| {
|
||||
let (head, tail) = bytes.split_at(N);
|
||||
(head.try_into().unwrap(), tail)
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_split_cstr() {
|
||||
assert!(split_cstr(b"").is_none());
|
||||
assert!(split_cstr(b"foo").is_none());
|
||||
|
||||
let (cstr, rest) = split_cstr(b"\0").expect("uh-oh");
|
||||
assert_eq!(cstr.to_bytes(), b"");
|
||||
assert_eq!(rest, b"");
|
||||
|
||||
let (cstr, rest) = split_cstr(b"foo\0bar").expect("uh-oh");
|
||||
assert_eq!(cstr.to_bytes(), b"foo");
|
||||
assert_eq!(rest, b"bar");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_split_at_const() {
|
||||
assert!(split_at_const::<0>(b"").is_some());
|
||||
assert!(split_at_const::<1>(b"").is_none());
|
||||
assert!(matches!(split_at_const::<1>(b"ok"), Some((b"o", b"k"))));
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,7 @@ mod channel_binding;
|
||||
mod messages;
|
||||
mod stream;
|
||||
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
// use crate::error::{ReportableError, UserFacingError};
|
||||
use std::io;
|
||||
use thiserror::Error;
|
||||
|
||||
@@ -40,29 +40,29 @@ pub enum Error {
|
||||
Io(#[from] io::Error),
|
||||
}
|
||||
|
||||
impl UserFacingError for Error {
|
||||
fn to_string_client(&self) -> String {
|
||||
use Error::*;
|
||||
match self {
|
||||
ChannelBindingFailed(m) => m.to_string(),
|
||||
ChannelBindingBadMethod(m) => format!("unsupported channel binding method {m}"),
|
||||
_ => "authentication protocol violation".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
// impl UserFacingError for Error {
|
||||
// fn to_string_client(&self) -> String {
|
||||
// use Error::*;
|
||||
// match self {
|
||||
// ChannelBindingFailed(m) => m.to_string(),
|
||||
// ChannelBindingBadMethod(m) => format!("unsupported channel binding method {m}"),
|
||||
// _ => "authentication protocol violation".to_string(),
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
impl ReportableError for Error {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User,
|
||||
Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
|
||||
Error::BadClientMessage(_) => crate::error::ErrorKind::User,
|
||||
Error::MissingBinding => crate::error::ErrorKind::Service,
|
||||
Error::Base64(_) => crate::error::ErrorKind::ControlPlane,
|
||||
Error::Io(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
}
|
||||
}
|
||||
}
|
||||
// impl ReportableError for Error {
|
||||
// fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
// match self {
|
||||
// Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User,
|
||||
// Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
|
||||
// Error::BadClientMessage(_) => crate::error::ErrorKind::User,
|
||||
// Error::MissingBinding => crate::error::ErrorKind::Service,
|
||||
// Error::Base64(_) => crate::error::ErrorKind::ControlPlane,
|
||||
// Error::Io(_) => crate::error::ErrorKind::ClientDisconnect,
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
/// A convenient result type for SASL exchange.
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
@@ -1,7 +1,10 @@
|
||||
//! Abstraction for the string-oriented SASL protocols.
|
||||
|
||||
use super::{messages::ServerMessage, Mechanism};
|
||||
use crate::stream::PqStream;
|
||||
use pq_proto::{
|
||||
framed::{ConnectionError, Framed},
|
||||
FeMessage, ProtocolError,
|
||||
};
|
||||
use std::io;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
@@ -9,7 +12,7 @@ use tracing::info;
|
||||
/// Abstracts away all peculiarities of the libpq's protocol.
|
||||
pub struct SaslStream<'a, S> {
|
||||
/// The underlying stream.
|
||||
stream: &'a mut PqStream<S>,
|
||||
stream: &'a mut Framed<S>,
|
||||
/// Current password message we received from client.
|
||||
current: bytes::Bytes,
|
||||
/// First SASL message produced by client.
|
||||
@@ -17,7 +20,7 @@ pub struct SaslStream<'a, S> {
|
||||
}
|
||||
|
||||
impl<'a, S> SaslStream<'a, S> {
|
||||
pub fn new(stream: &'a mut PqStream<S>, first: &'a str) -> Self {
|
||||
pub fn new(stream: &'a mut Framed<S>, first: &'a str) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
current: bytes::Bytes::new(),
|
||||
@@ -26,6 +29,27 @@ impl<'a, S> SaslStream<'a, S> {
|
||||
}
|
||||
}
|
||||
|
||||
fn err_connection() -> io::Error {
|
||||
io::Error::new(io::ErrorKind::ConnectionAborted, "connection is lost")
|
||||
}
|
||||
|
||||
pub async fn read_password_message<S: AsyncRead + Unpin>(
|
||||
framed: &mut Framed<S>,
|
||||
) -> io::Result<bytes::Bytes> {
|
||||
let msg = framed
|
||||
.read_message()
|
||||
.await
|
||||
.map_err(ConnectionError::into_io_error)?
|
||||
.ok_or_else(err_connection)?;
|
||||
match msg {
|
||||
FeMessage::PasswordMessage(msg) => Ok(msg),
|
||||
bad => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("unexpected message type: {:?}", bad),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
|
||||
// Receive a new SASL message from the client.
|
||||
async fn recv(&mut self) -> io::Result<&str> {
|
||||
@@ -33,7 +57,7 @@ impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
|
||||
return Ok(first);
|
||||
}
|
||||
|
||||
self.current = self.stream.read_password_message().await?;
|
||||
self.current = read_password_message(self.stream).await?;
|
||||
let s = std::str::from_utf8(&self.current)
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
|
||||
|
||||
@@ -44,7 +68,10 @@ impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
|
||||
impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
|
||||
// Send a SASL message to the client.
|
||||
async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
|
||||
self.stream.write_message(&msg.to_reply()).await?;
|
||||
self.stream
|
||||
.write_message(&msg.to_reply())
|
||||
.map_err(ProtocolError::into_io_error)?;
|
||||
self.stream.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -15,12 +15,16 @@ mod secret;
|
||||
mod signature;
|
||||
pub mod threadpool;
|
||||
|
||||
use anyhow::Context;
|
||||
pub use exchange::{exchange, Exchange};
|
||||
pub use key::ScramKey;
|
||||
use rustls::pki_types::CertificateDer;
|
||||
pub use secret::ServerSecret;
|
||||
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::{Digest, Sha256};
|
||||
use tracing::{error, info};
|
||||
use x509_parser::oid_registry;
|
||||
|
||||
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
|
||||
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
|
||||
@@ -57,12 +61,71 @@ fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
|
||||
hasher.finalize().into()
|
||||
}
|
||||
|
||||
/// Channel binding parameter
|
||||
///
|
||||
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
|
||||
/// Description: The hash of the TLS server's certificate as it
|
||||
/// appears, octet for octet, in the server's Certificate message. Note
|
||||
/// that the Certificate message contains a certificate_list, in which
|
||||
/// the first element is the server's certificate.
|
||||
///
|
||||
/// The hash function is to be selected as follows:
|
||||
///
|
||||
/// * if the certificate's signatureAlgorithm uses a single hash
|
||||
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
|
||||
///
|
||||
/// * if the certificate's signatureAlgorithm uses a single hash
|
||||
/// function and that hash function neither MD5 nor SHA-1, then use
|
||||
/// the hash function associated with the certificate's
|
||||
/// signatureAlgorithm;
|
||||
///
|
||||
/// * if the certificate's signatureAlgorithm uses no hash functions or
|
||||
/// uses multiple hash functions, then this channel binding type's
|
||||
/// channel bindings are undefined at this time (updates to is channel
|
||||
/// binding type may occur to address this issue if it ever arises).
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum TlsServerEndPoint {
|
||||
Sha256([u8; 32]),
|
||||
Undefined,
|
||||
}
|
||||
|
||||
impl TlsServerEndPoint {
|
||||
pub fn new(cert: &CertificateDer) -> anyhow::Result<Self> {
|
||||
let sha256_oids = [
|
||||
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
|
||||
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
|
||||
oid_registry::OID_PKCS1_SHA256WITHRSA,
|
||||
];
|
||||
|
||||
let pem = x509_parser::parse_x509_certificate(cert)
|
||||
.context("Failed to parse PEM object from cerficiate")?
|
||||
.1;
|
||||
|
||||
info!(subject = %pem.subject, "parsing TLS certificate");
|
||||
|
||||
let reg = oid_registry::OidRegistry::default().with_all_crypto();
|
||||
let oid = pem.signature_algorithm.oid();
|
||||
let alg = reg.get(oid);
|
||||
if sha256_oids.contains(oid) {
|
||||
let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
|
||||
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
|
||||
Ok(Self::Sha256(tls_server_end_point))
|
||||
} else {
|
||||
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
|
||||
Ok(Self::Undefined)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn supported(&self) -> bool {
|
||||
!matches!(self, TlsServerEndPoint::Undefined)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
intern::EndpointIdInt,
|
||||
sasl::{Mechanism, Step},
|
||||
EndpointId,
|
||||
scram::TlsServerEndPoint,
|
||||
};
|
||||
|
||||
use super::{threadpool::ThreadPool, Exchange, ServerSecret};
|
||||
@@ -79,11 +142,7 @@ mod tests {
|
||||
const NONCE: [u8; 18] = [
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
|
||||
];
|
||||
let mut exchange = Exchange::new(
|
||||
&secret,
|
||||
|| NONCE,
|
||||
crate::config::TlsServerEndPoint::Undefined,
|
||||
);
|
||||
let mut exchange = Exchange::new(&secret, || NONCE, TlsServerEndPoint::Undefined);
|
||||
|
||||
let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
|
||||
let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0=";
|
||||
@@ -120,11 +179,11 @@ mod tests {
|
||||
|
||||
async fn run_round_trip_test(server_password: &str, client_password: &str) {
|
||||
let pool = ThreadPool::new(1);
|
||||
let ep = "foo";
|
||||
|
||||
let ep = EndpointId::from("foo");
|
||||
let ep = EndpointIdInt::from(ep);
|
||||
|
||||
let scram_secret = ServerSecret::build(server_password).await.unwrap();
|
||||
let scram_secret = ServerSecret::build_test_secret(server_password)
|
||||
.await
|
||||
.unwrap();
|
||||
let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes())
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -1,6 +1,7 @@
|
||||
//! Implementation of the SCRAM authentication algorithm.
|
||||
|
||||
use std::convert::Infallible;
|
||||
use std::hash::Hash;
|
||||
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
@@ -13,8 +14,6 @@ use super::secret::ServerSecret;
|
||||
use super::signature::SignatureBuilder;
|
||||
use super::threadpool::ThreadPool;
|
||||
use super::ScramKey;
|
||||
use crate::config;
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::sasl::{self, ChannelBinding, Error as SaslError};
|
||||
|
||||
/// The only channel binding mode we currently support.
|
||||
@@ -59,14 +58,14 @@ enum ExchangeState {
|
||||
pub struct Exchange<'a> {
|
||||
state: ExchangeState,
|
||||
secret: &'a ServerSecret,
|
||||
tls_server_end_point: config::TlsServerEndPoint,
|
||||
tls_server_end_point: super::TlsServerEndPoint,
|
||||
}
|
||||
|
||||
impl<'a> Exchange<'a> {
|
||||
pub fn new(
|
||||
secret: &'a ServerSecret,
|
||||
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
|
||||
tls_server_end_point: config::TlsServerEndPoint,
|
||||
tls_server_end_point: super::TlsServerEndPoint,
|
||||
) -> Self {
|
||||
Self {
|
||||
state: ExchangeState::Initial(SaslInitial { nonce }),
|
||||
@@ -77,15 +76,15 @@ impl<'a> Exchange<'a> {
|
||||
}
|
||||
|
||||
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
|
||||
async fn derive_client_key(
|
||||
pool: &ThreadPool,
|
||||
endpoint: EndpointIdInt,
|
||||
async fn derive_client_key<K: Hash + Send + 'static>(
|
||||
pool: &ThreadPool<K>,
|
||||
concurrency_key: K,
|
||||
password: &[u8],
|
||||
salt: &[u8],
|
||||
iterations: u32,
|
||||
) -> ScramKey {
|
||||
let salted_password = pool
|
||||
.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
|
||||
.spawn_job(concurrency_key, Pbkdf2::start(password, salt, iterations))
|
||||
.await
|
||||
.expect("job should not be cancelled");
|
||||
|
||||
@@ -101,14 +100,15 @@ async fn derive_client_key(
|
||||
make_key(b"Client Key").into()
|
||||
}
|
||||
|
||||
pub async fn exchange(
|
||||
pool: &ThreadPool,
|
||||
endpoint: EndpointIdInt,
|
||||
pub async fn exchange<K: Hash + Send + 'static>(
|
||||
pool: &ThreadPool<K>,
|
||||
concurrency_key: K,
|
||||
secret: &ServerSecret,
|
||||
password: &[u8],
|
||||
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
|
||||
let salt = base64::decode(&secret.salt_base64)?;
|
||||
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
|
||||
let client_key =
|
||||
derive_client_key(pool, concurrency_key, password, &salt, secret.iterations).await;
|
||||
|
||||
if secret.is_password_invalid(&client_key).into() {
|
||||
Ok(sasl::Outcome::Failure("password doesn't match"))
|
||||
@@ -121,7 +121,7 @@ impl SaslInitial {
|
||||
fn transition(
|
||||
&self,
|
||||
secret: &ServerSecret,
|
||||
tls_server_end_point: &config::TlsServerEndPoint,
|
||||
tls_server_end_point: &super::TlsServerEndPoint,
|
||||
input: &str,
|
||||
) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
|
||||
let client_first_message = ClientFirstMessage::parse(input)
|
||||
@@ -156,7 +156,7 @@ impl SaslSentInner {
|
||||
fn transition(
|
||||
&self,
|
||||
secret: &ServerSecret,
|
||||
tls_server_end_point: &config::TlsServerEndPoint,
|
||||
tls_server_end_point: &super::TlsServerEndPoint,
|
||||
input: &str,
|
||||
) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
|
||||
let Self {
|
||||
@@ -169,8 +169,8 @@ impl SaslSentInner {
|
||||
.ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
|
||||
|
||||
let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
|
||||
config::TlsServerEndPoint::Sha256(x) => Ok(x),
|
||||
config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
|
||||
super::TlsServerEndPoint::Sha256(x) => Ok(x),
|
||||
super::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
|
||||
})?;
|
||||
|
||||
// This might've been caused by a MITM attack
|
||||
@@ -64,9 +64,7 @@ impl ServerSecret {
|
||||
}
|
||||
|
||||
/// Build a new server secret from the prerequisites.
|
||||
/// XXX: We only use this function in tests.
|
||||
#[cfg(test)]
|
||||
pub async fn build(password: &str) -> Option<Self> {
|
||||
pub async fn build_test_secret(password: &str) -> Option<Self> {
|
||||
Self::parse(&postgres_protocol::password::scram_sha_256(password.as_bytes()).await)
|
||||
}
|
||||
}
|
||||
@@ -4,29 +4,36 @@
|
||||
//! 1. Fairness per endpoint.
|
||||
//! 2. Yield support for high iteration counts.
|
||||
|
||||
use std::sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
Arc,
|
||||
use std::{
|
||||
hash::Hash,
|
||||
sync::{
|
||||
atomic::{AtomicU64, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use crossbeam_deque::{Injector, Stealer, Worker};
|
||||
use itertools::Itertools;
|
||||
use measured::{
|
||||
label::{FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue},
|
||||
CounterVec, Gauge, GaugeVec, LabelGroup, MetricGroup,
|
||||
};
|
||||
use parking_lot::{Condvar, Mutex};
|
||||
use rand::Rng;
|
||||
use rand::{rngs::SmallRng, SeedableRng};
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
use crate::{
|
||||
intern::EndpointIdInt,
|
||||
metrics::{ThreadPoolMetrics, ThreadPoolWorkerId},
|
||||
// intern::EndpointIdInt,
|
||||
// metrics::{ThreadPoolMetrics, ThreadPoolWorkerId},
|
||||
scram::countmin::CountMinSketch,
|
||||
};
|
||||
|
||||
use super::pbkdf2::Pbkdf2;
|
||||
|
||||
pub struct ThreadPool {
|
||||
queue: Injector<JobSpec>,
|
||||
stealers: Vec<Stealer<JobSpec>>,
|
||||
pub struct ThreadPool<K> {
|
||||
queue: Injector<JobSpec<K>>,
|
||||
stealers: Vec<Stealer<JobSpec<K>>>,
|
||||
parkers: Vec<(Condvar, Mutex<ThreadState>)>,
|
||||
/// bitpacked representation.
|
||||
/// lower 8 bits = number of sleeping threads
|
||||
@@ -42,7 +49,7 @@ enum ThreadState {
|
||||
Active,
|
||||
}
|
||||
|
||||
impl ThreadPool {
|
||||
impl<K: Hash + Send + 'static> ThreadPool<K> {
|
||||
pub fn new(n_workers: u8) -> Arc<Self> {
|
||||
let workers = (0..n_workers).map(|_| Worker::new_fifo()).collect_vec();
|
||||
let stealers = workers.iter().map(|w| w.stealer()).collect_vec();
|
||||
@@ -68,11 +75,7 @@ impl ThreadPool {
|
||||
pool
|
||||
}
|
||||
|
||||
pub fn spawn_job(
|
||||
&self,
|
||||
endpoint: EndpointIdInt,
|
||||
pbkdf2: Pbkdf2,
|
||||
) -> oneshot::Receiver<[u8; 32]> {
|
||||
pub fn spawn_job(&self, key: K, pbkdf2: Pbkdf2) -> oneshot::Receiver<[u8; 32]> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
let queue_was_empty = self.queue.is_empty();
|
||||
@@ -81,7 +84,7 @@ impl ThreadPool {
|
||||
self.queue.push(JobSpec {
|
||||
response: tx,
|
||||
pbkdf2,
|
||||
endpoint,
|
||||
key,
|
||||
});
|
||||
|
||||
// inspired from <https://github.com/rayon-rs/rayon/blob/3e3962cb8f7b50773bcc360b48a7a674a53a2c77/rayon-core/src/sleep/mod.rs#L242>
|
||||
@@ -139,7 +142,12 @@ impl ThreadPool {
|
||||
}
|
||||
}
|
||||
|
||||
fn steal(&self, rng: &mut impl Rng, skip: usize, worker: &Worker<JobSpec>) -> Option<JobSpec> {
|
||||
fn steal(
|
||||
&self,
|
||||
rng: &mut impl Rng,
|
||||
skip: usize,
|
||||
worker: &Worker<JobSpec<K>>,
|
||||
) -> Option<JobSpec<K>> {
|
||||
// announce thread as idle
|
||||
self.counters.fetch_add(256, Ordering::SeqCst);
|
||||
|
||||
@@ -188,7 +196,11 @@ impl ThreadPool {
|
||||
}
|
||||
}
|
||||
|
||||
fn thread_rt(pool: Arc<ThreadPool>, worker: Worker<JobSpec>, index: usize) {
|
||||
fn thread_rt<K: Hash + Send + 'static>(
|
||||
pool: Arc<ThreadPool<K>>,
|
||||
worker: Worker<JobSpec<K>>,
|
||||
index: usize,
|
||||
) {
|
||||
/// interval when we should steal from the global queue
|
||||
/// so that tail latencies are managed appropriately
|
||||
const STEAL_INTERVAL: usize = 61;
|
||||
@@ -236,7 +248,7 @@ fn thread_rt(pool: Arc<ThreadPool>, worker: Worker<JobSpec>, index: usize) {
|
||||
|
||||
// receiver is closed, cancel the task
|
||||
if !job.response.is_closed() {
|
||||
let rate = sketch.inc_and_return(&job.endpoint, job.pbkdf2.cost());
|
||||
let rate = sketch.inc_and_return(&job.key, job.pbkdf2.cost());
|
||||
|
||||
const P: f64 = 2000.0;
|
||||
// probability decreases as rate increases.
|
||||
@@ -287,24 +299,96 @@ fn thread_rt(pool: Arc<ThreadPool>, worker: Worker<JobSpec>, index: usize) {
|
||||
}
|
||||
}
|
||||
|
||||
struct JobSpec {
|
||||
struct JobSpec<K> {
|
||||
response: oneshot::Sender<[u8; 32]>,
|
||||
pbkdf2: Pbkdf2,
|
||||
endpoint: EndpointIdInt,
|
||||
key: K,
|
||||
}
|
||||
|
||||
pub struct ThreadPoolWorkers(usize);
|
||||
pub struct ThreadPoolWorkerId(pub usize);
|
||||
|
||||
impl LabelValue for ThreadPoolWorkerId {
|
||||
fn visit<V: measured::label::LabelVisitor>(&self, v: V) -> V::Output {
|
||||
v.write_int(self.0 as i64)
|
||||
}
|
||||
}
|
||||
|
||||
impl LabelGroup for ThreadPoolWorkerId {
|
||||
fn visit_values(&self, v: &mut impl measured::label::LabelGroupVisitor) {
|
||||
v.write_value(LabelName::from_str("worker"), self);
|
||||
}
|
||||
}
|
||||
|
||||
impl LabelGroupSet for ThreadPoolWorkers {
|
||||
type Group<'a> = ThreadPoolWorkerId;
|
||||
|
||||
fn cardinality(&self) -> Option<usize> {
|
||||
Some(self.0)
|
||||
}
|
||||
|
||||
fn encode_dense(&self, value: Self::Unique) -> Option<usize> {
|
||||
Some(value)
|
||||
}
|
||||
|
||||
fn decode_dense(&self, value: usize) -> Self::Group<'_> {
|
||||
ThreadPoolWorkerId(value)
|
||||
}
|
||||
|
||||
type Unique = usize;
|
||||
|
||||
fn encode(&self, value: Self::Group<'_>) -> Option<Self::Unique> {
|
||||
Some(value.0)
|
||||
}
|
||||
|
||||
fn decode(&self, value: &Self::Unique) -> Self::Group<'_> {
|
||||
ThreadPoolWorkerId(*value)
|
||||
}
|
||||
}
|
||||
|
||||
impl LabelSet for ThreadPoolWorkers {
|
||||
type Value<'a> = ThreadPoolWorkerId;
|
||||
|
||||
fn dynamic_cardinality(&self) -> Option<usize> {
|
||||
Some(self.0)
|
||||
}
|
||||
|
||||
fn encode(&self, value: Self::Value<'_>) -> Option<usize> {
|
||||
(value.0 < self.0).then_some(value.0)
|
||||
}
|
||||
|
||||
fn decode(&self, value: usize) -> Self::Value<'_> {
|
||||
ThreadPoolWorkerId(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl FixedCardinalitySet for ThreadPoolWorkers {
|
||||
fn cardinality(&self) -> usize {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(MetricGroup)]
|
||||
#[metric(new(workers: usize))]
|
||||
pub struct ThreadPoolMetrics {
|
||||
pub injector_queue_depth: Gauge,
|
||||
#[metric(init = GaugeVec::with_label_set(ThreadPoolWorkers(workers)))]
|
||||
pub worker_queue_depth: GaugeVec<ThreadPoolWorkers>,
|
||||
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
|
||||
pub worker_task_turns_total: CounterVec<ThreadPoolWorkers>,
|
||||
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
|
||||
pub worker_task_skips_total: CounterVec<ThreadPoolWorkers>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::EndpointId;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn hash_is_correct() {
|
||||
let pool = ThreadPool::new(1);
|
||||
|
||||
let ep = EndpointId::from("foo");
|
||||
let ep = EndpointIdInt::from(ep);
|
||||
let ep = "foo";
|
||||
|
||||
let salt = [0x55; 32];
|
||||
let actual = pool
|
||||
Reference in New Issue
Block a user