mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-23 16:10:37 +00:00
Added invariant check for project name. (#1921)
Summary: Added invariant checking for project name. Refactored ClientCredentials and TlsConfig.
* Added formatting invariant check for project name:
**\forall c \in project_name . c \in [alnum] U {'-'}.
** sni_data == <project_name>.<common_name>
* Added exhaustive tests for get_project_name.
* Refactored TlsConfig to contain common_name : Option<String>.
* Refactored ClientCredentials construction to construct project_name directly.
* Merged ProjectNameError into ClientCredsParseError.
* Tweaked proxy tests to accommodate refactored ClientCredentials construction semantics.
* [Pytests] Added project option argument to test_proxy_select_1.
* Removed project param from Api since now it's contained in creds.
* Refactored &Option<String> -> Option<&str>.
Co-authored-by: Dmitrii Ivanov <dima@neon.tech>.
This commit is contained in:
119
Cargo.lock
generated
119
Cargo.lock
generated
@@ -64,6 +64,45 @@ dependencies = [
|
||||
"nodrop",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "asn1-rs"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "30ff05a702273012438132f449575dbc804e27b2f3cbe3069aa237d26c98fa33"
|
||||
dependencies = [
|
||||
"asn1-rs-derive",
|
||||
"asn1-rs-impl",
|
||||
"displaydoc",
|
||||
"nom",
|
||||
"num-traits",
|
||||
"rusticata-macros",
|
||||
"thiserror",
|
||||
"time 0.3.9",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "asn1-rs-derive"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "db8b7511298d5b7784b40b092d9e9dcd3a627a5707e4b5e507931ab0d44eeebf"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"synstructure",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "asn1-rs-impl"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2777730b2039ac0f95f093556e61b6d26cebed5393ca6f152717777cec3a42ed"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-stream"
|
||||
version = "0.3.3"
|
||||
@@ -712,6 +751,12 @@ dependencies = [
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "data-encoding"
|
||||
version = "2.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3ee2393c4a91429dffb4bedf19f4d6abf27d8a732c8ce4980305d782e5426d57"
|
||||
|
||||
[[package]]
|
||||
name = "debugid"
|
||||
version = "0.7.3"
|
||||
@@ -721,6 +766,20 @@ dependencies = [
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "der-parser"
|
||||
version = "7.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fe398ac75057914d7d07307bf67dc7f3f574a26783b4fc7805a20ffa9f506e82"
|
||||
dependencies = [
|
||||
"asn1-rs",
|
||||
"displaydoc",
|
||||
"nom",
|
||||
"num-bigint",
|
||||
"num-traits",
|
||||
"rusticata-macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.9.0"
|
||||
@@ -762,6 +821,17 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "displaydoc"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3bf95dc3f046b9da4f2d51833c0d3547d8564ef6910f5c1ed130306a75b92886"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.6.1"
|
||||
@@ -1731,6 +1801,15 @@ dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "oid-registry"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38e20717fa0541f39bd146692035c37bedfa532b3e5071b35761082407546b2a"
|
||||
dependencies = [
|
||||
"asn1-rs",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.10.0"
|
||||
@@ -2250,6 +2329,7 @@ dependencies = [
|
||||
"url",
|
||||
"utils",
|
||||
"workspace_hack",
|
||||
"x509-parser",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2621,6 +2701,15 @@ dependencies = [
|
||||
"semver",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rusticata-macros"
|
||||
version = "4.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632"
|
||||
dependencies = [
|
||||
"nom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.20.4"
|
||||
@@ -3060,6 +3149,18 @@ version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "20518fe4a4c9acf048008599e464deb21beeae3d3578418951a189c235a7a9a8"
|
||||
|
||||
[[package]]
|
||||
name = "synstructure"
|
||||
version = "0.12.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f36bdaa60a83aca3921b5259d5400cbf5e90fc51931376a9bd4a0eb79aa7210f"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tar"
|
||||
version = "0.4.38"
|
||||
@@ -3922,6 +4023,24 @@ dependencies = [
|
||||
"tracing-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "x509-parser"
|
||||
version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9fb9bace5b5589ffead1afb76e43e34cff39cd0f3ce7e170ae0c29e53b88eb1c"
|
||||
dependencies = [
|
||||
"asn1-rs",
|
||||
"base64",
|
||||
"data-encoding",
|
||||
"der-parser",
|
||||
"lazy_static",
|
||||
"nom",
|
||||
"oid-registry",
|
||||
"rusticata-macros",
|
||||
"thiserror",
|
||||
"time 0.3.9",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "xattr"
|
||||
version = "0.2.2"
|
||||
|
||||
@@ -39,6 +39,8 @@ utils = { path = "../libs/utils" }
|
||||
metrics = { path = "../libs/metrics" }
|
||||
workspace_hack = { version = "0.1", path = "../workspace_hack" }
|
||||
|
||||
x509-parser = "0.13.2"
|
||||
|
||||
[dev-dependencies]
|
||||
rcgen = "0.8.14"
|
||||
rstest = "0.12"
|
||||
|
||||
@@ -19,7 +19,7 @@ pub type Result<T> = std::result::Result<T, ConsoleAuthError>;
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ConsoleAuthError {
|
||||
#[error(transparent)]
|
||||
BadProjectName(#[from] auth::credentials::ProjectNameError),
|
||||
BadProjectName(#[from] auth::credentials::ClientCredsParseError),
|
||||
|
||||
// We shouldn't include the actual secret here.
|
||||
#[error("Bad authentication secret")]
|
||||
@@ -74,18 +74,12 @@ pub enum AuthInfo {
|
||||
pub(super) struct Api<'a> {
|
||||
endpoint: &'a ApiUrl,
|
||||
creds: &'a ClientCredentials,
|
||||
/// Cache project name, since we'll need it several times.
|
||||
project: &'a str,
|
||||
}
|
||||
|
||||
impl<'a> Api<'a> {
|
||||
/// Construct an API object containing the auth parameters.
|
||||
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Result<Self> {
|
||||
Ok(Self {
|
||||
endpoint,
|
||||
creds,
|
||||
project: creds.project_name()?,
|
||||
})
|
||||
Ok(Self { endpoint, creds })
|
||||
}
|
||||
|
||||
/// Authenticate the existing user or throw an error.
|
||||
@@ -100,7 +94,7 @@ impl<'a> Api<'a> {
|
||||
let mut url = self.endpoint.clone();
|
||||
url.path_segments_mut().push("proxy_get_role_secret");
|
||||
url.query_pairs_mut()
|
||||
.append_pair("project", self.project)
|
||||
.append_pair("project", &self.creds.project_name)
|
||||
.append_pair("role", &self.creds.user);
|
||||
|
||||
// TODO: use a proper logger
|
||||
@@ -123,7 +117,8 @@ impl<'a> Api<'a> {
|
||||
async fn wake_compute(&self) -> Result<DatabaseInfo> {
|
||||
let mut url = self.endpoint.clone();
|
||||
url.path_segments_mut().push("proxy_wake_compute");
|
||||
url.query_pairs_mut().append_pair("project", self.project);
|
||||
url.query_pairs_mut()
|
||||
.append_pair("project", &self.creds.project_name);
|
||||
|
||||
// TODO: use a proper logger
|
||||
println!("cplane request: {url}");
|
||||
|
||||
@@ -8,10 +8,32 @@ use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[derive(Debug, Error, PartialEq)]
|
||||
pub enum ClientCredsParseError {
|
||||
#[error("Parameter `{0}` is missing in startup packet")]
|
||||
#[error("Parameter `{0}` is missing in startup packet.")]
|
||||
MissingKey(&'static str),
|
||||
|
||||
#[error(
|
||||
"Project name is not specified. \
|
||||
EITHER please upgrade the postgres client library (libpq) for SNI support \
|
||||
OR pass the project name as a parameter: '&options=project%3D<project-name>'."
|
||||
)]
|
||||
MissingSNIAndProjectName,
|
||||
|
||||
#[error("Inconsistent project name inferred from SNI ('{0}') and project option ('{1}').")]
|
||||
InconsistentProjectNameAndSNI(String, String),
|
||||
|
||||
#[error("Common name is not set.")]
|
||||
CommonNameNotSet,
|
||||
|
||||
#[error(
|
||||
"SNI ('{1}') inconsistently formatted with respect to common name ('{0}'). \
|
||||
SNI should be formatted as '<project-name>.<common-name>'."
|
||||
)]
|
||||
InconsistentCommonNameAndSNI(String, String),
|
||||
|
||||
#[error("Project name ('{0}') must contain only alphanumeric characters and hyphens ('-').")]
|
||||
ProjectNameContainsIllegalChars(String),
|
||||
}
|
||||
|
||||
impl UserFacingError for ClientCredsParseError {}
|
||||
@@ -22,15 +44,7 @@ impl UserFacingError for ClientCredsParseError {}
|
||||
pub struct ClientCredentials {
|
||||
pub user: String,
|
||||
pub dbname: String,
|
||||
|
||||
// New console API requires SNI info to determine the cluster name.
|
||||
// Other Auth backends don't need it.
|
||||
pub sni_data: Option<String>,
|
||||
|
||||
// project_name is passed as argument from options from url.
|
||||
// In case sni_data is missing: project_name is used to determine cluster name.
|
||||
// In case sni_data is available: project_name and sni_data should match (otherwise throws an error).
|
||||
pub project_name: Option<String>,
|
||||
pub project_name: String,
|
||||
}
|
||||
|
||||
impl ClientCredentials {
|
||||
@@ -38,60 +52,14 @@ impl ClientCredentials {
|
||||
// This logic will likely change in the future.
|
||||
self.user.ends_with("@zenith")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ProjectNameError {
|
||||
#[error("SNI is missing. EITHER please upgrade the postgres client library OR pass the project name as a parameter: '...&options=project%3D<project-name>...'.")]
|
||||
Missing,
|
||||
|
||||
#[error("SNI is malformed.")]
|
||||
Bad,
|
||||
|
||||
#[error("Inconsistent project name inferred from SNI and project option. String from SNI: '{0}', String from project option: '{1}'")]
|
||||
Inconsistent(String, String),
|
||||
}
|
||||
|
||||
impl UserFacingError for ProjectNameError {}
|
||||
|
||||
impl ClientCredentials {
|
||||
/// Determine project name from SNI or from project_name parameter from options argument.
|
||||
pub fn project_name(&self) -> Result<&str, ProjectNameError> {
|
||||
// Checking that if both sni_data and project_name are set, then they should match
|
||||
// otherwise, throws a ProjectNameError::Inconsistent error.
|
||||
if let Some(sni_data) = &self.sni_data {
|
||||
let project_name_from_sni_data =
|
||||
sni_data.split_once('.').ok_or(ProjectNameError::Bad)?.0;
|
||||
if let Some(project_name_from_options) = &self.project_name {
|
||||
if !project_name_from_options.eq(project_name_from_sni_data) {
|
||||
return Err(ProjectNameError::Inconsistent(
|
||||
project_name_from_sni_data.to_string(),
|
||||
project_name_from_options.to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
// determine the project name from self.sni_data if it exists, otherwise from self.project_name.
|
||||
let ret = match &self.sni_data {
|
||||
// if sni_data exists, use it to determine project name
|
||||
Some(sni_data) => sni_data.split_once('.').ok_or(ProjectNameError::Bad)?.0,
|
||||
// otherwise use project_option if it was manually set thought options parameter.
|
||||
None => self
|
||||
.project_name
|
||||
.as_ref()
|
||||
.ok_or(ProjectNameError::Missing)?
|
||||
.as_str(),
|
||||
};
|
||||
Ok(ret)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<HashMap<String, String>> for ClientCredentials {
|
||||
type Error = ClientCredsParseError;
|
||||
|
||||
fn try_from(mut value: HashMap<String, String>) -> Result<Self, Self::Error> {
|
||||
pub fn parse(
|
||||
mut options: HashMap<String, String>,
|
||||
sni_data: Option<&str>,
|
||||
common_name: Option<&str>,
|
||||
) -> Result<Self, ClientCredsParseError> {
|
||||
let mut get_param = |key| {
|
||||
value
|
||||
options
|
||||
.remove(key)
|
||||
.ok_or(ClientCredsParseError::MissingKey(key))
|
||||
};
|
||||
@@ -99,17 +67,15 @@ impl TryFrom<HashMap<String, String>> for ClientCredentials {
|
||||
let user = get_param("user")?;
|
||||
let dbname = get_param("database")?;
|
||||
let project_name = get_param("project").ok();
|
||||
let project_name = get_project_name(sni_data, common_name, project_name.as_deref())?;
|
||||
|
||||
Ok(Self {
|
||||
user,
|
||||
dbname,
|
||||
sni_data: None,
|
||||
project_name,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientCredentials {
|
||||
/// Use credentials to authenticate the user.
|
||||
pub async fn authenticate(
|
||||
self,
|
||||
@@ -120,3 +86,244 @@ impl ClientCredentials {
|
||||
super::backend::handle_user(config, client, self).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Inferring project name from sni_data.
|
||||
fn project_name_from_sni_data(
|
||||
sni_data: &str,
|
||||
common_name: &str,
|
||||
) -> Result<String, ClientCredsParseError> {
|
||||
let common_name_with_dot = format!(".{common_name}");
|
||||
// check that ".{common_name_with_dot}" is the actual suffix in sni_data
|
||||
if !sni_data.ends_with(&common_name_with_dot) {
|
||||
return Err(ClientCredsParseError::InconsistentCommonNameAndSNI(
|
||||
common_name.to_string(),
|
||||
sni_data.to_string(),
|
||||
));
|
||||
}
|
||||
// return sni_data without the common name suffix.
|
||||
Ok(sni_data
|
||||
.strip_suffix(&common_name_with_dot)
|
||||
.unwrap()
|
||||
.to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests_for_project_name_from_sni_data {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn passing() {
|
||||
let target_project_name = "my-project-123";
|
||||
let common_name = "localtest.me";
|
||||
let sni_data = format!("{target_project_name}.{common_name}");
|
||||
assert_eq!(
|
||||
project_name_from_sni_data(&sni_data, common_name),
|
||||
Ok(target_project_name.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throws_inconsistent_common_name_and_sni_data() {
|
||||
let target_project_name = "my-project-123";
|
||||
let common_name = "localtest.me";
|
||||
let wrong_suffix = "wrongtest.me";
|
||||
assert_eq!(common_name.len(), wrong_suffix.len());
|
||||
let wrong_common_name = format!("wrong{wrong_suffix}");
|
||||
let sni_data = format!("{target_project_name}.{wrong_common_name}");
|
||||
assert_eq!(
|
||||
project_name_from_sni_data(&sni_data, common_name),
|
||||
Err(ClientCredsParseError::InconsistentCommonNameAndSNI(
|
||||
common_name.to_string(),
|
||||
sni_data
|
||||
))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine project name from SNI or from project_name parameter from options argument.
|
||||
fn get_project_name(
|
||||
sni_data: Option<&str>,
|
||||
common_name: Option<&str>,
|
||||
project_name: Option<&str>,
|
||||
) -> Result<String, ClientCredsParseError> {
|
||||
// determine the project name from sni_data if it exists, otherwise from project_name.
|
||||
let ret = match sni_data {
|
||||
Some(sni_data) => {
|
||||
let common_name = common_name.ok_or(ClientCredsParseError::CommonNameNotSet)?;
|
||||
let project_name_from_sni = project_name_from_sni_data(sni_data, common_name)?;
|
||||
// check invariant: project name from options and from sni should match
|
||||
if let Some(project_name) = &project_name {
|
||||
if !project_name_from_sni.eq(project_name) {
|
||||
return Err(ClientCredsParseError::InconsistentProjectNameAndSNI(
|
||||
project_name_from_sni,
|
||||
project_name.to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
project_name_from_sni
|
||||
}
|
||||
None => project_name
|
||||
.ok_or(ClientCredsParseError::MissingSNIAndProjectName)?
|
||||
.to_string(),
|
||||
};
|
||||
|
||||
// check formatting invariant: project name must contain only alphanumeric characters and hyphens.
|
||||
if !ret.chars().all(|x: char| x.is_alphanumeric() || x == '-') {
|
||||
return Err(ClientCredsParseError::ProjectNameContainsIllegalChars(ret));
|
||||
}
|
||||
|
||||
Ok(ret)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests_for_project_name_only {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn passing_from_sni_data_only() {
|
||||
let target_project_name = "my-project-123";
|
||||
let common_name = "localtest.me";
|
||||
let sni_data = format!("{target_project_name}.{common_name}");
|
||||
assert_eq!(
|
||||
get_project_name(Some(&sni_data), Some(common_name), None),
|
||||
Ok(target_project_name.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throws_project_name_contains_illegal_chars_from_sni_data_only() {
|
||||
let project_name_prefix = "my-project";
|
||||
let project_name_suffix = "123";
|
||||
let common_name = "localtest.me";
|
||||
|
||||
for illegal_char_id in 0..256 {
|
||||
let illegal_char = char::from_u32(illegal_char_id).unwrap();
|
||||
if !(illegal_char.is_alphanumeric() || illegal_char == '-')
|
||||
&& illegal_char.to_string().len() == 1
|
||||
{
|
||||
let target_project_name =
|
||||
format!("{project_name_prefix}{illegal_char}{project_name_suffix}");
|
||||
let sni_data = format!("{target_project_name}.{common_name}");
|
||||
assert_eq!(
|
||||
get_project_name(Some(&sni_data), Some(common_name), None),
|
||||
Err(ClientCredsParseError::ProjectNameContainsIllegalChars(
|
||||
target_project_name
|
||||
))
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn passing_from_project_name_only() {
|
||||
let target_project_name = "my-project-123";
|
||||
let common_names = [Some("localtest.me"), None];
|
||||
for common_name in common_names {
|
||||
assert_eq!(
|
||||
get_project_name(None, common_name, Some(target_project_name)),
|
||||
Ok(target_project_name.to_string())
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throws_project_name_contains_illegal_chars_from_project_name_only() {
|
||||
let project_name_prefix = "my-project";
|
||||
let project_name_suffix = "123";
|
||||
let common_names = [Some("localtest.me"), None];
|
||||
|
||||
for common_name in common_names {
|
||||
for illegal_char_id in 0..256 {
|
||||
let illegal_char: char = char::from_u32(illegal_char_id).unwrap();
|
||||
if !(illegal_char.is_alphanumeric() || illegal_char == '-')
|
||||
&& illegal_char.to_string().len() == 1
|
||||
{
|
||||
let target_project_name =
|
||||
format!("{project_name_prefix}{illegal_char}{project_name_suffix}");
|
||||
assert_eq!(
|
||||
get_project_name(None, common_name, Some(&target_project_name)),
|
||||
Err(ClientCredsParseError::ProjectNameContainsIllegalChars(
|
||||
target_project_name
|
||||
))
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn passing_from_sni_data_and_project_name() {
|
||||
let target_project_name = "my-project-123";
|
||||
let common_name = "localtest.me";
|
||||
let sni_data = format!("{target_project_name}.{common_name}");
|
||||
assert_eq!(
|
||||
get_project_name(
|
||||
Some(&sni_data),
|
||||
Some(common_name),
|
||||
Some(target_project_name)
|
||||
),
|
||||
Ok(target_project_name.to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throws_inconsistent_project_name_and_sni() {
|
||||
let project_name_param = "my-project-123";
|
||||
let wrong_project_name = "not-my-project-123";
|
||||
let common_name = "localtest.me";
|
||||
let sni_data = format!("{wrong_project_name}.{common_name}");
|
||||
assert_eq!(
|
||||
get_project_name(Some(&sni_data), Some(common_name), Some(project_name_param)),
|
||||
Err(ClientCredsParseError::InconsistentProjectNameAndSNI(
|
||||
wrong_project_name.to_string(),
|
||||
project_name_param.to_string()
|
||||
))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throws_common_name_not_set() {
|
||||
let target_project_name = "my-project-123";
|
||||
let wrong_project_name = "not-my-project-123";
|
||||
let common_name = "localtest.me";
|
||||
let sni_datas = [
|
||||
Some(format!("{wrong_project_name}.{common_name}")),
|
||||
Some(format!("{target_project_name}.{common_name}")),
|
||||
];
|
||||
let project_names = [None, Some(target_project_name)];
|
||||
for sni_data in sni_datas {
|
||||
for project_name_param in project_names {
|
||||
assert_eq!(
|
||||
get_project_name(sni_data.as_deref(), None, project_name_param),
|
||||
Err(ClientCredsParseError::CommonNameNotSet)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn throws_inconsistent_common_name_and_sni_data() {
|
||||
let target_project_name = "my-project-123";
|
||||
let wrong_project_name = "not-my-project-123";
|
||||
let common_name = "localtest.me";
|
||||
let wrong_suffix = "wrongtest.me";
|
||||
assert_eq!(common_name.len(), wrong_suffix.len());
|
||||
let wrong_common_name = format!("wrong{wrong_suffix}");
|
||||
let sni_datas = [
|
||||
Some(format!("{wrong_project_name}.{wrong_common_name}")),
|
||||
Some(format!("{target_project_name}.{wrong_common_name}")),
|
||||
];
|
||||
let project_names = [None, Some(target_project_name)];
|
||||
for project_name_param in project_names {
|
||||
for sni_data in &sni_datas {
|
||||
assert_eq!(
|
||||
get_project_name(sni_data.as_deref(), Some(common_name), project_name_param),
|
||||
Err(ClientCredsParseError::InconsistentCommonNameAndSNI(
|
||||
common_name.to_string(),
|
||||
sni_data.clone().unwrap().to_string()
|
||||
))
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,23 +36,35 @@ pub struct ProxyConfig {
|
||||
pub auth_link_uri: ApiUrl,
|
||||
}
|
||||
|
||||
pub type TlsConfig = Arc<rustls::ServerConfig>;
|
||||
pub struct TlsConfig {
|
||||
pub config: Arc<rustls::ServerConfig>,
|
||||
pub common_name: Option<String>,
|
||||
}
|
||||
|
||||
impl TlsConfig {
|
||||
pub fn to_server_config(&self) -> Arc<rustls::ServerConfig> {
|
||||
self.config.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Configure TLS for the main endpoint.
|
||||
pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result<TlsConfig> {
|
||||
let key = {
|
||||
let key_bytes = std::fs::read(key_path).context("TLS key file")?;
|
||||
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..])
|
||||
.context("couldn't read TLS keys")?;
|
||||
.context(format!("Failed to read TLS keys at '{key_path}'"))?;
|
||||
|
||||
ensure!(keys.len() == 1, "keys.len() = {} (should be 1)", keys.len());
|
||||
keys.pop().map(rustls::PrivateKey).unwrap()
|
||||
};
|
||||
|
||||
let cert_chain_bytes = std::fs::read(cert_path)
|
||||
.context(format!("Failed to read TLS cert file at '{cert_path}.'"))?;
|
||||
let cert_chain = {
|
||||
let cert_chain_bytes = std::fs::read(cert_path).context("TLS cert file")?;
|
||||
rustls_pemfile::certs(&mut &cert_chain_bytes[..])
|
||||
.context("couldn't read TLS certificate chain")?
|
||||
.context(format!(
|
||||
"Failed to read TLS certificate chain from bytes from file at '{cert_path}'."
|
||||
))?
|
||||
.into_iter()
|
||||
.map(rustls::Certificate)
|
||||
.collect()
|
||||
@@ -64,7 +76,25 @@ pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result<TlsConfi
|
||||
// allow TLS 1.2 to be compatible with older client libraries
|
||||
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])?
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(cert_chain, key)?;
|
||||
.with_single_cert(cert_chain, key)?
|
||||
.into();
|
||||
|
||||
Ok(config.into())
|
||||
// determine common name from tls-cert (-c server.crt param).
|
||||
// used in asserting project name formatting invariant.
|
||||
let common_name = {
|
||||
let pem = x509_parser::pem::parse_x509_pem(&cert_chain_bytes)
|
||||
.context(format!(
|
||||
"Failed to parse PEM object from bytes from file at '{cert_path}'."
|
||||
))?
|
||||
.1;
|
||||
let almost_common_name = pem.parse_x509()?.tbs_certificate.subject.to_string();
|
||||
let expected_prefix = "CN=*.";
|
||||
let common_name = almost_common_name.strip_prefix(expected_prefix);
|
||||
common_name.map(str::to_string)
|
||||
};
|
||||
|
||||
Ok(TlsConfig {
|
||||
config,
|
||||
common_name,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -81,7 +81,7 @@ async fn handle_client(
|
||||
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
|
||||
}
|
||||
|
||||
let tls = config.tls_config.clone();
|
||||
let tls = config.tls_config.as_ref();
|
||||
let (stream, creds) = match handshake(stream, tls, cancel_map).await? {
|
||||
Some(x) => x,
|
||||
None => return Ok(()), // it's a cancellation request
|
||||
@@ -99,12 +99,14 @@ async fn handle_client(
|
||||
/// we also take an extra care of propagating only the select handshake errors to client.
|
||||
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
stream: S,
|
||||
mut tls: Option<TlsConfig>,
|
||||
mut tls: Option<&TlsConfig>,
|
||||
cancel_map: &CancelMap,
|
||||
) -> anyhow::Result<Option<(PqStream<Stream<S>>, auth::ClientCredentials)>> {
|
||||
// Client may try upgrading to each protocol only once
|
||||
let (mut tried_ssl, mut tried_gss) = (false, false);
|
||||
|
||||
let common_name = tls.and_then(|cfg| cfg.common_name.as_deref());
|
||||
|
||||
let mut stream = PqStream::new(Stream::from_raw(stream));
|
||||
loop {
|
||||
let msg = stream.read_startup_packet().await?;
|
||||
@@ -122,7 +124,9 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
if let Some(tls) = tls.take() {
|
||||
// Upgrade raw stream into a secure TLS-backed stream.
|
||||
// NOTE: We've consumed `tls`; this fact will be used later.
|
||||
stream = PqStream::new(stream.into_inner().upgrade(tls).await?);
|
||||
stream = PqStream::new(
|
||||
stream.into_inner().upgrade(tls.to_server_config()).await?,
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => bail!(ERR_PROTO_VIOLATION),
|
||||
@@ -143,15 +147,16 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
|
||||
}
|
||||
|
||||
// Here and forth: `or_else` demands that we use a future here
|
||||
let mut creds: auth::ClientCredentials = async { params.try_into() }
|
||||
.or_else(|e| stream.throw_error(e))
|
||||
.await?;
|
||||
// Get SNI info when available
|
||||
let sni_data = match stream.get_ref() {
|
||||
Stream::Tls { tls } => tls.get_ref().1.sni_hostname().map(|s| s.to_owned()),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
// Set SNI info when available
|
||||
if let Stream::Tls { tls } = stream.get_ref() {
|
||||
creds.sni_data = tls.get_ref().1.sni_hostname().map(|s| s.to_owned());
|
||||
}
|
||||
// Construct credentials
|
||||
let creds =
|
||||
auth::ClientCredentials::parse(params, sni_data.as_deref(), common_name);
|
||||
let creds = async { creds }.or_else(|e| stream.throw_error(e)).await?;
|
||||
|
||||
break Ok(Some((stream, creds)));
|
||||
}
|
||||
@@ -264,12 +269,13 @@ mod tests {
|
||||
}
|
||||
|
||||
/// Generate TLS certificates and build rustls configs for client and server.
|
||||
fn generate_tls_config(
|
||||
hostname: &str,
|
||||
) -> anyhow::Result<(ClientConfig<'_>, Arc<rustls::ServerConfig>)> {
|
||||
fn generate_tls_config<'a>(
|
||||
hostname: &'a str,
|
||||
common_name: &'a str,
|
||||
) -> anyhow::Result<(ClientConfig<'a>, TlsConfig)> {
|
||||
let (ca, cert, key) = generate_certs(hostname)?;
|
||||
|
||||
let server_config = {
|
||||
let tls_config = {
|
||||
let config = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
@@ -291,7 +297,12 @@ mod tests {
|
||||
ClientConfig { config, hostname }
|
||||
};
|
||||
|
||||
Ok((client_config, server_config))
|
||||
let tls_config = TlsConfig {
|
||||
config: tls_config,
|
||||
common_name: Some(common_name.to_string()),
|
||||
};
|
||||
|
||||
Ok((client_config, tls_config))
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -346,7 +357,7 @@ mod tests {
|
||||
auth: impl TestAuth + Send,
|
||||
) -> anyhow::Result<()> {
|
||||
let cancel_map = CancelMap::default();
|
||||
let (mut stream, _creds) = handshake(client, tls, &cancel_map)
|
||||
let (mut stream, _creds) = handshake(client, tls.as_ref(), &cancel_map)
|
||||
.await?
|
||||
.context("handshake failed")?;
|
||||
|
||||
@@ -365,7 +376,8 @@ mod tests {
|
||||
async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let (_, server_config) = generate_tls_config("localhost")?;
|
||||
let (_, server_config) =
|
||||
generate_tls_config("generic-project-name.localhost", "localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
|
||||
|
||||
let client_err = tokio_postgres::Config::new()
|
||||
@@ -393,7 +405,8 @@ mod tests {
|
||||
async fn handshake_tls() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let (client_config, server_config) = generate_tls_config("localhost")?;
|
||||
let (client_config, server_config) =
|
||||
generate_tls_config("generic-project-name.localhost", "localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), NoAuth));
|
||||
|
||||
let (_client, _conn) = tokio_postgres::Config::new()
|
||||
@@ -415,6 +428,7 @@ mod tests {
|
||||
let (_client, _conn) = tokio_postgres::Config::new()
|
||||
.user("john_doe")
|
||||
.dbname("earth")
|
||||
.options("project=generic-project-name")
|
||||
.ssl_mode(SslMode::Prefer)
|
||||
.connect_raw(server, NoTls)
|
||||
.await?;
|
||||
@@ -476,7 +490,8 @@ mod tests {
|
||||
async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let (client_config, server_config) = generate_tls_config("localhost")?;
|
||||
let (client_config, server_config) =
|
||||
generate_tls_config("generic-project-name.localhost", "localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(
|
||||
client,
|
||||
Some(server_config),
|
||||
@@ -498,7 +513,8 @@ mod tests {
|
||||
async fn scram_auth_mock() -> anyhow::Result<()> {
|
||||
let (client, server) = tokio::io::duplex(1024);
|
||||
|
||||
let (client_config, server_config) = generate_tls_config("localhost")?;
|
||||
let (client_config, server_config) =
|
||||
generate_tls_config("generic-project-name.localhost", "localhost")?;
|
||||
let proxy = tokio::spawn(dummy_proxy(
|
||||
client,
|
||||
Some(server_config),
|
||||
|
||||
@@ -2,7 +2,7 @@ import pytest
|
||||
|
||||
|
||||
def test_proxy_select_1(static_proxy):
|
||||
static_proxy.safe_psql("select 1;")
|
||||
static_proxy.safe_psql("select 1;", options="project=generic-project-name")
|
||||
|
||||
|
||||
# Pass extra options to the server.
|
||||
|
||||
Reference in New Issue
Block a user