From af0195b60478bc82cbb7c95c1421b5ab4c3e752e Mon Sep 17 00:00:00 2001 From: Dmitry Ivanov Date: Wed, 27 Apr 2022 13:34:59 +0300 Subject: [PATCH] [proxy] Introduce `cloud::Api` for communication with Neon Cloud * `cloud::legacy` talks to Cloud API V1. * `cloud::api` defines Cloud API v2. * `cloud::local` mocks the Cloud API V2 using a local postgres instance. * It's possible to choose between API versions using the `--api-version` flag. --- proxy/Cargo.toml | 2 +- proxy/src/auth.rs | 129 +++++++++++-------- proxy/src/auth/credentials.rs | 30 ++--- proxy/src/auth/flow.rs | 28 +--- proxy/src/cloud.rs | 46 +++++++ proxy/src/cloud/api.rs | 120 +++++++++++++++++ proxy/src/{cplane_api.rs => cloud/legacy.rs} | 65 +++------- proxy/src/cloud/local.rs | 76 +++++++++++ proxy/src/compute.rs | 63 +++------ proxy/src/config.rs | 84 +++++------- proxy/src/main.rs | 108 ++++++++-------- proxy/src/mgmt.rs | 8 +- proxy/src/proxy.rs | 4 +- proxy/src/scram.rs | 4 +- proxy/src/scram/key.rs | 4 + 15 files changed, 471 insertions(+), 300 deletions(-) create mode 100644 proxy/src/cloud.rs create mode 100644 proxy/src/cloud/api.rs rename proxy/src/{cplane_api.rs => cloud/legacy.rs} (81%) create mode 100644 proxy/src/cloud/local.rs diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index f7e872ceb9..73412609f3 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] anyhow = "1.0" +async-trait = "0.1" base64 = "0.13.0" bytes = { version = "1.0.1", features = ['serde'] } clap = "3.0" @@ -37,7 +38,6 @@ metrics = { path = "../libs/metrics" } workspace_hack = { version = "0.1", path = "../workspace_hack" } [dev-dependencies] -async-trait = "0.1" rcgen = "0.8.14" rstest = "0.12" tokio-postgres-rustls = "0.9.0" diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index c6d32040dc..5234dfc2c6 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -1,22 +1,16 @@ mod credentials; - -#[cfg(test)] mod flow; -use crate::compute::DatabaseInfo; -use crate::config::ProxyConfig; -use crate::cplane_api::{self, CPlaneApi}; +use crate::config::{CloudApi, ProxyConfig}; use crate::error::UserFacingError; use crate::stream::PqStream; -use crate::waiters; +use crate::{cloud, compute, waiters}; use std::io; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage}; pub use credentials::ClientCredentials; - -#[cfg(test)] pub use flow::*; /// Common authentication error. @@ -24,9 +18,14 @@ pub use flow::*; pub enum AuthErrorImpl { /// Authentication error reported by the console. #[error(transparent)] - Console(#[from] cplane_api::AuthError), + Console(#[from] cloud::AuthError), + + #[error(transparent)] + GetAuthInfo(#[from] cloud::api::GetAuthInfoError), + + #[error(transparent)] + WakeCompute(#[from] cloud::api::WakeComputeError), - #[cfg(test)] #[error(transparent)] Sasl(#[from] crate::sasl::Error), @@ -41,19 +40,19 @@ pub enum AuthErrorImpl { impl AuthErrorImpl { pub fn auth_failed(msg: impl Into) -> Self { - AuthErrorImpl::Console(cplane_api::AuthError::auth_failed(msg)) + AuthErrorImpl::Console(cloud::AuthError::auth_failed(msg)) } } impl From for AuthErrorImpl { fn from(e: waiters::RegisterError) -> Self { - AuthErrorImpl::Console(cplane_api::AuthError::from(e)) + AuthErrorImpl::Console(cloud::AuthError::from(e)) } } impl From for AuthErrorImpl { fn from(e: waiters::WaitError) -> Self { - AuthErrorImpl::Console(cplane_api::AuthError::from(e)) + AuthErrorImpl::Console(cloud::AuthError::from(e)) } } @@ -81,40 +80,28 @@ impl UserFacingError for AuthError { } } -async fn handle_static( - host: String, - port: u16, - client: &mut PqStream, - creds: ClientCredentials, -) -> Result { - client - .write_message(&Be::AuthenticationCleartextPassword) - .await?; - - // Read client's password bytes - let msg = client.read_password_message().await?; - let cleartext_password = parse_password(&msg).ok_or(AuthErrorImpl::MalformedPassword)?; - - let db_info = DatabaseInfo { - host, - port, - dbname: creds.dbname.clone(), - user: creds.user.clone(), - password: Some(cleartext_password.into()), - }; - - client - .write_message_noflush(&Be::AuthenticationOk)? - .write_message_noflush(&BeParameterStatusMessage::encoding())?; - - Ok(db_info) -} - -async fn handle_existing_user( +async fn handle_user( config: &ProxyConfig, client: &mut PqStream, creds: ClientCredentials, -) -> Result { +) -> Result { + if creds.is_existing_user() { + match &config.cloud_endpoint { + CloudApi::V1(api) => handle_existing_user_v1(api, client, creds).await, + CloudApi::V2(api) => handle_existing_user_v2(api.as_ref(), client, creds).await, + } + } else { + let redirect_uri = config.redirect_uri.as_ref(); + handle_new_user(redirect_uri, client).await + } +} + +/// Authenticate user via a legacy cloud API endpoint. +async fn handle_existing_user_v1( + cloud: &cloud::Legacy, + client: &mut PqStream, + creds: ClientCredentials, +) -> Result { let psql_session_id = new_psql_session_id(); let md5_salt = rand::random(); @@ -126,8 +113,7 @@ async fn handle_existing_user( let msg = client.read_password_message().await?; let md5_response = parse_password(&msg).ok_or(AuthErrorImpl::MalformedPassword)?; - let cplane = CPlaneApi::new(config.auth_endpoint.clone()); - let db_info = cplane + let db_info = cloud .authenticate_proxy_client(creds, md5_response, &md5_salt, &psql_session_id) .await?; @@ -135,17 +121,53 @@ async fn handle_existing_user( .write_message_noflush(&Be::AuthenticationOk)? .write_message_noflush(&BeParameterStatusMessage::encoding())?; - Ok(db_info) + Ok(compute::NodeInfo { + db_info, + scram_keys: None, + }) +} + +/// Authenticate user via a new cloud API endpoint which supports SCRAM. +async fn handle_existing_user_v2( + cloud: &(impl cloud::Api + ?Sized), + client: &mut PqStream, + creds: ClientCredentials, +) -> Result { + let auth_info = cloud.get_auth_info(&creds).await?; + + let flow = AuthFlow::new(client); + let scram_keys = match auth_info { + cloud::api::AuthInfo::Md5(_) => { + // TODO: decide if we should support MD5 in api v2 + return Err(AuthErrorImpl::auth_failed("MD5 is not supported").into()); + } + cloud::api::AuthInfo::Scram(secret) => { + let scram = Scram(&secret); + Some(compute::ScramKeys { + client_key: flow.begin(scram).await?.authenticate().await?.as_bytes(), + server_key: secret.server_key.as_bytes(), + }) + } + }; + + client + .write_message_noflush(&Be::AuthenticationOk)? + .write_message_noflush(&BeParameterStatusMessage::encoding())?; + + Ok(compute::NodeInfo { + db_info: cloud.wake_compute(&creds).await?, + scram_keys, + }) } async fn handle_new_user( - config: &ProxyConfig, + redirect_uri: &str, client: &mut PqStream, -) -> Result { +) -> Result { let psql_session_id = new_psql_session_id(); - let greeting = hello_message(&config.redirect_uri, &psql_session_id); + let greeting = hello_message(redirect_uri, &psql_session_id); - let db_info = cplane_api::with_waiter(psql_session_id, |waiter| async { + let db_info = cloud::with_waiter(psql_session_id, |waiter| async { // Give user a URL to spawn a new database client .write_message_noflush(&Be::AuthenticationOk)? @@ -160,7 +182,10 @@ async fn handle_new_user( client.write_message_noflush(&Be::NoticeResponse("Connecting to database."))?; - Ok(db_info) + Ok(compute::NodeInfo { + db_info, + scram_keys: None, + }) } fn new_psql_session_id() -> String { diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index c3bb6da4f8..a3d06b49a2 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -1,7 +1,7 @@ //! User credentials used in authentication. use super::AuthError; -use crate::compute::DatabaseInfo; +use crate::compute; use crate::config::ProxyConfig; use crate::error::UserFacingError; use crate::stream::PqStream; @@ -18,12 +18,20 @@ pub enum ClientCredsParseError { impl UserFacingError for ClientCredsParseError {} /// Various client credentials which we use for authentication. -#[derive(Debug, PartialEq, Eq)] +/// Note that we don't store any kind of client key or password here. +#[derive(Debug, Clone, PartialEq, Eq)] pub struct ClientCredentials { pub user: String, pub dbname: String, } +impl ClientCredentials { + pub fn is_existing_user(&self) -> bool { + // This logic will likely change in the future. + self.user.ends_with("@zenith") + } +} + impl TryFrom> for ClientCredentials { type Error = ClientCredsParseError; @@ -47,20 +55,8 @@ impl ClientCredentials { self, config: &ProxyConfig, client: &mut PqStream, - ) -> Result { - use crate::config::ClientAuthMethod::*; - use crate::config::RouterConfig::*; - match &config.router_config { - Static { host, port } => super::handle_static(host.clone(), *port, client, self).await, - Dynamic(Mixed) => { - if self.user.ends_with("@zenith") { - super::handle_existing_user(config, client, self).await - } else { - super::handle_new_user(config, client).await - } - } - Dynamic(Password) => super::handle_existing_user(config, client, self).await, - Dynamic(Link) => super::handle_new_user(config, client).await, - } + ) -> Result { + // This method is just a convenient facade for `handle_user` + super::handle_user(config, client, self).await } } diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index bcfd94a9ed..3eed0f0a23 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -27,19 +27,6 @@ impl AuthMethod for Scram<'_> { } } -/// Use password-based auth in [`AuthFlow`]. -pub struct Md5( - /// Salt for client. - pub [u8; 4], -); - -impl AuthMethod for Md5 { - #[inline(always)] - fn first_message(&self) -> BeMessage<'_> { - Be::AuthenticationMD5Password(self.0) - } -} - /// This wrapper for [`PqStream`] performs client authentication. #[must_use] pub struct AuthFlow<'a, Stream, State> { @@ -70,19 +57,10 @@ impl<'a, S: AsyncWrite + Unpin> AuthFlow<'a, S, Begin> { } } -/// Stream wrapper for handling simple MD5 password auth. -impl AuthFlow<'_, S, Md5> { - /// Perform user authentication. Raise an error in case authentication failed. - #[allow(unused)] - pub async fn authenticate(self) -> Result<(), AuthError> { - unimplemented!("MD5 auth flow is yet to be implemented"); - } -} - /// Stream wrapper for handling [SCRAM](crate::scram) auth. impl AuthFlow<'_, S, Scram<'_>> { /// Perform user authentication. Raise an error in case authentication failed. - pub async fn authenticate(self) -> Result<(), AuthError> { + pub async fn authenticate(self) -> Result { // Initial client message contains the chosen auth method's name. let msg = self.stream.read_password_message().await?; let sasl = sasl::FirstMessage::parse(&msg).ok_or(AuthErrorImpl::MalformedPassword)?; @@ -93,10 +71,10 @@ impl AuthFlow<'_, S, Scram<'_>> { } let secret = self.state.0; - sasl::SaslStream::new(self.stream, sasl.message) + let key = sasl::SaslStream::new(self.stream, sasl.message) .authenticate(scram::Exchange::new(secret, rand::random, None)) .await?; - Ok(()) + Ok(key) } } diff --git a/proxy/src/cloud.rs b/proxy/src/cloud.rs new file mode 100644 index 0000000000..679cfb97e1 --- /dev/null +++ b/proxy/src/cloud.rs @@ -0,0 +1,46 @@ +mod local; + +mod legacy; +pub use legacy::{AuthError, AuthErrorImpl, Legacy}; + +pub mod api; +pub use api::{Api, BoxedApi}; + +use crate::mgmt; +use crate::waiters::{self, Waiter, Waiters}; +use lazy_static::lazy_static; + +lazy_static! { + static ref CPLANE_WAITERS: Waiters = Default::default(); +} + +/// Give caller an opportunity to wait for the cloud's reply. +pub async fn with_waiter( + psql_session_id: impl Into, + action: impl FnOnce(Waiter<'static, mgmt::ComputeReady>) -> R, +) -> Result +where + R: std::future::Future>, + E: From, +{ + let waiter = CPLANE_WAITERS.register(psql_session_id.into())?; + action(waiter).await +} + +pub fn notify(psql_session_id: &str, msg: mgmt::ComputeReady) -> Result<(), waiters::NotifyError> { + CPLANE_WAITERS.notify(psql_session_id, msg) +} + +/// Construct a new opaque cloud API provider. +pub fn new(url: reqwest::Url) -> anyhow::Result { + Ok(match url.scheme() { + "https" | "http" => { + todo!("build a real cloud wrapper") + } + "postgresql" | "postgres" | "pg" => { + // Just point to a local running postgres instance. + Box::new(local::Local { url }) + } + other => anyhow::bail!("unsupported url scheme: {other}"), + }) +} diff --git a/proxy/src/cloud/api.rs b/proxy/src/cloud/api.rs new file mode 100644 index 0000000000..713140c1e6 --- /dev/null +++ b/proxy/src/cloud/api.rs @@ -0,0 +1,120 @@ +//! Declaration of Cloud API V2. + +use crate::{auth, scram}; +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum GetAuthInfoError { + // We shouldn't include the actual secret here. + #[error("Bad authentication secret")] + BadSecret, + + #[error("Bad client credentials: {0:?}")] + BadCredentials(crate::auth::ClientCredentials), + + #[error(transparent)] + Io(#[from] std::io::Error), +} + +// TODO: convert to an enum and describe possible sub-errors (see above) +#[derive(Debug, Error)] +#[error("Failed to wake up the compute node")] +pub struct WakeComputeError; + +/// Opaque implementation of Cloud API. +pub type BoxedApi = Box; + +/// Cloud API methods required by the proxy. +#[async_trait] +pub trait Api { + /// Get authentication information for the given user. + async fn get_auth_info( + &self, + creds: &auth::ClientCredentials, + ) -> Result; + + /// Wake up the compute node and return the corresponding connection info. + async fn wake_compute( + &self, + creds: &auth::ClientCredentials, + ) -> Result; +} + +/// Auth secret which is managed by the cloud. +pub enum AuthInfo { + /// Md5 hash of user's password. + Md5([u8; 16]), + /// [SCRAM](crate::scram) authentication info. + Scram(scram::ServerSecret), +} + +/// Compute node connection params provided by the cloud. +/// Note how it implements serde traits, since we receive it over the wire. +#[derive(Serialize, Deserialize, Default)] +pub struct DatabaseInfo { + pub host: String, + pub port: u16, + pub dbname: String, + pub user: String, + + /// [Cloud API V1](super::legacy) returns cleartext password, + /// but [Cloud API V2](super::api) implements [SCRAM](crate::scram) + /// authentication, so we can leverage this method and cope without password. + pub password: Option, +} + +// Manually implement debug to omit personal and sensitive info. +impl std::fmt::Debug for DatabaseInfo { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + fmt.debug_struct("DatabaseInfo") + .field("host", &self.host) + .field("port", &self.port) + .finish() + } +} + +impl From for tokio_postgres::Config { + fn from(db_info: DatabaseInfo) -> Self { + let mut config = tokio_postgres::Config::new(); + + config + .host(&db_info.host) + .port(db_info.port) + .dbname(&db_info.dbname) + .user(&db_info.user); + + if let Some(password) = db_info.password { + config.password(password); + } + + config + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn parse_db_info() -> anyhow::Result<()> { + let _: DatabaseInfo = serde_json::from_value(json!({ + "host": "localhost", + "port": 5432, + "dbname": "postgres", + "user": "john_doe", + "password": "password", + }))?; + + let _: DatabaseInfo = serde_json::from_value(json!({ + "host": "localhost", + "port": 5432, + "dbname": "postgres", + "user": "john_doe", + }))?; + + Ok(()) + } +} diff --git a/proxy/src/cplane_api.rs b/proxy/src/cloud/legacy.rs similarity index 81% rename from proxy/src/cplane_api.rs rename to proxy/src/cloud/legacy.rs index 21fce79df3..7d99995f1a 100644 --- a/proxy/src/cplane_api.rs +++ b/proxy/src/cloud/legacy.rs @@ -1,42 +1,19 @@ +//! Cloud API V1. + +use super::api::DatabaseInfo; use crate::auth::ClientCredentials; -use crate::compute::DatabaseInfo; use crate::error::UserFacingError; -use crate::mgmt; -use crate::waiters::{self, Waiter, Waiters}; -use lazy_static::lazy_static; +use crate::waiters; use serde::{Deserialize, Serialize}; use thiserror::Error; -lazy_static! { - static ref CPLANE_WAITERS: Waiters = Default::default(); -} - -/// Give caller an opportunity to wait for cplane's reply. -pub async fn with_waiter( - psql_session_id: impl Into, - action: impl FnOnce(Waiter<'static, mgmt::ComputeReady>) -> R, -) -> Result -where - R: std::future::Future>, - E: From, -{ - let waiter = CPLANE_WAITERS.register(psql_session_id.into())?; - action(waiter).await -} - -pub fn notify( - psql_session_id: &str, - msg: Result, -) -> Result<(), waiters::NotifyError> { - CPLANE_WAITERS.notify(psql_session_id, msg) -} - -/// Zenith console API wrapper. -pub struct CPlaneApi { +/// Neon cloud API provider. +pub struct Legacy { auth_endpoint: reqwest::Url, } -impl CPlaneApi { +impl Legacy { + /// Construct a new legacy cloud API provider. pub fn new(auth_endpoint: reqwest::Url) -> Self { Self { auth_endpoint } } @@ -95,7 +72,17 @@ impl UserFacingError for AuthError { } } -impl CPlaneApi { +// NOTE: the order of constructors is important. +// https://serde.rs/enum-representations.html#untagged +#[derive(Serialize, Deserialize, Debug)] +#[serde(untagged)] +enum ProxyAuthResponse { + Ready { conn_info: DatabaseInfo }, + Error { error: String }, + NotReady { ready: bool }, // TODO: get rid of `ready` +} + +impl Legacy { pub async fn authenticate_proxy_client( &self, creds: ClientCredentials, @@ -111,8 +98,8 @@ impl CPlaneApi { .append_pair("salt", &hex::encode(salt)) .append_pair("psql_session_id", psql_session_id); - with_waiter(psql_session_id, |waiter| async { - println!("cplane request: {}", url); + super::with_waiter(psql_session_id, |waiter| async { + println!("cloud request: {}", url); // TODO: leverage `reqwest::Client` to reuse connections let resp = reqwest::get(url).await?; if !resp.status().is_success() { @@ -135,16 +122,6 @@ impl CPlaneApi { } } -// NOTE: the order of constructors is important. -// https://serde.rs/enum-representations.html#untagged -#[derive(Serialize, Deserialize, Debug)] -#[serde(untagged)] -enum ProxyAuthResponse { - Ready { conn_info: DatabaseInfo }, - Error { error: String }, - NotReady { ready: bool }, // TODO: get rid of `ready` -} - #[cfg(test)] mod tests { use super::*; diff --git a/proxy/src/cloud/local.rs b/proxy/src/cloud/local.rs new file mode 100644 index 0000000000..88eda6630c --- /dev/null +++ b/proxy/src/cloud/local.rs @@ -0,0 +1,76 @@ +//! Local mock of Cloud API V2. + +use super::api::{self, Api, AuthInfo, DatabaseInfo}; +use crate::auth::ClientCredentials; +use crate::scram; +use async_trait::async_trait; + +/// Mocked cloud for testing purposes. +pub struct Local { + /// Database url, e.g. `postgres://user:password@localhost:5432/database`. + pub url: reqwest::Url, +} + +#[async_trait] +impl Api for Local { + async fn get_auth_info( + &self, + creds: &ClientCredentials, + ) -> Result { + // We wrap `tokio_postgres::Error` because we don't want to infect the + // method's error type with a detail that's specific to debug mode only. + let io_error = |e| std::io::Error::new(std::io::ErrorKind::Other, e); + + // Perhaps we could persist this connection, but then we'd have to + // write more code for reopening it if it got closed, which doesn't + // seem worth it. + let (client, connection) = + tokio_postgres::connect(self.url.as_str(), tokio_postgres::NoTls) + .await + .map_err(io_error)?; + + tokio::spawn(connection); + let query = "select rolpassword from pg_catalog.pg_authid where rolname = $1"; + let rows = client + .query(query, &[&creds.user]) + .await + .map_err(io_error)?; + + match &rows[..] { + // We can't get a secret if there's no such user. + [] => Err(api::GetAuthInfoError::BadCredentials(creds.to_owned())), + // We shouldn't get more than one row anyway. + [row, ..] => { + let entry = row.try_get(0).map_err(io_error)?; + scram::ServerSecret::parse(entry) + .map(AuthInfo::Scram) + .or_else(|| { + // It could be an md5 hash if it's not a SCRAM secret. + let text = entry.strip_prefix("md5")?; + Some(AuthInfo::Md5({ + let mut bytes = [0u8; 16]; + hex::decode_to_slice(text, &mut bytes).ok()?; + bytes + })) + }) + // Putting the secret into this message is a security hazard! + .ok_or(api::GetAuthInfoError::BadSecret) + } + } + } + + async fn wake_compute( + &self, + creds: &ClientCredentials, + ) -> Result { + // Local setup doesn't have a dedicated compute node, + // so we just return the local database we're pointed at. + Ok(DatabaseInfo { + host: self.url.host_str().unwrap_or("localhost").to_owned(), + port: self.url.port().unwrap_or(5432), + dbname: creds.dbname.to_owned(), + user: creds.user.to_owned(), + password: None, + }) + } +} diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index 3c0eee29bc..9949e91ea2 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -1,6 +1,6 @@ use crate::cancellation::CancelClosure; +use crate::cloud::api::DatabaseInfo; use crate::error::UserFacingError; -use serde::{Deserialize, Serialize}; use std::io; use std::net::SocketAddr; use thiserror::Error; @@ -23,32 +23,21 @@ pub enum ConnectionError { impl UserFacingError for ConnectionError {} -/// Compute node connection params. -#[derive(Serialize, Deserialize, Default)] -pub struct DatabaseInfo { - pub host: String, - pub port: u16, - pub dbname: String, - pub user: String, - pub password: Option, -} - -// Manually implement debug to omit personal and sensitive info -impl std::fmt::Debug for DatabaseInfo { - fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { - fmt.debug_struct("DatabaseInfo") - .field("host", &self.host) - .field("port", &self.port) - .finish() - } -} - /// PostgreSQL version as [`String`]. pub type Version = String; -impl DatabaseInfo { +/// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`. +pub type ScramKeys = tokio_postgres::config::ScramKeys<32>; + +/// Compute node connection params. +pub struct NodeInfo { + pub db_info: DatabaseInfo, + pub scram_keys: Option, +} + +impl NodeInfo { async fn connect_raw(&self) -> io::Result<(SocketAddr, TcpStream)> { - let host_port = format!("{}:{}", self.host, self.port); + let host_port = format!("{}:{}", self.db_info.host, self.db_info.port); let socket = TcpStream::connect(host_port).await?; let socket_addr = socket.peer_addr()?; socket2::SockRef::from(&socket).set_keepalive(true)?; @@ -63,11 +52,13 @@ impl DatabaseInfo { .await .map_err(|_| ConnectionError::FailedToConnectToCompute)?; - // TODO: establish a secure connection to the DB - let (client, conn) = tokio_postgres::Config::from(self) - .connect_raw(&mut socket, NoTls) - .await?; + let mut config = tokio_postgres::Config::from(self.db_info); + if let Some(scram_keys) = self.scram_keys { + config.auth_keys(tokio_postgres::config::AuthKeys::ScramSha256(scram_keys)); + } + // TODO: establish a secure connection to the DB + let (client, conn) = config.connect_raw(&mut socket, NoTls).await?; let version = conn .parameter("server_version") .ok_or(ConnectionError::FailedToFetchPgVersion)? @@ -78,21 +69,3 @@ impl DatabaseInfo { Ok((socket, version, cancel_closure)) } } - -impl From for tokio_postgres::Config { - fn from(db_info: DatabaseInfo) -> Self { - let mut config = tokio_postgres::Config::new(); - - config - .host(&db_info.host) - .port(db_info.port) - .dbname(&db_info.dbname) - .user(&db_info.user); - - if let Some(password) = db_info.password { - config.password(password); - } - - config - } -} diff --git a/proxy/src/config.rs b/proxy/src/config.rs index aef079d089..6b30df604d 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -1,65 +1,43 @@ +use crate::cloud; use anyhow::{bail, ensure, Context}; -use std::net::SocketAddr; -use std::str::FromStr; use std::sync::Arc; -pub type TlsConfig = Arc; - -#[non_exhaustive] -pub enum ClientAuthMethod { - Password, - Link, - - /// Use password auth only if username ends with "@zenith" - Mixed, -} - -pub enum RouterConfig { - Static { host: String, port: u16 }, - Dynamic(ClientAuthMethod), -} - -impl FromStr for ClientAuthMethod { - type Err = anyhow::Error; - - fn from_str(s: &str) -> anyhow::Result { - use ClientAuthMethod::*; - match s { - "password" => Ok(Password), - "link" => Ok(Link), - "mixed" => Ok(Mixed), - _ => bail!("Invalid option for router: `{}`", s), - } - } -} - pub struct ProxyConfig { - /// main entrypoint for users to connect to - pub proxy_address: SocketAddr, + /// Unauthenticated users will be redirected to this URL. + pub redirect_uri: reqwest::Url, - /// method of assigning compute nodes - pub router_config: RouterConfig, - - /// internally used for status and prometheus metrics - pub http_address: SocketAddr, - - /// management endpoint. Upon user account creation control plane - /// will notify us here, so that we can 'unfreeze' user session. - /// TODO It uses postgres protocol over TCP but should be migrated to http. - pub mgmt_address: SocketAddr, - - /// send unauthenticated users to this URI - pub redirect_uri: String, - - /// control plane address where we would check auth. - pub auth_endpoint: reqwest::Url, + /// Cloud API endpoint for user authentication. + pub cloud_endpoint: CloudApi, + /// TLS configuration for the proxy. pub tls_config: Option, } -pub fn configure_ssl(key_path: &str, cert_path: &str) -> anyhow::Result { +/// Cloud API configuration. +pub enum CloudApi { + /// We'll drop this one when [`CloudApi::V2`] is stable. + V1(crate::cloud::Legacy), + /// The new version of the cloud API. + V2(crate::cloud::BoxedApi), +} + +impl CloudApi { + /// Configure Cloud API provider. + pub fn new(version: &str, url: reqwest::Url) -> anyhow::Result { + Ok(match version { + "v1" => Self::V1(cloud::Legacy::new(url)), + "v2" => Self::V2(cloud::new(url)?), + _ => bail!("unknown cloud API version: {}", version), + }) + } +} + +pub type TlsConfig = Arc; + +/// Configure TLS for the main endpoint. +pub fn configure_tls(key_path: &str, cert_path: &str) -> anyhow::Result { let key = { - let key_bytes = std::fs::read(key_path).context("SSL key file")?; + let key_bytes = std::fs::read(key_path).context("TLS key file")?; let mut keys = rustls_pemfile::pkcs8_private_keys(&mut &key_bytes[..]) .context("couldn't read TLS keys")?; @@ -68,7 +46,7 @@ pub fn configure_ssl(key_path: &str, cert_path: &str) -> anyhow::Result>` into `Result`. async fn flatten_err( f: impl Future, JoinError>>, @@ -44,7 +37,7 @@ async fn flatten_err( #[tokio::main] async fn main() -> anyhow::Result<()> { metrics::set_common_metrics_prefix("zenith_proxy"); - let arg_matches = App::new("Zenith proxy/router") + let arg_matches = App::new("Neon proxy/router") .version(GIT_VERSION) .arg( Arg::new("proxy") @@ -97,77 +90,80 @@ async fn main() -> anyhow::Result<()> { .short('a') .long("auth-endpoint") .takes_value(true) - .help("API endpoint for authenticating users") + .help("cloud API endpoint for authenticating users") .default_value("http://localhost:3000/authenticate_proxy_request/"), ) .arg( - Arg::new("ssl-key") - .short('k') - .long("ssl-key") + Arg::new("api-version") + .long("api-version") .takes_value(true) - .help("path to SSL key for client postgres connections"), + .default_value("v1") + .possible_values(["v1", "v2"]) + .help("cloud API version to be used for authentication"), ) .arg( - Arg::new("ssl-cert") - .short('c') - .long("ssl-cert") + Arg::new("tls-key") + .short('k') + .long("tls-key") + .alias("ssl-key") // backwards compatibility .takes_value(true) - .help("path to SSL cert for client postgres connections"), + .help("path to TLS key for client postgres connections"), + ) + .arg( + Arg::new("tls-cert") + .short('c') + .long("tls-cert") + .alias("ssl-cert") // backwards compatibility + .takes_value(true) + .help("path to TLS cert for client postgres connections"), ) .get_matches(); let tls_config = match ( - arg_matches.value_of("ssl-key"), - arg_matches.value_of("ssl-cert"), + arg_matches.value_of("tls-key"), + arg_matches.value_of("tls-cert"), ) { - (Some(key_path), Some(cert_path)) => Some(config::configure_ssl(key_path, cert_path)?), + (Some(key_path), Some(cert_path)) => Some(config::configure_tls(key_path, cert_path)?), (None, None) => None, - _ => bail!("either both or neither ssl-key and ssl-cert must be specified"), + _ => bail!("either both or neither tls-key and tls-cert must be specified"), }; - let auth_method = arg_matches.value_of("auth-method").unwrap().parse()?; - let router_config = match arg_matches.value_of("static-router") { - None => RouterConfig::Dynamic(auth_method), - Some(addr) => { - if let ClientAuthMethod::Password = auth_method { - let (host, port) = addr.split_once(':').unwrap(); - RouterConfig::Static { - host: host.to_string(), - port: port.parse().unwrap(), - } - } else { - bail!("static-router requires --auth-method password") - } - } - }; + let proxy_address: SocketAddr = arg_matches.value_of("proxy").unwrap().parse()?; + let mgmt_address: SocketAddr = arg_matches.value_of("mgmt").unwrap().parse()?; + let http_address: SocketAddr = arg_matches.value_of("http").unwrap().parse()?; + + let cloud_endpoint = config::CloudApi::new( + arg_matches.value_of("api-version").unwrap(), + arg_matches.value_of("auth-endpoint").unwrap().parse()?, + )?; let config: &ProxyConfig = Box::leak(Box::new(ProxyConfig { - router_config, - proxy_address: arg_matches.value_of("proxy").unwrap().parse()?, - mgmt_address: arg_matches.value_of("mgmt").unwrap().parse()?, - http_address: arg_matches.value_of("http").unwrap().parse()?, redirect_uri: arg_matches.value_of("uri").unwrap().parse()?, - auth_endpoint: arg_matches.value_of("auth-endpoint").unwrap().parse()?, + cloud_endpoint, tls_config, })); println!("Version: {}", GIT_VERSION); // Check that we can bind to address before further initialization - println!("Starting http on {}", config.http_address); - let http_listener = TcpListener::bind(config.http_address).await?.into_std()?; + println!("Starting http on {}", http_address); + let http_listener = TcpListener::bind(http_address).await?.into_std()?; - println!("Starting mgmt on {}", config.mgmt_address); - let mgmt_listener = TcpListener::bind(config.mgmt_address).await?.into_std()?; + println!("Starting mgmt on {}", mgmt_address); + let mgmt_listener = TcpListener::bind(mgmt_address).await?.into_std()?; - println!("Starting proxy on {}", config.proxy_address); - let proxy_listener = TcpListener::bind(config.proxy_address).await?; + println!("Starting proxy on {}", proxy_address); + let proxy_listener = TcpListener::bind(proxy_address).await?; - let http = tokio::spawn(http::thread_main(http_listener)); - let proxy = tokio::spawn(proxy::thread_main(config, proxy_listener)); - let mgmt = tokio::task::spawn_blocking(move || mgmt::thread_main(mgmt_listener)); + let tasks = [ + tokio::spawn(http::thread_main(http_listener)), + tokio::spawn(proxy::thread_main(config, proxy_listener)), + tokio::task::spawn_blocking(move || mgmt::thread_main(mgmt_listener)), + ] + .map(flatten_err); - let tasks = [flatten_err(http), flatten_err(proxy), flatten_err(mgmt)]; + // This will block until all tasks have completed. + // Furthermore, the first one to fail will cancel the rest. let _: Vec<()> = futures::future::try_join_all(tasks).await?; Ok(()) diff --git a/proxy/src/mgmt.rs b/proxy/src/mgmt.rs index 23ad8a2013..c48df653d3 100644 --- a/proxy/src/mgmt.rs +++ b/proxy/src/mgmt.rs @@ -1,4 +1,4 @@ -use crate::{compute::DatabaseInfo, cplane_api}; +use crate::cloud; use anyhow::Context; use serde::Deserialize; use std::{ @@ -75,12 +75,12 @@ struct PsqlSessionResponse { #[derive(Deserialize)] enum PsqlSessionResult { - Success(DatabaseInfo), + Success(cloud::api::DatabaseInfo), Failure(String), } /// A message received by `mgmt` when a compute node is ready. -pub type ComputeReady = Result; +pub type ComputeReady = Result; impl PsqlSessionResult { fn into_compute_ready(self) -> ComputeReady { @@ -111,7 +111,7 @@ fn try_process_query(pgb: &mut PostgresBackend, query_string: &str) -> anyhow::R let resp: PsqlSessionResponse = serde_json::from_str(query_string)?; - match cplane_api::notify(&resp.session_id, resp.result.into_compute_ready()) { + match cloud::notify(&resp.session_id, resp.result.into_compute_ready()) { Ok(()) => { pgb.write_message_noflush(&SINGLE_COL_ROWDESC)? .write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))? diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index f7de1618df..4bce2bf40d 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -185,10 +185,10 @@ impl Client { // Authenticate and connect to a compute node. let auth = creds.authenticate(config, &mut stream).await; - let db_info = async { auth }.or_else(|e| stream.throw_error(e)).await?; + let node = async { auth }.or_else(|e| stream.throw_error(e)).await?; let (db, version, cancel_closure) = - db_info.connect().or_else(|e| stream.throw_error(e)).await?; + node.connect().or_else(|e| stream.throw_error(e)).await?; let cancel_key_data = session.enable_cancellation(cancel_closure); stream diff --git a/proxy/src/scram.rs b/proxy/src/scram.rs index 22fce7ac7e..7cc4191435 100644 --- a/proxy/src/scram.rs +++ b/proxy/src/scram.rs @@ -9,10 +9,12 @@ mod exchange; mod key; mod messages; -mod password; mod secret; mod signature; +#[cfg(test)] +mod password; + pub use exchange::Exchange; pub use key::ScramKey; pub use secret::ServerSecret; diff --git a/proxy/src/scram/key.rs b/proxy/src/scram/key.rs index 73dd5e1d5c..e9c65fcef3 100644 --- a/proxy/src/scram/key.rs +++ b/proxy/src/scram/key.rs @@ -16,6 +16,10 @@ impl ScramKey { pub fn sha256(&self) -> Self { super::sha256([self.as_ref()]).into() } + + pub fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] { + self.bytes + } } impl From<[u8; SCRAM_KEY_LEN]> for ScramKey {