diff --git a/Cargo.lock b/Cargo.lock index c1a14210de..56986cbb5e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7124,6 +7124,7 @@ dependencies = [ "pageserver_api", "postgres_ffi", "serde", + "thiserror", "tracing", "utils", "workspace_hack", diff --git a/libs/utils/src/postgres_client.rs b/libs/utils/src/postgres_client.rs index 3073bbde4c..3ed08a40ee 100644 --- a/libs/utils/src/postgres_client.rs +++ b/libs/utils/src/postgres_client.rs @@ -7,40 +7,21 @@ use postgres_connection::{parse_host_port, PgConnectionConfig}; use crate::id::TenantTimelineId; -/// Postgres client protocol types -#[derive( - Copy, - Clone, - PartialEq, - Eq, - strum_macros::EnumString, - strum_macros::Display, - serde_with::DeserializeFromStr, - serde_with::SerializeDisplay, - Debug, -)] -#[strum(serialize_all = "kebab-case")] -#[repr(u8)] +#[derive(Copy, Clone, PartialEq, Eq, Debug, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum InterpretedFormat { + Bincode, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[serde(tag = "type", content = "args")] +#[serde(rename_all = "kebab-case")] pub enum PostgresClientProtocol { /// Usual Postgres replication protocol Vanilla, /// Custom shard-aware protocol that replicates interpreted records. /// Used to send wal from safekeeper to pageserver. - Interpreted, -} - -impl TryFrom for PostgresClientProtocol { - type Error = u8; - - fn try_from(value: u8) -> Result { - Ok(match value { - v if v == (PostgresClientProtocol::Vanilla as u8) => PostgresClientProtocol::Vanilla, - v if v == (PostgresClientProtocol::Interpreted as u8) => { - PostgresClientProtocol::Interpreted - } - x => return Err(x), - }) - } + Interpreted { format: InterpretedFormat }, } pub struct ConnectionConfigArgs<'a> { @@ -63,7 +44,10 @@ impl<'a> ConnectionConfigArgs<'a> { "-c".to_owned(), format!("timeline_id={}", self.ttid.timeline_id), format!("tenant_id={}", self.ttid.tenant_id), - format!("protocol={}", self.protocol as u8), + format!( + "protocol={}", + serde_json::to_string(&self.protocol).unwrap() + ), ]; if self.shard_number.is_some() { diff --git a/libs/wal_decoder/Cargo.toml b/libs/wal_decoder/Cargo.toml index c8c0f4c990..02b9f72ed5 100644 --- a/libs/wal_decoder/Cargo.toml +++ b/libs/wal_decoder/Cargo.toml @@ -13,6 +13,7 @@ bytes.workspace = true pageserver_api.workspace = true postgres_ffi.workspace = true serde.workspace = true +thiserror.workspace = true tracing.workspace = true utils.workspace = true workspace_hack = { version = "0.1", path = "../../workspace_hack" } diff --git a/libs/wal_decoder/src/lib.rs b/libs/wal_decoder/src/lib.rs index a8a26956e6..96b717021f 100644 --- a/libs/wal_decoder/src/lib.rs +++ b/libs/wal_decoder/src/lib.rs @@ -1,3 +1,4 @@ pub mod decoder; pub mod models; pub mod serialized_batch; +pub mod wire_format; diff --git a/libs/wal_decoder/src/wire_format.rs b/libs/wal_decoder/src/wire_format.rs new file mode 100644 index 0000000000..ee857f043f --- /dev/null +++ b/libs/wal_decoder/src/wire_format.rs @@ -0,0 +1,52 @@ +use bytes::{BufMut, Bytes, BytesMut}; +use utils::bin_ser::{BeSer, DeserializeError, SerializeError}; +use utils::postgres_client::InterpretedFormat; + +use crate::models::InterpretedWalRecord; + +#[derive(Debug, thiserror::Error)] +pub enum ToWireFormatError { + #[error("{0}")] + Bincode(SerializeError), +} + +#[derive(Debug, thiserror::Error)] +pub enum FromWireFormatError { + #[error("{0}")] + Bincode(DeserializeError), +} + +pub trait ToWireFormat { + fn to_wire(self, format: InterpretedFormat) -> Result; +} + +pub trait FromWireFormat { + type T; + fn from_wire(buf: &Bytes, format: InterpretedFormat) -> Result; +} + +impl ToWireFormat for Vec { + fn to_wire(self, format: InterpretedFormat) -> Result { + match format { + InterpretedFormat::Bincode => { + let buf = BytesMut::new(); + let mut buf = buf.writer(); + self.ser_into(&mut buf) + .map_err(ToWireFormatError::Bincode)?; + Ok(buf.into_inner().freeze()) + } + } + } +} + +impl FromWireFormat for Vec { + type T = Self; + + fn from_wire(buf: &Bytes, format: InterpretedFormat) -> Result { + match format { + InterpretedFormat::Bincode => { + Vec::::des(buf).map_err(FromWireFormatError::Bincode) + } + } + } +} diff --git a/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs b/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs index 7a64703a30..583d6309ab 100644 --- a/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs +++ b/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs @@ -535,6 +535,7 @@ impl ConnectionManagerState { let node_id = new_sk.safekeeper_id; let connect_timeout = self.conf.wal_connect_timeout; let ingest_batch_size = self.conf.ingest_batch_size; + let protocol = self.conf.protocol; let timeline = Arc::clone(&self.timeline); let ctx = ctx.detached_child( TaskKind::WalReceiverConnectionHandler, @@ -548,6 +549,7 @@ impl ConnectionManagerState { let res = super::walreceiver_connection::handle_walreceiver_connection( timeline, + protocol, new_sk.wal_source_connconf, events_sender, cancellation.clone(), @@ -991,7 +993,7 @@ impl ConnectionManagerState { PostgresClientProtocol::Vanilla => { (None, None, None) }, - PostgresClientProtocol::Interpreted => { + PostgresClientProtocol::Interpreted { .. } => { let shard_identity = self.timeline.get_shard_identity(); ( Some(shard_identity.number.0), diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs index 1a0e66ceb3..4f17f07a86 100644 --- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs @@ -22,7 +22,10 @@ use tokio::{select, sync::watch, time}; use tokio_postgres::{replication::ReplicationStream, Client}; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, trace, warn, Instrument}; -use wal_decoder::models::{FlushUncommittedRecords, InterpretedWalRecord}; +use wal_decoder::{ + models::{FlushUncommittedRecords, InterpretedWalRecord}, + wire_format::FromWireFormat, +}; use super::TaskStateUpdate; use crate::{ @@ -36,7 +39,7 @@ use crate::{ use postgres_backend::is_expected_io_error; use postgres_connection::PgConnectionConfig; use postgres_ffi::waldecoder::WalStreamDecoder; -use utils::{bin_ser::BeSer, id::NodeId, lsn::Lsn}; +use utils::{id::NodeId, lsn::Lsn, postgres_client::PostgresClientProtocol}; use utils::{pageserver_feedback::PageserverFeedback, sync::gate::GateError}; /// Status of the connection. @@ -109,6 +112,7 @@ impl From for WalReceiverError { #[allow(clippy::too_many_arguments)] pub(super) async fn handle_walreceiver_connection( timeline: Arc, + protocol: PostgresClientProtocol, wal_source_connconf: PgConnectionConfig, events_sender: watch::Sender>, cancellation: CancellationToken, @@ -260,6 +264,11 @@ pub(super) async fn handle_walreceiver_connection( let mut walingest = WalIngest::new(timeline.as_ref(), startpoint, &ctx).await?; + let interpreted_format = match protocol { + PostgresClientProtocol::Vanilla => None, + PostgresClientProtocol::Interpreted { format } => Some(format), + }; + while let Some(replication_message) = { select! { _ = cancellation.cancelled() => { @@ -337,11 +346,13 @@ pub(super) async fn handle_walreceiver_connection( Lsn(raw.next_record_lsn().unwrap_or(0)) ); - let records = Vec::::des(raw.data()).with_context(|| { - anyhow::anyhow!( + let records = + Vec::::from_wire(raw.data(), interpreted_format.unwrap()) + .with_context(|| { + anyhow::anyhow!( "Failed to deserialize interpreted records ending at LSN {streaming_lsn}" ) - })?; + })?; // We start the modification at 0 because each interpreted record // advances it to its end LSN. 0 is just an initialization placeholder. diff --git a/safekeeper/src/handler.rs b/safekeeper/src/handler.rs index cec7c3c7ee..22f33b17e0 100644 --- a/safekeeper/src/handler.rs +++ b/safekeeper/src/handler.rs @@ -123,17 +123,10 @@ impl postgres_backend::Handler // https://github.com/neondatabase/neon/pull/2433#discussion_r970005064 match opt.split_once('=') { Some(("protocol", value)) => { - let raw_value = value - .parse::() - .with_context(|| format!("Failed to parse {value} as protocol"))?; - - self.protocol = Some( - PostgresClientProtocol::try_from(raw_value).map_err(|_| { - QueryError::Other(anyhow::anyhow!( - "Unexpected client protocol type: {raw_value}" - )) - })?, - ); + self.protocol = + Some(serde_json::from_str(value).with_context(|| { + format!("Failed to parse {value} as protocol") + })?); } Some(("ztenantid", value)) | Some(("tenant_id", value)) => { self.tenant_id = Some(value.parse().with_context(|| { @@ -180,7 +173,7 @@ impl postgres_backend::Handler ))); } } - PostgresClientProtocol::Interpreted => { + PostgresClientProtocol::Interpreted { .. } => { match (shard_count, shard_number, shard_stripe_size) { (Some(count), Some(number), Some(stripe_size)) => { let params = ShardParameters { diff --git a/safekeeper/src/send_interpreted_wal.rs b/safekeeper/src/send_interpreted_wal.rs index cf0ee276e9..fc318fdd4b 100644 --- a/safekeeper/src/send_interpreted_wal.rs +++ b/safekeeper/src/send_interpreted_wal.rs @@ -9,9 +9,10 @@ use postgres_ffi::{get_current_timestamp, waldecoder::WalStreamDecoder}; use pq_proto::{BeMessage, InterpretedWalRecordsBody, WalSndKeepAlive}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::time::MissedTickBehavior; -use utils::bin_ser::BeSer; use utils::lsn::Lsn; +use utils::postgres_client::InterpretedFormat; use wal_decoder::models::InterpretedWalRecord; +use wal_decoder::wire_format::ToWireFormat; use crate::send_wal::EndWatchView; use crate::wal_reader_stream::{WalBytes, WalReaderStreamBuilder}; @@ -20,6 +21,7 @@ use crate::wal_reader_stream::{WalBytes, WalReaderStreamBuilder}; /// This is used for sending WAL to the pageserver. Said WAL /// is pre-interpreted and filtered for the shard. pub(crate) struct InterpretedWalSender<'a, IO> { + pub(crate) format: InterpretedFormat, pub(crate) pgb: &'a mut PostgresBackend, pub(crate) wal_stream_builder: WalReaderStreamBuilder, pub(crate) end_watch_view: EndWatchView, @@ -81,10 +83,7 @@ impl InterpretedWalSender<'_, IO> { } } - let mut buf = Vec::new(); - records - .ser_into(&mut buf) - .with_context(|| "Failed to serialize interpreted WAL")?; + let buf = records.to_wire(self.format).with_context(|| "Failed to serialize interpreted WAL")?; // Reset the keep alive ticker since we are sending something // over the wire now. @@ -95,7 +94,7 @@ impl InterpretedWalSender<'_, IO> { streaming_lsn: wal_end_lsn.0, commit_lsn: available_wal_end_lsn.0, next_record_lsn: max_next_record_lsn.unwrap_or(Lsn::INVALID).0, - data: buf.as_slice(), + data: &buf, })).await?; } diff --git a/safekeeper/src/send_wal.rs b/safekeeper/src/send_wal.rs index 1acfcad418..88bb0ee461 100644 --- a/safekeeper/src/send_wal.rs +++ b/safekeeper/src/send_wal.rs @@ -454,7 +454,7 @@ impl SafekeeperPostgresHandler { } info!( - "starting streaming from {:?}, available WAL ends at {}, recovery={}, appname={:?}, protocol={}", + "starting streaming from {:?}, available WAL ends at {}, recovery={}, appname={:?}, protocol={:?}", start_pos, end_pos, matches!(end_watch, EndWatch::Flush(_)), @@ -489,7 +489,7 @@ impl SafekeeperPostgresHandler { Either::Left(sender.run()) } - PostgresClientProtocol::Interpreted => { + PostgresClientProtocol::Interpreted { format } => { let pg_version = tli.tli.get_state().await.1.server.pg_version / 10000; let end_watch_view = end_watch.view(); let wal_stream_builder = WalReaderStreamBuilder { @@ -502,6 +502,7 @@ impl SafekeeperPostgresHandler { }; let sender = InterpretedWalSender { + format, pgb, wal_stream_builder, end_watch_view, diff --git a/test_runner/performance/test_sharded_ingest.py b/test_runner/performance/test_sharded_ingest.py index e965aae5a0..afb1ef45ef 100644 --- a/test_runner/performance/test_sharded_ingest.py +++ b/test_runner/performance/test_sharded_ingest.py @@ -27,14 +27,29 @@ def test_sharded_ingest( and fanning out to a large number of shards on dedicated Pageservers. Comparing the base case (shard_count=1) to the sharded case indicates the overhead of sharding. """ - neon_env_builder.pageserver_config_override = ( - f"wal_receiver_protocol = '{wal_receiver_protocol}'" - ) - ROW_COUNT = 100_000_000 # about 7 GB of WAL neon_env_builder.num_pageservers = shard_count - env = neon_env_builder.init_start() + env = neon_env_builder.init_configs() + + for ps in env.pageservers: + if wal_receiver_protocol == "vanilla": + ps.patch_config_toml_nonrecursive({ + "wal_receiver_protocol": { + "type": "vanilla", + } + }) + elif wal_receiver_protocol == "interpreted": + ps.patch_config_toml_nonrecursive({ + "wal_receiver_protocol": { + "type": "interpreted", + "args": { + "format": "bincode" + } + } + }) + + env.start() # Create a sharded tenant and timeline, and migrate it to the respective pageservers. Ensure # the storage controller doesn't mess with shard placements.