Compare commits

...

4 Commits

Author SHA1 Message Date
Arseny Sher
f281dc5953 Add forgotten files. 2023-02-06 13:44:42 +04:00
Arseny Sher
48fb085ebd Test & bug fix it.
Add fake_timeline endoint creating timeline + some WAL.

curl -X POST http://127.0.0.1:7676/v1/fake_timeline
Set in pg_receivewal.c:
  stream.startpos = 0x1493AC8;
pg_install/v15/bin/pg_receivewal -v -d "host=localhost port=5454 options='-c tenant_id=deadbeefdeadbeefdeadbeefdeadbeef timeline_id=deadbeefdeadbeefdeadbeefdeadbeef'" -D ~/tmp/tmp/tmp
2023-02-03 17:14:51 +04:00
Arseny Sher
2bbd24edbf Get rid of futurex boxing through transmute. 2023-02-02 14:34:10 +04:00
Arseny Sher
5e972ccdc4 WIP safekeeper walsender: read-write from single task.
- Use postgres_backend_async throughout safekeeper.
- Use Framed in postgres_backend_async, it allows polling interface and
  takes some logic.
- Do read-write from single task in walsender.

The latter turned out to be more complicated than I initially expected due to 1)
borrow checking and 2) anon Future types. 1) required SendRc<Refcell<...>>
construct just to satisfy the checker; 2) is currently done via boxing futures,
which is a pointless heap allocation in active path.

I'll probably try to workaround 2) with transmute, but it made me wonder whether
socket split, like it was done previously, would be better. It is also messy
though:
- we need to manage two tasks, properly join them and should on exit/error
  should join pgbackend back to leave it in valid state; pgbackend itself must
  swell a bit to provide splitted interface.
- issues with tls
- tokio::io::split has pointless mutex inside

fixing walreceiver and proxy is not done yet
2023-02-02 12:03:45 +04:00
26 changed files with 1315 additions and 614 deletions

7
Cargo.lock generated
View File

@@ -2510,6 +2510,7 @@ name = "pq_proto"
version = "0.1.0"
dependencies = [
"anyhow",
"byteorder",
"bytes",
"pin-project-lite",
"postgres-protocol",
@@ -2517,6 +2518,7 @@ dependencies = [
"serde",
"thiserror",
"tokio",
"tokio-util",
"tracing",
"workspace_hack",
]
@@ -3074,6 +3076,7 @@ dependencies = [
"const_format",
"crc32c",
"fs2",
"futures",
"git-version",
"hex",
"humantime",
@@ -3082,6 +3085,7 @@ dependencies = [
"nix",
"once_cell",
"parking_lot",
"pin-project-lite",
"postgres",
"postgres-protocol",
"postgres_ffi",
@@ -4203,6 +4207,7 @@ dependencies = [
"byteorder",
"bytes",
"criterion",
"futures",
"git-version",
"hex",
"hex-literal",
@@ -4211,6 +4216,7 @@ dependencies = [
"metrics",
"nix",
"once_cell",
"pin-utils",
"pq_proto",
"rand",
"routerify",
@@ -4228,6 +4234,7 @@ dependencies = [
"thiserror",
"tokio",
"tokio-rustls",
"tokio-util",
"tracing",
"tracing-subscriber",
"workspace_hack",

View File

@@ -14,7 +14,7 @@ use anyhow::{Context, Result};
use utils::{
id::{TenantId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
postgres_backend_async::AuthType,
};
use crate::local_env::{LocalEnv, DEFAULT_PG_VERSION};

View File

@@ -19,7 +19,7 @@ use std::process::{Command, Stdio};
use utils::{
auth::{encode_from_key_file, Claims, Scope},
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
postgres_backend::AuthType,
postgres_backend_async::AuthType,
};
use crate::safekeeper::SafekeeperNode;

View File

@@ -7,11 +7,13 @@ license = "Apache-2.0"
[dependencies]
anyhow = "1.0"
bytes = "1.0.1"
byteorder = "1.4.3"
pin-project-lite = "0.2.7"
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
rand = "0.8.3"
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.17", features = ["macros"] }
tokio-util = { version = "0.7.3" }
tracing = "0.1"
thiserror = "1.0"

View File

@@ -0,0 +1,62 @@
//! Provides `PostgresCodec` defining how to serilize/deserialize Postgres
//! messages to/from the wire, to be used with `tokio_util::codec::Framed`.
use std::io;
use bytes::BytesMut;
use tokio_util::codec::{Decoder, Encoder};
use crate::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
// Defines how to serilize/deserialize Postgres messages to/from the wire, to be
// used with `tokio_util::codec::Framed`.
pub struct PostgresCodec {
// Have we already decoded startup message? All further should start with
// message type byte then.
startup_read: bool,
}
impl PostgresCodec {
pub fn new() -> Self {
PostgresCodec {
startup_read: false,
}
}
}
/// Error on postgres connection: either IO (physical transport error) or
/// protocol violation.
#[derive(thiserror::Error, Debug)]
pub enum ConnectionError {
#[error(transparent)]
Io(#[from] io::Error),
#[error(transparent)]
Protocol(#[from] ProtocolError),
}
impl Encoder<&BeMessage<'_>> for PostgresCodec {
type Error = ConnectionError;
fn encode(&mut self, item: &BeMessage, dst: &mut BytesMut) -> Result<(), ConnectionError> {
BeMessage::write(dst, &item)?;
Ok(())
}
}
impl Decoder for PostgresCodec {
type Item = FeMessage;
type Error = ConnectionError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<FeMessage>, ConnectionError> {
let msg = if !self.startup_read {
let msg = FeStartupPacket::parse(src);
if let Ok(Some(FeMessage::StartupPacket(FeStartupPacket::StartupMessage { .. }))) = msg
{
self.startup_read = true;
}
msg?
} else {
FeMessage::parse(src)?
};
Ok(msg)
}
}

View File

@@ -3,9 +3,11 @@
//! on message formats.
// Tools for calling certain async methods in sync contexts.
pub mod codec;
pub mod sync;
use anyhow::{ensure, Context, Result};
use anyhow::{anyhow, bail, ensure, Context, Result};
use byteorder::{BigEndian, ByteOrder, ReadBytesExt};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use postgres_protocol::PG_EPOCH;
use serde::{Deserialize, Serialize};
@@ -19,7 +21,7 @@ use std::{
time::{Duration, SystemTime},
};
use sync::{AsyncishRead, SyncFuture};
use tokio::io::AsyncReadExt;
// use tokio::io::AsyncReadExt;
use tracing::{trace, warn};
pub type Oid = u32;
@@ -194,36 +196,108 @@ macro_rules! retry_read {
};
}
/// An error occured during connection being open.
/// An error occured while parsing or serializing raw stream into Postgres
/// messages.
#[derive(thiserror::Error, Debug)]
pub enum ConnectionError {
pub enum ProtocolError {
/// IO error during writing to or reading from the connection socket.
/// removeme
#[error("Socket IO error: {0}")]
Socket(std::io::Error),
/// Invalid packet was received from client
/// Invalid packet was received from the client (e.g. unexpected message
/// type or broken len).
#[error("Protocol error: {0}")]
Protocol(String),
/// Failed to parse a protocol mesage
/// Failed to parse or, (unlikely), serialize a protocol message.
#[error("Message parse error: {0}")]
MessageParse(anyhow::Error),
}
impl From<anyhow::Error> for ConnectionError {
// Allows to return anyhow error from msg parsing routines, meaning less typing.
impl From<anyhow::Error> for ProtocolError {
fn from(e: anyhow::Error) -> Self {
Self::MessageParse(e)
}
}
impl ConnectionError {
impl ProtocolError {
pub fn into_io_error(self) -> io::Error {
match self {
ConnectionError::Socket(io) => io,
ProtocolError::Socket(io) => io,
other => io::Error::new(io::ErrorKind::Other, other.to_string()),
}
}
}
impl FeMessage {
/// Read and parse one message from the `buf` input buffer. If there is at
/// least one valid message, returns it, advancing `buf`; redundant copies
/// are avoided, as thanks to `bytes` crate ptrs in parsed message point
/// directly into the `buf` (processed data is garbage collected after
/// parsed message is dropped).
///
/// Returns None if `buf` doesn't contain enough data for a single message.
/// For efficiency, tries to reserve large enough space in `buf` for the
/// next message in this case.
///
/// Returns Error if message is malformed, the only possible ErrorKind is
/// InvalidInput.
//
// Inspired by rust-postgres Message::parse.
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
// Every message contains message type byte and 4 bytes len; can't do
// much without them.
if buf.len() < 5 {
let to_read = 5 - buf.len();
buf.reserve(to_read);
return Ok(None);
}
// We shouldn't advance `buf` as probably full message is not there yet,
// so can't directly use Bytes::get_u32 etc.
let tag = buf[0];
let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
if len < 4 {
return Err(ProtocolError::Protocol(format!(
"invalid message length {}",
len
)));
}
// lenth field includes itself, but not message type.
let total_len = len as usize + 1;
if buf.len() < total_len {
// Don't have full message yet.
let to_read = total_len - buf.len();
buf.reserve(to_read);
return Ok(None);
}
// got the message, advance buffer
let mut msg = buf.split_to(total_len).freeze();
msg.advance(5); // consume message type and len
match tag {
b'Q' => Ok(Some(FeMessage::Query(msg))),
b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
b'S' => Ok(Some(FeMessage::Sync)),
b'X' => Ok(Some(FeMessage::Terminate)),
b'd' => Ok(Some(FeMessage::CopyData(msg))),
b'c' => Ok(Some(FeMessage::CopyDone)),
b'f' => Ok(Some(FeMessage::CopyFail)),
b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
tag => {
return Err(ProtocolError::Protocol(format!(
"unknown message tag: {tag},'{msg:?}'"
)))
}
}
}
/// Read one message from the stream.
/// This function returns `Ok(None)` in case of EOF.
/// One way to handle this properly:
@@ -245,68 +319,8 @@ impl FeMessage {
/// }
/// ```
#[inline(never)]
pub fn read(
stream: &mut (impl io::Read + Unpin),
) -> Result<Option<FeMessage>, ConnectionError> {
Self::read_fut(&mut AsyncishRead(stream)).wait()
}
/// Read one message from the stream.
/// See documentation for `Self::read`.
pub fn read_fut<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = Result<Option<FeMessage>, ConnectionError>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
// We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof.
// SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and
// AsyncReadExt methods of the stream.
SyncFuture::new(async move {
// Each libpq message begins with a message type byte, followed by message length
// If the client closes the connection, return None. But if the client closes the
// connection in the middle of a message, we will return an error.
let tag = match retry_read!(stream.read_u8().await) {
Ok(b) => b,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ConnectionError::Socket(e)),
};
// The message length includes itself, so it better be at least 4.
let len = retry_read!(stream.read_u32().await)
.map_err(ConnectionError::Socket)?
.checked_sub(4)
.ok_or_else(|| ConnectionError::Protocol("invalid message length".to_string()))?;
let body = {
let mut buffer = vec![0u8; len as usize];
stream
.read_exact(&mut buffer)
.await
.map_err(ConnectionError::Socket)?;
Bytes::from(buffer)
};
match tag {
b'Q' => Ok(Some(FeMessage::Query(body))),
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
b'S' => Ok(Some(FeMessage::Sync)),
b'X' => Ok(Some(FeMessage::Terminate)),
b'd' => Ok(Some(FeMessage::CopyData(body))),
b'c' => Ok(Some(FeMessage::CopyDone)),
b'f' => Ok(Some(FeMessage::CopyFail)),
b'p' => Ok(Some(FeMessage::PasswordMessage(body))),
tag => {
return Err(ConnectionError::Protocol(format!(
"unknown message tag: {tag},'{body:?}'"
)))
}
}
})
pub fn read(_stream: &mut (impl io::Read + Unpin)) -> Result<Option<FeMessage>, ProtocolError> {
Ok(None) // removeme
}
}
@@ -314,21 +328,124 @@ impl FeStartupPacket {
/// Read startup message from the stream.
// XXX: It's tempting yet undesirable to accept `stream` by value,
// since such a change will cause user-supplied &mut references to be consumed
pub fn read(
stream: &mut (impl io::Read + Unpin),
) -> Result<Option<FeMessage>, ConnectionError> {
pub fn read(stream: &mut (impl io::Read + Unpin)) -> Result<Option<FeMessage>, ProtocolError> {
Self::read_fut(&mut AsyncishRead(stream)).wait()
}
/// Read and parse startup message from the `buf` input buffer. It is
/// different from [`FeMessage::parse`] because startup messages don't have
/// message type byte; otherwise, its comments apply.
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
const CANCEL_REQUEST_CODE: u32 = 5678;
const NEGOTIATE_SSL_CODE: u32 = 5679;
const NEGOTIATE_GSS_CODE: u32 = 5680;
if buf.len() < 4 {
let to_read = 5 - buf.len();
buf.reserve(to_read);
return Ok(None);
}
// We shouldn't advance `buf` as probably full message is not there yet,
// so can't directly use Bytes::get_u32 etc.
let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
if len < 8 || len > MAX_STARTUP_PACKET_LENGTH {
return Err(ProtocolError::Protocol(format!(
"invalid startup packet message length {}",
len
)));
}
if buf.len() < len {
// Don't have full message yet.
let to_read = len - buf.len();
buf.reserve(to_read);
return Ok(None);
}
// got the message, advance buffer
let mut msg = buf.split_to(len).freeze();
msg.advance(4); // consume len
let request_code = msg.get_u32();
let req_hi = request_code >> 16;
let req_lo = request_code & ((1 << 16) - 1);
// StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
let message = match (req_hi, req_lo) {
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
if msg.remaining() < 8 {
return Err(ProtocolError::MessageParse(anyhow!(
"CancelRequest message is malformed, backend PID / secret key missing"
)));
}
FeStartupPacket::CancelRequest(CancelKeyData {
backend_pid: msg.get_i32(),
cancel_key: msg.get_i32(),
})
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
// Requested upgrade to SSL (aka TLS)
FeStartupPacket::SslRequest
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
// Requested upgrade to GSSAPI
FeStartupPacket::GssEncRequest
}
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
return Err(ProtocolError::Protocol(format!(
"Unrecognized request code {unrecognized_code}"
)));
}
// TODO bail if protocol major_version is not 3?
(major_version, minor_version) => {
// StartupMessage
// Parse pairs of null-terminated strings (key, value).
// See `postgres: ProcessStartupPacket, build_startup_packet`.
let mut tokens = str::from_utf8(&msg)
.context("StartupMessage params: invalid utf-8")?
.strip_suffix('\0') // drop packet's own null
.ok_or_else(|| {
ProtocolError::Protocol(
"StartupMessage params: missing null terminator".to_string(),
)
})?
.split_terminator('\0');
let mut params = HashMap::new();
while let Some(name) = tokens.next() {
let value = tokens.next().ok_or_else(|| {
ProtocolError::Protocol(
"StartupMessage params: key without value".to_string(),
)
})?;
params.insert(name.to_owned(), value.to_owned());
}
FeStartupPacket::StartupMessage {
major_version,
minor_version,
params: StartupMessageParams { params },
}
}
};
Ok(Some(FeMessage::StartupPacket(message)))
}
/// Read startup message from the stream.
// XXX: It's tempting yet undesirable to accept `stream` by value,
// since such a change will cause user-supplied &mut references to be consumed
pub fn read_fut<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = Result<Option<FeMessage>, ConnectionError>> + '_>
) -> SyncFuture<Reader, impl Future<Output = Result<Option<FeMessage>, ProtocolError>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
use tokio::io::AsyncReadExt;
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
const CANCEL_REQUEST_CODE: u32 = 5678;
@@ -343,18 +460,18 @@ impl FeStartupPacket {
let len = match retry_read!(stream.read_u32().await) {
Ok(len) => len as usize,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ConnectionError::Socket(e)),
Err(e) => return Err(ProtocolError::Socket(e)),
};
#[allow(clippy::manual_range_contains)]
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
return Err(ConnectionError::Protocol(format!(
return Err(ProtocolError::Protocol(format!(
"invalid message length {len}"
)));
}
let request_code =
retry_read!(stream.read_u32().await).map_err(ConnectionError::Socket)?;
retry_read!(stream.read_u32().await).map_err(ProtocolError::Socket)?;
// the rest of startup packet are params
let params_len = len - 8;
@@ -362,7 +479,7 @@ impl FeStartupPacket {
stream
.read_exact(params_bytes.as_mut())
.await
.map_err(ConnectionError::Socket)?;
.map_err(ProtocolError::Socket)?;
// Parse params depending on request code
let req_hi = request_code >> 16;
@@ -370,14 +487,16 @@ impl FeStartupPacket {
let message = match (req_hi, req_lo) {
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
if params_len != 8 {
return Err(ConnectionError::Protocol(
return Err(ProtocolError::Protocol(
"expected 8 bytes for CancelRequest params".to_string(),
));
}
let mut cursor = Cursor::new(params_bytes);
FeStartupPacket::CancelRequest(CancelKeyData {
backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
backend_pid: 2,
cancel_key: 2,
// backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
// cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
})
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
@@ -389,7 +508,7 @@ impl FeStartupPacket {
FeStartupPacket::GssEncRequest
}
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
return Err(ConnectionError::Protocol(format!(
return Err(ProtocolError::Protocol(format!(
"Unrecognized request code {unrecognized_code}"
)));
}
@@ -401,7 +520,7 @@ impl FeStartupPacket {
.context("StartupMessage params: invalid utf-8")?
.strip_suffix('\0') // drop packet's own null
.ok_or_else(|| {
ConnectionError::Protocol(
ProtocolError::Protocol(
"StartupMessage params: missing null terminator".to_string(),
)
})?
@@ -410,7 +529,7 @@ impl FeStartupPacket {
let mut params = HashMap::new();
while let Some(name) = tokens.next() {
let value = tokens.next().ok_or_else(|| {
ConnectionError::Protocol(
ProtocolError::Protocol(
"StartupMessage params: key without value".to_string(),
)
})?;
@@ -440,6 +559,9 @@ impl FeParseMessage {
let _pstmt_name = read_cstr(&mut buf)?;
let query_string = read_cstr(&mut buf)?;
if buf.remaining() < 2 {
bail!("Parse message is malformed, nparams missing");
}
let nparams = buf.get_i16();
ensure!(nparams == 0, "query params not implemented");
@@ -466,6 +588,9 @@ impl FeDescribeMessage {
impl FeExecuteMessage {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let portal_name = read_cstr(&mut buf)?;
if buf.remaining() < 4 {
bail!("FeExecuteMessage message is malformed, maxrows missing");
}
let maxrows = buf.get_i32();
ensure!(portal_name.is_empty(), "named portals not implemented");
@@ -547,6 +672,11 @@ impl<'a> BeMessage<'a> {
value: b"UTF8",
};
pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
name: b"integer_datetimes",
value: b"on",
};
/// Build a [`BeMessage::ParameterStatus`] holding the server version.
pub fn server_version(version: &'a str) -> Self {
Self::ParameterStatus {
@@ -665,13 +795,12 @@ fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
}
/// Safe write of s into buf as cstring (String in the protocol).
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> {
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
let bytes = s.as_ref();
if bytes.contains(&0) {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"string contains embedded null",
));
return Err(ProtocolError::MessageParse(anyhow!(
"string contains embedded null"
)));
}
buf.put_slice(bytes);
buf.put_u8(0);
@@ -680,7 +809,7 @@ fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> {
fn read_cstr(buf: &mut Bytes) -> anyhow::Result<Bytes> {
let pos = buf.iter().position(|x| *x == 0);
let result = buf.split_to(pos.context("missing terminator")?);
let result = buf.split_to(pos.context("missing cstring terminator")?);
buf.advance(1); // drop the null terminator
Ok(result)
}
@@ -688,12 +817,12 @@ fn read_cstr(buf: &mut Bytes) -> anyhow::Result<Bytes> {
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
impl<'a> BeMessage<'a> {
/// Write message to the given buf.
// Unlike the reading side, we use BytesMut
// here as msg len precedes its body and it is handy to write it down first
// and then fill the length. With Write we would have to either calc it
// manually or have one more buffer.
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> io::Result<()> {
/// Serialize `message` to the given `buf`.
/// Apart from smart memory managemet, BytesMut is good here as msg len
/// precedes its body and it is handy to write it down first and then fill
/// the length. With Write we would have to either calc it manually or have
/// one more buffer.
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
match message {
BeMessage::AuthenticationOk => {
buf.put_u8(b'R');
@@ -719,7 +848,7 @@ impl<'a> BeMessage<'a> {
BeMessage::AuthenticationSasl(msg) => {
buf.put_u8(b'R');
write_body(buf, |buf| {
write_body(buf, |buf| -> Result<(), ProtocolError> {
use BeAuthenticationSaslMessage::*;
match msg {
Methods(methods) => {
@@ -738,7 +867,7 @@ impl<'a> BeMessage<'a> {
buf.put_slice(extra);
}
}
Ok::<_, io::Error>(())
Ok(())
})?;
}
@@ -829,7 +958,7 @@ impl<'a> BeMessage<'a> {
BeMessage::ErrorResponse(error_msg, pg_error_code) => {
// 'E' signalizes ErrorResponse messages
buf.put_u8(b'E');
write_body(buf, |buf| {
write_body(buf, |buf| -> Result<(), ProtocolError> {
buf.put_u8(b'S'); // severity
buf.put_slice(b"ERROR\0");
@@ -842,7 +971,7 @@ impl<'a> BeMessage<'a> {
write_cstr(error_msg, buf)?;
buf.put_u8(0); // terminator
Ok::<_, io::Error>(())
Ok(())
})?;
}
@@ -854,7 +983,7 @@ impl<'a> BeMessage<'a> {
// 'N' signalizes NoticeResponse messages
buf.put_u8(b'N');
write_body(buf, |buf| {
write_body(buf, |buf| -> Result<(), ProtocolError> {
buf.put_u8(b'S'); // severity
buf.put_slice(b"NOTICE\0");
@@ -865,7 +994,7 @@ impl<'a> BeMessage<'a> {
write_cstr(error_msg.as_bytes(), buf)?;
buf.put_u8(0); // terminator
Ok::<_, io::Error>(())
Ok(())
})?;
}
@@ -909,7 +1038,7 @@ impl<'a> BeMessage<'a> {
BeMessage::RowDescription(rows) => {
buf.put_u8(b'T');
write_body(buf, |buf| {
write_body(buf, |buf| -> Result<(), ProtocolError> {
buf.put_i16(rows.len() as i16); // # of fields
for row in rows.iter() {
write_cstr(row.name, buf)?;
@@ -920,7 +1049,7 @@ impl<'a> BeMessage<'a> {
buf.put_i32(-1); /* typmod */
buf.put_i16(0); /* format code */
}
Ok::<_, io::Error>(())
Ok(())
})?;
}

View File

@@ -111,7 +111,7 @@ pub trait RemoteStorage: Send + Sync + 'static {
}
pub struct Download {
pub download_stream: Pin<Box<dyn io::AsyncRead + Unpin + Send>>,
pub download_stream: Pin<Box<dyn io::AsyncRead + Unpin + Send + Sync>>,
/// Extra key-value data, associated with the current remote file.
pub metadata: Option<StorageMetadata>,
}

View File

@@ -10,13 +10,16 @@ async-trait = "0.1"
anyhow = "1.0"
bincode = "1.3"
bytes = "1.0.1"
futures = "0.3"
hyper = { version = "0.14.7", features = ["full"] }
pin-utils = "0.1"
routerify = "3"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1"
thiserror = "1.0"
tokio = { version = "1.17", features = ["macros"]}
tokio-rustls = "0.23"
tokio-util = { version = "0.7.3" }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
nix = "0.25"

View File

@@ -13,7 +13,7 @@ pub mod simple_rcu;
pub mod vec_map;
pub mod bin_ser;
pub mod postgres_backend;
// pub mod postgres_backend;
pub mod postgres_backend_async;
// helper functions for creating and fsyncing
@@ -52,6 +52,8 @@ pub mod signals;
pub mod fs_ext;
pub mod send_rc;
/// use with fail::cfg("$name", "return(2000)")
#[macro_export]
macro_rules! failpoint_sleep_millis_async {

View File

@@ -2,29 +2,24 @@
//! To use, create PostgresBackend and run() it, passing the Handler
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use crate::postgres_backend::AuthType;
use anyhow::Context;
use bytes::{Buf, Bytes, BytesMut};
use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR};
use std::future::Future;
use std::io;
use futures::stream::StreamExt;
use futures::{pin_mut, Sink, SinkExt};
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;
use tracing::{debug, error, info, trace};
use std::{fmt, io};
use std::{future::Future, str::FromStr};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
use tokio_rustls::TlsAcceptor;
use tokio_util::codec::Framed;
use tracing::{debug, error, info, trace};
pub fn is_expected_io_error(e: &io::Error) -> bool {
use io::ErrorKind::*;
matches!(
e.kind(),
ConnectionRefused | ConnectionAborted | ConnectionReset
)
}
use pq_proto::codec::{ConnectionError, PostgresCodec};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR};
/// An error, occurred during query processing:
/// either during the connection ([`ConnectionError`]) or before/after it.
@@ -40,7 +35,7 @@ pub enum QueryError {
impl From<io::Error> for QueryError {
fn from(e: io::Error) -> Self {
Self::Disconnected(ConnectionError::Socket(e))
Self::Disconnected(ConnectionError::Io(e))
}
}
@@ -53,6 +48,14 @@ impl QueryError {
}
}
pub fn is_expected_io_error(e: &io::Error) -> bool {
use io::ErrorKind::*;
matches!(
e.kind(),
ConnectionRefused | ConnectionAborted | ConnectionReset
)
}
#[async_trait::async_trait]
pub trait Handler {
/// Handle single query.
@@ -93,6 +96,7 @@ pub trait Handler {
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
Initialization,
// Encryption handshake is done; waiting for encrypted Startup message.
Encrypted,
Authentication,
Established,
@@ -105,15 +109,14 @@ pub enum ProcessMsgResult {
Break,
}
/// Always-writeable sock_split stream.
/// May not be readable. See [`PostgresBackend::take_stream_in`]
pub enum Stream {
Unencrypted(BufReader<tokio::net::TcpStream>),
Tls(Box<tokio_rustls::server::TlsStream<BufReader<tokio::net::TcpStream>>>),
Broken,
/// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
pub enum MaybeTlsStream {
Unencrypted(tokio::net::TcpStream),
Tls(Box<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>),
Broken, // temporary value for switch to TLS
}
impl AsyncWrite for Stream {
impl AsyncWrite for MaybeTlsStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
@@ -122,14 +125,14 @@ impl AsyncWrite for Stream {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
Self::Broken => unreachable!(),
_ => unreachable!(),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
Self::Broken => unreachable!(),
_ => unreachable!(),
}
}
fn poll_shutdown(
@@ -139,11 +142,11 @@ impl AsyncWrite for Stream {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Broken => unreachable!(),
_ => unreachable!(),
}
}
}
impl AsyncRead for Stream {
impl AsyncRead for MaybeTlsStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
@@ -152,18 +155,49 @@ impl AsyncRead for Stream {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
Self::Broken => unreachable!(),
_ => unreachable!(),
}
}
}
pub struct PostgresBackend {
stream: Stream,
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AuthType {
Trust,
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
NeonJWT,
}
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
// The data between 0 and "current position" as tracked by the bytes::Buf
// implementation of BytesMut, have already been written.
buf_out: BytesMut,
impl FromStr for AuthType {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"Trust" => Ok(Self::Trust),
"NeonJWT" => Ok(Self::NeonJWT),
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
}
}
}
impl fmt::Display for AuthType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
AuthType::Trust => "Trust",
AuthType::NeonJWT => "NeonJWT",
})
}
}
pub struct PostgresBackend {
// Provides serialization/deserialization to the underlying transport backed
// with buffers; implements Sink consuming messages and Stream reading them.
//
// Sink::start_send only queues message to the interal buffer.
// SinkExt::flush flushes buffer to the stream.
//
// StreamExt::read reads next message. In case of EOF without partial
// message it returns None.
stream: Framed<MaybeTlsStream, PostgresCodec>,
pub state: ProtoState,
@@ -196,10 +230,10 @@ impl PostgresBackend {
tls_config: Option<Arc<rustls::ServerConfig>>,
) -> io::Result<Self> {
let peer_addr = socket.peer_addr()?;
let stream = MaybeTlsStream::Unencrypted(socket);
Ok(Self {
stream: Stream::Unencrypted(BufReader::new(socket)),
buf_out: BytesMut::with_capacity(10 * 1024),
stream: Framed::new(stream, PostgresCodec::new()),
state: ProtoState::Initialization,
auth_type,
tls_config,
@@ -212,29 +246,60 @@ impl PostgresBackend {
}
/// Read full message or return None if connection is closed.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, QueryError> {
use ProtoState::*;
match self.state {
Initialization | Encrypted => FeStartupPacket::read_fut(&mut self.stream).await,
Authentication | Established => FeMessage::read_fut(&mut self.stream).await,
Closed => Ok(None),
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
if let ProtoState::Closed = self.state {
Ok(None)
} else {
let msg = self.stream.next().await;
// Option<Result<...>>, so swap.
msg.map_or(Ok(None), |res| res.map(Some))
}
.map_err(QueryError::from)
}
/// Polling version of read_message, saves the caller need to pin.
pub fn poll_read_message(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<Option<FeMessage>, ConnectionError>> {
let read_fut = self.read_message();
pin_mut!(read_fut);
read_fut.poll(cx)
}
/// Flush output buffer into the socket.
pub async fn flush(&mut self) -> io::Result<()> {
while self.buf_out.has_remaining() {
let bytes_written = self.stream.write(self.buf_out.chunk()).await?;
self.buf_out.advance(bytes_written);
}
self.buf_out.clear();
Ok(())
self.stream.flush().await.map_err(|e| match e {
ConnectionError::Io(e) => e,
// the only error we can get from flushing is IO
_ => unreachable!(),
})
}
/// Write message into internal output buffer.
pub fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
BeMessage::write(&mut self.buf_out, message)?;
/// Polling version of `flush()`, saves the caller need to pin.
pub fn poll_flush(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let flush_fut = self.flush();
pin_mut!(flush_fut);
flush_fut.poll(cx)
}
/// Write message into internal output buffer. Technically error type can be
/// only ProtocolError here (if, unlikely, serialization fails), but callers
/// typically wrap it anyway.
pub fn write_message(&mut self, message: &BeMessage<'_>) -> Result<&mut Self, ConnectionError> {
Pin::new(&mut self.stream).start_send(message)?;
Ok(self)
}
/// Write message into internal output buffer and flush it to the stream.
pub async fn write_message_flush(
&mut self,
message: &BeMessage<'_>,
) -> Result<&mut Self, ConnectionError> {
self.write_message(message)?;
self.flush().await?;
Ok(self)
}
@@ -246,28 +311,6 @@ impl PostgresBackend {
CopyDataWriter { pgb: self }
}
/// A polling function that tries to write all the data from 'buf_out' to the
/// underlying stream.
fn poll_write_buf(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
while self.buf_out.has_remaining() {
match Pin::new(&mut self.stream).poll_write(cx, self.buf_out.chunk()) {
Poll::Ready(Ok(bytes_written)) => {
self.buf_out.advance(bytes_written);
}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
}
}
Poll::Ready(Ok(()))
}
fn poll_flush(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.stream).poll_flush(cx)
}
// Wrapper for run_message_loop() that shuts down socket when we are done
pub async fn run<F, S>(
mut self,
@@ -279,7 +322,7 @@ impl PostgresBackend {
S: Future,
{
let ret = self.run_message_loop(handler, shutdown_watcher).await;
let _ = self.stream.shutdown();
let _ = self.stream.get_mut().shutdown();
ret
}
@@ -359,14 +402,22 @@ impl PostgresBackend {
}
async fn start_tls(&mut self) -> anyhow::Result<()> {
if let Stream::Unencrypted(plain_stream) =
std::mem::replace(&mut self.stream, Stream::Broken)
if let MaybeTlsStream::Unencrypted(plain_stream) =
// temporary replace stream with fake broken to prepare TLS one
std::mem::replace(self.stream.get_mut(), MaybeTlsStream::Broken)
{
let acceptor = TlsAcceptor::from(self.tls_config.clone().unwrap());
let tls_stream = acceptor.accept(plain_stream).await?;
self.stream = Stream::Tls(Box::new(tls_stream));
return Ok(());
match acceptor.accept(plain_stream).await {
Ok(tls_stream) => {
// push back ready TLS stream
*self.stream.get_mut() = MaybeTlsStream::Tls(Box::new(tls_stream));
return Ok(());
}
Err(e) => {
self.state = ProtoState::Closed;
return Err(e.into());
}
}
};
anyhow::bail!("TLS already started");
}
@@ -380,13 +431,12 @@ impl PostgresBackend {
let have_tls = self.tls_config.is_some();
match msg {
FeMessage::StartupPacket(m) => {
trace!("got startup message {m:?}");
match m {
FeStartupPacket::SslRequest => {
debug!("SSL requested");
self.write_message(&BeMessage::EncryptionResponse(have_tls))?;
if have_tls {
self.start_tls().await?;
self.state = ProtoState::Encrypted;
@@ -415,6 +465,7 @@ impl PostgresBackend {
AuthType::Trust => {
self.write_message(&BeMessage::AuthenticationOk)?
.write_message(&BeMessage::CLIENT_ENCODING)?
.write_message(&BeMessage::INTEGER_DATETIMES)?
// The async python driver requires a valid server_version
.write_message(&BeMessage::server_version("14.1"))?
.write_message(&BeMessage::ReadyForQuery)?;
@@ -454,6 +505,7 @@ impl PostgresBackend {
}
self.write_message(&BeMessage::AuthenticationOk)?
.write_message(&BeMessage::CLIENT_ENCODING)?
.write_message(&BeMessage::INTEGER_DATETIMES)?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
@@ -573,7 +625,7 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> {
// It's not strictly required to flush between each message, but makes it easier
// to view in wireshark, and usually the messages that the callers write are
// decently-sized anyway.
match this.pgb.poll_write_buf(cx) {
match this.pgb.poll_flush(cx) {
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
@@ -583,7 +635,11 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> {
// XXX: if the input is large, we should split it into multiple messages.
// Not sure what the threshold should be, but the ultimate hard limit is that
// the length cannot exceed u32.
this.pgb.write_message(&BeMessage::CopyData(buf))?;
this.pgb
.write_message(&BeMessage::CopyData(buf))
// write_message only writes to buffer, so can fail iff message is
// invaid, but CopyData can't be invalid.
.expect("failed to serialize CopyData");
Poll::Ready(Ok(buf.len()))
}
@@ -593,23 +649,14 @@ impl<'a> AsyncWrite for CopyDataWriter<'a> {
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
match this.pgb.poll_write_buf(cx) {
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
}
this.pgb.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
match this.pgb.poll_write_buf(cx) {
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
}
this.pgb.poll_flush(cx)
}
}
@@ -623,7 +670,7 @@ pub fn short_error(e: &QueryError) -> String {
pub(super) fn log_query_error(query: &str, e: &QueryError) {
match e {
QueryError::Disconnected(ConnectionError::Socket(io_error)) => {
QueryError::Disconnected(ConnectionError::Io(io_error)) => {
if is_expected_io_error(io_error) {
info!("query handler for '{query}' failed with expected io error: {io_error}");
} else {

116
libs/utils/src/send_rc.rs Normal file
View File

@@ -0,0 +1,116 @@
/// Provides Send wrappers of Rc and RefMut.
use std::{
borrow::Borrow,
cell::{Ref, RefCell, RefMut},
ops::{Deref, DerefMut},
rc::Rc,
};
/// Rc wrapper which is Send.
/// This is useful to allow transferring a group of Rcs pointing to the same
/// object between threads, e.g. in self referential struct.
#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct SendRc<T>
where
T: ?Sized,
{
rc: Rc<T>,
}
// SAFETY: Passing Rc(s)<T: Send> between threads is fine as long as there is no
// concurrent access to the object they point to, so you must move all such Rcs
// together. This appears to be impossible to express in rust type system and
// SendRc doesn't provide any additional protection -- but unlike sendable
// crate, neither it requires any additional actions before/after move. Ensuring
// that sending conforms to the above is the responsibility of the type user.
unsafe impl<T: ?Sized + Send> Send for SendRc<T> {}
impl<T> SendRc<T> {
/// Constructs a new SendRc<T>
pub fn new(value: T) -> SendRc<T> {
SendRc { rc: Rc::new(value) }
}
}
// https://stegosaurusdormant.com/understanding-derive-clone/ explains in detail
// why derive Clone doesn't work here.
impl<T> Clone for SendRc<T> {
fn clone(&self) -> Self {
SendRc {
rc: self.rc.clone(),
}
}
}
// Deref into inner rc.
impl<T> Deref for SendRc<T> {
type Target = Rc<T>;
fn deref(&self) -> &Self::Target {
&self.rc
}
}
/// Extends RefCell with borrow[_mut] variants which return Sendable Ref[Mut]
/// wrappers.
pub trait RefCellSend<T: ?Sized> {
fn borrow_mut_send(&self) -> RefMutSend<'_, T>;
}
impl<T: Sized> RefCellSend<T> for RefCell<T> {
fn borrow_mut_send(&self) -> RefMutSend<'_, T> {
RefMutSend {
ref_mut: self.borrow_mut(),
}
}
}
/// RefMut wrapper which is Send. See impl Send for safety. Allows to move a
/// RefMut along with RefCell it originates from between threads, e.g. have Send
/// Future containing RefMut.
#[derive(Debug)]
pub struct RefMutSend<'b, T>
where
T: 'b + ?Sized,
{
ref_mut: RefMut<'b, T>,
}
// SAFETY: Similar to SendRc, this is safe as long as RefMut stays in the same
// thread with original RefCell, so they should be passed together.
// Actually, since this is a referential type violating this is not
// straightforward; examples of unsafe usage could be
// - Passing a RefMut to different thread without source RefCell. Seems only
// possible with std::thread::scope.
// - Somehow multiple threads get access to single RefCell concurrently,
// violating its !Sync requirement. Improper usage of SendRc can do that.
unsafe impl<'b, T: ?Sized + Send> Send for RefMutSend<'b, T> {}
impl<'b, T> RefMutSend<'b, T> {
/// Constructs a new RefMutSend<T>
pub fn new(ref_mut: RefMut<'b, T>) -> RefMutSend<'b, T> {
RefMutSend { ref_mut }
}
}
// Deref into inner RefMut.
impl<'b, T> Deref for RefMutSend<'b, T>
where
T: 'b + ?Sized,
{
type Target = RefMut<'b, T>;
fn deref<'a>(&'a self) -> &'a RefMut<'b, T> {
&self.ref_mut
}
}
// DerefMut into inner RefMut.
impl<'b, T> DerefMut for RefMutSend<'b, T>
where
T: 'b + ?Sized,
{
fn deref_mut<'a>(&'a mut self) -> &'a mut RefMut<'b, T> {
&mut self.ref_mut
}
}

View File

@@ -24,7 +24,7 @@ use pageserver::{
use utils::{
auth::JwtAuth,
logging,
postgres_backend::AuthType,
postgres_backend_async::AuthType,
project_git_version,
sentry_init::{init_sentry, release_name},
signals::{self, Signal},

View File

@@ -24,7 +24,7 @@ use toml_edit::{Document, Item};
use utils::{
id::{NodeId, TenantId, TimelineId},
logging::LogFormat,
postgres_backend::AuthType,
postgres_backend_async::AuthType,
};
use crate::tenant::config::TenantConf;

View File

@@ -19,7 +19,7 @@ use pageserver_api::models::{
PagestreamFeMessage, PagestreamGetPageRequest, PagestreamGetPageResponse,
PagestreamNblocksRequest, PagestreamNblocksResponse,
};
use pq_proto::ConnectionError;
use pq_proto::codec::ConnectionError;
use pq_proto::FeStartupPacket;
use pq_proto::{BeMessage, FeMessage, RowDescriptor};
use std::io;
@@ -35,7 +35,7 @@ use utils::{
auth::{Claims, JwtAuth, Scope},
id::{TenantId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
postgres_backend_async::AuthType,
postgres_backend_async::{self, PostgresBackend},
simple_rcu::RcuReadGuard,
};
@@ -67,7 +67,7 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
Err(QueryError::Other(anyhow::anyhow!(msg)))
}
msg = pgb.read_message() => { msg }
msg = pgb.read_message() => { msg.map_err(QueryError::from)}
};
match msg {
@@ -78,14 +78,16 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
FeMessage::Sync => continue,
FeMessage::Terminate => {
let msg = "client terminated connection with Terminate message during COPY";
let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?;
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code())))
.expect("failed to serialize ErrorResponse");
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
break;
}
m => {
let msg = format!("unexpected message {m:?}");
pgb.write_message(&BeMessage::ErrorResponse(&msg, None))?;
pgb.write_message(&BeMessage::ErrorResponse(&msg, None))
.expect("failed to serialize ErrorResponse");
Err(io::Error::new(io::ErrorKind::Other, msg))?;
break;
}
@@ -95,16 +97,17 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
}
Ok(None) => {
let msg = "client closed connection during COPY";
let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?;
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code())))
.expect("failed to serialize ErrorResponse");
pgb.flush().await?;
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
}
Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => {
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
Err(io_error)?;
}
Err(other) => {
Err(io::Error::new(io::ErrorKind::Other, other))?;
Err(io::Error::new(io::ErrorKind::Other, other.to_string()))?;
}
};
}
@@ -202,7 +205,7 @@ async fn page_service_conn_main(
// we've been requested to shut down
Ok(())
}
Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => {
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
// `ConnectionReset` error happens when the Postgres client closes the connection.
// As this disconnection happens quite often and is expected,
// we decided to downgrade the logging level to `INFO`.

View File

@@ -2,7 +2,7 @@ use crate::error::UserFacingError;
use anyhow::bail;
use bytes::BytesMut;
use pin_project_lite::pin_project;
use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
use rustls::ServerConfig;
use std::pin::Pin;
use std::sync::Arc;
@@ -53,7 +53,7 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
// TODO: `FeStartupPacket::read_fut` should return `FeStartupPacket`
let msg = FeStartupPacket::read_fut(&mut self.stream)
.await
.map_err(ConnectionError::into_io_error)?
.map_err(ProtocolError::into_io_error)?
.ok_or_else(err_connection)?;
match msg {
@@ -75,7 +75,7 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
async fn read_message(&mut self) -> io::Result<FeMessage> {
FeMessage::read_fut(&mut self.stream)
.await
.map_err(ConnectionError::into_io_error)?
.map_err(ProtocolError::into_io_error)?
.ok_or_else(err_connection)
}
}

View File

@@ -4,7 +4,7 @@
# version, we can consider updating.
# See https://tracker.debian.org/pkg/rustc for more details on Debian rustc package,
# we use "unstable" version number as the highest version used in the project by default.
channel = "1.62.1"
channel = "1.66.1"
profile = "default"
# The default profile includes rustc, rust-std, cargo, rust-docs, rustfmt and clippy.
# https://rust-lang.github.io/rustup/concepts/profiles.html

View File

@@ -14,12 +14,14 @@ clap = { version = "4.0", features = ["derive"] }
const_format = "0.2.21"
crc32c = "0.6.0"
fs2 = "0.4.3"
futures = "0.3"
git-version = "0.3.5"
hex = "0.4.3"
humantime = "2.1.0"
hyper = "0.14"
nix = "0.25"
once_cell = "1.13.0"
pin-project-lite = "0.2"
parking_lot = "0.12.1"
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="43e6db254a97fdecbce33d8bc0890accfd74495e" }

View File

@@ -228,20 +228,20 @@ fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> {
let conf_cloned = conf.clone();
let safekeeper_thread = thread::Builder::new()
.name("safekeeper thread".into())
.name("WAL service thread".into())
.spawn(|| wal_service::thread_main(conf_cloned, pg_listener))
.unwrap();
threads.push(safekeeper_thread);
let conf_ = conf.clone();
threads.push(
thread::Builder::new()
.name("broker thread".into())
.spawn(|| {
broker::thread_main(conf_);
})?,
);
// threads.push(
// thread::Builder::new()
// .name("broker thread".into())
// .spawn(|| {
// broker::thread_main(conf_);
// })?,
// );
let conf_ = conf.clone();
threads.push(

View File

@@ -1,27 +1,23 @@
//! Part of Safekeeper pretending to be Postgres, i.e. handling Postgres
//! protocol commands.
use anyhow::{bail, Context};
use std::str;
use tracing::{info, info_span, Instrument};
use crate::auth::check_permission;
use crate::json_ctrl::{handle_json_ctrl, AppendLogicalMessage};
use crate::receive_wal::ReceiveWalConn;
use crate::send_wal::ReplicationConn;
use crate::{GlobalTimelines, SafeKeeperConf};
use anyhow::Context;
use postgres_ffi::PG_TLI;
use regex::Regex;
use pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID, TEXT_OID};
use std::str;
use tracing::info;
use regex::Regex;
use utils::auth::{Claims, Scope};
use utils::postgres_backend_async::QueryError;
use utils::{
id::{TenantId, TenantTimelineId, TimelineId},
lsn::Lsn,
postgres_backend::{self, PostgresBackend},
postgres_backend_async::{self, PostgresBackend},
};
/// Safekeeper handler of postgres commands
@@ -41,9 +37,11 @@ enum SafekeeperPostgresCommand {
StartReplication { start_lsn: Lsn },
IdentifySystem,
JSONCtrl { cmd: AppendLogicalMessage },
Show { guc: String },
}
fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
let cmd_lowercase = cmd.to_ascii_lowercase();
if cmd.starts_with("START_WAL_PUSH") {
Ok(SafekeeperPostgresCommand::StartWalPush)
} else if cmd.starts_with("START_REPLICATION") {
@@ -53,7 +51,7 @@ fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
let start_lsn = caps
.next()
.map(|cap| cap[1].parse::<Lsn>())
.context("failed to parse start LSN from START_REPLICATION command")??;
.context("parse start LSN from START_REPLICATION command")??;
Ok(SafekeeperPostgresCommand::StartReplication { start_lsn })
} else if cmd.starts_with("IDENTIFY_SYSTEM") {
Ok(SafekeeperPostgresCommand::IdentifySystem)
@@ -62,12 +60,21 @@ fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
Ok(SafekeeperPostgresCommand::JSONCtrl {
cmd: serde_json::from_str(cmd)?,
})
} else if cmd_lowercase.starts_with("show") {
let re = Regex::new(r"show ((?:[[:alpha:]]|_)+)").unwrap();
let mut caps = re.captures_iter(&cmd_lowercase);
let guc = caps
.next()
.map(|cap| cap[1].parse::<String>())
.context("parse guc in SHOW command")??;
Ok(SafekeeperPostgresCommand::Show { guc })
} else {
anyhow::bail!("unsupported command {cmd}");
}
}
impl postgres_backend::Handler for SafekeeperPostgresHandler {
#[async_trait::async_trait]
impl postgres_backend_async::Handler for SafekeeperPostgresHandler {
// tenant_id and timeline_id are passed in connection string params
fn startup(
&mut self,
@@ -137,7 +144,7 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
Ok(())
}
fn process_query(
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
query_string: &str,
@@ -147,9 +154,11 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
.starts_with("set datestyle to ")
{
// important for debug because psycopg2 executes "SET datestyle TO 'ISO'" on connect
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
pgb.write_message_flush(&BeMessage::CommandComplete(b"SELECT 1"))
.await?;
return Ok(());
}
let cmd = parse_cmd(query_string)?;
info!(
@@ -161,14 +170,22 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
let timeline_id = self.timeline_id.context("timelineid is required")?;
self.check_permission(Some(tenant_id))?;
self.ttid = TenantTimelineId::new(tenant_id, timeline_id);
let span_ttid = self.ttid; // satisfy borrow checker
let res = match cmd {
SafekeeperPostgresCommand::StartWalPush => ReceiveWalConn::new(pgb).run(self),
// SafekeeperPostgresCommand::StartWalPush => ReceiveWalConn::new(pgb).run(self),
SafekeeperPostgresCommand::StartWalPush => Ok(()),
SafekeeperPostgresCommand::StartReplication { start_lsn } => {
ReplicationConn::new(pgb).run(self, pgb, start_lsn)
self.handle_start_replication(pgb, start_lsn)
.instrument(info_span!("WAL sender", ttid = %span_ttid))
.await
}
SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb),
SafekeeperPostgresCommand::JSONCtrl { ref cmd } => handle_json_ctrl(self, pgb, cmd),
SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb).await,
SafekeeperPostgresCommand::Show { guc } => self.handle_show(guc, pgb).await,
SafekeeperPostgresCommand::JSONCtrl { ref cmd } => {
handle_json_ctrl(self, pgb, cmd).await
}
_ => unreachable!(),
};
match res {
@@ -217,7 +234,10 @@ impl SafekeeperPostgresHandler {
///
/// Handle IDENTIFY_SYSTEM replication command
///
fn handle_identify_system(&mut self, pgb: &mut PostgresBackend) -> Result<(), QueryError> {
async fn handle_identify_system(
&mut self,
pgb: &mut PostgresBackend,
) -> Result<(), QueryError> {
let tli = GlobalTimelines::get(self.ttid)?;
let lsn = if self.is_walproposer_recovery() {
@@ -235,7 +255,7 @@ impl SafekeeperPostgresHandler {
let tli_bytes = tli.as_bytes();
let sysid_bytes = sysid.as_bytes();
pgb.write_message_noflush(&BeMessage::RowDescription(&[
pgb.write_message(&BeMessage::RowDescription(&[
RowDescriptor {
name: b"systemid",
typoid: TEXT_OID,
@@ -261,13 +281,48 @@ impl SafekeeperPostgresHandler {
..Default::default()
},
]))?
.write_message_noflush(&BeMessage::DataRow(&[
.write_message(&BeMessage::DataRow(&[
Some(sysid_bytes),
Some(tli_bytes),
Some(lsn_bytes),
None,
]))?
.write_message(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?;
.write_message_flush(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))
.await?;
Ok(())
}
async fn handle_show(
&mut self,
guc: String,
pgb: &mut PostgresBackend,
) -> Result<(), QueryError> {
match guc.as_str() {
// pg_receivewal wants it
"data_directory_mode" => {
pgb.write_message(&BeMessage::RowDescription(&[RowDescriptor::int8_col(
b"data_directory_mode",
)]))?
// xxx we could return real one, not just 0700
.write_message(&BeMessage::DataRow(&[Some(0700.to_string().as_bytes())]))?
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
}
// pg_receivewal wants it
"wal_segment_size" => {
let tli = GlobalTimelines::get(self.ttid)?;
let wal_seg_size = tli.get_state().1.server.wal_seg_size;
let wal_seg_size_mb = (wal_seg_size / 1024 / 1024).to_string() + "MB";
pgb.write_message(&BeMessage::RowDescription(&[RowDescriptor::text_col(
b"wal_segment_size",
)]))?
.write_message(&BeMessage::DataRow(&[Some(wal_seg_size_mb.as_bytes())]))?
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
}
_ => {
return Err(anyhow::anyhow!("SHOW of unknown setting").into());
}
}
Ok(())
}

View File

@@ -8,11 +8,14 @@ use serde::Serialize;
use serde::Serializer;
use std::collections::{HashMap, HashSet};
use std::fmt::Display;
use std::str::FromStr;
use std::sync::Arc;
use storage_broker::proto::SafekeeperTimelineInfo;
use storage_broker::proto::TenantTimelineId as ProtoTenantTimelineId;
use tokio::task::JoinError;
use crate::json_ctrl::append_logical_message;
use crate::json_ctrl::AppendLogicalMessage;
use crate::safekeeper::ServerInfo;
use crate::safekeeper::Term;
@@ -191,6 +194,50 @@ async fn timeline_create_handler(mut request: Request<Body>) -> Result<Response<
json_response(StatusCode::OK, ())
}
// Create fake timeline + insert some valid WAL. Useful to test WAL streaming
// from safekeeper in isolation, e.g.
// pg_receivewal -v -d "host=localhost port=5454 options='-c tenant_id=deadbeefdeadbeefdeadbeefdeadbeef timeline_id=deadbeefdeadbeefdeadbeefdeadbeef'" -D ~/tmp/tmp/tmp
// (hacking pg_receivewal startpos is currently needed though to make pg_receivewal work)
async fn create_fake_timeline_handler(_request: Request<Body>) -> Result<Response<Body>, ApiError> {
let ttid = TenantTimelineId {
tenant_id: TenantId::from_str("deadbeefdeadbeefdeadbeefdeadbeef")
.expect("timeline_id parsing failed"),
timeline_id: TimelineId::from_str("deadbeefdeadbeefdeadbeefdeadbeef")
.expect("tenant_id parsing failed"),
};
let pg_version = 150000;
let server_info = ServerInfo {
pg_version,
system_id: 0,
wal_seg_size: WAL_SEGMENT_SIZE as u32,
};
let init_lsn = Lsn(0x1493AC8);
tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
let tli = GlobalTimelines::create(ttid, server_info, init_lsn, init_lsn)?;
let mut begin_lsn = init_lsn;
for _ in 0..16 {
let append = AppendLogicalMessage {
lm_prefix: "db".to_owned(),
lm_message: "hahabubu".to_owned(),
set_commit_lsn: true,
send_proposer_elected: false, // actually ignored here
term: 0,
epoch_start_lsn: init_lsn,
begin_lsn,
truncate_lsn: init_lsn,
pg_version,
};
let inserted = append_logical_message(&tli, &append)?;
begin_lsn = inserted.end_lsn;
}
Ok(())
})
.await
.map_err(|e| ApiError::InternalServerError(e.into()))?
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
/// Deactivates the timeline and removes its data directory.
async fn timeline_delete_force_handler(
mut request: Request<Body>,
@@ -302,6 +349,7 @@ pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder<hyper::Body, ApiError>
.get("/v1/status", status_handler)
// Will be used in the future instead of implicit timeline creation
.post("/v1/tenant/timeline", timeline_create_handler)
.post("/v1/fake_timeline", create_fake_timeline_handler)
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id",
timeline_status_handler,

View File

@@ -26,26 +26,26 @@ use crate::GlobalTimelines;
use postgres_ffi::encode_logical_message;
use postgres_ffi::WAL_SEGMENT_SIZE;
use pq_proto::{BeMessage, RowDescriptor, TEXT_OID};
use utils::{lsn::Lsn, postgres_backend::PostgresBackend};
use utils::{lsn::Lsn, postgres_backend_async::PostgresBackend};
#[derive(Serialize, Deserialize, Debug)]
pub struct AppendLogicalMessage {
// prefix and message to build LogicalMessage
lm_prefix: String,
lm_message: String,
pub lm_prefix: String,
pub lm_message: String,
// if true, commit_lsn will match flush_lsn after append
set_commit_lsn: bool,
pub set_commit_lsn: bool,
// if true, ProposerElected will be sent before append
send_proposer_elected: bool,
pub send_proposer_elected: bool,
// fields from AppendRequestHeader
term: Term,
epoch_start_lsn: Lsn,
begin_lsn: Lsn,
truncate_lsn: Lsn,
pg_version: u32,
pub term: Term,
pub epoch_start_lsn: Lsn,
pub begin_lsn: Lsn,
pub truncate_lsn: Lsn,
pub pg_version: u32,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -59,7 +59,7 @@ struct AppendResult {
/// Handles command to craft logical message WAL record with given
/// content, and then append it with specified term and lsn. This
/// function is used to test safekeepers in different scenarios.
pub fn handle_json_ctrl(
pub async fn handle_json_ctrl(
spg: &SafekeeperPostgresHandler,
pgb: &mut PostgresBackend,
append_request: &AppendLogicalMessage,
@@ -82,14 +82,15 @@ pub fn handle_json_ctrl(
let response_data = serde_json::to_vec(&response)
.with_context(|| format!("Response {response:?} is not a json array"))?;
pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor {
pgb.write_message(&BeMessage::RowDescription(&[RowDescriptor {
name: b"json",
typoid: TEXT_OID,
typlen: -1,
..Default::default()
}]))?
.write_message_noflush(&BeMessage::DataRow(&[Some(&response_data)]))?
.write_message(&BeMessage::CommandComplete(b"JSON_CTRL"))?;
.write_message(&BeMessage::DataRow(&[Some(&response_data)]))?
.write_message_flush(&BeMessage::CommandComplete(b"JSON_CTRL"))
.await?;
Ok(())
}
@@ -128,15 +129,15 @@ fn send_proposer_elected(tli: &Arc<Timeline>, term: Term, lsn: Lsn) -> anyhow::R
}
#[derive(Debug, Serialize, Deserialize)]
struct InsertedWAL {
pub struct InsertedWAL {
begin_lsn: Lsn,
end_lsn: Lsn,
pub end_lsn: Lsn,
append_response: AppendResponse,
}
/// Extend local WAL with new LogicalMessage record. To do that,
/// create AppendRequest with new WAL and pass it to safekeeper.
fn append_logical_message(
pub fn append_logical_message(
tli: &Arc<Timeline>,
msg: &AppendLogicalMessage,
) -> anyhow::Result<InsertedWAL> {

View File

@@ -26,7 +26,7 @@ use crate::safekeeper::ProposerAcceptorMessage;
use crate::handler::SafekeeperPostgresHandler;
use pq_proto::{BeMessage, FeMessage};
use utils::{postgres_backend::PostgresBackend, sock_split::ReadStream};
use utils::{postgres_backend_async::PostgresBackend, sock_split::ReadStream};
pub struct ReceiveWalConn<'pg> {
/// Postgres connection
@@ -59,82 +59,83 @@ impl<'pg> ReceiveWalConn<'pg> {
// Notify the libpq client that it's allowed to send `CopyData` messages
self.pg_backend
.write_message(&BeMessage::CopyBothResponse)?;
let r = self
.pg_backend
.take_stream_in()
.ok_or_else(|| anyhow!("failed to take read stream from pgbackend"))?;
let mut poll_reader = ProposerPollStream::new(r)?;
Ok(())
// let r = self
// .pg_backend
// .take_stream_in()
// .ok_or_else(|| anyhow!("failed to take read stream from pgbackend"))?;
// let mut poll_reader = ProposerPollStream::new(r)?;
// Receive information about server
let next_msg = poll_reader.recv_msg()?;
let tli = match next_msg {
ProposerAcceptorMessage::Greeting(ref greeting) => {
info!(
"start handshake with walproposer {} sysid {} timeline {}",
self.peer_addr, greeting.system_id, greeting.tli,
);
let server_info = ServerInfo {
pg_version: greeting.pg_version,
system_id: greeting.system_id,
wal_seg_size: greeting.wal_seg_size,
};
GlobalTimelines::create(spg.ttid, server_info, Lsn::INVALID, Lsn::INVALID)?
}
_ => {
return Err(QueryError::Other(anyhow::anyhow!(
"unexpected message {next_msg:?} instead of greeting"
)))
}
};
// let next_msg = poll_reader.recv_msg()?;
// let tli = match next_msg {
// ProposerAcceptorMessage::Greeting(ref greeting) => {
// info!(
// "start handshake with walproposer {} sysid {} timeline {}",
// self.peer_addr, greeting.system_id, greeting.tli,
// );
// let server_info = ServerInfo {
// pg_version: greeting.pg_version,
// system_id: greeting.system_id,
// wal_seg_size: greeting.wal_seg_size,
// };
// GlobalTimelines::create(spg.ttid, server_info, Lsn::INVALID, Lsn::INVALID)?
// }
// _ => {
// return Err(QueryError::Other(anyhow::anyhow!(
// "unexpected message {next_msg:?} instead of greeting"
// )))
// }
// };
let mut next_msg = Some(next_msg);
// let mut next_msg = None;
let mut first_time_through = true;
let mut _guard: Option<ComputeConnectionGuard> = None;
loop {
if matches!(next_msg, Some(ProposerAcceptorMessage::AppendRequest(_))) {
// poll AppendRequest's without blocking and write WAL to disk without flushing,
// while it's readily available
while let Some(ProposerAcceptorMessage::AppendRequest(append_request)) = next_msg {
let msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request);
// let mut first_time_through = true;
// let mut _guard: Option<ComputeConnectionGuard> = None;
// loop {
// if matches!(next_msg, Some(ProposerAcceptorMessage::AppendRequest(_))) {
// // poll AppendRequest's without blocking and write WAL to disk without flushing,
// // while it's readily available
// while let Some(ProposerAcceptorMessage::AppendRequest(append_request)) = next_msg {
// let msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request);
let reply = tli.process_msg(&msg)?;
if let Some(reply) = reply {
self.write_msg(&reply)?;
}
// let reply = tli.process_msg(&msg)?;
// if let Some(reply) = reply {
// self.write_msg(&reply)?;
// }
next_msg = poll_reader.poll_msg();
}
// // next_msg = poll_reader.poll_msg();
// next_msg = poll_reader.poll_msg();
// }
// flush all written WAL to the disk
let reply = tli.process_msg(&ProposerAcceptorMessage::FlushWAL)?;
if let Some(reply) = reply {
self.write_msg(&reply)?;
}
} else if let Some(msg) = next_msg.take() {
// process other message
let reply = tli.process_msg(&msg)?;
if let Some(reply) = reply {
self.write_msg(&reply)?;
}
}
if first_time_through {
// Register the connection and defer unregister. Do that only
// after processing first message, as it sets wal_seg_size,
// wanted by many.
tli.on_compute_connect()?;
_guard = Some(ComputeConnectionGuard {
timeline: Arc::clone(&tli),
});
first_time_through = false;
}
// // flush all written WAL to the disk
// let reply = tli.process_msg(&ProposerAcceptorMessage::FlushWAL)?;
// if let Some(reply) = reply {
// self.write_msg(&reply)?;
// }
// } else if let Some(msg) = next_msg.take() {
// // process other message
// let reply = tli.process_msg(&msg)?;
// if let Some(reply) = reply {
// self.write_msg(&reply)?;
// }
// }
// if first_time_through {
// // Register the connection and defer unregister. Do that only
// // after processing first message, as it sets wal_seg_size,
// // wanted by many.
// tli.on_compute_connect()?;
// _guard = Some(ComputeConnectionGuard {
// timeline: Arc::clone(&tli),
// });
// first_time_through = false;
// }
// blocking wait for the next message
if next_msg.is_none() {
next_msg = Some(poll_reader.recv_msg()?);
}
}
// // blocking wait for the next message
// if next_msg.is_none() {
// next_msg = Some(poll_reader.recv_msg()?);
// }
// }
}
}
@@ -144,37 +145,37 @@ struct ProposerPollStream {
}
impl ProposerPollStream {
fn new(mut r: ReadStream) -> anyhow::Result<Self> {
let (msg_tx, msg_rx) = channel();
// fn new(mut r: ReadStream) -> anyhow::Result<Self> {
// let (msg_tx, msg_rx) = channel();
let read_thread = thread::Builder::new()
.name("Read WAL thread".into())
.spawn(move || -> Result<(), QueryError> {
loop {
let copy_data = match FeMessage::read(&mut r)? {
Some(FeMessage::CopyData(bytes)) => Ok(bytes),
Some(msg) => Err(QueryError::Other(anyhow::anyhow!(
"expected `CopyData` message, found {msg:?}"
))),
None => Err(QueryError::from(std::io::Error::new(
std::io::ErrorKind::ConnectionAborted,
"walproposer closed the connection",
))),
}?;
// let read_thread = thread::Builder::new()
// .name("Read WAL thread".into())
// .spawn(move || -> Result<(), QueryError> {
// loop {
// let copy_data = match FeMessage::read(&mut r)? {
// Some(FeMessage::CopyData(bytes)) => Ok(bytes),
// Some(msg) => Err(QueryError::Other(anyhow::anyhow!(
// "expected `CopyData` message, found {msg:?}"
// ))),
// None => Err(QueryError::from(std::io::Error::new(
// std::io::ErrorKind::ConnectionAborted,
// "walproposer closed the connection",
// ))),
// }?;
let msg = ProposerAcceptorMessage::parse(copy_data)?;
msg_tx
.send(msg)
.context("Failed to send the proposer message")?;
}
// msg_tx will be dropped here, this will also close msg_rx
})?;
// let msg = ProposerAcceptorMessage::parse(copy_data)?;
// msg_tx
// .send(msg)
// .context("Failed to send the proposer message")?;
// }
// // msg_tx will be dropped here, this will also close msg_rx
// })?;
Ok(Self {
msg_rx,
read_thread: Some(read_thread),
})
}
// Ok(Self {
// msg_rx,
// read_thread: Some(read_thread),
// })
// }
fn recv_msg(&mut self) -> Result<ProposerAcceptorMessage, QueryError> {
self.msg_rx.recv().map_err(|_| {

View File

@@ -1,28 +1,36 @@
//! This module implements the streaming side of replication protocol, starting
//! with the "START_REPLICATION" message.
use anyhow::Context as AnyhowContext;
use bytes::Bytes;
use futures::future::BoxFuture;
use pin_project_lite::pin_project;
use postgres_ffi::get_current_timestamp;
use postgres_ffi::{TimestampTz, MAX_SEND_SIZE};
use serde::{Deserialize, Serialize};
use std::cell::RefCell;
use std::cmp::min;
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use std::sync::Arc;
use std::task::{ready, Context, Poll};
use std::time::Duration;
use std::{io, str, thread};
use tokio::sync::watch::Receiver;
use tokio::time::timeout;
use tracing::*;
use utils::postgres_backend_async::QueryError;
use utils::send_rc::RefCellSend;
use utils::send_rc::SendRc;
use pq_proto::{BeMessage, FeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody};
use utils::{bin_ser::BeSer, lsn::Lsn, postgres_backend_async::PostgresBackend};
use crate::handler::SafekeeperPostgresHandler;
use crate::timeline::{ReplicaState, Timeline};
use crate::wal_storage::WalReader;
use crate::GlobalTimelines;
use anyhow::Context;
use bytes::Bytes;
use postgres_ffi::get_current_timestamp;
use postgres_ffi::{TimestampTz, MAX_SEND_SIZE};
use serde::{Deserialize, Serialize};
use std::cmp::min;
use std::net::Shutdown;
use std::sync::Arc;
use std::time::Duration;
use std::{io, str, thread};
use utils::postgres_backend_async::QueryError;
use pq_proto::{BeMessage, FeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody};
use tokio::sync::watch::Receiver;
use tokio::time::timeout;
use tracing::*;
use utils::{bin_ser::BeSer, lsn::Lsn, postgres_backend::PostgresBackend, sock_split::ReadStream};
// See: https://www.postgresql.org/docs/13/protocol-replication.html
const HOT_STANDBY_FEEDBACK_TAG_BYTE: u8 = b'h';
@@ -60,13 +68,6 @@ pub struct StandbyReply {
pub reply_requested: bool,
}
/// A network connection that's speaking the replication protocol.
pub struct ReplicationConn {
/// This is an `Option` because we will spawn a background thread that will
/// `take` it from us.
stream_in: Option<ReadStream>,
}
/// Scope guard to unregister replication connection from timeline
struct ReplicationConnGuard {
replica: usize, // replica internal ID assigned by timeline
@@ -79,230 +80,418 @@ impl Drop for ReplicationConnGuard {
}
}
impl ReplicationConn {
/// Create a new `ReplicationConn`
pub fn new(pgb: &mut PostgresBackend) -> Self {
Self {
stream_in: pgb.take_stream_in(),
}
}
/// Handle incoming messages from the network.
/// This is spawned into the background by `handle_start_replication`.
fn background_thread(
mut stream_in: ReadStream,
replica_guard: Arc<ReplicationConnGuard>,
) -> anyhow::Result<()> {
let replica_id = replica_guard.replica;
let timeline = &replica_guard.timeline;
let mut state = ReplicaState::new();
// Wait for replica's feedback.
while let Some(msg) = FeMessage::read(&mut stream_in)? {
match &msg {
FeMessage::CopyData(m) => {
// There's three possible data messages that the client is supposed to send here:
// `HotStandbyFeedback` and `StandbyStatusUpdate` and `NeonStandbyFeedback`.
match m.first().cloned() {
Some(HOT_STANDBY_FEEDBACK_TAG_BYTE) => {
// Note: deserializing is on m[1..] because we skip the tag byte.
state.hs_feedback = HotStandbyFeedback::des(&m[1..])
.context("failed to deserialize HotStandbyFeedback")?;
timeline.update_replica_state(replica_id, state);
}
Some(STANDBY_STATUS_UPDATE_TAG_BYTE) => {
let _reply = StandbyReply::des(&m[1..])
.context("failed to deserialize StandbyReply")?;
// This must be a regular postgres replica,
// because pageserver doesn't send this type of messages to safekeeper.
// Currently this is not implemented, so this message is ignored.
warn!("unexpected StandbyReply. Read-only postgres replicas are not supported in safekeepers yet.");
// timeline.update_replica_state(replica_id, Some(state));
}
Some(NEON_STATUS_UPDATE_TAG_BYTE) => {
// Note: deserializing is on m[9..] because we skip the tag byte and len bytes.
let buf = Bytes::copy_from_slice(&m[9..]);
let reply = ReplicationFeedback::parse(buf);
trace!("ReplicationFeedback is {:?}", reply);
// Only pageserver sends ReplicationFeedback, so set the flag.
// This replica is the source of information to resend to compute.
state.pageserver_feedback = Some(reply);
timeline.update_replica_state(replica_id, state);
}
_ => warn!("unexpected message {:?}", msg),
}
}
FeMessage::Sync => {}
FeMessage::CopyFail => {
// Shutdown the connection, because rust-postgres client cannot be dropped
// when connection is alive.
let _ = stream_in.shutdown(Shutdown::Both);
anyhow::bail!("Copy failed");
}
_ => {
// We only handle `CopyData`, 'Sync', 'CopyFail' messages. Anything else is ignored.
info!("unexpected message {:?}", msg);
}
}
}
Ok(())
}
///
/// Handle START_REPLICATION replication command
///
pub fn run(
impl SafekeeperPostgresHandler {
pub async fn handle_start_replication(
&mut self,
spg: &mut SafekeeperPostgresHandler,
pgb: &mut PostgresBackend,
mut start_pos: Lsn,
start_pos: Lsn,
) -> Result<(), QueryError> {
let _enter = info_span!("WAL sender", ttid = %spg.ttid).entered();
let tli = GlobalTimelines::get(spg.ttid)?;
// spawn the background thread which receives HotStandbyFeedback messages.
let bg_timeline = Arc::clone(&tli);
let bg_stream_in = self.stream_in.take().unwrap();
let bg_timeline_id = spg.timeline_id.unwrap();
let appname = self.appname.clone();
let tli = GlobalTimelines::get(self.ttid)?;
let state = ReplicaState::new();
// This replica_id is used below to check if it's time to stop replication.
let replica_id = bg_timeline.add_replica(state);
let replica_id = tli.add_replica(state);
// Use a guard object to remove our entry from the timeline, when the background
// thread and us have both finished using it.
let replica_guard = Arc::new(ReplicationConnGuard {
let _guard = Arc::new(ReplicationConnGuard {
replica: replica_id,
timeline: bg_timeline,
timeline: tli.clone(),
});
let bg_replica_guard = Arc::clone(&replica_guard);
// TODO: here we got two threads, one for writing WAL and one for receiving
// feedback. If one of them fails, we should shutdown the other one too.
let _ = thread::Builder::new()
.name("HotStandbyFeedback thread".into())
.spawn(move || {
let _enter =
info_span!("HotStandbyFeedback thread", timeline = %bg_timeline_id).entered();
if let Err(err) = Self::background_thread(bg_stream_in, bg_replica_guard) {
error!("Replication background thread failed: {}", err);
// Walproposer gets special handling: safekeeper must give proposer all
// local WAL till the end, whether committed or not (walproposer will
// hang otherwise). That's because walproposer runs the consensus and
// synchronizes safekeepers on the most advanced one.
//
// There is a small risk of this WAL getting concurrently garbaged if
// another compute rises which collects majority and starts fixing log
// on this safekeeper itself. That's ok as (old) proposer will never be
// able to commit such WAL.
let stop_pos: Option<Lsn> = if self.is_walproposer_recovery() {
let wal_end = tli.get_flush_lsn();
Some(wal_end)
} else {
None
};
let end_pos = stop_pos.unwrap_or(Lsn::INVALID);
info!(
"starting streaming from {:?} till {:?}",
start_pos, stop_pos
);
// switch to copy
pgb.write_message(&BeMessage::CopyBothResponse)?;
let (_, persisted_state) = tli.get_state();
let wal_reader = WalReader::new(
self.conf.workdir.clone(),
self.conf.timeline_dir(&tli.ttid),
&persisted_state,
start_pos,
self.conf.wal_backup_enabled,
)?;
let write_ctx = SendRc::new(WriteContext {
wal_reader: RefCell::new(wal_reader),
send_buf: RefCell::new([0; MAX_SEND_SIZE]),
});
let mut c = ReplicationContext {
tli,
replica_id,
appname,
pgb,
start_pos,
end_pos,
stop_pos,
write_ctx,
feedback: ReplicaState::new(),
};
let _phantom_wf = c.wait_wal_fut();
let real_end_pos = c.end_pos;
c.end_pos = c.start_pos + 1; // to well form read_wal future
let _phantom_rf = c.read_wal_fut();
c.end_pos = real_end_pos;
ReplicationHandler {
c,
write_state: WriteState::FlushWal,
_phantom_wf,
_phantom_rf,
}
.await
}
}
pin_project! {
/// START_REPLICATION stream driver: sends WAL and receives feedback.
struct ReplicationHandler<'a, WF, RF>
where
WF: Future<Output = anyhow::Result<Option<Lsn>>>,
RF: Future<Output = anyhow::Result<usize>>,
{
c: ReplicationContext<'a>,
#[pin]
write_state: WriteState<WF, RF>,
// To deduce anonymous types.
_phantom_wf: WF,
_phantom_rf: RF,
}
}
/// Data ReplicationHandler maintains. Separated so we could generate WriteState
/// futures during init, deducing their type.
struct ReplicationContext<'a> {
tli: Arc<Timeline>,
appname: Option<String>,
replica_id: usize,
pgb: &'a mut PostgresBackend,
// Position since which we are sending next chunk.
start_pos: Lsn,
// WAL up to this position is known to be locally available.
end_pos: Lsn,
// If present, terminate after reaching this position; used by walproposer
// in recovery.
stop_pos: Option<Lsn>,
// This data is needed to create Future sending WAL, so we need to both have
// it here (to create new future) and borrow it to the future itself.
// Essentially this is a self referential struct. To satisfy borrow checker,
// use Rc<RefCell>. To make ReplicationHandler itself Send'able future, wrap
// it into SendRc; this is safe as ReplicationHandler is passed between
// threads only as a whole (during rescheduling).
//
// Right now we're in CurrentThread runtime, so Send is somewhat redundant;
// however, otherwise we'd need to inconveniently have separate !Send
// version of pg backend Handler trait (and work with LocalSet).
write_ctx: SendRc<WriteContext>,
feedback: ReplicaState,
}
// State which ReplicationHandler needs to create futures sending data.
struct WriteContext {
wal_reader: RefCell<WalReader>,
// buffer for readling WAL into to send it
send_buf: RefCell<[u8; MAX_SEND_SIZE]>,
}
// Yield points of WAL sending machinery.
pin_project! {
#[project = WriteStateProj]
enum WriteState<WF, RF>
where
WF: Future<Output = anyhow::Result<Option<Lsn>>>,
RF: Future<Output = anyhow::Result<usize>>,
{
WaitWal{ #[pin] fut: WF},
ReadWal{ #[pin] fut: RF},
FlushWal,
}
}
impl<WF, RF> Future for ReplicationHandler<'_, WF, RF>
where
WF: Future<Output = anyhow::Result<Option<Lsn>>>,
RF: Future<Output = anyhow::Result<usize>>,
{
type Output = Result<(), QueryError>;
// We need to read feedback from the socket and write data there at the same
// time. To avoid having to split socket, which creates messy split-join
// APIs, is problematic with TLS [1] and needs to manage two tasks, just run
// single task and use poll interfaces, basically manual state machine,
// which is simple here.
//
// [1] https://github.com/tokio-rs/tls/issues/40
//
// Completes only when the stream is over, technically on error currently.
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if let Poll::Ready(r) = self.as_mut().poll_read(cx) {
return Poll::Ready(r);
}
self.as_mut().poll_write(cx)
}
}
impl<WF, RF> ReplicationHandler<'_, WF, RF>
where
WF: Future<Output = anyhow::Result<Option<Lsn>>>,
RF: Future<Output = anyhow::Result<usize>>,
{
// Poll reading, i.e. getting feedback and processing it. Completes only on error/end of stream.
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), QueryError>> {
loop {
match ready!(self.as_mut().project().c.pgb.poll_read_message(cx)) {
Ok(Some(msg)) => self.as_mut().handle_feedback(&msg)?,
Ok(None) => {
return Poll::Ready(Err(QueryError::Other(anyhow::anyhow!(
"EOF on replication stream"
))))
}
})?;
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
runtime.block_on(async move {
let (inmem_state, persisted_state) = tli.get_state();
// add persisted_state.timeline_start_lsn == Lsn(0) check
// Walproposer gets special handling: safekeeper must give proposer all
// local WAL till the end, whether committed or not (walproposer will
// hang otherwise). That's because walproposer runs the consensus and
// synchronizes safekeepers on the most advanced one.
//
// There is a small risk of this WAL getting concurrently garbaged if
// another compute rises which collects majority and starts fixing log
// on this safekeeper itself. That's ok as (old) proposer will never be
// able to commit such WAL.
let stop_pos: Option<Lsn> = if spg.is_walproposer_recovery() {
let wal_end = tli.get_flush_lsn();
Some(wal_end)
} else {
None
Err(err) => return Poll::Ready(Err(err.into())),
};
}
}
info!("Start replication from {:?} till {:?}", start_pos, stop_pos);
// switch to copy
pgb.write_message(&BeMessage::CopyBothResponse)?;
let mut end_pos = stop_pos.unwrap_or(inmem_state.commit_lsn);
let mut wal_reader = WalReader::new(
spg.conf.workdir.clone(),
spg.conf.timeline_dir(&tli.ttid),
&persisted_state,
start_pos,
spg.conf.wal_backup_enabled,
)?;
// buffer for wal sending, limited by MAX_SEND_SIZE
let mut send_buf = vec![0u8; MAX_SEND_SIZE];
// watcher for commit_lsn updates
let mut commit_lsn_watch_rx = tli.get_commit_lsn_watch_rx();
loop {
if let Some(stop_pos) = stop_pos {
if start_pos >= stop_pos {
break; /* recovery finished */
fn handle_feedback(self: Pin<&mut Self>, msg: &FeMessage) -> Result<(), QueryError> {
let this = self.project();
match &msg {
FeMessage::CopyData(m) => {
// There's three possible data messages that the client is supposed to send here:
// `HotStandbyFeedback` and `StandbyStatusUpdate` and `NeonStandbyFeedback`.
match m.first().cloned() {
Some(HOT_STANDBY_FEEDBACK_TAG_BYTE) => {
// Note: deserializing is on m[1..] because we skip the tag byte.
this.c.feedback.hs_feedback = HotStandbyFeedback::des(&m[1..])
.context("failed to deserialize HotStandbyFeedback")?;
this.c
.tli
.update_replica_state(this.c.replica_id, this.c.feedback);
}
end_pos = stop_pos;
} else {
/* Wait until we have some data to stream */
let lsn = wait_for_lsn(&mut commit_lsn_watch_rx, start_pos).await?;
Some(STANDBY_STATUS_UPDATE_TAG_BYTE) => {
let _reply = StandbyReply::des(&m[1..])
.context("failed to deserialize StandbyReply")?;
// This must be a regular postgres replica,
// because pageserver doesn't send this type of messages to safekeeper.
// Currently we just ignore this, tracking progress for them is not supported.
}
Some(NEON_STATUS_UPDATE_TAG_BYTE) => {
// Note: deserializing is on m[9..] because we skip the tag byte and len bytes.
let buf = Bytes::copy_from_slice(&m[9..]);
let reply = ReplicationFeedback::parse(buf);
if let Some(lsn) = lsn {
end_pos = lsn;
} else {
// TODO: also check once in a while whether we are walsender
// to right pageserver.
if tli.should_walsender_stop(replica_id) {
// Shut down, timeline is suspended.
return Err(QueryError::from(io::Error::new(
io::ErrorKind::ConnectionAborted,
format!("end streaming to {:?}", spg.appname),
)));
}
trace!("ReplicationFeedback is {:?}", reply);
// Only pageserver sends ReplicationFeedback, so set the flag.
// This replica is the source of information to resend to compute.
this.c.feedback.pageserver_feedback = Some(reply);
// timeout expired: request pageserver status
pgb.write_message(&BeMessage::KeepAlive(WalSndKeepAlive {
sent_ptr: end_pos.0,
timestamp: get_current_timestamp(),
request_reply: true,
}))?;
this.c
.tli
.update_replica_state(this.c.replica_id, this.c.feedback);
}
_ => warn!("unexpected message {:?}", msg),
}
}
FeMessage::CopyFail => {
// XXX we should probably (tell pgb to) close the socket, as
// CopyFail in duplex copy is somewhat unexpected (at least to
// PG walsender; evidently client should finish it with
// CopyDone). Note that sync rust-postgres client (which we
// don't use anymore) hangs otherwise.
// https://github.com/sfackler/rust-postgres/issues/755
// https://github.com/neondatabase/neon/issues/935
//
return Err(anyhow::anyhow!("unexpected CopyFail").into());
}
_ => {
return Err(
anyhow::anyhow!("unexpected message {:?} in replication stream", msg).into(),
);
}
};
Ok(())
}
// Poll writing, i.e. sending more WAL. Completes only on error or when we
// decide to shutdown connection -- receiver is caughtup and there is no
// active computes; this is still handled as Err though.
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), QueryError>> {
// send while we don't block or error out
loop {
match &mut self.as_mut().project().write_state.project() {
WriteStateProj::WaitWal { fut } => match ready!(fut.as_mut().poll(cx))? {
Some(lsn) => {
self.as_mut().project().c.end_pos = lsn;
self.as_mut().start_read_wal();
continue;
}
// Timed out waiting for WAL, send keepalive and possibly terminate.
None => {
let mut this = self.as_mut().project();
if this.c.tli.should_walsender_stop(this.c.replica_id) {
// Terminate if there is nothing more to send.
// TODO close the stream properly
return Poll::Ready(Err(anyhow::anyhow!(format!(
"ending streaming to {:?} at {}, receiver is caughtup and there is no computes",
self.c.appname, self.c.start_pos,
)).into()));
}
let end_pos = this.c.end_pos.0;
this.c
.pgb
.write_message(&BeMessage::KeepAlive(WalSndKeepAlive {
sent_ptr: end_pos,
timestamp: get_current_timestamp(),
request_reply: true,
}))?;
/* flush KA */
this.write_state.set(WriteState::FlushWal);
}
},
WriteStateProj::ReadWal { fut } => {
let read_len = ready!(fut.as_mut().poll(cx))?;
assert!(read_len > 0, "read_len={}", read_len);
let mut this = self.as_mut().project();
let write_ctx_clone = this.c.write_ctx.clone();
let send_buf = &write_ctx_clone.send_buf.borrow()[..read_len];
let chunk_end = this.c.start_pos + read_len as u64;
// write data to the output buffer
this.c
.pgb
.write_message(&BeMessage::XLogData(XLogDataBody {
wal_start: this.c.start_pos.0,
wal_end: chunk_end.0,
timestamp: get_current_timestamp(),
data: send_buf,
}))
.context("Failed to write XLogData")?;
trace!("wrote a chunk of wal {}-{}", this.c.start_pos, chunk_end);
this.c.start_pos = chunk_end;
// and flush it
this.write_state.set(WriteState::FlushWal);
}
WriteStateProj::FlushWal => {
let this = self.as_mut().project();
let send_size = end_pos.checked_sub(start_pos).unwrap().0 as usize;
let send_size = min(send_size, send_buf.len());
let send_buf = &mut send_buf[..send_size];
// read wal into buffer
let send_size = wal_reader.read(send_buf).await?;
let send_buf = &send_buf[..send_size];
// Write some data to the network socket.
pgb.write_message(&BeMessage::XLogData(XLogDataBody {
wal_start: start_pos.0,
wal_end: end_pos.0,
timestamp: get_current_timestamp(),
data: send_buf,
}))
.context("Failed to send XLogData")?;
start_pos += send_size as u64;
trace!("sent WAL up to {}", start_pos);
ready!(this.c.pgb.poll_flush(cx))?;
// If we are streaming to walproposer, check it is time to stop.
if let Some(stop_pos) = this.c.stop_pos {
if this.c.start_pos >= stop_pos {
// recovery finished
// TODO close the stream properly
return Poll::Ready(Err(anyhow::anyhow!(format!(
"ending streaming to walproposer at {}, receiver is caughtup and there is no computes",
this.c.start_pos)).into()));
}
self.as_mut().start_read_wal();
continue;
} else {
// if we don't know next portion is already available, wait
// for it; otherwise proceed to sending
if self.c.end_pos <= self.c.start_pos {
self.as_mut().start_wait_wal();
} else {
self.as_mut().start_read_wal();
}
}
}
}
}
}
Ok(())
})
// Start waiting for WAL, creating future doing that.
fn start_wait_wal(self: Pin<&mut Self>) {
let fut = self.c.wait_wal_fut();
self.project().write_state.set(WriteState::WaitWal {
fut: {
// SAFETY: this function is the only way to assign WaitWal to
// write_state. We just workaround impossibility of specifying
// async fn type, which is anonymous.
// transmute_copy is used as transmute refuses generic param:
// https://users.rust-lang.org/t/transmute-doesnt-work-on-generic-types/87272
assert_eq!(std::mem::size_of::<WF>(), std::mem::size_of_val(&fut));
let t = unsafe { std::mem::transmute_copy(&fut) };
std::mem::forget(fut);
t
},
});
}
// Switch into reading WAL state, creating Future doing that.
fn start_read_wal(self: Pin<&mut Self>) {
let fut = self.c.read_wal_fut();
self.project().write_state.set(WriteState::ReadWal {
fut: {
// SAFETY: this function is the only way to assign ReadWal to
// write_state. We just workaround impossibility of specifying
// async fn type, which is anonymous.
// transmute_copy is used as transmute refuses generic param:
// https://users.rust-lang.org/t/transmute-doesnt-work-on-generic-types/87272
assert_eq!(std::mem::size_of::<RF>(), std::mem::size_of_val(&fut));
let t = unsafe { std::mem::transmute_copy(&fut) };
std::mem::forget(fut);
t
},
});
}
}
impl ReplicationContext<'_> {
// Create future waiting for WAL.
fn wait_wal_fut(&self) -> impl Future<Output = anyhow::Result<Option<Lsn>>> {
let mut commit_lsn_watch_rx = self.tli.get_commit_lsn_watch_rx();
let start_pos = self.start_pos;
async move { wait_for_lsn(&mut commit_lsn_watch_rx, start_pos).await }
}
// Create future reading WAL.
fn read_wal_fut(&self) -> impl Future<Output = anyhow::Result<usize>> {
let mut send_size = self
.end_pos
.checked_sub(self.start_pos)
.expect("reading wal without waiting for it first")
.0 as usize;
send_size = min(send_size, self.write_ctx.send_buf.borrow().len());
let write_ctx_fut = self.write_ctx.clone();
async move {
let mut wal_reader_ref = write_ctx_fut.wal_reader.borrow_mut_send();
let mut send_buf_ref = write_ctx_fut.send_buf.borrow_mut_send();
let send_buf = &mut send_buf_ref[..send_size];
wal_reader_ref.read(send_buf).await
}
}
}
const POLL_STATE_TIMEOUT: Duration = Duration::from_secs(1);
// Wait until we have commit_lsn > lsn or timeout expires. Returns latest commit_lsn.
// Wait until we have commit_lsn > lsn or timeout expires. Returns
// - Ok(Some(commit_lsn)) if needed lsn is successfully observed;
// - Ok(None) if timeout expired;
// - Err in case of error (if watch channel is in trouble, shouldn't happen).
async fn wait_for_lsn(rx: &mut Receiver<Lsn>, lsn: Lsn) -> anyhow::Result<Option<Lsn>> {
let commit_lsn: Lsn = *rx.borrow();
if commit_lsn > lsn {

View File

@@ -346,7 +346,9 @@ impl WalBackupTask {
backup_lsn, commit_lsn, e
);
retry_attempt = retry_attempt.saturating_add(1);
if retry_attempt < u32::MAX {
retry_attempt += 1;
}
}
}
}
@@ -385,7 +387,7 @@ async fn backup_single_segment(
) -> Result<()> {
let segment_file_path = seg.file_path(timeline_dir)?;
let remote_segment_path = segment_file_path
.strip_prefix(workspace_dir)
.strip_prefix(&workspace_dir)
.context("Failed to strip workspace dir prefix")
.and_then(RemotePath::new)
.with_context(|| {
@@ -467,7 +469,7 @@ async fn backup_object(source_file: &Path, target_file: &RemotePath, size: usize
pub async fn read_object(
file_path: &RemotePath,
offset: u64,
) -> anyhow::Result<Pin<Box<dyn tokio::io::AsyncRead>>> {
) -> anyhow::Result<Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>>> {
let storage = REMOTE_STORAGE
.get()
.context("Failed to get remote storage")?

View File

@@ -2,36 +2,54 @@
//! WAL service listens for client connections and
//! receive WAL from wal_proposer and send it to WAL receivers
//!
use anyhow::{Context, Result};
use regex::Regex;
use std::net::{TcpListener, TcpStream};
use std::thread;
use std::{future, thread};
use tokio::net::TcpStream;
use tracing::*;
use utils::postgres_backend_async::QueryError;
use crate::handler::SafekeeperPostgresHandler;
use crate::SafeKeeperConf;
use utils::postgres_backend::{AuthType, PostgresBackend};
use utils::postgres_backend_async::{AuthType, PostgresBackend};
/// Accept incoming TCP connections and spawn them into a background thread.
pub fn thread_main(conf: SafeKeeperConf, listener: TcpListener) -> ! {
loop {
match listener.accept() {
Ok((socket, peer_addr)) => {
debug!("accepted connection from {}", peer_addr);
let conf = conf.clone();
pub fn thread_main(conf: SafeKeeperConf, pg_listener: std::net::TcpListener) {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.context("create runtime")
// todo catch error in main thread
.expect("failed to create runtime");
let _ = thread::Builder::new()
.name("WAL service thread".into())
.spawn(move || {
if let Err(err) = handle_socket(socket, conf) {
error!("connection handler exited: {}", err);
}
})
.unwrap();
runtime
.block_on(async move {
// Tokio's from_std won't do this for us, per its comment.
pg_listener.set_nonblocking(true)?;
let listener = tokio::net::TcpListener::from_std(pg_listener)?;
loop {
match listener.accept().await {
Ok((socket, peer_addr)) => {
debug!("accepted connection from {}", peer_addr);
let conf = conf.clone();
let _ = thread::Builder::new()
.name("WAL service thread".into())
.spawn(move || {
if let Err(err) = handle_socket(socket, conf) {
error!("connection handler exited: {}", err);
}
})
.unwrap();
}
Err(e) => error!("Failed to accept connection: {}", e),
}
}
Err(e) => error!("Failed to accept connection: {}", e),
}
}
#[allow(unreachable_code)] // hint compiler the closure return type
Ok::<(), anyhow::Error>(())
})
.expect("listener failed")
}
// Get unique thread id (Rust internal), with ThreadId removed for shorter printing
@@ -44,9 +62,14 @@ fn get_tid() -> u64 {
/// This is run by `thread_main` above, inside a background thread.
///
fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryError> {
fn handle_socket(mut socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryError> {
let _enter = info_span!("", tid = ?get_tid()).entered();
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let local = tokio::task::LocalSet::new();
socket.set_nodelay(true)?;
let auth_type = match conf.auth {
@@ -54,9 +77,13 @@ fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryErr
Some(_) => AuthType::NeonJWT,
};
let mut conn_handler = SafekeeperPostgresHandler::new(conf);
let pgbackend = PostgresBackend::new(socket, auth_type, None, false)?;
// libpq replication protocol between safekeeper and replicas/pagers
pgbackend.run(&mut conn_handler)?;
let pgbackend = PostgresBackend::new(socket, auth_type, None)?;
// libpq protocol between safekeeper and walproposer / pageserver
// We don't use shutdown.
local.block_on(
&runtime,
pgbackend.run(&mut conn_handler, || future::pending::<()>()),
)?;
Ok(())
}

View File

@@ -450,7 +450,7 @@ pub struct WalReader {
timeline_dir: PathBuf,
wal_seg_size: usize,
pos: Lsn,
wal_segment: Option<Pin<Box<dyn AsyncRead>>>,
wal_segment: Option<Pin<Box<dyn AsyncRead + Send + Sync>>>,
// S3 will be used to read WAL if LSN is not available locally
enable_remote_read: bool,
@@ -491,6 +491,11 @@ impl WalReader {
})
}
pub async fn fake_read(&mut self) -> Result<usize> {
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
Ok(self.pos.0 as usize)
}
pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
let mut wal_segment = match self.wal_segment.take() {
Some(reader) => reader,
@@ -517,7 +522,7 @@ impl WalReader {
}
/// Open WAL segment at the current position of the reader.
async fn open_segment(&self) -> Result<Pin<Box<dyn AsyncRead>>> {
async fn open_segment(&self) -> Result<Pin<Box<dyn AsyncRead + Send + Sync>>> {
let xlogoff = self.pos.segment_offset(self.wal_seg_size);
let segno = self.pos.segment_number(self.wal_seg_size);
let wal_file_name = XLogFileName(PG_TLI, segno, self.wal_seg_size);