Make postgres_backend use generic IO type (#3789)

- Support measuring inbound and outbound traffic in MeasuredStream
- Start using MeasuredStream in safekeepers code
This commit is contained in:
Arthur Petukhovsky
2023-03-13 12:18:10 +03:00
committed by GitHub
parent 8699342249
commit d9a1329834
16 changed files with 234 additions and 154 deletions

View File

@@ -4,7 +4,7 @@ use crate::{
};
use anyhow::Context;
use once_cell::sync::Lazy;
use postgres_backend::{self, AuthType, PostgresBackend, QueryError};
use postgres_backend::{self, AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
use std::future;
use tokio::net::{TcpListener, TcpStream};
@@ -71,10 +71,10 @@ pub type ComputeReady = Result<DatabaseInfo, String>;
// TODO: replace with an http-based protocol.
struct MgmtHandler;
#[async_trait::async_trait]
impl postgres_backend::Handler for MgmtHandler {
impl postgres_backend::Handler<tokio::net::TcpStream> for MgmtHandler {
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackendTCP,
query: &str,
) -> Result<(), QueryError> {
try_process_query(pgb, query).await.map_err(|e| {
@@ -84,7 +84,7 @@ impl postgres_backend::Handler for MgmtHandler {
}
}
async fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> {
async fn try_process_query(pgb: &mut PostgresBackendTCP, query: &str) -> Result<(), QueryError> {
let resp: KickSession = serde_json::from_str(query).context("Failed to parse query as json")?;
let span = info_span!("event", session_id = resp.session_id);

View File

@@ -8,7 +8,7 @@ use crate::{
config::{ProxyConfig, TlsConfig},
console::{self, messages::MetricsAuxInfo},
error::io_error,
stream::{MeasuredStream, PqStream, Stream},
stream::{PqStream, Stream},
};
use anyhow::{bail, Context};
use futures::TryFutureExt;
@@ -18,6 +18,7 @@ use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info, warn};
use utils::measured_stream::MeasuredStream;
/// Number of times we should retry the `/proxy_wake_compute` http request.
const NUM_RETRIES_WAKE_COMPUTE: usize = 1;
@@ -353,16 +354,24 @@ async fn proxy_pass(
aux: &MetricsAuxInfo,
) -> anyhow::Result<()> {
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("tx"));
let mut client = MeasuredStream::new(client, |cnt| {
// Number of bytes we sent to the client (outbound).
m_sent.inc_by(cnt as u64);
});
let mut client = MeasuredStream::new(
client,
|_| {},
|cnt| {
// Number of bytes we sent to the client (outbound).
m_sent.inc_by(cnt as u64);
},
);
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("rx"));
let mut compute = MeasuredStream::new(compute, |cnt| {
// Number of bytes the client sent to the compute node (inbound).
m_recv.inc_by(cnt as u64);
});
let mut compute = MeasuredStream::new(
compute,
|_| {},
|cnt| {
// Number of bytes the client sent to the compute node (inbound).
m_recv.inc_by(cnt as u64);
},
);
// Starting from here we only proxy the client's traffic.
info!("performing the proxy pass...");

View File

@@ -217,68 +217,3 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
}
}
}
pin_project! {
/// This stream tracks all writes and calls user provided
/// callback when the underlying stream is flushed.
pub struct MeasuredStream<S, W> {
#[pin]
stream: S,
write_count: usize,
inc_write_count: W,
}
}
impl<S, W> MeasuredStream<S, W> {
pub fn new(stream: S, inc_write_count: W) -> Self {
Self {
stream,
write_count: 0,
inc_write_count,
}
}
}
impl<S: AsyncRead + Unpin, W> AsyncRead for MeasuredStream<S, W> {
fn poll_read(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
self.project().stream.poll_read(context, buf)
}
}
impl<S: AsyncWrite + Unpin, W: FnMut(usize)> AsyncWrite for MeasuredStream<S, W> {
fn poll_write(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<io::Result<usize>> {
let this = self.project();
this.stream.poll_write(context, buf).map_ok(|cnt| {
// Increment the write count.
*this.write_count += cnt;
cnt
})
}
fn poll_flush(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
let this = self.project();
this.stream.poll_flush(context).map_ok(|()| {
// Call the user provided callback and reset the write count.
(this.inc_write_count)(*this.write_count);
*this.write_count = 0;
})
}
fn poll_shutdown(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
self.project().stream.poll_shutdown(context)
}
}