diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 63573d49c0..9936176695 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -311,16 +311,10 @@ impl Client<'_, S> { .await?; let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("tx")); - let mut client = MeasuredStream::new(stream.into_inner(), |cnt| { - // Number of bytes we sent to the client (outbound). - m_sent.inc_by(cnt as u64); - }); + let mut client = MeasuredStream::new(stream.into_inner(), m_sent); let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("rx")); - let mut db = MeasuredStream::new(db.stream, |cnt| { - // Number of bytes the client sent to the compute node (inbound). - m_recv.inc_by(cnt as u64); - }); + let mut db = MeasuredStream::new(db.stream, m_recv); // Starting from here we only proxy the client's traffic. info!("performing the proxy pass..."); diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 02a0fabe9a..27b13d3319 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -228,27 +228,27 @@ impl AsyncWrite for Stream { } pin_project! { - /// This stream tracks all writes and calls user provided - /// callback when the underlying stream is flushed. - pub struct MeasuredStream { + /// This stream tracks all writes, and whenever the stream is flushed, + /// increments the user-provided counter by the number of bytes flushed. + pub struct MeasuredStream { #[pin] stream: S, write_count: usize, - inc_write_count: W, + write_counter: prometheus::IntCounter, } } -impl MeasuredStream { - pub fn new(stream: S, inc_write_count: W) -> Self { +impl MeasuredStream { + pub fn new(stream: S, write_counter: prometheus::IntCounter) -> Self { Self { stream, write_count: 0, - inc_write_count, + write_counter, } } } -impl AsyncRead for MeasuredStream { +impl AsyncRead for MeasuredStream { fn poll_read( self: Pin<&mut Self>, context: &mut task::Context<'_>, @@ -258,7 +258,7 @@ impl AsyncRead for MeasuredStream { } } -impl AsyncWrite for MeasuredStream { +impl AsyncWrite for MeasuredStream { fn poll_write( self: Pin<&mut Self>, context: &mut task::Context<'_>, @@ -279,7 +279,7 @@ impl AsyncWrite for MeasuredStream let this = self.project(); this.stream.poll_flush(context).map_ok(|()| { // Call the user provided callback and reset the write count. - (this.inc_write_count)(*this.write_count); + this.write_counter.inc_by(*this.write_count as u64); *this.write_count = 0; }) }