Remove sync postgres_backend, tidy up its split usage.

- Add support for splitting async postgres_backend into read and write halfes.
  Safekeeper needs this for bidirectional streams. To this end, encapsulate
  reading-writing postgres messages to framed.rs with split support without any
  additional changes (relying on BufRead for reading and BytesMut out buffer for
  writing).
- Use async postgres_backend throughout safekeeper (and in proxy auth link
  part).
- In both safekeeper COPY streams, do read-write from the same thread/task with
  select! for easier error handling.
- Tidy up finishing CopyBoth streams in safekeeper sending and receiving WAL
  -- join split parts back catching errors from them before returning.

Initially I hoped to do that read-write without split at all, through polling
IO:
https://github.com/neondatabase/neon/pull/3522
However that turned out to be more complicated than I initially expected
due to 1) borrow checking and 2) anon Future types. 1) required Rc<Refcell<...>>
which is Send construct just to satisfy the checker; 2) can be workaround with
transmute. But this is so messy that I decided to leave split.
This commit is contained in:
Arseny Sher
2023-02-02 12:03:45 +04:00
committed by Arseny Sher
parent 7627d85345
commit 0d8ced8534
45 changed files with 1657 additions and 2145 deletions

View File

@@ -17,7 +17,6 @@ tokio-rustls.workspace = true
tracing.workspace = true
pq_proto.workspace = true
utils.workspace = true
workspace_hack.workspace = true
[dev-dependencies]

View File

@@ -2,29 +2,26 @@
//! To use, create PostgresBackend and run() it, passing the Handler
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use anyhow::Context;
use bytes::{Buf, Bytes, BytesMut};
use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR};
use std::io;
use bytes::Bytes;
use futures::pin_mut;
use serde::{Deserialize, Serialize};
use std::io::ErrorKind;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;
use std::{future::Future, task::ready};
use tracing::{debug, error, info, trace};
use utils::postgres_backend::AuthType;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
use std::task::{ready, Poll};
use std::{fmt, io};
use std::{future::Future, str::FromStr};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsAcceptor;
use tracing::{debug, error, info, trace};
pub fn is_expected_io_error(e: &io::Error) -> bool {
use io::ErrorKind::*;
matches!(
e.kind(),
ConnectionRefused | ConnectionAborted | ConnectionReset
)
}
use pq_proto::framed::{Framed, FramedReader, FramedWriter};
use pq_proto::{
BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR,
SQLSTATE_SUCCESSFUL_COMPLETION,
};
/// An error, occurred during query processing:
/// either during the connection ([`ConnectionError`]) or before/after it.
@@ -53,12 +50,20 @@ impl QueryError {
}
}
pub fn is_expected_io_error(e: &io::Error) -> bool {
use io::ErrorKind::*;
matches!(
e.kind(),
ConnectionRefused | ConnectionAborted | ConnectionReset
)
}
#[async_trait::async_trait]
pub trait Handler {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might be not what we want after CopyData streaming, but currently we don't
/// care).
/// care). It will also flush out the output buffer.
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
@@ -92,9 +97,13 @@ pub trait Handler {
/// XXX: The order of the constructors matters.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
/// Nothing happened yet.
Initialization,
/// Encryption handshake is done; waiting for encrypted Startup message.
Encrypted,
/// Waiting for password (auth token).
Authentication,
/// Performed handshake and auth, ReadyForQuery is issued.
Established,
Closed,
}
@@ -105,15 +114,13 @@ pub enum ProcessMsgResult {
Break,
}
/// Always-writeable sock_split stream.
/// May not be readable. See [`PostgresBackend::take_stream_in`]
pub enum Stream {
Unencrypted(BufReader<tokio::net::TcpStream>),
Tls(Box<tokio_rustls::server::TlsStream<BufReader<tokio::net::TcpStream>>>),
Broken,
/// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
pub enum MaybeTlsStream {
Unencrypted(tokio::net::TcpStream),
Tls(Box<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>),
}
impl AsyncWrite for Stream {
impl AsyncWrite for MaybeTlsStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
@@ -122,14 +129,12 @@ impl AsyncWrite for Stream {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
Self::Broken => unreachable!(),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
Self::Broken => unreachable!(),
}
}
fn poll_shutdown(
@@ -139,11 +144,10 @@ impl AsyncWrite for Stream {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Broken => unreachable!(),
}
}
}
impl AsyncRead for Stream {
impl AsyncRead for MaybeTlsStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
@@ -152,18 +156,96 @@ impl AsyncRead for Stream {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
Self::Broken => unreachable!(),
}
}
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AuthType {
Trust,
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
NeonJWT,
}
impl FromStr for AuthType {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"Trust" => Ok(Self::Trust),
"NeonJWT" => Ok(Self::NeonJWT),
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
}
}
}
impl fmt::Display for AuthType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
AuthType::Trust => "Trust",
AuthType::NeonJWT => "NeonJWT",
})
}
}
/// Either full duplex Framed or write only half; the latter is left in
/// PostgresBackend after call to `split`. In principle we could always store a
/// pair of splitted handles, but that would force to to pay splitting price
/// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
enum MaybeWriteOnly {
Full(Framed<MaybeTlsStream>),
WriteOnly(FramedWriter<MaybeTlsStream>),
Broken, // temporary value palmed off during the split
}
impl MaybeWriteOnly {
async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
match self {
MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
MaybeWriteOnly::WriteOnly(_) => {
Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
}
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
match self {
MaybeWriteOnly::Full(framed) => framed.read_message().await,
MaybeWriteOnly::WriteOnly(_) => {
Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
}
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ConnectionError> {
match self {
MaybeWriteOnly::Full(framed) => framed.write_message(msg),
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
async fn flush(&mut self) -> io::Result<()> {
match self {
MaybeWriteOnly::Full(framed) => framed.flush().await,
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await,
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
async fn shutdown(&mut self) -> io::Result<()> {
match self {
MaybeWriteOnly::Full(framed) => framed.shutdown().await,
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await,
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
}
pub struct PostgresBackend {
stream: Stream,
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
// The data between 0 and "current position" as tracked by the bytes::Buf
// implementation of BytesMut, have already been written.
buf_out: BytesMut,
framed: MaybeWriteOnly,
pub state: ProtoState,
@@ -183,7 +265,7 @@ pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
query_string
}
// Cast a byte slice to a string slice, dropping null terminator if there's one.
/// Cast a byte slice to a string slice, dropping null terminator if there's one.
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
std::str::from_utf8(without_null).map_err(|e| e.into())
@@ -196,10 +278,10 @@ impl PostgresBackend {
tls_config: Option<Arc<rustls::ServerConfig>>,
) -> io::Result<Self> {
let peer_addr = socket.peer_addr()?;
let stream = MaybeTlsStream::Unencrypted(socket);
Ok(Self {
stream: Stream::Unencrypted(BufReader::new(socket)),
buf_out: BytesMut::with_capacity(10 * 1024),
framed: MaybeWriteOnly::Full(Framed::new(stream)),
state: ProtoState::Initialization,
auth_type,
tls_config,
@@ -211,30 +293,52 @@ impl PostgresBackend {
&self.peer_addr
}
/// Read full message or return None if connection is closed.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, QueryError> {
use ProtoState::*;
match self.state {
Initialization | Encrypted => FeStartupPacket::read_fut(&mut self.stream).await,
Authentication | Established => FeMessage::read_fut(&mut self.stream).await,
Closed => Ok(None),
/// Read full message or return None if connection is cleanly closed with no
/// unprocessed data.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
if let ProtoState::Closed = self.state {
Ok(None)
} else {
let m = self.framed.read_message().await?;
trace!("read msg {:?}", m);
Ok(m)
}
.map_err(QueryError::from)
}
/// Write message into internal output buffer, doesn't flush it. Technically
/// error type can be only ProtocolError here (if, unlikely, serialization
/// fails), but callers typically wrap it anyway.
pub fn write_message_noflush(
&mut self,
message: &BeMessage<'_>,
) -> Result<&mut Self, ConnectionError> {
self.framed.write_message_noflush(message)?;
trace!("wrote msg {:?}", message);
Ok(self)
}
/// Flush output buffer into the socket.
pub async fn flush(&mut self) -> io::Result<()> {
while self.buf_out.has_remaining() {
let bytes_written = self.stream.write(self.buf_out.chunk()).await?;
self.buf_out.advance(bytes_written);
}
self.buf_out.clear();
Ok(())
self.framed.flush().await
}
/// Write message into internal output buffer.
pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
BeMessage::write(&mut self.buf_out, message)?;
/// Polling version of `flush()`, saves the caller need to pin.
pub fn poll_flush(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let flush_fut = self.flush();
pin_mut!(flush_fut);
flush_fut.poll(cx)
}
/// Write message into internal output buffer and flush it to the stream.
pub async fn write_message(
&mut self,
message: &BeMessage<'_>,
) -> Result<&mut Self, ConnectionError> {
self.write_message_noflush(message)?;
self.flush().await?;
Ok(self)
}
@@ -246,26 +350,7 @@ impl PostgresBackend {
CopyDataWriter { pgb: self }
}
/// A polling function that tries to write all the data from 'buf_out' to the
/// underlying stream.
fn poll_write_buf(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
while self.buf_out.has_remaining() {
match ready!(Pin::new(&mut self.stream).poll_write(cx, self.buf_out.chunk())) {
Ok(bytes_written) => self.buf_out.advance(bytes_written),
Err(err) => return Poll::Ready(Err(err)),
}
}
Poll::Ready(Ok(()))
}
fn poll_flush(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.stream).poll_flush(cx)
}
// Wrapper for run_message_loop() that shuts down socket when we are done
/// Wrapper for run_message_loop() that shuts down socket when we are done
pub async fn run<F, S>(
mut self,
handler: &mut impl Handler,
@@ -276,7 +361,9 @@ impl PostgresBackend {
S: Future,
{
let ret = self.run_message_loop(handler, shutdown_watcher).await;
let _ = self.stream.shutdown();
// socket might be already closed, e.g. if previously received error,
// so ignore result.
self.framed.shutdown().await.ok();
ret
}
@@ -300,30 +387,12 @@ impl PostgresBackend {
return Ok(())
},
result = async {
while self.state < ProtoState::Established {
if let Some(msg) = self.read_message().await? {
trace!("got message {msg:?} during handshake");
match self.process_handshake_message(handler, msg).await? {
ProcessMsgResult::Continue => {
self.flush().await?;
continue;
}
ProcessMsgResult::Break => {
trace!("postgres backend to {:?} exited during handshake", self.peer_addr);
return Ok(());
}
}
} else {
trace!("postgres backend to {:?} exited during handshake", self.peer_addr);
return Ok(());
}
}
Ok::<(), QueryError>(())
} => {
result = self.handshake(handler) => {
// Handshake complete.
result?;
if self.state == ProtoState::Closed {
return Ok(()); // EOF during handshake
}
}
);
@@ -355,114 +424,207 @@ impl PostgresBackend {
Ok(())
}
async fn start_tls(&mut self) -> anyhow::Result<()> {
if let Stream::Unencrypted(plain_stream) =
std::mem::replace(&mut self.stream, Stream::Broken)
{
let acceptor = TlsAcceptor::from(self.tls_config.clone().unwrap());
let tls_stream = acceptor.accept(plain_stream).await?;
self.stream = Stream::Tls(Box::new(tls_stream));
return Ok(());
};
anyhow::bail!("TLS already started");
}
async fn process_handshake_message(
&mut self,
handler: &mut impl Handler,
msg: FeMessage,
) -> Result<ProcessMsgResult, QueryError> {
assert!(self.state < ProtoState::Established);
let have_tls = self.tls_config.is_some();
match msg {
FeMessage::StartupPacket(m) => {
trace!("got startup message {m:?}");
match m {
FeStartupPacket::SslRequest => {
debug!("SSL requested");
self.write_message_noflush(&BeMessage::EncryptionResponse(have_tls))?;
if have_tls {
self.start_tls().await?;
self.state = ProtoState::Encrypted;
}
}
FeStartupPacket::GssEncRequest => {
debug!("GSS requested");
self.write_message_noflush(&BeMessage::EncryptionResponse(false))?;
}
FeStartupPacket::StartupMessage { .. } => {
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
self.write_message_noflush(&BeMessage::ErrorResponse(
"must connect with TLS",
None,
))?;
return Err(QueryError::Other(anyhow::anyhow!(
"client did not connect with TLS"
)));
}
// NB: startup() may change self.auth_type -- we are using that in proxy code
// to bypass auth for new users.
handler.startup(self, &m)?;
match self.auth_type {
AuthType::Trust => {
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
// The async python driver requires a valid server_version
.write_message_noflush(&BeMessage::server_version("14.1"))?
.write_message_noflush(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
AuthType::NeonJWT => {
self.write_message_noflush(
&BeMessage::AuthenticationCleartextPassword,
)?;
self.state = ProtoState::Authentication;
}
}
}
FeStartupPacket::CancelRequest { .. } => {
self.state = ProtoState::Closed;
return Ok(ProcessMsgResult::Break);
}
}
/// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
async fn tls_upgrade(
src: MaybeTlsStream,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<MaybeTlsStream> {
match src {
MaybeTlsStream::Unencrypted(s) => {
let acceptor = TlsAcceptor::from(tls_config);
let tls_stream = acceptor.accept(s).await?;
Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
}
FeMessage::PasswordMessage(m) => {
trace!("got password message '{:?}'", m);
assert!(self.state == ProtoState::Authentication);
match self.auth_type {
AuthType::Trust => unreachable!(),
AuthType::NeonJWT => {
let (_, jwt_response) = m.split_last().context("protocol violation")?;
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
self.write_message_noflush(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
return Err(e);
}
}
}
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
.write_message_noflush(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
_ => {
self.state = ProtoState::Closed;
return Ok(ProcessMsgResult::Break);
MaybeTlsStream::Tls(_) => {
anyhow::bail!("TLS already started");
}
}
Ok(ProcessMsgResult::Continue)
}
async fn start_tls(&mut self) -> anyhow::Result<()> {
// temporary replace stream with fake to cook TLS one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(framed) => {
let tls_config = self
.tls_config
.as_ref()
.context("start_tls called without conf")?
.clone();
let tls_framed = framed
.map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config))
.await?;
// push back ready TLS stream
self.framed = MaybeWriteOnly::Full(tls_framed);
Ok(())
}
MaybeWriteOnly::WriteOnly(_) => {
anyhow::bail!("TLS upgrade attempt in split state")
}
MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"),
}
}
/// Split off owned read part from which messages can be read in different
/// task/thread.
pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader> {
// temporary replace stream with fake to cook split one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(framed) => {
let (reader, writer) = framed.split();
self.framed = MaybeWriteOnly::WriteOnly(writer);
Ok(PostgresBackendReader(reader))
}
MaybeWriteOnly::WriteOnly(_) => {
anyhow::bail!("PostgresBackend is already split")
}
MaybeWriteOnly::Broken => panic!("split on framed in invalid state"),
}
}
/// Join read part back.
pub fn unsplit(&mut self, reader: PostgresBackendReader) -> anyhow::Result<()> {
// temporary replace stream with fake to cook joined one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(_) => {
anyhow::bail!("PostgresBackend is not split")
}
MaybeWriteOnly::WriteOnly(writer) => {
let joined = Framed::unsplit(reader.0, writer);
self.framed = MaybeWriteOnly::Full(joined);
Ok(())
}
MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"),
}
}
/// Perform handshake with the client, transitioning to Established.
/// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()).
async fn handshake(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
while self.state < ProtoState::Authentication {
match self.framed.read_startup_message().await? {
Some(msg) => {
self.process_startup_message(handler, msg).await?;
}
None => {
trace!(
"postgres backend to {:?} received EOF during handshake",
self.peer_addr
);
self.state = ProtoState::Closed;
return Ok(());
}
}
}
// Perform auth, if needed.
if self.state == ProtoState::Authentication {
match self.framed.read_message().await? {
Some(FeMessage::PasswordMessage(m)) => {
assert!(self.auth_type == AuthType::NeonJWT);
let (_, jwt_response) = m.split_last().context("protocol violation")?;
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
self.write_message_noflush(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
return Err(e);
}
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
.write_message(&BeMessage::ReadyForQuery)
.await?;
self.state = ProtoState::Established;
}
Some(m) => {
return Err(QueryError::Other(anyhow::anyhow!(
"Unexpected message {:?} while waiting for handshake",
m
)));
}
None => {
trace!(
"postgres backend to {:?} received EOF during auth",
self.peer_addr
);
self.state = ProtoState::Closed;
return Ok(());
}
}
}
Ok(())
}
/// Process startup packet:
/// - transition to Established if auth type is trust
/// - transition to Authentication if auth type is NeonJWT.
/// - or perform TLS handshake -- then need to call this again to receive
/// actual startup packet.
async fn process_startup_message(
&mut self,
handler: &mut impl Handler,
msg: FeStartupPacket,
) -> Result<(), QueryError> {
assert!(self.state < ProtoState::Authentication);
let have_tls = self.tls_config.is_some();
match msg {
FeStartupPacket::SslRequest => {
debug!("SSL requested");
self.write_message(&BeMessage::EncryptionResponse(have_tls))
.await?;
if have_tls {
self.start_tls().await?;
self.state = ProtoState::Encrypted;
}
}
FeStartupPacket::GssEncRequest => {
debug!("GSS requested");
self.write_message(&BeMessage::EncryptionResponse(false))
.await?;
}
FeStartupPacket::StartupMessage { .. } => {
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None))
.await?;
return Err(QueryError::Other(anyhow::anyhow!(
"client did not connect with TLS"
)));
}
// NB: startup() may change self.auth_type -- we are using that in proxy code
// to bypass auth for new users.
handler.startup(self, &msg)?;
match self.auth_type {
AuthType::Trust => {
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
.write_message_noflush(&BeMessage::INTEGER_DATETIMES)?
// The async python driver requires a valid server_version
.write_message_noflush(&BeMessage::server_version("14.1"))?
.write_message(&BeMessage::ReadyForQuery)
.await?;
self.state = ProtoState::Established;
}
AuthType::NeonJWT => {
self.write_message(&BeMessage::AuthenticationCleartextPassword)
.await?;
self.state = ProtoState::Authentication;
}
}
}
FeStartupPacket::CancelRequest { .. } => {
return Err(QueryError::Other(anyhow::anyhow!(
"Unexpected CancelRequest message during handshake"
)));
}
}
Ok(())
}
async fn process_message(
@@ -476,10 +638,6 @@ impl PostgresBackend {
assert!(self.state == ProtoState::Established);
match msg {
FeMessage::StartupPacket(_) | FeMessage::PasswordMessage(_) => {
return Err(QueryError::Other(anyhow::anyhow!("protocol violation")));
}
FeMessage::Query(body) => {
// remove null terminator
let query_string = cstr_to_str(&body)?;
@@ -540,16 +698,114 @@ impl PostgresBackend {
// We prefer explicit pattern matching to wildcards, because
// this helps us spot the places where new variants are missing
FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => {
FeMessage::CopyData(_)
| FeMessage::CopyDone
| FeMessage::CopyFail
| FeMessage::PasswordMessage(_)
| FeMessage::StartupPacket(_) => {
return Err(QueryError::Other(anyhow::anyhow!(
"unexpected message type: {:?}",
msg
"unexpected message type: {msg:?}",
)));
}
}
Ok(ProcessMsgResult::Continue)
}
/// Log as info/error result of handling COPY stream and send back
/// ErrorResponse if that makes sense. Shutdown the stream if we got
/// Terminate. TODO: transition into waiting for Sync msg if we initiate the
/// close.
pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) {
use CopyStreamHandlerEnd::*;
let expected_end = match &end {
ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true,
CopyStreamHandlerEnd::Disconnected(ConnectionError::Socket(io_error))
if is_expected_io_error(io_error) =>
{
true
}
_ => false,
};
if expected_end {
info!("terminated: {:#}", end);
} else {
error!("terminated: {:?}", end);
}
// Note: no current usages ever send this
if let CopyDone = &end {
if let Err(e) = self.write_message(&BeMessage::CopyDone).await {
error!("failed to send CopyDone: {}", e);
}
}
if let Terminate = &end {
self.state = ProtoState::Closed;
}
let err_to_send_and_errcode = match &end {
ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
Other(_) => Some((end.to_string(), SQLSTATE_INTERNAL_ERROR)),
// Note: CopyFail in duplex copy is somewhat unexpected (at least to
// PG walsender; evidently and per my docs reading client should
// finish it with CopyDone). It is not a problem to recover from it
// finishing the stream in both directions like we do, but note that
// sync rust-postgres client (which we don't use anymore) hangs if
// socket is not closed here.
// https://github.com/sfackler/rust-postgres/issues/755
// https://github.com/neondatabase/neon/issues/935
//
// Currently, the version of tokio_postgres replication patch we use
// sends this when it closes the stream (e.g. pageserver decided to
// switch conn to another safekeeper and client gets dropped).
// Moreover, seems like 'connection' task errors with 'unexpected
// message from server' when it receives ErrorResponse (anything but
// CopyData/CopyDone) back.
CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
_ => None,
};
if let Some((err, errcode)) = err_to_send_and_errcode {
if let Err(ee) = self
.write_message(&BeMessage::ErrorResponse(&err, Some(errcode)))
.await
{
error!("failed to send ErrorResponse: {}", ee);
}
}
}
}
pub struct PostgresBackendReader(FramedReader<MaybeTlsStream>);
impl PostgresBackendReader {
/// Read full message or return None if connection is cleanly closed with no
/// unprocessed data.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
let m = self.0.read_message().await?;
trace!("read msg {:?}", m);
Ok(m)
}
/// Get CopyData contents of the next message in COPY stream or error
/// closing it. The error type is wider than actual errors which can happen
/// here -- it includes 'Other' and 'ServerInitiated', but that's ok for
/// current callers.
pub async fn read_copy_message(&mut self) -> Result<Bytes, CopyStreamHandlerEnd> {
match self.read_message().await? {
Some(msg) => match msg {
FeMessage::CopyData(m) => Ok(m),
FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone),
FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
_ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
format!("unexpected message in COPY stream {:?}", msg),
))),
},
None => Err(CopyStreamHandlerEnd::EOF),
}
}
}
///
@@ -572,16 +828,19 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> {
// It's not strictly required to flush between each message, but makes it easier
// to view in wireshark, and usually the messages that the callers write are
// decently-sized anyway.
match ready!(this.pgb.poll_write_buf(cx)) {
Ok(()) => {}
Err(err) => return Poll::Ready(Err(err)),
if let Err(err) = ready!(this.pgb.poll_flush(cx)) {
return Poll::Ready(Err(err));
}
// CopyData
// XXX: if the input is large, we should split it into multiple messages.
// Not sure what the threshold should be, but the ultimate hard limit is that
// the length cannot exceed u32.
this.pgb.write_message_noflush(&BeMessage::CopyData(buf))?;
this.pgb
.write_message_noflush(&BeMessage::CopyData(buf))
// write_message only writes to the buffer, so it can fail iff the
// message is invaid, but CopyData can't be invalid.
.map_err(|_| io::Error::new(ErrorKind::Other, "failed to serialize CopyData"))?;
Poll::Ready(Ok(buf.len()))
}
@@ -591,21 +850,14 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> {
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
match ready!(this.pgb.poll_write_buf(cx)) {
Ok(()) => {}
Err(err) => return Poll::Ready(Err(err)),
}
this.pgb.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
match ready!(this.pgb.poll_write_buf(cx)) {
Ok(()) => {}
Err(err) => return Poll::Ready(Err(err)),
}
this.pgb.poll_flush(cx)
}
}
@@ -617,7 +869,7 @@ pub fn short_error(e: &QueryError) -> String {
}
}
pub fn log_query_error(query: &str, e: &QueryError) {
fn log_query_error(query: &str, e: &QueryError) {
match e {
QueryError::Disconnected(ConnectionError::Socket(io_error)) => {
if is_expected_io_error(io_error) {
@@ -634,3 +886,26 @@ pub fn log_query_error(query: &str, e: &QueryError) {
}
}
}
/// Something finishing handling of COPY stream, see handle_copy_stream_end.
/// This is not always a real error, but it allows to use ? and thiserror impls.
#[derive(thiserror::Error, Debug)]
pub enum CopyStreamHandlerEnd {
/// Handler initiates the end of streaming.
#[error("{0}")]
ServerInitiated(String),
#[error("received CopyDone")]
CopyDone,
#[error("received CopyFail")]
CopyFail,
#[error("received Terminate")]
Terminate,
#[error("EOF on COPY stream")]
EOF,
/// The connection was lost
#[error(transparent)]
Disconnected(#[from] ConnectionError),
/// Some other error
#[error(transparent)]
Other(#[from] anyhow::Error),
}

View File

@@ -0,0 +1,21 @@
-----BEGIN CERTIFICATE-----
MIIDbjCCAlagAwIBAgIUGHJukXa1bQathgBHC40+A18BsnYwDQYJKoZIhvcNAQEL
BQAwYzELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExFjAUBgNVBAcM
DVNhbiBGcmFuY2lzY28xEzARBgNVBAoMCk15IENvbXBhbnkxEjAQBgNVBAMMCWxv
Y2FsaG9zdDAgFw0yMTA4MTMxODQyMjBaGA8yMTIxMDcyMDE4NDIyMFowYzELMAkG
A1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExFjAUBgNVBAcMDVNhbiBGcmFu
Y2lzY28xEzARBgNVBAoMCk15IENvbXBhbnkxEjAQBgNVBAMMCWxvY2FsaG9zdDCC
ASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOI9S+nh8ABMp5jpb7WWfAYr
tGJ4C7gi9IPTVIRxSSrt5KglEysrOiKlhan1Ut2e8CCudztdXtCvT8/goJWlmxpF
IQkErlCsOdGHeEJ0EZxoU1fMkBAQVf6Rb1JE9ladG2+D1e7yvxmMqfPVuU8lj+kN
nESP+I3ESNCtuqgtfcErxu3TuhSzV2slSi5lrYQCwERgCevl6LUNd2mEaYdS4mmJ
4RZqc2C4y7JO5wSDjga8GIBHJVo70HRVsvX7eE8r6tMP2HyGyonBitBKAc2QEQIv
cLCuMOTtTBlYcMvTmJEOHFKwIJXm0XmQfAWeKFfyK7493fB4Gu+8Dc1xC+IHaTEC
AwEAAaMYMBYwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBCwUAA4IB
AQBjY+g3eF8m8lEWz+QgKp88MhTdtJTsEsSz0GAi58SnEkuyxVOHjKEyjGKJWTtT
ICgmEzC85uaS7VBdftoYNmsbvNewGiisDGQRWCjOGM7lTaA4FQPADguexMvXh/nO
9PQoTxtp7qwvGWO2mED6LWU6bjT3cL+XgrOwT9sticRTl6/BXV8wAmyxT0DkQ3nJ
zbRuTP/G2kE0bRK++67kK0ovopRkX6Dl6di1EFlkAnPBC2d8tdcNTXYhkxZk4O0q
GUolwiuWz/dtD3tZ2bx3vqzT7uIFHS4XP6Q3SRNWFTGhuvAc7DPvCZBqxy6odeyQ
VxBgJtq+pNjYYkeaSQVQ+UMU
-----END CERTIFICATE-----

View File

@@ -0,0 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEA4j1L6eHwAEynmOlvtZZ8Biu0YngLuCL0g9NUhHFJKu3kqCUT
Kys6IqWFqfVS3Z7wIK53O11e0K9Pz+CglaWbGkUhCQSuUKw50Yd4QnQRnGhTV8yQ
EBBV/pFvUkT2Vp0bb4PV7vK/GYyp89W5TyWP6Q2cRI/4jcRI0K26qC19wSvG7dO6
FLNXayVKLmWthALARGAJ6+XotQ13aYRph1LiaYnhFmpzYLjLsk7nBIOOBrwYgEcl
WjvQdFWy9ft4Tyvq0w/YfIbKicGK0EoBzZARAi9wsK4w5O1MGVhwy9OYkQ4cUrAg
lebReZB8BZ4oV/Irvj3d8Hga77wNzXEL4gdpMQIDAQABAoIBAQClKycO+zpinZQG
GPbLVa/6OVIaSZYUusBUtaaQgrxuMPusnlSeQZLR1JH/APGchvq8gWLe3k3ogPT9
yPq0BhF0Xl+928L/dp1HkWWE7oQk8i1Wfiv27lY54iepoltN5KkxAsjfCC3oEz/I
mpINbFjiRmN90rYdmd2nLA6H1Z5ntZQm5AcTo3OJZlTVN9eH9TV8f0AQRQgUJsL9
75agSmj7euqZOqvvwfpsYzaZEhzMSG2QIcS3WglInbHy8c6ikZSm36J36wgsatMz
CBZ6pMNtonRSKvAECQhBGEA73evtnGbLH0EY9KouN4KSHEHob89dGVeeXozksf9x
QUE1/yOhAoGBAP818f7vIH6Z3QwWgTMwQsPBW+wNOIbTZrbZaihnz2K9XMu39TV6
DWQHMsOlvg2QURZGwqB3jFn4wqZHmt7XYwk553E60kIw4hDvgpkkqmXVwK3kZASQ
RRUax3hZ1gCWxpXlRZ1SvHNXjN9KEFwqQbR33XcxzC3TpSp0KYghT9jFAoGBAOLw
agejqSF+f/5W1QhEKlM+tSlluo2sn5kKVkM4nNezFukb3pu5oScFjoQQGsoaz5aU
kLlxW5h/aSxquhgcuo6I4Ux5dcgNm4QeonCCp+Qycn7tzyoJFL4odT9vYPQa5O9E
hD9aSqhBBD1IIOS2T3vcW6VxibKZx1CRMDdRz119AoGALflr1L8DHYteNLVBJRWG
kXkdtBJVooQmtr3Hz+uTgngWZWSIOc/45ZIeZPxQlmTvFpI8sWeX0wVrG0U+8vHe
F2Vk+hLcmavwrZhX8HqYb6vn/+tq0R+kMj8Wu+mDEawXrh0VQ1gKNsUIzZisBc5e
88G8FaLU41SDJniymqFVnvkCgYEA1ou/UfWRwg6b5tIkmKoI8aZJExgPpDzcrYyu
POLatLmlIUCt1b9K8V85evTWvtdWBd/yar8WfzeFMO69fGo8nOAfT3NMvJLQwblM
jN2Y6A4hXIpq3iyzpYsOPaiImn6KjQHTnSk5h5Pf9CeqoU8SGeEb629JZMYpPqvk
T4hSaOkCgYBPaf51oSAstqdj0vxrsFS3EN3D8Fk0xQWt9Ss3ZGFAlTaEq5xoIk4k
YfKVDv1S6/vlzbheIIzQ2lzVvG4AW+drQLsmEx5iMKvbNtFAur9kwUFU202Q2dki
ZQJ/JvjnPYFKxy+SVlLJ1h9RD9E3dgL/Ai7OUfbmX771vN0IQF7Z6Q==
-----END RSA PRIVATE KEY-----

View File

@@ -0,0 +1,139 @@
/// Test postgres_backend_async with tokio_postgres
use once_cell::sync::Lazy;
use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError};
use pq_proto::{BeMessage, RowDescriptor};
use std::io::Cursor;
use std::{future, sync::Arc};
use tokio::net::{TcpListener, TcpStream};
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::{Config, NoTls, SimpleQueryMessage};
use tokio_postgres_rustls::MakeRustlsConnect;
// generate client, server test streams
async fn make_tcp_pair() -> (TcpStream, TcpStream) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let client_stream = TcpStream::connect(addr).await.unwrap();
let (server_stream, _) = listener.accept().await.unwrap();
(client_stream, server_stream)
}
struct TestHandler {}
#[async_trait::async_trait]
impl Handler for TestHandler {
// return single col 'hey' for any query
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
_query_string: &str,
) -> Result<(), QueryError> {
pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor::text_col(
b"hey",
)]))?
.write_message_noflush(&BeMessage::DataRow(&[Some("hey".as_bytes())]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
Ok(())
}
}
// test that basic select works
#[tokio::test]
async fn simple_select() {
let (client_sock, server_sock) = make_tcp_pair().await;
// create and run pgbackend
let pgbackend =
PostgresBackend::new(server_sock, AuthType::Trust, None).expect("pgbackend creation");
tokio::spawn(async move {
let mut handler = TestHandler {};
pgbackend.run(&mut handler, future::pending::<()>).await
});
let conf = Config::new();
let (client, connection) = conf.connect_raw(client_sock, NoTls).await.expect("connect");
// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
if let SimpleQueryMessage::Row(row) = first_val {
let first_col = row.get(0).expect("first column");
assert_eq!(first_col, "hey");
} else {
panic!("expected SimpleQueryMessage::Row");
}
}
static KEY: Lazy<rustls::PrivateKey> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("key.pem"));
rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone())
});
static CERT: Lazy<rustls::Certificate> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("cert.pem"));
rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone())
});
// test that basic select with ssl works
#[tokio::test]
async fn simple_select_ssl() {
let (client_sock, server_sock) = make_tcp_pair().await;
let server_cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![CERT.clone()], KEY.clone())
.unwrap();
let tls_config = Some(Arc::new(server_cfg));
let pgbackend =
PostgresBackend::new(server_sock, AuthType::Trust, tls_config).expect("pgbackend creation");
tokio::spawn(async move {
let mut handler = TestHandler {};
pgbackend.run(&mut handler, future::pending::<()>).await
});
let client_cfg = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates({
let mut store = rustls::RootCertStore::empty();
store.add(&CERT).unwrap();
store
})
.with_no_client_auth();
let mut make_tls_connect = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg);
let tls_connect = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::make_tls_connect(
&mut make_tls_connect,
"localhost",
)
.expect("make_tls_connect");
let mut conf = Config::new();
conf.ssl_mode(SslMode::Require);
let (client, connection) = conf
.connect_raw(client_sock, tls_connect)
.await
.expect("connect");
// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
if let SimpleQueryMessage::Row(row) = first_val {
let first_col = row.get(0).expect("first column");
assert_eq!(first_col, "hey");
} else {
panic!("expected SimpleQueryMessage::Row");
}
}