mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-07 05:22:56 +00:00
[proxy] Refactoring
This patch attempts to fix some of the technical debt we had to introduce in previous patches.
This commit is contained in:
@@ -1,56 +1,58 @@
|
||||
mod credentials;
|
||||
mod flow;
|
||||
//! Client authentication mechanisms.
|
||||
|
||||
use crate::auth_backend::{console, legacy_console, link, postgres};
|
||||
use crate::config::{AuthBackendType, ProxyConfig};
|
||||
use crate::error::UserFacingError;
|
||||
use crate::stream::PqStream;
|
||||
use crate::{auth_backend, compute, waiters};
|
||||
use console::ConsoleAuthError::SniMissing;
|
||||
pub mod backend;
|
||||
pub use backend::DatabaseInfo;
|
||||
|
||||
mod credentials;
|
||||
pub use credentials::ClientCredentials;
|
||||
|
||||
mod flow;
|
||||
pub use flow::*;
|
||||
|
||||
use crate::{error::UserFacingError, waiters};
|
||||
use std::io;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
pub use credentials::ClientCredentials;
|
||||
pub use flow::*;
|
||||
/// Convenience wrapper for the authentication error.
|
||||
pub type Result<T> = std::result::Result<T, AuthError>;
|
||||
|
||||
/// Common authentication error.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthErrorImpl {
|
||||
/// Authentication error reported by the console.
|
||||
#[error(transparent)]
|
||||
Console(#[from] auth_backend::AuthError),
|
||||
Console(#[from] backend::AuthError),
|
||||
|
||||
#[error(transparent)]
|
||||
GetAuthInfo(#[from] auth_backend::console::ConsoleAuthError),
|
||||
GetAuthInfo(#[from] backend::console::ConsoleAuthError),
|
||||
|
||||
#[error(transparent)]
|
||||
Sasl(#[from] crate::sasl::Error),
|
||||
|
||||
/// For passwords that couldn't be processed by [`parse_password`].
|
||||
/// For passwords that couldn't be processed by [`backend::legacy_console::parse_password`].
|
||||
#[error("Malformed password message")]
|
||||
MalformedPassword,
|
||||
|
||||
/// Errors produced by [`PqStream`].
|
||||
/// Errors produced by [`crate::stream::PqStream`].
|
||||
#[error(transparent)]
|
||||
Io(#[from] io::Error),
|
||||
}
|
||||
|
||||
impl AuthErrorImpl {
|
||||
pub fn auth_failed(msg: impl Into<String>) -> Self {
|
||||
AuthErrorImpl::Console(auth_backend::AuthError::auth_failed(msg))
|
||||
Self::Console(backend::AuthError::auth_failed(msg))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<waiters::RegisterError> for AuthErrorImpl {
|
||||
fn from(e: waiters::RegisterError) -> Self {
|
||||
AuthErrorImpl::Console(auth_backend::AuthError::from(e))
|
||||
Self::Console(backend::AuthError::from(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<waiters::WaitError> for AuthErrorImpl {
|
||||
fn from(e: waiters::WaitError) -> Self {
|
||||
AuthErrorImpl::Console(auth_backend::AuthError::from(e))
|
||||
Self::Console(backend::AuthError::from(e))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,7 +65,7 @@ where
|
||||
AuthErrorImpl: From<T>,
|
||||
{
|
||||
fn from(e: T) -> Self {
|
||||
AuthError(Box::new(e.into()))
|
||||
Self(Box::new(e.into()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,34 +74,9 @@ impl UserFacingError for AuthError {
|
||||
use AuthErrorImpl::*;
|
||||
match self.0.as_ref() {
|
||||
Console(e) => e.to_string_client(),
|
||||
GetAuthInfo(e) => e.to_string_client(),
|
||||
MalformedPassword => self.to_string(),
|
||||
GetAuthInfo(e) if matches!(e, SniMissing) => e.to_string(),
|
||||
_ => "Internal error".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_user(
|
||||
config: &ProxyConfig,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
creds: ClientCredentials,
|
||||
) -> Result<compute::NodeInfo, AuthError> {
|
||||
match config.auth_backend {
|
||||
AuthBackendType::LegacyConsole => {
|
||||
legacy_console::handle_user(
|
||||
&config.auth_endpoint,
|
||||
&config.auth_link_uri,
|
||||
client,
|
||||
&creds,
|
||||
)
|
||||
.await
|
||||
}
|
||||
AuthBackendType::Console => {
|
||||
console::handle_user(config.auth_endpoint.as_ref(), client, &creds).await
|
||||
}
|
||||
AuthBackendType::Postgres => {
|
||||
postgres::handle_user(&config.auth_endpoint, client, &creds).await
|
||||
}
|
||||
AuthBackendType::Link => link::handle_user(config.auth_link_uri.as_ref(), client).await,
|
||||
}
|
||||
}
|
||||
|
||||
109
proxy/src/auth/backend.rs
Normal file
109
proxy/src/auth/backend.rs
Normal file
@@ -0,0 +1,109 @@
|
||||
mod legacy_console;
|
||||
mod link;
|
||||
mod postgres;
|
||||
|
||||
pub mod console;
|
||||
|
||||
pub use legacy_console::{AuthError, AuthErrorImpl};
|
||||
|
||||
use super::ClientCredentials;
|
||||
use crate::{
|
||||
compute,
|
||||
config::{AuthBackendType, ProxyConfig},
|
||||
mgmt,
|
||||
stream::PqStream,
|
||||
waiters::{self, Waiter, Waiters},
|
||||
};
|
||||
use lazy_static::lazy_static;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
lazy_static! {
|
||||
static ref CPLANE_WAITERS: Waiters<mgmt::ComputeReady> = Default::default();
|
||||
}
|
||||
|
||||
/// Give caller an opportunity to wait for the cloud's reply.
|
||||
pub async fn with_waiter<R, T, E>(
|
||||
psql_session_id: impl Into<String>,
|
||||
action: impl FnOnce(Waiter<'static, mgmt::ComputeReady>) -> R,
|
||||
) -> Result<T, E>
|
||||
where
|
||||
R: std::future::Future<Output = Result<T, E>>,
|
||||
E: From<waiters::RegisterError>,
|
||||
{
|
||||
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
|
||||
action(waiter).await
|
||||
}
|
||||
|
||||
pub fn notify(psql_session_id: &str, msg: mgmt::ComputeReady) -> Result<(), waiters::NotifyError> {
|
||||
CPLANE_WAITERS.notify(psql_session_id, msg)
|
||||
}
|
||||
|
||||
/// Compute node connection params provided by the cloud.
|
||||
/// Note how it implements serde traits, since we receive it over the wire.
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
pub struct DatabaseInfo {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub dbname: String,
|
||||
pub user: String,
|
||||
pub password: Option<String>,
|
||||
}
|
||||
|
||||
// Manually implement debug to omit personal and sensitive info.
|
||||
impl std::fmt::Debug for DatabaseInfo {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
fmt.debug_struct("DatabaseInfo")
|
||||
.field("host", &self.host)
|
||||
.field("port", &self.port)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DatabaseInfo> for tokio_postgres::Config {
|
||||
fn from(db_info: DatabaseInfo) -> Self {
|
||||
let mut config = tokio_postgres::Config::new();
|
||||
|
||||
config
|
||||
.host(&db_info.host)
|
||||
.port(db_info.port)
|
||||
.dbname(&db_info.dbname)
|
||||
.user(&db_info.user);
|
||||
|
||||
if let Some(password) = db_info.password {
|
||||
config.password(password);
|
||||
}
|
||||
|
||||
config
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_user(
|
||||
config: &ProxyConfig,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
creds: ClientCredentials,
|
||||
) -> super::Result<compute::NodeInfo> {
|
||||
use AuthBackendType::*;
|
||||
match config.auth_backend {
|
||||
LegacyConsole => {
|
||||
legacy_console::handle_user(
|
||||
&config.auth_endpoint,
|
||||
&config.auth_link_uri,
|
||||
client,
|
||||
&creds,
|
||||
)
|
||||
.await
|
||||
}
|
||||
Console => {
|
||||
console::Api::new(&config.auth_endpoint, &creds)?
|
||||
.handle_user(client)
|
||||
.await
|
||||
}
|
||||
Postgres => {
|
||||
postgres::Api::new(&config.auth_endpoint, &creds)?
|
||||
.handle_user(client)
|
||||
.await
|
||||
}
|
||||
Link => link::handle_user(&config.auth_link_uri, client).await,
|
||||
}
|
||||
}
|
||||
225
proxy/src/auth/backend/console.rs
Normal file
225
proxy/src/auth/backend/console.rs
Normal file
@@ -0,0 +1,225 @@
|
||||
//! Cloud API V2.
|
||||
|
||||
use crate::{
|
||||
auth::{self, AuthFlow, ClientCredentials, DatabaseInfo},
|
||||
compute,
|
||||
error::UserFacingError,
|
||||
scram,
|
||||
stream::PqStream,
|
||||
url::ApiUrl,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{future::Future, io};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
|
||||
|
||||
pub type Result<T> = std::result::Result<T, ConsoleAuthError>;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ConsoleAuthError {
|
||||
#[error(transparent)]
|
||||
BadProjectName(#[from] auth::credentials::ProjectNameError),
|
||||
|
||||
// We shouldn't include the actual secret here.
|
||||
#[error("Bad authentication secret")]
|
||||
BadSecret,
|
||||
|
||||
#[error("Console responded with a malformed compute address: '{0}'")]
|
||||
BadComputeAddress(String),
|
||||
|
||||
#[error("Console responded with a malformed JSON: '{0}'")]
|
||||
BadResponse(#[from] serde_json::Error),
|
||||
|
||||
/// HTTP status (other than 200) returned by the console.
|
||||
#[error("Console responded with an HTTP status: {0}")]
|
||||
HttpStatus(reqwest::StatusCode),
|
||||
|
||||
#[error(transparent)]
|
||||
Io(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
impl UserFacingError for ConsoleAuthError {
|
||||
fn to_string_client(&self) -> String {
|
||||
use ConsoleAuthError::*;
|
||||
match self {
|
||||
BadProjectName(e) => e.to_string_client(),
|
||||
_ => "Internal error".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: convert into an enum with "error"
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct GetRoleSecretResponse {
|
||||
role_secret: String,
|
||||
}
|
||||
|
||||
// TODO: convert into an enum with "error"
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct GetWakeComputeResponse {
|
||||
address: String,
|
||||
}
|
||||
|
||||
/// Auth secret which is managed by the cloud.
|
||||
pub enum AuthInfo {
|
||||
/// Md5 hash of user's password.
|
||||
Md5([u8; 16]),
|
||||
|
||||
/// [SCRAM](crate::scram) authentication info.
|
||||
Scram(scram::ServerSecret),
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub(super) struct Api<'a> {
|
||||
endpoint: &'a ApiUrl,
|
||||
creds: &'a ClientCredentials,
|
||||
/// Cache project name, since we'll need it several times.
|
||||
project: &'a str,
|
||||
}
|
||||
|
||||
impl<'a> Api<'a> {
|
||||
/// Construct an API object containing the auth parameters.
|
||||
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Result<Self> {
|
||||
Ok(Self {
|
||||
endpoint,
|
||||
creds,
|
||||
project: creds.project_name()?,
|
||||
})
|
||||
}
|
||||
|
||||
/// Authenticate the existing user or throw an error.
|
||||
pub(super) async fn handle_user(
|
||||
self,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
) -> auth::Result<compute::NodeInfo> {
|
||||
handle_user(client, &self, Self::get_auth_info, Self::wake_compute).await
|
||||
}
|
||||
|
||||
async fn get_auth_info(&self) -> Result<AuthInfo> {
|
||||
let mut url = self.endpoint.clone();
|
||||
url.path_segments_mut().push("proxy_get_role_secret");
|
||||
url.query_pairs_mut()
|
||||
.append_pair("project", self.project)
|
||||
.append_pair("role", &self.creds.user);
|
||||
|
||||
// TODO: use a proper logger
|
||||
println!("cplane request: {url}");
|
||||
|
||||
let resp = reqwest::get(url.into_inner()).await.map_err(io_error)?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(ConsoleAuthError::HttpStatus(resp.status()));
|
||||
}
|
||||
|
||||
let response: GetRoleSecretResponse =
|
||||
serde_json::from_str(&resp.text().await.map_err(io_error)?)?;
|
||||
|
||||
scram::ServerSecret::parse(response.role_secret.as_str())
|
||||
.map(AuthInfo::Scram)
|
||||
.ok_or(ConsoleAuthError::BadSecret)
|
||||
}
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
async fn wake_compute(&self) -> Result<DatabaseInfo> {
|
||||
let mut url = self.endpoint.clone();
|
||||
url.path_segments_mut().push("proxy_wake_compute");
|
||||
url.query_pairs_mut().append_pair("project", self.project);
|
||||
|
||||
// TODO: use a proper logger
|
||||
println!("cplane request: {url}");
|
||||
|
||||
let resp = reqwest::get(url.into_inner()).await.map_err(io_error)?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(ConsoleAuthError::HttpStatus(resp.status()));
|
||||
}
|
||||
|
||||
let response: GetWakeComputeResponse =
|
||||
serde_json::from_str(&resp.text().await.map_err(io_error)?)?;
|
||||
|
||||
let (host, port) = parse_host_port(&response.address)
|
||||
.ok_or(ConsoleAuthError::BadComputeAddress(response.address))?;
|
||||
|
||||
Ok(DatabaseInfo {
|
||||
host,
|
||||
port,
|
||||
dbname: self.creds.dbname.to_owned(),
|
||||
user: self.creds.user.to_owned(),
|
||||
password: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Common logic for user handling in API V2.
|
||||
/// We reuse this for a mock API implementation in [`super::postgres`].
|
||||
pub(super) async fn handle_user<'a, Endpoint, GetAuthInfo, WakeCompute>(
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
endpoint: &'a Endpoint,
|
||||
get_auth_info: impl FnOnce(&'a Endpoint) -> GetAuthInfo,
|
||||
wake_compute: impl FnOnce(&'a Endpoint) -> WakeCompute,
|
||||
) -> auth::Result<compute::NodeInfo>
|
||||
where
|
||||
GetAuthInfo: Future<Output = Result<AuthInfo>>,
|
||||
WakeCompute: Future<Output = Result<DatabaseInfo>>,
|
||||
{
|
||||
let auth_info = get_auth_info(endpoint).await?;
|
||||
|
||||
let flow = AuthFlow::new(client);
|
||||
let scram_keys = match auth_info {
|
||||
AuthInfo::Md5(_) => {
|
||||
// TODO: decide if we should support MD5 in api v2
|
||||
return Err(auth::AuthErrorImpl::auth_failed("MD5 is not supported").into());
|
||||
}
|
||||
AuthInfo::Scram(secret) => {
|
||||
let scram = auth::Scram(&secret);
|
||||
Some(compute::ScramKeys {
|
||||
client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(),
|
||||
server_key: secret.server_key.as_bytes(),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||
|
||||
Ok(compute::NodeInfo {
|
||||
db_info: wake_compute(endpoint).await?,
|
||||
scram_keys,
|
||||
})
|
||||
}
|
||||
|
||||
/// Upcast (almost) any error into an opaque [`io::Error`].
|
||||
pub(super) fn io_error(e: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
|
||||
io::Error::new(io::ErrorKind::Other, e)
|
||||
}
|
||||
|
||||
fn parse_host_port(input: &str) -> Option<(String, u16)> {
|
||||
let (host, port) = input.split_once(':')?;
|
||||
Some((host.to_owned(), port.parse().ok()?))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn parse_db_info() -> anyhow::Result<()> {
|
||||
let _: DatabaseInfo = serde_json::from_value(json!({
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"dbname": "postgres",
|
||||
"user": "john_doe",
|
||||
"password": "password",
|
||||
}))?;
|
||||
|
||||
let _: DatabaseInfo = serde_json::from_value(json!({
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"dbname": "postgres",
|
||||
"user": "john_doe",
|
||||
}))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,20 +1,18 @@
|
||||
//! Cloud API V1.
|
||||
|
||||
use super::console::DatabaseInfo;
|
||||
|
||||
use crate::auth::ClientCredentials;
|
||||
use crate::stream::PqStream;
|
||||
|
||||
use crate::{compute, waiters};
|
||||
use super::DatabaseInfo;
|
||||
use crate::{
|
||||
auth::{self, ClientCredentials},
|
||||
compute,
|
||||
error::UserFacingError,
|
||||
stream::PqStream,
|
||||
waiters,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::error::UserFacingError;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthErrorImpl {
|
||||
/// Authentication error reported by the console.
|
||||
@@ -45,7 +43,7 @@ pub struct AuthError(Box<AuthErrorImpl>);
|
||||
impl AuthError {
|
||||
/// Smart constructor for authentication error reported by `mgmt`.
|
||||
pub fn auth_failed(msg: impl Into<String>) -> Self {
|
||||
AuthError(Box::new(AuthErrorImpl::AuthFailed(msg.into())))
|
||||
Self(Box::new(AuthErrorImpl::AuthFailed(msg.into())))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,7 +52,7 @@ where
|
||||
AuthErrorImpl: From<T>,
|
||||
{
|
||||
fn from(e: T) -> Self {
|
||||
AuthError(Box::new(e.into()))
|
||||
Self(Box::new(e.into()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -120,7 +118,7 @@ async fn handle_existing_user(
|
||||
auth_endpoint: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
creds: &ClientCredentials,
|
||||
) -> Result<crate::compute::NodeInfo, crate::auth::AuthError> {
|
||||
) -> Result<compute::NodeInfo, auth::AuthError> {
|
||||
let psql_session_id = super::link::new_psql_session_id();
|
||||
let md5_salt = rand::random();
|
||||
|
||||
@@ -130,7 +128,7 @@ async fn handle_existing_user(
|
||||
|
||||
// Read client's password hash
|
||||
let msg = client.read_password_message().await?;
|
||||
let md5_response = parse_password(&msg).ok_or(crate::auth::AuthErrorImpl::MalformedPassword)?;
|
||||
let md5_response = parse_password(&msg).ok_or(auth::AuthErrorImpl::MalformedPassword)?;
|
||||
|
||||
let db_info = authenticate_proxy_client(
|
||||
auth_endpoint,
|
||||
@@ -156,11 +154,11 @@ pub async fn handle_user(
|
||||
auth_link_uri: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
creds: &ClientCredentials,
|
||||
) -> Result<crate::compute::NodeInfo, crate::auth::AuthError> {
|
||||
) -> auth::Result<compute::NodeInfo> {
|
||||
if creds.is_existing_user() {
|
||||
handle_existing_user(auth_endpoint, client, creds).await
|
||||
} else {
|
||||
super::link::handle_user(auth_link_uri.as_ref(), client).await
|
||||
super::link::handle_user(auth_link_uri, client).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::{compute, stream::PqStream};
|
||||
use crate::{auth, compute, stream::PqStream};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
|
||||
|
||||
@@ -19,13 +19,13 @@ pub fn new_psql_session_id() -> String {
|
||||
}
|
||||
|
||||
pub async fn handle_user(
|
||||
redirect_uri: &str,
|
||||
redirect_uri: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> Result<compute::NodeInfo, crate::auth::AuthError> {
|
||||
) -> auth::Result<compute::NodeInfo> {
|
||||
let psql_session_id = new_psql_session_id();
|
||||
let greeting = hello_message(redirect_uri, &psql_session_id);
|
||||
let greeting = hello_message(redirect_uri.as_str(), &psql_session_id);
|
||||
|
||||
let db_info = crate::auth_backend::with_waiter(psql_session_id, |waiter| async {
|
||||
let db_info = super::with_waiter(psql_session_id, |waiter| async {
|
||||
// Give user a URL to spawn a new database
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
@@ -34,9 +34,7 @@ pub async fn handle_user(
|
||||
.await?;
|
||||
|
||||
// Wait for web console response (see `mgmt`)
|
||||
waiter
|
||||
.await?
|
||||
.map_err(crate::auth::AuthErrorImpl::auth_failed)
|
||||
waiter.await?.map_err(auth::AuthErrorImpl::auth_failed)
|
||||
})
|
||||
.await?;
|
||||
|
||||
88
proxy/src/auth/backend/postgres.rs
Normal file
88
proxy/src/auth/backend/postgres.rs
Normal file
@@ -0,0 +1,88 @@
|
||||
//! Local mock of Cloud API V2.
|
||||
|
||||
use crate::{
|
||||
auth::{
|
||||
self,
|
||||
backend::console::{self, io_error, AuthInfo, Result},
|
||||
ClientCredentials, DatabaseInfo,
|
||||
},
|
||||
compute, scram,
|
||||
stream::PqStream,
|
||||
url::ApiUrl,
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
#[must_use]
|
||||
pub(super) struct Api<'a> {
|
||||
endpoint: &'a ApiUrl,
|
||||
creds: &'a ClientCredentials,
|
||||
}
|
||||
|
||||
impl<'a> Api<'a> {
|
||||
/// Construct an API object containing the auth parameters.
|
||||
pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Result<Self> {
|
||||
Ok(Self { endpoint, creds })
|
||||
}
|
||||
|
||||
/// Authenticate the existing user or throw an error.
|
||||
pub(super) async fn handle_user(
|
||||
self,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
) -> auth::Result<compute::NodeInfo> {
|
||||
// We reuse user handling logic from a production module.
|
||||
console::handle_user(client, &self, Self::get_auth_info, Self::wake_compute).await
|
||||
}
|
||||
|
||||
/// This implementation fetches the auth info from a local postgres instance.
|
||||
async fn get_auth_info(&self) -> Result<AuthInfo> {
|
||||
// Perhaps we could persist this connection, but then we'd have to
|
||||
// write more code for reopening it if it got closed, which doesn't
|
||||
// seem worth it.
|
||||
let (client, connection) =
|
||||
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls)
|
||||
.await
|
||||
.map_err(io_error)?;
|
||||
|
||||
tokio::spawn(connection);
|
||||
let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1";
|
||||
let rows = client
|
||||
.query(query, &[&self.creds.user])
|
||||
.await
|
||||
.map_err(io_error)?;
|
||||
|
||||
match &rows[..] {
|
||||
// We can't get a secret if there's no such user.
|
||||
[] => Err(io_error(format!("unknown user '{}'", self.creds.user)).into()),
|
||||
|
||||
// We shouldn't get more than one row anyway.
|
||||
[row, ..] => {
|
||||
let entry = row.try_get(0).map_err(io_error)?;
|
||||
scram::ServerSecret::parse(entry)
|
||||
.map(AuthInfo::Scram)
|
||||
.or_else(|| {
|
||||
// It could be an md5 hash if it's not a SCRAM secret.
|
||||
let text = entry.strip_prefix("md5")?;
|
||||
Some(AuthInfo::Md5({
|
||||
let mut bytes = [0u8; 16];
|
||||
hex::decode_to_slice(text, &mut bytes).ok()?;
|
||||
bytes
|
||||
}))
|
||||
})
|
||||
// Putting the secret into this message is a security hazard!
|
||||
.ok_or(console::ConsoleAuthError::BadSecret)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// We don't need to wake anything locally, so we just return the connection info.
|
||||
async fn wake_compute(&self) -> Result<DatabaseInfo> {
|
||||
Ok(DatabaseInfo {
|
||||
// TODO: handle that near CLI params parsing
|
||||
host: self.endpoint.host_str().unwrap_or("localhost").to_owned(),
|
||||
port: self.endpoint.port().unwrap_or(5432),
|
||||
dbname: self.creds.dbname.to_owned(),
|
||||
user: self.creds.user.to_owned(),
|
||||
password: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,5 @@
|
||||
//! User credentials used in authentication.
|
||||
|
||||
use super::AuthError;
|
||||
use crate::compute;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::error::UserFacingError;
|
||||
@@ -36,6 +35,27 @@ impl ClientCredentials {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ProjectNameError {
|
||||
#[error("SNI is missing, please upgrade the postgres client library")]
|
||||
Missing,
|
||||
|
||||
#[error("SNI is malformed")]
|
||||
Bad,
|
||||
}
|
||||
|
||||
impl UserFacingError for ProjectNameError {}
|
||||
|
||||
impl ClientCredentials {
|
||||
/// Determine project name from SNI.
|
||||
pub fn project_name(&self) -> Result<&str, ProjectNameError> {
|
||||
// Currently project name is passed as a top level domain
|
||||
let sni = self.sni_data.as_ref().ok_or(ProjectNameError::Missing)?;
|
||||
let (first, _) = sni.split_once('.').ok_or(ProjectNameError::Bad)?;
|
||||
Ok(first)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<HashMap<String, String>> for ClientCredentials {
|
||||
type Error = ClientCredsParseError;
|
||||
|
||||
@@ -47,11 +67,11 @@ impl TryFrom<HashMap<String, String>> for ClientCredentials {
|
||||
};
|
||||
|
||||
let user = get_param("user")?;
|
||||
let db = get_param("database")?;
|
||||
let dbname = get_param("database")?;
|
||||
|
||||
Ok(Self {
|
||||
user,
|
||||
dbname: db,
|
||||
dbname,
|
||||
sni_data: None,
|
||||
})
|
||||
}
|
||||
@@ -63,8 +83,8 @@ impl ClientCredentials {
|
||||
self,
|
||||
config: &ProxyConfig,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
|
||||
) -> Result<compute::NodeInfo, AuthError> {
|
||||
) -> super::Result<compute::NodeInfo> {
|
||||
// This method is just a convenient facade for `handle_user`
|
||||
super::handle_user(config, client, self).await
|
||||
super::backend::handle_user(config, client, self).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
//! Main authentication flow.
|
||||
|
||||
use super::{AuthError, AuthErrorImpl};
|
||||
use super::AuthErrorImpl;
|
||||
use crate::stream::PqStream;
|
||||
use crate::{sasl, scram};
|
||||
use std::io;
|
||||
@@ -32,7 +32,7 @@ impl AuthMethod for Scram<'_> {
|
||||
pub struct AuthFlow<'a, Stream, State> {
|
||||
/// The underlying stream which implements libpq's protocol.
|
||||
stream: &'a mut PqStream<Stream>,
|
||||
/// State might contain ancillary data (see [`AuthFlow::begin`]).
|
||||
/// State might contain ancillary data (see [`Self::begin`]).
|
||||
state: State,
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> {
|
||||
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub async fn authenticate(self) -> Result<scram::ScramKey, AuthError> {
|
||||
pub async fn authenticate(self) -> super::Result<scram::ScramKey> {
|
||||
// Initial client message contains the chosen auth method's name.
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let sasl = sasl::FirstMessage::parse(&msg).ok_or(AuthErrorImpl::MalformedPassword)?;
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
pub mod console;
|
||||
pub mod legacy_console;
|
||||
pub mod link;
|
||||
pub mod postgres;
|
||||
|
||||
pub use legacy_console::{AuthError, AuthErrorImpl};
|
||||
|
||||
use crate::mgmt;
|
||||
use crate::waiters::{self, Waiter, Waiters};
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
lazy_static! {
|
||||
static ref CPLANE_WAITERS: Waiters<mgmt::ComputeReady> = Default::default();
|
||||
}
|
||||
|
||||
/// Give caller an opportunity to wait for the cloud's reply.
|
||||
pub async fn with_waiter<R, T, E>(
|
||||
psql_session_id: impl Into<String>,
|
||||
action: impl FnOnce(Waiter<'static, mgmt::ComputeReady>) -> R,
|
||||
) -> Result<T, E>
|
||||
where
|
||||
R: std::future::Future<Output = Result<T, E>>,
|
||||
E: From<waiters::RegisterError>,
|
||||
{
|
||||
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
|
||||
action(waiter).await
|
||||
}
|
||||
|
||||
pub fn notify(psql_session_id: &str, msg: mgmt::ComputeReady) -> Result<(), waiters::NotifyError> {
|
||||
CPLANE_WAITERS.notify(psql_session_id, msg)
|
||||
}
|
||||
@@ -1,243 +0,0 @@
|
||||
//! Declaration of Cloud API V2.
|
||||
|
||||
use crate::{
|
||||
auth::{self, AuthFlow},
|
||||
compute, scram,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
use crate::auth::ClientCredentials;
|
||||
use crate::stream::PqStream;
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ConsoleAuthError {
|
||||
// We shouldn't include the actual secret here.
|
||||
#[error("Bad authentication secret")]
|
||||
BadSecret,
|
||||
|
||||
#[error("Bad client credentials: {0:?}")]
|
||||
BadCredentials(crate::auth::ClientCredentials),
|
||||
|
||||
#[error("SNI info is missing, please upgrade the postgres client library")]
|
||||
SniMissing,
|
||||
|
||||
#[error("Unexpected SNI content")]
|
||||
SniWrong,
|
||||
|
||||
#[error(transparent)]
|
||||
BadUrl(#[from] url::ParseError),
|
||||
|
||||
#[error(transparent)]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
/// HTTP status (other than 200) returned by the console.
|
||||
#[error("Console responded with an HTTP status: {0}")]
|
||||
HttpStatus(reqwest::StatusCode),
|
||||
|
||||
#[error(transparent)]
|
||||
Transport(#[from] reqwest::Error),
|
||||
|
||||
#[error("Console responded with a malformed JSON: '{0}'")]
|
||||
MalformedResponse(#[from] serde_json::Error),
|
||||
|
||||
#[error("Console responded with a malformed compute address: '{0}'")]
|
||||
MalformedComputeAddress(String),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct GetRoleSecretResponse {
|
||||
role_secret: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct GetWakeComputeResponse {
|
||||
address: String,
|
||||
}
|
||||
|
||||
/// Auth secret which is managed by the cloud.
|
||||
pub enum AuthInfo {
|
||||
/// Md5 hash of user's password.
|
||||
Md5([u8; 16]),
|
||||
/// [SCRAM](crate::scram) authentication info.
|
||||
Scram(scram::ServerSecret),
|
||||
}
|
||||
|
||||
/// Compute node connection params provided by the cloud.
|
||||
/// Note how it implements serde traits, since we receive it over the wire.
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
pub struct DatabaseInfo {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub dbname: String,
|
||||
pub user: String,
|
||||
|
||||
/// [Cloud API V1](super::legacy) returns cleartext password,
|
||||
/// but [Cloud API V2](super::api) implements [SCRAM](crate::scram)
|
||||
/// authentication, so we can leverage this method and cope without password.
|
||||
pub password: Option<String>,
|
||||
}
|
||||
|
||||
// Manually implement debug to omit personal and sensitive info.
|
||||
impl std::fmt::Debug for DatabaseInfo {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
fmt.debug_struct("DatabaseInfo")
|
||||
.field("host", &self.host)
|
||||
.field("port", &self.port)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DatabaseInfo> for tokio_postgres::Config {
|
||||
fn from(db_info: DatabaseInfo) -> Self {
|
||||
let mut config = tokio_postgres::Config::new();
|
||||
|
||||
config
|
||||
.host(&db_info.host)
|
||||
.port(db_info.port)
|
||||
.dbname(&db_info.dbname)
|
||||
.user(&db_info.user);
|
||||
|
||||
if let Some(password) = db_info.password {
|
||||
config.password(password);
|
||||
}
|
||||
|
||||
config
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_auth_info(
|
||||
auth_endpoint: &str,
|
||||
user: &str,
|
||||
cluster: &str,
|
||||
) -> Result<AuthInfo, ConsoleAuthError> {
|
||||
let mut url = reqwest::Url::parse(&format!("{auth_endpoint}/proxy_get_role_secret"))?;
|
||||
|
||||
url.query_pairs_mut()
|
||||
.append_pair("project", cluster)
|
||||
.append_pair("role", user);
|
||||
|
||||
// TODO: use a proper logger
|
||||
println!("cplane request: {}", url);
|
||||
|
||||
let resp = reqwest::get(url).await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(ConsoleAuthError::HttpStatus(resp.status()));
|
||||
}
|
||||
|
||||
let response: GetRoleSecretResponse = serde_json::from_str(resp.text().await?.as_str())?;
|
||||
|
||||
scram::ServerSecret::parse(response.role_secret.as_str())
|
||||
.map(AuthInfo::Scram)
|
||||
.ok_or(ConsoleAuthError::BadSecret)
|
||||
}
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
async fn wake_compute(
|
||||
auth_endpoint: &str,
|
||||
cluster: &str,
|
||||
) -> Result<(String, u16), ConsoleAuthError> {
|
||||
let mut url = reqwest::Url::parse(&format!("{auth_endpoint}/proxy_wake_compute"))?;
|
||||
url.query_pairs_mut().append_pair("project", cluster);
|
||||
|
||||
// TODO: use a proper logger
|
||||
println!("cplane request: {}", url);
|
||||
|
||||
let resp = reqwest::get(url).await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(ConsoleAuthError::HttpStatus(resp.status()));
|
||||
}
|
||||
|
||||
let response: GetWakeComputeResponse = serde_json::from_str(resp.text().await?.as_str())?;
|
||||
let (host, port) = response
|
||||
.address
|
||||
.split_once(':')
|
||||
.ok_or_else(|| ConsoleAuthError::MalformedComputeAddress(response.address.clone()))?;
|
||||
let port: u16 = port
|
||||
.parse()
|
||||
.map_err(|_| ConsoleAuthError::MalformedComputeAddress(response.address.clone()))?;
|
||||
|
||||
Ok((host.to_string(), port))
|
||||
}
|
||||
|
||||
pub async fn handle_user(
|
||||
auth_endpoint: &str,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
creds: &ClientCredentials,
|
||||
) -> Result<compute::NodeInfo, crate::auth::AuthError> {
|
||||
// Determine cluster name from SNI.
|
||||
let cluster = creds
|
||||
.sni_data
|
||||
.as_ref()
|
||||
.ok_or(ConsoleAuthError::SniMissing)?
|
||||
.split_once('.')
|
||||
.ok_or(ConsoleAuthError::SniWrong)?
|
||||
.0;
|
||||
|
||||
let user = creds.user.as_str();
|
||||
|
||||
// Step 1: get the auth secret
|
||||
let auth_info = get_auth_info(auth_endpoint, user, cluster).await?;
|
||||
|
||||
let flow = AuthFlow::new(client);
|
||||
let scram_keys = match auth_info {
|
||||
AuthInfo::Md5(_) => {
|
||||
// TODO: decide if we should support MD5 in api v2
|
||||
return Err(crate::auth::AuthErrorImpl::auth_failed("MD5 is not supported").into());
|
||||
}
|
||||
AuthInfo::Scram(secret) => {
|
||||
let scram = auth::Scram(&secret);
|
||||
Some(compute::ScramKeys {
|
||||
client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(),
|
||||
server_key: secret.server_key.as_bytes(),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||
|
||||
// Step 2: wake compute
|
||||
let (host, port) = wake_compute(auth_endpoint, cluster).await?;
|
||||
|
||||
Ok(compute::NodeInfo {
|
||||
db_info: DatabaseInfo {
|
||||
host,
|
||||
port,
|
||||
dbname: creds.dbname.clone(),
|
||||
user: creds.user.clone(),
|
||||
password: None,
|
||||
},
|
||||
scram_keys,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn parse_db_info() -> anyhow::Result<()> {
|
||||
let _: DatabaseInfo = serde_json::from_value(json!({
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"dbname": "postgres",
|
||||
"user": "john_doe",
|
||||
"password": "password",
|
||||
}))?;
|
||||
|
||||
let _: DatabaseInfo = serde_json::from_value(json!({
|
||||
"host": "localhost",
|
||||
"port": 5432,
|
||||
"dbname": "postgres",
|
||||
"user": "john_doe",
|
||||
}))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
//! Local mock of Cloud API V2.
|
||||
|
||||
use super::console::{self, AuthInfo, DatabaseInfo};
|
||||
use crate::scram;
|
||||
use crate::{auth::ClientCredentials, compute};
|
||||
|
||||
use crate::stream::PqStream;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage};
|
||||
|
||||
async fn get_auth_info(
|
||||
auth_endpoint: &str,
|
||||
creds: &ClientCredentials,
|
||||
) -> Result<AuthInfo, console::ConsoleAuthError> {
|
||||
// We wrap `tokio_postgres::Error` because we don't want to infect the
|
||||
// method's error type with a detail that's specific to debug mode only.
|
||||
let io_error = |e| std::io::Error::new(std::io::ErrorKind::Other, e);
|
||||
|
||||
// Perhaps we could persist this connection, but then we'd have to
|
||||
// write more code for reopening it if it got closed, which doesn't
|
||||
// seem worth it.
|
||||
let (client, connection) = tokio_postgres::connect(auth_endpoint, tokio_postgres::NoTls)
|
||||
.await
|
||||
.map_err(io_error)?;
|
||||
|
||||
tokio::spawn(connection);
|
||||
let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1";
|
||||
let rows = client
|
||||
.query(query, &[&creds.user])
|
||||
.await
|
||||
.map_err(io_error)?;
|
||||
|
||||
match &rows[..] {
|
||||
// We can't get a secret if there's no such user.
|
||||
[] => Err(console::ConsoleAuthError::BadCredentials(creds.to_owned())),
|
||||
// We shouldn't get more than one row anyway.
|
||||
[row, ..] => {
|
||||
let entry = row.try_get(0).map_err(io_error)?;
|
||||
scram::ServerSecret::parse(entry)
|
||||
.map(AuthInfo::Scram)
|
||||
.or_else(|| {
|
||||
// It could be an md5 hash if it's not a SCRAM secret.
|
||||
let text = entry.strip_prefix("md5")?;
|
||||
Some(AuthInfo::Md5({
|
||||
let mut bytes = [0u8; 16];
|
||||
hex::decode_to_slice(text, &mut bytes).ok()?;
|
||||
bytes
|
||||
}))
|
||||
})
|
||||
// Putting the secret into this message is a security hazard!
|
||||
.ok_or(console::ConsoleAuthError::BadSecret)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_user(
|
||||
auth_endpoint: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
creds: &ClientCredentials,
|
||||
) -> Result<compute::NodeInfo, crate::auth::AuthError> {
|
||||
let auth_info = get_auth_info(auth_endpoint.as_ref(), creds).await?;
|
||||
|
||||
let flow = crate::auth::AuthFlow::new(client);
|
||||
let scram_keys = match auth_info {
|
||||
AuthInfo::Md5(_) => {
|
||||
// TODO: decide if we should support MD5 in api v2
|
||||
return Err(crate::auth::AuthErrorImpl::auth_failed("MD5 is not supported").into());
|
||||
}
|
||||
AuthInfo::Scram(secret) => {
|
||||
let scram = crate::auth::Scram(&secret);
|
||||
Some(compute::ScramKeys {
|
||||
client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(),
|
||||
server_key: secret.server_key.as_bytes(),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
client
|
||||
.write_message_noflush(&Be::AuthenticationOk)?
|
||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||
|
||||
Ok(compute::NodeInfo {
|
||||
db_info: DatabaseInfo {
|
||||
// TODO: handle that near CLI params parsing
|
||||
host: auth_endpoint.host_str().unwrap_or("localhost").to_owned(),
|
||||
port: auth_endpoint.port().unwrap_or(5432),
|
||||
dbname: creds.dbname.to_owned(),
|
||||
user: creds.user.to_owned(),
|
||||
password: None,
|
||||
},
|
||||
scram_keys,
|
||||
})
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::auth_backend::console::DatabaseInfo;
|
||||
use crate::auth::DatabaseInfo;
|
||||
use crate::cancellation::CancelClosure;
|
||||
use crate::error::UserFacingError;
|
||||
use std::io;
|
||||
@@ -37,7 +37,7 @@ pub struct NodeInfo {
|
||||
|
||||
impl NodeInfo {
|
||||
async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream)> {
|
||||
let host_port = format!("{}:{}", self.db_info.host, self.db_info.port);
|
||||
let host_port = (self.db_info.host.as_str(), self.db_info.port);
|
||||
let socket = TcpStream::connect(host_port).await?;
|
||||
let socket_addr = socket.peer_addr()?;
|
||||
socket2::SockRef::from(&socket).set_keepalive(true)?;
|
||||
|
||||
@@ -1,39 +1,38 @@
|
||||
use anyhow::{ensure, Context};
|
||||
use crate::url::ApiUrl;
|
||||
use anyhow::{bail, ensure, Context};
|
||||
use std::{str::FromStr, sync::Arc};
|
||||
|
||||
#[non_exhaustive]
|
||||
pub enum AuthBackendType {
|
||||
/// Legacy Cloud API (V1).
|
||||
LegacyConsole,
|
||||
Console,
|
||||
Postgres,
|
||||
/// Authentication via a web browser.
|
||||
Link,
|
||||
/// Current Cloud API (V2).
|
||||
Console,
|
||||
/// Local mock of Cloud API (V2).
|
||||
Postgres,
|
||||
}
|
||||
|
||||
impl FromStr for AuthBackendType {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> anyhow::Result<Self> {
|
||||
println!("ClientAuthMethod::from_str: '{}'", s);
|
||||
use AuthBackendType::*;
|
||||
match s {
|
||||
"legacy" => Ok(LegacyConsole),
|
||||
"console" => Ok(Console),
|
||||
"postgres" => Ok(Postgres),
|
||||
"link" => Ok(Link),
|
||||
_ => Err(anyhow::anyhow!("Invlid option for auth method")),
|
||||
}
|
||||
Ok(match s {
|
||||
"legacy" => LegacyConsole,
|
||||
"console" => Console,
|
||||
"postgres" => Postgres,
|
||||
"link" => Link,
|
||||
_ => bail!("Invalid option `{s}` for auth method"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ProxyConfig {
|
||||
/// TLS configuration for the proxy.
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
|
||||
pub auth_backend: AuthBackendType,
|
||||
|
||||
pub auth_endpoint: reqwest::Url,
|
||||
|
||||
pub auth_link_uri: reqwest::Url,
|
||||
pub auth_endpoint: ApiUrl,
|
||||
pub auth_link_uri: ApiUrl,
|
||||
}
|
||||
|
||||
pub type TlsConfig = Arc<rustls::ServerConfig>;
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
//! in somewhat transparent manner (again via communication with control plane API).
|
||||
|
||||
mod auth;
|
||||
mod auth_backend;
|
||||
mod cancellation;
|
||||
mod compute;
|
||||
mod config;
|
||||
@@ -17,6 +16,7 @@ mod proxy;
|
||||
mod sasl;
|
||||
mod scram;
|
||||
mod stream;
|
||||
mod url;
|
||||
mod waiters;
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::auth_backend;
|
||||
use crate::auth;
|
||||
use anyhow::Context;
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
@@ -77,12 +77,12 @@ struct PsqlSessionResponse {
|
||||
|
||||
#[derive(Deserialize)]
|
||||
enum PsqlSessionResult {
|
||||
Success(auth_backend::console::DatabaseInfo),
|
||||
Success(auth::DatabaseInfo),
|
||||
Failure(String),
|
||||
}
|
||||
|
||||
/// A message received by `mgmt` when a compute node is ready.
|
||||
pub type ComputeReady = Result<auth_backend::console::DatabaseInfo, String>;
|
||||
pub type ComputeReady = Result<auth::DatabaseInfo, String>;
|
||||
|
||||
impl PsqlSessionResult {
|
||||
fn into_compute_ready(self) -> ComputeReady {
|
||||
@@ -113,7 +113,7 @@ fn try_process_query(pgb: &mut PostgresBackend, query_string: &str) -> anyhow::R
|
||||
|
||||
let resp: PsqlSessionResponse = serde_json::from_str(query_string)?;
|
||||
|
||||
match auth_backend::notify(&resp.session_id, resp.result.into_compute_ready()) {
|
||||
match auth::backend::notify(&resp.session_id, resp.result.into_compute_ready()) {
|
||||
Ok(()) => {
|
||||
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
|
||||
|
||||
82
proxy/src/url.rs
Normal file
82
proxy/src/url.rs
Normal file
@@ -0,0 +1,82 @@
|
||||
use anyhow::bail;
|
||||
use url::form_urlencoded::Serializer;
|
||||
|
||||
/// A [url](url::Url) type with additional guarantees.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ApiUrl(url::Url);
|
||||
|
||||
impl ApiUrl {
|
||||
/// Consume the wrapper and return inner [url](url::Url).
|
||||
pub fn into_inner(self) -> url::Url {
|
||||
self.0
|
||||
}
|
||||
|
||||
/// See [`url::Url::query_pairs_mut`].
|
||||
pub fn query_pairs_mut(&mut self) -> Serializer<'_, url::UrlQuery<'_>> {
|
||||
self.0.query_pairs_mut()
|
||||
}
|
||||
|
||||
/// See [`url::Url::path_segments_mut`].
|
||||
pub fn path_segments_mut(&mut self) -> url::PathSegmentsMut {
|
||||
// We've already verified that it works during construction.
|
||||
self.0.path_segments_mut().expect("bad API url")
|
||||
}
|
||||
}
|
||||
|
||||
/// This instance imposes additional requirements on the url.
|
||||
impl std::str::FromStr for ApiUrl {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> anyhow::Result<Self> {
|
||||
let mut url: url::Url = s.parse()?;
|
||||
|
||||
// Make sure that we can build upon this URL.
|
||||
if url.path_segments_mut().is_err() {
|
||||
bail!("bad API url provided");
|
||||
}
|
||||
|
||||
Ok(Self(url))
|
||||
}
|
||||
}
|
||||
|
||||
/// This instance is safe because it doesn't allow us to modify the object.
|
||||
impl std::ops::Deref for ApiUrl {
|
||||
type Target = url::Url;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ApiUrl {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
self.0.fmt(f)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn bad_url() {
|
||||
let url = "test:foobar";
|
||||
url.parse::<url::Url>().expect("unexpected parsing failure");
|
||||
let _ = url.parse::<ApiUrl>().expect_err("should not parse");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn good_url() {
|
||||
let url = "test://foobar";
|
||||
let mut a = url.parse::<url::Url>().expect("unexpected parsing failure");
|
||||
let mut b = url.parse::<ApiUrl>().expect("unexpected parsing failure");
|
||||
|
||||
a.path_segments_mut().unwrap().push("method");
|
||||
a.query_pairs_mut().append_pair("key", "value");
|
||||
|
||||
b.path_segments_mut().push("method");
|
||||
b.query_pairs_mut().append_pair("key", "value");
|
||||
|
||||
assert_eq!(a, b.into_inner());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user