diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs index 7fdadc7038..786cbcaa19 100644 --- a/proxy/src/scram/exchange.rs +++ b/proxy/src/scram/exchange.rs @@ -86,8 +86,7 @@ async fn derive_client_key( ) -> ScramKey { let salted_password = pool .spawn_job(endpoint, Pbkdf2::start(password, salt, iterations)) - .await - .expect("job should not be cancelled"); + .await; let make_key = |name| { let key = Hmac::::new_from_slice(&salted_password) diff --git a/proxy/src/scram/threadpool.rs b/proxy/src/scram/threadpool.rs index d73a927995..2702aeebfe 100644 --- a/proxy/src/scram/threadpool.rs +++ b/proxy/src/scram/threadpool.rs @@ -15,6 +15,7 @@ use std::{ task::{Context, Poll}, }; +use futures::FutureExt; use rand::Rng; use rand::{rngs::SmallRng, SeedableRng}; @@ -74,15 +75,13 @@ impl ThreadPool { }) } - pub(crate) fn spawn_job( - &self, - endpoint: EndpointIdInt, - pbkdf2: Pbkdf2, - ) -> tokio::task::JoinHandle<[u8; 32]> { - self.runtime - .as_ref() - .unwrap() - .spawn(JobSpec { pbkdf2, endpoint }) + pub(crate) fn spawn_job(&self, endpoint: EndpointIdInt, pbkdf2: Pbkdf2) -> JobHandle { + JobHandle( + self.runtime + .as_ref() + .unwrap() + .spawn(JobSpec { pbkdf2, endpoint }), + ) } } @@ -167,6 +166,26 @@ impl Future for JobSpec { } } +pub(crate) struct JobHandle(tokio::task::JoinHandle<[u8; 32]>); + +impl Future for JobHandle { + type Output = [u8; 32]; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.0.poll_unpin(cx) { + Poll::Ready(Ok(ok)) => Poll::Ready(ok), + Poll::Ready(Err(err)) => std::panic::resume_unwind(err.into_panic()), + Poll::Pending => Poll::Pending, + } + } +} + +impl Drop for JobHandle { + fn drop(&mut self) { + self.0.abort(); + } +} + #[cfg(test)] mod tests { use crate::EndpointId; @@ -183,8 +202,7 @@ mod tests { let salt = [0x55; 32]; let actual = pool .spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096)) - .await - .unwrap(); + .await; let expected = [ 10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242,