Refactor use_cleartext_password_flow.

It's not a property of the credentials that we receive from the
client, so remove it from ClientCredentials. Instead, pass it as an
argument directly to 'authenticate' function, where it's actually
used. All the rest of the changes is just plumbing to pass it through
the call stack to 'authenticate'
This commit is contained in:
Heikki Linnakangas
2023-02-07 19:36:30 +02:00
committed by Dmitry Ivanov
parent 0d3aefb274
commit d9c518b2cc
3 changed files with 31 additions and 23 deletions

View File

@@ -114,6 +114,7 @@ impl<'l> BackendType<'l, ClientCredentials<'_>> {
&'a mut self,
extra: &'a ConsoleReqExtra<'a>,
client: &'a mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
use_cleartext_password_flow: bool,
) -> auth::Result<Option<AuthSuccess<CachedNodeInfo>>> {
use BackendType::*;
@@ -158,7 +159,7 @@ impl<'l> BackendType<'l, ClientCredentials<'_>> {
(node, payload.password)
}
// This is a hack to allow cleartext password in secure connections (wss).
Console(api, creds) if creds.use_cleartext_password_flow => {
Console(api, creds) if use_cleartext_password_flow => {
let payload = fetch_plaintext_password(client).await?;
let node = api.wake_compute(extra, creds).await?;
@@ -182,16 +183,25 @@ impl<'l> BackendType<'l, ClientCredentials<'_>> {
}
/// Authenticate the client via the requested backend, possibly using credentials.
///
/// If `use_cleartext_password_flow` is true, we use the old cleartext password
/// flow. It is used for websocket connections, which want to minimize the number
/// of round trips. (Plaintext password authentication requires only one round-trip,
/// where SCRAM requires two.)
pub async fn authenticate<'a>(
&mut self,
extra: &'a ConsoleReqExtra<'a>,
client: &'a mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin>,
use_cleartext_password_flow: bool,
) -> auth::Result<AuthSuccess<CachedNodeInfo>> {
use BackendType::*;
// Handle cases when `project` is missing in `creds`.
// TODO: type safety: return `creds` with irrefutable `project`.
if let Some(res) = self.try_password_hack(extra, client).await? {
if let Some(res) = self
.try_password_hack(extra, client, use_cleartext_password_flow)
.await?
{
info!("user successfully authenticated (using the password hack)");
return Ok(res);
}

View File

@@ -34,9 +34,6 @@ pub struct ClientCredentials<'a> {
pub user: &'a str,
// TODO: this is a severe misnomer! We should think of a new name ASAP.
pub project: Option<Cow<'a, str>>,
/// If `True`, we'll use the old cleartext password flow. This is used for
/// websocket connections, which want to minimize the number of round trips.
pub use_cleartext_password_flow: bool,
}
impl ClientCredentials<'_> {
@@ -51,7 +48,6 @@ impl<'a> ClientCredentials<'a> {
params: &'a StartupMessageParams,
sni: Option<&str>,
common_name: Option<&str>,
use_cleartext_password_flow: bool,
) -> Result<Self, ClientCredsParseError> {
use ClientCredsParseError::*;
@@ -99,14 +95,12 @@ impl<'a> ClientCredentials<'a> {
info!(
user = user,
project = project.as_deref(),
use_cleartext_password_flow = use_cleartext_password_flow,
"credentials"
);
Ok(Self {
user,
project,
use_cleartext_password_flow,
})
}
}
@@ -131,7 +125,7 @@ mod tests {
// According to postgresql, only `user` should be required.
let options = StartupMessageParams::new([("user", "john_doe")]);
let creds = ClientCredentials::parse(&options, None, None, false)?;
let creds = ClientCredentials::parse(&options, None, None)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project, None);
@@ -146,7 +140,7 @@ mod tests {
("foo", "bar"), // should be ignored
]);
let creds = ClientCredentials::parse(&options, None, None, false)?;
let creds = ClientCredentials::parse(&options, None, None)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project, None);
@@ -160,7 +154,7 @@ mod tests {
let sni = Some("foo.localhost");
let common_name = Some("localhost");
let creds = ClientCredentials::parse(&options, sni, common_name, false)?;
let creds = ClientCredentials::parse(&options, sni, common_name)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("foo"));
@@ -174,7 +168,7 @@ mod tests {
("options", "-ckey=1 project=bar -c geqo=off"),
]);
let creds = ClientCredentials::parse(&options, None, None, false)?;
let creds = ClientCredentials::parse(&options, None, None)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("bar"));
@@ -188,7 +182,7 @@ mod tests {
let sni = Some("baz.localhost");
let common_name = Some("localhost");
let creds = ClientCredentials::parse(&options, sni, common_name, false)?;
let creds = ClientCredentials::parse(&options, sni, common_name)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.project.as_deref(), Some("baz"));
@@ -203,8 +197,7 @@ mod tests {
let sni = Some("second.localhost");
let common_name = Some("localhost");
let err =
ClientCredentials::parse(&options, sni, common_name, false).expect_err("should fail");
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
match err {
InconsistentProjectNames { domain, option } => {
assert_eq!(option, "first");
@@ -221,8 +214,7 @@ mod tests {
let sni = Some("project.localhost");
let common_name = Some("example.com");
let err =
ClientCredentials::parse(&options, sni, common_name, false).expect_err("should fail");
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
match err {
InconsistentSni { sni, cn } => {
assert_eq!(sni, "project.localhost");

View File

@@ -127,7 +127,7 @@ pub async fn handle_ws_client(
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ClientCredentials::parse(&params, hostname, common_name, true))
.map(|_| auth::ClientCredentials::parse(&params, hostname, common_name))
.transpose();
async { result }.or_else(|e| stream.throw_error(e)).await?
@@ -135,7 +135,7 @@ pub async fn handle_ws_client(
let client = Client::new(stream, creds, &params, session_id);
cancel_map
.with_session(|session| client.connect_to_db(session))
.with_session(|session| client.connect_to_db(session, true))
.await
}
@@ -165,7 +165,7 @@ async fn handle_client(
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ClientCredentials::parse(&params, sni, common_name, false))
.map(|_| auth::ClientCredentials::parse(&params, sni, common_name))
.transpose();
async { result }.or_else(|e| stream.throw_error(e)).await?
@@ -173,7 +173,7 @@ async fn handle_client(
let client = Client::new(stream, creds, &params, session_id);
cancel_map
.with_session(|session| client.connect_to_db(session))
.with_session(|session| client.connect_to_db(session, false))
.await
}
@@ -401,7 +401,11 @@ impl<'a, S> Client<'a, S> {
impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
/// Let the client authenticate and connect to the designated compute node.
async fn connect_to_db(self, session: cancellation::Session<'_>) -> anyhow::Result<()> {
async fn connect_to_db(
self,
session: cancellation::Session<'_>,
use_cleartext_password_flow: bool,
) -> anyhow::Result<()> {
let Self {
mut stream,
mut creds,
@@ -416,7 +420,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
let auth_result = async {
// `&mut stream` doesn't let us merge those 2 lines.
let res = creds.authenticate(&extra, &mut stream).await;
let res = creds
.authenticate(&extra, &mut stream, use_cleartext_password_flow)
.await;
async { res }.or_else(|e| stream.throw_error(e)).await
}
.instrument(info_span!("auth"))