wal_decoder: reuse codec throughout sender/receiver lifetime

Problem

Previously, we used `from_wire` and `to_wire` inline to encode and
decode record batches. This means we always have to match on the format,
and, more importantly, doesn't allow for reuse of the zstd
encoder/decoder.

Summary of Changes

Refactor such that the encoder and decoder can have the same lifetime
as the sender/receiver session.
This commit is contained in:
Vlad Lazar
2024-11-26 16:50:17 +01:00
parent 9e0148de11
commit 6666f6807b
8 changed files with 528 additions and 91 deletions

1
Cargo.lock generated
View File

@@ -7121,6 +7121,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"async-compression",
"async-trait",
"bytes",
"pageserver_api",
"postgres_ffi",

View File

@@ -9,6 +9,7 @@ testing = ["pageserver_api/testing"]
[dependencies]
async-compression.workspace = true
async-trait.workspace = true
anyhow.workspace = true
bytes.workspace = true
pageserver_api.workspace = true

View File

@@ -0,0 +1,192 @@
use bytes::{BufMut, Bytes, BytesMut};
use prost::Message;
use tokio::io::AsyncWriteExt;
use utils::postgres_client::{Compression, InterpretedFormat};
use crate::models::proto;
use crate::models::InterpretedWalRecords;
use crate::protobuf_conversions::TranscodeError;
use utils::bin_ser::{BeSer, DeserializeError, SerializeError};
#[derive(Debug, thiserror::Error)]
pub enum EncodeError {
#[error("{0}")]
Bincode(#[from] SerializeError),
#[error("{0}")]
Protobuf(#[from] ProtobufSerializeError),
#[error("{0}")]
Compression(#[from] std::io::Error),
}
#[derive(Debug, thiserror::Error)]
pub enum ProtobufSerializeError {
#[error("{0}")]
MetadataRecord(#[from] SerializeError),
#[error("{0}")]
Encode(#[from] prost::EncodeError),
}
#[derive(Debug, thiserror::Error)]
pub enum DecodeError {
#[error("{0}")]
Bincode(#[from] DeserializeError),
#[error("{0}")]
Protobuf(#[from] ProtobufDeserializeError),
#[error("{0}")]
Decompress(#[from] std::io::Error),
}
#[derive(Debug, thiserror::Error)]
pub enum ProtobufDeserializeError {
#[error("{0}")]
Transcode(#[from] TranscodeError),
#[error("{0}")]
Decode(#[from] prost::DecodeError),
}
pub fn encoder_from_proto(
format: InterpretedFormat,
compression: Option<Compression>,
) -> Box<dyn Encoder> {
match format {
InterpretedFormat::Bincode => Box::new(BincodeEncoder { compression }),
InterpretedFormat::Protobuf => Box::new(ProtobufEncoder { compression }),
}
}
pub fn make_decoder(
format: InterpretedFormat,
compression: Option<Compression>,
) -> Box<dyn Decoder> {
match format {
InterpretedFormat::Bincode => Box::new(BincodeDecoder { compression }),
InterpretedFormat::Protobuf => Box::new(ProtobufDecoder { compression }),
}
}
#[async_trait::async_trait]
pub trait Encoder: Send + Sync {
async fn encode(&self, records: InterpretedWalRecords) -> Result<Bytes, EncodeError>;
}
#[async_trait::async_trait]
pub trait Decoder: Send + Sync {
async fn decode(&self, buf: &Bytes) -> Result<InterpretedWalRecords, DecodeError>;
}
struct BincodeDecoder {
compression: Option<Compression>,
}
#[async_trait::async_trait]
impl Decoder for BincodeDecoder {
async fn decode(&self, buf: &Bytes) -> Result<InterpretedWalRecords, DecodeError> {
let decompressed_buf = match self.compression {
Some(Compression::Zstd { .. }) => {
use async_compression::tokio::write::ZstdDecoder;
let mut decoded_buf = Vec::with_capacity(buf.len());
let mut decoder = ZstdDecoder::new(&mut decoded_buf);
decoder.write_all(buf).await?;
decoder.flush().await?;
Bytes::from(decoded_buf)
}
None => buf.clone(),
};
InterpretedWalRecords::des(&decompressed_buf).map_err(DecodeError::Bincode)
}
}
struct BincodeEncoder {
compression: Option<Compression>,
}
#[async_trait::async_trait]
impl Encoder for BincodeEncoder {
async fn encode(&self, records: InterpretedWalRecords) -> Result<Bytes, EncodeError> {
use async_compression::tokio::write::ZstdEncoder;
use async_compression::Level;
let buf = BytesMut::new();
let mut buf = buf.writer();
records.ser_into(&mut buf)?;
let buf = buf.into_inner().freeze();
let compressed_buf = match self.compression {
Some(Compression::Zstd { level }) => {
let mut encoder = ZstdEncoder::with_quality(
Vec::with_capacity(buf.len() / 4),
Level::Precise(level as i32),
);
encoder.write_all(&buf).await?;
encoder.shutdown().await?;
Bytes::from(encoder.into_inner())
}
None => buf,
};
Ok(compressed_buf)
}
}
struct ProtobufDecoder {
compression: Option<Compression>,
}
#[async_trait::async_trait]
impl Decoder for ProtobufDecoder {
async fn decode(&self, buf: &Bytes) -> Result<InterpretedWalRecords, DecodeError> {
let decompressed_buf = match self.compression {
Some(Compression::Zstd { .. }) => {
use async_compression::tokio::write::ZstdDecoder;
let mut decoded_buf = Vec::with_capacity(buf.len());
let mut decoder = ZstdDecoder::new(&mut decoded_buf);
decoder.write_all(buf).await?;
decoder.flush().await?;
Bytes::from(decoded_buf)
}
None => buf.clone(),
};
let proto = proto::InterpretedWalRecords::decode(decompressed_buf)
.map_err(|e| DecodeError::Protobuf(e.into()))?;
InterpretedWalRecords::try_from(proto).map_err(|e| DecodeError::Protobuf(e.into()))
}
}
struct ProtobufEncoder {
compression: Option<Compression>,
}
#[async_trait::async_trait]
impl Encoder for ProtobufEncoder {
async fn encode(&self, records: InterpretedWalRecords) -> Result<Bytes, EncodeError> {
use async_compression::tokio::write::ZstdEncoder;
use async_compression::Level;
let proto: proto::InterpretedWalRecords = records.try_into()?;
let mut buf = BytesMut::new();
proto
.encode(&mut buf)
.map_err(|e| EncodeError::Protobuf(e.into()))?;
let buf = buf.freeze();
let compressed_buf = match self.compression {
Some(Compression::Zstd { level }) => {
let mut encoder = ZstdEncoder::with_quality(
Vec::with_capacity(buf.len() / 4),
Level::Precise(level as i32),
);
encoder.write_all(&buf).await?;
encoder.shutdown().await?;
Bytes::from(encoder.into_inner())
}
None => buf,
};
Ok(compressed_buf)
}
}

View File

@@ -1,4 +1,5 @@
pub mod codec;
pub mod decoder;
pub mod models;
pub mod protobuf_conversions;
pub mod serialized_batch;
pub mod wire_format;

View File

@@ -0,0 +1,220 @@
use pageserver_api::key::CompactKey;
use utils::bin_ser::{BeSer, DeserializeError, SerializeError};
use utils::lsn::Lsn;
use crate::models::{
FlushUncommittedRecords, InterpretedWalRecord, InterpretedWalRecords, MetadataRecord,
};
use crate::serialized_batch::{
ObservedValueMeta, SerializedValueBatch, SerializedValueMeta, ValueMeta,
};
use crate::models::proto;
#[derive(Debug, thiserror::Error)]
pub enum TranscodeError {
#[error("{0}")]
BadInput(String),
#[error("{0}")]
MetadataRecord(#[from] DeserializeError),
}
impl TryFrom<InterpretedWalRecords> for proto::InterpretedWalRecords {
type Error = SerializeError;
fn try_from(value: InterpretedWalRecords) -> Result<Self, Self::Error> {
let records = value
.records
.into_iter()
.map(proto::InterpretedWalRecord::try_from)
.collect::<Result<Vec<_>, _>>()?;
Ok(proto::InterpretedWalRecords {
records,
next_record_lsn: value.next_record_lsn.map(|l| l.0),
})
}
}
impl TryFrom<InterpretedWalRecord> for proto::InterpretedWalRecord {
type Error = SerializeError;
fn try_from(value: InterpretedWalRecord) -> Result<Self, Self::Error> {
let metadata_record = value
.metadata_record
.map(|meta_rec| -> Result<Vec<u8>, Self::Error> {
let mut buf = Vec::new();
meta_rec.ser_into(&mut buf)?;
Ok(buf)
})
.transpose()?;
Ok(proto::InterpretedWalRecord {
metadata_record,
batch: Some(proto::SerializedValueBatch::from(value.batch)),
next_record_lsn: value.next_record_lsn.0,
flush_uncommitted: matches!(value.flush_uncommitted, FlushUncommittedRecords::Yes),
xid: value.xid,
})
}
}
impl From<SerializedValueBatch> for proto::SerializedValueBatch {
fn from(value: SerializedValueBatch) -> Self {
proto::SerializedValueBatch {
raw: value.raw,
metadata: value
.metadata
.into_iter()
.map(proto::ValueMeta::from)
.collect(),
max_lsn: value.max_lsn.0,
len: value.len as u64,
}
}
}
impl From<ValueMeta> for proto::ValueMeta {
fn from(value: ValueMeta) -> Self {
match value {
ValueMeta::Observed(obs) => proto::ValueMeta {
r#type: proto::ValueMetaType::Observed.into(),
key: Some(proto::CompactKey::from(obs.key)),
lsn: obs.lsn.0,
batch_offset: None,
len: None,
will_init: None,
},
ValueMeta::Serialized(ser) => proto::ValueMeta {
r#type: proto::ValueMetaType::Serialized.into(),
key: Some(proto::CompactKey::from(ser.key)),
lsn: ser.lsn.0,
batch_offset: Some(ser.batch_offset),
len: Some(ser.len as u64),
will_init: Some(ser.will_init),
},
}
}
}
impl From<CompactKey> for proto::CompactKey {
fn from(value: CompactKey) -> Self {
proto::CompactKey {
high: (value.raw() >> 64) as i64,
low: value.raw() as i64,
}
}
}
impl TryFrom<proto::InterpretedWalRecords> for InterpretedWalRecords {
type Error = TranscodeError;
fn try_from(value: proto::InterpretedWalRecords) -> Result<Self, Self::Error> {
let records = value
.records
.into_iter()
.map(InterpretedWalRecord::try_from)
.collect::<Result<_, _>>()?;
Ok(InterpretedWalRecords {
records,
next_record_lsn: value.next_record_lsn.map(Lsn::from),
})
}
}
impl TryFrom<proto::InterpretedWalRecord> for InterpretedWalRecord {
type Error = TranscodeError;
fn try_from(value: proto::InterpretedWalRecord) -> Result<Self, Self::Error> {
let metadata_record = value
.metadata_record
.map(|mrec| -> Result<_, DeserializeError> { MetadataRecord::des(&mrec) })
.transpose()?;
let batch = {
let batch = value.batch.ok_or_else(|| {
TranscodeError::BadInput("InterpretedWalRecord::batch missing".to_string())
})?;
SerializedValueBatch::try_from(batch)?
};
Ok(InterpretedWalRecord {
metadata_record,
batch,
next_record_lsn: Lsn(value.next_record_lsn),
flush_uncommitted: if value.flush_uncommitted {
FlushUncommittedRecords::Yes
} else {
FlushUncommittedRecords::No
},
xid: value.xid,
})
}
}
impl TryFrom<proto::SerializedValueBatch> for SerializedValueBatch {
type Error = TranscodeError;
fn try_from(value: proto::SerializedValueBatch) -> Result<Self, Self::Error> {
let metadata = value
.metadata
.into_iter()
.map(ValueMeta::try_from)
.collect::<Result<Vec<_>, _>>()?;
Ok(SerializedValueBatch {
raw: value.raw,
metadata,
max_lsn: Lsn(value.max_lsn),
len: value.len as usize,
})
}
}
impl TryFrom<proto::ValueMeta> for ValueMeta {
type Error = TranscodeError;
fn try_from(value: proto::ValueMeta) -> Result<Self, Self::Error> {
match proto::ValueMetaType::try_from(value.r#type) {
Ok(proto::ValueMetaType::Serialized) => {
Ok(ValueMeta::Serialized(SerializedValueMeta {
key: value
.key
.ok_or_else(|| {
TranscodeError::BadInput("ValueMeta::key missing".to_string())
})?
.into(),
lsn: Lsn(value.lsn),
batch_offset: value.batch_offset.ok_or_else(|| {
TranscodeError::BadInput("ValueMeta::batch_offset missing".to_string())
})?,
len: value.len.ok_or_else(|| {
TranscodeError::BadInput("ValueMeta::len missing".to_string())
})? as usize,
will_init: value.will_init.ok_or_else(|| {
TranscodeError::BadInput("ValueMeta::will_init missing".to_string())
})?,
}))
}
Ok(proto::ValueMetaType::Observed) => Ok(ValueMeta::Observed(ObservedValueMeta {
key: value
.key
.ok_or_else(|| TranscodeError::BadInput("ValueMeta::key missing".to_string()))?
.into(),
lsn: Lsn(value.lsn),
})),
Err(_) => Err(TranscodeError::BadInput(format!(
"Unexpected ValueMeta::type {}",
value.r#type
))),
}
}
}
impl From<proto::CompactKey> for CompactKey {
fn from(value: proto::CompactKey) -> Self {
(((value.high as i128) << 64) | (value.low as i128)).into()
}
}

View File

@@ -23,8 +23,8 @@ use tokio_postgres::{replication::ReplicationStream, Client};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, trace, warn, Instrument};
use wal_decoder::{
codec::make_decoder,
models::{FlushUncommittedRecords, InterpretedWalRecord, InterpretedWalRecords},
wire_format::FromWireFormat,
};
use super::TaskStateUpdate;
@@ -264,12 +264,12 @@ pub(super) async fn handle_walreceiver_connection(
let mut walingest = WalIngest::new(timeline.as_ref(), startpoint, &ctx).await?;
let interpreted_proto_config = match protocol {
let interpreted_proto_decoder = match protocol {
PostgresClientProtocol::Vanilla => None,
PostgresClientProtocol::Interpreted {
format,
compression,
} => Some((format, compression)),
} => Some(make_decoder(format, compression)),
};
while let Some(replication_message) = {
@@ -345,14 +345,12 @@ pub(super) async fn handle_walreceiver_connection(
// were interpreted.
let streaming_lsn = Lsn::from(raw.streaming_lsn());
let (format, compression) = interpreted_proto_config.unwrap();
let batch = InterpretedWalRecords::from_wire(raw.data(), format, compression)
.await
.with_context(|| {
anyhow::anyhow!(
let decoder = interpreted_proto_decoder.as_ref().unwrap();
let batch = decoder.decode(raw.data()).await.with_context(|| {
anyhow::anyhow!(
"Failed to deserialize interpreted records ending at LSN {streaming_lsn}"
)
})?;
})?;
let InterpretedWalRecords {
records,

View File

@@ -8,12 +8,9 @@ use postgres_ffi::MAX_SEND_SIZE;
use postgres_ffi::{get_current_timestamp, waldecoder::WalStreamDecoder};
use pq_proto::{BeMessage, InterpretedWalRecordsBody, WalSndKeepAlive};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time::MissedTickBehavior;
use utils::lsn::Lsn;
use utils::postgres_client::Compression;
use utils::postgres_client::InterpretedFormat;
use wal_decoder::codec::Encoder;
use wal_decoder::models::{InterpretedWalRecord, InterpretedWalRecords};
use wal_decoder::wire_format::ToWireFormat;
use crate::send_wal::EndWatchView;
use crate::wal_reader_stream::{WalBytes, WalReaderStreamBuilder};
@@ -22,8 +19,7 @@ use crate::wal_reader_stream::{WalBytes, WalReaderStreamBuilder};
/// This is used for sending WAL to the pageserver. Said WAL
/// is pre-interpreted and filtered for the shard.
pub(crate) struct InterpretedWalSender<'a, IO> {
pub(crate) format: InterpretedFormat,
pub(crate) compression: Option<Compression>,
pub(crate) encoder: Box<dyn Encoder>,
pub(crate) pgb: &'a mut PostgresBackend<IO>,
pub(crate) wal_stream_builder: WalReaderStreamBuilder,
pub(crate) end_watch_view: EndWatchView,
@@ -45,6 +41,8 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> InterpretedWalSender<'_, IO> {
/// Err(CopyStreamHandlerEnd) is always returned; Result is used only for ?
/// convenience.
pub(crate) async fn run(self) -> Result<(), CopyStreamHandlerEnd> {
const KEEPALIVE_AFTER: Duration = Duration::from_secs(1);
let mut wal_position = self.wal_stream_builder.start_pos();
let mut wal_decoder =
WalStreamDecoder::new(self.wal_stream_builder.start_pos(), self.pg_version);
@@ -52,97 +50,122 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> InterpretedWalSender<'_, IO> {
let stream = self.wal_stream_builder.build(MAX_SEND_SIZE).await?;
let mut stream = std::pin::pin!(stream);
let mut keepalive_ticker = tokio::time::interval(Duration::from_secs(1));
keepalive_ticker.set_missed_tick_behavior(MissedTickBehavior::Skip);
keepalive_ticker.reset();
let (tx, mut rx) = tokio::sync::mpsc::channel::<Batch>(2);
loop {
tokio::select! {
// Get some WAL from the stream and then: decode, interpret and push it down the
// pipeline.
wal = stream.next(), if tx.capacity() > 0 => {
let WalBytes { wal, wal_start_lsn: _, wal_end_lsn, available_wal_end_lsn } = match wal {
Some(some) => some?,
None => { break; }
};
let wal_read_future = async move {
loop {
let _guard = tx
.reserve()
.await
.with_context(|| "Failed to reserve channel slot")?;
let wal = stream.next().await;
wal_position = wal_end_lsn;
wal_decoder.feed_bytes(&wal);
let mut records = Vec::new();
let mut max_next_record_lsn = None;
while let Some((next_record_lsn, recdata)) = wal_decoder
.poll_decode()
.with_context(|| "Failed to decode WAL")?
{
assert!(next_record_lsn.is_aligned());
max_next_record_lsn = Some(next_record_lsn);
// Deserialize and interpret WAL record
let interpreted = InterpretedWalRecord::from_bytes_filtered(
recdata,
&self.shard,
next_record_lsn,
self.pg_version,
)
.with_context(|| "Failed to interpret WAL")?;
if !interpreted.is_empty() {
records.push(interpreted);
}
let WalBytes {
wal,
wal_start_lsn: _,
wal_end_lsn,
available_wal_end_lsn,
} = match wal {
Some(some) => some?,
None => {
break;
}
};
let batch = InterpretedWalRecords {
records,
next_record_lsn: max_next_record_lsn
};
wal_position = wal_end_lsn;
wal_decoder.feed_bytes(&wal);
tx.send(Batch {wal_end_lsn, available_wal_end_lsn, records: batch}).await.unwrap();
},
// For a previously interpreted batch, serialize it and push it down the wire.
batch = rx.recv() => {
let batch = match batch {
let mut records = Vec::new();
let mut max_next_record_lsn = None;
while let Some((next_record_lsn, recdata)) = wal_decoder
.poll_decode()
.with_context(|| "Failed to decode WAL")?
{
assert!(next_record_lsn.is_aligned());
max_next_record_lsn = Some(next_record_lsn);
// Deserialize and interpret WAL record
let interpreted = InterpretedWalRecord::from_bytes_filtered(
recdata,
&self.shard,
next_record_lsn,
self.pg_version,
)
.with_context(|| "Failed to interpret WAL")?;
if !interpreted.is_empty() {
records.push(interpreted);
}
}
let batch = InterpretedWalRecords {
records,
next_record_lsn: max_next_record_lsn,
};
tx.send(Batch {
wal_end_lsn,
available_wal_end_lsn,
records: batch,
})
.await
.unwrap();
}
Ok::<_, CopyStreamHandlerEnd>(wal_position)
};
let encode_and_send_future = async move {
loop {
let timeout_or_batch = tokio::time::timeout(KEEPALIVE_AFTER, rx.recv()).await;
let batch = match timeout_or_batch {
Ok(batch) => match batch {
Some(b) => b,
None => { break; }
};
None => {
break;
}
},
Err(_) => {
self.pgb
.write_message(&BeMessage::KeepAlive(WalSndKeepAlive {
wal_end: self.end_watch_view.get().0,
timestamp: get_current_timestamp(),
request_reply: true,
}))
.await?;
let buf = batch
.records
.to_wire(self.format, self.compression)
.await
.with_context(|| "Failed to serialize interpreted WAL")
.map_err(CopyStreamHandlerEnd::from)?;
continue;
}
};
// Reset the keep alive ticker since we are sending something
// over the wire now.
keepalive_ticker.reset();
let buf = self
.encoder
.encode(batch.records)
.await
.with_context(|| "Failed to serialize interpreted WAL")
.map_err(CopyStreamHandlerEnd::from)?;
self.pgb
.write_message(&BeMessage::InterpretedWalRecords(InterpretedWalRecordsBody {
self.pgb
.write_message(&BeMessage::InterpretedWalRecords(
InterpretedWalRecordsBody {
streaming_lsn: batch.wal_end_lsn.0,
commit_lsn: batch.available_wal_end_lsn.0,
data: &buf,
})).await?;
}
// Send a periodic keep alive when the connection has been idle for a while.
_ = keepalive_ticker.tick() => {
self.pgb
.write_message(&BeMessage::KeepAlive(WalSndKeepAlive {
wal_end: self.end_watch_view.get().0,
timestamp: get_current_timestamp(),
request_reply: true,
}))
.await?;
}
},
))
.await?;
}
}
Ok::<_, CopyStreamHandlerEnd>(())
};
let pipeline_ok = tokio::try_join!(wal_read_future, encode_and_send_future)?;
let (final_wal_position, _) = pipeline_ok;
// The loop above ends when the receiver is caught up and there's no more WAL to send.
Err(CopyStreamHandlerEnd::ServerInitiated(format!(
"ending streaming to {:?} at {}, receiver is caughtup and there is no computes",
self.appname, wal_position,
self.appname, final_wal_position,
)))
}
}

View File

@@ -26,6 +26,7 @@ use utils::failpoint_support;
use utils::id::TenantTimelineId;
use utils::pageserver_feedback::PageserverFeedback;
use utils::postgres_client::PostgresClientProtocol;
use wal_decoder::codec::encoder_from_proto;
use std::cmp::{max, min};
use std::net::SocketAddr;
@@ -504,9 +505,9 @@ impl SafekeeperPostgresHandler {
wal_sender_guard: ws_guard.clone(),
};
let encoder = encoder_from_proto(format, compression);
let sender = InterpretedWalSender {
format,
compression,
encoder,
pgb,
wal_stream_builder,
end_watch_view,