[proxy] Refactoring

This patch attempts to fix some of the technical debt
we had to introduce in previous patches.
This commit is contained in:
Dmitry Ivanov
2022-05-18 16:01:56 +03:00
parent 757746b571
commit 5d813f9738
16 changed files with 599 additions and 470 deletions

View File

@@ -1,56 +1,58 @@
mod credentials;
mod flow;
//! Client authentication mechanisms.
use crate::auth_backend::{console, legacy_console, link, postgres};
use crate::config::{AuthBackendType, ProxyConfig};
use crate::error::UserFacingError;
use crate::stream::PqStream;
use crate::{auth_backend, compute, waiters};
use console::ConsoleAuthError::SniMissing;
pub mod backend;
pub use backend::DatabaseInfo;
mod credentials;
pub use credentials::ClientCredentials;
mod flow;
pub use flow::*;
use crate::{error::UserFacingError, waiters};
use std::io;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
pub use credentials::ClientCredentials;
pub use flow::*;
/// Convenience wrapper for the authentication error.
pub type Result<T> = std::result::Result<T, AuthError>;
/// Common authentication error.
#[derive(Debug, Error)]
pub enum AuthErrorImpl {
/// Authentication error reported by the console.
#[error(transparent)]
Console(#[from] auth_backend::AuthError),
Console(#[from] backend::AuthError),
#[error(transparent)]
GetAuthInfo(#[from] auth_backend::console::ConsoleAuthError),
GetAuthInfo(#[from] backend::console::ConsoleAuthError),
#[error(transparent)]
Sasl(#[from] crate::sasl::Error),
/// For passwords that couldn't be processed by [`parse_password`].
/// For passwords that couldn't be processed by [`backend::legacy_console::parse_password`].
#[error("Malformed password message")]
MalformedPassword,
/// Errors produced by [`PqStream`].
/// Errors produced by [`crate::stream::PqStream`].
#[error(transparent)]
Io(#[from] io::Error),
}
impl AuthErrorImpl {
pub fn auth_failed(msg: impl Into<String>) -> Self {
AuthErrorImpl::Console(auth_backend::AuthError::auth_failed(msg))
Self::Console(backend::AuthError::auth_failed(msg))
}
}
impl From<waiters::RegisterError> for AuthErrorImpl {
fn from(e: waiters::RegisterError) -> Self {
AuthErrorImpl::Console(auth_backend::AuthError::from(e))
Self::Console(backend::AuthError::from(e))
}
}
impl From<waiters::WaitError> for AuthErrorImpl {
fn from(e: waiters::WaitError) -> Self {
AuthErrorImpl::Console(auth_backend::AuthError::from(e))
Self::Console(backend::AuthError::from(e))
}
}
@@ -63,7 +65,7 @@ where
AuthErrorImpl: From<T>,
{
fn from(e: T) -> Self {
AuthError(Box::new(e.into()))
Self(Box::new(e.into()))
}
}
@@ -72,34 +74,9 @@ impl UserFacingError for AuthError {
use AuthErrorImpl::*;
match self.0.as_ref() {
Console(e) => e.to_string_client(),
GetAuthInfo(e) => e.to_string_client(),
MalformedPassword => self.to_string(),
GetAuthInfo(e) if matches!(e, SniMissing) => e.to_string(),
_ => "Internal error".to_string(),
}
}
}
async fn handle_user(
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
creds: ClientCredentials,
) -> Result<compute::NodeInfo, AuthError> {
match config.auth_backend {
AuthBackendType::LegacyConsole => {
legacy_console::handle_user(
&config.auth_endpoint,
&config.auth_link_uri,
client,
&creds,
)
.await
}
AuthBackendType::Console => {
console::handle_user(config.auth_endpoint.as_ref(), client, &creds).await
}
AuthBackendType::Postgres => {
postgres::handle_user(&config.auth_endpoint, client, &creds).await
}
AuthBackendType::Link => link::handle_user(config.auth_link_uri.as_ref(), client).await,
}
}

109
proxy/src/auth/backend.rs Normal file
View File

@@ -0,0 +1,109 @@
mod legacy_console;
mod link;
mod postgres;
pub mod console;
pub use legacy_console::{AuthError, AuthErrorImpl};
use super::ClientCredentials;
use crate::{
compute,
config::{AuthBackendType, ProxyConfig},
mgmt,
stream::PqStream,
waiters::{self, Waiter, Waiters},
};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncWrite};
lazy_static! {
static ref CPLANE_WAITERS: Waiters<mgmt::ComputeReady> = Default::default();
}
/// Give caller an opportunity to wait for the cloud's reply.
pub async fn with_waiter<R, T, E>(
psql_session_id: impl Into<String>,
action: impl FnOnce(Waiter<'static, mgmt::ComputeReady>) -> R,
) -> Result<T, E>
where
R: std::future::Future<Output = Result<T, E>>,
E: From<waiters::RegisterError>,
{
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
action(waiter).await
}
pub fn notify(psql_session_id: &str, msg: mgmt::ComputeReady) -> Result<(), waiters::NotifyError> {
CPLANE_WAITERS.notify(psql_session_id, msg)
}
/// Compute node connection params provided by the cloud.
/// Note how it implements serde traits, since we receive it over the wire.
#[derive(Serialize, Deserialize, Default)]
pub struct DatabaseInfo {
pub host: String,
pub port: u16,
pub dbname: String,
pub user: String,
pub password: Option<String>,
}
// Manually implement debug to omit personal and sensitive info.
impl std::fmt::Debug for DatabaseInfo {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct("DatabaseInfo")
.field("host", &self.host)
.field("port", &self.port)
.finish()
}
}
impl From<DatabaseInfo> for tokio_postgres::Config {
fn from(db_info: DatabaseInfo) -> Self {
let mut config = tokio_postgres::Config::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
if let Some(password) = db_info.password {
config.password(password);
}
config
}
}
pub(super) async fn handle_user(
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
creds: ClientCredentials,
) -> super::Result<compute::NodeInfo> {
use AuthBackendType::*;
match config.auth_backend {
LegacyConsole => {
legacy_console::handle_user(
&config.auth_endpoint,
&config.auth_link_uri,
client,
&creds,
)
.await
}
Console => {
console::Api::new(&config.auth_endpoint, &creds)?
.handle_user(client)
.await
}
Postgres => {
postgres::Api::new(&config.auth_endpoint, &creds)?
.handle_user(client)
.await
}
Link => link::handle_user(&config.auth_link_uri, client).await,
}
}

View File

@@ -0,0 +1,225 @@
//! Cloud API V2.
use crate::{
auth::{self, AuthFlow, ClientCredentials, DatabaseInfo},
compute,
error::UserFacingError,
scram,
stream::PqStream,
url::ApiUrl,
};
use serde::{Deserialize, Serialize};
use std::{future::Future, io};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
pub type Result<T> = std::result::Result<T, ConsoleAuthError>;
#[derive(Debug, Error)]
pub enum ConsoleAuthError {
#[error(transparent)]
BadProjectName(#[from] auth::credentials::ProjectNameError),
// We shouldn't include the actual secret here.
#[error("Bad authentication secret")]
BadSecret,
#[error("Console responded with a malformed compute address: '{0}'")]
BadComputeAddress(String),
#[error("Console responded with a malformed JSON: '{0}'")]
BadResponse(#[from] serde_json::Error),
/// HTTP status (other than 200) returned by the console.
#[error("Console responded with an HTTP status: {0}")]
HttpStatus(reqwest::StatusCode),
#[error(transparent)]
Io(#[from] std::io::Error),
}
impl UserFacingError for ConsoleAuthError {
fn to_string_client(&self) -> String {
use ConsoleAuthError::*;
match self {
BadProjectName(e) => e.to_string_client(),
_ => "Internal error".to_string(),
}
}
}
// TODO: convert into an enum with "error"
#[derive(Serialize, Deserialize, Debug)]
struct GetRoleSecretResponse {
role_secret: String,
}
// TODO: convert into an enum with "error"
#[derive(Serialize, Deserialize, Debug)]
struct GetWakeComputeResponse {
address: String,
}
/// Auth secret which is managed by the cloud.
pub enum AuthInfo {
/// Md5 hash of user's password.
Md5([u8; 16]),
/// [SCRAM](crate::scram) authentication info.
Scram(scram::ServerSecret),
}
#[must_use]
pub(super) struct Api<'a> {
endpoint: &'a ApiUrl,
creds: &'a ClientCredentials,
/// Cache project name, since we'll need it several times.
project: &'a str,
}
impl<'a> Api<'a> {
/// Construct an API object containing the auth parameters.
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Result<Self> {
Ok(Self {
endpoint,
creds,
project: creds.project_name()?,
})
}
/// Authenticate the existing user or throw an error.
pub(super) async fn handle_user(
self,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> auth::Result<compute::NodeInfo> {
handle_user(client, &self, Self::get_auth_info, Self::wake_compute).await
}
async fn get_auth_info(&self) -> Result<AuthInfo> {
let mut url = self.endpoint.clone();
url.path_segments_mut().push("proxy_get_role_secret");
url.query_pairs_mut()
.append_pair("project", self.project)
.append_pair("role", &self.creds.user);
// TODO: use a proper logger
println!("cplane request: {url}");
let resp = reqwest::get(url.into_inner()).await.map_err(io_error)?;
if !resp.status().is_success() {
return Err(ConsoleAuthError::HttpStatus(resp.status()));
}
let response: GetRoleSecretResponse =
serde_json::from_str(&resp.text().await.map_err(io_error)?)?;
scram::ServerSecret::parse(response.role_secret.as_str())
.map(AuthInfo::Scram)
.ok_or(ConsoleAuthError::BadSecret)
}
/// Wake up the compute node and return the corresponding connection info.
async fn wake_compute(&self) -> Result<DatabaseInfo> {
let mut url = self.endpoint.clone();
url.path_segments_mut().push("proxy_wake_compute");
url.query_pairs_mut().append_pair("project", self.project);
// TODO: use a proper logger
println!("cplane request: {url}");
let resp = reqwest::get(url.into_inner()).await.map_err(io_error)?;
if !resp.status().is_success() {
return Err(ConsoleAuthError::HttpStatus(resp.status()));
}
let response: GetWakeComputeResponse =
serde_json::from_str(&resp.text().await.map_err(io_error)?)?;
let (host, port) = parse_host_port(&response.address)
.ok_or(ConsoleAuthError::BadComputeAddress(response.address))?;
Ok(DatabaseInfo {
host,
port,
dbname: self.creds.dbname.to_owned(),
user: self.creds.user.to_owned(),
password: None,
})
}
}
/// Common logic for user handling in API V2.
/// We reuse this for a mock API implementation in [`super::postgres`].
pub(super) async fn handle_user<'a, Endpoint, GetAuthInfo, WakeCompute>(
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
endpoint: &'a Endpoint,
get_auth_info: impl FnOnce(&'a Endpoint) -> GetAuthInfo,
wake_compute: impl FnOnce(&'a Endpoint) -> WakeCompute,
) -> auth::Result<compute::NodeInfo>
where
GetAuthInfo: Future<Output = Result<AuthInfo>>,
WakeCompute: Future<Output = Result<DatabaseInfo>>,
{
let auth_info = get_auth_info(endpoint).await?;
let flow = AuthFlow::new(client);
let scram_keys = match auth_info {
AuthInfo::Md5(_) => {
// TODO: decide if we should support MD5 in api v2
return Err(auth::AuthErrorImpl::auth_failed("MD5 is not supported").into());
}
AuthInfo::Scram(secret) => {
let scram = auth::Scram(&secret);
Some(compute::ScramKeys {
client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(),
server_key: secret.server_key.as_bytes(),
})
}
};
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(compute::NodeInfo {
db_info: wake_compute(endpoint).await?,
scram_keys,
})
}
/// Upcast (almost) any error into an opaque [`io::Error`].
pub(super) fn io_error(e: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
io::Error::new(io::ErrorKind::Other, e)
}
fn parse_host_port(input: &str) -> Option<(String, u16)> {
let (host, port) = input.split_once(':')?;
Some((host.to_owned(), port.parse().ok()?))
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn parse_db_info() -> anyhow::Result<()> {
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
"password": "password",
}))?;
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
}))?;
Ok(())
}
}

View File

@@ -1,20 +1,18 @@
//! Cloud API V1.
use super::console::DatabaseInfo;
use crate::auth::ClientCredentials;
use crate::stream::PqStream;
use crate::{compute, waiters};
use super::DatabaseInfo;
use crate::{
auth::{self, ClientCredentials},
compute,
error::UserFacingError,
stream::PqStream,
waiters,
};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
use thiserror::Error;
use crate::error::UserFacingError;
#[derive(Debug, Error)]
pub enum AuthErrorImpl {
/// Authentication error reported by the console.
@@ -45,7 +43,7 @@ pub struct AuthError(Box<AuthErrorImpl>);
impl AuthError {
/// Smart constructor for authentication error reported by `mgmt`.
pub fn auth_failed(msg: impl Into<String>) -> Self {
AuthError(Box::new(AuthErrorImpl::AuthFailed(msg.into())))
Self(Box::new(AuthErrorImpl::AuthFailed(msg.into())))
}
}
@@ -54,7 +52,7 @@ where
AuthErrorImpl: From<T>,
{
fn from(e: T) -> Self {
AuthError(Box::new(e.into()))
Self(Box::new(e.into()))
}
}
@@ -120,7 +118,7 @@ async fn handle_existing_user(
auth_endpoint: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
creds: &ClientCredentials,
) -> Result<crate::compute::NodeInfo, crate::auth::AuthError> {
) -> Result<compute::NodeInfo, auth::AuthError> {
let psql_session_id = super::link::new_psql_session_id();
let md5_salt = rand::random();
@@ -130,7 +128,7 @@ async fn handle_existing_user(
// Read client's password hash
let msg = client.read_password_message().await?;
let md5_response = parse_password(&msg).ok_or(crate::auth::AuthErrorImpl::MalformedPassword)?;
let md5_response = parse_password(&msg).ok_or(auth::AuthErrorImpl::MalformedPassword)?;
let db_info = authenticate_proxy_client(
auth_endpoint,
@@ -156,11 +154,11 @@ pub async fn handle_user(
auth_link_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
creds: &ClientCredentials,
) -> Result<crate::compute::NodeInfo, crate::auth::AuthError> {
) -> auth::Result<compute::NodeInfo> {
if creds.is_existing_user() {
handle_existing_user(auth_endpoint, client, creds).await
} else {
super::link::handle_user(auth_link_uri.as_ref(), client).await
super::link::handle_user(auth_link_uri, client).await
}
}

View File

@@ -1,4 +1,4 @@
use crate::{compute, stream::PqStream};
use crate::{auth, compute, stream::PqStream};
use tokio::io::{AsyncRead, AsyncWrite};
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
@@ -19,13 +19,13 @@ pub fn new_psql_session_id() -> String {
}
pub async fn handle_user(
redirect_uri: &str,
redirect_uri: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> Result<compute::NodeInfo, crate::auth::AuthError> {
) -> auth::Result<compute::NodeInfo> {
let psql_session_id = new_psql_session_id();
let greeting = hello_message(redirect_uri, &psql_session_id);
let greeting = hello_message(redirect_uri.as_str(), &psql_session_id);
let db_info = crate::auth_backend::with_waiter(psql_session_id, |waiter| async {
let db_info = super::with_waiter(psql_session_id, |waiter| async {
// Give user a URL to spawn a new database
client
.write_message_noflush(&Be::AuthenticationOk)?
@@ -34,9 +34,7 @@ pub async fn handle_user(
.await?;
// Wait for web console response (see `mgmt`)
waiter
.await?
.map_err(crate::auth::AuthErrorImpl::auth_failed)
waiter.await?.map_err(auth::AuthErrorImpl::auth_failed)
})
.await?;

View File

@@ -0,0 +1,88 @@
//! Local mock of Cloud API V2.
use crate::{
auth::{
self,
backend::console::{self, io_error, AuthInfo, Result},
ClientCredentials, DatabaseInfo,
},
compute, scram,
stream::PqStream,
url::ApiUrl,
};
use tokio::io::{AsyncRead, AsyncWrite};
#[must_use]
pub(super) struct Api<'a> {
endpoint: &'a ApiUrl,
creds: &'a ClientCredentials,
}
impl<'a> Api<'a> {
/// Construct an API object containing the auth parameters.
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Result<Self> {
Ok(Self { endpoint, creds })
}
/// Authenticate the existing user or throw an error.
pub(super) async fn handle_user(
self,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> auth::Result<compute::NodeInfo> {
// We reuse user handling logic from a production module.
console::handle_user(client, &self, Self::get_auth_info, Self::wake_compute).await
}
/// This implementation fetches the auth info from a local postgres instance.
async fn get_auth_info(&self) -> Result<AuthInfo> {
// Perhaps we could persist this connection, but then we'd have to
// write more code for reopening it if it got closed, which doesn't
// seem worth it.
let (client, connection) =
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls)
.await
.map_err(io_error)?;
tokio::spawn(connection);
let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1";
let rows = client
.query(query, &[&self.creds.user])
.await
.map_err(io_error)?;
match &rows[..] {
// We can't get a secret if there's no such user.
[] => Err(io_error(format!("unknown user '{}'", self.creds.user)).into()),
// We shouldn't get more than one row anyway.
[row, ..] => {
let entry = row.try_get(0).map_err(io_error)?;
scram::ServerSecret::parse(entry)
.map(AuthInfo::Scram)
.or_else(|| {
// It could be an md5 hash if it's not a SCRAM secret.
let text = entry.strip_prefix("md5")?;
Some(AuthInfo::Md5({
let mut bytes = [0u8; 16];
hex::decode_to_slice(text, &mut bytes).ok()?;
bytes
}))
})
// Putting the secret into this message is a security hazard!
.ok_or(console::ConsoleAuthError::BadSecret)
}
}
}
/// We don't need to wake anything locally, so we just return the connection info.
async fn wake_compute(&self) -> Result<DatabaseInfo> {
Ok(DatabaseInfo {
// TODO: handle that near CLI params parsing
host: self.endpoint.host_str().unwrap_or("localhost").to_owned(),
port: self.endpoint.port().unwrap_or(5432),
dbname: self.creds.dbname.to_owned(),
user: self.creds.user.to_owned(),
password: None,
})
}
}

View File

@@ -1,6 +1,5 @@
//! User credentials used in authentication.
use super::AuthError;
use crate::compute;
use crate::config::ProxyConfig;
use crate::error::UserFacingError;
@@ -36,6 +35,27 @@ impl ClientCredentials {
}
}
#[derive(Debug, Error)]
pub enum ProjectNameError {
#[error("SNI is missing, please upgrade the postgres client library")]
Missing,
#[error("SNI is malformed")]
Bad,
}
impl UserFacingError for ProjectNameError {}
impl ClientCredentials {
/// Determine project name from SNI.
pub fn project_name(&self) -> Result<&str, ProjectNameError> {
// Currently project name is passed as a top level domain
let sni = self.sni_data.as_ref().ok_or(ProjectNameError::Missing)?;
let (first, _) = sni.split_once('.').ok_or(ProjectNameError::Bad)?;
Ok(first)
}
}
impl TryFrom<HashMap<String, String>> for ClientCredentials {
type Error = ClientCredsParseError;
@@ -47,11 +67,11 @@ impl TryFrom<HashMap<String, String>> for ClientCredentials {
};
let user = get_param("user")?;
let db = get_param("database")?;
let dbname = get_param("database")?;
Ok(Self {
user,
dbname: db,
dbname,
sni_data: None,
})
}
@@ -63,8 +83,8 @@ impl ClientCredentials {
self,
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> Result<compute::NodeInfo, AuthError> {
) -> super::Result<compute::NodeInfo> {
// This method is just a convenient facade for `handle_user`
super::handle_user(config, client, self).await
super::backend::handle_user(config, client, self).await
}
}

View File

@@ -1,6 +1,6 @@
//! Main authentication flow.
use super::{AuthError, AuthErrorImpl};
use super::AuthErrorImpl;
use crate::stream::PqStream;
use crate::{sasl, scram};
use std::io;
@@ -32,7 +32,7 @@ impl AuthMethod for Scram<'_> {
pub struct AuthFlow<'a, Stream, State> {
/// The underlying stream which implements libpq's protocol.
stream: &'a mut PqStream<Stream>,
/// State might contain ancillary data (see [`AuthFlow::begin`]).
/// State might contain ancillary data (see [`Self::begin`]).
state: State,
}
@@ -60,7 +60,7 @@ impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
/// Perform user authentication. Raise an error in case authentication failed.
pub async fn authenticate(self) -> Result<scram::ScramKey, AuthError> {
pub async fn authenticate(self) -> super::Result<scram::ScramKey> {
// Initial client message contains the chosen auth method's name.
let msg = self.stream.read_password_message().await?;
let sasl = sasl::FirstMessage::parse(&msg).ok_or(AuthErrorImpl::MalformedPassword)?;

View File

@@ -1,31 +0,0 @@
pub mod console;
pub mod legacy_console;
pub mod link;
pub mod postgres;
pub use legacy_console::{AuthError, AuthErrorImpl};
use crate::mgmt;
use crate::waiters::{self, Waiter, Waiters};
use lazy_static::lazy_static;
lazy_static! {
static ref CPLANE_WAITERS: Waiters<mgmt::ComputeReady> = Default::default();
}
/// Give caller an opportunity to wait for the cloud's reply.
pub async fn with_waiter<R, T, E>(
psql_session_id: impl Into<String>,
action: impl FnOnce(Waiter<'static, mgmt::ComputeReady>) -> R,
) -> Result<T, E>
where
R: std::future::Future<Output = Result<T, E>>,
E: From<waiters::RegisterError>,
{
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
action(waiter).await
}
pub fn notify(psql_session_id: &str, msg: mgmt::ComputeReady) -> Result<(), waiters::NotifyError> {
CPLANE_WAITERS.notify(psql_session_id, msg)
}

View File

@@ -1,243 +0,0 @@
//! Declaration of Cloud API V2.
use crate::{
auth::{self, AuthFlow},
compute, scram,
};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::auth::ClientCredentials;
use crate::stream::PqStream;
use tokio::io::{AsyncRead, AsyncWrite};
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
#[derive(Debug, Error)]
pub enum ConsoleAuthError {
// We shouldn't include the actual secret here.
#[error("Bad authentication secret")]
BadSecret,
#[error("Bad client credentials: {0:?}")]
BadCredentials(crate::auth::ClientCredentials),
#[error("SNI info is missing, please upgrade the postgres client library")]
SniMissing,
#[error("Unexpected SNI content")]
SniWrong,
#[error(transparent)]
BadUrl(#[from] url::ParseError),
#[error(transparent)]
Io(#[from] std::io::Error),
/// HTTP status (other than 200) returned by the console.
#[error("Console responded with an HTTP status: {0}")]
HttpStatus(reqwest::StatusCode),
#[error(transparent)]
Transport(#[from] reqwest::Error),
#[error("Console responded with a malformed JSON: '{0}'")]
MalformedResponse(#[from] serde_json::Error),
#[error("Console responded with a malformed compute address: '{0}'")]
MalformedComputeAddress(String),
}
#[derive(Serialize, Deserialize, Debug)]
struct GetRoleSecretResponse {
role_secret: String,
}
#[derive(Serialize, Deserialize, Debug)]
struct GetWakeComputeResponse {
address: String,
}
/// Auth secret which is managed by the cloud.
pub enum AuthInfo {
/// Md5 hash of user's password.
Md5([u8; 16]),
/// [SCRAM](crate::scram) authentication info.
Scram(scram::ServerSecret),
}
/// Compute node connection params provided by the cloud.
/// Note how it implements serde traits, since we receive it over the wire.
#[derive(Serialize, Deserialize, Default)]
pub struct DatabaseInfo {
pub host: String,
pub port: u16,
pub dbname: String,
pub user: String,
/// [Cloud API V1](super::legacy) returns cleartext password,
/// but [Cloud API V2](super::api) implements [SCRAM](crate::scram)
/// authentication, so we can leverage this method and cope without password.
pub password: Option<String>,
}
// Manually implement debug to omit personal and sensitive info.
impl std::fmt::Debug for DatabaseInfo {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct("DatabaseInfo")
.field("host", &self.host)
.field("port", &self.port)
.finish()
}
}
impl From<DatabaseInfo> for tokio_postgres::Config {
fn from(db_info: DatabaseInfo) -> Self {
let mut config = tokio_postgres::Config::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
if let Some(password) = db_info.password {
config.password(password);
}
config
}
}
async fn get_auth_info(
auth_endpoint: &str,
user: &str,
cluster: &str,
) -> Result<AuthInfo, ConsoleAuthError> {
let mut url = reqwest::Url::parse(&format!("{auth_endpoint}/proxy_get_role_secret"))?;
url.query_pairs_mut()
.append_pair("project", cluster)
.append_pair("role", user);
// TODO: use a proper logger
println!("cplane request: {}", url);
let resp = reqwest::get(url).await?;
if !resp.status().is_success() {
return Err(ConsoleAuthError::HttpStatus(resp.status()));
}
let response: GetRoleSecretResponse = serde_json::from_str(resp.text().await?.as_str())?;
scram::ServerSecret::parse(response.role_secret.as_str())
.map(AuthInfo::Scram)
.ok_or(ConsoleAuthError::BadSecret)
}
/// Wake up the compute node and return the corresponding connection info.
async fn wake_compute(
auth_endpoint: &str,
cluster: &str,
) -> Result<(String, u16), ConsoleAuthError> {
let mut url = reqwest::Url::parse(&format!("{auth_endpoint}/proxy_wake_compute"))?;
url.query_pairs_mut().append_pair("project", cluster);
// TODO: use a proper logger
println!("cplane request: {}", url);
let resp = reqwest::get(url).await?;
if !resp.status().is_success() {
return Err(ConsoleAuthError::HttpStatus(resp.status()));
}
let response: GetWakeComputeResponse = serde_json::from_str(resp.text().await?.as_str())?;
let (host, port) = response
.address
.split_once(':')
.ok_or_else(|| ConsoleAuthError::MalformedComputeAddress(response.address.clone()))?;
let port: u16 = port
.parse()
.map_err(|_| ConsoleAuthError::MalformedComputeAddress(response.address.clone()))?;
Ok((host.to_string(), port))
}
pub async fn handle_user(
auth_endpoint: &str,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: &ClientCredentials,
) -> Result<compute::NodeInfo, crate::auth::AuthError> {
// Determine cluster name from SNI.
let cluster = creds
.sni_data
.as_ref()
.ok_or(ConsoleAuthError::SniMissing)?
.split_once('.')
.ok_or(ConsoleAuthError::SniWrong)?
.0;
let user = creds.user.as_str();
// Step 1: get the auth secret
let auth_info = get_auth_info(auth_endpoint, user, cluster).await?;
let flow = AuthFlow::new(client);
let scram_keys = match auth_info {
AuthInfo::Md5(_) => {
// TODO: decide if we should support MD5 in api v2
return Err(crate::auth::AuthErrorImpl::auth_failed("MD5 is not supported").into());
}
AuthInfo::Scram(secret) => {
let scram = auth::Scram(&secret);
Some(compute::ScramKeys {
client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(),
server_key: secret.server_key.as_bytes(),
})
}
};
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
// Step 2: wake compute
let (host, port) = wake_compute(auth_endpoint, cluster).await?;
Ok(compute::NodeInfo {
db_info: DatabaseInfo {
host,
port,
dbname: creds.dbname.clone(),
user: creds.user.clone(),
password: None,
},
scram_keys,
})
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn parse_db_info() -> anyhow::Result<()> {
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
"password": "password",
}))?;
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
}))?;
Ok(())
}
}

View File

@@ -1,93 +0,0 @@
//! Local mock of Cloud API V2.
use super::console::{self, AuthInfo, DatabaseInfo};
use crate::scram;
use crate::{auth::ClientCredentials, compute};
use crate::stream::PqStream;
use tokio::io::{AsyncRead, AsyncWrite};
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
async fn get_auth_info(
auth_endpoint: &str,
creds: &ClientCredentials,
) -> Result<AuthInfo, console::ConsoleAuthError> {
// We wrap `tokio_postgres::Error` because we don't want to infect the
// method's error type with a detail that's specific to debug mode only.
let io_error = |e| std::io::Error::new(std::io::ErrorKind::Other, e);
// Perhaps we could persist this connection, but then we'd have to
// write more code for reopening it if it got closed, which doesn't
// seem worth it.
let (client, connection) = tokio_postgres::connect(auth_endpoint, tokio_postgres::NoTls)
.await
.map_err(io_error)?;
tokio::spawn(connection);
let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1";
let rows = client
.query(query, &[&creds.user])
.await
.map_err(io_error)?;
match &rows[..] {
// We can't get a secret if there's no such user.
[] => Err(console::ConsoleAuthError::BadCredentials(creds.to_owned())),
// We shouldn't get more than one row anyway.
[row, ..] => {
let entry = row.try_get(0).map_err(io_error)?;
scram::ServerSecret::parse(entry)
.map(AuthInfo::Scram)
.or_else(|| {
// It could be an md5 hash if it's not a SCRAM secret.
let text = entry.strip_prefix("md5")?;
Some(AuthInfo::Md5({
let mut bytes = [0u8; 16];
hex::decode_to_slice(text, &mut bytes).ok()?;
bytes
}))
})
// Putting the secret into this message is a security hazard!
.ok_or(console::ConsoleAuthError::BadSecret)
}
}
}
pub async fn handle_user(
auth_endpoint: &reqwest::Url,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: &ClientCredentials,
) -> Result<compute::NodeInfo, crate::auth::AuthError> {
let auth_info = get_auth_info(auth_endpoint.as_ref(), creds).await?;
let flow = crate::auth::AuthFlow::new(client);
let scram_keys = match auth_info {
AuthInfo::Md5(_) => {
// TODO: decide if we should support MD5 in api v2
return Err(crate::auth::AuthErrorImpl::auth_failed("MD5 is not supported").into());
}
AuthInfo::Scram(secret) => {
let scram = crate::auth::Scram(&secret);
Some(compute::ScramKeys {
client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(),
server_key: secret.server_key.as_bytes(),
})
}
};
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(compute::NodeInfo {
db_info: DatabaseInfo {
// TODO: handle that near CLI params parsing
host: auth_endpoint.host_str().unwrap_or("localhost").to_owned(),
port: auth_endpoint.port().unwrap_or(5432),
dbname: creds.dbname.to_owned(),
user: creds.user.to_owned(),
password: None,
},
scram_keys,
})
}

View File

@@ -1,4 +1,4 @@
use crate::auth_backend::console::DatabaseInfo;
use crate::auth::DatabaseInfo;
use crate::cancellation::CancelClosure;
use crate::error::UserFacingError;
use std::io;
@@ -37,7 +37,7 @@ pub struct NodeInfo {
impl NodeInfo {
async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream)> {
let host_port = format!("{}:{}", self.db_info.host, self.db_info.port);
let host_port = (self.db_info.host.as_str(), self.db_info.port);
let socket = TcpStream::connect(host_port).await?;
let socket_addr = socket.peer_addr()?;
socket2::SockRef::from(&socket).set_keepalive(true)?;

View File

@@ -1,39 +1,38 @@
use anyhow::{ensure, Context};
use crate::url::ApiUrl;
use anyhow::{bail, ensure, Context};
use std::{str::FromStr, sync::Arc};
#[non_exhaustive]
pub enum AuthBackendType {
/// Legacy Cloud API (V1).
LegacyConsole,
Console,
Postgres,
/// Authentication via a web browser.
Link,
/// Current Cloud API (V2).
Console,
/// Local mock of Cloud API (V2).
Postgres,
}
impl FromStr for AuthBackendType {
type Err = anyhow::Error;
fn from_str(s: &str) -> anyhow::Result<Self> {
println!("ClientAuthMethod::from_str: '{}'", s);
use AuthBackendType::*;
match s {
"legacy" => Ok(LegacyConsole),
"console" => Ok(Console),
"postgres" => Ok(Postgres),
"link" => Ok(Link),
_ => Err(anyhow::anyhow!("Invlid option for auth method")),
}
Ok(match s {
"legacy" => LegacyConsole,
"console" => Console,
"postgres" => Postgres,
"link" => Link,
_ => bail!("Invalid option `{s}` for auth method"),
})
}
}
pub struct ProxyConfig {
/// TLS configuration for the proxy.
pub tls_config: Option<TlsConfig>,
pub auth_backend: AuthBackendType,
pub auth_endpoint: reqwest::Url,
pub auth_link_uri: reqwest::Url,
pub auth_endpoint: ApiUrl,
pub auth_link_uri: ApiUrl,
}
pub type TlsConfig = Arc<rustls::ServerConfig>;

View File

@@ -5,7 +5,6 @@
//! in somewhat transparent manner (again via communication with control plane API).
mod auth;
mod auth_backend;
mod cancellation;
mod compute;
mod config;
@@ -17,6 +16,7 @@ mod proxy;
mod sasl;
mod scram;
mod stream;
mod url;
mod waiters;
use anyhow::{bail, Context};

View File

@@ -1,4 +1,4 @@
use crate::auth_backend;
use crate::auth;
use anyhow::Context;
use serde::Deserialize;
use std::{
@@ -77,12 +77,12 @@ struct PsqlSessionResponse {
#[derive(Deserialize)]
enum PsqlSessionResult {
Success(auth_backend::console::DatabaseInfo),
Success(auth::DatabaseInfo),
Failure(String),
}
/// A message received by `mgmt` when a compute node is ready.
pub type ComputeReady = Result<auth_backend::console::DatabaseInfo, String>;
pub type ComputeReady = Result<auth::DatabaseInfo, String>;
impl PsqlSessionResult {
fn into_compute_ready(self) -> ComputeReady {
@@ -113,7 +113,7 @@ fn try_process_query(pgb: &mut PostgresBackend, query_string: &str) -> anyhow::R
let resp: PsqlSessionResponse = serde_json::from_str(query_string)?;
match auth_backend::notify(&resp.session_id, resp.result.into_compute_ready()) {
match auth::backend::notify(&resp.session_id, resp.result.into_compute_ready()) {
Ok(()) => {
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?

82
proxy/src/url.rs Normal file
View File

@@ -0,0 +1,82 @@
use anyhow::bail;
use url::form_urlencoded::Serializer;
/// A [url](url::Url) type with additional guarantees.
#[derive(Debug, Clone)]
pub struct ApiUrl(url::Url);
impl ApiUrl {
/// Consume the wrapper and return inner [url](url::Url).
pub fn into_inner(self) -> url::Url {
self.0
}
/// See [`url::Url::query_pairs_mut`].
pub fn query_pairs_mut(&mut self) -> Serializer<'_, url::UrlQuery<'_>> {
self.0.query_pairs_mut()
}
/// See [`url::Url::path_segments_mut`].
pub fn path_segments_mut(&mut self) -> url::PathSegmentsMut {
// We've already verified that it works during construction.
self.0.path_segments_mut().expect("bad API url")
}
}
/// This instance imposes additional requirements on the url.
impl std::str::FromStr for ApiUrl {
type Err = anyhow::Error;
fn from_str(s: &str) -> anyhow::Result<Self> {
let mut url: url::Url = s.parse()?;
// Make sure that we can build upon this URL.
if url.path_segments_mut().is_err() {
bail!("bad API url provided");
}
Ok(Self(url))
}
}
/// This instance is safe because it doesn't allow us to modify the object.
impl std::ops::Deref for ApiUrl {
type Target = url::Url;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl std::fmt::Display for ApiUrl {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn bad_url() {
let url = "test:foobar";
url.parse::<url::Url>().expect("unexpected parsing failure");
let _ = url.parse::<ApiUrl>().expect_err("should not parse");
}
#[test]
fn good_url() {
let url = "test://foobar";
let mut a = url.parse::<url::Url>().expect("unexpected parsing failure");
let mut b = url.parse::<ApiUrl>().expect("unexpected parsing failure");
a.path_segments_mut().unwrap().push("method");
a.query_pairs_mut().append_pair("key", "value");
b.path_segments_mut().push("method");
b.query_pairs_mut().append_pair("key", "value");
assert_eq!(a, b.into_inner());
}
}