Compare commits

...

6 Commits

Author SHA1 Message Date
Dmitry Ivanov
632c07cab5 Fix lints 2023-02-28 19:40:28 +03:00
Dmitry Ivanov
e9f73707c7 [proxy] Prevent unauthorized wake-ups in the "password hack" flow 2023-02-28 19:27:08 +03:00
Dmitry Ivanov
f9f40fa41d [proxy] Introduce SniParams for creds parsing 2023-02-28 19:27:08 +03:00
Dmitry Ivanov
021ab8365f [proxy] Refactoring in the classic auth backend 2023-02-28 19:27:08 +03:00
Alexander Bayandin
000eb1b069 Bump tempfile from 3.3.0 to 3.4.0 (#3709)
Update `tempfile` crate to get rid of `remove_dir_all` dependency
Ref https://github.com/neondatabase/neon/security/dependabot/15
2023-02-27 12:44:08 +00:00
Heikki Linnakangas
f51b48fa49 Fix UNLOGGED tables.
Instead of trying to create missing files on the way, send init fork contents as
main fork from pageserver during basebackup. Add test for that. Call
put_rel_drop for init forks; previously they weren't removed. Bump
vendor/postgres to revert previous approach on Postgres side.

Co-authored-by: Arseny Sher <sher-ars@yandex.ru>

ref https://github.com/neondatabase/postgres/pull/264
ref https://github.com/neondatabase/postgres/pull/259
ref https://github.com/neondatabase/neon/issues/1222
2023-02-24 23:30:02 +04:00
22 changed files with 363 additions and 159 deletions

18
Cargo.lock generated
View File

@@ -3067,15 +3067,6 @@ dependencies = [
"workspace_hack",
]
[[package]]
name = "remove_dir_all"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7"
dependencies = [
"winapi",
]
[[package]]
name = "reqwest"
version = "0.11.14"
@@ -3849,16 +3840,15 @@ dependencies = [
[[package]]
name = "tempfile"
version = "3.3.0"
version = "3.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cdb1ef4eaeeaddc8fbd371e5017057064af0911902ef36b39801f67cc6d79e4"
checksum = "af18f7ae1acd354b992402e9ec5864359d693cd8a79dcbef59f76891701c1e95"
dependencies = [
"cfg-if",
"fastrand",
"libc",
"redox_syscall",
"remove_dir_all",
"winapi",
"rustix",
"windows-sys 0.42.0",
]
[[package]]

View File

@@ -150,7 +150,7 @@ workspace_hack = { version = "0.1", path = "./workspace_hack/" }
criterion = "0.4"
rcgen = "0.10"
rstest = "0.16"
tempfile = "3.2"
tempfile = "3.4"
tonic-build = "0.8"
# This is only needed for proxy's tests.

View File

@@ -98,6 +98,15 @@ impl RelTag {
name
}
pub fn with_forknum(&self, forknum: u8) -> Self {
RelTag {
forknum,
spcnode: self.spcnode,
dbnode: self.dbnode,
relnode: self.relnode,
}
}
}
///

View File

@@ -33,6 +33,7 @@ use pageserver_api::reltag::{RelTag, SlruKind};
use postgres_ffi::pg_constants::{DEFAULTTABLESPACE_OID, GLOBALTABLESPACE_OID};
use postgres_ffi::pg_constants::{PGDATA_SPECIAL_FILES, PGDATA_SUBDIRS, PG_HBA};
use postgres_ffi::relfile_utils::{INIT_FORKNUM, MAIN_FORKNUM};
use postgres_ffi::TransactionId;
use postgres_ffi::XLogFileName;
use postgres_ffi::PG_TLI;
@@ -190,14 +191,31 @@ where
{
self.add_dbdir(spcnode, dbnode, has_relmap_file).await?;
// Gather and send relational files in each database if full backup is requested.
if self.full_backup {
for rel in self
.timeline
.list_rels(spcnode, dbnode, self.lsn, self.ctx)
.await?
{
self.add_rel(rel).await?;
// If full backup is requested, include all relation files.
// Otherwise only include init forks of unlogged relations.
let rels = self
.timeline
.list_rels(spcnode, dbnode, self.lsn, self.ctx)
.await?;
for &rel in rels.iter() {
// Send init fork as main fork to provide well formed empty
// contents of UNLOGGED relations. Postgres copies it in
// `reinit.c` during recovery.
if rel.forknum == INIT_FORKNUM {
// I doubt we need _init fork itself, but having it at least
// serves as a marker relation is unlogged.
self.add_rel(rel, rel).await?;
self.add_rel(rel, rel.with_forknum(MAIN_FORKNUM)).await?;
continue;
}
if self.full_backup {
if rel.forknum == MAIN_FORKNUM && rels.contains(&rel.with_forknum(INIT_FORKNUM))
{
// skip this, will include it when we reach the init fork
continue;
}
self.add_rel(rel, rel).await?;
}
}
}
@@ -220,15 +238,16 @@ where
Ok(())
}
async fn add_rel(&mut self, tag: RelTag) -> anyhow::Result<()> {
/// Add contents of relfilenode `src`, naming it as `dst`.
async fn add_rel(&mut self, src: RelTag, dst: RelTag) -> anyhow::Result<()> {
let nblocks = self
.timeline
.get_rel_size(tag, self.lsn, false, self.ctx)
.get_rel_size(src, self.lsn, false, self.ctx)
.await?;
// If the relation is empty, create an empty file
if nblocks == 0 {
let file_name = tag.to_segfile_name(0);
let file_name = dst.to_segfile_name(0);
let header = new_tar_header(&file_name, 0)?;
self.ar.append(&header, &mut io::empty()).await?;
return Ok(());
@@ -244,12 +263,12 @@ where
for blknum in startblk..endblk {
let img = self
.timeline
.get_rel_page_at_lsn(tag, blknum, self.lsn, false, self.ctx)
.get_rel_page_at_lsn(src, blknum, self.lsn, false, self.ctx)
.await?;
segment_data.extend_from_slice(&img[..]);
}
let file_name = tag.to_segfile_name(seg as u32);
let file_name = dst.to_segfile_name(seg as u32);
let header = new_tar_header(&file_name, segment_data.len() as u64)?;
self.ar.append(&header, segment_data.as_slice()).await?;

View File

@@ -37,7 +37,7 @@ use crate::walrecord::*;
use crate::ZERO_PAGE;
use pageserver_api::reltag::{RelTag, SlruKind};
use postgres_ffi::pg_constants;
use postgres_ffi::relfile_utils::{FSM_FORKNUM, MAIN_FORKNUM, VISIBILITYMAP_FORKNUM};
use postgres_ffi::relfile_utils::{FSM_FORKNUM, INIT_FORKNUM, MAIN_FORKNUM, VISIBILITYMAP_FORKNUM};
use postgres_ffi::v14::nonrelfile_utils::mx_offset_to_member_segment;
use postgres_ffi::v14::xlog_utils::*;
use postgres_ffi::v14::CheckPoint;
@@ -762,7 +762,7 @@ impl<'a> WalIngest<'a> {
)?;
for xnode in &parsed.xnodes {
for forknum in MAIN_FORKNUM..=VISIBILITYMAP_FORKNUM {
for forknum in MAIN_FORKNUM..=INIT_FORKNUM {
let rel = RelTag {
forknum,
spcnode: xnode.spcnode,

View File

@@ -3,7 +3,7 @@
pub mod backend;
pub use backend::BackendType;
mod credentials;
pub mod credentials;
pub use credentials::ClientCredentials;
mod password_hack;

View File

@@ -11,7 +11,7 @@ use crate::{
provider::{CachedNodeInfo, ConsoleReqExtra},
Api,
},
stream, url,
scram, stream, url,
};
use futures::TryFutureExt;
use std::borrow::Cow;
@@ -59,8 +59,8 @@ impl std::fmt::Display for BackendType<'_, ()> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use BackendType::*;
match self {
Console(endpoint, _) => fmt.debug_tuple("Console").field(&endpoint.url()).finish(),
Postgres(endpoint, _) => fmt.debug_tuple("Postgres").field(&endpoint.url()).finish(),
Console(api, _) => fmt.debug_tuple("Console").field(&api.url()).finish(),
Postgres(api, _) => fmt.debug_tuple("Postgres").field(&api.url()).finish(),
Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
}
}
@@ -106,6 +106,23 @@ impl<'a, T, E> BackendType<'a, Result<T, E>> {
}
}
impl console::AuthInfo {
/// Either it's our way ([SCRAM](crate::scram)) or the highway :)
/// But seriously, we don't aim to support anything but SCRAM for now.
fn scram_or_goodbye(self) -> auth::Result<scram::ServerSecret> {
match self {
Self::Md5(_) => {
info!("auth endpoint chooses MD5");
Err(auth::AuthError::bad_auth_method("MD5"))
}
Self::Scram(secret) => {
info!("auth endpoint chooses SCRAM");
Ok(secret)
}
}
}
}
/// True to its name, this function encapsulates our current auth trade-offs.
/// Here, we choose the appropriate auth flow based on circumstances.
async fn auth_quirks(
@@ -183,7 +200,9 @@ impl BackendType<'_, ClientCredentials<'_>> {
info!("user successfully authenticated");
Ok(res)
}
}
impl BackendType<'_, ClientCredentials<'_>> {
/// When applicable, wake the compute node, gaining its connection info in the process.
/// The link auth flow doesn't support this, so we return [`None`] in that case.
pub async fn wake_compute(

View File

@@ -2,57 +2,54 @@ use super::AuthSuccess;
use crate::{
auth::{self, AuthFlow, ClientCredentials},
compute,
console::{self, AuthInfo, CachedNodeInfo, ConsoleReqExtra},
console::{self, CachedNodeInfo, ConsoleReqExtra},
sasl, scram,
stream::PqStream,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::config::AuthKeys;
use tracing::info;
pub(super) async fn authenticate(
async fn do_scram(
secret: scram::ServerSecret,
creds: &ClientCredentials<'_>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<compute::ScramKeys> {
let outcome = AuthFlow::new(client)
.begin(auth::Scram(&secret))
.await?
.authenticate()
.await?;
let client_key = match outcome {
sasl::Outcome::Success(key) => key,
sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
return Err(auth::AuthError::auth_failed(creds.user));
}
};
let keys = compute::ScramKeys {
client_key: client_key.as_bytes(),
server_key: secret.server_key.as_bytes(),
};
Ok(keys)
}
pub async fn authenticate(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
info!("fetching user's authentication info");
let info = api.get_auth_info(extra, creds).await?.unwrap_or_else(|| {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
AuthInfo::Scram(scram::ServerSecret::mock(creds.user, rand::random()))
});
let info = console::get_auth_info(api, extra, creds).await?;
let flow = AuthFlow::new(client);
let scram_keys = match info {
AuthInfo::Md5(_) => {
info!("auth endpoint chooses MD5");
return Err(auth::AuthError::bad_auth_method("MD5"));
}
AuthInfo::Scram(secret) => {
info!("auth endpoint chooses SCRAM");
let scram = auth::Scram(&secret);
let client_key = match flow.begin(scram).await?.authenticate().await? {
sasl::Outcome::Success(key) => key,
sasl::Outcome::Failure(reason) => {
info!("auth backend failed with an error: {reason}");
return Err(auth::AuthError::auth_failed(creds.user));
}
};
Some(compute::ScramKeys {
client_key: client_key.as_bytes(),
server_key: secret.server_key.as_bytes(),
})
}
};
let secret = info.scram_or_goodbye()?;
let scram_keys = do_scram(secret, creds, client).await?;
let mut node = api.wake_compute(extra, creds).await?;
if let Some(keys) = scram_keys {
use tokio_postgres::config::AuthKeys;
node.config.auth_keys(AuthKeys::ScramSha256(keys));
}
node.config.auth_keys(AuthKeys::ScramSha256(scram_keys));
Ok(AuthSuccess {
reported_auth_ok: false,

View File

@@ -5,11 +5,33 @@ use crate::{
self,
provider::{CachedNodeInfo, ConsoleReqExtra},
},
stream,
stream::PqStream,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
/// Wake the compute node, but only if the password is valid.
async fn get_compute(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>,
password: Vec<u8>,
) -> auth::Result<CachedNodeInfo> {
// TODO: this will slow down both "hacks" below; we probably need a cache.
let info = console::get_auth_info(api, extra, creds).await?;
let secret = info.scram_or_goodbye()?;
if !secret.matches_password(&password) {
info!("our obscure magic indicates that the password doesn't match");
return Err(auth::AuthError::auth_failed(creds.user));
}
let mut node = api.wake_compute(extra, creds).await?;
node.config.password(password);
Ok(node)
}
/// Compared to [SCRAM](crate::scram), cleartext password auth saves
/// one round trip and *expensive* computations (>= 4096 HMAC iterations).
/// These properties are benefical for serverless JS workers, so we
@@ -18,7 +40,7 @@ pub async fn cleartext_hack(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
warn!("cleartext auth flow override is enabled, proceeding");
let password = AuthFlow::new(client)
@@ -27,8 +49,7 @@ pub async fn cleartext_hack(
.authenticate()
.await?;
let mut node = api.wake_compute(extra, creds).await?;
node.config.password(password);
let node = get_compute(api, extra, creds, password).await?;
// Report tentative success; compute node will check the password anyway.
Ok(AuthSuccess {
@@ -43,7 +64,7 @@ pub async fn password_hack(
api: &impl console::Api,
extra: &ConsoleReqExtra<'_>,
creds: &mut ClientCredentials<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
warn!("project not specified, resorting to the password hack auth flow");
let payload = AuthFlow::new(client)
@@ -55,8 +76,7 @@ pub async fn password_hack(
info!(project = &payload.project, "received missing parameter");
creds.project = Some(payload.project.into());
let mut node = api.wake_compute(extra, creds).await?;
node.config.password(payload.password);
let node = get_compute(api, extra, creds, payload.password).await?;
// Report tentative success; compute node will check the password anyway.
Ok(AuthSuccess {

View File

@@ -31,12 +31,22 @@ pub enum ClientCredsParseError {
impl UserFacingError for ClientCredsParseError {}
/// eSNI parameters which might contain endpoint/project name.
#[derive(Default)]
pub struct SniParams<'a> {
/// Server Name Indication (TLS jargon).
pub sni: Option<&'a str>,
/// Common Name from a TLS certificate.
pub common_name: Option<&'a str>,
}
/// Various client credentials which we use for authentication.
/// Note that we don't store any kind of client key or password here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientCredentials<'a> {
/// Name of postgres role.
pub user: &'a str,
// TODO: this is a severe misnomer! We should think of a new name ASAP.
/// Also known as endpoint in the console.
pub project: Option<Cow<'a, str>>,
}
@@ -49,18 +59,17 @@ impl ClientCredentials<'_> {
impl<'a> ClientCredentials<'a> {
pub fn parse(
params: &'a StartupMessageParams,
sni: Option<&str>,
common_name: Option<&str>,
startup_params: &'a StartupMessageParams,
&SniParams { sni, common_name }: &SniParams<'_>,
) -> Result<Self, ClientCredsParseError> {
use ClientCredsParseError::*;
// Some parameters are stored in the startup message.
let get_param = |key| params.get(key).ok_or(MissingKey(key));
let get_param = |key| startup_params.get(key).ok_or(MissingKey(key));
let user = get_param("user")?;
// Project name might be passed via PG's command-line options.
let project_option = params.options_raw().and_then(|mut options| {
let project_option = startup_params.options_raw().and_then(|mut options| {
options
.find_map(|opt| opt.strip_prefix("project="))
.map(Cow::Borrowed)
@@ -122,7 +131,9 @@ mod tests {
// According to postgresql, only `user` should be required.
let options = StartupMessageParams::new([("user", "john_doe")]);
let creds = ClientCredentials::parse(&options, None, None)?;
let sni = SniParams::default();
let creds = ClientCredentials::parse(&options, &sni)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project, None);
@@ -131,13 +142,15 @@ mod tests {
#[test]
fn parse_excessive() -> anyhow::Result<()> {
let options = StartupMessageParams::new([
let startup = StartupMessageParams::new([
("user", "john_doe"),
("database", "world"), // should be ignored
("foo", "bar"), // should be ignored
]);
let creds = ClientCredentials::parse(&options, None, None)?;
let sni = SniParams::default();
let creds = ClientCredentials::parse(&startup, &sni)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project, None);
@@ -146,12 +159,14 @@ mod tests {
#[test]
fn parse_project_from_sni() -> anyhow::Result<()> {
let options = StartupMessageParams::new([("user", "john_doe")]);
let startup = StartupMessageParams::new([("user", "john_doe")]);
let sni = Some("foo.localhost");
let common_name = Some("localhost");
let sni = SniParams {
sni: Some("foo.localhost"),
common_name: Some("localhost"),
};
let creds = ClientCredentials::parse(&options, sni, common_name)?;
let creds = ClientCredentials::parse(&startup, &sni)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("foo"));
@@ -160,12 +175,14 @@ mod tests {
#[test]
fn parse_project_from_options() -> anyhow::Result<()> {
let options = StartupMessageParams::new([
let startup = StartupMessageParams::new([
("user", "john_doe"),
("options", "-ckey=1 project=bar -c geqo=off"),
]);
let creds = ClientCredentials::parse(&options, None, None)?;
let sni = SniParams::default();
let creds = ClientCredentials::parse(&startup, &sni)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("bar"));
@@ -174,12 +191,17 @@ mod tests {
#[test]
fn parse_projects_identical() -> anyhow::Result<()> {
let options = StartupMessageParams::new([("user", "john_doe"), ("options", "project=baz")]);
let startup = StartupMessageParams::new([
("user", "john_doe"),
("options", "project=baz"), // fmt
]);
let sni = Some("baz.localhost");
let common_name = Some("localhost");
let sni = SniParams {
sni: Some("baz.localhost"),
common_name: Some("localhost"),
};
let creds = ClientCredentials::parse(&options, sni, common_name)?;
let creds = ClientCredentials::parse(&startup, &sni)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("baz"));
@@ -188,13 +210,17 @@ mod tests {
#[test]
fn parse_projects_different() {
let options =
StartupMessageParams::new([("user", "john_doe"), ("options", "project=first")]);
let startup = StartupMessageParams::new([
("user", "john_doe"),
("options", "project=first"), // fmt
]);
let sni = Some("second.localhost");
let common_name = Some("localhost");
let sni = SniParams {
sni: Some("second.localhost"),
common_name: Some("localhost"),
};
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
let err = ClientCredentials::parse(&startup, &sni).expect_err("should fail");
match err {
InconsistentProjectNames { domain, option } => {
assert_eq!(option, "first");
@@ -206,12 +232,14 @@ mod tests {
#[test]
fn parse_inconsistent_sni() {
let options = StartupMessageParams::new([("user", "john_doe")]);
let startup = StartupMessageParams::new([("user", "john_doe")]);
let sni = Some("project.localhost");
let common_name = Some("example.com");
let sni = SniParams {
sni: Some("project.localhost"),
common_name: Some("example.com"),
};
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
let err = ClientCredentials::parse(&startup, &sni).expect_err("should fail");
match err {
InconsistentSni { sni, cn } => {
assert_eq!(sni, "project.localhost");

View File

@@ -6,7 +6,9 @@ pub mod messages;
/// Wrappers for console APIs and their mocks.
pub mod provider;
pub use provider::{errors, Api, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo};
pub use provider::{
errors, get_auth_info, Api, AuthInfo, CachedNodeInfo, ConsoleReqExtra, NodeInfo,
};
/// Various cache-related types.
pub mod caches {

View File

@@ -9,6 +9,7 @@ use crate::{
};
use async_trait::async_trait;
use std::sync::Arc;
use tracing::info;
pub mod errors {
use crate::{
@@ -175,6 +176,12 @@ pub struct NodeInfo {
pub type NodeInfoCache = TimedLru<Arc<str>, NodeInfo>;
pub type CachedNodeInfo = timed_lru::Cached<&'static NodeInfoCache>;
/// Various caches for [`console`].
pub struct ApiCaches {
/// Cache for the `wake_compute` API method.
pub node_info: NodeInfoCache,
}
/// This will allocate per each call, but the http requests alone
/// already require a few allocations, so it should be fine.
#[async_trait]
@@ -194,8 +201,21 @@ pub trait Api {
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
}
/// Various caches for [`console`].
pub struct ApiCaches {
/// Cache for the `wake_compute` API method.
pub node_info: NodeInfoCache,
/// A more insightful version of [`Api::get_auth_info`] which
/// knows what to do when we get [`None`] instead of [`AuthInfo`].
pub async fn get_auth_info(
api: &impl Api,
extra: &ConsoleReqExtra<'_>,
creds: &ClientCredentials<'_>,
) -> Result<AuthInfo, errors::GetAuthInfoError> {
info!("fetching user's authentication info");
let info = api.get_auth_info(extra, creds).await?.unwrap_or_else(|| {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
// This mocked secret will never lead to successful authentication.
info!("authentication info not found, mocking it");
AuthInfo::Scram(scram::ServerSecret::mock(creds.user, rand::random()))
});
Ok(info)
}

View File

@@ -2,7 +2,7 @@
mod tests;
use crate::{
auth::{self, backend::AuthSuccess},
auth::{self, backend::AuthSuccess, credentials},
cancellation::{self, CancelMap},
compute::{self, PostgresConnection},
config::{ProxyConfig, TlsConfig},
@@ -112,7 +112,6 @@ pub async fn handle_ws_client(
}
let tls = config.tls_config.as_ref();
let hostname = hostname.as_deref();
// TLS is None here, because the connection is already encrypted.
let do_handshake = handshake(stream, None, cancel_map);
@@ -121,13 +120,17 @@ pub async fn handle_ws_client(
None => return Ok(()), // it's a cancellation request
};
let sni = credentials::SniParams {
sni: hostname.as_deref(),
common_name: tls.and_then(|tls| tls.common_name.as_deref()),
};
// Extract credentials which we're going to use for auth.
let creds = {
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ClientCredentials::parse(&params, hostname, common_name))
.map(|_| auth::ClientCredentials::parse(&params, &sni))
.transpose();
async { result }.or_else(|e| stream.throw_error(e)).await?
@@ -159,14 +162,17 @@ async fn handle_client(
None => return Ok(()), // it's a cancellation request
};
let sni = credentials::SniParams {
sni: stream.get_ref().sni_hostname(),
common_name: tls.and_then(|tls| tls.common_name.as_deref()),
};
// Extract credentials which we're going to use for auth.
let creds = {
let sni = stream.get_ref().sni_hostname();
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ClientCredentials::parse(&params, sni, common_name))
.map(|_| auth::ClientCredentials::parse(&params, &sni))
.transpose();
async { result }.or_else(|e| stream.throw_error(e)).await?

View File

@@ -92,10 +92,10 @@ impl TestAuth for NoAuth {}
struct Scram(scram::ServerSecret);
impl Scram {
fn new(password: &str) -> anyhow::Result<Self> {
fn new(password: &[u8]) -> anyhow::Result<Self> {
let salt = rand::random::<[u8; 16]>();
let secret = scram::ServerSecret::build(password, &salt, 256)
.context("failed to generate scram secret")?;
let secret = scram::ServerSecret::build(password, &salt, 256);
Ok(Scram(secret))
}
@@ -230,11 +230,11 @@ async fn keepalive_is_inherited() -> anyhow::Result<()> {
}
#[rstest]
#[case("password_foo")]
#[case("pwd-bar")]
#[case("")]
#[case(b"password_foo")]
#[case(b"pwd-bar")]
#[case(b"")]
#[tokio::test]
async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
async fn scram_auth_good(#[case] password: &[u8]) -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (client_config, server_config) =

View File

@@ -12,7 +12,6 @@ mod messages;
mod secret;
mod signature;
#[cfg(test)]
mod password;
pub use exchange::Exchange;

View File

@@ -73,7 +73,7 @@ impl sasl::Mechanism for Exchange<'_> {
let server_first_message = client_first_message.build_server_first_message(
&(self.nonce)(),
&self.secret.salt_base64,
&self.secret.salt,
self.secret.iterations,
);
let msg = server_first_message.as_str().to_owned();

View File

@@ -75,19 +75,27 @@ impl<'a> ClientFirstMessage<'a> {
pub fn build_server_first_message(
&self,
nonce: &[u8; SCRAM_RAW_NONCE_LEN],
salt_base64: &str,
salt: &[u8],
iterations: u32,
) -> OwnedServerFirstMessage {
use std::fmt::Write;
let mut message = String::new();
// Write base64-encoded combined nonce.
write!(&mut message, "r={}", self.nonce).unwrap();
base64::encode_config_buf(nonce, base64::STANDARD, &mut message);
let combined_nonce = 2..message.len();
write!(&mut message, ",s={},i={}", salt_base64, iterations).unwrap();
// Write base64-encoded salt.
write!(&mut message, ",s=").unwrap();
base64::encode_config_buf(salt, base64::STANDARD, &mut message);
// Write number of iterations.
write!(&mut message, ",i={iterations}").unwrap();
// This design guarantees that it's impossible to create a
// server-first-message without receiving a client-first-message
// server-first-message without receiving a client-first-message.
OwnedServerFirstMessage {
message,
nonce: combined_nonce,
@@ -229,4 +237,49 @@ mod tests {
"SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
);
}
#[test]
fn build_server_messages() {
let input = "n,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju";
let client_first_message = ClientFirstMessage::parse(input).unwrap();
let nonce = [0; 18];
let salt = [1, 2, 3];
let iterations = 4096;
let server_first_message =
client_first_message.build_server_first_message(&nonce, &salt, iterations);
assert_eq!(
server_first_message.message,
"r=t8JwklwKecDLwSsA72rHmVjuAAAAAAAAAAAAAAAAAAAAAAAA,s=AQID,i=4096"
);
assert_eq!(
server_first_message.nonce(),
"t8JwklwKecDLwSsA72rHmVjuAAAAAAAAAAAAAAAAAAAAAAAA"
);
let input = [
"c=eSws",
"r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU",
"p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=",
]
.join(",");
let client_final_message = ClientFinalMessage::parse(&input).unwrap();
let signature_builder = SignatureBuilder {
client_first_message_bare: client_first_message.bare,
server_first_message: server_first_message.as_str(),
client_final_message_without_proof: client_final_message.without_proof,
};
let server_key = ScramKey::default();
let server_final_message =
client_final_message.build_server_final_message(signature_builder, &server_key);
assert_eq!(
server_final_message,
"v=XEL4X1vy5LnqIgOo4hOjm7zd1Ceyo9+nBUE+/zVnqLE="
);
}
}

View File

@@ -1,6 +1,7 @@
//! Password hashing routines.
use super::key::ScramKey;
use tracing::warn;
pub const SALTED_PASSWORD_LEN: usize = 32;
@@ -13,7 +14,12 @@ pub struct SaltedPassword {
impl SaltedPassword {
/// See `scram-common.c : scram_SaltedPassword` for details.
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2898> (see `PBKDF2`).
/// TODO: implement proper password normalization required by the RFC!
pub fn new(password: &[u8], salt: &[u8], iterations: u32) -> SaltedPassword {
if !password.is_ascii() {
warn!("found non-ascii symbols in password! salted password might be broken");
}
let one = 1_u32.to_be_bytes(); // magic
let mut current = super::hmac_sha256(password, [salt, &one]);
@@ -30,6 +36,7 @@ impl SaltedPassword {
}
/// Derive `ClientKey` from a salted hashed password.
#[cfg(test)]
pub fn client_key(&self) -> ScramKey {
super::hmac_sha256(&self.bytes, [b"Client Key".as_ref()]).into()
}

View File

@@ -1,15 +1,14 @@
//! Tools for SCRAM server secret management.
use super::base64_decode_array;
use super::key::ScramKey;
use super::{base64_decode_array, key::ScramKey, password::SaltedPassword};
/// Server secret is produced from [password](super::password::SaltedPassword)
/// Server secret is produced from [password](SaltedPassword)
/// and is used throughout the authentication process.
pub struct ServerSecret {
/// Number of iterations for `PBKDF2` function.
pub iterations: u32,
/// Salt used to hash user's password.
pub salt_base64: String,
pub salt: Vec<u8>,
/// Hashed `ClientKey`.
pub stored_key: ScramKey,
/// Used by client to verify server's signature.
@@ -30,7 +29,7 @@ impl ServerSecret {
let secret = ServerSecret {
iterations: iterations.parse().ok()?,
salt_base64: salt.to_owned(),
salt: base64::decode(salt).ok()?,
stored_key: base64_decode_array(stored_key)?.into(),
server_key: base64_decode_array(server_key)?.into(),
doomed: false,
@@ -48,31 +47,31 @@ impl ServerSecret {
Self {
iterations: 4096,
salt_base64: base64::encode(mocked_salt),
salt: mocked_salt.into(),
stored_key: ScramKey::default(),
server_key: ScramKey::default(),
doomed: true,
}
}
/// Check if this secret was derived from the given password.
pub fn matches_password(&self, password: &[u8]) -> bool {
let password = SaltedPassword::new(password, &self.salt, self.iterations);
self.server_key == password.server_key()
}
/// Build a new server secret from the prerequisites.
/// XXX: We only use this function in tests.
#[cfg(test)]
pub fn build(password: &str, salt: &[u8], iterations: u32) -> Option<Self> {
// TODO: implement proper password normalization required by the RFC
if !password.is_ascii() {
return None;
}
pub fn build(password: &[u8], salt: &[u8], iterations: u32) -> Self {
let password = SaltedPassword::new(password, salt, iterations);
let password = super::password::SaltedPassword::new(password.as_bytes(), salt, iterations);
Some(Self {
Self {
iterations,
salt_base64: base64::encode(salt),
salt: salt.into(),
stored_key: password.client_key().sha256(),
server_key: password.server_key(),
doomed: false,
})
}
}
}
@@ -87,17 +86,11 @@ mod tests {
let stored_key = "D5h6KTMBlUvDJk2Y8ELfC1Sjtc6k9YHjRyuRZyBNJns=";
let server_key = "Pi3QHbcluX//NDfVkKlFl88GGzlJ5LkyPwcdlN/QBvI=";
let secret = format!(
"SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}",
iterations = iterations,
salt = salt,
stored_key = stored_key,
server_key = server_key,
);
let secret = format!("SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}");
let parsed = ServerSecret::parse(&secret).unwrap();
assert_eq!(parsed.iterations, iterations);
assert_eq!(parsed.salt_base64, salt);
assert_eq!(base64::encode(parsed.salt), salt);
assert_eq!(base64::encode(parsed.stored_key), stored_key);
assert_eq!(base64::encode(parsed.server_key), server_key);
@@ -106,9 +99,9 @@ mod tests {
#[test]
fn build_scram_secret() {
let salt = b"salt";
let secret = ServerSecret::build("password", salt, 4096).unwrap();
let secret = ServerSecret::build(b"password", salt, 4096);
assert_eq!(secret.iterations, 4096);
assert_eq!(secret.salt_base64, base64::encode(salt));
assert_eq!(secret.salt, salt);
assert_eq!(
base64::encode(secret.stored_key.as_ref()),
"lF4cRm/Jky763CN4HtxdHnjV4Q8AWTNlKvGmEFFU8IQ="
@@ -118,4 +111,12 @@ mod tests {
"ub8OgRsftnk2ccDMOt7ffHXNcikRkQkq1lh4xaAqrSw="
);
}
#[test]
fn secret_match_password() {
let password = b"password";
let secret = ServerSecret::build(password, b"salt", 2);
assert!(secret.matches_password(password));
assert!(!secret.matches_password(b"different"));
}
}

View File

@@ -0,0 +1,34 @@
from fixtures.neon_fixtures import NeonEnv, fork_at_current_lsn
#
# Test UNLOGGED tables/relations. Postgres copies init fork contents to main
# fork to reset them during recovery. In Neon, pageserver directly sends init
# fork contents as main fork during basebackup.
#
def test_unlogged(neon_simple_env: NeonEnv):
env = neon_simple_env
env.neon_cli.create_branch("test_unlogged", "empty")
pg = env.postgres.create_start("test_unlogged")
conn = pg.connect()
cur = conn.cursor()
cur.execute("CREATE UNLOGGED TABLE iut (id int);")
# create index to test unlogged index relation as well
cur.execute("CREATE UNIQUE INDEX iut_idx ON iut (id);")
cur.execute("INSERT INTO iut values (42);")
# create another compute to fetch inital empty contents from pageserver
fork_at_current_lsn(env, pg, "test_unlogged_basebackup", "test_unlogged")
pg2 = env.postgres.create_start(
"test_unlogged_basebackup",
)
conn2 = pg2.connect()
cur2 = conn2.cursor()
# after restart table should be empty but valid
cur2.execute("PREPARE iut_plan (int) AS INSERT INTO iut VALUES ($1)")
cur2.execute("EXECUTE iut_plan (43);")
cur2.execute("SELECT * FROM iut")
assert cur2.fetchall() == [(43,)]