Make postgres_backend use generic IO type (#3789)

- Support measuring inbound and outbound traffic in MeasuredStream
- Start using MeasuredStream in safekeepers code
This commit is contained in:
Arthur Petukhovsky
2023-03-13 12:18:10 +03:00
committed by GitHub
parent 8699342249
commit d9a1329834
16 changed files with 234 additions and 154 deletions

1
Cargo.lock generated
View File

@@ -4532,6 +4532,7 @@ dependencies = [
"metrics",
"nix",
"once_cell",
"pin-project-lite",
"rand",
"routerify",
"sentry",

View File

@@ -59,14 +59,14 @@ pub fn is_expected_io_error(e: &io::Error) -> bool {
}
#[async_trait::async_trait]
pub trait Handler {
pub trait Handler<IO> {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might not be what we want after CopyData streaming, but currently we don't
/// care). It will also flush out the output buffer.
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackend<IO>,
query_string: &str,
) -> Result<(), QueryError>;
@@ -77,7 +77,7 @@ pub trait Handler {
/// to override whole init logic in implementations.
fn startup(
&mut self,
_pgb: &mut PostgresBackend,
_pgb: &mut PostgresBackend<IO>,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
@@ -86,7 +86,7 @@ pub trait Handler {
/// Check auth jwt
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend,
_pgb: &mut PostgresBackend<IO>,
_jwt_response: &[u8],
) -> Result<(), QueryError> {
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
@@ -115,12 +115,12 @@ pub enum ProcessMsgResult {
}
/// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
pub enum MaybeTlsStream {
Unencrypted(tokio::net::TcpStream),
Tls(Box<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>),
pub enum MaybeTlsStream<IO> {
Unencrypted(IO),
Tls(Box<tokio_rustls::server::TlsStream<IO>>),
}
impl AsyncWrite for MaybeTlsStream {
impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<IO> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
@@ -147,7 +147,7 @@ impl AsyncWrite for MaybeTlsStream {
}
}
}
impl AsyncRead for MaybeTlsStream {
impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<IO> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
@@ -192,13 +192,13 @@ impl fmt::Display for AuthType {
/// PostgresBackend after call to `split`. In principle we could always store a
/// pair of split handles, but that would force us to pay the splitting price
/// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
enum MaybeWriteOnly {
Full(Framed<MaybeTlsStream>),
WriteOnly(FramedWriter<MaybeTlsStream>),
enum MaybeWriteOnly<IO> {
Full(Framed<MaybeTlsStream<IO>>),
WriteOnly(FramedWriter<MaybeTlsStream<IO>>),
Broken, // temporary value palmed off during the split
}
impl MaybeWriteOnly {
impl<IO: AsyncRead + AsyncWrite + Unpin> MaybeWriteOnly<IO> {
async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
match self {
MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
@@ -244,8 +244,8 @@ impl MaybeWriteOnly {
}
}
pub struct PostgresBackend {
framed: MaybeWriteOnly,
pub struct PostgresBackend<IO> {
framed: MaybeWriteOnly<IO>,
pub state: ProtoState,
@@ -255,6 +255,8 @@ pub struct PostgresBackend {
pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
pub type PostgresBackendTCP = PostgresBackend<tokio::net::TcpStream>;
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
let mut query_string = query_string.to_vec();
if let Some(ch) = query_string.last() {
@@ -271,7 +273,7 @@ fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
std::str::from_utf8(without_null).map_err(|e| e.into())
}
impl PostgresBackend {
impl PostgresBackend<tokio::net::TcpStream> {
pub fn new(
socket: tokio::net::TcpStream,
auth_type: AuthType,
@@ -288,6 +290,25 @@ impl PostgresBackend {
peer_addr,
})
}
}
impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
/// Build a backend on top of an arbitrary IO object.
///
/// The socket is wrapped as an unencrypted stream and the protocol state
/// starts at `ProtoState::Initialization`. `peer_addr` is supplied by the
/// caller and stored as is. The current body always returns `Ok`.
pub fn new_from_io(
    socket: IO,
    peer_addr: SocketAddr,
    auth_type: AuthType,
    tls_config: Option<Arc<rustls::ServerConfig>>,
) -> io::Result<Self> {
    // Start with the full (not yet split) framed stream over the plain socket.
    let framed = MaybeWriteOnly::Full(Framed::new(MaybeTlsStream::Unencrypted(socket)));
    let backend = Self {
        framed,
        state: ProtoState::Initialization,
        auth_type,
        tls_config,
        peer_addr,
    };
    Ok(backend)
}
pub fn get_peer_addr(&self) -> &SocketAddr {
&self.peer_addr
@@ -346,14 +367,14 @@ impl PostgresBackend {
/// to it in CopyData messages, and writes them to the connection
///
/// The caller is responsible for sending CopyOutResponse and CopyDone messages.
pub fn copyout_writer(&mut self) -> CopyDataWriter {
pub fn copyout_writer(&mut self) -> CopyDataWriter<IO> {
CopyDataWriter { pgb: self }
}
/// Wrapper for run_message_loop() that shuts down socket when we are done
pub async fn run<F, S>(
mut self,
handler: &mut impl Handler,
handler: &mut impl Handler<IO>,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
@@ -369,7 +390,7 @@ impl PostgresBackend {
async fn run_message_loop<F, S>(
&mut self,
handler: &mut impl Handler,
handler: &mut impl Handler<IO>,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
@@ -426,9 +447,9 @@ impl PostgresBackend {
/// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
async fn tls_upgrade(
src: MaybeTlsStream,
src: MaybeTlsStream<IO>,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<MaybeTlsStream> {
) -> anyhow::Result<MaybeTlsStream<IO>> {
match src {
MaybeTlsStream::Unencrypted(s) => {
let acceptor = TlsAcceptor::from(tls_config);
@@ -466,7 +487,7 @@ impl PostgresBackend {
/// Split off owned read part from which messages can be read in different
/// task/thread.
pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader> {
pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader<IO>> {
// temporary replace stream with fake to cook split one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(framed) => {
@@ -482,7 +503,7 @@ impl PostgresBackend {
}
/// Join read part back.
pub fn unsplit(&mut self, reader: PostgresBackendReader) -> anyhow::Result<()> {
pub fn unsplit(&mut self, reader: PostgresBackendReader<IO>) -> anyhow::Result<()> {
// temporary replace stream with fake to cook joined one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(_) => {
@@ -499,7 +520,7 @@ impl PostgresBackend {
/// Perform handshake with the client, transitioning to Established.
/// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()).
async fn handshake(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
async fn handshake(&mut self, handler: &mut impl Handler<IO>) -> Result<(), QueryError> {
while self.state < ProtoState::Authentication {
match self.framed.read_startup_message().await? {
Some(msg) => {
@@ -565,7 +586,7 @@ impl PostgresBackend {
/// actual startup packet.
async fn process_startup_message(
&mut self,
handler: &mut impl Handler,
handler: &mut impl Handler<IO>,
msg: FeStartupPacket,
) -> Result<(), QueryError> {
assert!(self.state < ProtoState::Authentication);
@@ -629,7 +650,7 @@ impl PostgresBackend {
async fn process_message(
&mut self,
handler: &mut impl Handler,
handler: &mut impl Handler<IO>,
msg: FeMessage,
unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult, QueryError> {
@@ -776,9 +797,9 @@ impl PostgresBackend {
}
}
pub struct PostgresBackendReader(FramedReader<MaybeTlsStream>);
pub struct PostgresBackendReader<IO>(FramedReader<MaybeTlsStream<IO>>);
impl PostgresBackendReader {
impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackendReader<IO> {
/// Read full message or return None if connection is cleanly closed with no
/// unprocessed data.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
@@ -812,11 +833,11 @@ impl PostgresBackendReader {
/// messages.
///
pub struct CopyDataWriter<'a> {
pgb: &'a mut PostgresBackend,
pub struct CopyDataWriter<'a, IO> {
pgb: &'a mut PostgresBackend<IO>,
}
impl<'a> AsyncWrite for CopyDataWriter<'a> {
impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'a, IO> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,

View File

@@ -4,6 +4,7 @@ use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError};
use pq_proto::{BeMessage, RowDescriptor};
use std::io::Cursor;
use std::{future, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::MakeTlsConnect;
@@ -22,11 +23,11 @@ async fn make_tcp_pair() -> (TcpStream, TcpStream) {
struct TestHandler {}
#[async_trait::async_trait]
impl Handler for TestHandler {
impl<IO: AsyncRead + AsyncWrite + Unpin + Send> Handler<IO> for TestHandler {
// return single col 'hey' for any query
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackend<IO>,
_query_string: &str,
) -> Result<(), QueryError> {
pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor::text_col(

View File

@@ -18,6 +18,7 @@ futures = { workspace = true}
jsonwebtoken.workspace = true
nix.workspace = true
once_cell.workspace = true
pin-project-lite.workspace = true
routerify.workspace = true
serde.workspace = true
serde_json.workspace = true

View File

@@ -49,6 +49,8 @@ pub mod fs_ext;
pub mod history_buffer;
pub mod measured_stream;
/// use with fail::cfg("$name", "return(2000)")
#[macro_export]
macro_rules! failpoint_sleep_millis_async {

View File

@@ -0,0 +1,77 @@
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::{io, task};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
pin_project! {
/// Stream wrapper that measures the traffic passing through it.
///
/// Bytes read are reported to `inc_read_count` after every successful read.
/// Bytes written are accumulated in `write_count` and reported to
/// `inc_write_count` when the underlying stream is flushed.
pub struct MeasuredStream<S, R, W> {
// The wrapped stream; pinned so its poll methods can be called through the projection.
#[pin]
stream: S,
// Number of bytes written since the last successful flush.
write_count: usize,
// Callback invoked with the number of bytes obtained by a successful read.
inc_read_count: R,
// Callback invoked on a successful flush with the bytes written since the previous one.
inc_write_count: W,
}
}
impl<S, R, W> MeasuredStream<S, R, W> {
pub fn new(stream: S, inc_read_count: R, inc_write_count: W) -> Self {
Self {
stream,
write_count: 0,
inc_read_count,
inc_write_count,
}
}
}
impl<S: AsyncRead + Unpin, R: FnMut(usize), W> AsyncRead for MeasuredStream<S, R, W> {
    /// Forward the read to the wrapped stream and, if it completed
    /// successfully, report how many bytes it appended to `buf`.
    fn poll_read(
        self: Pin<&mut Self>,
        context: &mut task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> task::Poll<io::Result<()>> {
        let this = self.project();
        // Remember how much of the buffer was filled before the read, so
        // that only the newly read bytes are counted.
        let before = buf.filled().len();
        let outcome = this.stream.poll_read(context, buf);
        if let task::Poll::Ready(Ok(())) = &outcome {
            // Increment the read count.
            (this.inc_read_count)(buf.filled().len() - before);
        }
        outcome
    }
}
impl<S: AsyncWrite + Unpin, R, W: FnMut(usize)> AsyncWrite for MeasuredStream<S, R, W> {
    /// Forward the write to the wrapped stream and add the number of bytes
    /// it accepted to the pending write counter.
    fn poll_write(
        self: Pin<&mut Self>,
        context: &mut task::Context<'_>,
        buf: &[u8],
    ) -> task::Poll<io::Result<usize>> {
        let this = self.project();
        this.stream.poll_write(context, buf).map_ok(|cnt| {
            // Increment the write count.
            *this.write_count += cnt;
            cnt
        })
    }

    /// Flush the wrapped stream; on success report the bytes written since
    /// the previous report to `inc_write_count` and reset the counter.
    fn poll_flush(
        self: Pin<&mut Self>,
        context: &mut task::Context<'_>,
    ) -> task::Poll<io::Result<()>> {
        let this = self.project();
        this.stream.poll_flush(context).map_ok(|()| {
            // Call the user provided callback and reset the write count.
            (this.inc_write_count)(*this.write_count);
            *this.write_count = 0;
        })
    }

    /// Shut down the wrapped stream.
    ///
    /// A successful shutdown implies a flush, so bytes written since the
    /// last explicit flush are reported here. Otherwise they would be lost
    /// whenever the stream is closed without a preceding flush.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        context: &mut task::Context<'_>,
    ) -> task::Poll<io::Result<()>> {
        let this = self.project();
        this.stream.poll_shutdown(context).map_ok(|()| {
            // Only report when something is pending, so that the usual
            // flush-then-shutdown sequence triggers no extra callback.
            if *this.write_count > 0 {
                (this.inc_write_count)(*this.write_count);
                *this.write_count = 0;
            }
        })
    }
}

View File

@@ -20,6 +20,7 @@ use pageserver_api::models::{
PagestreamFeMessage, PagestreamGetPageRequest, PagestreamGetPageResponse,
PagestreamNblocksRequest, PagestreamNblocksResponse,
};
use postgres_backend::PostgresBackendTCP;
use postgres_backend::{self, is_expected_io_error, AuthType, PostgresBackend, QueryError};
use pq_proto::framed::ConnectionError;
use pq_proto::FeStartupPacket;
@@ -54,7 +55,7 @@ use crate::trace::Tracer;
use postgres_ffi::pg_constants::DEFAULTTABLESPACE_OID;
use postgres_ffi::BLCKSZ;
fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Bytes>> + '_ {
fn copyin_stream(pgb: &mut PostgresBackendTCP) -> impl Stream<Item = io::Result<Bytes>> + '_ {
async_stream::try_stream! {
loop {
let msg = tokio::select! {
@@ -288,7 +289,7 @@ impl PageServerHandler {
#[instrument(skip(self, pgb, ctx))]
async fn handle_pagerequests(
&self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackendTCP,
tenant_id: TenantId,
timeline_id: TimelineId,
ctx: RequestContext,
@@ -392,7 +393,7 @@ impl PageServerHandler {
#[instrument(skip(self, pgb, ctx))]
async fn handle_import_basebackup(
&self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackendTCP,
tenant_id: TenantId,
timeline_id: TimelineId,
base_lsn: Lsn,
@@ -448,7 +449,7 @@ impl PageServerHandler {
#[instrument(skip(self, pgb, ctx))]
async fn handle_import_wal(
&self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackendTCP,
tenant_id: TenantId,
timeline_id: TimelineId,
start_lsn: Lsn,
@@ -659,7 +660,7 @@ impl PageServerHandler {
#[instrument(skip(self, pgb, ctx))]
async fn handle_basebackup_request(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackendTCP,
tenant_id: TenantId,
timeline_id: TimelineId,
lsn: Option<Lsn>,
@@ -723,10 +724,10 @@ impl PageServerHandler {
}
#[async_trait::async_trait]
impl postgres_backend::Handler for PageServerHandler {
impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend,
_pgb: &mut PostgresBackendTCP,
jwt_response: &[u8],
) -> Result<(), QueryError> {
// this unwrap is never triggered, because check_auth_jwt only called when auth_type is NeonJWT
@@ -754,7 +755,7 @@ impl postgres_backend::Handler for PageServerHandler {
fn startup(
&mut self,
_pgb: &mut PostgresBackend,
_pgb: &mut PostgresBackendTCP,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
@@ -762,7 +763,7 @@ impl postgres_backend::Handler for PageServerHandler {
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackendTCP,
query_string: &str,
) -> Result<(), QueryError> {
let ctx = self.connection_ctx.attached_child();

View File

@@ -4,7 +4,7 @@ use crate::{
};
use anyhow::Context;
use once_cell::sync::Lazy;
use postgres_backend::{self, AuthType, PostgresBackend, QueryError};
use postgres_backend::{self, AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
use std::future;
use tokio::net::{TcpListener, TcpStream};
@@ -71,10 +71,10 @@ pub type ComputeReady = Result<DatabaseInfo, String>;
// TODO: replace with an http-based protocol.
struct MgmtHandler;
#[async_trait::async_trait]
impl postgres_backend::Handler for MgmtHandler {
impl postgres_backend::Handler<tokio::net::TcpStream> for MgmtHandler {
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackendTCP,
query: &str,
) -> Result<(), QueryError> {
try_process_query(pgb, query).await.map_err(|e| {
@@ -84,7 +84,7 @@ impl postgres_backend::Handler for MgmtHandler {
}
}
async fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> {
async fn try_process_query(pgb: &mut PostgresBackendTCP, query: &str) -> Result<(), QueryError> {
let resp: KickSession = serde_json::from_str(query).context("Failed to parse query as json")?;
let span = info_span!("event", session_id = resp.session_id);

View File

@@ -8,7 +8,7 @@ use crate::{
config::{ProxyConfig, TlsConfig},
console::{self, messages::MetricsAuxInfo},
error::io_error,
stream::{MeasuredStream, PqStream, Stream},
stream::{PqStream, Stream},
};
use anyhow::{bail, Context};
use futures::TryFutureExt;
@@ -18,6 +18,7 @@ use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info, warn};
use utils::measured_stream::MeasuredStream;
/// Number of times we should retry the `/proxy_wake_compute` http request.
const NUM_RETRIES_WAKE_COMPUTE: usize = 1;
@@ -353,16 +354,24 @@ async fn proxy_pass(
aux: &MetricsAuxInfo,
) -> anyhow::Result<()> {
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("tx"));
let mut client = MeasuredStream::new(client, |cnt| {
let mut client = MeasuredStream::new(
client,
|_| {},
|cnt| {
// Number of bytes we sent to the client (outbound).
m_sent.inc_by(cnt as u64);
});
},
);
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("rx"));
let mut compute = MeasuredStream::new(compute, |cnt| {
let mut compute = MeasuredStream::new(
compute,
|_| {},
|cnt| {
// Number of bytes the client sent to the compute node (inbound).
m_recv.inc_by(cnt as u64);
});
},
);
// Starting from here we only proxy the client's traffic.
info!("performing the proxy pass...");

View File

@@ -217,68 +217,3 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
}
}
}
pin_project! {
/// This stream tracks all writes and calls the user provided
/// callback when the underlying stream is flushed.
///
/// Reads are passed through to the underlying stream without being counted.
pub struct MeasuredStream<S, W> {
// The wrapped stream; pinned so its poll methods can be called through the projection.
#[pin]
stream: S,
// Number of bytes written since the last successful flush.
write_count: usize,
// Callback invoked on a successful flush with the bytes written since the previous one.
inc_write_count: W,
}
}
impl<S, W> MeasuredStream<S, W> {
pub fn new(stream: S, inc_write_count: W) -> Self {
Self {
stream,
write_count: 0,
inc_write_count,
}
}
}
impl<S: AsyncRead + Unpin, W> AsyncRead for MeasuredStream<S, W> {
    /// Reads are not measured: the call is handed straight to the wrapped stream.
    fn poll_read(
        self: Pin<&mut Self>,
        context: &mut task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> task::Poll<io::Result<()>> {
        let this = self.project();
        let inner = this.stream;
        inner.poll_read(context, buf)
    }
}
impl<S: AsyncWrite + Unpin, W: FnMut(usize)> AsyncWrite for MeasuredStream<S, W> {
    /// Forward the write to the wrapped stream and add the number of bytes
    /// it accepted to the pending write counter.
    fn poll_write(
        self: Pin<&mut Self>,
        context: &mut task::Context<'_>,
        buf: &[u8],
    ) -> task::Poll<io::Result<usize>> {
        let this = self.project();
        this.stream.poll_write(context, buf).map_ok(|cnt| {
            // Increment the write count.
            *this.write_count += cnt;
            cnt
        })
    }

    /// Flush the wrapped stream; on success report the bytes written since
    /// the previous report to `inc_write_count` and reset the counter.
    fn poll_flush(
        self: Pin<&mut Self>,
        context: &mut task::Context<'_>,
    ) -> task::Poll<io::Result<()>> {
        let this = self.project();
        this.stream.poll_flush(context).map_ok(|()| {
            // Call the user provided callback and reset the write count.
            (this.inc_write_count)(*this.write_count);
            *this.write_count = 0;
        })
    }

    /// Shut down the wrapped stream.
    ///
    /// A successful shutdown implies a flush, so bytes written since the
    /// last explicit flush are reported here. Otherwise they would be lost
    /// whenever the stream is closed without a preceding flush.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        context: &mut task::Context<'_>,
    ) -> task::Poll<io::Result<()>> {
        let this = self.project();
        this.stream.poll_shutdown(context).map_ok(|()| {
            // Only report when something is pending, so that the usual
            // flush-then-shutdown sequence triggers no extra callback.
            if *this.write_count > 0 {
                (this.inc_write_count)(*this.write_count);
                *this.write_count = 0;
            }
        })
    }
}

View File

@@ -3,6 +3,7 @@
use anyhow::Context;
use std::str;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, info_span, Instrument};
use crate::auth::check_permission;
@@ -67,11 +68,13 @@ fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
}
#[async_trait::async_trait]
impl postgres_backend::Handler for SafekeeperPostgresHandler {
impl<IO: AsyncRead + AsyncWrite + Unpin + Send> postgres_backend::Handler<IO>
for SafekeeperPostgresHandler
{
// tenant_id and timeline_id are passed in connection string params
fn startup(
&mut self,
_pgb: &mut PostgresBackend,
_pgb: &mut PostgresBackend<IO>,
sm: &FeStartupPacket,
) -> Result<(), QueryError> {
if let FeStartupPacket::StartupMessage { params, .. } = sm {
@@ -110,7 +113,7 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend,
_pgb: &mut PostgresBackend<IO>,
jwt_response: &[u8],
) -> Result<(), QueryError> {
// this unwrap is never triggered, because check_auth_jwt only called when auth_type is NeonJWT
@@ -139,7 +142,7 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackend<IO>,
query_string: &str,
) -> Result<(), QueryError> {
if query_string
@@ -216,9 +219,9 @@ impl SafekeeperPostgresHandler {
///
/// Handle IDENTIFY_SYSTEM replication command
///
async fn handle_identify_system(
async fn handle_identify_system<IO: AsyncRead + AsyncWrite + Unpin>(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackend<IO>,
) -> Result<(), QueryError> {
let tli = GlobalTimelines::get(self.ttid).map_err(|e| QueryError::Other(e.into()))?;

View File

@@ -12,6 +12,7 @@ use anyhow::Context;
use bytes::Bytes;
use postgres_backend::QueryError;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::*;
use utils::id::TenantTimelineId;
@@ -60,9 +61,9 @@ struct AppendResult {
/// Handles command to craft logical message WAL record with given
/// content, and then append it with specified term and lsn. This
/// function is used to test safekeepers in different scenarios.
pub async fn handle_json_ctrl(
pub async fn handle_json_ctrl<IO: AsyncRead + AsyncWrite + Unpin>(
spg: &SafekeeperPostgresHandler,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackend<IO>,
append_request: &AppendLogicalMessage,
) -> Result<(), QueryError> {
info!("JSON_CTRL request: {append_request:?}");

View File

@@ -7,7 +7,7 @@ use anyhow::Result;
use metrics::{
core::{AtomicU64, Collector, Desc, GenericGaugeVec, Opts},
proto::MetricFamily,
Gauge, IntGaugeVec,
register_int_counter_vec, Gauge, IntCounterVec, IntGaugeVec,
};
use once_cell::sync::Lazy;
use postgres_ffi::XLogSegNo;
@@ -61,6 +61,14 @@ pub static PERSIST_CONTROL_FILE_SECONDS: Lazy<Histogram> = Lazy::new(|| {
)
.expect("Failed to register safekeeper_persist_control_file_seconds histogram vec")
});
/// Counter of bytes transferred over PostgreSQL protocol connections,
/// partitioned by the `direction` label.
pub static PG_IO_BYTES: Lazy<IntCounterVec> = Lazy::new(|| {
    register_int_counter_vec!(
        "safekeeper_pg_io_bytes",
        "Bytes read from or written to any PostgreSQL connection",
        &["direction"]
    )
    // The metric is a counter vec, not a gauge; keep the message accurate
    // so that a registration failure is not misleading.
    .expect("Failed to register safekeeper_pg_io_bytes counter vec")
});
/// Metrics for WalStorage in a single timeline.
#[derive(Clone, Default)]

View File

@@ -20,6 +20,8 @@ use std::net::SocketAddr;
use std::sync::Arc;
use std::thread;
use std::thread::JoinHandle;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::sync::mpsc::channel;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::sync::mpsc::Receiver;
@@ -36,9 +38,9 @@ impl SafekeeperPostgresHandler {
/// Wrapper around handle_start_wal_push_guts handling result. Error is
/// handled here while we're still in walreceiver ttid span; with API
/// extension, this can probably be moved into postgres_backend.
pub async fn handle_start_wal_push(
pub async fn handle_start_wal_push<IO: AsyncRead + AsyncWrite + Unpin>(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackend<IO>,
) -> Result<(), QueryError> {
if let Err(end) = self.handle_start_wal_push_guts(pgb).await {
// Log the result and probably send it to the client, closing the stream.
@@ -47,9 +49,9 @@ impl SafekeeperPostgresHandler {
Ok(())
}
pub async fn handle_start_wal_push_guts(
pub async fn handle_start_wal_push_guts<IO: AsyncRead + AsyncWrite + Unpin>(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackend<IO>,
) -> Result<(), CopyStreamHandlerEnd> {
// Notify the libpq client that it's allowed to send `CopyData` messages
pgb.write_message(&BeMessage::CopyBothResponse).await?;
@@ -111,17 +113,17 @@ impl SafekeeperPostgresHandler {
}
}
struct NetworkReader<'a> {
struct NetworkReader<'a, IO> {
ttid: TenantTimelineId,
conn_id: ConnectionId,
pgb_reader: &'a mut PostgresBackendReader,
pgb_reader: &'a mut PostgresBackendReader<IO>,
peer_addr: SocketAddr,
// WalAcceptor is spawned when we learn server info from walproposer and
// create timeline; handle is put here.
acceptor_handle: &'a mut Option<JoinHandle<anyhow::Result<()>>>,
}
impl<'a> NetworkReader<'a> {
impl<'a, IO: AsyncRead + AsyncWrite + Unpin> NetworkReader<'a, IO> {
async fn run(
self,
msg_tx: Sender<ProposerAcceptorMessage>,
@@ -162,16 +164,16 @@ impl<'a> NetworkReader<'a> {
/// Read next message from walproposer.
/// TODO: Return Ok(None) on graceful termination.
async fn read_message(
pgb_reader: &mut PostgresBackendReader,
async fn read_message<IO: AsyncRead + AsyncWrite + Unpin>(
pgb_reader: &mut PostgresBackendReader<IO>,
) -> Result<ProposerAcceptorMessage, CopyStreamHandlerEnd> {
let copy_data = pgb_reader.read_copy_message().await?;
let msg = ProposerAcceptorMessage::parse(copy_data)?;
Ok(msg)
}
async fn read_network_loop(
pgb_reader: &mut PostgresBackendReader,
async fn read_network_loop<IO: AsyncRead + AsyncWrite + Unpin>(
pgb_reader: &mut PostgresBackendReader<IO>,
msg_tx: Sender<ProposerAcceptorMessage>,
mut next_msg: ProposerAcceptorMessage,
) -> Result<(), CopyStreamHandlerEnd> {
@@ -186,8 +188,8 @@ async fn read_network_loop(
/// Read replies from WalAcceptor and pass them back to socket. Returns Ok(())
/// if reply_rx closed; it must mean WalAcceptor terminated, joining it should
/// tell the error.
async fn network_write(
pgb_writer: &mut PostgresBackend,
async fn network_write<IO: AsyncRead + AsyncWrite + Unpin>(
pgb_writer: &mut PostgresBackend<IO>,
mut reply_rx: Receiver<AcceptorProposerMessage>,
) -> Result<(), CopyStreamHandlerEnd> {
let mut buf = BytesMut::with_capacity(128);

View File

@@ -13,6 +13,8 @@ use postgres_ffi::get_current_timestamp;
use postgres_ffi::{TimestampTz, MAX_SEND_SIZE};
use pq_proto::{BeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncWrite};
use std::cmp::min;
use std::str;
use std::sync::Arc;
@@ -74,9 +76,9 @@ impl SafekeeperPostgresHandler {
/// Wrapper around handle_start_replication_guts handling result. Error is
/// handled here while we're still in walsender ttid span; with API
/// extension, this can probably be moved into postgres_backend.
pub async fn handle_start_replication(
pub async fn handle_start_replication<IO: AsyncRead + AsyncWrite + Unpin>(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackend<IO>,
start_pos: Lsn,
) -> Result<(), QueryError> {
if let Err(end) = self.handle_start_replication_guts(pgb, start_pos).await {
@@ -86,9 +88,9 @@ impl SafekeeperPostgresHandler {
Ok(())
}
pub async fn handle_start_replication_guts(
pub async fn handle_start_replication_guts<IO: AsyncRead + AsyncWrite + Unpin>(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackend<IO>,
start_pos: Lsn,
) -> Result<(), CopyStreamHandlerEnd> {
let appname = self.appname.clone();
@@ -176,8 +178,8 @@ impl SafekeeperPostgresHandler {
}
/// A half driving sending WAL.
struct WalSender<'a> {
pgb: &'a mut PostgresBackend,
struct WalSender<'a, IO> {
pgb: &'a mut PostgresBackend<IO>,
tli: Arc<Timeline>,
appname: Option<String>,
// Position since which we are sending next chunk.
@@ -194,7 +196,7 @@ struct WalSender<'a> {
send_buf: [u8; MAX_SEND_SIZE],
}
impl WalSender<'_> {
impl<IO: AsyncRead + AsyncWrite + Unpin> WalSender<'_, IO> {
/// Send WAL until
/// - an error occurs
/// - if we are streaming to walproposer, we've streamed until stop_pos
@@ -282,14 +284,14 @@ impl WalSender<'_> {
}
/// A half driving receiving replies.
struct ReplyReader {
reader: PostgresBackendReader,
struct ReplyReader<IO> {
reader: PostgresBackendReader<IO>,
tli: Arc<Timeline>,
replica_id: usize,
feedback: ReplicaState,
}
impl ReplyReader {
impl<IO: AsyncRead + AsyncWrite + Unpin> ReplyReader<IO> {
async fn run(&mut self) -> Result<(), CopyStreamHandlerEnd> {
loop {
let msg = self.reader.read_copy_message().await?;

View File

@@ -7,9 +7,10 @@ use postgres_backend::QueryError;
use std::{future, thread};
use tokio::net::TcpStream;
use tracing::*;
use utils::measured_stream::MeasuredStream;
use crate::handler::SafekeeperPostgresHandler;
use crate::SafeKeeperConf;
use crate::{handler::SafekeeperPostgresHandler, metrics::PG_IO_BYTES};
use postgres_backend::{AuthType, PostgresBackend};
/// Accept incoming TCP connections and spawn them into a background thread.
@@ -67,14 +68,29 @@ fn handle_socket(
.build()?;
let local = tokio::task::LocalSet::new();
let read_metrics = PG_IO_BYTES.with_label_values(&["read"]);
let write_metrics = PG_IO_BYTES.with_label_values(&["write"]);
socket.set_nodelay(true)?;
let peer_addr = socket.peer_addr()?;
// TODO: measure cross-az traffic
let socket = MeasuredStream::new(
socket,
|cnt| {
read_metrics.inc_by(cnt as u64);
},
|cnt| {
write_metrics.inc_by(cnt as u64);
},
);
let auth_type = match conf.auth {
None => AuthType::Trust,
Some(_) => AuthType::NeonJWT,
};
let mut conn_handler = SafekeeperPostgresHandler::new(conf, conn_id);
let pgbackend = PostgresBackend::new(socket, auth_type, None)?;
let pgbackend = PostgresBackend::new_from_io(socket, peer_addr, auth_type, None)?;
// libpq protocol between safekeeper and walproposer / pageserver
// We don't use shutdown.
local.block_on(