diff --git a/Cargo.lock b/Cargo.lock index 117e102a35..6084c21a02 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1193,8 +1193,8 @@ dependencies = [ "once_cell", "parking_lot", "postgres", - "postgres-protocol", - "postgres-types", + "postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)", + "postgres-types 0.2.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)", "postgres_ffi", "rand", "regex", @@ -1208,7 +1208,7 @@ dependencies = [ "tempfile", "thiserror", "tokio", - "tokio-postgres", + "tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)", "tokio-stream", "toml_edit", "tracing", @@ -1326,9 +1326,9 @@ dependencies = [ "fallible-iterator", "futures", "log", - "postgres-protocol", + "postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)", "tokio", - "tokio-postgres", + "tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)", ] [[package]] @@ -1349,6 +1349,24 @@ dependencies = [ "stringprep", ] +[[package]] +name = "postgres-protocol" +version = "0.6.1" +source = "git+https://github.com/zenithdb/rust-postgres.git?rev=f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d#f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d" +dependencies = [ + "base64 0.13.0", + "byteorder", + "bytes", + "fallible-iterator", + "hmac 0.10.1", + "lazy_static", + "md-5", + "memchr", + "rand", + "sha2", + "stringprep", +] + [[package]] name = "postgres-types" version = "0.2.1" @@ -1356,7 +1374,17 @@ source = "git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b dependencies = [ "bytes", "fallible-iterator", - "postgres-protocol", + "postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)", +] + +[[package]] +name = "postgres-types" +version = "0.2.1" +source = "git+https://github.com/zenithdb/rust-postgres.git?rev=f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d#f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d" +dependencies = [ + "bytes", + "fallible-iterator", + "postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d)", ] [[package]] @@ -1437,7 +1465,7 @@ dependencies = [ "sha2", "thiserror", "tokio", - "tokio-postgres", + "tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d)", "zenith_utils", ] @@ -2098,8 +2126,30 @@ dependencies = [ "percent-encoding", "phf", "pin-project-lite", - "postgres-protocol", - "postgres-types", + "postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)", + "postgres-types 0.2.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)", + "socket2", + "tokio", + "tokio-util", +] + +[[package]] +name = "tokio-postgres" +version = "0.7.1" +source = "git+https://github.com/zenithdb/rust-postgres.git?rev=f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d#f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d" +dependencies = [ + "async-trait", + "byteorder", + "bytes", + "fallible-iterator", + "futures", + "log", + "parking_lot", + "percent-encoding", + "phf", + "pin-project-lite", + "postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d)", + "postgres-types 0.2.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d)", "socket2", "tokio", "tokio-util", @@ -2337,7 +2387,7 @@ dependencies = [ "hyper", "lazy_static", "postgres", - "postgres-protocol", + "postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)", "postgres_ffi", "regex", "routerify", @@ -2347,7 +2397,7 @@ dependencies = [ "signal-hook", "tempfile", "tokio", - "tokio-postgres", + "tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)", "tracing", "walkdir", "workspace_hack", diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 7845d714bf..32b6942425 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -24,6 +24,7 @@ serde_json = "1" sha2 = "0.9.8" thiserror = "1.0.30" tokio = { version = "1.11", features = ['macros'] } -tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858" } +# tokio-postgres = { path = "../../rust-postgres/tokio-postgres" } TODO remove this +tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev = "f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d"} zenith_utils = { path = "../zenith_utils" } diff --git a/proxy/src/cplane_api.rs b/proxy/src/cplane_api.rs index 5ca3838b7e..5cc3fe3b2f 100644 --- a/proxy/src/cplane_api.rs +++ b/proxy/src/cplane_api.rs @@ -21,35 +21,6 @@ enum ProxyAuthResponse { NotReady { ready: bool }, // TODO: get rid of `ready` } -impl DatabaseInfo { - pub fn socket_addr(&self) -> anyhow::Result { - let host_port = format!("{}:{}", self.host, self.port); - host_port - .to_socket_addrs() - .with_context(|| format!("cannot resolve {} to SocketAddr", host_port))? - .next() - .ok_or_else(|| anyhow!("cannot resolve at least one SocketAddr")) - } -} - -impl From for tokio_postgres::Config { - fn from(db_info: DatabaseInfo) -> Self { - let mut config = tokio_postgres::Config::new(); - - config - .host(&db_info.host) - .port(db_info.port) - .dbname(&db_info.dbname) - .user(&db_info.user); - - if let Some(password) = db_info.password { - config.password(password); - } - - config - } -} - pub struct CPlaneApi<'a> { auth_endpoint: &'a str, waiters: &'a ProxyWaiters, diff --git a/proxy/src/db.rs b/proxy/src/db.rs new file mode 100644 index 0000000000..0491c3668b --- /dev/null +++ b/proxy/src/db.rs @@ -0,0 +1,69 @@ +/// +/// Utils for connecting with the postgres dataabase. +/// + +use crate::scram::key::ScramKey; +use std::net::{SocketAddr, ToSocketAddrs}; +use anyhow::{Context, anyhow}; + +/// Sufficient information to authenticate as client. +pub struct ScramAuthSecret { + pub iterations: u32, + pub salt_base64: String, + pub client_key: ScramKey, + pub server_key: ScramKey, +} + +#[non_exhaustive] +pub enum AuthSecret { + Scram(ScramAuthSecret), + Password(String), +} + + +pub struct DatabaseAuthInfo { + pub host: String, + pub port: u16, + pub dbname: String, + pub user: String, + pub auth_secret: AuthSecret, +} + +impl From for tokio_postgres::Config { + fn from(auth_info: DatabaseAuthInfo) -> Self { + let mut config = tokio_postgres::Config::new(); + + config + .host(&auth_info.host) + .port(auth_info.port) + .dbname(&auth_info.dbname) + .user(&auth_info.user); + + match auth_info.auth_secret { + AuthSecret::Scram(scram_secret) => { + config.add_scram_key( + scram_secret.salt_base64.into_bytes(), // TODO test this + scram_secret.iterations, + scram_secret.client_key.bytes.to_vec(), + scram_secret.server_key.bytes.to_vec(), + ); + }, + AuthSecret::Password(password) => { + config.password(password); + } + } + + config + } +} + +impl DatabaseAuthInfo { + pub fn socket_addr(&self) -> anyhow::Result { + let host_port = format!("{}:{}", self.host, self.port); + host_port + .to_socket_addrs() + .with_context(|| format!("cannot resolve {} to SocketAddr", host_port))? + .next() + .ok_or_else(|| anyhow!("cannot resolve at least one SocketAddr")) + } +} diff --git a/proxy/src/main.rs b/proxy/src/main.rs index 6f26a82fbc..7142b31861 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -11,6 +11,7 @@ use state::{ProxyConfig, ProxyState}; use std::thread; use zenith_utils::{tcp_listener, GIT_VERSION}; +mod db; mod auth; mod cplane_api; mod mgmt; diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index f02f16eeed..182ef88f11 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -1,6 +1,8 @@ use crate::auth::{self, AuthStream}; use crate::cplane_api::{CPlaneApi, DatabaseInfo}; use crate::ProxyState; +use crate::db::{AuthSecret, DatabaseAuthInfo, ScramAuthSecret}; +use crate::scram::key::{SCRAM_KEY_LEN, ScramKey}; use anyhow::{anyhow, bail}; use lazy_static::lazy_static; use parking_lot::Mutex; @@ -127,7 +129,7 @@ impl ProxyConnection { }; let conn = match authenticate() { - Ok(Some(db_info)) => connect_to_db(db_info), + Ok(Some(db_auth_info)) => connect_to_db(db_auth_info), Ok(None) => return Ok(None), Err(e) => { // Report the error to the client @@ -212,7 +214,7 @@ impl ProxyConnection { } } - fn handle_existing_user(&mut self, _user: &str, _db: &str) -> anyhow::Result { + fn handle_existing_user(&mut self, user: &str, db: &str) -> anyhow::Result { let _cplane = CPlaneApi::new(&self.state.conf.auth_endpoint, &self.state.waiters); // TODO: fetch secret from console @@ -239,10 +241,30 @@ impl ProxyConnection { .write_message_noflush(&Be::AuthenticationOk)? .write_message_noflush(&BeParameterStatusMessage::encoding())?; - todo!() + // TODO get this info from console and tell it to start the db + let host = ""; + let port = 0; + + // TODO fish out the real client_key from the authenticate() call above + let client_key: ScramKey = [0; SCRAM_KEY_LEN].into(); + let scram_auth_secret = ScramAuthSecret { + iterations: secret.iterations, + salt_base64: secret.salt_base64, + client_key, + server_key: secret.server_key, + }; + let auth_info = DatabaseAuthInfo { + host: host.into(), + port, + dbname: db.into(), + user: user.into(), + auth_secret: AuthSecret::Scram(scram_auth_secret) + }; + + Ok(auth_info) } - fn handle_new_user(&mut self) -> anyhow::Result { + fn handle_new_user(&mut self) -> anyhow::Result { let greeting = hello_message(&self.state.conf.redirect_uri, &self.psql_session_id); // First, register this session @@ -260,7 +282,15 @@ impl ProxyConnection { self.pgb .write_message_noflush(&Be::NoticeResponse("Connecting to database.".into()))?; - Ok(db_info) + let db_auth_info = DatabaseAuthInfo { + host: db_info.host, + port: db_info.port, + dbname: db_info.dbname, + user: db_info.user, + auth_secret: AuthSecret::Password(db_info.password.unwrap()) + }; + + Ok(db_auth_info) } } @@ -281,7 +311,7 @@ fn hello_message(redirect_uri: &str, session_id: &str) -> String { /// Create a TCP connection to a postgres database, authenticate with it, and receive the ReadyForQuery message async fn connect_to_db( - db_info: DatabaseInfo, + db_info: DatabaseAuthInfo, ) -> anyhow::Result<(String, tokio::net::TcpStream, CancelKeyData)> { // Make raw connection. When connect_raw finishes we've received ReadyForQuery. let socket_addr = db_info.socket_addr()?; diff --git a/proxy/src/scram.rs b/proxy/src/scram.rs index 44519c453f..9ff3a84285 100644 --- a/proxy/src/scram.rs +++ b/proxy/src/scram.rs @@ -7,7 +7,7 @@ //! * mod channel_binding; -mod key; +pub mod key; // TODO do I have to make it pub? mod messages; mod secret; mod signature; diff --git a/proxy/src/scram/key.rs b/proxy/src/scram/key.rs index cae31e3919..37ac586d45 100644 --- a/proxy/src/scram/key.rs +++ b/proxy/src/scram/key.rs @@ -6,10 +6,10 @@ use sha2::{Digest, Sha256}; pub const SCRAM_KEY_LEN: usize = 32; /// Thin wrapper for byte array. -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq)] // TODO maybe no debug? Avoid accidental logging. #[repr(transparent)] pub struct ScramKey { - bytes: [u8; SCRAM_KEY_LEN], + pub bytes: [u8; SCRAM_KEY_LEN], // TODO does it have to be public? } impl ScramKey {