mirror of
https://github.com/neondatabase/neon.git
synced 2026-02-05 11:40:37 +00:00
Compare commits
7 Commits
projects_m
...
proxy-scra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c743e7a5cc | ||
|
|
d44873246d | ||
|
|
4c65cc3be7 | ||
|
|
b791f47eb3 | ||
|
|
7dc933b741 | ||
|
|
7e714ce8be | ||
|
|
61a5b59224 |
76
Cargo.lock
generated
76
Cargo.lock
generated
@@ -1193,8 +1193,8 @@ dependencies = [
|
|||||||
"once_cell",
|
"once_cell",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
"postgres",
|
"postgres",
|
||||||
"postgres-protocol",
|
"postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)",
|
||||||
"postgres-types",
|
"postgres-types 0.2.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)",
|
||||||
"postgres_ffi",
|
"postgres_ffi",
|
||||||
"rand",
|
"rand",
|
||||||
"regex",
|
"regex",
|
||||||
@@ -1208,7 +1208,7 @@ dependencies = [
|
|||||||
"tempfile",
|
"tempfile",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-postgres",
|
"tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)",
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
"toml_edit",
|
"toml_edit",
|
||||||
"tracing",
|
"tracing",
|
||||||
@@ -1326,9 +1326,9 @@ dependencies = [
|
|||||||
"fallible-iterator",
|
"fallible-iterator",
|
||||||
"futures",
|
"futures",
|
||||||
"log",
|
"log",
|
||||||
"postgres-protocol",
|
"postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-postgres",
|
"tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1349,6 +1349,24 @@ dependencies = [
|
|||||||
"stringprep",
|
"stringprep",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "postgres-protocol"
|
||||||
|
version = "0.6.1"
|
||||||
|
source = "git+https://github.com/zenithdb/rust-postgres.git?rev=f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d#f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d"
|
||||||
|
dependencies = [
|
||||||
|
"base64 0.13.0",
|
||||||
|
"byteorder",
|
||||||
|
"bytes",
|
||||||
|
"fallible-iterator",
|
||||||
|
"hmac 0.10.1",
|
||||||
|
"lazy_static",
|
||||||
|
"md-5",
|
||||||
|
"memchr",
|
||||||
|
"rand",
|
||||||
|
"sha2",
|
||||||
|
"stringprep",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "postgres-types"
|
name = "postgres-types"
|
||||||
version = "0.2.1"
|
version = "0.2.1"
|
||||||
@@ -1356,7 +1374,17 @@ source = "git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"bytes",
|
"bytes",
|
||||||
"fallible-iterator",
|
"fallible-iterator",
|
||||||
"postgres-protocol",
|
"postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "postgres-types"
|
||||||
|
version = "0.2.1"
|
||||||
|
source = "git+https://github.com/zenithdb/rust-postgres.git?rev=f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d#f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d"
|
||||||
|
dependencies = [
|
||||||
|
"bytes",
|
||||||
|
"fallible-iterator",
|
||||||
|
"postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d)",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1421,9 +1449,11 @@ name = "proxy"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"base64 0.13.0",
|
||||||
"bytes",
|
"bytes",
|
||||||
"clap",
|
"clap",
|
||||||
"hex",
|
"hex",
|
||||||
|
"hmac 0.10.1",
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
"md5",
|
"md5",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
@@ -1432,8 +1462,10 @@ dependencies = [
|
|||||||
"rustls 0.19.1",
|
"rustls 0.19.1",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"sha2",
|
||||||
|
"thiserror",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-postgres",
|
"tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d)",
|
||||||
"zenith_utils",
|
"zenith_utils",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -2094,8 +2126,30 @@ dependencies = [
|
|||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
"phf",
|
"phf",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"postgres-protocol",
|
"postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)",
|
||||||
"postgres-types",
|
"postgres-types 0.2.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)",
|
||||||
|
"socket2",
|
||||||
|
"tokio",
|
||||||
|
"tokio-util",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio-postgres"
|
||||||
|
version = "0.7.1"
|
||||||
|
source = "git+https://github.com/zenithdb/rust-postgres.git?rev=f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d#f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d"
|
||||||
|
dependencies = [
|
||||||
|
"async-trait",
|
||||||
|
"byteorder",
|
||||||
|
"bytes",
|
||||||
|
"fallible-iterator",
|
||||||
|
"futures",
|
||||||
|
"log",
|
||||||
|
"parking_lot",
|
||||||
|
"percent-encoding",
|
||||||
|
"phf",
|
||||||
|
"pin-project-lite",
|
||||||
|
"postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d)",
|
||||||
|
"postgres-types 0.2.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d)",
|
||||||
"socket2",
|
"socket2",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-util",
|
"tokio-util",
|
||||||
@@ -2333,7 +2387,7 @@ dependencies = [
|
|||||||
"hyper",
|
"hyper",
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
"postgres",
|
"postgres",
|
||||||
"postgres-protocol",
|
"postgres-protocol 0.6.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)",
|
||||||
"postgres_ffi",
|
"postgres_ffi",
|
||||||
"regex",
|
"regex",
|
||||||
"routerify",
|
"routerify",
|
||||||
@@ -2343,7 +2397,7 @@ dependencies = [
|
|||||||
"signal-hook",
|
"signal-hook",
|
||||||
"tempfile",
|
"tempfile",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-postgres",
|
"tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858)",
|
||||||
"tracing",
|
"tracing",
|
||||||
"walkdir",
|
"walkdir",
|
||||||
"workspace_hack",
|
"workspace_hack",
|
||||||
|
|||||||
@@ -8,18 +8,23 @@ edition = "2018"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0"
|
anyhow = "1.0"
|
||||||
|
base64 = "0.13.0"
|
||||||
bytes = { version = "1.0.1", features = ['serde'] }
|
bytes = { version = "1.0.1", features = ['serde'] }
|
||||||
|
clap = "2.33.0"
|
||||||
|
hex = "0.4.3"
|
||||||
|
hmac = "0.10.1"
|
||||||
lazy_static = "1.4.0"
|
lazy_static = "1.4.0"
|
||||||
md5 = "0.7.0"
|
md5 = "0.7.0"
|
||||||
rand = "0.8.3"
|
|
||||||
hex = "0.4.3"
|
|
||||||
parking_lot = "0.11.2"
|
parking_lot = "0.11.2"
|
||||||
|
rand = "0.8.3"
|
||||||
|
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] }
|
||||||
|
rustls = "0.19.1"
|
||||||
serde = "1"
|
serde = "1"
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
tokio = { version = "1.11", features = ["macros"] }
|
sha2 = "0.9.8"
|
||||||
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="9eb0dbfbeb6a6c1b79099b9f7ae4a8c021877858" }
|
thiserror = "1.0.30"
|
||||||
clap = "2.33.0"
|
tokio = { version = "1.11", features = ['macros'] }
|
||||||
rustls = "0.19.1"
|
# tokio-postgres = { path = "../../rust-postgres/tokio-postgres" } TODO remove this
|
||||||
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] }
|
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev = "f1f16657aaebe2b9b4b16ef7abf6dc42301bad5d"}
|
||||||
|
|
||||||
zenith_utils = { path = "../zenith_utils" }
|
zenith_utils = { path = "../zenith_utils" }
|
||||||
|
|||||||
140
proxy/src/auth.rs
Normal file
140
proxy/src/auth.rs
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
//! Authentication machinery.
|
||||||
|
|
||||||
|
use crate::sasl::{SaslFirstMessage, SaslMechanism, SaslMessage, SaslStream};
|
||||||
|
use crate::scram::key::ScramKey;
|
||||||
|
use crate::scram::{ScramExchangeServer, ScramSecret};
|
||||||
|
use anyhow::{bail, Context};
|
||||||
|
use zenith_utils::postgres_backend::{PostgresBackend, ProtoState};
|
||||||
|
use zenith_utils::pq_proto::{
|
||||||
|
BeAuthenticationSaslMessage as BeSaslMessage, BeMessage as Be, FeMessage as Fe, *,
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO: add SCRAM-SHA-256-PLUS
|
||||||
|
/// A list of supported SCRAM methods.
|
||||||
|
const SCRAM_METHODS: &[&str] = &["SCRAM-SHA-256"];
|
||||||
|
|
||||||
|
/// Initial state of [`AuthStream`].
|
||||||
|
pub struct Begin;
|
||||||
|
|
||||||
|
/// Use [SCRAM](crate::scram)-based auth in [`AuthStream`].
|
||||||
|
pub struct Scram<'a>(pub &'a ScramSecret);
|
||||||
|
|
||||||
|
/// Use password-based auth in [`AuthStream`].
|
||||||
|
pub struct Md5(
|
||||||
|
/// Salt for client.
|
||||||
|
pub [u8; 4],
|
||||||
|
);
|
||||||
|
|
||||||
|
/// Every authentication selector is supposed to implement this trait.
|
||||||
|
pub trait AuthMethod {
|
||||||
|
/// Any authentication selector should provide initial backend message
|
||||||
|
/// containing auth method name and parameters, e.g. md5 salt.
|
||||||
|
fn first_message(&self) -> BeMessage<'_>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AuthMethod for Scram<'_> {
|
||||||
|
#[inline(always)]
|
||||||
|
fn first_message(&self) -> BeMessage<'_> {
|
||||||
|
Be::AuthenticationSasl(BeSaslMessage::Methods(SCRAM_METHODS))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AuthMethod for Md5 {
|
||||||
|
#[inline(always)]
|
||||||
|
fn first_message(&self) -> BeMessage<'_> {
|
||||||
|
Be::AuthenticationMD5Password(self.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This wrapper for [`PostgresBackend`] performs client authentication.
|
||||||
|
#[must_use]
|
||||||
|
pub struct AuthStream<'a, State> {
|
||||||
|
/// The underlying stream which implements libpq's protocol.
|
||||||
|
pgb: &'a mut PostgresBackend,
|
||||||
|
/// State might contain ancillary data (see [`AuthStream::begin`]).
|
||||||
|
state: State,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Initial state of the stream wrapper.
|
||||||
|
impl<'a> AuthStream<'a, Begin> {
|
||||||
|
/// Create a new wrapper for client authentication.
|
||||||
|
pub fn new(pgb: &'a mut PostgresBackend) -> Self {
|
||||||
|
Self { pgb, state: Begin }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Move to the next step by sending auth method's name & params to client.
|
||||||
|
pub fn begin<M: AuthMethod>(self, method: M) -> anyhow::Result<AuthStream<'a, M>> {
|
||||||
|
self.pgb.write_message(&method.first_message())?;
|
||||||
|
self.pgb.state = ProtoState::Authentication;
|
||||||
|
|
||||||
|
Ok(AuthStream {
|
||||||
|
pgb: self.pgb,
|
||||||
|
state: method,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stream wrapper for handling simple MD5 password auth.
|
||||||
|
impl AuthStream<'_, Md5> {
|
||||||
|
/// Perform user authentication; Raise an error in case authentication failed.
|
||||||
|
pub fn authenticate(mut self) -> anyhow::Result<()> {
|
||||||
|
let msg = self.read_password_message()?;
|
||||||
|
let (_trailing_null, _md5_response) = msg.split_last().context("bad password message")?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stream wrapper for handling [SCRAM](crate::scram) auth.
|
||||||
|
impl AuthStream<'_, Scram<'_>> {
|
||||||
|
/// Perform user authentication; Raise an error in case authentication failed.
|
||||||
|
pub fn authenticate(mut self) -> anyhow::Result<ScramKey> {
|
||||||
|
// Initial client message contains the chosen auth method's name
|
||||||
|
let msg = self.read_password_message()?;
|
||||||
|
let sasl = SaslFirstMessage::parse(&msg).context("bad SASL message")?;
|
||||||
|
|
||||||
|
// Currently, the only supported SASL method is SCRAM
|
||||||
|
if !SCRAM_METHODS.contains(&sasl.method) {
|
||||||
|
bail!("unsupported SASL method: {}", sasl.method);
|
||||||
|
}
|
||||||
|
|
||||||
|
let secret = self.state.0;
|
||||||
|
let stream = (Some(msg.slice_ref(sasl.message)), &mut self);
|
||||||
|
let client_key = ScramExchangeServer::new(secret).authenticate(stream)?;
|
||||||
|
|
||||||
|
Ok(client_key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Only [`AuthMethod`] states should receive password messages.
|
||||||
|
impl<M: AuthMethod> AuthStream<'_, M> {
|
||||||
|
/// Receive a new [`PasswordMessage`](FeMessage::PasswordMessage) and extract its payload.
|
||||||
|
fn read_password_message(&mut self) -> anyhow::Result<bytes::Bytes> {
|
||||||
|
match self.pgb.read_message()? {
|
||||||
|
Some(Fe::PasswordMessage(msg)) => Ok(msg),
|
||||||
|
None => bail!("connection is lost"),
|
||||||
|
bad => bail!("unexpected message type: {:?}", bad),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Abstract away all intricacies of [`PostgresBackend`],
|
||||||
|
/// since [SASL](crate::sasl) protocols are text-based.
|
||||||
|
impl SaslStream for AuthStream<'_, Scram<'_>> {
|
||||||
|
type In = bytes::Bytes;
|
||||||
|
|
||||||
|
fn recv(&mut self) -> anyhow::Result<Self::In> {
|
||||||
|
self.read_password_message()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn send(&mut self, data: &SaslMessage<impl AsRef<[u8]>>) -> anyhow::Result<()> {
|
||||||
|
let reply = match data {
|
||||||
|
SaslMessage::Continue(reply) => BeSaslMessage::Continue(reply.as_ref()),
|
||||||
|
SaslMessage::Final(reply) => BeSaslMessage::Final(reply.as_ref()),
|
||||||
|
};
|
||||||
|
|
||||||
|
self.pgb.write_message(&Be::AuthenticationSasl(reply))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -21,35 +21,6 @@ enum ProxyAuthResponse {
|
|||||||
NotReady { ready: bool }, // TODO: get rid of `ready`
|
NotReady { ready: bool }, // TODO: get rid of `ready`
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DatabaseInfo {
|
|
||||||
pub fn socket_addr(&self) -> anyhow::Result<SocketAddr> {
|
|
||||||
let host_port = format!("{}:{}", self.host, self.port);
|
|
||||||
host_port
|
|
||||||
.to_socket_addrs()
|
|
||||||
.with_context(|| format!("cannot resolve {} to SocketAddr", host_port))?
|
|
||||||
.next()
|
|
||||||
.ok_or_else(|| anyhow!("cannot resolve at least one SocketAddr"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<DatabaseInfo> for tokio_postgres::Config {
|
|
||||||
fn from(db_info: DatabaseInfo) -> Self {
|
|
||||||
let mut config = tokio_postgres::Config::new();
|
|
||||||
|
|
||||||
config
|
|
||||||
.host(&db_info.host)
|
|
||||||
.port(db_info.port)
|
|
||||||
.dbname(&db_info.dbname)
|
|
||||||
.user(&db_info.user);
|
|
||||||
|
|
||||||
if let Some(password) = db_info.password {
|
|
||||||
config.password(password);
|
|
||||||
}
|
|
||||||
|
|
||||||
config
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct CPlaneApi<'a> {
|
pub struct CPlaneApi<'a> {
|
||||||
auth_endpoint: &'a str,
|
auth_endpoint: &'a str,
|
||||||
waiters: &'a ProxyWaiters,
|
waiters: &'a ProxyWaiters,
|
||||||
|
|||||||
69
proxy/src/db.rs
Normal file
69
proxy/src/db.rs
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
///
|
||||||
|
/// Utils for connecting with the postgres dataabase.
|
||||||
|
///
|
||||||
|
|
||||||
|
use crate::scram::key::ScramKey;
|
||||||
|
use std::net::{SocketAddr, ToSocketAddrs};
|
||||||
|
use anyhow::{Context, anyhow};
|
||||||
|
|
||||||
|
/// Sufficient information to authenticate as client.
|
||||||
|
pub struct ScramAuthSecret {
|
||||||
|
pub iterations: u32,
|
||||||
|
pub salt_base64: String,
|
||||||
|
pub client_key: ScramKey,
|
||||||
|
pub server_key: ScramKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[non_exhaustive]
|
||||||
|
pub enum AuthSecret {
|
||||||
|
Scram(ScramAuthSecret),
|
||||||
|
Password(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
pub struct DatabaseAuthInfo {
|
||||||
|
pub host: String,
|
||||||
|
pub port: u16,
|
||||||
|
pub dbname: String,
|
||||||
|
pub user: String,
|
||||||
|
pub auth_secret: AuthSecret,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<DatabaseAuthInfo> for tokio_postgres::Config {
|
||||||
|
fn from(auth_info: DatabaseAuthInfo) -> Self {
|
||||||
|
let mut config = tokio_postgres::Config::new();
|
||||||
|
|
||||||
|
config
|
||||||
|
.host(&auth_info.host)
|
||||||
|
.port(auth_info.port)
|
||||||
|
.dbname(&auth_info.dbname)
|
||||||
|
.user(&auth_info.user);
|
||||||
|
|
||||||
|
match auth_info.auth_secret {
|
||||||
|
AuthSecret::Scram(scram_secret) => {
|
||||||
|
config.add_scram_key(
|
||||||
|
base64::decode(scram_secret.salt_base64).unwrap(),
|
||||||
|
scram_secret.iterations,
|
||||||
|
scram_secret.client_key.bytes.to_vec(),
|
||||||
|
scram_secret.server_key.bytes.to_vec(),
|
||||||
|
);
|
||||||
|
},
|
||||||
|
AuthSecret::Password(password) => {
|
||||||
|
config.password(password);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DatabaseAuthInfo {
|
||||||
|
pub fn socket_addr(&self) -> anyhow::Result<SocketAddr> {
|
||||||
|
let host_port = format!("{}:{}", self.host, self.port);
|
||||||
|
host_port
|
||||||
|
.to_socket_addrs()
|
||||||
|
.with_context(|| format!("cannot resolve {} to SocketAddr", host_port))?
|
||||||
|
.next()
|
||||||
|
.ok_or_else(|| anyhow!("cannot resolve at least one SocketAddr"))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -11,9 +11,14 @@ use state::{ProxyConfig, ProxyState};
|
|||||||
use std::thread;
|
use std::thread;
|
||||||
use zenith_utils::{tcp_listener, GIT_VERSION};
|
use zenith_utils::{tcp_listener, GIT_VERSION};
|
||||||
|
|
||||||
|
mod db;
|
||||||
|
mod auth;
|
||||||
mod cplane_api;
|
mod cplane_api;
|
||||||
mod mgmt;
|
mod mgmt;
|
||||||
|
mod parse;
|
||||||
mod proxy;
|
mod proxy;
|
||||||
|
mod sasl;
|
||||||
|
mod scram;
|
||||||
mod state;
|
mod state;
|
||||||
mod waiters;
|
mod waiters;
|
||||||
|
|
||||||
|
|||||||
18
proxy/src/parse.rs
Normal file
18
proxy/src/parse.rs
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
//! Small parsing helpers.
|
||||||
|
|
||||||
|
use std::convert::TryInto;
|
||||||
|
use std::ffi::CStr;
|
||||||
|
|
||||||
|
pub fn split_cstr(bytes: &[u8]) -> Option<(&CStr, &[u8])> {
|
||||||
|
let pos = bytes.iter().position(|&x| x == 0)?;
|
||||||
|
let (cstr, other) = bytes.split_at(pos + 1);
|
||||||
|
// SAFETY: we've already checked that there's a terminator
|
||||||
|
Some((unsafe { CStr::from_bytes_with_nul_unchecked(cstr) }, other))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn split_at_const<const N: usize>(bytes: &[u8]) -> Option<(&[u8; N], &[u8])> {
|
||||||
|
(bytes.len() >= N).then(|| {
|
||||||
|
let (head, tail) = bytes.split_at(N);
|
||||||
|
(head.try_into().unwrap(), tail)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
use crate::cplane_api::{CPlaneApi, DatabaseInfo};
|
use crate::auth::{self, AuthStream};
|
||||||
|
use crate::cplane_api::CPlaneApi;
|
||||||
use crate::ProxyState;
|
use crate::ProxyState;
|
||||||
|
use crate::db::{AuthSecret, DatabaseAuthInfo, ScramAuthSecret};
|
||||||
use anyhow::{anyhow, bail};
|
use anyhow::{anyhow, bail};
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
@@ -10,7 +12,7 @@ use std::collections::HashMap;
|
|||||||
use std::net::{SocketAddr, TcpStream};
|
use std::net::{SocketAddr, TcpStream};
|
||||||
use std::{io, thread};
|
use std::{io, thread};
|
||||||
use tokio_postgres::NoTls;
|
use tokio_postgres::NoTls;
|
||||||
use zenith_utils::postgres_backend::{self, PostgresBackend, ProtoState, Stream};
|
use zenith_utils::postgres_backend::{self, PostgresBackend, Stream};
|
||||||
use zenith_utils::pq_proto::{BeMessage as Be, FeMessage as Fe, *};
|
use zenith_utils::pq_proto::{BeMessage as Be, FeMessage as Fe, *};
|
||||||
use zenith_utils::sock_split::{ReadStream, WriteStream};
|
use zenith_utils::sock_split::{ReadStream, WriteStream};
|
||||||
|
|
||||||
@@ -117,8 +119,9 @@ impl ProxyConnection {
|
|||||||
None => return Ok(None),
|
None => return Ok(None),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// HACK for local testing only
|
||||||
// Both scenarios here should end up producing database credentials
|
// Both scenarios here should end up producing database credentials
|
||||||
if username.ends_with("@zenith") {
|
if true || username.ends_with("@zenith") {
|
||||||
self.handle_existing_user(&username, &dbname).map(Some)
|
self.handle_existing_user(&username, &dbname).map(Some)
|
||||||
} else {
|
} else {
|
||||||
self.handle_new_user().map(Some)
|
self.handle_new_user().map(Some)
|
||||||
@@ -126,7 +129,7 @@ impl ProxyConnection {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let conn = match authenticate() {
|
let conn = match authenticate() {
|
||||||
Ok(Some(db_info)) => connect_to_db(db_info),
|
Ok(Some(db_auth_info)) => connect_to_db(db_auth_info),
|
||||||
Ok(None) => return Ok(None),
|
Ok(None) => return Ok(None),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// Report the error to the client
|
// Report the error to the client
|
||||||
@@ -211,43 +214,43 @@ impl ProxyConnection {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_existing_user(&mut self, user: &str, db: &str) -> anyhow::Result<DatabaseInfo> {
|
fn handle_existing_user(&mut self, user: &str, db: &str) -> anyhow::Result<DatabaseAuthInfo> {
|
||||||
let md5_salt = rand::random::<[u8; 4]>();
|
let _cplane = CPlaneApi::new(&self.state.conf.auth_endpoint, &self.state.waiters);
|
||||||
|
|
||||||
// Ask password
|
// TODO read from console
|
||||||
self.pgb
|
// I got this by running `select rolname, rolpassword from pg_authid;`
|
||||||
.write_message(&Be::AuthenticationMD5Password(&md5_salt))?;
|
let secret = crate::scram::ScramSecret::parse("SCRAM-SHA-256$4096:tExym9TW7MBl7OsE1FcZVQ==$Ao3nb0bStHOVIqOEUSfXdlvF9XIynqIGzSmDCs2O4p8=:WV6Eenyz5FGwuuVfKh0AXQVnzz4NLnKVV7FpVz1/1zY=").unwrap();
|
||||||
self.pgb.state = ProtoState::Authentication; // XXX
|
|
||||||
|
|
||||||
// Check password
|
let client_key = AuthStream::new(&mut self.pgb)
|
||||||
let msg = match self.pgb.read_message()? {
|
.begin(auth::Scram(&secret))?
|
||||||
Some(Fe::PasswordMessage(msg)) => msg,
|
.authenticate()?;
|
||||||
None => bail!("connection is lost"),
|
|
||||||
bad => bail!("unexpected message type: {:?}", bad),
|
|
||||||
};
|
|
||||||
println!("got message: {:?}", msg);
|
|
||||||
|
|
||||||
let (_trailing_null, md5_response) = msg
|
|
||||||
.split_last()
|
|
||||||
.ok_or_else(|| anyhow!("unexpected password message"))?;
|
|
||||||
|
|
||||||
let cplane = CPlaneApi::new(&self.state.conf.auth_endpoint, &self.state.waiters);
|
|
||||||
let db_info = cplane.authenticate_proxy_request(
|
|
||||||
user,
|
|
||||||
db,
|
|
||||||
md5_response,
|
|
||||||
&md5_salt,
|
|
||||||
&self.psql_session_id,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
self.pgb
|
self.pgb
|
||||||
.write_message_noflush(&Be::AuthenticationOk)?
|
.write_message_noflush(&Be::AuthenticationOk)?
|
||||||
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
|
||||||
|
|
||||||
Ok(db_info)
|
// TODO get this info from console and tell it to start the db
|
||||||
|
let host = "127.0.0.1";
|
||||||
|
let port = 5432;
|
||||||
|
|
||||||
|
let scram_auth_secret = ScramAuthSecret {
|
||||||
|
iterations: secret.iterations,
|
||||||
|
salt_base64: secret.salt_base64,
|
||||||
|
client_key,
|
||||||
|
server_key: secret.server_key,
|
||||||
|
};
|
||||||
|
let auth_info = DatabaseAuthInfo {
|
||||||
|
host: host.into(),
|
||||||
|
port,
|
||||||
|
dbname: db.into(),
|
||||||
|
user: user.into(),
|
||||||
|
auth_secret: AuthSecret::Scram(scram_auth_secret)
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(auth_info)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_new_user(&mut self) -> anyhow::Result<DatabaseInfo> {
|
fn handle_new_user(&mut self) -> anyhow::Result<DatabaseAuthInfo> {
|
||||||
let greeting = hello_message(&self.state.conf.redirect_uri, &self.psql_session_id);
|
let greeting = hello_message(&self.state.conf.redirect_uri, &self.psql_session_id);
|
||||||
|
|
||||||
// First, register this session
|
// First, register this session
|
||||||
@@ -265,7 +268,15 @@ impl ProxyConnection {
|
|||||||
self.pgb
|
self.pgb
|
||||||
.write_message_noflush(&Be::NoticeResponse("Connecting to database.".into()))?;
|
.write_message_noflush(&Be::NoticeResponse("Connecting to database.".into()))?;
|
||||||
|
|
||||||
Ok(db_info)
|
let db_auth_info = DatabaseAuthInfo {
|
||||||
|
host: db_info.host,
|
||||||
|
port: db_info.port,
|
||||||
|
dbname: db_info.dbname,
|
||||||
|
user: db_info.user,
|
||||||
|
auth_secret: AuthSecret::Password(db_info.password.unwrap())
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(db_auth_info)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -286,7 +297,7 @@ fn hello_message(redirect_uri: &str, session_id: &str) -> String {
|
|||||||
|
|
||||||
/// Create a TCP connection to a postgres database, authenticate with it, and receive the ReadyForQuery message
|
/// Create a TCP connection to a postgres database, authenticate with it, and receive the ReadyForQuery message
|
||||||
async fn connect_to_db(
|
async fn connect_to_db(
|
||||||
db_info: DatabaseInfo,
|
db_info: DatabaseAuthInfo,
|
||||||
) -> anyhow::Result<(String, tokio::net::TcpStream, CancelKeyData)> {
|
) -> anyhow::Result<(String, tokio::net::TcpStream, CancelKeyData)> {
|
||||||
// Make raw connection. When connect_raw finishes we've received ReadyForQuery.
|
// Make raw connection. When connect_raw finishes we've received ReadyForQuery.
|
||||||
let socket_addr = db_info.socket_addr()?;
|
let socket_addr = db_info.socket_addr()?;
|
||||||
|
|||||||
161
proxy/src/sasl.rs
Normal file
161
proxy/src/sasl.rs
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
//! Simple Authentication and Security Layer.
|
||||||
|
//!
|
||||||
|
//! RFC: <https://datatracker.ietf.org/doc/html/rfc4422>.
|
||||||
|
//!
|
||||||
|
//! Reference implementation:
|
||||||
|
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-sasl.c>
|
||||||
|
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth.c>
|
||||||
|
|
||||||
|
use crate::parse::{split_at_const, split_cstr};
|
||||||
|
use anyhow::Context;
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
/// SASL-specific payload of [`PasswordMessage`](zenith_utils::pq_proto::FeMessage::PasswordMessage).
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct SaslFirstMessage<'a> {
|
||||||
|
/// Authentication method, e.g. `"SCRAM-SHA-256"`.
|
||||||
|
pub method: &'a str,
|
||||||
|
/// Initial client message.
|
||||||
|
pub message: &'a [u8],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> SaslFirstMessage<'a> {
|
||||||
|
// NB: FromStr doesn't work with lifetimes
|
||||||
|
pub fn parse(bytes: &'a [u8]) -> Option<Self> {
|
||||||
|
let (method_cstr, tail) = split_cstr(bytes)?;
|
||||||
|
let method = method_cstr.to_str().ok()?;
|
||||||
|
|
||||||
|
let (len_bytes, message) = split_at_const(tail)?;
|
||||||
|
let len = u32::from_be_bytes(*len_bytes) as usize;
|
||||||
|
if len != message.len() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(Self { method, message })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A single SASL message.
|
||||||
|
/// This struct is deliberately decoupled from lower-level
|
||||||
|
/// [`BeAuthenticationSaslMessage`](zenith_utils::pq_proto::BeAuthenticationSaslMessage).
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum SaslMessage<T> {
|
||||||
|
/// We expect to see more steps.
|
||||||
|
Continue(T),
|
||||||
|
/// This is the final step.
|
||||||
|
Final(T),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This specialized trait provides capabilities akin to
|
||||||
|
/// [`std::io::Read`]+[`std::io::Write`] in oder to
|
||||||
|
/// abstract away underlying stream implementations.
|
||||||
|
pub trait SaslStream {
|
||||||
|
/// We'd like to use `AsRef<[str]>` here, but afaik there's
|
||||||
|
/// no cheap way to make [`String`] out of [`bytes::Bytes`];
|
||||||
|
/// On the other hand, byte slices are a decent middle ground.
|
||||||
|
type In: AsRef<[u8]>;
|
||||||
|
|
||||||
|
/// Receive a [SASL](crate::sasl) message from a client.
|
||||||
|
fn recv(&mut self) -> anyhow::Result<Self::In>;
|
||||||
|
|
||||||
|
/// Send a [SASL](crate::sasl) message to a client.
|
||||||
|
fn send(&mut self, data: &SaslMessage<impl AsRef<[u8]>>) -> anyhow::Result<()>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: SaslStream> SaslStream for &mut S {
|
||||||
|
type In = S::In;
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn recv(&mut self) -> anyhow::Result<Self::In> {
|
||||||
|
S::recv(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn send(&mut self, data: &SaslMessage<impl AsRef<[u8]>>) -> anyhow::Result<()> {
|
||||||
|
S::send(self, data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sometimes it's necessary to mix in a message we got from somewhere else.
|
||||||
|
impl<'a, V: AsRef<[u8]>, S: SaslStream<In = V>> SaslStream for (Option<V>, S) {
|
||||||
|
type In = S::In;
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn recv(&mut self) -> anyhow::Result<Self::In> {
|
||||||
|
// Try returning a stashed message first
|
||||||
|
match self.0.take() {
|
||||||
|
Some(value) => Ok(value),
|
||||||
|
None => self.1.recv(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn send(&mut self, data: &SaslMessage<impl AsRef<[u8]>>) -> anyhow::Result<()> {
|
||||||
|
self.1.send(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fine-grained auth errors help in writing tests.
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum SaslError {
|
||||||
|
#[error("failed to authenticate client: {0}")]
|
||||||
|
AuthenticationFailed(&'static str),
|
||||||
|
#[error("bad client message")]
|
||||||
|
BadClientMessage,
|
||||||
|
#[error(transparent)]
|
||||||
|
Other(#[from] anyhow::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A convenient result type for SASL exchange.
|
||||||
|
pub type Result<T> = std::result::Result<T, SaslError>;
|
||||||
|
|
||||||
|
pub enum SaslStep<T, R> {
|
||||||
|
Transition(T),
|
||||||
|
Authenticated(R),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Every SASL mechanism (e.g. [SCRAM](crate::scram)) is expected to implement this trait.
|
||||||
|
pub trait SaslMechanism<T>: Sized {
|
||||||
|
/// Produce a server challenge to be sent to the client.
|
||||||
|
/// This is how this method is called in PostgreSQL (libpq/sasl.h).
|
||||||
|
fn exchange(self, input: &str) -> Result<(SaslStep<Self, T>, String)>;
|
||||||
|
|
||||||
|
/// Perform SASL message exchange according to the underlying algorithm
|
||||||
|
/// until user is either authenticated or denied access.
|
||||||
|
fn authenticate(mut self, mut stream: impl SaslStream) -> Result<T> {
|
||||||
|
loop {
|
||||||
|
let msg = stream.recv()?;
|
||||||
|
let input = std::str::from_utf8(msg.as_ref()).context("bad encoding")?;
|
||||||
|
|
||||||
|
let (this, reply) = self.exchange(input)?;
|
||||||
|
match this {
|
||||||
|
SaslStep::Transition(this) => {
|
||||||
|
stream.send(&SaslMessage::Continue(reply))?;
|
||||||
|
self = this;
|
||||||
|
}
|
||||||
|
SaslStep::Authenticated(outcome) => {
|
||||||
|
stream.send(&SaslMessage::Final(reply))?;
|
||||||
|
return Ok(outcome);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::ffi::CStr;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_sasl_first_message() {
|
||||||
|
let proto = CStr::from_bytes_with_nul(b"SCRAM-SHA-256\0").unwrap();
|
||||||
|
let sasl = "n,,n=,r=KHQ2Gjc7NptyB8aov5/TnUy4".as_bytes();
|
||||||
|
let sasl_len = (sasl.len() as u32).to_be_bytes();
|
||||||
|
let bytes = [proto.to_bytes_with_nul(), sasl_len.as_ref(), sasl].concat();
|
||||||
|
|
||||||
|
let password = SaslFirstMessage::parse(&bytes).unwrap();
|
||||||
|
assert_eq!(password.method, proto.to_str().unwrap());
|
||||||
|
assert_eq!(password.message, sasl);
|
||||||
|
}
|
||||||
|
}
|
||||||
132
proxy/src/scram.rs
Normal file
132
proxy/src/scram.rs
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
//! Salted Challenge Response Authentication Mechanism.
|
||||||
|
//!
|
||||||
|
//! RFC: <https://datatracker.ietf.org/doc/html/rfc5802>.
|
||||||
|
//!
|
||||||
|
//! Reference implementation:
|
||||||
|
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
|
||||||
|
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
|
||||||
|
|
||||||
|
mod channel_binding;
|
||||||
|
pub mod key; // TODO do I have to make it pub?
|
||||||
|
mod messages;
|
||||||
|
mod secret;
|
||||||
|
mod signature;
|
||||||
|
|
||||||
|
pub use channel_binding::*;
|
||||||
|
pub use secret::*;
|
||||||
|
|
||||||
|
use crate::sasl::{self, SaslError, SaslMechanism, SaslStep};
|
||||||
|
use messages::{ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage};
|
||||||
|
use signature::SignatureBuilder;
|
||||||
|
|
||||||
|
pub use self::secret::ScramSecret;
|
||||||
|
|
||||||
|
/// Decode base64 into array without any heap allocations
|
||||||
|
fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N]> {
|
||||||
|
let mut bytes = [0u8; N];
|
||||||
|
|
||||||
|
let size = base64::decode_config_slice(input, base64::STANDARD, &mut bytes).ok()?;
|
||||||
|
if size != N {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
enum ScramExchangeServerState {
|
||||||
|
/// Waiting for [`ClientFirstMessage`].
|
||||||
|
Initial,
|
||||||
|
/// Waiting for [`ClientFinalMessage`].
|
||||||
|
SaltSent {
|
||||||
|
cbind_flag: ChannelBinding<String>,
|
||||||
|
client_first_message_bare: String,
|
||||||
|
server_first_message: OwnedServerFirstMessage,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Server's side of SCRAM auth algorithm.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ScramExchangeServer<'a> {
|
||||||
|
state: ScramExchangeServerState,
|
||||||
|
secret: &'a ScramSecret,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> ScramExchangeServer<'a> {
|
||||||
|
pub fn new(secret: &'a ScramSecret) -> Self {
|
||||||
|
Self {
|
||||||
|
state: ScramExchangeServerState::Initial,
|
||||||
|
secret,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SaslMechanism<key::ScramKey> for ScramExchangeServer<'_> {
|
||||||
|
fn exchange(mut self, input: &str) -> sasl::Result<(sasl::SaslStep<Self, key::ScramKey>, String)> {
|
||||||
|
use ScramExchangeServerState::*;
|
||||||
|
use sasl::SaslStep::*;
|
||||||
|
match &self.state {
|
||||||
|
Initial => {
|
||||||
|
let client_first_message =
|
||||||
|
ClientFirstMessage::parse(input).ok_or(SaslError::BadClientMessage)?;
|
||||||
|
|
||||||
|
let server_first_message = client_first_message.build_server_first_message(
|
||||||
|
// TODO: use secure random
|
||||||
|
&rand::random(),
|
||||||
|
&self.secret.salt_base64,
|
||||||
|
self.secret.iterations,
|
||||||
|
);
|
||||||
|
let msg = server_first_message.as_str().to_owned();
|
||||||
|
|
||||||
|
self.state = SaltSent {
|
||||||
|
cbind_flag: client_first_message.cbind_flag.map(str::to_owned),
|
||||||
|
client_first_message_bare: client_first_message.bare.to_owned(),
|
||||||
|
server_first_message,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok((Transition(self), msg))
|
||||||
|
}
|
||||||
|
SaltSent {
|
||||||
|
cbind_flag,
|
||||||
|
client_first_message_bare,
|
||||||
|
server_first_message,
|
||||||
|
} => {
|
||||||
|
let client_final_message =
|
||||||
|
ClientFinalMessage::parse(input).ok_or(SaslError::BadClientMessage)?;
|
||||||
|
|
||||||
|
let channel_binding = cbind_flag.encode(|_| {
|
||||||
|
// TODO: make global design decision regarding the certificate
|
||||||
|
todo!("fetch TLS certificate data")
|
||||||
|
});
|
||||||
|
|
||||||
|
// This might've been caused by a MITM attack
|
||||||
|
if client_final_message.channel_binding != channel_binding {
|
||||||
|
return Err(SaslError::AuthenticationFailed("channel binding failed"));
|
||||||
|
}
|
||||||
|
|
||||||
|
if client_final_message.nonce != server_first_message.nonce() {
|
||||||
|
return Err(SaslError::AuthenticationFailed("bad nonce"));
|
||||||
|
}
|
||||||
|
|
||||||
|
let signature_builder = SignatureBuilder {
|
||||||
|
client_first_message_bare,
|
||||||
|
server_first_message: server_first_message.as_str(),
|
||||||
|
client_final_message_without_proof: client_final_message.without_proof,
|
||||||
|
};
|
||||||
|
|
||||||
|
let client_key = signature_builder
|
||||||
|
.build(&self.secret.stored_key)
|
||||||
|
.derive_client_key(&client_final_message.proof);
|
||||||
|
|
||||||
|
if client_key.sha256() != self.secret.stored_key {
|
||||||
|
return Err(SaslError::AuthenticationFailed("keys don't match"));
|
||||||
|
}
|
||||||
|
|
||||||
|
let msg = client_final_message
|
||||||
|
.build_server_final_message(signature_builder, &self.secret.server_key);
|
||||||
|
|
||||||
|
Ok((Authenticated(client_key), msg))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
77
proxy/src/scram/channel_binding.rs
Normal file
77
proxy/src/scram/channel_binding.rs
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
//! Definition and parser for channel binding flag (a part of GS2 header).
|
||||||
|
|
||||||
|
/// Channel binding flag (possibly with params).
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
pub enum ChannelBinding<T> {
|
||||||
|
/// Client doesn't support channel binding.
|
||||||
|
NotSupportedClient,
|
||||||
|
/// Client thinks server doesn't support channel binding.
|
||||||
|
NotSupportedServer,
|
||||||
|
/// Client wants to use this type of channel binding.
|
||||||
|
Required(T),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> ChannelBinding<T> {
|
||||||
|
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> ChannelBinding<R> {
|
||||||
|
use ChannelBinding::*;
|
||||||
|
match self {
|
||||||
|
NotSupportedClient => NotSupportedClient,
|
||||||
|
NotSupportedServer => NotSupportedServer,
|
||||||
|
Required(x) => Required(f(x)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> ChannelBinding<&'a str> {
|
||||||
|
// NB: FromStr doesn't work with lifetimes
|
||||||
|
pub fn parse(input: &'a str) -> Option<Self> {
|
||||||
|
use ChannelBinding::*;
|
||||||
|
Some(match input {
|
||||||
|
"n" => NotSupportedClient,
|
||||||
|
"y" => NotSupportedServer,
|
||||||
|
other => Required(other.strip_prefix("p=")?),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: AsRef<str>> ChannelBinding<T> {
|
||||||
|
/// Encode channel binding data as base64 for subsequent checks.
|
||||||
|
pub fn encode(&self, get_cbind_data: impl FnOnce(&str) -> String) -> String {
|
||||||
|
use ChannelBinding::*;
|
||||||
|
match self {
|
||||||
|
NotSupportedClient => {
|
||||||
|
// base64::encode("n,,")
|
||||||
|
"biws".into()
|
||||||
|
}
|
||||||
|
NotSupportedServer => {
|
||||||
|
// base64::encode("y,,")
|
||||||
|
"eSws".into()
|
||||||
|
}
|
||||||
|
Required(s) => {
|
||||||
|
let s = s.as_ref();
|
||||||
|
let msg = format!("p={mode},,{data}", mode = s, data = get_cbind_data(s));
|
||||||
|
base64::encode(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn channel_binding_encode() {
|
||||||
|
use ChannelBinding::*;
|
||||||
|
|
||||||
|
let cases = [
|
||||||
|
(NotSupportedClient, base64::encode("n,,")),
|
||||||
|
(NotSupportedServer, base64::encode("y,,")),
|
||||||
|
(Required("foo"), base64::encode("p=foo,,bar")),
|
||||||
|
];
|
||||||
|
|
||||||
|
for (cb, input) in cases {
|
||||||
|
assert_eq!(cb.encode(|_| "bar".to_owned()), input);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
40
proxy/src/scram/key.rs
Normal file
40
proxy/src/scram/key.rs
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
//! Tools for client/server/stored keys management.
|
||||||
|
|
||||||
|
use sha2::{Digest, Sha256};
|
||||||
|
|
||||||
|
/// Faithfully taken from PostgreSQL.
|
||||||
|
pub const SCRAM_KEY_LEN: usize = 32;
|
||||||
|
|
||||||
|
/// Thin wrapper for byte array.
|
||||||
|
#[derive(Debug, PartialEq, Eq)] // TODO maybe no debug? Avoid accidental logging.
|
||||||
|
#[repr(transparent)]
|
||||||
|
pub struct ScramKey {
|
||||||
|
pub bytes: [u8; SCRAM_KEY_LEN], // TODO does it have to be public?
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ScramKey {
|
||||||
|
pub fn sha256(&self) -> ScramKey {
|
||||||
|
let mut bytes = [0u8; SCRAM_KEY_LEN];
|
||||||
|
bytes.copy_from_slice({
|
||||||
|
let mut hash = Sha256::new();
|
||||||
|
hash.update(&self.bytes);
|
||||||
|
hash.finalize().as_slice()
|
||||||
|
});
|
||||||
|
|
||||||
|
bytes.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<[u8; SCRAM_KEY_LEN]> for ScramKey {
|
||||||
|
#[inline(always)]
|
||||||
|
fn from(bytes: [u8; SCRAM_KEY_LEN]) -> Self {
|
||||||
|
Self { bytes }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsRef<[u8]> for ScramKey {
|
||||||
|
#[inline(always)]
|
||||||
|
fn as_ref(&self) -> &[u8] {
|
||||||
|
&self.bytes
|
||||||
|
}
|
||||||
|
}
|
||||||
228
proxy/src/scram/messages.rs
Normal file
228
proxy/src/scram/messages.rs
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
//! Definitions for SCRAM messages.
|
||||||
|
|
||||||
|
use super::base64_decode_array;
|
||||||
|
use super::channel_binding::ChannelBinding;
|
||||||
|
use super::key::{ScramKey, SCRAM_KEY_LEN};
|
||||||
|
use super::signature::SignatureBuilder;
|
||||||
|
use std::fmt;
|
||||||
|
use std::ops::Range;
|
||||||
|
|
||||||
|
/// Faithfully taken from PostgreSQL.
|
||||||
|
const SCRAM_RAW_NONCE_LEN: usize = 18;
|
||||||
|
|
||||||
|
/// Although we ignore all extensions, we still have to validate the message.
|
||||||
|
fn validate_sasl_extensions<'a>(parts: impl Iterator<Item = &'a str>) -> Option<()> {
|
||||||
|
for mut chars in parts.map(|s| s.chars()) {
|
||||||
|
let attr = chars.next()?;
|
||||||
|
if !('a'..'z').contains(&attr) && !('A'..'Z').contains(&attr) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let eq = chars.next()?;
|
||||||
|
if eq != '=' {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ClientFirstMessage<'a> {
|
||||||
|
/// `client-first-message-bare`.
|
||||||
|
pub bare: &'a str,
|
||||||
|
/// Channel binding mode.
|
||||||
|
pub cbind_flag: ChannelBinding<&'a str>,
|
||||||
|
/// (Client username)[<https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf/src/backend/libpq/auth-scram.c#L13>].
|
||||||
|
pub username: &'a str,
|
||||||
|
/// Client nonce.
|
||||||
|
pub nonce: &'a str,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> ClientFirstMessage<'a> {
|
||||||
|
// NB: FromStr doesn't work with lifetimes
|
||||||
|
pub fn parse(input: &'a str) -> Option<Self> {
|
||||||
|
let mut parts = input.split(',');
|
||||||
|
|
||||||
|
let cbind_flag = ChannelBinding::parse(parts.next()?)?;
|
||||||
|
|
||||||
|
// PG doesn't support authorization identity,
|
||||||
|
// so we don't bother defining GS2 header type
|
||||||
|
let authzid = parts.next()?;
|
||||||
|
if !authzid.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unfortunately, `parts.as_str()` is unstable
|
||||||
|
let pos = authzid.as_ptr() as usize - input.as_ptr() as usize + 1;
|
||||||
|
let (_, bare) = input.split_at(pos);
|
||||||
|
|
||||||
|
// In theory, these might be preceded by "reserved-mext" (i.e. "m=")
|
||||||
|
let username = parts.next()?.strip_prefix("n=")?;
|
||||||
|
let nonce = parts.next()?.strip_prefix("r=")?;
|
||||||
|
|
||||||
|
// Validate but ignore auth extensions
|
||||||
|
validate_sasl_extensions(parts)?;
|
||||||
|
|
||||||
|
Some(Self {
|
||||||
|
bare,
|
||||||
|
cbind_flag,
|
||||||
|
username,
|
||||||
|
nonce,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build_server_first_message(
|
||||||
|
&self,
|
||||||
|
nonce: &[u8; SCRAM_RAW_NONCE_LEN],
|
||||||
|
salt_base64: &str,
|
||||||
|
iterations: u32,
|
||||||
|
) -> OwnedServerFirstMessage {
|
||||||
|
use std::fmt::Write;
|
||||||
|
|
||||||
|
let mut message = String::new();
|
||||||
|
write!(&mut message, "r={}", self.nonce).unwrap();
|
||||||
|
base64::encode_config_buf(nonce, base64::STANDARD, &mut message);
|
||||||
|
let combined_nonce = 2..message.len();
|
||||||
|
write!(&mut message, ",s={},i={}", salt_base64, iterations).unwrap();
|
||||||
|
|
||||||
|
// This design guarantees that it's impossible to create a
|
||||||
|
// server-first-message without receiving a client-first-message
|
||||||
|
OwnedServerFirstMessage {
|
||||||
|
message,
|
||||||
|
nonce: combined_nonce,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ClientFinalMessage<'a> {
|
||||||
|
/// `client-final-message-without-proof`.
|
||||||
|
pub without_proof: &'a str,
|
||||||
|
/// Channel binding data (base64).
|
||||||
|
pub channel_binding: &'a str,
|
||||||
|
/// Combined client & server nonce.
|
||||||
|
pub nonce: &'a str,
|
||||||
|
/// Client auth proof.
|
||||||
|
pub proof: [u8; SCRAM_KEY_LEN],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> ClientFinalMessage<'a> {
|
||||||
|
// NB: FromStr doesn't work with lifetimes
|
||||||
|
pub fn parse(input: &'a str) -> Option<Self> {
|
||||||
|
let (without_proof, proof) = input.rsplit_once(',')?;
|
||||||
|
|
||||||
|
let mut parts = without_proof.split(',');
|
||||||
|
let channel_binding = parts.next()?.strip_prefix("c=")?;
|
||||||
|
let nonce = parts.next()?.strip_prefix("r=")?;
|
||||||
|
|
||||||
|
// Validate but ignore auth extensions
|
||||||
|
validate_sasl_extensions(parts)?;
|
||||||
|
|
||||||
|
let proof = base64_decode_array(proof.strip_prefix("p=")?)?;
|
||||||
|
|
||||||
|
Some(Self {
|
||||||
|
without_proof,
|
||||||
|
channel_binding,
|
||||||
|
nonce,
|
||||||
|
proof,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build_server_final_message(
|
||||||
|
&self,
|
||||||
|
signature_builder: SignatureBuilder,
|
||||||
|
server_key: &ScramKey,
|
||||||
|
) -> String {
|
||||||
|
let mut buf = String::from("v=");
|
||||||
|
base64::encode_config_buf(
|
||||||
|
signature_builder.build(server_key),
|
||||||
|
base64::STANDARD,
|
||||||
|
&mut buf,
|
||||||
|
);
|
||||||
|
|
||||||
|
buf
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct OwnedServerFirstMessage {
|
||||||
|
/// Owned `server-first-message`.
|
||||||
|
message: String,
|
||||||
|
/// Slice into `message`.
|
||||||
|
nonce: Range<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OwnedServerFirstMessage {
|
||||||
|
/// Extract combined nonce from the message.
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn nonce(&self) -> &str {
|
||||||
|
&self.message[self.nonce.clone()]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get reference to a text representation of the message.
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn as_str(&self) -> &str {
|
||||||
|
&self.message
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Debug for OwnedServerFirstMessage {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
f.debug_struct("ServerFirstMessage")
|
||||||
|
.field("message", &self.as_str())
|
||||||
|
.field("nonce", &self.nonce())
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_client_first_message() {
|
||||||
|
use ChannelBinding::*;
|
||||||
|
|
||||||
|
// (Almost) real strings captured during debug sessions
|
||||||
|
let cases = [
|
||||||
|
(NotSupportedClient, "n,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"),
|
||||||
|
(NotSupportedServer, "y,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju"),
|
||||||
|
(
|
||||||
|
Required("tls-server-end-point"),
|
||||||
|
"p=tls-server-end-point,,n=pepe,r=t8JwklwKecDLwSsA72rHmVju",
|
||||||
|
),
|
||||||
|
];
|
||||||
|
|
||||||
|
for (cb, input) in cases {
|
||||||
|
let msg = ClientFirstMessage::parse(input).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(msg.bare, "n=pepe,r=t8JwklwKecDLwSsA72rHmVju");
|
||||||
|
assert_eq!(msg.username, "pepe");
|
||||||
|
assert_eq!(msg.nonce, "t8JwklwKecDLwSsA72rHmVju");
|
||||||
|
assert_eq!(msg.cbind_flag, cb);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_client_final_message() {
|
||||||
|
let input = [
|
||||||
|
"c=eSws",
|
||||||
|
"r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU",
|
||||||
|
"p=SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=",
|
||||||
|
]
|
||||||
|
.join(",");
|
||||||
|
|
||||||
|
let msg = ClientFinalMessage::parse(&input).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
msg.without_proof,
|
||||||
|
"c=eSws,r=iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
msg.nonce,
|
||||||
|
"iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
base64::encode(msg.proof),
|
||||||
|
"SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI="
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
65
proxy/src/scram/secret.rs
Normal file
65
proxy/src/scram/secret.rs
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
//! Tools for SCRAM server secret management.
|
||||||
|
|
||||||
|
use super::base64_decode_array;
|
||||||
|
use super::key::ScramKey;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ScramSecret {
|
||||||
|
pub iterations: u32,
|
||||||
|
pub salt_base64: String,
|
||||||
|
pub stored_key: ScramKey,
|
||||||
|
pub server_key: ScramKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ScramSecret {
|
||||||
|
pub fn parse(input: &str) -> Option<Self> {
|
||||||
|
// SCRAM-SHA-256$<iterations>:<salt>$<storedkey>:<serverkey>
|
||||||
|
let s = input.strip_prefix("SCRAM-SHA-256$")?;
|
||||||
|
let (params, keys) = s.split_once('$')?;
|
||||||
|
|
||||||
|
let ((iterations, salt), (stored_key, server_key)) =
|
||||||
|
params.split_once(':').zip(keys.split_once(':'))?;
|
||||||
|
|
||||||
|
let secret = ScramSecret {
|
||||||
|
iterations: iterations.parse().ok()?,
|
||||||
|
salt_base64: salt.to_owned(),
|
||||||
|
stored_key: base64_decode_array(stored_key)?.into(),
|
||||||
|
server_key: base64_decode_array(server_key)?.into(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Some(secret)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mock() -> Self {
|
||||||
|
todo!("see auth-scram.c : mock_scram_secret")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_scram_secret() {
|
||||||
|
let iterations = 4096;
|
||||||
|
let salt = "+/tQQax7twvwTj64mjBsxQ==";
|
||||||
|
let stored_key = "D5h6KTMBlUvDJk2Y8ELfC1Sjtc6k9YHjRyuRZyBNJns=";
|
||||||
|
let server_key = "Pi3QHbcluX//NDfVkKlFl88GGzlJ5LkyPwcdlN/QBvI=";
|
||||||
|
|
||||||
|
let secret = format!(
|
||||||
|
"SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}",
|
||||||
|
iterations = iterations,
|
||||||
|
salt = salt,
|
||||||
|
stored_key = stored_key,
|
||||||
|
server_key = server_key,
|
||||||
|
);
|
||||||
|
|
||||||
|
let parsed = ScramSecret::parse(&secret).unwrap();
|
||||||
|
assert_eq!(parsed.iterations, iterations);
|
||||||
|
assert_eq!(parsed.salt_base64, salt);
|
||||||
|
|
||||||
|
// TODO: derive from 'password'
|
||||||
|
assert_eq!(base64::encode(parsed.stored_key), stored_key);
|
||||||
|
assert_eq!(base64::encode(parsed.server_key), server_key);
|
||||||
|
}
|
||||||
|
}
|
||||||
70
proxy/src/scram/signature.rs
Normal file
70
proxy/src/scram/signature.rs
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
//! Tools for client/server signature management.
|
||||||
|
|
||||||
|
use super::key::{ScramKey, SCRAM_KEY_LEN};
|
||||||
|
use hmac::{Hmac, Mac, NewMac};
|
||||||
|
use sha2::Sha256;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct SignatureBuilder<'a> {
|
||||||
|
pub client_first_message_bare: &'a str,
|
||||||
|
pub server_first_message: &'a str,
|
||||||
|
pub client_final_message_without_proof: &'a str,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SignatureBuilder<'_> {
|
||||||
|
pub fn build(&self, key: &ScramKey) -> Signature {
|
||||||
|
let mut mac = Hmac::<Sha256>::new_varkey(key.as_ref()).expect("bad key size");
|
||||||
|
|
||||||
|
mac.update(self.client_first_message_bare.as_bytes());
|
||||||
|
mac.update(b",");
|
||||||
|
mac.update(self.server_first_message.as_bytes());
|
||||||
|
mac.update(b",");
|
||||||
|
mac.update(self.client_final_message_without_proof.as_bytes());
|
||||||
|
|
||||||
|
// TODO: maybe newer `hmac` et al already migrated to regular arrays?
|
||||||
|
let mut signature = [0u8; SCRAM_KEY_LEN];
|
||||||
|
signature.copy_from_slice(mac.finalize().into_bytes().as_slice());
|
||||||
|
signature.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
#[repr(transparent)]
|
||||||
|
pub struct Signature {
|
||||||
|
bytes: [u8; SCRAM_KEY_LEN],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Signature {
|
||||||
|
/// Derive ClientKey from client's signature and proof
|
||||||
|
pub fn derive_client_key(&self, proof: &[u8; SCRAM_KEY_LEN]) -> ScramKey {
|
||||||
|
let signature = self.as_ref().iter();
|
||||||
|
|
||||||
|
// This is how the proof is calculated:
|
||||||
|
//
|
||||||
|
// 1. sha256(ClientKey) -> StoredKey
|
||||||
|
// 2. hmac_sha256(StoredKey, [messages...]) -> ClientSignature
|
||||||
|
// 3. ClientKey ^ ClientSignature -> ClientProof
|
||||||
|
//
|
||||||
|
// Step 3 implies that we can restore ClientKey from the proof
|
||||||
|
// by xoring the latter with the ClientSignature again. Afterwards
|
||||||
|
// we can check that the presumed ClientKey meets our expectations.
|
||||||
|
let mut bytes = [0u8; SCRAM_KEY_LEN];
|
||||||
|
for (i, value) in signature.zip(proof).map(|(x, y)| x ^ y).enumerate() {
|
||||||
|
bytes[i] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
bytes.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<[u8; SCRAM_KEY_LEN]> for Signature {
|
||||||
|
fn from(bytes: [u8; SCRAM_KEY_LEN]) -> Self {
|
||||||
|
Self { bytes }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsRef<[u8]> for Signature {
|
||||||
|
fn as_ref(&self) -> &[u8] {
|
||||||
|
&self.bytes
|
||||||
|
}
|
||||||
|
}
|
||||||
2
vendor/postgres
vendored
2
vendor/postgres
vendored
Submodule vendor/postgres updated: 14f9177a22...12250cf3af
@@ -373,9 +373,8 @@ impl PostgresBackend {
|
|||||||
}
|
}
|
||||||
AuthType::MD5 => {
|
AuthType::MD5 => {
|
||||||
rand::thread_rng().fill(&mut self.md5_salt);
|
rand::thread_rng().fill(&mut self.md5_salt);
|
||||||
let md5_salt = self.md5_salt;
|
|
||||||
self.write_message(&BeMessage::AuthenticationMD5Password(
|
self.write_message(&BeMessage::AuthenticationMD5Password(
|
||||||
&md5_salt,
|
self.md5_salt,
|
||||||
))?;
|
))?;
|
||||||
self.state = ProtoState::Authentication;
|
self.state = ProtoState::Authentication;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -353,7 +353,8 @@ fn read_null_terminated(buf: &mut Bytes) -> anyhow::Result<Bytes> {
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum BeMessage<'a> {
|
pub enum BeMessage<'a> {
|
||||||
AuthenticationOk,
|
AuthenticationOk,
|
||||||
AuthenticationMD5Password(&'a [u8; 4]),
|
AuthenticationMD5Password([u8; 4]),
|
||||||
|
AuthenticationSasl(BeAuthenticationSaslMessage<'a>),
|
||||||
AuthenticationCleartextPassword,
|
AuthenticationCleartextPassword,
|
||||||
BackendKeyData(CancelKeyData),
|
BackendKeyData(CancelKeyData),
|
||||||
BindComplete,
|
BindComplete,
|
||||||
@@ -381,6 +382,13 @@ pub enum BeMessage<'a> {
|
|||||||
KeepAlive(WalSndKeepAlive),
|
KeepAlive(WalSndKeepAlive),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum BeAuthenticationSaslMessage<'a> {
|
||||||
|
Methods(&'a [&'a str]),
|
||||||
|
Continue(&'a [u8]),
|
||||||
|
Final(&'a [u8]),
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum BeParameterStatusMessage<'a> {
|
pub enum BeParameterStatusMessage<'a> {
|
||||||
Encoding(&'a str),
|
Encoding(&'a str),
|
||||||
@@ -552,6 +560,32 @@ impl<'a> BeMessage<'a> {
|
|||||||
.unwrap(); // write into BytesMut can't fail
|
.unwrap(); // write into BytesMut can't fail
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BeMessage::AuthenticationSasl(msg) => {
|
||||||
|
buf.put_u8(b'R');
|
||||||
|
write_body(buf, |buf| {
|
||||||
|
use BeAuthenticationSaslMessage::*;
|
||||||
|
match msg {
|
||||||
|
Methods(methods) => {
|
||||||
|
buf.put_i32(10); // Specifies that SASL auth method is used.
|
||||||
|
for method in methods.iter() {
|
||||||
|
write_cstr(method.as_bytes(), buf)?;
|
||||||
|
}
|
||||||
|
buf.put_u8(0); // zero terminator for the list
|
||||||
|
}
|
||||||
|
Continue(extra) => {
|
||||||
|
buf.put_i32(11); // Continue SASL auth.
|
||||||
|
buf.put_slice(extra);
|
||||||
|
}
|
||||||
|
Final(extra) => {
|
||||||
|
buf.put_i32(12); // Send final SASL message.
|
||||||
|
buf.put_slice(extra);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok::<_, io::Error>(())
|
||||||
|
})
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
BeMessage::BackendKeyData(key_data) => {
|
BeMessage::BackendKeyData(key_data) => {
|
||||||
buf.put_u8(b'K');
|
buf.put_u8(b'K');
|
||||||
write_body(buf, |buf| {
|
write_body(buf, |buf| {
|
||||||
|
|||||||
Reference in New Issue
Block a user