mirror of
https://github.com/neondatabase/neon.git
synced 2025-12-27 08:09:58 +00:00
Make postgres_backend use generic IO type (#3789)
- Support measuring inbound and outbound traffic in MeasuredStream - Start using MeasuredStream in safekeepers code
This commit is contained in:
committed by
GitHub
parent
8699342249
commit
d9a1329834
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -4532,6 +4532,7 @@ dependencies = [
|
||||
"metrics",
|
||||
"nix",
|
||||
"once_cell",
|
||||
"pin-project-lite",
|
||||
"rand",
|
||||
"routerify",
|
||||
"sentry",
|
||||
|
||||
@@ -59,14 +59,14 @@ pub fn is_expected_io_error(e: &io::Error) -> bool {
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait Handler {
|
||||
pub trait Handler<IO> {
|
||||
/// Handle single query.
|
||||
/// postgres_backend will issue ReadyForQuery after calling this (this
|
||||
/// might be not what we want after CopyData streaming, but currently we don't
|
||||
/// care). It will also flush out the output buffer.
|
||||
async fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
pgb: &mut PostgresBackend<IO>,
|
||||
query_string: &str,
|
||||
) -> Result<(), QueryError>;
|
||||
|
||||
@@ -77,7 +77,7 @@ pub trait Handler {
|
||||
/// to override whole init logic in implementations.
|
||||
fn startup(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_pgb: &mut PostgresBackend<IO>,
|
||||
_sm: &FeStartupPacket,
|
||||
) -> Result<(), QueryError> {
|
||||
Ok(())
|
||||
@@ -86,7 +86,7 @@ pub trait Handler {
|
||||
/// Check auth jwt
|
||||
fn check_auth_jwt(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_pgb: &mut PostgresBackend<IO>,
|
||||
_jwt_response: &[u8],
|
||||
) -> Result<(), QueryError> {
|
||||
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
|
||||
@@ -115,12 +115,12 @@ pub enum ProcessMsgResult {
|
||||
}
|
||||
|
||||
/// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
|
||||
pub enum MaybeTlsStream {
|
||||
Unencrypted(tokio::net::TcpStream),
|
||||
Tls(Box<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>),
|
||||
pub enum MaybeTlsStream<IO> {
|
||||
Unencrypted(IO),
|
||||
Tls(Box<tokio_rustls::server::TlsStream<IO>>),
|
||||
}
|
||||
|
||||
impl AsyncWrite for MaybeTlsStream {
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<IO> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
@@ -147,7 +147,7 @@ impl AsyncWrite for MaybeTlsStream {
|
||||
}
|
||||
}
|
||||
}
|
||||
impl AsyncRead for MaybeTlsStream {
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<IO> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
@@ -192,13 +192,13 @@ impl fmt::Display for AuthType {
|
||||
/// PostgresBackend after call to `split`. In principle we could always store a
|
||||
/// pair of splitted handles, but that would force to to pay splitting price
|
||||
/// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
|
||||
enum MaybeWriteOnly {
|
||||
Full(Framed<MaybeTlsStream>),
|
||||
WriteOnly(FramedWriter<MaybeTlsStream>),
|
||||
enum MaybeWriteOnly<IO> {
|
||||
Full(Framed<MaybeTlsStream<IO>>),
|
||||
WriteOnly(FramedWriter<MaybeTlsStream<IO>>),
|
||||
Broken, // temporary value palmed off during the split
|
||||
}
|
||||
|
||||
impl MaybeWriteOnly {
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> MaybeWriteOnly<IO> {
|
||||
async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
|
||||
match self {
|
||||
MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
|
||||
@@ -244,8 +244,8 @@ impl MaybeWriteOnly {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresBackend {
|
||||
framed: MaybeWriteOnly,
|
||||
pub struct PostgresBackend<IO> {
|
||||
framed: MaybeWriteOnly<IO>,
|
||||
|
||||
pub state: ProtoState,
|
||||
|
||||
@@ -255,6 +255,8 @@ pub struct PostgresBackend {
|
||||
pub tls_config: Option<Arc<rustls::ServerConfig>>,
|
||||
}
|
||||
|
||||
pub type PostgresBackendTCP = PostgresBackend<tokio::net::TcpStream>;
|
||||
|
||||
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
|
||||
let mut query_string = query_string.to_vec();
|
||||
if let Some(ch) = query_string.last() {
|
||||
@@ -271,7 +273,7 @@ fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
|
||||
std::str::from_utf8(without_null).map_err(|e| e.into())
|
||||
}
|
||||
|
||||
impl PostgresBackend {
|
||||
impl PostgresBackend<tokio::net::TcpStream> {
|
||||
pub fn new(
|
||||
socket: tokio::net::TcpStream,
|
||||
auth_type: AuthType,
|
||||
@@ -288,6 +290,25 @@ impl PostgresBackend {
|
||||
peer_addr,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
|
||||
pub fn new_from_io(
|
||||
socket: IO,
|
||||
peer_addr: SocketAddr,
|
||||
auth_type: AuthType,
|
||||
tls_config: Option<Arc<rustls::ServerConfig>>,
|
||||
) -> io::Result<Self> {
|
||||
let stream = MaybeTlsStream::Unencrypted(socket);
|
||||
|
||||
Ok(Self {
|
||||
framed: MaybeWriteOnly::Full(Framed::new(stream)),
|
||||
state: ProtoState::Initialization,
|
||||
auth_type,
|
||||
tls_config,
|
||||
peer_addr,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_peer_addr(&self) -> &SocketAddr {
|
||||
&self.peer_addr
|
||||
@@ -346,14 +367,14 @@ impl PostgresBackend {
|
||||
/// to it in CopyData messages, and writes them to the connection
|
||||
///
|
||||
/// The caller is responsible for sending CopyOutResponse and CopyDone messages.
|
||||
pub fn copyout_writer(&mut self) -> CopyDataWriter {
|
||||
pub fn copyout_writer(&mut self) -> CopyDataWriter<IO> {
|
||||
CopyDataWriter { pgb: self }
|
||||
}
|
||||
|
||||
/// Wrapper for run_message_loop() that shuts down socket when we are done
|
||||
pub async fn run<F, S>(
|
||||
mut self,
|
||||
handler: &mut impl Handler,
|
||||
handler: &mut impl Handler<IO>,
|
||||
shutdown_watcher: F,
|
||||
) -> Result<(), QueryError>
|
||||
where
|
||||
@@ -369,7 +390,7 @@ impl PostgresBackend {
|
||||
|
||||
async fn run_message_loop<F, S>(
|
||||
&mut self,
|
||||
handler: &mut impl Handler,
|
||||
handler: &mut impl Handler<IO>,
|
||||
shutdown_watcher: F,
|
||||
) -> Result<(), QueryError>
|
||||
where
|
||||
@@ -426,9 +447,9 @@ impl PostgresBackend {
|
||||
|
||||
/// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
|
||||
async fn tls_upgrade(
|
||||
src: MaybeTlsStream,
|
||||
src: MaybeTlsStream<IO>,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
) -> anyhow::Result<MaybeTlsStream> {
|
||||
) -> anyhow::Result<MaybeTlsStream<IO>> {
|
||||
match src {
|
||||
MaybeTlsStream::Unencrypted(s) => {
|
||||
let acceptor = TlsAcceptor::from(tls_config);
|
||||
@@ -466,7 +487,7 @@ impl PostgresBackend {
|
||||
|
||||
/// Split off owned read part from which messages can be read in different
|
||||
/// task/thread.
|
||||
pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader> {
|
||||
pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader<IO>> {
|
||||
// temporary replace stream with fake to cook split one, Indiana Jones style
|
||||
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
|
||||
MaybeWriteOnly::Full(framed) => {
|
||||
@@ -482,7 +503,7 @@ impl PostgresBackend {
|
||||
}
|
||||
|
||||
/// Join read part back.
|
||||
pub fn unsplit(&mut self, reader: PostgresBackendReader) -> anyhow::Result<()> {
|
||||
pub fn unsplit(&mut self, reader: PostgresBackendReader<IO>) -> anyhow::Result<()> {
|
||||
// temporary replace stream with fake to cook joined one, Indiana Jones style
|
||||
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
|
||||
MaybeWriteOnly::Full(_) => {
|
||||
@@ -499,7 +520,7 @@ impl PostgresBackend {
|
||||
|
||||
/// Perform handshake with the client, transitioning to Established.
|
||||
/// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()).
|
||||
async fn handshake(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
|
||||
async fn handshake(&mut self, handler: &mut impl Handler<IO>) -> Result<(), QueryError> {
|
||||
while self.state < ProtoState::Authentication {
|
||||
match self.framed.read_startup_message().await? {
|
||||
Some(msg) => {
|
||||
@@ -565,7 +586,7 @@ impl PostgresBackend {
|
||||
/// actual startup packet.
|
||||
async fn process_startup_message(
|
||||
&mut self,
|
||||
handler: &mut impl Handler,
|
||||
handler: &mut impl Handler<IO>,
|
||||
msg: FeStartupPacket,
|
||||
) -> Result<(), QueryError> {
|
||||
assert!(self.state < ProtoState::Authentication);
|
||||
@@ -629,7 +650,7 @@ impl PostgresBackend {
|
||||
|
||||
async fn process_message(
|
||||
&mut self,
|
||||
handler: &mut impl Handler,
|
||||
handler: &mut impl Handler<IO>,
|
||||
msg: FeMessage,
|
||||
unnamed_query_string: &mut Bytes,
|
||||
) -> Result<ProcessMsgResult, QueryError> {
|
||||
@@ -776,9 +797,9 @@ impl PostgresBackend {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresBackendReader(FramedReader<MaybeTlsStream>);
|
||||
pub struct PostgresBackendReader<IO>(FramedReader<MaybeTlsStream<IO>>);
|
||||
|
||||
impl PostgresBackendReader {
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackendReader<IO> {
|
||||
/// Read full message or return None if connection is cleanly closed with no
|
||||
/// unprocessed data.
|
||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
@@ -812,11 +833,11 @@ impl PostgresBackendReader {
|
||||
/// messages.
|
||||
///
|
||||
|
||||
pub struct CopyDataWriter<'a> {
|
||||
pgb: &'a mut PostgresBackend,
|
||||
pub struct CopyDataWriter<'a, IO> {
|
||||
pgb: &'a mut PostgresBackend<IO>,
|
||||
}
|
||||
|
||||
impl<'a> AsyncWrite for CopyDataWriter<'a> {
|
||||
impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'a, IO> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
|
||||
@@ -4,6 +4,7 @@ use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError};
|
||||
use pq_proto::{BeMessage, RowDescriptor};
|
||||
use std::io::Cursor;
|
||||
use std::{future, sync::Arc};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::tls::MakeTlsConnect;
|
||||
@@ -22,11 +23,11 @@ async fn make_tcp_pair() -> (TcpStream, TcpStream) {
|
||||
struct TestHandler {}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Handler for TestHandler {
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin + Send> Handler<IO> for TestHandler {
|
||||
// return single col 'hey' for any query
|
||||
async fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
pgb: &mut PostgresBackend<IO>,
|
||||
_query_string: &str,
|
||||
) -> Result<(), QueryError> {
|
||||
pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor::text_col(
|
||||
|
||||
@@ -18,6 +18,7 @@ futures = { workspace = true}
|
||||
jsonwebtoken.workspace = true
|
||||
nix.workspace = true
|
||||
once_cell.workspace = true
|
||||
pin-project-lite.workspace = true
|
||||
routerify.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
||||
@@ -49,6 +49,8 @@ pub mod fs_ext;
|
||||
|
||||
pub mod history_buffer;
|
||||
|
||||
pub mod measured_stream;
|
||||
|
||||
/// use with fail::cfg("$name", "return(2000)")
|
||||
#[macro_export]
|
||||
macro_rules! failpoint_sleep_millis_async {
|
||||
|
||||
77
libs/utils/src/measured_stream.rs
Normal file
77
libs/utils/src/measured_stream.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
use pin_project_lite::pin_project;
|
||||
use std::pin::Pin;
|
||||
use std::{io, task};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
pin_project! {
|
||||
/// This stream tracks all writes and calls user provided
|
||||
/// callback when the underlying stream is flushed.
|
||||
pub struct MeasuredStream<S, R, W> {
|
||||
#[pin]
|
||||
stream: S,
|
||||
write_count: usize,
|
||||
inc_read_count: R,
|
||||
inc_write_count: W,
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, R, W> MeasuredStream<S, R, W> {
|
||||
pub fn new(stream: S, inc_read_count: R, inc_write_count: W) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
write_count: 0,
|
||||
inc_read_count,
|
||||
inc_write_count,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin, R: FnMut(usize), W> AsyncRead for MeasuredStream<S, R, W> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
let this = self.project();
|
||||
let filled = buf.filled().len();
|
||||
this.stream.poll_read(context, buf).map_ok(|()| {
|
||||
let cnt = buf.filled().len() - filled;
|
||||
// Increment the read count.
|
||||
(this.inc_read_count)(cnt);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin, R, W: FnMut(usize)> AsyncWrite for MeasuredStream<S, R, W> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> task::Poll<io::Result<usize>> {
|
||||
let this = self.project();
|
||||
this.stream.poll_write(context, buf).map_ok(|cnt| {
|
||||
// Increment the write count.
|
||||
*this.write_count += cnt;
|
||||
cnt
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
let this = self.project();
|
||||
this.stream.poll_flush(context).map_ok(|()| {
|
||||
// Call the user provided callback and reset the write count.
|
||||
(this.inc_write_count)(*this.write_count);
|
||||
*this.write_count = 0;
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
self.project().stream.poll_shutdown(context)
|
||||
}
|
||||
}
|
||||
@@ -20,6 +20,7 @@ use pageserver_api::models::{
|
||||
PagestreamFeMessage, PagestreamGetPageRequest, PagestreamGetPageResponse,
|
||||
PagestreamNblocksRequest, PagestreamNblocksResponse,
|
||||
};
|
||||
use postgres_backend::PostgresBackendTCP;
|
||||
use postgres_backend::{self, is_expected_io_error, AuthType, PostgresBackend, QueryError};
|
||||
use pq_proto::framed::ConnectionError;
|
||||
use pq_proto::FeStartupPacket;
|
||||
@@ -54,7 +55,7 @@ use crate::trace::Tracer;
|
||||
use postgres_ffi::pg_constants::DEFAULTTABLESPACE_OID;
|
||||
use postgres_ffi::BLCKSZ;
|
||||
|
||||
fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Bytes>> + '_ {
|
||||
fn copyin_stream(pgb: &mut PostgresBackendTCP) -> impl Stream<Item = io::Result<Bytes>> + '_ {
|
||||
async_stream::try_stream! {
|
||||
loop {
|
||||
let msg = tokio::select! {
|
||||
@@ -288,7 +289,7 @@ impl PageServerHandler {
|
||||
#[instrument(skip(self, pgb, ctx))]
|
||||
async fn handle_pagerequests(
|
||||
&self,
|
||||
pgb: &mut PostgresBackend,
|
||||
pgb: &mut PostgresBackendTCP,
|
||||
tenant_id: TenantId,
|
||||
timeline_id: TimelineId,
|
||||
ctx: RequestContext,
|
||||
@@ -392,7 +393,7 @@ impl PageServerHandler {
|
||||
#[instrument(skip(self, pgb, ctx))]
|
||||
async fn handle_import_basebackup(
|
||||
&self,
|
||||
pgb: &mut PostgresBackend,
|
||||
pgb: &mut PostgresBackendTCP,
|
||||
tenant_id: TenantId,
|
||||
timeline_id: TimelineId,
|
||||
base_lsn: Lsn,
|
||||
@@ -448,7 +449,7 @@ impl PageServerHandler {
|
||||
#[instrument(skip(self, pgb, ctx))]
|
||||
async fn handle_import_wal(
|
||||
&self,
|
||||
pgb: &mut PostgresBackend,
|
||||
pgb: &mut PostgresBackendTCP,
|
||||
tenant_id: TenantId,
|
||||
timeline_id: TimelineId,
|
||||
start_lsn: Lsn,
|
||||
@@ -659,7 +660,7 @@ impl PageServerHandler {
|
||||
#[instrument(skip(self, pgb, ctx))]
|
||||
async fn handle_basebackup_request(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
pgb: &mut PostgresBackendTCP,
|
||||
tenant_id: TenantId,
|
||||
timeline_id: TimelineId,
|
||||
lsn: Option<Lsn>,
|
||||
@@ -723,10 +724,10 @@ impl PageServerHandler {
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl postgres_backend::Handler for PageServerHandler {
|
||||
impl postgres_backend::Handler<tokio::net::TcpStream> for PageServerHandler {
|
||||
fn check_auth_jwt(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_pgb: &mut PostgresBackendTCP,
|
||||
jwt_response: &[u8],
|
||||
) -> Result<(), QueryError> {
|
||||
// this unwrap is never triggered, because check_auth_jwt only called when auth_type is NeonJWT
|
||||
@@ -754,7 +755,7 @@ impl postgres_backend::Handler for PageServerHandler {
|
||||
|
||||
fn startup(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_pgb: &mut PostgresBackendTCP,
|
||||
_sm: &FeStartupPacket,
|
||||
) -> Result<(), QueryError> {
|
||||
Ok(())
|
||||
@@ -762,7 +763,7 @@ impl postgres_backend::Handler for PageServerHandler {
|
||||
|
||||
async fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
pgb: &mut PostgresBackendTCP,
|
||||
query_string: &str,
|
||||
) -> Result<(), QueryError> {
|
||||
let ctx = self.connection_ctx.attached_child();
|
||||
|
||||
@@ -4,7 +4,7 @@ use crate::{
|
||||
};
|
||||
use anyhow::Context;
|
||||
use once_cell::sync::Lazy;
|
||||
use postgres_backend::{self, AuthType, PostgresBackend, QueryError};
|
||||
use postgres_backend::{self, AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
|
||||
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
|
||||
use std::future;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
@@ -71,10 +71,10 @@ pub type ComputeReady = Result<DatabaseInfo, String>;
|
||||
// TODO: replace with an http-based protocol.
|
||||
struct MgmtHandler;
|
||||
#[async_trait::async_trait]
|
||||
impl postgres_backend::Handler for MgmtHandler {
|
||||
impl postgres_backend::Handler<tokio::net::TcpStream> for MgmtHandler {
|
||||
async fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
pgb: &mut PostgresBackendTCP,
|
||||
query: &str,
|
||||
) -> Result<(), QueryError> {
|
||||
try_process_query(pgb, query).await.map_err(|e| {
|
||||
@@ -84,7 +84,7 @@ impl postgres_backend::Handler for MgmtHandler {
|
||||
}
|
||||
}
|
||||
|
||||
async fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> {
|
||||
async fn try_process_query(pgb: &mut PostgresBackendTCP, query: &str) -> Result<(), QueryError> {
|
||||
let resp: KickSession = serde_json::from_str(query).context("Failed to parse query as json")?;
|
||||
|
||||
let span = info_span!("event", session_id = resp.session_id);
|
||||
|
||||
@@ -8,7 +8,7 @@ use crate::{
|
||||
config::{ProxyConfig, TlsConfig},
|
||||
console::{self, messages::MetricsAuxInfo},
|
||||
error::io_error,
|
||||
stream::{MeasuredStream, PqStream, Stream},
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
use anyhow::{bail, Context};
|
||||
use futures::TryFutureExt;
|
||||
@@ -18,6 +18,7 @@ use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{error, info, warn};
|
||||
use utils::measured_stream::MeasuredStream;
|
||||
|
||||
/// Number of times we should retry the `/proxy_wake_compute` http request.
|
||||
const NUM_RETRIES_WAKE_COMPUTE: usize = 1;
|
||||
@@ -353,16 +354,24 @@ async fn proxy_pass(
|
||||
aux: &MetricsAuxInfo,
|
||||
) -> anyhow::Result<()> {
|
||||
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("tx"));
|
||||
let mut client = MeasuredStream::new(client, |cnt| {
|
||||
let mut client = MeasuredStream::new(
|
||||
client,
|
||||
|_| {},
|
||||
|cnt| {
|
||||
// Number of bytes we sent to the client (outbound).
|
||||
m_sent.inc_by(cnt as u64);
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("rx"));
|
||||
let mut compute = MeasuredStream::new(compute, |cnt| {
|
||||
let mut compute = MeasuredStream::new(
|
||||
compute,
|
||||
|_| {},
|
||||
|cnt| {
|
||||
// Number of bytes the client sent to the compute node (inbound).
|
||||
m_recv.inc_by(cnt as u64);
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
// Starting from here we only proxy the client's traffic.
|
||||
info!("performing the proxy pass...");
|
||||
|
||||
@@ -217,68 +217,3 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// This stream tracks all writes and calls user provided
|
||||
/// callback when the underlying stream is flushed.
|
||||
pub struct MeasuredStream<S, W> {
|
||||
#[pin]
|
||||
stream: S,
|
||||
write_count: usize,
|
||||
inc_write_count: W,
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, W> MeasuredStream<S, W> {
|
||||
pub fn new(stream: S, inc_write_count: W) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
write_count: 0,
|
||||
inc_write_count,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin, W> AsyncRead for MeasuredStream<S, W> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
self.project().stream.poll_read(context, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin, W: FnMut(usize)> AsyncWrite for MeasuredStream<S, W> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> task::Poll<io::Result<usize>> {
|
||||
let this = self.project();
|
||||
this.stream.poll_write(context, buf).map_ok(|cnt| {
|
||||
// Increment the write count.
|
||||
*this.write_count += cnt;
|
||||
cnt
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
let this = self.project();
|
||||
this.stream.poll_flush(context).map_ok(|()| {
|
||||
// Call the user provided callback and reset the write count.
|
||||
(this.inc_write_count)(*this.write_count);
|
||||
*this.write_count = 0;
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
context: &mut task::Context<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
self.project().stream.poll_shutdown(context)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
use anyhow::Context;
|
||||
use std::str;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, info_span, Instrument};
|
||||
|
||||
use crate::auth::check_permission;
|
||||
@@ -67,11 +68,13 @@ fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl postgres_backend::Handler for SafekeeperPostgresHandler {
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin + Send> postgres_backend::Handler<IO>
|
||||
for SafekeeperPostgresHandler
|
||||
{
|
||||
// tenant_id and timeline_id are passed in connection string params
|
||||
fn startup(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_pgb: &mut PostgresBackend<IO>,
|
||||
sm: &FeStartupPacket,
|
||||
) -> Result<(), QueryError> {
|
||||
if let FeStartupPacket::StartupMessage { params, .. } = sm {
|
||||
@@ -110,7 +113,7 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
|
||||
|
||||
fn check_auth_jwt(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_pgb: &mut PostgresBackend<IO>,
|
||||
jwt_response: &[u8],
|
||||
) -> Result<(), QueryError> {
|
||||
// this unwrap is never triggered, because check_auth_jwt only called when auth_type is NeonJWT
|
||||
@@ -139,7 +142,7 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
|
||||
|
||||
async fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
pgb: &mut PostgresBackend<IO>,
|
||||
query_string: &str,
|
||||
) -> Result<(), QueryError> {
|
||||
if query_string
|
||||
@@ -216,9 +219,9 @@ impl SafekeeperPostgresHandler {
|
||||
///
|
||||
/// Handle IDENTIFY_SYSTEM replication command
|
||||
///
|
||||
async fn handle_identify_system(
|
||||
async fn handle_identify_system<IO: AsyncRead + AsyncWrite + Unpin>(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
pgb: &mut PostgresBackend<IO>,
|
||||
) -> Result<(), QueryError> {
|
||||
let tli = GlobalTimelines::get(self.ttid).map_err(|e| QueryError::Other(e.into()))?;
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ use anyhow::Context;
|
||||
use bytes::Bytes;
|
||||
use postgres_backend::QueryError;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::*;
|
||||
use utils::id::TenantTimelineId;
|
||||
|
||||
@@ -60,9 +61,9 @@ struct AppendResult {
|
||||
/// Handles command to craft logical message WAL record with given
|
||||
/// content, and then append it with specified term and lsn. This
|
||||
/// function is used to test safekeepers in different scenarios.
|
||||
pub async fn handle_json_ctrl(
|
||||
pub async fn handle_json_ctrl<IO: AsyncRead + AsyncWrite + Unpin>(
|
||||
spg: &SafekeeperPostgresHandler,
|
||||
pgb: &mut PostgresBackend,
|
||||
pgb: &mut PostgresBackend<IO>,
|
||||
append_request: &AppendLogicalMessage,
|
||||
) -> Result<(), QueryError> {
|
||||
info!("JSON_CTRL request: {append_request:?}");
|
||||
|
||||
@@ -7,7 +7,7 @@ use anyhow::Result;
|
||||
use metrics::{
|
||||
core::{AtomicU64, Collector, Desc, GenericGaugeVec, Opts},
|
||||
proto::MetricFamily,
|
||||
Gauge, IntGaugeVec,
|
||||
register_int_counter_vec, Gauge, IntCounterVec, IntGaugeVec,
|
||||
};
|
||||
use once_cell::sync::Lazy;
|
||||
use postgres_ffi::XLogSegNo;
|
||||
@@ -61,6 +61,14 @@ pub static PERSIST_CONTROL_FILE_SECONDS: Lazy<Histogram> = Lazy::new(|| {
|
||||
)
|
||||
.expect("Failed to register safekeeper_persist_control_file_seconds histogram vec")
|
||||
});
|
||||
pub static PG_IO_BYTES: Lazy<IntCounterVec> = Lazy::new(|| {
|
||||
register_int_counter_vec!(
|
||||
"safekeeper_pg_io_bytes",
|
||||
"Bytes read from or written to any PostgreSQL connection",
|
||||
&["direction"]
|
||||
)
|
||||
.expect("Failed to register safekeeper_pg_io_bytes gauge")
|
||||
});
|
||||
|
||||
/// Metrics for WalStorage in a single timeline.
|
||||
#[derive(Clone, Default)]
|
||||
|
||||
@@ -20,6 +20,8 @@ use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use std::thread::JoinHandle;
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio::sync::mpsc::channel;
|
||||
use tokio::sync::mpsc::error::TryRecvError;
|
||||
use tokio::sync::mpsc::Receiver;
|
||||
@@ -36,9 +38,9 @@ impl SafekeeperPostgresHandler {
|
||||
/// Wrapper around handle_start_wal_push_guts handling result. Error is
|
||||
/// handled here while we're still in walreceiver ttid span; with API
|
||||
/// extension, this can probably be moved into postgres_backend.
|
||||
pub async fn handle_start_wal_push(
|
||||
pub async fn handle_start_wal_push<IO: AsyncRead + AsyncWrite + Unpin>(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
pgb: &mut PostgresBackend<IO>,
|
||||
) -> Result<(), QueryError> {
|
||||
if let Err(end) = self.handle_start_wal_push_guts(pgb).await {
|
||||
// Log the result and probably send it to the client, closing the stream.
|
||||
@@ -47,9 +49,9 @@ impl SafekeeperPostgresHandler {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn handle_start_wal_push_guts(
|
||||
pub async fn handle_start_wal_push_guts<IO: AsyncRead + AsyncWrite + Unpin>(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
pgb: &mut PostgresBackend<IO>,
|
||||
) -> Result<(), CopyStreamHandlerEnd> {
|
||||
// Notify the libpq client that it's allowed to send `CopyData` messages
|
||||
pgb.write_message(&BeMessage::CopyBothResponse).await?;
|
||||
@@ -111,17 +113,17 @@ impl SafekeeperPostgresHandler {
|
||||
}
|
||||
}
|
||||
|
||||
struct NetworkReader<'a> {
|
||||
struct NetworkReader<'a, IO> {
|
||||
ttid: TenantTimelineId,
|
||||
conn_id: ConnectionId,
|
||||
pgb_reader: &'a mut PostgresBackendReader,
|
||||
pgb_reader: &'a mut PostgresBackendReader<IO>,
|
||||
peer_addr: SocketAddr,
|
||||
// WalAcceptor is spawned when we learn server info from walproposer and
|
||||
// create timeline; handle is put here.
|
||||
acceptor_handle: &'a mut Option<JoinHandle<anyhow::Result<()>>>,
|
||||
}
|
||||
|
||||
impl<'a> NetworkReader<'a> {
|
||||
impl<'a, IO: AsyncRead + AsyncWrite + Unpin> NetworkReader<'a, IO> {
|
||||
async fn run(
|
||||
self,
|
||||
msg_tx: Sender<ProposerAcceptorMessage>,
|
||||
@@ -162,16 +164,16 @@ impl<'a> NetworkReader<'a> {
|
||||
|
||||
/// Read next message from walproposer.
|
||||
/// TODO: Return Ok(None) on graceful termination.
|
||||
async fn read_message(
|
||||
pgb_reader: &mut PostgresBackendReader,
|
||||
async fn read_message<IO: AsyncRead + AsyncWrite + Unpin>(
|
||||
pgb_reader: &mut PostgresBackendReader<IO>,
|
||||
) -> Result<ProposerAcceptorMessage, CopyStreamHandlerEnd> {
|
||||
let copy_data = pgb_reader.read_copy_message().await?;
|
||||
let msg = ProposerAcceptorMessage::parse(copy_data)?;
|
||||
Ok(msg)
|
||||
}
|
||||
|
||||
async fn read_network_loop(
|
||||
pgb_reader: &mut PostgresBackendReader,
|
||||
async fn read_network_loop<IO: AsyncRead + AsyncWrite + Unpin>(
|
||||
pgb_reader: &mut PostgresBackendReader<IO>,
|
||||
msg_tx: Sender<ProposerAcceptorMessage>,
|
||||
mut next_msg: ProposerAcceptorMessage,
|
||||
) -> Result<(), CopyStreamHandlerEnd> {
|
||||
@@ -186,8 +188,8 @@ async fn read_network_loop(
|
||||
/// Read replies from WalAcceptor and pass them back to socket. Returns Ok(())
|
||||
/// if reply_rx closed; it must mean WalAcceptor terminated, joining it should
|
||||
/// tell the error.
|
||||
async fn network_write(
|
||||
pgb_writer: &mut PostgresBackend,
|
||||
async fn network_write<IO: AsyncRead + AsyncWrite + Unpin>(
|
||||
pgb_writer: &mut PostgresBackend<IO>,
|
||||
mut reply_rx: Receiver<AcceptorProposerMessage>,
|
||||
) -> Result<(), CopyStreamHandlerEnd> {
|
||||
let mut buf = BytesMut::with_capacity(128);
|
||||
|
||||
@@ -13,6 +13,8 @@ use postgres_ffi::get_current_timestamp;
|
||||
use postgres_ffi::{TimestampTz, MAX_SEND_SIZE};
|
||||
use pq_proto::{BeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
|
||||
use std::cmp::min;
|
||||
use std::str;
|
||||
use std::sync::Arc;
|
||||
@@ -74,9 +76,9 @@ impl SafekeeperPostgresHandler {
|
||||
/// Wrapper around handle_start_replication_guts handling result. Error is
|
||||
/// handled here while we're still in walsender ttid span; with API
|
||||
/// extension, this can probably be moved into postgres_backend.
|
||||
pub async fn handle_start_replication(
|
||||
pub async fn handle_start_replication<IO: AsyncRead + AsyncWrite + Unpin>(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
pgb: &mut PostgresBackend<IO>,
|
||||
start_pos: Lsn,
|
||||
) -> Result<(), QueryError> {
|
||||
if let Err(end) = self.handle_start_replication_guts(pgb, start_pos).await {
|
||||
@@ -86,9 +88,9 @@ impl SafekeeperPostgresHandler {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn handle_start_replication_guts(
|
||||
pub async fn handle_start_replication_guts<IO: AsyncRead + AsyncWrite + Unpin>(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
pgb: &mut PostgresBackend<IO>,
|
||||
start_pos: Lsn,
|
||||
) -> Result<(), CopyStreamHandlerEnd> {
|
||||
let appname = self.appname.clone();
|
||||
@@ -176,8 +178,8 @@ impl SafekeeperPostgresHandler {
|
||||
}
|
||||
|
||||
/// A half driving sending WAL.
|
||||
struct WalSender<'a> {
|
||||
pgb: &'a mut PostgresBackend,
|
||||
struct WalSender<'a, IO> {
|
||||
pgb: &'a mut PostgresBackend<IO>,
|
||||
tli: Arc<Timeline>,
|
||||
appname: Option<String>,
|
||||
// Position since which we are sending next chunk.
|
||||
@@ -194,7 +196,7 @@ struct WalSender<'a> {
|
||||
send_buf: [u8; MAX_SEND_SIZE],
|
||||
}
|
||||
|
||||
impl WalSender<'_> {
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> WalSender<'_, IO> {
|
||||
/// Send WAL until
|
||||
/// - an error occurs
|
||||
/// - if we are streaming to walproposer, we've streamed until stop_pos
|
||||
@@ -282,14 +284,14 @@ impl WalSender<'_> {
|
||||
}
|
||||
|
||||
/// A half driving receiving replies.
|
||||
struct ReplyReader {
|
||||
reader: PostgresBackendReader,
|
||||
struct ReplyReader<IO> {
|
||||
reader: PostgresBackendReader<IO>,
|
||||
tli: Arc<Timeline>,
|
||||
replica_id: usize,
|
||||
feedback: ReplicaState,
|
||||
}
|
||||
|
||||
impl ReplyReader {
|
||||
impl<IO: AsyncRead + AsyncWrite + Unpin> ReplyReader<IO> {
|
||||
async fn run(&mut self) -> Result<(), CopyStreamHandlerEnd> {
|
||||
loop {
|
||||
let msg = self.reader.read_copy_message().await?;
|
||||
|
||||
@@ -7,9 +7,10 @@ use postgres_backend::QueryError;
|
||||
use std::{future, thread};
|
||||
use tokio::net::TcpStream;
|
||||
use tracing::*;
|
||||
use utils::measured_stream::MeasuredStream;
|
||||
|
||||
use crate::handler::SafekeeperPostgresHandler;
|
||||
use crate::SafeKeeperConf;
|
||||
use crate::{handler::SafekeeperPostgresHandler, metrics::PG_IO_BYTES};
|
||||
use postgres_backend::{AuthType, PostgresBackend};
|
||||
|
||||
/// Accept incoming TCP connections and spawn them into a background thread.
|
||||
@@ -67,14 +68,29 @@ fn handle_socket(
|
||||
.build()?;
|
||||
let local = tokio::task::LocalSet::new();
|
||||
|
||||
let read_metrics = PG_IO_BYTES.with_label_values(&["read"]);
|
||||
let write_metrics = PG_IO_BYTES.with_label_values(&["write"]);
|
||||
|
||||
socket.set_nodelay(true)?;
|
||||
let peer_addr = socket.peer_addr()?;
|
||||
|
||||
// TODO: measure cross-az traffic
|
||||
let socket = MeasuredStream::new(
|
||||
socket,
|
||||
|cnt| {
|
||||
read_metrics.inc_by(cnt as u64);
|
||||
},
|
||||
|cnt| {
|
||||
write_metrics.inc_by(cnt as u64);
|
||||
},
|
||||
);
|
||||
|
||||
let auth_type = match conf.auth {
|
||||
None => AuthType::Trust,
|
||||
Some(_) => AuthType::NeonJWT,
|
||||
};
|
||||
let mut conn_handler = SafekeeperPostgresHandler::new(conf, conn_id);
|
||||
let pgbackend = PostgresBackend::new(socket, auth_type, None)?;
|
||||
let pgbackend = PostgresBackend::new_from_io(socket, peer_addr, auth_type, None)?;
|
||||
// libpq protocol between safekeeper and walproposer / pageserver
|
||||
// We don't use shutdown.
|
||||
local.block_on(
|
||||
|
||||
Reference in New Issue
Block a user