diff --git a/Cargo.lock b/Cargo.lock index de24239357..6084c21a02 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1334,6 +1334,7 @@ dependencies = [ [[package]] name = "postgres-protocol" version = "0.6.1" +source = "git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858#9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858" dependencies = [ "base64 0.13.0", "byteorder", @@ -1351,7 +1352,7 @@ dependencies = [ [[package]] name = "postgres-protocol" version = "0.6.1" -source = "git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858#9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858" +source = "git+https://github.com/zenithdb/rust-postgres.git?rev=f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d#f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d" dependencies = [ "base64 0.13.0", "byteorder", @@ -1366,15 +1367,6 @@ dependencies = [ "stringprep", ] -[[package]] -name = "postgres-types" -version = "0.2.1" -dependencies = [ - "bytes", - "fallible-iterator", - "postgres-protocol 0.6.1", -] - [[package]] name = "postgres-types" version = "0.2.1" @@ -1385,6 +1377,16 @@ dependencies = [ "postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)", ] +[[package]] +name = "postgres-types" +version = "0.2.1" +source = "git+https://github.com/zenithdb/rust-postgres.git?rev=f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d#f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d" +dependencies = [ + "bytes", + "fallible-iterator", + "postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d)", +] + [[package]] name = "postgres_ffi" version = "0.1.0" @@ -1461,10 +1463,9 @@ dependencies = [ "serde", "serde_json", "sha2", - "stringprep", "thiserror", "tokio", - "tokio-postgres 0.7.1", + "tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d)", "zenith_utils", ] @@ -2110,27 +2111,6 @@ dependencies = [ "syn", ] -[[package]] -name = "tokio-postgres" -version = "0.7.1" -dependencies = [ - "async-trait", - "byteorder", - "bytes", - "fallible-iterator", - "futures", - "log", - "parking_lot", - "percent-encoding", - "phf", - "pin-project-lite", - "postgres-protocol 0.6.1", - "postgres-types 0.2.1", - "socket2", - "tokio", - "tokio-util", -] - [[package]] name = "tokio-postgres" version = "0.7.1" @@ -2153,6 +2133,28 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "tokio-postgres" +version = "0.7.1" +source = "git+https://github.com/zenithdb/rust-postgres.git?rev=f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d#f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d" +dependencies = [ + "async-trait", + "byteorder", + "bytes", + "fallible-iterator", + "futures", + "log", + "parking_lot", + "percent-encoding", + "phf", + "pin-project-lite", + "postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d)", + "postgres-types 0.2.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d)", + "socket2", + "tokio", + "tokio-util", +] + [[package]] name = "tokio-rustls" version = "0.22.0" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 8c4c0827ee..32b6942425 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -22,10 +22,9 @@ rustls = "0.19.1" serde = "1" serde_json = "1" sha2 = "0.9.8" -stringprep = "0.1.2" thiserror = "1.0.30" tokio = { version = "1.11", features = ['macros'] } -tokio-postgres = { path = "../../rust-postgres/tokio-postgres" } -# tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev = "f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d"} +# tokio-postgres = { path = "../../rust-postgres/tokio-postgres" } TODO remove this +tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev = "f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d"} zenith_utils = { path = "../zenith_utils" } diff --git a/proxy/src/auth.rs b/proxy/src/auth.rs index f5e7d8a6d1..5dafaf8c39 100644 --- a/proxy/src/auth.rs +++ b/proxy/src/auth.rs @@ -1,6 +1,7 @@ //! Authentication machinery. use crate::sasl::{SaslFirstMessage, SaslMechanism, SaslMessage, SaslStream}; +use crate::scram::key::ScramKey; use crate::scram::{ScramExchangeServer, ScramSecret}; use anyhow::{bail, Context}; use zenith_utils::postgres_backend::{PostgresBackend, ProtoState}; @@ -87,7 +88,7 @@ impl AuthStream<'_, Md5> { /// Stream wrapper for handling [SCRAM](crate::scram) auth. impl AuthStream<'_, Scram<'_>> { /// Perform user authentication; Raise an error in case authentication failed. - pub fn authenticate(mut self) -> anyhow::Result<()> { + pub fn authenticate(mut self) -> anyhow::Result { // Initial client message contains the chosen auth method's name let msg = self.read_password_message()?; let sasl = SaslFirstMessage::parse(&msg).context("bad SASL message")?; @@ -99,9 +100,9 @@ impl AuthStream<'_, Scram<'_>> { let secret = self.state.0; let stream = (Some(msg.slice_ref(sasl.message)), &mut self); - ScramExchangeServer::new(secret).authenticate(stream)?; + let client_key = ScramExchangeServer::new(secret).authenticate(stream)?; - Ok(()) + Ok(client_key) } } diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 05ab64370e..188b7d29c2 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -115,39 +115,6 @@ pub fn proxy_conn_main(state: &'static ProxyState, socket: TcpStream) -> anyhow: proxy(client.split(), server.split()) } -// HACK copied from tokio-postgres -// since postgres passwords are not required to exclude saslprep-prohibited -// characters or even be valid UTF8, we run saslprep if possible and otherwise -// return the raw password. -fn normalize(pass: &str) -> Vec { - match stringprep::saslprep(pass) { - Ok(pass) => pass.into_owned().into_bytes(), - Err(_) => pass.as_bytes().to_vec(), - } -} - -// HACK copied from tokio-postgres -fn hi(str: &[u8], salt: &[u8], i: u32) -> [u8; 32] { - let mut hmac = Hmac::::new_varkey(str).expect("HMAC is able to accept all key sizes"); - hmac.update(salt); - hmac.update(&[0, 0, 0, 1]); - let mut prev = hmac.finalize().into_bytes(); - - let mut hi = prev; - - for _ in 1..i { - let mut hmac = Hmac::::new_varkey(str).expect("already checked above"); - hmac.update(&prev); - prev = hmac.finalize().into_bytes(); - - for (hi, prev) in hi.iter_mut().zip(prev) { - *hi ^= prev; - } - } - - hi.into() -} - impl ProxyConnection { /// Returns Ok(None) when connection was successfully closed. fn handle_client(mut self) -> anyhow::Result> { @@ -254,53 +221,10 @@ impl ProxyConnection { fn handle_existing_user(&mut self, user: &str, db: &str) -> anyhow::Result { let _cplane = CPlaneApi::new(&self.state.conf.auth_endpoint, &self.state.waiters); - // TODO read from postres instance - // I got these values from the proxy error log on key cache miss - let salt = [180, 76, 114, 155, 212, 214, 236, 192, 101, 236, 235, 4, 212, 87, 25, 85]; - let iterations = 4096; + // I got this by running `select rolname, rolpassword from pg_authid;` + let secret = crate::scram::ScramSecret::parse("SCRAM-SHA-256$4096:tExym9TW7MBl7OsE1FcZVQ==$Ao3nb0bStHOVIqOEUSfXdlvF9XIynqIGzSmDCs2O4p8=:WV6Eenyz5FGwuuVfKh0AXQVnzz4NLnKVV7FpVz1/1zY=").unwrap(); - // TODO read from CLI - let password = "postgres"; - - let normalized = normalize(password); - let salted = hi(&normalized, &salt, iterations); - - let mut mac = Hmac::::new_varkey(&salted).unwrap(); - mac.update(b"Client Key"); - let client_key = mac.finalize().into_bytes(); - - let mut mac = Hmac::::new_varkey(&salted).unwrap(); - mac.update(b"Server Key"); - let server_key = mac.finalize().into_bytes(); - - let mut hash = Sha256::default(); - hash.update(client_key); - let stored_key = hash.finalize_fixed(); - - let secret = crate::scram::ScramSecret { - iterations, - salt_base64: base64::encode(salt), - stored_key: ScramKey { bytes: stored_key.try_into()? }, - server_key: ScramKey { bytes: server_key.try_into()? }, - }; - - // TODO: fetch secret from console - // user='user' password='password' - // let secret = crate::scram::ScramSecret::parse( - // &[ - // "SCRAM-SHA-256", - // "4096:XiWzgkfGNyY3ipsz08PY+A==", - // &[ - // "YMmirZHYtTB6erVDCxL4Zjn66Kn7RCfS+aV3qROV4o8=", - // "aCSKHnugk1l9Ut6VhO5VeeWsB8xhVdPk/NyEgjOJ3nk=", - // ] - // .join(":"), - // ] - // .join("$"), - // ) - // .unwrap(); - - AuthStream::new(&mut self.pgb) + let client_key = AuthStream::new(&mut self.pgb) .begin(auth::Scram(&secret))? .authenticate()?; @@ -312,9 +236,6 @@ impl ProxyConnection { let host = "127.0.0.1"; let port = 5432; - // TODO fish out the real client_key from the authenticate() call above - // let client_key: ScramKey = [0; SCRAM_KEY_LEN].into(); - let client_key = ScramKey { bytes: client_key.as_slice().try_into().unwrap() }; let scram_auth_secret = ScramAuthSecret { iterations: secret.iterations, salt_base64: secret.salt_base64, diff --git a/proxy/src/sasl.rs b/proxy/src/sasl.rs index 7726942c6e..7a994e5183 100644 --- a/proxy/src/sasl.rs +++ b/proxy/src/sasl.rs @@ -109,33 +109,36 @@ pub enum SaslError { /// A convenient result type for SASL exchange. pub type Result = std::result::Result; +pub enum SaslStep { + Transition(T), + Authenticated(R), +} + /// Every SASL mechanism (e.g. [SCRAM](crate::scram)) is expected to implement this trait. -pub trait SaslMechanism: Sized { +pub trait SaslMechanism: Sized { /// Produce a server challenge to be sent to the client. /// This is how this method is called in PostgreSQL (libpq/sasl.h). - fn exchange(self, input: &str) -> Result<(Option, String)>; + fn exchange(self, input: &str) -> Result<(SaslStep, String)>; /// Perform SASL message exchange according to the underlying algorithm /// until user is either authenticated or denied access. - fn authenticate(mut self, mut stream: impl SaslStream) -> Result<()> { + fn authenticate(mut self, mut stream: impl SaslStream) -> Result { loop { let msg = stream.recv()?; let input = std::str::from_utf8(msg.as_ref()).context("bad encoding")?; let (this, reply) = self.exchange(input)?; match this { - Some(this) => { + SaslStep::Transition(this) => { stream.send(&SaslMessage::Continue(reply))?; self = this; } - None => { + SaslStep::Authenticated(outcome) => { stream.send(&SaslMessage::Final(reply))?; - break; + return Ok(outcome); } } } - - Ok(()) } } diff --git a/proxy/src/scram.rs b/proxy/src/scram.rs index 9ff3a84285..13e04294c9 100644 --- a/proxy/src/scram.rs +++ b/proxy/src/scram.rs @@ -15,7 +15,7 @@ mod signature; pub use channel_binding::*; pub use secret::*; -use crate::sasl::{self, SaslError, SaslMechanism}; +use crate::sasl::{self, SaslError, SaslMechanism, SaslStep}; use messages::{ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage}; use signature::SignatureBuilder; @@ -61,9 +61,10 @@ impl<'a> ScramExchangeServer<'a> { } } -impl SaslMechanism for ScramExchangeServer<'_> { - fn exchange(mut self, input: &str) -> sasl::Result<(Option, String)> { +impl SaslMechanism for ScramExchangeServer<'_> { + fn exchange(mut self, input: &str) -> sasl::Result<(sasl::SaslStep, String)> { use ScramExchangeServerState::*; + use sasl::SaslStep::*; match &self.state { Initial => { let client_first_message = @@ -83,7 +84,7 @@ impl SaslMechanism for ScramExchangeServer<'_> { server_first_message, }; - Ok((Some(self), msg)) + Ok((Transition(self), msg)) } SaltSent { cbind_flag, @@ -124,7 +125,7 @@ impl SaslMechanism for ScramExchangeServer<'_> { let msg = client_final_message .build_server_final_message(signature_builder, &self.secret.server_key); - Ok((None, msg)) + Ok((Authenticated(client_key), msg)) } } }