proxy: experiment with idea to split crates

This commit is contained in:
Conrad Ludgate
2024-08-13 14:16:50 +01:00
parent a2968c6cf8
commit b62e7c0138
36 changed files with 549 additions and 264 deletions

View File

@@ -38,7 +38,7 @@ pub enum AuthErrorImpl {
/// SASL protocol errors (includes [SCRAM](crate::scram)).
#[error(transparent)]
Sasl(#[from] crate::sasl::Error),
Sasl(#[from] proxy_sasl::sasl::Error),
#[error("Unsupported authentication method: {0}")]
BadAuthMethod(Box<str>),
@@ -148,3 +148,28 @@ impl ReportableError for AuthError {
}
}
}
impl UserFacingError for proxy_sasl::sasl::Error {
fn to_string_client(&self) -> String {
match self {
proxy_sasl::sasl::Error::ChannelBindingFailed(m) => m.to_string(),
proxy_sasl::sasl::Error::ChannelBindingBadMethod(m) => {
format!("unsupported channel binding method {m}")
}
_ => "authentication protocol violation".to_string(),
}
}
}
impl ReportableError for proxy_sasl::sasl::Error {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
proxy_sasl::sasl::Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User,
proxy_sasl::sasl::Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
proxy_sasl::sasl::Error::BadClientMessage(_) => crate::error::ErrorKind::User,
proxy_sasl::sasl::Error::MissingBinding => crate::error::ErrorKind::Service,
proxy_sasl::sasl::Error::Base64(_) => crate::error::ErrorKind::ControlPlane,
proxy_sasl::sasl::Error::Io(_) => crate::error::ErrorKind::ClientDisconnect,
}
}
}

View File

@@ -9,6 +9,7 @@ use std::time::Duration;
use ipnet::{Ipv4Net, Ipv6Net};
pub use link::LinkAuthError;
use proxy_sasl::scram;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::config::AuthKeys;
use tracing::{info, warn};
@@ -36,7 +37,7 @@ use crate::{
},
stream, url,
};
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
use crate::{EndpointCacheKey, EndpointId, RoleName};
/// Alternative to [`std::borrow::Cow`] but doesn't need `T: ToOwned` as we don't need that functionality
pub enum MaybeOwned<'a, T> {
@@ -371,8 +372,8 @@ async fn authenticate_with_secret(
let auth_outcome =
validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
let keys = match auth_outcome {
crate::sasl::Outcome::Success(key) => key,
crate::sasl::Outcome::Failure(reason) => {
proxy_sasl::sasl::Outcome::Success(key) => key,
proxy_sasl::sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
return Err(auth::AuthError::auth_failed(&*info.user));
}
@@ -558,9 +559,9 @@ mod tests {
context::RequestMonitoring,
proxy::NeonOptions,
rate_limiter::{EndpointRateLimiter, RateBucketInfo},
scram::{threadpool::ThreadPool, ServerSecret},
stream::{PqStream, Stream},
};
use proxy_sasl::scram::{threadpool::ThreadPool, ServerSecret};
use super::{auth_quirks, AuthRateLimiter};
@@ -669,7 +670,11 @@ mod tests {
let ctx = RequestMonitoring::test();
let api = Auth {
ips: vec![],
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
secret: AuthSecret::Scram(
ServerSecret::build_test_secret("my-secret-password")
.await
.unwrap(),
),
};
let user_info = ComputeUserInfoMaybeEndpoint {
@@ -746,7 +751,11 @@ mod tests {
let ctx = RequestMonitoring::test();
let api = Auth {
ips: vec![],
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
secret: AuthSecret::Scram(
ServerSecret::build_test_secret("my-secret-password")
.await
.unwrap(),
),
};
let user_info = ComputeUserInfoMaybeEndpoint {
@@ -798,7 +807,11 @@ mod tests {
let ctx = RequestMonitoring::test();
let api = Auth {
ips: vec![],
secret: AuthSecret::Scram(ServerSecret::build("my-secret-password").await.unwrap()),
secret: AuthSecret::Scram(
ServerSecret::build_test_secret("my-secret-password")
.await
.unwrap(),
),
};
let user_info = ComputeUserInfoMaybeEndpoint {

View File

@@ -5,9 +5,9 @@ use crate::{
config::AuthenticationConfig,
console::AuthSecret,
context::RequestMonitoring,
sasl,
stream::{PqStream, Stream},
};
use proxy_sasl::sasl;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};

View File

@@ -7,9 +7,9 @@ use crate::{
console::AuthSecret,
context::RequestMonitoring,
intern::EndpointIdInt,
sasl,
stream::{self, Stream},
};
use proxy_sasl::sasl;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};

View File

@@ -2,16 +2,17 @@
use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
use crate::{
config::TlsServerEndPoint,
console::AuthSecret,
context::RequestMonitoring,
intern::EndpointIdInt,
sasl,
scram::{self, threadpool::ThreadPool},
stream::{PqStream, Stream},
};
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use proxy_sasl::{
sasl,
scram::{self, threadpool::ThreadPool, TlsServerEndPoint},
};
use std::{io, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
@@ -56,7 +57,7 @@ impl AuthMethod for PasswordHack {
/// Use clear-text password auth called `password` in docs
/// <https://www.postgresql.org/docs/current/auth-password.html>
pub struct CleartextPassword {
pub pool: Arc<ThreadPool>,
pub pool: Arc<ThreadPool<EndpointIdInt>>,
pub endpoint: EndpointIdInt,
pub secret: AuthSecret,
}
@@ -174,7 +175,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
}
info!("client chooses {}", sasl.method);
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
let outcome = sasl::SaslStream::new(&mut self.stream.framed, sasl.message)
.authenticate(scram::Exchange::new(
secret,
rand::random,
@@ -191,7 +192,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
}
pub(crate) async fn validate_password_and_exchange(
pool: &ThreadPool,
pool: &ThreadPool<EndpointIdInt>,
endpoint: EndpointIdInt,
password: &[u8],
secret: AuthSecret,
@@ -206,7 +207,8 @@ pub(crate) async fn validate_password_and_exchange(
}
// perform scram authentication as both client and server to validate the keys
AuthSecret::Scram(scram_secret) => {
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
let outcome =
proxy_sasl::scram::exchange(pool, endpoint, &scram_secret, password).await?;
let client_key = match outcome {
sasl::Outcome::Success(client_key) => client_key,

View File

@@ -7,10 +7,11 @@ use std::{net::SocketAddr, sync::Arc};
use futures::future::Either;
use itertools::Itertools;
use proxy::config::TlsServerEndPoint;
use proxy::context::RequestMonitoring;
use proxy::metrics::{Metrics, ThreadPoolMetrics};
use proxy::metrics::Metrics;
use proxy::proxy::{copy_bidirectional_client_compute, run_until_cancelled, ErrorSource};
use proxy_sasl::scram::threadpool::ThreadPoolMetrics;
use proxy_sasl::scram::TlsServerEndPoint;
use rustls::pki_types::PrivateKeyDer;
use tokio::net::TcpListener;

View File

@@ -30,7 +30,6 @@ use proxy::redis::cancellation_publisher::RedisPublisherClient;
use proxy::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use proxy::redis::elasticache;
use proxy::redis::notifications;
use proxy::scram::threadpool::ThreadPool;
use proxy::serverless::cancel_set::CancelSet;
use proxy::serverless::GlobalConnPoolOptions;
use proxy::usage_metrics;
@@ -38,6 +37,7 @@ use proxy::usage_metrics;
use anyhow::bail;
use proxy::config::{self, ProxyConfig};
use proxy::serverless;
use proxy_sasl::scram::threadpool::ThreadPool;
use remote_storage::RemoteStorageConfig;
use std::net::SocketAddr;
use std::pin::pin;
@@ -607,7 +607,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
timeout,
epoch,
&Metrics::get().wake_compute_lock,
)?));
)));
tokio::spawn(locks.garbage_collect_worker());
let url = args.auth_endpoint.parse()?;
@@ -658,7 +658,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
timeout,
epoch,
&Metrics::get().proxy.connect_compute_lock,
)?;
);
let http_config = HttpConfig {
pool_options: GlobalConnPoolOptions {

View File

@@ -371,7 +371,8 @@ impl Cache for ProjectInfoCacheImpl {
#[cfg(test)]
mod tests {
use super::*;
use crate::{scram::ServerSecret, ProjectId};
use crate::ProjectId;
use proxy_sasl::scram::ServerSecret;
#[tokio::test]
async fn test_project_info_cache_settings() {

View File

@@ -1,27 +1,26 @@
use crate::{
auth::{self, backend::AuthRateLimiter},
console::locks::ApiLocks,
intern::EndpointIdInt,
rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig},
scram::threadpool::ThreadPool,
serverless::{cancel_set::CancelSet, GlobalConnPoolOptions},
Host,
};
use anyhow::{bail, ensure, Context, Ok};
use itertools::Itertools;
use proxy_sasl::scram::{threadpool::ThreadPool, TlsServerEndPoint};
use remote_storage::RemoteStorageConfig;
use rustls::{
crypto::ring::sign,
pki_types::{CertificateDer, PrivateKeyDer},
};
use sha2::{Digest, Sha256};
use std::{
collections::{HashMap, HashSet},
str::FromStr,
sync::Arc,
time::Duration,
};
use tracing::{error, info};
use x509_parser::oid_registry;
pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
@@ -58,7 +57,7 @@ pub struct HttpConfig {
}
pub struct AuthenticationConfig {
pub thread_pool: Arc<ThreadPool>,
pub thread_pool: Arc<ThreadPool<EndpointIdInt>>,
pub scram_protocol_timeout: tokio::time::Duration,
pub rate_limiter_enabled: bool,
pub rate_limiter: AuthRateLimiter,
@@ -126,66 +125,6 @@ pub fn configure_tls(
})
}
/// Channel binding parameter
///
/// <https://www.rfc-editor.org/rfc/rfc5929#section-4>
/// Description: The hash of the TLS server's certificate as it
/// appears, octet for octet, in the server's Certificate message. Note
/// that the Certificate message contains a certificate_list, in which
/// the first element is the server's certificate.
///
/// The hash function is to be selected as follows:
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function, and that hash function is either MD5 or SHA-1, then use SHA-256;
///
/// * if the certificate's signatureAlgorithm uses a single hash
/// function and that hash function neither MD5 nor SHA-1, then use
/// the hash function associated with the certificate's
/// signatureAlgorithm;
///
/// * if the certificate's signatureAlgorithm uses no hash functions or
/// uses multiple hash functions, then this channel binding type's
/// channel bindings are undefined at this time (updates to is channel
/// binding type may occur to address this issue if it ever arises).
#[derive(Debug, Clone, Copy)]
pub enum TlsServerEndPoint {
Sha256([u8; 32]),
Undefined,
}
impl TlsServerEndPoint {
pub fn new(cert: &CertificateDer) -> anyhow::Result<Self> {
let sha256_oids = [
// I'm explicitly not adding MD5 or SHA1 here... They're bad.
oid_registry::OID_SIG_ECDSA_WITH_SHA256,
oid_registry::OID_PKCS1_SHA256WITHRSA,
];
let pem = x509_parser::parse_x509_certificate(cert)
.context("Failed to parse PEM object from cerficiate")?
.1;
info!(subject = %pem.subject, "parsing TLS certificate");
let reg = oid_registry::OidRegistry::default().with_all_crypto();
let oid = pem.signature_algorithm.oid();
let alg = reg.get(oid);
if sha256_oids.contains(oid) {
let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into();
info!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding");
Ok(Self::Sha256(tls_server_end_point))
} else {
error!(subject = %pem.subject, signature_algorithm = alg.map(|a| a.description()), "unknown channel binding");
Ok(Self::Undefined)
}
}
pub fn supported(&self) -> bool {
!matches!(self, TlsServerEndPoint::Undefined)
}
}
#[derive(Default, Debug)]
pub struct CertResolver {
certs: HashMap<String, (Arc<rustls::sign::CertifiedKey>, TlsServerEndPoint)>,

View File

@@ -16,9 +16,10 @@ use crate::{
intern::ProjectIdInt,
metrics::ApiLockMetrics,
rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token},
scram, EndpointCacheKey,
EndpointCacheKey,
};
use dashmap::DashMap;
use proxy_sasl::scram;
use std::{hash::Hash, sync::Arc, time::Duration};
use tokio::time::Instant;
use tracing::info;
@@ -469,15 +470,15 @@ impl<K: Hash + Eq + Clone> ApiLocks<K> {
timeout: Duration,
epoch: std::time::Duration,
metrics: &'static ApiLockMetrics,
) -> prometheus::Result<Self> {
Ok(Self {
) -> Self {
Self {
name,
node_locks: DashMap::with_shard_amount(shards),
config,
timeout,
epoch,
metrics,
})
}
}
pub async fn get_permit(&self, key: &K) -> Result<WakeComputePermit, ApiLockError> {

View File

@@ -5,7 +5,7 @@ use super::{
AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
};
use crate::context::RequestMonitoring;
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, url::ApiUrl};
use crate::{auth::IpPattern, cache::Cached};
use crate::{
console::{
@@ -15,6 +15,7 @@ use crate::{
BranchId, EndpointId, ProjectId,
};
use futures::TryFutureExt;
use proxy_sasl::scram;
use std::{str::FromStr, sync::Arc};
use thiserror::Error;
use tokio_postgres::{config::SslMode, Client};

View File

@@ -13,10 +13,11 @@ use crate::{
http,
metrics::{CacheOutcome, Metrics},
rate_limiter::WakeComputeRateLimiter,
scram, EndpointCacheKey,
EndpointCacheKey,
};
use crate::{cache::Cached, context::RequestMonitoring};
use futures::TryFutureExt;
use proxy_sasl::scram;
use std::{sync::Arc, time::Duration};
use tokio::time::Instant;
use tokio_postgres::config::SslMode;

View File

@@ -21,13 +21,13 @@ pub mod intern;
pub mod jemalloc;
pub mod logging;
pub mod metrics;
pub mod parse;
// pub mod parse;
pub mod protocol2;
pub mod proxy;
pub mod rate_limiter;
pub mod redis;
pub mod sasl;
pub mod scram;
// pub mod sasl;
// pub mod scram;
pub mod serverless;
pub mod stream;
pub mod url;

View File

@@ -2,13 +2,14 @@ use std::sync::{Arc, OnceLock};
use lasso::ThreadedRodeo;
use measured::{
label::{FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue, StaticLabelSet},
label::StaticLabelSet,
metric::{histogram::Thresholds, name::MetricName},
Counter, CounterVec, FixedCardinalityLabel, Gauge, GaugeVec, Histogram, HistogramVec,
LabelGroup, MetricGroup,
Counter, CounterVec, FixedCardinalityLabel, Gauge, Histogram, HistogramVec, LabelGroup,
MetricGroup,
};
use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLog, HyperLogLogVec};
use proxy_sasl::scram::threadpool::ThreadPoolMetrics;
use tokio::time::{self, Instant};
use crate::console::messages::ColdStartInfo;
@@ -546,78 +547,3 @@ pub enum RedisEventsCount {
PasswordUpdate,
AllowedIpsUpdate,
}
pub struct ThreadPoolWorkers(usize);
pub struct ThreadPoolWorkerId(pub usize);
impl LabelValue for ThreadPoolWorkerId {
fn visit<V: measured::label::LabelVisitor>(&self, v: V) -> V::Output {
v.write_int(self.0 as i64)
}
}
impl LabelGroup for ThreadPoolWorkerId {
fn visit_values(&self, v: &mut impl measured::label::LabelGroupVisitor) {
v.write_value(LabelName::from_str("worker"), self);
}
}
impl LabelGroupSet for ThreadPoolWorkers {
type Group<'a> = ThreadPoolWorkerId;
fn cardinality(&self) -> Option<usize> {
Some(self.0)
}
fn encode_dense(&self, value: Self::Unique) -> Option<usize> {
Some(value)
}
fn decode_dense(&self, value: usize) -> Self::Group<'_> {
ThreadPoolWorkerId(value)
}
type Unique = usize;
fn encode(&self, value: Self::Group<'_>) -> Option<Self::Unique> {
Some(value.0)
}
fn decode(&self, value: &Self::Unique) -> Self::Group<'_> {
ThreadPoolWorkerId(*value)
}
}
impl LabelSet for ThreadPoolWorkers {
type Value<'a> = ThreadPoolWorkerId;
fn dynamic_cardinality(&self) -> Option<usize> {
Some(self.0)
}
fn encode(&self, value: Self::Value<'_>) -> Option<usize> {
(value.0 < self.0).then_some(value.0)
}
fn decode(&self, value: usize) -> Self::Value<'_> {
ThreadPoolWorkerId(value)
}
}
impl FixedCardinalitySet for ThreadPoolWorkers {
fn cardinality(&self) -> usize {
self.0
}
}
#[derive(MetricGroup)]
#[metric(new(workers: usize))]
pub struct ThreadPoolMetrics {
pub injector_queue_depth: Gauge,
#[metric(init = GaugeVec::with_label_set(ThreadPoolWorkers(workers)))]
pub worker_queue_depth: GaugeVec<ThreadPoolWorkers>,
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
pub worker_task_turns_total: CounterVec<ThreadPoolWorkers>,
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
pub worker_task_skips_total: CounterVec<ThreadPoolWorkers>,
}

View File

@@ -16,9 +16,10 @@ use crate::console::messages::{ConsoleError, Details, MetricsAuxInfo, Status};
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend};
use crate::console::{self, CachedNodeInfo, NodeInfo};
use crate::error::ErrorKind;
use crate::{http, sasl, scram, BranchId, EndpointId, ProjectId};
use crate::{http, BranchId, EndpointId, ProjectId};
use anyhow::{bail, Context};
use async_trait::async_trait;
use proxy_sasl::{sasl, scram};
use retry::{retry_after, ShouldRetryWakeCompute};
use rstest::rstest;
use rustls::pki_types;
@@ -137,7 +138,7 @@ struct Scram(scram::ServerSecret);
impl Scram {
async fn new(password: &str) -> anyhow::Result<Self> {
let secret = scram::ServerSecret::build(password)
let secret = scram::ServerSecret::build_test_secret(password)
.await
.context("failed to generate scram secret")?;
Ok(Scram(secret))

View File

@@ -1,89 +0,0 @@
//! Simple Authentication and Security Layer.
//!
//! RFC: <https://datatracker.ietf.org/doc/html/rfc4422>.
//!
//! Reference implementation:
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-sasl.c>
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth.c>
mod channel_binding;
mod messages;
mod stream;
use crate::error::{ReportableError, UserFacingError};
use std::io;
use thiserror::Error;
pub use channel_binding::ChannelBinding;
pub use messages::FirstMessage;
pub use stream::{Outcome, SaslStream};
/// Fine-grained auth errors help in writing tests.
#[derive(Error, Debug)]
pub enum Error {
#[error("Channel binding failed: {0}")]
ChannelBindingFailed(&'static str),
#[error("Unsupported channel binding method: {0}")]
ChannelBindingBadMethod(Box<str>),
#[error("Bad client message: {0}")]
BadClientMessage(&'static str),
#[error("Internal error: missing digest")]
MissingBinding,
#[error("could not decode salt: {0}")]
Base64(#[from] base64::DecodeError),
#[error(transparent)]
Io(#[from] io::Error),
}
impl UserFacingError for Error {
fn to_string_client(&self) -> String {
use Error::*;
match self {
ChannelBindingFailed(m) => m.to_string(),
ChannelBindingBadMethod(m) => format!("unsupported channel binding method {m}"),
_ => "authentication protocol violation".to_string(),
}
}
}
impl ReportableError for Error {
fn get_error_kind(&self) -> crate::error::ErrorKind {
match self {
Error::ChannelBindingFailed(_) => crate::error::ErrorKind::User,
Error::ChannelBindingBadMethod(_) => crate::error::ErrorKind::User,
Error::BadClientMessage(_) => crate::error::ErrorKind::User,
Error::MissingBinding => crate::error::ErrorKind::Service,
Error::Base64(_) => crate::error::ErrorKind::ControlPlane,
Error::Io(_) => crate::error::ErrorKind::ClientDisconnect,
}
}
}
/// A convenient result type for SASL exchange.
pub type Result<T> = std::result::Result<T, Error>;
/// A result of one SASL exchange.
#[must_use]
pub enum Step<T, R> {
/// We should continue exchanging messages.
Continue(T, String),
/// The client has been authenticated successfully.
Success(R, String),
/// Authentication failed (reason attached).
Failure(&'static str),
}
/// Every SASL mechanism (e.g. [SCRAM](crate::scram)) is expected to implement this trait.
pub trait Mechanism: Sized {
/// What's produced as a result of successful authentication.
type Output;
/// Produce a server challenge to be sent to the client.
/// This is how this method is called in PostgreSQL (`libpq/sasl.h`).
fn exchange(self, input: &str) -> Result<Step<Self, Self::Output>>;
}

View File

@@ -1,84 +0,0 @@
//! Definition and parser for channel binding flag (a part of the `GS2` header).
/// Channel binding flag (possibly with params).
#[derive(Debug, PartialEq, Eq)]
pub enum ChannelBinding<T> {
/// Client doesn't support channel binding.
NotSupportedClient,
/// Client thinks server doesn't support channel binding.
NotSupportedServer,
/// Client wants to use this type of channel binding.
Required(T),
}
impl<T> ChannelBinding<T> {
pub fn and_then<R, E>(self, f: impl FnOnce(T) -> Result<R, E>) -> Result<ChannelBinding<R>, E> {
use ChannelBinding::*;
Ok(match self {
NotSupportedClient => NotSupportedClient,
NotSupportedServer => NotSupportedServer,
Required(x) => Required(f(x)?),
})
}
}
impl<'a> ChannelBinding<&'a str> {
// NB: FromStr doesn't work with lifetimes
pub fn parse(input: &'a str) -> Option<Self> {
use ChannelBinding::*;
Some(match input {
"n" => NotSupportedClient,
"y" => NotSupportedServer,
other => Required(other.strip_prefix("p=")?),
})
}
}
impl<T: std::fmt::Display> ChannelBinding<T> {
/// Encode channel binding data as base64 for subsequent checks.
pub fn encode<'a, E>(
&self,
get_cbind_data: impl FnOnce(&T) -> Result<&'a [u8], E>,
) -> Result<std::borrow::Cow<'static, str>, E> {
use ChannelBinding::*;
Ok(match self {
NotSupportedClient => {
// base64::encode("n,,")
"biws".into()
}
NotSupportedServer => {
// base64::encode("y,,")
"eSws".into()
}
Required(mode) => {
use std::io::Write;
let mut cbind_input = vec![];
write!(&mut cbind_input, "p={mode},,",).unwrap();
cbind_input.extend_from_slice(get_cbind_data(mode)?);
base64::encode(&cbind_input).into()
}
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn channel_binding_encode() -> anyhow::Result<()> {
use ChannelBinding::*;
let cases = [
(NotSupportedClient, base64::encode("n,,")),
(NotSupportedServer, base64::encode("y,,")),
(Required("foo"), base64::encode("p=foo,,bar")),
];
for (cb, input) in cases {
assert_eq!(cb.encode(|_| anyhow::Ok(b"bar"))?, input);
}
Ok(())
}
}

View File

@@ -1,68 +0,0 @@
//! Definitions for SASL messages.
use crate::parse::{split_at_const, split_cstr};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage};
/// SASL-specific payload of [`PasswordMessage`](pq_proto::FeMessage::PasswordMessage).
#[derive(Debug)]
pub struct FirstMessage<'a> {
/// Authentication method, e.g. `"SCRAM-SHA-256"`.
pub method: &'a str,
/// Initial client message.
pub message: &'a str,
}
impl<'a> FirstMessage<'a> {
// NB: FromStr doesn't work with lifetimes
pub fn parse(bytes: &'a [u8]) -> Option<Self> {
let (method_cstr, tail) = split_cstr(bytes)?;
let method = method_cstr.to_str().ok()?;
let (len_bytes, bytes) = split_at_const(tail)?;
let len = u32::from_be_bytes(*len_bytes) as usize;
if len != bytes.len() {
return None;
}
let message = std::str::from_utf8(bytes).ok()?;
Some(Self { method, message })
}
}
/// A single SASL message.
/// This struct is deliberately decoupled from lower-level
/// [`BeAuthenticationSaslMessage`].
#[derive(Debug)]
pub(super) enum ServerMessage<T> {
/// We expect to see more steps.
Continue(T),
/// This is the final step.
Final(T),
}
impl<'a> ServerMessage<&'a str> {
pub(super) fn to_reply(&self) -> BeMessage<'a> {
use BeAuthenticationSaslMessage::*;
BeMessage::AuthenticationSasl(match self {
ServerMessage::Continue(s) => Continue(s.as_bytes()),
ServerMessage::Final(s) => Final(s.as_bytes()),
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_sasl_first_message() {
let proto = "SCRAM-SHA-256";
let sasl = "n,,n=,r=KHQ2Gjc7NptyB8aov5/TnUy4";
let sasl_len = (sasl.len() as u32).to_be_bytes();
let bytes = [proto.as_bytes(), &[0], sasl_len.as_ref(), sasl.as_bytes()].concat();
let password = FirstMessage::parse(&bytes).unwrap();
assert_eq!(password.method, proto);
assert_eq!(password.message, sasl);
}
}

View File

@@ -1,92 +0,0 @@
//! Abstraction for the string-oriented SASL protocols.
use super::{messages::ServerMessage, Mechanism};
use crate::stream::PqStream;
use std::io;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
/// Abstracts away all peculiarities of the libpq's protocol.
pub struct SaslStream<'a, S> {
/// The underlying stream.
stream: &'a mut PqStream<S>,
/// Current password message we received from client.
current: bytes::Bytes,
/// First SASL message produced by client.
first: Option<&'a str>,
}
impl<'a, S> SaslStream<'a, S> {
pub fn new(stream: &'a mut PqStream<S>, first: &'a str) -> Self {
Self {
stream,
current: bytes::Bytes::new(),
first: Some(first),
}
}
}
impl<S: AsyncRead + Unpin> SaslStream<'_, S> {
// Receive a new SASL message from the client.
async fn recv(&mut self) -> io::Result<&str> {
if let Some(first) = self.first.take() {
return Ok(first);
}
self.current = self.stream.read_password_message().await?;
let s = std::str::from_utf8(&self.current)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
Ok(s)
}
}
impl<S: AsyncWrite + Unpin> SaslStream<'_, S> {
// Send a SASL message to the client.
async fn send(&mut self, msg: &ServerMessage<&str>) -> io::Result<()> {
self.stream.write_message(&msg.to_reply()).await?;
Ok(())
}
}
/// SASL authentication outcome.
/// It's much easier to match on those two variants
/// than to peek into a noisy protocol error type.
#[must_use = "caller must explicitly check for success"]
pub enum Outcome<R> {
/// Authentication succeeded and produced some value.
Success(R),
/// Authentication failed (reason attached).
Failure(&'static str),
}
impl<S: AsyncRead + AsyncWrite + Unpin> SaslStream<'_, S> {
/// Perform SASL message exchange according to the underlying algorithm
/// until user is either authenticated or denied access.
pub async fn authenticate<M: Mechanism>(
mut self,
mut mechanism: M,
) -> super::Result<Outcome<M::Output>> {
loop {
let input = self.recv().await?;
let step = mechanism.exchange(input).map_err(|error| {
info!(?error, "error during SASL exchange");
error
})?;
use super::Step;
return Ok(match step {
Step::Continue(moved_mechanism, reply) => {
self.send(&ServerMessage::Continue(&reply)).await?;
mechanism = moved_mechanism;
continue;
}
Step::Success(result, reply) => {
self.send(&ServerMessage::Final(&reply)).await?;
Outcome::Success(result)
}
Step::Failure(reason) => Outcome::Failure(reason),
});
}
}
}

View File

@@ -1,148 +0,0 @@
//! Salted Challenge Response Authentication Mechanism.
//!
//! RFC: <https://datatracker.ietf.org/doc/html/rfc5802>.
//!
//! Reference implementation:
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
mod countmin;
mod exchange;
mod key;
mod messages;
mod pbkdf2;
mod secret;
mod signature;
pub mod threadpool;
pub use exchange::{exchange, Exchange};
pub use key::ScramKey;
pub use secret::ServerSecret;
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
/// A list of supported SCRAM methods.
pub const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256];
pub const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256];
/// Decode base64 into array without any heap allocations
fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N]> {
let mut bytes = [0u8; N];
let size = base64::decode_config_slice(input, base64::STANDARD, &mut bytes).ok()?;
if size != N {
return None;
}
Some(bytes)
}
/// This function essentially is `Hmac(sha256, key, input)`.
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2104>.
fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("bad key size");
parts.into_iter().for_each(|s| mac.update(s));
mac.finalize().into_bytes().into()
}
fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
let mut hasher = Sha256::new();
parts.into_iter().for_each(|s| hasher.update(s));
hasher.finalize().into()
}
#[cfg(test)]
mod tests {
use crate::{
intern::EndpointIdInt,
sasl::{Mechanism, Step},
EndpointId,
};
use super::{threadpool::ThreadPool, Exchange, ServerSecret};
#[test]
fn snapshot() {
let iterations = 4096;
let salt = "QSXCR+Q6sek8bf92";
let stored_key = "FO+9jBb3MUukt6jJnzjPZOWc5ow/Pu6JtPyju0aqaE8=";
let server_key = "qxJ1SbmSAi5EcS0J5Ck/cKAm/+Ixa+Kwp63f4OHDgzo=";
let secret = format!("SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}",);
let secret = ServerSecret::parse(&secret).unwrap();
const NONCE: [u8; 18] = [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
];
let mut exchange = Exchange::new(
&secret,
|| NONCE,
crate::config::TlsServerEndPoint::Undefined,
);
let client_first = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO";
let client_final = "c=biws,r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,p=rw1r5Kph5ThxmaUBC2GAQ6MfXbPnNkFiTIvdb/Rear0=";
let server_first =
"r=rOprNGfwEbeRWgbNEkqOAQIDBAUGBwgJCgsMDQ4PEBES,s=QSXCR+Q6sek8bf92,i=4096";
let server_final = "v=qtUDIofVnIhM7tKn93EQUUt5vgMOldcDVu1HC+OH0o0=";
exchange = match exchange.exchange(client_first).unwrap() {
Step::Continue(exchange, message) => {
assert_eq!(message, server_first);
exchange
}
Step::Success(_, _) => panic!("expected continue, got success"),
Step::Failure(f) => panic!("{f}"),
};
let key = match exchange.exchange(client_final).unwrap() {
Step::Success(key, message) => {
assert_eq!(message, server_final);
key
}
Step::Continue(_, _) => panic!("expected success, got continue"),
Step::Failure(f) => panic!("{f}"),
};
assert_eq!(
key.as_bytes(),
[
74, 103, 1, 132, 12, 31, 200, 48, 28, 54, 82, 232, 207, 12, 138, 189, 40, 32, 134,
27, 125, 170, 232, 35, 171, 167, 166, 41, 70, 228, 182, 112,
]
);
}
async fn run_round_trip_test(server_password: &str, client_password: &str) {
let pool = ThreadPool::new(1);
let ep = EndpointId::from("foo");
let ep = EndpointIdInt::from(ep);
let scram_secret = ServerSecret::build(server_password).await.unwrap();
let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes())
.await
.unwrap();
match outcome {
crate::sasl::Outcome::Success(_) => {}
crate::sasl::Outcome::Failure(r) => panic!("{r}"),
}
}
#[tokio::test]
async fn round_trip() {
run_round_trip_test("pencil", "pencil").await
}
#[tokio::test]
#[should_panic(expected = "password doesn't match")]
async fn failure() {
run_round_trip_test("pencil", "eraser").await
}
}

View File

@@ -1,173 +0,0 @@
use std::hash::Hash;
/// estimator of hash jobs per second.
/// <https://en.wikipedia.org/wiki/Count%E2%80%93min_sketch>
pub struct CountMinSketch {
// one for each depth
hashers: Vec<ahash::RandomState>,
width: usize,
depth: usize,
// buckets, width*depth
buckets: Vec<u32>,
}
impl CountMinSketch {
/// Given parameters (ε, δ),
/// set width = ceil(e/ε)
/// set depth = ceil(ln(1/δ))
///
/// guarantees:
/// actual <= estimate
/// estimate <= actual + ε * N with probability 1 - δ
/// where N is the cardinality of the stream
pub fn with_params(epsilon: f64, delta: f64) -> Self {
CountMinSketch::new(
(std::f64::consts::E / epsilon).ceil() as usize,
(1.0_f64 / delta).ln().ceil() as usize,
)
}
fn new(width: usize, depth: usize) -> Self {
Self {
#[cfg(test)]
hashers: (0..depth)
.map(|i| {
// digits of pi for good randomness
ahash::RandomState::with_seeds(
314159265358979323,
84626433832795028,
84197169399375105,
82097494459230781 + i as u64,
)
})
.collect(),
#[cfg(not(test))]
hashers: (0..depth).map(|_| ahash::RandomState::new()).collect(),
width,
depth,
buckets: vec![0; width * depth],
}
}
pub fn inc_and_return<T: Hash>(&mut self, t: &T, x: u32) -> u32 {
let mut min = u32::MAX;
for row in 0..self.depth {
let col = (self.hashers[row].hash_one(t) as usize) % self.width;
let row = &mut self.buckets[row * self.width..][..self.width];
row[col] = row[col].saturating_add(x);
min = std::cmp::min(min, row[col]);
}
min
}
pub fn reset(&mut self) {
self.buckets.clear();
self.buckets.resize(self.width * self.depth, 0);
}
}
#[cfg(test)]
mod tests {
use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng};
use super::CountMinSketch;
fn eval_precision(n: usize, p: f64, q: f64) -> usize {
// fixed value of phi for consistent test
let mut rng = StdRng::seed_from_u64(16180339887498948482);
#[allow(non_snake_case)]
let mut N = 0;
let mut ids = vec![];
for _ in 0..n {
// number of insert operations
let n = rng.gen_range(1..100);
// number to insert at once
let m = rng.gen_range(1..4096);
let id = uuid::Builder::from_random_bytes(rng.gen()).into_uuid();
ids.push((id, n, m));
// N = sum(actual)
N += n * m;
}
// q% of counts will be within p of the actual value
let mut sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);
dbg!(sketch.buckets.len());
// insert a bunch of entries in a random order
let mut ids2 = ids.clone();
while !ids2.is_empty() {
ids2.shuffle(&mut rng);
let mut i = 0;
while i < ids2.len() {
sketch.inc_and_return(&ids2[i].0, ids2[i].1);
ids2[i].2 -= 1;
if ids2[i].2 == 0 {
ids2.remove(i);
} else {
i += 1;
}
}
}
let mut within_p = 0;
for (id, n, m) in ids {
let actual = n * m;
let estimate = sketch.inc_and_return(&id, 0);
// This estimate has the guarantee that actual <= estimate
assert!(actual <= estimate);
// This estimate has the guarantee that estimate <= actual + εN with probability 1 - δ.
// ε = p / N, δ = 1 - q;
// therefore, estimate <= actual + p with probability q.
if estimate as f64 <= actual as f64 + p {
within_p += 1;
}
}
within_p
}
#[test]
fn precision() {
assert_eq!(eval_precision(100, 100.0, 0.99), 100);
assert_eq!(eval_precision(1000, 100.0, 0.99), 1000);
assert_eq!(eval_precision(100, 4096.0, 0.99), 100);
assert_eq!(eval_precision(1000, 4096.0, 0.99), 1000);
// seems to be more precise than the literature indicates?
// probably numbers are too small to truly represent the probabilities.
assert_eq!(eval_precision(100, 4096.0, 0.90), 100);
assert_eq!(eval_precision(1000, 4096.0, 0.90), 1000);
assert_eq!(eval_precision(100, 4096.0, 0.1), 98);
assert_eq!(eval_precision(1000, 4096.0, 0.1), 991);
}
// returns memory usage in bytes, and the time complexity per insert.
fn eval_cost(p: f64, q: f64) -> (usize, usize) {
#[allow(non_snake_case)]
// N = sum(actual)
// Let's assume 1021 samples, all of 4096
let N = 1021 * 4096;
let sketch = CountMinSketch::with_params(p / N as f64, 1.0 - q);
let memory = size_of::<u32>() * sketch.buckets.len();
let time = sketch.depth;
(memory, time)
}
#[test]
fn memory_usage() {
assert_eq!(eval_cost(100.0, 0.99), (2273580, 5));
assert_eq!(eval_cost(4096.0, 0.99), (55520, 5));
assert_eq!(eval_cost(4096.0, 0.90), (33312, 3));
assert_eq!(eval_cost(4096.0, 0.1), (11104, 1));
}
}

View File

@@ -1,234 +0,0 @@
//! Implementation of the SCRAM authentication algorithm.
use std::convert::Infallible;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use super::messages::{
ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
};
use super::pbkdf2::Pbkdf2;
use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use super::threadpool::ThreadPool;
use super::ScramKey;
use crate::config;
use crate::intern::EndpointIdInt;
use crate::sasl::{self, ChannelBinding, Error as SaslError};
/// The only channel binding mode we currently support.
#[derive(Debug)]
struct TlsServerEndPoint;
impl std::fmt::Display for TlsServerEndPoint {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "tls-server-end-point")
}
}
impl std::str::FromStr for TlsServerEndPoint {
type Err = sasl::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"tls-server-end-point" => Ok(TlsServerEndPoint),
_ => Err(sasl::Error::ChannelBindingBadMethod(s.into())),
}
}
}
struct SaslSentInner {
cbind_flag: ChannelBinding<TlsServerEndPoint>,
client_first_message_bare: String,
server_first_message: OwnedServerFirstMessage,
}
struct SaslInitial {
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
}
enum ExchangeState {
/// Waiting for [`ClientFirstMessage`].
Initial(SaslInitial),
/// Waiting for [`ClientFinalMessage`].
SaltSent(SaslSentInner),
}
/// Server's side of SCRAM auth algorithm.
pub struct Exchange<'a> {
state: ExchangeState,
secret: &'a ServerSecret,
tls_server_end_point: config::TlsServerEndPoint,
}
impl<'a> Exchange<'a> {
pub fn new(
secret: &'a ServerSecret,
nonce: fn() -> [u8; SCRAM_RAW_NONCE_LEN],
tls_server_end_point: config::TlsServerEndPoint,
) -> Self {
Self {
state: ExchangeState::Initial(SaslInitial { nonce }),
secret,
tls_server_end_point,
}
}
}
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
async fn derive_client_key(
pool: &ThreadPool,
endpoint: EndpointIdInt,
password: &[u8],
salt: &[u8],
iterations: u32,
) -> ScramKey {
let salted_password = pool
.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
.await
.expect("job should not be cancelled");
let make_key = |name| {
let key = Hmac::<Sha256>::new_from_slice(&salted_password)
.expect("HMAC is able to accept all key sizes")
.chain_update(name)
.finalize();
<[u8; 32]>::from(key.into_bytes())
};
make_key(b"Client Key").into()
}
pub async fn exchange(
pool: &ThreadPool,
endpoint: EndpointIdInt,
secret: &ServerSecret,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
let salt = base64::decode(&secret.salt_base64)?;
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
if secret.is_password_invalid(&client_key).into() {
Ok(sasl::Outcome::Failure("password doesn't match"))
} else {
Ok(sasl::Outcome::Success(client_key))
}
}
impl SaslInitial {
fn transition(
&self,
secret: &ServerSecret,
tls_server_end_point: &config::TlsServerEndPoint,
input: &str,
) -> sasl::Result<sasl::Step<SaslSentInner, Infallible>> {
let client_first_message = ClientFirstMessage::parse(input)
.ok_or(SaslError::BadClientMessage("invalid client-first-message"))?;
// If the flag is set to "y" and the server supports channel
// binding, the server MUST fail authentication
if client_first_message.cbind_flag == ChannelBinding::NotSupportedServer
&& tls_server_end_point.supported()
{
return Err(SaslError::ChannelBindingFailed("SCRAM-PLUS not used"));
}
let server_first_message = client_first_message.build_server_first_message(
&(self.nonce)(),
&secret.salt_base64,
secret.iterations,
);
let msg = server_first_message.as_str().to_owned();
let next = SaslSentInner {
cbind_flag: client_first_message.cbind_flag.and_then(str::parse)?,
client_first_message_bare: client_first_message.bare.to_owned(),
server_first_message,
};
Ok(sasl::Step::Continue(next, msg))
}
}
impl SaslSentInner {
fn transition(
&self,
secret: &ServerSecret,
tls_server_end_point: &config::TlsServerEndPoint,
input: &str,
) -> sasl::Result<sasl::Step<Infallible, super::ScramKey>> {
let Self {
cbind_flag,
client_first_message_bare,
server_first_message,
} = self;
let client_final_message = ClientFinalMessage::parse(input)
.ok_or(SaslError::BadClientMessage("invalid client-final-message"))?;
let channel_binding = cbind_flag.encode(|_| match tls_server_end_point {
config::TlsServerEndPoint::Sha256(x) => Ok(x),
config::TlsServerEndPoint::Undefined => Err(SaslError::MissingBinding),
})?;
// This might've been caused by a MITM attack
if client_final_message.channel_binding != channel_binding {
return Err(SaslError::ChannelBindingFailed(
"insecure connection: secure channel data mismatch",
));
}
if client_final_message.nonce != server_first_message.nonce() {
return Err(SaslError::BadClientMessage("combined nonce doesn't match"));
}
let signature_builder = SignatureBuilder {
client_first_message_bare,
server_first_message: server_first_message.as_str(),
client_final_message_without_proof: client_final_message.without_proof,
};
let client_key = signature_builder
.build(&secret.stored_key)
.derive_client_key(&client_final_message.proof);
// Auth fails either if keys don't match or it's pre-determined to fail.
if secret.is_password_invalid(&client_key).into() {
return Ok(sasl::Step::Failure("password doesn't match"));
}
let msg =
client_final_message.build_server_final_message(signature_builder, &secret.server_key);
Ok(sasl::Step::Success(client_key, msg))
}
}
impl sasl::Mechanism for Exchange<'_> {
type Output = super::ScramKey;
fn exchange(mut self, input: &str) -> sasl::Result<sasl::Step<Self, Self::Output>> {
use {sasl::Step::*, ExchangeState::*};
match &self.state {
Initial(init) => {
match init.transition(self.secret, &self.tls_server_end_point, input)? {
Continue(sent, msg) => {
self.state = SaltSent(sent);
Ok(Continue(self, msg))
}
Success(x, _) => match x {},
Failure(msg) => Ok(Failure(msg)),
}
}
SaltSent(sent) => {
match sent.transition(self.secret, &self.tls_server_end_point, input)? {
Success(keys, msg) => Ok(Success(keys, msg)),
Continue(x, _) => match x {},
Failure(msg) => Ok(Failure(msg)),
}
}
}
}
}

View File

@@ -1,51 +0,0 @@
//! Tools for client/server/stored key management.
use subtle::ConstantTimeEq;
/// Faithfully taken from PostgreSQL.
pub const SCRAM_KEY_LEN: usize = 32;
/// One of the keys derived from the user's password.
/// We use the same structure for all keys, i.e.
/// `ClientKey`, `StoredKey`, and `ServerKey`.
#[derive(Clone, Default, Eq, Debug)]
#[repr(transparent)]
pub struct ScramKey {
bytes: [u8; SCRAM_KEY_LEN],
}
impl PartialEq for ScramKey {
fn eq(&self, other: &Self) -> bool {
self.ct_eq(other).into()
}
}
impl ConstantTimeEq for ScramKey {
fn ct_eq(&self, other: &Self) -> subtle::Choice {
self.bytes.ct_eq(&other.bytes)
}
}
impl ScramKey {
pub fn sha256(&self) -> Self {
super::sha256([self.as_ref()]).into()
}
pub fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] {
self.bytes
}
}
impl From<[u8; SCRAM_KEY_LEN]> for ScramKey {
#[inline(always)]
fn from(bytes: [u8; SCRAM_KEY_LEN]) -> Self {
Self { bytes }
}
}
impl AsRef<[u8]> for ScramKey {
#[inline(always)]
fn as_ref(&self) -> &[u8] {
&self.bytes
}
}

View File

@@ -1,257 +0,0 @@
//! Definitions for SCRAM messages.
use super::base64_decode_array;
use super::key::{ScramKey, SCRAM_KEY_LEN};
use super::signature::SignatureBuilder;
use crate::sasl::ChannelBinding;
use std::fmt;
use std::ops::Range;
/// Faithfully taken from PostgreSQL.
pub const SCRAM_RAW_NONCE_LEN: usize = 18;
/// Although we ignore all extensions, we still have to validate the message.
fn validate_sasl_extensions<'a>(parts: impl Iterator<Item = &'a str>) -> Option<()> {
for mut chars in parts.map(|s| s.chars()) {
let attr = chars.next()?;
if !attr.is_ascii_alphabetic() {
return None;
}
let eq = chars.next()?;
if eq != '=' {
return None;
}
}
Some(())
}
#[derive(Debug)]
pub struct ClientFirstMessage<'a> {
/// `client-first-message-bare`.
pub bare: &'a str,
/// Channel binding mode.
pub cbind_flag: ChannelBinding<&'a str>,
/// Client nonce.
pub nonce: &'a str,
}
impl<'a> ClientFirstMessage<'a> {
// NB: FromStr doesn't work with lifetimes
pub fn parse(input: &'a str) -> Option<Self> {
let mut parts = input.split(',');
let cbind_flag = ChannelBinding::parse(parts.next()?)?;
// PG doesn't support authorization identity,
// so we don't bother defining GS2 header type
let authzid = parts.next()?;
if !authzid.is_empty() {
return None;
}
// Unfortunately, `parts.as_str()` is unstable
let pos = authzid.as_ptr() as usize - input.as_ptr() as usize + 1;
let (_, bare) = input.split_at(pos);
// In theory, these might be preceded by "reserved-mext" (i.e. "m=")
let username = parts.next()?.strip_prefix("n=")?;
// https://github.com/postgres/postgres/blob/f83908798f78c4cafda217ca875602c88ea2ae28/src/backend/libpq/auth-scram.c#L13-L14
if !username.is_empty() {
tracing::warn!(username, "scram username provided, but is not expected")
// TODO(conrad):
// return None;
}
let nonce = parts.next()?.strip_prefix("r=")?;
// Validate but ignore auth extensions
validate_sasl_extensions(parts)?;
Some(Self {
bare,
cbind_flag,
nonce,
})
}
/// Build a response to [`ClientFirstMessage`].
pub fn build_server_first_message(
&self,
nonce: &[u8; SCRAM_RAW_NONCE_LEN],
salt_base64: &str,
iterations: u32,
) -> OwnedServerFirstMessage {
use std::fmt::Write;
let mut message = String::new();
write!(&mut message, "r={}", self.nonce).unwrap();
base64::encode_config_buf(nonce, base64::STANDARD, &mut message);
let combined_nonce = 2..message.len();
write!(&mut message, ",s={},i={}", salt_base64, iterations).unwrap();
// This design guarantees that it's impossible to create a
// server-first-message without receiving a client-first-message
OwnedServerFirstMessage {
message,
nonce: combined_nonce,
}
}
}
#[derive(Debug)]
pub struct ClientFinalMessage<'a> {
/// `client-final-message-without-proof`.
pub without_proof: &'a str,
/// Channel binding data (base64).
pub channel_binding: &'a str,
/// Combined client & server nonce.
pub nonce: &'a str,
/// Client auth proof.
pub proof: [u8; SCRAM_KEY_LEN],
}
impl<'a> ClientFinalMessage<'a> {
// NB: FromStr doesn't work with lifetimes
pub fn parse(input: &'a str) -> Option<Self> {
let (without_proof, proof) = input.rsplit_once(',')?;
let mut parts = without_proof.split(',');
let channel_binding = parts.next()?.strip_prefix("c=")?;
let nonce = parts.next()?.strip_prefix("r=")?;
// Validate but ignore auth extensions
validate_sasl_extensions(parts)?;
let proof = base64_decode_array(proof.strip_prefix("p=")?)?;
Some(Self {
without_proof,
channel_binding,
nonce,
proof,
})
}
/// Build a response to [`ClientFinalMessage`].
pub fn build_server_final_message(
&self,
signature_builder: SignatureBuilder,
server_key: &ScramKey,
) -> String {
let mut buf = String::from("v=");
base64::encode_config_buf(
signature_builder.build(server_key),
base64::STANDARD,
&mut buf,
);
buf
}
}
/// We need to keep a convenient representation of this
/// message for the next authentication step.
pub struct OwnedServerFirstMessage {
/// Owned `server-first-message`.
message: String,
/// Slice into `message`.
nonce: Range<usize>,
}
impl OwnedServerFirstMessage {
/// Extract combined nonce from the message.
#[inline(always)]
pub fn nonce(&self) -> &str {
&self.message[self.nonce.clone()]
}
/// Get reference to a text representation of the message.
#[inline(always)]
pub fn as_str(&self) -> &str {
&self.message
}
}
impl fmt::Debug for OwnedServerFirstMessage {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ServerFirstMessage")
.field("message", &self.as_str())
.field("nonce", &self.nonce())
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_client_first_message() {
use ChannelBinding::*;
// (Almost) real strings captured during debug sessions
let cases = [
(NotSupportedClient, "n,,n=,r=t8JwklwKecDLwSsA72rHmVju"),
(NotSupportedServer, "y,,n=,r=t8JwklwKecDLwSsA72rHmVju"),
(
Required("tls-server-end-point"),
"p=tls-server-end-point,,n=,r=t8JwklwKecDLwSsA72rHmVju",
),
];
for (cb, input) in cases {
let msg = ClientFirstMessage::parse(input).unwrap();
assert_eq!(msg.bare, "n=,r=t8JwklwKecDLwSsA72rHmVju");
assert_eq!(msg.nonce, "t8JwklwKecDLwSsA72rHmVju");
assert_eq!(msg.cbind_flag, cb);
}
}
#[test]
fn parse_client_first_message_with_invalid_gs2_authz() {
assert!(ClientFirstMessage::parse("n,authzid,n=,r=nonce").is_none())
}
#[test]
fn parse_client_first_message_with_extra_params() {
let msg = ClientFirstMessage::parse("n,,n=,r=nonce,a=foo,b=bar,c=baz").unwrap();
assert_eq!(msg.bare, "n=,r=nonce,a=foo,b=bar,c=baz");
assert_eq!(msg.nonce, "nonce");
assert_eq!(msg.cbind_flag, ChannelBinding::NotSupportedClient);
}
#[test]
fn parse_client_first_message_with_extra_params_invalid() {
// must be of the form `<ascii letter>=<...>`
assert!(ClientFirstMessage::parse("n,,n=,r=nonce,abc=foo").is_none());
assert!(ClientFirstMessage::parse("n,,n=,r=nonce,1=foo").is_none());
assert!(ClientFirstMessage::parse("n,,n=,r=nonce,a").is_none());
}
#[test]
fn parse_client_final_message() {
let input = [
"c=eSws",
"r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU",
"p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=",
]
.join(",");
let msg = ClientFinalMessage::parse(&input).unwrap();
assert_eq!(
msg.without_proof,
"c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
);
assert_eq!(
msg.nonce,
"iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
);
assert_eq!(
base64::encode(msg.proof),
"SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
);
}
}

View File

@@ -1,89 +0,0 @@
use hmac::{
digest::{consts::U32, generic_array::GenericArray},
Hmac, Mac,
};
use sha2::Sha256;
pub struct Pbkdf2 {
hmac: Hmac<Sha256>,
prev: GenericArray<u8, U32>,
hi: GenericArray<u8, U32>,
iterations: u32,
}
// inspired from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
impl Pbkdf2 {
pub fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self {
let hmac =
Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
let prev = hmac
.clone()
.chain_update(salt)
.chain_update(1u32.to_be_bytes())
.finalize()
.into_bytes();
Self {
hmac,
// one consumed for the hash above
iterations: iterations - 1,
hi: prev,
prev,
}
}
pub fn cost(&self) -> u32 {
(self.iterations).clamp(0, 4096)
}
pub fn turn(&mut self) -> std::task::Poll<[u8; 32]> {
let Self {
hmac,
prev,
hi,
iterations,
} = self;
// only do 4096 iterations per turn before sharing the thread for fairness
let n = (*iterations).clamp(0, 4096);
for _ in 0..n {
*prev = hmac.clone().chain_update(*prev).finalize().into_bytes();
for (hi, prev) in hi.iter_mut().zip(*prev) {
*hi ^= prev;
}
}
*iterations -= n;
if *iterations == 0 {
std::task::Poll::Ready((*hi).into())
} else {
std::task::Poll::Pending
}
}
}
#[cfg(test)]
mod tests {
use super::Pbkdf2;
use pbkdf2::pbkdf2_hmac_array;
use sha2::Sha256;
#[test]
fn works() {
let salt = b"sodium chloride";
let pass = b"Ne0n_!5_50_C007";
let mut job = Pbkdf2::start(pass, salt, 600000);
let hash = loop {
let std::task::Poll::Ready(hash) = job.turn() else {
continue;
};
break hash;
};
let expected = pbkdf2_hmac_array::<Sha256, 32>(pass, salt, 600000);
assert_eq!(hash, expected)
}
}

View File

@@ -1,100 +0,0 @@
//! Tools for SCRAM server secret management.
use subtle::{Choice, ConstantTimeEq};
use super::base64_decode_array;
use super::key::ScramKey;
/// Server secret is produced from user's password,
/// and is used throughout the authentication process.
#[derive(Clone, Eq, PartialEq, Debug)]
pub struct ServerSecret {
/// Number of iterations for `PBKDF2` function.
pub iterations: u32,
/// Salt used to hash user's password.
pub salt_base64: String,
/// Hashed `ClientKey`.
pub stored_key: ScramKey,
/// Used by client to verify server's signature.
pub server_key: ScramKey,
/// Should auth fail no matter what?
/// This is exactly the case for mocked secrets.
pub doomed: bool,
}
impl ServerSecret {
pub fn parse(input: &str) -> Option<Self> {
// SCRAM-SHA-256$<iterations>:<salt>$<storedkey>:<serverkey>
let s = input.strip_prefix("SCRAM-SHA-256$")?;
let (params, keys) = s.split_once('$')?;
let ((iterations, salt), (stored_key, server_key)) =
params.split_once(':').zip(keys.split_once(':'))?;
let secret = ServerSecret {
iterations: iterations.parse().ok()?,
salt_base64: salt.to_owned(),
stored_key: base64_decode_array(stored_key)?.into(),
server_key: base64_decode_array(server_key)?.into(),
doomed: false,
};
Some(secret)
}
pub fn is_password_invalid(&self, client_key: &ScramKey) -> Choice {
// constant time to not leak partial key match
client_key.sha256().ct_ne(&self.stored_key) | Choice::from(self.doomed as u8)
}
/// To avoid revealing information to an attacker, we use a
/// mocked server secret even if the user doesn't exist.
/// See `auth-scram.c : mock_scram_secret` for details.
pub fn mock(nonce: [u8; 32]) -> Self {
Self {
// this doesn't reveal much information as we're going to use
// iteration count 1 for our generated passwords going forward.
// PG16 users can set iteration count=1 already today.
iterations: 1,
salt_base64: base64::encode(nonce),
stored_key: ScramKey::default(),
server_key: ScramKey::default(),
doomed: true,
}
}
/// Build a new server secret from the prerequisites.
/// XXX: We only use this function in tests.
#[cfg(test)]
pub async fn build(password: &str) -> Option<Self> {
Self::parse(&postgres_protocol::password::scram_sha_256(password.as_bytes()).await)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_scram_secret() {
let iterations = 4096;
let salt = "+/tQQax7twvwTj64mjBsxQ==";
let stored_key = "D5h6KTMBlUvDJk2Y8ELfC1Sjtc6k9YHjRyuRZyBNJns=";
let server_key = "Pi3QHbcluX//NDfVkKlFl88GGzlJ5LkyPwcdlN/QBvI=";
let secret = format!(
"SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}",
iterations = iterations,
salt = salt,
stored_key = stored_key,
server_key = server_key,
);
let parsed = ServerSecret::parse(&secret).unwrap();
assert_eq!(parsed.iterations, iterations);
assert_eq!(parsed.salt_base64, salt);
assert_eq!(base64::encode(parsed.stored_key), stored_key);
assert_eq!(base64::encode(parsed.server_key), server_key);
}
}

View File

@@ -1,66 +0,0 @@
//! Tools for client/server signature management.
use super::key::{ScramKey, SCRAM_KEY_LEN};
/// A collection of message parts needed to derive the client's signature.
#[derive(Debug)]
pub struct SignatureBuilder<'a> {
pub client_first_message_bare: &'a str,
pub server_first_message: &'a str,
pub client_final_message_without_proof: &'a str,
}
impl SignatureBuilder<'_> {
pub fn build(&self, key: &ScramKey) -> Signature {
let parts = [
self.client_first_message_bare.as_bytes(),
b",",
self.server_first_message.as_bytes(),
b",",
self.client_final_message_without_proof.as_bytes(),
];
super::hmac_sha256(key.as_ref(), parts).into()
}
}
/// A computed value which, when xored with `ClientProof`,
/// produces `ClientKey` that we need for authentication.
#[derive(Debug)]
#[repr(transparent)]
pub struct Signature {
bytes: [u8; SCRAM_KEY_LEN],
}
impl Signature {
/// Derive `ClientKey` from client's signature and proof.
pub fn derive_client_key(&self, proof: &[u8; SCRAM_KEY_LEN]) -> ScramKey {
// This is how the proof is calculated:
//
// 1. sha256(ClientKey) -> StoredKey
// 2. hmac_sha256(StoredKey, [messages...]) -> ClientSignature
// 3. ClientKey ^ ClientSignature -> ClientProof
//
// Step 3 implies that we can restore ClientKey from the proof
// by xoring the latter with the ClientSignature. Afterwards we
// can check that the presumed ClientKey meets our expectations.
let mut signature = self.bytes;
for (i, x) in proof.iter().enumerate() {
signature[i] ^= x;
}
signature.into()
}
}
impl From<[u8; SCRAM_KEY_LEN]> for Signature {
fn from(bytes: [u8; SCRAM_KEY_LEN]) -> Self {
Self { bytes }
}
}
impl AsRef<[u8]> for Signature {
fn as_ref(&self) -> &[u8] {
&self.bytes
}
}

View File

@@ -1,321 +0,0 @@
//! Custom threadpool implementation for password hashing.
//!
//! Requirements:
//! 1. Fairness per endpoint.
//! 2. Yield support for high iteration counts.
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc,
};
use crossbeam_deque::{Injector, Stealer, Worker};
use itertools::Itertools;
use parking_lot::{Condvar, Mutex};
use rand::Rng;
use rand::{rngs::SmallRng, SeedableRng};
use tokio::sync::oneshot;
use crate::{
intern::EndpointIdInt,
metrics::{ThreadPoolMetrics, ThreadPoolWorkerId},
scram::countmin::CountMinSketch,
};
use super::pbkdf2::Pbkdf2;
pub struct ThreadPool {
queue: Injector<JobSpec>,
stealers: Vec<Stealer<JobSpec>>,
parkers: Vec<(Condvar, Mutex<ThreadState>)>,
/// bitpacked representation.
/// lower 8 bits = number of sleeping threads
/// next 8 bits = number of idle threads (searching for work)
counters: AtomicU64,
pub metrics: Arc<ThreadPoolMetrics>,
}
#[derive(PartialEq)]
enum ThreadState {
Parked,
Active,
}
impl ThreadPool {
pub fn new(n_workers: u8) -> Arc<Self> {
let workers = (0..n_workers).map(|_| Worker::new_fifo()).collect_vec();
let stealers = workers.iter().map(|w| w.stealer()).collect_vec();
let parkers = (0..n_workers)
.map(|_| (Condvar::new(), Mutex::new(ThreadState::Active)))
.collect_vec();
let pool = Arc::new(Self {
queue: Injector::new(),
stealers,
parkers,
// threads start searching for work
counters: AtomicU64::new((n_workers as u64) << 8),
metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)),
});
for (i, worker) in workers.into_iter().enumerate() {
let pool = Arc::clone(&pool);
std::thread::spawn(move || thread_rt(pool, worker, i));
}
pool
}
pub fn spawn_job(
&self,
endpoint: EndpointIdInt,
pbkdf2: Pbkdf2,
) -> oneshot::Receiver<[u8; 32]> {
let (tx, rx) = oneshot::channel();
let queue_was_empty = self.queue.is_empty();
self.metrics.injector_queue_depth.inc();
self.queue.push(JobSpec {
response: tx,
pbkdf2,
endpoint,
});
// inspired from <https://github.com/rayon-rs/rayon/blob/3e3962cb8f7b50773bcc360b48a7a674a53a2c77/rayon-core/src/sleep/mod.rs#L242>
let counts = self.counters.load(Ordering::SeqCst);
let num_awake_but_idle = (counts >> 8) & 0xff;
let num_sleepers = counts & 0xff;
// If the queue is non-empty, then we always wake up a worker
// -- clearly the existing idle jobs aren't enough. Otherwise,
// check to see if we have enough idle workers.
if !queue_was_empty || num_awake_but_idle == 0 {
let num_to_wake = Ord::min(1, num_sleepers);
self.wake_any_threads(num_to_wake);
}
rx
}
#[cold]
fn wake_any_threads(&self, mut num_to_wake: u64) {
if num_to_wake > 0 {
for i in 0..self.parkers.len() {
if self.wake_specific_thread(i) {
num_to_wake -= 1;
if num_to_wake == 0 {
return;
}
}
}
}
}
fn wake_specific_thread(&self, index: usize) -> bool {
let (condvar, lock) = &self.parkers[index];
let mut state = lock.lock();
if *state == ThreadState::Parked {
condvar.notify_one();
// When the thread went to sleep, it will have incremented
// this value. When we wake it, its our job to decrement
// it. We could have the thread do it, but that would
// introduce a delay between when the thread was
// *notified* and when this counter was decremented. That
// might mislead people with new work into thinking that
// there are sleeping threads that they should try to
// wake, when in fact there is nothing left for them to
// do.
self.counters.fetch_sub(1, Ordering::SeqCst);
*state = ThreadState::Active;
true
} else {
false
}
}
fn steal(&self, rng: &mut impl Rng, skip: usize, worker: &Worker<JobSpec>) -> Option<JobSpec> {
// announce thread as idle
self.counters.fetch_add(256, Ordering::SeqCst);
// try steal from the global queue
loop {
match self.queue.steal_batch_and_pop(worker) {
crossbeam_deque::Steal::Success(job) => {
self.metrics
.injector_queue_depth
.set(self.queue.len() as i64);
// no longer idle
self.counters.fetch_sub(256, Ordering::SeqCst);
return Some(job);
}
crossbeam_deque::Steal::Retry => continue,
crossbeam_deque::Steal::Empty => break,
}
}
// try steal from our neighbours
loop {
let mut retry = false;
let start = rng.gen_range(0..self.stealers.len());
let job = (start..self.stealers.len())
.chain(0..start)
.filter(|i| *i != skip)
.find_map(
|victim| match self.stealers[victim].steal_batch_and_pop(worker) {
crossbeam_deque::Steal::Success(job) => Some(job),
crossbeam_deque::Steal::Empty => None,
crossbeam_deque::Steal::Retry => {
retry = true;
None
}
},
);
if job.is_some() {
// no longer idle
self.counters.fetch_sub(256, Ordering::SeqCst);
return job;
}
if !retry {
return None;
}
}
}
}
fn thread_rt(pool: Arc<ThreadPool>, worker: Worker<JobSpec>, index: usize) {
/// interval when we should steal from the global queue
/// so that tail latencies are managed appropriately
const STEAL_INTERVAL: usize = 61;
/// How often to reset the sketch values
const SKETCH_RESET_INTERVAL: usize = 1021;
let mut rng = SmallRng::from_entropy();
// used to determine whether we should temporarily skip tasks for fairness.
// 99% of estimates will overcount by no more than 4096 samples
let mut sketch = CountMinSketch::with_params(1.0 / (SKETCH_RESET_INTERVAL as f64), 0.01);
let (condvar, lock) = &pool.parkers[index];
'wait: loop {
// wait for notification of work
{
let mut lock = lock.lock();
// queue is empty
pool.metrics
.worker_queue_depth
.set(ThreadPoolWorkerId(index), 0);
// subtract 1 from idle count, add 1 to sleeping count.
pool.counters.fetch_sub(255, Ordering::SeqCst);
*lock = ThreadState::Parked;
condvar.wait(&mut lock);
}
for i in 0.. {
let mut job = match worker
.pop()
.or_else(|| pool.steal(&mut rng, index, &worker))
{
Some(job) => job,
None => continue 'wait,
};
pool.metrics
.worker_queue_depth
.set(ThreadPoolWorkerId(index), worker.len() as i64);
// receiver is closed, cancel the task
if !job.response.is_closed() {
let rate = sketch.inc_and_return(&job.endpoint, job.pbkdf2.cost());
const P: f64 = 2000.0;
// probability decreases as rate increases.
// lower probability, higher chance of being skipped
//
// estimates (rate in terms of 4096 rounds):
// rate = 0 => probability = 100%
// rate = 10 => probability = 71.3%
// rate = 50 => probability = 62.1%
// rate = 500 => probability = 52.3%
// rate = 1021 => probability = 49.8%
//
// My expectation is that the pool queue will only begin backing up at ~1000rps
// in which case the SKETCH_RESET_INTERVAL represents 1 second. Thus, the rates above
// are in requests per second.
let probability = P.ln() / (P + rate as f64).ln();
if pool.queue.len() > 32 || rng.gen_bool(probability) {
pool.metrics
.worker_task_turns_total
.inc(ThreadPoolWorkerId(index));
match job.pbkdf2.turn() {
std::task::Poll::Ready(result) => {
let _ = job.response.send(result);
}
std::task::Poll::Pending => worker.push(job),
}
} else {
pool.metrics
.worker_task_skips_total
.inc(ThreadPoolWorkerId(index));
// skip for now
worker.push(job)
}
}
// if we get stuck with a few long lived jobs in the queue
// it's better to try and steal from the queue too for fairness
if i % STEAL_INTERVAL == 0 {
let _ = pool.queue.steal_batch(&worker);
}
if i % SKETCH_RESET_INTERVAL == 0 {
sketch.reset();
}
}
}
}
struct JobSpec {
response: oneshot::Sender<[u8; 32]>,
pbkdf2: Pbkdf2,
endpoint: EndpointIdInt,
}
#[cfg(test)]
mod tests {
use crate::EndpointId;
use super::*;
#[tokio::test]
async fn hash_is_correct() {
let pool = ThreadPool::new(1);
let ep = EndpointId::from("foo");
let ep = EndpointIdInt::from(ep);
let salt = [0x55; 32];
let actual = pool
.spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096))
.await
.unwrap();
let expected = [
10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,
178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140,
];
assert_eq!(actual, expected)
}
}

View File

@@ -79,11 +79,11 @@ impl PoolingBackend {
)
.await?;
let res = match auth_outcome {
crate::sasl::Outcome::Success(key) => {
proxy_sasl::sasl::Outcome::Success(key) => {
info!("user successfully authenticated");
Ok(key)
}
crate::sasl::Outcome::Failure(reason) => {
proxy_sasl::sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
Err(AuthError::auth_failed(&*conn_info.user_info.user))
}

View File

@@ -1,10 +1,10 @@
use crate::config::TlsServerEndPoint;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::Metrics;
use bytes::BytesMut;
use pq_proto::framed::{ConnectionError, Framed};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
use proxy_sasl::scram::TlsServerEndPoint;
use rustls::ServerConfig;
use std::pin::Pin;
use std::sync::Arc;