[proxy] Add the password hack authentication flow (#2095)

[proxy] Add the `password hack` authentication flow

This lets us authenticate users which can use neither
SNI (due to old libpq) nor connection string `options`
(due to restrictions in other client libraries).

Note: `PasswordHack` will accept passwords which are not
encoded in base64 via the "password" field. The assumption
is that most user passwords will be valid utf-8 strings,
and the rest may still be passed via "password_".
This commit is contained in:
Dmitry Ivanov
2022-07-25 17:23:10 +03:00
committed by GitHub
parent 39c59b8df5
commit 5f4ccae5c5
18 changed files with 750 additions and 550 deletions

View File

@@ -47,10 +47,12 @@ pub enum FeStartupPacket {
StartupMessage {
major_version: u32,
minor_version: u32,
params: HashMap<String, String>,
params: StartupMessageParams,
},
}
pub type StartupMessageParams = HashMap<String, String>;
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
pub struct CancelKeyData {
pub backend_pid: i32,

View File

@@ -1,11 +1,14 @@
//! Client authentication mechanisms.
pub mod backend;
pub use backend::DatabaseInfo;
pub use backend::{BackendType, DatabaseInfo};
mod credentials;
pub use credentials::ClientCredentials;
mod password_hack;
use password_hack::PasswordHackPayload;
mod flow;
pub use flow::*;
@@ -29,9 +32,8 @@ pub enum AuthErrorImpl {
#[error(transparent)]
Sasl(#[from] crate::sasl::Error),
/// For passwords that couldn't be processed by [`backend::legacy_console::parse_password`].
#[error("Malformed password message")]
MalformedPassword,
#[error("Malformed password message: {0}")]
MalformedPassword(&'static str),
/// Errors produced by [`crate::stream::PqStream`].
#[error(transparent)]
@@ -76,7 +78,7 @@ impl UserFacingError for AuthError {
Console(e) => e.to_string_client(),
GetAuthInfo(e) => e.to_string_client(),
Sasl(e) => e.to_string_client(),
MalformedPassword => self.to_string(),
MalformedPassword(_) => self.to_string(),
_ => "Internal error".to_string(),
}
}

View File

@@ -1,16 +1,14 @@
mod legacy_console;
mod link;
mod postgres;
pub mod console;
mod legacy_console;
pub use legacy_console::{AuthError, AuthErrorImpl};
use super::ClientCredentials;
use crate::{
compute,
config::{AuthBackendType, ProxyConfig},
mgmt,
auth::{self, AuthFlow, ClientCredentials},
compute, config, mgmt,
stream::PqStream,
waiters::{self, Waiter, Waiters},
};
@@ -78,32 +76,158 @@ impl From<DatabaseInfo> for tokio_postgres::Config {
}
}
pub(super) async fn handle_user(
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
creds: ClientCredentials,
) -> super::Result<compute::NodeInfo> {
use AuthBackendType::*;
match config.auth_backend {
LegacyConsole => {
legacy_console::handle_user(
&config.auth_endpoint,
&config.auth_link_uri,
client,
&creds,
)
.await
/// This type serves two purposes:
///
/// * When `T` is `()`, it's just a regular auth backend selector
/// which we use in [`crate::config::ProxyConfig`].
///
/// * However, when we substitute `T` with [`ClientCredentials`],
/// this helps us provide the credentials only to those auth
/// backends which require them for the authentication process.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BackendType<T> {
/// Legacy Cloud API (V1) + link auth.
LegacyConsole(T),
/// Current Cloud API (V2).
Console(T),
/// Local mock of Cloud API (V2).
Postgres(T),
/// Authentication via a web browser.
Link,
}
impl<T> BackendType<T> {
/// Very similar to [`std::option::Option::map`].
/// Maps [`BackendType<T>`] to [`BackendType<R>`] by applying
/// a function to a contained value.
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<R> {
use BackendType::*;
match self {
LegacyConsole(x) => LegacyConsole(f(x)),
Console(x) => Console(f(x)),
Postgres(x) => Postgres(f(x)),
Link => Link,
}
}
}
impl<T, E> BackendType<Result<T, E>> {
/// Very similar to [`std::option::Option::transpose`].
/// This is most useful for error handling.
pub fn transpose(self) -> Result<BackendType<T>, E> {
use BackendType::*;
match self {
LegacyConsole(x) => x.map(LegacyConsole),
Console(x) => x.map(Console),
Postgres(x) => x.map(Postgres),
Link => Ok(Link),
}
}
}
impl BackendType<ClientCredentials> {
/// Authenticate the client via the requested backend, possibly using credentials.
pub async fn authenticate(
mut self,
urls: &config::AuthUrls,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> super::Result<compute::NodeInfo> {
use BackendType::*;
if let Console(creds) | Postgres(creds) = &mut self {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the project name.
// We now expect to see a very specific payload in the place of password.
if creds.project().is_none() {
let payload = AuthFlow::new(client)
.begin(auth::PasswordHack)
.await?
.authenticate()
.await?;
// Finally we may finish the initialization of `creds`.
// TODO: add missing type safety to ClientCredentials.
creds.project = Some(payload.project);
let mut config = match &self {
Console(creds) => {
console::Api::new(&urls.auth_endpoint, creds)
.wake_compute()
.await?
}
Postgres(creds) => {
postgres::Api::new(&urls.auth_endpoint, creds)
.wake_compute()
.await?
}
_ => unreachable!("see the patterns above"),
};
// We should use a password from payload as well.
config.password(payload.password);
return Ok(compute::NodeInfo {
reported_auth_ok: false,
config,
});
}
}
match self {
LegacyConsole(creds) => {
legacy_console::handle_user(
&urls.auth_endpoint,
&urls.auth_link_uri,
&creds,
client,
)
.await
}
Console(creds) => {
console::Api::new(&urls.auth_endpoint, &creds)
.handle_user(client)
.await
}
Postgres(creds) => {
postgres::Api::new(&urls.auth_endpoint, &creds)
.handle_user(client)
.await
}
// NOTE: this auth backend doesn't use client credentials.
Link => link::handle_user(&urls.auth_link_uri, client).await,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_backend_type_map() {
let values = [
BackendType::LegacyConsole(0),
BackendType::Console(0),
BackendType::Postgres(0),
BackendType::Link,
];
for value in values {
assert_eq!(value.map(|x| x), value);
}
}
#[test]
fn test_backend_type_transpose() {
let values = [
BackendType::LegacyConsole(Ok::<_, ()>(0)),
BackendType::Console(Ok(0)),
BackendType::Postgres(Ok(0)),
BackendType::Link,
];
for value in values {
assert_eq!(value.map(Result::unwrap), value.transpose().unwrap());
}
Console => {
console::Api::new(&config.auth_endpoint, &creds)?
.handle_user(client)
.await
}
Postgres => {
postgres::Api::new(&config.auth_endpoint, &creds)?
.handle_user(client)
.await
}
Link => link::handle_user(&config.auth_link_uri, client).await,
}
}

View File

@@ -1,18 +1,17 @@
//! Cloud API V2.
use crate::{
auth::{self, AuthFlow, ClientCredentials, DatabaseInfo},
compute,
error::UserFacingError,
auth::{self, AuthFlow, ClientCredentials},
compute::{self, ComputeConnCfg},
error::{io_error, UserFacingError},
scram,
stream::PqStream,
url::ApiUrl,
};
use serde::{Deserialize, Serialize};
use std::{future::Future, io};
use std::future::Future;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
pub type Result<T> = std::result::Result<T, ConsoleAuthError>;
@@ -84,8 +83,8 @@ pub(super) struct Api<'a> {
impl<'a> Api<'a> {
/// Construct an API object containing the auth parameters.
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Result<Self> {
Ok(Self { endpoint, creds })
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Self {
Self { endpoint, creds }
}
/// Authenticate the existing user or throw an error.
@@ -100,7 +99,7 @@ impl<'a> Api<'a> {
let mut url = self.endpoint.clone();
url.path_segments_mut().push("proxy_get_role_secret");
url.query_pairs_mut()
.append_pair("project", self.creds.project_name.as_ref()?)
.append_pair("project", self.creds.project().expect("impossible"))
.append_pair("role", &self.creds.user);
// TODO: use a proper logger
@@ -120,11 +119,11 @@ impl<'a> Api<'a> {
}
/// Wake up the compute node and return the corresponding connection info.
async fn wake_compute(&self) -> Result<DatabaseInfo> {
pub(super) async fn wake_compute(&self) -> Result<ComputeConnCfg> {
let mut url = self.endpoint.clone();
url.path_segments_mut().push("proxy_wake_compute");
let project_name = self.creds.project_name.as_ref()?;
url.query_pairs_mut().append_pair("project", project_name);
url.query_pairs_mut()
.append_pair("project", self.creds.project().expect("impossible"));
// TODO: use a proper logger
println!("cplane request: {url}");
@@ -137,16 +136,20 @@ impl<'a> Api<'a> {
let response: GetWakeComputeResponse =
serde_json::from_str(&resp.text().await.map_err(io_error)?)?;
let (host, port) = parse_host_port(&response.address)
.ok_or(ConsoleAuthError::BadComputeAddress(response.address))?;
// Unfortunately, ownership won't let us use `Option::ok_or` here.
let (host, port) = match parse_host_port(&response.address) {
None => return Err(ConsoleAuthError::BadComputeAddress(response.address)),
Some(x) => x,
};
Ok(DatabaseInfo {
host,
port,
dbname: self.creds.dbname.to_owned(),
user: self.creds.user.to_owned(),
password: None,
})
let mut config = ComputeConnCfg::new();
config
.host(host)
.port(port)
.dbname(&self.creds.dbname)
.user(&self.creds.user);
Ok(config)
}
}
@@ -160,7 +163,7 @@ pub(super) async fn handle_user<'a, Endpoint, GetAuthInfo, WakeCompute>(
) -> auth::Result<compute::NodeInfo>
where
GetAuthInfo: Future<Output = Result<AuthInfo>>,
WakeCompute: Future<Output = Result<DatabaseInfo>>,
WakeCompute: Future<Output = Result<ComputeConnCfg>>,
{
let auth_info = get_auth_info(endpoint).await?;
@@ -179,48 +182,18 @@ where
}
};
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
let mut config = wake_compute(endpoint).await?;
if let Some(keys) = scram_keys {
config.auth_keys(tokio_postgres::config::AuthKeys::ScramSha256(keys));
}
Ok(compute::NodeInfo {
db_info: wake_compute(endpoint).await?,
scram_keys,
reported_auth_ok: false,
config,
})
}
/// Upcast (almost) any error into an opaque [`io::Error`].
pub(super) fn io_error(e: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
io::Error::new(io::ErrorKind::Other, e)
}
fn parse_host_port(input: &str) -> Option<(String, u16)> {
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
let (host, port) = input.split_once(':')?;
Some((host.to_owned(), port.parse().ok()?))
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn parse_db_info() -> anyhow::Result<()> {
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
"password": "password",
}))?;
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
}))?;
Ok(())
}
Some((host, port.parse().ok()?))
}

View File

@@ -11,7 +11,7 @@ use crate::{
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
use utils::pq_proto::BeMessage as Be;
#[derive(Debug, Error)]
pub enum AuthErrorImpl {
@@ -76,6 +76,12 @@ enum ProxyAuthResponse {
NotReady { ready: bool }, // TODO: get rid of `ready`
}
impl ClientCredentials {
fn is_existing_user(&self) -> bool {
self.user.ends_with("@zenith")
}
}
async fn authenticate_proxy_client(
auth_endpoint: &reqwest::Url,
creds: &ClientCredentials,
@@ -100,7 +106,7 @@ async fn authenticate_proxy_client(
}
let auth_info: ProxyAuthResponse = serde_json::from_str(resp.text().await?.as_str())?;
println!("got auth info: #{:?}", auth_info);
println!("got auth info: {:?}", auth_info);
use ProxyAuthResponse::*;
let db_info = match auth_info {
@@ -128,7 +134,9 @@ async fn handle_existing_user(
// Read client's password hash
let msg = client.read_password_message().await?;
let md5_response = parse_password(&msg).ok_or(auth::AuthErrorImpl::MalformedPassword)?;
let md5_response = parse_password(&msg).ok_or(auth::AuthErrorImpl::MalformedPassword(
"the password should be a valid null-terminated utf-8 string",
))?;
let db_info = authenticate_proxy_client(
auth_endpoint,
@@ -139,21 +147,17 @@ async fn handle_existing_user(
)
.await?;
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(compute::NodeInfo {
db_info,
scram_keys: None,
reported_auth_ok: false,
config: db_info.into(),
})
}
pub async fn handle_user(
auth_endpoint: &reqwest::Url,
auth_link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
creds: &ClientCredentials,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> auth::Result<compute::NodeInfo> {
if creds.is_existing_user() {
handle_existing_user(auth_endpoint, client, creds).await
@@ -201,4 +205,24 @@ mod tests {
.unwrap();
assert!(matches!(auth, ProxyAuthResponse::NotReady { .. }));
}
#[test]
fn parse_db_info() -> anyhow::Result<()> {
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
"password": "password",
}))?;
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
}))?;
Ok(())
}
}

View File

@@ -41,7 +41,7 @@ pub async fn handle_user(
client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?;
Ok(compute::NodeInfo {
db_info,
scram_keys: None,
reported_auth_ok: true,
config: db_info.into(),
})
}

View File

@@ -3,10 +3,12 @@
use crate::{
auth::{
self,
backend::console::{self, io_error, AuthInfo, Result},
ClientCredentials, DatabaseInfo,
backend::console::{self, AuthInfo, Result},
ClientCredentials,
},
compute, scram,
compute::{self, ComputeConnCfg},
error::io_error,
scram,
stream::PqStream,
url::ApiUrl,
};
@@ -20,8 +22,8 @@ pub(super) struct Api<'a> {
impl<'a> Api<'a> {
/// Construct an API object containing the auth parameters.
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Result<Self> {
Ok(Self { endpoint, creds })
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Self {
Self { endpoint, creds }
}
/// Authenticate the existing user or throw an error.
@@ -56,7 +58,10 @@ impl<'a> Api<'a> {
// We shouldn't get more than one row anyway.
[row, ..] => {
let entry = row.try_get(0).map_err(io_error)?;
let entry = row
.try_get("rolpassword")
.map_err(|e| io_error(format!("failed to read user's password: {e}")))?;
scram::ServerSecret::parse(entry)
.map(AuthInfo::Scram)
.or_else(|| {
@@ -75,14 +80,14 @@ impl<'a> Api<'a> {
}
/// We don't need to wake anything locally, so we just return the connection info.
async fn wake_compute(&self) -> Result<DatabaseInfo> {
Ok(DatabaseInfo {
// TODO: handle that near CLI params parsing
host: self.endpoint.host_str().unwrap_or("localhost").to_owned(),
port: self.endpoint.port().unwrap_or(5432),
dbname: self.creds.dbname.to_owned(),
user: self.creds.user.to_owned(),
password: None,
})
pub(super) async fn wake_compute(&self) -> Result<ComputeConnCfg> {
let mut config = ComputeConnCfg::new();
config
.host(self.endpoint.host_str().unwrap_or("localhost"))
.port(self.endpoint.port().unwrap_or(5432))
.dbname(&self.creds.dbname)
.user(&self.creds.user);
Ok(config)
}
}

View File

@@ -1,39 +1,25 @@
//! User credentials used in authentication.
use crate::compute;
use crate::config::ProxyConfig;
use crate::error::UserFacingError;
use crate::stream::PqStream;
use std::collections::HashMap;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use utils::pq_proto::StartupMessageParams;
#[derive(Debug, Error, PartialEq, Eq, Clone)]
pub enum ClientCredsParseError {
#[error("Parameter `{0}` is missing in startup packet.")]
#[error("Parameter '{0}' is missing in startup packet.")]
MissingKey(&'static str),
#[error(
"Project name is not specified. \
EITHER please upgrade the postgres client library (libpq) for SNI support \
OR pass the project name as a parameter: '&options=project%3D<project-name>'."
)]
MissingSNIAndProjectName,
#[error("Inconsistent project name inferred from SNI ('{0}') and project option ('{1}').")]
InconsistentProjectNameAndSNI(String, String),
#[error("Common name is not set.")]
CommonNameNotSet,
InconsistentProjectNames(String, String),
#[error(
"SNI ('{1}') inconsistently formatted with respect to common name ('{0}'). \
SNI should be formatted as '<project-name>.<common-name>'."
SNI should be formatted as '<project-name>.{0}'."
)]
InconsistentCommonNameAndSNI(String, String),
InconsistentSni(String, String),
#[error("Project name ('{0}') must contain only alphanumeric characters and hyphens ('-').")]
ProjectNameContainsIllegalChars(String),
#[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")]
MalformedProjectName(String),
}
impl UserFacingError for ClientCredsParseError {}
@@ -44,286 +30,171 @@ impl UserFacingError for ClientCredsParseError {}
pub struct ClientCredentials {
pub user: String,
pub dbname: String,
pub project_name: Result<String, ClientCredsParseError>,
pub project: Option<String>,
}
impl ClientCredentials {
pub fn is_existing_user(&self) -> bool {
// This logic will likely change in the future.
self.user.ends_with("@zenith")
pub fn project(&self) -> Option<&str> {
self.project.as_deref()
}
}
impl ClientCredentials {
pub fn parse(
mut options: HashMap<String, String>,
sni_data: Option<&str>,
mut options: StartupMessageParams,
sni: Option<&str>,
common_name: Option<&str>,
) -> Result<Self, ClientCredsParseError> {
let mut get_param = |key| {
options
.remove(key)
.ok_or(ClientCredsParseError::MissingKey(key))
};
use ClientCredsParseError::*;
// Some parameters are absolutely necessary, others not so much.
let mut get_param = |key| options.remove(key).ok_or(MissingKey(key));
// Some parameters are stored in the startup message.
let user = get_param("user")?;
let dbname = get_param("database")?;
let project_name = get_param("project").ok();
let project_name = get_project_name(sni_data, common_name, project_name.as_deref());
let project_a = get_param("project").ok();
// Alternative project name is in fact a subdomain from SNI.
// NOTE: we do not consider SNI if `common_name` is missing.
let project_b = sni
.zip(common_name)
.map(|(sni, cn)| {
// TODO: what if SNI is present but just a common name?
subdomain_from_sni(sni, cn)
.ok_or_else(|| InconsistentSni(sni.to_owned(), cn.to_owned()))
})
.transpose()?;
let project = match (project_a, project_b) {
// Invariant: if we have both project name variants, they should match.
(Some(a), Some(b)) if a != b => Some(Err(InconsistentProjectNames(a, b))),
(a, b) => a.or(b).map(|name| {
// Invariant: project name may not contain certain characters.
check_project_name(name).map_err(MalformedProjectName)
}),
}
.transpose()?;
Ok(Self {
user,
dbname,
project_name,
project,
})
}
}
/// Use credentials to authenticate the user.
pub async fn authenticate(
self,
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> super::Result<compute::NodeInfo> {
// This method is just a convenient facade for `handle_user`
super::backend::handle_user(config, client, self).await
fn check_project_name(name: String) -> Result<String, String> {
if name.chars().all(|c| c.is_alphanumeric() || c == '-') {
Ok(name)
} else {
Err(name)
}
}
/// Inferring project name from sni_data.
fn project_name_from_sni_data(
sni_data: &str,
common_name: &str,
) -> Result<String, ClientCredsParseError> {
let common_name_with_dot = format!(".{common_name}");
// check that ".{common_name_with_dot}" is the actual suffix in sni_data
if !sni_data.ends_with(&common_name_with_dot) {
return Err(ClientCredsParseError::InconsistentCommonNameAndSNI(
common_name.to_string(),
sni_data.to_string(),
fn subdomain_from_sni(sni: &str, common_name: &str) -> Option<String> {
sni.strip_suffix(common_name)?
.strip_suffix('.')
.map(str::to_owned)
}
#[cfg(test)]
mod tests {
use super::*;
fn make_options<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> StartupMessageParams {
StartupMessageParams::from(pairs.map(|(k, v)| (k.to_owned(), v.to_owned())))
}
#[test]
#[ignore = "TODO: fix how database is handled"]
fn parse_bare_minimum() -> anyhow::Result<()> {
// According to postgresql, only `user` should be required.
let options = make_options([("user", "john_doe")]);
// TODO: check that `creds.dbname` is None.
let creds = ClientCredentials::parse(options, None, None)?;
assert_eq!(creds.user, "john_doe");
Ok(())
}
#[test]
fn parse_missing_project() -> anyhow::Result<()> {
let options = make_options([("user", "john_doe"), ("database", "world")]);
let creds = ClientCredentials::parse(options, None, None)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.dbname, "world");
assert_eq!(creds.project, None);
Ok(())
}
#[test]
fn parse_project_from_sni() -> anyhow::Result<()> {
let options = make_options([("user", "john_doe"), ("database", "world")]);
let sni = Some("foo.localhost");
let common_name = Some("localhost");
let creds = ClientCredentials::parse(options, sni, common_name)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.dbname, "world");
assert_eq!(creds.project.as_deref(), Some("foo"));
Ok(())
}
#[test]
fn parse_project_from_options() -> anyhow::Result<()> {
let options = make_options([
("user", "john_doe"),
("database", "world"),
("project", "bar"),
]);
let creds = ClientCredentials::parse(options, None, None)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.dbname, "world");
assert_eq!(creds.project.as_deref(), Some("bar"));
Ok(())
}
#[test]
fn parse_projects_identical() -> anyhow::Result<()> {
let options = make_options([
("user", "john_doe"),
("database", "world"),
("project", "baz"),
]);
let sni = Some("baz.localhost");
let common_name = Some("localhost");
let creds = ClientCredentials::parse(options, sni, common_name)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.dbname, "world");
assert_eq!(creds.project.as_deref(), Some("baz"));
Ok(())
}
#[test]
fn parse_projects_different() {
let options = make_options([
("user", "john_doe"),
("database", "world"),
("project", "first"),
]);
let sni = Some("second.localhost");
let common_name = Some("localhost");
assert!(matches!(
ClientCredentials::parse(options, sni, common_name).expect_err("should fail"),
ClientCredsParseError::InconsistentProjectNames(_, _)
));
}
// return sni_data without the common name suffix.
Ok(sni_data
.strip_suffix(&common_name_with_dot)
.unwrap()
.to_string())
}
#[cfg(test)]
mod tests_for_project_name_from_sni_data {
use super::*;
#[test]
fn passing() {
let target_project_name = "my-project-123";
let common_name = "localtest.me";
let sni_data = format!("{target_project_name}.{common_name}");
assert_eq!(
project_name_from_sni_data(&sni_data, common_name),
Ok(target_project_name.to_string())
);
}
#[test]
fn throws_inconsistent_common_name_and_sni_data() {
let target_project_name = "my-project-123";
let common_name = "localtest.me";
let wrong_suffix = "wrongtest.me";
assert_eq!(common_name.len(), wrong_suffix.len());
let wrong_common_name = format!("wrong{wrong_suffix}");
let sni_data = format!("{target_project_name}.{wrong_common_name}");
assert_eq!(
project_name_from_sni_data(&sni_data, common_name),
Err(ClientCredsParseError::InconsistentCommonNameAndSNI(
common_name.to_string(),
sni_data
))
);
}
}
/// Determine project name from SNI or from project_name parameter from options argument.
fn get_project_name(
sni_data: Option<&str>,
common_name: Option<&str>,
project_name: Option<&str>,
) -> Result<String, ClientCredsParseError> {
// determine the project name from sni_data if it exists, otherwise from project_name.
let ret = match sni_data {
Some(sni_data) => {
let common_name = common_name.ok_or(ClientCredsParseError::CommonNameNotSet)?;
let project_name_from_sni = project_name_from_sni_data(sni_data, common_name)?;
// check invariant: project name from options and from sni should match
if let Some(project_name) = &project_name {
if !project_name_from_sni.eq(project_name) {
return Err(ClientCredsParseError::InconsistentProjectNameAndSNI(
project_name_from_sni,
project_name.to_string(),
));
}
}
project_name_from_sni
}
None => project_name
.ok_or(ClientCredsParseError::MissingSNIAndProjectName)?
.to_string(),
};
// check formatting invariant: project name must contain only alphanumeric characters and hyphens.
if !ret.chars().all(|x: char| x.is_alphanumeric() || x == '-') {
return Err(ClientCredsParseError::ProjectNameContainsIllegalChars(ret));
}
Ok(ret)
}
#[cfg(test)]
mod tests_for_project_name_only {
use super::*;
#[test]
fn passing_from_sni_data_only() {
let target_project_name = "my-project-123";
let common_name = "localtest.me";
let sni_data = format!("{target_project_name}.{common_name}");
assert_eq!(
get_project_name(Some(&sni_data), Some(common_name), None),
Ok(target_project_name.to_string())
);
}
#[test]
fn throws_project_name_contains_illegal_chars_from_sni_data_only() {
let project_name_prefix = "my-project";
let project_name_suffix = "123";
let common_name = "localtest.me";
for illegal_char_id in 0..256 {
let illegal_char = char::from_u32(illegal_char_id).unwrap();
if !(illegal_char.is_alphanumeric() || illegal_char == '-')
&& illegal_char.to_string().len() == 1
{
let target_project_name =
format!("{project_name_prefix}{illegal_char}{project_name_suffix}");
let sni_data = format!("{target_project_name}.{common_name}");
assert_eq!(
get_project_name(Some(&sni_data), Some(common_name), None),
Err(ClientCredsParseError::ProjectNameContainsIllegalChars(
target_project_name
))
);
}
}
}
#[test]
fn passing_from_project_name_only() {
let target_project_name = "my-project-123";
let common_names = [Some("localtest.me"), None];
for common_name in common_names {
assert_eq!(
get_project_name(None, common_name, Some(target_project_name)),
Ok(target_project_name.to_string())
);
}
}
#[test]
fn throws_project_name_contains_illegal_chars_from_project_name_only() {
let project_name_prefix = "my-project";
let project_name_suffix = "123";
let common_names = [Some("localtest.me"), None];
for common_name in common_names {
for illegal_char_id in 0..256 {
let illegal_char: char = char::from_u32(illegal_char_id).unwrap();
if !(illegal_char.is_alphanumeric() || illegal_char == '-')
&& illegal_char.to_string().len() == 1
{
let target_project_name =
format!("{project_name_prefix}{illegal_char}{project_name_suffix}");
assert_eq!(
get_project_name(None, common_name, Some(&target_project_name)),
Err(ClientCredsParseError::ProjectNameContainsIllegalChars(
target_project_name
))
);
}
}
}
}
#[test]
fn passing_from_sni_data_and_project_name() {
let target_project_name = "my-project-123";
let common_name = "localtest.me";
let sni_data = format!("{target_project_name}.{common_name}");
assert_eq!(
get_project_name(
Some(&sni_data),
Some(common_name),
Some(target_project_name)
),
Ok(target_project_name.to_string())
);
}
#[test]
fn throws_inconsistent_project_name_and_sni() {
let project_name_param = "my-project-123";
let wrong_project_name = "not-my-project-123";
let common_name = "localtest.me";
let sni_data = format!("{wrong_project_name}.{common_name}");
assert_eq!(
get_project_name(Some(&sni_data), Some(common_name), Some(project_name_param)),
Err(ClientCredsParseError::InconsistentProjectNameAndSNI(
wrong_project_name.to_string(),
project_name_param.to_string()
))
);
}
#[test]
fn throws_common_name_not_set() {
let target_project_name = "my-project-123";
let wrong_project_name = "not-my-project-123";
let common_name = "localtest.me";
let sni_datas = [
Some(format!("{wrong_project_name}.{common_name}")),
Some(format!("{target_project_name}.{common_name}")),
];
let project_names = [None, Some(target_project_name)];
for sni_data in sni_datas {
for project_name_param in project_names {
assert_eq!(
get_project_name(sni_data.as_deref(), None, project_name_param),
Err(ClientCredsParseError::CommonNameNotSet)
);
}
}
}
#[test]
fn throws_inconsistent_common_name_and_sni_data() {
let target_project_name = "my-project-123";
let wrong_project_name = "not-my-project-123";
let common_name = "localtest.me";
let wrong_suffix = "wrongtest.me";
assert_eq!(common_name.len(), wrong_suffix.len());
let wrong_common_name = format!("wrong{wrong_suffix}");
let sni_datas = [
Some(format!("{wrong_project_name}.{wrong_common_name}")),
Some(format!("{target_project_name}.{wrong_common_name}")),
];
let project_names = [None, Some(target_project_name)];
for project_name_param in project_names {
for sni_data in &sni_datas {
assert_eq!(
get_project_name(sni_data.as_deref(), Some(common_name), project_name_param),
Err(ClientCredsParseError::InconsistentCommonNameAndSNI(
common_name.to_string(),
sni_data.clone().unwrap().to_string()
))
);
}
}
}
}

View File

@@ -1,8 +1,7 @@
//! Main authentication flow.
use super::AuthErrorImpl;
use crate::stream::PqStream;
use crate::{sasl, scram};
use super::{AuthErrorImpl, PasswordHackPayload};
use crate::{sasl, scram, stream::PqStream};
use std::io;
use tokio::io::{AsyncRead, AsyncWrite};
use utils::pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
@@ -27,6 +26,17 @@ impl AuthMethod for Scram<'_> {
}
}
/// Use an ad hoc auth flow (for clients which don't support SNI) proposed in
/// <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
pub struct PasswordHack;
impl AuthMethod for PasswordHack {
#[inline(always)]
fn first_message(&self) -> BeMessage<'_> {
Be::AuthenticationCleartextPassword
}
}
/// This wrapper for [`PqStream`] performs client authentication.
#[must_use]
pub struct AuthFlow<'a, Stream, State> {
@@ -57,13 +67,34 @@ impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, PasswordHack> {
/// Perform user authentication. Raise an error in case authentication failed.
pub async fn authenticate(self) -> super::Result<PasswordHackPayload> {
let msg = self.stream.read_password_message().await?;
let password = msg
.strip_suffix(&[0])
.ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?;
// The so-called "password" should contain a base64-encoded json.
// We will use it later to route the client to their project.
let bytes = base64::decode(password)
.map_err(|_| AuthErrorImpl::MalformedPassword("bad encoding"))?;
let payload = serde_json::from_slice(&bytes)
.map_err(|_| AuthErrorImpl::MalformedPassword("invalid payload"))?;
Ok(payload)
}
}
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
/// Perform user authentication. Raise an error in case authentication failed.
pub async fn authenticate(self) -> super::Result<scram::ScramKey> {
// Initial client message contains the chosen auth method's name.
let msg = self.stream.read_password_message().await?;
let sasl = sasl::FirstMessage::parse(&msg).ok_or(AuthErrorImpl::MalformedPassword)?;
let sasl = sasl::FirstMessage::parse(&msg)
.ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?;
// Currently, the only supported SASL method is SCRAM.
if !scram::METHODS.contains(&sasl.method) {

View File

@@ -0,0 +1,102 @@
//! Payload for ad hoc authentication method for clients that don't support SNI.
//! See the `impl` for [`super::backend::BackendType<ClientCredentials>`].
//! Read more: <https://github.com/neondatabase/cloud/issues/1620#issuecomment-1165332290>.
use serde::{de, Deserialize, Deserializer};
use std::fmt;
#[derive(Deserialize)]
#[serde(untagged)]
pub enum Password {
/// A regular string for utf-8 encoded passwords.
Simple { password: String },
/// Password is base64-encoded because it may contain arbitrary byte sequences.
Encoded {
#[serde(rename = "password_", deserialize_with = "deserialize_base64")]
password: Vec<u8>,
},
}
impl AsRef<[u8]> for Password {
fn as_ref(&self) -> &[u8] {
match self {
Password::Simple { password } => password.as_ref(),
Password::Encoded { password } => password.as_ref(),
}
}
}
#[derive(Deserialize)]
pub struct PasswordHackPayload {
pub project: String,
#[serde(flatten)]
pub password: Password,
}
fn deserialize_base64<'a, D: Deserializer<'a>>(des: D) -> Result<Vec<u8>, D::Error> {
// It's very tempting to replace this with
//
// ```
// let base64: &str = Deserialize::deserialize(des)?;
// base64::decode(base64).map_err(serde::de::Error::custom)
// ```
//
// Unfortunately, we can't always deserialize into `&str`, so we'd
// have to use an allocating `String` instead. Thus, visitor is better.
struct Visitor;
impl<'de> de::Visitor<'de> for Visitor {
type Value = Vec<u8>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string")
}
fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
base64::decode(v).map_err(de::Error::custom)
}
}
des.deserialize_str(Visitor)
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
use serde_json::json;
#[test]
fn parse_password() -> anyhow::Result<()> {
let password: Password = serde_json::from_value(json!({
"password": "foo",
}))?;
assert_eq!(password.as_ref(), "foo".as_bytes());
let password: Password = serde_json::from_value(json!({
"password_": base64::encode("foo"),
}))?;
assert_eq!(password.as_ref(), "foo".as_bytes());
Ok(())
}
#[rstest]
#[case("password", str::to_owned)]
#[case("password_", base64::encode)]
fn parse(#[case] key: &str, #[case] encode: fn(&'static str) -> String) -> anyhow::Result<()> {
let (password, project) = ("password", "pie-in-the-sky");
let payload = json!({
"project": project,
key: encode(password),
});
let payload: PasswordHackPayload = serde_json::from_value(payload)?;
assert_eq!(payload.password.as_ref(), password.as_bytes());
assert_eq!(payload.project, project);
Ok(())
}
}

View File

@@ -1,8 +1,6 @@
use crate::auth::DatabaseInfo;
use crate::cancellation::CancelClosure;
use crate::error::UserFacingError;
use std::io;
use std::net::SocketAddr;
use crate::{cancellation::CancelClosure, error::UserFacingError};
use futures::TryFutureExt;
use std::{io, net::SocketAddr};
use thiserror::Error;
use tokio::net::TcpStream;
use tokio_postgres::NoTls;
@@ -21,44 +19,96 @@ pub enum ConnectionError {
FailedToFetchPgVersion,
}
impl UserFacingError for ConnectionError {}
/// PostgreSQL version as [`String`].
pub type Version = String;
impl UserFacingError for ConnectionError {
fn to_string_client(&self) -> String {
use ConnectionError::*;
match self {
// This helps us drop irrelevant library-specific prefixes.
// TODO: propagate severity level and other parameters.
Postgres(err) => match err.as_db_error() {
Some(err) => err.message().to_string(),
None => err.to_string(),
},
other => other.to_string(),
}
}
}
/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`.
pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
/// Compute node connection params.
pub type ComputeConnCfg = tokio_postgres::Config;
/// Various compute node info for establishing connection etc.
pub struct NodeInfo {
pub db_info: DatabaseInfo,
pub scram_keys: Option<ScramKeys>,
/// Did we send [`utils::pq_proto::BeMessage::AuthenticationOk`]?
pub reported_auth_ok: bool,
/// Compute node connection params.
pub config: tokio_postgres::Config,
}
impl NodeInfo {
async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream)> {
let host_port = (self.db_info.host.as_str(), self.db_info.port);
let socket = TcpStream::connect(host_port).await?;
let socket_addr = socket.peer_addr()?;
socket2::SockRef::from(&socket).set_keepalive(true)?;
use tokio_postgres::config::Host;
Ok((socket_addr, socket))
let connect_once = |host, port| {
TcpStream::connect((host, port)).and_then(|socket| async {
let socket_addr = socket.peer_addr()?;
// This prevents load balancer from severing the connection.
socket2::SockRef::from(&socket).set_keepalive(true)?;
Ok((socket_addr, socket))
})
};
// We can't reuse connection establishing logic from `tokio_postgres` here,
// because it has no means for extracting the underlying socket which we
// require for our business.
let mut connection_error = None;
let ports = self.config.get_ports();
for (i, host) in self.config.get_hosts().iter().enumerate() {
let port = ports.get(i).or_else(|| ports.get(0)).unwrap_or(&5432);
let host = match host {
Host::Tcp(host) => host.as_str(),
Host::Unix(_) => continue, // unix sockets are not welcome here
};
// TODO: maybe we should add a timeout.
match connect_once(host, *port).await {
Ok(socket) => return Ok(socket),
Err(err) => {
// We can't throw an error here, as there might be more hosts to try.
println!("failed to connect to compute `{host}:{port}`: {err}");
connection_error = Some(err);
}
}
}
Err(connection_error.unwrap_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
format!("couldn't connect: bad compute config: {:?}", self.config),
)
}))
}
}
pub struct PostgresConnection {
/// Socket connected to a compute node.
pub stream: TcpStream,
/// PostgreSQL version of this instance.
pub version: String,
}
impl NodeInfo {
/// Connect to a corresponding compute node.
pub async fn connect(self) -> Result<(TcpStream, Version, CancelClosure), ConnectionError> {
let (socket_addr, mut socket) = self
pub async fn connect(&self) -> Result<(PostgresConnection, CancelClosure), ConnectionError> {
let (socket_addr, mut stream) = self
.connect_raw()
.await
.map_err(|_| ConnectionError::FailedToConnectToCompute)?;
let mut config = tokio_postgres::Config::from(self.db_info);
if let Some(scram_keys) = self.scram_keys {
config.auth_keys(tokio_postgres::config::AuthKeys::ScramSha256(scram_keys));
}
// TODO: establish a secure connection to the DB
let (client, conn) = config.connect_raw(&mut socket, NoTls).await?;
let (client, conn) = self.config.connect_raw(&mut stream, NoTls).await?;
let version = conn
.parameter("server_version")
.ok_or(ConnectionError::FailedToFetchPgVersion)?
@@ -66,6 +116,8 @@ impl NodeInfo {
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
Ok((socket, version, cancel_closure))
let db = PostgresConnection { stream, version };
Ok((db, cancel_closure))
}
}

View File

@@ -1,28 +1,16 @@
use crate::url::ApiUrl;
use crate::{auth, url::ApiUrl};
use anyhow::{bail, ensure, Context};
use std::{str::FromStr, sync::Arc};
#[derive(Debug)]
pub enum AuthBackendType {
/// Legacy Cloud API (V1).
LegacyConsole,
/// Authentication via a web browser.
Link,
/// Current Cloud API (V2).
Console,
/// Local mock of Cloud API (V2).
Postgres,
}
impl FromStr for AuthBackendType {
impl FromStr for auth::BackendType<()> {
type Err = anyhow::Error;
fn from_str(s: &str) -> anyhow::Result<Self> {
use AuthBackendType::*;
use auth::BackendType::*;
Ok(match s {
"legacy" => LegacyConsole,
"console" => Console,
"postgres" => Postgres,
"legacy" => LegacyConsole(()),
"console" => Console(()),
"postgres" => Postgres(()),
"link" => Link,
_ => bail!("Invalid option `{s}` for auth method"),
})
@@ -31,7 +19,11 @@ impl FromStr for AuthBackendType {
pub struct ProxyConfig {
pub tls_config: Option<TlsConfig>,
pub auth_backend: AuthBackendType,
pub auth_backend: auth::BackendType<()>,
pub auth_urls: AuthUrls,
}
pub struct AuthUrls {
pub auth_endpoint: ApiUrl,
pub auth_link_uri: ApiUrl,
}
@@ -87,10 +79,8 @@ pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result<TlsConfi
"Failed to parse PEM object from bytes from file at '{cert_path}'."
))?
.1;
let almost_common_name = pem.parse_x509()?.tbs_certificate.subject.to_string();
let expected_prefix = "CN=*.";
let common_name = almost_common_name.strip_prefix(expected_prefix);
common_name.map(str::to_string)
let common_name = pem.parse_x509()?.subject().to_string();
common_name.strip_prefix("CN=*.").map(|s| s.to_string())
};
Ok(TlsConfig {

View File

@@ -1,3 +1,5 @@
use std::io;
/// Marks errors that may be safely shown to a client.
/// This trait can be seen as a specialized version of [`ToString`].
///
@@ -15,3 +17,8 @@ pub trait UserFacingError: ToString {
self.to_string()
}
}
/// Upcast (almost) any error into an opaque [`io::Error`].
pub fn io_error(e: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
io::Error::new(io::ErrorKind::Other, e)
}

View File

@@ -118,11 +118,15 @@ async fn main() -> anyhow::Result<()> {
let mgmt_address: SocketAddr = arg_matches.value_of("mgmt").unwrap().parse()?;
let http_address: SocketAddr = arg_matches.value_of("http").unwrap().parse()?;
let auth_urls = config::AuthUrls {
auth_endpoint: arg_matches.value_of("auth-endpoint").unwrap().parse()?,
auth_link_uri: arg_matches.value_of("uri").unwrap().parse()?,
};
let config: &ProxyConfig = Box::leak(Box::new(ProxyConfig {
tls_config,
auth_backend: arg_matches.value_of("auth-backend").unwrap().parse()?,
auth_endpoint: arg_matches.value_of("auth-endpoint").unwrap().parse()?,
auth_link_uri: arg_matches.value_of("uri").unwrap().parse()?,
auth_urls,
}));
println!("Version: {GIT_VERSION}");

View File

@@ -82,11 +82,22 @@ async fn handle_client(
}
let tls = config.tls_config.as_ref();
let (stream, creds) = match handshake(stream, tls, cancel_map).await? {
let (mut stream, params) = match handshake(stream, tls, cancel_map).await? {
Some(x) => x,
None => return Ok(()), // it's a cancellation request
};
let creds = {
let sni = stream.get_ref().sni_hostname();
let common_name = tls.and_then(|tls| tls.common_name.as_deref());
let result = config
.auth_backend
.map(|_| auth::ClientCredentials::parse(params, sni, common_name))
.transpose();
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let client = Client::new(stream, creds);
cancel_map
.with_session(|session| client.connect_to_db(config, session))
@@ -101,12 +112,10 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<&TlsConfig>,
cancel_map: &CancelMap,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, auth::ClientCredentials)>> {
) -> anyhow::Result<Option<(PqStream<Stream<S>>, StartupMessageParams)>> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
let common_name = tls.and_then(|cfg| cfg.common_name.as_deref());
let mut stream = PqStream::new(Stream::from_raw(stream));
loop {
let msg = stream.read_startup_packet().await?;
@@ -147,18 +156,7 @@ async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream.throw_error_str(ERR_INSECURE_CONNECTION).await?;
}
// Get SNI info when available
let sni_data = match stream.get_ref() {
Stream::Tls { tls } => tls.get_ref().1.sni_hostname().map(|s| s.to_owned()),
_ => None,
};
// Construct credentials
let creds =
auth::ClientCredentials::parse(params, sni_data.as_deref(), common_name);
let creds = async { creds }.or_else(|e| stream.throw_error(e)).await?;
break Ok(Some((stream, creds)));
break Ok(Some((stream, params)));
}
CancelRequest(cancel_key_data) => {
cancel_map.cancel_session(cancel_key_data).await?;
@@ -174,12 +172,12 @@ struct Client<S> {
/// The underlying libpq protocol stream.
stream: PqStream<S>,
/// Client credentials that we care about.
creds: auth::ClientCredentials,
creds: auth::BackendType<auth::ClientCredentials>,
}
impl<S> Client<S> {
/// Construct a new connection context.
fn new(stream: PqStream<S>, creds: auth::ClientCredentials) -> Self {
fn new(stream: PqStream<S>, creds: auth::BackendType<auth::ClientCredentials>) -> Self {
Self { stream, creds }
}
}
@@ -194,16 +192,22 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
let Self { mut stream, creds } = self;
// Authenticate and connect to a compute node.
let auth = creds.authenticate(config, &mut stream).await;
let auth = creds.authenticate(&config.auth_urls, &mut stream).await;
let node = async { auth }.or_else(|e| stream.throw_error(e)).await?;
let (db, version, cancel_closure) =
node.connect().or_else(|e| stream.throw_error(e)).await?;
let (db, cancel_closure) = node.connect().or_else(|e| stream.throw_error(e)).await?;
let cancel_key_data = session.enable_cancellation(cancel_closure);
// Report authentication success if we haven't done this already.
if !node.reported_auth_ok {
stream
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
}
stream
.write_message_noflush(&BeMessage::ParameterStatus(
BeParameterStatusMessage::ServerVersion(&version),
BeParameterStatusMessage::ServerVersion(&db.version),
))?
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
.write_message(&BeMessage::ReadyForQuery)
@@ -217,7 +221,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<S> {
}
// Starting from here we only proxy the client's traffic.
let mut db = MetricsStream::new(db, inc_proxied);
let mut db = MetricsStream::new(db.stream, inc_proxied);
let mut client = MetricsStream::new(stream.into_inner(), inc_proxied);
let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?;
@@ -279,9 +283,13 @@ mod tests {
let config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![cert], key)?;
.with_single_cert(vec![cert], key)?
.into();
config.into()
TlsConfig {
config,
common_name: Some(common_name.to_string()),
}
};
let client_config = {
@@ -297,11 +305,6 @@ mod tests {
ClientConfig { config, hostname }
};
let tls_config = TlsConfig {
config: tls_config,
common_name: Some(common_name.to_string()),
};
Ok((client_config, tls_config))
}
@@ -357,7 +360,7 @@ mod tests {
auth: impl TestAuth + Send,
) -> anyhow::Result<()> {
let cancel_map = CancelMap::default();
let (mut stream, _creds) = handshake(client, tls.as_ref(), &cancel_map)
let (mut stream, _params) = handshake(client, tls.as_ref(), &cancel_map)
.await?
.context("handshake failed")?;
@@ -436,32 +439,6 @@ mod tests {
proxy.await?
}
#[tokio::test]
async fn give_user_an_error_for_bad_creds() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth));
let client_err = tokio_postgres::Config::new()
.ssl_mode(SslMode::Disable)
.connect_raw(server, NoTls)
.await
.err() // -> Option<E>
.context("client shouldn't be able to connect")?;
// TODO: this is ugly, but `format!` won't allow us to extract fmt string
assert!(client_err.to_string().contains("missing in startup packet"));
let server_err = proxy
.await?
.err() // -> Option<E>
.context("server shouldn't accept client")?;
assert!(client_err.to_string().contains(&server_err.to_string()));
Ok(())
}
#[tokio::test]
async fn keepalive_is_inherited() -> anyhow::Result<()> {
use tokio::net::{TcpListener, TcpStream};

View File

@@ -145,6 +145,14 @@ impl<S> Stream<S> {
pub fn from_raw(raw: S) -> Self {
Self::Raw { raw }
}
/// Return SNI hostname when it's available.
pub fn sni_hostname(&self) -> Option<&str> {
match self {
Stream::Raw { .. } => None,
Stream::Tls { tls } => tls.get_ref().1.sni_hostname(),
}
}
}
#[derive(Debug, Error)]

View File

@@ -1,8 +1,34 @@
import pytest
import json
import base64
def test_proxy_select_1(static_proxy):
static_proxy.safe_psql("select 1;", options="project=generic-project-name")
static_proxy.safe_psql('select 1', options='project=generic-project-name')
def test_password_hack(static_proxy):
user = 'borat'
password = 'password'
static_proxy.safe_psql(f"create role {user} with login password '{password}'",
options='project=irrelevant')
def encode(s: str) -> str:
return base64.b64encode(s.encode('utf-8')).decode('utf-8')
magic = encode(json.dumps({
'project': 'irrelevant',
'password': password,
}))
static_proxy.safe_psql('select 1', sslsni=0, user=user, password=magic)
magic = encode(json.dumps({
'project': 'irrelevant',
'password_': encode(password),
}))
static_proxy.safe_psql('select 1', sslsni=0, user=user, password=magic)
# Pass extra options to the server.
@@ -11,8 +37,8 @@ def test_proxy_select_1(static_proxy):
# See https://github.com/neondatabase/neon/issues/1287
@pytest.mark.xfail
def test_proxy_options(static_proxy):
with static_proxy.connect(options="-cproxytest.option=value") as conn:
with static_proxy.connect(options='-cproxytest.option=value') as conn:
with conn.cursor() as cur:
cur.execute("SHOW proxytest.option;")
cur.execute('SHOW proxytest.option')
value = cur.fetchall()[0][0]
assert value == 'value'

View File

@@ -30,7 +30,7 @@ from dataclasses import dataclass
# Type-related stuff
from psycopg2.extensions import connection as PgConnection
from psycopg2.extensions import make_dsn, parse_dsn
from typing import Any, Callable, Dict, Iterator, List, Optional, Type, TypeVar, cast, Union, Tuple
from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, cast, Union, Tuple
from typing_extensions import Literal
import requests
@@ -280,20 +280,18 @@ class PgProtocol:
return str(make_dsn(**self.conn_options(**kwargs)))
def conn_options(self, **kwargs):
conn_options = self.default_options.copy()
result = self.default_options.copy()
if 'dsn' in kwargs:
conn_options.update(parse_dsn(kwargs['dsn']))
conn_options.update(kwargs)
result.update(parse_dsn(kwargs['dsn']))
result.update(kwargs)
# Individual statement timeout in seconds. 2 minutes should be
# enough for our tests, but if you need a longer, you can
# change it by calling "SET statement_timeout" after
# connecting.
if 'options' in conn_options:
conn_options['options'] = f"-cstatement_timeout=120s " + conn_options['options']
else:
conn_options['options'] = "-cstatement_timeout=120s"
return conn_options
options = result.get('options', '')
result['options'] = f'-cstatement_timeout=120s {options}'
return result
# autocommit=True here by default because that's what we need most of the time
def connect(self, autocommit=True, **kwargs) -> PgConnection:
@@ -1514,29 +1512,25 @@ def remote_pg(test_output_dir: Path) -> Iterator[RemotePostgres]:
class NeonProxy(PgProtocol):
def __init__(self, port: int, pg_port: int):
super().__init__(host="127.0.0.1",
user="proxy_user",
password="pytest2",
port=port,
dbname='postgres')
self.http_port = 7001
self.host = "127.0.0.1"
self.port = port
self.pg_port = pg_port
def __init__(self, proxy_port: int, http_port: int, auth_endpoint: str):
super().__init__(dsn=auth_endpoint, port=proxy_port)
self.host = '127.0.0.1'
self.http_port = http_port
self.proxy_port = proxy_port
self.auth_endpoint = auth_endpoint
self._popen: Optional[subprocess.Popen[bytes]] = None
def start(self) -> None:
assert self._popen is None
# Start proxy
bin_proxy = os.path.join(str(neon_binpath), 'proxy')
args = [bin_proxy]
args.extend(["--http", f"{self.host}:{self.http_port}"])
args.extend(["--proxy", f"{self.host}:{self.port}"])
args.extend(["--auth-backend", "postgres"])
args.extend(
["--auth-endpoint", f"postgres://proxy_auth:pytest1@localhost:{self.pg_port}/postgres"])
args = [
os.path.join(str(neon_binpath), 'proxy'),
*["--http", f"{self.host}:{self.http_port}"],
*["--proxy", f"{self.host}:{self.proxy_port}"],
*["--auth-backend", "postgres"],
*["--auth-endpoint", self.auth_endpoint],
]
self._popen = subprocess.Popen(args)
self._wait_until_ready()
@@ -1557,13 +1551,21 @@ class NeonProxy(PgProtocol):
@pytest.fixture(scope='function')
def static_proxy(vanilla_pg, port_distributor) -> Iterator[NeonProxy]:
"""Neon proxy that routes directly to vanilla postgres."""
vanilla_pg.start()
vanilla_pg.safe_psql("create user proxy_auth with password 'pytest1' superuser")
vanilla_pg.safe_psql("create user proxy_user with password 'pytest2'")
port = port_distributor.get_port()
pg_port = vanilla_pg.default_options['port']
with NeonProxy(port, pg_port) as proxy:
# For simplicity, we use the same user for both `--auth-endpoint` and `safe_psql`
vanilla_pg.start()
vanilla_pg.safe_psql("create user proxy with login superuser password 'password'")
port = vanilla_pg.default_options['port']
host = vanilla_pg.default_options['host']
dbname = vanilla_pg.default_options['dbname']
auth_endpoint = f'postgres://proxy:password@{host}:{port}/{dbname}'
proxy_port = port_distributor.get_port()
http_port = port_distributor.get_port()
with NeonProxy(proxy_port=proxy_port, http_port=http_port,
auth_endpoint=auth_endpoint) as proxy:
proxy.start()
yield proxy