wal_decoder: abstract wire format

We want to add new wire formats, but the current code hard-codes bincode.

To this end:
1. Rework the wal_receiver_protocol pageserver config to include a format
   modifier for the interpreted protocol type (the resulting encoding is
   sketched below).
2. Abstract wire format encoding and decoding into a separate module
   in wal_decoder.
3. Glue things back together.
Author: Vlad Lazar
Date: 2024-11-19 14:12:19 +01:00
Parent: 7a2f0ed8d4
Commit: 7b0f1605b2
11 changed files with 121 additions and 61 deletions
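
As a rough illustration of point 1, here is a minimal standalone sketch of the reworked protocol parameter. It assumes only serde and serde_json and re-declares the two enums added in this commit so the snippet compiles on its own; it shows the JSON that now travels in the `protocol=` startup option built by `ConnectionConfigArgs` and parsed back on the safekeeper side.

```rust
use serde::{Deserialize, Serialize};

// Re-declared copies of the enums added in this commit, so the sketch
// is self-contained; the real definitions live in utils::postgres_client.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
enum InterpretedFormat {
    Bincode,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", content = "args", rename_all = "kebab-case")]
enum PostgresClientProtocol {
    Vanilla,
    Interpreted { format: InterpretedFormat },
}

fn main() {
    // Vanilla replication carries no format modifier: {"type":"vanilla"}
    let vanilla = serde_json::to_string(&PostgresClientProtocol::Vanilla).unwrap();
    println!("protocol={vanilla}");

    // The interpreted protocol now names its wire format:
    // {"type":"interpreted","args":{"format":"bincode"}}
    let interpreted = PostgresClientProtocol::Interpreted {
        format: InterpretedFormat::Bincode,
    };
    println!("protocol={}", serde_json::to_string(&interpreted).unwrap());

    // The receiving side parses the same JSON back out of the startup option.
    let parsed: PostgresClientProtocol =
        serde_json::from_str(r#"{"type":"interpreted","args":{"format":"bincode"}}"#).unwrap();
    assert_eq!(parsed, interpreted);
}
```

The same `type`/`args` structure is what the test change at the bottom of this diff patches into the pageserver's `wal_receiver_protocol` TOML config.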

Cargo.lock (generated)

@@ -7124,6 +7124,7 @@ dependencies = [
"pageserver_api",
"postgres_ffi",
"serde",
"thiserror",
"tracing",
"utils",
"workspace_hack",


@@ -7,40 +7,21 @@ use postgres_connection::{parse_host_port, PgConnectionConfig};
use crate::id::TenantTimelineId;
/// Postgres client protocol types
#[derive(
Copy,
Clone,
PartialEq,
Eq,
strum_macros::EnumString,
strum_macros::Display,
serde_with::DeserializeFromStr,
serde_with::SerializeDisplay,
Debug,
)]
#[strum(serialize_all = "kebab-case")]
#[repr(u8)]
#[derive(Copy, Clone, PartialEq, Eq, Debug, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum InterpretedFormat {
Bincode,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type", content = "args")]
#[serde(rename_all = "kebab-case")]
pub enum PostgresClientProtocol {
/// Usual Postgres replication protocol
Vanilla,
/// Custom shard-aware protocol that replicates interpreted records.
/// Used to send wal from safekeeper to pageserver.
Interpreted,
}
impl TryFrom<u8> for PostgresClientProtocol {
type Error = u8;
fn try_from(value: u8) -> Result<Self, Self::Error> {
Ok(match value {
v if v == (PostgresClientProtocol::Vanilla as u8) => PostgresClientProtocol::Vanilla,
v if v == (PostgresClientProtocol::Interpreted as u8) => {
PostgresClientProtocol::Interpreted
}
x => return Err(x),
})
}
Interpreted { format: InterpretedFormat },
}
pub struct ConnectionConfigArgs<'a> {
@@ -63,7 +44,10 @@ impl<'a> ConnectionConfigArgs<'a> {
"-c".to_owned(),
format!("timeline_id={}", self.ttid.timeline_id),
format!("tenant_id={}", self.ttid.tenant_id),
format!("protocol={}", self.protocol as u8),
format!(
"protocol={}",
serde_json::to_string(&self.protocol).unwrap()
),
];
if self.shard_number.is_some() {


@@ -13,6 +13,7 @@ bytes.workspace = true
pageserver_api.workspace = true
postgres_ffi.workspace = true
serde.workspace = true
thiserror.workspace = true
tracing.workspace = true
utils.workspace = true
workspace_hack = { version = "0.1", path = "../../workspace_hack" }


@@ -1,3 +1,4 @@
pub mod decoder;
pub mod models;
pub mod serialized_batch;
pub mod wire_format;


@@ -0,0 +1,52 @@
use bytes::{BufMut, Bytes, BytesMut};
use utils::bin_ser::{BeSer, DeserializeError, SerializeError};
use utils::postgres_client::InterpretedFormat;
use crate::models::InterpretedWalRecord;
#[derive(Debug, thiserror::Error)]
pub enum ToWireFormatError {
#[error("{0}")]
Bincode(SerializeError),
}
#[derive(Debug, thiserror::Error)]
pub enum FromWireFormatError {
#[error("{0}")]
Bincode(DeserializeError),
}
pub trait ToWireFormat {
fn to_wire(self, format: InterpretedFormat) -> Result<Bytes, ToWireFormatError>;
}
pub trait FromWireFormat {
type T;
fn from_wire(buf: &Bytes, format: InterpretedFormat) -> Result<Self::T, FromWireFormatError>;
}
impl ToWireFormat for Vec<InterpretedWalRecord> {
fn to_wire(self, format: InterpretedFormat) -> Result<Bytes, ToWireFormatError> {
match format {
InterpretedFormat::Bincode => {
let buf = BytesMut::new();
let mut buf = buf.writer();
self.ser_into(&mut buf)
.map_err(ToWireFormatError::Bincode)?;
Ok(buf.into_inner().freeze())
}
}
}
}
impl FromWireFormat for Vec<InterpretedWalRecord> {
type T = Self;
fn from_wire(buf: &Bytes, format: InterpretedFormat) -> Result<Self, FromWireFormatError> {
match format {
InterpretedFormat::Bincode => {
Vec::<InterpretedWalRecord>::des(buf).map_err(FromWireFormatError::Bincode)
}
}
}
}


@@ -535,6 +535,7 @@ impl ConnectionManagerState {
let node_id = new_sk.safekeeper_id;
let connect_timeout = self.conf.wal_connect_timeout;
let ingest_batch_size = self.conf.ingest_batch_size;
let protocol = self.conf.protocol;
let timeline = Arc::clone(&self.timeline);
let ctx = ctx.detached_child(
TaskKind::WalReceiverConnectionHandler,
@@ -548,6 +549,7 @@ impl ConnectionManagerState {
let res = super::walreceiver_connection::handle_walreceiver_connection(
timeline,
protocol,
new_sk.wal_source_connconf,
events_sender,
cancellation.clone(),
@@ -991,7 +993,7 @@ impl ConnectionManagerState {
PostgresClientProtocol::Vanilla => {
(None, None, None)
},
PostgresClientProtocol::Interpreted => {
PostgresClientProtocol::Interpreted { .. } => {
let shard_identity = self.timeline.get_shard_identity();
(
Some(shard_identity.number.0),


@@ -22,7 +22,10 @@ use tokio::{select, sync::watch, time};
use tokio_postgres::{replication::ReplicationStream, Client};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, trace, warn, Instrument};
use wal_decoder::models::{FlushUncommittedRecords, InterpretedWalRecord};
use wal_decoder::{
models::{FlushUncommittedRecords, InterpretedWalRecord},
wire_format::FromWireFormat,
};
use super::TaskStateUpdate;
use crate::{
@@ -36,7 +39,7 @@ use crate::{
use postgres_backend::is_expected_io_error;
use postgres_connection::PgConnectionConfig;
use postgres_ffi::waldecoder::WalStreamDecoder;
use utils::{bin_ser::BeSer, id::NodeId, lsn::Lsn};
use utils::{id::NodeId, lsn::Lsn, postgres_client::PostgresClientProtocol};
use utils::{pageserver_feedback::PageserverFeedback, sync::gate::GateError};
/// Status of the connection.
@@ -109,6 +112,7 @@ impl From<WalDecodeError> for WalReceiverError {
#[allow(clippy::too_many_arguments)]
pub(super) async fn handle_walreceiver_connection(
timeline: Arc<Timeline>,
protocol: PostgresClientProtocol,
wal_source_connconf: PgConnectionConfig,
events_sender: watch::Sender<TaskStateUpdate<WalConnectionStatus>>,
cancellation: CancellationToken,
@@ -260,6 +264,11 @@ pub(super) async fn handle_walreceiver_connection(
let mut walingest = WalIngest::new(timeline.as_ref(), startpoint, &ctx).await?;
let interpreted_format = match protocol {
PostgresClientProtocol::Vanilla => None,
PostgresClientProtocol::Interpreted { format } => Some(format),
};
while let Some(replication_message) = {
select! {
_ = cancellation.cancelled() => {
@@ -337,11 +346,13 @@ pub(super) async fn handle_walreceiver_connection(
Lsn(raw.next_record_lsn().unwrap_or(0))
);
let records = Vec::<InterpretedWalRecord>::des(raw.data()).with_context(|| {
anyhow::anyhow!(
let records =
Vec::<InterpretedWalRecord>::from_wire(raw.data(), interpreted_format.unwrap())
.with_context(|| {
anyhow::anyhow!(
"Failed to deserialize interpreted records ending at LSN {streaming_lsn}"
)
})?;
})?;
// We start the modification at 0 because each interpreted record
// advances it to its end LSN. 0 is just an initialization placeholder.


@@ -123,17 +123,10 @@ impl<IO: AsyncRead + AsyncWrite + Unpin + Send> postgres_backend::Handler<IO>
// https://github.com/neondatabase/neon/pull/2433#discussion_r970005064
match opt.split_once('=') {
Some(("protocol", value)) => {
let raw_value = value
.parse::<u8>()
.with_context(|| format!("Failed to parse {value} as protocol"))?;
self.protocol = Some(
PostgresClientProtocol::try_from(raw_value).map_err(|_| {
QueryError::Other(anyhow::anyhow!(
"Unexpected client protocol type: {raw_value}"
))
})?,
);
self.protocol =
Some(serde_json::from_str(value).with_context(|| {
format!("Failed to parse {value} as protocol")
})?);
}
Some(("ztenantid", value)) | Some(("tenant_id", value)) => {
self.tenant_id = Some(value.parse().with_context(|| {
@@ -180,7 +173,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin + Send> postgres_backend::Handler<IO>
)));
}
}
PostgresClientProtocol::Interpreted => {
PostgresClientProtocol::Interpreted { .. } => {
match (shard_count, shard_number, shard_stripe_size) {
(Some(count), Some(number), Some(stripe_size)) => {
let params = ShardParameters {


@@ -9,9 +9,10 @@ use postgres_ffi::{get_current_timestamp, waldecoder::WalStreamDecoder};
use pq_proto::{BeMessage, InterpretedWalRecordsBody, WalSndKeepAlive};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time::MissedTickBehavior;
use utils::bin_ser::BeSer;
use utils::lsn::Lsn;
use utils::postgres_client::InterpretedFormat;
use wal_decoder::models::InterpretedWalRecord;
use wal_decoder::wire_format::ToWireFormat;
use crate::send_wal::EndWatchView;
use crate::wal_reader_stream::{WalBytes, WalReaderStreamBuilder};
@@ -20,6 +21,7 @@ use crate::wal_reader_stream::{WalBytes, WalReaderStreamBuilder};
/// This is used for sending WAL to the pageserver. Said WAL
/// is pre-interpreted and filtered for the shard.
pub(crate) struct InterpretedWalSender<'a, IO> {
pub(crate) format: InterpretedFormat,
pub(crate) pgb: &'a mut PostgresBackend<IO>,
pub(crate) wal_stream_builder: WalReaderStreamBuilder,
pub(crate) end_watch_view: EndWatchView,
@@ -81,10 +83,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> InterpretedWalSender<'_, IO> {
}
}
let mut buf = Vec::new();
records
.ser_into(&mut buf)
.with_context(|| "Failed to serialize interpreted WAL")?;
let buf = records.to_wire(self.format).with_context(|| "Failed to serialize interpreted WAL")?;
// Reset the keep alive ticker since we are sending something
// over the wire now.
@@ -95,7 +94,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> InterpretedWalSender<'_, IO> {
streaming_lsn: wal_end_lsn.0,
commit_lsn: available_wal_end_lsn.0,
next_record_lsn: max_next_record_lsn.unwrap_or(Lsn::INVALID).0,
data: buf.as_slice(),
data: &buf,
})).await?;
}


@@ -454,7 +454,7 @@ impl SafekeeperPostgresHandler {
}
info!(
"starting streaming from {:?}, available WAL ends at {}, recovery={}, appname={:?}, protocol={}",
"starting streaming from {:?}, available WAL ends at {}, recovery={}, appname={:?}, protocol={:?}",
start_pos,
end_pos,
matches!(end_watch, EndWatch::Flush(_)),
@@ -489,7 +489,7 @@ impl SafekeeperPostgresHandler {
Either::Left(sender.run())
}
PostgresClientProtocol::Interpreted => {
PostgresClientProtocol::Interpreted { format } => {
let pg_version = tli.tli.get_state().await.1.server.pg_version / 10000;
let end_watch_view = end_watch.view();
let wal_stream_builder = WalReaderStreamBuilder {
@@ -502,6 +502,7 @@ impl SafekeeperPostgresHandler {
};
let sender = InterpretedWalSender {
format,
pgb,
wal_stream_builder,
end_watch_view,


@@ -27,14 +27,29 @@ def test_sharded_ingest(
and fanning out to a large number of shards on dedicated Pageservers. Comparing the base case
(shard_count=1) to the sharded case indicates the overhead of sharding.
"""
neon_env_builder.pageserver_config_override = (
f"wal_receiver_protocol = '{wal_receiver_protocol}'"
)
ROW_COUNT = 100_000_000 # about 7 GB of WAL
neon_env_builder.num_pageservers = shard_count
env = neon_env_builder.init_start()
env = neon_env_builder.init_configs()
for ps in env.pageservers:
if wal_receiver_protocol == "vanilla":
ps.patch_config_toml_nonrecursive({
"wal_receiver_protocol": {
"type": "vanilla",
}
})
elif wal_receiver_protocol == "interpreted":
ps.patch_config_toml_nonrecursive({
"wal_receiver_protocol": {
"type": "interpreted",
"args": {
"format": "bincode"
}
}
})
env.start()
# Create a sharded tenant and timeline, and migrate it to the respective pageservers. Ensure
# the storage controller doesn't mess with shard placements.