Compare commits

...

4 Commits

Author SHA1 Message Date
Dmitry Ivanov
632c07cab5 Fix lints 2023-02-28 19:40:28 +03:00
Dmitry Ivanov
e9f73707c7 [proxy] Prevent unauthorized wake-ups in the "password hack" flow 2023-02-28 19:27:08 +03:00
Dmitry Ivanov
f9f40fa41d [proxy] Introduce SniParams for creds parsing 2023-02-28 19:27:08 +03:00
Dmitry Ivanov
021ab8365f [proxy] Refactoring in the classic auth backend 2023-02-28 19:27:08 +03:00
14 changed files with 279 additions and 127 deletions

View File

@@ -3,7 +3,7 @@
pub mod backend;
pub use backend::BackendType;
mod credentials;
pub mod credentials;
pub use credentials::ClientCredentials;
mod password_hack;

View File

@@ -11,7 +11,7 @@ use crate::{
provider::{CachedNodeInfo, ConsoleReqExtra},
Api,
},
stream, url,
scram, stream, url,
};
use futures::TryFutureExt;
use std::borrow::Cow;
@@ -59,8 +59,8 @@ impl std::fmt::Display for BackendType<'_, ()> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use BackendType::*;
match self {
Console(endpoint, _) => fmt.debug_tuple("Console").field(&endpoint.url()).finish(),
Postgres(endpoint, _) => fmt.debug_tuple("Postgres").field(&endpoint.url()).finish(),
Console(api, _) => fmt.debug_tuple("Console").field(&api.url()).finish(),
Postgres(api, _) => fmt.debug_tuple("Postgres").field(&api.url()).finish(),
Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
}
}
@@ -106,6 +106,23 @@ impl<'a, T, E> BackendType<'a, Result<T, E>> {
}
}
impl console::AuthInfo {
/// Either it's our way ([SCRAM](crate::scram)) or the highway :)
/// But seriously, we don't aim to support anything but SCRAM for now.
fn scram_or_goodbye(self) -> auth::Result<scram::ServerSecret> {
match self {
Self::Md5(_) => {
info!("auth endpoint chooses MD5");
Err(auth::AuthError::bad_auth_method("MD5"))
}
Self::Scram(secret) => {
info!("auth endpoint chooses SCRAM");
Ok(secret)
}
}
}
}
/// True to its name, this function encapsulates our current auth trade-offs.
/// Here, we choose the appropriate auth flow based on circumstances.
async fn auth_quirks(
@@ -183,7 +200,9 @@ impl BackendType<'_, ClientCredentials<'_>> {
info!("user successfully authenticated");
Ok(res)
}
}
impl BackendType<'_, ClientCredentials<'_>> {
/// When applicable, wake the compute node, gaining its connection info in the process.
/// The link auth flow doesn't support this, so we return [`None`] in that case.
pub async fn wake_compute(

View File

@@ -2,57 +2,54 @@ use super::AuthSuccess;
use crate::{
auth::{self, AuthFlow, ClientCredentials},
compute,
console::{self, AuthInfo, CachedNodeInfo, ConsoleReqExtra},
console::{self, CachedNodeInfo, ConsoleReqExtra},
sasl, scram,
stream::PqStream,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::config::AuthKeys;
use tracing::info;
pub(super) async fn authenticate(
async fn do_scram(
secret: scram::ServerSecret,
creds: &ClientCredentials<'_>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<compute::ScramKeys> {
let outcome = AuthFlow::new(client)
.begin(auth::Scram(&secret))
.await?
.authenticate()
.await?;
let client_key = match outcome {
sasl::Outcome::Success(key) => key,
sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
return Err(auth::AuthError::auth_failed(creds.user));
}
};
let keys = compute::ScramKeys {
client_key: client_key.as_bytes(),
server_key: secret.server_key.as_bytes(),
};
Ok(keys)
}
pub async fn authenticate(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
info!("fetching user's authentication info");
let info = api.get_auth_info(extra, creds).await?.unwrap_or_else(|| {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
AuthInfo::Scram(scram::ServerSecret::mock(creds.user, rand::random()))
});
let info = console::get_auth_info(api, extra, creds).await?;
let flow = AuthFlow::new(client);
let scram_keys = match info {
AuthInfo::Md5(_) => {
info!("auth endpoint chooses MD5");
return Err(auth::AuthError::bad_auth_method("MD5"));
}
AuthInfo::Scram(secret) => {
info!("auth endpoint chooses SCRAM");
let scram = auth::Scram(&secret);
let client_key = match flow.begin(scram).await?.authenticate().await? {
sasl::Outcome::Success(key) => key,
sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
return Err(auth::AuthError::auth_failed(creds.user));
}
};
Some(compute::ScramKeys {
client_key: client_key.as_bytes(),
server_key: secret.server_key.as_bytes(),
})
}
};
let secret = info.scram_or_goodbye()?;
let scram_keys = do_scram(secret, creds, client).await?;
let mut node = api.wake_compute(extra, creds).await?;
if let Some(keys) = scram_keys {
use tokio_postgres::config::AuthKeys;
node.config.auth_keys(AuthKeys::ScramSha256(keys));
}
node.config.auth_keys(AuthKeys::ScramSha256(scram_keys));
Ok(AuthSuccess {
reported_auth_ok: false,

View File

@@ -5,11 +5,33 @@ use crate::{
self,
provider::{CachedNodeInfo, ConsoleReqExtra},
},
stream,
stream::PqStream,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
/// Wake the compute node, but only if the password is valid.
async fn get_compute(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>,
password: Vec<u8>,
) -> auth::Result<CachedNodeInfo> {
// TODO: this will slow down both "hacks" below; we probably need a cache.
let info = console::get_auth_info(api, extra, creds).await?;
let secret = info.scram_or_goodbye()?;
if !secret.matches_password(&password) {
info!("our obscure magic indicates that the password doesn't match");
return Err(auth::AuthError::auth_failed(creds.user));
}
let mut node = api.wake_compute(extra, creds).await?;
node.config.password(password);
Ok(node)
}
/// Compared to [SCRAM](crate::scram), cleartext password auth saves
/// one round trip and *expensive* computations (>= 4096 HMAC iterations).
/// These properties are benefical for serverless JS workers, so we
@@ -18,7 +40,7 @@ pub async fn cleartext_hack(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
warn!("cleartext auth flow override is enabled, proceeding");
let password = AuthFlow::new(client)
@@ -27,8 +49,7 @@ pub async fn cleartext_hack(
.authenticate()
.await?;
let mut node = api.wake_compute(extra, creds).await?;
node.config.password(password);
let node = get_compute(api, extra, creds, password).await?;
// Report tentative success; compute node will check the password anyway.
Ok(AuthSuccess {
@@ -43,7 +64,7 @@ pub async fn password_hack(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
warn!("project not specified, resorting to the password hack auth flow");
let payload = AuthFlow::new(client)
@@ -55,8 +76,7 @@ pub async fn password_hack(
info!(project = &payload.project, "received missing parameter");
creds.project = Some(payload.project.into());
let mut node = api.wake_compute(extra, creds).await?;
node.config.password(payload.password);
let node = get_compute(api, extra, creds, payload.password).await?;
// Report tentative success; compute node will check the password anyway.
Ok(AuthSuccess {

View File

@@ -31,12 +31,22 @@ pub enum ClientCredsParseError {
impl UserFacingError for ClientCredsParseError {}
/// eSNI parameters which might contain endpoint/project name.
#[derive(Default)]
pub struct SniParams<'a> {
/// Server Name Indication (TLS jargon).
pub sni: Option<&'a str>,
/// Common Name from a TLS certificate.
pub common_name: Option<&'a str>,
}
/// Various client credentials which we use for authentication.
/// Note that we don't store any kind of client key or password here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientCredentials<'a> {
/// Name of postgres role.
pub user: &'a str,
// TODO: this is a severe misnomer! We should think of a new name ASAP.
/// Also known as endpoint in the console.
pub project: Option<Cow<'a, str>>,
}
@@ -49,18 +59,17 @@ impl ClientCredentials<'_> {
impl<'a> ClientCredentials<'a> {
pub fn parse(
params: &'a StartupMessageParams,
sni: Option<&str>,
common_name: Option<&str>,
startup_params: &'a StartupMessageParams,
&SniParams { sni, common_name }: &SniParams<'_>,
) -> Result<Self, ClientCredsParseError> {
use ClientCredsParseError::*;
// Some parameters are stored in the startup message.
let get_param = |key| params.get(key).ok_or(MissingKey(key));
let get_param = |key| startup_params.get(key).ok_or(MissingKey(key));
let user = get_param("user")?;
// Project name might be passed via PG's command-line options.
let project_option = params.options_raw().and_then(|mut options| {
let project_option = startup_params.options_raw().and_then(|mut options| {
options
.find_map(|opt| opt.strip_prefix("project="))
.map(Cow::Borrowed)
@@ -122,7 +131,9 @@ mod tests {
// According to postgresql, only `user` should be required.
let options = StartupMessageParams::new([("user", "john_doe")]);
let creds = ClientCredentials::parse(&options, None, None)?;
let sni = SniParams::default();
let creds = ClientCredentials::parse(&options, &sni)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project, None);
@@ -131,13 +142,15 @@ mod tests {
#[test]
fn parse_excessive() -> anyhow::Result<()> {
let options = StartupMessageParams::new([
let startup = StartupMessageParams::new([
("user", "john_doe"),
("database", "world"), // should be ignored
("foo", "bar"), // should be ignored
]);
let creds = ClientCredentials::parse(&options, None, None)?;
let sni = SniParams::default();
let creds = ClientCredentials::parse(&startup, &sni)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project, None);
@@ -146,12 +159,14 @@ mod tests {
#[test]
fn parse_project_from_sni() -> anyhow::Result<()> {
let options = StartupMessageParams::new([("user", "john_doe")]);
let startup = StartupMessageParams::new([("user", "john_doe")]);
let sni = Some("foo.localhost");
let common_name = Some("localhost");
let sni = SniParams {
sni: Some("foo.localhost"),
common_name: Some("localhost"),
};
let creds = ClientCredentials::parse(&options, sni, common_name)?;
let creds = ClientCredentials::parse(&startup, &sni)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("foo"));
@@ -160,12 +175,14 @@ mod tests {
#[test]
fn parse_project_from_options() -> anyhow::Result<()> {
let options = StartupMessageParams::new([
let startup = StartupMessageParams::new([
("user", "john_doe"),
("options", "-ckey=1 project=bar -c geqo=off"),
]);
let creds = ClientCredentials::parse(&options, None, None)?;
let sni = SniParams::default();
let creds = ClientCredentials::parse(&startup, &sni)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("bar"));
@@ -174,12 +191,17 @@ mod tests {
#[test]
fn parse_projects_identical() -> anyhow::Result<()> {
let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]);
let startup = StartupMessageParams::new([
("user", "john_doe"),
("options", "project=baz"), // fmt
]);
let sni = Some("baz.localhost");
let common_name = Some("localhost");
let sni = SniParams {
sni: Some("baz.localhost"),
common_name: Some("localhost"),
};
let creds = ClientCredentials::parse(&options, sni, common_name)?;
let creds = ClientCredentials::parse(&startup, &sni)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("baz"));
@@ -188,13 +210,17 @@ mod tests {
#[test]
fn parse_projects_different() {
let options =
StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]);
let startup = StartupMessageParams::new([
("user", "john_doe"),
("options", "project=first"), // fmt
]);
let sni = Some("second.localhost");
let common_name = Some("localhost");
let sni = SniParams {
sni: Some("second.localhost"),
common_name: Some("localhost"),
};
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
let err = ClientCredentials::parse(&startup, &sni).expect_err("should fail");
match err {
InconsistentProjectNames { domain, option } => {
assert_eq!(option, "first");
@@ -206,12 +232,14 @@ mod tests {
#[test]
fn parse_inconsistent_sni() {
let options = StartupMessageParams::new([("user", "john_doe")]);
let startup = StartupMessageParams::new([("user", "john_doe")]);
let sni = Some("project.localhost");
let common_name = Some("example.com");
let sni = SniParams {
sni: Some("project.localhost"),
common_name: Some("example.com"),
};
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
let err = ClientCredentials::parse(&startup, &sni).expect_err("should fail");
match err {
InconsistentSni { sni, cn } => {
assert_eq!(sni, "project.localhost");

View File

@@ -6,7 +6,9 @@ pub mod messages;
/// Wrappers for console APIs and their mocks.
pub mod provider;
pub use provider::{errors, Api, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo};
pub use provider::{
errors, get_auth_info, Api, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
};
/// Various cache-related types.
pub mod caches {

View File

@@ -9,6 +9,7 @@ use crate::{
};
use async_trait::async_trait;
use std::sync::Arc;
use tracing::info;
pub mod errors {
use crate::{
@@ -175,6 +176,12 @@ pub struct NodeInfo {
pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;
pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>;
/// Various caches for [`console`].
pub struct ApiCaches {
/// Cache for the `wake_compute` API method.
pub node_info: NodeInfoCache,
}
/// This will allocate per each call, but the http requests alone
/// already require a few allocations, so it should be fine.
#[async_trait]
@@ -194,8 +201,21 @@ pub trait Api {
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
}
/// Various caches for [`console`].
pub struct ApiCaches {
/// Cache for the `wake_compute` API method.
pub node_info: NodeInfoCache,
/// A more insightful version of [`Api::get_auth_info`] which
/// knows what to do when we get [`None`] instead of [`AuthInfo`].
pub async fn get_auth_info(
api: &impl Api,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
) -> Result<AuthInfo, errors::GetAuthInfoError> {
info!("fetching user's authentication info");
let info = api.get_auth_info(extra, creds).await?.unwrap_or_else(|| {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
AuthInfo::Scram(scram::ServerSecret::mock(creds.user, rand::random()))
});
Ok(info)
}

View File

@@ -2,7 +2,7 @@
mod tests;
use crate::{
auth::{self, backend::AuthSuccess},
auth::{self, backend::AuthSuccess, credentials},
cancellation::{self, CancelMap},
compute::{self, PostgresConnection},
config::{ProxyConfig, TlsConfig},
@@ -112,7 +112,6 @@ pub async fn handle_ws_client(
}
let tls = config.tls_config.as_ref();
let hostname = hostname.as_deref();
// TLS is None here, because the connection is already encrypted.
let do_handshake = handshake(stream, None, cancel_map);
@@ -121,13 +120,17 @@ pub async fn handle_ws_client(
None => return Ok(()), // it's a cancellation request
};
let sni = credentials::SniParams {
sni: hostname.as_deref(),
common_name: tls.and_then(|tls| tls.common_name.as_deref()),
};
// Extract credentials which we're going to use for auth.
let creds = {
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ClientCredentials::parse(&params, hostname, common_name))
.map(|_| auth::ClientCredentials::parse(&params, &sni))
.transpose();
async { result }.or_else(|e| stream.throw_error(e)).await?
@@ -159,14 +162,17 @@ async fn handle_client(
None => return Ok(()), // it's a cancellation request
};
let sni = credentials::SniParams {
sni: stream.get_ref().sni_hostname(),
common_name: tls.and_then(|tls| tls.common_name.as_deref()),
};
// Extract credentials which we're going to use for auth.
let creds = {
let sni = stream.get_ref().sni_hostname();
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ClientCredentials::parse(&params, sni, common_name))
.map(|_| auth::ClientCredentials::parse(&params, &sni))
.transpose();
async { result }.or_else(|e| stream.throw_error(e)).await?

View File

@@ -92,10 +92,10 @@ impl TestAuth for NoAuth {}
struct Scram(scram::ServerSecret);
impl Scram {
fn new(password: &str) -> anyhow::Result<Self> {
fn new(password: &[u8]) -> anyhow::Result<Self> {
let salt = rand::random::<[u8; 16]>();
let secret = scram::ServerSecret::build(password, &salt, 256)
.context("failed to generate scram secret")?;
let secret = scram::ServerSecret::build(password, &salt, 256);
Ok(Scram(secret))
}
@@ -230,11 +230,11 @@ async fn keepalive_is_inherited() -> anyhow::Result<()> {
}
#[rstest]
#[case("password_foo")]
#[case("pwd-bar")]
#[case("")]
#[case(b"password_foo")]
#[case(b"pwd-bar")]
#[case(b"")]
#[tokio::test]
async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
async fn scram_auth_good(#[case] password: &[u8]) -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (client_config, server_config) =

View File

@@ -12,7 +12,6 @@ mod messages;
mod secret;
mod signature;
#[cfg(test)]
mod password;
pub use exchange::Exchange;

View File

@@ -73,7 +73,7 @@ impl sasl::Mechanism for Exchange<'_> {
let server_first_message = client_first_message.build_server_first_message(
&(self.nonce)(),
&self.secret.salt_base64,
&self.secret.salt,
self.secret.iterations,
);
let msg = server_first_message.as_str().to_owned();

View File

@@ -75,19 +75,27 @@ impl<'a> ClientFirstMessage<'a> {
pub fn build_server_first_message(
&self,
nonce: &[u8; SCRAM_RAW_NONCE_LEN],
salt_base64: &str,
salt: &[u8],
iterations: u32,
) -> OwnedServerFirstMessage {
use std::fmt::Write;
let mut message = String::new();
// Write base64-encoded combined nonce.
write!(&mut message, "r={}", self.nonce).unwrap();
base64::encode_config_buf(nonce, base64::STANDARD, &mut message);
let combined_nonce = 2..message.len();
write!(&mut message, ",s={},i={}", salt_base64, iterations).unwrap();
// Write base64-encoded salt.
write!(&mut message, ",s=").unwrap();
base64::encode_config_buf(salt, base64::STANDARD, &mut message);
// Write number of iterations.
write!(&mut message, ",i={iterations}").unwrap();
// This design guarantees that it's impossible to create a
// server-first-message without receiving a client-first-message
// server-first-message without receiving a client-first-message.
OwnedServerFirstMessage {
message,
nonce: combined_nonce,
@@ -229,4 +237,49 @@ mod tests {
"SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
);
}
#[test]
fn build_server_messages() {
let input = "n,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju";
let client_first_message = ClientFirstMessage::parse(input).unwrap();
let nonce = [0; 18];
let salt = [1, 2, 3];
let iterations = 4096;
let server_first_message =
client_first_message.build_server_first_message(&nonce, &salt, iterations);
assert_eq!(
server_first_message.message,
"r=t8JwklwKecDLwSsA72rHmVjuAAAAAAAAAAAAAAAAAAAAAAAA,s=AQID,i=4096"
);
assert_eq!(
server_first_message.nonce(),
"t8JwklwKecDLwSsA72rHmVjuAAAAAAAAAAAAAAAAAAAAAAAA"
);
let input = [
"c=eSws",
"r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU",
"p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=",
]
.join(",");
let client_final_message = ClientFinalMessage::parse(&input).unwrap();
let signature_builder = SignatureBuilder {
client_first_message_bare: client_first_message.bare,
server_first_message: server_first_message.as_str(),
client_final_message_without_proof: client_final_message.without_proof,
};
let server_key = ScramKey::default();
let server_final_message =
client_final_message.build_server_final_message(signature_builder, &server_key);
assert_eq!(
server_final_message,
"v=XEL4X1vy5LnqIgOo4hOjm7zd1Ceyo9+nBUE+/zVnqLE="
);
}
}

View File

@@ -1,6 +1,7 @@
//! Password hashing routines.
use super::key::ScramKey;
use tracing::warn;
pub const SALTED_PASSWORD_LEN: usize = 32;
@@ -13,7 +14,12 @@ pub struct SaltedPassword {
impl SaltedPassword {
/// See `scram-common.c : scram_SaltedPassword` for details.
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2898> (see `PBKDF2`).
/// TODO: implement proper password normalization required by the RFC!
pub fn new(password: &[u8], salt: &[u8], iterations: u32) -> SaltedPassword {
if !password.is_ascii() {
warn!("found non-ascii symbols in password! salted password might be broken");
}
let one = 1_u32.to_be_bytes(); // magic
let mut current = super::hmac_sha256(password, [salt, &one]);
@@ -30,6 +36,7 @@ impl SaltedPassword {
}
/// Derive `ClientKey` from a salted hashed password.
#[cfg(test)]
pub fn client_key(&self) -> ScramKey {
super::hmac_sha256(&self.bytes, [b"Client Key".as_ref()]).into()
}

View File

@@ -1,15 +1,14 @@
//! Tools for SCRAM server secret management.
use super::base64_decode_array;
use super::key::ScramKey;
use super::{base64_decode_array, key::ScramKey, password::SaltedPassword};
/// Server secret is produced from [password](super::password::SaltedPassword)
/// Server secret is produced from [password](SaltedPassword)
/// and is used throughout the authentication process.
pub struct ServerSecret {
/// Number of iterations for `PBKDF2` function.
pub iterations: u32,
/// Salt used to hash user's password.
pub salt_base64: String,
pub salt: Vec<u8>,
/// Hashed `ClientKey`.
pub stored_key: ScramKey,
/// Used by client to verify server's signature.
@@ -30,7 +29,7 @@ impl ServerSecret {
let secret = ServerSecret {
iterations: iterations.parse().ok()?,
salt_base64: salt.to_owned(),
salt: base64::decode(salt).ok()?,
stored_key: base64_decode_array(stored_key)?.into(),
server_key: base64_decode_array(server_key)?.into(),
doomed: false,
@@ -48,31 +47,31 @@ impl ServerSecret {
Self {
iterations: 4096,
salt_base64: base64::encode(mocked_salt),
salt: mocked_salt.into(),
stored_key: ScramKey::default(),
server_key: ScramKey::default(),
doomed: true,
}
}
/// Check if this secret was derived from the given password.
pub fn matches_password(&self, password: &[u8]) -> bool {
let password = SaltedPassword::new(password, &self.salt, self.iterations);
self.server_key == password.server_key()
}
/// Build a new server secret from the prerequisites.
/// XXX: We only use this function in tests.
#[cfg(test)]
pub fn build(password: &str, salt: &[u8], iterations: u32) -> Option<Self> {
// TODO: implement proper password normalization required by the RFC
if !password.is_ascii() {
return None;
}
pub fn build(password: &[u8], salt: &[u8], iterations: u32) -> Self {
let password = SaltedPassword::new(password, salt, iterations);
let password = super::password::SaltedPassword::new(password.as_bytes(), salt, iterations);
Some(Self {
Self {
iterations,
salt_base64: base64::encode(salt),
salt: salt.into(),
stored_key: password.client_key().sha256(),
server_key: password.server_key(),
doomed: false,
})
}
}
}
@@ -87,17 +86,11 @@ mod tests {
let stored_key = "D5h6KTMBlUvDJk2Y8ELfC1Sjtc6k9YHjRyuRZyBNJns=";
let server_key = "Pi3QHbcluX//NDfVkKlFl88GGzlJ5LkyPwcdlN/QBvI=";
let secret = format!(
"SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}",
iterations = iterations,
salt = salt,
stored_key = stored_key,
server_key = server_key,
);
let secret = format!("SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}");
let parsed = ServerSecret::parse(&secret).unwrap();
assert_eq!(parsed.iterations, iterations);
assert_eq!(parsed.salt_base64, salt);
assert_eq!(base64::encode(parsed.salt), salt);
assert_eq!(base64::encode(parsed.stored_key), stored_key);
assert_eq!(base64::encode(parsed.server_key), server_key);
@@ -106,9 +99,9 @@ mod tests {
#[test]
fn build_scram_secret() {
let salt = b"salt";
let secret = ServerSecret::build("password", salt, 4096).unwrap();
let secret = ServerSecret::build(b"password", salt, 4096);
assert_eq!(secret.iterations, 4096);
assert_eq!(secret.salt_base64, base64::encode(salt));
assert_eq!(secret.salt, salt);
assert_eq!(
base64::encode(secret.stored_key.as_ref()),
"lF4cRm/Jky763CN4HtxdHnjV4Q8AWTNlKvGmEFFU8IQ="
@@ -118,4 +111,12 @@ mod tests {
"ub8OgRsftnk2ccDMOt7ffHXNcikRkQkq1lh4xaAqrSw="
);
}
#[test]
fn secret_match_password() {
let password = b"password";
let secret = ServerSecret::build(password, b"salt", 2);
assert!(secret.matches_password(password));
assert!(!secret.matches_password(b"different"));
}
}