diff --git a/Cargo.lock b/Cargo.lock index d9c5ae5123..5a71566b80 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -86,9 +86,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.51" +version = "0.1.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44318e776df68115a881de9a8fd1b9e53368d7a4a5ce4cc48517da3393233a5e" +checksum = "061a7acccaa286c011ddc30970520b98fa40e00c9d644633fb26b5fc63a265e3" dependencies = [ "proc-macro2", "quote", @@ -1564,6 +1564,8 @@ name = "proxy" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", + "base64 0.13.0", "bytes", "clap", "futures", diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 09bf1b72f8..6539d6d444 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -31,6 +31,8 @@ scopeguard = "1.1.0" zenith_utils = { path = "../zenith_utils" } zenith_metrics = { path = "../zenith_metrics" } +base64 = "0.13.0" +async-trait = "0.1.52" [dev-dependencies] tokio-postgres-rustls = "0.8.0" diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs new file mode 100644 index 0000000000..9f4a8fdd81 --- /dev/null +++ b/proxy/src/auth.rs @@ -0,0 +1,41 @@ +use crate::db::AuthSecret; +use crate::stream::PqStream; +use bytes::Bytes; +use tokio::io::{AsyncRead, AsyncWrite}; +use zenith_utils::pq_proto::BeMessage as Be; + + +/// Stored secret for authenticating the user via md5 but authenticating +/// to the compute database with a (possibly different) plaintext password. +pub struct PlaintextStoredSecret { + pub salt: [u8; 4], + pub hashed_salted_password: Bytes, + pub compute_db_password: String, +} + +/// Sufficient information to auth user and create AuthSecret +#[non_exhaustive] +pub enum StoredSecret { + PlaintextPassword(PlaintextStoredSecret), + // TODO add md5 option? + // TODO add SCRAM option +} + +pub async fn authenticate( + client: &mut PqStream, + stored_secret: StoredSecret +) -> anyhow::Result { + match stored_secret { + StoredSecret::PlaintextPassword(stored) => { + client.write_message(&Be::AuthenticationMD5Password(&stored.salt)).await?; + let provided = client.read_password_message().await?; + anyhow::ensure!(provided == stored.hashed_salted_password); + Ok(AuthSecret::Password(stored.compute_db_password)) + }, + } +} + +#[async_trait::async_trait] +pub trait SecretStore { + async fn get_stored_secret(&self, creds: &crate::cplane_api::ClientCredentials) -> anyhow::Result; +} diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs new file mode 100644 index 0000000000..0721f17031 --- /dev/null +++ b/proxy/src/compute.rs @@ -0,0 +1,7 @@ +use crate::{cplane_api::ClientCredentials, db::DatabaseConnInfo}; + + +#[async_trait::async_trait] +pub trait ComputeProvider { + async fn get_compute_node(&self, creds: &ClientCredentials) -> anyhow::Result; +} diff --git a/proxy/src/cplane_api.rs b/proxy/src/cplane_api.rs index 5d96dacaf1..8209719102 100644 --- a/proxy/src/cplane_api.rs +++ b/proxy/src/cplane_api.rs @@ -45,35 +45,6 @@ enum ProxyAuthResponse { NotReady { ready: bool }, // TODO: get rid of `ready` } -impl DatabaseInfo { - pub fn socket_addr(&self) -> anyhow::Result { - let host_port = format!("{}:{}", self.host, self.port); - host_port - .to_socket_addrs() - .with_context(|| format!("cannot resolve {} to SocketAddr", host_port))? - .next() - .context("cannot resolve at least one SocketAddr") - } -} - -impl From for tokio_postgres::Config { - fn from(db_info: DatabaseInfo) -> Self { - let mut config = tokio_postgres::Config::new(); - - config - .host(&db_info.host) - .port(db_info.port) - .dbname(&db_info.dbname) - .user(&db_info.user); - - if let Some(password) = db_info.password { - config.password(password); - } - - config - } -} - pub struct CPlaneApi<'a> { auth_endpoint: &'a str, waiters: &'a ProxyWaiters, diff --git a/proxy/src/db.rs b/proxy/src/db.rs new file mode 100644 index 0000000000..1c5a401a12 --- /dev/null +++ b/proxy/src/db.rs @@ -0,0 +1,58 @@ +/// +/// Utils for connecting with the postgres dataabase. +/// + +use std::net::{SocketAddr, ToSocketAddrs}; +use anyhow::{Context, anyhow}; + +use crate::cplane_api::ClientCredentials; + +pub struct DatabaseConnInfo { + pub host: String, + pub port: u16, +} + +pub struct DatabaseAuthInfo { + pub conn_info: DatabaseConnInfo, + pub creds: ClientCredentials, + pub auth_secret: AuthSecret, +} + +/// Sufficient information to auth with database +#[non_exhaustive] +#[derive(Debug)] +pub enum AuthSecret { + Password(String), + // TODO add SCRAM option +} + +impl From for tokio_postgres::Config { + fn from(auth_info: DatabaseAuthInfo) -> Self { + let mut config = tokio_postgres::Config::new(); + + config + .host(&auth_info.conn_info.host) + .port(auth_info.conn_info.port) + .dbname(&auth_info.creds.dbname) + .user(&auth_info.creds.user); + + match auth_info.auth_secret { + AuthSecret::Password(password) => { + config.password(password); + } + } + + config + } +} + +impl DatabaseConnInfo { + pub fn socket_addr(&self) -> anyhow::Result { + let host_port = format!("{}:{}", self.host, self.port); + host_port + .to_socket_addrs() + .with_context(|| format!("cannot resolve {} to SocketAddr", host_port))? + .next() + .ok_or_else(|| anyhow!("cannot resolve at least one SocketAddr")) + } +} diff --git a/proxy/src/main.rs b/proxy/src/main.rs index 70c20f90b2..ce2288c16a 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -10,6 +10,10 @@ use clap::{App, Arg}; use state::{ProxyConfig, ProxyState}; use zenith_utils::{tcp_listener, GIT_VERSION}; +mod compute; +mod mock; +mod auth; +mod db; mod cancellation; mod cplane_api; mod http; diff --git a/proxy/src/mock.rs b/proxy/src/mock.rs new file mode 100644 index 0000000000..3291a44aec --- /dev/null +++ b/proxy/src/mock.rs @@ -0,0 +1,32 @@ +use bytes::Bytes; + +use crate::{auth::{PlaintextStoredSecret, SecretStore, StoredSecret}, compute::ComputeProvider, cplane_api::ClientCredentials, db::DatabaseConnInfo}; + + +pub struct MockConsole { +} + +#[async_trait::async_trait] +impl SecretStore for MockConsole { + async fn get_stored_secret(&self, creds: &ClientCredentials) -> anyhow::Result { + let salt = [0; 4]; + match (&creds.user[..], &creds.dbname[..]) { + ("postgres", "postgres") => Ok(StoredSecret::PlaintextPassword(PlaintextStoredSecret { + salt, + hashed_salted_password: "md52fff09cd9def51601fc5445943b3a11f\0".into(), + compute_db_password: "postgres".into(), + })), + _ => unimplemented!() + } + } +} + +#[async_trait::async_trait] +impl ComputeProvider for MockConsole{ + async fn get_compute_node(&self, creds: &ClientCredentials) -> anyhow::Result { + return Ok(DatabaseConnInfo { + host: "127.0.0.1".into(), + port: 5432, + }) + } +} diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 4666c941d9..096211e6c9 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -1,5 +1,9 @@ +use crate::auth::{self, StoredSecret, SecretStore}; use crate::cancellation::{self, CancelClosure}; +use crate::compute::ComputeProvider; use crate::cplane_api as cplane; +use crate::db::{AuthSecret, DatabaseAuthInfo}; +use crate::mock::MockConsole; use crate::state::SslConfig; use crate::stream::{PqStream, Stream}; use crate::ProxyState; @@ -140,24 +144,28 @@ async fn handshake( } } -// TODO: implement proper authentication async fn connect_client_to_db( mut client: PqStream, creds: cplane::ClientCredentials, session: cancellation::Session, ) -> anyhow::Result<()> { - // TODO: get this from an api call - let db_info = cplane::DatabaseInfo { - host: "127.0.0.1".into(), - port: 5432, - dbname: creds.dbname, - user: "dmitry".into(), - password: None, + // Authenticate + // TODO use real console + let console = MockConsole {}; + let stored_secret = console.get_stored_secret(&creds).await?; + let auth_secret = auth::authenticate(&mut client, stored_secret).await?; + let conn_info = console.get_compute_node(&creds).await?; + let db_auth_info = DatabaseAuthInfo { + conn_info, + creds, + auth_secret, }; - let (mut db, version, cancel_closure) = connect_to_db(db_info).await?; + // Connect to db + let (mut db, version, cancel_closure) = connect_to_db(db_auth_info).await?; let cancel_key_data = session.enable_cancellation(cancel_closure); + // Report success to client client .write_message_noflush(&Be::AuthenticationOk)? .write_message_noflush(&BeParameterStatusMessage::encoding())? @@ -191,10 +199,10 @@ fn hello_message(redirect_uri: &str, session_id: &str) -> String { /// Connect to a corresponding compute node. async fn connect_to_db( - db_info: cplane::DatabaseInfo, + db_info: DatabaseAuthInfo, ) -> anyhow::Result<(TcpStream, String, CancelClosure)> { // TODO: establish a secure connection to the DB - let socket_addr = db_info.socket_addr()?; + let socket_addr = db_info.conn_info.socket_addr()?; let mut socket = TcpStream::connect(socket_addr).await?; let (client, conn) = tokio_postgres::Config::from(db_info) diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 47ac0810d4..0d39818ac7 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -49,6 +49,14 @@ impl PqStream { other => anyhow::bail!("bad message type: {:?}", other), } } + + pub async fn read_password_message(&mut self) -> anyhow::Result { + match FeMessage::read_fut(&mut self.stream).await? { + Some(FeMessage::PasswordMessage(msg)) => Ok(msg), + None => anyhow::bail!("connection is lost"), + other => anyhow::bail!("bad message type: {:?}", other), + } + } } impl PqStream {