From 0a13a06053bc5422b6eb0f8dd807653c69c65521 Mon Sep 17 00:00:00 2001 From: Vlad Lazar Date: Wed, 20 Nov 2024 12:32:22 +0100 Subject: [PATCH] wal_decoder: add compression support --- Cargo.lock | 2 + libs/utils/src/postgres_client.rs | 11 ++- libs/wal_decoder/Cargo.toml | 2 + libs/wal_decoder/src/wire_format.rs | 70 ++++++++++++++++--- .../walreceiver/walreceiver_connection.rs | 18 +++-- safekeeper/src/send_interpreted_wal.rs | 5 +- safekeeper/src/send_wal.rs | 6 +- 7 files changed, 95 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d828761331..53cad80ca8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7120,12 +7120,14 @@ name = "wal_decoder" version = "0.1.0" dependencies = [ "anyhow", + "async-compression", "bytes", "pageserver_api", "postgres_ffi", "prost", "serde", "thiserror", + "tokio", "tonic", "tonic-build", "tracing", diff --git a/libs/utils/src/postgres_client.rs b/libs/utils/src/postgres_client.rs index 3a1a51d876..a62568202b 100644 --- a/libs/utils/src/postgres_client.rs +++ b/libs/utils/src/postgres_client.rs @@ -14,6 +14,12 @@ pub enum InterpretedFormat { Protobuf, } +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum Compression { + Zstd { level: i8 }, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] #[serde(tag = "type", content = "args")] #[serde(rename_all = "kebab-case")] @@ -22,7 +28,10 @@ pub enum PostgresClientProtocol { Vanilla, /// Custom shard-aware protocol that replicates interpreted records. /// Used to send wal from safekeeper to pageserver. - Interpreted { format: InterpretedFormat }, + Interpreted { + format: InterpretedFormat, + compression: Option, + }, } pub struct ConnectionConfigArgs<'a> { diff --git a/libs/wal_decoder/Cargo.toml b/libs/wal_decoder/Cargo.toml index 96e1295434..8fac4e38ca 100644 --- a/libs/wal_decoder/Cargo.toml +++ b/libs/wal_decoder/Cargo.toml @@ -8,6 +8,7 @@ license.workspace = true testing = ["pageserver_api/testing"] [dependencies] +async-compression.workspace = true anyhow.workspace = true bytes.workspace = true pageserver_api.workspace = true @@ -15,6 +16,7 @@ prost.workspace = true postgres_ffi.workspace = true serde.workspace = true thiserror.workspace = true +tokio = { workspace = true, features = ["io-util"] } tonic.workspace = true tracing.workspace = true utils.workspace = true diff --git a/libs/wal_decoder/src/wire_format.rs b/libs/wal_decoder/src/wire_format.rs index 531f492bb5..bfa20a006a 100644 --- a/libs/wal_decoder/src/wire_format.rs +++ b/libs/wal_decoder/src/wire_format.rs @@ -1,9 +1,10 @@ use bytes::{BufMut, Bytes, BytesMut}; use pageserver_api::key::CompactKey; use prost::{DecodeError, EncodeError, Message}; +use tokio::io::AsyncWriteExt; use utils::bin_ser::{BeSer, DeserializeError, SerializeError}; use utils::lsn::Lsn; -use utils::postgres_client::InterpretedFormat; +use utils::postgres_client::{Compression, InterpretedFormat}; use crate::models::{ FlushUncommittedRecords, InterpretedWalRecord, InterpretedWalRecords, MetadataRecord, @@ -26,6 +27,8 @@ pub enum ToWireFormatError { Bincode(#[from] SerializeError), #[error("{0}")] Protobuf(#[from] ProtobufSerializeError), + #[error("{0}")] + Compression(#[from] std::io::Error), } #[derive(Debug, thiserror::Error)] @@ -42,6 +45,8 @@ pub enum FromWireFormatError { Bincode(#[from] DeserializeError), #[error("{0}")] Protobuf(#[from] ProtobufDeserializeError), + #[error("{0}")] + Decompress(#[from] std::io::Error), } #[derive(Debug, thiserror::Error)] @@ -61,17 +66,32 @@ pub enum TranscodeError { } pub trait ToWireFormat { - fn to_wire(self, format: InterpretedFormat) -> Result; + fn to_wire( + self, + format: InterpretedFormat, + compression: Option, + ) -> impl std::future::Future> + Send; } pub trait FromWireFormat { type T; - fn from_wire(buf: &Bytes, format: InterpretedFormat) -> Result; + fn from_wire( + buf: &Bytes, + format: InterpretedFormat, + compression: Option, + ) -> impl std::future::Future> + Send; } impl ToWireFormat for InterpretedWalRecords { - fn to_wire(self, format: InterpretedFormat) -> Result { - match format { + async fn to_wire( + self, + format: InterpretedFormat, + compression: Option, + ) -> Result { + use async_compression::tokio::write::ZstdEncoder; + use async_compression::Level; + + let encode_res: Result = match format { InterpretedFormat::Bincode => { let buf = BytesMut::new(); let mut buf = buf.writer(); @@ -87,20 +107,52 @@ impl ToWireFormat for InterpretedWalRecords { Ok(buf.freeze()) } - } + }; + + let buf = encode_res?; + let compressed_buf = match compression { + Some(Compression::Zstd { level }) => { + let mut encoder = ZstdEncoder::with_quality( + Vec::with_capacity(buf.len() / 4), + Level::Precise(level as i32), + ); + encoder.write_all(&buf).await?; + encoder.shutdown().await?; + Bytes::from(encoder.into_inner()) + } + None => buf, + }; + + Ok(compressed_buf) } } impl FromWireFormat for InterpretedWalRecords { type T = Self; - fn from_wire(buf: &Bytes, format: InterpretedFormat) -> Result { + async fn from_wire( + buf: &Bytes, + format: InterpretedFormat, + compression: Option, + ) -> Result { + let decompressed_buf = match compression { + Some(Compression::Zstd { .. }) => { + use async_compression::tokio::write::ZstdDecoder; + let mut decoded_buf = Vec::with_capacity(buf.len()); + let mut decoder = ZstdDecoder::new(&mut decoded_buf); + decoder.write_all(buf).await?; + decoder.flush().await?; + Bytes::from(decoded_buf) + } + None => buf.clone(), + }; + match format { InterpretedFormat::Bincode => { - InterpretedWalRecords::des(buf).map_err(FromWireFormatError::Bincode) + InterpretedWalRecords::des(&decompressed_buf).map_err(FromWireFormatError::Bincode) } InterpretedFormat::Protobuf => { - let proto = ProtoInterpretedWalRecords::decode(buf.clone()) + let proto = ProtoInterpretedWalRecords::decode(decompressed_buf) .map_err(|e| FromWireFormatError::Protobuf(e.into()))?; InterpretedWalRecords::try_from(proto) .map_err(|e| FromWireFormatError::Protobuf(e.into())) diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs index cd3b7b3418..31cf1b6307 100644 --- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs @@ -264,9 +264,12 @@ pub(super) async fn handle_walreceiver_connection( let mut walingest = WalIngest::new(timeline.as_ref(), startpoint, &ctx).await?; - let interpreted_format = match protocol { + let interpreted_proto_config = match protocol { PostgresClientProtocol::Vanilla => None, - PostgresClientProtocol::Interpreted { format } => Some(format), + PostgresClientProtocol::Interpreted { + format, + compression, + } => Some((format, compression)), }; while let Some(replication_message) = { @@ -342,13 +345,14 @@ pub(super) async fn handle_walreceiver_connection( // were interpreted. let streaming_lsn = Lsn::from(raw.streaming_lsn()); - let batch = - InterpretedWalRecords::from_wire(raw.data(), interpreted_format.unwrap()) - .with_context(|| { - anyhow::anyhow!( + let (format, compression) = interpreted_proto_config.unwrap(); + let batch = InterpretedWalRecords::from_wire(raw.data(), format, compression) + .await + .with_context(|| { + anyhow::anyhow!( "Failed to deserialize interpreted records ending at LSN {streaming_lsn}" ) - })?; + })?; let InterpretedWalRecords { records, diff --git a/safekeeper/src/send_interpreted_wal.rs b/safekeeper/src/send_interpreted_wal.rs index 0fc536e7b2..03874b63f5 100644 --- a/safekeeper/src/send_interpreted_wal.rs +++ b/safekeeper/src/send_interpreted_wal.rs @@ -9,6 +9,7 @@ use postgres_ffi::{get_current_timestamp, waldecoder::WalStreamDecoder}; use pq_proto::{BeMessage, InterpretedWalRecordsBody, WalSndKeepAlive}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::time::MissedTickBehavior; +use utils::postgres_client::Compression; use utils::postgres_client::InterpretedFormat; use wal_decoder::models::{InterpretedWalRecord, InterpretedWalRecords}; use wal_decoder::wire_format::ToWireFormat; @@ -21,6 +22,7 @@ use crate::wal_reader_stream::{WalBytes, WalReaderStreamBuilder}; /// is pre-interpreted and filtered for the shard. pub(crate) struct InterpretedWalSender<'a, IO> { pub(crate) format: InterpretedFormat, + pub(crate) compression: Option, pub(crate) pgb: &'a mut PostgresBackend, pub(crate) wal_stream_builder: WalReaderStreamBuilder, pub(crate) end_watch_view: EndWatchView, @@ -86,7 +88,8 @@ impl InterpretedWalSender<'_, IO> { records, next_record_lsn: max_next_record_lsn }; - let buf = batch.to_wire(self.format).with_context(|| "Failed to serialize interpreted WAL")?; + let buf = batch.to_wire(self.format, self.compression).await + .with_context(|| "Failed to serialize interpreted WAL")?; // Reset the keep alive ticker since we are sending something // over the wire now. diff --git a/safekeeper/src/send_wal.rs b/safekeeper/src/send_wal.rs index 88bb0ee461..225b7f4c05 100644 --- a/safekeeper/src/send_wal.rs +++ b/safekeeper/src/send_wal.rs @@ -489,7 +489,10 @@ impl SafekeeperPostgresHandler { Either::Left(sender.run()) } - PostgresClientProtocol::Interpreted { format } => { + PostgresClientProtocol::Interpreted { + format, + compression, + } => { let pg_version = tli.tli.get_state().await.1.server.pg_version / 10000; let end_watch_view = end_watch.view(); let wal_stream_builder = WalReaderStreamBuilder { @@ -503,6 +506,7 @@ impl SafekeeperPostgresHandler { let sender = InterpretedWalSender { format, + compression, pgb, wal_stream_builder, end_watch_view,