mirror of
https://github.com/neondatabase/neon.git
synced 2025-12-23 06:09:59 +00:00
Remove sync postgres_backend, tidy up its split usage.
- Add support for splitting async postgres_backend into read and write halfes. Safekeeper needs this for bidirectional streams. To this end, encapsulate reading-writing postgres messages to framed.rs with split support without any additional changes (relying on BufRead for reading and BytesMut out buffer for writing). - Use async postgres_backend throughout safekeeper (and in proxy auth link part). - In both safekeeper COPY streams, do read-write from the same thread/task with select! for easier error handling. - Tidy up finishing CopyBoth streams in safekeeper sending and receiving WAL -- join split parts back catching errors from them before returning. Initially I hoped to do that read-write without split at all, through polling IO: https://github.com/neondatabase/neon/pull/3522 However that turned out to be more complicated than I initially expected due to 1) borrow checking and 2) anon Future types. 1) required Rc<Refcell<...>> which is Send construct just to satisfy the checker; 2) can be workaround with transmute. But this is so messy that I decided to leave split.
This commit is contained in:
@@ -17,7 +17,6 @@ tokio-rustls.workspace = true
|
||||
tracing.workspace = true
|
||||
|
||||
pq_proto.workspace = true
|
||||
utils.workspace = true
|
||||
workspace_hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -2,29 +2,26 @@
|
||||
//! To use, create PostgresBackend and run() it, passing the Handler
|
||||
//! implementation determining how to process the queries. Currently its API
|
||||
//! is rather narrow, but we can extend it once required.
|
||||
|
||||
use anyhow::Context;
|
||||
use bytes::{Buf, Bytes, BytesMut};
|
||||
use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR};
|
||||
use std::io;
|
||||
use bytes::Bytes;
|
||||
use futures::pin_mut;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::io::ErrorKind;
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::Poll;
|
||||
use std::{future::Future, task::ready};
|
||||
use tracing::{debug, error, info, trace};
|
||||
use utils::postgres_backend::AuthType;
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
|
||||
use std::task::{ready, Poll};
|
||||
use std::{fmt, io};
|
||||
use std::{future::Future, str::FromStr};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tracing::{debug, error, info, trace};
|
||||
|
||||
pub fn is_expected_io_error(e: &io::Error) -> bool {
|
||||
use io::ErrorKind::*;
|
||||
matches!(
|
||||
e.kind(),
|
||||
ConnectionRefused | ConnectionAborted | ConnectionReset
|
||||
)
|
||||
}
|
||||
use pq_proto::framed::{Framed, FramedReader, FramedWriter};
|
||||
use pq_proto::{
|
||||
BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR,
|
||||
SQLSTATE_SUCCESSFUL_COMPLETION,
|
||||
};
|
||||
|
||||
/// An error, occurred during query processing:
|
||||
/// either during the connection ([`ConnectionError`]) or before/after it.
|
||||
@@ -53,12 +50,20 @@ impl QueryError {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_expected_io_error(e: &io::Error) -> bool {
|
||||
use io::ErrorKind::*;
|
||||
matches!(
|
||||
e.kind(),
|
||||
ConnectionRefused | ConnectionAborted | ConnectionReset
|
||||
)
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait Handler {
|
||||
/// Handle single query.
|
||||
/// postgres_backend will issue ReadyForQuery after calling this (this
|
||||
/// might be not what we want after CopyData streaming, but currently we don't
|
||||
/// care).
|
||||
/// care). It will also flush out the output buffer.
|
||||
async fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
@@ -92,9 +97,13 @@ pub trait Handler {
|
||||
/// XXX: The order of the constructors matters.
|
||||
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
|
||||
pub enum ProtoState {
|
||||
/// Nothing happened yet.
|
||||
Initialization,
|
||||
/// Encryption handshake is done; waiting for encrypted Startup message.
|
||||
Encrypted,
|
||||
/// Waiting for password (auth token).
|
||||
Authentication,
|
||||
/// Performed handshake and auth, ReadyForQuery is issued.
|
||||
Established,
|
||||
Closed,
|
||||
}
|
||||
@@ -105,15 +114,13 @@ pub enum ProcessMsgResult {
|
||||
Break,
|
||||
}
|
||||
|
||||
/// Always-writeable sock_split stream.
|
||||
/// May not be readable. See [`PostgresBackend::take_stream_in`]
|
||||
pub enum Stream {
|
||||
Unencrypted(BufReader<tokio::net::TcpStream>),
|
||||
Tls(Box<tokio_rustls::server::TlsStream<BufReader<tokio::net::TcpStream>>>),
|
||||
Broken,
|
||||
/// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
|
||||
pub enum MaybeTlsStream {
|
||||
Unencrypted(tokio::net::TcpStream),
|
||||
Tls(Box<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>),
|
||||
}
|
||||
|
||||
impl AsyncWrite for Stream {
|
||||
impl AsyncWrite for MaybeTlsStream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
@@ -122,14 +129,12 @@ impl AsyncWrite for Stream {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
Self::Broken => unreachable!(),
|
||||
}
|
||||
}
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
|
||||
Self::Broken => unreachable!(),
|
||||
}
|
||||
}
|
||||
fn poll_shutdown(
|
||||
@@ -139,11 +144,10 @@ impl AsyncWrite for Stream {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
Self::Broken => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl AsyncRead for Stream {
|
||||
impl AsyncRead for MaybeTlsStream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
@@ -152,18 +156,96 @@ impl AsyncRead for Stream {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
Self::Broken => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum AuthType {
|
||||
Trust,
|
||||
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
|
||||
NeonJWT,
|
||||
}
|
||||
|
||||
impl FromStr for AuthType {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"Trust" => Ok(Self::Trust),
|
||||
"NeonJWT" => Ok(Self::NeonJWT),
|
||||
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for AuthType {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(match self {
|
||||
AuthType::Trust => "Trust",
|
||||
AuthType::NeonJWT => "NeonJWT",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Either full duplex Framed or write only half; the latter is left in
|
||||
/// PostgresBackend after call to `split`. In principle we could always store a
|
||||
/// pair of splitted handles, but that would force to to pay splitting price
|
||||
/// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
|
||||
enum MaybeWriteOnly {
|
||||
Full(Framed<MaybeTlsStream>),
|
||||
WriteOnly(FramedWriter<MaybeTlsStream>),
|
||||
Broken, // temporary value palmed off during the split
|
||||
}
|
||||
|
||||
impl MaybeWriteOnly {
|
||||
async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
|
||||
match self {
|
||||
MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
|
||||
MaybeWriteOnly::WriteOnly(_) => {
|
||||
Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
|
||||
}
|
||||
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
match self {
|
||||
MaybeWriteOnly::Full(framed) => framed.read_message().await,
|
||||
MaybeWriteOnly::WriteOnly(_) => {
|
||||
Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
|
||||
}
|
||||
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
|
||||
}
|
||||
}
|
||||
|
||||
fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ConnectionError> {
|
||||
match self {
|
||||
MaybeWriteOnly::Full(framed) => framed.write_message(msg),
|
||||
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
|
||||
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn flush(&mut self) -> io::Result<()> {
|
||||
match self {
|
||||
MaybeWriteOnly::Full(framed) => framed.flush().await,
|
||||
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await,
|
||||
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn shutdown(&mut self) -> io::Result<()> {
|
||||
match self {
|
||||
MaybeWriteOnly::Full(framed) => framed.shutdown().await,
|
||||
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await,
|
||||
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresBackend {
|
||||
stream: Stream,
|
||||
|
||||
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
|
||||
// The data between 0 and "current position" as tracked by the bytes::Buf
|
||||
// implementation of BytesMut, have already been written.
|
||||
buf_out: BytesMut,
|
||||
framed: MaybeWriteOnly,
|
||||
|
||||
pub state: ProtoState,
|
||||
|
||||
@@ -183,7 +265,7 @@ pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
|
||||
query_string
|
||||
}
|
||||
|
||||
// Cast a byte slice to a string slice, dropping null terminator if there's one.
|
||||
/// Cast a byte slice to a string slice, dropping null terminator if there's one.
|
||||
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
|
||||
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
|
||||
std::str::from_utf8(without_null).map_err(|e| e.into())
|
||||
@@ -196,10 +278,10 @@ impl PostgresBackend {
|
||||
tls_config: Option<Arc<rustls::ServerConfig>>,
|
||||
) -> io::Result<Self> {
|
||||
let peer_addr = socket.peer_addr()?;
|
||||
let stream = MaybeTlsStream::Unencrypted(socket);
|
||||
|
||||
Ok(Self {
|
||||
stream: Stream::Unencrypted(BufReader::new(socket)),
|
||||
buf_out: BytesMut::with_capacity(10 * 1024),
|
||||
framed: MaybeWriteOnly::Full(Framed::new(stream)),
|
||||
state: ProtoState::Initialization,
|
||||
auth_type,
|
||||
tls_config,
|
||||
@@ -211,30 +293,52 @@ impl PostgresBackend {
|
||||
&self.peer_addr
|
||||
}
|
||||
|
||||
/// Read full message or return None if connection is closed.
|
||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, QueryError> {
|
||||
use ProtoState::*;
|
||||
match self.state {
|
||||
Initialization | Encrypted => FeStartupPacket::read_fut(&mut self.stream).await,
|
||||
Authentication | Established => FeMessage::read_fut(&mut self.stream).await,
|
||||
Closed => Ok(None),
|
||||
/// Read full message or return None if connection is cleanly closed with no
|
||||
/// unprocessed data.
|
||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
if let ProtoState::Closed = self.state {
|
||||
Ok(None)
|
||||
} else {
|
||||
let m = self.framed.read_message().await?;
|
||||
trace!("read msg {:?}", m);
|
||||
Ok(m)
|
||||
}
|
||||
.map_err(QueryError::from)
|
||||
}
|
||||
|
||||
/// Write message into internal output buffer, doesn't flush it. Technically
|
||||
/// error type can be only ProtocolError here (if, unlikely, serialization
|
||||
/// fails), but callers typically wrap it anyway.
|
||||
pub fn write_message_noflush(
|
||||
&mut self,
|
||||
message: &BeMessage<'_>,
|
||||
) -> Result<&mut Self, ConnectionError> {
|
||||
self.framed.write_message_noflush(message)?;
|
||||
trace!("wrote msg {:?}", message);
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Flush output buffer into the socket.
|
||||
pub async fn flush(&mut self) -> io::Result<()> {
|
||||
while self.buf_out.has_remaining() {
|
||||
let bytes_written = self.stream.write(self.buf_out.chunk()).await?;
|
||||
self.buf_out.advance(bytes_written);
|
||||
}
|
||||
self.buf_out.clear();
|
||||
Ok(())
|
||||
self.framed.flush().await
|
||||
}
|
||||
|
||||
/// Write message into internal output buffer.
|
||||
pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
|
||||
BeMessage::write(&mut self.buf_out, message)?;
|
||||
/// Polling version of `flush()`, saves the caller need to pin.
|
||||
pub fn poll_flush(
|
||||
&mut self,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let flush_fut = self.flush();
|
||||
pin_mut!(flush_fut);
|
||||
flush_fut.poll(cx)
|
||||
}
|
||||
|
||||
/// Write message into internal output buffer and flush it to the stream.
|
||||
pub async fn write_message(
|
||||
&mut self,
|
||||
message: &BeMessage<'_>,
|
||||
) -> Result<&mut Self, ConnectionError> {
|
||||
self.write_message_noflush(message)?;
|
||||
self.flush().await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
@@ -246,26 +350,7 @@ impl PostgresBackend {
|
||||
CopyDataWriter { pgb: self }
|
||||
}
|
||||
|
||||
/// A polling function that tries to write all the data from 'buf_out' to the
|
||||
/// underlying stream.
|
||||
fn poll_write_buf(
|
||||
&mut self,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
while self.buf_out.has_remaining() {
|
||||
match ready!(Pin::new(&mut self.stream).poll_write(cx, self.buf_out.chunk())) {
|
||||
Ok(bytes_written) => self.buf_out.advance(bytes_written),
|
||||
Err(err) => return Poll::Ready(Err(err)),
|
||||
}
|
||||
}
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_flush(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||
Pin::new(&mut self.stream).poll_flush(cx)
|
||||
}
|
||||
|
||||
// Wrapper for run_message_loop() that shuts down socket when we are done
|
||||
/// Wrapper for run_message_loop() that shuts down socket when we are done
|
||||
pub async fn run<F, S>(
|
||||
mut self,
|
||||
handler: &mut impl Handler,
|
||||
@@ -276,7 +361,9 @@ impl PostgresBackend {
|
||||
S: Future,
|
||||
{
|
||||
let ret = self.run_message_loop(handler, shutdown_watcher).await;
|
||||
let _ = self.stream.shutdown();
|
||||
// socket might be already closed, e.g. if previously received error,
|
||||
// so ignore result.
|
||||
self.framed.shutdown().await.ok();
|
||||
ret
|
||||
}
|
||||
|
||||
@@ -300,30 +387,12 @@ impl PostgresBackend {
|
||||
return Ok(())
|
||||
},
|
||||
|
||||
result = async {
|
||||
while self.state < ProtoState::Established {
|
||||
if let Some(msg) = self.read_message().await? {
|
||||
trace!("got message {msg:?} during handshake");
|
||||
|
||||
match self.process_handshake_message(handler, msg).await? {
|
||||
ProcessMsgResult::Continue => {
|
||||
self.flush().await?;
|
||||
continue;
|
||||
}
|
||||
ProcessMsgResult::Break => {
|
||||
trace!("postgres backend to {:?} exited during handshake", self.peer_addr);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
trace!("postgres backend to {:?} exited during handshake", self.peer_addr);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
Ok::<(), QueryError>(())
|
||||
} => {
|
||||
result = self.handshake(handler) => {
|
||||
// Handshake complete.
|
||||
result?;
|
||||
if self.state == ProtoState::Closed {
|
||||
return Ok(()); // EOF during handshake
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
@@ -355,114 +424,207 @@ impl PostgresBackend {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_tls(&mut self) -> anyhow::Result<()> {
|
||||
if let Stream::Unencrypted(plain_stream) =
|
||||
std::mem::replace(&mut self.stream, Stream::Broken)
|
||||
{
|
||||
let acceptor = TlsAcceptor::from(self.tls_config.clone().unwrap());
|
||||
let tls_stream = acceptor.accept(plain_stream).await?;
|
||||
|
||||
self.stream = Stream::Tls(Box::new(tls_stream));
|
||||
return Ok(());
|
||||
};
|
||||
anyhow::bail!("TLS already started");
|
||||
}
|
||||
|
||||
async fn process_handshake_message(
|
||||
&mut self,
|
||||
handler: &mut impl Handler,
|
||||
msg: FeMessage,
|
||||
) -> Result<ProcessMsgResult, QueryError> {
|
||||
assert!(self.state < ProtoState::Established);
|
||||
let have_tls = self.tls_config.is_some();
|
||||
match msg {
|
||||
FeMessage::StartupPacket(m) => {
|
||||
trace!("got startup message {m:?}");
|
||||
|
||||
match m {
|
||||
FeStartupPacket::SslRequest => {
|
||||
debug!("SSL requested");
|
||||
|
||||
self.write_message_noflush(&BeMessage::EncryptionResponse(have_tls))?;
|
||||
if have_tls {
|
||||
self.start_tls().await?;
|
||||
self.state = ProtoState::Encrypted;
|
||||
}
|
||||
}
|
||||
FeStartupPacket::GssEncRequest => {
|
||||
debug!("GSS requested");
|
||||
self.write_message_noflush(&BeMessage::EncryptionResponse(false))?;
|
||||
}
|
||||
FeStartupPacket::StartupMessage { .. } => {
|
||||
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
|
||||
self.write_message_noflush(&BeMessage::ErrorResponse(
|
||||
"must connect with TLS",
|
||||
None,
|
||||
))?;
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"client did not connect with TLS"
|
||||
)));
|
||||
}
|
||||
|
||||
// NB: startup() may change self.auth_type -- we are using that in proxy code
|
||||
// to bypass auth for new users.
|
||||
handler.startup(self, &m)?;
|
||||
|
||||
match self.auth_type {
|
||||
AuthType::Trust => {
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?
|
||||
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
|
||||
// The async python driver requires a valid server_version
|
||||
.write_message_noflush(&BeMessage::server_version("14.1"))?
|
||||
.write_message_noflush(&BeMessage::ReadyForQuery)?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
AuthType::NeonJWT => {
|
||||
self.write_message_noflush(
|
||||
&BeMessage::AuthenticationCleartextPassword,
|
||||
)?;
|
||||
self.state = ProtoState::Authentication;
|
||||
}
|
||||
}
|
||||
}
|
||||
FeStartupPacket::CancelRequest { .. } => {
|
||||
self.state = ProtoState::Closed;
|
||||
return Ok(ProcessMsgResult::Break);
|
||||
}
|
||||
}
|
||||
/// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
|
||||
async fn tls_upgrade(
|
||||
src: MaybeTlsStream,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
) -> anyhow::Result<MaybeTlsStream> {
|
||||
match src {
|
||||
MaybeTlsStream::Unencrypted(s) => {
|
||||
let acceptor = TlsAcceptor::from(tls_config);
|
||||
let tls_stream = acceptor.accept(s).await?;
|
||||
Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
|
||||
}
|
||||
|
||||
FeMessage::PasswordMessage(m) => {
|
||||
trace!("got password message '{:?}'", m);
|
||||
|
||||
assert!(self.state == ProtoState::Authentication);
|
||||
|
||||
match self.auth_type {
|
||||
AuthType::Trust => unreachable!(),
|
||||
AuthType::NeonJWT => {
|
||||
let (_, jwt_response) = m.split_last().context("protocol violation")?;
|
||||
|
||||
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
|
||||
self.write_message_noflush(&BeMessage::ErrorResponse(
|
||||
&e.to_string(),
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?
|
||||
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
|
||||
.write_message_noflush(&BeMessage::ReadyForQuery)?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
|
||||
_ => {
|
||||
self.state = ProtoState::Closed;
|
||||
return Ok(ProcessMsgResult::Break);
|
||||
MaybeTlsStream::Tls(_) => {
|
||||
anyhow::bail!("TLS already started");
|
||||
}
|
||||
}
|
||||
Ok(ProcessMsgResult::Continue)
|
||||
}
|
||||
|
||||
async fn start_tls(&mut self) -> anyhow::Result<()> {
|
||||
// temporary replace stream with fake to cook TLS one, Indiana Jones style
|
||||
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
|
||||
MaybeWriteOnly::Full(framed) => {
|
||||
let tls_config = self
|
||||
.tls_config
|
||||
.as_ref()
|
||||
.context("start_tls called without conf")?
|
||||
.clone();
|
||||
let tls_framed = framed
|
||||
.map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config))
|
||||
.await?;
|
||||
// push back ready TLS stream
|
||||
self.framed = MaybeWriteOnly::Full(tls_framed);
|
||||
Ok(())
|
||||
}
|
||||
MaybeWriteOnly::WriteOnly(_) => {
|
||||
anyhow::bail!("TLS upgrade attempt in split state")
|
||||
}
|
||||
MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Split off owned read part from which messages can be read in different
|
||||
/// task/thread.
|
||||
pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader> {
|
||||
// temporary replace stream with fake to cook split one, Indiana Jones style
|
||||
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
|
||||
MaybeWriteOnly::Full(framed) => {
|
||||
let (reader, writer) = framed.split();
|
||||
self.framed = MaybeWriteOnly::WriteOnly(writer);
|
||||
Ok(PostgresBackendReader(reader))
|
||||
}
|
||||
MaybeWriteOnly::WriteOnly(_) => {
|
||||
anyhow::bail!("PostgresBackend is already split")
|
||||
}
|
||||
MaybeWriteOnly::Broken => panic!("split on framed in invalid state"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Join read part back.
|
||||
pub fn unsplit(&mut self, reader: PostgresBackendReader) -> anyhow::Result<()> {
|
||||
// temporary replace stream with fake to cook joined one, Indiana Jones style
|
||||
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
|
||||
MaybeWriteOnly::Full(_) => {
|
||||
anyhow::bail!("PostgresBackend is not split")
|
||||
}
|
||||
MaybeWriteOnly::WriteOnly(writer) => {
|
||||
let joined = Framed::unsplit(reader.0, writer);
|
||||
self.framed = MaybeWriteOnly::Full(joined);
|
||||
Ok(())
|
||||
}
|
||||
MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform handshake with the client, transitioning to Established.
|
||||
/// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()).
|
||||
async fn handshake(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
|
||||
while self.state < ProtoState::Authentication {
|
||||
match self.framed.read_startup_message().await? {
|
||||
Some(msg) => {
|
||||
self.process_startup_message(handler, msg).await?;
|
||||
}
|
||||
None => {
|
||||
trace!(
|
||||
"postgres backend to {:?} received EOF during handshake",
|
||||
self.peer_addr
|
||||
);
|
||||
self.state = ProtoState::Closed;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Perform auth, if needed.
|
||||
if self.state == ProtoState::Authentication {
|
||||
match self.framed.read_message().await? {
|
||||
Some(FeMessage::PasswordMessage(m)) => {
|
||||
assert!(self.auth_type == AuthType::NeonJWT);
|
||||
|
||||
let (_, jwt_response) = m.split_last().context("protocol violation")?;
|
||||
|
||||
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
|
||||
self.write_message_noflush(&BeMessage::ErrorResponse(
|
||||
&e.to_string(),
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?
|
||||
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
|
||||
.write_message(&BeMessage::ReadyForQuery)
|
||||
.await?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
Some(m) => {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"Unexpected message {:?} while waiting for handshake",
|
||||
m
|
||||
)));
|
||||
}
|
||||
None => {
|
||||
trace!(
|
||||
"postgres backend to {:?} received EOF during auth",
|
||||
self.peer_addr
|
||||
);
|
||||
self.state = ProtoState::Closed;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Process startup packet:
|
||||
/// - transition to Established if auth type is trust
|
||||
/// - transition to Authentication if auth type is NeonJWT.
|
||||
/// - or perform TLS handshake -- then need to call this again to receive
|
||||
/// actual startup packet.
|
||||
async fn process_startup_message(
|
||||
&mut self,
|
||||
handler: &mut impl Handler,
|
||||
msg: FeStartupPacket,
|
||||
) -> Result<(), QueryError> {
|
||||
assert!(self.state < ProtoState::Authentication);
|
||||
let have_tls = self.tls_config.is_some();
|
||||
match msg {
|
||||
FeStartupPacket::SslRequest => {
|
||||
debug!("SSL requested");
|
||||
|
||||
self.write_message(&BeMessage::EncryptionResponse(have_tls))
|
||||
.await?;
|
||||
|
||||
if have_tls {
|
||||
self.start_tls().await?;
|
||||
self.state = ProtoState::Encrypted;
|
||||
}
|
||||
}
|
||||
FeStartupPacket::GssEncRequest => {
|
||||
debug!("GSS requested");
|
||||
self.write_message(&BeMessage::EncryptionResponse(false))
|
||||
.await?;
|
||||
}
|
||||
FeStartupPacket::StartupMessage { .. } => {
|
||||
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
|
||||
self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None))
|
||||
.await?;
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"client did not connect with TLS"
|
||||
)));
|
||||
}
|
||||
|
||||
// NB: startup() may change self.auth_type -- we are using that in proxy code
|
||||
// to bypass auth for new users.
|
||||
handler.startup(self, &msg)?;
|
||||
|
||||
match self.auth_type {
|
||||
AuthType::Trust => {
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?
|
||||
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
|
||||
.write_message_noflush(&BeMessage::INTEGER_DATETIMES)?
|
||||
// The async python driver requires a valid server_version
|
||||
.write_message_noflush(&BeMessage::server_version("14.1"))?
|
||||
.write_message(&BeMessage::ReadyForQuery)
|
||||
.await?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
AuthType::NeonJWT => {
|
||||
self.write_message(&BeMessage::AuthenticationCleartextPassword)
|
||||
.await?;
|
||||
self.state = ProtoState::Authentication;
|
||||
}
|
||||
}
|
||||
}
|
||||
FeStartupPacket::CancelRequest { .. } => {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"Unexpected CancelRequest message during handshake"
|
||||
)));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn process_message(
|
||||
@@ -476,10 +638,6 @@ impl PostgresBackend {
|
||||
assert!(self.state == ProtoState::Established);
|
||||
|
||||
match msg {
|
||||
FeMessage::StartupPacket(_) | FeMessage::PasswordMessage(_) => {
|
||||
return Err(QueryError::Other(anyhow::anyhow!("protocol violation")));
|
||||
}
|
||||
|
||||
FeMessage::Query(body) => {
|
||||
// remove null terminator
|
||||
let query_string = cstr_to_str(&body)?;
|
||||
@@ -540,16 +698,114 @@ impl PostgresBackend {
|
||||
|
||||
// We prefer explicit pattern matching to wildcards, because
|
||||
// this helps us spot the places where new variants are missing
|
||||
FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => {
|
||||
FeMessage::CopyData(_)
|
||||
| FeMessage::CopyDone
|
||||
| FeMessage::CopyFail
|
||||
| FeMessage::PasswordMessage(_)
|
||||
| FeMessage::StartupPacket(_) => {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"unexpected message type: {:?}",
|
||||
msg
|
||||
"unexpected message type: {msg:?}",
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ProcessMsgResult::Continue)
|
||||
}
|
||||
|
||||
/// Log as info/error result of handling COPY stream and send back
|
||||
/// ErrorResponse if that makes sense. Shutdown the stream if we got
|
||||
/// Terminate. TODO: transition into waiting for Sync msg if we initiate the
|
||||
/// close.
|
||||
pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) {
|
||||
use CopyStreamHandlerEnd::*;
|
||||
|
||||
let expected_end = match &end {
|
||||
ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true,
|
||||
CopyStreamHandlerEnd::Disconnected(ConnectionError::Socket(io_error))
|
||||
if is_expected_io_error(io_error) =>
|
||||
{
|
||||
true
|
||||
}
|
||||
_ => false,
|
||||
};
|
||||
if expected_end {
|
||||
info!("terminated: {:#}", end);
|
||||
} else {
|
||||
error!("terminated: {:?}", end);
|
||||
}
|
||||
|
||||
// Note: no current usages ever send this
|
||||
if let CopyDone = &end {
|
||||
if let Err(e) = self.write_message(&BeMessage::CopyDone).await {
|
||||
error!("failed to send CopyDone: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
if let Terminate = &end {
|
||||
self.state = ProtoState::Closed;
|
||||
}
|
||||
|
||||
let err_to_send_and_errcode = match &end {
|
||||
ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
|
||||
Other(_) => Some((end.to_string(), SQLSTATE_INTERNAL_ERROR)),
|
||||
// Note: CopyFail in duplex copy is somewhat unexpected (at least to
|
||||
// PG walsender; evidently and per my docs reading client should
|
||||
// finish it with CopyDone). It is not a problem to recover from it
|
||||
// finishing the stream in both directions like we do, but note that
|
||||
// sync rust-postgres client (which we don't use anymore) hangs if
|
||||
// socket is not closed here.
|
||||
// https://github.com/sfackler/rust-postgres/issues/755
|
||||
// https://github.com/neondatabase/neon/issues/935
|
||||
//
|
||||
// Currently, the version of tokio_postgres replication patch we use
|
||||
// sends this when it closes the stream (e.g. pageserver decided to
|
||||
// switch conn to another safekeeper and client gets dropped).
|
||||
// Moreover, seems like 'connection' task errors with 'unexpected
|
||||
// message from server' when it receives ErrorResponse (anything but
|
||||
// CopyData/CopyDone) back.
|
||||
CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
|
||||
_ => None,
|
||||
};
|
||||
if let Some((err, errcode)) = err_to_send_and_errcode {
|
||||
if let Err(ee) = self
|
||||
.write_message(&BeMessage::ErrorResponse(&err, Some(errcode)))
|
||||
.await
|
||||
{
|
||||
error!("failed to send ErrorResponse: {}", ee);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresBackendReader(FramedReader<MaybeTlsStream>);
|
||||
|
||||
impl PostgresBackendReader {
|
||||
/// Read full message or return None if connection is cleanly closed with no
|
||||
/// unprocessed data.
|
||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
let m = self.0.read_message().await?;
|
||||
trace!("read msg {:?}", m);
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
/// Get CopyData contents of the next message in COPY stream or error
|
||||
/// closing it. The error type is wider than actual errors which can happen
|
||||
/// here -- it includes 'Other' and 'ServerInitiated', but that's ok for
|
||||
/// current callers.
|
||||
pub async fn read_copy_message(&mut self) -> Result<Bytes, CopyStreamHandlerEnd> {
|
||||
match self.read_message().await? {
|
||||
Some(msg) => match msg {
|
||||
FeMessage::CopyData(m) => Ok(m),
|
||||
FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone),
|
||||
FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
|
||||
FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
|
||||
_ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
|
||||
format!("unexpected message in COPY stream {:?}", msg),
|
||||
))),
|
||||
},
|
||||
None => Err(CopyStreamHandlerEnd::EOF),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
@@ -572,16 +828,19 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> {
|
||||
// It's not strictly required to flush between each message, but makes it easier
|
||||
// to view in wireshark, and usually the messages that the callers write are
|
||||
// decently-sized anyway.
|
||||
match ready!(this.pgb.poll_write_buf(cx)) {
|
||||
Ok(()) => {}
|
||||
Err(err) => return Poll::Ready(Err(err)),
|
||||
if let Err(err) = ready!(this.pgb.poll_flush(cx)) {
|
||||
return Poll::Ready(Err(err));
|
||||
}
|
||||
|
||||
// CopyData
|
||||
// XXX: if the input is large, we should split it into multiple messages.
|
||||
// Not sure what the threshold should be, but the ultimate hard limit is that
|
||||
// the length cannot exceed u32.
|
||||
this.pgb.write_message_noflush(&BeMessage::CopyData(buf))?;
|
||||
this.pgb
|
||||
.write_message_noflush(&BeMessage::CopyData(buf))
|
||||
// write_message only writes to the buffer, so it can fail iff the
|
||||
// message is invaid, but CopyData can't be invalid.
|
||||
.map_err(|_| io::Error::new(ErrorKind::Other, "failed to serialize CopyData"))?;
|
||||
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
@@ -591,21 +850,14 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> {
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let this = self.get_mut();
|
||||
match ready!(this.pgb.poll_write_buf(cx)) {
|
||||
Ok(()) => {}
|
||||
Err(err) => return Poll::Ready(Err(err)),
|
||||
}
|
||||
this.pgb.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let this = self.get_mut();
|
||||
match ready!(this.pgb.poll_write_buf(cx)) {
|
||||
Ok(()) => {}
|
||||
Err(err) => return Poll::Ready(Err(err)),
|
||||
}
|
||||
this.pgb.poll_flush(cx)
|
||||
}
|
||||
}
|
||||
@@ -617,7 +869,7 @@ pub fn short_error(e: &QueryError) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn log_query_error(query: &str, e: &QueryError) {
|
||||
fn log_query_error(query: &str, e: &QueryError) {
|
||||
match e {
|
||||
QueryError::Disconnected(ConnectionError::Socket(io_error)) => {
|
||||
if is_expected_io_error(io_error) {
|
||||
@@ -634,3 +886,26 @@ pub fn log_query_error(query: &str, e: &QueryError) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Something finishing handling of COPY stream, see handle_copy_stream_end.
|
||||
/// This is not always a real error, but it allows to use ? and thiserror impls.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum CopyStreamHandlerEnd {
|
||||
/// Handler initiates the end of streaming.
|
||||
#[error("{0}")]
|
||||
ServerInitiated(String),
|
||||
#[error("received CopyDone")]
|
||||
CopyDone,
|
||||
#[error("received CopyFail")]
|
||||
CopyFail,
|
||||
#[error("received Terminate")]
|
||||
Terminate,
|
||||
#[error("EOF on COPY stream")]
|
||||
EOF,
|
||||
/// The connection was lost
|
||||
#[error(transparent)]
|
||||
Disconnected(#[from] ConnectionError),
|
||||
/// Some other error
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
21
libs/postgres_backend/tests/cert.pem
Normal file
21
libs/postgres_backend/tests/cert.pem
Normal file
@@ -0,0 +1,21 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDbjCCAlagAwIBAgIUGHJukXa1bQathgBHC40+A18BsnYwDQYJKoZIhvcNAQEL
|
||||
BQAwYzELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExFjAUBgNVBAcM
|
||||
DVNhbiBGcmFuY2lzY28xEzARBgNVBAoMCk15IENvbXBhbnkxEjAQBgNVBAMMCWxv
|
||||
Y2FsaG9zdDAgFw0yMTA4MTMxODQyMjBaGA8yMTIxMDcyMDE4NDIyMFowYzELMAkG
|
||||
A1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExFjAUBgNVBAcMDVNhbiBGcmFu
|
||||
Y2lzY28xEzARBgNVBAoMCk15IENvbXBhbnkxEjAQBgNVBAMMCWxvY2FsaG9zdDCC
|
||||
ASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOI9S+nh8ABMp5jpb7WWfAYr
|
||||
tGJ4C7gi9IPTVIRxSSrt5KglEysrOiKlhan1Ut2e8CCudztdXtCvT8/goJWlmxpF
|
||||
IQkErlCsOdGHeEJ0EZxoU1fMkBAQVf6Rb1JE9ladG2+D1e7yvxmMqfPVuU8lj+kN
|
||||
nESP+I3ESNCtuqgtfcErxu3TuhSzV2slSi5lrYQCwERgCevl6LUNd2mEaYdS4mmJ
|
||||
4RZqc2C4y7JO5wSDjga8GIBHJVo70HRVsvX7eE8r6tMP2HyGyonBitBKAc2QEQIv
|
||||
cLCuMOTtTBlYcMvTmJEOHFKwIJXm0XmQfAWeKFfyK7493fB4Gu+8Dc1xC+IHaTEC
|
||||
AwEAAaMYMBYwFAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBCwUAA4IB
|
||||
AQBjY+g3eF8m8lEWz+QgKp88MhTdtJTsEsSz0GAi58SnEkuyxVOHjKEyjGKJWTtT
|
||||
ICgmEzC85uaS7VBdftoYNmsbvNewGiisDGQRWCjOGM7lTaA4FQPADguexMvXh/nO
|
||||
9PQoTxtp7qwvGWO2mED6LWU6bjT3cL+XgrOwT9sticRTl6/BXV8wAmyxT0DkQ3nJ
|
||||
zbRuTP/G2kE0bRK++67kK0ovopRkX6Dl6di1EFlkAnPBC2d8tdcNTXYhkxZk4O0q
|
||||
GUolwiuWz/dtD3tZ2bx3vqzT7uIFHS4XP6Q3SRNWFTGhuvAc7DPvCZBqxy6odeyQ
|
||||
VxBgJtq+pNjYYkeaSQVQ+UMU
|
||||
-----END CERTIFICATE-----
|
||||
27
libs/postgres_backend/tests/key.pem
Normal file
27
libs/postgres_backend/tests/key.pem
Normal file
@@ -0,0 +1,27 @@
|
||||
-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEpAIBAAKCAQEA4j1L6eHwAEynmOlvtZZ8Biu0YngLuCL0g9NUhHFJKu3kqCUT
|
||||
Kys6IqWFqfVS3Z7wIK53O11e0K9Pz+CglaWbGkUhCQSuUKw50Yd4QnQRnGhTV8yQ
|
||||
EBBV/pFvUkT2Vp0bb4PV7vK/GYyp89W5TyWP6Q2cRI/4jcRI0K26qC19wSvG7dO6
|
||||
FLNXayVKLmWthALARGAJ6+XotQ13aYRph1LiaYnhFmpzYLjLsk7nBIOOBrwYgEcl
|
||||
WjvQdFWy9ft4Tyvq0w/YfIbKicGK0EoBzZARAi9wsK4w5O1MGVhwy9OYkQ4cUrAg
|
||||
lebReZB8BZ4oV/Irvj3d8Hga77wNzXEL4gdpMQIDAQABAoIBAQClKycO+zpinZQG
|
||||
GPbLVa/6OVIaSZYUusBUtaaQgrxuMPusnlSeQZLR1JH/APGchvq8gWLe3k3ogPT9
|
||||
yPq0BhF0Xl+928L/dp1HkWWE7oQk8i1Wfiv27lY54iepoltN5KkxAsjfCC3oEz/I
|
||||
mpINbFjiRmN90rYdmd2nLA6H1Z5ntZQm5AcTo3OJZlTVN9eH9TV8f0AQRQgUJsL9
|
||||
75agSmj7euqZOqvvwfpsYzaZEhzMSG2QIcS3WglInbHy8c6ikZSm36J36wgsatMz
|
||||
CBZ6pMNtonRSKvAECQhBGEA73evtnGbLH0EY9KouN4KSHEHob89dGVeeXozksf9x
|
||||
QUE1/yOhAoGBAP818f7vIH6Z3QwWgTMwQsPBW+wNOIbTZrbZaihnz2K9XMu39TV6
|
||||
DWQHMsOlvg2QURZGwqB3jFn4wqZHmt7XYwk553E60kIw4hDvgpkkqmXVwK3kZASQ
|
||||
RRUax3hZ1gCWxpXlRZ1SvHNXjN9KEFwqQbR33XcxzC3TpSp0KYghT9jFAoGBAOLw
|
||||
agejqSF+f/5W1QhEKlM+tSlluo2sn5kKVkM4nNezFukb3pu5oScFjoQQGsoaz5aU
|
||||
kLlxW5h/aSxquhgcuo6I4Ux5dcgNm4QeonCCp+Qycn7tzyoJFL4odT9vYPQa5O9E
|
||||
hD9aSqhBBD1IIOS2T3vcW6VxibKZx1CRMDdRz119AoGALflr1L8DHYteNLVBJRWG
|
||||
kXkdtBJVooQmtr3Hz+uTgngWZWSIOc/45ZIeZPxQlmTvFpI8sWeX0wVrG0U+8vHe
|
||||
F2Vk+hLcmavwrZhX8HqYb6vn/+tq0R+kMj8Wu+mDEawXrh0VQ1gKNsUIzZisBc5e
|
||||
88G8FaLU41SDJniymqFVnvkCgYEA1ou/UfWRwg6b5tIkmKoI8aZJExgPpDzcrYyu
|
||||
POLatLmlIUCt1b9K8V85evTWvtdWBd/yar8WfzeFMO69fGo8nOAfT3NMvJLQwblM
|
||||
jN2Y6A4hXIpq3iyzpYsOPaiImn6KjQHTnSk5h5Pf9CeqoU8SGeEb629JZMYpPqvk
|
||||
T4hSaOkCgYBPaf51oSAstqdj0vxrsFS3EN3D8Fk0xQWt9Ss3ZGFAlTaEq5xoIk4k
|
||||
YfKVDv1S6/vlzbheIIzQ2lzVvG4AW+drQLsmEx5iMKvbNtFAur9kwUFU202Q2dki
|
||||
ZQJ/JvjnPYFKxy+SVlLJ1h9RD9E3dgL/Ai7OUfbmX771vN0IQF7Z6Q==
|
||||
-----END RSA PRIVATE KEY-----
|
||||
139
libs/postgres_backend/tests/simple_select.rs
Normal file
139
libs/postgres_backend/tests/simple_select.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
/// Test postgres_backend_async with tokio_postgres
|
||||
use once_cell::sync::Lazy;
|
||||
use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError};
|
||||
use pq_proto::{BeMessage, RowDescriptor};
|
||||
use std::io::Cursor;
|
||||
use std::{future, sync::Arc};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::tls::MakeTlsConnect;
|
||||
use tokio_postgres::{Config, NoTls, SimpleQueryMessage};
|
||||
use tokio_postgres_rustls::MakeRustlsConnect;
|
||||
|
||||
// generate client, server test streams
|
||||
async fn make_tcp_pair() -> (TcpStream, TcpStream) {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let client_stream = TcpStream::connect(addr).await.unwrap();
|
||||
let (server_stream, _) = listener.accept().await.unwrap();
|
||||
(client_stream, server_stream)
|
||||
}
|
||||
|
||||
struct TestHandler {}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Handler for TestHandler {
|
||||
// return single col 'hey' for any query
|
||||
async fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
_query_string: &str,
|
||||
) -> Result<(), QueryError> {
|
||||
pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor::text_col(
|
||||
b"hey",
|
||||
)]))?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[Some("hey".as_bytes())]))?
|
||||
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// test that basic select works
|
||||
#[tokio::test]
|
||||
async fn simple_select() {
|
||||
let (client_sock, server_sock) = make_tcp_pair().await;
|
||||
|
||||
// create and run pgbackend
|
||||
let pgbackend =
|
||||
PostgresBackend::new(server_sock, AuthType::Trust, None).expect("pgbackend creation");
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut handler = TestHandler {};
|
||||
pgbackend.run(&mut handler, future::pending::<()>).await
|
||||
});
|
||||
|
||||
let conf = Config::new();
|
||||
let (client, connection) = conf.connect_raw(client_sock, NoTls).await.expect("connect");
|
||||
// The connection object performs the actual communication with the database,
|
||||
// so spawn it off to run on its own.
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = connection.await {
|
||||
eprintln!("connection error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
|
||||
if let SimpleQueryMessage::Row(row) = first_val {
|
||||
let first_col = row.get(0).expect("first column");
|
||||
assert_eq!(first_col, "hey");
|
||||
} else {
|
||||
panic!("expected SimpleQueryMessage::Row");
|
||||
}
|
||||
}
|
||||
|
||||
static KEY: Lazy<rustls::PrivateKey> = Lazy::new(|| {
|
||||
let mut cursor = Cursor::new(include_bytes!("key.pem"));
|
||||
rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone())
|
||||
});
|
||||
|
||||
static CERT: Lazy<rustls::Certificate> = Lazy::new(|| {
|
||||
let mut cursor = Cursor::new(include_bytes!("cert.pem"));
|
||||
rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone())
|
||||
});
|
||||
|
||||
// test that basic select with ssl works
|
||||
#[tokio::test]
|
||||
async fn simple_select_ssl() {
|
||||
let (client_sock, server_sock) = make_tcp_pair().await;
|
||||
|
||||
let server_cfg = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(vec![CERT.clone()], KEY.clone())
|
||||
.unwrap();
|
||||
let tls_config = Some(Arc::new(server_cfg));
|
||||
let pgbackend =
|
||||
PostgresBackend::new(server_sock, AuthType::Trust, tls_config).expect("pgbackend creation");
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut handler = TestHandler {};
|
||||
pgbackend.run(&mut handler, future::pending::<()>).await
|
||||
});
|
||||
|
||||
let client_cfg = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates({
|
||||
let mut store = rustls::RootCertStore::empty();
|
||||
store.add(&CERT).unwrap();
|
||||
store
|
||||
})
|
||||
.with_no_client_auth();
|
||||
let mut make_tls_connect = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg);
|
||||
let tls_connect = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::make_tls_connect(
|
||||
&mut make_tls_connect,
|
||||
"localhost",
|
||||
)
|
||||
.expect("make_tls_connect");
|
||||
|
||||
let mut conf = Config::new();
|
||||
conf.ssl_mode(SslMode::Require);
|
||||
let (client, connection) = conf
|
||||
.connect_raw(client_sock, tls_connect)
|
||||
.await
|
||||
.expect("connect");
|
||||
// The connection object performs the actual communication with the database,
|
||||
// so spawn it off to run on its own.
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = connection.await {
|
||||
eprintln!("connection error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
|
||||
if let SimpleQueryMessage::Row(row) = first_val {
|
||||
let first_col = row.get(0).expect("first column");
|
||||
assert_eq!(first_col, "hey");
|
||||
} else {
|
||||
panic!("expected SimpleQueryMessage::Row");
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user