diff --git a/Cargo.lock b/Cargo.lock index d154b4eaea..dab3d12263 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3054,6 +3054,7 @@ dependencies = [ "hyper", "metrics", "once_cell", + "pin-project-lite", "serde", "serde_json", "tempfile", diff --git a/libs/remote_storage/Cargo.toml b/libs/remote_storage/Cargo.toml index 4382fbac32..15812e8439 100644 --- a/libs/remote_storage/Cargo.toml +++ b/libs/remote_storage/Cargo.toml @@ -21,7 +21,7 @@ toml_edit.workspace = true tracing.workspace = true metrics.workspace = true utils.workspace = true - +pin-project-lite.workspace = true workspace_hack.workspace = true [dev-dependencies] diff --git a/libs/remote_storage/src/s3_bucket.rs b/libs/remote_storage/src/s3_bucket.rs index 18a2c5dedd..93f5e0596e 100644 --- a/libs/remote_storage/src/s3_bucket.rs +++ b/libs/remote_storage/src/s3_bucket.rs @@ -20,7 +20,10 @@ use aws_sdk_s3::{ }; use aws_smithy_http::body::SdkBody; use hyper::Body; -use tokio::{io, sync::Semaphore}; +use tokio::{ + io::{self, AsyncRead}, + sync::Semaphore, +}; use tokio_util::io::ReaderStream; use tracing::debug; @@ -102,7 +105,7 @@ pub struct S3Bucket { // Every request to S3 can be throttled or cancelled, if a certain number of requests per second is exceeded. // Same goes to IAM, which is queried before every S3 request, if enabled. IAM has even lower RPS threshold. // The helps to ensure we don't exceed the thresholds. - concurrency_limiter: Semaphore, + concurrency_limiter: Arc, } #[derive(Default)] @@ -162,7 +165,7 @@ impl S3Bucket { client, bucket_name: aws_config.bucket_name.clone(), prefix_in_bucket, - concurrency_limiter: Semaphore::new(aws_config.concurrency_limit.get()), + concurrency_limiter: Arc::new(Semaphore::new(aws_config.concurrency_limit.get())), }) } @@ -194,9 +197,10 @@ impl S3Bucket { } async fn download_object(&self, request: GetObjectRequest) -> Result { - let _guard = self + let permit = self .concurrency_limiter - .acquire() + .clone() + .acquire_owned() .await .context("Concurrency limiter semaphore got closed during S3 download") .map_err(DownloadError::Other)?; @@ -217,9 +221,10 @@ impl S3Bucket { let metadata = object_output.metadata().cloned().map(StorageMetadata); Ok(Download { metadata, - download_stream: Box::pin(io::BufReader::new( + download_stream: Box::pin(io::BufReader::new(RatelimitedAsyncRead::new( + permit, object_output.body.into_async_read(), - )), + ))), }) } Err(SdkError::ServiceError { @@ -240,6 +245,32 @@ impl S3Bucket { } } +pin_project_lite::pin_project! { + /// An `AsyncRead` adapter which carries a permit for the lifetime of the value. + struct RatelimitedAsyncRead { + permit: tokio::sync::OwnedSemaphorePermit, + #[pin] + inner: S, + } +} + +impl RatelimitedAsyncRead { + fn new(permit: tokio::sync::OwnedSemaphorePermit, inner: S) -> Self { + RatelimitedAsyncRead { permit, inner } + } +} + +impl AsyncRead for RatelimitedAsyncRead { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut io::ReadBuf<'_>, + ) -> std::task::Poll> { + let this = self.project(); + this.inner.poll_read(cx, buf) + } +} + #[async_trait::async_trait] impl RemoteStorage for S3Bucket { async fn list(&self) -> anyhow::Result> {