Make postgres_backend use generic IO type (#3789)

- Support measuring inbound and outbound traffic in MeasuredStream
- Start using MeasuredStream in safekeepers code
This commit is contained in:
Arthur Petukhovsky
2023-03-13 12:18:10 +03:00
committed by GitHub
parent 8699342249
commit d9a1329834
16 changed files with 234 additions and 154 deletions

View File

@@ -59,14 +59,14 @@ pub fn is_expected_io_error(e: &io::Error) -> bool {
}
#[async_trait::async_trait]
pub trait Handler {
pub trait Handler<IO> {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might be not what we want after CopyData streaming, but currently we don't
/// care). It will also flush out the output buffer.
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
pgb: &mut PostgresBackend<IO>,
query_string: &str,
) -> Result<(), QueryError>;
@@ -77,7 +77,7 @@ pub trait Handler {
/// to override whole init logic in implementations.
fn startup(
&mut self,
_pgb: &mut PostgresBackend,
_pgb: &mut PostgresBackend<IO>,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
@@ -86,7 +86,7 @@ pub trait Handler {
/// Check auth jwt
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend,
_pgb: &mut PostgresBackend<IO>,
_jwt_response: &[u8],
) -> Result<(), QueryError> {
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
@@ -115,12 +115,12 @@ pub enum ProcessMsgResult {
}
/// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
pub enum MaybeTlsStream {
Unencrypted(tokio::net::TcpStream),
Tls(Box<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>),
pub enum MaybeTlsStream<IO> {
Unencrypted(IO),
Tls(Box<tokio_rustls::server::TlsStream<IO>>),
}
impl AsyncWrite for MaybeTlsStream {
impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<IO> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
@@ -147,7 +147,7 @@ impl AsyncWrite for MaybeTlsStream {
}
}
}
impl AsyncRead for MaybeTlsStream {
impl<IO: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<IO> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
@@ -192,13 +192,13 @@ impl fmt::Display for AuthType {
/// PostgresBackend after call to `split`. In principle we could always store a
/// pair of splitted handles, but that would force to to pay splitting price
/// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
enum MaybeWriteOnly {
Full(Framed<MaybeTlsStream>),
WriteOnly(FramedWriter<MaybeTlsStream>),
enum MaybeWriteOnly<IO> {
Full(Framed<MaybeTlsStream<IO>>),
WriteOnly(FramedWriter<MaybeTlsStream<IO>>),
Broken, // temporary value palmed off during the split
}
impl MaybeWriteOnly {
impl<IO: AsyncRead + AsyncWrite + Unpin> MaybeWriteOnly<IO> {
async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
match self {
MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
@@ -244,8 +244,8 @@ impl MaybeWriteOnly {
}
}
pub struct PostgresBackend {
framed: MaybeWriteOnly,
pub struct PostgresBackend<IO> {
framed: MaybeWriteOnly<IO>,
pub state: ProtoState,
@@ -255,6 +255,8 @@ pub struct PostgresBackend {
pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
pub type PostgresBackendTCP = PostgresBackend<tokio::net::TcpStream>;
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
let mut query_string = query_string.to_vec();
if let Some(ch) = query_string.last() {
@@ -271,7 +273,7 @@ fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
std::str::from_utf8(without_null).map_err(|e| e.into())
}
impl PostgresBackend {
impl PostgresBackend<tokio::net::TcpStream> {
pub fn new(
socket: tokio::net::TcpStream,
auth_type: AuthType,
@@ -288,6 +290,25 @@ impl PostgresBackend {
peer_addr,
})
}
}
impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
pub fn new_from_io(
socket: IO,
peer_addr: SocketAddr,
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
) -> io::Result<Self> {
let stream = MaybeTlsStream::Unencrypted(socket);
Ok(Self {
framed: MaybeWriteOnly::Full(Framed::new(stream)),
state: ProtoState::Initialization,
auth_type,
tls_config,
peer_addr,
})
}
pub fn get_peer_addr(&self) -> &SocketAddr {
&self.peer_addr
@@ -346,14 +367,14 @@ impl PostgresBackend {
/// to it in CopyData messages, and writes them to the connection
///
/// The caller is responsible for sending CopyOutResponse and CopyDone messages.
pub fn copyout_writer(&mut self) -> CopyDataWriter {
pub fn copyout_writer(&mut self) -> CopyDataWriter<IO> {
CopyDataWriter { pgb: self }
}
/// Wrapper for run_message_loop() that shuts down socket when we are done
pub async fn run<F, S>(
mut self,
handler: &mut impl Handler,
handler: &mut impl Handler<IO>,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
@@ -369,7 +390,7 @@ impl PostgresBackend {
async fn run_message_loop<F, S>(
&mut self,
handler: &mut impl Handler,
handler: &mut impl Handler<IO>,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
@@ -426,9 +447,9 @@ impl PostgresBackend {
/// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
async fn tls_upgrade(
src: MaybeTlsStream,
src: MaybeTlsStream<IO>,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<MaybeTlsStream> {
) -> anyhow::Result<MaybeTlsStream<IO>> {
match src {
MaybeTlsStream::Unencrypted(s) => {
let acceptor = TlsAcceptor::from(tls_config);
@@ -466,7 +487,7 @@ impl PostgresBackend {
/// Split off owned read part from which messages can be read in different
/// task/thread.
pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader> {
pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader<IO>> {
// temporary replace stream with fake to cook split one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(framed) => {
@@ -482,7 +503,7 @@ impl PostgresBackend {
}
/// Join read part back.
pub fn unsplit(&mut self, reader: PostgresBackendReader) -> anyhow::Result<()> {
pub fn unsplit(&mut self, reader: PostgresBackendReader<IO>) -> anyhow::Result<()> {
// temporary replace stream with fake to cook joined one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(_) => {
@@ -499,7 +520,7 @@ impl PostgresBackend {
/// Perform handshake with the client, transitioning to Established.
/// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()).
async fn handshake(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
async fn handshake(&mut self, handler: &mut impl Handler<IO>) -> Result<(), QueryError> {
while self.state < ProtoState::Authentication {
match self.framed.read_startup_message().await? {
Some(msg) => {
@@ -565,7 +586,7 @@ impl PostgresBackend {
/// actual startup packet.
async fn process_startup_message(
&mut self,
handler: &mut impl Handler,
handler: &mut impl Handler<IO>,
msg: FeStartupPacket,
) -> Result<(), QueryError> {
assert!(self.state < ProtoState::Authentication);
@@ -629,7 +650,7 @@ impl PostgresBackend {
async fn process_message(
&mut self,
handler: &mut impl Handler,
handler: &mut impl Handler<IO>,
msg: FeMessage,
unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult, QueryError> {
@@ -776,9 +797,9 @@ impl PostgresBackend {
}
}
pub struct PostgresBackendReader(FramedReader<MaybeTlsStream>);
pub struct PostgresBackendReader<IO>(FramedReader<MaybeTlsStream<IO>>);
impl PostgresBackendReader {
impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackendReader<IO> {
/// Read full message or return None if connection is cleanly closed with no
/// unprocessed data.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
@@ -812,11 +833,11 @@ impl PostgresBackendReader {
/// messages.
///
pub struct CopyDataWriter<'a> {
pgb: &'a mut PostgresBackend,
pub struct CopyDataWriter<'a, IO> {
pgb: &'a mut PostgresBackend<IO>,
}
impl<'a> AsyncWrite for CopyDataWriter<'a> {
impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'a, IO> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,