diff --git a/Cargo.lock b/Cargo.lock index da38779279..072aebc032 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2761,5 +2761,4 @@ name = "zenith_utils" version = "0.1.0" dependencies = [ "thiserror", - "tokio", ] diff --git a/pageserver/src/page_cache.rs b/pageserver/src/page_cache.rs index 667b16fb90..aae2fcef77 100644 --- a/pageserver/src/page_cache.rs +++ b/pageserver/src/page_cache.rs @@ -62,7 +62,7 @@ pub struct PageCache { // WAL redo manager walredo_mgr: WalRedoManager, - // Allows .await on the arrival of a particular LSN. + // Allows waiting for the arrival of a particular LSN. seqwait_lsn: SeqWait, // Counters, for metrics collection. @@ -170,12 +170,7 @@ fn gc_thread_main(conf: &PageServerConf, timelineid: ZTimelineId) { info!("Garbage collection thread started {}", timelineid); let pcache = get_pagecache(conf, timelineid).unwrap(); - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - - runtime.block_on(pcache.do_gc(conf)).unwrap(); + pcache.do_gc(conf).unwrap(); } fn open_rocksdb(_conf: &PageServerConf, timelineid: ZTimelineId) -> rocksdb::DB { @@ -380,10 +375,10 @@ impl PageCache { /// /// Returns an 8k page image /// - pub async fn get_page_at_lsn(&self, tag: BufferTag, req_lsn: Lsn) -> anyhow::Result { + pub fn get_page_at_lsn(&self, tag: BufferTag, req_lsn: Lsn) -> anyhow::Result { self.num_getpage_requests.fetch_add(1, Ordering::Relaxed); - let lsn = self.wait_lsn(req_lsn).await?; + let lsn = self.wait_lsn(req_lsn)?; // Look up cache entry. If it's a page image, return that. If it's a WAL record, // ask the WAL redo service to reconstruct the page image from the WAL records. @@ -409,7 +404,7 @@ impl PageCache { page_img = img.clone(); } else if content.wal_record.is_some() { // Request the WAL redo manager to apply the WAL records for us. - page_img = self.walredo_mgr.request_redo(tag, lsn).await?; + page_img = self.walredo_mgr.request_redo(tag, lsn)?; } else { // No base image, and no WAL record. 
Huh? bail!("no page image or WAL record for requested page"); @@ -441,16 +436,16 @@ impl PageCache { /// /// Get size of relation at given LSN. /// - pub async fn relsize_get(&self, rel: &RelTag, lsn: Lsn) -> anyhow::Result { - self.wait_lsn(lsn).await?; + pub fn relsize_get(&self, rel: &RelTag, lsn: Lsn) -> anyhow::Result { + self.wait_lsn(lsn)?; return self.relsize_get_nowait(rel, lsn); } /// /// Does relation exist at given LSN? /// - pub async fn relsize_exist(&self, rel: &RelTag, req_lsn: Lsn) -> anyhow::Result { - let lsn = self.wait_lsn(req_lsn).await?; + pub fn relsize_exist(&self, rel: &RelTag, req_lsn: Lsn) -> anyhow::Result { + let lsn = self.wait_lsn(req_lsn)?; let key = CacheKey { tag: BufferTag { @@ -815,7 +810,7 @@ impl PageCache { Ok(0) } - async fn do_gc(&self, conf: &PageServerConf) -> anyhow::Result { + fn do_gc(&self, conf: &PageServerConf) -> anyhow::Result { let mut buf = BytesMut::new(); loop { thread::sleep(conf.gc_period); @@ -867,7 +862,7 @@ impl PageCache { if (v[0] & PAGE_IMAGE_FLAG) == 0 { trace!("Reconstruct most recent page {:?}", key); // force reconstruction of most recent page version - self.walredo_mgr.request_redo(key.tag, key.lsn).await?; + self.walredo_mgr.request_redo(key.tag, key.lsn)?; reconstructed += 1; } @@ -887,7 +882,7 @@ impl PageCache { let v = iter.value().unwrap(); if (v[0] & PAGE_IMAGE_FLAG) == 0 { trace!("Reconstruct horizon page {:?}", key); - self.walredo_mgr.request_redo(key.tag, key.lsn).await?; + self.walredo_mgr.request_redo(key.tag, key.lsn)?; truncated += 1; } } @@ -930,7 +925,7 @@ impl PageCache { // // Wait until WAL has been received up to the given LSN. // - async fn wait_lsn(&self, mut lsn: Lsn) -> anyhow::Result { + fn wait_lsn(&self, mut lsn: Lsn) -> anyhow::Result { // When invalid LSN is requested, it means "don't wait, return latest version of the page" // This is necessary for bootstrap. 
if lsn == Lsn(0) { @@ -945,7 +940,6 @@ impl PageCache { self.seqwait_lsn .wait_for_timeout(lsn, TIMEOUT) - .await .with_context(|| { format!( "Timed out while waiting for WAL record at LSN {} to arrive", diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 72f97aaaa7..2721aa4873 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -10,21 +10,16 @@ // *callmemaybe $url* -- ask pageserver to start walreceiver on $url // -use byteorder::{BigEndian, ByteOrder}; +use byteorder::{ReadBytesExt, WriteBytesExt, BE}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use log::*; use regex::Regex; use std::io; +use std::io::{BufReader, BufWriter, Read, Write}; +use std::net::{TcpListener, TcpStream}; use std::str::FromStr; -use std::sync::Arc; use std::thread; use std::time::Duration; -use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter}; -use tokio::net::{TcpListener, TcpStream}; -use tokio::runtime; -use tokio::runtime::Runtime; -use tokio::sync::mpsc; -use tokio::task; use zenith_utils::lsn::Lsn; use crate::basebackup; @@ -116,26 +111,41 @@ enum StartupRequestCode { } impl FeStartupMessage { - pub fn parse(buf: &mut BytesMut) -> Result> { + pub fn read(stream: &mut dyn std::io::Read) -> Result> { const MAX_STARTUP_PACKET_LENGTH: u32 = 10000; const CANCEL_REQUEST_CODE: u32 = (1234 << 16) | 5678; const NEGOTIATE_SSL_CODE: u32 = (1234 << 16) | 5679; const NEGOTIATE_GSS_CODE: u32 = (1234 << 16) | 5680; - if buf.len() < 4 { - return Ok(None); - } - let len = BigEndian::read_u32(&buf[0..4]); - + // Read length. If the connection is closed before reading anything (or before + // reading 4 bytes, to be precise), return None to indicate that the connection + // was closed. This matches the PostgreSQL server's behavior, which avoids noise + // in the log if the client opens connection but closes it immediately. 
+ let len = match stream.read_u32::() { + Ok(len) => len, + Err(err) => { + if err.kind() == std::io::ErrorKind::UnexpectedEof { + return Ok(None); + } else { + return Err(err); + } + } + }; if len < 4 || len as u32 > MAX_STARTUP_PACKET_LENGTH { return Err(io::Error::new( io::ErrorKind::InvalidData, "invalid message length", )); } + let bodylen = len - 4; - let version = BigEndian::read_u32(&buf[4..8]); + // Read the rest of the startup packet + let mut body_buf: Vec = vec![0; bodylen as usize]; + stream.read_exact(&mut body_buf)?; + let mut body = Bytes::from(body_buf); + // Parse the first field, which indicates what kind of a packet it is + let version = body.get_u32(); let kind = match version { CANCEL_REQUEST_CODE => StartupRequestCode::Cancel, NEGOTIATE_SSL_CODE => StartupRequestCode::NegotiateSsl, @@ -143,7 +153,8 @@ impl FeStartupMessage { _ => StartupRequestCode::Normal, }; - buf.advance(len as usize); + // Ignore the rest of the packet + Ok(Some(FeMessage::StartupMessage(FeStartupMessage { version, kind, @@ -328,35 +339,38 @@ impl FeCloseMessage { } impl FeMessage { - pub fn parse(buf: &mut BytesMut) -> Result> { - if buf.len() < 5 { - let to_read = 5 - buf.len(); - buf.reserve(to_read); - return Ok(None); - } - - let tag = buf[0]; - let len = BigEndian::read_u32(&buf[1..5]); + pub fn read(stream: &mut dyn Read) -> Result> { + // Each libpq message begins with a message type byte, followed by message length + // If the client closes the connection, return None. But if the client closes the + // connection in the middle of a message, we will return an error. 
+ let tag = match stream.read_u8() { + Ok(b) => b, + Err(err) => { + if err.kind() == std::io::ErrorKind::UnexpectedEof { + return Ok(None); + } else { + return Err(err); + } + } + }; + let len = stream.read_u32::()?; + // The message length includes itself, so it better be at least 4 if len < 4 { return Err(io::Error::new( io::ErrorKind::InvalidInput, "invalid message length: parsing u32", )); } + let bodylen = len - 4; - let total_len = len as usize + 1; - if buf.len() < total_len { - let to_read = total_len - buf.len(); - buf.reserve(to_read); - return Ok(None); - } + // Read message body + let mut body_buf: Vec = vec![0; bodylen as usize]; + stream.read_exact(&mut body_buf)?; - let mut body = buf.split_to(total_len); - body.advance(5); - - let mut body = body.freeze(); + let mut body = Bytes::from(body_buf); + // Parse it match tag { b'Q' => Ok(Some(FeMessage::Query(FeQueryMessage { body }))), b'P' => Ok(Some(FeParseMessage::parse(body)?)), @@ -385,13 +399,13 @@ impl FeMessage { 2 => Ok(Some(FeMessage::ZenithReadRequest(zreq))), _ => Err(io::Error::new( io::ErrorKind::InvalidInput, - format!("unknown smgr message tag: {},'{:?}'", smgr_tag, buf), + format!("unknown smgr message tag: {},'{:?}'", smgr_tag, body), )), } } tag => Err(io::Error::new( io::ErrorKind::InvalidInput, - format!("unknown message tag: {},'{:?}'", tag, buf), + format!("unknown message tag: {},'{:?}'", tag, body), )), } } @@ -399,152 +413,118 @@ impl FeMessage { /////////////////////////////////////////////////////////////////////////////// +/// +/// Main loop of the page service. +/// +/// Listens for connections, and launches a new handler thread for each. 
+/// pub fn thread_main(conf: &PageServerConf) { - // Create a new thread pool - // - // FIXME: It would be nice to keep this single-threaded for debugging purposes, - // but that currently leads to a deadlock: if a GetPage@LSN request arrives - // for an LSN that hasn't been received yet, the thread gets stuck waiting for - // the WAL to arrive. If the WAL receiver hasn't been launched yet, i.e - // we haven't received a "callmemaybe" request yet to tell us where to get the - // WAL, we will not have a thread available to process the "callmemaybe" - // request when it does arrive. Using a thread pool alleviates the problem so - // that it doesn't happen in the tests anymore, but in principle it could still - // happen if we receive enough GetPage@LSN requests to consume all of the - // available threads. - //let runtime = runtime::Builder::new_current_thread().enable_all().build().unwrap(); - let runtime = runtime::Runtime::new().unwrap(); - info!("Starting page server on {}", conf.listen_addr); - let runtime_ref = Arc::new(runtime); + let listener = TcpListener::bind(conf.listen_addr).unwrap(); - runtime_ref.block_on(async { - let listener = TcpListener::bind(conf.listen_addr).await.unwrap(); + loop { + let (socket, peer_addr) = listener.accept().unwrap(); + debug!("accepted connection from {}", peer_addr); + socket.set_nodelay(true).unwrap(); + let mut conn_handler = Connection::new(conf.clone(), socket); - loop { - let (socket, peer_addr) = listener.accept().await.unwrap(); - debug!("accepted connection from {}", peer_addr); - socket.set_nodelay(true).unwrap(); - let mut conn_handler = Connection::new(conf.clone(), socket, &runtime_ref); - - task::spawn(async move { - if let Err(err) = conn_handler.run().await { - error!("error: {}", err); - } - }); - } - }); + thread::spawn(move || { + if let Err(err) = conn_handler.run() { + error!("error: {}", err); + } + }); + } } #[derive(Debug)] struct Connection { + stream_in: BufReader, stream: BufWriter, buffer: 
BytesMut, init_done: bool, conf: PageServerConf, - runtime: Arc, } impl Connection { - pub fn new(conf: PageServerConf, socket: TcpStream, runtime: &Arc) -> Connection { + pub fn new(conf: PageServerConf, socket: TcpStream) -> Connection { Connection { + stream_in: BufReader::new(socket.try_clone().unwrap()), stream: BufWriter::new(socket), buffer: BytesMut::with_capacity(10 * 1024), init_done: false, conf, - runtime: Arc::clone(runtime), } } // // Read full message or return None if connection is closed // - async fn read_message(&mut self) -> Result> { - loop { - if let Some(message) = self.parse_message()? { - return Ok(Some(message)); - } - - if self.stream.read_buf(&mut self.buffer).await? == 0 { - if self.buffer.is_empty() { - return Ok(None); - } else { - return Err(io::Error::new( - io::ErrorKind::Other, - "connection reset by peer", - )); - } - } - } - } - - fn parse_message(&mut self) -> Result> { + fn read_message(&mut self) -> Result> { if !self.init_done { - FeStartupMessage::parse(&mut self.buffer) + FeStartupMessage::read(&mut self.stream_in) } else { - FeMessage::parse(&mut self.buffer) + FeMessage::read(&mut self.stream_in) } } - async fn write_message_noflush(&mut self, message: &BeMessage) -> io::Result<()> { + fn write_message_noflush(&mut self, message: &BeMessage) -> io::Result<()> { match message { BeMessage::AuthenticationOk => { - self.stream.write_u8(b'R').await?; - self.stream.write_i32(4 + 4).await?; - self.stream.write_i32(0).await?; + self.stream.write_u8(b'R')?; + self.stream.write_i32::(4 + 4)?; + self.stream.write_i32::(0)?; } BeMessage::ReadyForQuery => { - self.stream.write_u8(b'Z').await?; - self.stream.write_i32(4 + 1).await?; - self.stream.write_u8(b'I').await?; + self.stream.write_u8(b'Z')?; + self.stream.write_i32::(4 + 1)?; + self.stream.write_u8(b'I')?; } BeMessage::ParseComplete => { - self.stream.write_u8(b'1').await?; - self.stream.write_i32(4).await?; + self.stream.write_u8(b'1')?; + self.stream.write_i32::(4)?; } 
BeMessage::BindComplete => { - self.stream.write_u8(b'2').await?; - self.stream.write_i32(4).await?; + self.stream.write_u8(b'2')?; + self.stream.write_i32::(4)?; } BeMessage::CloseComplete => { - self.stream.write_u8(b'3').await?; - self.stream.write_i32(4).await?; + self.stream.write_u8(b'3')?; + self.stream.write_i32::(4)?; } BeMessage::NoData => { - self.stream.write_u8(b'n').await?; - self.stream.write_i32(4).await?; + self.stream.write_u8(b'n')?; + self.stream.write_i32::(4)?; } BeMessage::ParameterDescription => { - self.stream.write_u8(b't').await?; - self.stream.write_i32(6).await?; + self.stream.write_u8(b't')?; + self.stream.write_i32::(6)?; // we don't support params, so always 0 - self.stream.write_i16(0).await?; + self.stream.write_i16::(0)?; } BeMessage::RowDescription => { // XXX let b = Bytes::from("data\0"); - self.stream.write_u8(b'T').await?; + self.stream.write_u8(b'T')?; self.stream - .write_i32(4 + 2 + b.len() as i32 + 3 * (4 + 2)) - .await?; + .write_i32::(4 + 2 + b.len() as i32 + 3 * (4 + 2))?; - self.stream.write_i16(1).await?; - self.stream.write_all(&b).await?; - self.stream.write_i32(0).await?; /* table oid */ - self.stream.write_i16(0).await?; /* attnum */ - self.stream.write_i32(25).await?; /* TEXTOID */ - self.stream.write_i16(-1).await?; /* typlen */ - self.stream.write_i32(0).await?; /* typmod */ - self.stream.write_i16(0).await?; /* format code */ + self.stream.write_i16::(1)?; + self.stream.write_all(&b)?; + self.stream.write_i32::(0)?; /* table oid */ + self.stream.write_i16::(0)?; /* attnum */ + self.stream.write_i32::(25)?; /* TEXTOID */ + self.stream.write_i16::(-1)?; /* typlen */ + self.stream.write_i32::(0)?; /* typmod */ + self.stream.write_i16::(0)?; /* format code */ } // XXX: accept some text data @@ -552,74 +532,73 @@ impl Connection { // XXX let b = Bytes::from("hello world"); - self.stream.write_u8(b'D').await?; - self.stream.write_i32(4 + 2 + 4 + b.len() as i32).await?; + self.stream.write_u8(b'D')?; + 
self.stream.write_i32::(4 + 2 + 4 + b.len() as i32)?; - self.stream.write_i16(1).await?; - self.stream.write_i32(b.len() as i32).await?; - self.stream.write_all(&b).await?; + self.stream.write_i16::(1)?; + self.stream.write_i32::(b.len() as i32)?; + self.stream.write_all(&b)?; } BeMessage::ControlFile => { // TODO pass checkpoint and xid info in this message let b = Bytes::from("hello pg_control"); - self.stream.write_u8(b'D').await?; - self.stream.write_i32(4 + 2 + 4 + b.len() as i32).await?; + self.stream.write_u8(b'D')?; + self.stream.write_i32::(4 + 2 + 4 + b.len() as i32)?; - self.stream.write_i16(1).await?; - self.stream.write_i32(b.len() as i32).await?; - self.stream.write_all(&b).await?; + self.stream.write_i16::(1)?; + self.stream.write_i32::(b.len() as i32)?; + self.stream.write_all(&b)?; } BeMessage::CommandComplete => { let b = Bytes::from("SELECT 1\0"); - self.stream.write_u8(b'C').await?; - self.stream.write_i32(4 + b.len() as i32).await?; - self.stream.write_all(&b).await?; + self.stream.write_u8(b'C')?; + self.stream.write_i32::(4 + b.len() as i32)?; + self.stream.write_all(&b)?; } BeMessage::ZenithStatusResponse(resp) => { - self.stream.write_u8(b'd').await?; - self.stream.write_u32(4 + 1 + 1 + 4).await?; - self.stream.write_u8(100).await?; /* tag from pagestore_client.h */ - self.stream.write_u8(resp.ok as u8).await?; - self.stream.write_u32(resp.n_blocks).await?; + self.stream.write_u8(b'd')?; + self.stream.write_u32::(4 + 1 + 1 + 4)?; + self.stream.write_u8(100)?; /* tag from pagestore_client.h */ + self.stream.write_u8(resp.ok as u8)?; + self.stream.write_u32::(resp.n_blocks)?; } BeMessage::ZenithNblocksResponse(resp) => { - self.stream.write_u8(b'd').await?; - self.stream.write_u32(4 + 1 + 1 + 4).await?; - self.stream.write_u8(101).await?; /* tag from pagestore_client.h */ - self.stream.write_u8(resp.ok as u8).await?; - self.stream.write_u32(resp.n_blocks).await?; + self.stream.write_u8(b'd')?; + self.stream.write_u32::(4 + 1 + 1 + 4)?; + 
self.stream.write_u8(101)?; /* tag from pagestore_client.h */ + self.stream.write_u8(resp.ok as u8)?; + self.stream.write_u32::(resp.n_blocks)?; } BeMessage::ZenithReadResponse(resp) => { - self.stream.write_u8(b'd').await?; + self.stream.write_u8(b'd')?; self.stream - .write_u32(4 + 1 + 1 + 4 + resp.page.len() as u32) - .await?; - self.stream.write_u8(102).await?; /* tag from pagestore_client.h */ - self.stream.write_u8(resp.ok as u8).await?; - self.stream.write_u32(resp.n_blocks).await?; - self.stream.write_all(&resp.page.clone()).await?; + .write_u32::(4 + 1 + 1 + 4 + resp.page.len() as u32)?; + self.stream.write_u8(102)?; /* tag from pagestore_client.h */ + self.stream.write_u8(resp.ok as u8)?; + self.stream.write_u32::(resp.n_blocks)?; + self.stream.write_all(&resp.page.clone())?; } } Ok(()) } - async fn write_message(&mut self, message: &BeMessage) -> io::Result<()> { - self.write_message_noflush(message).await?; - self.stream.flush().await + fn write_message(&mut self, message: &BeMessage) -> io::Result<()> { + self.write_message_noflush(message)?; + self.stream.flush() } - async fn run(&mut self) -> Result<()> { + fn run(&mut self) -> Result<()> { let mut unnamed_query_string = Bytes::new(); loop { - let msg = self.read_message().await?; + let msg = self.read_message()?; trace!("got message {:?}", msg); match msg { Some(FeMessage::StartupMessage(m)) => { @@ -628,41 +607,39 @@ impl Connection { match m.kind { StartupRequestCode::NegotiateGss | StartupRequestCode::NegotiateSsl => { let b = Bytes::from("N"); - self.stream.write_all(&b).await?; - self.stream.flush().await?; + self.stream.write_all(&b)?; + self.stream.flush()?; } StartupRequestCode::Normal => { - self.write_message_noflush(&BeMessage::AuthenticationOk) - .await?; - self.write_message(&BeMessage::ReadyForQuery).await?; + self.write_message_noflush(&BeMessage::AuthenticationOk)?; + self.write_message(&BeMessage::ReadyForQuery)?; self.init_done = true; } StartupRequestCode::Cancel => return Ok(()), 
} } Some(FeMessage::Query(m)) => { - self.process_query(m.body).await?; + self.process_query(m.body)?; } Some(FeMessage::Parse(m)) => { unnamed_query_string = m.query_string; - self.write_message(&BeMessage::ParseComplete).await?; + self.write_message(&BeMessage::ParseComplete)?; } Some(FeMessage::Describe(_)) => { - self.write_message_noflush(&BeMessage::ParameterDescription) - .await?; - self.write_message(&BeMessage::NoData).await?; + self.write_message_noflush(&BeMessage::ParameterDescription)?; + self.write_message(&BeMessage::NoData)?; } Some(FeMessage::Bind(_)) => { - self.write_message(&BeMessage::BindComplete).await?; + self.write_message(&BeMessage::BindComplete)?; } Some(FeMessage::Close(_)) => { - self.write_message(&BeMessage::CloseComplete).await?; + self.write_message(&BeMessage::CloseComplete)?; } Some(FeMessage::Execute(_)) => { - self.process_query(unnamed_query_string.clone()).await?; + self.process_query(unnamed_query_string.clone())?; } Some(FeMessage::Sync) => { - self.write_message(&BeMessage::ReadyForQuery).await?; + self.write_message(&BeMessage::ReadyForQuery)?; } Some(FeMessage::Terminate) => { break; @@ -681,7 +658,7 @@ impl Connection { Ok(()) } - async fn process_query(&mut self, query_string: Bytes) -> Result<()> { + fn process_query(&mut self, query_string: Bytes) -> Result<()> { debug!("process query {:?}", query_string); // remove null terminator, if any @@ -691,13 +668,13 @@ impl Connection { } if query_string.starts_with(b"controlfile") { - self.handle_controlfile().await + self.handle_controlfile() } else if query_string.starts_with(b"pagestream ") { let (_l, r) = query_string.split_at("pagestream ".len()); let timelineid_str = String::from_utf8(r.to_vec()).unwrap(); let timelineid = ZTimelineId::from_str(&timelineid_str).unwrap(); - self.handle_pagerequests(timelineid).await + self.handle_pagerequests(timelineid) } else if query_string.starts_with(b"basebackup ") { let (_l, r) = query_string.split_at("basebackup ".len()); let r 
= r.to_vec(); @@ -706,10 +683,9 @@ impl Connection { let timelineid = ZTimelineId::from_str(&timelineid_str).unwrap(); // Check that the timeline exists - self.handle_basebackup_request(timelineid).await?; - self.write_message_noflush(&BeMessage::CommandComplete) - .await?; - self.write_message(&BeMessage::ReadyForQuery).await + self.handle_basebackup_request(timelineid)?; + self.write_message_noflush(&BeMessage::CommandComplete)?; + self.write_message(&BeMessage::ReadyForQuery) } else if query_string.starts_with(b"callmemaybe ") { let query_str = String::from_utf8(query_string.to_vec()) .unwrap() @@ -733,36 +709,29 @@ impl Connection { walreceiver::launch_wal_receiver(&self.conf, timelineid, &connstr); - self.write_message_noflush(&BeMessage::CommandComplete) - .await?; - self.write_message(&BeMessage::ReadyForQuery).await + self.write_message_noflush(&BeMessage::CommandComplete)?; + self.write_message(&BeMessage::ReadyForQuery) } else if query_string.starts_with(b"status") { - self.write_message_noflush(&BeMessage::RowDescription) - .await?; - self.write_message_noflush(&BeMessage::DataRow).await?; - self.write_message_noflush(&BeMessage::CommandComplete) - .await?; - self.write_message(&BeMessage::ReadyForQuery).await + self.write_message_noflush(&BeMessage::RowDescription)?; + self.write_message_noflush(&BeMessage::DataRow)?; + self.write_message_noflush(&BeMessage::CommandComplete)?; + self.write_message(&BeMessage::ReadyForQuery) } else { - self.write_message_noflush(&BeMessage::RowDescription) - .await?; - self.write_message_noflush(&BeMessage::DataRow).await?; - self.write_message_noflush(&BeMessage::CommandComplete) - .await?; - self.write_message(&BeMessage::ReadyForQuery).await + self.write_message_noflush(&BeMessage::RowDescription)?; + self.write_message_noflush(&BeMessage::DataRow)?; + self.write_message_noflush(&BeMessage::CommandComplete)?; + self.write_message(&BeMessage::ReadyForQuery) } } - async fn handle_controlfile(&mut self) -> Result<()> { - 
self.write_message_noflush(&BeMessage::RowDescription) - .await?; - self.write_message_noflush(&BeMessage::ControlFile).await?; - self.write_message_noflush(&BeMessage::CommandComplete) - .await?; - self.write_message(&BeMessage::ReadyForQuery).await + fn handle_controlfile(&mut self) -> Result<()> { + self.write_message_noflush(&BeMessage::RowDescription)?; + self.write_message_noflush(&BeMessage::ControlFile)?; + self.write_message_noflush(&BeMessage::CommandComplete)?; + self.write_message(&BeMessage::ReadyForQuery) } - async fn handle_pagerequests(&mut self, timelineid: ZTimelineId) -> Result<()> { + fn handle_pagerequests(&mut self, timelineid: ZTimelineId) -> Result<()> { // Check that the timeline exists let pcache = page_cache::get_or_restore_pagecache(&self.conf, timelineid); if pcache.is_err() { @@ -773,14 +742,14 @@ impl Connection { let pcache = pcache.unwrap(); /* switch client to COPYBOTH */ - self.stream.write_u8(b'W').await?; - self.stream.write_i32(4 + 1 + 2).await?; - self.stream.write_u8(0).await?; /* copy_is_binary */ - self.stream.write_i16(0).await?; /* numAttributes */ - self.stream.flush().await?; + self.stream.write_u8(b'W')?; + self.stream.write_i32::(4 + 1 + 2)?; + self.stream.write_u8(0)?; /* copy_is_binary */ + self.stream.write_i16::(0)?; /* numAttributes */ + self.stream.flush()?; loop { - let message = self.read_message().await?; + let message = self.read_message()?; if let Some(m) = &message { trace!("query({:?}): {:?}", timelineid, m); @@ -800,13 +769,12 @@ impl Connection { forknum: req.forknum, }; - let exist = pcache.relsize_exist(&tag, req.lsn).await.unwrap_or(false); + let exist = pcache.relsize_exist(&tag, req.lsn).unwrap_or(false); self.write_message(&BeMessage::ZenithStatusResponse(ZenithStatusResponse { ok: exist, n_blocks: 0, - })) - .await? + }))? 
} Some(FeMessage::ZenithNblocksRequest(req)) => { let tag = page_cache::RelTag { @@ -816,13 +784,12 @@ impl Connection { forknum: req.forknum, }; - let n_blocks = pcache.relsize_get(&tag, req.lsn).await.unwrap_or(0); + let n_blocks = pcache.relsize_get(&tag, req.lsn).unwrap_or(0); self.write_message(&BeMessage::ZenithNblocksResponse(ZenithStatusResponse { ok: true, n_blocks, - })) - .await? + }))? } Some(FeMessage::ZenithReadRequest(req)) => { let buf_tag = page_cache::BufferTag { @@ -835,7 +802,7 @@ impl Connection { blknum: req.blkno, }; - let msg = match pcache.get_page_at_lsn(buf_tag, req.lsn).await { + let msg = match pcache.get_page_at_lsn(buf_tag, req.lsn) { Ok(p) => BeMessage::ZenithReadResponse(ZenithReadResponse { ok: true, n_blocks: 0, @@ -852,14 +819,14 @@ impl Connection { } }; - self.write_message(&msg).await? + self.write_message(&msg)? } _ => {} } } } - async fn handle_basebackup_request(&mut self, timelineid: ZTimelineId) -> Result<()> { + fn handle_basebackup_request(&mut self, timelineid: ZTimelineId) -> Result<()> { // check that the timeline exists let pcache = page_cache::get_or_restore_pagecache(&self.conf, timelineid); if pcache.is_err() { @@ -870,11 +837,11 @@ impl Connection { /* switch client to COPYOUT */ let stream = &mut self.stream; - stream.write_u8(b'H').await?; - stream.write_i32(4 + 1 + 2).await?; - stream.write_u8(0).await?; /* copy_is_binary */ - stream.write_i16(0).await?; /* numAttributes */ - stream.flush().await?; + stream.write_u8(b'H')?; + stream.write_i32::(4 + 1 + 2)?; + stream.write_u8(0)?; /* copy_is_binary */ + stream.write_i16::(0)?; /* numAttributes */ + stream.flush()?; info!("sent CopyOut"); /* Send a tarball of the latest snapshot on the timeline */ @@ -882,49 +849,16 @@ impl Connection { // find latest snapshot let snapshotlsn = restore_local_repo::find_latest_snapshot(&self.conf, timelineid).unwrap(); - // Stream it - let (s, mut r) = mpsc::channel(5); - - let f_tar = task::spawn_blocking(move || { - 
basebackup::send_snapshot_tarball(&mut CopyDataSink(s), timelineid, snapshotlsn)?; - Ok(()) - }); - let f_tar2 = async { - let joinres = f_tar.await; - - if let Err(joinreserr) = joinres { - return Err(io::Error::new(io::ErrorKind::InvalidData, joinreserr)); - } - joinres.unwrap() - }; - - let f_pump = async move { - loop { - let buf = r.recv().await; - if buf.is_none() { - break; - } - let buf = buf.unwrap(); - - // CopyData - stream.write_u8(b'd').await?; - stream.write_u32((4 + buf.len()) as u32).await?; - stream.write_all(&buf).await?; - trace!("CopyData sent for {} bytes!", buf.len()); - - // FIXME: flush isn't really required, but makes it easier - // to view in wireshark - stream.flush().await?; - } - Ok(()) - }; - - tokio::try_join!(f_tar2, f_pump)?; + basebackup::send_snapshot_tarball( + &mut CopyDataSink { stream: stream }, + timelineid, + snapshotlsn, + )?; // CopyDone - self.stream.write_u8(b'c').await?; - self.stream.write_u32(4).await?; - self.stream.flush().await?; + self.stream.write_u8(b'c')?; + self.stream.write_u32::(4)?; + self.stream.flush()?; debug!("CopyDone sent!"); // FIXME: I'm getting an error from the tokio copyout driver without this. @@ -936,15 +870,28 @@ impl Connection { } } -struct CopyDataSink(mpsc::Sender); +/// +/// A std::io::Write implementation that wraps all data written to it in CopyData +/// messages. +/// +struct CopyDataSink<'a> { + stream: &'a mut BufWriter, +} -impl std::io::Write for CopyDataSink { +impl<'a> std::io::Write for CopyDataSink<'a> { fn write(&mut self, data: &[u8]) -> std::result::Result { - let buf = Bytes::copy_from_slice(data); + // CopyData + // FIXME: if the input is large, we should split it into multiple messages. + // Not sure what the threshold should be, but the ultimate hard limit is that + // the length cannot exceed u32. 
+ self.stream.write_u8(b'd')?; + self.stream.write_u32::((4 + data.len()) as u32)?; + self.stream.write_all(&data)?; + trace!("CopyData sent for {} bytes!", data.len()); - if let Err(e) = self.0.blocking_send(buf) { - return Err(io::Error::new(io::ErrorKind::Other, e)); - } + // FIXME: flush isn't really required, but makes it easier + // to view in wireshark + self.stream.flush()?; Ok(data.len()) } diff --git a/pageserver/src/walreceiver.rs b/pageserver/src/walreceiver.rs index 5ef5f1cf02..3ab75ee02c 100644 --- a/pageserver/src/walreceiver.rs +++ b/pageserver/src/walreceiver.rs @@ -26,8 +26,9 @@ use std::path::PathBuf; use std::str::FromStr; use std::sync::Mutex; use std::thread; -use tokio::runtime; -use tokio::time::{sleep, Duration}; +use std::thread::sleep; +use std::time::Duration; +use tokio::runtime::Runtime; use tokio_postgres::replication::{PgTimestamp, ReplicationStream}; use tokio_postgres::{NoTls, SimpleQueryMessage, SimpleQueryRow}; use tokio_stream::StreamExt; @@ -95,30 +96,38 @@ fn thread_main(conf: &PageServerConf, timelineid: ZTimelineId) { timelineid ); - let runtime = runtime::Builder::new_current_thread() + // We need a tokio runtime to call the rust-postgres copy_both function. + // Most functions in the rust-postgres driver have a blocking wrapper, + // but copy_both does not (TODO: the copy_both support is still work-in-progress + // as of this writing. Check later if that has changed, or implement the + // wrapper ourselves in rust-postgres) + let runtime = tokio::runtime::Builder::new_current_thread() .enable_all() .build() .unwrap(); - runtime.block_on(async { - loop { - // Look up the current WAL producer address - let wal_producer_connstr = get_wal_producer_connstr(timelineid); + // + // Make a connection to the WAL safekeeper, or directly to the primary PostgreSQL server, + // and start streaming WAL from it. If the connection is lost, keep retrying. 
+ // + loop { + // Look up the current WAL producer address + let wal_producer_connstr = get_wal_producer_connstr(timelineid); - let res = walreceiver_main(conf, timelineid, &wal_producer_connstr).await; + let res = walreceiver_main(&runtime, conf, timelineid, &wal_producer_connstr); - if let Err(e) = res { - info!( - "WAL streaming connection failed ({}), retrying in 1 second", - e - ); - sleep(Duration::from_secs(1)).await; - } + if let Err(e) = res { + info!( + "WAL streaming connection failed ({}), retrying in 1 second", + e + ); + sleep(Duration::from_secs(1)); } - }); + } } -async fn walreceiver_main( +fn walreceiver_main( + runtime: &Runtime, conf: &PageServerConf, timelineid: ZTimelineId, wal_producer_connstr: &str, @@ -126,18 +135,19 @@ async fn walreceiver_main( // Connect to the database in replication mode. info!("connecting to {:?}", wal_producer_connstr); let connect_cfg = format!("{} replication=true", wal_producer_connstr); - let (rclient, connection) = tokio_postgres::connect(&connect_cfg, NoTls).await?; + + let (rclient, connection) = runtime.block_on(tokio_postgres::connect(&connect_cfg, NoTls))?; info!("connected!"); // The connection object performs the actual communication with the database, // so spawn it off to run on its own. 
- tokio::spawn(async move { + runtime.spawn(async move { if let Err(e) = connection.await { error!("connection error: {}", e); } }); - let identify = identify_system(&rclient).await?; + let identify = identify_system(runtime, &rclient)?; info!("{:?}", identify); let end_of_wal = Lsn::from(u64::from(identify.xlogpos)); let mut caught_up = false; @@ -174,14 +184,15 @@ async fn walreceiver_main( ); let query = format!("START_REPLICATION PHYSICAL {}", startpoint); - let copy_stream = rclient.copy_both_simple::(&query).await?; + + let copy_stream = runtime.block_on(rclient.copy_both_simple::(&query))?; let physical_stream = ReplicationStream::new(copy_stream); tokio::pin!(physical_stream); let mut waldecoder = WalStreamDecoder::new(startpoint); - while let Some(replication_message) = physical_stream.next().await { + while let Some(replication_message) = runtime.block_on(physical_stream.next()) { match replication_message? { ReplicationMessage::XLogData(xlog_data) => { // Pass the WAL data to the decoder, and see if we can decode @@ -309,10 +320,11 @@ async fn walreceiver_main( let ts = PgTimestamp::now()?; const NO_REPLY: u8 = 0u8; - physical_stream - .as_mut() - .standby_status_update(write_lsn, flush_lsn, apply_lsn, ts, NO_REPLY) - .await?; + runtime.block_on( + physical_stream + .as_mut() + .standby_status_update(write_lsn, flush_lsn, apply_lsn, ts, NO_REPLY), + )?; } } _ => (), @@ -341,9 +353,12 @@ pub struct IdentifySystem { pub struct IdentifyError; /// Run the postgres `IDENTIFY_SYSTEM` command -pub async fn identify_system(client: &tokio_postgres::Client) -> Result { +pub fn identify_system( + runtime: &Runtime, + client: &tokio_postgres::Client, +) -> Result { let query_str = "IDENTIFY_SYSTEM"; - let response = client.simple_query(query_str).await?; + let response = runtime.block_on(client.simple_query(query_str))?; // get(N) from row, then parse it as some destination type. 
fn get_parse(row: &SimpleQueryRow, idx: usize) -> Result diff --git a/pageserver/src/walredo.rs b/pageserver/src/walredo.rs index f170e18008..abb965f1f6 100644 --- a/pageserver/src/walredo.rs +++ b/pageserver/src/walredo.rs @@ -24,13 +24,13 @@ use std::io::prelude::*; use std::io::Error; use std::path::PathBuf; use std::process::Stdio; +use std::sync::mpsc; use std::sync::{Arc, Mutex}; use std::time::Duration; use std::time::Instant; use tokio::io::AsyncBufReadExt; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::process::{Child, ChildStdin, ChildStdout, Command}; -use tokio::sync::{mpsc, oneshot}; use tokio::time::timeout; use zenith_utils::lsn::Lsn; @@ -52,8 +52,8 @@ pub struct WalRedoManager { conf: PageServerConf, timelineid: ZTimelineId, - request_tx: mpsc::UnboundedSender, - request_rx: Mutex>>, + request_tx: Mutex>, + request_rx: Mutex>>, } struct WalRedoManagerInternal { @@ -61,7 +61,7 @@ struct WalRedoManagerInternal { timelineid: ZTimelineId, pcache: Arc, - request_rx: mpsc::UnboundedReceiver, + request_rx: mpsc::Receiver, } #[derive(Debug)] @@ -69,7 +69,7 @@ struct WalRedoRequest { tag: BufferTag, lsn: Lsn, - response_channel: oneshot::Sender>, + response_channel: mpsc::Sender>, } /// An error happened in WAL redo @@ -89,12 +89,12 @@ impl WalRedoManager { /// This only initializes the struct. You need to call WalRedoManager::launch to /// start the thread that processes the requests. pub fn new(conf: &PageServerConf, timelineid: ZTimelineId) -> WalRedoManager { - let (tx, rx) = mpsc::unbounded_channel(); + let (tx, rx) = mpsc::channel(); WalRedoManager { conf: conf.clone(), timelineid, - request_tx: tx, + request_tx: Mutex::new(tx), request_rx: Mutex::new(Some(rx)), } } @@ -114,22 +114,13 @@ impl WalRedoManager { let _walredo_thread = std::thread::Builder::new() .name("WAL redo thread".into()) .spawn(move || { - // We block on waiting for requests on the walredo request channel, but - // use async I/O to communicate with the child process. 
Initialize the - // runtime for the async part. - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .unwrap(); - let mut internal = WalRedoManagerInternal { _conf: conf_copy, timelineid: timelineid, pcache: pcache, request_rx: request_rx, }; - - runtime.block_on(internal.wal_redo_main()); + internal.wal_redo_main(); }) .unwrap(); } @@ -138,9 +129,9 @@ impl WalRedoManager { /// Request the WAL redo manager to apply WAL records, to reconstruct the page image /// of the given page version. /// - pub async fn request_redo(&self, tag: BufferTag, lsn: Lsn) -> Result { + pub fn request_redo(&self, tag: BufferTag, lsn: Lsn) -> Result { // Create a channel where to receive the response - let (tx, rx) = oneshot::channel::>(); + let (tx, rx) = mpsc::channel::>(); let request = WalRedoRequest { tag, @@ -149,10 +140,12 @@ impl WalRedoManager { }; self.request_tx + .lock() + .unwrap() .send(request) .expect("could not send WAL redo request"); - rx.await + rx.recv() .expect("could not receive response to WAL redo request") } } @@ -164,9 +157,17 @@ impl WalRedoManagerInternal { // // Main entry point for the WAL applicator thread. // - async fn wal_redo_main(&mut self) { + fn wal_redo_main(&mut self) { info!("WAL redo thread started {}", self.timelineid); + // We block on waiting for requests on the walredo request channel, but + // use async I/O to communicate with the child process. Initialize the + // runtime for the async part. + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + // Loop forever, handling requests as they come. 
loop { let mut process: WalRedoProcess; @@ -174,7 +175,7 @@ impl WalRedoManagerInternal { info!("launching WAL redo postgres process {}", self.timelineid); - process = WalRedoProcess::launch(&datadir).await.unwrap(); + process = runtime.block_on(WalRedoProcess::launch(&datadir)).unwrap(); info!("WAL redo postgres started"); // Pretty arbitrarily, reuse the same Postgres process for 100000 requests. @@ -182,9 +183,9 @@ impl WalRedoManagerInternal { // using up all shared buffers in Postgres's shared buffer cache; we don't // want to write any pages to disk in the WAL redo process. for _i in 1..100000 { - let request = self.request_rx.recv().await.unwrap(); + let request = self.request_rx.recv().unwrap(); - let result = self.handle_apply_request(&process, &request).await; + let result = runtime.block_on(self.handle_apply_request(&process, &request)); let result_ok = result.is_ok(); // Send the result to the requester @@ -202,11 +203,13 @@ impl WalRedoManagerInternal { // Time to kill the 'postgres' process. A new one will be launched on next // iteration of the loop. + // Drive the shutdown and wait futures with block_on: a future that is never polled does nothing. + // TODO: SIGKILL if needed info!("killing WAL redo postgres process"); - let _ = process.stdin.get_mut().shutdown().await; + let _ = runtime.block_on(process.stdin.get_mut().shutdown()); let mut child = process.child; drop(process.stdin); - let _ = child.wait().await; + let _ = runtime.block_on(child.wait()); } } @@ -441,6 +444,13 @@ impl WalRedoProcess { let mut stdin = self.stdin.borrow_mut(); let mut stdout = self.stdout.borrow_mut(); + // We do three things simultaneously: send the old base image and WAL records to + // the child process's stdin, read the result from child's stdout, and forward any logging + // information that the child writes to its stderr to the page server's log. + // + // 'f_stdin' handles writing the base image and WAL records to the child process. + // 'f_stdout' below reads the result back.
And 'f_stderr', which was spawned into the + // tokio runtime in the 'launch' function already, forwards the logging. let f_stdin = async { // Send base image, if any. (If the record initializes the page, previous page // version is not needed.) @@ -487,10 +497,6 @@ impl WalRedoProcess { Ok::<[u8; 8192], Error>(buf) }; - // Kill the process. This closes its stdin, which should signal the process - // to terminate. TODO: SIGKILL if needed - //child.wait(); - let res = futures::try_join!(f_stdout, f_stdin)?; let buf = res.0; diff --git a/zenith_utils/Cargo.toml b/zenith_utils/Cargo.toml index a26a772c97..ee549ab2f9 100644 --- a/zenith_utils/Cargo.toml +++ b/zenith_utils/Cargo.toml @@ -5,8 +5,4 @@ authors = ["Eric Seppanen "] edition = "2018" [dependencies] -tokio = { version = "1.5", features = ["sync", "time" ] } thiserror = "1" - -[dev-dependencies] -tokio = { version = "1.5", features = ["macros", "rt"] } diff --git a/zenith_utils/src/seqwait.rs b/zenith_utils/src/seqwait.rs index b4f3cdd454..409090256e 100644 --- a/zenith_utils/src/seqwait.rs +++ b/zenith_utils/src/seqwait.rs @@ -1,12 +1,12 @@ #![warn(missing_docs)] -use std::collections::BTreeMap; +use std::cmp::{Eq, Ordering, PartialOrd}; +use std::collections::BinaryHeap; use std::fmt::Debug; use std::mem; +use std::sync::mpsc::{channel, Receiver, Sender}; use std::sync::Mutex; use std::time::Duration; -use tokio::sync::watch::{channel, Receiver, Sender}; -use tokio::time::timeout; /// An error happened while waiting for a number #[derive(Debug, PartialEq, thiserror::Error)] @@ -23,14 +23,44 @@ struct SeqWaitInt where T: Ord, { - waiters: BTreeMap, Receiver<()>)>, + waiters: BinaryHeap>, current: T, shutdown: bool, } +struct Waiter +where + T: Ord, +{ + wake_num: T, // wake me when this number arrives ... + wake_channel: Sender<()>, // ... by sending a message to this channel +} + +// BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here +// to get that. 
+impl PartialOrd for Waiter { + fn partial_cmp(&self, other: &Self) -> Option { + other.wake_num.partial_cmp(&self.wake_num) + } +} + +impl Ord for Waiter { + fn cmp(&self, other: &Self) -> Ordering { + other.wake_num.cmp(&self.wake_num) + } +} + +impl PartialEq for Waiter { + fn eq(&self, other: &Self) -> bool { + other.wake_num == self.wake_num + } +} + +impl Eq for Waiter {} + /// A tool for waiting on a sequence number /// -/// This provides a way to await the arrival of a number. +/// This provides a way to wait for the arrival of a number. /// As soon as the number arrives by another caller calling /// [`advance`], then the waiter will be woken up. /// @@ -56,7 +86,7 @@ where /// Create a new `SeqWait`, initialized to a particular number pub fn new(starting_num: T) -> Self { let internal = SeqWaitInt { - waiters: BTreeMap::new(), + waiters: BinaryHeap::new(), current: starting_num, shutdown: false, }; @@ -92,29 +122,12 @@ where /// /// This call won't complete until someone has called `advance` /// with a number greater than or equal to the one we're waiting for. - pub async fn wait_for(&self, num: T) -> Result<(), SeqWaitError> { - let mut rx = { - let mut internal = self.internal.lock().unwrap(); - if internal.current >= num { - return Ok(()); - } - if internal.shutdown { - return Err(SeqWaitError::Shutdown); - } - - // If we already have a channel for waiting on this number, reuse it. - if let Some((_, rx)) = internal.waiters.get_mut(&num) { - // an Err from changed() means the sender was dropped. - rx.clone() - } else { - // Create a new channel. - let (tx, rx) = channel(()); - internal.waiters.insert(num, (tx, rx.clone())); - rx - } - // Drop the lock as we exit this scope.
- }; - rx.changed().await.map_err(|_| SeqWaitError::Shutdown) + pub fn wait_for(&self, num: T) -> Result<(), SeqWaitError> { + match self.queue_for_wait(num) { + Ok(None) => Ok(()), + Ok(Some(rx)) => rx.recv().map_err(|_| SeqWaitError::Shutdown), + Err(e) => Err(e), + } } /// Wait for a number to arrive @@ -124,14 +137,36 @@ where /// /// If that hasn't happened after the specified timeout duration, /// [`SeqWaitError::Timeout`] will be returned. - pub async fn wait_for_timeout( - &self, - num: T, - timeout_duration: Duration, - ) -> Result<(), SeqWaitError> { - timeout(timeout_duration, self.wait_for(num)) - .await - .unwrap_or(Err(SeqWaitError::Timeout)) + pub fn wait_for_timeout(&self, num: T, timeout_duration: Duration) -> Result<(), SeqWaitError> { + match self.queue_for_wait(num) { + Ok(None) => Ok(()), + Ok(Some(rx)) => rx.recv_timeout(timeout_duration).map_err(|e| match e { + std::sync::mpsc::RecvTimeoutError::Timeout => SeqWaitError::Timeout, + std::sync::mpsc::RecvTimeoutError::Disconnected => SeqWaitError::Shutdown, + }), + Err(e) => Err(e), + } + } + + /// Register and return a channel that will be notified when a number arrives, + /// or None, if it has already arrived. + fn queue_for_wait(&self, num: T) -> Result>, SeqWaitError> { + let mut internal = self.internal.lock().unwrap(); + if internal.current >= num { + return Ok(None); + } + if internal.shutdown { + return Err(SeqWaitError::Shutdown); + } + + // Create a new channel. + let (tx, rx) = channel(); + internal.waiters.push(Waiter { + wake_num: num, + wake_channel: tx, + }); + // Drop the lock as we exit this scope. + Ok(Some(rx)) } /// Announce a new number has arrived @@ -152,22 +187,19 @@ where } internal.current = num; - // split_off will give me all the high-numbered waiters, - // so split and then swap. Everything at or above `num` - // stays. 
- let mut split = internal.waiters.split_off(&num); - std::mem::swap(&mut split, &mut internal.waiters); - - // `split_at` didn't get the value at `num`; if it's - // there take that too. - if let Some(sleeper) = internal.waiters.remove(&num) { - split.insert(num, sleeper); + // Pop all waiters <= num from the heap. Collect them in a vector, and + // wake them up after releasing the lock. + let mut wake_these = Vec::new(); + while let Some(n) = internal.waiters.peek() { + if n.wake_num > num { + break; + } + wake_these.push(internal.waiters.pop().unwrap().wake_channel); } - - split + wake_these }; - for (_wake_num, (tx, _rx)) in wake_these { + for tx in wake_these { // This can fail if there are no receivers. // We don't care; discard the error. let _ = tx.send(()); @@ -179,38 +211,40 @@ where mod tests { use super::*; use std::sync::Arc; - use tokio::time::{sleep, Duration}; + use std::thread::sleep; + use std::thread::spawn; + use std::time::Duration; - #[tokio::test] - async fn seqwait() { + #[test] + fn seqwait() { let seq = Arc::new(SeqWait::new(0)); let seq2 = Arc::clone(&seq); let seq3 = Arc::clone(&seq); - tokio::spawn(async move { - seq2.wait_for(42).await.expect("wait_for 42"); + spawn(move || { + seq2.wait_for(42).expect("wait_for 42"); seq2.advance(100); - seq2.wait_for(999).await.expect_err("no 999"); + seq2.wait_for(999).expect_err("no 999"); }); - tokio::spawn(async move { - seq3.wait_for(42).await.expect("wait_for 42"); - seq3.wait_for(0).await.expect("wait_for 0"); + spawn(move || { + seq3.wait_for(42).expect("wait_for 42"); + seq3.wait_for(0).expect("wait_for 0"); }); - sleep(Duration::from_secs(1)).await; + sleep(Duration::from_secs(1)); seq.advance(99); - seq.wait_for(100).await.expect("wait_for 100"); + seq.wait_for(100).expect("wait_for 100"); seq.shutdown(); } - #[tokio::test] - async fn seqwait_timeout() { + #[test] + fn seqwait_timeout() { let seq = Arc::new(SeqWait::new(0)); let seq2 = Arc::clone(&seq); - tokio::spawn(async move { + 
spawn(move || { let timeout = Duration::from_millis(1); - let res = seq2.wait_for_timeout(42, timeout).await; + let res = seq2.wait_for_timeout(42, timeout); assert_eq!(res, Err(SeqWaitError::Timeout)); }); - sleep(Duration::from_secs(1)).await; + sleep(Duration::from_secs(1)); // This will attempt to wake, but nothing will happen // because the waiter already dropped its Receiver. seq.advance(99);