Refactor use_cleartext_password_flow.

It's not a property of the credentials that we receive from the
client, so remove it from ClientCredentials. Instead, pass it as an
argument directly to 'authenticate' function, where it's actually
used. All the rest of the changes is just plumbing to pass it through
the call stack to 'authenticate'
This commit is contained in:
Heikki Linnakangas
2023-01-24 21:35:56 +02:00
parent 3e150419ef
commit 4d68e3108f
4 changed files with 67 additions and 53 deletions

View File

@@ -143,6 +143,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
&mut self,
extra: &ConsoleReqExtra<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
use_cleartext_password_flow: bool,
) -> auth::Result<Option<AuthSuccess<NodeInfo>>> {
use BackendType::*;
@@ -190,7 +191,7 @@ impl BackendType<'_, ClientCredentials<'_>> {
(node, payload)
}
Console(endpoint, creds) if creds.use_cleartext_password_flow => {
Console(endpoint, creds) if use_cleartext_password_flow => {
// This is a hack to allow cleartext password in secure connections (wss).
let payload = fetch_plaintext_password(client).await?;
let creds = creds.as_ref();
@@ -220,17 +221,25 @@ impl BackendType<'_, ClientCredentials<'_>> {
}
/// Authenticate the client via the requested backend, possibly using credentials.
///
/// If `use_cleartext_password_flow` is true, we use the old cleartext password
/// flow. It is used for websocket connections, which want to minimize the number
/// of round trips.
#[instrument(skip_all)]
pub async fn authenticate(
mut self,
extra: &ConsoleReqExtra<'_>,
client: &mut stream::PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
use_cleartext_password_flow: bool,
) -> auth::Result<AuthSuccess<NodeInfo>> {
use BackendType::*;
// Handle cases when `project` is missing in `creds`.
// TODO: type safety: return `creds` with irrefutable `project`.
if let Some(res) = self.try_password_hack(extra, client).await? {
if let Some(res) = self
.try_password_hack(extra, client, use_cleartext_password_flow)
.await?
{
info!("user successfully authenticated (using the password hack)");
return Ok(res);
}

View File

@@ -34,9 +34,6 @@ pub struct ClientCredentials<'a> {
pub user: &'a str,
pub dbname: &'a str,
pub project: Option<Cow<'a, str>>,
/// If `True`, we'll use the old cleartext password flow. This is used for
/// websocket connections, which want to minimize the number of round trips.
pub use_cleartext_password_flow: bool,
}
impl ClientCredentials<'_> {
@@ -53,7 +50,6 @@ impl<'a> ClientCredentials<'a> {
user: self.user,
dbname: self.dbname,
project: self.project().map(Cow::Borrowed),
use_cleartext_password_flow: self.use_cleartext_password_flow,
}
}
}
@@ -63,7 +59,6 @@ impl<'a> ClientCredentials<'a> {
params: &'a StartupMessageParams,
sni: Option<&str>,
common_name: Option<&str>,
use_cleartext_password_flow: bool,
) -> Result<Self, ClientCredsParseError> {
use ClientCredsParseError::*;
@@ -113,7 +108,6 @@ impl<'a> ClientCredentials<'a> {
user = user,
dbname = dbname,
project = project.as_deref(),
use_cleartext_password_flow = use_cleartext_password_flow,
"credentials"
);
@@ -121,7 +115,6 @@ impl<'a> ClientCredentials<'a> {
user,
dbname,
project,
use_cleartext_password_flow,
})
}
}
@@ -148,7 +141,7 @@ mod tests {
let options = StartupMessageParams::new([("user", "john_doe")]);
// TODO: check that `creds.dbname` is None.
let creds = ClientCredentials::parse(&options, None, None, false)?;
let creds = ClientCredentials::parse(&options, None, None)?;
assert_eq!(creds.user, "john_doe");
Ok(())
@@ -158,7 +151,7 @@ mod tests {
fn parse_missing_project() -> anyhow::Result<()> {
let options = StartupMessageParams::new([("user", "john_doe"), ("database", "world")]);
let creds = ClientCredentials::parse(&options, None, None, false)?;
let creds = ClientCredentials::parse(&options, None, None)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.dbname, "world");
assert_eq!(creds.project, None);
@@ -173,7 +166,7 @@ mod tests {
let sni = Some("foo.localhost");
let common_name = Some("localhost");
let creds = ClientCredentials::parse(&options, sni, common_name, false)?;
let creds = ClientCredentials::parse(&options, sni, common_name)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.dbname, "world");
assert_eq!(creds.project.as_deref(), Some("foo"));
@@ -189,7 +182,7 @@ mod tests {
("options", "-ckey=1 project=bar -c geqo=off"),
]);
let creds = ClientCredentials::parse(&options, None, None, false)?;
let creds = ClientCredentials::parse(&options, None, None)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.dbname, "world");
assert_eq!(creds.project.as_deref(), Some("bar"));
@@ -208,7 +201,7 @@ mod tests {
let sni = Some("baz.localhost");
let common_name = Some("localhost");
let creds = ClientCredentials::parse(&options, sni, common_name, false)?;
let creds = ClientCredentials::parse(&options, sni, common_name)?;
assert_eq!(creds.user, "john_doe");
assert_eq!(creds.dbname, "world");
assert_eq!(creds.project.as_deref(), Some("baz"));
@@ -227,8 +220,7 @@ mod tests {
let sni = Some("second.localhost");
let common_name = Some("localhost");
let err =
ClientCredentials::parse(&options, sni, common_name, false).expect_err("should fail");
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
match err {
InconsistentProjectNames { domain, option } => {
assert_eq!(option, "first");
@@ -245,8 +237,7 @@ mod tests {
let sni = Some("project.localhost");
let common_name = Some("example.com");
let err =
ClientCredentials::parse(&options, sni, common_name, false).expect_err("should fail");
let err = ClientCredentials::parse(&options, sni, common_name).expect_err("should fail");
match err {
InconsistentSni { sni, cn } => {
assert_eq!(sni, "project.localhost");

View File

@@ -25,12 +25,11 @@ impl CancelMap {
cancel_closure.try_cancel_query().await
}
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
where
F: FnOnce(Session<'a>) -> R,
R: std::future::Future<Output = anyhow::Result<V>>,
{
/// Create a new session, with a new client-facing random cancellation key.
///
/// Use `enable_query_cancellation` to register a database cancellation
/// key with it, and to get the client-facing key.
pub fn new_session<'a>(&'a self) -> anyhow::Result<Session<'a>> {
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
// expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the
@@ -44,17 +43,9 @@ impl CancelMap {
.lock()
.try_insert(key, None)
.map_err(|_| anyhow!("query cancellation key already exists: {key}"))?;
// This will guarantee that the session gets dropped
// as soon as the future is finished.
scopeguard::defer! {
self.0.lock().remove(&key);
info!("dropped query cancellation key {key}");
}
info!("registered new query cancellation key {key}");
let session = Session::new(key, self);
f(session).await
Ok(Session::new(key, self))
}
#[cfg(test)]
@@ -111,7 +102,7 @@ impl<'a> Session<'a> {
impl Session<'_> {
/// Store the cancel token for the given session.
/// This enables query cancellation in [`crate::proxy::handshake`].
pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
info!("enabling query cancellation for this session");
self.cancel_map
.0
@@ -122,6 +113,14 @@ impl Session<'_> {
}
}
impl<'a> Drop for Session<'a> {
fn drop(&mut self) {
let key = &self.key;
self.cancel_map.0.lock().remove(key);
info!("dropped query cancellation key {key}");
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -132,14 +131,14 @@ mod tests {
static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(Default::default);
let (tx, rx) = tokio::sync::oneshot::channel();
let task = tokio::spawn(CANCEL_MAP.with_session(|session| async move {
let session = CANCEL_MAP.new_session()?;
let task = tokio::spawn(async move {
assert!(CANCEL_MAP.contains(&session));
tx.send(()).expect("failed to send");
futures::future::pending::<()>().await; // sleep forever
Ok(())
}));
});
// Wait until the task has been spawned.
rx.await.context("failed to hear from the task")?;

View File

@@ -109,16 +109,20 @@ pub async fn handle_ws_client(
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ClientCredentials::parse(&params, hostname, common_name, true))
.map(|_| auth::ClientCredentials::parse(&params, hostname, common_name))
.transpose();
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let client = Client::new(stream, creds, &params, session_id);
cancel_map
.with_session(|session| client.handle_connection(session))
.await
let client = Client::new(
stream,
creds,
&params,
session_id,
cancel_map.new_session()?,
);
client.handle_connection(true).await
}
/// Handle an incoming client connection, handshake and authentication.
@@ -150,16 +154,20 @@ async fn handle_client(
let result = config
.auth_backend
.as_ref()
.map(|_| auth::ClientCredentials::parse(&params, sni, common_name, false))
.map(|_| auth::ClientCredentials::parse(&params, sni, common_name))
.transpose();
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let client = Client::new(stream, creds, &params, session_id);
cancel_map
.with_session(|session| client.handle_connection(session))
.await
let client = Client::new(
stream,
creds,
&params,
session_id,
cancel_map.new_session()?,
);
client.handle_connection(false).await
}
/// Establish a (most probably, secure) connection with the client.
@@ -238,6 +246,8 @@ struct Client<'a, S> {
params: &'a StartupMessageParams,
/// Unique connection ID.
session_id: uuid::Uuid,
session: cancellation::Session<'a>,
}
impl<'a, S> Client<'a, S> {
@@ -247,19 +257,21 @@ impl<'a, S> Client<'a, S> {
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
params: &'a StartupMessageParams,
session_id: uuid::Uuid,
session: cancellation::Session<'a>,
) -> Self {
Self {
stream,
creds,
params,
session_id,
session,
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
async fn handle_connection(self, session: cancellation::Session<'_>) -> anyhow::Result<()> {
let (mut client, mut db) = self.connect_to_db(session).await?;
async fn handle_connection(self, use_cleartext_password_flow: bool) -> anyhow::Result<()> {
let (mut client, mut db) = self.connect_to_db(use_cleartext_password_flow).await?;
// Starting from here we only proxy the client's traffic.
info!("performing the proxy pass...");
@@ -271,13 +283,14 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
#[instrument(skip_all)]
async fn connect_to_db(
self,
session: cancellation::Session<'_>,
use_cleartext_password_flow: bool,
) -> anyhow::Result<(MeasuredStream<S>, MeasuredStream<tokio::net::TcpStream>)> {
let Self {
mut stream,
creds,
params,
session_id,
session,
} = self;
let extra = auth::ConsoleReqExtra {
@@ -287,7 +300,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
let auth_result = async {
// `&mut stream` doesn't let us merge those 2 lines.
let res = creds.authenticate(&extra, &mut stream).await;
let res = creds
.authenticate(&extra, &mut stream, use_cleartext_password_flow)
.await;
async { res }.or_else(|e| stream.throw_error(e)).await
}
.await?;