diff --git a/Cargo.lock b/Cargo.lock
index f6630cc203..fcdc424636 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3864,6 +3864,7 @@ dependencies = [
  "bytes",
  "camino",
  "camino-tempfile",
+ "futures",
  "futures-util",
  "http-types",
  "hyper",
@@ -4291,6 +4292,7 @@ dependencies = [
  "tokio-io-timeout",
  "tokio-postgres",
  "tokio-stream",
+ "tokio-util",
  "toml_edit",
  "tracing",
  "url",
@@ -5220,6 +5222,7 @@ checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15"
 dependencies = [
  "bytes",
  "futures-core",
+ "futures-io",
  "futures-sink",
  "pin-project-lite",
  "tokio",
diff --git a/libs/remote_storage/Cargo.toml b/libs/remote_storage/Cargo.toml
index e8bfc005d3..2cc59a947b 100644
--- a/libs/remote_storage/Cargo.toml
+++ b/libs/remote_storage/Cargo.toml
@@ -16,10 +16,11 @@ aws-credential-types.workspace = true
 bytes.workspace = true
 camino.workspace = true
 hyper = { workspace = true, features = ["stream"] }
+futures.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 tokio = { workspace = true, features = ["sync", "fs", "io-util"] }
-tokio-util.workspace = true
+tokio-util = { workspace = true, features = ["compat"] }
 toml_edit.workspace = true
 tracing.workspace = true
 scopeguard.workspace = true
diff --git a/libs/remote_storage/src/azure_blob.rs b/libs/remote_storage/src/azure_blob.rs
index ae08e9b171..e559d00ded 100644
--- a/libs/remote_storage/src/azure_blob.rs
+++ b/libs/remote_storage/src/azure_blob.rs
@@ -1,21 +1,24 @@
 //! Azure Blob Storage wrapper
+use std::borrow::Cow;
 use std::collections::HashMap;
 use std::env;
 use std::num::NonZeroU32;
+use std::pin::Pin;
 use std::sync::Arc;
-use std::{borrow::Cow, io::Cursor};
 use super::REMOTE_STORAGE_PREFIX_SEPARATOR;
 use anyhow::Result;
 use azure_core::request_options::{MaxResults, Metadata, Range};
+use azure_core::RetryOptions;
 use azure_identity::DefaultAzureCredential;
 use azure_storage::StorageCredentials;
 use azure_storage_blobs::prelude::ClientBuilder;
 use azure_storage_blobs::{blob::operations::GetBlobBuilder, prelude::ContainerClient};
+use bytes::Bytes;
+use futures::stream::Stream;
 use futures_util::StreamExt;
 use http_types::StatusCode;
-use tokio::io::AsyncRead;
 use tracing::debug;
 use crate::s3_bucket::RequestKind;
@@ -49,7 +52,8 @@ impl AzureBlobStorage {
             StorageCredentials::token_credential(Arc::new(token_credential))
         };
-        let builder = ClientBuilder::new(account, credentials);
+        // we have an outer retry
+        let builder = ClientBuilder::new(account, credentials).retry(RetryOptions::none());
         let client = builder.container_client(azure_config.container_name.to_owned());
@@ -116,7 +120,8 @@ impl AzureBlobStorage {
         let mut metadata = HashMap::new();
         // TODO give proper streaming response instead of buffering into RAM
         // https://github.com/neondatabase/neon/issues/5563
-        let mut buf = Vec::new();
+
+        let mut bufs = Vec::new();
         while let Some(part) = response.next().await {
             let part = part.map_err(to_download_error)?;
             if let Some(blob_meta) = part.blob.metadata {
@@ -127,10 +132,10 @@ impl AzureBlobStorage {
                 .collect()
                 .await
                 .map_err(|e| DownloadError::Other(e.into()))?;
-            buf.extend_from_slice(&data.slice(..));
+            bufs.push(data);
         }
         Ok(Download {
-            download_stream: Box::pin(Cursor::new(buf)),
+            download_stream: Box::pin(futures::stream::iter(bufs.into_iter().map(Ok))),
             metadata: Some(StorageMetadata(metadata)),
         })
     }
@@ -217,9 +222,10 @@ impl RemoteStorage for AzureBlobStorage {
         }
         Ok(res)
     }
+
     async fn upload(
         &self,
-        mut from: impl AsyncRead + Unpin + Send + Sync + 'static,
+        from: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
         data_size_bytes: usize,
         to: &RemotePath,
         metadata: Option<StorageMetadata>,
     ) -> anyhow::Result<()> {
@@ -227,13 +233,12 @@ impl RemoteStorage for AzureBlobStorage {
         let _permit = self.permit(RequestKind::Put).await;
         let blob_client = self.client.blob_client(self.relative_path_to_name(to));
-        // TODO FIX THIS UGLY HACK and don't buffer the entire object
-        // into RAM here, but use the streaming interface. For that,
-        // we'd have to change the interface though...
-        // https://github.com/neondatabase/neon/issues/5563
-        let mut buf = Vec::with_capacity(data_size_bytes);
-        tokio::io::copy(&mut from, &mut buf).await?;
-        let body = azure_core::Body::Bytes(buf.into());
+        let from: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static>> =
+            Box::pin(from);
+
+        let from = NonSeekableStream::new(from, data_size_bytes);
+
+        let body = azure_core::Body::SeekableStream(Box::new(from));
         let mut builder = blob_client.put_block_blob(body);
@@ -312,3 +317,153 @@ impl RemoteStorage for AzureBlobStorage {
         Ok(())
     }
 }
+
+pin_project_lite::pin_project! {
+    /// Hack to work around not being able to stream once with azure sdk.
+    ///
+    /// Azure sdk clones streams around with the assumption that they are like
+    /// `Arc` (except not supporting tokio), however our streams are not like
+    /// that. For example for an `index_part.json` we just have a single chunk of [`Bytes`]
+    /// representing the whole serialized vec. It could be trivially cloneable and "semi-trivially"
+    /// seekable, but we can also just re-try the request easier.
+    #[project = NonSeekableStreamProj]
+    enum NonSeekableStream<S> {
+        /// A stream wrapper's initial form.
+        ///
+        /// Mutex exists to allow moving when cloning. If the sdk changes to do less than 1
+        /// clone before first request, then this must be changed.
+        Initial {
+            inner: std::sync::Mutex<Option<tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>>>,
+            len: usize,
+        },
+        /// The actually readable variant, produced by cloning the Initial variant.
+        ///
+        /// The sdk currently always clones once, even without retry policy.
+        Actual {
+            #[pin]
+            inner: tokio_util::compat::Compat<tokio_util::io::StreamReader<S, Bytes>>,
+            len: usize,
+            read_any: bool,
+        },
+        /// Most likely unneeded, but left to make life easier, in case more clones are added.
+        Cloned {
+            len_was: usize,
+        }
+    }
+}
+
+impl<S> NonSeekableStream<S>
+where
+    S: Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
+{
+    fn new(inner: S, len: usize) -> NonSeekableStream<S> {
+        use tokio_util::compat::TokioAsyncReadCompatExt;
+
+        let inner = tokio_util::io::StreamReader::new(inner).compat();
+        let inner = Some(inner);
+        let inner = std::sync::Mutex::new(inner);
+        NonSeekableStream::Initial { inner, len }
+    }
+}
+
+impl<S> std::fmt::Debug for NonSeekableStream<S> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::Initial { len, .. } => f.debug_struct("Initial").field("len", len).finish(),
+            Self::Actual { len, .. } => f.debug_struct("Actual").field("len", len).finish(),
+            Self::Cloned { len_was, .. } => f.debug_struct("Cloned").field("len", len_was).finish(),
+        }
+    }
+}
+
+impl<S> futures::io::AsyncRead for NonSeekableStream<S>
+where
+    S: Stream<Item = std::io::Result<Bytes>>,
+{
+    fn poll_read(
+        self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+        buf: &mut [u8],
+    ) -> std::task::Poll<std::io::Result<usize>> {
+        match self.project() {
+            NonSeekableStreamProj::Actual {
+                inner, read_any, ..
+            } => {
+                *read_any = true;
+                inner.poll_read(cx, buf)
+            }
+            // NonSeekableStream::Initial does not support reading because it is just much easier
+            // to have the mutex in place where one does not poll the contents, or that's how it
+            // seemed originally. If there is a version upgrade which changes the cloning, then
+            // that support needs to be hacked in.
+            //
+            // including {self:?} into the message would be useful, but unsure how to unproject.
+            _ => std::task::Poll::Ready(Err(std::io::Error::new(
+                std::io::ErrorKind::Other,
+                "cloned or initial values cannot be read",
+            ))),
+        }
+    }
+}
+
+impl<S> Clone for NonSeekableStream<S> {
+    /// Weird clone implementation exists to support the sdk doing cloning before issuing the first
+    /// request, see type documentation.
+    fn clone(&self) -> Self {
+        use NonSeekableStream::*;
+
+        match self {
+            Initial { inner, len } => {
+                if let Some(inner) = inner.lock().unwrap().take() {
+                    Actual {
+                        inner,
+                        len: *len,
+                        read_any: false,
+                    }
+                } else {
+                    Self::Cloned { len_was: *len }
+                }
+            }
+            Actual { len, .. } => Cloned { len_was: *len },
+            Cloned { len_was } => Cloned { len_was: *len_was },
+        }
+    }
+}
+
+#[async_trait::async_trait]
+impl<S> azure_core::SeekableStream for NonSeekableStream<S>
+where
+    S: Stream<Item = std::io::Result<Bytes>> + Unpin + Send + Sync + 'static,
+{
+    async fn reset(&mut self) -> azure_core::error::Result<()> {
+        use NonSeekableStream::*;
+
+        let msg = match self {
+            Initial { inner, .. } => {
+                if inner.get_mut().unwrap().is_some() {
+                    return Ok(());
+                } else {
+                    "reset after first clone is not supported"
+                }
+            }
+            Actual { read_any, .. } if !*read_any => return Ok(()),
+            Actual { .. } => "reset after reading is not supported",
+            Cloned { .. } => "reset after second clone is not supported",
+        };
+        Err(azure_core::error::Error::new(
+            azure_core::error::ErrorKind::Io,
+            std::io::Error::new(std::io::ErrorKind::Other, msg),
+        ))
+    }
+
+    // Note: it is not documented if this should be the total or remaining length, total passes the
+    // tests.
+    fn len(&self) -> usize {
+        use NonSeekableStream::*;
+        match self {
+            Initial { len, .. } => *len,
+            Actual { len, .. } => *len,
+            Cloned { len_was, .. } => *len_was,
+        }
+    }
+}
diff --git a/libs/remote_storage/src/lib.rs b/libs/remote_storage/src/lib.rs
index e6d306ff66..e77c54e1e7 100644
--- a/libs/remote_storage/src/lib.rs
+++ b/libs/remote_storage/src/lib.rs
@@ -19,8 +19,10 @@ use std::{collections::HashMap, fmt::Debug, num::NonZeroUsize, pin::Pin, sync::A
 use anyhow::{bail, Context};
 use camino::{Utf8Path, Utf8PathBuf};
+use bytes::Bytes;
+use futures::stream::Stream;
 use serde::{Deserialize, Serialize};
-use tokio::{io, sync::Semaphore};
+use tokio::sync::Semaphore;
 use toml_edit::Item;
 use tracing::info;
@@ -179,7 +181,7 @@ pub trait RemoteStorage: Send + Sync + 'static {
     /// Streams the local file contents into remote into the remote storage entry.
     async fn upload(
         &self,
-        from: impl io::AsyncRead + Unpin + Send + Sync + 'static,
+        from: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
         // S3 PUT request requires the content length to be specified,
         // otherwise it starts to fail with the concurrent connection count increasing.
         data_size_bytes: usize,
@@ -206,7 +208,7 @@ pub trait RemoteStorage: Send + Sync + 'static {
 }
 pub struct Download {
-    pub download_stream: Pin<Box<dyn io::AsyncRead + Unpin + Send + Sync>>,
+    pub download_stream: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Unpin + Send + Sync>>,
     /// Extra key-value data, associated with the current remote file.
     pub metadata: Option<StorageMetadata>,
 }
@@ -300,7 +302,7 @@ impl GenericRemoteStorage {
     pub async fn upload(
         &self,
-        from: impl io::AsyncRead + Unpin + Send + Sync + 'static,
+        from: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
         data_size_bytes: usize,
         to: &RemotePath,
         metadata: Option<StorageMetadata>,
@@ -398,7 +400,7 @@ impl GenericRemoteStorage {
     /// this path is used for the remote object id conversion only.
     pub async fn upload_storage_object(
         &self,
-        from: impl tokio::io::AsyncRead + Unpin + Send + Sync + 'static,
+        from: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
         from_size_bytes: usize,
         to: &RemotePath,
     ) -> anyhow::Result<()> {
diff --git a/libs/remote_storage/src/local_fs.rs b/libs/remote_storage/src/local_fs.rs
index fccc78de20..0016c21955 100644
--- a/libs/remote_storage/src/local_fs.rs
+++ b/libs/remote_storage/src/local_fs.rs
@@ -7,11 +7,14 @@ use std::{borrow::Cow, future::Future, io::ErrorKind, pin::Pin};
 use anyhow::{bail, ensure, Context};
+use bytes::Bytes;
 use camino::{Utf8Path, Utf8PathBuf};
+use futures::stream::Stream;
 use tokio::{
     fs,
     io::{self, AsyncReadExt, AsyncSeekExt, AsyncWriteExt},
 };
+use tokio_util::io::ReaderStream;
 use tracing::*;
 use utils::{crashsafe::path_with_suffix_extension, fs_ext::is_directory_empty};
@@ -219,7 +222,7 @@ impl RemoteStorage for LocalFs {
     async fn upload(
         &self,
-        data: impl io::AsyncRead + Unpin + Send + Sync + 'static,
+        data: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync,
         data_size_bytes: usize,
         to: &RemotePath,
         metadata: Option<StorageMetadata>,
@@ -252,8 +255,11 @@ impl RemoteStorage for LocalFs {
         );
         let from_size_bytes = data_size_bytes as u64;
+        let data = tokio_util::io::StreamReader::new(data);
+        let data = std::pin::pin!(data);
         let mut buffer_to_read = data.take(from_size_bytes);
+        // alternatively we could just write the bytes to a file, but local_fs is a testing utility
         let bytes_read = io::copy(&mut buffer_to_read, &mut destination)
             .await
             .with_context(|| {
@@ -308,7 +314,7 @@ impl RemoteStorage for LocalFs {
     async fn download(&self, from: &RemotePath) -> Result<Download, DownloadError> {
         let target_path = from.with_base(&self.storage_root);
         if file_exists(&target_path).map_err(DownloadError::BadInput)? {
-            let source = io::BufReader::new(
+            let source = ReaderStream::new(
                 fs::OpenOptions::new()
                     .read(true)
                     .open(&target_path)
                     .await
@@ -348,16 +354,14 @@ impl RemoteStorage for LocalFs {
         }
         let target_path = from.with_base(&self.storage_root);
         if file_exists(&target_path).map_err(DownloadError::BadInput)? {
-            let mut source = io::BufReader::new(
-                fs::OpenOptions::new()
-                    .read(true)
-                    .open(&target_path)
-                    .await
-                    .with_context(|| {
-                        format!("Failed to open source file {target_path:?} to use in the download")
-                    })
-                    .map_err(DownloadError::Other)?,
-            );
+            let mut source = tokio::fs::OpenOptions::new()
+                .read(true)
+                .open(&target_path)
+                .await
+                .with_context(|| {
+                    format!("Failed to open source file {target_path:?} to use in the download")
+                })
+                .map_err(DownloadError::Other)?;
             source
                 .seek(io::SeekFrom::Start(start_inclusive))
                 .await
@@ -371,11 +375,13 @@ impl RemoteStorage for LocalFs {
             Ok(match end_exclusive {
                 Some(end_exclusive) => Download {
                     metadata,
-                    download_stream: Box::pin(source.take(end_exclusive - start_inclusive)),
+                    download_stream: Box::pin(ReaderStream::new(
+                        source.take(end_exclusive - start_inclusive),
+                    )),
                 },
                 None => Download {
                     metadata,
-                    download_stream: Box::pin(source),
+                    download_stream: Box::pin(ReaderStream::new(source)),
                 },
             })
         } else {
@@ -475,7 +481,9 @@ fn file_exists(file_path: &Utf8Path) -> anyhow::Result<bool> {
 mod fs_tests {
     use super::*;
+    use bytes::Bytes;
     use camino_tempfile::tempdir;
+    use futures_util::Stream;
     use std::{collections::HashMap, io::Write};
     async fn read_and_assert_remote_file_contents(
@@ -485,7 +493,7 @@ mod fs_tests {
         remote_storage_path: &RemotePath,
         expected_metadata: Option<&StorageMetadata>,
     ) -> anyhow::Result<String> {
-        let mut download = storage
+        let download = storage
            .download(remote_storage_path)
            .await
            .map_err(|e| anyhow::anyhow!("Download failed: {e}"))?;
@@ -494,13 +502,9 @@ mod fs_tests {
            "Unexpected metadata returned for the downloaded file"
        );
-        let mut contents = String::new();
-        download
-            .download_stream
-            .read_to_string(&mut contents)
-            .await
-            .context("Failed to read remote file contents into string")?;
-        Ok(contents)
+        let contents = aggregate(download.download_stream).await?;
+
+        String::from_utf8(contents).map_err(anyhow::Error::new)
     }
     #[tokio::test]
@@ -529,25 +533,26 @@ mod fs_tests {
        let storage = create_storage()?;
        let id = RemotePath::new(Utf8Path::new("dummy"))?;
-       let content = std::io::Cursor::new(b"12345");
+       let content = Bytes::from_static(b"12345");
+       let content = move || futures::stream::once(futures::future::ready(Ok(content.clone())));
        // Check that you get an error if the size parameter doesn't match the actual
        // size of the stream.
        storage
-           .upload(Box::new(content.clone()), 0, &id, None)
+           .upload(content(), 0, &id, None)
           .await
          .expect_err("upload with zero size succeeded");
        storage
-           .upload(Box::new(content.clone()), 4, &id, None)
+           .upload(content(), 4, &id, None)
          .await
          .expect_err("upload with too short size succeeded");
        storage
-           .upload(Box::new(content.clone()), 6, &id, None)
+           .upload(content(), 6, &id, None)
          .await
          .expect_err("upload with too large size succeeded");
        // Correct size is 5, this should succeed.
- storage.upload(Box::new(content), 5, &id, None).await?; + storage.upload(content(), 5, &id, None).await?; Ok(()) } @@ -595,7 +600,7 @@ mod fs_tests { let uploaded_bytes = dummy_contents(upload_name).into_bytes(); let (first_part_local, second_part_local) = uploaded_bytes.split_at(3); - let mut first_part_download = storage + let first_part_download = storage .download_byte_range(&upload_target, 0, Some(first_part_local.len() as u64)) .await?; assert!( @@ -603,21 +608,13 @@ mod fs_tests { "No metadata should be returned for no metadata upload" ); - let mut first_part_remote = io::BufWriter::new(std::io::Cursor::new(Vec::new())); - io::copy( - &mut first_part_download.download_stream, - &mut first_part_remote, - ) - .await?; - first_part_remote.flush().await?; - let first_part_remote = first_part_remote.into_inner().into_inner(); + let first_part_remote = aggregate(first_part_download.download_stream).await?; assert_eq!( - first_part_local, - first_part_remote.as_slice(), + first_part_local, first_part_remote, "First part bytes should be returned when requested" ); - let mut second_part_download = storage + let second_part_download = storage .download_byte_range( &upload_target, first_part_local.len() as u64, @@ -629,17 +626,9 @@ mod fs_tests { "No metadata should be returned for no metadata upload" ); - let mut second_part_remote = io::BufWriter::new(std::io::Cursor::new(Vec::new())); - io::copy( - &mut second_part_download.download_stream, - &mut second_part_remote, - ) - .await?; - second_part_remote.flush().await?; - let second_part_remote = second_part_remote.into_inner().into_inner(); + let second_part_remote = aggregate(second_part_download.download_stream).await?; assert_eq!( - second_part_local, - second_part_remote.as_slice(), + second_part_local, second_part_remote, "Second part bytes should be returned when requested" ); @@ -729,17 +718,10 @@ mod fs_tests { let uploaded_bytes = dummy_contents(upload_name).into_bytes(); let (first_part_local, _) = uploaded_bytes.split_at(3); - let mut partial_download_with_metadata = storage + let partial_download_with_metadata = storage .download_byte_range(&upload_target, 0, Some(first_part_local.len() as u64)) .await?; - let mut first_part_remote = io::BufWriter::new(std::io::Cursor::new(Vec::new())); - io::copy( - &mut partial_download_with_metadata.download_stream, - &mut first_part_remote, - ) - .await?; - first_part_remote.flush().await?; - let first_part_remote = first_part_remote.into_inner().into_inner(); + let first_part_remote = aggregate(partial_download_with_metadata.download_stream).await?; assert_eq!( first_part_local, first_part_remote.as_slice(), @@ -815,16 +797,16 @@ mod fs_tests { ) })?; - storage - .upload(Box::new(file), size, &relative_path, metadata) - .await?; + let file = tokio_util::io::ReaderStream::new(file); + + storage.upload(file, size, &relative_path, metadata).await?; Ok(relative_path) } async fn create_file_for_upload( path: &Utf8Path, contents: &str, - ) -> anyhow::Result<(io::BufReader, usize)> { + ) -> anyhow::Result<(fs::File, usize)> { std::fs::create_dir_all(path.parent().unwrap())?; let mut file_for_writing = std::fs::OpenOptions::new() .write(true) @@ -834,7 +816,7 @@ mod fs_tests { drop(file_for_writing); let file_size = path.metadata()?.len() as usize; Ok(( - io::BufReader::new(fs::OpenOptions::new().read(true).open(&path).await?), + fs::OpenOptions::new().read(true).open(&path).await?, file_size, )) } @@ -848,4 +830,16 @@ mod fs_tests { files.sort_by(|a, b| a.0.cmp(&b.0)); Ok(files) } + + async fn 
aggregate(
+        stream: impl Stream<Item = std::io::Result<Bytes>>,
+    ) -> anyhow::Result<Vec<u8>> {
+        use futures::stream::StreamExt;
+        let mut out = Vec::new();
+        let mut stream = std::pin::pin!(stream);
+        while let Some(res) = stream.next().await {
+            out.extend_from_slice(&res?[..]);
+        }
+        Ok(out)
+    }
 }
diff --git a/libs/remote_storage/src/s3_bucket.rs b/libs/remote_storage/src/s3_bucket.rs
index 3016a14ec9..97fa1bbf5b 100644
--- a/libs/remote_storage/src/s3_bucket.rs
+++ b/libs/remote_storage/src/s3_bucket.rs
@@ -4,9 +4,14 @@
 //! allowing multiple api users to independently work with the same S3 bucket, if
 //! their bucket prefixes are both specified and different.
-use std::{borrow::Cow, sync::Arc};
+use std::{
+    borrow::Cow,
+    pin::Pin,
+    sync::Arc,
+    task::{Context, Poll},
+};
-use anyhow::Context;
+use anyhow::Context as _;
 use aws_config::{
     environment::credentials::EnvironmentVariableCredentialsProvider,
     imds::credentials::ImdsCredentialsProvider,
@@ -28,11 +33,10 @@ use aws_smithy_async::rt::sleep::TokioSleep;
 use aws_smithy_types::body::SdkBody;
 use aws_smithy_types::byte_stream::ByteStream;
+use bytes::Bytes;
+use futures::stream::Stream;
 use hyper::Body;
 use scopeguard::ScopeGuard;
-use tokio::io::{self, AsyncRead};
-use tokio_util::io::ReaderStream;
-use tracing::debug;
 use super::StorageMetadata;
 use crate::{
@@ -63,7 +67,7 @@ struct GetObjectRequest {
 impl S3Bucket {
     /// Creates the S3 storage, errors if incorrect AWS S3 configuration provided.
     pub fn new(aws_config: &S3Config) -> anyhow::Result<Self> {
-        debug!(
+        tracing::debug!(
             "Creating s3 remote storage for S3 bucket {}",
             aws_config.bucket_name
         );
@@ -225,12 +229,15 @@ impl S3Bucket {
         match get_object {
             Ok(object_output) => {
                 let metadata = object_output.metadata().cloned().map(StorageMetadata);
+
+                let body = object_output.body;
+                let body = ByteStreamAsStream::from(body);
+                let body = PermitCarrying::new(permit, body);
+                let body = TimedDownload::new(started_at, body);
+
                 Ok(Download {
                     metadata,
-                    download_stream: Box::pin(io::BufReader::new(TimedDownload::new(
-                        started_at,
-                        RatelimitedAsyncRead::new(permit, object_output.body.into_async_read()),
-                    ))),
+                    download_stream: Box::pin(body),
                 })
             }
             Err(SdkError::ServiceError(e)) if matches!(e.err(), GetObjectError::NoSuchKey(_)) => {
@@ -243,29 +250,55 @@ impl S3Bucket {
     }
 }
+pin_project_lite::pin_project! {
+    struct ByteStreamAsStream {
+        #[pin]
+        inner: aws_smithy_types::byte_stream::ByteStream
+    }
+}
+
+impl From<aws_smithy_types::byte_stream::ByteStream> for ByteStreamAsStream {
+    fn from(inner: aws_smithy_types::byte_stream::ByteStream) -> Self {
+        ByteStreamAsStream { inner }
+    }
+}
+
+impl Stream for ByteStreamAsStream {
+    type Item = std::io::Result<Bytes>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        // this does the std::io::ErrorKind::Other conversion
+        self.project().inner.poll_next(cx).map_err(|x| x.into())
+    }
+
+    // cannot implement size_hint because inner.size_hint is remaining size in bytes, which makes
+    // sense and Stream::size_hint does not really
+}
+
 pin_project_lite::pin_project! {
     /// An `AsyncRead` adapter which carries a permit for the lifetime of the value.
-    struct RatelimitedAsyncRead<S> {
+    struct PermitCarrying<S> {
         permit: tokio::sync::OwnedSemaphorePermit,
         #[pin]
         inner: S,
     }
 }
-impl<S: AsyncRead> RatelimitedAsyncRead<S> {
+impl<S> PermitCarrying<S> {
     fn new(permit: tokio::sync::OwnedSemaphorePermit, inner: S) -> Self {
-        RatelimitedAsyncRead { permit, inner }
+        Self { permit, inner }
     }
 }
-impl<S: AsyncRead> AsyncRead for RatelimitedAsyncRead<S> {
-    fn poll_read(
-        self: std::pin::Pin<&mut Self>,
-        cx: &mut std::task::Context<'_>,
-        buf: &mut io::ReadBuf<'_>,
-    ) -> std::task::Poll<io::Result<()>> {
-        let this = self.project();
-        this.inner.poll_read(cx, buf)
+impl<S: Stream<Item = std::io::Result<Bytes>>> Stream for PermitCarrying<S> {
+    type Item = <S as Stream>::Item;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        self.project().inner.poll_next(cx)
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.inner.size_hint()
     }
 }
@@ -285,7 +318,7 @@ pin_project_lite::pin_project! {
     }
 }
-impl<S: AsyncRead> TimedDownload<S> {
+impl<S> TimedDownload<S> {
     fn new(started_at: std::time::Instant, inner: S) -> Self {
         TimedDownload {
             started_at,
@@ -295,25 +328,26 @@ impl<S> TimedDownload<S> {
         }
     }
-impl<S: AsyncRead> AsyncRead for TimedDownload<S> {
-    fn poll_read(
-        self: std::pin::Pin<&mut Self>,
-        cx: &mut std::task::Context<'_>,
-        buf: &mut io::ReadBuf<'_>,
-    ) -> std::task::Poll<io::Result<()>> {
+impl<S: Stream<Item = std::io::Result<Bytes>>> Stream for TimedDownload<S> {
+    type Item = <S as Stream>::Item;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        use std::task::ready;
+
         let this = self.project();
-        let before = buf.filled().len();
-        let read = std::task::ready!(this.inner.poll_read(cx, buf));
-        let read_eof = buf.filled().len() == before;
-
-        match read {
-            Ok(()) if read_eof => *this.outcome = AttemptOutcome::Ok,
-            Ok(()) => { /* still in progress */ }
-            Err(_) => *this.outcome = AttemptOutcome::Err,
+        let res = ready!(this.inner.poll_next(cx));
+        match &res {
+            Some(Ok(_)) => {}
+            Some(Err(_)) => *this.outcome = metrics::AttemptOutcome::Err,
+            None => *this.outcome = metrics::AttemptOutcome::Ok,
         }
-        std::task::Poll::Ready(read)
+        Poll::Ready(res)
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.inner.size_hint()
     }
 }
@@ -403,7 +437,7 @@ impl RemoteStorage for S3Bucket {
     async fn upload(
         &self,
-        from: impl io::AsyncRead + Unpin + Send + Sync + 'static,
+        from: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
         from_size_bytes: usize,
         to: &RemotePath,
         metadata: Option<StorageMetadata>,
@@ -413,7 +447,7 @@ impl RemoteStorage for S3Bucket {
         let started_at = start_measuring_requests(kind);
-        let body = Body::wrap_stream(ReaderStream::new(from));
+        let body = Body::wrap_stream(from);
         let bytes_stream = ByteStream::new(SdkBody::from_body_0_4(body));
         let res = self
diff --git a/libs/remote_storage/src/simulate_failures.rs b/libs/remote_storage/src/simulate_failures.rs
index cd13db1923..802b0db7f5 100644
--- a/libs/remote_storage/src/simulate_failures.rs
+++ b/libs/remote_storage/src/simulate_failures.rs
@@ -1,6 +1,8 @@
 //! This module provides a wrapper around a real RemoteStorage implementation that
 //! causes the first N attempts at each upload or download operation to fail. For
 //! testing purposes.
+use bytes::Bytes;
+use futures::stream::Stream;
 use std::collections::hash_map::Entry;
 use std::collections::HashMap;
 use std::sync::Mutex;
@@ -108,7 +110,7 @@ impl RemoteStorage for UnreliableWrapper {
     async fn upload(
         &self,
-        data: impl tokio::io::AsyncRead + Unpin + Send + Sync + 'static,
+        data: impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
         // S3 PUT request requires the content length to be specified,
         // otherwise it starts to fail with the concurrent connection count increasing.
data_size_bytes: usize, diff --git a/libs/remote_storage/tests/test_real_azure.rs b/libs/remote_storage/tests/test_real_azure.rs index b631079bc5..7327803198 100644 --- a/libs/remote_storage/tests/test_real_azure.rs +++ b/libs/remote_storage/tests/test_real_azure.rs @@ -7,7 +7,9 @@ use std::sync::Arc; use std::time::UNIX_EPOCH; use anyhow::Context; +use bytes::Bytes; use camino::Utf8Path; +use futures::stream::Stream; use once_cell::sync::OnceCell; use remote_storage::{ AzureConfig, Download, GenericRemoteStorage, RemotePath, RemoteStorageConfig, RemoteStorageKind, @@ -180,23 +182,14 @@ async fn azure_delete_objects_works(ctx: &mut MaybeEnabledAzure) -> anyhow::Resu let path3 = RemotePath::new(Utf8Path::new(format!("{}/path3", ctx.base_prefix).as_str())) .with_context(|| "RemotePath conversion")?; - let data1 = "remote blob data1".as_bytes(); - let data1_len = data1.len(); - let data2 = "remote blob data2".as_bytes(); - let data2_len = data2.len(); - let data3 = "remote blob data3".as_bytes(); - let data3_len = data3.len(); - ctx.client - .upload(std::io::Cursor::new(data1), data1_len, &path1, None) - .await?; + let (data, len) = upload_stream("remote blob data1".as_bytes().into()); + ctx.client.upload(data, len, &path1, None).await?; - ctx.client - .upload(std::io::Cursor::new(data2), data2_len, &path2, None) - .await?; + let (data, len) = upload_stream("remote blob data2".as_bytes().into()); + ctx.client.upload(data, len, &path2, None).await?; - ctx.client - .upload(std::io::Cursor::new(data3), data3_len, &path3, None) - .await?; + let (data, len) = upload_stream("remote blob data3".as_bytes().into()); + ctx.client.upload(data, len, &path3, None).await?; ctx.client.delete_objects(&[path1, path2]).await?; @@ -219,53 +212,56 @@ async fn azure_upload_download_works(ctx: &mut MaybeEnabledAzure) -> anyhow::Res let path = RemotePath::new(Utf8Path::new(format!("{}/file", ctx.base_prefix).as_str())) .with_context(|| "RemotePath conversion")?; - let data = "remote blob data here".as_bytes(); - let data_len = data.len() as u64; + let orig = bytes::Bytes::from_static("remote blob data here".as_bytes()); - ctx.client - .upload(std::io::Cursor::new(data), data.len(), &path, None) - .await?; + let (data, len) = wrap_stream(orig.clone()); - async fn download_and_compare(mut dl: Download) -> anyhow::Result> { + ctx.client.upload(data, len, &path, None).await?; + + async fn download_and_compare(dl: Download) -> anyhow::Result> { let mut buf = Vec::new(); - tokio::io::copy(&mut dl.download_stream, &mut buf).await?; + tokio::io::copy_buf( + &mut tokio_util::io::StreamReader::new(dl.download_stream), + &mut buf, + ) + .await?; Ok(buf) } // Normal download request let dl = ctx.client.download(&path).await?; let buf = download_and_compare(dl).await?; - assert_eq!(buf, data); + assert_eq!(&buf, &orig); // Full range (end specified) let dl = ctx .client - .download_byte_range(&path, 0, Some(data_len)) + .download_byte_range(&path, 0, Some(len as u64)) .await?; let buf = download_and_compare(dl).await?; - assert_eq!(buf, data); + assert_eq!(&buf, &orig); // partial range (end specified) let dl = ctx.client.download_byte_range(&path, 4, Some(10)).await?; let buf = download_and_compare(dl).await?; - assert_eq!(buf, data[4..10]); + assert_eq!(&buf, &orig[4..10]); // partial range (end beyond real end) let dl = ctx .client - .download_byte_range(&path, 8, Some(data_len * 100)) + .download_byte_range(&path, 8, Some(len as u64 * 100)) .await?; let buf = download_and_compare(dl).await?; - assert_eq!(buf, data[8..]); + 
assert_eq!(&buf, &orig[8..]); // Partial range (end unspecified) let dl = ctx.client.download_byte_range(&path, 4, None).await?; let buf = download_and_compare(dl).await?; - assert_eq!(buf, data[4..]); + assert_eq!(&buf, &orig[4..]); // Full range (end unspecified) let dl = ctx.client.download_byte_range(&path, 0, None).await?; let buf = download_and_compare(dl).await?; - assert_eq!(buf, data); + assert_eq!(&buf, &orig); debug!("Cleanup: deleting file at path {path:?}"); ctx.client @@ -504,11 +500,8 @@ async fn upload_azure_data( let blob_path = blob_prefix.join(Utf8Path::new(&format!("blob_{i}"))); debug!("Creating remote item {i} at path {blob_path:?}"); - let data = format!("remote blob data {i}").into_bytes(); - let data_len = data.len(); - task_client - .upload(std::io::Cursor::new(data), data_len, &blob_path, None) - .await?; + let (data, len) = upload_stream(format!("remote blob data {i}").into_bytes().into()); + task_client.upload(data, len, &blob_path, None).await?; Ok::<_, anyhow::Error>((blob_prefix, blob_path)) }); @@ -589,11 +582,8 @@ async fn upload_simple_azure_data( .with_context(|| format!("{blob_path:?} to RemotePath conversion"))?; debug!("Creating remote item {i} at path {blob_path:?}"); - let data = format!("remote blob data {i}").into_bytes(); - let data_len = data.len(); - task_client - .upload(std::io::Cursor::new(data), data_len, &blob_path, None) - .await?; + let (data, len) = upload_stream(format!("remote blob data {i}").into_bytes().into()); + task_client.upload(data, len, &blob_path, None).await?; Ok::<_, anyhow::Error>(blob_path) }); @@ -622,3 +612,32 @@ async fn upload_simple_azure_data( ControlFlow::Continue(uploaded_blobs) } } + +// FIXME: copypasted from test_real_s3, can't remember how to share a module which is not compiled +// to binary +fn upload_stream( + content: std::borrow::Cow<'static, [u8]>, +) -> ( + impl Stream> + Send + Sync + 'static, + usize, +) { + use std::borrow::Cow; + + let content = match content { + Cow::Borrowed(x) => Bytes::from_static(x), + Cow::Owned(vec) => Bytes::from(vec), + }; + wrap_stream(content) +} + +fn wrap_stream( + content: bytes::Bytes, +) -> ( + impl Stream> + Send + Sync + 'static, + usize, +) { + let len = content.len(); + let content = futures::future::ready(Ok(content)); + + (futures::stream::once(content), len) +} diff --git a/libs/remote_storage/tests/test_real_s3.rs b/libs/remote_storage/tests/test_real_s3.rs index 48f00e0106..ecd834e61c 100644 --- a/libs/remote_storage/tests/test_real_s3.rs +++ b/libs/remote_storage/tests/test_real_s3.rs @@ -7,7 +7,9 @@ use std::sync::Arc; use std::time::UNIX_EPOCH; use anyhow::Context; +use bytes::Bytes; use camino::Utf8Path; +use futures::stream::Stream; use once_cell::sync::OnceCell; use remote_storage::{ GenericRemoteStorage, RemotePath, RemoteStorageConfig, RemoteStorageKind, S3Config, @@ -176,23 +178,14 @@ async fn s3_delete_objects_works(ctx: &mut MaybeEnabledS3) -> anyhow::Result<()> let path3 = RemotePath::new(Utf8Path::new(format!("{}/path3", ctx.base_prefix).as_str())) .with_context(|| "RemotePath conversion")?; - let data1 = "remote blob data1".as_bytes(); - let data1_len = data1.len(); - let data2 = "remote blob data2".as_bytes(); - let data2_len = data2.len(); - let data3 = "remote blob data3".as_bytes(); - let data3_len = data3.len(); - ctx.client - .upload(std::io::Cursor::new(data1), data1_len, &path1, None) - .await?; + let (data, len) = upload_stream("remote blob data1".as_bytes().into()); + ctx.client.upload(data, len, &path1, None).await?; - ctx.client - 
.upload(std::io::Cursor::new(data2), data2_len, &path2, None) - .await?; + let (data, len) = upload_stream("remote blob data2".as_bytes().into()); + ctx.client.upload(data, len, &path2, None).await?; - ctx.client - .upload(std::io::Cursor::new(data3), data3_len, &path3, None) - .await?; + let (data, len) = upload_stream("remote blob data3".as_bytes().into()); + ctx.client.upload(data, len, &path3, None).await?; ctx.client.delete_objects(&[path1, path2]).await?; @@ -432,11 +425,9 @@ async fn upload_s3_data( let blob_path = blob_prefix.join(Utf8Path::new(&format!("blob_{i}"))); debug!("Creating remote item {i} at path {blob_path:?}"); - let data = format!("remote blob data {i}").into_bytes(); - let data_len = data.len(); - task_client - .upload(std::io::Cursor::new(data), data_len, &blob_path, None) - .await?; + let (data, data_len) = + upload_stream(format!("remote blob data {i}").into_bytes().into()); + task_client.upload(data, data_len, &blob_path, None).await?; Ok::<_, anyhow::Error>((blob_prefix, blob_path)) }); @@ -517,11 +508,9 @@ async fn upload_simple_s3_data( .with_context(|| format!("{blob_path:?} to RemotePath conversion"))?; debug!("Creating remote item {i} at path {blob_path:?}"); - let data = format!("remote blob data {i}").into_bytes(); - let data_len = data.len(); - task_client - .upload(std::io::Cursor::new(data), data_len, &blob_path, None) - .await?; + let (data, data_len) = + upload_stream(format!("remote blob data {i}").into_bytes().into()); + task_client.upload(data, data_len, &blob_path, None).await?; Ok::<_, anyhow::Error>(blob_path) }); @@ -550,3 +539,30 @@ async fn upload_simple_s3_data( ControlFlow::Continue(uploaded_blobs) } } + +fn upload_stream( + content: std::borrow::Cow<'static, [u8]>, +) -> ( + impl Stream> + Send + Sync + 'static, + usize, +) { + use std::borrow::Cow; + + let content = match content { + Cow::Borrowed(x) => Bytes::from_static(x), + Cow::Owned(vec) => Bytes::from(vec), + }; + wrap_stream(content) +} + +fn wrap_stream( + content: bytes::Bytes, +) -> ( + impl Stream> + Send + Sync + 'static, + usize, +) { + let len = content.len(); + let content = futures::future::ready(Ok(content)); + + (futures::stream::once(content), len) +} diff --git a/pageserver/src/tenant/delete.rs b/pageserver/src/tenant/delete.rs index 548b173c0d..b8d6d0a321 100644 --- a/pageserver/src/tenant/delete.rs +++ b/pageserver/src/tenant/delete.rs @@ -77,8 +77,10 @@ async fn create_remote_delete_mark( let data: &[u8] = &[]; backoff::retry( || async { + let data = bytes::Bytes::from_static(data); + let stream = futures::stream::once(futures::future::ready(Ok(data))); remote_storage - .upload(data, 0, &remote_mark_path, None) + .upload(stream, 0, &remote_mark_path, None) .await }, |_e| false, diff --git a/pageserver/src/tenant/remote_timeline_client/download.rs b/pageserver/src/tenant/remote_timeline_client/download.rs index deb5ea84a8..3356f55f34 100644 --- a/pageserver/src/tenant/remote_timeline_client/download.rs +++ b/pageserver/src/tenant/remote_timeline_client/download.rs @@ -75,12 +75,11 @@ pub async fn download_layer_file<'a>( let (mut destination_file, bytes_amount) = download_retry( || async { - // TODO: this doesn't use the cached fd for some reason? 
- let mut destination_file = fs::File::create(&temp_file_path) + let destination_file = tokio::fs::File::create(&temp_file_path) .await .with_context(|| format!("create a destination file for layer '{temp_file_path}'")) .map_err(DownloadError::Other)?; - let mut download = storage + let download = storage .download(&remote_path) .await .with_context(|| { @@ -90,9 +89,14 @@ pub async fn download_layer_file<'a>( }) .map_err(DownloadError::Other)?; + let mut destination_file = + tokio::io::BufWriter::with_capacity(8 * 1024, destination_file); + + let mut reader = tokio_util::io::StreamReader::new(download.download_stream); + let bytes_amount = tokio::time::timeout( MAX_DOWNLOAD_DURATION, - tokio::io::copy(&mut download.download_stream, &mut destination_file), + tokio::io::copy_buf(&mut reader, &mut destination_file), ) .await .map_err(|e| DownloadError::Other(anyhow::anyhow!("Timed out {:?}", e)))? @@ -103,6 +107,8 @@ pub async fn download_layer_file<'a>( }) .map_err(DownloadError::Other)?; + let destination_file = destination_file.into_inner(); + Ok((destination_file, bytes_amount)) }, &format!("download {remote_path:?}"), @@ -220,20 +226,22 @@ async fn do_download_index_part( index_generation: Generation, cancel: CancellationToken, ) -> Result { + use futures::stream::StreamExt; + let remote_path = remote_index_path(tenant_shard_id, timeline_id, index_generation); let index_part_bytes = download_retry_forever( || async { - let mut index_part_download = storage.download(&remote_path).await?; + let index_part_download = storage.download(&remote_path).await?; let mut index_part_bytes = Vec::new(); - tokio::io::copy( - &mut index_part_download.download_stream, - &mut index_part_bytes, - ) - .await - .with_context(|| format!("download index part at {remote_path:?}")) - .map_err(DownloadError::Other)?; + let mut stream = std::pin::pin!(index_part_download.download_stream); + while let Some(chunk) = stream.next().await { + let chunk = chunk + .with_context(|| format!("download index part at {remote_path:?}")) + .map_err(DownloadError::Other)?; + index_part_bytes.extend_from_slice(&chunk[..]); + } Ok(index_part_bytes) }, &format!("download {remote_path:?}"), @@ -398,7 +406,7 @@ pub(crate) async fn download_initdb_tar_zst( let file = download_retry( || async { - let mut file = OpenOptions::new() + let file = OpenOptions::new() .create(true) .truncate(true) .read(true) @@ -408,13 +416,17 @@ pub(crate) async fn download_initdb_tar_zst( .with_context(|| format!("tempfile creation {temp_path}")) .map_err(DownloadError::Other)?; - let mut download = storage.download(&remote_path).await?; + let download = storage.download(&remote_path).await?; + let mut download = tokio_util::io::StreamReader::new(download.download_stream); + let mut writer = tokio::io::BufWriter::with_capacity(8 * 1024, file); - tokio::io::copy(&mut download.download_stream, &mut file) + tokio::io::copy_buf(&mut download, &mut writer) .await .with_context(|| format!("download initdb.tar.zst at {remote_path:?}")) .map_err(DownloadError::Other)?; + let mut file = writer.into_inner(); + file.seek(std::io::SeekFrom::Start(0)) .await .with_context(|| format!("rewinding initdb.tar.zst at: {remote_path:?}")) diff --git a/pageserver/src/tenant/remote_timeline_client/upload.rs b/pageserver/src/tenant/remote_timeline_client/upload.rs index 4ca4438003..0ec539a64e 100644 --- a/pageserver/src/tenant/remote_timeline_client/upload.rs +++ b/pageserver/src/tenant/remote_timeline_client/upload.rs @@ -41,11 +41,15 @@ pub(super) async fn 
upload_index_part<'a>( .to_s3_bytes() .context("serialize index part file into bytes")?; let index_part_size = index_part_bytes.len(); - let index_part_bytes = tokio::io::BufReader::new(std::io::Cursor::new(index_part_bytes)); + let index_part_bytes = bytes::Bytes::from(index_part_bytes); let remote_path = remote_index_path(tenant_shard_id, timeline_id, generation); storage - .upload_storage_object(Box::new(index_part_bytes), index_part_size, &remote_path) + .upload_storage_object( + futures::stream::once(futures::future::ready(Ok(index_part_bytes))), + index_part_size, + &remote_path, + ) .await .with_context(|| format!("upload index part for '{tenant_shard_id} / {timeline_id}'")) } @@ -101,8 +105,10 @@ pub(super) async fn upload_timeline_layer<'a>( let fs_size = usize::try_from(fs_size) .with_context(|| format!("convert {source_path:?} size {fs_size} usize"))?; + let reader = tokio_util::io::ReaderStream::with_capacity(source_file, 8 * 1024); + storage - .upload(source_file, fs_size, &storage_path, None) + .upload(reader, fs_size, &storage_path, None) .await .with_context(|| format!("upload layer from local path '{source_path}'"))?; @@ -119,7 +125,8 @@ pub(crate) async fn upload_initdb_dir( tracing::trace!("uploading initdb dir"); let size = initdb_dir.len(); - let bytes = tokio::io::BufReader::new(std::io::Cursor::new(initdb_dir)); + + let bytes = futures::stream::once(futures::future::ready(Ok(initdb_dir))); let remote_path = remote_initdb_archive_path(tenant_id, timeline_id); storage diff --git a/safekeeper/Cargo.toml b/safekeeper/Cargo.toml index 53fcd5ff07..cccb4ebd79 100644 --- a/safekeeper/Cargo.toml +++ b/safekeeper/Cargo.toml @@ -35,6 +35,7 @@ serde_with.workspace = true signal-hook.workspace = true thiserror.workspace = true tokio = { workspace = true, features = ["fs"] } +tokio-util = { workspace = true } tokio-io-timeout.workspace = true tokio-postgres.workspace = true toml_edit.workspace = true diff --git a/safekeeper/src/wal_backup.rs b/safekeeper/src/wal_backup.rs index 22c68ce3c9..2e2cb11e3f 100644 --- a/safekeeper/src/wal_backup.rs +++ b/safekeeper/src/wal_backup.rs @@ -494,15 +494,13 @@ async fn backup_object( .as_ref() .unwrap(); - let file = tokio::io::BufReader::new( - File::open(&source_file) - .await - .with_context(|| format!("Failed to open file {} for wal backup", source_file))?, - ); - - storage - .upload_storage_object(Box::new(file), size, target_file) + let file = File::open(&source_file) .await + .with_context(|| format!("Failed to open file {source_file:?} for wal backup"))?; + + let file = tokio_util::io::ReaderStream::with_capacity(file, 8 * 1024); + + storage.upload_storage_object(file, size, target_file).await } pub async fn read_object( @@ -524,5 +522,9 @@ pub async fn read_object( format!("Failed to open WAL segment download stream for remote path {file_path:?}") })?; - Ok(download.download_stream) + let reader = tokio_util::io::StreamReader::new(download.download_stream); + + let reader = tokio::io::BufReader::with_capacity(8 * 1024, reader); + + Ok(Box::pin(reader)) } diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 3e46731adf..82945dfacb 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -65,7 +65,7 @@ subtle = { version = "2" } time = { version = "0.3", features = ["local-offset", "macros", "serde-well-known"] } tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] } tokio-rustls = { version = "0.24" } -tokio-util = { 
version = "0.7", features = ["codec", "io"] } +tokio-util = { version = "0.7", features = ["codec", "compat", "io"] } toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } toml_edit = { version = "0.19", features = ["serde"] } tower = { version = "0.4", default-features = false, features = ["balance", "buffer", "limit", "log", "timeout", "util"] }
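// For reference, a minimal sketch (not part of the diff) of how call sites feed the new
// `Stream<Item = std::io::Result<Bytes>>` upload interface, mirroring the `wrap_stream`
// test helper and the `ReaderStream` call sites above. The `upload_examples` function,
// its `storage`/`path` parameters, and the file name are hypothetical placeholders.
use bytes::Bytes;
use remote_storage::{GenericRemoteStorage, RemotePath};

async fn upload_examples(storage: &GenericRemoteStorage, path: &RemotePath) -> anyhow::Result<()> {
    // In-memory payload: a single-chunk stream; the length is passed separately because
    // S3 PUT (and the Azure SeekableStream wrapper) need the content length up front.
    let body = Bytes::from_static(b"remote blob data");
    let len = body.len();
    let stream = futures::stream::once(futures::future::ready(Ok(body)));
    storage.upload(stream, len, path, None).await?;

    // File payload: wrap the tokio file in a ReaderStream instead of handing over the
    // AsyncRead itself, as upload_timeline_layer and wal_backup now do.
    let file = tokio::fs::File::open("layer_file").await?;
    let size = file.metadata().await?.len() as usize;
    let stream = tokio_util::io::ReaderStream::with_capacity(file, 8 * 1024);
    storage.upload(stream, size, path, None).await?;

    Ok(())
}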