diff --git a/Cargo.lock b/Cargo.lock index 9d2e335bd2..b5bac69326 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1408,6 +1408,7 @@ dependencies = [ "tokio-stream", "tui", "walkdir", + "zenith_utils", ] [[package]] @@ -2755,3 +2756,7 @@ dependencies = [ [[package]] name = "zenith_utils" version = "0.1.0" +dependencies = [ + "thiserror", + "tokio", +] diff --git a/control_plane/src/compute.rs b/control_plane/src/compute.rs index d5fecba6df..9807756232 100644 --- a/control_plane/src/compute.rs +++ b/control_plane/src/compute.rs @@ -3,23 +3,22 @@ use std::io::{Read, Write}; use std::net::SocketAddr; use std::net::TcpStream; use std::os::unix::fs::PermissionsExt; +use std::path::Path; use std::process::Command; use std::sync::Arc; use std::time::Duration; use std::{collections::BTreeMap, path::PathBuf}; -use std::path::Path; use anyhow::{Context, Result}; use lazy_static::lazy_static; use regex::Regex; -use tar; use postgres::{Client, NoTls}; use crate::local_env::LocalEnv; use crate::storage::{PageServerNode, WalProposerNode}; -use pageserver::ZTimelineId; use pageserver::zenith_repo_dir; +use pageserver::ZTimelineId; // // ComputeControlPlane @@ -192,11 +191,11 @@ impl PostgresNode { ); let port: u16 = CONF_PORT_RE .captures(config.as_str()) - .ok_or(anyhow::Error::msg(err_msg.clone() + " 1"))? + .ok_or_else(|| anyhow::Error::msg(err_msg.clone() + " 1"))? .iter() .last() - .ok_or(anyhow::Error::msg(err_msg.clone() + " 2"))? - .ok_or(anyhow::Error::msg(err_msg.clone() + " 3"))? + .ok_or_else(|| anyhow::Error::msg(err_msg.clone() + " 2"))? + .ok_or_else(|| anyhow::Error::msg(err_msg.clone() + " 3"))? .as_str() .parse() .with_context(|| err_msg)?; @@ -294,7 +293,7 @@ impl PostgresNode { // slot or something proper, to prevent the compute node // from removing WAL that hasn't been streamed to the safekeepr or // page server yet. But this will do for now. 
- self.append_conf("postgresql.conf", &format!("wal_keep_size='10TB'\n")); + self.append_conf("postgresql.conf", "wal_keep_size='10TB'\n"); // Connect it to the page server. @@ -447,10 +446,9 @@ impl PostgresNode { } } - pub fn pg_regress(&self) { self.safe_psql("postgres", "CREATE DATABASE regression"); - let data_dir = zenith_repo_dir(); + let data_dir = zenith_repo_dir(); let regress_run_path = data_dir.join("regress"); fs::create_dir_all(regress_run_path.clone()).unwrap(); fs::create_dir_all(regress_run_path.join("testtablespace")).unwrap(); diff --git a/control_plane/src/local_env.rs b/control_plane/src/local_env.rs index db71721e21..ce3badf857 100644 --- a/control_plane/src/local_env.rs +++ b/control_plane/src/local_env.rs @@ -15,8 +15,8 @@ use std::process::{Command, Stdio}; use anyhow::Result; use serde_derive::{Deserialize, Serialize}; -use pageserver::ZTimelineId; use pageserver::zenith_repo_dir; +use pageserver::ZTimelineId; use walkeeper::xlog_utils; // @@ -101,7 +101,7 @@ pub fn init() -> Result<()> { // ok, we are good to go let mut conf = LocalEnv { - repo_path: repo_path.clone(), + repo_path, pg_distrib_dir, zenith_distrib_dir, systemid: 0, @@ -247,7 +247,7 @@ pub fn test_env(testname: &str) -> LocalEnv { systemid: 0, }; init_repo(&mut local_env).expect("could not initialize zenith repository"); - return local_env; + local_env } // Find the directory where the binaries were put (i.e. target/debug/) @@ -259,7 +259,7 @@ pub fn cargo_bin_dir() -> PathBuf { pathbuf.pop(); } - return pathbuf; + pathbuf } #[derive(Debug, Clone, Copy)] @@ -351,7 +351,7 @@ pub fn find_end_of_wal(local_env: &LocalEnv, timeline: ZTimelineId) -> Result Result<(u32, u32, u32), FilePathError> { u32::from_str_radix(segno_match.unwrap().as_str(), 10)? 
}; - return Ok((relnode, forknum, segno)); + Ok((relnode, forknum, segno)) } fn parse_rel_file_path(path: &str) -> Result<(), FilePathError> { @@ -172,9 +172,9 @@ fn parse_rel_file_path(path: &str) -> Result<(), FilePathError> { if let Some(fname) = path.strip_prefix("global/") { let (_relnode, _forknum, _segno) = parse_filename(fname)?; - return Ok(()); + Ok(()) } else if let Some(dbpath) = path.strip_prefix("base/") { - let mut s = dbpath.split("/"); + let mut s = dbpath.split('/'); let dbnode_str = s .next() .ok_or_else(|| FilePathError::new("invalid relation data file name"))?; @@ -188,15 +188,15 @@ fn parse_rel_file_path(path: &str) -> Result<(), FilePathError> { let (_relnode, _forknum, _segno) = parse_filename(fname)?; - return Ok(()); + Ok(()) } else if let Some(_) = path.strip_prefix("pg_tblspc/") { // TODO - return Err(FilePathError::new("tablespaces not supported")); + Err(FilePathError::new("tablespaces not supported")) } else { - return Err(FilePathError::new("invalid relation data file name")); + Err(FilePathError::new("invalid relation data file name")) } } fn is_rel_file_path(path: &str) -> bool { - return parse_rel_file_path(path).is_ok(); + parse_rel_file_path(path).is_ok() } diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 0a966a81a6..8801e5de14 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -3,6 +3,7 @@ // use log::*; +use parse_duration::parse; use std::fs; use std::fs::OpenOptions; use std::io; @@ -10,7 +11,6 @@ use std::path::PathBuf; use std::process::exit; use std::thread; use std::time::Duration; -use parse_duration::parse; use anyhow::{Context, Result}; use clap::{App, Arg}; @@ -19,12 +19,12 @@ use daemonize::Daemonize; use slog::Drain; use pageserver::page_service; -use pageserver::zenith_repo_dir; use pageserver::tui; +use pageserver::zenith_repo_dir; //use pageserver::walreceiver; use pageserver::PageServerConf; -const DEFAULT_GC_HORIZON : u64 = 
64*1024*1024; +const DEFAULT_GC_HORIZON: u64 = 64 * 1024 * 1024; fn main() -> Result<()> { let arg_matches = App::new("Zenith page server") @@ -63,7 +63,7 @@ fn main() -> Result<()> { daemonize: false, interactive: false, gc_horizon: DEFAULT_GC_HORIZON, - gc_period: Duration::from_secs(10), + gc_period: Duration::from_secs(10), listen_addr: "127.0.0.1:5430".parse().unwrap(), }; @@ -139,7 +139,7 @@ fn start_pageserver(conf: &PageServerConf) -> Result<()> { .with_context(|| format!("failed to open {:?}", &log_filename))?; let daemonize = Daemonize::new() - .pid_file(repodir.clone().join("pageserver.pid")) + .pid_file(repodir.join("pageserver.pid")) .working_directory(repodir) .stdout(stdout) .stderr(stderr); @@ -183,9 +183,9 @@ fn start_pageserver(conf: &PageServerConf) -> Result<()> { .unwrap(); threads.push(page_server_thread); - if tui_thread.is_some() { + if let Some(tui_thread) = tui_thread { // The TUI thread exits when the user asks to Quit. - tui_thread.unwrap().join().unwrap(); + tui_thread.join().unwrap(); } else { // In non-interactive mode, wait forever. for t in threads { @@ -203,18 +203,19 @@ fn init_logging(conf: &PageServerConf) -> Result Result>, pub walredo_receiver: Receiver>, - valid_lsn_condvar: Condvar, + // Allows .await on the arrival of a particular LSN. + seqwait_lsn: SeqWait, // Counters, for metrics collection. pub num_entries: AtomicU64, @@ -51,6 +53,7 @@ pub struct PageCache { pub first_valid_lsn: AtomicU64, pub last_valid_lsn: AtomicU64, pub last_record_lsn: AtomicU64, + walreceiver_works: AtomicBool, } #[derive(Clone)] @@ -99,7 +102,6 @@ struct PageCacheShared { first_valid_lsn: u64, last_valid_lsn: u64, last_record_lsn: u64, - walreceiver_works: bool, } lazy_static! 
{ @@ -155,15 +157,15 @@ pub fn get_or_restore_pagecache( }) .unwrap(); - return Ok(result); + Ok(result) } } } fn gc_thread_main(conf: &PageServerConf, timelineid: ZTimelineId) { info!("Garbage collection thread started {}", timelineid); - let pcache = get_pagecache(conf, timelineid).unwrap(); - pcache.do_gc(conf).unwrap(); + let pcache = get_pagecache(conf, timelineid).unwrap(); + pcache.do_gc(conf).unwrap(); } fn open_rocksdb(_conf: &PageServerConf, timelineid: ZTimelineId) -> DB { @@ -185,9 +187,8 @@ fn init_page_cache(conf: &PageServerConf, timelineid: ZTimelineId) -> PageCache first_valid_lsn: 0, last_valid_lsn: 0, last_record_lsn: 0, - walreceiver_works: false, }), - valid_lsn_condvar: Condvar::new(), + seqwait_lsn: SeqWait::new(0), walredo_sender: s, walredo_receiver: r, @@ -200,6 +201,7 @@ fn init_page_cache(conf: &PageServerConf, timelineid: ZTimelineId) -> PageCache first_valid_lsn: AtomicU64::new(0), last_valid_lsn: AtomicU64::new(0), last_record_lsn: AtomicU64::new(0), + walreceiver_works: AtomicBool::new(false), } } @@ -355,7 +357,7 @@ impl WALRecord { buf.put_u64(self.lsn); buf.put_u8(self.will_init as u8); buf.put_u8(self.truncate as u8); - buf.put_u32(self.main_data_offset); + buf.put_u32(self.main_data_offset); buf.put_u32(self.rec.len() as u32); buf.put_slice(&self.rec[..]); } @@ -363,7 +365,7 @@ impl WALRecord { let lsn = buf.get_u64(); let will_init = buf.get_u8() != 0; let truncate = buf.get_u8() != 0; - let main_data_offset = buf.get_u32(); + let main_data_offset = buf.get_u32(); let mut dst = vec![0u8; buf.get_u32() as usize]; buf.copy_to_slice(&mut dst); WALRecord { @@ -371,7 +373,7 @@ impl WALRecord { will_init, truncate, rec: Bytes::from(dst), - main_data_offset + main_data_offset, } } } @@ -379,84 +381,88 @@ impl WALRecord { // Public interface functions impl PageCache { - fn do_gc(&self, conf: &PageServerConf) -> anyhow::Result { - let mut minbuf = BytesMut::new(); - let mut maxbuf = BytesMut::new(); - let cf = 
self.db.cf_handle(DEFAULT_COLUMN_FAMILY_NAME).unwrap(); - loop { - thread::sleep(conf.gc_period); - let last_lsn = self.get_last_valid_lsn(); - if last_lsn > conf.gc_horizon { - let horizon = last_lsn - conf.gc_horizon; - let mut maxkey = CacheKey { - tag: BufferTag { - rel: RelTag { - spcnode: u32::MAX, - dbnode: u32::MAX, - relnode: u32::MAX, - forknum: u8::MAX, - }, - blknum: u32::MAX, - }, - lsn: u64::MAX - }; - loop { - maxbuf.clear(); - maxkey.pack(&mut maxbuf); - let mut iter = self.db.iterator(IteratorMode::From(&maxbuf[..], Direction::Reverse)); - if let Some((k,v)) = iter.next() { - minbuf.clear(); - minbuf.extend_from_slice(&v); - let content = CacheEntryContent::unpack(&mut minbuf); - minbuf.clear(); - minbuf.extend_from_slice(&k); - let key = CacheKey::unpack(&mut minbuf); + fn do_gc(&self, conf: &PageServerConf) -> anyhow::Result { + let mut minbuf = BytesMut::new(); + let mut maxbuf = BytesMut::new(); + let cf = self.db.cf_handle(DEFAULT_COLUMN_FAMILY_NAME).unwrap(); + loop { + thread::sleep(conf.gc_period); + let last_lsn = self.get_last_valid_lsn(); + if last_lsn > conf.gc_horizon { + let horizon = last_lsn - conf.gc_horizon; + let mut maxkey = CacheKey { + tag: BufferTag { + rel: RelTag { + spcnode: u32::MAX, + dbnode: u32::MAX, + relnode: u32::MAX, + forknum: u8::MAX, + }, + blknum: u32::MAX, + }, + lsn: u64::MAX, + }; + loop { + maxbuf.clear(); + maxkey.pack(&mut maxbuf); + let mut iter = self + .db + .iterator(IteratorMode::From(&maxbuf[..], Direction::Reverse)); + if let Some((k, v)) = iter.next() { + minbuf.clear(); + minbuf.extend_from_slice(&v); + let content = CacheEntryContent::unpack(&mut minbuf); + minbuf.clear(); + minbuf.extend_from_slice(&k); + let key = CacheKey::unpack(&mut minbuf); - // Construct boundaries for old records cleanup - maxkey.tag = key.tag; - let last_lsn = key.lsn; - maxkey.lsn = min(horizon, last_lsn); // do not remove last version + // Construct boundaries for old records cleanup + maxkey.tag = key.tag; + let 
last_lsn = key.lsn; + maxkey.lsn = min(horizon, last_lsn); // do not remove last version - let mut minkey = maxkey.clone(); - minkey.lsn = 0; + let mut minkey = maxkey.clone(); + minkey.lsn = 0; - // reconstruct most recent page version - if content.wal_record.is_some() { - // force reconstruction of most recent page version - self.reconstruct_page(key, content)?; - } + // reconstruct most recent page version + if content.wal_record.is_some() { + // force reconstruction of most recent page version + self.reconstruct_page(key, content)?; + } - maxbuf.clear(); - maxkey.pack(&mut maxbuf); + maxbuf.clear(); + maxkey.pack(&mut maxbuf); - if last_lsn > horizon { - // locate most recent record before horizon - let mut iter = self.db.iterator(IteratorMode::From(&maxbuf[..], Direction::Reverse)); - if let Some((k,v)) = iter.next() { - minbuf.clear(); - minbuf.extend_from_slice(&v); - let content = CacheEntryContent::unpack(&mut minbuf); - if content.wal_record.is_some() { - minbuf.clear(); - minbuf.extend_from_slice(&k); - let key = CacheKey::unpack(&mut minbuf); - self.reconstruct_page(key, content)?; - } - } - } - // remove records prior to horizon - minbuf.clear(); - minkey.pack(&mut minbuf); - self.db.delete_range_cf(cf, &minbuf[..], &maxbuf[..])?; + if last_lsn > horizon { + // locate most recent record before horizon + let mut iter = self + .db + .iterator(IteratorMode::From(&maxbuf[..], Direction::Reverse)); + if let Some((k, v)) = iter.next() { + minbuf.clear(); + minbuf.extend_from_slice(&v); + let content = CacheEntryContent::unpack(&mut minbuf); + if content.wal_record.is_some() { + minbuf.clear(); + minbuf.extend_from_slice(&k); + let key = CacheKey::unpack(&mut minbuf); + self.reconstruct_page(key, content)?; + } + } + } + // remove records prior to horizon + minbuf.clear(); + minkey.pack(&mut minbuf); + self.db.delete_range_cf(cf, &minbuf[..], &maxbuf[..])?; - maxkey = minkey; - } - } - } - } - } + maxkey = minkey; + } + } + } + } + } - fn 
reconstruct_page(&self, key: CacheKey, content: CacheEntryContent) -> anyhow::Result { + fn reconstruct_page(&self, key: CacheKey, content: CacheEntryContent) -> anyhow::Result { let entry_rc = Arc::new(CacheEntry::new(key.clone(), content)); let mut entry_content = entry_rc.content.lock().unwrap(); @@ -473,80 +479,56 @@ impl PageCache { let page_img = match &entry_content.page_image { Some(p) => p.clone(), None => { - error!( - "could not apply WAL to reconstruct page image for GetPage@LSN request" - ); + error!("could not apply WAL to reconstruct page image for GetPage@LSN request"); bail!("could not apply WAL to reconstruct page image"); } }; self.put_page_image(key.tag, key.lsn, page_img.clone()); - Ok(page_img) - } + Ok(page_img) + } - fn wait_lsn(&self, lsn: u64) -> anyhow::Result<()> { - let mut shared = self.shared.lock().unwrap(); - let mut waited = false; - - // There is a a race at postgres instance start - // when we request a page before walsender established connection - // and was able to stream the page. Just don't wait and return what we have. - // TODO is there any corner case when this is incorrect? - if !shared.walreceiver_works { + async fn wait_lsn(&self, lsn: u64) -> anyhow::Result<()> { + let walreceiver_works = self.walreceiver_works.load(Ordering::Acquire); + if walreceiver_works { + self.seqwait_lsn + .wait_for_timeout(lsn, TIMEOUT) + .await + .with_context(|| { + format!( + "Timed out while waiting for WAL record at LSN {:X}/{:X} to arrive", + lsn >> 32, + lsn & 0xffff_ffff + ) + })?; + } else { + // There is a a race at postgres instance start + // when we request a page before walsender established connection + // and was able to stream the page. Just don't wait and return what we have. + // TODO is there any corner case when this is incorrect? 
trace!( - " walreceiver doesn't work yet last_valid_lsn {}, requested {}", - shared.last_valid_lsn, + "walreceiver doesn't work yet last_valid_lsn {}, requested {}", + self.last_valid_lsn.load(Ordering::Acquire), lsn ); } - if shared.walreceiver_works { + let shared = self.shared.lock().unwrap(); - while lsn > shared.last_valid_lsn { - // TODO: Wait for the WAL receiver to catch up - waited = true; - trace!( - "not caught up yet: {}, requested {}", - shared.last_valid_lsn, - lsn - ); - let wait_result = self - .valid_lsn_condvar - .wait_timeout(shared, TIMEOUT) - .unwrap(); - - shared = wait_result.0; - if wait_result.1.timed_out() { - bail!( - "Timed out while waiting for WAL record at LSN {:X}/{:X} to arrive", - lsn >> 32, - lsn & 0xffff_ffff - ); - } - } - } - if waited { - trace!("caught up now, continuing"); - } - - if lsn < shared.first_valid_lsn { - bail!( - "LSN {:X}/{:X} has already been removed", - lsn >> 32, - lsn & 0xffff_ffff - ); - } - Ok(()) - } + if walreceiver_works { + assert!(lsn <= shared.last_valid_lsn); + } + Ok(()) + } // // GetPage@LSN // // Returns an 8k page image // - pub fn get_page_at_lsn(&self, tag: BufferTag, lsn: u64) -> anyhow::Result { + pub async fn get_page_at_lsn(&self, tag: BufferTag, lsn: u64) -> anyhow::Result { self.num_getpage_requests.fetch_add(1, Ordering::Relaxed); - self.wait_lsn(lsn)?; + self.wait_lsn(lsn).await?; // Look up cache entry. If it's a page image, return that. If it's a WAL record, // ask the WAL redo service to reconstruct the page image from the WAL records. @@ -581,8 +563,8 @@ impl PageCache { } else if content.wal_record.is_some() { buf.clear(); buf.extend_from_slice(&k); - let key = CacheKey::unpack(&mut buf); - page_img = self.reconstruct_page(key, content)?; + let key = CacheKey::unpack(&mut buf); + page_img = self.reconstruct_page(key, content)?; } else { // No base image, and no WAL record. Huh? 
bail!("no page image or WAL record for requested page"); @@ -602,7 +584,7 @@ impl PageCache { tag.blknum ); - return Ok(page_img); + Ok(page_img) } // @@ -660,7 +642,7 @@ impl PageCache { } records.reverse(); - return (base_img, records); + (base_img, records) } // @@ -692,9 +674,9 @@ impl PageCache { // Adds a relation-wide WAL record (like truncate) to the page cache, // associating it with all pages started with specified block number // - pub fn put_rel_wal_record(&self, tag: BufferTag, rec: WALRecord) { + pub async fn put_rel_wal_record(&self, tag: BufferTag, rec: WALRecord) -> anyhow::Result<()> { let mut key = CacheKey { tag, lsn: rec.lsn }; - let old_rel_size = self.relsize_get(&tag.rel, u64::MAX).unwrap(); + let old_rel_size = self.relsize_get(&tag.rel, u64::MAX).await?; let content = CacheEntryContent { page_image: None, wal_record: Some(rec), @@ -716,6 +698,7 @@ impl PageCache { let n = (old_rel_size - tag.blknum) as u64; self.num_entries.fetch_add(n, Ordering::Relaxed); self.num_wal_records.fetch_add(n, Ordering::Relaxed); + Ok(()) } // @@ -751,11 +734,11 @@ impl PageCache { if lsn >= oldlsn { // Now we receive entries from walreceiver and should wait if from_walreceiver { - shared.walreceiver_works = true; + self.walreceiver_works.store(true, Ordering::Release); } shared.last_valid_lsn = lsn; - self.valid_lsn_condvar.notify_all(); + self.seqwait_lsn.advance(lsn); self.last_valid_lsn.store(lsn, Ordering::Relaxed); } else { @@ -781,7 +764,7 @@ impl PageCache { shared.last_valid_lsn = lsn; shared.last_record_lsn = lsn; - self.valid_lsn_condvar.notify_all(); + self.seqwait_lsn.advance(lsn); self.last_valid_lsn.store(lsn, Ordering::Relaxed); self.last_record_lsn.store(lsn, Ordering::Relaxed); @@ -821,13 +804,13 @@ impl PageCache { pub fn get_last_valid_lsn(&self) -> u64 { let shared = self.shared.lock().unwrap(); - return shared.last_record_lsn; + shared.last_record_lsn } - pub fn relsize_get(&self, rel: &RelTag, lsn: u64) -> anyhow::Result { - if lsn != 
u64::MAX { - self.wait_lsn(lsn)?; - } + pub async fn relsize_get(&self, rel: &RelTag, lsn: u64) -> anyhow::Result { + if lsn != u64::MAX { + self.wait_lsn(lsn).await?; + } let mut key = CacheKey { tag: BufferTag { @@ -867,11 +850,11 @@ impl PageCache { } break; } - return Ok(0); + Ok(0) } - pub fn relsize_exist(&self, rel: &RelTag, lsn: u64) -> anyhow::Result { - self.wait_lsn(lsn)?; + pub async fn relsize_exist(&self, rel: &RelTag, lsn: u64) -> anyhow::Result { + self.wait_lsn(lsn).await?; let key = CacheKey { tag: BufferTag { @@ -893,7 +876,7 @@ impl PageCache { return Ok(true); } } - return Ok(false); + Ok(false) } pub fn get_stats(&self) -> PageCacheStats { diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 7ce285164c..2ba1c64de9 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -18,13 +18,13 @@ use std::io; use std::str::FromStr; use std::sync::Arc; use std::thread; +use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter}; use tokio::net::{TcpListener, TcpStream}; use tokio::runtime; use tokio::runtime::Runtime; use tokio::sync::mpsc; use tokio::task; -use std::time::Duration; use crate::basebackup; use crate::page_cache; @@ -186,12 +186,11 @@ fn read_null_terminated(buf: &mut Bytes) -> Result { } result.put_u8(byte); } - return Ok(result.freeze()); + Ok(result.freeze()) } impl FeParseMessage { - pub fn parse(body: Bytes) -> Result { - let mut buf = body.clone(); + pub fn parse(mut buf: Bytes) -> Result { let _pstmt_name = read_null_terminated(&mut buf)?; let query_string = read_null_terminated(&mut buf)?; let nparams = buf.get_i16(); @@ -201,7 +200,7 @@ impl FeParseMessage { // now, just ignore the statement name, assuming that the client never // uses more than one prepared statement at a time. 
/* - if pstmt_name.len() != 0 { + if !pstmt_name.is_empty() { return Err(io::Error::new( io::ErrorKind::InvalidInput, "named prepared statements not implemented in Parse", @@ -227,14 +226,13 @@ struct FeDescribeMessage { } impl FeDescribeMessage { - pub fn parse(body: Bytes) -> Result { - let mut buf = body.clone(); + pub fn parse(mut buf: Bytes) -> Result { let kind = buf.get_u8(); let _pstmt_name = read_null_terminated(&mut buf)?; // FIXME: see FeParseMessage::parse /* - if pstmt_name.len() != 0 { + if !pstmt_name.is_empty() { return Err(io::Error::new( io::ErrorKind::InvalidInput, "named prepared statements not implemented in Describe", @@ -261,12 +259,11 @@ struct FeExecuteMessage { } impl FeExecuteMessage { - pub fn parse(body: Bytes) -> Result { - let mut buf = body.clone(); + pub fn parse(mut buf: Bytes) -> Result { let portal_name = read_null_terminated(&mut buf)?; let maxrows = buf.get_i32(); - if portal_name.len() != 0 { + if !portal_name.is_empty() { return Err(io::Error::new( io::ErrorKind::InvalidInput, "named portals not implemented", @@ -289,12 +286,11 @@ impl FeExecuteMessage { struct FeBindMessage {} impl FeBindMessage { - pub fn parse(body: Bytes) -> Result { - let mut buf = body.clone(); + pub fn parse(mut buf: Bytes) -> Result { let portal_name = read_null_terminated(&mut buf)?; let _pstmt_name = read_null_terminated(&mut buf)?; - if portal_name.len() != 0 { + if !portal_name.is_empty() { return Err(io::Error::new( io::ErrorKind::InvalidInput, "named portals not implemented", @@ -303,7 +299,7 @@ impl FeBindMessage { // FIXME: see FeParseMessage::parse /* - if pstmt_name.len() != 0 { + if !pstmt_name.is_empty() { return Err(io::Error::new( io::ErrorKind::InvalidInput, "named prepared statements not implemented", @@ -320,8 +316,7 @@ impl FeBindMessage { struct FeCloseMessage {} impl FeCloseMessage { - pub fn parse(body: Bytes) -> Result { - let mut buf = body.clone(); + pub fn parse(mut buf: Bytes) -> Result { let _kind = buf.get_u8(); let 
_pstmt_or_portal_name = read_null_terminated(&mut buf)?; @@ -362,7 +357,7 @@ impl FeMessage { let mut body = body.freeze(); match tag { - b'Q' => Ok(Some(FeMessage::Query(FeQueryMessage { body: body }))), + b'Q' => Ok(Some(FeMessage::Query(FeQueryMessage { body }))), b'P' => Ok(Some(FeParseMessage::parse(body)?)), b'D' => Ok(Some(FeDescribeMessage::parse(body)?)), b'E' => Ok(Some(FeExecuteMessage::parse(body)?)), @@ -423,7 +418,7 @@ pub fn thread_main(conf: &PageServerConf) { let runtime_ref = Arc::new(runtime); - runtime_ref.clone().block_on(async { + runtime_ref.block_on(async { let listener = TcpListener::bind(conf.listen_addr).await.unwrap(); loop { @@ -534,7 +529,7 @@ impl Connection { BeMessage::RowDescription => { // XXX - let mut b = Bytes::from("data\0"); + let b = Bytes::from("data\0"); self.stream.write_u8(b'T').await?; self.stream @@ -542,7 +537,7 @@ impl Connection { .await?; self.stream.write_i16(1).await?; - self.stream.write_all(&mut b).await?; + self.stream.write_all(&b).await?; self.stream.write_i32(0).await?; /* table oid */ self.stream.write_i16(0).await?; /* attnum */ self.stream.write_i32(25).await?; /* TEXTOID */ @@ -554,34 +549,34 @@ impl Connection { // XXX: accept some text data BeMessage::DataRow => { // XXX - let mut b = Bytes::from("hello world"); + let b = Bytes::from("hello world"); self.stream.write_u8(b'D').await?; self.stream.write_i32(4 + 2 + 4 + b.len() as i32).await?; self.stream.write_i16(1).await?; self.stream.write_i32(b.len() as i32).await?; - self.stream.write_all(&mut b).await?; + self.stream.write_all(&b).await?; } BeMessage::ControlFile => { // TODO pass checkpoint and xid info in this message - let mut b = Bytes::from("hello pg_control"); + let b = Bytes::from("hello pg_control"); self.stream.write_u8(b'D').await?; self.stream.write_i32(4 + 2 + 4 + b.len() as i32).await?; self.stream.write_i16(1).await?; self.stream.write_i32(b.len() as i32).await?; - self.stream.write_all(&mut b).await?; + 
self.stream.write_all(&b).await?; } BeMessage::CommandComplete => { - let mut b = Bytes::from("SELECT 1\0"); + let b = Bytes::from("SELECT 1\0"); self.stream.write_u8(b'C').await?; self.stream.write_i32(4 + b.len() as i32).await?; - self.stream.write_all(&mut b).await?; + self.stream.write_all(&b).await?; } BeMessage::ZenithStatusResponse(resp) => { @@ -608,7 +603,7 @@ impl Connection { self.stream.write_u8(102).await?; /* tag from pagestore_client.h */ self.stream.write_u8(resp.ok as u8).await?; self.stream.write_u32(resp.n_blocks).await?; - self.stream.write_all(&mut resp.page.clone()).await?; + self.stream.write_all(&resp.page.clone()).await?; } } @@ -631,8 +626,8 @@ impl Connection { match m.kind { StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => { - let mut b = Bytes::from("N"); - self.stream.write_all(&mut b).await?; + let b = Bytes::from("N"); + self.stream.write_all(&b).await?; self.stream.flush().await?; } StartupRequestCode::Normal => { @@ -724,7 +719,7 @@ impl Connection { let caps = re.captures(&query_str); let caps = caps.unwrap(); - let timelineid = ZTimelineId::from_str(caps.get(1).unwrap().as_str().clone()).unwrap(); + let timelineid = ZTimelineId::from_str(caps.get(1).unwrap().as_str()).unwrap(); let connstr: String = String::from(caps.get(2).unwrap().as_str()); // Check that the timeline exists @@ -804,7 +799,7 @@ impl Connection { forknum: req.forknum, }; - let exist = pcache.relsize_exist(&tag, req.lsn).unwrap_or(false); + let exist = pcache.relsize_exist(&tag, req.lsn).await.unwrap_or(false); self.write_message(&BeMessage::ZenithStatusResponse(ZenithStatusResponse { ok: exist, @@ -820,7 +815,7 @@ impl Connection { forknum: req.forknum, }; - let n_blocks = pcache.relsize_get(&tag, req.lsn).unwrap_or(0); + let n_blocks = pcache.relsize_get(&tag, req.lsn).await.unwrap_or(0); self.write_message(&BeMessage::ZenithNblocksResponse(ZenithStatusResponse { ok: true, @@ -839,7 +834,7 @@ impl Connection { blknum: req.blkno, }; - let 
msg = match pcache.get_page_at_lsn(buf_tag, req.lsn) { + let msg = match pcache.get_page_at_lsn(buf_tag, req.lsn).await { Ok(p) => BeMessage::ZenithReadResponse(ZenithReadResponse { ok: true, n_blocks: 0, @@ -896,13 +891,10 @@ impl Connection { let f_tar2 = async { let joinres = f_tar.await; - if joinres.is_err() { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - joinres.unwrap_err(), - )); + if let Err(joinreserr) = joinres { + return Err(io::Error::new(io::ErrorKind::InvalidData, joinreserr)); } - return joinres.unwrap(); + joinres.unwrap() }; let f_pump = async move { @@ -911,12 +903,12 @@ impl Connection { if buf.is_none() { break; } - let mut buf = buf.unwrap(); + let buf = buf.unwrap(); // CopyData stream.write_u8(b'd').await?; stream.write_u32((4 + buf.len()) as u32).await?; - stream.write_all(&mut buf).await?; + stream.write_all(&buf).await?; trace!("CopyData sent for {} bytes!", buf.len()); // FIXME: flush isn't really required, but makes it easier diff --git a/pageserver/src/restore_local_repo.rs b/pageserver/src/restore_local_repo.rs index bf5e48a76e..90598b3c4f 100644 --- a/pageserver/src/restore_local_repo.rs +++ b/pageserver/src/restore_local_repo.rs @@ -27,9 +27,9 @@ use anyhow::Result; use bytes::Bytes; use crate::page_cache; -use crate::page_cache::RelTag; use crate::page_cache::BufferTag; use crate::page_cache::PageCache; +use crate::page_cache::RelTag; use crate::waldecoder::{decode_wal_record, WalStreamDecoder}; use crate::PageServerConf; use crate::ZTimelineId; @@ -187,10 +187,9 @@ fn restore_relfile( // Does it look like a relation file? 
let p = parse_relfilename(path.file_name().unwrap().to_str().unwrap()); - if p.is_err() { - let e = p.unwrap_err(); + if let Err(e) = p { warn!("unrecognized file in snapshot: {:?} ({})", path, e); - return Err(e)?; + return Err(e.into()); } let (relnode, forknum, segno) = p.unwrap(); @@ -205,12 +204,12 @@ fn restore_relfile( Ok(_) => { let tag = BufferTag { rel: RelTag { - spcnode: spcoid, - dbnode: dboid, - relnode: relnode, - forknum: forknum as u8, - }, - blknum: blknum, + spcnode: spcoid, + dbnode: dboid, + relnode: relnode, + forknum: forknum as u8, + }, + blknum, }; pcache.put_page_image(tag, lsn, Bytes::copy_from_slice(&buf)); /* @@ -249,7 +248,7 @@ fn restore_wal( ) -> Result<()> { let walpath = format!("timelines/{}/wal", timeline); - let mut waldecoder = WalStreamDecoder::new(u64::from(startpoint)); + let mut waldecoder = WalStreamDecoder::new(startpoint); let mut segno = XLByteToSeg(startpoint, 16 * 1024 * 1024); let mut offset = XLogSegmentOffset(startpoint, 16 * 1024 * 1024); @@ -261,7 +260,7 @@ fn restore_wal( // It could be as .partial if !PathBuf::from(&path).exists() { - path = path + ".partial"; + path += ".partial"; } // Slurp the WAL file @@ -303,18 +302,18 @@ fn restore_wal( for blk in decoded.blocks.iter() { let tag = BufferTag { rel: RelTag { - spcnode: blk.rnode_spcnode, - dbnode: blk.rnode_dbnode, - relnode: blk.rnode_relnode, - forknum: blk.forknum as u8, - }, + spcnode: blk.rnode_spcnode, + dbnode: blk.rnode_dbnode, + relnode: blk.rnode_relnode, + forknum: blk.forknum as u8, + }, blknum: blk.blkno, }; let rec = page_cache::WALRecord { - lsn: lsn, + lsn, will_init: blk.will_init || blk.apply_image, - truncate: false, + truncate: false, rec: recdata.clone(), main_data_offset: decoded.main_data_offset as u32, }; @@ -483,5 +482,5 @@ fn parse_relfilename(fname: &str) -> Result<(u32, u32, u32), FilePathError> { u32::from_str_radix(segno_match.unwrap().as_str(), 10)? 
}; - return Ok((relnode, forknum, segno)); + Ok((relnode, forknum, segno)) } diff --git a/pageserver/src/restore_s3.rs b/pageserver/src/restore_s3.rs index 253e6e589b..6b9d7d1ad8 100644 --- a/pageserver/src/restore_s3.rs +++ b/pageserver/src/restore_s3.rs @@ -38,12 +38,9 @@ pub fn restore_main(conf: &PageServerConf) { let result = restore_chunk(conf).await; match result { - Ok(_) => { - return; - } + Ok(_) => {} Err(err) => { error!("S3 error: {}", err); - return; } } }); @@ -199,7 +196,7 @@ fn parse_filename(fname: &str) -> Result<(u32, u32, u32, u64), FilePathError> { .ok_or_else(|| FilePathError::new("invalid relation data file name"))?; let relnode_str = caps.name("relnode").unwrap().as_str(); - let relnode = u32::from_str_radix(relnode_str, 10)?; + let relnode: u32 = relnode_str.parse()?; let forkname_match = caps.name("forkname"); let forkname = if forkname_match.is_none() { @@ -213,14 +210,14 @@ fn parse_filename(fname: &str) -> Result<(u32, u32, u32, u64), FilePathError> { let segno = if segno_match.is_none() { 0 } else { - u32::from_str_radix(segno_match.unwrap().as_str(), 10)? + segno_match.unwrap().as_str().parse::()? 
}; - let lsn_hi = u64::from_str_radix(caps.name("lsnhi").unwrap().as_str(), 16)?; - let lsn_lo = u64::from_str_radix(caps.name("lsnlo").unwrap().as_str(), 16)?; + let lsn_hi: u64 = caps.name("lsnhi").unwrap().as_str().parse()?; + let lsn_lo: u64 = caps.name("lsnlo").unwrap().as_str().parse()?; let lsn = lsn_hi << 32 | lsn_lo; - return Ok((relnode, forknum, segno, lsn)); + Ok((relnode, forknum, segno, lsn)) } fn parse_rel_file_path(path: &str) -> Result { @@ -244,20 +241,20 @@ fn parse_rel_file_path(path: &str) -> Result Result slog_scope::GlobalLoggerGuard { { return true; } - return false; + false }) .fuse(); @@ -41,7 +41,7 @@ pub fn init_logging() -> slog_scope::GlobalLoggerGuard { { return true; } - return false; + false }) .fuse(); @@ -52,7 +52,7 @@ pub fn init_logging() -> slog_scope::GlobalLoggerGuard { { return true; } - return false; + false }) .fuse(); @@ -65,7 +65,7 @@ pub fn init_logging() -> slog_scope::GlobalLoggerGuard { { return true; } - return false; + false }) .fuse(); @@ -84,11 +84,11 @@ pub fn init_logging() -> slog_scope::GlobalLoggerGuard { return true; } - return false; + false }) .fuse(); let logger = slog::Logger::root(drain, slog::o!()); - return slog_scope::set_global_logger(logger); + slog_scope::set_global_logger(logger) } pub fn ui_main() -> Result<(), Box> { diff --git a/pageserver/src/tui_event.rs b/pageserver/src/tui_event.rs index 5546b680ee..d88cac5d5b 100644 --- a/pageserver/src/tui_event.rs +++ b/pageserver/src/tui_event.rs @@ -76,8 +76,8 @@ impl Events { }; Events { rx, - ignore_exit_key, input_handle, + ignore_exit_key, tick_handle, } } diff --git a/pageserver/src/tui_logger.rs b/pageserver/src/tui_logger.rs index dcb4a23467..663add4065 100644 --- a/pageserver/src/tui_logger.rs +++ b/pageserver/src/tui_logger.rs @@ -51,7 +51,7 @@ impl Drain for TuiLogger { events.pop_back(); } - return Ok(()); + Ok(()) } } diff --git a/pageserver/src/waldecoder.rs b/pageserver/src/waldecoder.rs index 623b9b7189..15dca57786 100644 --- 
a/pageserver/src/waldecoder.rs +++ b/pageserver/src/waldecoder.rs @@ -227,7 +227,7 @@ impl WalStreamDecoder { // FIXME: check that hdr.xlp_rem_len matches self.contlen //println!("next xlog page (xlp_rem_len: {})", hdr.xlp_rem_len); - return hdr; + hdr } #[allow(non_snake_case)] @@ -239,7 +239,7 @@ impl WalStreamDecoder { xlp_xlog_blcksz: self.inputbuf.get_u32_le(), }; - return hdr; + hdr } } @@ -352,7 +352,7 @@ fn is_xlog_switch_record(rec: &Bytes) -> bool { buf.advance(2); // 2 bytes of padding let _xl_crc = buf.get_u32_le(); - return xl_info == pg_constants::XLOG_SWITCH && xl_rmid == pg_constants::RM_XLOG_ID; + xl_info == pg_constants::XLOG_SWITCH && xl_rmid == pg_constants::RM_XLOG_ID } pub type Oid = u32; @@ -680,7 +680,7 @@ pub fn decode_wal_record(record: Bytes) -> DecodedWALRecord { } DecodedWALRecord { - xl_info, + xl_info, xl_rmid, record, blocks, diff --git a/pageserver/src/walreceiver.rs b/pageserver/src/walreceiver.rs index 8e8b61989e..da1cfbd6c7 100644 --- a/pageserver/src/walreceiver.rs +++ b/pageserver/src/walreceiver.rs @@ -244,7 +244,8 @@ async fn walreceiver_main( } // include truncate wal record in all pages if decoded.xl_rmid == pg_constants::RM_SMGR_ID - && (decoded.xl_info & pg_constants::XLR_RMGR_INFO_MASK) == pg_constants::XLOG_SMGR_TRUNCATE + && (decoded.xl_info & pg_constants::XLR_RMGR_INFO_MASK) + == pg_constants::XLOG_SMGR_TRUNCATE { let truncate = decode_truncate_record(&decoded); if (truncate.flags & SMGR_TRUNCATE_HEAP) != 0 { @@ -262,9 +263,9 @@ async fn walreceiver_main( will_init: false, truncate: true, rec: recdata.clone(), - main_data_offset: decoded.main_data_offset as u32, + main_data_offset: decoded.main_data_offset as u32, }; - pcache.put_rel_wal_record(tag, rec); + pcache.put_rel_wal_record(tag, rec).await?; } } // Now that this record has been handled, let the page cache know that @@ -438,7 +439,7 @@ fn write_wal_file( let mut bytes_written: usize = 0; let mut partial; let mut start_pos = startpos; - const ZERO_BLOCK: 
&'static [u8] = &[0u8; XLOG_BLCKSZ]; + const ZERO_BLOCK: &[u8] = &[0u8; XLOG_BLCKSZ]; let wal_dir = PathBuf::from(format!("timelines/{}/wal", timeline)); diff --git a/pageserver/src/walredo.rs b/pageserver/src/walredo.rs index 106c17ba22..fafbb376d8 100644 --- a/pageserver/src/walredo.rs +++ b/pageserver/src/walredo.rs @@ -226,7 +226,7 @@ fn handle_apply_request( // Wake up the requester, whether the operation succeeded or not. entry_rc.walredo_condvar.notify_all(); - return result; + result } struct WalRedoProcess { @@ -325,7 +325,7 @@ impl WalRedoProcess { ) -> Result { let mut stdin = self.stdin.borrow_mut(); let mut stdout = self.stdout.borrow_mut(); - return runtime.block_on(async { + runtime.block_on(async { // // This async block sends all the commands to the process. // @@ -388,7 +388,7 @@ impl WalRedoProcess { let buf = res.0; Ok::(Bytes::from(std::vec::Vec::from(buf))) - }); + }) } } @@ -396,13 +396,13 @@ fn build_begin_redo_for_block_msg(tag: BufferTag) -> Bytes { let len = 4 + 5 * 4; let mut buf = BytesMut::with_capacity(1 + len); - buf.put_u8('B' as u8); + buf.put_u8(b'B'); buf.put_u32(len as u32); tag.pack(&mut buf); assert!(buf.len() == 1 + len); - return buf.freeze(); + buf.freeze() } fn build_push_page_msg(tag: BufferTag, base_img: Bytes) -> Bytes { @@ -411,39 +411,39 @@ fn build_push_page_msg(tag: BufferTag, base_img: Bytes) -> Bytes { let len = 4 + 5 * 4 + base_img.len(); let mut buf = BytesMut::with_capacity(1 + len); - buf.put_u8('P' as u8); + buf.put_u8(b'P'); buf.put_u32(len as u32); tag.pack(&mut buf); buf.put(base_img); assert!(buf.len() == 1 + len); - return buf.freeze(); + buf.freeze() } fn build_apply_record_msg(endlsn: u64, rec: Bytes) -> Bytes { let len = 4 + 8 + rec.len(); let mut buf = BytesMut::with_capacity(1 + len); - buf.put_u8('A' as u8); + buf.put_u8(b'A'); buf.put_u32(len as u32); buf.put_u64(endlsn); buf.put(rec); assert!(buf.len() == 1 + len); - return buf.freeze(); + buf.freeze() } fn build_get_page_msg(tag: BufferTag) -> 
Bytes { let len = 4 + 5 * 4; let mut buf = BytesMut::with_capacity(1 + len); - buf.put_u8('G' as u8); + buf.put_u8(b'G'); buf.put_u32(len as u32); tag.pack(&mut buf); assert!(buf.len() == 1 + len); - return buf.freeze(); + buf.freeze() } diff --git a/postgres_ffi/src/lib.rs b/postgres_ffi/src/lib.rs index b6cf6bdb2b..59cad0db39 100644 --- a/postgres_ffi/src/lib.rs +++ b/postgres_ffi/src/lib.rs @@ -18,13 +18,13 @@ impl ControlFileData { controlfile = unsafe { std::mem::transmute::<[u8; SIZEOF_CONTROLDATA], ControlFileData>(b) }; - return controlfile; + controlfile } } -pub fn decode_pg_control(buf: Bytes) -> Result { +pub fn decode_pg_control(mut buf: Bytes) -> Result { let mut b: [u8; SIZEOF_CONTROLDATA] = [0u8; SIZEOF_CONTROLDATA]; - buf.clone().copy_to_slice(&mut b); + buf.copy_to_slice(&mut b); let controlfile: ControlFileData; @@ -63,5 +63,5 @@ pub fn encode_pg_control(controlfile: ControlFileData) -> Bytes { // Fill the rest of the control file with zeros. buf.resize(PG_CONTROL_FILE_SIZE as usize, 0); - return buf.into(); + buf.into() } diff --git a/walkeeper/src/bin/wal_acceptor.rs b/walkeeper/src/bin/wal_acceptor.rs index 8dfa31e23b..57503b1912 100644 --- a/walkeeper/src/bin/wal_acceptor.rs +++ b/walkeeper/src/bin/wal_acceptor.rs @@ -69,7 +69,7 @@ fn main() -> Result<()> { let mut conf = WalAcceptorConf { data_dir: PathBuf::from("./"), - systemid: systemid, + systemid, daemonize: false, no_sync: false, pageserver_addr: None, diff --git a/walkeeper/src/pq_protocol.rs b/walkeeper/src/pq_protocol.rs index f6e18d9aa4..57517c322f 100644 --- a/walkeeper/src/pq_protocol.rs +++ b/walkeeper/src/pq_protocol.rs @@ -91,9 +91,9 @@ impl FeStartupMessage { options = true; } else if options { for opt in p.split(' ') { - if opt.starts_with("ztimelineid=") { + if let Some(ztimelineid_str) = opt.strip_prefix("ztimelineid=") { // FIXME: rethrow parsing error, don't unwrap - timelineid = Some(ZTimelineId::from_str(&opt[12..]).unwrap()); + timelineid = 
Some(ZTimelineId::from_str(ztimelineid_str).unwrap()); break; } } diff --git a/walkeeper/src/wal_service.rs b/walkeeper/src/wal_service.rs index 64627d33b5..9d7e6a8bfc 100644 --- a/walkeeper/src/wal_service.rs +++ b/walkeeper/src/wal_service.rs @@ -444,7 +444,7 @@ impl Timeline { fn get_hs_feedback(&self) -> HotStandbyFeedback { let shared_state = self.mutex.lock().unwrap(); - return shared_state.hs_feedback; + shared_state.hs_feedback } // Load and lock control file (prevent running more than one instance of safekeeper) @@ -527,7 +527,7 @@ impl Timeline { let file = shared_state.control_file.as_mut().unwrap(); file.seek(SeekFrom::Start(0))?; - file.write_all(&mut buf[..])?; + file.write_all(&buf[..])?; if sync { file.sync_all()?; } @@ -554,7 +554,7 @@ impl Connection { async fn run(&mut self) -> Result<()> { self.inbuf.resize(4, 0u8); self.stream.read_exact(&mut self.inbuf[0..4]).await?; - let startup_pkg_len = BigEndian::read_u32(&mut self.inbuf[0..4]); + let startup_pkg_len = BigEndian::read_u32(&self.inbuf[0..4]); if startup_pkg_len == 0 { self.receive_wal().await?; // internal protocol between wal_proposer and wal_acceptor } else { @@ -997,12 +997,12 @@ impl Connection { // Try to fetch replica's feedback match self.stream.try_read_buf(&mut self.inbuf) { Ok(0) => break, - Ok(_) => match self.parse_message()? { - Some(FeMessage::CopyData(m)) => self - .timeline() - .add_hs_feedback(HotStandbyFeedback::parse(&m.body)), - _ => {} - }, + Ok(_) => { + if let Some(FeMessage::CopyData(m)) = self.parse_message()? 
{ + self.timeline() + .add_hs_feedback(HotStandbyFeedback::parse(&m.body)) + } + } Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {} Err(e) => { return Err(e); @@ -1102,7 +1102,7 @@ impl Connection { let mut bytes_written: usize = 0; let mut partial; let mut start_pos = startpos; - const ZERO_BLOCK: &'static [u8] = &[0u8; XLOG_BLCKSZ]; + const ZERO_BLOCK: &[u8] = &[0u8; XLOG_BLCKSZ]; /* Extract WAL location for this block */ let mut xlogoff = XLogSegmentOffset(start_pos, wal_seg_size) as usize; diff --git a/walkeeper/src/xlog_utils.rs b/walkeeper/src/xlog_utils.rs index 7c18131186..c31a160cce 100644 --- a/walkeeper/src/xlog_utils.rs +++ b/walkeeper/src/xlog_utils.rs @@ -23,17 +23,17 @@ pub type XLogSegNo = u64; #[allow(non_snake_case)] pub fn XLogSegmentOffset(xlogptr: XLogRecPtr, wal_segsz_bytes: usize) -> u32 { - return (xlogptr as u32) & (wal_segsz_bytes as u32 - 1); + (xlogptr as u32) & (wal_segsz_bytes as u32 - 1) } #[allow(non_snake_case)] pub fn XLogSegmentsPerXLogId(wal_segsz_bytes: usize) -> XLogSegNo { - return (0x100000000u64 / wal_segsz_bytes as u64) as XLogSegNo; + (0x100000000u64 / wal_segsz_bytes as u64) as XLogSegNo } #[allow(non_snake_case)] pub fn XLByteToSeg(xlogptr: XLogRecPtr, wal_segsz_bytes: usize) -> XLogSegNo { - return xlogptr / wal_segsz_bytes as u64; + xlogptr / wal_segsz_bytes as u64 } #[allow(non_snake_case)] @@ -42,7 +42,7 @@ pub fn XLogSegNoOffsetToRecPtr( offset: u32, wal_segsz_bytes: usize, ) -> XLogRecPtr { - return segno * (wal_segsz_bytes as u64) + (offset as u64); + segno * (wal_segsz_bytes as u64) + (offset as u64) } #[allow(non_snake_case)] @@ -60,7 +60,7 @@ pub fn XLogFromFileName(fname: &str, wal_seg_size: usize) -> (XLogSegNo, TimeLin let tli = u32::from_str_radix(&fname[0..8], 16).unwrap(); let log = u32::from_str_radix(&fname[8..16], 16).unwrap() as XLogSegNo; let seg = u32::from_str_radix(&fname[16..24], 16).unwrap() as XLogSegNo; - return (log * XLogSegmentsPerXLogId(wal_seg_size) + seg, tli); + (log * 
XLogSegmentsPerXLogId(wal_seg_size) + seg, tli) } #[allow(non_snake_case)] @@ -70,7 +70,7 @@ pub fn IsXLogFileName(fname: &str) -> bool { #[allow(non_snake_case)] pub fn IsPartialXLogFileName(fname: &str) -> bool { - return fname.ends_with(".partial") && IsXLogFileName(&fname[0..fname.len() - 8]); + fname.ends_with(".partial") && IsXLogFileName(&fname[0..fname.len() - 8]) } pub fn get_current_timestamp() -> TimestampTz { @@ -181,7 +181,7 @@ fn find_end_of_wal_segment( } } } - return last_valid_rec_pos as u32; + last_valid_rec_pos as u32 } pub fn find_end_of_wal( @@ -237,7 +237,7 @@ pub fn find_end_of_wal( let high_ptr = XLogSegNoOffsetToRecPtr(high_segno, high_offs, wal_seg_size); return (high_ptr, high_tli); } - return (0, 0); + (0, 0) } pub fn main() { diff --git a/zenith/src/main.rs b/zenith/src/main.rs index 53d1528a6b..8cbd97e1ea 100644 --- a/zenith/src/main.rs +++ b/zenith/src/main.rs @@ -76,7 +76,7 @@ fn main() -> Result<()> { // all other commands would need config - let repopath = PathBuf::from(zenith_repo_dir()); + let repopath = zenith_repo_dir(); if !repopath.exists() { bail!( "Zenith repository does not exists in {}.\n\ @@ -186,7 +186,7 @@ fn handle_pg(pg_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<()> { let node = cplane .nodes .get(name) - .ok_or(anyhow!("postgres {} is not found", name))?; + .ok_or_else(|| anyhow!("postgres {} is not found", name))?; node.start()?; } ("stop", Some(sub_m)) => { @@ -194,7 +194,7 @@ fn handle_pg(pg_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<()> { let node = cplane .nodes .get(name) - .ok_or(anyhow!("postgres {} is not found", name))?; + .ok_or_else(|| anyhow!("postgres {} is not found", name))?; node.stop()?; } @@ -277,19 +277,19 @@ fn list_branches() -> Result<()> { // // fn parse_point_in_time(s: &str) -> Result { - let mut strings = s.split("@"); + let mut strings = s.split('@'); let name = strings.next().unwrap(); let lsn: Option; if let Some(lsnstr) = strings.next() { - let mut s = 
lsnstr.split("/"); + let mut s = lsnstr.split('/'); let lsn_hi: u64 = s .next() - .ok_or(anyhow!("invalid LSN in point-in-time specification"))? + .ok_or_else(|| anyhow!("invalid LSN in point-in-time specification"))? .parse()?; let lsn_lo: u64 = s .next() - .ok_or(anyhow!("invalid LSN in point-in-time specification"))? + .ok_or_else(|| anyhow!("invalid LSN in point-in-time specification"))? .parse()?; lsn = Some(lsn_hi << 32 | lsn_lo); } else { @@ -312,11 +312,8 @@ fn parse_point_in_time(s: &str) -> Result { let pointstr = fs::read_to_string(branchpath)?; let mut result = parse_point_in_time(&pointstr)?; - if lsn.is_some() { - result.lsn = lsn.unwrap(); - } else { - result.lsn = 0; - } + + result.lsn = lsn.unwrap_or(0); return Ok(result); } diff --git a/zenith_utils/Cargo.toml b/zenith_utils/Cargo.toml index 77bc1e9ecb..a26a772c97 100644 --- a/zenith_utils/Cargo.toml +++ b/zenith_utils/Cargo.toml @@ -5,3 +5,8 @@ authors = ["Eric Seppanen "] edition = "2018" [dependencies] +tokio = { version = "1.5", features = ["sync", "time" ] } +thiserror = "1" + +[dev-dependencies] +tokio = { version = "1.5", features = ["macros", "rt"] } diff --git a/zenith_utils/src/lib.rs b/zenith_utils/src/lib.rs index 2d86ad041f..8acd9cb84b 100644 --- a/zenith_utils/src/lib.rs +++ b/zenith_utils/src/lib.rs @@ -1,2 +1,4 @@ //! zenith_utils is intended to be a place to put code that is shared //! between other crates in this repository. 
+
+pub mod seqwait;
diff --git a/zenith_utils/src/seqwait.rs b/zenith_utils/src/seqwait.rs
new file mode 100644
index 0000000000..bd94b8b350
--- /dev/null
+++ b/zenith_utils/src/seqwait.rs
@@ -0,0 +1,238 @@
+use std::collections::BTreeMap;
+use std::mem;
+use std::sync::Mutex;
+use std::time::Duration;
+use tokio::sync::watch::{channel, Receiver, Sender};
+use tokio::time::timeout;
+
+/// An error happened while waiting for a number
+#[derive(Debug, PartialEq, thiserror::Error)]
+#[error("SeqWaitError")]
+pub enum SeqWaitError {
+    /// The wait timeout was reached
+    Timeout,
+    /// [`SeqWait::shutdown`] was called
+    Shutdown,
+}
+
+/// Internal components of a `SeqWait`
+struct SeqWaitInt {
+    /// One watch channel per awaited number. The `Receiver` is kept next to
+    /// its `Sender` so that later waiters on the same number can clone it.
+    waiters: BTreeMap<u64, (Sender<()>, Receiver<()>)>,
+    /// Highest number announced so far via `advance`.
+    current: u64,
+    /// Set by `shutdown`; makes every later wait on a future number fail.
+    shutdown: bool,
+}
+
+/// A tool for waiting on a sequence number
+///
+/// This provides a way to await the arrival of a number.
+/// As soon as the number arrives by another caller calling
+/// [`advance`], then the waiter will be woken up.
+///
+/// This implementation takes a blocking Mutex on both [`wait_for`]
+/// and [`advance`], meaning there may be unexpected executor blocking
+/// due to thread scheduling unfairness. There are probably better
+/// implementations, but we can probably live with this for now.
+///
+/// [`wait_for`]: SeqWait::wait_for
+/// [`advance`]: SeqWait::advance
+///
+pub struct SeqWait {
+    internal: Mutex<SeqWaitInt>,
+}
+
+impl SeqWait {
+    /// Create a new `SeqWait`, initialized to a particular number
+    pub fn new(starting_num: u64) -> Self {
+        let internal = SeqWaitInt {
+            waiters: BTreeMap::new(),
+            current: starting_num,
+            shutdown: false,
+        };
+        SeqWait {
+            internal: Mutex::new(internal),
+        }
+    }
+
+    /// Shut down a `SeqWait`, causing all waiters (present and
+    /// future) to return an error.
+    ///
+    /// Waits for numbers that have already arrived still succeed.
+    pub fn shutdown(&self) {
+        let waiters = {
+            let mut internal = self.internal.lock().unwrap();
+
+            // Block any future waiters from registering.
+            internal.shutdown = true;
+
+            // Steal the entire waiters map. This must be the value of
+            // the block (no trailing semicolon) so that the map is
+            // handed out of the critical section instead of being
+            // dropped while the lock is still held.
+            mem::take(&mut internal.waiters)
+
+            // Drop the lock as we exit this scope.
+        };
+
+        // When we drop the waiters list, each Receiver will
+        // be woken with an error.
+        // This drop doesn't need to be explicit; it's done
+        // here to make it easier to read the code and understand
+        // the order of events.
+        drop(waiters);
+    }
+
+    /// Wait for a number to arrive
+    ///
+    /// This call won't complete until someone has called `advance`
+    /// with a number greater than or equal to the one we're waiting for.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`SeqWaitError::Shutdown`] if [`SeqWait::shutdown`] was
+    /// called before the number arrived.
+    pub async fn wait_for(&self, num: u64) -> Result<(), SeqWaitError> {
+        let mut rx = {
+            let mut internal = self.internal.lock().unwrap();
+            if internal.current >= num {
+                return Ok(());
+            }
+            if internal.shutdown {
+                return Err(SeqWaitError::Shutdown);
+            }
+
+            // Reuse the channel already registered for this number,
+            // or create one if we are the first to wait on it.
+            let (_tx, rx) = internal.waiters.entry(num).or_insert_with(|| channel(()));
+            rx.clone()
+            // Drop the lock as we exit this scope.
+        };
+        // An Err from changed() means the sender was dropped;
+        // that is reported to the caller as a shutdown.
+        rx.changed().await.map_err(|_| SeqWaitError::Shutdown)
+    }
+
+    /// Wait for a number to arrive
+    ///
+    /// This call won't complete until someone has called `advance`
+    /// with a number greater than or equal to the one we're waiting for.
+    ///
+    /// If that hasn't happened after the specified timeout duration,
+    /// [`SeqWaitError::Timeout`] will be returned.
+    pub async fn wait_for_timeout(
+        &self,
+        num: u64,
+        timeout_duration: Duration,
+    ) -> Result<(), SeqWaitError> {
+        timeout(timeout_duration, self.wait_for(num))
+            .await
+            .unwrap_or(Err(SeqWaitError::Timeout))
+    }
+
+    /// Announce a new number has arrived
+    ///
+    /// All waiters at this value or below will be woken.
+    ///
+    /// `advance` will panic if you send it a lower number than
+    /// a previous call.
+    pub fn advance(&self, num: u64) {
+        let wake_these = {
+            let mut internal = self.internal.lock().unwrap();
+
+            if internal.current > num {
+                panic!(
+                    "tried to advance backwards, from {} to {}",
+                    internal.current, num
+                );
+            }
+            internal.current = num;
+
+            match num.checked_add(1) {
+                // split_off hands back all the high-numbered waiters,
+                // i.e. everything at or above (num + 1). Those stay
+                // registered; what remains in the old map gets woken.
+                Some(first_kept) => {
+                    let kept = internal.waiters.split_off(&first_kept);
+                    mem::replace(&mut internal.waiters, kept)
+                }
+                // num is u64::MAX: nothing can be above it, so wake
+                // everyone. (num + 1 would overflow here.)
+                None => mem::take(&mut internal.waiters),
+            }
+        };
+
+        for (_wake_num, (tx, _rx)) in wake_these {
+            // This can fail if there are no receivers.
+            // We don't care; discard the error.
+            let _ = tx.send(());
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::sync::Arc;
+    use tokio::time::{sleep, Duration};
+
+    #[tokio::test]
+    async fn seqwait() {
+        let seq = Arc::new(SeqWait::new(0));
+        let seq2 = Arc::clone(&seq);
+        let seq3 = Arc::clone(&seq);
+        tokio::spawn(async move {
+            seq2.wait_for(42).await.expect("wait_for 42");
+            seq2.advance(100);
+            seq2.wait_for(999).await.expect_err("no 999");
+        });
+        tokio::spawn(async move {
+            seq3.wait_for(42).await.expect("wait_for 42");
+            seq3.wait_for(0).await.expect("wait_for 0");
+        });
+        sleep(Duration::from_secs(1)).await;
+        seq.advance(99);
+        seq.wait_for(100).await.expect("wait_for 100");
+        seq.shutdown();
+    }
+
+    #[tokio::test]
+    async fn seqwait_timeout() {
+        let seq = Arc::new(SeqWait::new(0));
+        let seq2 = Arc::clone(&seq);
+        tokio::spawn(async move {
+            let timeout = Duration::from_millis(1);
+            let res = seq2.wait_for_timeout(42, timeout).await;
+            assert_eq!(res, Err(SeqWaitError::Timeout));
+        });
+        sleep(Duration::from_secs(1)).await;
+        // This will attempt to wake, but nothing will happen
+        // because the waiter already dropped its Receiver.
+        seq.advance(99);
+    }
+
+    #[tokio::test]
+    async fn seqwait_shutdown_rejects_new_waiters() {
+        let seq = SeqWait::new(0);
+        seq.shutdown();
+        // Numbers that already arrived still succeed after shutdown...
+        seq.wait_for(0).await.expect("wait_for 0");
+        // ...but waiting for a future number must fail right away
+        // instead of hanging until the timeout.
+        let res = seq.wait_for_timeout(1, Duration::from_secs(1)).await;
+        assert_eq!(res, Err(SeqWaitError::Shutdown));
+    }
+
+    #[tokio::test]
+    async fn seqwait_advance_to_max() {
+        let seq = Arc::new(SeqWait::new(0));
+        let seq2 = Arc::clone(&seq);
+        let waiter = tokio::spawn(async move { seq2.wait_for(u64::MAX).await });
+        // Give the waiter a chance to register before advancing.
+        sleep(Duration::from_millis(10)).await;
+        seq.advance(u64::MAX);
+        assert_eq!(waiter.await.expect("waiter panicked"), Ok(()));
+    }
+}