mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-14 00:42:54 +00:00
Remove sync postgres_backend, tidy up its split usage.
- Add support for splitting async postgres_backend into read and write halfes. Safekeeper needs this for bidirectional streams. To this end, encapsulate reading-writing postgres messages to framed.rs with split support without any additional changes (relying on BufRead for reading and BytesMut out buffer for writing). - Use async postgres_backend throughout safekeeper (and in proxy auth link part). - In both safekeeper COPY streams, do read-write from the same thread/task with select! for easier error handling. - Tidy up finishing CopyBoth streams in safekeeper sending and receiving WAL -- join split parts back catching errors from them before returning. Initially I hoped to do that read-write without split at all, through polling IO: https://github.com/neondatabase/neon/pull/3522 However that turned out to be more complicated than I initially expected due to 1) borrow checking and 2) anon Future types. 1) required Rc<Refcell<...>> which is Send construct just to satisfy the checker; 2) can be workaround with transmute. But this is so messy that I decided to leave split.
This commit is contained in:
23
Cargo.lock
generated
23
Cargo.lock
generated
@@ -913,6 +913,7 @@ dependencies = [
|
||||
"once_cell",
|
||||
"pageserver_api",
|
||||
"postgres",
|
||||
"postgres_backend",
|
||||
"postgres_connection",
|
||||
"regex",
|
||||
"reqwest",
|
||||
@@ -2696,7 +2697,6 @@ dependencies = [
|
||||
"tokio-postgres-rustls",
|
||||
"tokio-rustls",
|
||||
"tracing",
|
||||
"utils",
|
||||
"workspace_hack",
|
||||
]
|
||||
|
||||
@@ -2922,6 +2922,7 @@ dependencies = [
|
||||
"opentelemetry",
|
||||
"parking_lot",
|
||||
"pin-project-lite",
|
||||
"postgres_backend",
|
||||
"pq_proto",
|
||||
"prometheus",
|
||||
"rand",
|
||||
@@ -3301,15 +3302,6 @@ dependencies = [
|
||||
"base64 0.21.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-split"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "78802c9612b4689d207acff746f38132ca1b12dadb55d471aa5f10fd580f47d3"
|
||||
dependencies = [
|
||||
"rustls",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustversion"
|
||||
version = "1.0.11"
|
||||
@@ -3346,6 +3338,7 @@ dependencies = [
|
||||
"parking_lot",
|
||||
"postgres",
|
||||
"postgres-protocol",
|
||||
"postgres_backend",
|
||||
"postgres_ffi",
|
||||
"pq_proto",
|
||||
"regex",
|
||||
@@ -4539,12 +4532,8 @@ dependencies = [
|
||||
"metrics",
|
||||
"nix",
|
||||
"once_cell",
|
||||
"pq_proto",
|
||||
"rand",
|
||||
"routerify",
|
||||
"rustls",
|
||||
"rustls-pemfile",
|
||||
"rustls-split",
|
||||
"sentry",
|
||||
"serde",
|
||||
"serde_json",
|
||||
@@ -4858,14 +4847,19 @@ name = "workspace_hack"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"clap 4.1.4",
|
||||
"crossbeam-utils",
|
||||
"digest",
|
||||
"either",
|
||||
"fail",
|
||||
"futures",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-executor",
|
||||
"futures-sink",
|
||||
"futures-util",
|
||||
"hashbrown 0.12.3",
|
||||
"indexmap",
|
||||
@@ -4890,6 +4884,7 @@ dependencies = [
|
||||
"socket2",
|
||||
"syn",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tokio-util",
|
||||
"tonic",
|
||||
"tower",
|
||||
|
||||
@@ -24,6 +24,7 @@ url.workspace = true
|
||||
# Note: Do not directly depend on pageserver or safekeeper; use pageserver_api or safekeeper_api
|
||||
# instead, so that recompile times are better.
|
||||
pageserver_api.workspace = true
|
||||
postgres_backend.workspace = true
|
||||
safekeeper_api.workspace = true
|
||||
postgres_connection.workspace = true
|
||||
storage_broker.workspace = true
|
||||
|
||||
@@ -17,6 +17,7 @@ use pageserver_api::{
|
||||
DEFAULT_HTTP_LISTEN_ADDR as DEFAULT_PAGESERVER_HTTP_ADDR,
|
||||
DEFAULT_PG_LISTEN_ADDR as DEFAULT_PAGESERVER_PG_ADDR,
|
||||
};
|
||||
use postgres_backend::AuthType;
|
||||
use safekeeper_api::{
|
||||
DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT,
|
||||
DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT,
|
||||
@@ -30,7 +31,6 @@ use utils::{
|
||||
auth::{Claims, Scope},
|
||||
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
|
||||
lsn::Lsn,
|
||||
postgres_backend::AuthType,
|
||||
project_git_version,
|
||||
};
|
||||
|
||||
|
||||
@@ -11,10 +11,10 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use postgres_backend::AuthType;
|
||||
use utils::{
|
||||
id::{TenantId, TimelineId},
|
||||
lsn::Lsn,
|
||||
postgres_backend::AuthType,
|
||||
};
|
||||
|
||||
use crate::local_env::{LocalEnv, DEFAULT_PG_VERSION};
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
use anyhow::{bail, ensure, Context};
|
||||
|
||||
use postgres_backend::AuthType;
|
||||
use reqwest::Url;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
@@ -19,7 +20,6 @@ use std::process::{Command, Stdio};
|
||||
use utils::{
|
||||
auth::{encode_from_key_file, Claims, Scope},
|
||||
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
|
||||
postgres_backend::AuthType,
|
||||
};
|
||||
|
||||
use crate::safekeeper::SafekeeperNode;
|
||||
|
||||
@@ -11,6 +11,7 @@ use anyhow::{bail, Context};
|
||||
use pageserver_api::models::{
|
||||
TenantConfigRequest, TenantCreateRequest, TenantInfo, TimelineCreateRequest, TimelineInfo,
|
||||
};
|
||||
use postgres_backend::AuthType;
|
||||
use postgres_connection::{parse_host_port, PgConnectionConfig};
|
||||
use reqwest::blocking::{Client, RequestBuilder, Response};
|
||||
use reqwest::{IntoUrl, Method};
|
||||
@@ -20,7 +21,6 @@ use utils::{
|
||||
http::error::HttpErrorBody,
|
||||
id::{TenantId, TimelineId},
|
||||
lsn::Lsn,
|
||||
postgres_backend::AuthType,
|
||||
};
|
||||
|
||||
use crate::{background_process, local_env::LocalEnv};
|
||||
|
||||
@@ -17,7 +17,6 @@ tokio-rustls.workspace = true
|
||||
tracing.workspace = true
|
||||
|
||||
pq_proto.workspace = true
|
||||
utils.workspace = true
|
||||
workspace_hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -2,29 +2,26 @@
|
||||
//! To use, create PostgresBackend and run() it, passing the Handler
|
||||
//! implementation determining how to process the queries. Currently its API
|
||||
//! is rather narrow, but we can extend it once required.
|
||||
|
||||
use anyhow::Context;
|
||||
use bytes::{Buf, Bytes, BytesMut};
|
||||
use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR};
|
||||
use std::io;
|
||||
use bytes::Bytes;
|
||||
use futures::pin_mut;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::io::ErrorKind;
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::Poll;
|
||||
use std::{future::Future, task::ready};
|
||||
use tracing::{debug, error, info, trace};
|
||||
use utils::postgres_backend::AuthType;
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
|
||||
use std::task::{ready, Poll};
|
||||
use std::{fmt, io};
|
||||
use std::{future::Future, str::FromStr};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tracing::{debug, error, info, trace};
|
||||
|
||||
pub fn is_expected_io_error(e: &io::Error) -> bool {
|
||||
use io::ErrorKind::*;
|
||||
matches!(
|
||||
e.kind(),
|
||||
ConnectionRefused | ConnectionAborted | ConnectionReset
|
||||
)
|
||||
}
|
||||
use pq_proto::framed::{Framed, FramedReader, FramedWriter};
|
||||
use pq_proto::{
|
||||
BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR,
|
||||
SQLSTATE_SUCCESSFUL_COMPLETION,
|
||||
};
|
||||
|
||||
/// An error, occurred during query processing:
|
||||
/// either during the connection ([`ConnectionError`]) or before/after it.
|
||||
@@ -53,12 +50,20 @@ impl QueryError {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_expected_io_error(e: &io::Error) -> bool {
|
||||
use io::ErrorKind::*;
|
||||
matches!(
|
||||
e.kind(),
|
||||
ConnectionRefused | ConnectionAborted | ConnectionReset
|
||||
)
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait Handler {
|
||||
/// Handle single query.
|
||||
/// postgres_backend will issue ReadyForQuery after calling this (this
|
||||
/// might be not what we want after CopyData streaming, but currently we don't
|
||||
/// care).
|
||||
/// care). It will also flush out the output buffer.
|
||||
async fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
@@ -92,9 +97,13 @@ pub trait Handler {
|
||||
/// XXX: The order of the constructors matters.
|
||||
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
|
||||
pub enum ProtoState {
|
||||
/// Nothing happened yet.
|
||||
Initialization,
|
||||
/// Encryption handshake is done; waiting for encrypted Startup message.
|
||||
Encrypted,
|
||||
/// Waiting for password (auth token).
|
||||
Authentication,
|
||||
/// Performed handshake and auth, ReadyForQuery is issued.
|
||||
Established,
|
||||
Closed,
|
||||
}
|
||||
@@ -105,15 +114,13 @@ pub enum ProcessMsgResult {
|
||||
Break,
|
||||
}
|
||||
|
||||
/// Always-writeable sock_split stream.
|
||||
/// May not be readable. See [`PostgresBackend::take_stream_in`]
|
||||
pub enum Stream {
|
||||
Unencrypted(BufReader<tokio::net::TcpStream>),
|
||||
Tls(Box<tokio_rustls::server::TlsStream<BufReader<tokio::net::TcpStream>>>),
|
||||
Broken,
|
||||
/// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
|
||||
pub enum MaybeTlsStream {
|
||||
Unencrypted(tokio::net::TcpStream),
|
||||
Tls(Box<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>),
|
||||
}
|
||||
|
||||
impl AsyncWrite for Stream {
|
||||
impl AsyncWrite for MaybeTlsStream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
@@ -122,14 +129,12 @@ impl AsyncWrite for Stream {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
Self::Broken => unreachable!(),
|
||||
}
|
||||
}
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
|
||||
Self::Broken => unreachable!(),
|
||||
}
|
||||
}
|
||||
fn poll_shutdown(
|
||||
@@ -139,11 +144,10 @@ impl AsyncWrite for Stream {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
Self::Broken => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl AsyncRead for Stream {
|
||||
impl AsyncRead for MaybeTlsStream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
@@ -152,18 +156,96 @@ impl AsyncRead for Stream {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
Self::Broken => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum AuthType {
|
||||
Trust,
|
||||
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
|
||||
NeonJWT,
|
||||
}
|
||||
|
||||
impl FromStr for AuthType {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"Trust" => Ok(Self::Trust),
|
||||
"NeonJWT" => Ok(Self::NeonJWT),
|
||||
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for AuthType {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(match self {
|
||||
AuthType::Trust => "Trust",
|
||||
AuthType::NeonJWT => "NeonJWT",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Either full duplex Framed or write only half; the latter is left in
|
||||
/// PostgresBackend after call to `split`. In principle we could always store a
|
||||
/// pair of splitted handles, but that would force to to pay splitting price
|
||||
/// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
|
||||
enum MaybeWriteOnly {
|
||||
Full(Framed<MaybeTlsStream>),
|
||||
WriteOnly(FramedWriter<MaybeTlsStream>),
|
||||
Broken, // temporary value palmed off during the split
|
||||
}
|
||||
|
||||
impl MaybeWriteOnly {
|
||||
async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
|
||||
match self {
|
||||
MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
|
||||
MaybeWriteOnly::WriteOnly(_) => {
|
||||
Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
|
||||
}
|
||||
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
match self {
|
||||
MaybeWriteOnly::Full(framed) => framed.read_message().await,
|
||||
MaybeWriteOnly::WriteOnly(_) => {
|
||||
Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
|
||||
}
|
||||
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
|
||||
}
|
||||
}
|
||||
|
||||
fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ConnectionError> {
|
||||
match self {
|
||||
MaybeWriteOnly::Full(framed) => framed.write_message(msg),
|
||||
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
|
||||
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn flush(&mut self) -> io::Result<()> {
|
||||
match self {
|
||||
MaybeWriteOnly::Full(framed) => framed.flush().await,
|
||||
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await,
|
||||
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn shutdown(&mut self) -> io::Result<()> {
|
||||
match self {
|
||||
MaybeWriteOnly::Full(framed) => framed.shutdown().await,
|
||||
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await,
|
||||
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresBackend {
|
||||
stream: Stream,
|
||||
|
||||
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
|
||||
// The data between 0 and "current position" as tracked by the bytes::Buf
|
||||
// implementation of BytesMut, have already been written.
|
||||
buf_out: BytesMut,
|
||||
framed: MaybeWriteOnly,
|
||||
|
||||
pub state: ProtoState,
|
||||
|
||||
@@ -183,7 +265,7 @@ pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
|
||||
query_string
|
||||
}
|
||||
|
||||
// Cast a byte slice to a string slice, dropping null terminator if there's one.
|
||||
/// Cast a byte slice to a string slice, dropping null terminator if there's one.
|
||||
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
|
||||
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
|
||||
std::str::from_utf8(without_null).map_err(|e| e.into())
|
||||
@@ -196,10 +278,10 @@ impl PostgresBackend {
|
||||
tls_config: Option<Arc<rustls::ServerConfig>>,
|
||||
) -> io::Result<Self> {
|
||||
let peer_addr = socket.peer_addr()?;
|
||||
let stream = MaybeTlsStream::Unencrypted(socket);
|
||||
|
||||
Ok(Self {
|
||||
stream: Stream::Unencrypted(BufReader::new(socket)),
|
||||
buf_out: BytesMut::with_capacity(10 * 1024),
|
||||
framed: MaybeWriteOnly::Full(Framed::new(stream)),
|
||||
state: ProtoState::Initialization,
|
||||
auth_type,
|
||||
tls_config,
|
||||
@@ -211,30 +293,52 @@ impl PostgresBackend {
|
||||
&self.peer_addr
|
||||
}
|
||||
|
||||
/// Read full message or return None if connection is closed.
|
||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, QueryError> {
|
||||
use ProtoState::*;
|
||||
match self.state {
|
||||
Initialization | Encrypted => FeStartupPacket::read_fut(&mut self.stream).await,
|
||||
Authentication | Established => FeMessage::read_fut(&mut self.stream).await,
|
||||
Closed => Ok(None),
|
||||
/// Read full message or return None if connection is cleanly closed with no
|
||||
/// unprocessed data.
|
||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
if let ProtoState::Closed = self.state {
|
||||
Ok(None)
|
||||
} else {
|
||||
let m = self.framed.read_message().await?;
|
||||
trace!("read msg {:?}", m);
|
||||
Ok(m)
|
||||
}
|
||||
.map_err(QueryError::from)
|
||||
}
|
||||
|
||||
/// Write message into internal output buffer, doesn't flush it. Technically
|
||||
/// error type can be only ProtocolError here (if, unlikely, serialization
|
||||
/// fails), but callers typically wrap it anyway.
|
||||
pub fn write_message_noflush(
|
||||
&mut self,
|
||||
message: &BeMessage<'_>,
|
||||
) -> Result<&mut Self, ConnectionError> {
|
||||
self.framed.write_message_noflush(message)?;
|
||||
trace!("wrote msg {:?}", message);
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Flush output buffer into the socket.
|
||||
pub async fn flush(&mut self) -> io::Result<()> {
|
||||
while self.buf_out.has_remaining() {
|
||||
let bytes_written = self.stream.write(self.buf_out.chunk()).await?;
|
||||
self.buf_out.advance(bytes_written);
|
||||
}
|
||||
self.buf_out.clear();
|
||||
Ok(())
|
||||
self.framed.flush().await
|
||||
}
|
||||
|
||||
/// Write message into internal output buffer.
|
||||
pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
|
||||
BeMessage::write(&mut self.buf_out, message)?;
|
||||
/// Polling version of `flush()`, saves the caller need to pin.
|
||||
pub fn poll_flush(
|
||||
&mut self,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let flush_fut = self.flush();
|
||||
pin_mut!(flush_fut);
|
||||
flush_fut.poll(cx)
|
||||
}
|
||||
|
||||
/// Write message into internal output buffer and flush it to the stream.
|
||||
pub async fn write_message(
|
||||
&mut self,
|
||||
message: &BeMessage<'_>,
|
||||
) -> Result<&mut Self, ConnectionError> {
|
||||
self.write_message_noflush(message)?;
|
||||
self.flush().await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
@@ -246,26 +350,7 @@ impl PostgresBackend {
|
||||
CopyDataWriter { pgb: self }
|
||||
}
|
||||
|
||||
/// A polling function that tries to write all the data from 'buf_out' to the
|
||||
/// underlying stream.
|
||||
fn poll_write_buf(
|
||||
&mut self,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
while self.buf_out.has_remaining() {
|
||||
match ready!(Pin::new(&mut self.stream).poll_write(cx, self.buf_out.chunk())) {
|
||||
Ok(bytes_written) => self.buf_out.advance(bytes_written),
|
||||
Err(err) => return Poll::Ready(Err(err)),
|
||||
}
|
||||
}
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_flush(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||
Pin::new(&mut self.stream).poll_flush(cx)
|
||||
}
|
||||
|
||||
// Wrapper for run_message_loop() that shuts down socket when we are done
|
||||
/// Wrapper for run_message_loop() that shuts down socket when we are done
|
||||
pub async fn run<F, S>(
|
||||
mut self,
|
||||
handler: &mut impl Handler,
|
||||
@@ -276,7 +361,9 @@ impl PostgresBackend {
|
||||
S: Future,
|
||||
{
|
||||
let ret = self.run_message_loop(handler, shutdown_watcher).await;
|
||||
let _ = self.stream.shutdown();
|
||||
// socket might be already closed, e.g. if previously received error,
|
||||
// so ignore result.
|
||||
self.framed.shutdown().await.ok();
|
||||
ret
|
||||
}
|
||||
|
||||
@@ -300,30 +387,12 @@ impl PostgresBackend {
|
||||
return Ok(())
|
||||
},
|
||||
|
||||
result = async {
|
||||
while self.state < ProtoState::Established {
|
||||
if let Some(msg) = self.read_message().await? {
|
||||
trace!("got message {msg:?} during handshake");
|
||||
|
||||
match self.process_handshake_message(handler, msg).await? {
|
||||
ProcessMsgResult::Continue => {
|
||||
self.flush().await?;
|
||||
continue;
|
||||
}
|
||||
ProcessMsgResult::Break => {
|
||||
trace!("postgres backend to {:?} exited during handshake", self.peer_addr);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
trace!("postgres backend to {:?} exited during handshake", self.peer_addr);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
Ok::<(), QueryError>(())
|
||||
} => {
|
||||
result = self.handshake(handler) => {
|
||||
// Handshake complete.
|
||||
result?;
|
||||
if self.state == ProtoState::Closed {
|
||||
return Ok(()); // EOF during handshake
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
@@ -355,114 +424,207 @@ impl PostgresBackend {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_tls(&mut self) -> anyhow::Result<()> {
|
||||
if let Stream::Unencrypted(plain_stream) =
|
||||
std::mem::replace(&mut self.stream, Stream::Broken)
|
||||
{
|
||||
let acceptor = TlsAcceptor::from(self.tls_config.clone().unwrap());
|
||||
let tls_stream = acceptor.accept(plain_stream).await?;
|
||||
|
||||
self.stream = Stream::Tls(Box::new(tls_stream));
|
||||
return Ok(());
|
||||
};
|
||||
anyhow::bail!("TLS already started");
|
||||
}
|
||||
|
||||
async fn process_handshake_message(
|
||||
&mut self,
|
||||
handler: &mut impl Handler,
|
||||
msg: FeMessage,
|
||||
) -> Result<ProcessMsgResult, QueryError> {
|
||||
assert!(self.state < ProtoState::Established);
|
||||
let have_tls = self.tls_config.is_some();
|
||||
match msg {
|
||||
FeMessage::StartupPacket(m) => {
|
||||
trace!("got startup message {m:?}");
|
||||
|
||||
match m {
|
||||
FeStartupPacket::SslRequest => {
|
||||
debug!("SSL requested");
|
||||
|
||||
self.write_message_noflush(&BeMessage::EncryptionResponse(have_tls))?;
|
||||
if have_tls {
|
||||
self.start_tls().await?;
|
||||
self.state = ProtoState::Encrypted;
|
||||
}
|
||||
}
|
||||
FeStartupPacket::GssEncRequest => {
|
||||
debug!("GSS requested");
|
||||
self.write_message_noflush(&BeMessage::EncryptionResponse(false))?;
|
||||
}
|
||||
FeStartupPacket::StartupMessage { .. } => {
|
||||
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
|
||||
self.write_message_noflush(&BeMessage::ErrorResponse(
|
||||
"must connect with TLS",
|
||||
None,
|
||||
))?;
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"client did not connect with TLS"
|
||||
)));
|
||||
}
|
||||
|
||||
// NB: startup() may change self.auth_type -- we are using that in proxy code
|
||||
// to bypass auth for new users.
|
||||
handler.startup(self, &m)?;
|
||||
|
||||
match self.auth_type {
|
||||
AuthType::Trust => {
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?
|
||||
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
|
||||
// The async python driver requires a valid server_version
|
||||
.write_message_noflush(&BeMessage::server_version("14.1"))?
|
||||
.write_message_noflush(&BeMessage::ReadyForQuery)?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
AuthType::NeonJWT => {
|
||||
self.write_message_noflush(
|
||||
&BeMessage::AuthenticationCleartextPassword,
|
||||
)?;
|
||||
self.state = ProtoState::Authentication;
|
||||
}
|
||||
}
|
||||
}
|
||||
FeStartupPacket::CancelRequest { .. } => {
|
||||
self.state = ProtoState::Closed;
|
||||
return Ok(ProcessMsgResult::Break);
|
||||
}
|
||||
}
|
||||
/// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
|
||||
async fn tls_upgrade(
|
||||
src: MaybeTlsStream,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
) -> anyhow::Result<MaybeTlsStream> {
|
||||
match src {
|
||||
MaybeTlsStream::Unencrypted(s) => {
|
||||
let acceptor = TlsAcceptor::from(tls_config);
|
||||
let tls_stream = acceptor.accept(s).await?;
|
||||
Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
|
||||
}
|
||||
|
||||
FeMessage::PasswordMessage(m) => {
|
||||
trace!("got password message '{:?}'", m);
|
||||
|
||||
assert!(self.state == ProtoState::Authentication);
|
||||
|
||||
match self.auth_type {
|
||||
AuthType::Trust => unreachable!(),
|
||||
AuthType::NeonJWT => {
|
||||
let (_, jwt_response) = m.split_last().context("protocol violation")?;
|
||||
|
||||
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
|
||||
self.write_message_noflush(&BeMessage::ErrorResponse(
|
||||
&e.to_string(),
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?
|
||||
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
|
||||
.write_message_noflush(&BeMessage::ReadyForQuery)?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
|
||||
_ => {
|
||||
self.state = ProtoState::Closed;
|
||||
return Ok(ProcessMsgResult::Break);
|
||||
MaybeTlsStream::Tls(_) => {
|
||||
anyhow::bail!("TLS already started");
|
||||
}
|
||||
}
|
||||
Ok(ProcessMsgResult::Continue)
|
||||
}
|
||||
|
||||
async fn start_tls(&mut self) -> anyhow::Result<()> {
|
||||
// temporary replace stream with fake to cook TLS one, Indiana Jones style
|
||||
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
|
||||
MaybeWriteOnly::Full(framed) => {
|
||||
let tls_config = self
|
||||
.tls_config
|
||||
.as_ref()
|
||||
.context("start_tls called without conf")?
|
||||
.clone();
|
||||
let tls_framed = framed
|
||||
.map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config))
|
||||
.await?;
|
||||
// push back ready TLS stream
|
||||
self.framed = MaybeWriteOnly::Full(tls_framed);
|
||||
Ok(())
|
||||
}
|
||||
MaybeWriteOnly::WriteOnly(_) => {
|
||||
anyhow::bail!("TLS upgrade attempt in split state")
|
||||
}
|
||||
MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Split off owned read part from which messages can be read in different
|
||||
/// task/thread.
|
||||
pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader> {
|
||||
// temporary replace stream with fake to cook split one, Indiana Jones style
|
||||
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
|
||||
MaybeWriteOnly::Full(framed) => {
|
||||
let (reader, writer) = framed.split();
|
||||
self.framed = MaybeWriteOnly::WriteOnly(writer);
|
||||
Ok(PostgresBackendReader(reader))
|
||||
}
|
||||
MaybeWriteOnly::WriteOnly(_) => {
|
||||
anyhow::bail!("PostgresBackend is already split")
|
||||
}
|
||||
MaybeWriteOnly::Broken => panic!("split on framed in invalid state"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Join read part back.
|
||||
pub fn unsplit(&mut self, reader: PostgresBackendReader) -> anyhow::Result<()> {
|
||||
// temporary replace stream with fake to cook joined one, Indiana Jones style
|
||||
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
|
||||
MaybeWriteOnly::Full(_) => {
|
||||
anyhow::bail!("PostgresBackend is not split")
|
||||
}
|
||||
MaybeWriteOnly::WriteOnly(writer) => {
|
||||
let joined = Framed::unsplit(reader.0, writer);
|
||||
self.framed = MaybeWriteOnly::Full(joined);
|
||||
Ok(())
|
||||
}
|
||||
MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform handshake with the client, transitioning to Established.
|
||||
/// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()).
|
||||
async fn handshake(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
|
||||
while self.state < ProtoState::Authentication {
|
||||
match self.framed.read_startup_message().await? {
|
||||
Some(msg) => {
|
||||
self.process_startup_message(handler, msg).await?;
|
||||
}
|
||||
None => {
|
||||
trace!(
|
||||
"postgres backend to {:?} received EOF during handshake",
|
||||
self.peer_addr
|
||||
);
|
||||
self.state = ProtoState::Closed;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Perform auth, if needed.
|
||||
if self.state == ProtoState::Authentication {
|
||||
match self.framed.read_message().await? {
|
||||
Some(FeMessage::PasswordMessage(m)) => {
|
||||
assert!(self.auth_type == AuthType::NeonJWT);
|
||||
|
||||
let (_, jwt_response) = m.split_last().context("protocol violation")?;
|
||||
|
||||
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
|
||||
self.write_message_noflush(&BeMessage::ErrorResponse(
|
||||
&e.to_string(),
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?
|
||||
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
|
||||
.write_message(&BeMessage::ReadyForQuery)
|
||||
.await?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
Some(m) => {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"Unexpected message {:?} while waiting for handshake",
|
||||
m
|
||||
)));
|
||||
}
|
||||
None => {
|
||||
trace!(
|
||||
"postgres backend to {:?} received EOF during auth",
|
||||
self.peer_addr
|
||||
);
|
||||
self.state = ProtoState::Closed;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Process startup packet:
|
||||
/// - transition to Established if auth type is trust
|
||||
/// - transition to Authentication if auth type is NeonJWT.
|
||||
/// - or perform TLS handshake -- then need to call this again to receive
|
||||
/// actual startup packet.
|
||||
async fn process_startup_message(
|
||||
&mut self,
|
||||
handler: &mut impl Handler,
|
||||
msg: FeStartupPacket,
|
||||
) -> Result<(), QueryError> {
|
||||
assert!(self.state < ProtoState::Authentication);
|
||||
let have_tls = self.tls_config.is_some();
|
||||
match msg {
|
||||
FeStartupPacket::SslRequest => {
|
||||
debug!("SSL requested");
|
||||
|
||||
self.write_message(&BeMessage::EncryptionResponse(have_tls))
|
||||
.await?;
|
||||
|
||||
if have_tls {
|
||||
self.start_tls().await?;
|
||||
self.state = ProtoState::Encrypted;
|
||||
}
|
||||
}
|
||||
FeStartupPacket::GssEncRequest => {
|
||||
debug!("GSS requested");
|
||||
self.write_message(&BeMessage::EncryptionResponse(false))
|
||||
.await?;
|
||||
}
|
||||
FeStartupPacket::StartupMessage { .. } => {
|
||||
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
|
||||
self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None))
|
||||
.await?;
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"client did not connect with TLS"
|
||||
)));
|
||||
}
|
||||
|
||||
// NB: startup() may change self.auth_type -- we are using that in proxy code
|
||||
// to bypass auth for new users.
|
||||
handler.startup(self, &msg)?;
|
||||
|
||||
match self.auth_type {
|
||||
AuthType::Trust => {
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?
|
||||
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
|
||||
.write_message_noflush(&BeMessage::INTEGER_DATETIMES)?
|
||||
// The async python driver requires a valid server_version
|
||||
.write_message_noflush(&BeMessage::server_version("14.1"))?
|
||||
.write_message(&BeMessage::ReadyForQuery)
|
||||
.await?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
AuthType::NeonJWT => {
|
||||
self.write_message(&BeMessage::AuthenticationCleartextPassword)
|
||||
.await?;
|
||||
self.state = ProtoState::Authentication;
|
||||
}
|
||||
}
|
||||
}
|
||||
FeStartupPacket::CancelRequest { .. } => {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"Unexpected CancelRequest message during handshake"
|
||||
)));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn process_message(
|
||||
@@ -476,10 +638,6 @@ impl PostgresBackend {
|
||||
assert!(self.state == ProtoState::Established);
|
||||
|
||||
match msg {
|
||||
FeMessage::StartupPacket(_) | FeMessage::PasswordMessage(_) => {
|
||||
return Err(QueryError::Other(anyhow::anyhow!("protocol violation")));
|
||||
}
|
||||
|
||||
FeMessage::Query(body) => {
|
||||
// remove null terminator
|
||||
let query_string = cstr_to_str(&body)?;
|
||||
@@ -540,16 +698,114 @@ impl PostgresBackend {
|
||||
|
||||
// We prefer explicit pattern matching to wildcards, because
|
||||
// this helps us spot the places where new variants are missing
|
||||
FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => {
|
||||
FeMessage::CopyData(_)
|
||||
| FeMessage::CopyDone
|
||||
| FeMessage::CopyFail
|
||||
| FeMessage::PasswordMessage(_)
|
||||
| FeMessage::StartupPacket(_) => {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"unexpected message type: {:?}",
|
||||
msg
|
||||
"unexpected message type: {msg:?}",
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ProcessMsgResult::Continue)
|
||||
}
|
||||
|
||||
/// Log as info/error result of handling COPY stream and send back
|
||||
/// ErrorResponse if that makes sense. Shutdown the stream if we got
|
||||
/// Terminate. TODO: transition into waiting for Sync msg if we initiate the
|
||||
/// close.
|
||||
pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) {
|
||||
use CopyStreamHandlerEnd::*;
|
||||
|
||||
let expected_end = match &end {
|
||||
ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true,
|
||||
CopyStreamHandlerEnd::Disconnected(ConnectionError::Socket(io_error))
|
||||
if is_expected_io_error(io_error) =>
|
||||
{
|
||||
true
|
||||
}
|
||||
_ => false,
|
||||
};
|
||||
if expected_end {
|
||||
info!("terminated: {:#}", end);
|
||||
} else {
|
||||
error!("terminated: {:?}", end);
|
||||
}
|
||||
|
||||
// Note: no current usages ever send this
|
||||
if let CopyDone = &end {
|
||||
if let Err(e) = self.write_message(&BeMessage::CopyDone).await {
|
||||
error!("failed to send CopyDone: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
if let Terminate = &end {
|
||||
self.state = ProtoState::Closed;
|
||||
}
|
||||
|
||||
let err_to_send_and_errcode = match &end {
|
||||
ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
|
||||
Other(_) => Some((end.to_string(), SQLSTATE_INTERNAL_ERROR)),
|
||||
// Note: CopyFail in duplex copy is somewhat unexpected (at least to
|
||||
// PG walsender; evidently and per my docs reading client should
|
||||
// finish it with CopyDone). It is not a problem to recover from it
|
||||
// finishing the stream in both directions like we do, but note that
|
||||
// sync rust-postgres client (which we don't use anymore) hangs if
|
||||
// socket is not closed here.
|
||||
// https://github.com/sfackler/rust-postgres/issues/755
|
||||
// https://github.com/neondatabase/neon/issues/935
|
||||
//
|
||||
// Currently, the version of tokio_postgres replication patch we use
|
||||
// sends this when it closes the stream (e.g. pageserver decided to
|
||||
// switch conn to another safekeeper and client gets dropped).
|
||||
// Moreover, seems like 'connection' task errors with 'unexpected
|
||||
// message from server' when it receives ErrorResponse (anything but
|
||||
// CopyData/CopyDone) back.
|
||||
CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
|
||||
_ => None,
|
||||
};
|
||||
if let Some((err, errcode)) = err_to_send_and_errcode {
|
||||
if let Err(ee) = self
|
||||
.write_message(&BeMessage::ErrorResponse(&err, Some(errcode)))
|
||||
.await
|
||||
{
|
||||
error!("failed to send ErrorResponse: {}", ee);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresBackendReader(FramedReader<MaybeTlsStream>);
|
||||
|
||||
impl PostgresBackendReader {
|
||||
/// Read full message or return None if connection is cleanly closed with no
|
||||
/// unprocessed data.
|
||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
let m = self.0.read_message().await?;
|
||||
trace!("read msg {:?}", m);
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
/// Get CopyData contents of the next message in COPY stream or error
|
||||
/// closing it. The error type is wider than actual errors which can happen
|
||||
/// here -- it includes 'Other' and 'ServerInitiated', but that's ok for
|
||||
/// current callers.
|
||||
pub async fn read_copy_message(&mut self) -> Result<Bytes, CopyStreamHandlerEnd> {
|
||||
match self.read_message().await? {
|
||||
Some(msg) => match msg {
|
||||
FeMessage::CopyData(m) => Ok(m),
|
||||
FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone),
|
||||
FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
|
||||
FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
|
||||
_ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
|
||||
format!("unexpected message in COPY stream {:?}", msg),
|
||||
))),
|
||||
},
|
||||
None => Err(CopyStreamHandlerEnd::EOF),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
@@ -572,16 +828,19 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> {
|
||||
// It's not strictly required to flush between each message, but makes it easier
|
||||
// to view in wireshark, and usually the messages that the callers write are
|
||||
// decently-sized anyway.
|
||||
match ready!(this.pgb.poll_write_buf(cx)) {
|
||||
Ok(()) => {}
|
||||
Err(err) => return Poll::Ready(Err(err)),
|
||||
if let Err(err) = ready!(this.pgb.poll_flush(cx)) {
|
||||
return Poll::Ready(Err(err));
|
||||
}
|
||||
|
||||
// CopyData
|
||||
// XXX: if the input is large, we should split it into multiple messages.
|
||||
// Not sure what the threshold should be, but the ultimate hard limit is that
|
||||
// the length cannot exceed u32.
|
||||
this.pgb.write_message_noflush(&BeMessage::CopyData(buf))?;
|
||||
this.pgb
|
||||
.write_message_noflush(&BeMessage::CopyData(buf))
|
||||
// write_message only writes to the buffer, so it can fail iff the
|
||||
// message is invaid, but CopyData can't be invalid.
|
||||
.map_err(|_| io::Error::new(ErrorKind::Other, "failed to serialize CopyData"))?;
|
||||
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
@@ -591,21 +850,14 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> {
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let this = self.get_mut();
|
||||
match ready!(this.pgb.poll_write_buf(cx)) {
|
||||
Ok(()) => {}
|
||||
Err(err) => return Poll::Ready(Err(err)),
|
||||
}
|
||||
this.pgb.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let this = self.get_mut();
|
||||
match ready!(this.pgb.poll_write_buf(cx)) {
|
||||
Ok(()) => {}
|
||||
Err(err) => return Poll::Ready(Err(err)),
|
||||
}
|
||||
this.pgb.poll_flush(cx)
|
||||
}
|
||||
}
|
||||
@@ -617,7 +869,7 @@ pub fn short_error(e: &QueryError) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn log_query_error(query: &str, e: &QueryError) {
|
||||
fn log_query_error(query: &str, e: &QueryError) {
|
||||
match e {
|
||||
QueryError::Disconnected(ConnectionError::Socket(io_error)) => {
|
||||
if is_expected_io_error(io_error) {
|
||||
@@ -634,3 +886,26 @@ pub fn log_query_error(query: &str, e: &QueryError) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Something finishing handling of COPY stream, see handle_copy_stream_end.
|
||||
/// This is not always a real error, but it allows to use ? and thiserror impls.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum CopyStreamHandlerEnd {
|
||||
/// Handler initiates the end of streaming.
|
||||
#[error("{0}")]
|
||||
ServerInitiated(String),
|
||||
#[error("received CopyDone")]
|
||||
CopyDone,
|
||||
#[error("received CopyFail")]
|
||||
CopyFail,
|
||||
#[error("received Terminate")]
|
||||
Terminate,
|
||||
#[error("EOF on COPY stream")]
|
||||
EOF,
|
||||
/// The connection was lost
|
||||
#[error(transparent)]
|
||||
Disconnected(#[from] ConnectionError),
|
||||
/// Some other error
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
139
libs/postgres_backend/tests/simple_select.rs
Normal file
139
libs/postgres_backend/tests/simple_select.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
/// Test postgres_backend_async with tokio_postgres
|
||||
use once_cell::sync::Lazy;
|
||||
use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError};
|
||||
use pq_proto::{BeMessage, RowDescriptor};
|
||||
use std::io::Cursor;
|
||||
use std::{future, sync::Arc};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tokio_postgres::tls::MakeTlsConnect;
|
||||
use tokio_postgres::{Config, NoTls, SimpleQueryMessage};
|
||||
use tokio_postgres_rustls::MakeRustlsConnect;
|
||||
|
||||
// generate client, server test streams
|
||||
async fn make_tcp_pair() -> (TcpStream, TcpStream) {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let client_stream = TcpStream::connect(addr).await.unwrap();
|
||||
let (server_stream, _) = listener.accept().await.unwrap();
|
||||
(client_stream, server_stream)
|
||||
}
|
||||
|
||||
struct TestHandler {}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Handler for TestHandler {
|
||||
// return single col 'hey' for any query
|
||||
async fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
_query_string: &str,
|
||||
) -> Result<(), QueryError> {
|
||||
pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor::text_col(
|
||||
b"hey",
|
||||
)]))?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[Some("hey".as_bytes())]))?
|
||||
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// test that basic select works
|
||||
#[tokio::test]
|
||||
async fn simple_select() {
|
||||
let (client_sock, server_sock) = make_tcp_pair().await;
|
||||
|
||||
// create and run pgbackend
|
||||
let pgbackend =
|
||||
PostgresBackend::new(server_sock, AuthType::Trust, None).expect("pgbackend creation");
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut handler = TestHandler {};
|
||||
pgbackend.run(&mut handler, future::pending::<()>).await
|
||||
});
|
||||
|
||||
let conf = Config::new();
|
||||
let (client, connection) = conf.connect_raw(client_sock, NoTls).await.expect("connect");
|
||||
// The connection object performs the actual communication with the database,
|
||||
// so spawn it off to run on its own.
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = connection.await {
|
||||
eprintln!("connection error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
|
||||
if let SimpleQueryMessage::Row(row) = first_val {
|
||||
let first_col = row.get(0).expect("first column");
|
||||
assert_eq!(first_col, "hey");
|
||||
} else {
|
||||
panic!("expected SimpleQueryMessage::Row");
|
||||
}
|
||||
}
|
||||
|
||||
static KEY: Lazy<rustls::PrivateKey> = Lazy::new(|| {
|
||||
let mut cursor = Cursor::new(include_bytes!("key.pem"));
|
||||
rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone())
|
||||
});
|
||||
|
||||
static CERT: Lazy<rustls::Certificate> = Lazy::new(|| {
|
||||
let mut cursor = Cursor::new(include_bytes!("cert.pem"));
|
||||
rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone())
|
||||
});
|
||||
|
||||
// test that basic select with ssl works
|
||||
#[tokio::test]
|
||||
async fn simple_select_ssl() {
|
||||
let (client_sock, server_sock) = make_tcp_pair().await;
|
||||
|
||||
let server_cfg = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(vec![CERT.clone()], KEY.clone())
|
||||
.unwrap();
|
||||
let tls_config = Some(Arc::new(server_cfg));
|
||||
let pgbackend =
|
||||
PostgresBackend::new(server_sock, AuthType::Trust, tls_config).expect("pgbackend creation");
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut handler = TestHandler {};
|
||||
pgbackend.run(&mut handler, future::pending::<()>).await
|
||||
});
|
||||
|
||||
let client_cfg = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates({
|
||||
let mut store = rustls::RootCertStore::empty();
|
||||
store.add(&CERT).unwrap();
|
||||
store
|
||||
})
|
||||
.with_no_client_auth();
|
||||
let mut make_tls_connect = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg);
|
||||
let tls_connect = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::make_tls_connect(
|
||||
&mut make_tls_connect,
|
||||
"localhost",
|
||||
)
|
||||
.expect("make_tls_connect");
|
||||
|
||||
let mut conf = Config::new();
|
||||
conf.ssl_mode(SslMode::Require);
|
||||
let (client, connection) = conf
|
||||
.connect_raw(client_sock, tls_connect)
|
||||
.await
|
||||
.expect("connect");
|
||||
// The connection object performs the actual communication with the database,
|
||||
// so spawn it off to run on its own.
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = connection.await {
|
||||
eprintln!("connection error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
|
||||
if let SimpleQueryMessage::Row(row) = first_val {
|
||||
let first_col = row.get(0).expect("first column");
|
||||
assert_eq!(first_col, "hey");
|
||||
} else {
|
||||
panic!("expected SimpleQueryMessage::Row");
|
||||
}
|
||||
}
|
||||
175
libs/pq_proto/src/framed.rs
Normal file
175
libs/pq_proto/src/framed.rs
Normal file
@@ -0,0 +1,175 @@
|
||||
//! Provides `Framed` -- writing/flushing and reading Postgres messages to/from
|
||||
//! the async stream.
|
||||
use bytes::{Buf, BytesMut};
|
||||
use std::{
|
||||
future::Future,
|
||||
io::{self, ErrorKind},
|
||||
};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
|
||||
|
||||
use crate::{BeMessage, ConnectionError, FeMessage, FeStartupPacket};
|
||||
|
||||
const INITIAL_CAPACITY: usize = 8 * 1024;
|
||||
|
||||
/// Wraps async io `stream`, providing messages to write/flush + read Postgres
|
||||
/// messages.
|
||||
pub struct Framed<S> {
|
||||
stream: BufReader<S>,
|
||||
write_buf: BytesMut,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> Framed<S> {
|
||||
pub fn new(stream: S) -> Self {
|
||||
Self {
|
||||
stream: BufReader::new(stream),
|
||||
write_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a shared reference to the underlying stream.
|
||||
pub fn get_ref(&self) -> &S {
|
||||
self.stream.get_ref()
|
||||
}
|
||||
|
||||
/// Extract the underlying stream.
|
||||
pub fn into_inner(self) -> S {
|
||||
self.stream.into_inner()
|
||||
}
|
||||
|
||||
/// Return new Framed with stream type transformed by async f, for TLS
|
||||
/// upgrade.
|
||||
pub async fn map_stream<S2: AsyncRead, E, F, Fut>(self, f: F) -> Result<Framed<S2>, E>
|
||||
where
|
||||
F: FnOnce(S) -> Fut,
|
||||
Fut: Future<Output = Result<S2, E>>,
|
||||
{
|
||||
let stream = f(self.stream.into_inner()).await?;
|
||||
Ok(Framed {
|
||||
stream: BufReader::new(stream),
|
||||
write_buf: self.write_buf,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> Framed<S> {
|
||||
pub async fn read_startup_message(
|
||||
&mut self,
|
||||
) -> Result<Option<FeStartupPacket>, ConnectionError> {
|
||||
let msg = FeStartupPacket::read(&mut self.stream).await?;
|
||||
|
||||
match msg {
|
||||
Some(FeMessage::StartupPacket(packet)) => Ok(Some(packet)),
|
||||
None => Ok(None),
|
||||
_ => panic!("unreachable state"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
FeMessage::read(&mut self.stream).await
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + AsyncRead + Unpin> Framed<S> {
|
||||
/// Write next message to the output buffer; doesn't flush.
|
||||
pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ConnectionError> {
|
||||
BeMessage::write(&mut self.write_buf, msg).map_err(|e| e.into())
|
||||
}
|
||||
|
||||
/// Flush out the buffer. This function is cancellation safe: it can be
|
||||
/// interrupted and flushing will be continued in the next call.
|
||||
pub async fn flush(&mut self) -> Result<(), io::Error> {
|
||||
flush(&mut self.stream, &mut self.write_buf).await
|
||||
}
|
||||
|
||||
/// Flush out the buffer and shutdown the stream.
|
||||
pub async fn shutdown(&mut self) -> Result<(), io::Error> {
|
||||
shutdown(&mut self.stream, &mut self.write_buf).await
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> Framed<S> {
|
||||
/// Split into owned read and write parts. Beware of potential issues with
|
||||
/// using halves in different tasks on TLS stream:
|
||||
/// https://github.com/tokio-rs/tls/issues/40
|
||||
pub fn split(self) -> (FramedReader<S>, FramedWriter<S>) {
|
||||
let (read_half, write_half) = tokio::io::split(self.stream);
|
||||
let reader = FramedReader { stream: read_half };
|
||||
let writer = FramedWriter {
|
||||
stream: write_half,
|
||||
write_buf: self.write_buf,
|
||||
};
|
||||
(reader, writer)
|
||||
}
|
||||
|
||||
/// Join read and write parts back.
|
||||
pub fn unsplit(reader: FramedReader<S>, writer: FramedWriter<S>) -> Self {
|
||||
Self {
|
||||
stream: reader.stream.unsplit(writer.stream),
|
||||
write_buf: writer.write_buf,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read-only version of `Framed`.
|
||||
pub struct FramedReader<S> {
|
||||
stream: ReadHalf<BufReader<S>>,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> FramedReader<S> {
|
||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
FeMessage::read(&mut self.stream).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Write-only version of `Framed`.
|
||||
pub struct FramedWriter<S> {
|
||||
stream: WriteHalf<BufReader<S>>,
|
||||
write_buf: BytesMut,
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + AsyncRead + Unpin> FramedWriter<S> {
|
||||
/// Write next message to the output buffer; doesn't flush.
|
||||
pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ConnectionError> {
|
||||
BeMessage::write(&mut self.write_buf, msg).map_err(|e| e.into())
|
||||
}
|
||||
|
||||
/// Flush out the buffer. This function is cancellation safe: it can be
|
||||
/// interrupted and flushing will be continued in the next call.
|
||||
pub async fn flush(&mut self) -> Result<(), io::Error> {
|
||||
flush(&mut self.stream, &mut self.write_buf).await
|
||||
}
|
||||
|
||||
/// Flush out the buffer and shutdown the stream.
|
||||
pub async fn shutdown(&mut self) -> Result<(), io::Error> {
|
||||
shutdown(&mut self.stream, &mut self.write_buf).await
|
||||
}
|
||||
}
|
||||
|
||||
async fn flush<S: AsyncWrite + Unpin>(
|
||||
stream: &mut S,
|
||||
write_buf: &mut BytesMut,
|
||||
) -> Result<(), io::Error> {
|
||||
while write_buf.has_remaining() {
|
||||
let bytes_written = stream.write(write_buf.chunk()).await?;
|
||||
if bytes_written == 0 {
|
||||
return Err(io::Error::new(
|
||||
ErrorKind::WriteZero,
|
||||
"failed to write message",
|
||||
));
|
||||
}
|
||||
// The advanced part will be garbage collected, likely during shifting
|
||||
// data left on next attempt to write to buffer when free space is not
|
||||
// enough.
|
||||
write_buf.advance(bytes_written);
|
||||
}
|
||||
write_buf.clear();
|
||||
stream.flush().await
|
||||
}
|
||||
|
||||
async fn shutdown<S: AsyncWrite + Unpin>(
|
||||
stream: &mut S,
|
||||
write_buf: &mut BytesMut,
|
||||
) -> Result<(), io::Error> {
|
||||
flush(stream, write_buf).await?;
|
||||
stream.shutdown().await
|
||||
}
|
||||
@@ -2,8 +2,7 @@
|
||||
//! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
|
||||
//! on message formats.
|
||||
|
||||
// Tools for calling certain async methods in sync contexts.
|
||||
pub mod sync;
|
||||
pub mod framed;
|
||||
|
||||
use anyhow::{ensure, Context, Result};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
@@ -13,12 +12,10 @@ use std::{
|
||||
borrow::Cow,
|
||||
collections::HashMap,
|
||||
fmt,
|
||||
future::Future,
|
||||
io::{self, Cursor},
|
||||
str,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
use sync::{AsyncishRead, SyncFuture};
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tracing::{trace, warn};
|
||||
|
||||
@@ -211,7 +208,7 @@ macro_rules! retry_read {
|
||||
pub enum ConnectionError {
|
||||
/// IO error during writing to or reading from the connection socket.
|
||||
#[error("Socket IO error: {0}")]
|
||||
Socket(std::io::Error),
|
||||
Socket(#[from] std::io::Error),
|
||||
/// Invalid packet was received from client
|
||||
#[error("Protocol error: {0}")]
|
||||
Protocol(String),
|
||||
@@ -238,87 +235,56 @@ impl ConnectionError {
|
||||
impl FeMessage {
|
||||
/// Read one message from the stream.
|
||||
/// This function returns `Ok(None)` in case of EOF.
|
||||
/// One way to handle this properly:
|
||||
///
|
||||
/// ```
|
||||
/// # use std::io;
|
||||
/// # use pq_proto::FeMessage;
|
||||
/// #
|
||||
/// # fn process_message(msg: FeMessage) -> anyhow::Result<()> {
|
||||
/// # Ok(())
|
||||
/// # };
|
||||
/// #
|
||||
/// fn do_the_job(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<()> {
|
||||
/// while let Some(msg) = FeMessage::read(stream)? {
|
||||
/// process_message(msg)?;
|
||||
/// }
|
||||
///
|
||||
/// Ok(())
|
||||
/// }
|
||||
/// ```
|
||||
#[inline(never)]
|
||||
pub fn read(
|
||||
stream: &mut (impl io::Read + Unpin),
|
||||
) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
Self::read_fut(&mut AsyncishRead(stream)).wait()
|
||||
}
|
||||
|
||||
/// Read one message from the stream.
|
||||
/// See documentation for `Self::read`.
|
||||
pub fn read_fut<Reader>(
|
||||
stream: &mut Reader,
|
||||
) -> SyncFuture<Reader, impl Future<Output = Result<Option<FeMessage>, ConnectionError>> + '_>
|
||||
pub async fn read<Reader>(stream: &mut Reader) -> Result<Option<FeMessage>, ConnectionError>
|
||||
where
|
||||
Reader: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
// We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof.
|
||||
// SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and
|
||||
// AsyncReadExt methods of the stream.
|
||||
SyncFuture::new(async move {
|
||||
// Each libpq message begins with a message type byte, followed by message length
|
||||
// If the client closes the connection, return None. But if the client closes the
|
||||
// connection in the middle of a message, we will return an error.
|
||||
let tag = match retry_read!(stream.read_u8().await) {
|
||||
Ok(b) => b,
|
||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||
Err(e) => return Err(ConnectionError::Socket(e)),
|
||||
};
|
||||
// Each libpq message begins with a message type byte, followed by message length
|
||||
// If the client closes the connection, return None. But if the client closes the
|
||||
// connection in the middle of a message, we will return an error.
|
||||
let tag = match retry_read!(stream.read_u8().await) {
|
||||
Ok(b) => b,
|
||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||
Err(e) => return Err(ConnectionError::Socket(e)),
|
||||
};
|
||||
|
||||
// The message length includes itself, so it better be at least 4.
|
||||
let len = retry_read!(stream.read_u32().await)
|
||||
.map_err(ConnectionError::Socket)?
|
||||
.checked_sub(4)
|
||||
.ok_or_else(|| ConnectionError::Protocol("invalid message length".to_string()))?;
|
||||
// The message length includes itself, so it better be at least 4.
|
||||
let len = retry_read!(stream.read_u32().await)
|
||||
.map_err(ConnectionError::Socket)?
|
||||
.checked_sub(4)
|
||||
.ok_or_else(|| ConnectionError::Protocol("invalid message length".to_string()))?;
|
||||
|
||||
let body = {
|
||||
let mut buffer = vec![0u8; len as usize];
|
||||
stream
|
||||
.read_exact(&mut buffer)
|
||||
.await
|
||||
.map_err(ConnectionError::Socket)?;
|
||||
Bytes::from(buffer)
|
||||
};
|
||||
let body = {
|
||||
let mut buffer = vec![0u8; len as usize];
|
||||
stream
|
||||
.read_exact(&mut buffer)
|
||||
.await
|
||||
.map_err(ConnectionError::Socket)?;
|
||||
Bytes::from(buffer)
|
||||
};
|
||||
|
||||
match tag {
|
||||
b'Q' => Ok(Some(FeMessage::Query(body))),
|
||||
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
|
||||
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
|
||||
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
|
||||
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
|
||||
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
|
||||
b'S' => Ok(Some(FeMessage::Sync)),
|
||||
b'X' => Ok(Some(FeMessage::Terminate)),
|
||||
b'd' => Ok(Some(FeMessage::CopyData(body))),
|
||||
b'c' => Ok(Some(FeMessage::CopyDone)),
|
||||
b'f' => Ok(Some(FeMessage::CopyFail)),
|
||||
b'p' => Ok(Some(FeMessage::PasswordMessage(body))),
|
||||
tag => {
|
||||
return Err(ConnectionError::Protocol(format!(
|
||||
"unknown message tag: {tag},'{body:?}'"
|
||||
)))
|
||||
}
|
||||
match tag {
|
||||
b'Q' => Ok(Some(FeMessage::Query(body))),
|
||||
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
|
||||
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
|
||||
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
|
||||
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
|
||||
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
|
||||
b'S' => Ok(Some(FeMessage::Sync)),
|
||||
b'X' => Ok(Some(FeMessage::Terminate)),
|
||||
b'd' => Ok(Some(FeMessage::CopyData(body))),
|
||||
b'c' => Ok(Some(FeMessage::CopyDone)),
|
||||
b'f' => Ok(Some(FeMessage::CopyFail)),
|
||||
b'p' => Ok(Some(FeMessage::PasswordMessage(body))),
|
||||
tag => {
|
||||
return Err(ConnectionError::Protocol(format!(
|
||||
"unknown message tag: {tag},'{body:?}'"
|
||||
)))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -326,18 +292,7 @@ impl FeStartupPacket {
|
||||
/// Read startup message from the stream.
|
||||
// XXX: It's tempting yet undesirable to accept `stream` by value,
|
||||
// since such a change will cause user-supplied &mut references to be consumed
|
||||
pub fn read(
|
||||
stream: &mut (impl io::Read + Unpin),
|
||||
) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
Self::read_fut(&mut AsyncishRead(stream)).wait()
|
||||
}
|
||||
|
||||
/// Read startup message from the stream.
|
||||
// XXX: It's tempting yet undesirable to accept `stream` by value,
|
||||
// since such a change will cause user-supplied &mut references to be consumed
|
||||
pub fn read_fut<Reader>(
|
||||
stream: &mut Reader,
|
||||
) -> SyncFuture<Reader, impl Future<Output = Result<Option<FeMessage>, ConnectionError>> + '_>
|
||||
pub async fn read<Reader>(stream: &mut Reader) -> Result<Option<FeMessage>, ConnectionError>
|
||||
where
|
||||
Reader: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
@@ -347,99 +302,96 @@ impl FeStartupPacket {
|
||||
const NEGOTIATE_SSL_CODE: u32 = 5679;
|
||||
const NEGOTIATE_GSS_CODE: u32 = 5680;
|
||||
|
||||
SyncFuture::new(async move {
|
||||
// Read length. If the connection is closed before reading anything (or before
|
||||
// reading 4 bytes, to be precise), return None to indicate that the connection
|
||||
// was closed. This matches the PostgreSQL server's behavior, which avoids noise
|
||||
// in the log if the client opens connection but closes it immediately.
|
||||
let len = match retry_read!(stream.read_u32().await) {
|
||||
Ok(len) => len as usize,
|
||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||
Err(e) => return Err(ConnectionError::Socket(e)),
|
||||
};
|
||||
// Read length. If the connection is closed before reading anything (or before
|
||||
// reading 4 bytes, to be precise), return None to indicate that the connection
|
||||
// was closed. This matches the PostgreSQL server's behavior, which avoids noise
|
||||
// in the log if the client opens connection but closes it immediately.
|
||||
let len = match retry_read!(stream.read_u32().await) {
|
||||
Ok(len) => len as usize,
|
||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||
Err(e) => return Err(ConnectionError::Socket(e)),
|
||||
};
|
||||
|
||||
#[allow(clippy::manual_range_contains)]
|
||||
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
|
||||
#[allow(clippy::manual_range_contains)]
|
||||
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
|
||||
return Err(ConnectionError::Protocol(format!(
|
||||
"invalid message length {len}"
|
||||
)));
|
||||
}
|
||||
|
||||
let request_code = retry_read!(stream.read_u32().await).map_err(ConnectionError::Socket)?;
|
||||
|
||||
// the rest of startup packet are params
|
||||
let params_len = len - 8;
|
||||
let mut params_bytes = vec![0u8; params_len];
|
||||
stream
|
||||
.read_exact(params_bytes.as_mut())
|
||||
.await
|
||||
.map_err(ConnectionError::Socket)?;
|
||||
|
||||
// Parse params depending on request code
|
||||
let req_hi = request_code >> 16;
|
||||
let req_lo = request_code & ((1 << 16) - 1);
|
||||
let message = match (req_hi, req_lo) {
|
||||
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
|
||||
if params_len != 8 {
|
||||
return Err(ConnectionError::Protocol(
|
||||
"expected 8 bytes for CancelRequest params".to_string(),
|
||||
));
|
||||
}
|
||||
let mut cursor = Cursor::new(params_bytes);
|
||||
FeStartupPacket::CancelRequest(CancelKeyData {
|
||||
backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
|
||||
cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
|
||||
})
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
|
||||
// Requested upgrade to SSL (aka TLS)
|
||||
FeStartupPacket::SslRequest
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
|
||||
// Requested upgrade to GSSAPI
|
||||
FeStartupPacket::GssEncRequest
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
|
||||
return Err(ConnectionError::Protocol(format!(
|
||||
"invalid message length {len}"
|
||||
"Unrecognized request code {unrecognized_code}"
|
||||
)));
|
||||
}
|
||||
// TODO bail if protocol major_version is not 3?
|
||||
(major_version, minor_version) => {
|
||||
// Parse pairs of null-terminated strings (key, value).
|
||||
// See `postgres: ProcessStartupPacket, build_startup_packet`.
|
||||
let mut tokens = str::from_utf8(¶ms_bytes)
|
||||
.context("StartupMessage params: invalid utf-8")?
|
||||
.strip_suffix('\0') // drop packet's own null
|
||||
.ok_or_else(|| {
|
||||
ConnectionError::Protocol(
|
||||
"StartupMessage params: missing null terminator".to_string(),
|
||||
)
|
||||
})?
|
||||
.split_terminator('\0');
|
||||
|
||||
let request_code =
|
||||
retry_read!(stream.read_u32().await).map_err(ConnectionError::Socket)?;
|
||||
let mut params = HashMap::new();
|
||||
while let Some(name) = tokens.next() {
|
||||
let value = tokens.next().ok_or_else(|| {
|
||||
ConnectionError::Protocol(
|
||||
"StartupMessage params: key without value".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
// the rest of startup packet are params
|
||||
let params_len = len - 8;
|
||||
let mut params_bytes = vec![0u8; params_len];
|
||||
stream
|
||||
.read_exact(params_bytes.as_mut())
|
||||
.await
|
||||
.map_err(ConnectionError::Socket)?;
|
||||
|
||||
// Parse params depending on request code
|
||||
let req_hi = request_code >> 16;
|
||||
let req_lo = request_code & ((1 << 16) - 1);
|
||||
let message = match (req_hi, req_lo) {
|
||||
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
|
||||
if params_len != 8 {
|
||||
return Err(ConnectionError::Protocol(
|
||||
"expected 8 bytes for CancelRequest params".to_string(),
|
||||
));
|
||||
}
|
||||
let mut cursor = Cursor::new(params_bytes);
|
||||
FeStartupPacket::CancelRequest(CancelKeyData {
|
||||
backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
|
||||
cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
|
||||
})
|
||||
params.insert(name.to_owned(), value.to_owned());
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
|
||||
// Requested upgrade to SSL (aka TLS)
|
||||
FeStartupPacket::SslRequest
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
|
||||
// Requested upgrade to GSSAPI
|
||||
FeStartupPacket::GssEncRequest
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
|
||||
return Err(ConnectionError::Protocol(format!(
|
||||
"Unrecognized request code {unrecognized_code}"
|
||||
)));
|
||||
}
|
||||
// TODO bail if protocol major_version is not 3?
|
||||
(major_version, minor_version) => {
|
||||
// Parse pairs of null-terminated strings (key, value).
|
||||
// See `postgres: ProcessStartupPacket, build_startup_packet`.
|
||||
let mut tokens = str::from_utf8(¶ms_bytes)
|
||||
.context("StartupMessage params: invalid utf-8")?
|
||||
.strip_suffix('\0') // drop packet's own null
|
||||
.ok_or_else(|| {
|
||||
ConnectionError::Protocol(
|
||||
"StartupMessage params: missing null terminator".to_string(),
|
||||
)
|
||||
})?
|
||||
.split_terminator('\0');
|
||||
|
||||
let mut params = HashMap::new();
|
||||
while let Some(name) = tokens.next() {
|
||||
let value = tokens.next().ok_or_else(|| {
|
||||
ConnectionError::Protocol(
|
||||
"StartupMessage params: key without value".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
params.insert(name.to_owned(), value.to_owned());
|
||||
}
|
||||
|
||||
FeStartupPacket::StartupMessage {
|
||||
major_version,
|
||||
minor_version,
|
||||
params: StartupMessageParams { params },
|
||||
}
|
||||
FeStartupPacket::StartupMessage {
|
||||
major_version,
|
||||
minor_version,
|
||||
params: StartupMessageParams { params },
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Some(FeMessage::StartupPacket(message)))
|
||||
})
|
||||
Ok(Some(FeMessage::StartupPacket(message)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -559,6 +511,11 @@ impl<'a> BeMessage<'a> {
|
||||
value: b"UTF8",
|
||||
};
|
||||
|
||||
pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
|
||||
name: b"integer_datetimes",
|
||||
value: b"on",
|
||||
};
|
||||
|
||||
/// Build a [`BeMessage::ParameterStatus`] holding the server version.
|
||||
pub fn server_version(version: &'a str) -> Self {
|
||||
Self::ParameterStatus {
|
||||
@@ -698,6 +655,7 @@ fn read_cstr(buf: &mut Bytes) -> anyhow::Result<Bytes> {
|
||||
}
|
||||
|
||||
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
|
||||
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
|
||||
|
||||
impl<'a> BeMessage<'a> {
|
||||
/// Write message to the given buf.
|
||||
@@ -1149,15 +1107,6 @@ mod tests {
|
||||
let params = make_params("foo\\ bar \\ \\\\ baz\\ lol");
|
||||
assert_eq!(split_options(¶ms), ["foo bar", " \\", "baz ", "lol"]);
|
||||
}
|
||||
|
||||
// Make sure that `read` is sync/async callable
|
||||
async fn _assert(stream: &mut (impl tokio::io::AsyncRead + Unpin)) {
|
||||
let _ = FeMessage::read(&mut [].as_ref());
|
||||
let _ = FeMessage::read_fut(stream).await;
|
||||
|
||||
let _ = FeStartupPacket::read(&mut [].as_ref());
|
||||
let _ = FeStartupPacket::read_fut(stream).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn terminate_code(code: &[u8; 5]) -> [u8; 6] {
|
||||
|
||||
@@ -1,179 +0,0 @@
|
||||
use pin_project_lite::pin_project;
|
||||
use std::future::Future;
|
||||
use std::marker::PhantomData;
|
||||
use std::pin::Pin;
|
||||
use std::{io, task};
|
||||
|
||||
pin_project! {
|
||||
/// We use this future to mark certain methods
|
||||
/// as callable in both sync and async modes.
|
||||
#[repr(transparent)]
|
||||
pub struct SyncFuture<S, T: Future> {
|
||||
#[pin]
|
||||
inner: T,
|
||||
_marker: PhantomData<S>,
|
||||
}
|
||||
}
|
||||
|
||||
/// This wrapper lets us synchronously wait for inner future's completion
|
||||
/// (see [`SyncFuture::wait`]) **provided that `S` implements [`SyncProof`]**.
|
||||
/// For instance, `S` may be substituted with types implementing
|
||||
/// [`tokio::io::AsyncRead`], but it's not the only viable option.
|
||||
impl<S, T: Future> SyncFuture<S, T> {
|
||||
/// NOTE: caller should carefully pick a type for `S`,
|
||||
/// because we don't want to enable [`SyncFuture::wait`] when
|
||||
/// it's in fact impossible to run the future synchronously.
|
||||
/// Violation of this contract will not cause UB, but
|
||||
/// panics and async event loop freezes won't please you.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// ```
|
||||
/// # use pq_proto::sync::SyncFuture;
|
||||
/// # use std::future::Future;
|
||||
/// # use tokio::io::AsyncReadExt;
|
||||
/// #
|
||||
/// // Parse a pair of numbers from a stream
|
||||
/// pub fn parse_pair<Reader>(
|
||||
/// stream: &mut Reader,
|
||||
/// ) -> SyncFuture<Reader, impl Future<Output = anyhow::Result<(u32, u64)>> + '_>
|
||||
/// where
|
||||
/// Reader: tokio::io::AsyncRead + Unpin,
|
||||
/// {
|
||||
/// // If `Reader` is a `SyncProof`, this will give caller
|
||||
/// // an opportunity to use `SyncFuture::wait`, because
|
||||
/// // `.await` will always result in `Poll::Ready`.
|
||||
/// SyncFuture::new(async move {
|
||||
/// let x = stream.read_u32().await?;
|
||||
/// let y = stream.read_u64().await?;
|
||||
/// Ok((x, y))
|
||||
/// })
|
||||
/// }
|
||||
/// ```
|
||||
pub fn new(inner: T) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T: Future> Future for SyncFuture<S, T> {
|
||||
type Output = T::Output;
|
||||
|
||||
/// In async code, [`SyncFuture`] behaves like a regular wrapper.
|
||||
#[inline(always)]
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
|
||||
self.project().inner.poll(cx)
|
||||
}
|
||||
}
|
||||
|
||||
/// Postulates that we can call [`SyncFuture::wait`].
|
||||
/// If implementer is also a [`Future`], it should always
|
||||
/// return [`task::Poll::Ready`] from [`Future::poll`].
|
||||
///
|
||||
/// Each implementation should document which futures
|
||||
/// specifically are being declared sync-proof.
|
||||
pub trait SyncPostulate {}
|
||||
|
||||
impl<T: SyncPostulate> SyncPostulate for &T {}
|
||||
impl<T: SyncPostulate> SyncPostulate for &mut T {}
|
||||
|
||||
impl<P: SyncPostulate, T: Future> SyncFuture<P, T> {
|
||||
/// Synchronously wait for future completion.
|
||||
pub fn wait(mut self) -> T::Output {
|
||||
const RAW_WAKER: task::RawWaker = task::RawWaker::new(
|
||||
std::ptr::null(),
|
||||
&task::RawWakerVTable::new(
|
||||
|_| RAW_WAKER,
|
||||
|_| panic!("SyncFuture: failed to wake"),
|
||||
|_| panic!("SyncFuture: failed to wake by ref"),
|
||||
|_| { /* drop is no-op */ },
|
||||
),
|
||||
);
|
||||
|
||||
// SAFETY: We never move `self` during this call;
|
||||
// furthermore, it will be dropped in the end regardless of panics
|
||||
let this = unsafe { Pin::new_unchecked(&mut self) };
|
||||
|
||||
// SAFETY: This waker doesn't do anything apart from panicking
|
||||
let waker = unsafe { task::Waker::from_raw(RAW_WAKER) };
|
||||
let context = &mut task::Context::from_waker(&waker);
|
||||
|
||||
match this.poll(context) {
|
||||
task::Poll::Ready(res) => res,
|
||||
_ => panic!("SyncFuture: unexpected pending!"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// This wrapper turns any [`std::io::Read`] into a blocking [`tokio::io::AsyncRead`],
|
||||
/// which lets us abstract over sync & async readers in methods returning [`SyncFuture`].
|
||||
/// NOTE: you **should not** use this in async code.
|
||||
#[repr(transparent)]
|
||||
pub struct AsyncishRead<T: io::Read + Unpin>(pub T);
|
||||
|
||||
/// This lets us call [`SyncFuture<AsyncishRead<_>, _>::wait`],
|
||||
/// and allows the future to await on any of the [`AsyncRead`]
|
||||
/// and [`AsyncReadExt`] methods on `AsyncishRead`.
|
||||
impl<T: io::Read + Unpin> SyncPostulate for AsyncishRead<T> {}
|
||||
|
||||
impl<T: io::Read + Unpin> tokio::io::AsyncRead for AsyncishRead<T> {
|
||||
#[inline(always)]
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> task::Poll<io::Result<()>> {
|
||||
task::Poll::Ready(
|
||||
// `Read::read` will block, meaning we don't need a real event loop!
|
||||
self.0
|
||||
.read(buf.initialize_unfilled())
|
||||
.map(|sz| buf.advance(sz)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
// async helper(stream: &mut impl AsyncRead) -> io::Result<u32>
|
||||
fn bytes_add<Reader>(
|
||||
stream: &mut Reader,
|
||||
) -> SyncFuture<Reader, impl Future<Output = io::Result<u32>> + '_>
|
||||
where
|
||||
Reader: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
SyncFuture::new(async move {
|
||||
let a = stream.read_u32().await?;
|
||||
let b = stream.read_u32().await?;
|
||||
Ok(a + b)
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sync() {
|
||||
let bytes = [100u32.to_be_bytes(), 200u32.to_be_bytes()].concat();
|
||||
let res = bytes_add(&mut AsyncishRead(&mut &bytes[..]))
|
||||
.wait()
|
||||
.unwrap();
|
||||
assert_eq!(res, 300);
|
||||
}
|
||||
|
||||
// We need a single-threaded executor for this test
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
async fn test_async() {
|
||||
let (mut tx, mut rx) = tokio::net::UnixStream::pair().unwrap();
|
||||
|
||||
let write = async move {
|
||||
tx.write_u32(100).await?;
|
||||
tx.write_u32(200).await?;
|
||||
Ok(())
|
||||
};
|
||||
|
||||
let (res, ()) = tokio::try_join!(bytes_add(&mut rx), write).unwrap();
|
||||
assert_eq!(res, 300);
|
||||
}
|
||||
}
|
||||
@@ -111,7 +111,7 @@ pub trait RemoteStorage: Send + Sync + 'static {
|
||||
}
|
||||
|
||||
pub struct Download {
|
||||
pub download_stream: Pin<Box<dyn io::AsyncRead + Unpin + Send>>,
|
||||
pub download_stream: Pin<Box<dyn io::AsyncRead + Unpin + Send + Sync>>,
|
||||
/// Extra key-value data, associated with the current remote file.
|
||||
pub metadata: Option<StorageMetadata>,
|
||||
}
|
||||
|
||||
@@ -19,8 +19,6 @@ jsonwebtoken.workspace = true
|
||||
nix.workspace = true
|
||||
once_cell.workspace = true
|
||||
routerify.workspace = true
|
||||
rustls.workspace = true
|
||||
rustls-split.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
signal-hook.workspace = true
|
||||
@@ -36,7 +34,6 @@ url.workspace = true
|
||||
uuid = { version = "1.2", features = ["v4", "serde"] }
|
||||
|
||||
metrics.workspace = true
|
||||
pq_proto.workspace = true
|
||||
workspace_hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
@@ -44,7 +41,6 @@ byteorder.workspace = true
|
||||
bytes.workspace = true
|
||||
criterion.workspace = true
|
||||
hex-literal.workspace = true
|
||||
rustls-pemfile.workspace = true
|
||||
tempfile.workspace = true
|
||||
|
||||
[[bench]]
|
||||
|
||||
@@ -13,7 +13,6 @@ pub mod simple_rcu;
|
||||
pub mod vec_map;
|
||||
|
||||
pub mod bin_ser;
|
||||
pub mod postgres_backend;
|
||||
|
||||
// helper functions for creating and fsyncing
|
||||
pub mod crashsafe;
|
||||
@@ -26,9 +25,6 @@ pub mod id;
|
||||
// http endpoint utils
|
||||
pub mod http;
|
||||
|
||||
// socket splitting utils
|
||||
pub mod sock_split;
|
||||
|
||||
// common log initialisation routine
|
||||
pub mod logging;
|
||||
|
||||
|
||||
@@ -1,544 +0,0 @@
|
||||
//! Server-side synchronous Postgres connection, as limited as we need.
|
||||
//! To use, create PostgresBackend and run() it, passing the Handler
|
||||
//! implementation determining how to process the queries. Currently its API
|
||||
//! is rather narrow, but we can extend it once required.
|
||||
|
||||
use crate::sock_split::{BidiStream, ReadStream, WriteStream};
|
||||
use anyhow::Context;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
use std::io::{self, Write};
|
||||
use std::net::{Shutdown, SocketAddr, TcpStream};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tracing::*;
|
||||
|
||||
pub fn is_expected_io_error(e: &io::Error) -> bool {
|
||||
use io::ErrorKind::*;
|
||||
matches!(
|
||||
e.kind(),
|
||||
ConnectionRefused | ConnectionAborted | ConnectionReset
|
||||
)
|
||||
}
|
||||
|
||||
/// An error, occurred during query processing:
|
||||
/// either during the connection ([`ConnectionError`]) or before/after it.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum QueryError {
|
||||
/// The connection was lost while processing the query.
|
||||
#[error(transparent)]
|
||||
Disconnected(#[from] ConnectionError),
|
||||
/// Some other error
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
impl From<io::Error> for QueryError {
|
||||
fn from(e: io::Error) -> Self {
|
||||
Self::Disconnected(ConnectionError::Socket(e))
|
||||
}
|
||||
}
|
||||
|
||||
impl QueryError {
|
||||
pub fn pg_error_code(&self) -> &'static [u8; 5] {
|
||||
match self {
|
||||
Self::Disconnected(_) => b"08006", // connection failure
|
||||
Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Handler {
|
||||
/// Handle single query.
|
||||
/// postgres_backend will issue ReadyForQuery after calling this (this
|
||||
/// might be not what we want after CopyData streaming, but currently we don't
|
||||
/// care).
|
||||
fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
query_string: &str,
|
||||
) -> Result<(), QueryError>;
|
||||
|
||||
/// Called on startup packet receival, allows to process params.
|
||||
///
|
||||
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
|
||||
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
|
||||
/// to override whole init logic in implementations.
|
||||
fn startup(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_sm: &FeStartupPacket,
|
||||
) -> Result<(), QueryError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check auth jwt
|
||||
fn check_auth_jwt(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_jwt_response: &[u8],
|
||||
) -> Result<(), QueryError> {
|
||||
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
|
||||
}
|
||||
|
||||
fn is_shutdown_requested(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// PostgresBackend protocol state.
|
||||
/// XXX: The order of the constructors matters.
|
||||
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
|
||||
pub enum ProtoState {
|
||||
Initialization,
|
||||
Encrypted,
|
||||
Authentication,
|
||||
Established,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum AuthType {
|
||||
Trust,
|
||||
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
|
||||
NeonJWT,
|
||||
}
|
||||
|
||||
impl FromStr for AuthType {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"Trust" => Ok(Self::Trust),
|
||||
"NeonJWT" => Ok(Self::NeonJWT),
|
||||
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for AuthType {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(match self {
|
||||
AuthType::Trust => "Trust",
|
||||
AuthType::NeonJWT => "NeonJWT",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum ProcessMsgResult {
|
||||
Continue,
|
||||
Break,
|
||||
}
|
||||
|
||||
/// Always-writeable sock_split stream.
|
||||
/// May not be readable. See [`PostgresBackend::take_stream_in`]
|
||||
pub enum Stream {
|
||||
Bidirectional(BidiStream),
|
||||
WriteOnly(WriteStream),
|
||||
}
|
||||
|
||||
impl Stream {
|
||||
fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Bidirectional(bidi_stream) => bidi_stream.shutdown(how),
|
||||
Self::WriteOnly(write_stream) => write_stream.shutdown(how),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl io::Write for Stream {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
Self::Bidirectional(bidi_stream) => bidi_stream.write(buf),
|
||||
Self::WriteOnly(write_stream) => write_stream.write(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Bidirectional(bidi_stream) => bidi_stream.flush(),
|
||||
Self::WriteOnly(write_stream) => write_stream.flush(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresBackend {
|
||||
stream: Option<Stream>,
|
||||
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
|
||||
buf_out: BytesMut,
|
||||
|
||||
pub state: ProtoState,
|
||||
|
||||
auth_type: AuthType,
|
||||
|
||||
peer_addr: SocketAddr,
|
||||
pub tls_config: Option<Arc<rustls::ServerConfig>>,
|
||||
}
|
||||
|
||||
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
|
||||
let mut query_string = query_string.to_vec();
|
||||
if let Some(ch) = query_string.last() {
|
||||
if *ch == 0 {
|
||||
query_string.pop();
|
||||
}
|
||||
}
|
||||
query_string
|
||||
}
|
||||
|
||||
// Helper function for socket read loops
|
||||
pub fn is_socket_read_timed_out(error: &anyhow::Error) -> bool {
|
||||
for cause in error.chain() {
|
||||
if let Some(io_error) = cause.downcast_ref::<io::Error>() {
|
||||
if io_error.kind() == std::io::ErrorKind::WouldBlock {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// Cast a byte slice to a string slice, dropping null terminator if there's one.
|
||||
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
|
||||
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
|
||||
std::str::from_utf8(without_null).map_err(|e| e.into())
|
||||
}
|
||||
|
||||
impl PostgresBackend {
|
||||
pub fn new(
|
||||
socket: TcpStream,
|
||||
auth_type: AuthType,
|
||||
tls_config: Option<Arc<rustls::ServerConfig>>,
|
||||
set_read_timeout: bool,
|
||||
) -> io::Result<Self> {
|
||||
let peer_addr = socket.peer_addr()?;
|
||||
if set_read_timeout {
|
||||
socket
|
||||
.set_read_timeout(Some(Duration::from_secs(5)))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
stream: Some(Stream::Bidirectional(BidiStream::from_tcp(socket))),
|
||||
buf_out: BytesMut::with_capacity(10 * 1024),
|
||||
state: ProtoState::Initialization,
|
||||
auth_type,
|
||||
tls_config,
|
||||
peer_addr,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn into_stream(self) -> Stream {
|
||||
self.stream.unwrap()
|
||||
}
|
||||
|
||||
/// Get direct reference (into the Option) to the read stream.
|
||||
fn get_stream_in(&mut self) -> anyhow::Result<&mut BidiStream> {
|
||||
match &mut self.stream {
|
||||
Some(Stream::Bidirectional(stream)) => Ok(stream),
|
||||
_ => anyhow::bail!("reader taken"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_peer_addr(&self) -> &SocketAddr {
|
||||
&self.peer_addr
|
||||
}
|
||||
|
||||
pub fn take_stream_in(&mut self) -> Option<ReadStream> {
|
||||
let stream = self.stream.take();
|
||||
match stream {
|
||||
Some(Stream::Bidirectional(bidi_stream)) => {
|
||||
let (read, write) = bidi_stream.split();
|
||||
self.stream = Some(Stream::WriteOnly(write));
|
||||
Some(read)
|
||||
}
|
||||
stream => {
|
||||
self.stream = stream;
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read full message or return None if connection is closed.
|
||||
pub fn read_message(&mut self) -> Result<Option<FeMessage>, QueryError> {
|
||||
let (state, stream) = (self.state, self.get_stream_in()?);
|
||||
|
||||
use ProtoState::*;
|
||||
match state {
|
||||
Initialization | Encrypted => FeStartupPacket::read(stream),
|
||||
Authentication | Established => FeMessage::read(stream),
|
||||
}
|
||||
.map_err(QueryError::from)
|
||||
}
|
||||
|
||||
/// Write message into internal output buffer.
|
||||
pub fn write_message_noflush(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
|
||||
BeMessage::write(&mut self.buf_out, message)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Flush output buffer into the socket.
|
||||
pub fn flush(&mut self) -> io::Result<&mut Self> {
|
||||
let stream = self.stream.as_mut().unwrap();
|
||||
stream.write_all(&self.buf_out)?;
|
||||
self.buf_out.clear();
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Write message into internal buffer and flush it.
|
||||
pub fn write_message(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
|
||||
self.write_message_noflush(message)?;
|
||||
self.flush()
|
||||
}
|
||||
|
||||
// Wrapper for run_message_loop() that shuts down socket when we are done
|
||||
pub fn run(mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
|
||||
let ret = self.run_message_loop(handler);
|
||||
if let Some(stream) = self.stream.as_mut() {
|
||||
let _ = stream.shutdown(Shutdown::Both);
|
||||
}
|
||||
ret
|
||||
}
|
||||
|
||||
fn run_message_loop(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
|
||||
trace!("postgres backend to {:?} started", self.peer_addr);
|
||||
|
||||
let mut unnamed_query_string = Bytes::new();
|
||||
|
||||
while !handler.is_shutdown_requested() {
|
||||
match self.read_message() {
|
||||
Ok(message) => {
|
||||
if let Some(msg) = message {
|
||||
trace!("got message {msg:?}");
|
||||
|
||||
match self.process_message(handler, msg, &mut unnamed_query_string)? {
|
||||
ProcessMsgResult::Continue => continue,
|
||||
ProcessMsgResult::Break => break,
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
if let QueryError::Other(e) = &e {
|
||||
if is_socket_read_timed_out(e) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trace!("postgres backend to {:?} exited", self.peer_addr);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn start_tls(&mut self) -> anyhow::Result<()> {
|
||||
match self.stream.take() {
|
||||
Some(Stream::Bidirectional(bidi_stream)) => {
|
||||
let conn = rustls::ServerConnection::new(self.tls_config.clone().unwrap())?;
|
||||
self.stream = Some(Stream::Bidirectional(bidi_stream.start_tls(conn)?));
|
||||
Ok(())
|
||||
}
|
||||
stream => {
|
||||
self.stream = stream;
|
||||
anyhow::bail!("can't start TLs without bidi stream");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn process_message(
|
||||
&mut self,
|
||||
handler: &mut impl Handler,
|
||||
msg: FeMessage,
|
||||
unnamed_query_string: &mut Bytes,
|
||||
) -> Result<ProcessMsgResult, QueryError> {
|
||||
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
|
||||
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
|
||||
if self.state < ProtoState::Established
|
||||
&& !matches!(
|
||||
msg,
|
||||
FeMessage::PasswordMessage(_) | FeMessage::StartupPacket(_)
|
||||
)
|
||||
{
|
||||
return Err(QueryError::Other(anyhow::anyhow!("protocol violation")));
|
||||
}
|
||||
|
||||
let have_tls = self.tls_config.is_some();
|
||||
match msg {
|
||||
FeMessage::StartupPacket(m) => {
|
||||
trace!("got startup message {m:?}");
|
||||
|
||||
match m {
|
||||
FeStartupPacket::SslRequest => {
|
||||
debug!("SSL requested");
|
||||
|
||||
self.write_message(&BeMessage::EncryptionResponse(have_tls))?;
|
||||
if have_tls {
|
||||
self.start_tls()?;
|
||||
self.state = ProtoState::Encrypted;
|
||||
}
|
||||
}
|
||||
FeStartupPacket::GssEncRequest => {
|
||||
debug!("GSS requested");
|
||||
self.write_message(&BeMessage::EncryptionResponse(false))?;
|
||||
}
|
||||
FeStartupPacket::StartupMessage { .. } => {
|
||||
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
|
||||
self.write_message(&BeMessage::ErrorResponse(
|
||||
"must connect with TLS",
|
||||
None,
|
||||
))?;
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"client did not connect with TLS"
|
||||
)));
|
||||
}
|
||||
|
||||
// NB: startup() may change self.auth_type -- we are using that in proxy code
|
||||
// to bypass auth for new users.
|
||||
handler.startup(self, &m)?;
|
||||
|
||||
match self.auth_type {
|
||||
AuthType::Trust => {
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?
|
||||
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
|
||||
// The async python driver requires a valid server_version
|
||||
.write_message_noflush(&BeMessage::server_version("14.1"))?
|
||||
.write_message(&BeMessage::ReadyForQuery)?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
AuthType::NeonJWT => {
|
||||
self.write_message(&BeMessage::AuthenticationCleartextPassword)?;
|
||||
self.state = ProtoState::Authentication;
|
||||
}
|
||||
}
|
||||
}
|
||||
FeStartupPacket::CancelRequest { .. } => {
|
||||
return Ok(ProcessMsgResult::Break);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
FeMessage::PasswordMessage(m) => {
|
||||
trace!("got password message '{:?}'", m);
|
||||
|
||||
assert!(self.state == ProtoState::Authentication);
|
||||
|
||||
match self.auth_type {
|
||||
AuthType::Trust => unreachable!(),
|
||||
AuthType::NeonJWT => {
|
||||
let (_, jwt_response) = m.split_last().context("protocol violation")?;
|
||||
|
||||
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
|
||||
self.write_message(&BeMessage::ErrorResponse(
|
||||
&e.to_string(),
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
self.write_message_noflush(&BeMessage::AuthenticationOk)?
|
||||
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
|
||||
.write_message(&BeMessage::ReadyForQuery)?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
|
||||
FeMessage::Query(body) => {
|
||||
// remove null terminator
|
||||
let query_string = cstr_to_str(&body)?;
|
||||
|
||||
trace!("got query {query_string:?}");
|
||||
if let Err(e) = handler.process_query(self, query_string) {
|
||||
log_query_error(query_string, &e);
|
||||
let short_error = short_error(&e);
|
||||
self.write_message_noflush(&BeMessage::ErrorResponse(
|
||||
&short_error,
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
}
|
||||
self.write_message(&BeMessage::ReadyForQuery)?;
|
||||
}
|
||||
|
||||
FeMessage::Parse(m) => {
|
||||
*unnamed_query_string = m.query_string;
|
||||
self.write_message(&BeMessage::ParseComplete)?;
|
||||
}
|
||||
|
||||
FeMessage::Describe(_) => {
|
||||
self.write_message_noflush(&BeMessage::ParameterDescription)?
|
||||
.write_message(&BeMessage::NoData)?;
|
||||
}
|
||||
|
||||
FeMessage::Bind(_) => {
|
||||
self.write_message(&BeMessage::BindComplete)?;
|
||||
}
|
||||
|
||||
FeMessage::Close(_) => {
|
||||
self.write_message(&BeMessage::CloseComplete)?;
|
||||
}
|
||||
|
||||
FeMessage::Execute(_) => {
|
||||
let query_string = cstr_to_str(unnamed_query_string)?;
|
||||
trace!("got execute {query_string:?}");
|
||||
if let Err(e) = handler.process_query(self, query_string) {
|
||||
log_query_error(query_string, &e);
|
||||
self.write_message(&BeMessage::ErrorResponse(
|
||||
&e.to_string(),
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
}
|
||||
// NOTE there is no ReadyForQuery message. This handler is used
|
||||
// for basebackup and it uses CopyOut which doesn't require
|
||||
// ReadyForQuery message and backend just switches back to
|
||||
// processing mode after sending CopyDone or ErrorResponse.
|
||||
}
|
||||
|
||||
FeMessage::Sync => {
|
||||
self.write_message(&BeMessage::ReadyForQuery)?;
|
||||
}
|
||||
|
||||
FeMessage::Terminate => {
|
||||
return Ok(ProcessMsgResult::Break);
|
||||
}
|
||||
|
||||
// We prefer explicit pattern matching to wildcards, because
|
||||
// this helps us spot the places where new variants are missing
|
||||
FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"unexpected message type: {msg:?}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ProcessMsgResult::Continue)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn short_error(e: &QueryError) -> String {
|
||||
match e {
|
||||
QueryError::Disconnected(connection_error) => connection_error.to_string(),
|
||||
QueryError::Other(e) => format!("{e:#}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn log_query_error(query: &str, e: &QueryError) {
|
||||
match e {
|
||||
QueryError::Disconnected(ConnectionError::Socket(io_error)) => {
|
||||
if is_expected_io_error(io_error) {
|
||||
info!("query handler for '{query}' failed with expected io error: {io_error}");
|
||||
} else {
|
||||
error!("query handler for '{query}' failed with io error: {io_error}");
|
||||
}
|
||||
}
|
||||
QueryError::Disconnected(other_connection_error) => {
|
||||
error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
|
||||
}
|
||||
QueryError::Other(e) => {
|
||||
error!("query handler for '{query}' failed: {e:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,206 +0,0 @@
|
||||
use std::{
|
||||
io::{self, BufReader, Write},
|
||||
net::{Shutdown, TcpStream},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use rustls::Connection;
|
||||
|
||||
/// Wrapper supporting reads of a shared TcpStream.
|
||||
pub struct ArcTcpRead(Arc<TcpStream>);
|
||||
|
||||
impl io::Read for ArcTcpRead {
|
||||
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
(&*self.0).read(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for ArcTcpRead {
|
||||
type Target = TcpStream;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.0.deref()
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper around a TCP Stream supporting buffered reads.
|
||||
pub struct BufStream(BufReader<ArcTcpRead>);
|
||||
|
||||
impl io::Read for BufStream {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
self.0.read(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl io::Write for BufStream {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
self.get_ref().write(buf)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
self.get_ref().flush()
|
||||
}
|
||||
}
|
||||
|
||||
impl BufStream {
|
||||
/// Unwrap into the internal BufReader.
|
||||
fn into_reader(self) -> BufReader<ArcTcpRead> {
|
||||
self.0
|
||||
}
|
||||
|
||||
/// Returns a reference to the underlying TcpStream.
|
||||
fn get_ref(&self) -> &TcpStream {
|
||||
&self.0.get_ref().0
|
||||
}
|
||||
}
|
||||
|
||||
pub enum ReadStream {
|
||||
Tcp(BufReader<ArcTcpRead>),
|
||||
Tls(rustls_split::ReadHalf),
|
||||
}
|
||||
|
||||
impl io::Read for ReadStream {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
Self::Tcp(reader) => reader.read(buf),
|
||||
Self::Tls(read_half) => read_half.read(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReadStream {
|
||||
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.get_ref().shutdown(how),
|
||||
Self::Tls(write_half) => write_half.shutdown(how),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum WriteStream {
|
||||
Tcp(Arc<TcpStream>),
|
||||
Tls(rustls_split::WriteHalf),
|
||||
}
|
||||
|
||||
impl WriteStream {
|
||||
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.shutdown(how),
|
||||
Self::Tls(write_half) => write_half.shutdown(how),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl io::Write for WriteStream {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.as_ref().write(buf),
|
||||
Self::Tls(write_half) => write_half.write(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.as_ref().flush(),
|
||||
Self::Tls(write_half) => write_half.flush(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type TlsStream<T> = rustls::StreamOwned<rustls::ServerConnection, T>;
|
||||
|
||||
pub enum BidiStream {
|
||||
Tcp(BufStream),
|
||||
/// This variant is boxed, because [`rustls::ServerConnection`] is quite larger than [`BufStream`].
|
||||
Tls(Box<TlsStream<BufStream>>),
|
||||
}
|
||||
|
||||
impl BidiStream {
|
||||
pub fn from_tcp(stream: TcpStream) -> Self {
|
||||
Self::Tcp(BufStream(BufReader::new(ArcTcpRead(Arc::new(stream)))))
|
||||
}
|
||||
|
||||
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.get_ref().shutdown(how),
|
||||
Self::Tls(tls_boxed) => {
|
||||
if how == Shutdown::Read {
|
||||
tls_boxed.sock.get_ref().shutdown(how)
|
||||
} else {
|
||||
tls_boxed.conn.send_close_notify();
|
||||
let res = tls_boxed.flush();
|
||||
tls_boxed.sock.get_ref().shutdown(how)?;
|
||||
res
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Split the bi-directional stream into two owned read and write halves.
|
||||
pub fn split(self) -> (ReadStream, WriteStream) {
|
||||
match self {
|
||||
Self::Tcp(stream) => {
|
||||
let reader = stream.into_reader();
|
||||
let stream: Arc<TcpStream> = reader.get_ref().0.clone();
|
||||
|
||||
(ReadStream::Tcp(reader), WriteStream::Tcp(stream))
|
||||
}
|
||||
Self::Tls(tls_boxed) => {
|
||||
let reader = tls_boxed.sock.into_reader();
|
||||
let buffer_data = reader.buffer().to_owned();
|
||||
let read_buf_cfg = rustls_split::BufCfg::with_data(buffer_data, 8192);
|
||||
let write_buf_cfg = rustls_split::BufCfg::with_capacity(8192);
|
||||
|
||||
// TODO would be nice to avoid the Arc here
|
||||
let socket = Arc::try_unwrap(reader.into_inner().0).unwrap();
|
||||
|
||||
let (read_half, write_half) = rustls_split::split(
|
||||
socket,
|
||||
Connection::Server(tls_boxed.conn),
|
||||
read_buf_cfg,
|
||||
write_buf_cfg,
|
||||
);
|
||||
(ReadStream::Tls(read_half), WriteStream::Tls(write_half))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start_tls(self, mut conn: rustls::ServerConnection) -> io::Result<Self> {
|
||||
match self {
|
||||
Self::Tcp(mut stream) => {
|
||||
conn.complete_io(&mut stream)?;
|
||||
assert!(!conn.is_handshaking());
|
||||
Ok(Self::Tls(Box::new(TlsStream::new(conn, stream))))
|
||||
}
|
||||
Self::Tls { .. } => Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"TLS is already started on this stream",
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl io::Read for BidiStream {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.read(buf),
|
||||
Self::Tls(tls_boxed) => tls_boxed.read(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl io::Write for BidiStream {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.write(buf),
|
||||
Self::Tls(tls_boxed) => tls_boxed.write(buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.flush(),
|
||||
Self::Tls(tls_boxed) => tls_boxed.flush(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,238 +0,0 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
io::{Cursor, Read, Write},
|
||||
net::{TcpListener, TcpStream},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
use utils::{
|
||||
postgres_backend::QueryError,
|
||||
postgres_backend::{AuthType, Handler, PostgresBackend},
|
||||
};
|
||||
|
||||
fn make_tcp_pair() -> (TcpStream, TcpStream) {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let client_stream = TcpStream::connect(addr).unwrap();
|
||||
let (server_stream, _) = listener.accept().unwrap();
|
||||
(server_stream, client_stream)
|
||||
}
|
||||
|
||||
static KEY: Lazy<rustls::PrivateKey> = Lazy::new(|| {
|
||||
let mut cursor = Cursor::new(include_bytes!("key.pem"));
|
||||
rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone())
|
||||
});
|
||||
|
||||
static CERT: Lazy<rustls::Certificate> = Lazy::new(|| {
|
||||
let mut cursor = Cursor::new(include_bytes!("cert.pem"));
|
||||
rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone())
|
||||
});
|
||||
|
||||
#[test]
|
||||
// [false-positive](https://github.com/rust-lang/rust-clippy/issues/9274),
|
||||
// we resize the vector so doing some modifications after all
|
||||
#[allow(clippy::read_zero_byte_vec)]
|
||||
fn ssl() {
|
||||
let (mut client_sock, server_sock) = make_tcp_pair();
|
||||
|
||||
const QUERY: &str = "hello world";
|
||||
|
||||
let client_jh = std::thread::spawn(move || {
|
||||
// SSLRequest
|
||||
client_sock.write_u32::<BigEndian>(8).unwrap();
|
||||
client_sock.write_u32::<BigEndian>(80877103).unwrap();
|
||||
|
||||
let ssl_response = client_sock.read_u8().unwrap();
|
||||
assert_eq!(b'S', ssl_response);
|
||||
|
||||
let cfg = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates({
|
||||
let mut store = rustls::RootCertStore::empty();
|
||||
store.add(&CERT).unwrap();
|
||||
store
|
||||
})
|
||||
.with_no_client_auth();
|
||||
let client_config = Arc::new(cfg);
|
||||
|
||||
let dns_name = "localhost".try_into().unwrap();
|
||||
let mut conn = rustls::ClientConnection::new(client_config, dns_name).unwrap();
|
||||
|
||||
conn.complete_io(&mut client_sock).unwrap();
|
||||
assert!(!conn.is_handshaking());
|
||||
|
||||
let mut stream = rustls::Stream::new(&mut conn, &mut client_sock);
|
||||
|
||||
// StartupMessage
|
||||
stream.write_u32::<BigEndian>(9).unwrap();
|
||||
stream.write_u32::<BigEndian>(196608).unwrap();
|
||||
stream.write_u8(0).unwrap();
|
||||
stream.flush().unwrap();
|
||||
|
||||
// wait for ReadyForQuery
|
||||
let mut msg_buf = Vec::new();
|
||||
loop {
|
||||
let msg = stream.read_u8().unwrap();
|
||||
let size = stream.read_u32::<BigEndian>().unwrap() - 4;
|
||||
msg_buf.resize(size as usize, 0);
|
||||
stream.read_exact(&mut msg_buf).unwrap();
|
||||
|
||||
if msg == b'Z' {
|
||||
// ReadyForQuery
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Query
|
||||
stream.write_u8(b'Q').unwrap();
|
||||
stream
|
||||
.write_u32::<BigEndian>(4u32 + QUERY.len() as u32)
|
||||
.unwrap();
|
||||
stream.write_all(QUERY.as_ref()).unwrap();
|
||||
stream.flush().unwrap();
|
||||
|
||||
// ReadyForQuery
|
||||
let msg = stream.read_u8().unwrap();
|
||||
assert_eq!(msg, b'Z');
|
||||
});
|
||||
|
||||
struct TestHandler {
|
||||
got_query: bool,
|
||||
}
|
||||
impl Handler for TestHandler {
|
||||
fn process_query(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
query_string: &str,
|
||||
) -> Result<(), QueryError> {
|
||||
self.got_query = query_string == QUERY;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
let mut handler = TestHandler { got_query: false };
|
||||
|
||||
let cfg = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(vec![CERT.clone()], KEY.clone())
|
||||
.unwrap();
|
||||
let tls_config = Some(Arc::new(cfg));
|
||||
|
||||
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config, true).unwrap();
|
||||
pgb.run(&mut handler).unwrap();
|
||||
assert!(handler.got_query);
|
||||
|
||||
client_jh.join().unwrap();
|
||||
|
||||
// TODO consider shutdown behavior
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_ssl() {
|
||||
let (mut client_sock, server_sock) = make_tcp_pair();
|
||||
|
||||
let client_jh = std::thread::spawn(move || {
|
||||
let mut buf = BytesMut::new();
|
||||
|
||||
// SSLRequest
|
||||
buf.put_u32(8);
|
||||
buf.put_u32(80877103);
|
||||
client_sock.write_all(&buf).unwrap();
|
||||
buf.clear();
|
||||
|
||||
let ssl_response = client_sock.read_u8().unwrap();
|
||||
assert_eq!(b'N', ssl_response);
|
||||
});
|
||||
|
||||
struct TestHandler;
|
||||
|
||||
impl Handler for TestHandler {
|
||||
fn process_query(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_query_string: &str,
|
||||
) -> Result<(), QueryError> {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
|
||||
let mut handler = TestHandler;
|
||||
|
||||
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, None, true).unwrap();
|
||||
pgb.run(&mut handler).unwrap();
|
||||
|
||||
client_jh.join().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn server_forces_ssl() {
|
||||
let (mut client_sock, server_sock) = make_tcp_pair();
|
||||
|
||||
let client_jh = std::thread::spawn(move || {
|
||||
// StartupMessage
|
||||
client_sock.write_u32::<BigEndian>(9).unwrap();
|
||||
client_sock.write_u32::<BigEndian>(196608).unwrap();
|
||||
client_sock.write_u8(0).unwrap();
|
||||
client_sock.flush().unwrap();
|
||||
|
||||
// ErrorResponse
|
||||
assert_eq!(client_sock.read_u8().unwrap(), b'E');
|
||||
let len = client_sock.read_u32::<BigEndian>().unwrap() - 4;
|
||||
|
||||
let mut body = vec![0; len as usize];
|
||||
client_sock.read_exact(&mut body).unwrap();
|
||||
let mut body = Bytes::from(body);
|
||||
|
||||
let mut errors = HashMap::new();
|
||||
loop {
|
||||
let field_type = body.get_u8();
|
||||
if field_type == 0u8 {
|
||||
break;
|
||||
}
|
||||
|
||||
let end_idx = body.iter().position(|&b| b == 0u8).unwrap();
|
||||
let mut value = body.split_to(end_idx + 1);
|
||||
assert_eq!(value[end_idx], 0u8);
|
||||
value.truncate(end_idx);
|
||||
let old = errors.insert(field_type, value);
|
||||
assert!(old.is_none());
|
||||
}
|
||||
|
||||
assert!(!body.has_remaining());
|
||||
|
||||
assert_eq!("must connect with TLS", errors.get(&b'M').unwrap());
|
||||
|
||||
// TODO read failure
|
||||
});
|
||||
|
||||
struct TestHandler;
|
||||
impl Handler for TestHandler {
|
||||
fn process_query(
|
||||
&mut self,
|
||||
_pgb: &mut PostgresBackend,
|
||||
_query_string: &str,
|
||||
) -> Result<(), QueryError> {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
let mut handler = TestHandler;
|
||||
|
||||
let cfg = rustls::ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(vec![CERT.clone()], KEY.clone())
|
||||
.unwrap();
|
||||
let tls_config = Some(Arc::new(cfg));
|
||||
|
||||
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config, true).unwrap();
|
||||
let res = pgb.run(&mut handler).unwrap_err();
|
||||
assert_eq!("client did not connect with TLS", format!("{}", res));
|
||||
|
||||
client_jh.join().unwrap();
|
||||
|
||||
// TODO consider shutdown behavior
|
||||
}
|
||||
@@ -23,11 +23,10 @@ use pageserver::{
|
||||
tenant::mgr,
|
||||
virtual_file,
|
||||
};
|
||||
use postgres_backend::AuthType;
|
||||
use utils::{
|
||||
auth::JwtAuth,
|
||||
logging,
|
||||
postgres_backend::AuthType,
|
||||
project_git_version,
|
||||
logging, project_git_version,
|
||||
sentry_init::init_sentry,
|
||||
signals::{self, Signal},
|
||||
tcp_listener,
|
||||
|
||||
@@ -21,10 +21,10 @@ use std::time::Duration;
|
||||
use toml_edit;
|
||||
use toml_edit::{Document, Item};
|
||||
|
||||
use postgres_backend::AuthType;
|
||||
use utils::{
|
||||
id::{NodeId, TenantId, TimelineId},
|
||||
logging::LogFormat,
|
||||
postgres_backend::AuthType,
|
||||
};
|
||||
|
||||
use crate::tenant::config::TenantConf;
|
||||
|
||||
@@ -20,7 +20,7 @@ use pageserver_api::models::{
|
||||
PagestreamFeMessage, PagestreamGetPageRequest, PagestreamGetPageResponse,
|
||||
PagestreamNblocksRequest, PagestreamNblocksResponse,
|
||||
};
|
||||
use postgres_backend::{self, is_expected_io_error, PostgresBackend, QueryError};
|
||||
use postgres_backend::{self, is_expected_io_error, AuthType, PostgresBackend, QueryError};
|
||||
use pq_proto::ConnectionError;
|
||||
use pq_proto::FeStartupPacket;
|
||||
use pq_proto::{BeMessage, FeMessage, RowDescriptor};
|
||||
@@ -36,7 +36,6 @@ use utils::{
|
||||
auth::{Claims, JwtAuth, Scope},
|
||||
id::{TenantId, TimelineId},
|
||||
lsn::Lsn,
|
||||
postgres_backend::AuthType,
|
||||
simple_rcu::RcuReadGuard,
|
||||
};
|
||||
|
||||
@@ -68,7 +67,7 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
|
||||
Err(QueryError::Other(anyhow::anyhow!(msg)))
|
||||
}
|
||||
|
||||
msg = pgb.read_message() => { msg }
|
||||
msg = pgb.read_message() => { msg.map_err(QueryError::from)}
|
||||
};
|
||||
|
||||
match msg {
|
||||
@@ -79,14 +78,16 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
|
||||
FeMessage::Sync => continue,
|
||||
FeMessage::Terminate => {
|
||||
let msg = "client terminated connection with Terminate message during COPY";
|
||||
let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
||||
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?;
|
||||
let query_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
||||
// error can't happen here, ErrorResponse serialization should be always ok
|
||||
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
|
||||
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
|
||||
break;
|
||||
}
|
||||
m => {
|
||||
let msg = format!("unexpected message {m:?}");
|
||||
pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None))?;
|
||||
// error can't happen here, ErrorResponse serialization should be always ok
|
||||
pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None)).map_err(|e| e.into_io_error())?;
|
||||
Err(io::Error::new(io::ErrorKind::Other, msg))?;
|
||||
break;
|
||||
}
|
||||
@@ -96,8 +97,9 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
|
||||
}
|
||||
Ok(None) => {
|
||||
let msg = "client closed connection during COPY";
|
||||
let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
||||
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?;
|
||||
let query_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
||||
// error can't happen here, ErrorResponse serialization should be always ok
|
||||
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
|
||||
pgb.flush().await?;
|
||||
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
|
||||
}
|
||||
@@ -105,7 +107,7 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
|
||||
Err(io_error)?;
|
||||
}
|
||||
Err(other) => {
|
||||
Err(io::Error::new(io::ErrorKind::Other, other))?;
|
||||
Err(io::Error::new(io::ErrorKind::Other, other.to_string()))?;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -435,8 +435,8 @@ fn ignore_expected_errors(pg_error: postgres::Error) -> anyhow::Result<postgres:
|
||||
{
|
||||
return Ok(pg_error);
|
||||
} else if let Some(db_error) = pg_error.as_db_error() {
|
||||
if db_error.code() == &SqlState::CONNECTION_FAILURE
|
||||
&& db_error.message().contains("end streaming")
|
||||
if db_error.code() == &SqlState::SUCCESSFUL_COMPLETION
|
||||
&& db_error.message().contains("ending streaming")
|
||||
{
|
||||
return Ok(pg_error);
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ once_cell.workspace = true
|
||||
opentelemetry.workspace = true
|
||||
parking_lot.workspace = true
|
||||
pin-project-lite.workspace = true
|
||||
postgres_backend.workspace = true
|
||||
pq_proto.workspace = true
|
||||
prometheus.workspace = true
|
||||
rand.workspace = true
|
||||
|
||||
@@ -4,13 +4,11 @@ use crate::{
|
||||
};
|
||||
use anyhow::Context;
|
||||
use once_cell::sync::Lazy;
|
||||
use postgres_backend::{self, AuthType, PostgresBackend, QueryError};
|
||||
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
|
||||
use std::{net::TcpStream, thread};
|
||||
use std::future;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tracing::{error, info, info_span};
|
||||
use utils::{
|
||||
postgres_backend::QueryError,
|
||||
postgres_backend::{self, AuthType, PostgresBackend},
|
||||
};
|
||||
|
||||
static CPLANE_WAITERS: Lazy<Waiters<ComputeReady>> = Lazy::new(Default::default);
|
||||
|
||||
@@ -33,7 +31,7 @@ pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::N
|
||||
|
||||
/// Console management API listener task.
|
||||
/// It spawns console response handlers needed for the link auth.
|
||||
pub async fn task_main(listener: tokio::net::TcpListener) -> anyhow::Result<()> {
|
||||
pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("mgmt has shut down");
|
||||
}
|
||||
@@ -42,18 +40,12 @@ pub async fn task_main(listener: tokio::net::TcpListener) -> anyhow::Result<()>
|
||||
let (socket, peer_addr) = listener.accept().await?;
|
||||
info!("accepted connection from {peer_addr}");
|
||||
|
||||
let socket = socket.into_std()?;
|
||||
socket
|
||||
.set_nodelay(true)
|
||||
.context("failed to set client socket option")?;
|
||||
socket
|
||||
.set_nonblocking(false)
|
||||
.context("failed to set client socket option")?;
|
||||
|
||||
// TODO: replace with async tasks.
|
||||
thread::spawn(move || {
|
||||
let tid = std::thread::current().id();
|
||||
let span = info_span!("mgmt", thread = format_args!("{tid:?}"));
|
||||
tokio::task::spawn(async move {
|
||||
let span = info_span!("mgmt", peer = %peer_addr);
|
||||
let _enter = span.enter();
|
||||
|
||||
info!("started a new console management API thread");
|
||||
@@ -61,16 +53,16 @@ pub async fn task_main(listener: tokio::net::TcpListener) -> anyhow::Result<()>
|
||||
info!("console management API thread is about to finish");
|
||||
}
|
||||
|
||||
if let Err(e) = handle_connection(socket) {
|
||||
if let Err(e) = handle_connection(socket).await {
|
||||
error!("thread failed with an error: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
|
||||
let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None, true)?;
|
||||
pgbackend.run(&mut MgmtHandler)
|
||||
async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
|
||||
let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None)?;
|
||||
pgbackend.run(&mut MgmtHandler, future::pending::<()>).await
|
||||
}
|
||||
|
||||
/// A message received by `mgmt` when a compute node is ready.
|
||||
@@ -78,16 +70,21 @@ pub type ComputeReady = Result<DatabaseInfo, String>;
|
||||
|
||||
// TODO: replace with an http-based protocol.
|
||||
struct MgmtHandler;
|
||||
#[async_trait::async_trait]
|
||||
impl postgres_backend::Handler for MgmtHandler {
|
||||
fn process_query(&mut self, pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> {
|
||||
try_process_query(pgb, query).map_err(|e| {
|
||||
async fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
query: &str,
|
||||
) -> Result<(), QueryError> {
|
||||
try_process_query(pgb, query).await.map_err(|e| {
|
||||
error!("failed to process response: {e:?}");
|
||||
e
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> {
|
||||
async fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> {
|
||||
let resp: KickSession = serde_json::from_str(query).context("Failed to parse query as json")?;
|
||||
|
||||
let span = info_span!("event", session_id = resp.session_id);
|
||||
@@ -98,11 +95,11 @@ fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), Query
|
||||
Ok(()) => {
|
||||
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
|
||||
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("failed to deliver response to per-client task");
|
||||
pgb.write_message(&BeMessage::ErrorResponse(&e.to_string(), None))?;
|
||||
pgb.write_message_noflush(&BeMessage::ErrorResponse(&e.to_string(), None))?;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
|
||||
pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
|
||||
// TODO: `FeStartupPacket::read_fut` should return `FeStartupPacket`
|
||||
let msg = FeStartupPacket::read_fut(&mut self.stream)
|
||||
let msg = FeStartupPacket::read(&mut self.stream)
|
||||
.await
|
||||
.map_err(ConnectionError::into_io_error)?
|
||||
.ok_or_else(err_connection)?;
|
||||
@@ -73,7 +73,7 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||
}
|
||||
|
||||
async fn read_message(&mut self) -> io::Result<FeMessage> {
|
||||
FeMessage::read_fut(&mut self.stream)
|
||||
FeMessage::read(&mut self.stream)
|
||||
.await
|
||||
.map_err(ConnectionError::into_io_error)?
|
||||
.ok_or_else(err_connection)
|
||||
|
||||
@@ -11,12 +11,18 @@
|
||||
|
||||
# Not every feature is supported in macOS builds. Avoid running regular linting
|
||||
# script that checks every feature.
|
||||
#
|
||||
# manual-range-contains wants
|
||||
# !(8..=MAX_STARTUP_PACKET_LENGTH).contains(&len)
|
||||
# instead of
|
||||
# len < 4 || len > MAX_STARTUP_PACKET_LENGTH
|
||||
# , let's disagree.
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
# no extra features to test currently, add more here when needed
|
||||
cargo clippy --locked --all --all-targets --features testing -- -A unknown_lints -D warnings
|
||||
cargo clippy --locked --all --all-targets --features testing -- -A unknown_lints -A clippy::manual-range-contains -D warnings
|
||||
else
|
||||
# * `-A unknown_lints` – do not warn about unknown lint suppressions
|
||||
# that people with newer toolchains might use
|
||||
# * `-D warnings` - fail on any warnings (`cargo` returns non-zero exit status)
|
||||
cargo clippy --locked --all --all-targets --all-features -- -A unknown_lints -D warnings
|
||||
cargo clippy --locked --all --all-targets --all-features -- -A unknown_lints -A clippy::manual-range-contains -D warnings
|
||||
fi
|
||||
|
||||
@@ -36,6 +36,7 @@ toml_edit.workspace = true
|
||||
tracing.workspace = true
|
||||
url.workspace = true
|
||||
metrics.workspace = true
|
||||
postgres_backend.workspace = true
|
||||
postgres_ffi.workspace = true
|
||||
pq_proto.workspace = true
|
||||
remote_storage.workspace = true
|
||||
|
||||
@@ -236,7 +236,7 @@ fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> {
|
||||
|
||||
let conf_cloned = conf.clone();
|
||||
let safekeeper_thread = thread::Builder::new()
|
||||
.name("safekeeper thread".into())
|
||||
.name("WAL service thread".into())
|
||||
.spawn(|| wal_service::thread_main(conf_cloned, pg_listener))
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -1,27 +1,23 @@
|
||||
//! Part of Safekeeper pretending to be Postgres, i.e. handling Postgres
|
||||
//! protocol commands.
|
||||
|
||||
use anyhow::Context;
|
||||
use std::str;
|
||||
use tracing::{info, info_span, Instrument};
|
||||
|
||||
use crate::auth::check_permission;
|
||||
use crate::json_ctrl::{handle_json_ctrl, AppendLogicalMessage};
|
||||
use crate::receive_wal::ReceiveWalConn;
|
||||
|
||||
use crate::send_wal::ReplicationConn;
|
||||
|
||||
use crate::{GlobalTimelines, SafeKeeperConf};
|
||||
use anyhow::Context;
|
||||
|
||||
use postgres_backend::QueryError;
|
||||
use postgres_backend::{self, PostgresBackend};
|
||||
use postgres_ffi::PG_TLI;
|
||||
use regex::Regex;
|
||||
|
||||
use pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID, TEXT_OID};
|
||||
use std::str;
|
||||
use tracing::info;
|
||||
use regex::Regex;
|
||||
use utils::auth::{Claims, Scope};
|
||||
use utils::postgres_backend::QueryError;
|
||||
use utils::{
|
||||
id::{TenantId, TenantTimelineId, TimelineId},
|
||||
lsn::Lsn,
|
||||
postgres_backend::{self, PostgresBackend},
|
||||
};
|
||||
|
||||
/// Safekeeper handler of postgres commands
|
||||
@@ -53,7 +49,7 @@ fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
|
||||
let start_lsn = caps
|
||||
.next()
|
||||
.map(|cap| cap[1].parse::<Lsn>())
|
||||
.context("failed to parse start LSN from START_REPLICATION command")??;
|
||||
.context("parse start LSN from START_REPLICATION command")??;
|
||||
Ok(SafekeeperPostgresCommand::StartReplication { start_lsn })
|
||||
} else if cmd.starts_with("IDENTIFY_SYSTEM") {
|
||||
Ok(SafekeeperPostgresCommand::IdentifySystem)
|
||||
@@ -67,6 +63,7 @@ fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl postgres_backend::Handler for SafekeeperPostgresHandler {
|
||||
// tenant_id and timeline_id are passed in connection string params
|
||||
fn startup(
|
||||
@@ -137,7 +134,7 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn process_query(
|
||||
async fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
query_string: &str,
|
||||
@@ -147,9 +144,10 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
|
||||
.starts_with("set datestyle to ")
|
||||
{
|
||||
// important for debug because psycopg2 executes "SET datestyle TO 'ISO'" on connect
|
||||
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let cmd = parse_cmd(query_string)?;
|
||||
|
||||
info!(
|
||||
@@ -161,26 +159,23 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
|
||||
let timeline_id = self.timeline_id.context("timelineid is required")?;
|
||||
self.check_permission(Some(tenant_id))?;
|
||||
self.ttid = TenantTimelineId::new(tenant_id, timeline_id);
|
||||
let span_ttid = self.ttid; // satisfy borrow checker
|
||||
|
||||
let res = match cmd {
|
||||
SafekeeperPostgresCommand::StartWalPush => ReceiveWalConn::new(pgb).run(self),
|
||||
match cmd {
|
||||
SafekeeperPostgresCommand::StartWalPush => {
|
||||
self.handle_start_wal_push(pgb)
|
||||
.instrument(info_span!("WAL receiver", ttid = %span_ttid))
|
||||
.await
|
||||
}
|
||||
SafekeeperPostgresCommand::StartReplication { start_lsn } => {
|
||||
ReplicationConn::new(pgb).run(self, pgb, start_lsn)
|
||||
self.handle_start_replication(pgb, start_lsn)
|
||||
.instrument(info_span!("WAL sender", ttid = %span_ttid))
|
||||
.await
|
||||
}
|
||||
SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb),
|
||||
SafekeeperPostgresCommand::JSONCtrl { ref cmd } => handle_json_ctrl(self, pgb, cmd),
|
||||
};
|
||||
|
||||
match res {
|
||||
Ok(()) => Ok(()),
|
||||
Err(QueryError::Disconnected(connection_error)) => {
|
||||
info!("Timeline {tenant_id}/{timeline_id} query failed with connection error: {connection_error}");
|
||||
Err(QueryError::Disconnected(connection_error))
|
||||
SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb).await,
|
||||
SafekeeperPostgresCommand::JSONCtrl { ref cmd } => {
|
||||
handle_json_ctrl(self, pgb, cmd).await
|
||||
}
|
||||
Err(QueryError::Other(e)) => Err(QueryError::Other(e.context(format!(
|
||||
"Failed to process query for timeline {}",
|
||||
self.ttid
|
||||
)))),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -217,7 +212,10 @@ impl SafekeeperPostgresHandler {
|
||||
///
|
||||
/// Handle IDENTIFY_SYSTEM replication command
|
||||
///
|
||||
fn handle_identify_system(&mut self, pgb: &mut PostgresBackend) -> Result<(), QueryError> {
|
||||
async fn handle_identify_system(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
) -> Result<(), QueryError> {
|
||||
let tli = GlobalTimelines::get(self.ttid).map_err(|e| QueryError::Other(e.into()))?;
|
||||
|
||||
let lsn = if self.is_walproposer_recovery() {
|
||||
@@ -267,7 +265,7 @@ impl SafekeeperPostgresHandler {
|
||||
Some(lsn_bytes),
|
||||
None,
|
||||
]))?
|
||||
.write_message(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?;
|
||||
.write_message_noflush(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -168,12 +168,9 @@ async fn timeline_create_handler(mut request: Request<Body>) -> Result<Response<
|
||||
.commit_lsn
|
||||
.segment_lsn(server_info.wal_seg_size as usize)
|
||||
});
|
||||
tokio::task::spawn_blocking(move || {
|
||||
GlobalTimelines::create(ttid, server_info, request_data.commit_lsn, local_start_lsn)
|
||||
})
|
||||
.await
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))?
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
GlobalTimelines::create(ttid, server_info, request_data.commit_lsn, local_start_lsn)
|
||||
.await
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
|
||||
json_response(StatusCode::OK, ())
|
||||
}
|
||||
|
||||
@@ -10,10 +10,10 @@ use std::sync::Arc;
|
||||
|
||||
use anyhow::Context;
|
||||
use bytes::Bytes;
|
||||
use postgres_backend::QueryError;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::*;
|
||||
use utils::id::TenantTimelineId;
|
||||
use utils::postgres_backend::QueryError;
|
||||
|
||||
use crate::handler::SafekeeperPostgresHandler;
|
||||
use crate::safekeeper::{AcceptorProposerMessage, AppendResponse, ServerInfo};
|
||||
@@ -23,29 +23,30 @@ use crate::safekeeper::{
|
||||
use crate::safekeeper::{SafeKeeperState, Term, TermHistory, TermSwitchEntry};
|
||||
use crate::timeline::Timeline;
|
||||
use crate::GlobalTimelines;
|
||||
use postgres_backend::PostgresBackend;
|
||||
use postgres_ffi::encode_logical_message;
|
||||
use postgres_ffi::WAL_SEGMENT_SIZE;
|
||||
use pq_proto::{BeMessage, RowDescriptor, TEXT_OID};
|
||||
use utils::{lsn::Lsn, postgres_backend::PostgresBackend};
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct AppendLogicalMessage {
|
||||
// prefix and message to build LogicalMessage
|
||||
lm_prefix: String,
|
||||
lm_message: String,
|
||||
pub lm_prefix: String,
|
||||
pub lm_message: String,
|
||||
|
||||
// if true, commit_lsn will match flush_lsn after append
|
||||
set_commit_lsn: bool,
|
||||
pub set_commit_lsn: bool,
|
||||
|
||||
// if true, ProposerElected will be sent before append
|
||||
send_proposer_elected: bool,
|
||||
pub send_proposer_elected: bool,
|
||||
|
||||
// fields from AppendRequestHeader
|
||||
term: Term,
|
||||
epoch_start_lsn: Lsn,
|
||||
begin_lsn: Lsn,
|
||||
truncate_lsn: Lsn,
|
||||
pg_version: u32,
|
||||
pub term: Term,
|
||||
pub epoch_start_lsn: Lsn,
|
||||
pub begin_lsn: Lsn,
|
||||
pub truncate_lsn: Lsn,
|
||||
pub pg_version: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
@@ -59,7 +60,7 @@ struct AppendResult {
|
||||
/// Handles command to craft logical message WAL record with given
|
||||
/// content, and then append it with specified term and lsn. This
|
||||
/// function is used to test safekeepers in different scenarios.
|
||||
pub fn handle_json_ctrl(
|
||||
pub async fn handle_json_ctrl(
|
||||
spg: &SafekeeperPostgresHandler,
|
||||
pgb: &mut PostgresBackend,
|
||||
append_request: &AppendLogicalMessage,
|
||||
@@ -67,7 +68,7 @@ pub fn handle_json_ctrl(
|
||||
info!("JSON_CTRL request: {append_request:?}");
|
||||
|
||||
// need to init safekeeper state before AppendRequest
|
||||
let tli = prepare_safekeeper(spg.ttid, append_request.pg_version)?;
|
||||
let tli = prepare_safekeeper(spg.ttid, append_request.pg_version).await?;
|
||||
|
||||
// if send_proposer_elected is true, we need to update local history
|
||||
if append_request.send_proposer_elected {
|
||||
@@ -89,13 +90,16 @@ pub fn handle_json_ctrl(
|
||||
..Default::default()
|
||||
}]))?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[Some(&response_data)]))?
|
||||
.write_message(&BeMessage::CommandComplete(b"JSON_CTRL"))?;
|
||||
.write_message_noflush(&BeMessage::CommandComplete(b"JSON_CTRL"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Prepare safekeeper to process append requests without crashes,
|
||||
/// by sending ProposerGreeting with default server.wal_seg_size.
|
||||
fn prepare_safekeeper(ttid: TenantTimelineId, pg_version: u32) -> anyhow::Result<Arc<Timeline>> {
|
||||
async fn prepare_safekeeper(
|
||||
ttid: TenantTimelineId,
|
||||
pg_version: u32,
|
||||
) -> anyhow::Result<Arc<Timeline>> {
|
||||
GlobalTimelines::create(
|
||||
ttid,
|
||||
ServerInfo {
|
||||
@@ -106,6 +110,7 @@ fn prepare_safekeeper(ttid: TenantTimelineId, pg_version: u32) -> anyhow::Result
|
||||
Lsn::INVALID,
|
||||
Lsn::INVALID,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
fn send_proposer_elected(tli: &Arc<Timeline>, term: Term, lsn: Lsn) -> anyhow::Result<()> {
|
||||
@@ -128,15 +133,15 @@ fn send_proposer_elected(tli: &Arc<Timeline>, term: Term, lsn: Lsn) -> anyhow::R
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct InsertedWAL {
|
||||
pub struct InsertedWAL {
|
||||
begin_lsn: Lsn,
|
||||
end_lsn: Lsn,
|
||||
pub end_lsn: Lsn,
|
||||
append_response: AppendResponse,
|
||||
}
|
||||
|
||||
/// Extend local WAL with new LogicalMessage record. To do that,
|
||||
/// create AppendRequest with new WAL and pass it to safekeeper.
|
||||
fn append_logical_message(
|
||||
pub fn append_logical_message(
|
||||
tli: &Arc<Timeline>,
|
||||
msg: &AppendLogicalMessage,
|
||||
) -> anyhow::Result<InsertedWAL> {
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
use storage_broker::Uri;
|
||||
//
|
||||
use remote_storage::RemoteStorageConfig;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
use storage_broker::Uri;
|
||||
|
||||
use utils::id::{NodeId, TenantId, TenantTimelineId};
|
||||
|
||||
|
||||
@@ -2,204 +2,284 @@
|
||||
//! Gets messages from the network, passes them down to consensus module and
|
||||
//! sends replies back.
|
||||
|
||||
use anyhow::anyhow;
|
||||
use anyhow::Context;
|
||||
|
||||
use bytes::BytesMut;
|
||||
use tracing::*;
|
||||
use utils::lsn::Lsn;
|
||||
use utils::postgres_backend::QueryError;
|
||||
|
||||
use crate::handler::SafekeeperPostgresHandler;
|
||||
use crate::safekeeper::AcceptorProposerMessage;
|
||||
use crate::safekeeper::ProposerAcceptorMessage;
|
||||
use crate::safekeeper::ServerInfo;
|
||||
use crate::timeline::Timeline;
|
||||
use crate::GlobalTimelines;
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use bytes::BytesMut;
|
||||
use nix::unistd::gettid;
|
||||
use postgres_backend::CopyStreamHandlerEnd;
|
||||
use postgres_backend::PostgresBackend;
|
||||
use postgres_backend::PostgresBackendReader;
|
||||
use postgres_backend::QueryError;
|
||||
use pq_proto::BeMessage;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::mpsc::channel;
|
||||
use std::sync::mpsc::Receiver;
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use std::thread::JoinHandle;
|
||||
use tokio::sync::mpsc::channel;
|
||||
use tokio::sync::mpsc::error::TryRecvError;
|
||||
use tokio::sync::mpsc::Receiver;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tokio::task::spawn_blocking;
|
||||
use tracing::*;
|
||||
use utils::id::TenantTimelineId;
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
use crate::safekeeper::AcceptorProposerMessage;
|
||||
use crate::safekeeper::ProposerAcceptorMessage;
|
||||
const MSG_QUEUE_SIZE: usize = 256;
|
||||
const REPLY_QUEUE_SIZE: usize = 16;
|
||||
|
||||
use crate::handler::SafekeeperPostgresHandler;
|
||||
use pq_proto::{BeMessage, FeMessage};
|
||||
use utils::{postgres_backend::PostgresBackend, sock_split::ReadStream};
|
||||
|
||||
pub struct ReceiveWalConn<'pg> {
|
||||
/// Postgres connection
|
||||
pg_backend: &'pg mut PostgresBackend,
|
||||
/// The cached result of `pg_backend.socket().peer_addr()` (roughly)
|
||||
peer_addr: SocketAddr,
|
||||
}
|
||||
|
||||
impl<'pg> ReceiveWalConn<'pg> {
|
||||
pub fn new(pg: &'pg mut PostgresBackend) -> ReceiveWalConn<'pg> {
|
||||
let peer_addr = *pg.get_peer_addr();
|
||||
ReceiveWalConn {
|
||||
pg_backend: pg,
|
||||
peer_addr,
|
||||
impl SafekeeperPostgresHandler {
|
||||
/// Wrapper around handle_start_wal_push_guts handling result. Error is
|
||||
/// handled here while we're still in walreceiver ttid span; with API
|
||||
/// extension, this can probably be moved into postgres_backend.
|
||||
pub async fn handle_start_wal_push(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
) -> Result<(), QueryError> {
|
||||
if let Err(end) = self.handle_start_wal_push_guts(pgb).await {
|
||||
// Log the result and probably send it to the client, closing the stream.
|
||||
pgb.handle_copy_stream_end(end).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Send message to the postgres
|
||||
fn write_msg(&mut self, msg: &AcceptorProposerMessage) -> anyhow::Result<()> {
|
||||
let mut buf = BytesMut::with_capacity(128);
|
||||
msg.serialize(&mut buf)?;
|
||||
self.pg_backend.write_message(&BeMessage::CopyData(&buf))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Receive WAL from wal_proposer
|
||||
pub fn run(&mut self, spg: &mut SafekeeperPostgresHandler) -> Result<(), QueryError> {
|
||||
let _enter = info_span!("WAL acceptor", ttid = %spg.ttid).entered();
|
||||
|
||||
pub async fn handle_start_wal_push_guts(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
) -> Result<(), CopyStreamHandlerEnd> {
|
||||
// Notify the libpq client that it's allowed to send `CopyData` messages
|
||||
self.pg_backend
|
||||
.write_message(&BeMessage::CopyBothResponse)?;
|
||||
pgb.write_message(&BeMessage::CopyBothResponse).await?;
|
||||
|
||||
let r = self
|
||||
.pg_backend
|
||||
.take_stream_in()
|
||||
.ok_or_else(|| anyhow!("failed to take read stream from pgbackend"))?;
|
||||
let mut poll_reader = ProposerPollStream::new(r)?;
|
||||
// Experiments [1] confirm that doing network IO in one (this) thread and
|
||||
// processing with disc IO in another significantly improves
|
||||
// performance; we spawn off WalAcceptor thread for message processing
|
||||
// to this end.
|
||||
//
|
||||
// [1] https://github.com/neondatabase/neon/pull/1318
|
||||
let (msg_tx, msg_rx) = channel(MSG_QUEUE_SIZE);
|
||||
let (reply_tx, reply_rx) = channel(REPLY_QUEUE_SIZE);
|
||||
let mut acceptor_handle: Option<JoinHandle<anyhow::Result<()>>> = None;
|
||||
|
||||
// Receive information about server
|
||||
let next_msg = poll_reader.recv_msg()?;
|
||||
let tli = match next_msg {
|
||||
ProposerAcceptorMessage::Greeting(ref greeting) => {
|
||||
info!(
|
||||
"start handshake with walproposer {} sysid {} timeline {}",
|
||||
self.peer_addr, greeting.system_id, greeting.tli,
|
||||
);
|
||||
let server_info = ServerInfo {
|
||||
pg_version: greeting.pg_version,
|
||||
system_id: greeting.system_id,
|
||||
wal_seg_size: greeting.wal_seg_size,
|
||||
};
|
||||
GlobalTimelines::create(spg.ttid, server_info, Lsn::INVALID, Lsn::INVALID)?
|
||||
}
|
||||
_ => {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"unexpected message {next_msg:?} instead of greeting"
|
||||
)))
|
||||
}
|
||||
// Concurrently receive and send data; replies are not synchronized with
|
||||
// sends, so this avoids deadlocks.
|
||||
let mut pgb_reader = pgb.split().context("START_WAL_PUSH split")?;
|
||||
let peer_addr = *pgb.get_peer_addr();
|
||||
let res = tokio::select! {
|
||||
// todo: add read|write .context to these errors
|
||||
r = read_network(self.ttid, &mut pgb_reader, peer_addr, msg_tx, &mut acceptor_handle, msg_rx, reply_tx) => r,
|
||||
r = write_network(pgb, reply_rx) => r,
|
||||
};
|
||||
|
||||
let mut next_msg = Some(next_msg);
|
||||
// Join pg backend back.
|
||||
pgb.unsplit(pgb_reader)?;
|
||||
|
||||
let mut first_time_through = true;
|
||||
let mut _guard: Option<ComputeConnectionGuard> = None;
|
||||
loop {
|
||||
if matches!(next_msg, Some(ProposerAcceptorMessage::AppendRequest(_))) {
|
||||
// poll AppendRequest's without blocking and write WAL to disk without flushing,
|
||||
// while it's readily available
|
||||
while let Some(ProposerAcceptorMessage::AppendRequest(append_request)) = next_msg {
|
||||
let msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request);
|
||||
|
||||
let reply = tli.process_msg(&msg)?;
|
||||
if let Some(reply) = reply {
|
||||
self.write_msg(&reply)?;
|
||||
}
|
||||
|
||||
next_msg = poll_reader.poll_msg();
|
||||
}
|
||||
|
||||
// flush all written WAL to the disk
|
||||
let reply = tli.process_msg(&ProposerAcceptorMessage::FlushWAL)?;
|
||||
if let Some(reply) = reply {
|
||||
self.write_msg(&reply)?;
|
||||
}
|
||||
} else if let Some(msg) = next_msg.take() {
|
||||
// process other message
|
||||
let reply = tli.process_msg(&msg)?;
|
||||
if let Some(reply) = reply {
|
||||
self.write_msg(&reply)?;
|
||||
}
|
||||
}
|
||||
if first_time_through {
|
||||
// Register the connection and defer unregister. Do that only
|
||||
// after processing first message, as it sets wal_seg_size,
|
||||
// wanted by many.
|
||||
tli.on_compute_connect()?;
|
||||
_guard = Some(ComputeConnectionGuard {
|
||||
timeline: Arc::clone(&tli),
|
||||
});
|
||||
first_time_through = false;
|
||||
// Join the spawned WalAcceptor. At this point chans to/from it passed
|
||||
// to network routines are dropped, so it will exit as soon as it
|
||||
// touches them.
|
||||
match acceptor_handle {
|
||||
None => {
|
||||
// failed even before spawning; read_network should have error
|
||||
Err(res.expect_err("no error with WalAcceptor not spawn"))
|
||||
}
|
||||
Some(handle) => {
|
||||
let wal_acceptor_res = handle.join();
|
||||
|
||||
// blocking wait for the next message
|
||||
if next_msg.is_none() {
|
||||
next_msg = Some(poll_reader.recv_msg()?);
|
||||
// If there was any network error, return it.
|
||||
res?;
|
||||
|
||||
// Otherwise, WalAcceptor thread must have errored.
|
||||
match wal_acceptor_res {
|
||||
Ok(Ok(_)) => Ok(()), // can't happen currently; would be if we add graceful termination
|
||||
Ok(Err(e)) => Err(CopyStreamHandlerEnd::Other(e.context("WAL acceptor"))),
|
||||
Err(_) => Err(CopyStreamHandlerEnd::Other(anyhow!(
|
||||
"WalAcceptor thread panicked",
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ProposerPollStream {
|
||||
msg_rx: Receiver<ProposerAcceptorMessage>,
|
||||
read_thread: Option<thread::JoinHandle<Result<(), QueryError>>>,
|
||||
/// Read next message from walproposer.
|
||||
/// TODO: Return Ok(None) on graceful termination.
|
||||
async fn read_message(
|
||||
pgb_reader: &mut PostgresBackendReader,
|
||||
) -> Result<ProposerAcceptorMessage, CopyStreamHandlerEnd> {
|
||||
let copy_data = pgb_reader.read_copy_message().await?;
|
||||
let msg = ProposerAcceptorMessage::parse(copy_data)?;
|
||||
Ok(msg)
|
||||
}
|
||||
|
||||
impl ProposerPollStream {
|
||||
fn new(mut r: ReadStream) -> anyhow::Result<Self> {
|
||||
let (msg_tx, msg_rx) = channel();
|
||||
|
||||
let read_thread = thread::Builder::new()
|
||||
.name("Read WAL thread".into())
|
||||
.spawn(move || -> Result<(), QueryError> {
|
||||
loop {
|
||||
let copy_data = match FeMessage::read(&mut r)? {
|
||||
Some(FeMessage::CopyData(bytes)) => Ok(bytes),
|
||||
Some(msg) => Err(QueryError::Other(anyhow::anyhow!(
|
||||
"expected `CopyData` message, found {msg:?}"
|
||||
))),
|
||||
None => Err(QueryError::from(std::io::Error::new(
|
||||
std::io::ErrorKind::ConnectionAborted,
|
||||
"walproposer closed the connection",
|
||||
))),
|
||||
}?;
|
||||
|
||||
let msg = ProposerAcceptorMessage::parse(copy_data)?;
|
||||
msg_tx
|
||||
.send(msg)
|
||||
.context("Failed to send the proposer message")?;
|
||||
}
|
||||
// msg_tx will be dropped here, this will also close msg_rx
|
||||
})?;
|
||||
|
||||
Ok(Self {
|
||||
msg_rx,
|
||||
read_thread: Some(read_thread),
|
||||
})
|
||||
}
|
||||
|
||||
fn recv_msg(&mut self) -> Result<ProposerAcceptorMessage, QueryError> {
|
||||
self.msg_rx.recv().map_err(|_| {
|
||||
// return error from the read thread
|
||||
let res = match self.read_thread.take() {
|
||||
Some(thread) => thread.join(),
|
||||
None => return QueryError::Other(anyhow::anyhow!("read thread is gone")),
|
||||
/// Read messages from socket and pass it to WalAcceptor thread. Returns Ok(())
|
||||
/// if msg_tx closed; it must mean WalAcceptor terminated, joining it should
|
||||
/// tell the error.
|
||||
async fn read_network(
|
||||
ttid: TenantTimelineId,
|
||||
pgb_reader: &mut PostgresBackendReader,
|
||||
peer_addr: SocketAddr,
|
||||
msg_tx: Sender<ProposerAcceptorMessage>,
|
||||
// WalAcceptor is spawned when we learn server info from walproposer and
|
||||
// create timeline; handle is put here.
|
||||
acceptor_handle: &mut Option<JoinHandle<anyhow::Result<()>>>,
|
||||
msg_rx: Receiver<ProposerAcceptorMessage>,
|
||||
reply_tx: Sender<AcceptorProposerMessage>,
|
||||
) -> Result<(), CopyStreamHandlerEnd> {
|
||||
// Receive information about server to create timeline, if not yet.
|
||||
let next_msg = read_message(pgb_reader).await?;
|
||||
let tli = match next_msg {
|
||||
ProposerAcceptorMessage::Greeting(ref greeting) => {
|
||||
info!(
|
||||
"start handshake with walproposer {} sysid {} timeline {}",
|
||||
peer_addr, greeting.system_id, greeting.tli,
|
||||
);
|
||||
let server_info = ServerInfo {
|
||||
pg_version: greeting.pg_version,
|
||||
system_id: greeting.system_id,
|
||||
wal_seg_size: greeting.wal_seg_size,
|
||||
};
|
||||
GlobalTimelines::create(ttid, server_info, Lsn::INVALID, Lsn::INVALID).await?
|
||||
}
|
||||
_ => {
|
||||
return Err(CopyStreamHandlerEnd::Other(anyhow::anyhow!(
|
||||
"unexpected message {next_msg:?} instead of greeting"
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
match res {
|
||||
Ok(Ok(())) => {
|
||||
QueryError::Other(anyhow::anyhow!("unexpected result from read thread"))
|
||||
}
|
||||
Err(err) => QueryError::Other(anyhow::anyhow!("read thread panicked: {err:?}")),
|
||||
Ok(Err(err)) => err,
|
||||
*acceptor_handle = Some(
|
||||
WalAcceptor::spawn(tli.clone(), msg_rx, reply_tx).context("spawn WalAcceptor thread")?,
|
||||
);
|
||||
|
||||
// Forward all messages to WalAcceptor
|
||||
read_network_loop(pgb_reader, msg_tx, next_msg).await
|
||||
}
|
||||
|
||||
async fn read_network_loop(
|
||||
pgb_reader: &mut PostgresBackendReader,
|
||||
msg_tx: Sender<ProposerAcceptorMessage>,
|
||||
mut next_msg: ProposerAcceptorMessage,
|
||||
) -> Result<(), CopyStreamHandlerEnd> {
|
||||
loop {
|
||||
if msg_tx.send(next_msg).await.is_err() {
|
||||
return Ok(()); // chan closed, WalAcceptor terminated
|
||||
}
|
||||
next_msg = read_message(pgb_reader).await?;
|
||||
}
|
||||
}
|
||||
|
||||
/// Read replies from WalAcceptor and pass them back to socket. Returns Ok(())
|
||||
/// if reply_rx closed; it must mean WalAcceptor terminated, joining it should
|
||||
/// tell the error.
|
||||
async fn write_network(
|
||||
pgb_writer: &mut PostgresBackend,
|
||||
mut reply_rx: Receiver<AcceptorProposerMessage>,
|
||||
) -> Result<(), CopyStreamHandlerEnd> {
|
||||
let mut buf = BytesMut::with_capacity(128);
|
||||
|
||||
loop {
|
||||
match reply_rx.recv().await {
|
||||
Some(msg) => {
|
||||
buf.clear();
|
||||
msg.serialize(&mut buf)?;
|
||||
pgb_writer.write_message(&BeMessage::CopyData(&buf)).await?;
|
||||
}
|
||||
})
|
||||
None => return Ok(()), // chan closed, WalAcceptor terminated
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Takes messages from msg_rx, processes and pushes replies to reply_tx.
|
||||
struct WalAcceptor {
|
||||
tli: Arc<Timeline>,
|
||||
msg_rx: Receiver<ProposerAcceptorMessage>,
|
||||
reply_tx: Sender<AcceptorProposerMessage>,
|
||||
}
|
||||
|
||||
impl WalAcceptor {
|
||||
/// Spawn thread with WalAcceptor running, return handle to it.
|
||||
fn spawn(
|
||||
tli: Arc<Timeline>,
|
||||
msg_rx: Receiver<ProposerAcceptorMessage>,
|
||||
reply_tx: Sender<AcceptorProposerMessage>,
|
||||
) -> anyhow::Result<JoinHandle<anyhow::Result<()>>> {
|
||||
let thread_name = format!("WAL acceptor {}", tli.ttid);
|
||||
thread::Builder::new()
|
||||
.name(thread_name)
|
||||
.spawn(move || -> anyhow::Result<()> {
|
||||
let mut wa = WalAcceptor {
|
||||
tli,
|
||||
msg_rx,
|
||||
reply_tx,
|
||||
};
|
||||
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
let span_ttid = wa.tli.ttid; // satisfy borrow checker
|
||||
runtime.block_on(
|
||||
wa.run()
|
||||
.instrument(info_span!("WAL acceptor", tid = %gettid(), ttid = %span_ttid)),
|
||||
)
|
||||
})
|
||||
.map_err(anyhow::Error::from)
|
||||
}
|
||||
|
||||
fn poll_msg(&mut self) -> Option<ProposerAcceptorMessage> {
|
||||
let res = self.msg_rx.try_recv();
|
||||
/// The main loop. Returns Ok(()) if either msg_rx or reply_tx got closed;
|
||||
/// it must mean that network thread terminated.
|
||||
async fn run(&mut self) -> anyhow::Result<()> {
|
||||
// Register the connection and defer unregister.
|
||||
self.tli.on_compute_connect().await?;
|
||||
let _guard = ComputeConnectionGuard {
|
||||
timeline: Arc::clone(&self.tli),
|
||||
};
|
||||
|
||||
match res {
|
||||
Err(_) => None,
|
||||
Ok(msg) => Some(msg),
|
||||
let mut next_msg: ProposerAcceptorMessage;
|
||||
|
||||
loop {
|
||||
let opt_msg = self.msg_rx.recv().await;
|
||||
if opt_msg.is_none() {
|
||||
return Ok(()); // chan closed, streaming terminated
|
||||
}
|
||||
next_msg = opt_msg.unwrap();
|
||||
|
||||
if matches!(next_msg, ProposerAcceptorMessage::AppendRequest(_)) {
|
||||
// loop through AppendRequest's while it's readily available to
|
||||
// write as many WAL as possible without fsyncing
|
||||
while let ProposerAcceptorMessage::AppendRequest(append_request) = next_msg {
|
||||
let noflush_msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request);
|
||||
|
||||
if let Some(reply) = self.tli.process_msg(&noflush_msg)? {
|
||||
if self.reply_tx.send(reply).await.is_err() {
|
||||
return Ok(()); // chan closed, streaming terminated
|
||||
}
|
||||
}
|
||||
|
||||
match self.msg_rx.try_recv() {
|
||||
Ok(msg) => next_msg = msg,
|
||||
Err(TryRecvError::Empty) => break,
|
||||
Err(TryRecvError::Disconnected) => return Ok(()), // chan closed, streaming terminated
|
||||
}
|
||||
}
|
||||
|
||||
// flush all written WAL to the disk
|
||||
if let Some(reply) = self.tli.process_msg(&ProposerAcceptorMessage::FlushWAL)? {
|
||||
if self.reply_tx.send(reply).await.is_err() {
|
||||
return Ok(()); // chan closed, streaming terminated
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// process message other than AppendRequest
|
||||
if let Some(reply) = self.tli.process_msg(&next_msg)? {
|
||||
if self.reply_tx.send(reply).await.is_err() {
|
||||
return Ok(()); // chan closed, streaming terminated
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -210,8 +290,13 @@ struct ComputeConnectionGuard {
|
||||
|
||||
impl Drop for ComputeConnectionGuard {
|
||||
fn drop(&mut self) {
|
||||
if let Err(e) = self.timeline.on_compute_disconnect() {
|
||||
error!("failed to unregister compute connection: {}", e);
|
||||
}
|
||||
let tli = self.timeline.clone();
|
||||
// tokio forbids to call blocking_send inside the runtime, and see
|
||||
// comments in on_compute_disconnect why we call blocking_send.
|
||||
spawn_blocking(move || {
|
||||
if let Err(e) = tli.on_compute_disconnect() {
|
||||
error!("failed to unregister compute connection: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -488,7 +488,7 @@ impl AcceptorProposerMessage {
|
||||
buf.put_u64_le(msg.hs_feedback.xmin);
|
||||
buf.put_u64_le(msg.hs_feedback.catalog_xmin);
|
||||
|
||||
msg.pageserver_feedback.serialize(buf)?
|
||||
msg.pageserver_feedback.serialize(buf)?;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,24 +5,22 @@ use crate::handler::SafekeeperPostgresHandler;
|
||||
use crate::timeline::{ReplicaState, Timeline};
|
||||
use crate::wal_storage::WalReader;
|
||||
use crate::GlobalTimelines;
|
||||
use anyhow::Context;
|
||||
|
||||
use anyhow::Context as AnyhowContext;
|
||||
use bytes::Bytes;
|
||||
use postgres_backend::PostgresBackend;
|
||||
use postgres_backend::{CopyStreamHandlerEnd, PostgresBackendReader, QueryError};
|
||||
use postgres_ffi::get_current_timestamp;
|
||||
use postgres_ffi::{TimestampTz, MAX_SEND_SIZE};
|
||||
use pq_proto::{BeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cmp::min;
|
||||
use std::net::Shutdown;
|
||||
use std::str;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::{io, str, thread};
|
||||
use utils::postgres_backend::QueryError;
|
||||
|
||||
use pq_proto::{BeMessage, FeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody};
|
||||
use tokio::sync::watch::Receiver;
|
||||
use tokio::time::timeout;
|
||||
use tracing::*;
|
||||
use utils::{bin_ser::BeSer, lsn::Lsn, postgres_backend::PostgresBackend, sock_split::ReadStream};
|
||||
use utils::{bin_ser::BeSer, lsn::Lsn};
|
||||
|
||||
// See: https://www.postgresql.org/docs/13/protocol-replication.html
|
||||
const HOT_STANDBY_FEEDBACK_TAG_BYTE: u8 = b'h';
|
||||
@@ -60,13 +58,6 @@ pub struct StandbyReply {
|
||||
pub reply_requested: bool,
|
||||
}
|
||||
|
||||
/// A network connection that's speaking the replication protocol.
|
||||
pub struct ReplicationConn {
|
||||
/// This is an `Option` because we will spawn a background thread that will
|
||||
/// `take` it from us.
|
||||
stream_in: Option<ReadStream>,
|
||||
}
|
||||
|
||||
/// Scope guard to unregister replication connection from timeline
|
||||
struct ReplicationConnGuard {
|
||||
replica: usize, // replica internal ID assigned by timeline
|
||||
@@ -79,230 +70,275 @@ impl Drop for ReplicationConnGuard {
|
||||
}
|
||||
}
|
||||
|
||||
impl ReplicationConn {
|
||||
/// Create a new `ReplicationConn`
|
||||
pub fn new(pgb: &mut PostgresBackend) -> Self {
|
||||
Self {
|
||||
stream_in: pgb.take_stream_in(),
|
||||
impl SafekeeperPostgresHandler {
|
||||
/// Wrapper around handle_start_replication_guts handling result. Error is
|
||||
/// handled here while we're still in walsender ttid span; with API
|
||||
/// extension, this can probably be moved into postgres_backend.
|
||||
pub async fn handle_start_replication(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
start_pos: Lsn,
|
||||
) -> Result<(), QueryError> {
|
||||
if let Err(end) = self.handle_start_replication_guts(pgb, start_pos).await {
|
||||
// Log the result and probably send it to the client, closing the stream.
|
||||
pgb.handle_copy_stream_end(end).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle incoming messages from the network.
|
||||
/// This is spawned into the background by `handle_start_replication`.
|
||||
fn background_thread(
|
||||
mut stream_in: ReadStream,
|
||||
replica_guard: Arc<ReplicationConnGuard>,
|
||||
) -> anyhow::Result<()> {
|
||||
let replica_id = replica_guard.replica;
|
||||
let timeline = &replica_guard.timeline;
|
||||
|
||||
let mut state = ReplicaState::new();
|
||||
// Wait for replica's feedback.
|
||||
while let Some(msg) = FeMessage::read(&mut stream_in)? {
|
||||
match &msg {
|
||||
FeMessage::CopyData(m) => {
|
||||
// There's three possible data messages that the client is supposed to send here:
|
||||
// `HotStandbyFeedback` and `StandbyStatusUpdate` and `NeonStandbyFeedback`.
|
||||
|
||||
match m.first().cloned() {
|
||||
Some(HOT_STANDBY_FEEDBACK_TAG_BYTE) => {
|
||||
// Note: deserializing is on m[1..] because we skip the tag byte.
|
||||
state.hs_feedback = HotStandbyFeedback::des(&m[1..])
|
||||
.context("failed to deserialize HotStandbyFeedback")?;
|
||||
timeline.update_replica_state(replica_id, state);
|
||||
}
|
||||
Some(STANDBY_STATUS_UPDATE_TAG_BYTE) => {
|
||||
let _reply = StandbyReply::des(&m[1..])
|
||||
.context("failed to deserialize StandbyReply")?;
|
||||
// This must be a regular postgres replica,
|
||||
// because pageserver doesn't send this type of messages to safekeeper.
|
||||
// Currently this is not implemented, so this message is ignored.
|
||||
|
||||
warn!("unexpected StandbyReply. Read-only postgres replicas are not supported in safekeepers yet.");
|
||||
// timeline.update_replica_state(replica_id, Some(state));
|
||||
}
|
||||
Some(NEON_STATUS_UPDATE_TAG_BYTE) => {
|
||||
// Note: deserializing is on m[9..] because we skip the tag byte and len bytes.
|
||||
let buf = Bytes::copy_from_slice(&m[9..]);
|
||||
let reply = ReplicationFeedback::parse(buf);
|
||||
|
||||
trace!("ReplicationFeedback is {:?}", reply);
|
||||
// Only pageserver sends ReplicationFeedback, so set the flag.
|
||||
// This replica is the source of information to resend to compute.
|
||||
state.pageserver_feedback = Some(reply);
|
||||
|
||||
timeline.update_replica_state(replica_id, state);
|
||||
}
|
||||
_ => warn!("unexpected message {:?}", msg),
|
||||
}
|
||||
}
|
||||
FeMessage::Sync => {}
|
||||
FeMessage::CopyFail => {
|
||||
// Shutdown the connection, because rust-postgres client cannot be dropped
|
||||
// when connection is alive.
|
||||
let _ = stream_in.shutdown(Shutdown::Both);
|
||||
anyhow::bail!("Copy failed");
|
||||
}
|
||||
_ => {
|
||||
// We only handle `CopyData`, 'Sync', 'CopyFail' messages. Anything else is ignored.
|
||||
info!("unexpected message {:?}", msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
///
|
||||
/// Handle START_REPLICATION replication command
|
||||
///
|
||||
pub fn run(
|
||||
pub async fn handle_start_replication_guts(
|
||||
&mut self,
|
||||
spg: &mut SafekeeperPostgresHandler,
|
||||
pgb: &mut PostgresBackend,
|
||||
mut start_pos: Lsn,
|
||||
) -> Result<(), QueryError> {
|
||||
let _enter = info_span!("WAL sender", ttid = %spg.ttid).entered();
|
||||
|
||||
let tli = GlobalTimelines::get(spg.ttid).map_err(|e| QueryError::Other(e.into()))?;
|
||||
|
||||
// spawn the background thread which receives HotStandbyFeedback messages.
|
||||
let bg_timeline = Arc::clone(&tli);
|
||||
let bg_stream_in = self.stream_in.take().unwrap();
|
||||
let bg_timeline_id = spg.timeline_id.unwrap();
|
||||
start_pos: Lsn,
|
||||
) -> Result<(), CopyStreamHandlerEnd> {
|
||||
let appname = self.appname.clone();
|
||||
let tli =
|
||||
GlobalTimelines::get(self.ttid).map_err(|e| CopyStreamHandlerEnd::Other(e.into()))?;
|
||||
|
||||
let state = ReplicaState::new();
|
||||
// This replica_id is used below to check if it's time to stop replication.
|
||||
let replica_id = bg_timeline.add_replica(state);
|
||||
let replica_id = tli.add_replica(state);
|
||||
|
||||
// Use a guard object to remove our entry from the timeline, when the background
|
||||
// thread and us have both finished using it.
|
||||
let replica_guard = Arc::new(ReplicationConnGuard {
|
||||
let _guard = Arc::new(ReplicationConnGuard {
|
||||
replica: replica_id,
|
||||
timeline: bg_timeline,
|
||||
timeline: tli.clone(),
|
||||
});
|
||||
let bg_replica_guard = Arc::clone(&replica_guard);
|
||||
|
||||
// TODO: here we got two threads, one for writing WAL and one for receiving
|
||||
// feedback. If one of them fails, we should shutdown the other one too.
|
||||
let _ = thread::Builder::new()
|
||||
.name("HotStandbyFeedback thread".into())
|
||||
.spawn(move || {
|
||||
let _enter =
|
||||
info_span!("HotStandbyFeedback thread", timeline = %bg_timeline_id).entered();
|
||||
if let Err(err) = Self::background_thread(bg_stream_in, bg_replica_guard) {
|
||||
error!("Replication background thread failed: {}", err);
|
||||
// Walproposer gets special handling: safekeeper must give proposer all
|
||||
// local WAL till the end, whether committed or not (walproposer will
|
||||
// hang otherwise). That's because walproposer runs the consensus and
|
||||
// synchronizes safekeepers on the most advanced one.
|
||||
//
|
||||
// There is a small risk of this WAL getting concurrently garbaged if
|
||||
// another compute rises which collects majority and starts fixing log
|
||||
// on this safekeeper itself. That's ok as (old) proposer will never be
|
||||
// able to commit such WAL.
|
||||
let stop_pos: Option<Lsn> = if self.is_walproposer_recovery() {
|
||||
let wal_end = tli.get_flush_lsn();
|
||||
Some(wal_end)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let end_pos = stop_pos.unwrap_or(Lsn::INVALID);
|
||||
|
||||
info!(
|
||||
"starting streaming from {:?} till {:?}",
|
||||
start_pos, stop_pos
|
||||
);
|
||||
|
||||
// switch to copy
|
||||
pgb.write_message(&BeMessage::CopyBothResponse).await?;
|
||||
|
||||
let (_, persisted_state) = tli.get_state();
|
||||
let wal_reader = WalReader::new(
|
||||
self.conf.workdir.clone(),
|
||||
self.conf.timeline_dir(&tli.ttid),
|
||||
&persisted_state,
|
||||
start_pos,
|
||||
self.conf.wal_backup_enabled,
|
||||
)?;
|
||||
|
||||
// Split to concurrently receive and send data; replies are generally
|
||||
// not synchronized with sends, so this avoids deadlocks.
|
||||
let reader = pgb.split().context("START_REPLICATION split")?;
|
||||
|
||||
let mut sender = WalSender {
|
||||
pgb,
|
||||
tli: tli.clone(),
|
||||
appname,
|
||||
start_pos,
|
||||
end_pos,
|
||||
stop_pos,
|
||||
commit_lsn_watch_rx: tli.get_commit_lsn_watch_rx(),
|
||||
replica_id,
|
||||
wal_reader,
|
||||
send_buf: [0; MAX_SEND_SIZE],
|
||||
};
|
||||
let mut reply_reader = ReplyReader {
|
||||
reader,
|
||||
tli,
|
||||
replica_id,
|
||||
feedback: ReplicaState::new(),
|
||||
};
|
||||
|
||||
let res = tokio::select! {
|
||||
// todo: add read|write .context to these errors
|
||||
r = sender.run() => r,
|
||||
r = reply_reader.run() => r,
|
||||
};
|
||||
// Join pg backend back.
|
||||
pgb.unsplit(reply_reader.reader)?;
|
||||
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
/// A half driving sending WAL.
|
||||
struct WalSender<'a> {
|
||||
pgb: &'a mut PostgresBackend,
|
||||
tli: Arc<Timeline>,
|
||||
appname: Option<String>,
|
||||
// Position since which we are sending next chunk.
|
||||
start_pos: Lsn,
|
||||
// WAL up to this position is known to be locally available.
|
||||
end_pos: Lsn,
|
||||
// If present, terminate after reaching this position; used by walproposer
|
||||
// in recovery.
|
||||
stop_pos: Option<Lsn>,
|
||||
commit_lsn_watch_rx: Receiver<Lsn>,
|
||||
replica_id: usize,
|
||||
wal_reader: WalReader,
|
||||
// buffer for readling WAL into to send it
|
||||
send_buf: [u8; MAX_SEND_SIZE],
|
||||
}
|
||||
|
||||
impl WalSender<'_> {
|
||||
/// Send WAL until
|
||||
/// - an error occurs
|
||||
/// - if we are streaming to walproposer, we've streamed until stop_pos
|
||||
/// (recovery finished)
|
||||
/// - receiver is caughtup and there is no computes
|
||||
///
|
||||
/// Err(CopyStreamHandlerEnd) is always returned; Result is used only for ?
|
||||
/// convenience.
|
||||
async fn run(&mut self) -> Result<(), CopyStreamHandlerEnd> {
|
||||
loop {
|
||||
// If we are streaming to walproposer, check it is time to stop.
|
||||
if let Some(stop_pos) = self.stop_pos {
|
||||
if self.start_pos >= stop_pos {
|
||||
// recovery finished
|
||||
return Err(CopyStreamHandlerEnd::ServerInitiated(format!(
|
||||
"ending streaming to walproposer at {}, recovery finished",
|
||||
self.start_pos
|
||||
)));
|
||||
}
|
||||
})?;
|
||||
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
runtime.block_on(async move {
|
||||
let (inmem_state, persisted_state) = tli.get_state();
|
||||
// add persisted_state.timeline_start_lsn == Lsn(0) check
|
||||
|
||||
// Walproposer gets special handling: safekeeper must give proposer all
|
||||
// local WAL till the end, whether committed or not (walproposer will
|
||||
// hang otherwise). That's because walproposer runs the consensus and
|
||||
// synchronizes safekeepers on the most advanced one.
|
||||
//
|
||||
// There is a small risk of this WAL getting concurrently garbaged if
|
||||
// another compute rises which collects majority and starts fixing log
|
||||
// on this safekeeper itself. That's ok as (old) proposer will never be
|
||||
// able to commit such WAL.
|
||||
let stop_pos: Option<Lsn> = if spg.is_walproposer_recovery() {
|
||||
let wal_end = tli.get_flush_lsn();
|
||||
Some(wal_end)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
// Wait for the next portion if it is not there yet, or just
|
||||
// update our end of WAL available for sending value, we
|
||||
// communicate it to the receiver.
|
||||
self.wait_wal().await?;
|
||||
}
|
||||
|
||||
info!("Start replication from {:?} till {:?}", start_pos, stop_pos);
|
||||
// try to send as much as available, capped by MAX_SEND_SIZE
|
||||
let mut send_size = self
|
||||
.end_pos
|
||||
.checked_sub(self.start_pos)
|
||||
.context("reading wal without waiting for it first")?
|
||||
.0 as usize;
|
||||
send_size = min(send_size, self.send_buf.len());
|
||||
let send_buf = &mut self.send_buf[..send_size];
|
||||
// read wal into buffer
|
||||
send_size = self.wal_reader.read(send_buf).await?;
|
||||
let send_buf = &send_buf[..send_size];
|
||||
|
||||
// switch to copy
|
||||
pgb.write_message(&BeMessage::CopyBothResponse)?;
|
||||
|
||||
let mut end_pos = stop_pos.unwrap_or(inmem_state.commit_lsn);
|
||||
|
||||
let mut wal_reader = WalReader::new(
|
||||
spg.conf.workdir.clone(),
|
||||
spg.conf.timeline_dir(&tli.ttid),
|
||||
&persisted_state,
|
||||
start_pos,
|
||||
spg.conf.wal_backup_enabled,
|
||||
)?;
|
||||
|
||||
// buffer for wal sending, limited by MAX_SEND_SIZE
|
||||
let mut send_buf = vec![0u8; MAX_SEND_SIZE];
|
||||
|
||||
// watcher for commit_lsn updates
|
||||
let mut commit_lsn_watch_rx = tli.get_commit_lsn_watch_rx();
|
||||
|
||||
loop {
|
||||
if let Some(stop_pos) = stop_pos {
|
||||
if start_pos >= stop_pos {
|
||||
break; /* recovery finished */
|
||||
}
|
||||
end_pos = stop_pos;
|
||||
} else {
|
||||
/* Wait until we have some data to stream */
|
||||
let lsn = wait_for_lsn(&mut commit_lsn_watch_rx, start_pos).await?;
|
||||
|
||||
if let Some(lsn) = lsn {
|
||||
end_pos = lsn;
|
||||
} else {
|
||||
// TODO: also check once in a while whether we are walsender
|
||||
// to right pageserver.
|
||||
if tli.should_walsender_stop(replica_id) {
|
||||
// Shut down, timeline is suspended.
|
||||
return Err(QueryError::from(io::Error::new(
|
||||
io::ErrorKind::ConnectionAborted,
|
||||
format!("end streaming to {:?}", spg.appname),
|
||||
)));
|
||||
}
|
||||
|
||||
// timeout expired: request pageserver status
|
||||
pgb.write_message(&BeMessage::KeepAlive(WalSndKeepAlive {
|
||||
sent_ptr: end_pos.0,
|
||||
timestamp: get_current_timestamp(),
|
||||
request_reply: true,
|
||||
}))?;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let send_size = end_pos.checked_sub(start_pos).unwrap().0 as usize;
|
||||
let send_size = min(send_size, send_buf.len());
|
||||
|
||||
let send_buf = &mut send_buf[..send_size];
|
||||
|
||||
// read wal into buffer
|
||||
let send_size = wal_reader.read(send_buf).await?;
|
||||
let send_buf = &send_buf[..send_size];
|
||||
|
||||
// Write some data to the network socket.
|
||||
pgb.write_message(&BeMessage::XLogData(XLogDataBody {
|
||||
wal_start: start_pos.0,
|
||||
wal_end: end_pos.0,
|
||||
// and send it
|
||||
self.pgb
|
||||
.write_message(&BeMessage::XLogData(XLogDataBody {
|
||||
wal_start: self.start_pos.0,
|
||||
wal_end: self.end_pos.0,
|
||||
timestamp: get_current_timestamp(),
|
||||
data: send_buf,
|
||||
}))
|
||||
.context("Failed to send XLogData")?;
|
||||
.await?;
|
||||
|
||||
start_pos += send_size as u64;
|
||||
trace!("sent WAL up to {}", start_pos);
|
||||
trace!(
|
||||
"sent {} bytes of WAL {}-{}",
|
||||
send_size,
|
||||
self.start_pos,
|
||||
self.start_pos + send_size as u64
|
||||
);
|
||||
self.start_pos += send_size as u64;
|
||||
}
|
||||
}
|
||||
|
||||
/// wait until we have WAL to stream, sending keepalives and checking for
|
||||
/// exit in the meanwhile
|
||||
async fn wait_wal(&mut self) -> Result<(), CopyStreamHandlerEnd> {
|
||||
loop {
|
||||
if let Some(lsn) = wait_for_lsn(&mut self.commit_lsn_watch_rx, self.start_pos).await? {
|
||||
self.end_pos = lsn;
|
||||
return Ok(());
|
||||
}
|
||||
// Timed out waiting for WAL, check for termination and send KA
|
||||
if self.tli.should_walsender_stop(self.replica_id) {
|
||||
// Terminate if there is nothing more to send.
|
||||
// TODO close the stream properly
|
||||
return Err(CopyStreamHandlerEnd::ServerInitiated(format!(
|
||||
"ending streaming to {:?} at {}, receiver is caughtup and there is no computes",
|
||||
self.appname, self.start_pos,
|
||||
)));
|
||||
}
|
||||
self.pgb
|
||||
.write_message(&BeMessage::KeepAlive(WalSndKeepAlive {
|
||||
sent_ptr: self.end_pos.0,
|
||||
timestamp: get_current_timestamp(),
|
||||
request_reply: true,
|
||||
}))
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
/// A half driving receiving replies.
|
||||
struct ReplyReader {
|
||||
reader: PostgresBackendReader,
|
||||
tli: Arc<Timeline>,
|
||||
replica_id: usize,
|
||||
feedback: ReplicaState,
|
||||
}
|
||||
|
||||
impl ReplyReader {
|
||||
async fn run(&mut self) -> Result<(), CopyStreamHandlerEnd> {
|
||||
loop {
|
||||
let msg = self.reader.read_copy_message().await?;
|
||||
self.handle_feedback(&msg)?
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_feedback(&mut self, msg: &Bytes) -> anyhow::Result<()> {
|
||||
match msg.first().cloned() {
|
||||
Some(HOT_STANDBY_FEEDBACK_TAG_BYTE) => {
|
||||
// Note: deserializing is on m[1..] because we skip the tag byte.
|
||||
self.feedback.hs_feedback = HotStandbyFeedback::des(&msg[1..])
|
||||
.context("failed to deserialize HotStandbyFeedback")?;
|
||||
self.tli
|
||||
.update_replica_state(self.replica_id, self.feedback);
|
||||
}
|
||||
Some(STANDBY_STATUS_UPDATE_TAG_BYTE) => {
|
||||
let _reply =
|
||||
StandbyReply::des(&msg[1..]).context("failed to deserialize StandbyReply")?;
|
||||
// This must be a regular postgres replica,
|
||||
// because pageserver doesn't send this type of messages to safekeeper.
|
||||
// Currently we just ignore this, tracking progress for them is not supported.
|
||||
}
|
||||
Some(NEON_STATUS_UPDATE_TAG_BYTE) => {
|
||||
// pageserver sends this.
|
||||
// Note: deserializing is on m[9..] because we skip the tag byte and len bytes.
|
||||
let buf = Bytes::copy_from_slice(&msg[9..]);
|
||||
let reply = ReplicationFeedback::parse(buf);
|
||||
|
||||
trace!("ReplicationFeedback is {:?}", reply);
|
||||
// Only pageserver sends ReplicationFeedback, so set the flag.
|
||||
// This replica is the source of information to resend to compute.
|
||||
self.feedback.pageserver_feedback = Some(reply);
|
||||
|
||||
self.tli
|
||||
.update_replica_state(self.replica_id, self.feedback);
|
||||
}
|
||||
_ => warn!("unexpected message {:?}", msg),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
const POLL_STATE_TIMEOUT: Duration = Duration::from_secs(1);
|
||||
|
||||
// Wait until we have commit_lsn > lsn or timeout expires. Returns latest commit_lsn.
|
||||
/// Wait until we have commit_lsn > lsn or timeout expires. Returns
|
||||
/// - Ok(Some(commit_lsn)) if needed lsn is successfully observed;
|
||||
/// - Ok(None) if timeout expired;
|
||||
/// - Err in case of error (if watch channel is in trouble, shouldn't happen).
|
||||
async fn wait_for_lsn(rx: &mut Receiver<Lsn>, lsn: Lsn) -> anyhow::Result<Option<Lsn>> {
|
||||
let commit_lsn: Lsn = *rx.borrow();
|
||||
if commit_lsn > lsn {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//! This module implements Timeline lifecycle management and has all neccessary code
|
||||
//! This module implements Timeline lifecycle management and has all necessary code
|
||||
//! to glue together SafeKeeper and all other background services.
|
||||
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
@@ -532,7 +532,7 @@ impl Timeline {
|
||||
|
||||
/// Register compute connection, starting timeline-related activity if it is
|
||||
/// not running yet.
|
||||
pub fn on_compute_connect(&self) -> Result<()> {
|
||||
pub async fn on_compute_connect(&self) -> Result<()> {
|
||||
if self.is_cancelled() {
|
||||
bail!(TimelineError::Cancelled(self.ttid));
|
||||
}
|
||||
@@ -546,7 +546,7 @@ impl Timeline {
|
||||
// Wake up wal backup launcher, if offloading not started yet.
|
||||
if is_wal_backup_action_pending {
|
||||
// Can fail only if channel to a static thread got closed, which is not normal at all.
|
||||
self.wal_backup_launcher_tx.blocking_send(self.ttid)?;
|
||||
self.wal_backup_launcher_tx.send(self.ttid).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -563,6 +563,11 @@ impl Timeline {
|
||||
// Wake up wal backup launcher, if it is time to stop the offloading.
|
||||
if is_wal_backup_action_pending {
|
||||
// Can fail only if channel to a static thread got closed, which is not normal at all.
|
||||
//
|
||||
// Note: this is blocking_send because on_compute_disconnect is called in Drop, there is
|
||||
// no async Drop and we use current thread runtimes. With current thread rt spawning
|
||||
// task in drop impl is racy, as thread along with runtime might finish before the task.
|
||||
// This should be switched send.await when/if we go to full async.
|
||||
self.wal_backup_launcher_tx.blocking_send(self.ttid)?;
|
||||
}
|
||||
Ok(())
|
||||
|
||||
@@ -171,7 +171,7 @@ impl GlobalTimelines {
|
||||
|
||||
/// Create a new timeline with the given id. If the timeline already exists, returns
|
||||
/// an existing timeline.
|
||||
pub fn create(
|
||||
pub async fn create(
|
||||
ttid: TenantTimelineId,
|
||||
server_info: ServerInfo,
|
||||
commit_lsn: Lsn,
|
||||
@@ -199,28 +199,20 @@ impl GlobalTimelines {
|
||||
|
||||
// Take a lock and finish the initialization holding this mutex. No other threads
|
||||
// can interfere with creation after we will insert timeline into the map.
|
||||
let mut shared_state = timeline.write_shared_state();
|
||||
{
|
||||
let mut shared_state = timeline.write_shared_state();
|
||||
|
||||
// We can get a race condition here in case of concurrent create calls, but only
|
||||
// in theory. create() will return valid timeline on the next try.
|
||||
TIMELINES_STATE
|
||||
.lock()
|
||||
.unwrap()
|
||||
.try_insert(timeline.clone())?;
|
||||
// We can get a race condition here in case of concurrent create calls, but only
|
||||
// in theory. create() will return valid timeline on the next try.
|
||||
TIMELINES_STATE
|
||||
.lock()
|
||||
.unwrap()
|
||||
.try_insert(timeline.clone())?;
|
||||
|
||||
// Write the new timeline to the disk and start background workers.
|
||||
// Bootstrap is transactional, so if it fails, the timeline will be deleted,
|
||||
// and the state on disk should remain unchanged.
|
||||
match timeline.bootstrap(&mut shared_state) {
|
||||
Ok(_) => {
|
||||
// We are done with bootstrap, release the lock, return the timeline.
|
||||
drop(shared_state);
|
||||
timeline
|
||||
.wal_backup_launcher_tx
|
||||
.blocking_send(timeline.ttid)?;
|
||||
Ok(timeline)
|
||||
}
|
||||
Err(e) => {
|
||||
// Write the new timeline to the disk and start background workers.
|
||||
// Bootstrap is transactional, so if it fails, the timeline will be deleted,
|
||||
// and the state on disk should remain unchanged.
|
||||
if let Err(e) = timeline.bootstrap(&mut shared_state) {
|
||||
// Note: the most likely reason for bootstrap failure is that the timeline
|
||||
// directory already exists on disk. This happens when timeline is corrupted
|
||||
// and wasn't loaded from disk on startup because of that. We want to preserve
|
||||
@@ -232,9 +224,13 @@ impl GlobalTimelines {
|
||||
|
||||
// Timeline failed to bootstrap, it cannot be used. Remove it from the map.
|
||||
TIMELINES_STATE.lock().unwrap().timelines.remove(&ttid);
|
||||
Err(e)
|
||||
return Err(e);
|
||||
}
|
||||
// We are done with bootstrap, release the lock, return the timeline.
|
||||
// {} block forces release before .await
|
||||
}
|
||||
timeline.wal_backup_launcher_tx.send(timeline.ttid).await?;
|
||||
Ok(timeline)
|
||||
}
|
||||
|
||||
/// Get a timeline from the global map. If it's not present, it doesn't exist on disk,
|
||||
@@ -254,7 +250,7 @@ impl GlobalTimelines {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns all timelines. This is used for background timeline proccesses.
|
||||
/// Returns all timelines. This is used for background timeline processes.
|
||||
pub fn get_all() -> Vec<Arc<Timeline>> {
|
||||
let global_lock = TIMELINES_STATE.lock().unwrap();
|
||||
global_lock
|
||||
|
||||
@@ -191,7 +191,7 @@ async fn wal_backup_launcher_main_loop(
|
||||
.map(|c| GenericRemoteStorage::from_config(c).expect("failed to create remote storage"))
|
||||
});
|
||||
|
||||
// Presense in this map means launcher is aware s3 offloading is needed for
|
||||
// Presence in this map means launcher is aware s3 offloading is needed for
|
||||
// the timeline, but task is started only if it makes sense for to offload
|
||||
// from this safekeeper.
|
||||
let mut tasks: HashMap<TenantTimelineId, WalBackupTimelineEntry> = HashMap::new();
|
||||
@@ -467,7 +467,7 @@ async fn backup_object(source_file: &Path, target_file: &RemotePath, size: usize
|
||||
pub async fn read_object(
|
||||
file_path: &RemotePath,
|
||||
offset: u64,
|
||||
) -> anyhow::Result<Pin<Box<dyn tokio::io::AsyncRead>>> {
|
||||
) -> anyhow::Result<Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>>> {
|
||||
let storage = REMOTE_STORAGE
|
||||
.get()
|
||||
.context("Failed to get remote storage")?
|
||||
|
||||
@@ -2,50 +2,65 @@
|
||||
//! WAL service listens for client connections and
|
||||
//! receive WAL from wal_proposer and send it to WAL receivers
|
||||
//!
|
||||
use regex::Regex;
|
||||
use std::net::{TcpListener, TcpStream};
|
||||
use std::thread;
|
||||
use anyhow::{Context, Result};
|
||||
use nix::unistd::gettid;
|
||||
use postgres_backend::QueryError;
|
||||
use std::{future, thread};
|
||||
use tokio::net::TcpStream;
|
||||
use tracing::*;
|
||||
use utils::postgres_backend::QueryError;
|
||||
|
||||
use crate::handler::SafekeeperPostgresHandler;
|
||||
use crate::SafeKeeperConf;
|
||||
use utils::postgres_backend::{AuthType, PostgresBackend};
|
||||
use postgres_backend::{AuthType, PostgresBackend};
|
||||
|
||||
/// Accept incoming TCP connections and spawn them into a background thread.
|
||||
pub fn thread_main(conf: SafeKeeperConf, listener: TcpListener) -> ! {
|
||||
loop {
|
||||
match listener.accept() {
|
||||
Ok((socket, peer_addr)) => {
|
||||
debug!("accepted connection from {}", peer_addr);
|
||||
let conf = conf.clone();
|
||||
pub fn thread_main(conf: SafeKeeperConf, pg_listener: std::net::TcpListener) {
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.context("create runtime")
|
||||
// todo catch error in main thread
|
||||
.expect("failed to create runtime");
|
||||
|
||||
let _ = thread::Builder::new()
|
||||
.name("WAL service thread".into())
|
||||
.spawn(move || {
|
||||
if let Err(err) = handle_socket(socket, conf) {
|
||||
error!("connection handler exited: {}", err);
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
runtime
|
||||
.block_on(async move {
|
||||
// Tokio's from_std won't do this for us, per its comment.
|
||||
pg_listener.set_nonblocking(true)?;
|
||||
let listener = tokio::net::TcpListener::from_std(pg_listener)?;
|
||||
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((socket, peer_addr)) => {
|
||||
debug!("accepted connection from {}", peer_addr);
|
||||
let conf = conf.clone();
|
||||
|
||||
let _ = thread::Builder::new()
|
||||
.name("WAL service thread".into())
|
||||
.spawn(move || {
|
||||
if let Err(err) = handle_socket(socket, conf) {
|
||||
error!("connection handler exited: {}", err);
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
Err(e) => error!("Failed to accept connection: {}", e),
|
||||
}
|
||||
}
|
||||
Err(e) => error!("Failed to accept connection: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get unique thread id (Rust internal), with ThreadId removed for shorter printing
|
||||
fn get_tid() -> u64 {
|
||||
let tids = format!("{:?}", thread::current().id());
|
||||
let r = Regex::new(r"ThreadId\((\d+)\)").unwrap();
|
||||
let caps = r.captures(&tids).unwrap();
|
||||
caps.get(1).unwrap().as_str().parse().unwrap()
|
||||
#[allow(unreachable_code)] // hint compiler the closure return type
|
||||
Ok::<(), anyhow::Error>(())
|
||||
})
|
||||
.expect("listener failed")
|
||||
}
|
||||
|
||||
/// This is run by `thread_main` above, inside a background thread.
|
||||
///
|
||||
fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryError> {
|
||||
let _enter = info_span!("", tid = ?get_tid()).entered();
|
||||
let _enter = info_span!("", tid = %gettid()).entered();
|
||||
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
let local = tokio::task::LocalSet::new();
|
||||
|
||||
socket.set_nodelay(true)?;
|
||||
|
||||
@@ -54,9 +69,13 @@ fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryErr
|
||||
Some(_) => AuthType::NeonJWT,
|
||||
};
|
||||
let mut conn_handler = SafekeeperPostgresHandler::new(conf);
|
||||
let pgbackend = PostgresBackend::new(socket, auth_type, None, false)?;
|
||||
// libpq replication protocol between safekeeper and replicas/pagers
|
||||
pgbackend.run(&mut conn_handler)?;
|
||||
let pgbackend = PostgresBackend::new(socket, auth_type, None)?;
|
||||
// libpq protocol between safekeeper and walproposer / pageserver
|
||||
// We don't use shutdown.
|
||||
local.block_on(
|
||||
&runtime,
|
||||
pgbackend.run(&mut conn_handler, future::pending::<()>),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -471,7 +471,7 @@ pub struct WalReader {
|
||||
timeline_dir: PathBuf,
|
||||
wal_seg_size: usize,
|
||||
pos: Lsn,
|
||||
wal_segment: Option<Pin<Box<dyn AsyncRead>>>,
|
||||
wal_segment: Option<Pin<Box<dyn AsyncRead + Send + Sync>>>,
|
||||
|
||||
// S3 will be used to read WAL if LSN is not available locally
|
||||
enable_remote_read: bool,
|
||||
@@ -538,7 +538,7 @@ impl WalReader {
|
||||
}
|
||||
|
||||
/// Open WAL segment at the current position of the reader.
|
||||
async fn open_segment(&self) -> Result<Pin<Box<dyn AsyncRead>>> {
|
||||
async fn open_segment(&self) -> Result<Pin<Box<dyn AsyncRead + Send + Sync>>> {
|
||||
let xlogoff = self.pos.segment_offset(self.wal_seg_size);
|
||||
let segno = self.pos.segment_number(self.wal_seg_size);
|
||||
let wal_file_name = XLogFileName(PG_TLI, segno, self.wal_seg_size);
|
||||
|
||||
@@ -2068,8 +2068,10 @@ class NeonPageserver(PgProtocol):
|
||||
".*Connection aborted: connection error: error communicating with the server: Broken pipe.*",
|
||||
".*Connection aborted: connection error: error communicating with the server: Transport endpoint is not connected.*",
|
||||
".*Connection aborted: connection error: error communicating with the server: Connection reset by peer.*",
|
||||
# FIXME: replication patch for tokio_postgres regards any but CopyDone/CopyData message in CopyBoth stream as unexpected
|
||||
".*Connection aborted: connection error: unexpected message from server*",
|
||||
".*kill_and_wait_impl.*: wait successful.*",
|
||||
".*Replication stream finished: db error: ERROR: Socket IO error: end streaming to Some.*",
|
||||
".*Replication stream finished: db error:.*ending streaming to Some*",
|
||||
".*query handler for 'pagestream.*failed: Broken pipe.*", # pageserver notices compute shut down
|
||||
".*query handler for 'pagestream.*failed: Connection reset by peer.*", # pageserver notices compute shut down
|
||||
# safekeeper connection can fail with this, in the window between timeline creation
|
||||
|
||||
@@ -1138,8 +1138,8 @@ def test_delete_force(neon_env_builder: NeonEnvBuilder, auth_enabled: bool):
|
||||
# FIXME: are these expected?
|
||||
env.pageserver.allowed_errors.extend(
|
||||
[
|
||||
".*Failed to process query for timeline .*: Timeline .* was not found in global map.*",
|
||||
".*Failed to process query for timeline .*: Timeline .* was cancelled and cannot be used anymore.*",
|
||||
".*Timeline .* was not found in global map.*",
|
||||
".*Timeline .* was cancelled and cannot be used anymore.*",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -14,14 +14,19 @@ publish = false
|
||||
### BEGIN HAKARI SECTION
|
||||
[dependencies]
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
byteorder = { version = "1" }
|
||||
bytes = { version = "1", features = ["serde"] }
|
||||
chrono = { version = "0.4", default-features = false, features = ["clock", "serde"] }
|
||||
clap = { version = "4", features = ["derive", "string"] }
|
||||
crossbeam-utils = { version = "0.8" }
|
||||
digest = { version = "0.10", features = ["mac", "std"] }
|
||||
either = { version = "1" }
|
||||
fail = { version = "0.5", default-features = false, features = ["failpoints"] }
|
||||
futures = { version = "0.3" }
|
||||
futures-channel = { version = "0.3", features = ["sink"] }
|
||||
futures-core = { version = "0.3" }
|
||||
futures-executor = { version = "0.3" }
|
||||
futures-sink = { version = "0.3" }
|
||||
futures-util = { version = "0.3", features = ["channel", "io", "sink"] }
|
||||
hashbrown = { version = "0.12", features = ["raw"] }
|
||||
indexmap = { version = "1", default-features = false, features = ["std"] }
|
||||
@@ -45,6 +50,7 @@ serde = { version = "1", features = ["alloc", "derive"] }
|
||||
serde_json = { version = "1", features = ["raw_value"] }
|
||||
socket2 = { version = "0.4", default-features = false, features = ["all"] }
|
||||
tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "sync", "time"] }
|
||||
tokio-rustls = { version = "0.23" }
|
||||
tokio-util = { version = "0.7", features = ["codec", "io"] }
|
||||
tonic = { version = "0.8", features = ["tls-roots"] }
|
||||
tower = { version = "0.4", features = ["balance", "buffer", "limit", "retry", "timeout", "util"] }
|
||||
|
||||
Reference in New Issue
Block a user