diff --git a/Cargo.lock b/Cargo.lock
index c7af140f7d..9196015057 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -244,6 +244,17 @@ dependencies = [
  "syn 2.0.52",
 ]
 
+[[package]]
+name = "async-timer"
+version = "0.7.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ba5fa6ed76cb2aa820707b4eb9ec46f42da9ce70b0eafab5e5e34942b38a44d5"
+dependencies = [
+ "libc",
+ "wasm-bindgen",
+ "winapi",
+]
+
 [[package]]
 name = "async-trait"
 version = "0.1.68"
@@ -3590,6 +3601,7 @@ dependencies = [
  "arc-swap",
  "async-compression",
  "async-stream",
+ "async-timer",
  "bit_field",
  "byteorder",
  "bytes",
diff --git a/Cargo.toml b/Cargo.toml
index dbda930535..d74efd51f8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -47,6 +47,7 @@ anyhow = { version = "1.0", features = ["backtrace"] }
 arc-swap = "1.6"
 async-compression = { version = "0.4.0", features = ["tokio", "gzip", "zstd"] }
 atomic-take = "1.1.0"
+async-timer = "0.7.4"
 azure_core = { version = "0.19", default-features = false, features = ["enable_reqwest_rustls", "hmac_rust"] }
 azure_identity = { version = "0.19", default-features = false, features = ["enable_reqwest_rustls"] }
 azure_storage = { version = "0.19", default-features = false, features = ["enable_reqwest_rustls"] }
diff --git a/pageserver/Cargo.toml b/pageserver/Cargo.toml
index 143d8236df..ed2ee2f5c2 100644
--- a/pageserver/Cargo.toml
+++ b/pageserver/Cargo.toml
@@ -15,6 +15,7 @@ anyhow.workspace = true
 arc-swap.workspace = true
 async-compression.workspace = true
 async-stream.workspace = true
+async-timer.workspace = true
 bit_field.workspace = true
 byteorder.workspace = true
 bytes.workspace = true
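The workspace changes above pull in `async-timer` so the handler can own one restartable timer instead of constructing a fresh `tokio::time::sleep_until` future on every loop iteration (see the page_service.rs diff below). Here is a minimal, self-contained sketch of that reuse pattern; it is not part of the patch, assumes `async-timer = "0.7.4"` plus `tokio` with the `rt-multi-thread` and `macros` features, and the durations are illustrative:

```rust
use std::time::Duration;

use async_timer::Oneshot; // trait that provides `restart`; the patch imports it the same way

#[tokio::main]
async fn main() {
    // Allocated once with a placeholder duration, mirroring the 999s
    // placeholder in `PageServerHandler::new`; it is re-armed before each use.
    let mut timer = Box::pin(async_timer::oneshot::Timer::new(Duration::from_secs(999)));

    for timeout in [Duration::from_millis(10), Duration::from_millis(20)] {
        // `Oneshot::restart` needs a `Waker`, so the re-arm happens inside
        // `poll_fn`, exactly as the patch does before storing the deadline.
        std::future::poll_fn(|cx| {
            timer.restart(timeout, cx.waker());
            std::task::Poll::Ready(())
        })
        .await;

        // Await the re-armed timer; in the patch this is a `tokio::select!` arm.
        timer.as_mut().await;
        println!("timer fired after {timeout:?}");
    }
}
```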
diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs
index e9eb4bfe65..d25df2ef3e 100644
--- a/pageserver/src/page_service.rs
+++ b/pageserver/src/page_service.rs
@@ -3,6 +3,7 @@
 use anyhow::{bail, Context};
 use async_compression::tokio::write::GzipEncoder;
+use async_timer::Oneshot;
 use bytes::Buf;
 use futures::FutureExt;
 use itertools::Itertools;
@@ -22,6 +23,7 @@
 use pq_proto::FeStartupPacket;
 use pq_proto::{BeMessage, FeMessage, RowDescriptor};
 use std::borrow::Cow;
 use std::io;
+use std::pin::Pin;
 use std::str;
 use std::str::FromStr;
 use std::sync::Arc;
@@ -314,11 +316,15 @@ struct PageServerHandler {
 
     timeline_handles: TimelineHandles,
 
-    /// Messages queued up for the next processing batch
-    next_batch: Option<BatchedFeMessage>,
-
     /// See [`PageServerConf::server_side_batch_timeout`]
     server_side_batch_timeout: Option<Duration>,
+
+    server_side_batch_timer: Pin<Box<async_timer::oneshot::Timer>>,
+}
+
+struct Carry {
+    msg: BatchedFeMessage,
+    started_at: Instant,
 }
 
 struct TimelineHandles {
@@ -582,8 +588,10 @@ impl PageServerHandler {
             connection_ctx,
             timeline_handles: TimelineHandles::new(tenant_manager),
             cancel,
-            next_batch: None,
             server_side_batch_timeout,
+            server_side_batch_timer: Box::pin(async_timer::oneshot::Timer::new(
+                Duration::from_secs(999),
+            )), // reset each iteration
         }
     }
 
@@ -617,42 +625,87 @@ impl PageServerHandler {
         pgb: &mut PostgresBackend<IO>,
         tenant_id: &TenantId,
         timeline_id: &TimelineId,
+        maybe_carry: &mut Option<Carry>,
         ctx: &RequestContext,
-    ) -> Result<Option<BatchOrEof>, QueryError>
+    ) -> Result<BatchOrEof, QueryError>
     where
         IO: AsyncRead + AsyncWrite + Send + Sync + Unpin,
     {
         debug_assert_current_span_has_tenant_and_timeline_id_no_shard_id();
 
-        let mut batch = self.next_batch.take();
-        let mut batch_started_at: Option<Instant> = None;
+        let mut batching_deadline_storage = None; // TODO: can this just be an unsync once_cell?
 
-        let next_batch: Option<BatchedFeMessage> = loop {
-            let sleep_fut = match (self.server_side_batch_timeout, batch_started_at) {
-                (Some(batch_timeout), Some(started_at)) => futures::future::Either::Left(
-                    tokio::time::sleep_until((started_at + batch_timeout).into()),
-                ),
-                _ => futures::future::Either::Right(futures::future::pending()),
+        loop {
+            // Create a future that will become ready when we need to stop batching.
+            // If there's a carry, take the time it has already spent batching into consideration.
+            use futures::future::Either;
+            let batching_deadline = match (
+                &*maybe_carry as &Option<Carry>,
+                &mut batching_deadline_storage,
+            ) {
+                (None, None) => Either::Left(futures::future::pending()), // there's no deadline before we have something batched
+                (None, Some(_)) => unreachable!(),
+                (Some(_), Some(fut)) => Either::Right(fut), // below arm already ran
+                (Some(carry), None) => {
+                    match self.server_side_batch_timeout {
+                        None => {
+                            return Ok(BatchOrEof::Batch(smallvec::smallvec![
+                                maybe_carry
+                                    .take()
+                                    .expect("we already checked it's Some")
+                                    .msg
+                            ]))
+                        }
+                        Some(batch_timeout) => {
+                            // Take into consideration the time the carry spent waiting.
+                            let now = Instant::now();
+                            let batch_timeout =
+                                batch_timeout.saturating_sub(now - carry.started_at);
+                            if batch_timeout.is_zero() {
+                                // the timer doesn't support restarting with zero duration
+                                return Ok(BatchOrEof::Batch(smallvec::smallvec![
+                                    maybe_carry
+                                        .take()
+                                        .expect("we already checked it's Some")
+                                        .msg
+                                ]));
+                            } else {
+                                std::future::poll_fn(|ctx| {
+                                    self.server_side_batch_timer
+                                        .restart(batch_timeout, ctx.waker());
+                                    std::task::Poll::Ready(())
+                                })
+                                .await;
+                                batching_deadline_storage = Some(&mut self.server_side_batch_timer);
+                                Either::Right(
+                                    batching_deadline_storage.as_mut().expect("we just set it"),
+                                )
+                            }
+                        }
+                    }
+                }
             };
-
             let msg = tokio::select! {
                 biased;
                 _ = self.cancel.cancelled() => {
                     return Err(QueryError::Shutdown)
                 }
-                msg = pgb.read_message() => {
-                    msg
-                }
-                _ = sleep_fut => {
-                    assert!(batch.is_some(), "batch_started_at => sleep_fut = futures::future::pending()");
-                    trace!("batch timeout");
-                    break None;
+                _ = batching_deadline => {
+                    return Ok(BatchOrEof::Batch(smallvec::smallvec![maybe_carry.take().expect("per construction of batching_deadline").msg]));
                 }
+                msg = pgb.read_message() => { msg }
             };
+
+            let msg_start = Instant::now();
+
+            // Rest of this loop body is trying to batch `msg` into the current batch.
+            // If we can add `msg` to the batch, we continue into the next loop iteration.
+            // If we can't add `msg` to the batch, we carry `msg` over to the next call.
+
             let copy_data_bytes = match msg? {
                 Some(FeMessage::CopyData(bytes)) => bytes,
                 Some(FeMessage::Terminate) => {
-                    return Ok(Some(BatchOrEof::Eof));
+                    return Ok(BatchOrEof::Eof);
                 }
                 Some(m) => {
                     return Err(QueryError::Other(anyhow::anyhow!(
@@ -660,10 +713,11 @@ impl PageServerHandler {
                    )));
                }
                None => {
-                    return Ok(Some(BatchOrEof::Eof));
+                    return Ok(BatchOrEof::Eof);
                } // client disconnected
            };
 
            trace!("query: {copy_data_bytes:?}");
 
+            fail::fail_point!("ps::handle-pagerequest-message");
            // parse request
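The `(Some(carry), None)` arm above charges the batch timeout against the time the carried message has already spent waiting, and flushes immediately on a zero budget because the timer cannot be restarted with a zero duration. A standalone sketch of just that arithmetic, using only `std` (the function name is ours, not the patch's):

```rust
use std::time::{Duration, Instant};

/// Returns `None` when the carried message has exhausted its batching budget
/// (flush immediately), otherwise the remaining time to re-arm the timer with.
fn remaining_batch_budget(
    batch_timeout: Duration,
    carry_started_at: Instant,
) -> Option<Duration> {
    let already_waited = carry_started_at.elapsed();
    let remaining = batch_timeout.saturating_sub(already_waited);
    if remaining.is_zero() {
        None
    } else {
        Some(remaining)
    }
}

fn main() {
    // A carry that started waiting 80us ago against a 100us timeout: ~20us left.
    let started_at = Instant::now() - Duration::from_micros(80);
    println!("{:?}", remaining_batch_budget(Duration::from_micros(100), started_at));
}
```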
@@ -705,11 +759,11 @@ impl PageServerHandler {
                    span,
                    error: $error,
                };
-                let batch_and_error = match batch {
-                    Some(b) => smallvec::smallvec![b, error],
+                let batch_and_error = match maybe_carry.take() {
+                    Some(carry) => smallvec::smallvec![carry.msg, error],
                    None => smallvec::smallvec![error],
                };
-                Ok(Some(BatchOrEof::Batch(batch_and_error)))
+                Ok(BatchOrEof::Batch(batch_and_error))
            }};
        }
 
@@ -762,26 +816,18 @@ impl PageServerHandler {
            }
        };
 
-        let batch_timeout = match self.server_side_batch_timeout {
-            Some(value) => value,
-            None => {
-                // Batching is not enabled - stop on the first message.
-                return Ok(Some(BatchOrEof::Batch(smallvec::smallvec![this_msg])));
-            }
-        };
-
        // check if we can batch
-        match (&mut batch, this_msg) {
+        match (maybe_carry.take(), this_msg) {
            (None, this_msg) => {
-                batch = Some(this_msg);
+                *maybe_carry = Some(Carry { msg: this_msg, started_at: msg_start });
            }
            (
-                Some(BatchedFeMessage::GetPage {
+                Some(Carry { msg: BatchedFeMessage::GetPage {
                    span: _,
                    shard: accum_shard,
-                    pages: accum_pages,
+                    pages: mut accum_pages,
                    effective_request_lsn: accum_lsn,
-                }),
+                }, started_at: _ }),
                BatchedFeMessage::GetPage {
                    span: _,
                    shard: this_shard,
@@ -805,7 +851,7 @@ impl PageServerHandler {
                    }
                    // the vectored get currently only supports a single LSN, so, bounce as soon
                    // as the effective request_lsn changes
-                    if *accum_lsn != this_lsn {
+                    if accum_lsn != this_lsn {
                        trace!(%accum_lsn, %this_lsn, "stopping batching because LSN changed");
                        return false;
                    }
@@ -816,21 +862,16 @@ impl PageServerHandler {
                // ok to batch
                accum_pages.extend(this_pages);
            }
-            (Some(_), this_msg) => {
+            (Some(carry), this_msg) => {
                // by default, don't continue batching
-                break Some(this_msg);
+                *maybe_carry = Some(Carry {
+                    msg: this_msg,
+                    started_at: msg_start,
+                });
+                return Ok(BatchOrEof::Batch(smallvec::smallvec![carry.msg]));
            }
        }
-
-        // batching impl piece
-        let started_at = batch_started_at.get_or_insert_with(Instant::now);
-        if started_at.elapsed() > batch_timeout {
-            break None;
-        }
-        };
-
-        self.next_batch = next_batch;
-        Ok(batch.map(|b| BatchOrEof::Batch(smallvec::smallvec![b])))
+        }
    }
 
    /// Pagestream sub-protocol handler.
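The `GetPage`-on-`GetPage` arm above admits a message into the running batch only while a guard closure holds; the hunks show two of its conditions (batch size and effective LSN), and batching also stops at a shard boundary. Below is a simplified model of that predicate with hypothetical stand-in types, not the pageserver's; the 32 comes from the "current vectored get max batch size" comment in the test diff further down:

```rust
// Hypothetical stand-ins for the pageserver's shard and LSN types.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
struct ShardId(u8);
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
struct Lsn(u64);

const MAX_GET_VECTORED_KEYS: usize = 32;

fn can_extend_batch(
    accum: (ShardId, Lsn, usize), // shard, effective request LSN, pages accumulated
    this: (ShardId, Lsn),
) -> bool {
    let (accum_shard, accum_lsn, accum_len) = accum;
    let (this_shard, this_lsn) = this;
    // Bounce as soon as the batch is full, crosses a shard boundary, or the
    // effective request LSN changes (the vectored get takes a single LSN).
    accum_len < MAX_GET_VECTORED_KEYS && accum_shard == this_shard && accum_lsn == this_lsn
}

fn main() {
    assert!(can_extend_batch((ShardId(0), Lsn(42), 3), (ShardId(0), Lsn(42))));
    assert!(!can_extend_batch((ShardId(0), Lsn(42), 3), (ShardId(0), Lsn(43))));
}
```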
@@ -868,22 +909,17 @@ impl PageServerHandler {
            }
        }
 
-        // If [`PageServerHandler`] is reused for multiple pagestreams,
-        // then make sure to not process requests from the previous ones.
-        self.next_batch = None;
+        let mut carry: Option<Carry> = None;
 
        loop {
            let maybe_batched = self
-                .read_batch_from_connection(pgb, &tenant_id, &timeline_id, &ctx)
+                .read_batch_from_connection(pgb, &tenant_id, &timeline_id, &mut carry, &ctx)
                .await?;
            let batched = match maybe_batched {
-                Some(BatchOrEof::Batch(b)) => b,
-                Some(BatchOrEof::Eof) => {
+                BatchOrEof::Batch(b) => b,
+                BatchOrEof::Eof => {
                    break;
                }
-                None => {
-                    continue;
-                }
            };
 
            for batch in batched {
diff --git a/test_runner/regress/test_pageserver_getpage_merge.py b/test_runner/regress/test_pageserver_getpage_merge.py
index fbca14ebac..1d284d2547 100644
--- a/test_runner/regress/test_pageserver_getpage_merge.py
+++ b/test_runner/regress/test_pageserver_getpage_merge.py
@@ -7,9 +7,10 @@ from fixtures.neon_fixtures import NeonEnvBuilder
 from fixtures.log_helper import log
 
 @pytest.mark.parametrize("tablesize_mib", [50, 500])
-@pytest.mark.parametrize("batch_timeout", ["10us", "100us", "1ms"])
-@pytest.mark.parametrize("target_runtime", [30])
-def test_getpage_merge_smoke(neon_env_builder: NeonEnvBuilder, tablesize_mib: int, batch_timeout: str, target_runtime: int):
+@pytest.mark.parametrize("batch_timeout", [None, "1ns", "5us", "10us", "100us", "1ms"])
+@pytest.mark.parametrize("target_runtime", [5])
+@pytest.mark.parametrize("effective_io_concurrency", [1, 32, 64, 100])  # 32 is the current vectored get max batch size
+def test_getpage_merge_smoke(neon_env_builder: NeonEnvBuilder, tablesize_mib: int, batch_timeout: str, target_runtime: int, effective_io_concurrency: int):
     """
     Do a bunch of sequential scans and ensure that the pageserver does some merging.
     """
@@ -23,6 +24,12 @@ def test_getpage_merge_smoke(neon_env_builder: NeonEnvBuilder, tablesize_mib: in
     conn = endpoint.connect()
     cur = conn.cursor()
 
+    log.info("tablesize_mib=%d, batch_timeout=%s, target_runtime=%d, effective_io_concurrency=%d", tablesize_mib, batch_timeout, target_runtime, effective_io_concurrency)
+
+    cur.execute("SET max_parallel_workers_per_gather=0")  # disable parallel backends
+    cur.execute(f"SET effective_io_concurrency={effective_io_concurrency}")
+    cur.execute("SET neon.readahead_buffer_size=128")  # this is the current default value, but let's hard-code that
+
     #
     # Setup
     #
@@ -93,11 +100,6 @@ def test_getpage_merge_smoke(neon_env_builder: NeonEnvBuilder, tablesize_mib: in
             return self.metrics.normalize(self.iters)
 
     def workload() -> Result:
-        cur.execute("SET max_parallel_workers_per_gather=0")  # disable parallel backends
-        cur.execute("SET effective_io_concurrency=100")
-        cur.execute("SET neon.readahead_buffer_size=128")
-        # cur.execute("SET neon.flush_output_after=1")
-
         start = time.time()
         iters = 0
         while time.time() - start < target_runtime or iters < 2:
@@ -112,21 +114,9 @@ def test_getpage_merge_smoke(neon_env_builder: NeonEnvBuilder, tablesize_mib: in
         after = get_metrics()
         return Result(metrics=after-before, iters=iters)
 
-    log.info("workload without merge")
-    env.pageserver.restart()  # reset the metrics
-    without_merge = workload()
-
-    log.info("workload with merge")
     env.pageserver.patch_config_toml_nonrecursive({"server_side_batch_timeout": batch_timeout})
     env.pageserver.restart()
-    with_merge = workload()
-
-    results = {
-        "baseline": without_merge.normalized,
-        "candiate": with_merge.normalized,
-        "delta": with_merge.normalized - without_merge.normalized,
-        "relative": with_merge.normalized / without_merge.normalized
-    }
+    results = workload()
 
     #
     # Assertions on collected data
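Taken together, the caller-side contract in the last page_service.rs hunk is: every call to `read_batch_from_connection` now yields either a batch or EOF, so the old `None => continue` case disappears, and the cross-call state is a caller-owned carry scoped to one pagestream rather than `self.next_batch` (which is presumably why the cross-pagestream reset could be deleted). A toy model of that calling convention, with simplified stand-in types, not the pageserver's:

```rust
use std::time::Instant;

// Simplified stand-ins for the pageserver's types.
enum BatchOrEof {
    Batch(Vec<String>),
    Eof,
}

#[allow(dead_code)]
struct Carry {
    msg: String,
    started_at: Instant,
}

/// Toy caller: drain batches until EOF, threading the carry through each call.
fn drain(mut read_batch: impl FnMut(&mut Option<Carry>) -> BatchOrEof) {
    let mut carry: Option<Carry> = None;
    loop {
        let batch = match read_batch(&mut carry) {
            BatchOrEof::Batch(b) => b,
            BatchOrEof::Eof => break,
        };
        for msg in batch {
            println!("processing {msg}");
        }
    }
}

fn main() {
    // Stub producer standing in for the real connection-reading method.
    let mut calls = 0;
    drain(|_carry| {
        calls += 1;
        if calls < 3 {
            BatchOrEof::Batch(vec![format!("msg{calls}")])
        } else {
            BatchOrEof::Eof
        }
    });
}
```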