diff --git a/libs/postgres_ffi/src/waldecoder.rs b/libs/postgres_ffi/src/waldecoder.rs
index d4b7efbac4..cbb761236c 100644
--- a/libs/postgres_ffi/src/waldecoder.rs
+++ b/libs/postgres_ffi/src/waldecoder.rs
@@ -18,19 +18,25 @@ use bytes::{Buf, BufMut, Bytes, BytesMut};
 use crc32c::*;
 use log::*;
 use std::cmp::min;
+use std::num::NonZeroU32;
 use thiserror::Error;
 use utils::lsn::Lsn;
 
+enum State {
+    WaitingForRecord,
+    ReassemblingRecord {
+        recordbuf: BytesMut,
+        contlen: NonZeroU32,
+    },
+    SkippingEverything {
+        skip_until_lsn: Lsn,
+    },
+}
+
 pub struct WalStreamDecoder {
     lsn: Lsn,
-
-    contlen: u32,
-    padlen: u32,
-
     inputbuf: BytesMut,
-
-    /// buffer used to reassemble records that cross page boundaries.
-    recordbuf: BytesMut,
+    state: State,
 }
 
 #[derive(Error, Debug, Clone)]
@@ -48,12 +54,8 @@ impl WalStreamDecoder {
     pub fn new(lsn: Lsn) -> WalStreamDecoder {
         WalStreamDecoder {
             lsn,
-
-            contlen: 0,
-            padlen: 0,
-
             inputbuf: BytesMut::new(),
-            recordbuf: BytesMut::new(),
+            state: State::WaitingForRecord,
         }
     }
 
@@ -80,26 +82,39 @@ impl WalStreamDecoder {
                     hdr.xlp_pageaddr, self.lsn
                 ));
             }
-            if self.contlen == 0 {
-                if hdr.xlp_info & XLP_FIRST_IS_CONTRECORD != 0 {
-                    return Err(
-                        "invalid xlog page header: unexpected XLP_FIRST_IS_CONTRECORD".into(),
-                    );
+            match self.state {
+                State::WaitingForRecord => {
+                    if hdr.xlp_info & XLP_FIRST_IS_CONTRECORD != 0 {
+                        return Err(
+                            "invalid xlog page header: unexpected XLP_FIRST_IS_CONTRECORD".into(),
+                        );
+                    }
+                    if hdr.xlp_rem_len != 0 {
+                        return Err(format!(
+                            "invalid xlog page header: xlp_rem_len={}, but it's not a contrecord",
+                            hdr.xlp_rem_len
+                        ));
+                    }
                 }
-            } else {
-                if hdr.xlp_info & XLP_FIRST_IS_CONTRECORD == 0 {
-                    return Err(
-                        "invalid xlog page header: XLP_FIRST_IS_CONTRECORD expected, not found"
-                            .into(),
-                    );
+                State::ReassemblingRecord { contlen, .. } => {
+                    if hdr.xlp_info & XLP_FIRST_IS_CONTRECORD == 0 {
+                        return Err(
+                            "invalid xlog page header: XLP_FIRST_IS_CONTRECORD expected, not found"
+                                .into(),
+                        );
+                    }
+                    if hdr.xlp_rem_len != contlen.get() {
+                        return Err(format!(
+                            "invalid xlog page header: xlp_rem_len={}, expected {}",
+                            hdr.xlp_rem_len,
+                            contlen.get()
+                        ));
+                    }
                 }
-            }
-            if hdr.xlp_rem_len != self.contlen {
-                return Err(format!(
-                    "invalid xlog page header: xlp_rem_len={}, expected {}",
-                    hdr.xlp_rem_len, self.contlen
-                ));
-            }
+                State::SkippingEverything { .. } => {
+                    panic!("Should not be validating page header in the SkippingEverything state");
+                }
+            };
             Ok(())
         };
         validate_impl().map_err(|msg| WalDecodeError { msg, lsn: self.lsn })
@@ -114,115 +129,121 @@ impl WalStreamDecoder {
     /// Err(WalDecodeError): an error occurred while decoding, meaning the input was invalid.
     ///
     pub fn poll_decode(&mut self) -> Result<Option<(Lsn, Bytes)>, WalDecodeError> {
-        let recordbuf;
-
         // Run state machine that validates page headers, and reassembles records
         // that cross page boundaries.
         loop {
             // parse and verify page boundaries as we go
-            if self.padlen > 0 {
-                // We should first skip padding, as we may have to skip some page headers if we're processing the XLOG_SWITCH record.
-                if self.inputbuf.remaining() < self.padlen as usize {
-                    return Ok(None);
-                }
-
-                // skip padding
-                self.inputbuf.advance(self.padlen as usize);
-                self.lsn += self.padlen as u64;
-                self.padlen = 0;
-            } else if self.lsn.segment_offset(pg_constants::WAL_SEGMENT_SIZE) == 0 {
-                // parse long header
-
-                if self.inputbuf.remaining() < XLOG_SIZE_OF_XLOG_LONG_PHD {
-                    return Ok(None);
-                }
-
-                let hdr = XLogLongPageHeaderData::from_bytes(&mut self.inputbuf).map_err(|e| {
-                    WalDecodeError {
-                        msg: format!("long header deserialization failed {}", e),
-                        lsn: self.lsn,
-                    }
-                })?;
-
-                self.validate_page_header(&hdr.std)?;
-
-                self.lsn += XLOG_SIZE_OF_XLOG_LONG_PHD as u64;
-                continue;
-            } else if self.lsn.block_offset() == 0 {
-                if self.inputbuf.remaining() < XLOG_SIZE_OF_XLOG_SHORT_PHD {
-                    return Ok(None);
-                }
-
-                let hdr = XLogPageHeaderData::from_bytes(&mut self.inputbuf).map_err(|e| {
-                    WalDecodeError {
-                        msg: format!("header deserialization failed {}", e),
-                        lsn: self.lsn,
-                    }
-                })?;
-                self.validate_page_header(&hdr)?;
-
-                self.lsn += XLOG_SIZE_OF_XLOG_SHORT_PHD as u64;
-                continue;
-            } else if self.contlen == 0 {
-                assert!(self.recordbuf.is_empty());
-
-                // need to have at least the xl_tot_len field
-                if self.inputbuf.remaining() < 4 {
-                    return Ok(None);
-                }
-
-                // peek xl_tot_len at the beginning of the record.
-                // FIXME: assumes little-endian
-                let xl_tot_len = (&self.inputbuf[0..4]).get_u32_le();
-                if (xl_tot_len as usize) < XLOG_SIZE_OF_XLOG_RECORD {
-                    return Err(WalDecodeError {
-                        msg: format!("invalid xl_tot_len {}", xl_tot_len),
-                        lsn: self.lsn,
-                    });
-                }
-
-                // Fast path for the common case that the whole record fits on the page.
-                let pageleft = self.lsn.remaining_in_block() as u32;
-                if self.inputbuf.remaining() >= xl_tot_len as usize && xl_tot_len <= pageleft {
-                    // Take the record from the 'inputbuf', and validate it.
-                    recordbuf = self.inputbuf.copy_to_bytes(xl_tot_len as usize);
-                    self.lsn += xl_tot_len as u64;
-                    break;
-                } else {
-                    // Need to assemble the record from pieces. Remember the size of the
-                    // record, and loop back. On next iteration, we will reach the 'else'
-                    // branch below, and copy the part of the record that was on this page
-                    // to 'recordbuf'. Subsequent iterations will skip page headers, and
-                    // append the continuations from the next pages to 'recordbuf'.
-                    self.recordbuf.reserve(xl_tot_len as usize);
-                    self.contlen = xl_tot_len;
-                    continue;
-                }
-            } else {
-                // we're continuing a record, possibly from previous page.
-                let pageleft = self.lsn.remaining_in_block() as u32;
-
-                // read the rest of the record, or as much as fits on this page.
-                let n = min(self.contlen, pageleft) as usize;
-
-                if self.inputbuf.remaining() < n {
-                    return Ok(None);
-                }
-
-                self.recordbuf.put(self.inputbuf.split_to(n));
-                self.lsn += n as u64;
-                self.contlen -= n as u32;
-
-                if self.contlen == 0 {
-                    // The record is now complete.
-                    recordbuf = std::mem::replace(&mut self.recordbuf, BytesMut::new()).freeze();
-                    break;
-                }
-                continue;
+            // However, we may have to skip some page headers if we're processing the XLOG_SWITCH record or skipping padding for whatever reason.
+            match self.state {
+                State::WaitingForRecord | State::ReassemblingRecord { .. } => {
+                    if self.lsn.segment_offset(pg_constants::WAL_SEGMENT_SIZE) == 0 {
+                        // parse long header
+
+                        if self.inputbuf.remaining() < XLOG_SIZE_OF_XLOG_LONG_PHD {
+                            return Ok(None);
+                        }
+
+                        let hdr = XLogLongPageHeaderData::from_bytes(&mut self.inputbuf).map_err(
+                            |e| WalDecodeError {
+                                msg: format!("long header deserialization failed {}", e),
+                                lsn: self.lsn,
+                            },
+                        )?;
+
+                        self.validate_page_header(&hdr.std)?;
+
+                        self.lsn += XLOG_SIZE_OF_XLOG_LONG_PHD as u64;
+                    } else if self.lsn.block_offset() == 0 {
+                        if self.inputbuf.remaining() < XLOG_SIZE_OF_XLOG_SHORT_PHD {
+                            return Ok(None);
+                        }
+
+                        let hdr =
+                            XLogPageHeaderData::from_bytes(&mut self.inputbuf).map_err(|e| {
+                                WalDecodeError {
+                                    msg: format!("header deserialization failed {}", e),
+                                    lsn: self.lsn,
+                                }
+                            })?;
+
+                        self.validate_page_header(&hdr)?;
+
+                        self.lsn += XLOG_SIZE_OF_XLOG_SHORT_PHD as u64;
+                    }
+                }
+                State::SkippingEverything { .. } => {}
+            }
+            match &mut self.state {
+                State::WaitingForRecord => {
+                    // need to have at least the xl_tot_len field
+                    if self.inputbuf.remaining() < 4 {
+                        return Ok(None);
+                    }
+
+                    // peek xl_tot_len at the beginning of the record.
+                    // FIXME: assumes little-endian
+                    let xl_tot_len = (&self.inputbuf[0..4]).get_u32_le();
+                    if (xl_tot_len as usize) < XLOG_SIZE_OF_XLOG_RECORD {
+                        return Err(WalDecodeError {
+                            msg: format!("invalid xl_tot_len {}", xl_tot_len),
+                            lsn: self.lsn,
+                        });
+                    }
+                    // Fast path for the common case that the whole record fits on the page.
+                    let pageleft = self.lsn.remaining_in_block() as u32;
+                    if self.inputbuf.remaining() >= xl_tot_len as usize && xl_tot_len <= pageleft {
+                        self.lsn += xl_tot_len as u64;
+                        let recordbuf = self.inputbuf.copy_to_bytes(xl_tot_len as usize);
+                        return Ok(Some(self.complete_record(recordbuf)?));
+                    } else {
+                        // Need to assemble the record from pieces. Remember the size of the
+                        // record, and loop back. On next iteration, we will reach the 'else'
+                        // branch below, and copy the part of the record that was on this page
+                        // to 'recordbuf'. Subsequent iterations will skip page headers, and
+                        // append the continuations from the next pages to 'recordbuf'.
+                        self.state = State::ReassemblingRecord {
+                            recordbuf: BytesMut::with_capacity(xl_tot_len as usize),
+                            contlen: NonZeroU32::new(xl_tot_len).unwrap(),
+                        }
+                    }
+                }
+                State::ReassemblingRecord { recordbuf, contlen } => {
+                    // we're continuing a record, possibly from previous page.
+                    let pageleft = self.lsn.remaining_in_block() as u32;
+
+                    // read the rest of the record, or as much as fits on this page.
+                    let n = min(contlen.get(), pageleft) as usize;
+
+                    if self.inputbuf.remaining() < n {
+                        return Ok(None);
+                    }
+
+                    recordbuf.put(self.inputbuf.split_to(n));
+                    self.lsn += n as u64;
+                    *contlen = match NonZeroU32::new(contlen.get() - n as u32) {
+                        Some(x) => x,
+                        None => {
+                            // The record is now complete.
+                            let recordbuf = std::mem::replace(recordbuf, BytesMut::new()).freeze();
+                            return Ok(Some(self.complete_record(recordbuf)?));
+                        }
+                    }
+                }
+                State::SkippingEverything { skip_until_lsn } => {
+                    assert!(*skip_until_lsn >= self.lsn);
+                    let n = skip_until_lsn.0 - self.lsn.0;
+                    if self.inputbuf.remaining() < n as usize {
+                        return Ok(None);
+                    }
+                    self.inputbuf.advance(n as usize);
+                    self.lsn += n;
+                    self.state = State::WaitingForRecord;
+                }
             }
         }
+    }
 
+    fn complete_record(&mut self, recordbuf: Bytes) -> Result<(Lsn, Bytes), WalDecodeError> {
         // We now have a record in the 'recordbuf' local variable.
         let xlogrec =
             XLogRecord::from_slice(&recordbuf[0..XLOG_SIZE_OF_XLOG_RECORD]).map_err(|e| {
@@ -244,18 +265,20 @@ impl WalStreamDecoder {
 
         // XLOG_SWITCH records are special. If we see one, we need to skip
         // to the next WAL segment.
-        if xlogrec.is_xlog_switch_record() {
+        let next_lsn = if xlogrec.is_xlog_switch_record() {
             trace!("saw xlog switch record at {}", self.lsn);
-            self.padlen = self.lsn.calc_padding(pg_constants::WAL_SEGMENT_SIZE as u64) as u32;
+            self.lsn + self.lsn.calc_padding(pg_constants::WAL_SEGMENT_SIZE as u64)
         } else {
             // Pad to an 8-byte boundary
-            self.padlen = self.lsn.calc_padding(8u32) as u32;
-        }
+            self.lsn.align()
+        };
+        self.state = State::SkippingEverything {
+            skip_until_lsn: next_lsn,
+        };
 
         // We should return LSN of the next record, not the last byte of this record or
         // the byte immediately after. Note that this handles both XLOG_SWITCH and usual
         // records, the former "spans" until the next WAL segment (see test_xlog_switch).
-        let result = (self.lsn + self.padlen as u64, recordbuf);
-        Ok(Some(result))
+        Ok((next_lsn, recordbuf))
     }
 }
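
After this change, a caller drives the decoder exactly as before: feed input, then poll until Ok(None). A minimal sketch, not part of the patch, assuming the decoder's existing input-side method (called feed_bytes here, which this diff does not show) and the crate paths implied by the file's imports:

    use bytes::Bytes;
    use postgres_ffi::waldecoder::{WalDecodeError, WalStreamDecoder};
    use utils::lsn::Lsn;

    /// Pull every record that is currently decodable out of the decoder.
    fn drain(decoder: &mut WalStreamDecoder) -> Result<Vec<(Lsn, Bytes)>, WalDecodeError> {
        let mut records = Vec::new();
        // poll_decode returns Ok(None) once it runs out of input; the State enum
        // remembers exactly where it stopped (mid-record, mid-skip), so feeding
        // more bytes (e.g. decoder.feed_bytes(&chunk)) and polling again resumes
        // where it left off.
        while let Some((next_lsn, rec)) = decoder.poll_decode()? {
            // next_lsn is the LSN of the *next* record: alignment padding and
            // XLOG_SWITCH segment skips are already accounted for.
            records.push((next_lsn, rec));
        }
        Ok(records)
    }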
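A detail worth calling out: ReassemblingRecord stores the remaining byte count as NonZeroU32 rather than u32, so the old convention of "contlen == 0 means not reassembling" can no longer be silently violated; the decrement itself proves when the record is complete. A self-contained model of that arm's bookkeeping, using only the standard library:

    use std::num::NonZeroU32;

    // Consuming n bytes either leaves a provably nonzero remainder (stay in
    // ReassemblingRecord) or yields None, which is the only way to reach the
    // completion path -- a "zero bytes left but still reassembling" state is
    // unrepresentable.
    fn after_consuming(contlen: NonZeroU32, n: u32) -> Option<NonZeroU32> {
        NonZeroU32::new(contlen.get() - n)
    }

    fn main() {
        let total = NonZeroU32::new(100).unwrap();
        assert_eq!(after_consuming(total, 60).map(NonZeroU32::get), Some(40));
        assert_eq!(after_consuming(total, 100), None); // record complete
        println!("ok");
    }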
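Likewise, SkippingEverything replaces the old relative padlen counter with an absolute target LSN computed once in complete_record. A sketch of that computation with LSNs modelled as plain u64; the rounding behaviour of the real Lsn::calc_padding and Lsn::align is an assumption based on their names and this usage:

    // Default PostgreSQL WAL segment size, matching pg_constants::WAL_SEGMENT_SIZE.
    const WAL_SEGMENT_SIZE: u64 = 16 * 1024 * 1024;

    // Bytes needed to reach the next multiple of `align` (a power of two),
    // 0 if already aligned; assumed equivalent to Lsn::calc_padding.
    fn calc_padding(lsn: u64, align: u64) -> u64 {
        lsn.wrapping_neg() & (align - 1)
    }

    fn skip_until_lsn(lsn: u64, is_xlog_switch: bool) -> u64 {
        if is_xlog_switch {
            // The rest of the segment after an XLOG_SWITCH record is dead space.
            lsn + calc_padding(lsn, WAL_SEGMENT_SIZE)
        } else {
            // Ordinary records are padded to the next 8-byte boundary (Lsn::align).
            lsn + calc_padding(lsn, 8)
        }
    }

    fn main() {
        assert_eq!(skip_until_lsn(0x0100_0005, false), 0x0100_0008);
        assert_eq!(skip_until_lsn(0x0100_0005, true), 0x0200_0000);
        println!("ok");
    }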