diff --git a/libs/postgres_ffi/src/lib.rs b/libs/postgres_ffi/src/lib.rs index 91f400b74d..9529f45787 100644 --- a/libs/postgres_ffi/src/lib.rs +++ b/libs/postgres_ffi/src/lib.rs @@ -61,7 +61,47 @@ pub fn page_set_lsn(pg: &mut [u8], lsn: Lsn) { /// Calculate page checksum and stamp it onto the page. /// NB: this will zero out and ignore any existing checksum. pub fn page_set_checksum(page: &mut [u8], blkno: u32) { - page[8..10].copy_from_slice(&[0u8; 2]); let checksum = pg_checksum_page(page, blkno); page[8..10].copy_from_slice(&checksum.to_le_bytes()); } + +/// Check if page checksum is valid. +pub fn page_verify_checksum(page: &[u8], blkno: u32) -> bool { + let checksum = pg_checksum_page(page, blkno); + checksum == u16::from_le_bytes(page[8..10].try_into().unwrap()) +} + +#[cfg(test)] +mod tests { + use crate::pg_constants::BLCKSZ; + use crate::{page_set_checksum, page_verify_checksum}; + use utils::pg_checksum_page::pg_checksum_page; + + #[test] + fn set_and_verify_checksum() { + // Create a page with some content and without a correct checksum. + let mut page: [u8; BLCKSZ as usize] = [0; BLCKSZ as usize]; + for (i, byte) in page.iter_mut().enumerate().take(BLCKSZ as usize) { + *byte = i as u8; + } + + // Calculate the checksum. + let checksum = pg_checksum_page(&page[..], 0); + + // Sanity check: random bytes in the checksum attribute should not be + // a valid checksum. + assert_ne!( + checksum, + u16::from_le_bytes(page[8..10].try_into().unwrap()) + ); + + // Set the actual checksum. + page_set_checksum(&mut page, 0); + + // Verify the checksum. + assert!(page_verify_checksum(&page[..], 0)); + + // Checksum is not valid with another block number. + assert!(!page_verify_checksum(&page[..], 1)); + } +} diff --git a/libs/postgres_ffi/src/waldecoder.rs b/libs/postgres_ffi/src/waldecoder.rs index 91542d268f..a3365da089 100644 --- a/libs/postgres_ffi/src/waldecoder.rs +++ b/libs/postgres_ffi/src/waldecoder.rs @@ -14,7 +14,6 @@ use super::XLogLongPageHeaderData; use super::XLogPageHeaderData; use super::XLogRecord; use bytes::{Buf, BufMut, Bytes, BytesMut}; -use crc32c::*; use log::*; use std::cmp::min; use thiserror::Error; @@ -198,18 +197,12 @@ impl WalStreamDecoder { } // We now have a record in the 'recordbuf' local variable. - let xlogrec = - XLogRecord::from_slice(&recordbuf[0..XLOG_SIZE_OF_XLOG_RECORD]).map_err(|e| { - WalDecodeError { - msg: format!("xlog record deserialization failed {}", e), - lsn: self.lsn, - } - })?; + let xlogrec = XLogRecord::from_buf(&recordbuf).map_err(|e| WalDecodeError { + msg: format!("xlog record deserialization failed {}", e), + lsn: self.lsn, + })?; - let mut crc = 0; - crc = crc32c_append(crc, &recordbuf[XLOG_RECORD_CRC_OFFS + 4..]); - crc = crc32c_append(crc, &recordbuf[0..XLOG_RECORD_CRC_OFFS]); - if crc != xlogrec.xl_crc { + if !wal_record_verify_checksum(&xlogrec, &recordbuf) { return Err(WalDecodeError { msg: "WAL record crc mismatch".into(), lsn: self.lsn, diff --git a/libs/postgres_ffi/src/xlog_utils.rs b/libs/postgres_ffi/src/xlog_utils.rs index 67541d844e..d835c9b9e6 100644 --- a/libs/postgres_ffi/src/xlog_utils.rs +++ b/libs/postgres_ffi/src/xlog_utils.rs @@ -477,6 +477,10 @@ impl XLogRecord { XLogRecord::des(buf) } + pub fn from_buf(buf: &[u8]) -> Result { + XLogRecord::from_slice(&buf[0..XLOG_SIZE_OF_XLOG_RECORD]) + } + pub fn from_bytes(buf: &mut B) -> Result { use utils::bin_ser::LeSer; XLogRecord::des_from(&mut buf.reader()) @@ -742,3 +746,11 @@ mod tests { assert_eq!(checkpoint.nextXid.value, 2048); } } + +pub fn wal_record_verify_checksum(rec: &XLogRecord, recordbuf: &Bytes) -> bool { + let mut crc = 0; + crc = crc32c_append(crc, &recordbuf[XLOG_RECORD_CRC_OFFS + 4..]); + crc = crc32c_append(crc, &recordbuf[0..XLOG_RECORD_CRC_OFFS]); + + crc == rec.xl_crc +} diff --git a/libs/postgres_ffi/wal_generate/src/lib.rs b/libs/postgres_ffi/wal_generate/src/lib.rs index 37fd805e27..88dd112f92 100644 --- a/libs/postgres_ffi/wal_generate/src/lib.rs +++ b/libs/postgres_ffi/wal_generate/src/lib.rs @@ -55,8 +55,8 @@ impl Conf { let output = self .new_pg_command("initdb")? .arg("-D") - .arg("--data-checksums") .arg(self.datadir.as_os_str()) + .arg("--data-checksums") .args(&["-U", "postgres", "--no-instructions", "--no-sync"]) .output()?; debug!("initdb output: {:?}", output); diff --git a/libs/utils/src/pg_checksum_page.rs b/libs/utils/src/pg_checksum_page.rs index 5c7cdaf7ac..d1886316b0 100644 --- a/libs/utils/src/pg_checksum_page.rs +++ b/libs/utils/src/pg_checksum_page.rs @@ -35,6 +35,9 @@ fn checksum_comp(checksum: u32, value: u32) -> u32 { * The checksum includes the block number (to detect the case where a page is * somehow moved to a different location), the page header (excluding the * checksum itself), and the page data. + * + * As in C implementation in Postgres, the checksum attribute on the page is + * excluded from the calculation and preserved. */ pub fn pg_checksum_page(data: &[u8], blkno: u32) -> u16 { let page = unsafe { std::mem::transmute::<&[u8], &[u32]>(data) }; @@ -43,8 +46,19 @@ pub fn pg_checksum_page(data: &[u8], blkno: u32) -> u16 { /* main checksum calculation */ for i in 0..(BLCKSZ / (4 * N_SUMS)) { - for j in 0..N_SUMS { - sums[j] = checksum_comp(sums[j], page[i * N_SUMS + j]); + for (j, sum) in sums.iter_mut().enumerate().take(N_SUMS) { + let chunk_i = i * N_SUMS + j; + let chunk: u32; + if chunk_i == 2 { + let mut chunk_copy = page[chunk_i].to_le_bytes(); + chunk_copy[0] = 0; + chunk_copy[1] = 0; + chunk = u32::from_le_bytes(chunk_copy); + } else { + chunk = page[chunk_i]; + } + + *sum = checksum_comp(*sum, chunk); } } /* finally add in two rounds of zeroes for additional mixing */ @@ -68,3 +82,39 @@ pub fn pg_checksum_page(data: &[u8], blkno: u32) -> u16 { */ ((checksum % 65535) + 1) as u16 } + +#[cfg(test)] +mod tests { + use super::{pg_checksum_page, BLCKSZ}; + + #[test] + fn page_with_and_without_checksum() { + // Create a page with some content and without a correct checksum. + let mut page: [u8; BLCKSZ] = [0; BLCKSZ]; + for (i, byte) in page.iter_mut().enumerate().take(BLCKSZ) { + *byte = i as u8; + } + + // Calculate the checksum. + let checksum = pg_checksum_page(&page[..], 0); + + // Zero the checksum attribute on the page. + page[8..10].copy_from_slice(&[0u8; 2]); + + // Calculate the checksum again, should be the same. + let new_checksum = pg_checksum_page(&page[..], 0); + assert_eq!(checksum, new_checksum); + + // Set the correct checksum into the page. + page[8..10].copy_from_slice(&checksum.to_le_bytes()); + + // Calculate the checksum again, should be the same. + let new_checksum = pg_checksum_page(&page[..], 0); + assert_eq!(checksum, new_checksum); + + // Check that we protect from the page transposition, i.e. page is the + // same, but in the wrong place. + let wrong_blockno_checksum = pg_checksum_page(&page[..], 1); + assert_ne!(checksum, wrong_blockno_checksum); + } +} diff --git a/pageserver/src/walredo.rs b/pageserver/src/walredo.rs index 22f280d5bd..05669b851f 100644 --- a/pageserver/src/walredo.rs +++ b/pageserver/src/walredo.rs @@ -48,7 +48,8 @@ use postgres_ffi::nonrelfile_utils::mx_offset_to_flags_bitshift; use postgres_ffi::nonrelfile_utils::mx_offset_to_flags_offset; use postgres_ffi::nonrelfile_utils::mx_offset_to_member_offset; use postgres_ffi::nonrelfile_utils::transaction_id_set_status; -use postgres_ffi::pg_constants; +use postgres_ffi::xlog_utils::wal_record_verify_checksum; +use postgres_ffi::{page_verify_checksum, pg_constants, XLogRecord}; /// /// `RelTag` + block number (`blknum`) gives us a unique id of the page in the cluster. @@ -726,6 +727,13 @@ impl PostgresRedoProcess { let mut writebuf: Vec = Vec::new(); build_begin_redo_for_block_msg(tag, &mut writebuf); if let Some(img) = base_img { + if !page_verify_checksum(&img, tag.blknum) { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("block {} of relation {} is invalid", tag.blknum, tag.rel), + )); + } + build_push_page_msg(tag, &img, &mut writebuf); } for (lsn, rec) in records.iter() { @@ -734,6 +742,25 @@ impl PostgresRedoProcess { rec: postgres_rec, } = rec { + let xlogrec = XLogRecord::from_buf(postgres_rec).map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!( + "could not deserialize WAL record for relation {} at LSN {}: {}", + tag.rel, lsn, e + ), + ) + })?; + if !wal_record_verify_checksum(&xlogrec, postgres_rec) { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!( + "WAL record for relation {} at LSN {} is invalid", + tag.rel, lsn + ), + )); + } + build_apply_record_msg(*lsn, postgres_rec, &mut writebuf); } else { return Err(Error::new(