diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index c98b267beb..2831e69676 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -1,5 +1,4 @@ use std::fs; -use std::io::Read; use std::os::unix::fs::PermissionsExt; use std::path::Path; use std::process::{Command, Stdio}; @@ -16,6 +15,7 @@ use utils::lsn::Lsn; use compute_api::responses::{ComputeMetrics, ComputeStatus}; use compute_api::spec::{ComputeMode, ComputeSpec}; +use utils::measured_stream::MeasuredReader; use crate::config; use crate::pg_helpers::*; @@ -134,50 +134,6 @@ impl TryFrom for ParsedSpec { } } -/// Wrapper for a reader that counts bytes and reports metrics. -/// -/// HACK The interface of this struct is a little funny, mostly because we want -/// to use it as input for tar::Archive::new(reader), which for some reason -/// takes ownership of the reader instead of just &mut. So we can't access -/// the reader to read the byte count because we lose ownership. Instead we -/// pass the ComputeNode inside the struct and update metrics on Drop. -struct ByteCounter<'a, R: Read> { - inner: R, - byte_count: usize, - compute_node: &'a ComputeNode, -} - -impl<'a, R: Read> ByteCounter<'a, R> { - fn new(reader: R, compute_node: &'a ComputeNode) -> Self { - Self { - inner: reader, - byte_count: 0, - compute_node, - } - } -} - -impl Read for ByteCounter<'_, R> { - fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - let result = self.inner.read(buf); - if let Ok(n_bytes) = result { - self.byte_count += n_bytes - } - result - } -} - -impl Drop for ByteCounter<'_, R> { - fn drop(&mut self) { - self.compute_node - .state - .lock() - .unwrap() - .metrics - .basebackup_bytes = self.byte_count as u64; - } -} - impl ComputeNode { pub fn set_status(&self, status: ComputeStatus) { let mut state = self.state.lock().unwrap(); @@ -224,17 +180,20 @@ impl ComputeNode { _ => format!("basebackup {} {} {}", spec.tenant_id, spec.timeline_id, lsn), }; let copyreader = client.copy_out(basebackup_cmd.as_str())?; - let read_counter = ByteCounter::new(copyreader, self); + let mut measured_reader = MeasuredReader::new(copyreader); // Read the archive directly from the `CopyOutReader` // // Set `ignore_zeros` so that unpack() reads all the Copy data and // doesn't stop at the end-of-archive marker. Otherwise, if the server // sends an Error after finishing the tarball, we will not notice it. - let mut ar = tar::Archive::new(read_counter); + let mut ar = tar::Archive::new(&mut measured_reader); ar.set_ignore_zeros(true); ar.unpack(&self.pgdata)?; + // Report metrics + self.state.lock().unwrap().metrics.basebackup_bytes = + measured_reader.get_byte_count() as u64; self.state.lock().unwrap().metrics.basebackup_ms = Utc::now() .signed_duration_since(start_time) .to_std() diff --git a/libs/utils/src/measured_stream.rs b/libs/utils/src/measured_stream.rs index c37d686a1d..c82fc13109 100644 --- a/libs/utils/src/measured_stream.rs +++ b/libs/utils/src/measured_stream.rs @@ -1,4 +1,5 @@ use pin_project_lite::pin_project; +use std::io::Read; use std::pin::Pin; use std::{io, task}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; @@ -75,3 +76,34 @@ impl AsyncWrite for MeasuredStream { + inner: R, + byte_count: usize, +} + +impl MeasuredReader { + pub fn new(reader: R) -> Self { + Self { + inner: reader, + byte_count: 0, + } + } + + pub fn get_byte_count(&self) -> usize { + self.byte_count + } +} + +impl Read for MeasuredReader { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let result = self.inner.read(buf); + if let Ok(n_bytes) = result { + self.byte_count += n_bytes + } + result + } +}