mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-17 10:22:56 +00:00
Compare commits
4 Commits
release-pr
...
asher/sk-w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f281dc5953 | ||
|
|
48fb085ebd | ||
|
|
2bbd24edbf | ||
|
|
5e972ccdc4 |
7
Cargo.lock
generated
7
Cargo.lock
generated
@@ -2510,6 +2510,7 @@ name = "pq_proto"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"pin-project-lite",
|
||||
"postgres-protocol",
|
||||
@@ -2517,6 +2518,7 @@ dependencies = [
|
||||
"serde",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"workspace_hack",
|
||||
]
|
||||
@@ -3074,6 +3076,7 @@ dependencies = [
|
||||
"const_format",
|
||||
"crc32c",
|
||||
"fs2",
|
||||
"futures",
|
||||
"git-version",
|
||||
"hex",
|
||||
"humantime",
|
||||
@@ -3082,6 +3085,7 @@ dependencies = [
|
||||
"nix",
|
||||
"once_cell",
|
||||
"parking_lot",
|
||||
"pin-project-lite",
|
||||
"postgres",
|
||||
"postgres-protocol",
|
||||
"postgres_ffi",
|
||||
@@ -4203,6 +4207,7 @@ dependencies = [
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"criterion",
|
||||
"futures",
|
||||
"git-version",
|
||||
"hex",
|
||||
"hex-literal",
|
||||
@@ -4211,6 +4216,7 @@ dependencies = [
|
||||
"metrics",
|
||||
"nix",
|
||||
"once_cell",
|
||||
"pin-utils",
|
||||
"pq_proto",
|
||||
"rand",
|
||||
"routerify",
|
||||
@@ -4228,6 +4234,7 @@ dependencies = [
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"workspace_hack",
|
||||
|
||||
@@ -14,7 +14,7 @@ use anyhow::{Context, Result};
|
||||
use utils::{
|
||||
id::{TenantId, TimelineId},
|
||||
lsn::Lsn,
|
||||
postgres_backend::AuthType,
|
||||
postgres_backend_async::AuthType,
|
||||
};
|
||||
|
||||
use crate::local_env::{LocalEnv, DEFAULT_PG_VERSION};
|
||||
|
||||
@@ -19,7 +19,7 @@ use std::process::{Command, Stdio};
|
||||
use utils::{
|
||||
auth::{encode_from_key_file, Claims, Scope},
|
||||
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
|
||||
postgres_backend::AuthType,
|
||||
postgres_backend_async::AuthType,
|
||||
};
|
||||
|
||||
use crate::safekeeper::SafekeeperNode;
|
||||
|
||||
@@ -7,11 +7,13 @@ license = "Apache-2.0"
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
bytes = "1.0.1"
|
||||
byteorder = "1.4.3"
|
||||
pin-project-lite = "0.2.7"
|
||||
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
|
||||
rand = "0.8.3"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
tokio = { version = "1.17", features = ["macros"] }
|
||||
tokio-util = { version = "0.7.3" }
|
||||
tracing = "0.1"
|
||||
thiserror = "1.0"
|
||||
|
||||
|
||||
62
libs/pq_proto/src/codec.rs
Normal file
62
libs/pq_proto/src/codec.rs
Normal file
@@ -0,0 +1,62 @@
|
||||
//! Provides `PostgresCodec` defining how to serilize/deserialize Postgres
|
||||
//! messages to/from the wire, to be used with `tokio_util::codec::Framed`.
|
||||
use std::io;
|
||||
|
||||
use bytes::BytesMut;
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
use crate::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
|
||||
|
||||
// Defines how to serilize/deserialize Postgres messages to/from the wire, to be
|
||||
// used with `tokio_util::codec::Framed`.
|
||||
pub struct PostgresCodec {
|
||||
// Have we already decoded startup message? All further should start with
|
||||
// message type byte then.
|
||||
startup_read: bool,
|
||||
}
|
||||
|
||||
impl PostgresCodec {
|
||||
pub fn new() -> Self {
|
||||
PostgresCodec {
|
||||
startup_read: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Error on postgres connection: either IO (physical transport error) or
|
||||
/// protocol violation.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ConnectionError {
|
||||
#[error(transparent)]
|
||||
Io(#[from] io::Error),
|
||||
#[error(transparent)]
|
||||
Protocol(#[from] ProtocolError),
|
||||
}
|
||||
|
||||
impl Encoder<&BeMessage<'_>> for PostgresCodec {
|
||||
type Error = ConnectionError;
|
||||
|
||||
fn encode(&mut self, item: &BeMessage, dst: &mut BytesMut) -> Result<(), ConnectionError> {
|
||||
BeMessage::write(dst, &item)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for PostgresCodec {
|
||||
type Item = FeMessage;
|
||||
type Error = ConnectionError;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
let msg = if !self.startup_read {
|
||||
let msg = FeStartupPacket::parse(src);
|
||||
if let Ok(Some(FeMessage::StartupPacket(FeStartupPacket::StartupMessage { .. }))) = msg
|
||||
{
|
||||
self.startup_read = true;
|
||||
}
|
||||
msg?
|
||||
} else {
|
||||
FeMessage::parse(src)?
|
||||
};
|
||||
Ok(msg)
|
||||
}
|
||||
}
|
||||
@@ -3,9 +3,11 @@
|
||||
//! on message formats.
|
||||
|
||||
// Tools for calling certain async methods in sync contexts.
|
||||
pub mod codec;
|
||||
pub mod sync;
|
||||
|
||||
use anyhow::{ensure, Context, Result};
|
||||
use anyhow::{anyhow, bail, ensure, Context, Result};
|
||||
use byteorder::{BigEndian, ByteOrder, ReadBytesExt};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use postgres_protocol::PG_EPOCH;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -19,7 +21,7 @@ use std::{
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
use sync::{AsyncishRead, SyncFuture};
|
||||
use tokio::io::AsyncReadExt;
|
||||
// use tokio::io::AsyncReadExt;
|
||||
use tracing::{trace, warn};
|
||||
|
||||
pub type Oid = u32;
|
||||
@@ -194,36 +196,108 @@ macro_rules! retry_read {
|
||||
};
|
||||
}
|
||||
|
||||
/// An error occured during connection being open.
|
||||
/// An error occured while parsing or serializing raw stream into Postgres
|
||||
/// messages.
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum ConnectionError {
|
||||
pub enum ProtocolError {
|
||||
/// IO error during writing to or reading from the connection socket.
|
||||
/// removeme
|
||||
#[error("Socket IO error: {0}")]
|
||||
Socket(std::io::Error),
|
||||
/// Invalid packet was received from client
|
||||
/// Invalid packet was received from the client (e.g. unexpected message
|
||||
/// type or broken len).
|
||||
#[error("Protocol error: {0}")]
|
||||
Protocol(String),
|
||||
/// Failed to parse a protocol mesage
|
||||
/// Failed to parse or, (unlikely), serialize a protocol message.
|
||||
#[error("Message parse error: {0}")]
|
||||
MessageParse(anyhow::Error),
|
||||
}
|
||||
|
||||
impl From<anyhow::Error> for ConnectionError {
|
||||
// Allows to return anyhow error from msg parsing routines, meaning less typing.
|
||||
impl From<anyhow::Error> for ProtocolError {
|
||||
fn from(e: anyhow::Error) -> Self {
|
||||
Self::MessageParse(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnectionError {
|
||||
impl ProtocolError {
|
||||
pub fn into_io_error(self) -> io::Error {
|
||||
match self {
|
||||
ConnectionError::Socket(io) => io,
|
||||
ProtocolError::Socket(io) => io,
|
||||
other => io::Error::new(io::ErrorKind::Other, other.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FeMessage {
|
||||
/// Read and parse one message from the `buf` input buffer. If there is at
|
||||
/// least one valid message, returns it, advancing `buf`; redundant copies
|
||||
/// are avoided, as thanks to `bytes` crate ptrs in parsed message point
|
||||
/// directly into the `buf` (processed data is garbage collected after
|
||||
/// parsed message is dropped).
|
||||
///
|
||||
/// Returns None if `buf` doesn't contain enough data for a single message.
|
||||
/// For efficiency, tries to reserve large enough space in `buf` for the
|
||||
/// next message in this case.
|
||||
///
|
||||
/// Returns Error if message is malformed, the only possible ErrorKind is
|
||||
/// InvalidInput.
|
||||
//
|
||||
// Inspired by rust-postgres Message::parse.
|
||||
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
|
||||
// Every message contains message type byte and 4 bytes len; can't do
|
||||
// much without them.
|
||||
if buf.len() < 5 {
|
||||
let to_read = 5 - buf.len();
|
||||
buf.reserve(to_read);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// We shouldn't advance `buf` as probably full message is not there yet,
|
||||
// so can't directly use Bytes::get_u32 etc.
|
||||
let tag = buf[0];
|
||||
let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
|
||||
if len < 4 {
|
||||
return Err(ProtocolError::Protocol(format!(
|
||||
"invalid message length {}",
|
||||
len
|
||||
)));
|
||||
}
|
||||
|
||||
// lenth field includes itself, but not message type.
|
||||
let total_len = len as usize + 1;
|
||||
if buf.len() < total_len {
|
||||
// Don't have full message yet.
|
||||
let to_read = total_len - buf.len();
|
||||
buf.reserve(to_read);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// got the message, advance buffer
|
||||
let mut msg = buf.split_to(total_len).freeze();
|
||||
msg.advance(5); // consume message type and len
|
||||
|
||||
match tag {
|
||||
b'Q' => Ok(Some(FeMessage::Query(msg))),
|
||||
b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
|
||||
b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
|
||||
b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
|
||||
b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
|
||||
b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
|
||||
b'S' => Ok(Some(FeMessage::Sync)),
|
||||
b'X' => Ok(Some(FeMessage::Terminate)),
|
||||
b'd' => Ok(Some(FeMessage::CopyData(msg))),
|
||||
b'c' => Ok(Some(FeMessage::CopyDone)),
|
||||
b'f' => Ok(Some(FeMessage::CopyFail)),
|
||||
b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
|
||||
tag => {
|
||||
return Err(ProtocolError::Protocol(format!(
|
||||
"unknown message tag: {tag},'{msg:?}'"
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read one message from the stream.
|
||||
/// This function returns `Ok(None)` in case of EOF.
|
||||
/// One way to handle this properly:
|
||||
@@ -245,68 +319,8 @@ impl FeMessage {
|
||||
/// }
|
||||
/// ```
|
||||
#[inline(never)]
|
||||
pub fn read(
|
||||
stream: &mut (impl io::Read + Unpin),
|
||||
) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
Self::read_fut(&mut AsyncishRead(stream)).wait()
|
||||
}
|
||||
|
||||
/// Read one message from the stream.
|
||||
/// See documentation for `Self::read`.
|
||||
pub fn read_fut<Reader>(
|
||||
stream: &mut Reader,
|
||||
) -> SyncFuture<Reader, impl Future<Output = Result<Option<FeMessage>, ConnectionError>> + '_>
|
||||
where
|
||||
Reader: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
// We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof.
|
||||
// SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and
|
||||
// AsyncReadExt methods of the stream.
|
||||
SyncFuture::new(async move {
|
||||
// Each libpq message begins with a message type byte, followed by message length
|
||||
// If the client closes the connection, return None. But if the client closes the
|
||||
// connection in the middle of a message, we will return an error.
|
||||
let tag = match retry_read!(stream.read_u8().await) {
|
||||
Ok(b) => b,
|
||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||
Err(e) => return Err(ConnectionError::Socket(e)),
|
||||
};
|
||||
|
||||
// The message length includes itself, so it better be at least 4.
|
||||
let len = retry_read!(stream.read_u32().await)
|
||||
.map_err(ConnectionError::Socket)?
|
||||
.checked_sub(4)
|
||||
.ok_or_else(|| ConnectionError::Protocol("invalid message length".to_string()))?;
|
||||
|
||||
let body = {
|
||||
let mut buffer = vec![0u8; len as usize];
|
||||
stream
|
||||
.read_exact(&mut buffer)
|
||||
.await
|
||||
.map_err(ConnectionError::Socket)?;
|
||||
Bytes::from(buffer)
|
||||
};
|
||||
|
||||
match tag {
|
||||
b'Q' => Ok(Some(FeMessage::Query(body))),
|
||||
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
|
||||
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
|
||||
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
|
||||
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
|
||||
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
|
||||
b'S' => Ok(Some(FeMessage::Sync)),
|
||||
b'X' => Ok(Some(FeMessage::Terminate)),
|
||||
b'd' => Ok(Some(FeMessage::CopyData(body))),
|
||||
b'c' => Ok(Some(FeMessage::CopyDone)),
|
||||
b'f' => Ok(Some(FeMessage::CopyFail)),
|
||||
b'p' => Ok(Some(FeMessage::PasswordMessage(body))),
|
||||
tag => {
|
||||
return Err(ConnectionError::Protocol(format!(
|
||||
"unknown message tag: {tag},'{body:?}'"
|
||||
)))
|
||||
}
|
||||
}
|
||||
})
|
||||
pub fn read(_stream: &mut (impl io::Read + Unpin)) -> Result<Option<FeMessage>, ProtocolError> {
|
||||
Ok(None) // removeme
|
||||
}
|
||||
}
|
||||
|
||||
@@ -314,21 +328,124 @@ impl FeStartupPacket {
|
||||
/// Read startup message from the stream.
|
||||
// XXX: It's tempting yet undesirable to accept `stream` by value,
|
||||
// since such a change will cause user-supplied &mut references to be consumed
|
||||
pub fn read(
|
||||
stream: &mut (impl io::Read + Unpin),
|
||||
) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
pub fn read(stream: &mut (impl io::Read + Unpin)) -> Result<Option<FeMessage>, ProtocolError> {
|
||||
Self::read_fut(&mut AsyncishRead(stream)).wait()
|
||||
}
|
||||
|
||||
/// Read and parse startup message from the `buf` input buffer. It is
|
||||
/// different from [`FeMessage::parse`] because startup messages don't have
|
||||
/// message type byte; otherwise, its comments apply.
|
||||
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
|
||||
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
|
||||
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
|
||||
const CANCEL_REQUEST_CODE: u32 = 5678;
|
||||
const NEGOTIATE_SSL_CODE: u32 = 5679;
|
||||
const NEGOTIATE_GSS_CODE: u32 = 5680;
|
||||
|
||||
if buf.len() < 4 {
|
||||
let to_read = 5 - buf.len();
|
||||
buf.reserve(to_read);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// We shouldn't advance `buf` as probably full message is not there yet,
|
||||
// so can't directly use Bytes::get_u32 etc.
|
||||
let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
|
||||
if len < 8 || len > MAX_STARTUP_PACKET_LENGTH {
|
||||
return Err(ProtocolError::Protocol(format!(
|
||||
"invalid startup packet message length {}",
|
||||
len
|
||||
)));
|
||||
}
|
||||
|
||||
if buf.len() < len {
|
||||
// Don't have full message yet.
|
||||
let to_read = len - buf.len();
|
||||
buf.reserve(to_read);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// got the message, advance buffer
|
||||
let mut msg = buf.split_to(len).freeze();
|
||||
msg.advance(4); // consume len
|
||||
|
||||
let request_code = msg.get_u32();
|
||||
let req_hi = request_code >> 16;
|
||||
let req_lo = request_code & ((1 << 16) - 1);
|
||||
// StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
|
||||
let message = match (req_hi, req_lo) {
|
||||
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
|
||||
if msg.remaining() < 8 {
|
||||
return Err(ProtocolError::MessageParse(anyhow!(
|
||||
"CancelRequest message is malformed, backend PID / secret key missing"
|
||||
)));
|
||||
}
|
||||
FeStartupPacket::CancelRequest(CancelKeyData {
|
||||
backend_pid: msg.get_i32(),
|
||||
cancel_key: msg.get_i32(),
|
||||
})
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
|
||||
// Requested upgrade to SSL (aka TLS)
|
||||
FeStartupPacket::SslRequest
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
|
||||
// Requested upgrade to GSSAPI
|
||||
FeStartupPacket::GssEncRequest
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
|
||||
return Err(ProtocolError::Protocol(format!(
|
||||
"Unrecognized request code {unrecognized_code}"
|
||||
)));
|
||||
}
|
||||
// TODO bail if protocol major_version is not 3?
|
||||
(major_version, minor_version) => {
|
||||
// StartupMessage
|
||||
|
||||
// Parse pairs of null-terminated strings (key, value).
|
||||
// See `postgres: ProcessStartupPacket, build_startup_packet`.
|
||||
let mut tokens = str::from_utf8(&msg)
|
||||
.context("StartupMessage params: invalid utf-8")?
|
||||
.strip_suffix('\0') // drop packet's own null
|
||||
.ok_or_else(|| {
|
||||
ProtocolError::Protocol(
|
||||
"StartupMessage params: missing null terminator".to_string(),
|
||||
)
|
||||
})?
|
||||
.split_terminator('\0');
|
||||
|
||||
let mut params = HashMap::new();
|
||||
while let Some(name) = tokens.next() {
|
||||
let value = tokens.next().ok_or_else(|| {
|
||||
ProtocolError::Protocol(
|
||||
"StartupMessage params: key without value".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
params.insert(name.to_owned(), value.to_owned());
|
||||
}
|
||||
|
||||
FeStartupPacket::StartupMessage {
|
||||
major_version,
|
||||
minor_version,
|
||||
params: StartupMessageParams { params },
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(Some(FeMessage::StartupPacket(message)))
|
||||
}
|
||||
|
||||
/// Read startup message from the stream.
|
||||
// XXX: It's tempting yet undesirable to accept `stream` by value,
|
||||
// since such a change will cause user-supplied &mut references to be consumed
|
||||
pub fn read_fut<Reader>(
|
||||
stream: &mut Reader,
|
||||
) -> SyncFuture<Reader, impl Future<Output = Result<Option<FeMessage>, ConnectionError>> + '_>
|
||||
) -> SyncFuture<Reader, impl Future<Output = Result<Option<FeMessage>, ProtocolError>> + '_>
|
||||
where
|
||||
Reader: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
|
||||
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
|
||||
const CANCEL_REQUEST_CODE: u32 = 5678;
|
||||
@@ -343,18 +460,18 @@ impl FeStartupPacket {
|
||||
let len = match retry_read!(stream.read_u32().await) {
|
||||
Ok(len) => len as usize,
|
||||
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||
Err(e) => return Err(ConnectionError::Socket(e)),
|
||||
Err(e) => return Err(ProtocolError::Socket(e)),
|
||||
};
|
||||
|
||||
#[allow(clippy::manual_range_contains)]
|
||||
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
|
||||
return Err(ConnectionError::Protocol(format!(
|
||||
return Err(ProtocolError::Protocol(format!(
|
||||
"invalid message length {len}"
|
||||
)));
|
||||
}
|
||||
|
||||
let request_code =
|
||||
retry_read!(stream.read_u32().await).map_err(ConnectionError::Socket)?;
|
||||
retry_read!(stream.read_u32().await).map_err(ProtocolError::Socket)?;
|
||||
|
||||
// the rest of startup packet are params
|
||||
let params_len = len - 8;
|
||||
@@ -362,7 +479,7 @@ impl FeStartupPacket {
|
||||
stream
|
||||
.read_exact(params_bytes.as_mut())
|
||||
.await
|
||||
.map_err(ConnectionError::Socket)?;
|
||||
.map_err(ProtocolError::Socket)?;
|
||||
|
||||
// Parse params depending on request code
|
||||
let req_hi = request_code >> 16;
|
||||
@@ -370,14 +487,16 @@ impl FeStartupPacket {
|
||||
let message = match (req_hi, req_lo) {
|
||||
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
|
||||
if params_len != 8 {
|
||||
return Err(ConnectionError::Protocol(
|
||||
return Err(ProtocolError::Protocol(
|
||||
"expected 8 bytes for CancelRequest params".to_string(),
|
||||
));
|
||||
}
|
||||
let mut cursor = Cursor::new(params_bytes);
|
||||
FeStartupPacket::CancelRequest(CancelKeyData {
|
||||
backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
|
||||
cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
|
||||
backend_pid: 2,
|
||||
cancel_key: 2,
|
||||
// backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
|
||||
// cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
|
||||
})
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
|
||||
@@ -389,7 +508,7 @@ impl FeStartupPacket {
|
||||
FeStartupPacket::GssEncRequest
|
||||
}
|
||||
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
|
||||
return Err(ConnectionError::Protocol(format!(
|
||||
return Err(ProtocolError::Protocol(format!(
|
||||
"Unrecognized request code {unrecognized_code}"
|
||||
)));
|
||||
}
|
||||
@@ -401,7 +520,7 @@ impl FeStartupPacket {
|
||||
.context("StartupMessage params: invalid utf-8")?
|
||||
.strip_suffix('\0') // drop packet's own null
|
||||
.ok_or_else(|| {
|
||||
ConnectionError::Protocol(
|
||||
ProtocolError::Protocol(
|
||||
"StartupMessage params: missing null terminator".to_string(),
|
||||
)
|
||||
})?
|
||||
@@ -410,7 +529,7 @@ impl FeStartupPacket {
|
||||
let mut params = HashMap::new();
|
||||
while let Some(name) = tokens.next() {
|
||||
let value = tokens.next().ok_or_else(|| {
|
||||
ConnectionError::Protocol(
|
||||
ProtocolError::Protocol(
|
||||
"StartupMessage params: key without value".to_string(),
|
||||
)
|
||||
})?;
|
||||
@@ -440,6 +559,9 @@ impl FeParseMessage {
|
||||
|
||||
let _pstmt_name = read_cstr(&mut buf)?;
|
||||
let query_string = read_cstr(&mut buf)?;
|
||||
if buf.remaining() < 2 {
|
||||
bail!("Parse message is malformed, nparams missing");
|
||||
}
|
||||
let nparams = buf.get_i16();
|
||||
|
||||
ensure!(nparams == 0, "query params not implemented");
|
||||
@@ -466,6 +588,9 @@ impl FeDescribeMessage {
|
||||
impl FeExecuteMessage {
|
||||
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
|
||||
let portal_name = read_cstr(&mut buf)?;
|
||||
if buf.remaining() < 4 {
|
||||
bail!("FeExecuteMessage message is malformed, maxrows missing");
|
||||
}
|
||||
let maxrows = buf.get_i32();
|
||||
|
||||
ensure!(portal_name.is_empty(), "named portals not implemented");
|
||||
@@ -547,6 +672,11 @@ impl<'a> BeMessage<'a> {
|
||||
value: b"UTF8",
|
||||
};
|
||||
|
||||
pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
|
||||
name: b"integer_datetimes",
|
||||
value: b"on",
|
||||
};
|
||||
|
||||
/// Build a [`BeMessage::ParameterStatus`] holding the server version.
|
||||
pub fn server_version(version: &'a str) -> Self {
|
||||
Self::ParameterStatus {
|
||||
@@ -665,13 +795,12 @@ fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
|
||||
}
|
||||
|
||||
/// Safe write of s into buf as cstring (String in the protocol).
|
||||
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> {
|
||||
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
|
||||
let bytes = s.as_ref();
|
||||
if bytes.contains(&0) {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"string contains embedded null",
|
||||
));
|
||||
return Err(ProtocolError::MessageParse(anyhow!(
|
||||
"string contains embedded null"
|
||||
)));
|
||||
}
|
||||
buf.put_slice(bytes);
|
||||
buf.put_u8(0);
|
||||
@@ -680,7 +809,7 @@ fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> {
|
||||
|
||||
fn read_cstr(buf: &mut Bytes) -> anyhow::Result<Bytes> {
|
||||
let pos = buf.iter().position(|x| *x == 0);
|
||||
let result = buf.split_to(pos.context("missing terminator")?);
|
||||
let result = buf.split_to(pos.context("missing cstring terminator")?);
|
||||
buf.advance(1); // drop the null terminator
|
||||
Ok(result)
|
||||
}
|
||||
@@ -688,12 +817,12 @@ fn read_cstr(buf: &mut Bytes) -> anyhow::Result<Bytes> {
|
||||
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
|
||||
|
||||
impl<'a> BeMessage<'a> {
|
||||
/// Write message to the given buf.
|
||||
// Unlike the reading side, we use BytesMut
|
||||
// here as msg len precedes its body and it is handy to write it down first
|
||||
// and then fill the length. With Write we would have to either calc it
|
||||
// manually or have one more buffer.
|
||||
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> io::Result<()> {
|
||||
/// Serialize `message` to the given `buf`.
|
||||
/// Apart from smart memory managemet, BytesMut is good here as msg len
|
||||
/// precedes its body and it is handy to write it down first and then fill
|
||||
/// the length. With Write we would have to either calc it manually or have
|
||||
/// one more buffer.
|
||||
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
|
||||
match message {
|
||||
BeMessage::AuthenticationOk => {
|
||||
buf.put_u8(b'R');
|
||||
@@ -719,7 +848,7 @@ impl<'a> BeMessage<'a> {
|
||||
|
||||
BeMessage::AuthenticationSasl(msg) => {
|
||||
buf.put_u8(b'R');
|
||||
write_body(buf, |buf| {
|
||||
write_body(buf, |buf| -> Result<(), ProtocolError> {
|
||||
use BeAuthenticationSaslMessage::*;
|
||||
match msg {
|
||||
Methods(methods) => {
|
||||
@@ -738,7 +867,7 @@ impl<'a> BeMessage<'a> {
|
||||
buf.put_slice(extra);
|
||||
}
|
||||
}
|
||||
Ok::<_, io::Error>(())
|
||||
Ok(())
|
||||
})?;
|
||||
}
|
||||
|
||||
@@ -829,7 +958,7 @@ impl<'a> BeMessage<'a> {
|
||||
BeMessage::ErrorResponse(error_msg, pg_error_code) => {
|
||||
// 'E' signalizes ErrorResponse messages
|
||||
buf.put_u8(b'E');
|
||||
write_body(buf, |buf| {
|
||||
write_body(buf, |buf| -> Result<(), ProtocolError> {
|
||||
buf.put_u8(b'S'); // severity
|
||||
buf.put_slice(b"ERROR\0");
|
||||
|
||||
@@ -842,7 +971,7 @@ impl<'a> BeMessage<'a> {
|
||||
write_cstr(error_msg, buf)?;
|
||||
|
||||
buf.put_u8(0); // terminator
|
||||
Ok::<_, io::Error>(())
|
||||
Ok(())
|
||||
})?;
|
||||
}
|
||||
|
||||
@@ -854,7 +983,7 @@ impl<'a> BeMessage<'a> {
|
||||
|
||||
// 'N' signalizes NoticeResponse messages
|
||||
buf.put_u8(b'N');
|
||||
write_body(buf, |buf| {
|
||||
write_body(buf, |buf| -> Result<(), ProtocolError> {
|
||||
buf.put_u8(b'S'); // severity
|
||||
buf.put_slice(b"NOTICE\0");
|
||||
|
||||
@@ -865,7 +994,7 @@ impl<'a> BeMessage<'a> {
|
||||
write_cstr(error_msg.as_bytes(), buf)?;
|
||||
|
||||
buf.put_u8(0); // terminator
|
||||
Ok::<_, io::Error>(())
|
||||
Ok(())
|
||||
})?;
|
||||
}
|
||||
|
||||
@@ -909,7 +1038,7 @@ impl<'a> BeMessage<'a> {
|
||||
|
||||
BeMessage::RowDescription(rows) => {
|
||||
buf.put_u8(b'T');
|
||||
write_body(buf, |buf| {
|
||||
write_body(buf, |buf| -> Result<(), ProtocolError> {
|
||||
buf.put_i16(rows.len() as i16); // # of fields
|
||||
for row in rows.iter() {
|
||||
write_cstr(row.name, buf)?;
|
||||
@@ -920,7 +1049,7 @@ impl<'a> BeMessage<'a> {
|
||||
buf.put_i32(-1); /* typmod */
|
||||
buf.put_i16(0); /* format code */
|
||||
}
|
||||
Ok::<_, io::Error>(())
|
||||
Ok(())
|
||||
})?;
|
||||
}
|
||||
|
||||
|
||||
@@ -111,7 +111,7 @@ pub trait RemoteStorage: Send + Sync + 'static {
|
||||
}
|
||||
|
||||
pub struct Download {
|
||||
pub download_stream: Pin<Box<dyn io::AsyncRead + Unpin + Send>>,
|
||||
pub download_stream: Pin<Box<dyn io::AsyncRead + Unpin + Send + Sync>>,
|
||||
/// Extra key-value data, associated with the current remote file.
|
||||
pub metadata: Option<StorageMetadata>,
|
||||
}
|
||||
|
||||
@@ -10,13 +10,16 @@ async-trait = "0.1"
|
||||
anyhow = "1.0"
|
||||
bincode = "1.3"
|
||||
bytes = "1.0.1"
|
||||
futures = "0.3"
|
||||
hyper = { version = "0.14.7", features = ["full"] }
|
||||
pin-utils = "0.1"
|
||||
routerify = "3"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
thiserror = "1.0"
|
||||
tokio = { version = "1.17", features = ["macros"]}
|
||||
tokio-rustls = "0.23"
|
||||
tokio-util = { version = "0.7.3" }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
|
||||
nix = "0.25"
|
||||
|
||||
@@ -13,7 +13,7 @@ pub mod simple_rcu;
|
||||
pub mod vec_map;
|
||||
|
||||
pub mod bin_ser;
|
||||
pub mod postgres_backend;
|
||||
// pub mod postgres_backend;
|
||||
pub mod postgres_backend_async;
|
||||
|
||||
// helper functions for creating and fsyncing
|
||||
@@ -52,6 +52,8 @@ pub mod signals;
|
||||
|
||||
pub mod fs_ext;
|
||||
|
||||
pub mod send_rc;
|
||||
|
||||
/// use with fail::cfg("$name", "return(2000)")
|
||||
#[macro_export]
|
||||
macro_rules! failpoint_sleep_millis_async {
|
||||
|
||||
@@ -2,29 +2,24 @@
|
||||
//! To use, create PostgresBackend and run() it, passing the Handler
|
||||
//! implementation determining how to process the queries. Currently its API
|
||||
//! is rather narrow, but we can extend it once required.
|
||||
|
||||
use crate::postgres_backend::AuthType;
|
||||
use anyhow::Context;
|
||||
use bytes::{Buf, Bytes, BytesMut};
|
||||
use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR};
|
||||
use std::future::Future;
|
||||
use std::io;
|
||||
use futures::stream::StreamExt;
|
||||
use futures::{pin_mut, Sink, SinkExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::Poll;
|
||||
use tracing::{debug, error, info, trace};
|
||||
|
||||
use std::{fmt, io};
|
||||
use std::{future::Future, str::FromStr};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tokio_util::codec::Framed;
|
||||
use tracing::{debug, error, info, trace};
|
||||
|
||||
pub fn is_expected_io_error(e: &io::Error) -> bool {
|
||||
use io::ErrorKind::*;
|
||||
matches!(
|
||||
e.kind(),
|
||||
ConnectionRefused | ConnectionAborted | ConnectionReset
|
||||
)
|
||||
}
|
||||
use pq_proto::codec::{ConnectionError, PostgresCodec};
|
||||
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR};
|
||||
|
||||
/// An error, occurred during query processing:
|
||||
/// either during the connection ([`ConnectionError`]) or before/after it.
|
||||
@@ -40,7 +35,7 @@ pub enum QueryError {
|
||||
|
||||
impl From<io::Error> for QueryError {
|
||||
fn from(e: io::Error) -> Self {
|
||||
Self::Disconnected(ConnectionError::Socket(e))
|
||||
Self::Disconnected(ConnectionError::Io(e))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,6 +48,14 @@ impl QueryError {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_expected_io_error(e: &io::Error) -> bool {
|
||||
use io::ErrorKind::*;
|
||||
matches!(
|
||||
e.kind(),
|
||||
ConnectionRefused | ConnectionAborted | ConnectionReset
|
||||
)
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait Handler {
|
||||
/// Handle single query.
|
||||
@@ -93,6 +96,7 @@ pub trait Handler {
|
||||
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
|
||||
pub enum ProtoState {
|
||||
Initialization,
|
||||
// Encryption handshake is done; waiting for encrypted Startup message.
|
||||
Encrypted,
|
||||
Authentication,
|
||||
Established,
|
||||
@@ -105,15 +109,14 @@ pub enum ProcessMsgResult {
|
||||
Break,
|
||||
}
|
||||
|
||||
/// Always-writeable sock_split stream.
|
||||
/// May not be readable. See [`PostgresBackend::take_stream_in`]
|
||||
pub enum Stream {
|
||||
Unencrypted(BufReader<tokio::net::TcpStream>),
|
||||
Tls(Box<tokio_rustls::server::TlsStream<BufReader<tokio::net::TcpStream>>>),
|
||||
Broken,
|
||||
/// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
|
||||
pub enum MaybeTlsStream {
|
||||
Unencrypted(tokio::net::TcpStream),
|
||||
Tls(Box<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>),
|
||||
Broken, // temporary value for switch to TLS
|
||||
}
|
||||
|
||||
impl AsyncWrite for Stream {
|
||||
impl AsyncWrite for MaybeTlsStream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
@@ -122,14 +125,14 @@ impl AsyncWrite for Stream {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
Self::Broken => unreachable!(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
|
||||
Self::Broken => unreachable!(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
fn poll_shutdown(
|
||||
@@ -139,11 +142,11 @@ impl AsyncWrite for Stream {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
Self::Broken => unreachable!(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl AsyncRead for Stream {
|
||||
impl AsyncRead for MaybeTlsStream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
@@ -152,18 +155,49 @@ impl AsyncRead for Stream {
|
||||
match self.get_mut() {
|
||||
Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
Self::Broken => unreachable!(),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresBackend {
|
||||
stream: Stream,
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum AuthType {
|
||||
Trust,
|
||||
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
|
||||
NeonJWT,
|
||||
}
|
||||
|
||||
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
|
||||
// The data between 0 and "current position" as tracked by the bytes::Buf
|
||||
// implementation of BytesMut, have already been written.
|
||||
buf_out: BytesMut,
|
||||
impl FromStr for AuthType {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"Trust" => Ok(Self::Trust),
|
||||
"NeonJWT" => Ok(Self::NeonJWT),
|
||||
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for AuthType {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(match self {
|
||||
AuthType::Trust => "Trust",
|
||||
AuthType::NeonJWT => "NeonJWT",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresBackend {
|
||||
// Provides serialization/deserialization to the underlying transport backed
|
||||
// with buffers; implements Sink consuming messages and Stream reading them.
|
||||
//
|
||||
// Sink::start_send only queues message to the interal buffer.
|
||||
// SinkExt::flush flushes buffer to the stream.
|
||||
//
|
||||
// StreamExt::read reads next message. In case of EOF without partial
|
||||
// message it returns None.
|
||||
stream: Framed<MaybeTlsStream, PostgresCodec>,
|
||||
|
||||
pub state: ProtoState,
|
||||
|
||||
@@ -196,10 +230,10 @@ impl PostgresBackend {
|
||||
tls_config: Option<Arc<rustls::ServerConfig>>,
|
||||
) -> io::Result<Self> {
|
||||
let peer_addr = socket.peer_addr()?;
|
||||
let stream = MaybeTlsStream::Unencrypted(socket);
|
||||
|
||||
Ok(Self {
|
||||
stream: Stream::Unencrypted(BufReader::new(socket)),
|
||||
buf_out: BytesMut::with_capacity(10 * 1024),
|
||||
stream: Framed::new(stream, PostgresCodec::new()),
|
||||
state: ProtoState::Initialization,
|
||||
auth_type,
|
||||
tls_config,
|
||||
@@ -212,29 +246,60 @@ impl PostgresBackend {
|
||||
}
|
||||
|
||||
/// Read full message or return None if connection is closed.
|
||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, QueryError> {
|
||||
use ProtoState::*;
|
||||
match self.state {
|
||||
Initialization | Encrypted => FeStartupPacket::read_fut(&mut self.stream).await,
|
||||
Authentication | Established => FeMessage::read_fut(&mut self.stream).await,
|
||||
Closed => Ok(None),
|
||||
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
|
||||
if let ProtoState::Closed = self.state {
|
||||
Ok(None)
|
||||
} else {
|
||||
let msg = self.stream.next().await;
|
||||
// Option<Result<...>>, so swap.
|
||||
msg.map_or(Ok(None), |res| res.map(Some))
|
||||
}
|
||||
.map_err(QueryError::from)
|
||||
}
|
||||
|
||||
/// Polling version of read_message, saves the caller need to pin.
|
||||
pub fn poll_read_message(
|
||||
&mut self,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<Option<FeMessage>, ConnectionError>> {
|
||||
let read_fut = self.read_message();
|
||||
pin_mut!(read_fut);
|
||||
read_fut.poll(cx)
|
||||
}
|
||||
|
||||
/// Flush output buffer into the socket.
|
||||
pub async fn flush(&mut self) -> io::Result<()> {
|
||||
while self.buf_out.has_remaining() {
|
||||
let bytes_written = self.stream.write(self.buf_out.chunk()).await?;
|
||||
self.buf_out.advance(bytes_written);
|
||||
}
|
||||
self.buf_out.clear();
|
||||
Ok(())
|
||||
self.stream.flush().await.map_err(|e| match e {
|
||||
ConnectionError::Io(e) => e,
|
||||
// the only error we can get from flushing is IO
|
||||
_ => unreachable!(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Write message into internal output buffer.
|
||||
pub fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
|
||||
BeMessage::write(&mut self.buf_out, message)?;
|
||||
/// Polling version of `flush()`, saves the caller need to pin.
|
||||
pub fn poll_flush(
|
||||
&mut self,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let flush_fut = self.flush();
|
||||
pin_mut!(flush_fut);
|
||||
flush_fut.poll(cx)
|
||||
}
|
||||
|
||||
/// Write message into internal output buffer. Technically error type can be
|
||||
/// only ProtocolError here (if, unlikely, serialization fails), but callers
|
||||
/// typically wrap it anyway.
|
||||
pub fn write_message(&mut self, message: &BeMessage<'_>) -> Result<&mut Self, ConnectionError> {
|
||||
Pin::new(&mut self.stream).start_send(message)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
/// Write message into internal output buffer and flush it to the stream.
|
||||
pub async fn write_message_flush(
|
||||
&mut self,
|
||||
message: &BeMessage<'_>,
|
||||
) -> Result<&mut Self, ConnectionError> {
|
||||
self.write_message(message)?;
|
||||
self.flush().await?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
@@ -246,28 +311,6 @@ impl PostgresBackend {
|
||||
CopyDataWriter { pgb: self }
|
||||
}
|
||||
|
||||
/// A polling function that tries to write all the data from 'buf_out' to the
|
||||
/// underlying stream.
|
||||
fn poll_write_buf(
|
||||
&mut self,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
while self.buf_out.has_remaining() {
|
||||
match Pin::new(&mut self.stream).poll_write(cx, self.buf_out.chunk()) {
|
||||
Poll::Ready(Ok(bytes_written)) => {
|
||||
self.buf_out.advance(bytes_written);
|
||||
}
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_flush(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||
Pin::new(&mut self.stream).poll_flush(cx)
|
||||
}
|
||||
|
||||
// Wrapper for run_message_loop() that shuts down socket when we are done
|
||||
pub async fn run<F, S>(
|
||||
mut self,
|
||||
@@ -279,7 +322,7 @@ impl PostgresBackend {
|
||||
S: Future,
|
||||
{
|
||||
let ret = self.run_message_loop(handler, shutdown_watcher).await;
|
||||
let _ = self.stream.shutdown();
|
||||
let _ = self.stream.get_mut().shutdown();
|
||||
ret
|
||||
}
|
||||
|
||||
@@ -359,14 +402,22 @@ impl PostgresBackend {
|
||||
}
|
||||
|
||||
async fn start_tls(&mut self) -> anyhow::Result<()> {
|
||||
if let Stream::Unencrypted(plain_stream) =
|
||||
std::mem::replace(&mut self.stream, Stream::Broken)
|
||||
if let MaybeTlsStream::Unencrypted(plain_stream) =
|
||||
// temporary replace stream with fake broken to prepare TLS one
|
||||
std::mem::replace(self.stream.get_mut(), MaybeTlsStream::Broken)
|
||||
{
|
||||
let acceptor = TlsAcceptor::from(self.tls_config.clone().unwrap());
|
||||
let tls_stream = acceptor.accept(plain_stream).await?;
|
||||
|
||||
self.stream = Stream::Tls(Box::new(tls_stream));
|
||||
return Ok(());
|
||||
match acceptor.accept(plain_stream).await {
|
||||
Ok(tls_stream) => {
|
||||
// push back ready TLS stream
|
||||
*self.stream.get_mut() = MaybeTlsStream::Tls(Box::new(tls_stream));
|
||||
return Ok(());
|
||||
}
|
||||
Err(e) => {
|
||||
self.state = ProtoState::Closed;
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
};
|
||||
anyhow::bail!("TLS already started");
|
||||
}
|
||||
@@ -380,13 +431,12 @@ impl PostgresBackend {
|
||||
let have_tls = self.tls_config.is_some();
|
||||
match msg {
|
||||
FeMessage::StartupPacket(m) => {
|
||||
trace!("got startup message {m:?}");
|
||||
|
||||
match m {
|
||||
FeStartupPacket::SslRequest => {
|
||||
debug!("SSL requested");
|
||||
|
||||
self.write_message(&BeMessage::EncryptionResponse(have_tls))?;
|
||||
|
||||
if have_tls {
|
||||
self.start_tls().await?;
|
||||
self.state = ProtoState::Encrypted;
|
||||
@@ -415,6 +465,7 @@ impl PostgresBackend {
|
||||
AuthType::Trust => {
|
||||
self.write_message(&BeMessage::AuthenticationOk)?
|
||||
.write_message(&BeMessage::CLIENT_ENCODING)?
|
||||
.write_message(&BeMessage::INTEGER_DATETIMES)?
|
||||
// The async python driver requires a valid server_version
|
||||
.write_message(&BeMessage::server_version("14.1"))?
|
||||
.write_message(&BeMessage::ReadyForQuery)?;
|
||||
@@ -454,6 +505,7 @@ impl PostgresBackend {
|
||||
}
|
||||
self.write_message(&BeMessage::AuthenticationOk)?
|
||||
.write_message(&BeMessage::CLIENT_ENCODING)?
|
||||
.write_message(&BeMessage::INTEGER_DATETIMES)?
|
||||
.write_message(&BeMessage::ReadyForQuery)?;
|
||||
self.state = ProtoState::Established;
|
||||
}
|
||||
@@ -573,7 +625,7 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> {
|
||||
// It's not strictly required to flush between each message, but makes it easier
|
||||
// to view in wireshark, and usually the messages that the callers write are
|
||||
// decently-sized anyway.
|
||||
match this.pgb.poll_write_buf(cx) {
|
||||
match this.pgb.poll_flush(cx) {
|
||||
Poll::Ready(Ok(())) => {}
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
@@ -583,7 +635,11 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> {
|
||||
// XXX: if the input is large, we should split it into multiple messages.
|
||||
// Not sure what the threshold should be, but the ultimate hard limit is that
|
||||
// the length cannot exceed u32.
|
||||
this.pgb.write_message(&BeMessage::CopyData(buf))?;
|
||||
this.pgb
|
||||
.write_message(&BeMessage::CopyData(buf))
|
||||
// write_message only writes to buffer, so can fail iff message is
|
||||
// invaid, but CopyData can't be invalid.
|
||||
.expect("failed to serialize CopyData");
|
||||
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
@@ -593,23 +649,14 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> {
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let this = self.get_mut();
|
||||
match this.pgb.poll_write_buf(cx) {
|
||||
Poll::Ready(Ok(())) => {}
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
this.pgb.poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let this = self.get_mut();
|
||||
match this.pgb.poll_write_buf(cx) {
|
||||
Poll::Ready(Ok(())) => {}
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
this.pgb.poll_flush(cx)
|
||||
}
|
||||
}
|
||||
@@ -623,7 +670,7 @@ pub fn short_error(e: &QueryError) -> String {
|
||||
|
||||
pub(super) fn log_query_error(query: &str, e: &QueryError) {
|
||||
match e {
|
||||
QueryError::Disconnected(ConnectionError::Socket(io_error)) => {
|
||||
QueryError::Disconnected(ConnectionError::Io(io_error)) => {
|
||||
if is_expected_io_error(io_error) {
|
||||
info!("query handler for '{query}' failed with expected io error: {io_error}");
|
||||
} else {
|
||||
|
||||
116
libs/utils/src/send_rc.rs
Normal file
116
libs/utils/src/send_rc.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
/// Provides Send wrappers of Rc and RefMut.
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
cell::{Ref, RefCell, RefMut},
|
||||
ops::{Deref, DerefMut},
|
||||
rc::Rc,
|
||||
};
|
||||
|
||||
/// Rc wrapper which is Send.
|
||||
/// This is useful to allow transferring a group of Rcs pointing to the same
|
||||
/// object between threads, e.g. in self referential struct.
|
||||
#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
|
||||
pub struct SendRc<T>
|
||||
where
|
||||
T: ?Sized,
|
||||
{
|
||||
rc: Rc<T>,
|
||||
}
|
||||
|
||||
// SAFETY: Passing Rc(s)<T: Send> between threads is fine as long as there is no
|
||||
// concurrent access to the object they point to, so you must move all such Rcs
|
||||
// together. This appears to be impossible to express in rust type system and
|
||||
// SendRc doesn't provide any additional protection -- but unlike sendable
|
||||
// crate, neither it requires any additional actions before/after move. Ensuring
|
||||
// that sending conforms to the above is the responsibility of the type user.
|
||||
unsafe impl<T: ?Sized + Send> Send for SendRc<T> {}
|
||||
|
||||
impl<T> SendRc<T> {
|
||||
/// Constructs a new SendRc<T>
|
||||
pub fn new(value: T) -> SendRc<T> {
|
||||
SendRc { rc: Rc::new(value) }
|
||||
}
|
||||
}
|
||||
|
||||
// https://stegosaurusdormant.com/understanding-derive-clone/ explains in detail
|
||||
// why derive Clone doesn't work here.
|
||||
impl<T> Clone for SendRc<T> {
|
||||
fn clone(&self) -> Self {
|
||||
SendRc {
|
||||
rc: self.rc.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Deref into inner rc.
|
||||
impl<T> Deref for SendRc<T> {
|
||||
type Target = Rc<T>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.rc
|
||||
}
|
||||
}
|
||||
|
||||
/// Extends RefCell with borrow[_mut] variants which return Sendable Ref[Mut]
|
||||
/// wrappers.
|
||||
pub trait RefCellSend<T: ?Sized> {
|
||||
fn borrow_mut_send(&self) -> RefMutSend<'_, T>;
|
||||
}
|
||||
|
||||
impl<T: Sized> RefCellSend<T> for RefCell<T> {
|
||||
fn borrow_mut_send(&self) -> RefMutSend<'_, T> {
|
||||
RefMutSend {
|
||||
ref_mut: self.borrow_mut(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// RefMut wrapper which is Send. See impl Send for safety. Allows to move a
|
||||
/// RefMut along with RefCell it originates from between threads, e.g. have Send
|
||||
/// Future containing RefMut.
|
||||
#[derive(Debug)]
|
||||
pub struct RefMutSend<'b, T>
|
||||
where
|
||||
T: 'b + ?Sized,
|
||||
{
|
||||
ref_mut: RefMut<'b, T>,
|
||||
}
|
||||
|
||||
// SAFETY: Similar to SendRc, this is safe as long as RefMut stays in the same
|
||||
// thread with original RefCell, so they should be passed together.
|
||||
// Actually, since this is a referential type violating this is not
|
||||
// straightforward; examples of unsafe usage could be
|
||||
// - Passing a RefMut to different thread without source RefCell. Seems only
|
||||
// possible with std::thread::scope.
|
||||
// - Somehow multiple threads get access to single RefCell concurrently,
|
||||
// violating its !Sync requirement. Improper usage of SendRc can do that.
|
||||
unsafe impl<'b, T: ?Sized + Send> Send for RefMutSend<'b, T> {}
|
||||
|
||||
impl<'b, T> RefMutSend<'b, T> {
|
||||
/// Constructs a new RefMutSend<T>
|
||||
pub fn new(ref_mut: RefMut<'b, T>) -> RefMutSend<'b, T> {
|
||||
RefMutSend { ref_mut }
|
||||
}
|
||||
}
|
||||
|
||||
// Deref into inner RefMut.
|
||||
impl<'b, T> Deref for RefMutSend<'b, T>
|
||||
where
|
||||
T: 'b + ?Sized,
|
||||
{
|
||||
type Target = RefMut<'b, T>;
|
||||
|
||||
fn deref<'a>(&'a self) -> &'a RefMut<'b, T> {
|
||||
&self.ref_mut
|
||||
}
|
||||
}
|
||||
|
||||
// DerefMut into inner RefMut.
|
||||
impl<'b, T> DerefMut for RefMutSend<'b, T>
|
||||
where
|
||||
T: 'b + ?Sized,
|
||||
{
|
||||
fn deref_mut<'a>(&'a mut self) -> &'a mut RefMut<'b, T> {
|
||||
&mut self.ref_mut
|
||||
}
|
||||
}
|
||||
@@ -24,7 +24,7 @@ use pageserver::{
|
||||
use utils::{
|
||||
auth::JwtAuth,
|
||||
logging,
|
||||
postgres_backend::AuthType,
|
||||
postgres_backend_async::AuthType,
|
||||
project_git_version,
|
||||
sentry_init::{init_sentry, release_name},
|
||||
signals::{self, Signal},
|
||||
|
||||
@@ -24,7 +24,7 @@ use toml_edit::{Document, Item};
|
||||
use utils::{
|
||||
id::{NodeId, TenantId, TimelineId},
|
||||
logging::LogFormat,
|
||||
postgres_backend::AuthType,
|
||||
postgres_backend_async::AuthType,
|
||||
};
|
||||
|
||||
use crate::tenant::config::TenantConf;
|
||||
|
||||
@@ -19,7 +19,7 @@ use pageserver_api::models::{
|
||||
PagestreamFeMessage, PagestreamGetPageRequest, PagestreamGetPageResponse,
|
||||
PagestreamNblocksRequest, PagestreamNblocksResponse,
|
||||
};
|
||||
use pq_proto::ConnectionError;
|
||||
use pq_proto::codec::ConnectionError;
|
||||
use pq_proto::FeStartupPacket;
|
||||
use pq_proto::{BeMessage, FeMessage, RowDescriptor};
|
||||
use std::io;
|
||||
@@ -35,7 +35,7 @@ use utils::{
|
||||
auth::{Claims, JwtAuth, Scope},
|
||||
id::{TenantId, TimelineId},
|
||||
lsn::Lsn,
|
||||
postgres_backend::AuthType,
|
||||
postgres_backend_async::AuthType,
|
||||
postgres_backend_async::{self, PostgresBackend},
|
||||
simple_rcu::RcuReadGuard,
|
||||
};
|
||||
@@ -67,7 +67,7 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
|
||||
Err(QueryError::Other(anyhow::anyhow!(msg)))
|
||||
}
|
||||
|
||||
msg = pgb.read_message() => { msg }
|
||||
msg = pgb.read_message() => { msg.map_err(QueryError::from)}
|
||||
};
|
||||
|
||||
match msg {
|
||||
@@ -78,14 +78,16 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
|
||||
FeMessage::Sync => continue,
|
||||
FeMessage::Terminate => {
|
||||
let msg = "client terminated connection with Terminate message during COPY";
|
||||
let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
||||
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?;
|
||||
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
||||
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code())))
|
||||
.expect("failed to serialize ErrorResponse");
|
||||
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
|
||||
break;
|
||||
}
|
||||
m => {
|
||||
let msg = format!("unexpected message {m:?}");
|
||||
pgb.write_message(&BeMessage::ErrorResponse(&msg, None))?;
|
||||
pgb.write_message(&BeMessage::ErrorResponse(&msg, None))
|
||||
.expect("failed to serialize ErrorResponse");
|
||||
Err(io::Error::new(io::ErrorKind::Other, msg))?;
|
||||
break;
|
||||
}
|
||||
@@ -95,16 +97,17 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
|
||||
}
|
||||
Ok(None) => {
|
||||
let msg = "client closed connection during COPY";
|
||||
let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
||||
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?;
|
||||
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
|
||||
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code())))
|
||||
.expect("failed to serialize ErrorResponse");
|
||||
pgb.flush().await?;
|
||||
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
|
||||
}
|
||||
Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => {
|
||||
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
|
||||
Err(io_error)?;
|
||||
}
|
||||
Err(other) => {
|
||||
Err(io::Error::new(io::ErrorKind::Other, other))?;
|
||||
Err(io::Error::new(io::ErrorKind::Other, other.to_string()))?;
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -202,7 +205,7 @@ async fn page_service_conn_main(
|
||||
// we've been requested to shut down
|
||||
Ok(())
|
||||
}
|
||||
Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => {
|
||||
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
|
||||
// `ConnectionReset` error happens when the Postgres client closes the connection.
|
||||
// As this disconnection happens quite often and is expected,
|
||||
// we decided to downgrade the logging level to `INFO`.
|
||||
|
||||
@@ -2,7 +2,7 @@ use crate::error::UserFacingError;
|
||||
use anyhow::bail;
|
||||
use bytes::BytesMut;
|
||||
use pin_project_lite::pin_project;
|
||||
use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket};
|
||||
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
|
||||
use rustls::ServerConfig;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
@@ -53,7 +53,7 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||
// TODO: `FeStartupPacket::read_fut` should return `FeStartupPacket`
|
||||
let msg = FeStartupPacket::read_fut(&mut self.stream)
|
||||
.await
|
||||
.map_err(ConnectionError::into_io_error)?
|
||||
.map_err(ProtocolError::into_io_error)?
|
||||
.ok_or_else(err_connection)?;
|
||||
|
||||
match msg {
|
||||
@@ -75,7 +75,7 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||
async fn read_message(&mut self) -> io::Result<FeMessage> {
|
||||
FeMessage::read_fut(&mut self.stream)
|
||||
.await
|
||||
.map_err(ConnectionError::into_io_error)?
|
||||
.map_err(ProtocolError::into_io_error)?
|
||||
.ok_or_else(err_connection)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# version, we can consider updating.
|
||||
# See https://tracker.debian.org/pkg/rustc for more details on Debian rustc package,
|
||||
# we use "unstable" version number as the highest version used in the project by default.
|
||||
channel = "1.62.1"
|
||||
channel = "1.66.1"
|
||||
profile = "default"
|
||||
# The default profile includes rustc, rust-std, cargo, rust-docs, rustfmt and clippy.
|
||||
# https://rust-lang.github.io/rustup/concepts/profiles.html
|
||||
|
||||
@@ -14,12 +14,14 @@ clap = { version = "4.0", features = ["derive"] }
|
||||
const_format = "0.2.21"
|
||||
crc32c = "0.6.0"
|
||||
fs2 = "0.4.3"
|
||||
futures = "0.3"
|
||||
git-version = "0.3.5"
|
||||
hex = "0.4.3"
|
||||
humantime = "2.1.0"
|
||||
hyper = "0.14"
|
||||
nix = "0.25"
|
||||
once_cell = "1.13.0"
|
||||
pin-project-lite = "0.2"
|
||||
parking_lot = "0.12.1"
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
|
||||
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
|
||||
|
||||
@@ -228,20 +228,20 @@ fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> {
|
||||
|
||||
let conf_cloned = conf.clone();
|
||||
let safekeeper_thread = thread::Builder::new()
|
||||
.name("safekeeper thread".into())
|
||||
.name("WAL service thread".into())
|
||||
.spawn(|| wal_service::thread_main(conf_cloned, pg_listener))
|
||||
.unwrap();
|
||||
|
||||
threads.push(safekeeper_thread);
|
||||
|
||||
let conf_ = conf.clone();
|
||||
threads.push(
|
||||
thread::Builder::new()
|
||||
.name("broker thread".into())
|
||||
.spawn(|| {
|
||||
broker::thread_main(conf_);
|
||||
})?,
|
||||
);
|
||||
// threads.push(
|
||||
// thread::Builder::new()
|
||||
// .name("broker thread".into())
|
||||
// .spawn(|| {
|
||||
// broker::thread_main(conf_);
|
||||
// })?,
|
||||
// );
|
||||
|
||||
let conf_ = conf.clone();
|
||||
threads.push(
|
||||
|
||||
@@ -1,27 +1,23 @@
|
||||
//! Part of Safekeeper pretending to be Postgres, i.e. handling Postgres
|
||||
//! protocol commands.
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use std::str;
|
||||
use tracing::{info, info_span, Instrument};
|
||||
|
||||
use crate::auth::check_permission;
|
||||
use crate::json_ctrl::{handle_json_ctrl, AppendLogicalMessage};
|
||||
use crate::receive_wal::ReceiveWalConn;
|
||||
|
||||
use crate::send_wal::ReplicationConn;
|
||||
|
||||
use crate::{GlobalTimelines, SafeKeeperConf};
|
||||
use anyhow::Context;
|
||||
|
||||
use postgres_ffi::PG_TLI;
|
||||
use regex::Regex;
|
||||
|
||||
use pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID, TEXT_OID};
|
||||
use std::str;
|
||||
use tracing::info;
|
||||
use regex::Regex;
|
||||
use utils::auth::{Claims, Scope};
|
||||
use utils::postgres_backend_async::QueryError;
|
||||
use utils::{
|
||||
id::{TenantId, TenantTimelineId, TimelineId},
|
||||
lsn::Lsn,
|
||||
postgres_backend::{self, PostgresBackend},
|
||||
postgres_backend_async::{self, PostgresBackend},
|
||||
};
|
||||
|
||||
/// Safekeeper handler of postgres commands
|
||||
@@ -41,9 +37,11 @@ enum SafekeeperPostgresCommand {
|
||||
StartReplication { start_lsn: Lsn },
|
||||
IdentifySystem,
|
||||
JSONCtrl { cmd: AppendLogicalMessage },
|
||||
Show { guc: String },
|
||||
}
|
||||
|
||||
fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
|
||||
let cmd_lowercase = cmd.to_ascii_lowercase();
|
||||
if cmd.starts_with("START_WAL_PUSH") {
|
||||
Ok(SafekeeperPostgresCommand::StartWalPush)
|
||||
} else if cmd.starts_with("START_REPLICATION") {
|
||||
@@ -53,7 +51,7 @@ fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
|
||||
let start_lsn = caps
|
||||
.next()
|
||||
.map(|cap| cap[1].parse::<Lsn>())
|
||||
.context("failed to parse start LSN from START_REPLICATION command")??;
|
||||
.context("parse start LSN from START_REPLICATION command")??;
|
||||
Ok(SafekeeperPostgresCommand::StartReplication { start_lsn })
|
||||
} else if cmd.starts_with("IDENTIFY_SYSTEM") {
|
||||
Ok(SafekeeperPostgresCommand::IdentifySystem)
|
||||
@@ -62,12 +60,21 @@ fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
|
||||
Ok(SafekeeperPostgresCommand::JSONCtrl {
|
||||
cmd: serde_json::from_str(cmd)?,
|
||||
})
|
||||
} else if cmd_lowercase.starts_with("show") {
|
||||
let re = Regex::new(r"show ((?:[[:alpha:]]|_)+)").unwrap();
|
||||
let mut caps = re.captures_iter(&cmd_lowercase);
|
||||
let guc = caps
|
||||
.next()
|
||||
.map(|cap| cap[1].parse::<String>())
|
||||
.context("parse guc in SHOW command")??;
|
||||
Ok(SafekeeperPostgresCommand::Show { guc })
|
||||
} else {
|
||||
anyhow::bail!("unsupported command {cmd}");
|
||||
}
|
||||
}
|
||||
|
||||
impl postgres_backend::Handler for SafekeeperPostgresHandler {
|
||||
#[async_trait::async_trait]
|
||||
impl postgres_backend_async::Handler for SafekeeperPostgresHandler {
|
||||
// tenant_id and timeline_id are passed in connection string params
|
||||
fn startup(
|
||||
&mut self,
|
||||
@@ -137,7 +144,7 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn process_query(
|
||||
async fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
query_string: &str,
|
||||
@@ -147,9 +154,11 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
|
||||
.starts_with("set datestyle to ")
|
||||
{
|
||||
// important for debug because psycopg2 executes "SET datestyle TO 'ISO'" on connect
|
||||
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
pgb.write_message_flush(&BeMessage::CommandComplete(b"SELECT 1"))
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let cmd = parse_cmd(query_string)?;
|
||||
|
||||
info!(
|
||||
@@ -161,14 +170,22 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
|
||||
let timeline_id = self.timeline_id.context("timelineid is required")?;
|
||||
self.check_permission(Some(tenant_id))?;
|
||||
self.ttid = TenantTimelineId::new(tenant_id, timeline_id);
|
||||
let span_ttid = self.ttid; // satisfy borrow checker
|
||||
|
||||
let res = match cmd {
|
||||
SafekeeperPostgresCommand::StartWalPush => ReceiveWalConn::new(pgb).run(self),
|
||||
// SafekeeperPostgresCommand::StartWalPush => ReceiveWalConn::new(pgb).run(self),
|
||||
SafekeeperPostgresCommand::StartWalPush => Ok(()),
|
||||
SafekeeperPostgresCommand::StartReplication { start_lsn } => {
|
||||
ReplicationConn::new(pgb).run(self, pgb, start_lsn)
|
||||
self.handle_start_replication(pgb, start_lsn)
|
||||
.instrument(info_span!("WAL sender", ttid = %span_ttid))
|
||||
.await
|
||||
}
|
||||
SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb),
|
||||
SafekeeperPostgresCommand::JSONCtrl { ref cmd } => handle_json_ctrl(self, pgb, cmd),
|
||||
SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb).await,
|
||||
SafekeeperPostgresCommand::Show { guc } => self.handle_show(guc, pgb).await,
|
||||
SafekeeperPostgresCommand::JSONCtrl { ref cmd } => {
|
||||
handle_json_ctrl(self, pgb, cmd).await
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
match res {
|
||||
@@ -217,7 +234,10 @@ impl SafekeeperPostgresHandler {
|
||||
///
|
||||
/// Handle IDENTIFY_SYSTEM replication command
|
||||
///
|
||||
fn handle_identify_system(&mut self, pgb: &mut PostgresBackend) -> Result<(), QueryError> {
|
||||
async fn handle_identify_system(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackend,
|
||||
) -> Result<(), QueryError> {
|
||||
let tli = GlobalTimelines::get(self.ttid)?;
|
||||
|
||||
let lsn = if self.is_walproposer_recovery() {
|
||||
@@ -235,7 +255,7 @@ impl SafekeeperPostgresHandler {
|
||||
let tli_bytes = tli.as_bytes();
|
||||
let sysid_bytes = sysid.as_bytes();
|
||||
|
||||
pgb.write_message_noflush(&BeMessage::RowDescription(&[
|
||||
pgb.write_message(&BeMessage::RowDescription(&[
|
||||
RowDescriptor {
|
||||
name: b"systemid",
|
||||
typoid: TEXT_OID,
|
||||
@@ -261,13 +281,48 @@ impl SafekeeperPostgresHandler {
|
||||
..Default::default()
|
||||
},
|
||||
]))?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[
|
||||
.write_message(&BeMessage::DataRow(&[
|
||||
Some(sysid_bytes),
|
||||
Some(tli_bytes),
|
||||
Some(lsn_bytes),
|
||||
None,
|
||||
]))?
|
||||
.write_message(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?;
|
||||
.write_message_flush(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_show(
|
||||
&mut self,
|
||||
guc: String,
|
||||
pgb: &mut PostgresBackend,
|
||||
) -> Result<(), QueryError> {
|
||||
match guc.as_str() {
|
||||
// pg_receivewal wants it
|
||||
"data_directory_mode" => {
|
||||
pgb.write_message(&BeMessage::RowDescription(&[RowDescriptor::int8_col(
|
||||
b"data_directory_mode",
|
||||
)]))?
|
||||
// xxx we could return real one, not just 0700
|
||||
.write_message(&BeMessage::DataRow(&[Some(0700.to_string().as_bytes())]))?
|
||||
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
}
|
||||
// pg_receivewal wants it
|
||||
"wal_segment_size" => {
|
||||
let tli = GlobalTimelines::get(self.ttid)?;
|
||||
let wal_seg_size = tli.get_state().1.server.wal_seg_size;
|
||||
let wal_seg_size_mb = (wal_seg_size / 1024 / 1024).to_string() + "MB";
|
||||
|
||||
pgb.write_message(&BeMessage::RowDescription(&[RowDescriptor::text_col(
|
||||
b"wal_segment_size",
|
||||
)]))?
|
||||
.write_message(&BeMessage::DataRow(&[Some(wal_seg_size_mb.as_bytes())]))?
|
||||
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
}
|
||||
_ => {
|
||||
return Err(anyhow::anyhow!("SHOW of unknown setting").into());
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -8,11 +8,14 @@ use serde::Serialize;
|
||||
use serde::Serializer;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::fmt::Display;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use storage_broker::proto::SafekeeperTimelineInfo;
|
||||
use storage_broker::proto::TenantTimelineId as ProtoTenantTimelineId;
|
||||
use tokio::task::JoinError;
|
||||
|
||||
use crate::json_ctrl::append_logical_message;
|
||||
use crate::json_ctrl::AppendLogicalMessage;
|
||||
use crate::safekeeper::ServerInfo;
|
||||
use crate::safekeeper::Term;
|
||||
|
||||
@@ -191,6 +194,50 @@ async fn timeline_create_handler(mut request: Request<Body>) -> Result<Response<
|
||||
json_response(StatusCode::OK, ())
|
||||
}
|
||||
|
||||
// Create fake timeline + insert some valid WAL. Useful to test WAL streaming
|
||||
// from safekeeper in isolation, e.g.
|
||||
// pg_receivewal -v -d "host=localhost port=5454 options='-c tenant_id=deadbeefdeadbeefdeadbeefdeadbeef timeline_id=deadbeefdeadbeefdeadbeefdeadbeef'" -D ~/tmp/tmp/tmp
|
||||
// (hacking pg_receivewal startpos is currently needed though to make pg_receivewal work)
|
||||
async fn create_fake_timeline_handler(_request: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
let ttid = TenantTimelineId {
|
||||
tenant_id: TenantId::from_str("deadbeefdeadbeefdeadbeefdeadbeef")
|
||||
.expect("timeline_id parsing failed"),
|
||||
timeline_id: TimelineId::from_str("deadbeefdeadbeefdeadbeefdeadbeef")
|
||||
.expect("tenant_id parsing failed"),
|
||||
};
|
||||
let pg_version = 150000;
|
||||
let server_info = ServerInfo {
|
||||
pg_version,
|
||||
system_id: 0,
|
||||
wal_seg_size: WAL_SEGMENT_SIZE as u32,
|
||||
};
|
||||
let init_lsn = Lsn(0x1493AC8);
|
||||
tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
|
||||
let tli = GlobalTimelines::create(ttid, server_info, init_lsn, init_lsn)?;
|
||||
let mut begin_lsn = init_lsn;
|
||||
for _ in 0..16 {
|
||||
let append = AppendLogicalMessage {
|
||||
lm_prefix: "db".to_owned(),
|
||||
lm_message: "hahabubu".to_owned(),
|
||||
set_commit_lsn: true,
|
||||
send_proposer_elected: false, // actually ignored here
|
||||
term: 0,
|
||||
epoch_start_lsn: init_lsn,
|
||||
begin_lsn,
|
||||
truncate_lsn: init_lsn,
|
||||
pg_version,
|
||||
};
|
||||
let inserted = append_logical_message(&tli, &append)?;
|
||||
begin_lsn = inserted.end_lsn;
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
.await
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))?
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
json_response(StatusCode::OK, ())
|
||||
}
|
||||
|
||||
/// Deactivates the timeline and removes its data directory.
|
||||
async fn timeline_delete_force_handler(
|
||||
mut request: Request<Body>,
|
||||
@@ -302,6 +349,7 @@ pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder<hyper::Body, ApiError>
|
||||
.get("/v1/status", status_handler)
|
||||
// Will be used in the future instead of implicit timeline creation
|
||||
.post("/v1/tenant/timeline", timeline_create_handler)
|
||||
.post("/v1/fake_timeline", create_fake_timeline_handler)
|
||||
.get(
|
||||
"/v1/tenant/:tenant_id/timeline/:timeline_id",
|
||||
timeline_status_handler,
|
||||
|
||||
@@ -26,26 +26,26 @@ use crate::GlobalTimelines;
|
||||
use postgres_ffi::encode_logical_message;
|
||||
use postgres_ffi::WAL_SEGMENT_SIZE;
|
||||
use pq_proto::{BeMessage, RowDescriptor, TEXT_OID};
|
||||
use utils::{lsn::Lsn, postgres_backend::PostgresBackend};
|
||||
use utils::{lsn::Lsn, postgres_backend_async::PostgresBackend};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct AppendLogicalMessage {
|
||||
// prefix and message to build LogicalMessage
|
||||
lm_prefix: String,
|
||||
lm_message: String,
|
||||
pub lm_prefix: String,
|
||||
pub lm_message: String,
|
||||
|
||||
// if true, commit_lsn will match flush_lsn after append
|
||||
set_commit_lsn: bool,
|
||||
pub set_commit_lsn: bool,
|
||||
|
||||
// if true, ProposerElected will be sent before append
|
||||
send_proposer_elected: bool,
|
||||
pub send_proposer_elected: bool,
|
||||
|
||||
// fields from AppendRequestHeader
|
||||
term: Term,
|
||||
epoch_start_lsn: Lsn,
|
||||
begin_lsn: Lsn,
|
||||
truncate_lsn: Lsn,
|
||||
pg_version: u32,
|
||||
pub term: Term,
|
||||
pub epoch_start_lsn: Lsn,
|
||||
pub begin_lsn: Lsn,
|
||||
pub truncate_lsn: Lsn,
|
||||
pub pg_version: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
@@ -59,7 +59,7 @@ struct AppendResult {
|
||||
/// Handles command to craft logical message WAL record with given
|
||||
/// content, and then append it with specified term and lsn. This
|
||||
/// function is used to test safekeepers in different scenarios.
|
||||
pub fn handle_json_ctrl(
|
||||
pub async fn handle_json_ctrl(
|
||||
spg: &SafekeeperPostgresHandler,
|
||||
pgb: &mut PostgresBackend,
|
||||
append_request: &AppendLogicalMessage,
|
||||
@@ -82,14 +82,15 @@ pub fn handle_json_ctrl(
|
||||
let response_data = serde_json::to_vec(&response)
|
||||
.with_context(|| format!("Response {response:?} is not a json array"))?;
|
||||
|
||||
pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor {
|
||||
pgb.write_message(&BeMessage::RowDescription(&[RowDescriptor {
|
||||
name: b"json",
|
||||
typoid: TEXT_OID,
|
||||
typlen: -1,
|
||||
..Default::default()
|
||||
}]))?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[Some(&response_data)]))?
|
||||
.write_message(&BeMessage::CommandComplete(b"JSON_CTRL"))?;
|
||||
.write_message(&BeMessage::DataRow(&[Some(&response_data)]))?
|
||||
.write_message_flush(&BeMessage::CommandComplete(b"JSON_CTRL"))
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -128,15 +129,15 @@ fn send_proposer_elected(tli: &Arc<Timeline>, term: Term, lsn: Lsn) -> anyhow::R
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct InsertedWAL {
|
||||
pub struct InsertedWAL {
|
||||
begin_lsn: Lsn,
|
||||
end_lsn: Lsn,
|
||||
pub end_lsn: Lsn,
|
||||
append_response: AppendResponse,
|
||||
}
|
||||
|
||||
/// Extend local WAL with new LogicalMessage record. To do that,
|
||||
/// create AppendRequest with new WAL and pass it to safekeeper.
|
||||
fn append_logical_message(
|
||||
pub fn append_logical_message(
|
||||
tli: &Arc<Timeline>,
|
||||
msg: &AppendLogicalMessage,
|
||||
) -> anyhow::Result<InsertedWAL> {
|
||||
|
||||
@@ -26,7 +26,7 @@ use crate::safekeeper::ProposerAcceptorMessage;
|
||||
|
||||
use crate::handler::SafekeeperPostgresHandler;
|
||||
use pq_proto::{BeMessage, FeMessage};
|
||||
use utils::{postgres_backend::PostgresBackend, sock_split::ReadStream};
|
||||
use utils::{postgres_backend_async::PostgresBackend, sock_split::ReadStream};
|
||||
|
||||
pub struct ReceiveWalConn<'pg> {
|
||||
/// Postgres connection
|
||||
@@ -59,82 +59,83 @@ impl<'pg> ReceiveWalConn<'pg> {
|
||||
// Notify the libpq client that it's allowed to send `CopyData` messages
|
||||
self.pg_backend
|
||||
.write_message(&BeMessage::CopyBothResponse)?;
|
||||
|
||||
let r = self
|
||||
.pg_backend
|
||||
.take_stream_in()
|
||||
.ok_or_else(|| anyhow!("failed to take read stream from pgbackend"))?;
|
||||
let mut poll_reader = ProposerPollStream::new(r)?;
|
||||
Ok(())
|
||||
// let r = self
|
||||
// .pg_backend
|
||||
// .take_stream_in()
|
||||
// .ok_or_else(|| anyhow!("failed to take read stream from pgbackend"))?;
|
||||
// let mut poll_reader = ProposerPollStream::new(r)?;
|
||||
|
||||
// Receive information about server
|
||||
let next_msg = poll_reader.recv_msg()?;
|
||||
let tli = match next_msg {
|
||||
ProposerAcceptorMessage::Greeting(ref greeting) => {
|
||||
info!(
|
||||
"start handshake with walproposer {} sysid {} timeline {}",
|
||||
self.peer_addr, greeting.system_id, greeting.tli,
|
||||
);
|
||||
let server_info = ServerInfo {
|
||||
pg_version: greeting.pg_version,
|
||||
system_id: greeting.system_id,
|
||||
wal_seg_size: greeting.wal_seg_size,
|
||||
};
|
||||
GlobalTimelines::create(spg.ttid, server_info, Lsn::INVALID, Lsn::INVALID)?
|
||||
}
|
||||
_ => {
|
||||
return Err(QueryError::Other(anyhow::anyhow!(
|
||||
"unexpected message {next_msg:?} instead of greeting"
|
||||
)))
|
||||
}
|
||||
};
|
||||
// let next_msg = poll_reader.recv_msg()?;
|
||||
// let tli = match next_msg {
|
||||
// ProposerAcceptorMessage::Greeting(ref greeting) => {
|
||||
// info!(
|
||||
// "start handshake with walproposer {} sysid {} timeline {}",
|
||||
// self.peer_addr, greeting.system_id, greeting.tli,
|
||||
// );
|
||||
// let server_info = ServerInfo {
|
||||
// pg_version: greeting.pg_version,
|
||||
// system_id: greeting.system_id,
|
||||
// wal_seg_size: greeting.wal_seg_size,
|
||||
// };
|
||||
// GlobalTimelines::create(spg.ttid, server_info, Lsn::INVALID, Lsn::INVALID)?
|
||||
// }
|
||||
// _ => {
|
||||
// return Err(QueryError::Other(anyhow::anyhow!(
|
||||
// "unexpected message {next_msg:?} instead of greeting"
|
||||
// )))
|
||||
// }
|
||||
// };
|
||||
|
||||
let mut next_msg = Some(next_msg);
|
||||
// let mut next_msg = None;
|
||||
|
||||
let mut first_time_through = true;
|
||||
let mut _guard: Option<ComputeConnectionGuard> = None;
|
||||
loop {
|
||||
if matches!(next_msg, Some(ProposerAcceptorMessage::AppendRequest(_))) {
|
||||
// poll AppendRequest's without blocking and write WAL to disk without flushing,
|
||||
// while it's readily available
|
||||
while let Some(ProposerAcceptorMessage::AppendRequest(append_request)) = next_msg {
|
||||
let msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request);
|
||||
// let mut first_time_through = true;
|
||||
// let mut _guard: Option<ComputeConnectionGuard> = None;
|
||||
// loop {
|
||||
// if matches!(next_msg, Some(ProposerAcceptorMessage::AppendRequest(_))) {
|
||||
// // poll AppendRequest's without blocking and write WAL to disk without flushing,
|
||||
// // while it's readily available
|
||||
// while let Some(ProposerAcceptorMessage::AppendRequest(append_request)) = next_msg {
|
||||
// let msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request);
|
||||
|
||||
let reply = tli.process_msg(&msg)?;
|
||||
if let Some(reply) = reply {
|
||||
self.write_msg(&reply)?;
|
||||
}
|
||||
// let reply = tli.process_msg(&msg)?;
|
||||
// if let Some(reply) = reply {
|
||||
// self.write_msg(&reply)?;
|
||||
// }
|
||||
|
||||
next_msg = poll_reader.poll_msg();
|
||||
}
|
||||
// // next_msg = poll_reader.poll_msg();
|
||||
// next_msg = poll_reader.poll_msg();
|
||||
// }
|
||||
|
||||
// flush all written WAL to the disk
|
||||
let reply = tli.process_msg(&ProposerAcceptorMessage::FlushWAL)?;
|
||||
if let Some(reply) = reply {
|
||||
self.write_msg(&reply)?;
|
||||
}
|
||||
} else if let Some(msg) = next_msg.take() {
|
||||
// process other message
|
||||
let reply = tli.process_msg(&msg)?;
|
||||
if let Some(reply) = reply {
|
||||
self.write_msg(&reply)?;
|
||||
}
|
||||
}
|
||||
if first_time_through {
|
||||
// Register the connection and defer unregister. Do that only
|
||||
// after processing first message, as it sets wal_seg_size,
|
||||
// wanted by many.
|
||||
tli.on_compute_connect()?;
|
||||
_guard = Some(ComputeConnectionGuard {
|
||||
timeline: Arc::clone(&tli),
|
||||
});
|
||||
first_time_through = false;
|
||||
}
|
||||
// // flush all written WAL to the disk
|
||||
// let reply = tli.process_msg(&ProposerAcceptorMessage::FlushWAL)?;
|
||||
// if let Some(reply) = reply {
|
||||
// self.write_msg(&reply)?;
|
||||
// }
|
||||
// } else if let Some(msg) = next_msg.take() {
|
||||
// // process other message
|
||||
// let reply = tli.process_msg(&msg)?;
|
||||
// if let Some(reply) = reply {
|
||||
// self.write_msg(&reply)?;
|
||||
// }
|
||||
// }
|
||||
// if first_time_through {
|
||||
// // Register the connection and defer unregister. Do that only
|
||||
// // after processing first message, as it sets wal_seg_size,
|
||||
// // wanted by many.
|
||||
// tli.on_compute_connect()?;
|
||||
// _guard = Some(ComputeConnectionGuard {
|
||||
// timeline: Arc::clone(&tli),
|
||||
// });
|
||||
// first_time_through = false;
|
||||
// }
|
||||
|
||||
// blocking wait for the next message
|
||||
if next_msg.is_none() {
|
||||
next_msg = Some(poll_reader.recv_msg()?);
|
||||
}
|
||||
}
|
||||
// // blocking wait for the next message
|
||||
// if next_msg.is_none() {
|
||||
// next_msg = Some(poll_reader.recv_msg()?);
|
||||
// }
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,37 +145,37 @@ struct ProposerPollStream {
|
||||
}
|
||||
|
||||
impl ProposerPollStream {
|
||||
fn new(mut r: ReadStream) -> anyhow::Result<Self> {
|
||||
let (msg_tx, msg_rx) = channel();
|
||||
// fn new(mut r: ReadStream) -> anyhow::Result<Self> {
|
||||
// let (msg_tx, msg_rx) = channel();
|
||||
|
||||
let read_thread = thread::Builder::new()
|
||||
.name("Read WAL thread".into())
|
||||
.spawn(move || -> Result<(), QueryError> {
|
||||
loop {
|
||||
let copy_data = match FeMessage::read(&mut r)? {
|
||||
Some(FeMessage::CopyData(bytes)) => Ok(bytes),
|
||||
Some(msg) => Err(QueryError::Other(anyhow::anyhow!(
|
||||
"expected `CopyData` message, found {msg:?}"
|
||||
))),
|
||||
None => Err(QueryError::from(std::io::Error::new(
|
||||
std::io::ErrorKind::ConnectionAborted,
|
||||
"walproposer closed the connection",
|
||||
))),
|
||||
}?;
|
||||
// let read_thread = thread::Builder::new()
|
||||
// .name("Read WAL thread".into())
|
||||
// .spawn(move || -> Result<(), QueryError> {
|
||||
// loop {
|
||||
// let copy_data = match FeMessage::read(&mut r)? {
|
||||
// Some(FeMessage::CopyData(bytes)) => Ok(bytes),
|
||||
// Some(msg) => Err(QueryError::Other(anyhow::anyhow!(
|
||||
// "expected `CopyData` message, found {msg:?}"
|
||||
// ))),
|
||||
// None => Err(QueryError::from(std::io::Error::new(
|
||||
// std::io::ErrorKind::ConnectionAborted,
|
||||
// "walproposer closed the connection",
|
||||
// ))),
|
||||
// }?;
|
||||
|
||||
let msg = ProposerAcceptorMessage::parse(copy_data)?;
|
||||
msg_tx
|
||||
.send(msg)
|
||||
.context("Failed to send the proposer message")?;
|
||||
}
|
||||
// msg_tx will be dropped here, this will also close msg_rx
|
||||
})?;
|
||||
// let msg = ProposerAcceptorMessage::parse(copy_data)?;
|
||||
// msg_tx
|
||||
// .send(msg)
|
||||
// .context("Failed to send the proposer message")?;
|
||||
// }
|
||||
// // msg_tx will be dropped here, this will also close msg_rx
|
||||
// })?;
|
||||
|
||||
Ok(Self {
|
||||
msg_rx,
|
||||
read_thread: Some(read_thread),
|
||||
})
|
||||
}
|
||||
// Ok(Self {
|
||||
// msg_rx,
|
||||
// read_thread: Some(read_thread),
|
||||
// })
|
||||
// }
|
||||
|
||||
fn recv_msg(&mut self) -> Result<ProposerAcceptorMessage, QueryError> {
|
||||
self.msg_rx.recv().map_err(|_| {
|
||||
|
||||
@@ -1,28 +1,36 @@
|
||||
//! This module implements the streaming side of replication protocol, starting
|
||||
//! with the "START_REPLICATION" message.
|
||||
|
||||
use anyhow::Context as AnyhowContext;
|
||||
use bytes::Bytes;
|
||||
use futures::future::BoxFuture;
|
||||
use pin_project_lite::pin_project;
|
||||
use postgres_ffi::get_current_timestamp;
|
||||
use postgres_ffi::{TimestampTz, MAX_SEND_SIZE};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cell::RefCell;
|
||||
use std::cmp::min;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
use std::task::{ready, Context, Poll};
|
||||
use std::time::Duration;
|
||||
use std::{io, str, thread};
|
||||
use tokio::sync::watch::Receiver;
|
||||
use tokio::time::timeout;
|
||||
use tracing::*;
|
||||
use utils::postgres_backend_async::QueryError;
|
||||
use utils::send_rc::RefCellSend;
|
||||
use utils::send_rc::SendRc;
|
||||
|
||||
use pq_proto::{BeMessage, FeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody};
|
||||
use utils::{bin_ser::BeSer, lsn::Lsn, postgres_backend_async::PostgresBackend};
|
||||
|
||||
use crate::handler::SafekeeperPostgresHandler;
|
||||
use crate::timeline::{ReplicaState, Timeline};
|
||||
use crate::wal_storage::WalReader;
|
||||
use crate::GlobalTimelines;
|
||||
use anyhow::Context;
|
||||
|
||||
use bytes::Bytes;
|
||||
use postgres_ffi::get_current_timestamp;
|
||||
use postgres_ffi::{TimestampTz, MAX_SEND_SIZE};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cmp::min;
|
||||
use std::net::Shutdown;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::{io, str, thread};
|
||||
use utils::postgres_backend_async::QueryError;
|
||||
|
||||
use pq_proto::{BeMessage, FeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody};
|
||||
use tokio::sync::watch::Receiver;
|
||||
use tokio::time::timeout;
|
||||
use tracing::*;
|
||||
use utils::{bin_ser::BeSer, lsn::Lsn, postgres_backend::PostgresBackend, sock_split::ReadStream};
|
||||
|
||||
// See: https://www.postgresql.org/docs/13/protocol-replication.html
|
||||
const HOT_STANDBY_FEEDBACK_TAG_BYTE: u8 = b'h';
|
||||
@@ -60,13 +68,6 @@ pub struct StandbyReply {
|
||||
pub reply_requested: bool,
|
||||
}
|
||||
|
||||
/// A network connection that's speaking the replication protocol.
|
||||
pub struct ReplicationConn {
|
||||
/// This is an `Option` because we will spawn a background thread that will
|
||||
/// `take` it from us.
|
||||
stream_in: Option<ReadStream>,
|
||||
}
|
||||
|
||||
/// Scope guard to unregister replication connection from timeline
|
||||
struct ReplicationConnGuard {
|
||||
replica: usize, // replica internal ID assigned by timeline
|
||||
@@ -79,230 +80,418 @@ impl Drop for ReplicationConnGuard {
|
||||
}
|
||||
}
|
||||
|
||||
impl ReplicationConn {
|
||||
/// Create a new `ReplicationConn`
|
||||
pub fn new(pgb: &mut PostgresBackend) -> Self {
|
||||
Self {
|
||||
stream_in: pgb.take_stream_in(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle incoming messages from the network.
|
||||
/// This is spawned into the background by `handle_start_replication`.
|
||||
fn background_thread(
|
||||
mut stream_in: ReadStream,
|
||||
replica_guard: Arc<ReplicationConnGuard>,
|
||||
) -> anyhow::Result<()> {
|
||||
let replica_id = replica_guard.replica;
|
||||
let timeline = &replica_guard.timeline;
|
||||
|
||||
let mut state = ReplicaState::new();
|
||||
// Wait for replica's feedback.
|
||||
while let Some(msg) = FeMessage::read(&mut stream_in)? {
|
||||
match &msg {
|
||||
FeMessage::CopyData(m) => {
|
||||
// There's three possible data messages that the client is supposed to send here:
|
||||
// `HotStandbyFeedback` and `StandbyStatusUpdate` and `NeonStandbyFeedback`.
|
||||
|
||||
match m.first().cloned() {
|
||||
Some(HOT_STANDBY_FEEDBACK_TAG_BYTE) => {
|
||||
// Note: deserializing is on m[1..] because we skip the tag byte.
|
||||
state.hs_feedback = HotStandbyFeedback::des(&m[1..])
|
||||
.context("failed to deserialize HotStandbyFeedback")?;
|
||||
timeline.update_replica_state(replica_id, state);
|
||||
}
|
||||
Some(STANDBY_STATUS_UPDATE_TAG_BYTE) => {
|
||||
let _reply = StandbyReply::des(&m[1..])
|
||||
.context("failed to deserialize StandbyReply")?;
|
||||
// This must be a regular postgres replica,
|
||||
// because pageserver doesn't send this type of messages to safekeeper.
|
||||
// Currently this is not implemented, so this message is ignored.
|
||||
|
||||
warn!("unexpected StandbyReply. Read-only postgres replicas are not supported in safekeepers yet.");
|
||||
// timeline.update_replica_state(replica_id, Some(state));
|
||||
}
|
||||
Some(NEON_STATUS_UPDATE_TAG_BYTE) => {
|
||||
// Note: deserializing is on m[9..] because we skip the tag byte and len bytes.
|
||||
let buf = Bytes::copy_from_slice(&m[9..]);
|
||||
let reply = ReplicationFeedback::parse(buf);
|
||||
|
||||
trace!("ReplicationFeedback is {:?}", reply);
|
||||
// Only pageserver sends ReplicationFeedback, so set the flag.
|
||||
// This replica is the source of information to resend to compute.
|
||||
state.pageserver_feedback = Some(reply);
|
||||
|
||||
timeline.update_replica_state(replica_id, state);
|
||||
}
|
||||
_ => warn!("unexpected message {:?}", msg),
|
||||
}
|
||||
}
|
||||
FeMessage::Sync => {}
|
||||
FeMessage::CopyFail => {
|
||||
// Shutdown the connection, because rust-postgres client cannot be dropped
|
||||
// when connection is alive.
|
||||
let _ = stream_in.shutdown(Shutdown::Both);
|
||||
anyhow::bail!("Copy failed");
|
||||
}
|
||||
_ => {
|
||||
// We only handle `CopyData`, 'Sync', 'CopyFail' messages. Anything else is ignored.
|
||||
info!("unexpected message {:?}", msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
///
|
||||
/// Handle START_REPLICATION replication command
|
||||
///
|
||||
pub fn run(
|
||||
impl SafekeeperPostgresHandler {
|
||||
pub async fn handle_start_replication(
|
||||
&mut self,
|
||||
spg: &mut SafekeeperPostgresHandler,
|
||||
pgb: &mut PostgresBackend,
|
||||
mut start_pos: Lsn,
|
||||
start_pos: Lsn,
|
||||
) -> Result<(), QueryError> {
|
||||
let _enter = info_span!("WAL sender", ttid = %spg.ttid).entered();
|
||||
|
||||
let tli = GlobalTimelines::get(spg.ttid)?;
|
||||
|
||||
// spawn the background thread which receives HotStandbyFeedback messages.
|
||||
let bg_timeline = Arc::clone(&tli);
|
||||
let bg_stream_in = self.stream_in.take().unwrap();
|
||||
let bg_timeline_id = spg.timeline_id.unwrap();
|
||||
let appname = self.appname.clone();
|
||||
let tli = GlobalTimelines::get(self.ttid)?;
|
||||
|
||||
let state = ReplicaState::new();
|
||||
// This replica_id is used below to check if it's time to stop replication.
|
||||
let replica_id = bg_timeline.add_replica(state);
|
||||
let replica_id = tli.add_replica(state);
|
||||
|
||||
// Use a guard object to remove our entry from the timeline, when the background
|
||||
// thread and us have both finished using it.
|
||||
let replica_guard = Arc::new(ReplicationConnGuard {
|
||||
let _guard = Arc::new(ReplicationConnGuard {
|
||||
replica: replica_id,
|
||||
timeline: bg_timeline,
|
||||
timeline: tli.clone(),
|
||||
});
|
||||
let bg_replica_guard = Arc::clone(&replica_guard);
|
||||
|
||||
// TODO: here we got two threads, one for writing WAL and one for receiving
|
||||
// feedback. If one of them fails, we should shutdown the other one too.
|
||||
let _ = thread::Builder::new()
|
||||
.name("HotStandbyFeedback thread".into())
|
||||
.spawn(move || {
|
||||
let _enter =
|
||||
info_span!("HotStandbyFeedback thread", timeline = %bg_timeline_id).entered();
|
||||
if let Err(err) = Self::background_thread(bg_stream_in, bg_replica_guard) {
|
||||
error!("Replication background thread failed: {}", err);
|
||||
// Walproposer gets special handling: safekeeper must give proposer all
|
||||
// local WAL till the end, whether committed or not (walproposer will
|
||||
// hang otherwise). That's because walproposer runs the consensus and
|
||||
// synchronizes safekeepers on the most advanced one.
|
||||
//
|
||||
// There is a small risk of this WAL getting concurrently garbaged if
|
||||
// another compute rises which collects majority and starts fixing log
|
||||
// on this safekeeper itself. That's ok as (old) proposer will never be
|
||||
// able to commit such WAL.
|
||||
let stop_pos: Option<Lsn> = if self.is_walproposer_recovery() {
|
||||
let wal_end = tli.get_flush_lsn();
|
||||
Some(wal_end)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let end_pos = stop_pos.unwrap_or(Lsn::INVALID);
|
||||
|
||||
info!(
|
||||
"starting streaming from {:?} till {:?}",
|
||||
start_pos, stop_pos
|
||||
);
|
||||
|
||||
// switch to copy
|
||||
pgb.write_message(&BeMessage::CopyBothResponse)?;
|
||||
|
||||
let (_, persisted_state) = tli.get_state();
|
||||
let wal_reader = WalReader::new(
|
||||
self.conf.workdir.clone(),
|
||||
self.conf.timeline_dir(&tli.ttid),
|
||||
&persisted_state,
|
||||
start_pos,
|
||||
self.conf.wal_backup_enabled,
|
||||
)?;
|
||||
let write_ctx = SendRc::new(WriteContext {
|
||||
wal_reader: RefCell::new(wal_reader),
|
||||
send_buf: RefCell::new([0; MAX_SEND_SIZE]),
|
||||
});
|
||||
|
||||
let mut c = ReplicationContext {
|
||||
tli,
|
||||
replica_id,
|
||||
appname,
|
||||
pgb,
|
||||
start_pos,
|
||||
end_pos,
|
||||
stop_pos,
|
||||
write_ctx,
|
||||
feedback: ReplicaState::new(),
|
||||
};
|
||||
|
||||
let _phantom_wf = c.wait_wal_fut();
|
||||
let real_end_pos = c.end_pos;
|
||||
c.end_pos = c.start_pos + 1; // to well form read_wal future
|
||||
let _phantom_rf = c.read_wal_fut();
|
||||
c.end_pos = real_end_pos;
|
||||
|
||||
ReplicationHandler {
|
||||
c,
|
||||
write_state: WriteState::FlushWal,
|
||||
_phantom_wf,
|
||||
_phantom_rf,
|
||||
}
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
pin_project! {
|
||||
/// START_REPLICATION stream driver: sends WAL and receives feedback.
|
||||
struct ReplicationHandler<'a, WF, RF>
|
||||
where
|
||||
WF: Future<Output = anyhow::Result<Option<Lsn>>>,
|
||||
RF: Future<Output = anyhow::Result<usize>>,
|
||||
{
|
||||
c: ReplicationContext<'a>,
|
||||
#[pin]
|
||||
write_state: WriteState<WF, RF>,
|
||||
// To deduce anonymous types.
|
||||
_phantom_wf: WF,
|
||||
_phantom_rf: RF,
|
||||
}
|
||||
}
|
||||
|
||||
/// Data ReplicationHandler maintains. Separated so we could generate WriteState
|
||||
/// futures during init, deducing their type.
|
||||
struct ReplicationContext<'a> {
|
||||
tli: Arc<Timeline>,
|
||||
appname: Option<String>,
|
||||
replica_id: usize,
|
||||
pgb: &'a mut PostgresBackend,
|
||||
// Position since which we are sending next chunk.
|
||||
start_pos: Lsn,
|
||||
// WAL up to this position is known to be locally available.
|
||||
end_pos: Lsn,
|
||||
// If present, terminate after reaching this position; used by walproposer
|
||||
// in recovery.
|
||||
stop_pos: Option<Lsn>,
|
||||
// This data is needed to create Future sending WAL, so we need to both have
|
||||
// it here (to create new future) and borrow it to the future itself.
|
||||
// Essentially this is a self referential struct. To satisfy borrow checker,
|
||||
// use Rc<RefCell>. To make ReplicationHandler itself Send'able future, wrap
|
||||
// it into SendRc; this is safe as ReplicationHandler is passed between
|
||||
// threads only as a whole (during rescheduling).
|
||||
//
|
||||
// Right now we're in CurrentThread runtime, so Send is somewhat redundant;
|
||||
// however, otherwise we'd need to inconveniently have separate !Send
|
||||
// version of pg backend Handler trait (and work with LocalSet).
|
||||
write_ctx: SendRc<WriteContext>,
|
||||
feedback: ReplicaState,
|
||||
}
|
||||
|
||||
// State which ReplicationHandler needs to create futures sending data.
|
||||
struct WriteContext {
|
||||
wal_reader: RefCell<WalReader>,
|
||||
// buffer for readling WAL into to send it
|
||||
send_buf: RefCell<[u8; MAX_SEND_SIZE]>,
|
||||
}
|
||||
|
||||
// Yield points of WAL sending machinery.
|
||||
pin_project! {
|
||||
#[project = WriteStateProj]
|
||||
enum WriteState<WF, RF>
|
||||
where
|
||||
WF: Future<Output = anyhow::Result<Option<Lsn>>>,
|
||||
RF: Future<Output = anyhow::Result<usize>>,
|
||||
{
|
||||
WaitWal{ #[pin] fut: WF},
|
||||
ReadWal{ #[pin] fut: RF},
|
||||
FlushWal,
|
||||
}
|
||||
}
|
||||
|
||||
impl<WF, RF> Future for ReplicationHandler<'_, WF, RF>
|
||||
where
|
||||
WF: Future<Output = anyhow::Result<Option<Lsn>>>,
|
||||
RF: Future<Output = anyhow::Result<usize>>,
|
||||
{
|
||||
type Output = Result<(), QueryError>;
|
||||
|
||||
// We need to read feedback from the socket and write data there at the same
|
||||
// time. To avoid having to split socket, which creates messy split-join
|
||||
// APIs, is problematic with TLS [1] and needs to manage two tasks, just run
|
||||
// single task and use poll interfaces, basically manual state machine,
|
||||
// which is simple here.
|
||||
//
|
||||
// [1] https://github.com/tokio-rs/tls/issues/40
|
||||
//
|
||||
// Completes only when the stream is over, technically on error currently.
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
if let Poll::Ready(r) = self.as_mut().poll_read(cx) {
|
||||
return Poll::Ready(r);
|
||||
}
|
||||
self.as_mut().poll_write(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<WF, RF> ReplicationHandler<'_, WF, RF>
|
||||
where
|
||||
WF: Future<Output = anyhow::Result<Option<Lsn>>>,
|
||||
RF: Future<Output = anyhow::Result<usize>>,
|
||||
{
|
||||
// Poll reading, i.e. getting feedback and processing it. Completes only on error/end of stream.
|
||||
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), QueryError>> {
|
||||
loop {
|
||||
match ready!(self.as_mut().project().c.pgb.poll_read_message(cx)) {
|
||||
Ok(Some(msg)) => self.as_mut().handle_feedback(&msg)?,
|
||||
Ok(None) => {
|
||||
return Poll::Ready(Err(QueryError::Other(anyhow::anyhow!(
|
||||
"EOF on replication stream"
|
||||
))))
|
||||
}
|
||||
})?;
|
||||
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
runtime.block_on(async move {
|
||||
let (inmem_state, persisted_state) = tli.get_state();
|
||||
// add persisted_state.timeline_start_lsn == Lsn(0) check
|
||||
|
||||
// Walproposer gets special handling: safekeeper must give proposer all
|
||||
// local WAL till the end, whether committed or not (walproposer will
|
||||
// hang otherwise). That's because walproposer runs the consensus and
|
||||
// synchronizes safekeepers on the most advanced one.
|
||||
//
|
||||
// There is a small risk of this WAL getting concurrently garbaged if
|
||||
// another compute rises which collects majority and starts fixing log
|
||||
// on this safekeeper itself. That's ok as (old) proposer will never be
|
||||
// able to commit such WAL.
|
||||
let stop_pos: Option<Lsn> = if spg.is_walproposer_recovery() {
|
||||
let wal_end = tli.get_flush_lsn();
|
||||
Some(wal_end)
|
||||
} else {
|
||||
None
|
||||
Err(err) => return Poll::Ready(Err(err.into())),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
info!("Start replication from {:?} till {:?}", start_pos, stop_pos);
|
||||
|
||||
// switch to copy
|
||||
pgb.write_message(&BeMessage::CopyBothResponse)?;
|
||||
|
||||
let mut end_pos = stop_pos.unwrap_or(inmem_state.commit_lsn);
|
||||
|
||||
let mut wal_reader = WalReader::new(
|
||||
spg.conf.workdir.clone(),
|
||||
spg.conf.timeline_dir(&tli.ttid),
|
||||
&persisted_state,
|
||||
start_pos,
|
||||
spg.conf.wal_backup_enabled,
|
||||
)?;
|
||||
|
||||
// buffer for wal sending, limited by MAX_SEND_SIZE
|
||||
let mut send_buf = vec![0u8; MAX_SEND_SIZE];
|
||||
|
||||
// watcher for commit_lsn updates
|
||||
let mut commit_lsn_watch_rx = tli.get_commit_lsn_watch_rx();
|
||||
|
||||
loop {
|
||||
if let Some(stop_pos) = stop_pos {
|
||||
if start_pos >= stop_pos {
|
||||
break; /* recovery finished */
|
||||
fn handle_feedback(self: Pin<&mut Self>, msg: &FeMessage) -> Result<(), QueryError> {
|
||||
let this = self.project();
|
||||
match &msg {
|
||||
FeMessage::CopyData(m) => {
|
||||
// There's three possible data messages that the client is supposed to send here:
|
||||
// `HotStandbyFeedback` and `StandbyStatusUpdate` and `NeonStandbyFeedback`.
|
||||
match m.first().cloned() {
|
||||
Some(HOT_STANDBY_FEEDBACK_TAG_BYTE) => {
|
||||
// Note: deserializing is on m[1..] because we skip the tag byte.
|
||||
this.c.feedback.hs_feedback = HotStandbyFeedback::des(&m[1..])
|
||||
.context("failed to deserialize HotStandbyFeedback")?;
|
||||
this.c
|
||||
.tli
|
||||
.update_replica_state(this.c.replica_id, this.c.feedback);
|
||||
}
|
||||
end_pos = stop_pos;
|
||||
} else {
|
||||
/* Wait until we have some data to stream */
|
||||
let lsn = wait_for_lsn(&mut commit_lsn_watch_rx, start_pos).await?;
|
||||
Some(STANDBY_STATUS_UPDATE_TAG_BYTE) => {
|
||||
let _reply = StandbyReply::des(&m[1..])
|
||||
.context("failed to deserialize StandbyReply")?;
|
||||
// This must be a regular postgres replica,
|
||||
// because pageserver doesn't send this type of messages to safekeeper.
|
||||
// Currently we just ignore this, tracking progress for them is not supported.
|
||||
}
|
||||
Some(NEON_STATUS_UPDATE_TAG_BYTE) => {
|
||||
// Note: deserializing is on m[9..] because we skip the tag byte and len bytes.
|
||||
let buf = Bytes::copy_from_slice(&m[9..]);
|
||||
let reply = ReplicationFeedback::parse(buf);
|
||||
|
||||
if let Some(lsn) = lsn {
|
||||
end_pos = lsn;
|
||||
} else {
|
||||
// TODO: also check once in a while whether we are walsender
|
||||
// to right pageserver.
|
||||
if tli.should_walsender_stop(replica_id) {
|
||||
// Shut down, timeline is suspended.
|
||||
return Err(QueryError::from(io::Error::new(
|
||||
io::ErrorKind::ConnectionAborted,
|
||||
format!("end streaming to {:?}", spg.appname),
|
||||
)));
|
||||
}
|
||||
trace!("ReplicationFeedback is {:?}", reply);
|
||||
// Only pageserver sends ReplicationFeedback, so set the flag.
|
||||
// This replica is the source of information to resend to compute.
|
||||
this.c.feedback.pageserver_feedback = Some(reply);
|
||||
|
||||
// timeout expired: request pageserver status
|
||||
pgb.write_message(&BeMessage::KeepAlive(WalSndKeepAlive {
|
||||
sent_ptr: end_pos.0,
|
||||
timestamp: get_current_timestamp(),
|
||||
request_reply: true,
|
||||
}))?;
|
||||
this.c
|
||||
.tli
|
||||
.update_replica_state(this.c.replica_id, this.c.feedback);
|
||||
}
|
||||
_ => warn!("unexpected message {:?}", msg),
|
||||
}
|
||||
}
|
||||
FeMessage::CopyFail => {
|
||||
// XXX we should probably (tell pgb to) close the socket, as
|
||||
// CopyFail in duplex copy is somewhat unexpected (at least to
|
||||
// PG walsender; evidently client should finish it with
|
||||
// CopyDone). Note that sync rust-postgres client (which we
|
||||
// don't use anymore) hangs otherwise.
|
||||
// https://github.com/sfackler/rust-postgres/issues/755
|
||||
// https://github.com/neondatabase/neon/issues/935
|
||||
//
|
||||
return Err(anyhow::anyhow!("unexpected CopyFail").into());
|
||||
}
|
||||
_ => {
|
||||
return Err(
|
||||
anyhow::anyhow!("unexpected message {:?} in replication stream", msg).into(),
|
||||
);
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Poll writing, i.e. sending more WAL. Completes only on error or when we
|
||||
// decide to shutdown connection -- receiver is caughtup and there is no
|
||||
// active computes; this is still handled as Err though.
|
||||
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), QueryError>> {
|
||||
// send while we don't block or error out
|
||||
loop {
|
||||
match &mut self.as_mut().project().write_state.project() {
|
||||
WriteStateProj::WaitWal { fut } => match ready!(fut.as_mut().poll(cx))? {
|
||||
Some(lsn) => {
|
||||
self.as_mut().project().c.end_pos = lsn;
|
||||
self.as_mut().start_read_wal();
|
||||
continue;
|
||||
}
|
||||
// Timed out waiting for WAL, send keepalive and possibly terminate.
|
||||
None => {
|
||||
let mut this = self.as_mut().project();
|
||||
if this.c.tli.should_walsender_stop(this.c.replica_id) {
|
||||
// Terminate if there is nothing more to send.
|
||||
// TODO close the stream properly
|
||||
return Poll::Ready(Err(anyhow::anyhow!(format!(
|
||||
"ending streaming to {:?} at {}, receiver is caughtup and there is no computes",
|
||||
self.c.appname, self.c.start_pos,
|
||||
)).into()));
|
||||
}
|
||||
let end_pos = this.c.end_pos.0;
|
||||
this.c
|
||||
.pgb
|
||||
.write_message(&BeMessage::KeepAlive(WalSndKeepAlive {
|
||||
sent_ptr: end_pos,
|
||||
timestamp: get_current_timestamp(),
|
||||
request_reply: true,
|
||||
}))?;
|
||||
/* flush KA */
|
||||
this.write_state.set(WriteState::FlushWal);
|
||||
}
|
||||
},
|
||||
WriteStateProj::ReadWal { fut } => {
|
||||
let read_len = ready!(fut.as_mut().poll(cx))?;
|
||||
assert!(read_len > 0, "read_len={}", read_len);
|
||||
|
||||
let mut this = self.as_mut().project();
|
||||
let write_ctx_clone = this.c.write_ctx.clone();
|
||||
let send_buf = &write_ctx_clone.send_buf.borrow()[..read_len];
|
||||
let chunk_end = this.c.start_pos + read_len as u64;
|
||||
// write data to the output buffer
|
||||
this.c
|
||||
.pgb
|
||||
.write_message(&BeMessage::XLogData(XLogDataBody {
|
||||
wal_start: this.c.start_pos.0,
|
||||
wal_end: chunk_end.0,
|
||||
timestamp: get_current_timestamp(),
|
||||
data: send_buf,
|
||||
}))
|
||||
.context("Failed to write XLogData")?;
|
||||
trace!("wrote a chunk of wal {}-{}", this.c.start_pos, chunk_end);
|
||||
this.c.start_pos = chunk_end;
|
||||
// and flush it
|
||||
this.write_state.set(WriteState::FlushWal);
|
||||
}
|
||||
WriteStateProj::FlushWal => {
|
||||
let this = self.as_mut().project();
|
||||
|
||||
let send_size = end_pos.checked_sub(start_pos).unwrap().0 as usize;
|
||||
let send_size = min(send_size, send_buf.len());
|
||||
|
||||
let send_buf = &mut send_buf[..send_size];
|
||||
|
||||
// read wal into buffer
|
||||
let send_size = wal_reader.read(send_buf).await?;
|
||||
let send_buf = &send_buf[..send_size];
|
||||
|
||||
// Write some data to the network socket.
|
||||
pgb.write_message(&BeMessage::XLogData(XLogDataBody {
|
||||
wal_start: start_pos.0,
|
||||
wal_end: end_pos.0,
|
||||
timestamp: get_current_timestamp(),
|
||||
data: send_buf,
|
||||
}))
|
||||
.context("Failed to send XLogData")?;
|
||||
|
||||
start_pos += send_size as u64;
|
||||
trace!("sent WAL up to {}", start_pos);
|
||||
ready!(this.c.pgb.poll_flush(cx))?;
|
||||
// If we are streaming to walproposer, check it is time to stop.
|
||||
if let Some(stop_pos) = this.c.stop_pos {
|
||||
if this.c.start_pos >= stop_pos {
|
||||
// recovery finished
|
||||
// TODO close the stream properly
|
||||
return Poll::Ready(Err(anyhow::anyhow!(format!(
|
||||
"ending streaming to walproposer at {}, receiver is caughtup and there is no computes",
|
||||
this.c.start_pos)).into()));
|
||||
}
|
||||
self.as_mut().start_read_wal();
|
||||
continue;
|
||||
} else {
|
||||
// if we don't know next portion is already available, wait
|
||||
// for it; otherwise proceed to sending
|
||||
if self.c.end_pos <= self.c.start_pos {
|
||||
self.as_mut().start_wait_wal();
|
||||
} else {
|
||||
self.as_mut().start_read_wal();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
// Start waiting for WAL, creating future doing that.
|
||||
fn start_wait_wal(self: Pin<&mut Self>) {
|
||||
let fut = self.c.wait_wal_fut();
|
||||
self.project().write_state.set(WriteState::WaitWal {
|
||||
fut: {
|
||||
// SAFETY: this function is the only way to assign WaitWal to
|
||||
// write_state. We just workaround impossibility of specifying
|
||||
// async fn type, which is anonymous.
|
||||
// transmute_copy is used as transmute refuses generic param:
|
||||
// https://users.rust-lang.org/t/transmute-doesnt-work-on-generic-types/87272
|
||||
assert_eq!(std::mem::size_of::<WF>(), std::mem::size_of_val(&fut));
|
||||
let t = unsafe { std::mem::transmute_copy(&fut) };
|
||||
std::mem::forget(fut);
|
||||
t
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Switch into reading WAL state, creating Future doing that.
|
||||
fn start_read_wal(self: Pin<&mut Self>) {
|
||||
let fut = self.c.read_wal_fut();
|
||||
self.project().write_state.set(WriteState::ReadWal {
|
||||
fut: {
|
||||
// SAFETY: this function is the only way to assign ReadWal to
|
||||
// write_state. We just workaround impossibility of specifying
|
||||
// async fn type, which is anonymous.
|
||||
// transmute_copy is used as transmute refuses generic param:
|
||||
// https://users.rust-lang.org/t/transmute-doesnt-work-on-generic-types/87272
|
||||
assert_eq!(std::mem::size_of::<RF>(), std::mem::size_of_val(&fut));
|
||||
let t = unsafe { std::mem::transmute_copy(&fut) };
|
||||
std::mem::forget(fut);
|
||||
t
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl ReplicationContext<'_> {
|
||||
// Create future waiting for WAL.
|
||||
fn wait_wal_fut(&self) -> impl Future<Output = anyhow::Result<Option<Lsn>>> {
|
||||
let mut commit_lsn_watch_rx = self.tli.get_commit_lsn_watch_rx();
|
||||
let start_pos = self.start_pos;
|
||||
async move { wait_for_lsn(&mut commit_lsn_watch_rx, start_pos).await }
|
||||
}
|
||||
|
||||
// Create future reading WAL.
|
||||
fn read_wal_fut(&self) -> impl Future<Output = anyhow::Result<usize>> {
|
||||
let mut send_size = self
|
||||
.end_pos
|
||||
.checked_sub(self.start_pos)
|
||||
.expect("reading wal without waiting for it first")
|
||||
.0 as usize;
|
||||
send_size = min(send_size, self.write_ctx.send_buf.borrow().len());
|
||||
let write_ctx_fut = self.write_ctx.clone();
|
||||
async move {
|
||||
let mut wal_reader_ref = write_ctx_fut.wal_reader.borrow_mut_send();
|
||||
let mut send_buf_ref = write_ctx_fut.send_buf.borrow_mut_send();
|
||||
|
||||
let send_buf = &mut send_buf_ref[..send_size];
|
||||
wal_reader_ref.read(send_buf).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const POLL_STATE_TIMEOUT: Duration = Duration::from_secs(1);
|
||||
|
||||
// Wait until we have commit_lsn > lsn or timeout expires. Returns latest commit_lsn.
|
||||
// Wait until we have commit_lsn > lsn or timeout expires. Returns
|
||||
// - Ok(Some(commit_lsn)) if needed lsn is successfully observed;
|
||||
// - Ok(None) if timeout expired;
|
||||
// - Err in case of error (if watch channel is in trouble, shouldn't happen).
|
||||
async fn wait_for_lsn(rx: &mut Receiver<Lsn>, lsn: Lsn) -> anyhow::Result<Option<Lsn>> {
|
||||
let commit_lsn: Lsn = *rx.borrow();
|
||||
if commit_lsn > lsn {
|
||||
|
||||
@@ -346,7 +346,9 @@ impl WalBackupTask {
|
||||
backup_lsn, commit_lsn, e
|
||||
);
|
||||
|
||||
retry_attempt = retry_attempt.saturating_add(1);
|
||||
if retry_attempt < u32::MAX {
|
||||
retry_attempt += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -385,7 +387,7 @@ async fn backup_single_segment(
|
||||
) -> Result<()> {
|
||||
let segment_file_path = seg.file_path(timeline_dir)?;
|
||||
let remote_segment_path = segment_file_path
|
||||
.strip_prefix(workspace_dir)
|
||||
.strip_prefix(&workspace_dir)
|
||||
.context("Failed to strip workspace dir prefix")
|
||||
.and_then(RemotePath::new)
|
||||
.with_context(|| {
|
||||
@@ -467,7 +469,7 @@ async fn backup_object(source_file: &Path, target_file: &RemotePath, size: usize
|
||||
pub async fn read_object(
|
||||
file_path: &RemotePath,
|
||||
offset: u64,
|
||||
) -> anyhow::Result<Pin<Box<dyn tokio::io::AsyncRead>>> {
|
||||
) -> anyhow::Result<Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>>> {
|
||||
let storage = REMOTE_STORAGE
|
||||
.get()
|
||||
.context("Failed to get remote storage")?
|
||||
|
||||
@@ -2,36 +2,54 @@
|
||||
//! WAL service listens for client connections and
|
||||
//! receive WAL from wal_proposer and send it to WAL receivers
|
||||
//!
|
||||
use anyhow::{Context, Result};
|
||||
use regex::Regex;
|
||||
use std::net::{TcpListener, TcpStream};
|
||||
use std::thread;
|
||||
use std::{future, thread};
|
||||
use tokio::net::TcpStream;
|
||||
use tracing::*;
|
||||
use utils::postgres_backend_async::QueryError;
|
||||
|
||||
use crate::handler::SafekeeperPostgresHandler;
|
||||
use crate::SafeKeeperConf;
|
||||
use utils::postgres_backend::{AuthType, PostgresBackend};
|
||||
use utils::postgres_backend_async::{AuthType, PostgresBackend};
|
||||
|
||||
/// Accept incoming TCP connections and spawn them into a background thread.
|
||||
pub fn thread_main(conf: SafeKeeperConf, listener: TcpListener) -> ! {
|
||||
loop {
|
||||
match listener.accept() {
|
||||
Ok((socket, peer_addr)) => {
|
||||
debug!("accepted connection from {}", peer_addr);
|
||||
let conf = conf.clone();
|
||||
pub fn thread_main(conf: SafeKeeperConf, pg_listener: std::net::TcpListener) {
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.context("create runtime")
|
||||
// todo catch error in main thread
|
||||
.expect("failed to create runtime");
|
||||
|
||||
let _ = thread::Builder::new()
|
||||
.name("WAL service thread".into())
|
||||
.spawn(move || {
|
||||
if let Err(err) = handle_socket(socket, conf) {
|
||||
error!("connection handler exited: {}", err);
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
runtime
|
||||
.block_on(async move {
|
||||
// Tokio's from_std won't do this for us, per its comment.
|
||||
pg_listener.set_nonblocking(true)?;
|
||||
let listener = tokio::net::TcpListener::from_std(pg_listener)?;
|
||||
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((socket, peer_addr)) => {
|
||||
debug!("accepted connection from {}", peer_addr);
|
||||
let conf = conf.clone();
|
||||
|
||||
let _ = thread::Builder::new()
|
||||
.name("WAL service thread".into())
|
||||
.spawn(move || {
|
||||
if let Err(err) = handle_socket(socket, conf) {
|
||||
error!("connection handler exited: {}", err);
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
Err(e) => error!("Failed to accept connection: {}", e),
|
||||
}
|
||||
}
|
||||
Err(e) => error!("Failed to accept connection: {}", e),
|
||||
}
|
||||
}
|
||||
#[allow(unreachable_code)] // hint compiler the closure return type
|
||||
Ok::<(), anyhow::Error>(())
|
||||
})
|
||||
.expect("listener failed")
|
||||
}
|
||||
|
||||
// Get unique thread id (Rust internal), with ThreadId removed for shorter printing
|
||||
@@ -44,9 +62,14 @@ fn get_tid() -> u64 {
|
||||
|
||||
/// This is run by `thread_main` above, inside a background thread.
|
||||
///
|
||||
fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryError> {
|
||||
fn handle_socket(mut socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryError> {
|
||||
let _enter = info_span!("", tid = ?get_tid()).entered();
|
||||
|
||||
let runtime = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
let local = tokio::task::LocalSet::new();
|
||||
|
||||
socket.set_nodelay(true)?;
|
||||
|
||||
let auth_type = match conf.auth {
|
||||
@@ -54,9 +77,13 @@ fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryErr
|
||||
Some(_) => AuthType::NeonJWT,
|
||||
};
|
||||
let mut conn_handler = SafekeeperPostgresHandler::new(conf);
|
||||
let pgbackend = PostgresBackend::new(socket, auth_type, None, false)?;
|
||||
// libpq replication protocol between safekeeper and replicas/pagers
|
||||
pgbackend.run(&mut conn_handler)?;
|
||||
let pgbackend = PostgresBackend::new(socket, auth_type, None)?;
|
||||
// libpq protocol between safekeeper and walproposer / pageserver
|
||||
// We don't use shutdown.
|
||||
local.block_on(
|
||||
&runtime,
|
||||
pgbackend.run(&mut conn_handler, || future::pending::<()>()),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -450,7 +450,7 @@ pub struct WalReader {
|
||||
timeline_dir: PathBuf,
|
||||
wal_seg_size: usize,
|
||||
pos: Lsn,
|
||||
wal_segment: Option<Pin<Box<dyn AsyncRead>>>,
|
||||
wal_segment: Option<Pin<Box<dyn AsyncRead + Send + Sync>>>,
|
||||
|
||||
// S3 will be used to read WAL if LSN is not available locally
|
||||
enable_remote_read: bool,
|
||||
@@ -491,6 +491,11 @@ impl WalReader {
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn fake_read(&mut self) -> Result<usize> {
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
Ok(self.pos.0 as usize)
|
||||
}
|
||||
|
||||
pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
|
||||
let mut wal_segment = match self.wal_segment.take() {
|
||||
Some(reader) => reader,
|
||||
@@ -517,7 +522,7 @@ impl WalReader {
|
||||
}
|
||||
|
||||
/// Open WAL segment at the current position of the reader.
|
||||
async fn open_segment(&self) -> Result<Pin<Box<dyn AsyncRead>>> {
|
||||
async fn open_segment(&self) -> Result<Pin<Box<dyn AsyncRead + Send + Sync>>> {
|
||||
let xlogoff = self.pos.segment_offset(self.wal_seg_size);
|
||||
let segno = self.pos.segment_number(self.wal_seg_size);
|
||||
let wal_file_name = XLogFileName(PG_TLI, segno, self.wal_seg_size);
|
||||
|
||||
Reference in New Issue
Block a user