diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index d09470d15e..a50d23e351 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -22,10 +22,6 @@ pub type Result = std::result::Result; /// Common authentication error. #[derive(Debug, Error)] pub enum AuthErrorImpl { - // This will be dropped in the future. - #[error(transparent)] - Legacy(#[from] backend::LegacyAuthError), - #[error(transparent)] Link(#[from] backend::LinkAuthError), @@ -78,7 +74,6 @@ impl UserFacingError for AuthError { fn to_string_client(&self) -> String { use AuthErrorImpl::*; match self.0.as_ref() { - Legacy(e) => e.to_string_client(), Link(e) => e.to_string_client(), GetAuthInfo(e) => e.to_string_client(), WakeCompute(e) => e.to_string_client(), diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 9c43620ffb..de0719a196 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -6,9 +6,6 @@ pub use link::LinkAuthError; mod console; pub use console::{GetAuthInfoError, WakeComputeError}; -mod legacy_console; -pub use legacy_console::LegacyAuthError; - use crate::{ auth::{self, AuthFlow, ClientCredentials}, compute, config, mgmt, @@ -56,7 +53,7 @@ impl std::fmt::Debug for DatabaseInfo { fmt.debug_struct("DatabaseInfo") .field("host", &self.host) .field("port", &self.port) - .finish() + .finish_non_exhaustive() } } @@ -88,8 +85,6 @@ impl From for tokio_postgres::Config { /// backends which require them for the authentication process. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum BackendType { - /// Legacy Cloud API (V1) + link auth. - LegacyConsole(T), /// Current Cloud API (V2). Console(T), /// Local mock of Cloud API (V2). @@ -105,7 +100,6 @@ impl BackendType { pub fn map(self, f: impl FnOnce(T) -> R) -> BackendType { use BackendType::*; match self { - LegacyConsole(x) => LegacyConsole(f(x)), Console(x) => Console(f(x)), Postgres(x) => Postgres(f(x)), Link => Link, @@ -119,7 +113,6 @@ impl BackendType> { pub fn transpose(self) -> Result, E> { use BackendType::*; match self { - LegacyConsole(x) => x.map(LegacyConsole), Console(x) => x.map(Console), Postgres(x) => x.map(Postgres), Link => Ok(Link), @@ -176,15 +169,6 @@ impl BackendType> { } match self { - LegacyConsole(creds) => { - legacy_console::handle_user( - &urls.auth_endpoint, - &urls.auth_link_uri, - &creds, - client, - ) - .await - } Console(creds) => { console::Api::new(&urls.auth_endpoint, &creds) .handle_user(client) @@ -208,7 +192,6 @@ mod tests { #[test] fn test_backend_type_map() { let values = [ - BackendType::LegacyConsole(0), BackendType::Console(0), BackendType::Postgres(0), BackendType::Link, @@ -222,8 +205,7 @@ mod tests { #[test] fn test_backend_type_transpose() { let values = [ - BackendType::LegacyConsole(Ok::<_, ()>(0)), - BackendType::Console(Ok(0)), + BackendType::Console(Ok::<_, ()>(0)), BackendType::Postgres(Ok(0)), BackendType::Link, ]; diff --git a/proxy/src/auth/backend/legacy_console.rs b/proxy/src/auth/backend/legacy_console.rs deleted file mode 100644 index b99a004dcd..0000000000 --- a/proxy/src/auth/backend/legacy_console.rs +++ /dev/null @@ -1,208 +0,0 @@ -//! Cloud API V1. - -use super::DatabaseInfo; -use crate::{ - auth::{self, ClientCredentials}, - compute, - error::UserFacingError, - stream::PqStream, - waiters, -}; -use serde::{Deserialize, Serialize}; -use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite}; -use utils::pq_proto::BeMessage as Be; - -#[derive(Debug, Error)] -pub enum LegacyAuthError { - /// Authentication error reported by the console. - #[error("Authentication failed: {0}")] - AuthFailed(String), - - /// HTTP status (other than 200) returned by the console. - #[error("Console responded with an HTTP status: {0}")] - HttpStatus(reqwest::StatusCode), - - #[error("Console responded with a malformed JSON: {0}")] - BadResponse(#[from] serde_json::Error), - - #[error(transparent)] - Transport(#[from] reqwest::Error), - - #[error(transparent)] - WaiterRegister(#[from] waiters::RegisterError), - - #[error(transparent)] - WaiterWait(#[from] waiters::WaitError), -} - -impl UserFacingError for LegacyAuthError { - fn to_string_client(&self) -> String { - use LegacyAuthError::*; - match self { - AuthFailed(_) | HttpStatus(_) => self.to_string(), - _ => "Internal error".to_string(), - } - } -} - -// NOTE: the order of constructors is important. -// https://serde.rs/enum-representations.html#untagged -#[derive(Serialize, Deserialize, Debug)] -#[serde(untagged)] -enum ProxyAuthResponse { - Ready { conn_info: DatabaseInfo }, - Error { error: String }, - NotReady { ready: bool }, // TODO: get rid of `ready` -} - -impl ClientCredentials<'_> { - fn is_existing_user(&self) -> bool { - self.user.ends_with("@zenith") - } -} - -async fn authenticate_proxy_client( - auth_endpoint: &reqwest::Url, - creds: &ClientCredentials<'_>, - md5_response: &str, - salt: &[u8; 4], - psql_session_id: &str, -) -> Result { - let mut url = auth_endpoint.clone(); - url.query_pairs_mut() - .append_pair("login", creds.user) - .append_pair("database", creds.dbname) - .append_pair("md5response", md5_response) - .append_pair("salt", &hex::encode(salt)) - .append_pair("psql_session_id", psql_session_id); - - super::with_waiter(psql_session_id, |waiter| async { - println!("cloud request: {}", url); - // TODO: leverage `reqwest::Client` to reuse connections - let resp = reqwest::get(url).await?; - if !resp.status().is_success() { - return Err(LegacyAuthError::HttpStatus(resp.status())); - } - - let auth_info = serde_json::from_str(resp.text().await?.as_str())?; - println!("got auth info: {:?}", auth_info); - - use ProxyAuthResponse::*; - let db_info = match auth_info { - Ready { conn_info } => conn_info, - Error { error } => return Err(LegacyAuthError::AuthFailed(error)), - NotReady { .. } => waiter.await?.map_err(LegacyAuthError::AuthFailed)?, - }; - - Ok(db_info) - }) - .await -} - -async fn handle_existing_user( - auth_endpoint: &reqwest::Url, - client: &mut PqStream, - creds: &ClientCredentials<'_>, -) -> auth::Result { - let psql_session_id = super::link::new_psql_session_id(); - let md5_salt = rand::random(); - - client - .write_message(&Be::AuthenticationMD5Password(md5_salt)) - .await?; - - // Read client's password hash - let msg = client.read_password_message().await?; - let md5_response = parse_password(&msg).ok_or(auth::AuthErrorImpl::MalformedPassword( - "the password should be a valid null-terminated utf-8 string", - ))?; - - let db_info = authenticate_proxy_client( - auth_endpoint, - creds, - md5_response, - &md5_salt, - &psql_session_id, - ) - .await?; - - Ok(compute::NodeInfo { - reported_auth_ok: false, - config: db_info.into(), - }) -} - -pub async fn handle_user( - auth_endpoint: &reqwest::Url, - auth_link_uri: &reqwest::Url, - creds: &ClientCredentials<'_>, - client: &mut PqStream, -) -> auth::Result { - if creds.is_existing_user() { - handle_existing_user(auth_endpoint, client, creds).await - } else { - super::link::handle_user(auth_link_uri, client).await - } -} - -fn parse_password(bytes: &[u8]) -> Option<&str> { - std::str::from_utf8(bytes).ok()?.strip_suffix('\0') -} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - - #[test] - fn test_proxy_auth_response() { - // Ready - let auth: ProxyAuthResponse = serde_json::from_value(json!({ - "ready": true, - "conn_info": DatabaseInfo::default(), - })) - .unwrap(); - assert!(matches!( - auth, - ProxyAuthResponse::Ready { - conn_info: DatabaseInfo { .. } - } - )); - - // Error - let auth: ProxyAuthResponse = serde_json::from_value(json!({ - "ready": false, - "error": "too bad, so sad", - })) - .unwrap(); - assert!(matches!(auth, ProxyAuthResponse::Error { .. })); - - // NotReady - let auth: ProxyAuthResponse = serde_json::from_value(json!({ - "ready": false, - })) - .unwrap(); - assert!(matches!(auth, ProxyAuthResponse::NotReady { .. })); - } - - #[test] - fn parse_db_info() -> anyhow::Result<()> { - let _: DatabaseInfo = serde_json::from_value(json!({ - "host": "localhost", - "port": 5432, - "dbname": "postgres", - "user": "john_doe", - "password": "password", - }))?; - - let _: DatabaseInfo = serde_json::from_value(json!({ - "host": "localhost", - "port": 5432, - "dbname": "postgres", - "user": "john_doe", - }))?; - - Ok(()) - } -} diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 1f01c25734..8835d660d5 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -1,21 +1,6 @@ use crate::{auth, url::ApiUrl}; -use anyhow::{bail, ensure, Context}; -use std::{str::FromStr, sync::Arc}; - -impl FromStr for auth::BackendType<()> { - type Err = anyhow::Error; - - fn from_str(s: &str) -> anyhow::Result { - use auth::BackendType::*; - Ok(match s { - "legacy" => LegacyConsole(()), - "console" => Console(()), - "postgres" => Postgres(()), - "link" => Link, - _ => bail!("Invalid option `{s}` for auth method"), - }) - } -} +use anyhow::{ensure, Context}; +use std::sync::Arc; pub struct ProxyConfig { pub tls_config: Option, diff --git a/proxy/src/main.rs b/proxy/src/main.rs index 2521f2af21..efe45f6386 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -20,7 +20,7 @@ mod url; mod waiters; use anyhow::{bail, Context}; -use clap::{App, Arg}; +use clap::{self, Arg}; use config::ProxyConfig; use futures::FutureExt; use std::{future::Future, net::SocketAddr}; @@ -36,9 +36,26 @@ async fn flatten_err( f.map(|r| r.context("join error").and_then(|x| x)).await } +/// A proper parser for auth backend parameter. +impl clap::ValueEnum for auth::BackendType<()> { + fn value_variants<'a>() -> &'a [Self] { + use auth::BackendType::*; + &[Console(()), Postgres(()), Link] + } + + fn to_possible_value<'a>(&self) -> Option> { + use auth::BackendType::*; + Some(clap::PossibleValue::new(match self { + Console(_) => "console", + Postgres(_) => "postgres", + Link => "link", + })) + } +} + #[tokio::main] async fn main() -> anyhow::Result<()> { - let arg_matches = App::new("Neon proxy/router") + let arg_matches = clap::App::new("Neon proxy/router") .version(GIT_VERSION) .arg( Arg::new("proxy") @@ -52,8 +69,8 @@ async fn main() -> anyhow::Result<()> { Arg::new("auth-backend") .long("auth-backend") .takes_value(true) - .help("Possible values: legacy | console | postgres | link") - .default_value("legacy"), + .value_parser(clap::builder::EnumValueParser::>::new()) + .default_value("link"), ) .arg( Arg::new("mgmt") @@ -118,6 +135,10 @@ async fn main() -> anyhow::Result<()> { let mgmt_address: SocketAddr = arg_matches.value_of("mgmt").unwrap().parse()?; let http_address: SocketAddr = arg_matches.value_of("http").unwrap().parse()?; + let auth_backend = *arg_matches + .try_get_one::>("auth-backend")? + .unwrap(); + let auth_urls = config::AuthUrls { auth_endpoint: arg_matches.value_of("auth-endpoint").unwrap().parse()?, auth_link_uri: arg_matches.value_of("uri").unwrap().parse()?, @@ -125,7 +146,7 @@ async fn main() -> anyhow::Result<()> { let config: &ProxyConfig = Box::leak(Box::new(ProxyConfig { tls_config, - auth_backend: arg_matches.value_of("auth-backend").unwrap().parse()?, + auth_backend, auth_urls, }));