From bcf5cd908ec830d43a9c5250c8d78bda137e84ab Mon Sep 17 00:00:00 2001 From: Kliment Serafimov Date: Thu, 2 Jun 2022 00:04:26 +0200 Subject: [PATCH] Merged changes. --- libs/utils/src/pq_proto.rs | 10 +- proxy/src/auth/credentials.rs | 34 +++- proxy/src/auth_backend/console.rs | 251 ++++++++++++++++++++++++++++++ 3 files changed, 286 insertions(+), 9 deletions(-) create mode 100644 proxy/src/auth_backend/console.rs diff --git a/libs/utils/src/pq_proto.rs b/libs/utils/src/pq_proto.rs index ce86cf8c91..0ad8adb3a6 100644 --- a/libs/utils/src/pq_proto.rs +++ b/libs/utils/src/pq_proto.rs @@ -269,11 +269,15 @@ impl FeStartupPacket { .next() .context("expected even number of params in StartupMessage")?; if name == "options" { - // deprecated way of passing params as cmd line args - for cmdopt in value.split(' ') { - let nameval: Vec<&str> = cmdopt.split('=').collect(); + //parsing options arguments "..&options=:,.." + //extended example and set of options: + //https://github.com/neondatabase/neon/blob/main/docs/rfcs/016-connection-routing.md#connection-url + for cmdopt in value.split(',') { + let nameval: Vec<&str> = cmdopt.split(':').collect(); if nameval.len() == 2 { params.insert(nameval[0].to_string(), nameval[1].to_string()); + } else { + //todo: inform user / throw error message if options format is wrong. } } } else { diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 467e7db282..2636c2237d 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -26,6 +26,10 @@ pub struct ClientCredentials { // New console API requires SNI info to determine the cluster name. // Other Auth backends don't need it. pub sni_data: Option, + + // cluster_option is passed as argument from options from url. + // To be used to determine cluster name in case sni_data is missing. + pub project_option: Option, } impl ClientCredentials { @@ -37,10 +41,10 @@ impl ClientCredentials { #[derive(Debug, Error)] pub enum ProjectNameError { - #[error("SNI is missing, please upgrade the postgres client library")] + #[error("SNI info is missing. EITHER please upgrade the postgres client library OR pass the project name as a parameter: '..&options=project:..'.")] Missing, - #[error("SNI is malformed")] + #[error("SNI is malformed.")] Bad, } @@ -49,10 +53,22 @@ impl UserFacingError for ProjectNameError {} impl ClientCredentials { /// Determine project name from SNI. pub fn project_name(&self) -> Result<&str, ProjectNameError> { - // Currently project name is passed as a top level domain - let sni = self.sni_data.as_ref().ok_or(ProjectNameError::Missing)?; - let (first, _) = sni.split_once('.').ok_or(ProjectNameError::Bad)?; - Ok(first) + let ret = match &self.sni_data { + //if sni_data exists, use it to determine project name + Some(sni_data) => { + sni_data + .split_once('.') + .ok_or(ProjectNameError::Bad)? + .0 + } + //otherwise use project_option if it was manually set thought ..&options=project: parameter + None => self + .project_option + .as_ref() + .ok_or(ProjectNameError::Missing)? + .as_str(), + }; + Ok(ret) } } @@ -68,11 +84,17 @@ impl TryFrom> for ClientCredentials { let user = get_param("user")?; let dbname = get_param("database")?; + let project = get_param("project"); + let project_option = match project { + Ok(project) => Some(project), + Err(_) => None, + }; Ok(Self { user, dbname, sni_data: None, + project_option, }) } } diff --git a/proxy/src/auth_backend/console.rs b/proxy/src/auth_backend/console.rs new file mode 100644 index 0000000000..551bcb7bb6 --- /dev/null +++ b/proxy/src/auth_backend/console.rs @@ -0,0 +1,251 @@ +//! Declaration of Cloud API V2. + +use crate::{ + auth::{self, AuthFlow}, + compute, scram, +}; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +use crate::auth::ClientCredentials; +use crate::stream::PqStream; + +use tokio::io::{AsyncRead, AsyncWrite}; +use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage}; + +#[derive(Debug, Error)] +pub enum ConsoleAuthError { + // We shouldn't include the actual secret here. + #[error("Bad authentication secret")] + BadSecret, + + #[error("Bad client credentials: {0:?}")] + BadCredentials(crate::auth::ClientCredentials), + + #[error("SNI info is missing. EITHER please upgrade the postgres client library OR pass ..&options=cluster:.. parameter")] + SniMissingAndProjectNameMissing, + + #[error("Unexpected SNI content")] + SniWrong, + + #[error(transparent)] + BadUrl(#[from] url::ParseError), + + #[error(transparent)] + Io(#[from] std::io::Error), + + /// HTTP status (other than 200) returned by the console. + #[error("Console responded with an HTTP status: {0}")] + HttpStatus(reqwest::StatusCode), + + #[error(transparent)] + Transport(#[from] reqwest::Error), + + #[error("Console responded with a malformed JSON: '{0}'")] + MalformedResponse(#[from] serde_json::Error), + + #[error("Console responded with a malformed compute address: '{0}'")] + MalformedComputeAddress(String), +} + +#[derive(Serialize, Deserialize, Debug)] +struct GetRoleSecretResponse { + role_secret: String, +} + +#[derive(Serialize, Deserialize, Debug)] +struct GetWakeComputeResponse { + address: String, +} + +/// Auth secret which is managed by the cloud. +pub enum AuthInfo { + /// Md5 hash of user's password. + Md5([u8; 16]), + /// [SCRAM](crate::scram) authentication info. + Scram(scram::ServerSecret), +} + +/// Compute node connection params provided by the cloud. +/// Note how it implements serde traits, since we receive it over the wire. +#[derive(Serialize, Deserialize, Default)] +pub struct DatabaseInfo { + pub host: String, + pub port: u16, + pub dbname: String, + pub user: String, + + /// [Cloud API V1](super::legacy) returns cleartext password, + /// but [Cloud API V2](super::api) implements [SCRAM](crate::scram) + /// authentication, so we can leverage this method and cope without password. + pub password: Option, +} + +// Manually implement debug to omit personal and sensitive info. +impl std::fmt::Debug for DatabaseInfo { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + fmt.debug_struct("DatabaseInfo") + .field("host", &self.host) + .field("port", &self.port) + .finish() + } +} + +impl From for tokio_postgres::Config { + fn from(db_info: DatabaseInfo) -> Self { + let mut config = tokio_postgres::Config::new(); + + config + .host(&db_info.host) + .port(db_info.port) + .dbname(&db_info.dbname) + .user(&db_info.user); + + if let Some(password) = db_info.password { + config.password(password); + } + + config + } +} + +async fn get_auth_info( + auth_endpoint: &str, + user: &str, + cluster: &str, +) -> Result { + let mut url = reqwest::Url::parse(&format!("{auth_endpoint}/proxy_get_role_secret"))?; + + url.query_pairs_mut() + .append_pair("project", cluster) + .append_pair("role", user); + + // TODO: use a proper logger + println!("cplane request: {}", url); + + let resp = reqwest::get(url).await?; + if !resp.status().is_success() { + return Err(ConsoleAuthError::HttpStatus(resp.status())); + } + + let response: GetRoleSecretResponse = serde_json::from_str(resp.text().await?.as_str())?; + + scram::ServerSecret::parse(response.role_secret.as_str()) + .map(AuthInfo::Scram) + .ok_or(ConsoleAuthError::BadSecret) +} + +/// Wake up the compute node and return the corresponding connection info. +async fn wake_compute( + auth_endpoint: &str, + cluster: &str, +) -> Result<(String, u16), ConsoleAuthError> { + let mut url = reqwest::Url::parse(&format!("{auth_endpoint}/proxy_wake_compute"))?; + url.query_pairs_mut().append_pair("project", cluster); + + // TODO: use a proper logger + println!("cplane request: {}", url); + + let resp = reqwest::get(url).await?; + if !resp.status().is_success() { + return Err(ConsoleAuthError::HttpStatus(resp.status())); + } + + let response: GetWakeComputeResponse = serde_json::from_str(resp.text().await?.as_str())?; + let (host, port) = response + .address + .split_once(':') + .ok_or_else(|| ConsoleAuthError::MalformedComputeAddress(response.address.clone()))?; + let port: u16 = port + .parse() + .map_err(|_| ConsoleAuthError::MalformedComputeAddress(response.address.clone()))?; + + Ok((host.to_string(), port)) +} + +pub async fn handle_user( + auth_endpoint: &str, + client: &mut PqStream, + creds: &ClientCredentials, +) -> Result { + // Determine cluster name from SNI (creds.sni_data) or from creds.cluster_option. + let cluster = match &creds.sni_data { + //if sni_data exists, use it + Some(sni_data) => { + sni_data + .split_once('.') + .ok_or(ConsoleAuthError::SniWrong)? + .0 + } + //otherwise use cluster_option if it was manually set thought ..&options=cluster: parameter + None => creds + .cluster_option + .as_ref() + .ok_or(ConsoleAuthError::SniMissingAndProjectNameMissing)? + .as_str(), + }; + + let user = creds.user.as_str(); + + // Step 1: get the auth secret + let auth_info = get_auth_info(auth_endpoint, user, cluster).await?; + + let flow = AuthFlow::new(client); + let scram_keys = match auth_info { + AuthInfo::Md5(_) => { + // TODO: decide if we should support MD5 in api v2 + return Err(crate::auth::AuthErrorImpl::auth_failed("MD5 is not supported").into()); + } + AuthInfo::Scram(secret) => { + let scram = auth::Scram(&secret); + Some(compute::ScramKeys { + client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(), + server_key: secret.server_key.as_bytes(), + }) + } + }; + + client + .write_message_noflush(&Be::AuthenticationOk)? + .write_message_noflush(&BeParameterStatusMessage::encoding())?; + + // Step 2: wake compute + let (host, port) = wake_compute(auth_endpoint, cluster).await?; + + Ok(compute::NodeInfo { + db_info: DatabaseInfo { + host, + port, + dbname: creds.dbname.clone(), + user: creds.user.clone(), + password: None, + }, + scram_keys, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn parse_db_info() -> anyhow::Result<()> { + let _: DatabaseInfo = serde_json::from_value(json!({ + "host": "localhost", + "port": 5432, + "dbname": "postgres", + "user": "john_doe", + "password": "password", + }))?; + + let _: DatabaseInfo = serde_json::from_value(json!({ + "host": "localhost", + "port": 5432, + "dbname": "postgres", + "user": "john_doe", + }))?; + + Ok(()) + } +}