From 5f4ccae5c5d426d8587ac9f91b251f8f842f4333 Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Mon, 25 Jul 2022 17:23:10 +0300 Subject: [PATCH] [proxy] Add the `password hack` authentication flow (#2095) [proxy] Add the `password hack` authentication flow This lets us authenticate users which can use neither SNI (due to old libpq) nor connection string `options` (due to restrictions in other client libraries). Note: `PasswordHack` will accept passwords which are not encoded in base64 via the "password" field. The assumption is that most user passwords will be valid utf-8 strings, and the rest may still be passed via "password_". --- libs/utils/src/pq_proto.rs | 4 +- proxy/src/auth.rs | 12 +- proxy/src/auth/backend.rs | 186 ++++++++-- proxy/src/auth/backend/console.rs | 91 ++--- proxy/src/auth/backend/legacy_console.rs | 44 ++- proxy/src/auth/backend/link.rs | 4 +- proxy/src/auth/backend/postgres.rs | 35 +- proxy/src/auth/credentials.rs | 431 ++++++++--------------- proxy/src/auth/flow.rs | 39 +- proxy/src/auth/password_hack.rs | 102 ++++++ proxy/src/compute.rs | 104 ++++-- proxy/src/config.rs | 36 +- proxy/src/error.rs | 7 + proxy/src/main.rs | 8 +- proxy/src/proxy.rs | 91 ++--- proxy/src/stream.rs | 8 + test_runner/batch_others/test_proxy.py | 32 +- test_runner/fixtures/neon_fixtures.py | 66 ++-- 18 files changed, 750 insertions(+), 550 deletions(-) create mode 100644 proxy/src/auth/password_hack.rs diff --git a/libs/utils/src/pq_proto.rs b/libs/utils/src/pq_proto.rs index 0a320f123c..3dcae4d0af 100644 --- a/libs/utils/src/pq_proto.rs +++ b/libs/utils/src/pq_proto.rs @@ -47,10 +47,12 @@ pub enum FeStartupPacket { StartupMessage { major_version: u32, minor_version: u32, - params: HashMap, + params: StartupMessageParams, }, } +pub type StartupMessageParams = HashMap; + #[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)] pub struct CancelKeyData { pub backend_pid: i32, diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index 9bddd58fce..61c7458e16 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -1,11 +1,14 @@ //! Client authentication mechanisms. pub mod backend; -pub use backend::DatabaseInfo; +pub use backend::{BackendType, DatabaseInfo}; mod credentials; pub use credentials::ClientCredentials; +mod password_hack; +use password_hack::PasswordHackPayload; + mod flow; pub use flow::*; @@ -29,9 +32,8 @@ pub enum AuthErrorImpl { #[error(transparent)] Sasl(#[from] crate::sasl::Error), - /// For passwords that couldn't be processed by [`backend::legacy_console::parse_password`]. - #[error("Malformed password message")] - MalformedPassword, + #[error("Malformed password message: {0}")] + MalformedPassword(&'static str), /// Errors produced by [`crate::stream::PqStream`]. #[error(transparent)] @@ -76,7 +78,7 @@ impl UserFacingError for AuthError { Console(e) => e.to_string_client(), GetAuthInfo(e) => e.to_string_client(), Sasl(e) => e.to_string_client(), - MalformedPassword => self.to_string(), + MalformedPassword(_) => self.to_string(), _ => "Internal error".to_string(), } } diff --git a/proxy/src/auth/backend.rs b/proxy/src/auth/backend.rs index 1d41f7f932..5e87059c86 100644 --- a/proxy/src/auth/backend.rs +++ b/proxy/src/auth/backend.rs @@ -1,16 +1,14 @@ -mod legacy_console; mod link; mod postgres; pub mod console; +mod legacy_console; pub use legacy_console::{AuthError, AuthErrorImpl}; -use super::ClientCredentials; use crate::{ - compute, - config::{AuthBackendType, ProxyConfig}, - mgmt, + auth::{self, AuthFlow, ClientCredentials}, + compute, config, mgmt, stream::PqStream, waiters::{self, Waiter, Waiters}, }; @@ -78,32 +76,158 @@ impl From for tokio_postgres::Config { } } -pub(super) async fn handle_user( - config: &ProxyConfig, - client: &mut PqStream, - creds: ClientCredentials, -) -> super::Result { - use AuthBackendType::*; - match config.auth_backend { - LegacyConsole => { - legacy_console::handle_user( - &config.auth_endpoint, - &config.auth_link_uri, - client, - &creds, - ) - .await +/// This type serves two purposes: +/// +/// * When `T` is `()`, it's just a regular auth backend selector +/// which we use in [`crate::config::ProxyConfig`]. +/// +/// * However, when we substitute `T` with [`ClientCredentials`], +/// this helps us provide the credentials only to those auth +/// backends which require them for the authentication process. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum BackendType { + /// Legacy Cloud API (V1) + link auth. + LegacyConsole(T), + /// Current Cloud API (V2). + Console(T), + /// Local mock of Cloud API (V2). + Postgres(T), + /// Authentication via a web browser. + Link, +} + +impl BackendType { + /// Very similar to [`std::option::Option::map`]. + /// Maps [`BackendType`] to [`BackendType`] by applying + /// a function to a contained value. + pub fn map(self, f: impl FnOnce(T) -> R) -> BackendType { + use BackendType::*; + match self { + LegacyConsole(x) => LegacyConsole(f(x)), + Console(x) => Console(f(x)), + Postgres(x) => Postgres(f(x)), + Link => Link, + } + } +} + +impl BackendType> { + /// Very similar to [`std::option::Option::transpose`]. + /// This is most useful for error handling. + pub fn transpose(self) -> Result, E> { + use BackendType::*; + match self { + LegacyConsole(x) => x.map(LegacyConsole), + Console(x) => x.map(Console), + Postgres(x) => x.map(Postgres), + Link => Ok(Link), + } + } +} + +impl BackendType { + /// Authenticate the client via the requested backend, possibly using credentials. + pub async fn authenticate( + mut self, + urls: &config::AuthUrls, + client: &mut PqStream, + ) -> super::Result { + use BackendType::*; + + if let Console(creds) | Postgres(creds) = &mut self { + // If there's no project so far, that entails that client doesn't + // support SNI or other means of passing the project name. + // We now expect to see a very specific payload in the place of password. + if creds.project().is_none() { + let payload = AuthFlow::new(client) + .begin(auth::PasswordHack) + .await? + .authenticate() + .await?; + + // Finally we may finish the initialization of `creds`. + // TODO: add missing type safety to ClientCredentials. + creds.project = Some(payload.project); + + let mut config = match &self { + Console(creds) => { + console::Api::new(&urls.auth_endpoint, creds) + .wake_compute() + .await? + } + Postgres(creds) => { + postgres::Api::new(&urls.auth_endpoint, creds) + .wake_compute() + .await? + } + _ => unreachable!("see the patterns above"), + }; + + // We should use a password from payload as well. + config.password(payload.password); + + return Ok(compute::NodeInfo { + reported_auth_ok: false, + config, + }); + } + } + + match self { + LegacyConsole(creds) => { + legacy_console::handle_user( + &urls.auth_endpoint, + &urls.auth_link_uri, + &creds, + client, + ) + .await + } + Console(creds) => { + console::Api::new(&urls.auth_endpoint, &creds) + .handle_user(client) + .await + } + Postgres(creds) => { + postgres::Api::new(&urls.auth_endpoint, &creds) + .handle_user(client) + .await + } + // NOTE: this auth backend doesn't use client credentials. + Link => link::handle_user(&urls.auth_link_uri, client).await, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_backend_type_map() { + let values = [ + BackendType::LegacyConsole(0), + BackendType::Console(0), + BackendType::Postgres(0), + BackendType::Link, + ]; + + for value in values { + assert_eq!(value.map(|x| x), value); + } + } + + #[test] + fn test_backend_type_transpose() { + let values = [ + BackendType::LegacyConsole(Ok::<_, ()>(0)), + BackendType::Console(Ok(0)), + BackendType::Postgres(Ok(0)), + BackendType::Link, + ]; + + for value in values { + assert_eq!(value.map(Result::unwrap), value.transpose().unwrap()); } - Console => { - console::Api::new(&config.auth_endpoint, &creds)? - .handle_user(client) - .await - } - Postgres => { - postgres::Api::new(&config.auth_endpoint, &creds)? - .handle_user(client) - .await - } - Link => link::handle_user(&config.auth_link_uri, client).await, } } diff --git a/proxy/src/auth/backend/console.rs b/proxy/src/auth/backend/console.rs index 3085f0b0e4..a8ff1a3522 100644 --- a/proxy/src/auth/backend/console.rs +++ b/proxy/src/auth/backend/console.rs @@ -1,18 +1,17 @@ //! Cloud API V2. use crate::{ - auth::{self, AuthFlow, ClientCredentials, DatabaseInfo}, - compute, - error::UserFacingError, + auth::{self, AuthFlow, ClientCredentials}, + compute::{self, ComputeConnCfg}, + error::{io_error, UserFacingError}, scram, stream::PqStream, url::ApiUrl, }; use serde::{Deserialize, Serialize}; -use std::{future::Future, io}; +use std::future::Future; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; -use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage}; pub type Result = std::result::Result; @@ -84,8 +83,8 @@ pub(super) struct Api<'a> { impl<'a> Api<'a> { /// Construct an API object containing the auth parameters. - pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Result { - Ok(Self { endpoint, creds }) + pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Self { + Self { endpoint, creds } } /// Authenticate the existing user or throw an error. @@ -100,7 +99,7 @@ impl<'a> Api<'a> { let mut url = self.endpoint.clone(); url.path_segments_mut().push("proxy_get_role_secret"); url.query_pairs_mut() - .append_pair("project", self.creds.project_name.as_ref()?) + .append_pair("project", self.creds.project().expect("impossible")) .append_pair("role", &self.creds.user); // TODO: use a proper logger @@ -120,11 +119,11 @@ impl<'a> Api<'a> { } /// Wake up the compute node and return the corresponding connection info. - async fn wake_compute(&self) -> Result { + pub(super) async fn wake_compute(&self) -> Result { let mut url = self.endpoint.clone(); url.path_segments_mut().push("proxy_wake_compute"); - let project_name = self.creds.project_name.as_ref()?; - url.query_pairs_mut().append_pair("project", project_name); + url.query_pairs_mut() + .append_pair("project", self.creds.project().expect("impossible")); // TODO: use a proper logger println!("cplane request: {url}"); @@ -137,16 +136,20 @@ impl<'a> Api<'a> { let response: GetWakeComputeResponse = serde_json::from_str(&resp.text().await.map_err(io_error)?)?; - let (host, port) = parse_host_port(&response.address) - .ok_or(ConsoleAuthError::BadComputeAddress(response.address))?; + // Unfortunately, ownership won't let us use `Option::ok_or` here. + let (host, port) = match parse_host_port(&response.address) { + None => return Err(ConsoleAuthError::BadComputeAddress(response.address)), + Some(x) => x, + }; - Ok(DatabaseInfo { - host, - port, - dbname: self.creds.dbname.to_owned(), - user: self.creds.user.to_owned(), - password: None, - }) + let mut config = ComputeConnCfg::new(); + config + .host(host) + .port(port) + .dbname(&self.creds.dbname) + .user(&self.creds.user); + + Ok(config) } } @@ -160,7 +163,7 @@ pub(super) async fn handle_user<'a, Endpoint, GetAuthInfo, WakeCompute>( ) -> auth::Result where GetAuthInfo: Future>, - WakeCompute: Future>, + WakeCompute: Future>, { let auth_info = get_auth_info(endpoint).await?; @@ -179,48 +182,18 @@ where } }; - client - .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&BeParameterStatusMessage::encoding())?; + let mut config = wake_compute(endpoint).await?; + if let Some(keys) = scram_keys { + config.auth_keys(tokio_postgres::config::AuthKeys::ScramSha256(keys)); + } Ok(compute::NodeInfo { - db_info: wake_compute(endpoint).await?, - scram_keys, + reported_auth_ok: false, + config, }) } -/// Upcast (almost) any error into an opaque [`io::Error`]. -pub(super) fn io_error(e: impl Into>) -> io::Error { - io::Error::new(io::ErrorKind::Other, e) -} - -fn parse_host_port(input: &str) -> Option<(String, u16)> { +fn parse_host_port(input: &str) -> Option<(&str, u16)> { let (host, port) = input.split_once(':')?; - Some((host.to_owned(), port.parse().ok()?)) -} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - - #[test] - fn parse_db_info() -> anyhow::Result<()> { - let _: DatabaseInfo = serde_json::from_value(json!({ - "host": "localhost", - "port": 5432, - "dbname": "postgres", - "user": "john_doe", - "password": "password", - }))?; - - let _: DatabaseInfo = serde_json::from_value(json!({ - "host": "localhost", - "port": 5432, - "dbname": "postgres", - "user": "john_doe", - }))?; - - Ok(()) - } + Some((host, port.parse().ok()?)) } diff --git a/proxy/src/auth/backend/legacy_console.rs b/proxy/src/auth/backend/legacy_console.rs index 467da63a98..7a5e9b6f62 100644 --- a/proxy/src/auth/backend/legacy_console.rs +++ b/proxy/src/auth/backend/legacy_console.rs @@ -11,7 +11,7 @@ use crate::{ use serde::{Deserialize, Serialize}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; -use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage}; +use utils::pq_proto::BeMessage as Be; #[derive(Debug, Error)] pub enum AuthErrorImpl { @@ -76,6 +76,12 @@ enum ProxyAuthResponse { NotReady { ready: bool }, // TODO: get rid of `ready` } +impl ClientCredentials { + fn is_existing_user(&self) -> bool { + self.user.ends_with("@zenith") + } +} + async fn authenticate_proxy_client( auth_endpoint: &reqwest::Url, creds: &ClientCredentials, @@ -100,7 +106,7 @@ async fn authenticate_proxy_client( } let auth_info: ProxyAuthResponse = serde_json::from_str(resp.text().await?.as_str())?; - println!("got auth info: #{:?}", auth_info); + println!("got auth info: {:?}", auth_info); use ProxyAuthResponse::*; let db_info = match auth_info { @@ -128,7 +134,9 @@ async fn handle_existing_user( // Read client's password hash let msg = client.read_password_message().await?; - let md5_response = parse_password(&msg).ok_or(auth::AuthErrorImpl::MalformedPassword)?; + let md5_response = parse_password(&msg).ok_or(auth::AuthErrorImpl::MalformedPassword( + "the password should be a valid null-terminated utf-8 string", + ))?; let db_info = authenticate_proxy_client( auth_endpoint, @@ -139,21 +147,17 @@ async fn handle_existing_user( ) .await?; - client - .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&BeParameterStatusMessage::encoding())?; - Ok(compute::NodeInfo { - db_info, - scram_keys: None, + reported_auth_ok: false, + config: db_info.into(), }) } pub async fn handle_user( auth_endpoint: &reqwest::Url, auth_link_uri: &reqwest::Url, - client: &mut PqStream, creds: &ClientCredentials, + client: &mut PqStream, ) -> auth::Result { if creds.is_existing_user() { handle_existing_user(auth_endpoint, client, creds).await @@ -201,4 +205,24 @@ mod tests { .unwrap(); assert!(matches!(auth, ProxyAuthResponse::NotReady { .. })); } + + #[test] + fn parse_db_info() -> anyhow::Result<()> { + let _: DatabaseInfo = serde_json::from_value(json!({ + "host": "localhost", + "port": 5432, + "dbname": "postgres", + "user": "john_doe", + "password": "password", + }))?; + + let _: DatabaseInfo = serde_json::from_value(json!({ + "host": "localhost", + "port": 5432, + "dbname": "postgres", + "user": "john_doe", + }))?; + + Ok(()) + } } diff --git a/proxy/src/auth/backend/link.rs b/proxy/src/auth/backend/link.rs index 669c9e00e9..d658a34825 100644 --- a/proxy/src/auth/backend/link.rs +++ b/proxy/src/auth/backend/link.rs @@ -41,7 +41,7 @@ pub async fn handle_user( client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?; Ok(compute::NodeInfo { - db_info, - scram_keys: None, + reported_auth_ok: true, + config: db_info.into(), }) } diff --git a/proxy/src/auth/backend/postgres.rs b/proxy/src/auth/backend/postgres.rs index 721b9db095..1d7ab8f249 100644 --- a/proxy/src/auth/backend/postgres.rs +++ b/proxy/src/auth/backend/postgres.rs @@ -3,10 +3,12 @@ use crate::{ auth::{ self, - backend::console::{self, io_error, AuthInfo, Result}, - ClientCredentials, DatabaseInfo, + backend::console::{self, AuthInfo, Result}, + ClientCredentials, }, - compute, scram, + compute::{self, ComputeConnCfg}, + error::io_error, + scram, stream::PqStream, url::ApiUrl, }; @@ -20,8 +22,8 @@ pub(super) struct Api<'a> { impl<'a> Api<'a> { /// Construct an API object containing the auth parameters. - pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Result { - Ok(Self { endpoint, creds }) + pub(super) fn new(endpoint: &'a ApiUrl, creds: &'a ClientCredentials) -> Self { + Self { endpoint, creds } } /// Authenticate the existing user or throw an error. @@ -56,7 +58,10 @@ impl<'a> Api<'a> { // We shouldn't get more than one row anyway. [row, ..] => { - let entry = row.try_get(0).map_err(io_error)?; + let entry = row + .try_get("rolpassword") + .map_err(|e| io_error(format!("failed to read user's password: {e}")))?; + scram::ServerSecret::parse(entry) .map(AuthInfo::Scram) .or_else(|| { @@ -75,14 +80,14 @@ impl<'a> Api<'a> { } /// We don't need to wake anything locally, so we just return the connection info. - async fn wake_compute(&self) -> Result { - Ok(DatabaseInfo { - // TODO: handle that near CLI params parsing - host: self.endpoint.host_str().unwrap_or("localhost").to_owned(), - port: self.endpoint.port().unwrap_or(5432), - dbname: self.creds.dbname.to_owned(), - user: self.creds.user.to_owned(), - password: None, - }) + pub(super) async fn wake_compute(&self) -> Result { + let mut config = ComputeConnCfg::new(); + config + .host(self.endpoint.host_str().unwrap_or("localhost")) + .port(self.endpoint.port().unwrap_or(5432)) + .dbname(&self.creds.dbname) + .user(&self.creds.user); + + Ok(config) } } diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index b5312fbe1f..4c72da1c48 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -1,39 +1,25 @@ //! User credentials used in authentication. -use crate::compute; -use crate::config::ProxyConfig; use crate::error::UserFacingError; -use crate::stream::PqStream; -use std::collections::HashMap; use thiserror::Error; -use tokio::io::{AsyncRead, AsyncWrite}; +use utils::pq_proto::StartupMessageParams; #[derive(Debug, Error, PartialEq, Eq, Clone)] pub enum ClientCredsParseError { - #[error("Parameter `{0}` is missing in startup packet.")] + #[error("Parameter '{0}' is missing in startup packet.")] MissingKey(&'static str), - #[error( - "Project name is not specified. \ - EITHER please upgrade the postgres client library (libpq) for SNI support \ - OR pass the project name as a parameter: '&options=project%3D'." - )] - MissingSNIAndProjectName, - #[error("Inconsistent project name inferred from SNI ('{0}') and project option ('{1}').")] - InconsistentProjectNameAndSNI(String, String), - - #[error("Common name is not set.")] - CommonNameNotSet, + InconsistentProjectNames(String, String), #[error( "SNI ('{1}') inconsistently formatted with respect to common name ('{0}'). \ - SNI should be formatted as '.'." + SNI should be formatted as '.{0}'." )] - InconsistentCommonNameAndSNI(String, String), + InconsistentSni(String, String), - #[error("Project name ('{0}') must contain only alphanumeric characters and hyphens ('-').")] - ProjectNameContainsIllegalChars(String), + #[error("Project name ('{0}') must contain only alphanumeric characters and hyphen.")] + MalformedProjectName(String), } impl UserFacingError for ClientCredsParseError {} @@ -44,286 +30,171 @@ impl UserFacingError for ClientCredsParseError {} pub struct ClientCredentials { pub user: String, pub dbname: String, - pub project_name: Result, + pub project: Option, } impl ClientCredentials { - pub fn is_existing_user(&self) -> bool { - // This logic will likely change in the future. - self.user.ends_with("@zenith") + pub fn project(&self) -> Option<&str> { + self.project.as_deref() } +} +impl ClientCredentials { pub fn parse( - mut options: HashMap, - sni_data: Option<&str>, + mut options: StartupMessageParams, + sni: Option<&str>, common_name: Option<&str>, ) -> Result { - let mut get_param = |key| { - options - .remove(key) - .ok_or(ClientCredsParseError::MissingKey(key)) - }; + use ClientCredsParseError::*; + // Some parameters are absolutely necessary, others not so much. + let mut get_param = |key| options.remove(key).ok_or(MissingKey(key)); + + // Some parameters are stored in the startup message. let user = get_param("user")?; let dbname = get_param("database")?; - let project_name = get_param("project").ok(); - let project_name = get_project_name(sni_data, common_name, project_name.as_deref()); + let project_a = get_param("project").ok(); + + // Alternative project name is in fact a subdomain from SNI. + // NOTE: we do not consider SNI if `common_name` is missing. + let project_b = sni + .zip(common_name) + .map(|(sni, cn)| { + // TODO: what if SNI is present but just a common name? + subdomain_from_sni(sni, cn) + .ok_or_else(|| InconsistentSni(sni.to_owned(), cn.to_owned())) + }) + .transpose()?; + + let project = match (project_a, project_b) { + // Invariant: if we have both project name variants, they should match. + (Some(a), Some(b)) if a != b => Some(Err(InconsistentProjectNames(a, b))), + (a, b) => a.or(b).map(|name| { + // Invariant: project name may not contain certain characters. + check_project_name(name).map_err(MalformedProjectName) + }), + } + .transpose()?; Ok(Self { user, dbname, - project_name, + project, }) } +} - /// Use credentials to authenticate the user. - pub async fn authenticate( - self, - config: &ProxyConfig, - client: &mut PqStream, - ) -> super::Result { - // This method is just a convenient facade for `handle_user` - super::backend::handle_user(config, client, self).await +fn check_project_name(name: String) -> Result { + if name.chars().all(|c| c.is_alphanumeric() || c == '-') { + Ok(name) + } else { + Err(name) } } -/// Inferring project name from sni_data. -fn project_name_from_sni_data( - sni_data: &str, - common_name: &str, -) -> Result { - let common_name_with_dot = format!(".{common_name}"); - // check that ".{common_name_with_dot}" is the actual suffix in sni_data - if !sni_data.ends_with(&common_name_with_dot) { - return Err(ClientCredsParseError::InconsistentCommonNameAndSNI( - common_name.to_string(), - sni_data.to_string(), +fn subdomain_from_sni(sni: &str, common_name: &str) -> Option { + sni.strip_suffix(common_name)? + .strip_suffix('.') + .map(str::to_owned) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_options<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> StartupMessageParams { + StartupMessageParams::from(pairs.map(|(k, v)| (k.to_owned(), v.to_owned()))) + } + + #[test] + #[ignore = "TODO: fix how database is handled"] + fn parse_bare_minimum() -> anyhow::Result<()> { + // According to postgresql, only `user` should be required. + let options = make_options([("user", "john_doe")]); + + // TODO: check that `creds.dbname` is None. + let creds = ClientCredentials::parse(options, None, None)?; + assert_eq!(creds.user, "john_doe"); + + Ok(()) + } + + #[test] + fn parse_missing_project() -> anyhow::Result<()> { + let options = make_options([("user", "john_doe"), ("database", "world")]); + + let creds = ClientCredentials::parse(options, None, None)?; + assert_eq!(creds.user, "john_doe"); + assert_eq!(creds.dbname, "world"); + assert_eq!(creds.project, None); + + Ok(()) + } + + #[test] + fn parse_project_from_sni() -> anyhow::Result<()> { + let options = make_options([("user", "john_doe"), ("database", "world")]); + + let sni = Some("foo.localhost"); + let common_name = Some("localhost"); + + let creds = ClientCredentials::parse(options, sni, common_name)?; + assert_eq!(creds.user, "john_doe"); + assert_eq!(creds.dbname, "world"); + assert_eq!(creds.project.as_deref(), Some("foo")); + + Ok(()) + } + + #[test] + fn parse_project_from_options() -> anyhow::Result<()> { + let options = make_options([ + ("user", "john_doe"), + ("database", "world"), + ("project", "bar"), + ]); + + let creds = ClientCredentials::parse(options, None, None)?; + assert_eq!(creds.user, "john_doe"); + assert_eq!(creds.dbname, "world"); + assert_eq!(creds.project.as_deref(), Some("bar")); + + Ok(()) + } + + #[test] + fn parse_projects_identical() -> anyhow::Result<()> { + let options = make_options([ + ("user", "john_doe"), + ("database", "world"), + ("project", "baz"), + ]); + + let sni = Some("baz.localhost"); + let common_name = Some("localhost"); + + let creds = ClientCredentials::parse(options, sni, common_name)?; + assert_eq!(creds.user, "john_doe"); + assert_eq!(creds.dbname, "world"); + assert_eq!(creds.project.as_deref(), Some("baz")); + + Ok(()) + } + + #[test] + fn parse_projects_different() { + let options = make_options([ + ("user", "john_doe"), + ("database", "world"), + ("project", "first"), + ]); + + let sni = Some("second.localhost"); + let common_name = Some("localhost"); + + assert!(matches!( + ClientCredentials::parse(options, sni, common_name).expect_err("should fail"), + ClientCredsParseError::InconsistentProjectNames(_, _) )); } - // return sni_data without the common name suffix. - Ok(sni_data - .strip_suffix(&common_name_with_dot) - .unwrap() - .to_string()) -} - -#[cfg(test)] -mod tests_for_project_name_from_sni_data { - use super::*; - - #[test] - fn passing() { - let target_project_name = "my-project-123"; - let common_name = "localtest.me"; - let sni_data = format!("{target_project_name}.{common_name}"); - assert_eq!( - project_name_from_sni_data(&sni_data, common_name), - Ok(target_project_name.to_string()) - ); - } - - #[test] - fn throws_inconsistent_common_name_and_sni_data() { - let target_project_name = "my-project-123"; - let common_name = "localtest.me"; - let wrong_suffix = "wrongtest.me"; - assert_eq!(common_name.len(), wrong_suffix.len()); - let wrong_common_name = format!("wrong{wrong_suffix}"); - let sni_data = format!("{target_project_name}.{wrong_common_name}"); - assert_eq!( - project_name_from_sni_data(&sni_data, common_name), - Err(ClientCredsParseError::InconsistentCommonNameAndSNI( - common_name.to_string(), - sni_data - )) - ); - } -} - -/// Determine project name from SNI or from project_name parameter from options argument. -fn get_project_name( - sni_data: Option<&str>, - common_name: Option<&str>, - project_name: Option<&str>, -) -> Result { - // determine the project name from sni_data if it exists, otherwise from project_name. - let ret = match sni_data { - Some(sni_data) => { - let common_name = common_name.ok_or(ClientCredsParseError::CommonNameNotSet)?; - let project_name_from_sni = project_name_from_sni_data(sni_data, common_name)?; - // check invariant: project name from options and from sni should match - if let Some(project_name) = &project_name { - if !project_name_from_sni.eq(project_name) { - return Err(ClientCredsParseError::InconsistentProjectNameAndSNI( - project_name_from_sni, - project_name.to_string(), - )); - } - } - project_name_from_sni - } - None => project_name - .ok_or(ClientCredsParseError::MissingSNIAndProjectName)? - .to_string(), - }; - - // check formatting invariant: project name must contain only alphanumeric characters and hyphens. - if !ret.chars().all(|x: char| x.is_alphanumeric() || x == '-') { - return Err(ClientCredsParseError::ProjectNameContainsIllegalChars(ret)); - } - - Ok(ret) -} - -#[cfg(test)] -mod tests_for_project_name_only { - use super::*; - - #[test] - fn passing_from_sni_data_only() { - let target_project_name = "my-project-123"; - let common_name = "localtest.me"; - let sni_data = format!("{target_project_name}.{common_name}"); - assert_eq!( - get_project_name(Some(&sni_data), Some(common_name), None), - Ok(target_project_name.to_string()) - ); - } - - #[test] - fn throws_project_name_contains_illegal_chars_from_sni_data_only() { - let project_name_prefix = "my-project"; - let project_name_suffix = "123"; - let common_name = "localtest.me"; - - for illegal_char_id in 0..256 { - let illegal_char = char::from_u32(illegal_char_id).unwrap(); - if !(illegal_char.is_alphanumeric() || illegal_char == '-') - && illegal_char.to_string().len() == 1 - { - let target_project_name = - format!("{project_name_prefix}{illegal_char}{project_name_suffix}"); - let sni_data = format!("{target_project_name}.{common_name}"); - assert_eq!( - get_project_name(Some(&sni_data), Some(common_name), None), - Err(ClientCredsParseError::ProjectNameContainsIllegalChars( - target_project_name - )) - ); - } - } - } - - #[test] - fn passing_from_project_name_only() { - let target_project_name = "my-project-123"; - let common_names = [Some("localtest.me"), None]; - for common_name in common_names { - assert_eq!( - get_project_name(None, common_name, Some(target_project_name)), - Ok(target_project_name.to_string()) - ); - } - } - - #[test] - fn throws_project_name_contains_illegal_chars_from_project_name_only() { - let project_name_prefix = "my-project"; - let project_name_suffix = "123"; - let common_names = [Some("localtest.me"), None]; - - for common_name in common_names { - for illegal_char_id in 0..256 { - let illegal_char: char = char::from_u32(illegal_char_id).unwrap(); - if !(illegal_char.is_alphanumeric() || illegal_char == '-') - && illegal_char.to_string().len() == 1 - { - let target_project_name = - format!("{project_name_prefix}{illegal_char}{project_name_suffix}"); - assert_eq!( - get_project_name(None, common_name, Some(&target_project_name)), - Err(ClientCredsParseError::ProjectNameContainsIllegalChars( - target_project_name - )) - ); - } - } - } - } - - #[test] - fn passing_from_sni_data_and_project_name() { - let target_project_name = "my-project-123"; - let common_name = "localtest.me"; - let sni_data = format!("{target_project_name}.{common_name}"); - assert_eq!( - get_project_name( - Some(&sni_data), - Some(common_name), - Some(target_project_name) - ), - Ok(target_project_name.to_string()) - ); - } - - #[test] - fn throws_inconsistent_project_name_and_sni() { - let project_name_param = "my-project-123"; - let wrong_project_name = "not-my-project-123"; - let common_name = "localtest.me"; - let sni_data = format!("{wrong_project_name}.{common_name}"); - assert_eq!( - get_project_name(Some(&sni_data), Some(common_name), Some(project_name_param)), - Err(ClientCredsParseError::InconsistentProjectNameAndSNI( - wrong_project_name.to_string(), - project_name_param.to_string() - )) - ); - } - - #[test] - fn throws_common_name_not_set() { - let target_project_name = "my-project-123"; - let wrong_project_name = "not-my-project-123"; - let common_name = "localtest.me"; - let sni_datas = [ - Some(format!("{wrong_project_name}.{common_name}")), - Some(format!("{target_project_name}.{common_name}")), - ]; - let project_names = [None, Some(target_project_name)]; - for sni_data in sni_datas { - for project_name_param in project_names { - assert_eq!( - get_project_name(sni_data.as_deref(), None, project_name_param), - Err(ClientCredsParseError::CommonNameNotSet) - ); - } - } - } - - #[test] - fn throws_inconsistent_common_name_and_sni_data() { - let target_project_name = "my-project-123"; - let wrong_project_name = "not-my-project-123"; - let common_name = "localtest.me"; - let wrong_suffix = "wrongtest.me"; - assert_eq!(common_name.len(), wrong_suffix.len()); - let wrong_common_name = format!("wrong{wrong_suffix}"); - let sni_datas = [ - Some(format!("{wrong_project_name}.{wrong_common_name}")), - Some(format!("{target_project_name}.{wrong_common_name}")), - ]; - let project_names = [None, Some(target_project_name)]; - for project_name_param in project_names { - for sni_data in &sni_datas { - assert_eq!( - get_project_name(sni_data.as_deref(), Some(common_name), project_name_param), - Err(ClientCredsParseError::InconsistentCommonNameAndSNI( - common_name.to_string(), - sni_data.clone().unwrap().to_string() - )) - ); - } - } - } } diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 7efff13bfc..705f1e3807 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -1,8 +1,7 @@ //! Main authentication flow. -use super::AuthErrorImpl; -use crate::stream::PqStream; -use crate::{sasl, scram}; +use super::{AuthErrorImpl, PasswordHackPayload}; +use crate::{sasl, scram, stream::PqStream}; use std::io; use tokio::io::{AsyncRead, AsyncWrite}; use utils::pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be}; @@ -27,6 +26,17 @@ impl AuthMethod for Scram<'_> { } } +/// Use an ad hoc auth flow (for clients which don't support SNI) proposed in +/// . +pub struct PasswordHack; + +impl AuthMethod for PasswordHack { + #[inline(always)] + fn first_message(&self) -> BeMessage<'_> { + Be::AuthenticationCleartextPassword + } +} + /// This wrapper for [`PqStream`] performs client authentication. #[must_use] pub struct AuthFlow<'a, Stream, State> { @@ -57,13 +67,34 @@ impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { } } +impl AuthFlow<'_, S, PasswordHack> { + /// Perform user authentication. Raise an error in case authentication failed. + pub async fn authenticate(self) -> super::Result { + let msg = self.stream.read_password_message().await?; + let password = msg + .strip_suffix(&[0]) + .ok_or(AuthErrorImpl::MalformedPassword("missing terminator"))?; + + // The so-called "password" should contain a base64-encoded json. + // We will use it later to route the client to their project. + let bytes = base64::decode(password) + .map_err(|_| AuthErrorImpl::MalformedPassword("bad encoding"))?; + + let payload = serde_json::from_slice(&bytes) + .map_err(|_| AuthErrorImpl::MalformedPassword("invalid payload"))?; + + Ok(payload) + } +} + /// Stream wrapper for handling [SCRAM](crate::scram) auth. impl AuthFlow<'_, S, Scram<'_>> { /// Perform user authentication. Raise an error in case authentication failed. pub async fn authenticate(self) -> super::Result { // Initial client message contains the chosen auth method's name. let msg = self.stream.read_password_message().await?; - let sasl = sasl::FirstMessage::parse(&msg).ok_or(AuthErrorImpl::MalformedPassword)?; + let sasl = sasl::FirstMessage::parse(&msg) + .ok_or(AuthErrorImpl::MalformedPassword("bad sasl message"))?; // Currently, the only supported SASL method is SCRAM. if !scram::METHODS.contains(&sasl.method) { diff --git a/proxy/src/auth/password_hack.rs b/proxy/src/auth/password_hack.rs new file mode 100644 index 0000000000..6a1258ab31 --- /dev/null +++ b/proxy/src/auth/password_hack.rs @@ -0,0 +1,102 @@ +//! Payload for ad hoc authentication method for clients that don't support SNI. +//! See the `impl` for [`super::backend::BackendType`]. +//! Read more: . + +use serde::{de, Deserialize, Deserializer}; +use std::fmt; + +#[derive(Deserialize)] +#[serde(untagged)] +pub enum Password { + /// A regular string for utf-8 encoded passwords. + Simple { password: String }, + + /// Password is base64-encoded because it may contain arbitrary byte sequences. + Encoded { + #[serde(rename = "password_", deserialize_with = "deserialize_base64")] + password: Vec, + }, +} + +impl AsRef<[u8]> for Password { + fn as_ref(&self) -> &[u8] { + match self { + Password::Simple { password } => password.as_ref(), + Password::Encoded { password } => password.as_ref(), + } + } +} + +#[derive(Deserialize)] +pub struct PasswordHackPayload { + pub project: String, + + #[serde(flatten)] + pub password: Password, +} + +fn deserialize_base64<'a, D: Deserializer<'a>>(des: D) -> Result, D::Error> { + // It's very tempting to replace this with + // + // ``` + // let base64: &str = Deserialize::deserialize(des)?; + // base64::decode(base64).map_err(serde::de::Error::custom) + // ``` + // + // Unfortunately, we can't always deserialize into `&str`, so we'd + // have to use an allocating `String` instead. Thus, visitor is better. + struct Visitor; + + impl<'de> de::Visitor<'de> for Visitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string") + } + + fn visit_str(self, v: &str) -> Result { + base64::decode(v).map_err(de::Error::custom) + } + } + + des.deserialize_str(Visitor) +} + +#[cfg(test)] +mod tests { + use super::*; + use rstest::rstest; + use serde_json::json; + + #[test] + fn parse_password() -> anyhow::Result<()> { + let password: Password = serde_json::from_value(json!({ + "password": "foo", + }))?; + assert_eq!(password.as_ref(), "foo".as_bytes()); + + let password: Password = serde_json::from_value(json!({ + "password_": base64::encode("foo"), + }))?; + assert_eq!(password.as_ref(), "foo".as_bytes()); + + Ok(()) + } + + #[rstest] + #[case("password", str::to_owned)] + #[case("password_", base64::encode)] + fn parse(#[case] key: &str, #[case] encode: fn(&'static str) -> String) -> anyhow::Result<()> { + let (password, project) = ("password", "pie-in-the-sky"); + let payload = json!({ + "project": project, + key: encode(password), + }); + + let payload: PasswordHackPayload = serde_json::from_value(payload)?; + assert_eq!(payload.password.as_ref(), password.as_bytes()); + assert_eq!(payload.project, project); + + Ok(()) + } +} diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index cccd6e60d4..896ef3588d 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -1,8 +1,6 @@ -use crate::auth::DatabaseInfo; -use crate::cancellation::CancelClosure; -use crate::error::UserFacingError; -use std::io; -use std::net::SocketAddr; +use crate::{cancellation::CancelClosure, error::UserFacingError}; +use futures::TryFutureExt; +use std::{io, net::SocketAddr}; use thiserror::Error; use tokio::net::TcpStream; use tokio_postgres::NoTls; @@ -21,44 +19,96 @@ pub enum ConnectionError { FailedToFetchPgVersion, } -impl UserFacingError for ConnectionError {} - -/// PostgreSQL version as [`String`]. -pub type Version = String; +impl UserFacingError for ConnectionError { + fn to_string_client(&self) -> String { + use ConnectionError::*; + match self { + // This helps us drop irrelevant library-specific prefixes. + // TODO: propagate severity level and other parameters. + Postgres(err) => match err.as_db_error() { + Some(err) => err.message().to_string(), + None => err.to_string(), + }, + other => other.to_string(), + } + } +} /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`. pub type ScramKeys = tokio_postgres::config::ScramKeys<32>; -/// Compute node connection params. +pub type ComputeConnCfg = tokio_postgres::Config; + +/// Various compute node info for establishing connection etc. pub struct NodeInfo { - pub db_info: DatabaseInfo, - pub scram_keys: Option, + /// Did we send [`utils::pq_proto::BeMessage::AuthenticationOk`]? + pub reported_auth_ok: bool, + /// Compute node connection params. + pub config: tokio_postgres::Config, } impl NodeInfo { async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream)> { - let host_port = (self.db_info.host.as_str(), self.db_info.port); - let socket = TcpStream::connect(host_port).await?; - let socket_addr = socket.peer_addr()?; - socket2::SockRef::from(&socket).set_keepalive(true)?; + use tokio_postgres::config::Host; - Ok((socket_addr, socket)) + let connect_once = |host, port| { + TcpStream::connect((host, port)).and_then(|socket| async { + let socket_addr = socket.peer_addr()?; + // This prevents load balancer from severing the connection. + socket2::SockRef::from(&socket).set_keepalive(true)?; + Ok((socket_addr, socket)) + }) + }; + + // We can't reuse connection establishing logic from `tokio_postgres` here, + // because it has no means for extracting the underlying socket which we + // require for our business. + let mut connection_error = None; + let ports = self.config.get_ports(); + for (i, host) in self.config.get_hosts().iter().enumerate() { + let port = ports.get(i).or_else(|| ports.get(0)).unwrap_or(&5432); + let host = match host { + Host::Tcp(host) => host.as_str(), + Host::Unix(_) => continue, // unix sockets are not welcome here + }; + + // TODO: maybe we should add a timeout. + match connect_once(host, *port).await { + Ok(socket) => return Ok(socket), + Err(err) => { + // We can't throw an error here, as there might be more hosts to try. + println!("failed to connect to compute `{host}:{port}`: {err}"); + connection_error = Some(err); + } + } + } + + Err(connection_error.unwrap_or_else(|| { + io::Error::new( + io::ErrorKind::Other, + format!("couldn't connect: bad compute config: {:?}", self.config), + ) + })) } +} +pub struct PostgresConnection { + /// Socket connected to a compute node. + pub stream: TcpStream, + /// PostgreSQL version of this instance. + pub version: String, +} + +impl NodeInfo { /// Connect to a corresponding compute node. - pub async fn connect(self) -> Result<(TcpStream, Version, CancelClosure), ConnectionError> { - let (socket_addr, mut socket) = self + pub async fn connect(&self) -> Result<(PostgresConnection, CancelClosure), ConnectionError> { + let (socket_addr, mut stream) = self .connect_raw() .await .map_err(|_| ConnectionError::FailedToConnectToCompute)?; - let mut config = tokio_postgres::Config::from(self.db_info); - if let Some(scram_keys) = self.scram_keys { - config.auth_keys(tokio_postgres::config::AuthKeys::ScramSha256(scram_keys)); - } - // TODO: establish a secure connection to the DB - let (client, conn) = config.connect_raw(&mut socket, NoTls).await?; + let (client, conn) = self.config.connect_raw(&mut stream, NoTls).await?; let version = conn .parameter("server_version") .ok_or(ConnectionError::FailedToFetchPgVersion)? @@ -66,6 +116,8 @@ impl NodeInfo { let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token()); - Ok((socket, version, cancel_closure)) + let db = PostgresConnection { stream, version }; + + Ok((db, cancel_closure)) } } diff --git a/proxy/src/config.rs b/proxy/src/config.rs index df3923de1a..1f01c25734 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -1,28 +1,16 @@ -use crate::url::ApiUrl; +use crate::{auth, url::ApiUrl}; use anyhow::{bail, ensure, Context}; use std::{str::FromStr, sync::Arc}; -#[derive(Debug)] -pub enum AuthBackendType { - /// Legacy Cloud API (V1). - LegacyConsole, - /// Authentication via a web browser. - Link, - /// Current Cloud API (V2). - Console, - /// Local mock of Cloud API (V2). - Postgres, -} - -impl FromStr for AuthBackendType { +impl FromStr for auth::BackendType<()> { type Err = anyhow::Error; fn from_str(s: &str) -> anyhow::Result { - use AuthBackendType::*; + use auth::BackendType::*; Ok(match s { - "legacy" => LegacyConsole, - "console" => Console, - "postgres" => Postgres, + "legacy" => LegacyConsole(()), + "console" => Console(()), + "postgres" => Postgres(()), "link" => Link, _ => bail!("Invalid option `{s}` for auth method"), }) @@ -31,7 +19,11 @@ impl FromStr for AuthBackendType { pub struct ProxyConfig { pub tls_config: Option, - pub auth_backend: AuthBackendType, + pub auth_backend: auth::BackendType<()>, + pub auth_urls: AuthUrls, +} + +pub struct AuthUrls { pub auth_endpoint: ApiUrl, pub auth_link_uri: ApiUrl, } @@ -87,10 +79,8 @@ pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result>) -> io::Error { + io::Error::new(io::ErrorKind::Other, e) +} diff --git a/proxy/src/main.rs b/proxy/src/main.rs index b68b2440dd..2521f2af21 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -118,11 +118,15 @@ async fn main() -> anyhow::Result<()> { let mgmt_address: SocketAddr = arg_matches.value_of("mgmt").unwrap().parse()?; let http_address: SocketAddr = arg_matches.value_of("http").unwrap().parse()?; + let auth_urls = config::AuthUrls { + auth_endpoint: arg_matches.value_of("auth-endpoint").unwrap().parse()?, + auth_link_uri: arg_matches.value_of("uri").unwrap().parse()?, + }; + let config: &ProxyConfig = Box::leak(Box::new(ProxyConfig { tls_config, auth_backend: arg_matches.value_of("auth-backend").unwrap().parse()?, - auth_endpoint: arg_matches.value_of("auth-endpoint").unwrap().parse()?, - auth_link_uri: arg_matches.value_of("uri").unwrap().parse()?, + auth_urls, })); println!("Version: {GIT_VERSION}"); diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 7e364b5e9c..f202782109 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -82,11 +82,22 @@ async fn handle_client( } let tls = config.tls_config.as_ref(); - let (stream, creds) = match handshake(stream, tls, cancel_map).await? { + let (mut stream, params) = match handshake(stream, tls, cancel_map).await? { Some(x) => x, None => return Ok(()), // it's a cancellation request }; + let creds = { + let sni = stream.get_ref().sni_hostname(); + let common_name = tls.and_then(|tls| tls.common_name.as_deref()); + let result = config + .auth_backend + .map(|_| auth::ClientCredentials::parse(params, sni, common_name)) + .transpose(); + + async { result }.or_else(|e| stream.throw_error(e)).await? + }; + let client = Client::new(stream, creds); cancel_map .with_session(|session| client.connect_to_db(config, session)) @@ -101,12 +112,10 @@ async fn handshake( stream: S, mut tls: Option<&TlsConfig>, cancel_map: &CancelMap, -) -> anyhow::Result>, auth::ClientCredentials)>> { +) -> anyhow::Result>, StartupMessageParams)>> { // Client may try upgrading to each protocol only once let (mut tried_ssl, mut tried_gss) = (false, false); - let common_name = tls.and_then(|cfg| cfg.common_name.as_deref()); - let mut stream = PqStream::new(Stream::from_raw(stream)); loop { let msg = stream.read_startup_packet().await?; @@ -147,18 +156,7 @@ async fn handshake( stream.throw_error_str(ERR_INSECURE_CONNECTION).await?; } - // Get SNI info when available - let sni_data = match stream.get_ref() { - Stream::Tls { tls } => tls.get_ref().1.sni_hostname().map(|s| s.to_owned()), - _ => None, - }; - - // Construct credentials - let creds = - auth::ClientCredentials::parse(params, sni_data.as_deref(), common_name); - let creds = async { creds }.or_else(|e| stream.throw_error(e)).await?; - - break Ok(Some((stream, creds))); + break Ok(Some((stream, params))); } CancelRequest(cancel_key_data) => { cancel_map.cancel_session(cancel_key_data).await?; @@ -174,12 +172,12 @@ struct Client { /// The underlying libpq protocol stream. stream: PqStream, /// Client credentials that we care about. - creds: auth::ClientCredentials, + creds: auth::BackendType, } impl Client { /// Construct a new connection context. - fn new(stream: PqStream, creds: auth::ClientCredentials) -> Self { + fn new(stream: PqStream, creds: auth::BackendType) -> Self { Self { stream, creds } } } @@ -194,16 +192,22 @@ impl Client { let Self { mut stream, creds } = self; // Authenticate and connect to a compute node. - let auth = creds.authenticate(config, &mut stream).await; + let auth = creds.authenticate(&config.auth_urls, &mut stream).await; let node = async { auth }.or_else(|e| stream.throw_error(e)).await?; - let (db, version, cancel_closure) = - node.connect().or_else(|e| stream.throw_error(e)).await?; + let (db, cancel_closure) = node.connect().or_else(|e| stream.throw_error(e)).await?; let cancel_key_data = session.enable_cancellation(cancel_closure); + // Report authentication success if we haven't done this already. + if !node.reported_auth_ok { + stream + .write_message_noflush(&Be::AuthenticationOk)? + .write_message_noflush(&BeParameterStatusMessage::encoding())?; + } + stream .write_message_noflush(&BeMessage::ParameterStatus( - BeParameterStatusMessage::ServerVersion(&version), + BeParameterStatusMessage::ServerVersion(&db.version), ))? .write_message_noflush(&Be::BackendKeyData(cancel_key_data))? .write_message(&BeMessage::ReadyForQuery) @@ -217,7 +221,7 @@ impl Client { } // Starting from here we only proxy the client's traffic. - let mut db = MetricsStream::new(db, inc_proxied); + let mut db = MetricsStream::new(db.stream, inc_proxied); let mut client = MetricsStream::new(stream.into_inner(), inc_proxied); let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?; @@ -279,9 +283,13 @@ mod tests { let config = rustls::ServerConfig::builder() .with_safe_defaults() .with_no_client_auth() - .with_single_cert(vec![cert], key)?; + .with_single_cert(vec![cert], key)? + .into(); - config.into() + TlsConfig { + config, + common_name: Some(common_name.to_string()), + } }; let client_config = { @@ -297,11 +305,6 @@ mod tests { ClientConfig { config, hostname } }; - let tls_config = TlsConfig { - config: tls_config, - common_name: Some(common_name.to_string()), - }; - Ok((client_config, tls_config)) } @@ -357,7 +360,7 @@ mod tests { auth: impl TestAuth + Send, ) -> anyhow::Result<()> { let cancel_map = CancelMap::default(); - let (mut stream, _creds) = handshake(client, tls.as_ref(), &cancel_map) + let (mut stream, _params) = handshake(client, tls.as_ref(), &cancel_map) .await? .context("handshake failed")?; @@ -436,32 +439,6 @@ mod tests { proxy.await? } - #[tokio::test] - async fn give_user_an_error_for_bad_creds() -> anyhow::Result<()> { - let (client, server) = tokio::io::duplex(1024); - - let proxy = tokio::spawn(dummy_proxy(client, None, NoAuth)); - - let client_err = tokio_postgres::Config::new() - .ssl_mode(SslMode::Disable) - .connect_raw(server, NoTls) - .await - .err() // -> Option - .context("client shouldn't be able to connect")?; - - // TODO: this is ugly, but `format!` won't allow us to extract fmt string - assert!(client_err.to_string().contains("missing in startup packet")); - - let server_err = proxy - .await? - .err() // -> Option - .context("server shouldn't accept client")?; - - assert!(client_err.to_string().contains(&server_err.to_string())); - - Ok(()) - } - #[tokio::test] async fn keepalive_is_inherited() -> anyhow::Result<()> { use tokio::net::{TcpListener, TcpStream}; diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 42b0185fde..54ff8bcc07 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -145,6 +145,14 @@ impl Stream { pub fn from_raw(raw: S) -> Self { Self::Raw { raw } } + + /// Return SNI hostname when it's available. + pub fn sni_hostname(&self) -> Option<&str> { + match self { + Stream::Raw { .. } => None, + Stream::Tls { tls } => tls.get_ref().1.sni_hostname(), + } + } } #[derive(Debug, Error)] diff --git a/test_runner/batch_others/test_proxy.py b/test_runner/batch_others/test_proxy.py index ebeede8df7..92c8475e69 100644 --- a/test_runner/batch_others/test_proxy.py +++ b/test_runner/batch_others/test_proxy.py @@ -1,8 +1,34 @@ import pytest +import json +import base64 def test_proxy_select_1(static_proxy): - static_proxy.safe_psql("select 1;", options="project=generic-project-name") + static_proxy.safe_psql('select 1', options='project=generic-project-name') + + +def test_password_hack(static_proxy): + user = 'borat' + password = 'password' + static_proxy.safe_psql(f"create role {user} with login password '{password}'", + options='project=irrelevant') + + def encode(s: str) -> str: + return base64.b64encode(s.encode('utf-8')).decode('utf-8') + + magic = encode(json.dumps({ + 'project': 'irrelevant', + 'password': password, + })) + + static_proxy.safe_psql('select 1', sslsni=0, user=user, password=magic) + + magic = encode(json.dumps({ + 'project': 'irrelevant', + 'password_': encode(password), + })) + + static_proxy.safe_psql('select 1', sslsni=0, user=user, password=magic) # Pass extra options to the server. @@ -11,8 +37,8 @@ def test_proxy_select_1(static_proxy): # See https://github.com/neondatabase/neon/issues/1287 @pytest.mark.xfail def test_proxy_options(static_proxy): - with static_proxy.connect(options="-cproxytest.option=value") as conn: + with static_proxy.connect(options='-cproxytest.option=value') as conn: with conn.cursor() as cur: - cur.execute("SHOW proxytest.option;") + cur.execute('SHOW proxytest.option') value = cur.fetchall()[0][0] assert value == 'value' diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 3a6a233208..b1fba29e3b 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -30,7 +30,7 @@ from dataclasses import dataclass # Type-related stuff from psycopg2.extensions import connection as PgConnection from psycopg2.extensions import make_dsn, parse_dsn -from typing import Any, Callable, Dict, Iterator, List, Optional, Type, TypeVar, cast, Union, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, cast, Union, Tuple from typing_extensions import Literal import requests @@ -280,20 +280,18 @@ class PgProtocol: return str(make_dsn(**self.conn_options(**kwargs))) def conn_options(self, **kwargs): - conn_options = self.default_options.copy() + result = self.default_options.copy() if 'dsn' in kwargs: - conn_options.update(parse_dsn(kwargs['dsn'])) - conn_options.update(kwargs) + result.update(parse_dsn(kwargs['dsn'])) + result.update(kwargs) # Individual statement timeout in seconds. 2 minutes should be # enough for our tests, but if you need a longer, you can # change it by calling "SET statement_timeout" after # connecting. - if 'options' in conn_options: - conn_options['options'] = f"-cstatement_timeout=120s " + conn_options['options'] - else: - conn_options['options'] = "-cstatement_timeout=120s" - return conn_options + options = result.get('options', '') + result['options'] = f'-cstatement_timeout=120s {options}' + return result # autocommit=True here by default because that's what we need most of the time def connect(self, autocommit=True, **kwargs) -> PgConnection: @@ -1514,29 +1512,25 @@ def remote_pg(test_output_dir: Path) -> Iterator[RemotePostgres]: class NeonProxy(PgProtocol): - def __init__(self, port: int, pg_port: int): - super().__init__(host="127.0.0.1", - user="proxy_user", - password="pytest2", - port=port, - dbname='postgres') - self.http_port = 7001 - self.host = "127.0.0.1" - self.port = port - self.pg_port = pg_port + def __init__(self, proxy_port: int, http_port: int, auth_endpoint: str): + super().__init__(dsn=auth_endpoint, port=proxy_port) + self.host = '127.0.0.1' + self.http_port = http_port + self.proxy_port = proxy_port + self.auth_endpoint = auth_endpoint self._popen: Optional[subprocess.Popen[bytes]] = None def start(self) -> None: assert self._popen is None # Start proxy - bin_proxy = os.path.join(str(neon_binpath), 'proxy') - args = [bin_proxy] - args.extend(["--http", f"{self.host}:{self.http_port}"]) - args.extend(["--proxy", f"{self.host}:{self.port}"]) - args.extend(["--auth-backend", "postgres"]) - args.extend( - ["--auth-endpoint", f"postgres://proxy_auth:pytest1@localhost:{self.pg_port}/postgres"]) + args = [ + os.path.join(str(neon_binpath), 'proxy'), + *["--http", f"{self.host}:{self.http_port}"], + *["--proxy", f"{self.host}:{self.proxy_port}"], + *["--auth-backend", "postgres"], + *["--auth-endpoint", self.auth_endpoint], + ] self._popen = subprocess.Popen(args) self._wait_until_ready() @@ -1557,13 +1551,21 @@ class NeonProxy(PgProtocol): @pytest.fixture(scope='function') def static_proxy(vanilla_pg, port_distributor) -> Iterator[NeonProxy]: """Neon proxy that routes directly to vanilla postgres.""" - vanilla_pg.start() - vanilla_pg.safe_psql("create user proxy_auth with password 'pytest1' superuser") - vanilla_pg.safe_psql("create user proxy_user with password 'pytest2'") - port = port_distributor.get_port() - pg_port = vanilla_pg.default_options['port'] - with NeonProxy(port, pg_port) as proxy: + # For simplicity, we use the same user for both `--auth-endpoint` and `safe_psql` + vanilla_pg.start() + vanilla_pg.safe_psql("create user proxy with login superuser password 'password'") + + port = vanilla_pg.default_options['port'] + host = vanilla_pg.default_options['host'] + dbname = vanilla_pg.default_options['dbname'] + auth_endpoint = f'postgres://proxy:password@{host}:{port}/{dbname}' + + proxy_port = port_distributor.get_port() + http_port = port_distributor.get_port() + + with NeonProxy(proxy_port=proxy_port, http_port=http_port, + auth_endpoint=auth_endpoint) as proxy: proxy.start() yield proxy