mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-15 17:32:56 +00:00
Merge remote-tracking branch 'upstream/main' into jcsp/storcon-split-refine
This commit is contained in:
1
.github/workflows/actionlint.yml
vendored
1
.github/workflows/actionlint.yml
vendored
@@ -17,6 +17,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
actionlint:
|
||||
if: ${{ !contains(github.event.pull_request.labels.*.name, 'run-no-ci') }}
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
2
.github/workflows/build_and_test.yml
vendored
2
.github/workflows/build_and_test.yml
vendored
@@ -26,8 +26,8 @@ env:
|
||||
|
||||
jobs:
|
||||
check-permissions:
|
||||
if: ${{ !contains(github.event.pull_request.labels.*.name, 'run-no-ci') }}
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Disallow PRs from forks
|
||||
if: |
|
||||
|
||||
2
.github/workflows/neon_extra_builds.yml
vendored
2
.github/workflows/neon_extra_builds.yml
vendored
@@ -117,6 +117,7 @@ jobs:
|
||||
|
||||
check-linux-arm-build:
|
||||
timeout-minutes: 90
|
||||
if: ${{ !contains(github.event.pull_request.labels.*.name, 'run-no-ci') }}
|
||||
runs-on: [ self-hosted, dev, arm64 ]
|
||||
|
||||
env:
|
||||
@@ -237,6 +238,7 @@ jobs:
|
||||
|
||||
check-codestyle-rust-arm:
|
||||
timeout-minutes: 90
|
||||
if: ${{ !contains(github.event.pull_request.labels.*.name, 'run-no-ci') }}
|
||||
runs-on: [ self-hosted, dev, arm64 ]
|
||||
|
||||
container:
|
||||
|
||||
32
Cargo.lock
generated
32
Cargo.lock
generated
@@ -1639,6 +1639,22 @@ dependencies = [
|
||||
"rusticata-macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "desim"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bytes",
|
||||
"hex",
|
||||
"parking_lot 0.12.1",
|
||||
"rand 0.8.5",
|
||||
"scopeguard",
|
||||
"smallvec",
|
||||
"tracing",
|
||||
"utils",
|
||||
"workspace_hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "diesel"
|
||||
version = "2.1.4"
|
||||
@@ -2247,11 +2263,11 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "hashlink"
|
||||
version = "0.8.2"
|
||||
version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0761a1b9491c4f2e3d66aa0f62d0fba0af9a0e2852e4d48ea506632a4b56e6aa"
|
||||
checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7"
|
||||
dependencies = [
|
||||
"hashbrown 0.13.2",
|
||||
"hashbrown 0.14.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3936,6 +3952,7 @@ dependencies = [
|
||||
"pin-project-lite",
|
||||
"postgres-protocol",
|
||||
"rand 0.8.5",
|
||||
"serde",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tracing",
|
||||
@@ -4827,6 +4844,7 @@ dependencies = [
|
||||
"clap",
|
||||
"const_format",
|
||||
"crc32c",
|
||||
"desim",
|
||||
"fail",
|
||||
"fs2",
|
||||
"futures",
|
||||
@@ -4842,6 +4860,7 @@ dependencies = [
|
||||
"postgres_backend",
|
||||
"postgres_ffi",
|
||||
"pq_proto",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"remote_storage",
|
||||
"reqwest",
|
||||
@@ -4862,8 +4881,10 @@ dependencies = [
|
||||
"tokio-util",
|
||||
"toml_edit",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"url",
|
||||
"utils",
|
||||
"walproposer",
|
||||
"workspace_hack",
|
||||
]
|
||||
|
||||
@@ -5740,7 +5761,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "tokio-epoll-uring"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#d6a1c93442fb6b3a5bec490204961134e54925dc"
|
||||
source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#868d2c42b5d54ca82fead6e8f2f233b69a540d3e"
|
||||
dependencies = [
|
||||
"futures",
|
||||
"nix 0.26.4",
|
||||
@@ -6265,8 +6286,9 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "uring-common"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#d6a1c93442fb6b3a5bec490204961134e54925dc"
|
||||
source = "git+https://github.com/neondatabase/tokio-epoll-uring.git?branch=main#868d2c42b5d54ca82fead6e8f2f233b69a540d3e"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"io-uring",
|
||||
"libc",
|
||||
]
|
||||
|
||||
@@ -18,6 +18,7 @@ members = [
|
||||
"libs/pageserver_api",
|
||||
"libs/postgres_ffi",
|
||||
"libs/safekeeper_api",
|
||||
"libs/desim",
|
||||
"libs/utils",
|
||||
"libs/consumption_metrics",
|
||||
"libs/postgres_backend",
|
||||
@@ -80,7 +81,7 @@ futures-core = "0.3"
|
||||
futures-util = "0.3"
|
||||
git-version = "0.3"
|
||||
hashbrown = "0.13"
|
||||
hashlink = "0.8.1"
|
||||
hashlink = "0.8.4"
|
||||
hdrhistogram = "7.5.2"
|
||||
hex = "0.4"
|
||||
hex-literal = "0.4"
|
||||
@@ -203,6 +204,7 @@ postgres_ffi = { version = "0.1", path = "./libs/postgres_ffi/" }
|
||||
pq_proto = { version = "0.1", path = "./libs/pq_proto/" }
|
||||
remote_storage = { version = "0.1", path = "./libs/remote_storage/" }
|
||||
safekeeper_api = { version = "0.1", path = "./libs/safekeeper_api" }
|
||||
desim = { version = "0.1", path = "./libs/desim" }
|
||||
storage_broker = { version = "0.1", path = "./storage_broker/" } # Note: main broker code is inside the binary crate, so linking with the library shouldn't be heavy.
|
||||
tenant_size_model = { version = "0.1", path = "./libs/tenant_size_model/" }
|
||||
tracing-utils = { version = "0.1", path = "./libs/tracing-utils/" }
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::io::BufRead;
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
use std::os::unix::fs::{symlink, PermissionsExt};
|
||||
use std::path::Path;
|
||||
use std::process::{Command, Stdio};
|
||||
use std::str::FromStr;
|
||||
@@ -634,6 +634,48 @@ impl ComputeNode {
|
||||
// Update pg_hba.conf received with basebackup.
|
||||
update_pg_hba(pgdata_path)?;
|
||||
|
||||
// Place pg_dynshmem under /dev/shm. This allows us to use
|
||||
// 'dynamic_shared_memory_type = mmap' so that the files are placed in
|
||||
// /dev/shm, similar to how 'dynamic_shared_memory_type = posix' works.
|
||||
//
|
||||
// Why on earth don't we just stick to the 'posix' default, you might
|
||||
// ask. It turns out that making large allocations with 'posix' doesn't
|
||||
// work very well with autoscaling. The behavior we want is that:
|
||||
//
|
||||
// 1. You can make large DSM allocations, larger than the current RAM
|
||||
// size of the VM, without errors
|
||||
//
|
||||
// 2. If the allocated memory is really used, the VM is scaled up
|
||||
// automatically to accommodate that
|
||||
//
|
||||
// We try to make that possible by having swap in the VM. But with the
|
||||
// default 'posix' DSM implementation, we fail step 1, even when there's
|
||||
// plenty of swap available. PostgreSQL uses posix_fallocate() to create
|
||||
// the shmem segment, which is really just a file in /dev/shm in Linux,
|
||||
// but posix_fallocate() on tmpfs returns ENOMEM if the size is larger
|
||||
// than available RAM.
|
||||
//
|
||||
// Using 'dynamic_shared_memory_type = mmap' works around that, because
|
||||
// the Postgres 'mmap' DSM implementation doesn't use
|
||||
// posix_fallocate(). Instead, it uses repeated calls to write(2) to
|
||||
// fill the file with zeros. It's weird that that differs between
|
||||
// 'posix' and 'mmap', but we take advantage of it. When the file is
|
||||
// filled slowly with write(2), the kernel allows it to grow larger, as
|
||||
// long as there's swap available.
|
||||
//
|
||||
// In short, using 'dynamic_shared_memory_type = mmap' allows us one DSM
|
||||
// segment to be larger than currently available RAM. But because we
|
||||
// don't want to store it on a real file, which the kernel would try to
|
||||
// flush to disk, so symlink pg_dynshm to /dev/shm.
|
||||
//
|
||||
// We don't set 'dynamic_shared_memory_type = mmap' here, we let the
|
||||
// control plane control that option. If 'mmap' is not used, this
|
||||
// symlink doesn't affect anything.
|
||||
//
|
||||
// See https://github.com/neondatabase/autoscaling/issues/800
|
||||
std::fs::remove_dir(pgdata_path.join("pg_dynshmem"))?;
|
||||
symlink("/dev/shm/", pgdata_path.join("pg_dynshmem"))?;
|
||||
|
||||
match spec.mode {
|
||||
ComputeMode::Primary => {}
|
||||
ComputeMode::Replica | ComputeMode::Static(..) => {
|
||||
|
||||
18
libs/desim/Cargo.toml
Normal file
18
libs/desim/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
|
||||
[package]
|
||||
name = "desim"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
rand.workspace = true
|
||||
tracing.workspace = true
|
||||
bytes.workspace = true
|
||||
utils.workspace = true
|
||||
parking_lot.workspace = true
|
||||
hex.workspace = true
|
||||
scopeguard.workspace = true
|
||||
smallvec = { workspace = true, features = ["write"] }
|
||||
|
||||
workspace_hack.workspace = true
|
||||
7
libs/desim/README.md
Normal file
7
libs/desim/README.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# Discrete Event SIMulator
|
||||
|
||||
This is a library for running simulations of distributed systems. The main idea is borrowed from [FoundationDB](https://www.youtube.com/watch?v=4fFDFbi3toc).
|
||||
|
||||
Each node runs as a separate thread. This library was not optimized for speed yet, but it's already much faster than running usual intergration tests in real time, because it uses virtual simulation time and can fast-forward time to skip intervals where all nodes are doing nothing but sleeping or waiting for something.
|
||||
|
||||
The original purpose for this library is to test walproposer and safekeeper implementation working together, in a scenarios close to the real world environment. This simulator is determenistic and can inject failures in networking without waiting minutes of wall-time to trigger timeout, which makes it easier to find bugs in our consensus implementation compared to using integration tests.
|
||||
108
libs/desim/src/chan.rs
Normal file
108
libs/desim/src/chan.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
use std::{collections::VecDeque, sync::Arc};
|
||||
|
||||
use parking_lot::{Mutex, MutexGuard};
|
||||
|
||||
use crate::executor::{self, PollSome, Waker};
|
||||
|
||||
/// FIFO channel with blocking send and receive. Can be cloned and shared between threads.
|
||||
/// Blocking functions should be used only from threads that are managed by the executor.
|
||||
pub struct Chan<T> {
|
||||
shared: Arc<State<T>>,
|
||||
}
|
||||
|
||||
impl<T> Clone for Chan<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Chan {
|
||||
shared: self.shared.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Default for Chan<T> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Chan<T> {
|
||||
pub fn new() -> Chan<T> {
|
||||
Chan {
|
||||
shared: Arc::new(State {
|
||||
queue: Mutex::new(VecDeque::new()),
|
||||
waker: Waker::new(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a message from the front of the queue, block if the queue is empty.
|
||||
/// If not called from the executor thread, it can block forever.
|
||||
pub fn recv(&self) -> T {
|
||||
self.shared.recv()
|
||||
}
|
||||
|
||||
/// Panic if the queue is empty.
|
||||
pub fn must_recv(&self) -> T {
|
||||
self.shared
|
||||
.try_recv()
|
||||
.expect("message should've been ready")
|
||||
}
|
||||
|
||||
/// Get a message from the front of the queue, return None if the queue is empty.
|
||||
/// Never blocks.
|
||||
pub fn try_recv(&self) -> Option<T> {
|
||||
self.shared.try_recv()
|
||||
}
|
||||
|
||||
/// Send a message to the back of the queue.
|
||||
pub fn send(&self, t: T) {
|
||||
self.shared.send(t);
|
||||
}
|
||||
}
|
||||
|
||||
struct State<T> {
|
||||
queue: Mutex<VecDeque<T>>,
|
||||
waker: Waker,
|
||||
}
|
||||
|
||||
impl<T> State<T> {
|
||||
fn send(&self, t: T) {
|
||||
self.queue.lock().push_back(t);
|
||||
self.waker.wake_all();
|
||||
}
|
||||
|
||||
fn try_recv(&self) -> Option<T> {
|
||||
let mut q = self.queue.lock();
|
||||
q.pop_front()
|
||||
}
|
||||
|
||||
fn recv(&self) -> T {
|
||||
// interrupt the receiver to prevent consuming everything at once
|
||||
executor::yield_me(0);
|
||||
|
||||
let mut queue = self.queue.lock();
|
||||
if let Some(t) = queue.pop_front() {
|
||||
return t;
|
||||
}
|
||||
loop {
|
||||
self.waker.wake_me_later();
|
||||
if let Some(t) = queue.pop_front() {
|
||||
return t;
|
||||
}
|
||||
MutexGuard::unlocked(&mut queue, || {
|
||||
executor::yield_me(-1);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> PollSome for Chan<T> {
|
||||
/// Schedules a wakeup for the current thread.
|
||||
fn wake_me(&self) {
|
||||
self.shared.waker.wake_me_later();
|
||||
}
|
||||
|
||||
/// Checks if chan has any pending messages.
|
||||
fn has_some(&self) -> bool {
|
||||
!self.shared.queue.lock().is_empty()
|
||||
}
|
||||
}
|
||||
483
libs/desim/src/executor.rs
Normal file
483
libs/desim/src/executor.rs
Normal file
@@ -0,0 +1,483 @@
|
||||
use std::{
|
||||
panic::AssertUnwindSafe,
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicU32, AtomicU8, Ordering},
|
||||
mpsc, Arc, OnceLock,
|
||||
},
|
||||
thread::JoinHandle,
|
||||
};
|
||||
|
||||
use tracing::{debug, error, trace};
|
||||
|
||||
use crate::time::Timing;
|
||||
|
||||
/// Stores status of the running threads. Threads are registered in the runtime upon creation
|
||||
/// and deregistered upon termination.
|
||||
pub struct Runtime {
|
||||
// stores handles to all threads that are currently running
|
||||
threads: Vec<ThreadHandle>,
|
||||
// stores current time and pending wakeups
|
||||
clock: Arc<Timing>,
|
||||
// thread counter
|
||||
thread_counter: AtomicU32,
|
||||
// Thread step counter -- how many times all threads has been actually
|
||||
// stepped (note that all world/time/executor/thread have slightly different
|
||||
// meaning of steps). For observability.
|
||||
pub step_counter: u64,
|
||||
}
|
||||
|
||||
impl Runtime {
|
||||
/// Init new runtime, no running threads.
|
||||
pub fn new(clock: Arc<Timing>) -> Self {
|
||||
Self {
|
||||
threads: Vec::new(),
|
||||
clock,
|
||||
thread_counter: AtomicU32::new(0),
|
||||
step_counter: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Spawn a new thread and register it in the runtime.
|
||||
pub fn spawn<F>(&mut self, f: F) -> ExternalHandle
|
||||
where
|
||||
F: FnOnce() + Send + 'static,
|
||||
{
|
||||
let (tx, rx) = mpsc::channel();
|
||||
|
||||
let clock = self.clock.clone();
|
||||
let tid = self.thread_counter.fetch_add(1, Ordering::SeqCst);
|
||||
debug!("spawning thread-{}", tid);
|
||||
|
||||
let join = std::thread::spawn(move || {
|
||||
let _guard = tracing::info_span!("", tid).entered();
|
||||
|
||||
let res = std::panic::catch_unwind(AssertUnwindSafe(|| {
|
||||
with_thread_context(|ctx| {
|
||||
assert!(ctx.clock.set(clock).is_ok());
|
||||
ctx.id.store(tid, Ordering::SeqCst);
|
||||
tx.send(ctx.clone()).expect("failed to send thread context");
|
||||
// suspend thread to put it to `threads` in sleeping state
|
||||
ctx.yield_me(0);
|
||||
});
|
||||
|
||||
// start user-provided function
|
||||
f();
|
||||
}));
|
||||
debug!("thread finished");
|
||||
|
||||
if let Err(e) = res {
|
||||
with_thread_context(|ctx| {
|
||||
if !ctx.allow_panic.load(std::sync::atomic::Ordering::SeqCst) {
|
||||
error!("thread panicked, terminating the process: {:?}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
debug!("thread panicked: {:?}", e);
|
||||
let mut result = ctx.result.lock();
|
||||
if result.0 == -1 {
|
||||
*result = (256, format!("thread panicked: {:?}", e));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
with_thread_context(|ctx| {
|
||||
ctx.finish_me();
|
||||
});
|
||||
});
|
||||
|
||||
let ctx = rx.recv().expect("failed to receive thread context");
|
||||
let handle = ThreadHandle::new(ctx.clone(), join);
|
||||
|
||||
self.threads.push(handle);
|
||||
|
||||
ExternalHandle { ctx }
|
||||
}
|
||||
|
||||
/// Returns true if there are any unfinished activity, such as running thread or pending events.
|
||||
/// Otherwise returns false, which means all threads are blocked forever.
|
||||
pub fn step(&mut self) -> bool {
|
||||
trace!("runtime step");
|
||||
|
||||
// have we run any thread?
|
||||
let mut ran = false;
|
||||
|
||||
self.threads.retain(|thread: &ThreadHandle| {
|
||||
let res = thread.ctx.wakeup.compare_exchange(
|
||||
PENDING_WAKEUP,
|
||||
NO_WAKEUP,
|
||||
Ordering::SeqCst,
|
||||
Ordering::SeqCst,
|
||||
);
|
||||
if res.is_err() {
|
||||
// thread has no pending wakeups, leaving as is
|
||||
return true;
|
||||
}
|
||||
ran = true;
|
||||
|
||||
trace!("entering thread-{}", thread.ctx.tid());
|
||||
let status = thread.step();
|
||||
self.step_counter += 1;
|
||||
trace!(
|
||||
"out of thread-{} with status {:?}",
|
||||
thread.ctx.tid(),
|
||||
status
|
||||
);
|
||||
|
||||
if status == Status::Sleep {
|
||||
true
|
||||
} else {
|
||||
trace!("thread has finished");
|
||||
// removing the thread from the list
|
||||
false
|
||||
}
|
||||
});
|
||||
|
||||
if !ran {
|
||||
trace!("no threads were run, stepping clock");
|
||||
if let Some(ctx_to_wake) = self.clock.step() {
|
||||
trace!("waking up thread-{}", ctx_to_wake.tid());
|
||||
ctx_to_wake.inc_wake();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Kill all threads. This is done by setting a flag in each thread context and waking it up.
|
||||
pub fn crash_all_threads(&mut self) {
|
||||
for thread in self.threads.iter() {
|
||||
thread.ctx.crash_stop();
|
||||
}
|
||||
|
||||
// all threads should be finished after a few steps
|
||||
while !self.threads.is_empty() {
|
||||
self.step();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Runtime {
|
||||
fn drop(&mut self) {
|
||||
debug!("dropping the runtime");
|
||||
self.crash_all_threads();
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ExternalHandle {
|
||||
ctx: Arc<ThreadContext>,
|
||||
}
|
||||
|
||||
impl ExternalHandle {
|
||||
/// Returns true if thread has finished execution.
|
||||
pub fn is_finished(&self) -> bool {
|
||||
let status = self.ctx.mutex.lock();
|
||||
*status == Status::Finished
|
||||
}
|
||||
|
||||
/// Returns exitcode and message, which is available after thread has finished execution.
|
||||
pub fn result(&self) -> (i32, String) {
|
||||
let result = self.ctx.result.lock();
|
||||
result.clone()
|
||||
}
|
||||
|
||||
/// Returns thread id.
|
||||
pub fn id(&self) -> u32 {
|
||||
self.ctx.id.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Sets a flag to crash thread on the next wakeup.
|
||||
pub fn crash_stop(&self) {
|
||||
self.ctx.crash_stop();
|
||||
}
|
||||
}
|
||||
|
||||
struct ThreadHandle {
|
||||
ctx: Arc<ThreadContext>,
|
||||
_join: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl ThreadHandle {
|
||||
/// Create a new [`ThreadHandle`] and wait until thread will enter [`Status::Sleep`] state.
|
||||
fn new(ctx: Arc<ThreadContext>, join: JoinHandle<()>) -> Self {
|
||||
let mut status = ctx.mutex.lock();
|
||||
// wait until thread will go into the first yield
|
||||
while *status != Status::Sleep {
|
||||
ctx.condvar.wait(&mut status);
|
||||
}
|
||||
drop(status);
|
||||
|
||||
Self { ctx, _join: join }
|
||||
}
|
||||
|
||||
/// Allows thread to execute one step of its execution.
|
||||
/// Returns [`Status`] of the thread after the step.
|
||||
fn step(&self) -> Status {
|
||||
let mut status = self.ctx.mutex.lock();
|
||||
assert!(matches!(*status, Status::Sleep));
|
||||
|
||||
*status = Status::Running;
|
||||
self.ctx.condvar.notify_all();
|
||||
|
||||
while *status == Status::Running {
|
||||
self.ctx.condvar.wait(&mut status);
|
||||
}
|
||||
|
||||
*status
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
enum Status {
|
||||
/// Thread is running.
|
||||
Running,
|
||||
/// Waiting for event to complete, will be resumed by the executor step, once wakeup flag is set.
|
||||
Sleep,
|
||||
/// Thread finished execution.
|
||||
Finished,
|
||||
}
|
||||
|
||||
const NO_WAKEUP: u8 = 0;
|
||||
const PENDING_WAKEUP: u8 = 1;
|
||||
|
||||
pub struct ThreadContext {
|
||||
id: AtomicU32,
|
||||
// used to block thread until it is woken up
|
||||
mutex: parking_lot::Mutex<Status>,
|
||||
condvar: parking_lot::Condvar,
|
||||
// used as a flag to indicate runtime that thread is ready to be woken up
|
||||
wakeup: AtomicU8,
|
||||
clock: OnceLock<Arc<Timing>>,
|
||||
// execution result, set by exit() call
|
||||
result: parking_lot::Mutex<(i32, String)>,
|
||||
// determines if process should be killed on receiving panic
|
||||
allow_panic: AtomicBool,
|
||||
// acts as a signal that thread should crash itself on the next wakeup
|
||||
crash_request: AtomicBool,
|
||||
}
|
||||
|
||||
impl ThreadContext {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
id: AtomicU32::new(0),
|
||||
mutex: parking_lot::Mutex::new(Status::Running),
|
||||
condvar: parking_lot::Condvar::new(),
|
||||
wakeup: AtomicU8::new(NO_WAKEUP),
|
||||
clock: OnceLock::new(),
|
||||
result: parking_lot::Mutex::new((-1, String::new())),
|
||||
allow_panic: AtomicBool::new(false),
|
||||
crash_request: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Functions for executor to control thread execution.
|
||||
impl ThreadContext {
|
||||
/// Set atomic flag to indicate that thread is ready to be woken up.
|
||||
fn inc_wake(&self) {
|
||||
self.wakeup.store(PENDING_WAKEUP, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
/// Internal function used for event queues.
|
||||
pub(crate) fn schedule_wakeup(self: &Arc<Self>, after_ms: u64) {
|
||||
self.clock
|
||||
.get()
|
||||
.unwrap()
|
||||
.schedule_wakeup(after_ms, self.clone());
|
||||
}
|
||||
|
||||
fn tid(&self) -> u32 {
|
||||
self.id.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
fn crash_stop(&self) {
|
||||
let status = self.mutex.lock();
|
||||
if *status == Status::Finished {
|
||||
debug!(
|
||||
"trying to crash thread-{}, which is already finished",
|
||||
self.tid()
|
||||
);
|
||||
return;
|
||||
}
|
||||
assert!(matches!(*status, Status::Sleep));
|
||||
drop(status);
|
||||
|
||||
self.allow_panic.store(true, Ordering::SeqCst);
|
||||
self.crash_request.store(true, Ordering::SeqCst);
|
||||
// set a wakeup
|
||||
self.inc_wake();
|
||||
// it will panic on the next wakeup
|
||||
}
|
||||
}
|
||||
|
||||
// Internal functions.
|
||||
impl ThreadContext {
|
||||
/// Blocks thread until it's woken up by the executor. If `after_ms` is 0, is will be
|
||||
/// woken on the next step. If `after_ms` > 0, wakeup is scheduled after that time.
|
||||
/// Otherwise wakeup is not scheduled inside `yield_me`, and should be arranged before
|
||||
/// calling this function.
|
||||
fn yield_me(self: &Arc<Self>, after_ms: i64) {
|
||||
let mut status = self.mutex.lock();
|
||||
assert!(matches!(*status, Status::Running));
|
||||
|
||||
match after_ms.cmp(&0) {
|
||||
std::cmp::Ordering::Less => {
|
||||
// block until something wakes us up
|
||||
}
|
||||
std::cmp::Ordering::Equal => {
|
||||
// tell executor that we are ready to be woken up
|
||||
self.inc_wake();
|
||||
}
|
||||
std::cmp::Ordering::Greater => {
|
||||
// schedule wakeup
|
||||
self.clock
|
||||
.get()
|
||||
.unwrap()
|
||||
.schedule_wakeup(after_ms as u64, self.clone());
|
||||
}
|
||||
}
|
||||
|
||||
*status = Status::Sleep;
|
||||
self.condvar.notify_all();
|
||||
|
||||
// wait until executor wakes us up
|
||||
while *status != Status::Running {
|
||||
self.condvar.wait(&mut status);
|
||||
}
|
||||
|
||||
if self.crash_request.load(Ordering::SeqCst) {
|
||||
panic!("crashed by request");
|
||||
}
|
||||
}
|
||||
|
||||
/// Called only once, exactly before thread finishes execution.
|
||||
fn finish_me(&self) {
|
||||
let mut status = self.mutex.lock();
|
||||
assert!(matches!(*status, Status::Running));
|
||||
|
||||
*status = Status::Finished;
|
||||
{
|
||||
let mut result = self.result.lock();
|
||||
if result.0 == -1 {
|
||||
*result = (0, "finished normally".to_owned());
|
||||
}
|
||||
}
|
||||
self.condvar.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
/// Invokes the given closure with a reference to the current thread [`ThreadContext`].
|
||||
#[inline(always)]
|
||||
fn with_thread_context<T>(f: impl FnOnce(&Arc<ThreadContext>) -> T) -> T {
|
||||
thread_local!(static THREAD_DATA: Arc<ThreadContext> = Arc::new(ThreadContext::new()));
|
||||
THREAD_DATA.with(f)
|
||||
}
|
||||
|
||||
/// Waker is used to wake up threads that are blocked on condition.
|
||||
/// It keeps track of contexts [`Arc<ThreadContext>`] and can increment the counter
|
||||
/// of several contexts to send a notification.
|
||||
pub struct Waker {
|
||||
// contexts that are waiting for a notification
|
||||
contexts: parking_lot::Mutex<smallvec::SmallVec<[Arc<ThreadContext>; 8]>>,
|
||||
}
|
||||
|
||||
impl Default for Waker {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Waker {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
contexts: parking_lot::Mutex::new(smallvec::SmallVec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Subscribe current thread to receive a wake notification later.
|
||||
pub fn wake_me_later(&self) {
|
||||
with_thread_context(|ctx| {
|
||||
self.contexts.lock().push(ctx.clone());
|
||||
});
|
||||
}
|
||||
|
||||
/// Wake up all threads that are waiting for a notification and clear the list.
|
||||
pub fn wake_all(&self) {
|
||||
let mut v = self.contexts.lock();
|
||||
for ctx in v.iter() {
|
||||
ctx.inc_wake();
|
||||
}
|
||||
v.clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// See [`ThreadContext::yield_me`].
|
||||
pub fn yield_me(after_ms: i64) {
|
||||
with_thread_context(|ctx| ctx.yield_me(after_ms))
|
||||
}
|
||||
|
||||
/// Get current time.
|
||||
pub fn now() -> u64 {
|
||||
with_thread_context(|ctx| ctx.clock.get().unwrap().now())
|
||||
}
|
||||
|
||||
pub fn exit(code: i32, msg: String) {
|
||||
with_thread_context(|ctx| {
|
||||
ctx.allow_panic.store(true, Ordering::SeqCst);
|
||||
let mut result = ctx.result.lock();
|
||||
*result = (code, msg);
|
||||
panic!("exit");
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) fn get_thread_ctx() -> Arc<ThreadContext> {
|
||||
with_thread_context(|ctx| ctx.clone())
|
||||
}
|
||||
|
||||
/// Trait for polling channels until they have something.
|
||||
pub trait PollSome {
|
||||
/// Schedule wakeup for message arrival.
|
||||
fn wake_me(&self);
|
||||
|
||||
/// Check if channel has a ready message.
|
||||
fn has_some(&self) -> bool;
|
||||
}
|
||||
|
||||
/// Blocks current thread until one of the channels has a ready message. Returns
|
||||
/// index of the channel that has a message. If timeout is reached, returns None.
|
||||
///
|
||||
/// Negative timeout means block forever. Zero timeout means check channels and return
|
||||
/// immediately. Positive timeout means block until timeout is reached.
|
||||
pub fn epoll_chans(chans: &[Box<dyn PollSome>], timeout: i64) -> Option<usize> {
|
||||
let deadline = if timeout < 0 {
|
||||
0
|
||||
} else {
|
||||
now() + timeout as u64
|
||||
};
|
||||
|
||||
loop {
|
||||
for chan in chans {
|
||||
chan.wake_me()
|
||||
}
|
||||
|
||||
for (i, chan) in chans.iter().enumerate() {
|
||||
if chan.has_some() {
|
||||
return Some(i);
|
||||
}
|
||||
}
|
||||
|
||||
if timeout < 0 {
|
||||
// block until wakeup
|
||||
yield_me(-1);
|
||||
} else {
|
||||
let current_time = now();
|
||||
if current_time >= deadline {
|
||||
return None;
|
||||
}
|
||||
|
||||
yield_me((deadline - current_time) as i64);
|
||||
}
|
||||
}
|
||||
}
|
||||
8
libs/desim/src/lib.rs
Normal file
8
libs/desim/src/lib.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
pub mod chan;
|
||||
pub mod executor;
|
||||
pub mod network;
|
||||
pub mod node_os;
|
||||
pub mod options;
|
||||
pub mod proto;
|
||||
pub mod time;
|
||||
pub mod world;
|
||||
451
libs/desim/src/network.rs
Normal file
451
libs/desim/src/network.rs
Normal file
@@ -0,0 +1,451 @@
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
collections::{BinaryHeap, VecDeque},
|
||||
fmt::{self, Debug},
|
||||
ops::DerefMut,
|
||||
sync::{mpsc, Arc},
|
||||
};
|
||||
|
||||
use parking_lot::{
|
||||
lock_api::{MappedMutexGuard, MutexGuard},
|
||||
Mutex, RawMutex,
|
||||
};
|
||||
use rand::rngs::StdRng;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::{
|
||||
executor::{self, ThreadContext},
|
||||
options::NetworkOptions,
|
||||
proto::NetEvent,
|
||||
proto::NodeEvent,
|
||||
};
|
||||
|
||||
use super::{chan::Chan, proto::AnyMessage};
|
||||
|
||||
pub struct NetworkTask {
|
||||
options: Arc<NetworkOptions>,
|
||||
connections: Mutex<Vec<VirtualConnection>>,
|
||||
/// min-heap of connections having something to deliver.
|
||||
events: Mutex<BinaryHeap<Event>>,
|
||||
task_context: Arc<ThreadContext>,
|
||||
}
|
||||
|
||||
impl NetworkTask {
|
||||
pub fn start_new(options: Arc<NetworkOptions>, tx: mpsc::Sender<Arc<NetworkTask>>) {
|
||||
let ctx = executor::get_thread_ctx();
|
||||
let task = Arc::new(Self {
|
||||
options,
|
||||
connections: Mutex::new(Vec::new()),
|
||||
events: Mutex::new(BinaryHeap::new()),
|
||||
task_context: ctx,
|
||||
});
|
||||
|
||||
// send the task upstream
|
||||
tx.send(task.clone()).unwrap();
|
||||
|
||||
// start the task
|
||||
task.start();
|
||||
}
|
||||
|
||||
pub fn start_new_connection(self: &Arc<Self>, rng: StdRng, dst_accept: Chan<NodeEvent>) -> TCP {
|
||||
let now = executor::now();
|
||||
let connection_id = self.connections.lock().len();
|
||||
|
||||
let vc = VirtualConnection {
|
||||
connection_id,
|
||||
dst_accept,
|
||||
dst_sockets: [Chan::new(), Chan::new()],
|
||||
state: Mutex::new(ConnectionState {
|
||||
buffers: [NetworkBuffer::new(None), NetworkBuffer::new(Some(now))],
|
||||
rng,
|
||||
}),
|
||||
};
|
||||
vc.schedule_timeout(self);
|
||||
vc.send_connect(self);
|
||||
|
||||
let recv_chan = vc.dst_sockets[0].clone();
|
||||
self.connections.lock().push(vc);
|
||||
|
||||
TCP {
|
||||
net: self.clone(),
|
||||
conn_id: connection_id,
|
||||
dir: 0,
|
||||
recv_chan,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// private functions
|
||||
impl NetworkTask {
|
||||
/// Schedule to wakeup network task (self) `after_ms` later to deliver
|
||||
/// messages of connection `id`.
|
||||
fn schedule(&self, id: usize, after_ms: u64) {
|
||||
self.events.lock().push(Event {
|
||||
time: executor::now() + after_ms,
|
||||
conn_id: id,
|
||||
});
|
||||
self.task_context.schedule_wakeup(after_ms);
|
||||
}
|
||||
|
||||
/// Get locked connection `id`.
|
||||
fn get(&self, id: usize) -> MappedMutexGuard<'_, RawMutex, VirtualConnection> {
|
||||
MutexGuard::map(self.connections.lock(), |connections| {
|
||||
connections.get_mut(id).unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
fn collect_pending_events(&self, now: u64, vec: &mut Vec<Event>) {
|
||||
vec.clear();
|
||||
let mut events = self.events.lock();
|
||||
while let Some(event) = events.peek() {
|
||||
if event.time > now {
|
||||
break;
|
||||
}
|
||||
let event = events.pop().unwrap();
|
||||
vec.push(event);
|
||||
}
|
||||
}
|
||||
|
||||
fn start(self: &Arc<Self>) {
|
||||
debug!("started network task");
|
||||
|
||||
let mut events = Vec::new();
|
||||
loop {
|
||||
let now = executor::now();
|
||||
self.collect_pending_events(now, &mut events);
|
||||
|
||||
for event in events.drain(..) {
|
||||
let conn = self.get(event.conn_id);
|
||||
conn.process(self);
|
||||
}
|
||||
|
||||
// block until wakeup
|
||||
executor::yield_me(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 0 - from node(0) to node(1)
|
||||
// 1 - from node(1) to node(0)
|
||||
type MessageDirection = u8;
|
||||
|
||||
fn sender_str(dir: MessageDirection) -> &'static str {
|
||||
match dir {
|
||||
0 => "client",
|
||||
1 => "server",
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn receiver_str(dir: MessageDirection) -> &'static str {
|
||||
match dir {
|
||||
0 => "server",
|
||||
1 => "client",
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Virtual connection between two nodes.
|
||||
/// Node 0 is the creator of the connection (client),
|
||||
/// and node 1 is the acceptor (server).
|
||||
struct VirtualConnection {
|
||||
connection_id: usize,
|
||||
/// one-off chan, used to deliver Accept message to dst
|
||||
dst_accept: Chan<NodeEvent>,
|
||||
/// message sinks
|
||||
dst_sockets: [Chan<NetEvent>; 2],
|
||||
state: Mutex<ConnectionState>,
|
||||
}
|
||||
|
||||
struct ConnectionState {
|
||||
buffers: [NetworkBuffer; 2],
|
||||
rng: StdRng,
|
||||
}
|
||||
|
||||
impl VirtualConnection {
|
||||
/// Notify the future about the possible timeout.
|
||||
fn schedule_timeout(&self, net: &NetworkTask) {
|
||||
if let Some(timeout) = net.options.keepalive_timeout {
|
||||
net.schedule(self.connection_id, timeout);
|
||||
}
|
||||
}
|
||||
|
||||
/// Send the handshake (Accept) to the server.
|
||||
fn send_connect(&self, net: &NetworkTask) {
|
||||
let now = executor::now();
|
||||
let mut state = self.state.lock();
|
||||
let delay = net.options.connect_delay.delay(&mut state.rng);
|
||||
let buffer = &mut state.buffers[0];
|
||||
assert!(buffer.buf.is_empty());
|
||||
assert!(!buffer.recv_closed);
|
||||
assert!(!buffer.send_closed);
|
||||
assert!(buffer.last_recv.is_none());
|
||||
|
||||
let delay = if let Some(ms) = delay {
|
||||
ms
|
||||
} else {
|
||||
debug!("NET: TCP #{} dropped connect", self.connection_id);
|
||||
buffer.send_closed = true;
|
||||
return;
|
||||
};
|
||||
|
||||
// Send a message into the future.
|
||||
buffer
|
||||
.buf
|
||||
.push_back((now + delay, AnyMessage::InternalConnect));
|
||||
net.schedule(self.connection_id, delay);
|
||||
}
|
||||
|
||||
/// Transmit some of the messages from the buffer to the nodes.
|
||||
fn process(&self, net: &Arc<NetworkTask>) {
|
||||
let now = executor::now();
|
||||
|
||||
let mut state = self.state.lock();
|
||||
|
||||
for direction in 0..2 {
|
||||
self.process_direction(
|
||||
net,
|
||||
state.deref_mut(),
|
||||
now,
|
||||
direction as MessageDirection,
|
||||
&self.dst_sockets[direction ^ 1],
|
||||
);
|
||||
}
|
||||
|
||||
// Close the one side of the connection by timeout if the node
|
||||
// has not received any messages for a long time.
|
||||
if let Some(timeout) = net.options.keepalive_timeout {
|
||||
let mut to_close = [false, false];
|
||||
for direction in 0..2 {
|
||||
let buffer = &mut state.buffers[direction];
|
||||
if buffer.recv_closed {
|
||||
continue;
|
||||
}
|
||||
if let Some(last_recv) = buffer.last_recv {
|
||||
if now - last_recv >= timeout {
|
||||
debug!(
|
||||
"NET: connection {} timed out at {}",
|
||||
self.connection_id,
|
||||
receiver_str(direction as MessageDirection)
|
||||
);
|
||||
let node_idx = direction ^ 1;
|
||||
to_close[node_idx] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
drop(state);
|
||||
|
||||
for (node_idx, should_close) in to_close.iter().enumerate() {
|
||||
if *should_close {
|
||||
self.close(node_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Process messages in the buffer in the given direction.
|
||||
fn process_direction(
|
||||
&self,
|
||||
net: &Arc<NetworkTask>,
|
||||
state: &mut ConnectionState,
|
||||
now: u64,
|
||||
direction: MessageDirection,
|
||||
to_socket: &Chan<NetEvent>,
|
||||
) {
|
||||
let buffer = &mut state.buffers[direction as usize];
|
||||
if buffer.recv_closed {
|
||||
assert!(buffer.buf.is_empty());
|
||||
}
|
||||
|
||||
while !buffer.buf.is_empty() && buffer.buf.front().unwrap().0 <= now {
|
||||
let msg = buffer.buf.pop_front().unwrap().1;
|
||||
|
||||
buffer.last_recv = Some(now);
|
||||
self.schedule_timeout(net);
|
||||
|
||||
if let AnyMessage::InternalConnect = msg {
|
||||
// TODO: assert to_socket is the server
|
||||
let server_to_client = TCP {
|
||||
net: net.clone(),
|
||||
conn_id: self.connection_id,
|
||||
dir: direction ^ 1,
|
||||
recv_chan: to_socket.clone(),
|
||||
};
|
||||
// special case, we need to deliver new connection to a separate channel
|
||||
self.dst_accept.send(NodeEvent::Accept(server_to_client));
|
||||
} else {
|
||||
to_socket.send(NetEvent::Message(msg));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to send a message to the buffer, optionally dropping it and
|
||||
/// determining delivery timestamp.
|
||||
fn send(&self, net: &NetworkTask, direction: MessageDirection, msg: AnyMessage) {
|
||||
let now = executor::now();
|
||||
let mut state = self.state.lock();
|
||||
|
||||
let (delay, close) = if let Some(ms) = net.options.send_delay.delay(&mut state.rng) {
|
||||
(ms, false)
|
||||
} else {
|
||||
(0, true)
|
||||
};
|
||||
|
||||
let buffer = &mut state.buffers[direction as usize];
|
||||
if buffer.send_closed {
|
||||
debug!(
|
||||
"NET: TCP #{} dropped message {:?} (broken pipe)",
|
||||
self.connection_id, msg
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if close {
|
||||
debug!(
|
||||
"NET: TCP #{} dropped message {:?} (pipe just broke)",
|
||||
self.connection_id, msg
|
||||
);
|
||||
buffer.send_closed = true;
|
||||
return;
|
||||
}
|
||||
|
||||
if buffer.recv_closed {
|
||||
debug!(
|
||||
"NET: TCP #{} dropped message {:?} (recv closed)",
|
||||
self.connection_id, msg
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Send a message into the future.
|
||||
buffer.buf.push_back((now + delay, msg));
|
||||
net.schedule(self.connection_id, delay);
|
||||
}
|
||||
|
||||
/// Close the connection. Only one side of the connection will be closed,
|
||||
/// and no further messages will be delivered. The other side will not be notified.
|
||||
fn close(&self, node_idx: usize) {
|
||||
let mut state = self.state.lock();
|
||||
let recv_buffer = &mut state.buffers[1 ^ node_idx];
|
||||
if recv_buffer.recv_closed {
|
||||
debug!(
|
||||
"NET: TCP #{} closed twice at {}",
|
||||
self.connection_id,
|
||||
sender_str(node_idx as MessageDirection),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
debug!(
|
||||
"NET: TCP #{} closed at {}",
|
||||
self.connection_id,
|
||||
sender_str(node_idx as MessageDirection),
|
||||
);
|
||||
recv_buffer.recv_closed = true;
|
||||
for msg in recv_buffer.buf.drain(..) {
|
||||
debug!(
|
||||
"NET: TCP #{} dropped message {:?} (closed)",
|
||||
self.connection_id, msg
|
||||
);
|
||||
}
|
||||
|
||||
let send_buffer = &mut state.buffers[node_idx];
|
||||
send_buffer.send_closed = true;
|
||||
drop(state);
|
||||
|
||||
// TODO: notify the other side?
|
||||
|
||||
self.dst_sockets[node_idx].send(NetEvent::Closed);
|
||||
}
|
||||
}
|
||||
|
||||
struct NetworkBuffer {
|
||||
/// Messages paired with time of delivery
|
||||
buf: VecDeque<(u64, AnyMessage)>,
|
||||
/// True if the connection is closed on the receiving side,
|
||||
/// i.e. no more messages from the buffer will be delivered.
|
||||
recv_closed: bool,
|
||||
/// True if the connection is closed on the sending side,
|
||||
/// i.e. no more messages will be added to the buffer.
|
||||
send_closed: bool,
|
||||
/// Last time a message was delivered from the buffer.
|
||||
/// If None, it means that the server is the receiver and
|
||||
/// it has not yet aware of this connection (i.e. has not
|
||||
/// received the Accept).
|
||||
last_recv: Option<u64>,
|
||||
}
|
||||
|
||||
impl NetworkBuffer {
|
||||
fn new(last_recv: Option<u64>) -> Self {
|
||||
Self {
|
||||
buf: VecDeque::new(),
|
||||
recv_closed: false,
|
||||
send_closed: false,
|
||||
last_recv,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Single end of a bidirectional network stream without reordering (TCP-like).
|
||||
/// Reads are implemented using channels, writes go to the buffer inside VirtualConnection.
|
||||
pub struct TCP {
|
||||
net: Arc<NetworkTask>,
|
||||
conn_id: usize,
|
||||
dir: MessageDirection,
|
||||
recv_chan: Chan<NetEvent>,
|
||||
}
|
||||
|
||||
impl Debug for TCP {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "TCP #{} ({})", self.conn_id, sender_str(self.dir),)
|
||||
}
|
||||
}
|
||||
|
||||
impl TCP {
|
||||
/// Send a message to the other side. It's guaranteed that it will not arrive
|
||||
/// before the arrival of all messages sent earlier.
|
||||
pub fn send(&self, msg: AnyMessage) {
|
||||
let conn = self.net.get(self.conn_id);
|
||||
conn.send(&self.net, self.dir, msg);
|
||||
}
|
||||
|
||||
/// Get a channel to receive incoming messages.
|
||||
pub fn recv_chan(&self) -> Chan<NetEvent> {
|
||||
self.recv_chan.clone()
|
||||
}
|
||||
|
||||
pub fn connection_id(&self) -> usize {
|
||||
self.conn_id
|
||||
}
|
||||
|
||||
pub fn close(&self) {
|
||||
let conn = self.net.get(self.conn_id);
|
||||
conn.close(self.dir as usize);
|
||||
}
|
||||
}
|
||||
struct Event {
|
||||
time: u64,
|
||||
conn_id: usize,
|
||||
}
|
||||
|
||||
// BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here
|
||||
// to get that.
|
||||
impl PartialOrd for Event {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for Event {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
(other.time, other.conn_id).cmp(&(self.time, self.conn_id))
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for Event {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
(other.time, other.conn_id) == (self.time, self.conn_id)
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for Event {}
|
||||
54
libs/desim/src/node_os.rs
Normal file
54
libs/desim/src/node_os.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use rand::Rng;
|
||||
|
||||
use crate::proto::NodeEvent;
|
||||
|
||||
use super::{
|
||||
chan::Chan,
|
||||
network::TCP,
|
||||
world::{Node, NodeId, World},
|
||||
};
|
||||
|
||||
/// Abstraction with all functions (aka syscalls) available to the node.
|
||||
#[derive(Clone)]
|
||||
pub struct NodeOs {
|
||||
world: Arc<World>,
|
||||
internal: Arc<Node>,
|
||||
}
|
||||
|
||||
impl NodeOs {
|
||||
pub fn new(world: Arc<World>, internal: Arc<Node>) -> NodeOs {
|
||||
NodeOs { world, internal }
|
||||
}
|
||||
|
||||
/// Get the node id.
|
||||
pub fn id(&self) -> NodeId {
|
||||
self.internal.id
|
||||
}
|
||||
|
||||
/// Opens a bidirectional connection with the other node. Always successful.
|
||||
pub fn open_tcp(&self, dst: NodeId) -> TCP {
|
||||
self.world.open_tcp(dst)
|
||||
}
|
||||
|
||||
/// Returns a channel to receive node events (socket Accept and internal messages).
|
||||
pub fn node_events(&self) -> Chan<NodeEvent> {
|
||||
self.internal.node_events()
|
||||
}
|
||||
|
||||
/// Get current time.
|
||||
pub fn now(&self) -> u64 {
|
||||
self.world.now()
|
||||
}
|
||||
|
||||
/// Generate a random number in range [0, max).
|
||||
pub fn random(&self, max: u64) -> u64 {
|
||||
self.internal.rng.lock().gen_range(0..max)
|
||||
}
|
||||
|
||||
/// Append a new event to the world event log.
|
||||
pub fn log_event(&self, data: String) {
|
||||
self.internal.log_event(data)
|
||||
}
|
||||
}
|
||||
50
libs/desim/src/options.rs
Normal file
50
libs/desim/src/options.rs
Normal file
@@ -0,0 +1,50 @@
|
||||
use rand::{rngs::StdRng, Rng};
|
||||
|
||||
/// Describes random delays and failures. Delay will be uniformly distributed in [min, max].
|
||||
/// Connection failure will occur with the probablity fail_prob.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Delay {
|
||||
pub min: u64,
|
||||
pub max: u64,
|
||||
pub fail_prob: f64, // [0; 1]
|
||||
}
|
||||
|
||||
impl Delay {
|
||||
/// Create a struct with no delay, no failures.
|
||||
pub fn empty() -> Delay {
|
||||
Delay {
|
||||
min: 0,
|
||||
max: 0,
|
||||
fail_prob: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a struct with a fixed delay.
|
||||
pub fn fixed(ms: u64) -> Delay {
|
||||
Delay {
|
||||
min: ms,
|
||||
max: ms,
|
||||
fail_prob: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a random delay in range [min, max]. Return None if the
|
||||
/// message should be dropped.
|
||||
pub fn delay(&self, rng: &mut StdRng) -> Option<u64> {
|
||||
if rng.gen_bool(self.fail_prob) {
|
||||
return None;
|
||||
}
|
||||
Some(rng.gen_range(self.min..=self.max))
|
||||
}
|
||||
}
|
||||
|
||||
/// Describes network settings. All network packets will be subjected to the same delays and failures.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct NetworkOptions {
|
||||
/// Connection will be automatically closed after this timeout if no data is received.
|
||||
pub keepalive_timeout: Option<u64>,
|
||||
/// New connections will be delayed by this amount of time.
|
||||
pub connect_delay: Delay,
|
||||
/// Each message will be delayed by this amount of time.
|
||||
pub send_delay: Delay,
|
||||
}
|
||||
63
libs/desim/src/proto.rs
Normal file
63
libs/desim/src/proto.rs
Normal file
@@ -0,0 +1,63 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use bytes::Bytes;
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
use crate::{network::TCP, world::NodeId};
|
||||
|
||||
/// Internal node events.
|
||||
#[derive(Debug)]
|
||||
pub enum NodeEvent {
|
||||
Accept(TCP),
|
||||
Internal(AnyMessage),
|
||||
}
|
||||
|
||||
/// Events that are coming from a network socket.
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum NetEvent {
|
||||
Message(AnyMessage),
|
||||
Closed,
|
||||
}
|
||||
|
||||
/// Custom events generated throughout the simulation. Can be used by the test to verify the correctness.
|
||||
#[derive(Debug)]
|
||||
pub struct SimEvent {
|
||||
pub time: u64,
|
||||
pub node: NodeId,
|
||||
pub data: String,
|
||||
}
|
||||
|
||||
/// Umbrella type for all possible flavours of messages. These events can be sent over network
|
||||
/// or to an internal node events channel.
|
||||
#[derive(Clone)]
|
||||
pub enum AnyMessage {
|
||||
/// Not used, empty placeholder.
|
||||
None,
|
||||
/// Used internally for notifying node about new incoming connection.
|
||||
InternalConnect,
|
||||
Just32(u32),
|
||||
ReplCell(ReplCell),
|
||||
Bytes(Bytes),
|
||||
LSN(u64),
|
||||
}
|
||||
|
||||
impl Debug for AnyMessage {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
AnyMessage::None => write!(f, "None"),
|
||||
AnyMessage::InternalConnect => write!(f, "InternalConnect"),
|
||||
AnyMessage::Just32(v) => write!(f, "Just32({})", v),
|
||||
AnyMessage::ReplCell(v) => write!(f, "ReplCell({:?})", v),
|
||||
AnyMessage::Bytes(v) => write!(f, "Bytes({})", hex::encode(v)),
|
||||
AnyMessage::LSN(v) => write!(f, "LSN({})", Lsn(*v)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Used in reliable_copy_test.rs
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ReplCell {
|
||||
pub value: u32,
|
||||
pub client_id: u32,
|
||||
pub seqno: u32,
|
||||
}
|
||||
129
libs/desim/src/time.rs
Normal file
129
libs/desim/src/time.rs
Normal file
@@ -0,0 +1,129 @@
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
collections::BinaryHeap,
|
||||
ops::DerefMut,
|
||||
sync::{
|
||||
atomic::{AtomicU32, AtomicU64},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use parking_lot::Mutex;
|
||||
use tracing::trace;
|
||||
|
||||
use crate::executor::ThreadContext;
|
||||
|
||||
/// Holds current time and all pending wakeup events.
|
||||
pub struct Timing {
|
||||
/// Current world's time.
|
||||
current_time: AtomicU64,
|
||||
/// Pending timers.
|
||||
queue: Mutex<BinaryHeap<Pending>>,
|
||||
/// Global nonce. Makes picking events from binary heap queue deterministic
|
||||
/// by appending a number to events with the same timestamp.
|
||||
nonce: AtomicU32,
|
||||
/// Used to schedule fake events.
|
||||
fake_context: Arc<ThreadContext>,
|
||||
}
|
||||
|
||||
impl Default for Timing {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Timing {
|
||||
/// Create a new empty clock with time set to 0.
|
||||
pub fn new() -> Timing {
|
||||
Timing {
|
||||
current_time: AtomicU64::new(0),
|
||||
queue: Mutex::new(BinaryHeap::new()),
|
||||
nonce: AtomicU32::new(0),
|
||||
fake_context: Arc::new(ThreadContext::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the current world's time.
|
||||
pub fn now(&self) -> u64 {
|
||||
self.current_time.load(std::sync::atomic::Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Tick-tock the global clock. Return the event ready to be processed
|
||||
/// or move the clock forward and then return the event.
|
||||
pub(crate) fn step(&self) -> Option<Arc<ThreadContext>> {
|
||||
let mut queue = self.queue.lock();
|
||||
|
||||
if queue.is_empty() {
|
||||
// no future events
|
||||
return None;
|
||||
}
|
||||
|
||||
if !self.is_event_ready(queue.deref_mut()) {
|
||||
let next_time = queue.peek().unwrap().time;
|
||||
self.current_time
|
||||
.store(next_time, std::sync::atomic::Ordering::SeqCst);
|
||||
trace!("rewind time to {}", next_time);
|
||||
assert!(self.is_event_ready(queue.deref_mut()));
|
||||
}
|
||||
|
||||
Some(queue.pop().unwrap().wake_context)
|
||||
}
|
||||
|
||||
/// Append an event to the queue, to wakeup the thread in `ms` milliseconds.
|
||||
pub(crate) fn schedule_wakeup(&self, ms: u64, wake_context: Arc<ThreadContext>) {
|
||||
self.nonce.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
||||
let nonce = self.nonce.load(std::sync::atomic::Ordering::SeqCst);
|
||||
self.queue.lock().push(Pending {
|
||||
time: self.now() + ms,
|
||||
nonce,
|
||||
wake_context,
|
||||
})
|
||||
}
|
||||
|
||||
/// Append a fake event to the queue, to prevent clocks from skipping this time.
|
||||
pub fn schedule_fake(&self, ms: u64) {
|
||||
self.queue.lock().push(Pending {
|
||||
time: self.now() + ms,
|
||||
nonce: 0,
|
||||
wake_context: self.fake_context.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
/// Return true if there is a ready event.
|
||||
fn is_event_ready(&self, queue: &mut BinaryHeap<Pending>) -> bool {
|
||||
queue.peek().map_or(false, |x| x.time <= self.now())
|
||||
}
|
||||
|
||||
/// Clear all pending events.
|
||||
pub(crate) fn clear(&self) {
|
||||
self.queue.lock().clear();
|
||||
}
|
||||
}
|
||||
|
||||
struct Pending {
|
||||
time: u64,
|
||||
nonce: u32,
|
||||
wake_context: Arc<ThreadContext>,
|
||||
}
|
||||
|
||||
// BinaryHeap is a max-heap, and we want a min-heap. Reverse the ordering here
|
||||
// to get that.
|
||||
impl PartialOrd for Pending {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for Pending {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
(other.time, other.nonce).cmp(&(self.time, self.nonce))
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for Pending {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
(other.time, other.nonce) == (self.time, self.nonce)
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for Pending {}
|
||||
180
libs/desim/src/world.rs
Normal file
180
libs/desim/src/world.rs
Normal file
@@ -0,0 +1,180 @@
|
||||
use parking_lot::Mutex;
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
use std::{
|
||||
ops::DerefMut,
|
||||
sync::{mpsc, Arc},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
executor::{ExternalHandle, Runtime},
|
||||
network::NetworkTask,
|
||||
options::NetworkOptions,
|
||||
proto::{NodeEvent, SimEvent},
|
||||
time::Timing,
|
||||
};
|
||||
|
||||
use super::{chan::Chan, network::TCP, node_os::NodeOs};
|
||||
|
||||
pub type NodeId = u32;
|
||||
|
||||
/// World contains simulation state.
|
||||
pub struct World {
|
||||
nodes: Mutex<Vec<Arc<Node>>>,
|
||||
/// Random number generator.
|
||||
rng: Mutex<StdRng>,
|
||||
/// Internal event log.
|
||||
events: Mutex<Vec<SimEvent>>,
|
||||
/// Separate task that processes all network messages.
|
||||
network_task: Arc<NetworkTask>,
|
||||
/// Runtime for running threads and moving time.
|
||||
runtime: Mutex<Runtime>,
|
||||
/// To get current time.
|
||||
timing: Arc<Timing>,
|
||||
}
|
||||
|
||||
impl World {
|
||||
pub fn new(seed: u64, options: Arc<NetworkOptions>) -> World {
|
||||
let timing = Arc::new(Timing::new());
|
||||
let mut runtime = Runtime::new(timing.clone());
|
||||
|
||||
let (tx, rx) = mpsc::channel();
|
||||
|
||||
runtime.spawn(move || {
|
||||
// create and start network background thread, and send it back via the channel
|
||||
NetworkTask::start_new(options, tx)
|
||||
});
|
||||
|
||||
// wait for the network task to start
|
||||
while runtime.step() {}
|
||||
|
||||
let network_task = rx.recv().unwrap();
|
||||
|
||||
World {
|
||||
nodes: Mutex::new(Vec::new()),
|
||||
rng: Mutex::new(StdRng::seed_from_u64(seed)),
|
||||
events: Mutex::new(Vec::new()),
|
||||
network_task,
|
||||
runtime: Mutex::new(runtime),
|
||||
timing,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn step(&self) -> bool {
|
||||
self.runtime.lock().step()
|
||||
}
|
||||
|
||||
pub fn get_thread_step_count(&self) -> u64 {
|
||||
self.runtime.lock().step_counter
|
||||
}
|
||||
|
||||
/// Create a new random number generator.
|
||||
pub fn new_rng(&self) -> StdRng {
|
||||
let mut rng = self.rng.lock();
|
||||
StdRng::from_rng(rng.deref_mut()).unwrap()
|
||||
}
|
||||
|
||||
/// Create a new node.
|
||||
pub fn new_node(self: &Arc<Self>) -> Arc<Node> {
|
||||
let mut nodes = self.nodes.lock();
|
||||
let id = nodes.len() as NodeId;
|
||||
let node = Arc::new(Node::new(id, self.clone(), self.new_rng()));
|
||||
nodes.push(node.clone());
|
||||
node
|
||||
}
|
||||
|
||||
/// Get an internal node state by id.
|
||||
fn get_node(&self, id: NodeId) -> Option<Arc<Node>> {
|
||||
let nodes = self.nodes.lock();
|
||||
let num = id as usize;
|
||||
if num < nodes.len() {
|
||||
Some(nodes[num].clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stop_all(&self) {
|
||||
self.runtime.lock().crash_all_threads();
|
||||
}
|
||||
|
||||
/// Returns a writable end of a TCP connection, to send src->dst messages.
|
||||
pub fn open_tcp(self: &Arc<World>, dst: NodeId) -> TCP {
|
||||
// TODO: replace unwrap() with /dev/null socket.
|
||||
let dst = self.get_node(dst).unwrap();
|
||||
let dst_accept = dst.node_events.lock().clone();
|
||||
|
||||
let rng = self.new_rng();
|
||||
self.network_task.start_new_connection(rng, dst_accept)
|
||||
}
|
||||
|
||||
/// Get current time.
|
||||
pub fn now(&self) -> u64 {
|
||||
self.timing.now()
|
||||
}
|
||||
|
||||
/// Get a copy of the internal clock.
|
||||
pub fn clock(&self) -> Arc<Timing> {
|
||||
self.timing.clone()
|
||||
}
|
||||
|
||||
pub fn add_event(&self, node: NodeId, data: String) {
|
||||
let time = self.now();
|
||||
self.events.lock().push(SimEvent { time, node, data });
|
||||
}
|
||||
|
||||
pub fn take_events(&self) -> Vec<SimEvent> {
|
||||
let mut events = self.events.lock();
|
||||
let mut res = Vec::new();
|
||||
std::mem::swap(&mut res, &mut events);
|
||||
res
|
||||
}
|
||||
|
||||
pub fn deallocate(&self) {
|
||||
self.stop_all();
|
||||
self.timing.clear();
|
||||
self.nodes.lock().clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// Internal node state.
|
||||
pub struct Node {
|
||||
pub id: NodeId,
|
||||
node_events: Mutex<Chan<NodeEvent>>,
|
||||
world: Arc<World>,
|
||||
pub(crate) rng: Mutex<StdRng>,
|
||||
}
|
||||
|
||||
impl Node {
|
||||
pub fn new(id: NodeId, world: Arc<World>, rng: StdRng) -> Node {
|
||||
Node {
|
||||
id,
|
||||
node_events: Mutex::new(Chan::new()),
|
||||
world,
|
||||
rng: Mutex::new(rng),
|
||||
}
|
||||
}
|
||||
|
||||
/// Spawn a new thread with this node context.
|
||||
pub fn launch(self: &Arc<Self>, f: impl FnOnce(NodeOs) + Send + 'static) -> ExternalHandle {
|
||||
let node = self.clone();
|
||||
let world = self.world.clone();
|
||||
self.world.runtime.lock().spawn(move || {
|
||||
f(NodeOs::new(world, node.clone()));
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a channel to receive Accepts and internal messages.
|
||||
pub fn node_events(&self) -> Chan<NodeEvent> {
|
||||
self.node_events.lock().clone()
|
||||
}
|
||||
|
||||
/// This will drop all in-flight Accept messages.
|
||||
pub fn replug_node_events(&self, chan: Chan<NodeEvent>) {
|
||||
*self.node_events.lock() = chan;
|
||||
}
|
||||
|
||||
/// Append event to the world's log.
|
||||
pub fn log_event(&self, data: String) {
|
||||
self.world.add_event(self.id, data)
|
||||
}
|
||||
}
|
||||
244
libs/desim/tests/reliable_copy_test.rs
Normal file
244
libs/desim/tests/reliable_copy_test.rs
Normal file
@@ -0,0 +1,244 @@
|
||||
//! Simple test to verify that simulator is working.
|
||||
#[cfg(test)]
|
||||
mod reliable_copy_test {
|
||||
use anyhow::Result;
|
||||
use desim::executor::{self, PollSome};
|
||||
use desim::options::{Delay, NetworkOptions};
|
||||
use desim::proto::{NetEvent, NodeEvent, ReplCell};
|
||||
use desim::world::{NodeId, World};
|
||||
use desim::{node_os::NodeOs, proto::AnyMessage};
|
||||
use parking_lot::Mutex;
|
||||
use std::sync::Arc;
|
||||
use tracing::info;
|
||||
|
||||
/// Disk storage trait and implementation.
|
||||
pub trait Storage<T> {
|
||||
fn flush_pos(&self) -> u32;
|
||||
fn flush(&mut self) -> Result<()>;
|
||||
fn write(&mut self, t: T);
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SharedStorage<T> {
|
||||
pub state: Arc<Mutex<InMemoryStorage<T>>>,
|
||||
}
|
||||
|
||||
impl<T> SharedStorage<T> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
state: Arc::new(Mutex::new(InMemoryStorage::new())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Storage<T> for SharedStorage<T> {
|
||||
fn flush_pos(&self) -> u32 {
|
||||
self.state.lock().flush_pos
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> Result<()> {
|
||||
executor::yield_me(0);
|
||||
self.state.lock().flush()
|
||||
}
|
||||
|
||||
fn write(&mut self, t: T) {
|
||||
executor::yield_me(0);
|
||||
self.state.lock().write(t);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct InMemoryStorage<T> {
|
||||
pub data: Vec<T>,
|
||||
pub flush_pos: u32,
|
||||
}
|
||||
|
||||
impl<T> InMemoryStorage<T> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
data: Vec::new(),
|
||||
flush_pos: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn flush(&mut self) -> Result<()> {
|
||||
self.flush_pos = self.data.len() as u32;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn write(&mut self, t: T) {
|
||||
self.data.push(t);
|
||||
}
|
||||
}
|
||||
|
||||
/// Server implementation.
|
||||
pub fn run_server(os: NodeOs, mut storage: Box<dyn Storage<u32>>) {
|
||||
info!("started server");
|
||||
|
||||
let node_events = os.node_events();
|
||||
let mut epoll_vec: Vec<Box<dyn PollSome>> = vec![Box::new(node_events.clone())];
|
||||
let mut sockets = vec![];
|
||||
|
||||
loop {
|
||||
let index = executor::epoll_chans(&epoll_vec, -1).unwrap();
|
||||
|
||||
if index == 0 {
|
||||
let node_event = node_events.must_recv();
|
||||
info!("got node event: {:?}", node_event);
|
||||
if let NodeEvent::Accept(tcp) = node_event {
|
||||
tcp.send(AnyMessage::Just32(storage.flush_pos()));
|
||||
epoll_vec.push(Box::new(tcp.recv_chan()));
|
||||
sockets.push(tcp);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
let recv_chan = sockets[index - 1].recv_chan();
|
||||
let socket = &sockets[index - 1];
|
||||
|
||||
let event = recv_chan.must_recv();
|
||||
info!("got event: {:?}", event);
|
||||
if let NetEvent::Message(AnyMessage::ReplCell(cell)) = event {
|
||||
if cell.seqno != storage.flush_pos() {
|
||||
info!("got out of order data: {:?}", cell);
|
||||
continue;
|
||||
}
|
||||
storage.write(cell.value);
|
||||
storage.flush().unwrap();
|
||||
socket.send(AnyMessage::Just32(storage.flush_pos()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Client copies all data from array to the remote node.
|
||||
pub fn run_client(os: NodeOs, data: &[ReplCell], dst: NodeId) {
|
||||
info!("started client");
|
||||
|
||||
let mut delivered = 0;
|
||||
|
||||
let mut sock = os.open_tcp(dst);
|
||||
let mut recv_chan = sock.recv_chan();
|
||||
|
||||
while delivered < data.len() {
|
||||
let num = &data[delivered];
|
||||
info!("sending data: {:?}", num.clone());
|
||||
sock.send(AnyMessage::ReplCell(num.clone()));
|
||||
|
||||
// loop {
|
||||
let event = recv_chan.recv();
|
||||
match event {
|
||||
NetEvent::Message(AnyMessage::Just32(flush_pos)) => {
|
||||
if flush_pos == 1 + delivered as u32 {
|
||||
delivered += 1;
|
||||
}
|
||||
}
|
||||
NetEvent::Closed => {
|
||||
info!("connection closed, reestablishing");
|
||||
sock = os.open_tcp(dst);
|
||||
recv_chan = sock.recv_chan();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// }
|
||||
}
|
||||
|
||||
let sock = os.open_tcp(dst);
|
||||
for num in data {
|
||||
info!("sending data: {:?}", num.clone());
|
||||
sock.send(AnyMessage::ReplCell(num.clone()));
|
||||
}
|
||||
|
||||
info!("sent all data and finished client");
|
||||
}
|
||||
|
||||
/// Run test simulations.
|
||||
#[test]
|
||||
fn sim_example_reliable_copy() {
|
||||
utils::logging::init(
|
||||
utils::logging::LogFormat::Test,
|
||||
utils::logging::TracingErrorLayerEnablement::Disabled,
|
||||
utils::logging::Output::Stdout,
|
||||
)
|
||||
.expect("logging init failed");
|
||||
|
||||
let delay = Delay {
|
||||
min: 1,
|
||||
max: 60,
|
||||
fail_prob: 0.4,
|
||||
};
|
||||
|
||||
let network = NetworkOptions {
|
||||
keepalive_timeout: Some(50),
|
||||
connect_delay: delay.clone(),
|
||||
send_delay: delay.clone(),
|
||||
};
|
||||
|
||||
for seed in 0..20 {
|
||||
let u32_data: [u32; 5] = [1, 2, 3, 4, 5];
|
||||
let data = u32_to_cells(&u32_data, 1);
|
||||
let world = Arc::new(World::new(seed, Arc::new(network.clone())));
|
||||
|
||||
start_simulation(Options {
|
||||
world,
|
||||
time_limit: 1_000_000,
|
||||
client_fn: Box::new(move |os, server_id| run_client(os, &data, server_id)),
|
||||
u32_data,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Options {
|
||||
pub world: Arc<World>,
|
||||
pub time_limit: u64,
|
||||
pub u32_data: [u32; 5],
|
||||
pub client_fn: Box<dyn FnOnce(NodeOs, u32) + Send + 'static>,
|
||||
}
|
||||
|
||||
pub fn start_simulation(options: Options) {
|
||||
let world = options.world;
|
||||
|
||||
let client_node = world.new_node();
|
||||
let server_node = world.new_node();
|
||||
let server_id = server_node.id;
|
||||
|
||||
// start the client thread
|
||||
client_node.launch(move |os| {
|
||||
let client_fn = options.client_fn;
|
||||
client_fn(os, server_id);
|
||||
});
|
||||
|
||||
// start the server thread
|
||||
let shared_storage = SharedStorage::new();
|
||||
let server_storage = shared_storage.clone();
|
||||
server_node.launch(move |os| run_server(os, Box::new(server_storage)));
|
||||
|
||||
while world.step() && world.now() < options.time_limit {}
|
||||
|
||||
let disk_data = shared_storage.state.lock().data.clone();
|
||||
assert!(verify_data(&disk_data, &options.u32_data[..]));
|
||||
}
|
||||
|
||||
pub fn u32_to_cells(data: &[u32], client_id: u32) -> Vec<ReplCell> {
|
||||
let mut res = Vec::new();
|
||||
for (i, _) in data.iter().enumerate() {
|
||||
res.push(ReplCell {
|
||||
client_id,
|
||||
seqno: i as u32,
|
||||
value: data[i],
|
||||
});
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
fn verify_data(disk_data: &[u32], data: &[u32]) -> bool {
|
||||
if disk_data.len() != data.len() {
|
||||
return false;
|
||||
}
|
||||
for i in 0..data.len() {
|
||||
if disk_data[i] != data[i] {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
@@ -494,6 +494,8 @@ pub struct TimelineInfo {
|
||||
pub current_logical_size: u64,
|
||||
pub current_logical_size_is_accurate: bool,
|
||||
|
||||
pub directory_entries_counts: Vec<u64>,
|
||||
|
||||
/// Sum of the size of all layer files.
|
||||
/// If a layer is present in both local FS and S3, it counts only once.
|
||||
pub current_physical_size: Option<u64>, // is None when timeline is Unloaded
|
||||
|
||||
@@ -124,6 +124,7 @@ impl RelTag {
|
||||
Ord,
|
||||
strum_macros::EnumIter,
|
||||
strum_macros::FromRepr,
|
||||
enum_map::Enum,
|
||||
)]
|
||||
#[repr(u8)]
|
||||
pub enum SlruKind {
|
||||
|
||||
@@ -431,11 +431,11 @@ pub fn generate_wal_segment(segno: u64, system_id: u64, lsn: Lsn) -> Result<Byte
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Serialize)]
|
||||
struct XlLogicalMessage {
|
||||
db_id: Oid,
|
||||
transactional: uint32, // bool, takes 4 bytes due to alignment in C structures
|
||||
prefix_size: uint64,
|
||||
message_size: uint64,
|
||||
pub struct XlLogicalMessage {
|
||||
pub db_id: Oid,
|
||||
pub transactional: uint32, // bool, takes 4 bytes due to alignment in C structures
|
||||
pub prefix_size: uint64,
|
||||
pub message_size: uint64,
|
||||
}
|
||||
|
||||
impl XlLogicalMessage {
|
||||
|
||||
@@ -13,5 +13,6 @@ rand.workspace = true
|
||||
tokio.workspace = true
|
||||
tracing.workspace = true
|
||||
thiserror.workspace = true
|
||||
serde.workspace = true
|
||||
|
||||
workspace_hack.workspace = true
|
||||
|
||||
@@ -7,6 +7,7 @@ pub mod framed;
|
||||
|
||||
use byteorder::{BigEndian, ReadBytesExt};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{borrow::Cow, collections::HashMap, fmt, io, str};
|
||||
|
||||
// re-export for use in utils pageserver_feedback.rs
|
||||
@@ -123,7 +124,7 @@ impl StartupMessageParams {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
|
||||
#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct CancelKeyData {
|
||||
pub backend_pid: i32,
|
||||
pub cancel_key: i32,
|
||||
|
||||
@@ -54,12 +54,10 @@ impl Generation {
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
pub fn get_suffix(&self) -> String {
|
||||
pub fn get_suffix(&self) -> impl std::fmt::Display {
|
||||
match self {
|
||||
Self::Valid(v) => {
|
||||
format!("-{:08x}", v)
|
||||
}
|
||||
Self::None => "".into(),
|
||||
Self::Valid(v) => GenerationFileSuffix(Some(*v)),
|
||||
Self::None => GenerationFileSuffix(None),
|
||||
Self::Broken => {
|
||||
panic!("Tried to use a broken generation");
|
||||
}
|
||||
@@ -90,6 +88,7 @@ impl Generation {
|
||||
}
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
pub fn next(&self) -> Generation {
|
||||
match self {
|
||||
Self::Valid(n) => Self::Valid(*n + 1),
|
||||
@@ -107,6 +106,18 @@ impl Generation {
|
||||
}
|
||||
}
|
||||
|
||||
struct GenerationFileSuffix(Option<u32>);
|
||||
|
||||
impl std::fmt::Display for GenerationFileSuffix {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
if let Some(g) = self.0 {
|
||||
write!(f, "-{g:08x}")
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for Generation {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
@@ -164,4 +175,24 @@ mod test {
|
||||
assert!(Generation::none() < Generation::new(0));
|
||||
assert!(Generation::none() < Generation::new(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn suffix_is_stable() {
|
||||
use std::fmt::Write as _;
|
||||
|
||||
// the suffix must remain stable through-out the pageserver remote storage evolution and
|
||||
// not be changed accidentially without thinking about migration
|
||||
let examples = [
|
||||
(line!(), Generation::None, ""),
|
||||
(line!(), Generation::Valid(0), "-00000000"),
|
||||
(line!(), Generation::Valid(u32::MAX), "-ffffffff"),
|
||||
];
|
||||
|
||||
let mut s = String::new();
|
||||
for (line, gen, expected) in examples {
|
||||
s.clear();
|
||||
write!(s, "{}", &gen.get_suffix()).expect("string grows");
|
||||
assert_eq!(s, expected, "example on {line}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,37 +69,44 @@ impl<T> OnceCell<T> {
|
||||
F: FnOnce(InitPermit) -> Fut,
|
||||
Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
|
||||
{
|
||||
let sem = {
|
||||
loop {
|
||||
let sem = {
|
||||
let guard = self.inner.lock().unwrap();
|
||||
if guard.value.is_some() {
|
||||
return Ok(Guard(guard));
|
||||
}
|
||||
guard.init_semaphore.clone()
|
||||
};
|
||||
|
||||
{
|
||||
let permit = {
|
||||
// increment the count for the duration of queued
|
||||
let _guard = CountWaitingInitializers::start(self);
|
||||
sem.acquire().await
|
||||
};
|
||||
|
||||
let Ok(permit) = permit else {
|
||||
let guard = self.inner.lock().unwrap();
|
||||
if !Arc::ptr_eq(&sem, &guard.init_semaphore) {
|
||||
// there was a take_and_deinit in between
|
||||
continue;
|
||||
}
|
||||
assert!(
|
||||
guard.value.is_some(),
|
||||
"semaphore got closed, must be initialized"
|
||||
);
|
||||
return Ok(Guard(guard));
|
||||
};
|
||||
|
||||
permit.forget();
|
||||
}
|
||||
|
||||
let permit = InitPermit(sem);
|
||||
let (value, _permit) = factory(permit).await?;
|
||||
|
||||
let guard = self.inner.lock().unwrap();
|
||||
if guard.value.is_some() {
|
||||
return Ok(Guard(guard));
|
||||
}
|
||||
guard.init_semaphore.clone()
|
||||
};
|
||||
|
||||
let permit = {
|
||||
// increment the count for the duration of queued
|
||||
let _guard = CountWaitingInitializers::start(self);
|
||||
sem.acquire_owned().await
|
||||
};
|
||||
|
||||
match permit {
|
||||
Ok(permit) => {
|
||||
let permit = InitPermit(permit);
|
||||
let (value, _permit) = factory(permit).await?;
|
||||
|
||||
let guard = self.inner.lock().unwrap();
|
||||
|
||||
Ok(Self::set0(value, guard))
|
||||
}
|
||||
Err(_closed) => {
|
||||
let guard = self.inner.lock().unwrap();
|
||||
assert!(
|
||||
guard.value.is_some(),
|
||||
"semaphore got closed, must be initialized"
|
||||
);
|
||||
return Ok(Guard(guard));
|
||||
}
|
||||
return Ok(Self::set0(value, guard));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -197,27 +204,41 @@ impl<'a, T> Guard<'a, T> {
|
||||
/// [`OnceCell::get_or_init`] will wait on it to complete.
|
||||
pub fn take_and_deinit(&mut self) -> (T, InitPermit) {
|
||||
let mut swapped = Inner::default();
|
||||
let permit = swapped
|
||||
.init_semaphore
|
||||
.clone()
|
||||
.try_acquire_owned()
|
||||
.expect("we just created this");
|
||||
let sem = swapped.init_semaphore.clone();
|
||||
// acquire and forget right away, moving the control over to InitPermit
|
||||
sem.try_acquire().expect("we just created this").forget();
|
||||
std::mem::swap(&mut *self.0, &mut swapped);
|
||||
swapped
|
||||
.value
|
||||
.map(|v| (v, InitPermit(permit)))
|
||||
.map(|v| (v, InitPermit(sem)))
|
||||
.expect("guard is not created unless value has been initialized")
|
||||
}
|
||||
}
|
||||
|
||||
/// Type held by OnceCell (de)initializing task.
|
||||
pub struct InitPermit(tokio::sync::OwnedSemaphorePermit);
|
||||
///
|
||||
/// On drop, this type will return the permit.
|
||||
pub struct InitPermit(Arc<tokio::sync::Semaphore>);
|
||||
|
||||
impl Drop for InitPermit {
|
||||
fn drop(&mut self) {
|
||||
assert_eq!(
|
||||
self.0.available_permits(),
|
||||
0,
|
||||
"InitPermit should only exist as the unique permit"
|
||||
);
|
||||
self.0.add_permits(1);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use futures::Future;
|
||||
|
||||
use super::*;
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
pin::{pin, Pin},
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
time::Duration,
|
||||
};
|
||||
@@ -380,4 +401,85 @@ mod tests {
|
||||
.unwrap();
|
||||
assert_eq!(*g, "now initialized");
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn reproduce_init_take_deinit_race() {
|
||||
init_take_deinit_scenario(|cell, factory| {
|
||||
Box::pin(async {
|
||||
cell.get_or_init(factory).await.unwrap();
|
||||
})
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
type BoxedInitFuture<T, E> = Pin<Box<dyn Future<Output = Result<(T, InitPermit), E>>>>;
|
||||
type BoxedInitFunction<T, E> = Box<dyn Fn(InitPermit) -> BoxedInitFuture<T, E>>;
|
||||
|
||||
/// Reproduce an assertion failure.
|
||||
///
|
||||
/// This has interesting generics to be generic between `get_or_init` and `get_mut_or_init`.
|
||||
/// We currently only have one, but the structure is kept.
|
||||
async fn init_take_deinit_scenario<F>(init_way: F)
|
||||
where
|
||||
F: for<'a> Fn(
|
||||
&'a OnceCell<&'static str>,
|
||||
BoxedInitFunction<&'static str, Infallible>,
|
||||
) -> Pin<Box<dyn Future<Output = ()> + 'a>>,
|
||||
{
|
||||
let cell = OnceCell::default();
|
||||
|
||||
// acquire the init_semaphore only permit to drive initializing tasks in order to waiting
|
||||
// on the same semaphore.
|
||||
let permit = cell
|
||||
.inner
|
||||
.lock()
|
||||
.unwrap()
|
||||
.init_semaphore
|
||||
.clone()
|
||||
.try_acquire_owned()
|
||||
.unwrap();
|
||||
|
||||
let mut t1 = pin!(init_way(
|
||||
&cell,
|
||||
Box::new(|permit| Box::pin(async move { Ok(("t1", permit)) })),
|
||||
));
|
||||
|
||||
let mut t2 = pin!(init_way(
|
||||
&cell,
|
||||
Box::new(|permit| Box::pin(async move { Ok(("t2", permit)) })),
|
||||
));
|
||||
|
||||
// drive t2 first to the init_semaphore -- the timeout will be hit once t2 future can
|
||||
// no longer make progress
|
||||
tokio::select! {
|
||||
_ = &mut t2 => unreachable!("it cannot get permit"),
|
||||
_ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
|
||||
}
|
||||
|
||||
// followed by t1 in the init_semaphore
|
||||
tokio::select! {
|
||||
_ = &mut t1 => unreachable!("it cannot get permit"),
|
||||
_ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
|
||||
}
|
||||
|
||||
// now let t2 proceed and initialize
|
||||
drop(permit);
|
||||
t2.await;
|
||||
|
||||
let (s, permit) = { cell.get().unwrap().take_and_deinit() };
|
||||
assert_eq!("t2", s);
|
||||
|
||||
// now originally t1 would see the semaphore it has as closed. it cannot yet get a permit from
|
||||
// the new one.
|
||||
tokio::select! {
|
||||
_ = &mut t1 => unreachable!("it cannot get permit"),
|
||||
_ = tokio::time::sleep(Duration::from_secs(3600 * 24 * 7 * 365)) => {}
|
||||
}
|
||||
|
||||
// only now we get to initialize it
|
||||
drop(permit);
|
||||
t1.await;
|
||||
|
||||
assert_eq!("t1", *cell.get().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,6 +34,9 @@ fn main() -> anyhow::Result<()> {
|
||||
println!("cargo:rustc-link-lib=static=walproposer");
|
||||
println!("cargo:rustc-link-search={walproposer_lib_search_str}");
|
||||
|
||||
// Rebuild crate when libwalproposer.a changes
|
||||
println!("cargo:rerun-if-changed={walproposer_lib_search_str}/libwalproposer.a");
|
||||
|
||||
let pg_config_bin = pg_install_abs.join("v16").join("bin").join("pg_config");
|
||||
let inc_server_path: String = if pg_config_bin.exists() {
|
||||
let output = Command::new(pg_config_bin)
|
||||
@@ -79,6 +82,7 @@ fn main() -> anyhow::Result<()> {
|
||||
.allowlist_function("WalProposerBroadcast")
|
||||
.allowlist_function("WalProposerPoll")
|
||||
.allowlist_function("WalProposerFree")
|
||||
.allowlist_function("SafekeeperStateDesiredEvents")
|
||||
.allowlist_var("DEBUG5")
|
||||
.allowlist_var("DEBUG4")
|
||||
.allowlist_var("DEBUG3")
|
||||
|
||||
@@ -22,6 +22,7 @@ use crate::bindings::WalProposerExecStatusType;
|
||||
use crate::bindings::WalproposerShmemState;
|
||||
use crate::bindings::XLogRecPtr;
|
||||
use crate::walproposer::ApiImpl;
|
||||
use crate::walproposer::StreamingCallback;
|
||||
use crate::walproposer::WaitResult;
|
||||
|
||||
extern "C" fn get_shmem_state(wp: *mut WalProposer) -> *mut WalproposerShmemState {
|
||||
@@ -36,7 +37,8 @@ extern "C" fn start_streaming(wp: *mut WalProposer, startpos: XLogRecPtr) {
|
||||
unsafe {
|
||||
let callback_data = (*(*wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).start_streaming(startpos)
|
||||
let callback = StreamingCallback::new(wp);
|
||||
(*api).start_streaming(startpos, &callback);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,19 +136,18 @@ extern "C" fn conn_async_read(
|
||||
unsafe {
|
||||
let callback_data = (*(*(*sk).wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
let (res, result) = (*api).conn_async_read(&mut (*sk));
|
||||
|
||||
// This function has guarantee that returned buf will be valid until
|
||||
// the next call. So we can store a Vec in each Safekeeper and reuse
|
||||
// it on the next call.
|
||||
let mut inbuf = take_vec_u8(&mut (*sk).inbuf).unwrap_or_default();
|
||||
|
||||
inbuf.clear();
|
||||
inbuf.extend_from_slice(res);
|
||||
|
||||
let result = (*api).conn_async_read(&mut (*sk), &mut inbuf);
|
||||
|
||||
// Put a Vec back to sk->inbuf and return data ptr.
|
||||
*amount = inbuf.len() as i32;
|
||||
*buf = store_vec_u8(&mut (*sk).inbuf, inbuf);
|
||||
*amount = res.len() as i32;
|
||||
|
||||
result
|
||||
}
|
||||
@@ -182,6 +183,10 @@ extern "C" fn recovery_download(wp: *mut WalProposer, sk: *mut Safekeeper) -> bo
|
||||
unsafe {
|
||||
let callback_data = (*(*(*sk).wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
|
||||
// currently `recovery_download` is always called right after election
|
||||
(*api).after_election(&mut (*wp));
|
||||
|
||||
(*api).recovery_download(&mut (*wp), &mut (*sk))
|
||||
}
|
||||
}
|
||||
@@ -277,7 +282,8 @@ extern "C" fn wait_event_set(
|
||||
}
|
||||
WaitResult::Timeout => {
|
||||
*event_sk = std::ptr::null_mut();
|
||||
*events = crate::bindings::WL_TIMEOUT;
|
||||
// WaitEventSetWait returns 0 for timeout.
|
||||
*events = 0;
|
||||
0
|
||||
}
|
||||
WaitResult::Network(sk, event_mask) => {
|
||||
@@ -340,7 +346,7 @@ extern "C" fn log_internal(
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum Level {
|
||||
Debug5,
|
||||
Debug4,
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
use std::ffi::CString;
|
||||
|
||||
use postgres_ffi::WAL_SEGMENT_SIZE;
|
||||
use utils::id::TenantTimelineId;
|
||||
use utils::{id::TenantTimelineId, lsn::Lsn};
|
||||
|
||||
use crate::{
|
||||
api_bindings::{create_api, take_vec_u8, Level},
|
||||
bindings::{
|
||||
NeonWALReadResult, Safekeeper, WalProposer, WalProposerConfig, WalProposerCreate,
|
||||
WalProposerFree, WalProposerStart,
|
||||
NeonWALReadResult, Safekeeper, WalProposer, WalProposerBroadcast, WalProposerConfig,
|
||||
WalProposerCreate, WalProposerFree, WalProposerPoll, WalProposerStart,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -16,11 +16,11 @@ use crate::{
|
||||
///
|
||||
/// Refer to `pgxn/neon/walproposer.h` for documentation.
|
||||
pub trait ApiImpl {
|
||||
fn get_shmem_state(&self) -> &mut crate::bindings::WalproposerShmemState {
|
||||
fn get_shmem_state(&self) -> *mut crate::bindings::WalproposerShmemState {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn start_streaming(&self, _startpos: u64) {
|
||||
fn start_streaming(&self, _startpos: u64, _callback: &StreamingCallback) {
|
||||
todo!()
|
||||
}
|
||||
|
||||
@@ -70,7 +70,11 @@ pub trait ApiImpl {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn conn_async_read(&self, _sk: &mut Safekeeper) -> (&[u8], crate::bindings::PGAsyncReadResult) {
|
||||
fn conn_async_read(
|
||||
&self,
|
||||
_sk: &mut Safekeeper,
|
||||
_vec: &mut Vec<u8>,
|
||||
) -> crate::bindings::PGAsyncReadResult {
|
||||
todo!()
|
||||
}
|
||||
|
||||
@@ -151,12 +155,14 @@ pub trait ApiImpl {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum WaitResult {
|
||||
Latch,
|
||||
Timeout,
|
||||
Network(*mut Safekeeper, u32),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Config {
|
||||
/// Tenant and timeline id
|
||||
pub ttid: TenantTimelineId,
|
||||
@@ -242,6 +248,24 @@ impl Drop for Wrapper {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct StreamingCallback {
|
||||
wp: *mut WalProposer,
|
||||
}
|
||||
|
||||
impl StreamingCallback {
|
||||
pub fn new(wp: *mut WalProposer) -> StreamingCallback {
|
||||
StreamingCallback { wp }
|
||||
}
|
||||
|
||||
pub fn broadcast(&self, startpos: Lsn, endpos: Lsn) {
|
||||
unsafe { WalProposerBroadcast(self.wp, startpos.0, endpos.0) }
|
||||
}
|
||||
|
||||
pub fn poll(&self) {
|
||||
unsafe { WalProposerPoll(self.wp) }
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use core::panic;
|
||||
@@ -344,14 +368,13 @@ mod tests {
|
||||
fn conn_async_read(
|
||||
&self,
|
||||
_: &mut crate::bindings::Safekeeper,
|
||||
) -> (&[u8], crate::bindings::PGAsyncReadResult) {
|
||||
vec: &mut Vec<u8>,
|
||||
) -> crate::bindings::PGAsyncReadResult {
|
||||
println!("conn_async_read");
|
||||
let reply = self.next_safekeeper_reply();
|
||||
println!("conn_async_read result: {:?}", reply);
|
||||
(
|
||||
reply,
|
||||
crate::bindings::PGAsyncReadResult_PG_ASYNC_READ_SUCCESS,
|
||||
)
|
||||
vec.extend_from_slice(reply);
|
||||
crate::bindings::PGAsyncReadResult_PG_ASYNC_READ_SUCCESS
|
||||
}
|
||||
|
||||
fn conn_blocking_write(&self, _: &mut crate::bindings::Safekeeper, buf: &[u8]) -> bool {
|
||||
|
||||
@@ -234,7 +234,7 @@ impl DeletionHeader {
|
||||
let header_bytes = serde_json::to_vec(self).context("serialize deletion header")?;
|
||||
let header_path = conf.deletion_header_path();
|
||||
let temp_path = path_with_suffix_extension(&header_path, TEMP_SUFFIX);
|
||||
VirtualFile::crashsafe_overwrite(&header_path, &temp_path, &header_bytes)
|
||||
VirtualFile::crashsafe_overwrite(&header_path, &temp_path, header_bytes)
|
||||
.await
|
||||
.maybe_fatal_err("save deletion header")?;
|
||||
|
||||
@@ -325,7 +325,7 @@ impl DeletionList {
|
||||
let temp_path = path_with_suffix_extension(&path, TEMP_SUFFIX);
|
||||
|
||||
let bytes = serde_json::to_vec(self).expect("Failed to serialize deletion list");
|
||||
VirtualFile::crashsafe_overwrite(&path, &temp_path, &bytes)
|
||||
VirtualFile::crashsafe_overwrite(&path, &temp_path, bytes)
|
||||
.await
|
||||
.maybe_fatal_err("save deletion list")
|
||||
.map_err(Into::into)
|
||||
|
||||
@@ -422,6 +422,7 @@ async fn build_timeline_info_common(
|
||||
tenant::timeline::logical_size::Accuracy::Approximate => false,
|
||||
tenant::timeline::logical_size::Accuracy::Exact => true,
|
||||
},
|
||||
directory_entries_counts: timeline.get_directory_metrics().to_vec(),
|
||||
current_physical_size,
|
||||
current_logical_size_non_incremental: None,
|
||||
timeline_dir_layer_file_size_sum: None,
|
||||
@@ -488,7 +489,9 @@ async fn timeline_create_handler(
|
||||
let state = get_state(&request);
|
||||
|
||||
async {
|
||||
let tenant = state.tenant_manager.get_attached_tenant_shard(tenant_shard_id, false)?;
|
||||
let tenant = state
|
||||
.tenant_manager
|
||||
.get_attached_tenant_shard(tenant_shard_id, false)?;
|
||||
|
||||
tenant.wait_to_become_active(ACTIVE_TENANT_TIMEOUT).await?;
|
||||
|
||||
@@ -498,48 +501,62 @@ async fn timeline_create_handler(
|
||||
tracing::info!("bootstrapping");
|
||||
}
|
||||
|
||||
match tenant.create_timeline(
|
||||
new_timeline_id,
|
||||
request_data.ancestor_timeline_id.map(TimelineId::from),
|
||||
request_data.ancestor_start_lsn,
|
||||
request_data.pg_version.unwrap_or(crate::DEFAULT_PG_VERSION),
|
||||
request_data.existing_initdb_timeline_id,
|
||||
state.broker_client.clone(),
|
||||
&ctx,
|
||||
)
|
||||
.await {
|
||||
match tenant
|
||||
.create_timeline(
|
||||
new_timeline_id,
|
||||
request_data.ancestor_timeline_id,
|
||||
request_data.ancestor_start_lsn,
|
||||
request_data.pg_version.unwrap_or(crate::DEFAULT_PG_VERSION),
|
||||
request_data.existing_initdb_timeline_id,
|
||||
state.broker_client.clone(),
|
||||
&ctx,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(new_timeline) => {
|
||||
// Created. Construct a TimelineInfo for it.
|
||||
let timeline_info = build_timeline_info_common(&new_timeline, &ctx, tenant::timeline::GetLogicalSizePriority::User)
|
||||
.await
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
let timeline_info = build_timeline_info_common(
|
||||
&new_timeline,
|
||||
&ctx,
|
||||
tenant::timeline::GetLogicalSizePriority::User,
|
||||
)
|
||||
.await
|
||||
.map_err(ApiError::InternalServerError)?;
|
||||
json_response(StatusCode::CREATED, timeline_info)
|
||||
}
|
||||
Err(_) if tenant.cancel.is_cancelled() => {
|
||||
// In case we get some ugly error type during shutdown, cast it into a clean 503.
|
||||
json_response(StatusCode::SERVICE_UNAVAILABLE, HttpErrorBody::from_msg("Tenant shutting down".to_string()))
|
||||
}
|
||||
Err(tenant::CreateTimelineError::Conflict | tenant::CreateTimelineError::AlreadyCreating) => {
|
||||
json_response(StatusCode::CONFLICT, ())
|
||||
}
|
||||
Err(tenant::CreateTimelineError::AncestorLsn(err)) => {
|
||||
json_response(StatusCode::NOT_ACCEPTABLE, HttpErrorBody::from_msg(
|
||||
format!("{err:#}")
|
||||
))
|
||||
}
|
||||
Err(e @ tenant::CreateTimelineError::AncestorNotActive) => {
|
||||
json_response(StatusCode::SERVICE_UNAVAILABLE, HttpErrorBody::from_msg(e.to_string()))
|
||||
}
|
||||
Err(tenant::CreateTimelineError::ShuttingDown) => {
|
||||
json_response(StatusCode::SERVICE_UNAVAILABLE,HttpErrorBody::from_msg("tenant shutting down".to_string()))
|
||||
json_response(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
HttpErrorBody::from_msg("Tenant shutting down".to_string()),
|
||||
)
|
||||
}
|
||||
Err(
|
||||
tenant::CreateTimelineError::Conflict
|
||||
| tenant::CreateTimelineError::AlreadyCreating,
|
||||
) => json_response(StatusCode::CONFLICT, ()),
|
||||
Err(tenant::CreateTimelineError::AncestorLsn(err)) => json_response(
|
||||
StatusCode::NOT_ACCEPTABLE,
|
||||
HttpErrorBody::from_msg(format!("{err:#}")),
|
||||
),
|
||||
Err(e @ tenant::CreateTimelineError::AncestorNotActive) => json_response(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
HttpErrorBody::from_msg(e.to_string()),
|
||||
),
|
||||
Err(tenant::CreateTimelineError::ShuttingDown) => json_response(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
HttpErrorBody::from_msg("tenant shutting down".to_string()),
|
||||
),
|
||||
Err(tenant::CreateTimelineError::Other(err)) => Err(ApiError::InternalServerError(err)),
|
||||
}
|
||||
}
|
||||
.instrument(info_span!("timeline_create",
|
||||
tenant_id = %tenant_shard_id.tenant_id,
|
||||
shard_id = %tenant_shard_id.shard_slug(),
|
||||
timeline_id = %new_timeline_id, lsn=?request_data.ancestor_start_lsn, pg_version=?request_data.pg_version))
|
||||
timeline_id = %new_timeline_id,
|
||||
lsn=?request_data.ancestor_start_lsn,
|
||||
pg_version=?request_data.pg_version
|
||||
))
|
||||
.await
|
||||
}
|
||||
|
||||
|
||||
@@ -602,6 +602,15 @@ pub(crate) mod initial_logical_size {
|
||||
});
|
||||
}
|
||||
|
||||
static DIRECTORY_ENTRIES_COUNT: Lazy<UIntGaugeVec> = Lazy::new(|| {
|
||||
register_uint_gauge_vec!(
|
||||
"pageserver_directory_entries_count",
|
||||
"Sum of the entries in pageserver-stored directory listings",
|
||||
&["tenant_id", "shard_id", "timeline_id"]
|
||||
)
|
||||
.expect("failed to define a metric")
|
||||
});
|
||||
|
||||
pub(crate) static TENANT_STATE_METRIC: Lazy<UIntGaugeVec> = Lazy::new(|| {
|
||||
register_uint_gauge_vec!(
|
||||
"pageserver_tenant_states_count",
|
||||
@@ -1809,6 +1818,7 @@ pub(crate) struct TimelineMetrics {
|
||||
resident_physical_size_gauge: UIntGauge,
|
||||
/// copy of LayeredTimeline.current_logical_size
|
||||
pub current_logical_size_gauge: UIntGauge,
|
||||
pub directory_entries_count_gauge: Lazy<UIntGauge, Box<dyn Send + Fn() -> UIntGauge>>,
|
||||
pub num_persistent_files_created: IntCounter,
|
||||
pub persistent_bytes_written: IntCounter,
|
||||
pub evictions: IntCounter,
|
||||
@@ -1818,12 +1828,12 @@ pub(crate) struct TimelineMetrics {
|
||||
impl TimelineMetrics {
|
||||
pub fn new(
|
||||
tenant_shard_id: &TenantShardId,
|
||||
timeline_id: &TimelineId,
|
||||
timeline_id_raw: &TimelineId,
|
||||
evictions_with_low_residence_duration_builder: EvictionsWithLowResidenceDurationBuilder,
|
||||
) -> Self {
|
||||
let tenant_id = tenant_shard_id.tenant_id.to_string();
|
||||
let shard_id = format!("{}", tenant_shard_id.shard_slug());
|
||||
let timeline_id = timeline_id.to_string();
|
||||
let timeline_id = timeline_id_raw.to_string();
|
||||
let flush_time_histo = StorageTimeMetrics::new(
|
||||
StorageTimeOperation::LayerFlush,
|
||||
&tenant_id,
|
||||
@@ -1876,6 +1886,22 @@ impl TimelineMetrics {
|
||||
let current_logical_size_gauge = CURRENT_LOGICAL_SIZE
|
||||
.get_metric_with_label_values(&[&tenant_id, &shard_id, &timeline_id])
|
||||
.unwrap();
|
||||
// TODO use impl Trait syntax here once we have ability to use it: https://github.com/rust-lang/rust/issues/63065
|
||||
let directory_entries_count_gauge_closure = {
|
||||
let tenant_shard_id = *tenant_shard_id;
|
||||
let timeline_id_raw = *timeline_id_raw;
|
||||
move || {
|
||||
let tenant_id = tenant_shard_id.tenant_id.to_string();
|
||||
let shard_id = format!("{}", tenant_shard_id.shard_slug());
|
||||
let timeline_id = timeline_id_raw.to_string();
|
||||
let gauge: UIntGauge = DIRECTORY_ENTRIES_COUNT
|
||||
.get_metric_with_label_values(&[&tenant_id, &shard_id, &timeline_id])
|
||||
.unwrap();
|
||||
gauge
|
||||
}
|
||||
};
|
||||
let directory_entries_count_gauge: Lazy<UIntGauge, Box<dyn Send + Fn() -> UIntGauge>> =
|
||||
Lazy::new(Box::new(directory_entries_count_gauge_closure));
|
||||
let num_persistent_files_created = NUM_PERSISTENT_FILES_CREATED
|
||||
.get_metric_with_label_values(&[&tenant_id, &shard_id, &timeline_id])
|
||||
.unwrap();
|
||||
@@ -1902,6 +1928,7 @@ impl TimelineMetrics {
|
||||
last_record_gauge,
|
||||
resident_physical_size_gauge,
|
||||
current_logical_size_gauge,
|
||||
directory_entries_count_gauge,
|
||||
num_persistent_files_created,
|
||||
persistent_bytes_written,
|
||||
evictions,
|
||||
@@ -1944,6 +1971,9 @@ impl Drop for TimelineMetrics {
|
||||
RESIDENT_PHYSICAL_SIZE.remove_label_values(&[tenant_id, &shard_id, timeline_id]);
|
||||
}
|
||||
let _ = CURRENT_LOGICAL_SIZE.remove_label_values(&[tenant_id, &shard_id, timeline_id]);
|
||||
if let Some(metric) = Lazy::get(&DIRECTORY_ENTRIES_COUNT) {
|
||||
let _ = metric.remove_label_values(&[tenant_id, &shard_id, timeline_id]);
|
||||
}
|
||||
let _ =
|
||||
NUM_PERSISTENT_FILES_CREATED.remove_label_values(&[tenant_id, &shard_id, timeline_id]);
|
||||
let _ = PERSISTENT_BYTES_WRITTEN.remove_label_values(&[tenant_id, &shard_id, timeline_id]);
|
||||
|
||||
@@ -14,6 +14,7 @@ use crate::span::debug_assert_current_span_has_tenant_and_timeline_id_no_shard_i
|
||||
use crate::walrecord::NeonWalRecord;
|
||||
use anyhow::{ensure, Context};
|
||||
use bytes::{Buf, Bytes, BytesMut};
|
||||
use enum_map::Enum;
|
||||
use pageserver_api::key::{
|
||||
dbdir_key_range, is_rel_block_key, is_slru_block_key, rel_block_to_key, rel_dir_to_key,
|
||||
rel_key_range, rel_size_to_key, relmap_file_key, slru_block_to_key, slru_dir_to_key,
|
||||
@@ -155,6 +156,7 @@ impl Timeline {
|
||||
pending_updates: HashMap::new(),
|
||||
pending_deletions: Vec::new(),
|
||||
pending_nblocks: 0,
|
||||
pending_directory_entries: Vec::new(),
|
||||
lsn,
|
||||
}
|
||||
}
|
||||
@@ -868,6 +870,7 @@ pub struct DatadirModification<'a> {
|
||||
pending_updates: HashMap<Key, Vec<(Lsn, Value)>>,
|
||||
pending_deletions: Vec<(Range<Key>, Lsn)>,
|
||||
pending_nblocks: i64,
|
||||
pending_directory_entries: Vec<(DirectoryKind, usize)>,
|
||||
}
|
||||
|
||||
impl<'a> DatadirModification<'a> {
|
||||
@@ -899,6 +902,7 @@ impl<'a> DatadirModification<'a> {
|
||||
let buf = DbDirectory::ser(&DbDirectory {
|
||||
dbdirs: HashMap::new(),
|
||||
})?;
|
||||
self.pending_directory_entries.push((DirectoryKind::Db, 0));
|
||||
self.put(DBDIR_KEY, Value::Image(buf.into()));
|
||||
|
||||
// Create AuxFilesDirectory
|
||||
@@ -907,16 +911,24 @@ impl<'a> DatadirModification<'a> {
|
||||
let buf = TwoPhaseDirectory::ser(&TwoPhaseDirectory {
|
||||
xids: HashSet::new(),
|
||||
})?;
|
||||
self.pending_directory_entries
|
||||
.push((DirectoryKind::TwoPhase, 0));
|
||||
self.put(TWOPHASEDIR_KEY, Value::Image(buf.into()));
|
||||
|
||||
let buf: Bytes = SlruSegmentDirectory::ser(&SlruSegmentDirectory::default())?.into();
|
||||
let empty_dir = Value::Image(buf);
|
||||
self.put(slru_dir_to_key(SlruKind::Clog), empty_dir.clone());
|
||||
self.pending_directory_entries
|
||||
.push((DirectoryKind::SlruSegment(SlruKind::Clog), 0));
|
||||
self.put(
|
||||
slru_dir_to_key(SlruKind::MultiXactMembers),
|
||||
empty_dir.clone(),
|
||||
);
|
||||
self.pending_directory_entries
|
||||
.push((DirectoryKind::SlruSegment(SlruKind::Clog), 0));
|
||||
self.put(slru_dir_to_key(SlruKind::MultiXactOffsets), empty_dir);
|
||||
self.pending_directory_entries
|
||||
.push((DirectoryKind::SlruSegment(SlruKind::MultiXactOffsets), 0));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1017,6 +1029,7 @@ impl<'a> DatadirModification<'a> {
|
||||
let buf = RelDirectory::ser(&RelDirectory {
|
||||
rels: HashSet::new(),
|
||||
})?;
|
||||
self.pending_directory_entries.push((DirectoryKind::Rel, 0));
|
||||
self.put(
|
||||
rel_dir_to_key(spcnode, dbnode),
|
||||
Value::Image(Bytes::from(buf)),
|
||||
@@ -1039,6 +1052,8 @@ impl<'a> DatadirModification<'a> {
|
||||
if !dir.xids.insert(xid) {
|
||||
anyhow::bail!("twophase file for xid {} already exists", xid);
|
||||
}
|
||||
self.pending_directory_entries
|
||||
.push((DirectoryKind::TwoPhase, dir.xids.len()));
|
||||
self.put(
|
||||
TWOPHASEDIR_KEY,
|
||||
Value::Image(Bytes::from(TwoPhaseDirectory::ser(&dir)?)),
|
||||
@@ -1074,6 +1089,8 @@ impl<'a> DatadirModification<'a> {
|
||||
let mut dir = DbDirectory::des(&buf)?;
|
||||
if dir.dbdirs.remove(&(spcnode, dbnode)).is_some() {
|
||||
let buf = DbDirectory::ser(&dir)?;
|
||||
self.pending_directory_entries
|
||||
.push((DirectoryKind::Db, dir.dbdirs.len()));
|
||||
self.put(DBDIR_KEY, Value::Image(buf.into()));
|
||||
} else {
|
||||
warn!(
|
||||
@@ -1111,6 +1128,8 @@ impl<'a> DatadirModification<'a> {
|
||||
// Didn't exist. Update dbdir
|
||||
dbdir.dbdirs.insert((rel.spcnode, rel.dbnode), false);
|
||||
let buf = DbDirectory::ser(&dbdir).context("serialize db")?;
|
||||
self.pending_directory_entries
|
||||
.push((DirectoryKind::Db, dbdir.dbdirs.len()));
|
||||
self.put(DBDIR_KEY, Value::Image(buf.into()));
|
||||
|
||||
// and create the RelDirectory
|
||||
@@ -1125,6 +1144,10 @@ impl<'a> DatadirModification<'a> {
|
||||
if !rel_dir.rels.insert((rel.relnode, rel.forknum)) {
|
||||
return Err(RelationError::AlreadyExists);
|
||||
}
|
||||
|
||||
self.pending_directory_entries
|
||||
.push((DirectoryKind::Rel, rel_dir.rels.len()));
|
||||
|
||||
self.put(
|
||||
rel_dir_key,
|
||||
Value::Image(Bytes::from(
|
||||
@@ -1216,6 +1239,9 @@ impl<'a> DatadirModification<'a> {
|
||||
let buf = self.get(dir_key, ctx).await?;
|
||||
let mut dir = RelDirectory::des(&buf)?;
|
||||
|
||||
self.pending_directory_entries
|
||||
.push((DirectoryKind::Rel, dir.rels.len()));
|
||||
|
||||
if dir.rels.remove(&(rel.relnode, rel.forknum)) {
|
||||
self.put(dir_key, Value::Image(Bytes::from(RelDirectory::ser(&dir)?)));
|
||||
} else {
|
||||
@@ -1251,6 +1277,8 @@ impl<'a> DatadirModification<'a> {
|
||||
if !dir.segments.insert(segno) {
|
||||
anyhow::bail!("slru segment {kind:?}/{segno} already exists");
|
||||
}
|
||||
self.pending_directory_entries
|
||||
.push((DirectoryKind::SlruSegment(kind), dir.segments.len()));
|
||||
self.put(
|
||||
dir_key,
|
||||
Value::Image(Bytes::from(SlruSegmentDirectory::ser(&dir)?)),
|
||||
@@ -1295,6 +1323,8 @@ impl<'a> DatadirModification<'a> {
|
||||
if !dir.segments.remove(&segno) {
|
||||
warn!("slru segment {:?}/{} does not exist", kind, segno);
|
||||
}
|
||||
self.pending_directory_entries
|
||||
.push((DirectoryKind::SlruSegment(kind), dir.segments.len()));
|
||||
self.put(
|
||||
dir_key,
|
||||
Value::Image(Bytes::from(SlruSegmentDirectory::ser(&dir)?)),
|
||||
@@ -1325,6 +1355,8 @@ impl<'a> DatadirModification<'a> {
|
||||
if !dir.xids.remove(&xid) {
|
||||
warn!("twophase file for xid {} does not exist", xid);
|
||||
}
|
||||
self.pending_directory_entries
|
||||
.push((DirectoryKind::TwoPhase, dir.xids.len()));
|
||||
self.put(
|
||||
TWOPHASEDIR_KEY,
|
||||
Value::Image(Bytes::from(TwoPhaseDirectory::ser(&dir)?)),
|
||||
@@ -1340,6 +1372,8 @@ impl<'a> DatadirModification<'a> {
|
||||
let buf = AuxFilesDirectory::ser(&AuxFilesDirectory {
|
||||
files: HashMap::new(),
|
||||
})?;
|
||||
self.pending_directory_entries
|
||||
.push((DirectoryKind::AuxFiles, 0));
|
||||
self.put(AUX_FILES_KEY, Value::Image(Bytes::from(buf)));
|
||||
Ok(())
|
||||
}
|
||||
@@ -1366,6 +1400,9 @@ impl<'a> DatadirModification<'a> {
|
||||
} else {
|
||||
dir.files.insert(path, Bytes::copy_from_slice(content));
|
||||
}
|
||||
self.pending_directory_entries
|
||||
.push((DirectoryKind::AuxFiles, dir.files.len()));
|
||||
|
||||
self.put(
|
||||
AUX_FILES_KEY,
|
||||
Value::Image(Bytes::from(
|
||||
@@ -1427,6 +1464,10 @@ impl<'a> DatadirModification<'a> {
|
||||
self.pending_nblocks = 0;
|
||||
}
|
||||
|
||||
for (kind, count) in std::mem::take(&mut self.pending_directory_entries) {
|
||||
writer.update_directory_entries_count(kind, count as u64);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1464,6 +1505,10 @@ impl<'a> DatadirModification<'a> {
|
||||
writer.update_current_logical_size(pending_nblocks * i64::from(BLCKSZ));
|
||||
}
|
||||
|
||||
for (kind, count) in std::mem::take(&mut self.pending_directory_entries) {
|
||||
writer.update_directory_entries_count(kind, count as u64);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1588,6 +1633,23 @@ struct SlruSegmentDirectory {
|
||||
segments: HashSet<u32>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Debug, enum_map::Enum)]
|
||||
#[repr(u8)]
|
||||
pub(crate) enum DirectoryKind {
|
||||
Db,
|
||||
TwoPhase,
|
||||
Rel,
|
||||
AuxFiles,
|
||||
SlruSegment(SlruKind),
|
||||
}
|
||||
|
||||
impl DirectoryKind {
|
||||
pub(crate) const KINDS_NUM: usize = <DirectoryKind as Enum>::LENGTH;
|
||||
pub(crate) fn offset(&self) -> usize {
|
||||
self.into_usize()
|
||||
}
|
||||
}
|
||||
|
||||
static ZERO_PAGE: Bytes = Bytes::from_static(&[0u8; BLCKSZ as usize]);
|
||||
|
||||
#[allow(clippy::bool_assert_comparison)]
|
||||
|
||||
@@ -644,10 +644,10 @@ impl Tenant {
|
||||
|
||||
// The attach task will carry a GateGuard, so that shutdown() reliably waits for it to drop out if
|
||||
// we shut down while attaching.
|
||||
let Ok(attach_gate_guard) = tenant.gate.enter() else {
|
||||
// We just created the Tenant: nothing else can have shut it down yet
|
||||
unreachable!();
|
||||
};
|
||||
let attach_gate_guard = tenant
|
||||
.gate
|
||||
.enter()
|
||||
.expect("We just created the Tenant: nothing else can have shut it down yet");
|
||||
|
||||
// Do all the hard work in the background
|
||||
let tenant_clone = Arc::clone(&tenant);
|
||||
@@ -755,36 +755,27 @@ impl Tenant {
|
||||
AttachType::Normal
|
||||
};
|
||||
|
||||
let preload_timer = TENANT.preload.start_timer();
|
||||
let preload = match mode {
|
||||
SpawnMode::Create => {
|
||||
// Don't count the skipped preload into the histogram of preload durations
|
||||
preload_timer.stop_and_discard();
|
||||
let preload = match (&mode, &remote_storage) {
|
||||
(SpawnMode::Create, _) => {
|
||||
None
|
||||
},
|
||||
SpawnMode::Normal => {
|
||||
match &remote_storage {
|
||||
Some(remote_storage) => Some(
|
||||
match tenant_clone
|
||||
.preload(remote_storage, task_mgr::shutdown_token())
|
||||
.instrument(
|
||||
tracing::info_span!(parent: None, "attach_preload", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug()),
|
||||
)
|
||||
.await {
|
||||
Ok(p) => {
|
||||
preload_timer.observe_duration();
|
||||
p
|
||||
}
|
||||
,
|
||||
Err(e) => {
|
||||
make_broken(&tenant_clone, anyhow::anyhow!(e));
|
||||
return Ok(());
|
||||
}
|
||||
},
|
||||
),
|
||||
None => None,
|
||||
(SpawnMode::Normal, Some(remote_storage)) => {
|
||||
let _preload_timer = TENANT.preload.start_timer();
|
||||
let res = tenant_clone
|
||||
.preload(remote_storage, task_mgr::shutdown_token())
|
||||
.await;
|
||||
match res {
|
||||
Ok(p) => Some(p),
|
||||
Err(e) => {
|
||||
make_broken(&tenant_clone, anyhow::anyhow!(e));
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
(SpawnMode::Normal, None) => {
|
||||
let _preload_timer = TENANT.preload.start_timer();
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
// Remote preload is complete.
|
||||
@@ -820,36 +811,37 @@ impl Tenant {
|
||||
info!("ready for backgound jobs barrier");
|
||||
}
|
||||
|
||||
match DeleteTenantFlow::resume_from_attach(
|
||||
let deleted = DeleteTenantFlow::resume_from_attach(
|
||||
deletion,
|
||||
&tenant_clone,
|
||||
preload,
|
||||
tenants,
|
||||
&ctx,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Err(err) => {
|
||||
make_broken(&tenant_clone, anyhow::anyhow!(err));
|
||||
return Ok(());
|
||||
}
|
||||
Ok(()) => return Ok(()),
|
||||
.await;
|
||||
|
||||
if let Err(e) = deleted {
|
||||
make_broken(&tenant_clone, anyhow::anyhow!(e));
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// We will time the duration of the attach phase unless this is a creation (attach will do no work)
|
||||
let attach_timer = match mode {
|
||||
SpawnMode::Create => None,
|
||||
SpawnMode::Normal => {Some(TENANT.attach.start_timer())}
|
||||
let attached = {
|
||||
let _attach_timer = match mode {
|
||||
SpawnMode::Create => None,
|
||||
SpawnMode::Normal => {Some(TENANT.attach.start_timer())}
|
||||
};
|
||||
tenant_clone.attach(preload, mode, &ctx).await
|
||||
};
|
||||
match tenant_clone.attach(preload, mode, &ctx).await {
|
||||
|
||||
match attached {
|
||||
Ok(()) => {
|
||||
info!("attach finished, activating");
|
||||
if let Some(t)= attach_timer {t.observe_duration();}
|
||||
tenant_clone.activate(broker_client, None, &ctx);
|
||||
}
|
||||
Err(e) => {
|
||||
if let Some(t)= attach_timer {t.observe_duration();}
|
||||
make_broken(&tenant_clone, anyhow::anyhow!(e));
|
||||
}
|
||||
}
|
||||
@@ -862,34 +854,26 @@ impl Tenant {
|
||||
// logical size calculations: if logical size calculation semaphore is saturated,
|
||||
// then warmup will wait for that before proceeding to the next tenant.
|
||||
if let AttachType::Warmup(_permit) = attach_type {
|
||||
let mut futs = FuturesUnordered::new();
|
||||
let timelines: Vec<_> = tenant_clone.timelines.lock().unwrap().values().cloned().collect();
|
||||
for t in timelines {
|
||||
futs.push(t.await_initial_logical_size())
|
||||
}
|
||||
let mut futs: FuturesUnordered<_> = tenant_clone.timelines.lock().unwrap().values().cloned().map(|t| t.await_initial_logical_size()).collect();
|
||||
tracing::info!("Waiting for initial logical sizes while warming up...");
|
||||
while futs.next().await.is_some() {
|
||||
|
||||
}
|
||||
while futs.next().await.is_some() {}
|
||||
tracing::info!("Warm-up complete");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
.instrument({
|
||||
let span = tracing::info_span!(parent: None, "attach", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), gen=?generation);
|
||||
span.follows_from(Span::current());
|
||||
span
|
||||
}),
|
||||
.instrument(tracing::info_span!(parent: None, "attach", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), gen=?generation)),
|
||||
);
|
||||
Ok(tenant)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) async fn preload(
|
||||
self: &Arc<Tenant>,
|
||||
remote_storage: &GenericRemoteStorage,
|
||||
cancel: CancellationToken,
|
||||
) -> anyhow::Result<TenantPreload> {
|
||||
span::debug_assert_current_span_has_tenant_id();
|
||||
// Get list of remote timelines
|
||||
// download index files for every tenant timeline
|
||||
info!("listing remote timelines");
|
||||
@@ -2896,7 +2880,7 @@ impl Tenant {
|
||||
let config_path = config_path.to_owned();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
Handle::current().block_on(async move {
|
||||
let conf_content = conf_content.as_bytes();
|
||||
let conf_content = conf_content.into_bytes();
|
||||
VirtualFile::crashsafe_overwrite(&config_path, &temp_path, conf_content)
|
||||
.await
|
||||
.with_context(|| {
|
||||
@@ -2933,7 +2917,7 @@ impl Tenant {
|
||||
let target_config_path = target_config_path.to_owned();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
Handle::current().block_on(async move {
|
||||
let conf_content = conf_content.as_bytes();
|
||||
let conf_content = conf_content.into_bytes();
|
||||
VirtualFile::crashsafe_overwrite(&target_config_path, &temp_path, conf_content)
|
||||
.await
|
||||
.with_context(|| {
|
||||
@@ -3982,6 +3966,8 @@ pub(crate) mod harness {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[derive(Debug)]
|
||||
enum LoadMode {
|
||||
Local,
|
||||
Remote,
|
||||
@@ -4064,7 +4050,7 @@ pub(crate) mod harness {
|
||||
info_span!("TenantHarness", tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug())
|
||||
}
|
||||
|
||||
pub async fn load(&self) -> (Arc<Tenant>, RequestContext) {
|
||||
pub(crate) async fn load(&self) -> (Arc<Tenant>, RequestContext) {
|
||||
let ctx = RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error);
|
||||
(
|
||||
self.try_load(&ctx)
|
||||
@@ -4074,31 +4060,31 @@ pub(crate) mod harness {
|
||||
)
|
||||
}
|
||||
|
||||
fn remote_empty(&self) -> bool {
|
||||
let tenant_path = self.conf.tenant_path(&self.tenant_shard_id);
|
||||
let remote_tenant_dir = self
|
||||
.remote_fs_dir
|
||||
.join(tenant_path.strip_prefix(&self.conf.workdir).unwrap());
|
||||
if std::fs::metadata(&remote_tenant_dir).is_err() {
|
||||
return true;
|
||||
}
|
||||
|
||||
match std::fs::read_dir(remote_tenant_dir)
|
||||
.unwrap()
|
||||
.flatten()
|
||||
.next()
|
||||
{
|
||||
Some(entry) => {
|
||||
tracing::debug!(
|
||||
"remote_empty: not empty, found file {}",
|
||||
entry.file_name().to_string_lossy(),
|
||||
);
|
||||
false
|
||||
}
|
||||
None => true,
|
||||
}
|
||||
/// For tests that specifically want to exercise the local load path, which does
|
||||
/// not use remote storage.
|
||||
pub(crate) async fn try_load_local(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<Arc<Tenant>> {
|
||||
self.do_try_load(ctx, LoadMode::Local).await
|
||||
}
|
||||
|
||||
/// The 'load' in this function is either a local load or a normal attachment,
|
||||
pub(crate) async fn try_load(&self, ctx: &RequestContext) -> anyhow::Result<Arc<Tenant>> {
|
||||
// If we have nothing in remote storage, must use load_local instead of attach: attach
|
||||
// will error out if there are no timelines.
|
||||
//
|
||||
// See https://github.com/neondatabase/neon/issues/5456 for how we will eliminate
|
||||
// this weird state of a Tenant which exists but doesn't have any timelines.
|
||||
let mode = match self.remote_empty() {
|
||||
true => LoadMode::Local,
|
||||
false => LoadMode::Remote,
|
||||
};
|
||||
|
||||
self.do_try_load(ctx, mode).await
|
||||
}
|
||||
|
||||
#[instrument(skip_all, fields(tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug(), ?mode))]
|
||||
async fn do_try_load(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
@@ -4125,20 +4111,13 @@ pub(crate) mod harness {
|
||||
|
||||
match mode {
|
||||
LoadMode::Local => {
|
||||
tenant
|
||||
.load_local(ctx)
|
||||
.instrument(info_span!("try_load", tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug()))
|
||||
.await?;
|
||||
tenant.load_local(ctx).await?;
|
||||
}
|
||||
LoadMode::Remote => {
|
||||
let preload = tenant
|
||||
.preload(&self.remote_storage, CancellationToken::new())
|
||||
.instrument(info_span!("try_load_preload", tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug()))
|
||||
.await?;
|
||||
tenant
|
||||
.attach(Some(preload), SpawnMode::Normal, ctx)
|
||||
.instrument(info_span!("try_load", tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug()))
|
||||
.await?;
|
||||
tenant.attach(Some(preload), SpawnMode::Normal, ctx).await?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4149,25 +4128,29 @@ pub(crate) mod harness {
|
||||
Ok(tenant)
|
||||
}
|
||||
|
||||
/// For tests that specifically want to exercise the local load path, which does
|
||||
/// not use remote storage.
|
||||
pub async fn try_load_local(&self, ctx: &RequestContext) -> anyhow::Result<Arc<Tenant>> {
|
||||
self.do_try_load(ctx, LoadMode::Local).await
|
||||
}
|
||||
fn remote_empty(&self) -> bool {
|
||||
let tenant_path = self.conf.tenant_path(&self.tenant_shard_id);
|
||||
let remote_tenant_dir = self
|
||||
.remote_fs_dir
|
||||
.join(tenant_path.strip_prefix(&self.conf.workdir).unwrap());
|
||||
if std::fs::metadata(&remote_tenant_dir).is_err() {
|
||||
return true;
|
||||
}
|
||||
|
||||
/// The 'load' in this function is either a local load or a normal attachment,
|
||||
pub async fn try_load(&self, ctx: &RequestContext) -> anyhow::Result<Arc<Tenant>> {
|
||||
// If we have nothing in remote storage, must use load_local instead of attach: attach
|
||||
// will error out if there are no timelines.
|
||||
//
|
||||
// See https://github.com/neondatabase/neon/issues/5456 for how we will eliminate
|
||||
// this weird state of a Tenant which exists but doesn't have any timelines.
|
||||
let mode = match self.remote_empty() {
|
||||
true => LoadMode::Local,
|
||||
false => LoadMode::Remote,
|
||||
};
|
||||
|
||||
self.do_try_load(ctx, mode).await
|
||||
match std::fs::read_dir(remote_tenant_dir)
|
||||
.unwrap()
|
||||
.flatten()
|
||||
.next()
|
||||
{
|
||||
Some(entry) => {
|
||||
tracing::debug!(
|
||||
"remote_empty: not empty, found file {}",
|
||||
entry.file_name().to_string_lossy(),
|
||||
);
|
||||
false
|
||||
}
|
||||
None => true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn timeline_path(&self, timeline_id: &TimelineId) -> Utf8PathBuf {
|
||||
|
||||
@@ -11,6 +11,9 @@
|
||||
//! len < 128: 0XXXXXXX
|
||||
//! len >= 128: 1XXXXXXX XXXXXXXX XXXXXXXX XXXXXXXX
|
||||
//!
|
||||
use bytes::{BufMut, BytesMut};
|
||||
use tokio_epoll_uring::{BoundedBuf, Slice};
|
||||
|
||||
use crate::context::RequestContext;
|
||||
use crate::page_cache::PAGE_SZ;
|
||||
use crate::tenant::block_io::BlockCursor;
|
||||
@@ -100,6 +103,8 @@ pub struct BlobWriter<const BUFFERED: bool> {
|
||||
offset: u64,
|
||||
/// A buffer to save on write calls, only used if BUFFERED=true
|
||||
buf: Vec<u8>,
|
||||
/// We do tiny writes for the length headers; they need to be in an owned buffer;
|
||||
io_buf: Option<BytesMut>,
|
||||
}
|
||||
|
||||
impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
|
||||
@@ -108,6 +113,7 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
|
||||
inner,
|
||||
offset: start_offset,
|
||||
buf: Vec::with_capacity(Self::CAPACITY),
|
||||
io_buf: Some(BytesMut::new()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,21 +123,31 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
|
||||
|
||||
const CAPACITY: usize = if BUFFERED { PAGE_SZ } else { 0 };
|
||||
|
||||
#[inline(always)]
|
||||
/// Writes the given buffer directly to the underlying `VirtualFile`.
|
||||
/// You need to make sure that the internal buffer is empty, otherwise
|
||||
/// data will be written in wrong order.
|
||||
async fn write_all_unbuffered(&mut self, src_buf: &[u8]) -> Result<(), Error> {
|
||||
self.inner.write_all(src_buf).await?;
|
||||
self.offset += src_buf.len() as u64;
|
||||
Ok(())
|
||||
#[inline(always)]
|
||||
async fn write_all_unbuffered<B: BoundedBuf>(
|
||||
&mut self,
|
||||
src_buf: B,
|
||||
) -> (B::Buf, Result<(), Error>) {
|
||||
let (src_buf, res) = self.inner.write_all(src_buf).await;
|
||||
let nbytes = match res {
|
||||
Ok(nbytes) => nbytes,
|
||||
Err(e) => return (src_buf, Err(e)),
|
||||
};
|
||||
self.offset += nbytes as u64;
|
||||
(src_buf, Ok(()))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
/// Flushes the internal buffer to the underlying `VirtualFile`.
|
||||
pub async fn flush_buffer(&mut self) -> Result<(), Error> {
|
||||
self.inner.write_all(&self.buf).await?;
|
||||
self.buf.clear();
|
||||
let buf = std::mem::take(&mut self.buf);
|
||||
let (mut buf, res) = self.inner.write_all(buf).await;
|
||||
res?;
|
||||
buf.clear();
|
||||
self.buf = buf;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -146,62 +162,91 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
|
||||
}
|
||||
|
||||
/// Internal, possibly buffered, write function
|
||||
async fn write_all(&mut self, mut src_buf: &[u8]) -> Result<(), Error> {
|
||||
async fn write_all<B: BoundedBuf>(&mut self, src_buf: B) -> (B::Buf, Result<(), Error>) {
|
||||
if !BUFFERED {
|
||||
assert!(self.buf.is_empty());
|
||||
self.write_all_unbuffered(src_buf).await?;
|
||||
return Ok(());
|
||||
return self.write_all_unbuffered(src_buf).await;
|
||||
}
|
||||
let remaining = Self::CAPACITY - self.buf.len();
|
||||
let src_buf_len = src_buf.bytes_init();
|
||||
if src_buf_len == 0 {
|
||||
return (Slice::into_inner(src_buf.slice_full()), Ok(()));
|
||||
}
|
||||
let mut src_buf = src_buf.slice(0..src_buf_len);
|
||||
// First try to copy as much as we can into the buffer
|
||||
if remaining > 0 {
|
||||
let copied = self.write_into_buffer(src_buf);
|
||||
src_buf = &src_buf[copied..];
|
||||
let copied = self.write_into_buffer(&src_buf);
|
||||
src_buf = src_buf.slice(copied..);
|
||||
}
|
||||
// Then, if the buffer is full, flush it out
|
||||
if self.buf.len() == Self::CAPACITY {
|
||||
self.flush_buffer().await?;
|
||||
if let Err(e) = self.flush_buffer().await {
|
||||
return (Slice::into_inner(src_buf), Err(e));
|
||||
}
|
||||
}
|
||||
// Finally, write the tail of src_buf:
|
||||
// If it wholly fits into the buffer without
|
||||
// completely filling it, then put it there.
|
||||
// If not, write it out directly.
|
||||
if !src_buf.is_empty() {
|
||||
let src_buf = if !src_buf.is_empty() {
|
||||
assert_eq!(self.buf.len(), 0);
|
||||
if src_buf.len() < Self::CAPACITY {
|
||||
let copied = self.write_into_buffer(src_buf);
|
||||
let copied = self.write_into_buffer(&src_buf);
|
||||
// We just verified above that src_buf fits into our internal buffer.
|
||||
assert_eq!(copied, src_buf.len());
|
||||
Slice::into_inner(src_buf)
|
||||
} else {
|
||||
self.write_all_unbuffered(src_buf).await?;
|
||||
let (src_buf, res) = self.write_all_unbuffered(src_buf).await;
|
||||
if let Err(e) = res {
|
||||
return (src_buf, Err(e));
|
||||
}
|
||||
src_buf
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
} else {
|
||||
Slice::into_inner(src_buf)
|
||||
};
|
||||
(src_buf, Ok(()))
|
||||
}
|
||||
|
||||
/// Write a blob of data. Returns the offset that it was written to,
|
||||
/// which can be used to retrieve the data later.
|
||||
pub async fn write_blob(&mut self, srcbuf: &[u8]) -> Result<u64, Error> {
|
||||
pub async fn write_blob<B: BoundedBuf>(&mut self, srcbuf: B) -> (B::Buf, Result<u64, Error>) {
|
||||
let offset = self.offset;
|
||||
|
||||
if srcbuf.len() < 128 {
|
||||
// Short blob. Write a 1-byte length header
|
||||
let len_buf = srcbuf.len() as u8;
|
||||
self.write_all(&[len_buf]).await?;
|
||||
} else {
|
||||
// Write a 4-byte length header
|
||||
if srcbuf.len() > 0x7fff_ffff {
|
||||
return Err(Error::new(
|
||||
ErrorKind::Other,
|
||||
format!("blob too large ({} bytes)", srcbuf.len()),
|
||||
));
|
||||
let len = srcbuf.bytes_init();
|
||||
|
||||
let mut io_buf = self.io_buf.take().expect("we always put it back below");
|
||||
io_buf.clear();
|
||||
let (io_buf, hdr_res) = async {
|
||||
if len < 128 {
|
||||
// Short blob. Write a 1-byte length header
|
||||
io_buf.put_u8(len as u8);
|
||||
self.write_all(io_buf).await
|
||||
} else {
|
||||
// Write a 4-byte length header
|
||||
if len > 0x7fff_ffff {
|
||||
return (
|
||||
io_buf,
|
||||
Err(Error::new(
|
||||
ErrorKind::Other,
|
||||
format!("blob too large ({} bytes)", len),
|
||||
)),
|
||||
);
|
||||
}
|
||||
let mut len_buf = (len as u32).to_be_bytes();
|
||||
len_buf[0] |= 0x80;
|
||||
io_buf.extend_from_slice(&len_buf[..]);
|
||||
self.write_all(io_buf).await
|
||||
}
|
||||
let mut len_buf = ((srcbuf.len()) as u32).to_be_bytes();
|
||||
len_buf[0] |= 0x80;
|
||||
self.write_all(&len_buf).await?;
|
||||
}
|
||||
self.write_all(srcbuf).await?;
|
||||
Ok(offset)
|
||||
.await;
|
||||
self.io_buf = Some(io_buf);
|
||||
match hdr_res {
|
||||
Ok(_) => (),
|
||||
Err(e) => return (Slice::into_inner(srcbuf.slice(..)), Err(e)),
|
||||
}
|
||||
let (srcbuf, res) = self.write_all(srcbuf).await;
|
||||
(srcbuf, res.map(|_| offset))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -248,12 +293,14 @@ mod tests {
|
||||
let file = VirtualFile::create(pathbuf.as_path()).await?;
|
||||
let mut wtr = BlobWriter::<BUFFERED>::new(file, 0);
|
||||
for blob in blobs.iter() {
|
||||
let offs = wtr.write_blob(blob).await?;
|
||||
let (_, res) = wtr.write_blob(blob.clone()).await;
|
||||
let offs = res?;
|
||||
offsets.push(offs);
|
||||
}
|
||||
// Write out one page worth of zeros so that we can
|
||||
// read again with read_blk
|
||||
let offs = wtr.write_blob(&vec![0; PAGE_SZ]).await?;
|
||||
let (_, res) = wtr.write_blob(vec![0; PAGE_SZ]).await;
|
||||
let offs = res?;
|
||||
println!("Writing final blob at offs={offs}");
|
||||
wtr.flush_buffer().await?;
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use pageserver_api::{models::TenantState, shard::TenantShardId};
|
||||
use remote_storage::{GenericRemoteStorage, RemotePath};
|
||||
use tokio::sync::OwnedMutexGuard;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, instrument, Instrument, Span};
|
||||
use tracing::{error, instrument, Instrument};
|
||||
|
||||
use utils::{backoff, completion, crashsafe, fs_ext, id::TimelineId};
|
||||
|
||||
@@ -496,11 +496,7 @@ impl DeleteTenantFlow {
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
.instrument({
|
||||
let span = tracing::info_span!(parent: None, "delete_tenant", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug());
|
||||
span.follows_from(Span::current());
|
||||
span
|
||||
}),
|
||||
.instrument(tracing::info_span!(parent: None, "delete_tenant", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug())),
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -279,7 +279,7 @@ pub async fn save_metadata(
|
||||
let path = conf.metadata_path(tenant_shard_id, timeline_id);
|
||||
let temp_path = path_with_suffix_extension(&path, TEMP_FILE_SUFFIX);
|
||||
let metadata_bytes = data.to_bytes().context("serialize metadata")?;
|
||||
VirtualFile::crashsafe_overwrite(&path, &temp_path, &metadata_bytes)
|
||||
VirtualFile::crashsafe_overwrite(&path, &temp_path, metadata_bytes)
|
||||
.await
|
||||
.context("write metadata")?;
|
||||
Ok(())
|
||||
|
||||
@@ -1700,23 +1700,6 @@ impl RemoteTimelineClient {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_layers_metadata(
|
||||
&self,
|
||||
layers: Vec<LayerFileName>,
|
||||
) -> anyhow::Result<Vec<Option<LayerFileMetadata>>> {
|
||||
let q = self.upload_queue.lock().unwrap();
|
||||
let q = match &*q {
|
||||
UploadQueue::Stopped(_) | UploadQueue::Uninitialized => {
|
||||
anyhow::bail!("queue is in state {}", q.as_str())
|
||||
}
|
||||
UploadQueue::Initialized(inner) => inner,
|
||||
};
|
||||
|
||||
let decorated = layers.into_iter().map(|l| q.latest_files.get(&l).cloned());
|
||||
|
||||
Ok(decorated.collect())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remote_timelines_path(tenant_shard_id: &TenantShardId) -> RemotePath {
|
||||
|
||||
@@ -486,7 +486,7 @@ impl<'a> TenantDownloader<'a> {
|
||||
let heatmap_path_bg = heatmap_path.clone();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
tokio::runtime::Handle::current().block_on(async move {
|
||||
VirtualFile::crashsafe_overwrite(&heatmap_path_bg, &temp_path, &heatmap_bytes).await
|
||||
VirtualFile::crashsafe_overwrite(&heatmap_path_bg, &temp_path, heatmap_bytes).await
|
||||
})
|
||||
})
|
||||
.await
|
||||
|
||||
@@ -257,6 +257,12 @@ impl LayerAccessStats {
|
||||
ret
|
||||
}
|
||||
|
||||
/// Get the latest access timestamp, falling back to latest residence event, further falling
|
||||
/// back to `SystemTime::now` for a usable timestamp for eviction.
|
||||
pub(crate) fn latest_activity_or_now(&self) -> SystemTime {
|
||||
self.latest_activity().unwrap_or_else(SystemTime::now)
|
||||
}
|
||||
|
||||
/// Get the latest access timestamp, falling back to latest residence event.
|
||||
///
|
||||
/// This function can only return `None` if there has not yet been a call to the
|
||||
@@ -271,7 +277,7 @@ impl LayerAccessStats {
|
||||
/// that that type can only be produced by inserting into the layer map.
|
||||
///
|
||||
/// [`record_residence_event`]: Self::record_residence_event
|
||||
pub(crate) fn latest_activity(&self) -> Option<SystemTime> {
|
||||
fn latest_activity(&self) -> Option<SystemTime> {
|
||||
let locked = self.0.lock().unwrap();
|
||||
let inner = &locked.for_eviction_policy;
|
||||
match inner.last_accesses.recent() {
|
||||
|
||||
@@ -416,27 +416,31 @@ impl DeltaLayerWriterInner {
|
||||
/// The values must be appended in key, lsn order.
|
||||
///
|
||||
async fn put_value(&mut self, key: Key, lsn: Lsn, val: Value) -> anyhow::Result<()> {
|
||||
self.put_value_bytes(key, lsn, &Value::ser(&val)?, val.will_init())
|
||||
.await
|
||||
let (_, res) = self
|
||||
.put_value_bytes(key, lsn, Value::ser(&val)?, val.will_init())
|
||||
.await;
|
||||
res
|
||||
}
|
||||
|
||||
async fn put_value_bytes(
|
||||
&mut self,
|
||||
key: Key,
|
||||
lsn: Lsn,
|
||||
val: &[u8],
|
||||
val: Vec<u8>,
|
||||
will_init: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
) -> (Vec<u8>, anyhow::Result<()>) {
|
||||
assert!(self.lsn_range.start <= lsn);
|
||||
|
||||
let off = self.blob_writer.write_blob(val).await?;
|
||||
let (val, res) = self.blob_writer.write_blob(val).await;
|
||||
let off = match res {
|
||||
Ok(off) => off,
|
||||
Err(e) => return (val, Err(anyhow::anyhow!(e))),
|
||||
};
|
||||
|
||||
let blob_ref = BlobRef::new(off, will_init);
|
||||
|
||||
let delta_key = DeltaKey::from_key_lsn(&key, lsn);
|
||||
self.tree.append(&delta_key.0, blob_ref.0)?;
|
||||
|
||||
Ok(())
|
||||
let res = self.tree.append(&delta_key.0, blob_ref.0);
|
||||
(val, res.map_err(|e| anyhow::anyhow!(e)))
|
||||
}
|
||||
|
||||
fn size(&self) -> u64 {
|
||||
@@ -457,7 +461,8 @@ impl DeltaLayerWriterInner {
|
||||
file.seek(SeekFrom::Start(index_start_blk as u64 * PAGE_SZ as u64))
|
||||
.await?;
|
||||
for buf in block_buf.blocks {
|
||||
file.write_all(buf.as_ref()).await?;
|
||||
let (_buf, res) = file.write_all(buf).await;
|
||||
res?;
|
||||
}
|
||||
assert!(self.lsn_range.start < self.lsn_range.end);
|
||||
// Fill in the summary on blk 0
|
||||
@@ -472,17 +477,12 @@ impl DeltaLayerWriterInner {
|
||||
index_root_blk,
|
||||
};
|
||||
|
||||
let mut buf = smallvec::SmallVec::<[u8; PAGE_SZ]>::new();
|
||||
let mut buf = Vec::with_capacity(PAGE_SZ);
|
||||
// TODO: could use smallvec here but it's a pain with Slice<T>
|
||||
Summary::ser_into(&summary, &mut buf)?;
|
||||
if buf.spilled() {
|
||||
// This is bad as we only have one free block for the summary
|
||||
warn!(
|
||||
"Used more than one page size for summary buffer: {}",
|
||||
buf.len()
|
||||
);
|
||||
}
|
||||
file.seek(SeekFrom::Start(0)).await?;
|
||||
file.write_all(&buf).await?;
|
||||
let (_buf, res) = file.write_all(buf).await;
|
||||
res?;
|
||||
|
||||
let metadata = file
|
||||
.metadata()
|
||||
@@ -587,9 +587,9 @@ impl DeltaLayerWriter {
|
||||
&mut self,
|
||||
key: Key,
|
||||
lsn: Lsn,
|
||||
val: &[u8],
|
||||
val: Vec<u8>,
|
||||
will_init: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
) -> (Vec<u8>, anyhow::Result<()>) {
|
||||
self.inner
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
@@ -675,18 +675,12 @@ impl DeltaLayer {
|
||||
|
||||
let new_summary = rewrite(actual_summary);
|
||||
|
||||
let mut buf = smallvec::SmallVec::<[u8; PAGE_SZ]>::new();
|
||||
let mut buf = Vec::with_capacity(PAGE_SZ);
|
||||
// TODO: could use smallvec here, but it's a pain with Slice<T>
|
||||
Summary::ser_into(&new_summary, &mut buf).context("serialize")?;
|
||||
if buf.spilled() {
|
||||
// The code in DeltaLayerWriterInner just warn!()s for this.
|
||||
// It should probably error out as well.
|
||||
return Err(RewriteSummaryError::Other(anyhow::anyhow!(
|
||||
"Used more than one page size for summary buffer: {}",
|
||||
buf.len()
|
||||
)));
|
||||
}
|
||||
file.seek(SeekFrom::Start(0)).await?;
|
||||
file.write_all(&buf).await?;
|
||||
let (_buf, res) = file.write_all(buf).await;
|
||||
res?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -341,18 +341,12 @@ impl ImageLayer {
|
||||
|
||||
let new_summary = rewrite(actual_summary);
|
||||
|
||||
let mut buf = smallvec::SmallVec::<[u8; PAGE_SZ]>::new();
|
||||
let mut buf = Vec::with_capacity(PAGE_SZ);
|
||||
// TODO: could use smallvec here but it's a pain with Slice<T>
|
||||
Summary::ser_into(&new_summary, &mut buf).context("serialize")?;
|
||||
if buf.spilled() {
|
||||
// The code in ImageLayerWriterInner just warn!()s for this.
|
||||
// It should probably error out as well.
|
||||
return Err(RewriteSummaryError::Other(anyhow::anyhow!(
|
||||
"Used more than one page size for summary buffer: {}",
|
||||
buf.len()
|
||||
)));
|
||||
}
|
||||
file.seek(SeekFrom::Start(0)).await?;
|
||||
file.write_all(&buf).await?;
|
||||
let (_buf, res) = file.write_all(buf).await;
|
||||
res?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -528,9 +522,11 @@ impl ImageLayerWriterInner {
|
||||
///
|
||||
/// The page versions must be appended in blknum order.
|
||||
///
|
||||
async fn put_image(&mut self, key: Key, img: &[u8]) -> anyhow::Result<()> {
|
||||
async fn put_image(&mut self, key: Key, img: Bytes) -> anyhow::Result<()> {
|
||||
ensure!(self.key_range.contains(&key));
|
||||
let off = self.blob_writer.write_blob(img).await?;
|
||||
let (_img, res) = self.blob_writer.write_blob(img).await;
|
||||
// TODO: re-use the buffer for `img` further upstack
|
||||
let off = res?;
|
||||
|
||||
let mut keybuf: [u8; KEY_SIZE] = [0u8; KEY_SIZE];
|
||||
key.write_to_byte_slice(&mut keybuf);
|
||||
@@ -553,7 +549,8 @@ impl ImageLayerWriterInner {
|
||||
.await?;
|
||||
let (index_root_blk, block_buf) = self.tree.finish()?;
|
||||
for buf in block_buf.blocks {
|
||||
file.write_all(buf.as_ref()).await?;
|
||||
let (_buf, res) = file.write_all(buf).await;
|
||||
res?;
|
||||
}
|
||||
|
||||
// Fill in the summary on blk 0
|
||||
@@ -568,17 +565,12 @@ impl ImageLayerWriterInner {
|
||||
index_root_blk,
|
||||
};
|
||||
|
||||
let mut buf = smallvec::SmallVec::<[u8; PAGE_SZ]>::new();
|
||||
let mut buf = Vec::with_capacity(PAGE_SZ);
|
||||
// TODO: could use smallvec here but it's a pain with Slice<T>
|
||||
Summary::ser_into(&summary, &mut buf)?;
|
||||
if buf.spilled() {
|
||||
// This is bad as we only have one free block for the summary
|
||||
warn!(
|
||||
"Used more than one page size for summary buffer: {}",
|
||||
buf.len()
|
||||
);
|
||||
}
|
||||
file.seek(SeekFrom::Start(0)).await?;
|
||||
file.write_all(&buf).await?;
|
||||
let (_buf, res) = file.write_all(buf).await;
|
||||
res?;
|
||||
|
||||
let metadata = file
|
||||
.metadata()
|
||||
@@ -659,7 +651,7 @@ impl ImageLayerWriter {
|
||||
///
|
||||
/// The page versions must be appended in blknum order.
|
||||
///
|
||||
pub async fn put_image(&mut self, key: Key, img: &[u8]) -> anyhow::Result<()> {
|
||||
pub async fn put_image(&mut self, key: Key, img: Bytes) -> anyhow::Result<()> {
|
||||
self.inner.as_mut().unwrap().put_image(key, img).await
|
||||
}
|
||||
|
||||
|
||||
@@ -383,9 +383,11 @@ impl InMemoryLayer {
|
||||
for (lsn, pos) in vec_map.as_slice() {
|
||||
cursor.read_blob_into_buf(*pos, &mut buf, &ctx).await?;
|
||||
let will_init = Value::des(&buf)?.will_init();
|
||||
delta_layer_writer
|
||||
.put_value_bytes(key, *lsn, &buf, will_init)
|
||||
.await?;
|
||||
let res;
|
||||
(buf, res) = delta_layer_writer
|
||||
.put_value_bytes(key, *lsn, buf, will_init)
|
||||
.await;
|
||||
res?;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1413,10 +1413,6 @@ impl ResidentLayer {
|
||||
&self.owner.0.path
|
||||
}
|
||||
|
||||
pub(crate) fn access_stats(&self) -> &LayerAccessStats {
|
||||
self.owner.access_stats()
|
||||
}
|
||||
|
||||
pub(crate) fn metadata(&self) -> LayerFileMetadata {
|
||||
self.owner.metadata()
|
||||
}
|
||||
|
||||
@@ -12,7 +12,9 @@ use bytes::Bytes;
|
||||
use camino::{Utf8Path, Utf8PathBuf};
|
||||
use enumset::EnumSet;
|
||||
use fail::fail_point;
|
||||
use futures::stream::StreamExt;
|
||||
use itertools::Itertools;
|
||||
use once_cell::sync::Lazy;
|
||||
use pageserver_api::{
|
||||
keyspace::{key_range_size, KeySpaceAccum},
|
||||
models::{
|
||||
@@ -33,17 +35,22 @@ use tokio_util::sync::CancellationToken;
|
||||
use tracing::*;
|
||||
use utils::sync::gate::Gate;
|
||||
|
||||
use std::collections::{BTreeMap, BinaryHeap, HashMap, HashSet};
|
||||
use std::ops::{Deref, Range};
|
||||
use std::pin::pin;
|
||||
use std::sync::atomic::Ordering as AtomicOrdering;
|
||||
use std::sync::{Arc, Mutex, RwLock, Weak};
|
||||
use std::time::{Duration, Instant, SystemTime};
|
||||
use std::{
|
||||
array,
|
||||
collections::{BTreeMap, BinaryHeap, HashMap, HashSet},
|
||||
sync::atomic::AtomicU64,
|
||||
};
|
||||
use std::{
|
||||
cmp::{max, min, Ordering},
|
||||
ops::ControlFlow,
|
||||
};
|
||||
|
||||
use crate::pgdatadir_mapping::DirectoryKind;
|
||||
use crate::tenant::timeline::logical_size::CurrentLogicalSize;
|
||||
use crate::tenant::{
|
||||
layer_map::{LayerMap, SearchResult},
|
||||
@@ -105,7 +112,7 @@ use self::logical_size::LogicalSize;
|
||||
use self::walreceiver::{WalReceiver, WalReceiverConf};
|
||||
|
||||
use super::config::TenantConf;
|
||||
use super::remote_timeline_client::index::{IndexLayerMetadata, IndexPart};
|
||||
use super::remote_timeline_client::index::IndexPart;
|
||||
use super::remote_timeline_client::RemoteTimelineClient;
|
||||
use super::secondary::heatmap::{HeatMapLayer, HeatMapTimeline};
|
||||
use super::{debug_assert_current_span_has_tenant_and_timeline_id, AttachedTenantConf};
|
||||
@@ -257,6 +264,8 @@ pub struct Timeline {
|
||||
// in `crate::page_service` writes these metrics.
|
||||
pub(crate) query_metrics: crate::metrics::SmgrQueryTimePerTimeline,
|
||||
|
||||
directory_metrics: [AtomicU64; DirectoryKind::KINDS_NUM],
|
||||
|
||||
/// Ensures layers aren't frozen by checkpointer between
|
||||
/// [`Timeline::get_layer_for_write`] and layer reads.
|
||||
/// Locked automatically by [`TimelineWriter`] and checkpointer.
|
||||
@@ -789,6 +798,10 @@ impl Timeline {
|
||||
self.metrics.resident_physical_size_get()
|
||||
}
|
||||
|
||||
pub(crate) fn get_directory_metrics(&self) -> [u64; DirectoryKind::KINDS_NUM] {
|
||||
array::from_fn(|idx| self.directory_metrics[idx].load(AtomicOrdering::Relaxed))
|
||||
}
|
||||
|
||||
///
|
||||
/// Wait until WAL has been received and processed up to this LSN.
|
||||
///
|
||||
@@ -1458,7 +1471,7 @@ impl Timeline {
|
||||
generation,
|
||||
shard_identity,
|
||||
pg_version,
|
||||
layers: Arc::new(tokio::sync::RwLock::new(LayerManager::create())),
|
||||
layers: Default::default(),
|
||||
wanted_image_layers: Mutex::new(None),
|
||||
|
||||
walredo_mgr,
|
||||
@@ -1495,6 +1508,8 @@ impl Timeline {
|
||||
&timeline_id,
|
||||
),
|
||||
|
||||
directory_metrics: array::from_fn(|_| AtomicU64::new(0)),
|
||||
|
||||
flush_loop_state: Mutex::new(FlushLoopState::NotStarted),
|
||||
|
||||
layer_flush_start_tx,
|
||||
@@ -2263,6 +2278,29 @@ impl Timeline {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn update_directory_entries_count(&self, kind: DirectoryKind, count: u64) {
|
||||
self.directory_metrics[kind.offset()].store(count, AtomicOrdering::Relaxed);
|
||||
let aux_metric =
|
||||
self.directory_metrics[DirectoryKind::AuxFiles.offset()].load(AtomicOrdering::Relaxed);
|
||||
|
||||
let sum_of_entries = self
|
||||
.directory_metrics
|
||||
.iter()
|
||||
.map(|v| v.load(AtomicOrdering::Relaxed))
|
||||
.sum();
|
||||
// Set a high general threshold and a lower threshold for the auxiliary files,
|
||||
// as we can have large numbers of relations in the db directory.
|
||||
const SUM_THRESHOLD: u64 = 5000;
|
||||
const AUX_THRESHOLD: u64 = 1000;
|
||||
if sum_of_entries >= SUM_THRESHOLD || aux_metric >= AUX_THRESHOLD {
|
||||
self.metrics
|
||||
.directory_entries_count_gauge
|
||||
.set(sum_of_entries);
|
||||
} else if let Some(metric) = Lazy::get(&self.metrics.directory_entries_count_gauge) {
|
||||
metric.set(sum_of_entries);
|
||||
}
|
||||
}
|
||||
|
||||
async fn find_layer(&self, layer_file_name: &str) -> Option<Layer> {
|
||||
let guard = self.layers.read().await;
|
||||
for historic_layer in guard.layer_map().iter_historic_layers() {
|
||||
@@ -2283,45 +2321,28 @@ impl Timeline {
|
||||
/// should treat this as a cue to simply skip doing any heatmap uploading
|
||||
/// for this timeline.
|
||||
pub(crate) async fn generate_heatmap(&self) -> Option<HeatMapTimeline> {
|
||||
let eviction_info = self.get_local_layers_for_disk_usage_eviction().await;
|
||||
// no point in heatmaps without remote client
|
||||
let _remote_client = self.remote_client.as_ref()?;
|
||||
|
||||
let remote_client = match &self.remote_client {
|
||||
Some(c) => c,
|
||||
None => return None,
|
||||
};
|
||||
if !self.is_active() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let layer_file_names = eviction_info
|
||||
.resident_layers
|
||||
.iter()
|
||||
.map(|l| l.layer.get_name())
|
||||
.collect::<Vec<_>>();
|
||||
let guard = self.layers.read().await;
|
||||
|
||||
let decorated = match remote_client.get_layers_metadata(layer_file_names) {
|
||||
Ok(d) => d,
|
||||
Err(_) => {
|
||||
// Getting metadata only fails on Timeline in bad state.
|
||||
return None;
|
||||
}
|
||||
};
|
||||
let resident = guard.resident_layers().map(|layer| {
|
||||
let last_activity_ts = layer.access_stats().latest_activity_or_now();
|
||||
|
||||
let heatmap_layers = std::iter::zip(
|
||||
eviction_info.resident_layers.into_iter(),
|
||||
decorated.into_iter(),
|
||||
)
|
||||
.filter_map(|(layer, remote_info)| {
|
||||
remote_info.map(|remote_info| {
|
||||
HeatMapLayer::new(
|
||||
layer.layer.get_name(),
|
||||
IndexLayerMetadata::from(remote_info),
|
||||
layer.last_activity_ts,
|
||||
)
|
||||
})
|
||||
HeatMapLayer::new(
|
||||
layer.layer_desc().filename(),
|
||||
layer.metadata().into(),
|
||||
last_activity_ts,
|
||||
)
|
||||
});
|
||||
|
||||
Some(HeatMapTimeline::new(
|
||||
self.timeline_id,
|
||||
heatmap_layers.collect(),
|
||||
))
|
||||
let layers = resident.collect().await;
|
||||
|
||||
Some(HeatMapTimeline::new(self.timeline_id, layers))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3328,7 +3349,7 @@ impl Timeline {
|
||||
}
|
||||
};
|
||||
|
||||
image_layer_writer.put_image(img_key, &img).await?;
|
||||
image_layer_writer.put_image(img_key, img).await?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4662,41 +4683,24 @@ impl Timeline {
|
||||
/// Returns non-remote layers for eviction.
|
||||
pub(crate) async fn get_local_layers_for_disk_usage_eviction(&self) -> DiskUsageEvictionInfo {
|
||||
let guard = self.layers.read().await;
|
||||
let layers = guard.layer_map();
|
||||
|
||||
let mut max_layer_size: Option<u64> = None;
|
||||
let mut resident_layers = Vec::new();
|
||||
|
||||
for l in layers.iter_historic_layers() {
|
||||
let file_size = l.file_size();
|
||||
max_layer_size = max_layer_size.map_or(Some(file_size), |m| Some(m.max(file_size)));
|
||||
let resident_layers = guard
|
||||
.resident_layers()
|
||||
.map(|layer| {
|
||||
let file_size = layer.layer_desc().file_size;
|
||||
max_layer_size = max_layer_size.map_or(Some(file_size), |m| Some(m.max(file_size)));
|
||||
|
||||
let l = guard.get_from_desc(&l);
|
||||
let last_activity_ts = layer.access_stats().latest_activity_or_now();
|
||||
|
||||
let l = match l.keep_resident().await {
|
||||
Ok(Some(l)) => l,
|
||||
Ok(None) => continue,
|
||||
Err(e) => {
|
||||
// these should not happen, but we cannot make them statically impossible right
|
||||
// now.
|
||||
tracing::warn!(layer=%l, "failed to keep the layer resident: {e:#}");
|
||||
continue;
|
||||
EvictionCandidate {
|
||||
layer: layer.into(),
|
||||
last_activity_ts,
|
||||
relative_last_activity: finite_f32::FiniteF32::ZERO,
|
||||
}
|
||||
};
|
||||
|
||||
let last_activity_ts = l.access_stats().latest_activity().unwrap_or_else(|| {
|
||||
// We only use this fallback if there's an implementation error.
|
||||
// `latest_activity` already does rate-limited warn!() log.
|
||||
debug!(layer=%l, "last_activity returns None, using SystemTime::now");
|
||||
SystemTime::now()
|
||||
});
|
||||
|
||||
resident_layers.push(EvictionCandidate {
|
||||
layer: l.drop_eviction_guard().into(),
|
||||
last_activity_ts,
|
||||
relative_last_activity: finite_f32::FiniteF32::ZERO,
|
||||
});
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
DiskUsageEvictionInfo {
|
||||
max_layer_size,
|
||||
|
||||
@@ -6,7 +6,7 @@ use std::{
|
||||
use anyhow::Context;
|
||||
use pageserver_api::{models::TimelineState, shard::TenantShardId};
|
||||
use tokio::sync::OwnedMutexGuard;
|
||||
use tracing::{debug, error, info, instrument, warn, Instrument, Span};
|
||||
use tracing::{debug, error, info, instrument, warn, Instrument};
|
||||
use utils::{crashsafe, fs_ext, id::TimelineId};
|
||||
|
||||
use crate::{
|
||||
@@ -541,12 +541,7 @@ impl DeleteTimelineFlow {
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
.instrument({
|
||||
let span =
|
||||
tracing::info_span!(parent: None, "delete_timeline", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(),timeline_id=%timeline_id);
|
||||
span.follows_from(Span::current());
|
||||
span
|
||||
}),
|
||||
.instrument(tracing::info_span!(parent: None, "delete_timeline", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(),timeline_id=%timeline_id)),
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -239,12 +239,7 @@ impl Timeline {
|
||||
}
|
||||
};
|
||||
|
||||
let last_activity_ts = hist_layer.access_stats().latest_activity().unwrap_or_else(|| {
|
||||
// We only use this fallback if there's an implementation error.
|
||||
// `latest_activity` already does rate-limited warn!() log.
|
||||
debug!(layer=%hist_layer, "last_activity returns None, using SystemTime::now");
|
||||
SystemTime::now()
|
||||
});
|
||||
let last_activity_ts = hist_layer.access_stats().latest_activity_or_now();
|
||||
|
||||
let no_activity_for = match now.duration_since(last_activity_ts) {
|
||||
Ok(d) => d,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use anyhow::{bail, ensure, Context, Result};
|
||||
use futures::StreamExt;
|
||||
use pageserver_api::shard::TenantShardId;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use tracing::trace;
|
||||
@@ -20,19 +21,13 @@ use crate::{
|
||||
};
|
||||
|
||||
/// Provides semantic APIs to manipulate the layer map.
|
||||
#[derive(Default)]
|
||||
pub(crate) struct LayerManager {
|
||||
layer_map: LayerMap,
|
||||
layer_fmgr: LayerFileManager<Layer>,
|
||||
}
|
||||
|
||||
impl LayerManager {
|
||||
pub(crate) fn create() -> Self {
|
||||
Self {
|
||||
layer_map: LayerMap::default(),
|
||||
layer_fmgr: LayerFileManager::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_from_desc(&self, desc: &PersistentLayerDesc) -> Layer {
|
||||
self.layer_fmgr.get_from_desc(desc)
|
||||
}
|
||||
@@ -246,6 +241,32 @@ impl LayerManager {
|
||||
layer.delete_on_drop();
|
||||
}
|
||||
|
||||
pub(crate) fn resident_layers(&self) -> impl futures::stream::Stream<Item = Layer> + '_ {
|
||||
// for small layer maps, we most likely have all resident, but for larger more are likely
|
||||
// to be evicted assuming lots of layers correlated with longer lifespan.
|
||||
|
||||
let layers = self
|
||||
.layer_map()
|
||||
.iter_historic_layers()
|
||||
.map(|desc| self.get_from_desc(&desc));
|
||||
|
||||
let layers = futures::stream::iter(layers);
|
||||
|
||||
layers.filter_map(|layer| async move {
|
||||
// TODO(#6028): this query does not really need to see the ResidentLayer
|
||||
match layer.keep_resident().await {
|
||||
Ok(Some(layer)) => Some(layer.drop_eviction_guard()),
|
||||
Ok(None) => None,
|
||||
Err(e) => {
|
||||
// these should not happen, but we cannot make them statically impossible right
|
||||
// now.
|
||||
tracing::warn!(%layer, "failed to keep the layer resident: {e:#}");
|
||||
None
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn contains(&self, layer: &Layer) -> bool {
|
||||
self.layer_fmgr.contains(layer)
|
||||
}
|
||||
@@ -253,6 +274,12 @@ impl LayerManager {
|
||||
|
||||
pub(crate) struct LayerFileManager<T>(HashMap<PersistentLayerKey, T>);
|
||||
|
||||
impl<T> Default for LayerFileManager<T> {
|
||||
fn default() -> Self {
|
||||
Self(HashMap::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: AsLayerDesc + Clone> LayerFileManager<T> {
|
||||
fn get_from_desc(&self, desc: &PersistentLayerDesc) -> T {
|
||||
// The assumption for the `expect()` is that all code maintains the following invariant:
|
||||
@@ -275,10 +302,6 @@ impl<T: AsLayerDesc + Clone> LayerFileManager<T> {
|
||||
self.0.contains_key(&layer.layer_desc().key())
|
||||
}
|
||||
|
||||
pub(crate) fn new() -> Self {
|
||||
Self(HashMap::new())
|
||||
}
|
||||
|
||||
pub(crate) fn remove(&mut self, layer: &T) {
|
||||
let present = self.0.remove(&layer.layer_desc().key());
|
||||
if present.is_none() && cfg!(debug_assertions) {
|
||||
|
||||
@@ -19,7 +19,7 @@ use once_cell::sync::OnceCell;
|
||||
use pageserver_api::shard::TenantShardId;
|
||||
use std::fs::{self, File};
|
||||
use std::io::{Error, ErrorKind, Seek, SeekFrom};
|
||||
use tokio_epoll_uring::IoBufMut;
|
||||
use tokio_epoll_uring::{BoundedBuf, IoBufMut, Slice};
|
||||
|
||||
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd};
|
||||
use std::os::unix::fs::FileExt;
|
||||
@@ -410,10 +410,10 @@ impl VirtualFile {
|
||||
/// step, the tmp path is renamed to the final path. As renames are
|
||||
/// atomic, a crash during the write operation will never leave behind a
|
||||
/// partially written file.
|
||||
pub async fn crashsafe_overwrite(
|
||||
pub async fn crashsafe_overwrite<B: BoundedBuf>(
|
||||
final_path: &Utf8Path,
|
||||
tmp_path: &Utf8Path,
|
||||
content: &[u8],
|
||||
content: B,
|
||||
) -> std::io::Result<()> {
|
||||
let Some(final_path_parent) = final_path.parent() else {
|
||||
return Err(std::io::Error::from_raw_os_error(
|
||||
@@ -430,7 +430,8 @@ impl VirtualFile {
|
||||
.create_new(true),
|
||||
)
|
||||
.await?;
|
||||
file.write_all(content).await?;
|
||||
let (_content, res) = file.write_all(content).await;
|
||||
res?;
|
||||
file.sync_all().await?;
|
||||
drop(file); // before the rename, that's important!
|
||||
// renames are atomic
|
||||
@@ -601,23 +602,36 @@ impl VirtualFile {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn write_all(&mut self, mut buf: &[u8]) -> Result<(), Error> {
|
||||
/// Writes `buf.slice(0..buf.bytes_init())`.
|
||||
/// Returns the IoBuf that is underlying the BoundedBuf `buf`.
|
||||
/// I.e., the returned value's `bytes_init()` method returns something different than the `bytes_init()` that was passed in.
|
||||
/// It's quite brittle and easy to mis-use, so, we return the size in the Ok() variant.
|
||||
pub async fn write_all<B: BoundedBuf>(&mut self, buf: B) -> (B::Buf, Result<usize, Error>) {
|
||||
let nbytes = buf.bytes_init();
|
||||
if nbytes == 0 {
|
||||
return (Slice::into_inner(buf.slice_full()), Ok(0));
|
||||
}
|
||||
let mut buf = buf.slice(0..nbytes);
|
||||
while !buf.is_empty() {
|
||||
match self.write(buf).await {
|
||||
// TODO: push `Slice` further down
|
||||
match self.write(&buf).await {
|
||||
Ok(0) => {
|
||||
return Err(Error::new(
|
||||
std::io::ErrorKind::WriteZero,
|
||||
"failed to write whole buffer",
|
||||
));
|
||||
return (
|
||||
Slice::into_inner(buf),
|
||||
Err(Error::new(
|
||||
std::io::ErrorKind::WriteZero,
|
||||
"failed to write whole buffer",
|
||||
)),
|
||||
);
|
||||
}
|
||||
Ok(n) => {
|
||||
buf = &buf[n..];
|
||||
buf = buf.slice(n..);
|
||||
}
|
||||
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
|
||||
Err(e) => return Err(e),
|
||||
Err(e) => return (Slice::into_inner(buf), Err(e)),
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
(Slice::into_inner(buf), Ok(nbytes))
|
||||
}
|
||||
|
||||
async fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
|
||||
@@ -676,7 +690,6 @@ where
|
||||
F: FnMut(tokio_epoll_uring::Slice<B>, u64) -> Fut,
|
||||
Fut: std::future::Future<Output = (tokio_epoll_uring::Slice<B>, std::io::Result<usize>)>,
|
||||
{
|
||||
use tokio_epoll_uring::BoundedBuf;
|
||||
let mut buf: tokio_epoll_uring::Slice<B> = buf.slice_full(); // includes all the uninitialized memory
|
||||
while buf.bytes_total() != 0 {
|
||||
let res;
|
||||
@@ -1063,10 +1076,19 @@ mod tests {
|
||||
MaybeVirtualFile::File(file) => file.seek(pos),
|
||||
}
|
||||
}
|
||||
async fn write_all(&mut self, buf: &[u8]) -> Result<(), Error> {
|
||||
async fn write_all<B: BoundedBuf>(&mut self, buf: B) -> Result<(), Error> {
|
||||
match self {
|
||||
MaybeVirtualFile::VirtualFile(file) => file.write_all(buf).await,
|
||||
MaybeVirtualFile::File(file) => file.write_all(buf),
|
||||
MaybeVirtualFile::VirtualFile(file) => {
|
||||
let (_buf, res) = file.write_all(buf).await;
|
||||
res.map(|_| ())
|
||||
}
|
||||
MaybeVirtualFile::File(file) => {
|
||||
let buf_len = buf.bytes_init();
|
||||
if buf_len == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
file.write_all(&buf.slice(0..buf_len))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1141,7 +1163,7 @@ mod tests {
|
||||
.to_owned(),
|
||||
)
|
||||
.await?;
|
||||
file_a.write_all(b"foobar").await?;
|
||||
file_a.write_all(b"foobar".to_vec()).await?;
|
||||
|
||||
// cannot read from a file opened in write-only mode
|
||||
let _ = file_a.read_string().await.unwrap_err();
|
||||
@@ -1150,7 +1172,7 @@ mod tests {
|
||||
let mut file_a = openfunc(path_a, OpenOptions::new().read(true).to_owned()).await?;
|
||||
|
||||
// cannot write to a file opened in read-only mode
|
||||
let _ = file_a.write_all(b"bar").await.unwrap_err();
|
||||
let _ = file_a.write_all(b"bar".to_vec()).await.unwrap_err();
|
||||
|
||||
// Try simple read
|
||||
assert_eq!("foobar", file_a.read_string().await?);
|
||||
@@ -1293,7 +1315,7 @@ mod tests {
|
||||
let path = testdir.join("myfile");
|
||||
let tmp_path = testdir.join("myfile.tmp");
|
||||
|
||||
VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"foo")
|
||||
VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"foo".to_vec())
|
||||
.await
|
||||
.unwrap();
|
||||
let mut file = MaybeVirtualFile::from(VirtualFile::open(&path).await.unwrap());
|
||||
@@ -1302,7 +1324,7 @@ mod tests {
|
||||
assert!(!tmp_path.exists());
|
||||
drop(file);
|
||||
|
||||
VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"bar")
|
||||
VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"bar".to_vec())
|
||||
.await
|
||||
.unwrap();
|
||||
let mut file = MaybeVirtualFile::from(VirtualFile::open(&path).await.unwrap());
|
||||
@@ -1324,7 +1346,7 @@ mod tests {
|
||||
std::fs::write(&tmp_path, "some preexisting junk that should be removed").unwrap();
|
||||
assert!(tmp_path.exists());
|
||||
|
||||
VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"foo")
|
||||
VirtualFile::crashsafe_overwrite(&path, &tmp_path, b"foo".to_vec())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -346,7 +346,7 @@ impl WalIngest {
|
||||
let info = decoded.xl_info & pg_constants::XLR_RMGR_INFO_MASK;
|
||||
|
||||
if info == pg_constants::XLOG_LOGICAL_MESSAGE {
|
||||
let xlrec = XlLogicalMessage::decode(&mut buf);
|
||||
let xlrec = crate::walrecord::XlLogicalMessage::decode(&mut buf);
|
||||
let prefix = std::str::from_utf8(&buf[0..xlrec.prefix_size - 1])?;
|
||||
let message = &buf[xlrec.prefix_size..xlrec.prefix_size + xlrec.message_size];
|
||||
if prefix == "neon-test" {
|
||||
|
||||
@@ -3079,14 +3079,6 @@ neon_redo_read_buffer_filter(XLogReaderState *record, uint8 block_id)
|
||||
XLogRecGetBlockTag(record, block_id, &rinfo, &forknum, &blkno);
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Out of an abundance of caution, we always run redo on shared catalogs,
|
||||
* regardless of whether the block is stored in shared buffers. See also
|
||||
* this function's top comment.
|
||||
*/
|
||||
if (!OidIsValid(NInfoGetDbOid(rinfo)))
|
||||
return false;
|
||||
|
||||
CopyNRelFileInfoToBufTag(tag, rinfo);
|
||||
tag.forkNum = forknum;
|
||||
tag.blockNum = blkno;
|
||||
@@ -3100,17 +3092,28 @@ neon_redo_read_buffer_filter(XLogReaderState *record, uint8 block_id)
|
||||
*/
|
||||
LWLockAcquire(partitionLock, LW_SHARED);
|
||||
|
||||
/* Try to find the relevant buffer */
|
||||
buffer = BufTableLookup(&tag, hash);
|
||||
|
||||
no_redo_needed = buffer < 0;
|
||||
/*
|
||||
* Out of an abundance of caution, we always run redo on shared catalogs,
|
||||
* regardless of whether the block is stored in shared buffers. See also
|
||||
* this function's top comment.
|
||||
*/
|
||||
if (!OidIsValid(NInfoGetDbOid(rinfo)))
|
||||
{
|
||||
no_redo_needed = false;
|
||||
}
|
||||
else
|
||||
{
|
||||
/* Try to find the relevant buffer */
|
||||
buffer = BufTableLookup(&tag, hash);
|
||||
|
||||
no_redo_needed = buffer < 0;
|
||||
}
|
||||
/* In both cases st lwlsn past this WAL record */
|
||||
SetLastWrittenLSNForBlock(end_recptr, rinfo, forknum, blkno);
|
||||
|
||||
/*
|
||||
* we don't have the buffer in memory, update lwLsn past this record, also
|
||||
* evict page fro file cache
|
||||
* evict page from file cache
|
||||
*/
|
||||
if (no_redo_needed)
|
||||
lfc_evict(rinfo, forknum, blkno);
|
||||
|
||||
@@ -688,7 +688,7 @@ RecvAcceptorGreeting(Safekeeper *sk)
|
||||
if (!AsyncReadMessage(sk, (AcceptorProposerMessage *) &sk->greetResponse))
|
||||
return;
|
||||
|
||||
wp_log(LOG, "received AcceptorGreeting from safekeeper %s:%s", sk->host, sk->port);
|
||||
wp_log(LOG, "received AcceptorGreeting from safekeeper %s:%s, term=" INT64_FORMAT, sk->host, sk->port, sk->greetResponse.term);
|
||||
|
||||
/* Protocol is all good, move to voting. */
|
||||
sk->state = SS_VOTING;
|
||||
@@ -922,6 +922,7 @@ static void
|
||||
DetermineEpochStartLsn(WalProposer *wp)
|
||||
{
|
||||
TermHistory *dth;
|
||||
int n_ready = 0;
|
||||
|
||||
wp->propEpochStartLsn = InvalidXLogRecPtr;
|
||||
wp->donorEpoch = 0;
|
||||
@@ -932,6 +933,8 @@ DetermineEpochStartLsn(WalProposer *wp)
|
||||
{
|
||||
if (wp->safekeeper[i].state == SS_IDLE)
|
||||
{
|
||||
n_ready++;
|
||||
|
||||
if (GetEpoch(&wp->safekeeper[i]) > wp->donorEpoch ||
|
||||
(GetEpoch(&wp->safekeeper[i]) == wp->donorEpoch &&
|
||||
wp->safekeeper[i].voteResponse.flushLsn > wp->propEpochStartLsn))
|
||||
@@ -958,6 +961,16 @@ DetermineEpochStartLsn(WalProposer *wp)
|
||||
}
|
||||
}
|
||||
|
||||
if (n_ready < wp->quorum)
|
||||
{
|
||||
/*
|
||||
* This is a rare case that can be triggered if safekeeper has voted and disconnected.
|
||||
* In this case, its state will not be SS_IDLE and its vote cannot be used, because
|
||||
* we clean up `voteResponse` in `ShutdownConnection`.
|
||||
*/
|
||||
wp_log(FATAL, "missing majority of votes, collected %d, expected %d, got %d", wp->n_votes, wp->quorum, n_ready);
|
||||
}
|
||||
|
||||
/*
|
||||
* If propEpochStartLsn is 0, it means flushLsn is 0 everywhere, we are bootstrapping
|
||||
* and nothing was committed yet. Start streaming then from the basebackup LSN.
|
||||
|
||||
@@ -486,6 +486,8 @@ typedef struct walproposer_api
|
||||
*
|
||||
* On success, the data is placed in *buf. It is valid until the next call
|
||||
* to this function.
|
||||
*
|
||||
* Returns PG_ASYNC_READ_FAIL on closed connection.
|
||||
*/
|
||||
PGAsyncReadResult (*conn_async_read) (Safekeeper *sk, char **buf, int *amount);
|
||||
|
||||
@@ -532,6 +534,13 @@ typedef struct walproposer_api
|
||||
* Returns 0 if timeout is reached, 1 if some event happened. Updates
|
||||
* events mask to indicate events and sets sk to the safekeeper which has
|
||||
* an event.
|
||||
*
|
||||
* On timeout, events is set to WL_NO_EVENTS. On socket event, events is
|
||||
* set to WL_SOCKET_READABLE and/or WL_SOCKET_WRITEABLE. When socket is
|
||||
* closed, events is set to WL_SOCKET_READABLE.
|
||||
*
|
||||
* WL_SOCKET_WRITEABLE is usually set only when we need to flush the buffer.
|
||||
* It can be returned only if caller asked for this event in the last *_event_set call.
|
||||
*/
|
||||
int (*wait_event_set) (WalProposer *wp, long timeout, Safekeeper **sk, uint32 *events);
|
||||
|
||||
|
||||
@@ -36,9 +36,6 @@ pub enum AuthErrorImpl {
|
||||
#[error(transparent)]
|
||||
GetAuthInfo(#[from] console::errors::GetAuthInfoError),
|
||||
|
||||
#[error(transparent)]
|
||||
WakeCompute(#[from] console::errors::WakeComputeError),
|
||||
|
||||
/// SASL protocol errors (includes [SCRAM](crate::scram)).
|
||||
#[error(transparent)]
|
||||
Sasl(#[from] crate::sasl::Error),
|
||||
@@ -119,7 +116,6 @@ impl UserFacingError for AuthError {
|
||||
match self.0.as_ref() {
|
||||
Link(e) => e.to_string_client(),
|
||||
GetAuthInfo(e) => e.to_string_client(),
|
||||
WakeCompute(e) => e.to_string_client(),
|
||||
Sasl(e) => e.to_string_client(),
|
||||
AuthFailed(_) => self.to_string(),
|
||||
BadAuthMethod(_) => self.to_string(),
|
||||
@@ -139,7 +135,6 @@ impl ReportableError for AuthError {
|
||||
match self.0.as_ref() {
|
||||
Link(e) => e.get_error_kind(),
|
||||
GetAuthInfo(e) => e.get_error_kind(),
|
||||
WakeCompute(e) => e.get_error_kind(),
|
||||
Sasl(e) => e.get_error_kind(),
|
||||
AuthFailed(_) => crate::error::ErrorKind::User,
|
||||
BadAuthMethod(_) => crate::error::ErrorKind::User,
|
||||
|
||||
@@ -10,9 +10,9 @@ use crate::auth::validate_password_and_exchange;
|
||||
use crate::cache::Cached;
|
||||
use crate::console::errors::GetAuthInfoError;
|
||||
use crate::console::provider::{CachedRoleSecret, ConsoleBackend};
|
||||
use crate::console::AuthSecret;
|
||||
use crate::console::{AuthSecret, NodeInfo};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::proxy::wake_compute::wake_compute;
|
||||
use crate::proxy::connect_compute::ComputeConnectBackend;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::stream::Stream;
|
||||
use crate::{
|
||||
@@ -26,7 +26,6 @@ use crate::{
|
||||
stream, url,
|
||||
};
|
||||
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
|
||||
use futures::TryFutureExt;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::info;
|
||||
@@ -56,11 +55,11 @@ impl<T> std::ops::Deref for MaybeOwned<'_, T> {
|
||||
/// * However, when we substitute `T` with [`ComputeUserInfoMaybeEndpoint`],
|
||||
/// this helps us provide the credentials only to those auth
|
||||
/// backends which require them for the authentication process.
|
||||
pub enum BackendType<'a, T> {
|
||||
pub enum BackendType<'a, T, D> {
|
||||
/// Cloud API (V2).
|
||||
Console(MaybeOwned<'a, ConsoleBackend>, T),
|
||||
/// Authentication via a web browser.
|
||||
Link(MaybeOwned<'a, url::ApiUrl>),
|
||||
Link(MaybeOwned<'a, url::ApiUrl>, D),
|
||||
}
|
||||
|
||||
pub trait TestBackend: Send + Sync + 'static {
|
||||
@@ -71,7 +70,7 @@ pub trait TestBackend: Send + Sync + 'static {
|
||||
fn get_role_secret(&self) -> Result<CachedRoleSecret, console::errors::GetAuthInfoError>;
|
||||
}
|
||||
|
||||
impl std::fmt::Display for BackendType<'_, ()> {
|
||||
impl std::fmt::Display for BackendType<'_, (), ()> {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
@@ -86,51 +85,50 @@ impl std::fmt::Display for BackendType<'_, ()> {
|
||||
#[cfg(test)]
|
||||
ConsoleBackend::Test(_) => fmt.debug_tuple("Test").finish(),
|
||||
},
|
||||
Link(url) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
|
||||
Link(url, _) => fmt.debug_tuple("Link").field(&url.as_str()).finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> BackendType<'_, T> {
|
||||
impl<T, D> BackendType<'_, T, D> {
|
||||
/// Very similar to [`std::option::Option::as_ref`].
|
||||
/// This helps us pass structured config to async tasks.
|
||||
pub fn as_ref(&self) -> BackendType<'_, &T> {
|
||||
pub fn as_ref(&self) -> BackendType<'_, &T, &D> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(c, x) => Console(MaybeOwned::Borrowed(c), x),
|
||||
Link(c) => Link(MaybeOwned::Borrowed(c)),
|
||||
Link(c, x) => Link(MaybeOwned::Borrowed(c), x),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> BackendType<'a, T> {
|
||||
impl<'a, T, D> BackendType<'a, T, D> {
|
||||
/// Very similar to [`std::option::Option::map`].
|
||||
/// Maps [`BackendType<T>`] to [`BackendType<R>`] by applying
|
||||
/// a function to a contained value.
|
||||
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R> {
|
||||
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> BackendType<'a, R, D> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(c, x) => Console(c, f(x)),
|
||||
Link(c) => Link(c),
|
||||
Link(c, x) => Link(c, x),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T, E> BackendType<'a, Result<T, E>> {
|
||||
impl<'a, T, D, E> BackendType<'a, Result<T, E>, D> {
|
||||
/// Very similar to [`std::option::Option::transpose`].
|
||||
/// This is most useful for error handling.
|
||||
pub fn transpose(self) -> Result<BackendType<'a, T>, E> {
|
||||
pub fn transpose(self) -> Result<BackendType<'a, T, D>, E> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(c, x) => x.map(|x| Console(c, x)),
|
||||
Link(c) => Ok(Link(c)),
|
||||
Link(c, x) => Ok(Link(c, x)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ComputeCredentials<T> {
|
||||
pub struct ComputeCredentials {
|
||||
pub info: ComputeUserInfo,
|
||||
pub keys: T,
|
||||
pub keys: ComputeCredentialKeys,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -153,7 +151,6 @@ impl ComputeUserInfo {
|
||||
}
|
||||
|
||||
pub enum ComputeCredentialKeys {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Password(Vec<u8>),
|
||||
AuthKeys(AuthKeys),
|
||||
}
|
||||
@@ -188,19 +185,21 @@ async fn auth_quirks(
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
// If there's no project so far, that entails that client doesn't
|
||||
// support SNI or other means of passing the endpoint (project) name.
|
||||
// We now expect to see a very specific payload in the place of password.
|
||||
let (info, unauthenticated_password) = match user_info.try_into() {
|
||||
Err(info) => {
|
||||
let res = hacks::password_hack_no_authentication(info, client, &mut ctx.latency_timer)
|
||||
.await?;
|
||||
let res = hacks::password_hack_no_authentication(ctx, info, client).await?;
|
||||
|
||||
ctx.set_endpoint_id(res.info.endpoint.clone());
|
||||
tracing::Span::current().record("ep", &tracing::field::display(&res.info.endpoint));
|
||||
|
||||
(res.info, Some(res.keys))
|
||||
let password = match res.keys {
|
||||
ComputeCredentialKeys::Password(p) => p,
|
||||
_ => unreachable!("password hack should return a password"),
|
||||
};
|
||||
(res.info, Some(password))
|
||||
}
|
||||
Ok(info) => (info, None),
|
||||
};
|
||||
@@ -254,7 +253,7 @@ async fn authenticate_with_secret(
|
||||
unauthenticated_password: Option<Vec<u8>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
if let Some(password) = unauthenticated_password {
|
||||
let auth_outcome = validate_password_and_exchange(&password, secret)?;
|
||||
let keys = match auth_outcome {
|
||||
@@ -276,21 +275,22 @@ async fn authenticate_with_secret(
|
||||
// Perform cleartext auth if we're allowed to do that.
|
||||
// Currently, we use it for websocket connections (latency).
|
||||
if allow_cleartext {
|
||||
return hacks::authenticate_cleartext(info, client, &mut ctx.latency_timer, secret).await;
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
return hacks::authenticate_cleartext(ctx, info, client, secret).await;
|
||||
}
|
||||
|
||||
// Finally, proceed with the main auth flow (SCRAM-based).
|
||||
classic::authenticate(info, client, config, &mut ctx.latency_timer, secret).await
|
||||
classic::authenticate(ctx, info, client, config, secret).await
|
||||
}
|
||||
|
||||
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint, &()> {
|
||||
/// Get compute endpoint name from the credentials.
|
||||
pub fn get_endpoint(&self) -> Option<EndpointId> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(_, user_info) => user_info.endpoint_id.clone(),
|
||||
Link(_) => Some("link".into()),
|
||||
Link(_, _) => Some("link".into()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -300,7 +300,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
|
||||
match self {
|
||||
Console(_, user_info) => &user_info.user,
|
||||
Link(_) => "link",
|
||||
Link(_, _) => "link",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -312,7 +312,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
allow_cleartext: bool,
|
||||
config: &'static AuthenticationConfig,
|
||||
) -> auth::Result<(CachedNodeInfo, BackendType<'a, ComputeUserInfo>)> {
|
||||
) -> auth::Result<BackendType<'a, ComputeCredentials, NodeInfo>> {
|
||||
use BackendType::*;
|
||||
|
||||
let res = match self {
|
||||
@@ -323,33 +323,17 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
"performing authentication using the console"
|
||||
);
|
||||
|
||||
let compute_credentials =
|
||||
let credentials =
|
||||
auth_quirks(ctx, &*api, user_info, client, allow_cleartext, config).await?;
|
||||
|
||||
let mut num_retries = 0;
|
||||
let mut node =
|
||||
wake_compute(&mut num_retries, ctx, &api, &compute_credentials.info).await?;
|
||||
|
||||
ctx.set_project(node.aux.clone());
|
||||
|
||||
match compute_credentials.keys {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
ComputeCredentialKeys::Password(password) => node.config.password(password),
|
||||
ComputeCredentialKeys::AuthKeys(auth_keys) => node.config.auth_keys(auth_keys),
|
||||
};
|
||||
|
||||
(node, BackendType::Console(api, compute_credentials.info))
|
||||
BackendType::Console(api, credentials)
|
||||
}
|
||||
// NOTE: this auth backend doesn't use client credentials.
|
||||
Link(url) => {
|
||||
Link(url, _) => {
|
||||
info!("performing link authentication");
|
||||
|
||||
let node_info = link::authenticate(ctx, &url, client).await?;
|
||||
let info = link::authenticate(ctx, &url, client).await?;
|
||||
|
||||
(
|
||||
CachedNodeInfo::new_uncached(node_info),
|
||||
BackendType::Link(url),
|
||||
)
|
||||
BackendType::Link(url, info)
|
||||
}
|
||||
};
|
||||
|
||||
@@ -358,7 +342,7 @@ impl<'a> BackendType<'a, ComputeUserInfoMaybeEndpoint> {
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendType<'_, ComputeUserInfo> {
|
||||
impl BackendType<'_, ComputeUserInfo, &()> {
|
||||
pub async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
@@ -366,7 +350,7 @@ impl BackendType<'_, ComputeUserInfo> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(api, user_info) => api.get_role_secret(ctx, user_info).await,
|
||||
Link(_) => Ok(Cached::new_uncached(None)),
|
||||
Link(_, _) => Ok(Cached::new_uncached(None)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -377,21 +361,51 @@ impl BackendType<'_, ComputeUserInfo> {
|
||||
use BackendType::*;
|
||||
match self {
|
||||
Console(api, user_info) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
Link(_) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// When applicable, wake the compute node, gaining its connection info in the process.
|
||||
/// The link auth flow doesn't support this, so we return [`None`] in that case.
|
||||
pub async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
) -> Result<Option<CachedNodeInfo>, console::errors::WakeComputeError> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(api, user_info) => api.wake_compute(ctx, user_info).map_ok(Some).await,
|
||||
Link(_) => Ok(None),
|
||||
Link(_, _) => Ok((Cached::new_uncached(Arc::new(vec![])), None)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, NodeInfo> {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
|
||||
Link(_, info) => Ok(Cached::new_uncached(info.clone())),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
|
||||
match self {
|
||||
BackendType::Console(_, creds) => Some(&creds.keys),
|
||||
BackendType::Link(_, _) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ComputeConnectBackend for BackendType<'_, ComputeCredentials, &()> {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
) -> Result<CachedNodeInfo, console::errors::WakeComputeError> {
|
||||
use BackendType::*;
|
||||
|
||||
match self {
|
||||
Console(api, creds) => api.wake_compute(ctx, &creds.info).await,
|
||||
Link(_, _) => unreachable!("link auth flow doesn't support waking the compute"),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_keys(&self) -> Option<&ComputeCredentialKeys> {
|
||||
match self {
|
||||
BackendType::Console(_, creds) => Some(&creds.keys),
|
||||
BackendType::Link(_, _) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ use crate::{
|
||||
compute,
|
||||
config::AuthenticationConfig,
|
||||
console::AuthSecret,
|
||||
metrics::LatencyTimer,
|
||||
context::RequestMonitoring,
|
||||
sasl,
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
@@ -12,12 +12,12 @@ use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{info, warn};
|
||||
|
||||
pub(super) async fn authenticate(
|
||||
ctx: &mut RequestMonitoring,
|
||||
creds: ComputeUserInfo,
|
||||
client: &mut PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
config: &'static AuthenticationConfig,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
secret: AuthSecret,
|
||||
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
let flow = AuthFlow::new(client);
|
||||
let scram_keys = match secret {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
@@ -27,13 +27,11 @@ pub(super) async fn authenticate(
|
||||
}
|
||||
AuthSecret::Scram(secret) => {
|
||||
info!("auth endpoint chooses SCRAM");
|
||||
let scram = auth::Scram(&secret);
|
||||
let scram = auth::Scram(&secret, &mut *ctx);
|
||||
|
||||
let auth_outcome = tokio::time::timeout(
|
||||
config.scram_protocol_timeout,
|
||||
async {
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = latency_timer.pause();
|
||||
|
||||
flow.begin(scram).await.map_err(|error| {
|
||||
warn!(?error, "error sending scram acknowledgement");
|
||||
|
||||
@@ -4,7 +4,7 @@ use super::{
|
||||
use crate::{
|
||||
auth::{self, AuthFlow},
|
||||
console::AuthSecret,
|
||||
metrics::LatencyTimer,
|
||||
context::RequestMonitoring,
|
||||
sasl,
|
||||
stream::{self, Stream},
|
||||
};
|
||||
@@ -16,15 +16,16 @@ use tracing::{info, warn};
|
||||
/// These properties are benefical for serverless JS workers, so we
|
||||
/// use this mechanism for websocket connections.
|
||||
pub async fn authenticate_cleartext(
|
||||
ctx: &mut RequestMonitoring,
|
||||
info: ComputeUserInfo,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
secret: AuthSecret,
|
||||
) -> auth::Result<ComputeCredentials<ComputeCredentialKeys>> {
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
warn!("cleartext auth flow override is enabled, proceeding");
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = latency_timer.pause();
|
||||
let _paused = ctx.latency_timer.pause();
|
||||
|
||||
let auth_outcome = AuthFlow::new(client)
|
||||
.begin(auth::CleartextPassword(secret))
|
||||
@@ -47,14 +48,15 @@ pub async fn authenticate_cleartext(
|
||||
/// Similar to [`authenticate_cleartext`], but there's a specific password format,
|
||||
/// and passwords are not yet validated (we don't know how to validate them!)
|
||||
pub async fn password_hack_no_authentication(
|
||||
ctx: &mut RequestMonitoring,
|
||||
info: ComputeUserInfoNoEndpoint,
|
||||
client: &mut stream::PqStream<Stream<impl AsyncRead + AsyncWrite + Unpin>>,
|
||||
latency_timer: &mut LatencyTimer,
|
||||
) -> auth::Result<ComputeCredentials<Vec<u8>>> {
|
||||
) -> auth::Result<ComputeCredentials> {
|
||||
warn!("project not specified, resorting to the password hack auth flow");
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = latency_timer.pause();
|
||||
let _paused = ctx.latency_timer.pause();
|
||||
|
||||
let payload = AuthFlow::new(client)
|
||||
.begin(auth::PasswordHack)
|
||||
@@ -71,6 +73,6 @@ pub async fn password_hack_no_authentication(
|
||||
options: info.options,
|
||||
endpoint: payload.endpoint,
|
||||
},
|
||||
keys: payload.password,
|
||||
keys: ComputeCredentialKeys::Password(payload.password),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -61,6 +61,8 @@ pub(super) async fn authenticate(
|
||||
link_uri: &reqwest::Url,
|
||||
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
) -> auth::Result<NodeInfo> {
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Web);
|
||||
|
||||
// registering waiter can fail if we get unlucky with rng.
|
||||
// just try again.
|
||||
let (psql_session_id, waiter) = loop {
|
||||
|
||||
@@ -99,6 +99,9 @@ impl ComputeUserInfoMaybeEndpoint {
|
||||
// record the values if we have them
|
||||
ctx.set_application(params.get("application_name").map(SmolStr::from));
|
||||
ctx.set_user(user.clone());
|
||||
if let Some(dbname) = params.get("database") {
|
||||
ctx.set_dbname(dbname.into());
|
||||
}
|
||||
|
||||
// Project name might be passed via PG's command-line options.
|
||||
let endpoint_option = params
|
||||
|
||||
@@ -4,9 +4,11 @@ use super::{backend::ComputeCredentialKeys, AuthErrorImpl, PasswordHackPayload};
|
||||
use crate::{
|
||||
config::TlsServerEndPoint,
|
||||
console::AuthSecret,
|
||||
context::RequestMonitoring,
|
||||
sasl, scram,
|
||||
stream::{PqStream, Stream},
|
||||
};
|
||||
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
|
||||
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
|
||||
use std::io;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
@@ -23,7 +25,7 @@ pub trait AuthMethod {
|
||||
pub struct Begin;
|
||||
|
||||
/// Use [SCRAM](crate::scram)-based auth in [`AuthFlow`].
|
||||
pub struct Scram<'a>(pub &'a scram::ServerSecret);
|
||||
pub struct Scram<'a>(pub &'a scram::ServerSecret, pub &'a mut RequestMonitoring);
|
||||
|
||||
impl AuthMethod for Scram<'_> {
|
||||
#[inline(always)]
|
||||
@@ -138,6 +140,11 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
/// Perform user authentication. Raise an error in case authentication failed.
|
||||
pub async fn authenticate(self) -> super::Result<sasl::Outcome<scram::ScramKey>> {
|
||||
let Scram(secret, ctx) = self.state;
|
||||
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer.pause();
|
||||
|
||||
// Initial client message contains the chosen auth method's name.
|
||||
let msg = self.stream.read_password_message().await?;
|
||||
let sasl = sasl::FirstMessage::parse(&msg)
|
||||
@@ -148,9 +155,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
|
||||
return Err(super::AuthError::bad_auth_method(sasl.method));
|
||||
}
|
||||
|
||||
match sasl.method {
|
||||
SCRAM_SHA_256 => ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256),
|
||||
SCRAM_SHA_256_PLUS => {
|
||||
ctx.auth_method = Some(crate::context::AuthMethod::ScramSha256Plus)
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
info!("client chooses {}", sasl.method);
|
||||
|
||||
let secret = self.state.0;
|
||||
let outcome = sasl::SaslStream::new(self.stream, sasl.message)
|
||||
.authenticate(scram::Exchange::new(
|
||||
secret,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use futures::future::Either;
|
||||
use proxy::auth;
|
||||
use proxy::auth::backend::MaybeOwned;
|
||||
use proxy::cancellation::CancelMap;
|
||||
use proxy::cancellation::CancellationHandler;
|
||||
use proxy::config::AuthenticationConfig;
|
||||
use proxy::config::CacheOptions;
|
||||
use proxy::config::HttpConfig;
|
||||
@@ -12,6 +14,7 @@ use proxy::rate_limiter::EndpointRateLimiter;
|
||||
use proxy::rate_limiter::RateBucketInfo;
|
||||
use proxy::rate_limiter::RateLimiterConfig;
|
||||
use proxy::redis::notifications;
|
||||
use proxy::redis::publisher::RedisPublisherClient;
|
||||
use proxy::serverless::GlobalConnPoolOptions;
|
||||
use proxy::usage_metrics;
|
||||
|
||||
@@ -22,6 +25,7 @@ use std::net::SocketAddr;
|
||||
use std::pin::pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::info;
|
||||
@@ -129,6 +133,9 @@ struct ProxyCliArgs {
|
||||
/// Can be given multiple times for different bucket sizes.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
|
||||
endpoint_rps_limit: Vec<RateBucketInfo>,
|
||||
/// Redis rate limiter max number of requests per second.
|
||||
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
|
||||
redis_rps_limit: Vec<RateBucketInfo>,
|
||||
/// Initial limit for dynamic rate limiter. Makes sense only if `rate_limit_algorithm` is *not* `None`.
|
||||
#[clap(long, default_value_t = 100)]
|
||||
initial_limit: usize,
|
||||
@@ -225,6 +232,19 @@ async fn main() -> anyhow::Result<()> {
|
||||
let cancellation_token = CancellationToken::new();
|
||||
|
||||
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new(&config.endpoint_rps_limit));
|
||||
let cancel_map = CancelMap::default();
|
||||
let redis_publisher = match &args.redis_notifications {
|
||||
Some(url) => Some(Arc::new(Mutex::new(RedisPublisherClient::new(
|
||||
url,
|
||||
args.region.clone(),
|
||||
&config.redis_rps_limit,
|
||||
)?))),
|
||||
None => None,
|
||||
};
|
||||
let cancellation_handler = Arc::new(CancellationHandler::new(
|
||||
cancel_map.clone(),
|
||||
redis_publisher,
|
||||
));
|
||||
|
||||
// client facing tasks. these will exit on error or on cancellation
|
||||
// cancellation returns Ok(())
|
||||
@@ -234,6 +254,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
proxy_listener,
|
||||
cancellation_token.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
cancellation_handler.clone(),
|
||||
));
|
||||
|
||||
// TODO: rename the argument to something like serverless.
|
||||
@@ -248,6 +269,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
serverless_listener,
|
||||
cancellation_token.clone(),
|
||||
endpoint_rate_limiter.clone(),
|
||||
cancellation_handler.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
@@ -271,7 +293,12 @@ async fn main() -> anyhow::Result<()> {
|
||||
let cache = api.caches.project_info.clone();
|
||||
if let Some(url) = args.redis_notifications {
|
||||
info!("Starting redis notifications listener ({url})");
|
||||
maintenance_tasks.spawn(notifications::task_main(url.to_owned(), cache.clone()));
|
||||
maintenance_tasks.spawn(notifications::task_main(
|
||||
url.to_owned(),
|
||||
cache.clone(),
|
||||
cancel_map.clone(),
|
||||
args.region.clone(),
|
||||
));
|
||||
}
|
||||
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
|
||||
}
|
||||
@@ -383,7 +410,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
}
|
||||
AuthBackend::Link => {
|
||||
let url = args.uri.parse()?;
|
||||
auth::BackendType::Link(MaybeOwned::Owned(url))
|
||||
auth::BackendType::Link(MaybeOwned::Owned(url), ())
|
||||
}
|
||||
};
|
||||
let http_config = HttpConfig {
|
||||
@@ -403,6 +430,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
|
||||
let mut endpoint_rps_limit = args.endpoint_rps_limit.clone();
|
||||
RateBucketInfo::validate(&mut endpoint_rps_limit)?;
|
||||
let mut redis_rps_limit = args.redis_rps_limit.clone();
|
||||
RateBucketInfo::validate(&mut redis_rps_limit)?;
|
||||
|
||||
let config = Box::leak(Box::new(ProxyConfig {
|
||||
tls_config,
|
||||
@@ -414,6 +443,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
|
||||
require_client_ip: args.require_client_ip,
|
||||
disable_ip_check_for_http: args.disable_ip_check_for_http,
|
||||
endpoint_rps_limit,
|
||||
redis_rps_limit,
|
||||
handshake_timeout: args.handshake_timeout,
|
||||
// TODO: add this argument
|
||||
region: args.region.clone(),
|
||||
|
||||
@@ -1,16 +1,28 @@
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use pq_proto::CancelKeyData;
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use thiserror::Error;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_postgres::{CancelToken, NoTls};
|
||||
use tracing::info;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::error::ReportableError;
|
||||
use crate::{
|
||||
error::ReportableError, metrics::NUM_CANCELLATION_REQUESTS,
|
||||
redis::publisher::RedisPublisherClient,
|
||||
};
|
||||
|
||||
pub type CancelMap = Arc<DashMap<CancelKeyData, Option<CancelClosure>>>;
|
||||
|
||||
/// Enables serving `CancelRequest`s.
|
||||
#[derive(Default)]
|
||||
pub struct CancelMap(DashMap<CancelKeyData, Option<CancelClosure>>);
|
||||
///
|
||||
/// If there is a `RedisPublisherClient` available, it will be used to publish the cancellation key to other proxy instances.
|
||||
pub struct CancellationHandler {
|
||||
map: CancelMap,
|
||||
redis_client: Option<Arc<Mutex<RedisPublisherClient>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CancelError {
|
||||
@@ -32,15 +44,43 @@ impl ReportableError for CancelError {
|
||||
}
|
||||
}
|
||||
|
||||
impl CancelMap {
|
||||
impl CancellationHandler {
|
||||
pub fn new(map: CancelMap, redis_client: Option<Arc<Mutex<RedisPublisherClient>>>) -> Self {
|
||||
Self { map, redis_client }
|
||||
}
|
||||
/// Cancel a running query for the corresponding connection.
|
||||
pub async fn cancel_session(&self, key: CancelKeyData) -> Result<(), CancelError> {
|
||||
pub async fn cancel_session(
|
||||
&self,
|
||||
key: CancelKeyData,
|
||||
session_id: Uuid,
|
||||
) -> Result<(), CancelError> {
|
||||
let from = "from_client";
|
||||
// NB: we should immediately release the lock after cloning the token.
|
||||
let Some(cancel_closure) = self.0.get(&key).and_then(|x| x.clone()) else {
|
||||
let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else {
|
||||
tracing::warn!("query cancellation key not found: {key}");
|
||||
if let Some(redis_client) = &self.redis_client {
|
||||
NUM_CANCELLATION_REQUESTS
|
||||
.with_label_values(&[from, "not_found"])
|
||||
.inc();
|
||||
info!("publishing cancellation key to Redis");
|
||||
match redis_client.lock().await.try_publish(key, session_id).await {
|
||||
Ok(()) => {
|
||||
info!("cancellation key successfuly published to Redis");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("failed to publish a message: {e}");
|
||||
return Err(CancelError::IO(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
e.to_string(),
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
NUM_CANCELLATION_REQUESTS
|
||||
.with_label_values(&[from, "found"])
|
||||
.inc();
|
||||
info!("cancelling query per user's request using key {key}");
|
||||
cancel_closure.try_cancel_query().await
|
||||
}
|
||||
@@ -57,7 +97,7 @@ impl CancelMap {
|
||||
|
||||
// Random key collisions are unlikely to happen here, but they're still possible,
|
||||
// which is why we have to take care not to rewrite an existing key.
|
||||
match self.0.entry(key) {
|
||||
match self.map.entry(key) {
|
||||
dashmap::mapref::entry::Entry::Occupied(_) => continue,
|
||||
dashmap::mapref::entry::Entry::Vacant(e) => {
|
||||
e.insert(None);
|
||||
@@ -69,18 +109,46 @@ impl CancelMap {
|
||||
info!("registered new query cancellation key {key}");
|
||||
Session {
|
||||
key,
|
||||
cancel_map: self,
|
||||
cancellation_handler: self,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn contains(&self, session: &Session) -> bool {
|
||||
self.0.contains_key(&session.key)
|
||||
self.map.contains_key(&session.key)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn is_empty(&self) -> bool {
|
||||
self.0.is_empty()
|
||||
self.map.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait NotificationsCancellationHandler {
|
||||
async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl NotificationsCancellationHandler for CancellationHandler {
|
||||
async fn cancel_session_no_publish(&self, key: CancelKeyData) -> Result<(), CancelError> {
|
||||
let from = "from_redis";
|
||||
let cancel_closure = self.map.get(&key).and_then(|x| x.clone());
|
||||
match cancel_closure {
|
||||
Some(cancel_closure) => {
|
||||
NUM_CANCELLATION_REQUESTS
|
||||
.with_label_values(&[from, "found"])
|
||||
.inc();
|
||||
cancel_closure.try_cancel_query().await
|
||||
}
|
||||
None => {
|
||||
NUM_CANCELLATION_REQUESTS
|
||||
.with_label_values(&[from, "not_found"])
|
||||
.inc();
|
||||
tracing::warn!("query cancellation key not found: {key}");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,7 +183,7 @@ pub struct Session {
|
||||
/// The user-facing key identifying this session.
|
||||
key: CancelKeyData,
|
||||
/// The [`CancelMap`] this session belongs to.
|
||||
cancel_map: Arc<CancelMap>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
@@ -123,7 +191,9 @@ impl Session {
|
||||
/// This enables query cancellation in `crate::proxy::prepare_client_connection`.
|
||||
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
|
||||
info!("enabling query cancellation for this session");
|
||||
self.cancel_map.0.insert(self.key, Some(cancel_closure));
|
||||
self.cancellation_handler
|
||||
.map
|
||||
.insert(self.key, Some(cancel_closure));
|
||||
|
||||
self.key
|
||||
}
|
||||
@@ -131,7 +201,7 @@ impl Session {
|
||||
|
||||
impl Drop for Session {
|
||||
fn drop(&mut self) {
|
||||
self.cancel_map.0.remove(&self.key);
|
||||
self.cancellation_handler.map.remove(&self.key);
|
||||
info!("dropped query cancellation key {}", &self.key);
|
||||
}
|
||||
}
|
||||
@@ -142,13 +212,16 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn check_session_drop() -> anyhow::Result<()> {
|
||||
let cancel_map: Arc<CancelMap> = Default::default();
|
||||
let cancellation_handler = Arc::new(CancellationHandler {
|
||||
map: CancelMap::default(),
|
||||
redis_client: None,
|
||||
});
|
||||
|
||||
let session = cancel_map.clone().get_session();
|
||||
assert!(cancel_map.contains(&session));
|
||||
let session = cancellation_handler.clone().get_session();
|
||||
assert!(cancellation_handler.contains(&session));
|
||||
drop(session);
|
||||
// Check that the session has been dropped.
|
||||
assert!(cancel_map.is_empty());
|
||||
assert!(cancellation_handler.is_empty());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::{
|
||||
auth::parse_endpoint_param,
|
||||
cancellation::CancelClosure,
|
||||
console::errors::WakeComputeError,
|
||||
console::{errors::WakeComputeError, messages::MetricsAuxInfo},
|
||||
context::RequestMonitoring,
|
||||
error::{ReportableError, UserFacingError},
|
||||
metrics::NUM_DB_CONNECTIONS_GAUGE,
|
||||
@@ -93,7 +93,7 @@ impl ConnCfg {
|
||||
}
|
||||
|
||||
/// Reuse password or auth keys from the other config.
|
||||
pub fn reuse_password(&mut self, other: &Self) {
|
||||
pub fn reuse_password(&mut self, other: Self) {
|
||||
if let Some(password) = other.get_password() {
|
||||
self.password(password);
|
||||
}
|
||||
@@ -253,6 +253,8 @@ pub struct PostgresConnection {
|
||||
pub params: std::collections::HashMap<String, String>,
|
||||
/// Query cancellation token.
|
||||
pub cancel_closure: CancelClosure,
|
||||
/// Labels for proxy's metrics.
|
||||
pub aux: MetricsAuxInfo,
|
||||
|
||||
_guage: IntCounterPairGuard,
|
||||
}
|
||||
@@ -263,6 +265,7 @@ impl ConnCfg {
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
allow_self_signed_compute: bool,
|
||||
aux: MetricsAuxInfo,
|
||||
timeout: Duration,
|
||||
) -> Result<PostgresConnection, ConnectionError> {
|
||||
let (socket_addr, stream, host) = self.connect_raw(timeout).await?;
|
||||
@@ -297,6 +300,7 @@ impl ConnCfg {
|
||||
stream,
|
||||
params,
|
||||
cancel_closure,
|
||||
aux,
|
||||
_guage: NUM_DB_CONNECTIONS_GAUGE
|
||||
.with_label_values(&[ctx.protocol])
|
||||
.guard(),
|
||||
|
||||
@@ -13,7 +13,7 @@ use x509_parser::oid_registry;
|
||||
|
||||
pub struct ProxyConfig {
|
||||
pub tls_config: Option<TlsConfig>,
|
||||
pub auth_backend: auth::BackendType<'static, ()>,
|
||||
pub auth_backend: auth::BackendType<'static, (), ()>,
|
||||
pub metric_collection: Option<MetricCollectionConfig>,
|
||||
pub allow_self_signed_compute: bool,
|
||||
pub http_config: HttpConfig,
|
||||
@@ -21,6 +21,7 @@ pub struct ProxyConfig {
|
||||
pub require_client_ip: bool,
|
||||
pub disable_ip_check_for_http: bool,
|
||||
pub endpoint_rps_limit: Vec<RateBucketInfo>,
|
||||
pub redis_rps_limit: Vec<RateBucketInfo>,
|
||||
pub region: String,
|
||||
pub handshake_timeout: Duration,
|
||||
}
|
||||
|
||||
@@ -4,7 +4,10 @@ pub mod neon;
|
||||
|
||||
use super::messages::MetricsAuxInfo;
|
||||
use crate::{
|
||||
auth::{backend::ComputeUserInfo, IpPattern},
|
||||
auth::{
|
||||
backend::{ComputeCredentialKeys, ComputeUserInfo},
|
||||
IpPattern,
|
||||
},
|
||||
cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
|
||||
compute,
|
||||
config::{CacheOptions, ProjectInfoCacheOptions},
|
||||
@@ -261,6 +264,34 @@ pub struct NodeInfo {
|
||||
pub allow_self_signed_compute: bool,
|
||||
}
|
||||
|
||||
impl NodeInfo {
|
||||
pub async fn connect(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
timeout: Duration,
|
||||
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
|
||||
self.config
|
||||
.connect(
|
||||
ctx,
|
||||
self.allow_self_signed_compute,
|
||||
self.aux.clone(),
|
||||
timeout,
|
||||
)
|
||||
.await
|
||||
}
|
||||
pub fn reuse_settings(&mut self, other: Self) {
|
||||
self.allow_self_signed_compute = other.allow_self_signed_compute;
|
||||
self.config.reuse_password(other.config);
|
||||
}
|
||||
|
||||
pub fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
|
||||
match keys {
|
||||
ComputeCredentialKeys::Password(password) => self.config.password(password),
|
||||
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
pub type NodeInfoCache = TimedLru<EndpointCacheKey, NodeInfo>;
|
||||
pub type CachedNodeInfo = Cached<&'static NodeInfoCache>;
|
||||
pub type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
|
||||
|
||||
@@ -176,9 +176,7 @@ impl super::Api for Api {
|
||||
_ctx: &mut RequestMonitoring,
|
||||
_user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
self.do_wake_compute()
|
||||
.map_ok(CachedNodeInfo::new_uncached)
|
||||
.await
|
||||
self.do_wake_compute().map_ok(Cached::new_uncached).await
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ use crate::{
|
||||
console::messages::MetricsAuxInfo,
|
||||
error::ErrorKind,
|
||||
metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND},
|
||||
BranchId, EndpointId, ProjectId, RoleName,
|
||||
BranchId, DbName, EndpointId, ProjectId, RoleName,
|
||||
};
|
||||
|
||||
pub mod parquet;
|
||||
@@ -34,9 +34,11 @@ pub struct RequestMonitoring {
|
||||
project: Option<ProjectId>,
|
||||
branch: Option<BranchId>,
|
||||
endpoint_id: Option<EndpointId>,
|
||||
dbname: Option<DbName>,
|
||||
user: Option<RoleName>,
|
||||
application: Option<SmolStr>,
|
||||
error_kind: Option<ErrorKind>,
|
||||
pub(crate) auth_method: Option<AuthMethod>,
|
||||
success: bool,
|
||||
|
||||
// extra
|
||||
@@ -45,6 +47,15 @@ pub struct RequestMonitoring {
|
||||
pub latency_timer: LatencyTimer,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum AuthMethod {
|
||||
// aka link aka passwordless
|
||||
Web,
|
||||
ScramSha256,
|
||||
ScramSha256Plus,
|
||||
Cleartext,
|
||||
}
|
||||
|
||||
impl RequestMonitoring {
|
||||
pub fn new(
|
||||
session_id: Uuid,
|
||||
@@ -62,9 +73,11 @@ impl RequestMonitoring {
|
||||
project: None,
|
||||
branch: None,
|
||||
endpoint_id: None,
|
||||
dbname: None,
|
||||
user: None,
|
||||
application: None,
|
||||
error_kind: None,
|
||||
auth_method: None,
|
||||
success: false,
|
||||
|
||||
sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
|
||||
@@ -106,10 +119,18 @@ impl RequestMonitoring {
|
||||
self.application = app.or_else(|| self.application.clone());
|
||||
}
|
||||
|
||||
pub fn set_dbname(&mut self, dbname: DbName) {
|
||||
self.dbname = Some(dbname);
|
||||
}
|
||||
|
||||
pub fn set_user(&mut self, user: RoleName) {
|
||||
self.user = Some(user);
|
||||
}
|
||||
|
||||
pub fn set_auth_method(&mut self, auth_method: AuthMethod) {
|
||||
self.auth_method = Some(auth_method);
|
||||
}
|
||||
|
||||
pub fn set_error_kind(&mut self, kind: ErrorKind) {
|
||||
ERROR_BY_KIND
|
||||
.with_label_values(&[kind.to_metric_label()])
|
||||
|
||||
@@ -84,8 +84,10 @@ struct RequestData {
|
||||
username: Option<String>,
|
||||
application_name: Option<String>,
|
||||
endpoint_id: Option<String>,
|
||||
database: Option<String>,
|
||||
project: Option<String>,
|
||||
branch: Option<String>,
|
||||
auth_method: Option<&'static str>,
|
||||
error: Option<&'static str>,
|
||||
/// Success is counted if we form a HTTP response with sql rows inside
|
||||
/// Or if we make it to proxy_pass
|
||||
@@ -104,8 +106,15 @@ impl From<RequestMonitoring> for RequestData {
|
||||
username: value.user.as_deref().map(String::from),
|
||||
application_name: value.application.as_deref().map(String::from),
|
||||
endpoint_id: value.endpoint_id.as_deref().map(String::from),
|
||||
database: value.dbname.as_deref().map(String::from),
|
||||
project: value.project.as_deref().map(String::from),
|
||||
branch: value.branch.as_deref().map(String::from),
|
||||
auth_method: value.auth_method.as_ref().map(|x| match x {
|
||||
super::AuthMethod::Web => "web",
|
||||
super::AuthMethod::ScramSha256 => "scram_sha_256",
|
||||
super::AuthMethod::ScramSha256Plus => "scram_sha_256_plus",
|
||||
super::AuthMethod::Cleartext => "cleartext",
|
||||
}),
|
||||
protocol: value.protocol,
|
||||
region: value.region,
|
||||
error: value.error_kind.as_ref().map(|e| e.to_metric_label()),
|
||||
@@ -431,8 +440,10 @@ mod tests {
|
||||
application_name: Some("test".to_owned()),
|
||||
username: Some(hex::encode(rng.gen::<[u8; 4]>())),
|
||||
endpoint_id: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
||||
database: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
||||
project: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
||||
branch: Some(hex::encode(rng.gen::<[u8; 16]>())),
|
||||
auth_method: None,
|
||||
protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
|
||||
region: "us-east-1",
|
||||
error: None,
|
||||
@@ -505,15 +516,15 @@ mod tests {
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[
|
||||
(1087635, 3, 6000),
|
||||
(1087288, 3, 6000),
|
||||
(1087444, 3, 6000),
|
||||
(1087572, 3, 6000),
|
||||
(1087468, 3, 6000),
|
||||
(1087500, 3, 6000),
|
||||
(1087533, 3, 6000),
|
||||
(1087566, 3, 6000),
|
||||
(362671, 1, 2000)
|
||||
(1313727, 3, 6000),
|
||||
(1313720, 3, 6000),
|
||||
(1313780, 3, 6000),
|
||||
(1313737, 3, 6000),
|
||||
(1313867, 3, 6000),
|
||||
(1313709, 3, 6000),
|
||||
(1313501, 3, 6000),
|
||||
(1313737, 3, 6000),
|
||||
(438118, 1, 2000)
|
||||
],
|
||||
);
|
||||
|
||||
@@ -543,11 +554,11 @@ mod tests {
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[
|
||||
(1028637, 5, 10000),
|
||||
(1031969, 5, 10000),
|
||||
(1019900, 5, 10000),
|
||||
(1020365, 5, 10000),
|
||||
(1025010, 5, 10000)
|
||||
(1219459, 5, 10000),
|
||||
(1225609, 5, 10000),
|
||||
(1227403, 5, 10000),
|
||||
(1226765, 5, 10000),
|
||||
(1218043, 5, 10000)
|
||||
],
|
||||
);
|
||||
|
||||
@@ -579,11 +590,11 @@ mod tests {
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[
|
||||
(1210770, 6, 12000),
|
||||
(1211036, 6, 12000),
|
||||
(1210990, 6, 12000),
|
||||
(1210861, 6, 12000),
|
||||
(202073, 1, 2000)
|
||||
(1205106, 5, 10000),
|
||||
(1204837, 5, 10000),
|
||||
(1205130, 5, 10000),
|
||||
(1205118, 5, 10000),
|
||||
(1205373, 5, 10000)
|
||||
],
|
||||
);
|
||||
|
||||
@@ -608,15 +619,15 @@ mod tests {
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[
|
||||
(1087635, 3, 6000),
|
||||
(1087288, 3, 6000),
|
||||
(1087444, 3, 6000),
|
||||
(1087572, 3, 6000),
|
||||
(1087468, 3, 6000),
|
||||
(1087500, 3, 6000),
|
||||
(1087533, 3, 6000),
|
||||
(1087566, 3, 6000),
|
||||
(362671, 1, 2000)
|
||||
(1313727, 3, 6000),
|
||||
(1313720, 3, 6000),
|
||||
(1313780, 3, 6000),
|
||||
(1313737, 3, 6000),
|
||||
(1313867, 3, 6000),
|
||||
(1313709, 3, 6000),
|
||||
(1313501, 3, 6000),
|
||||
(1313737, 3, 6000),
|
||||
(438118, 1, 2000)
|
||||
],
|
||||
);
|
||||
|
||||
@@ -653,7 +664,7 @@ mod tests {
|
||||
// files are smaller than the size threshold, but they took too long to fill so were flushed early
|
||||
assert_eq!(
|
||||
file_stats,
|
||||
[(545264, 2, 3001), (545025, 2, 3000), (544857, 2, 2999)],
|
||||
[(658383, 2, 3001), (658097, 2, 3000), (657893, 2, 2999)],
|
||||
);
|
||||
|
||||
tmpdir.close().unwrap();
|
||||
|
||||
@@ -29,7 +29,7 @@ pub trait UserFacingError: ReportableError {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||
pub enum ErrorKind {
|
||||
/// Wrong password, unknown endpoint, protocol violation, etc...
|
||||
User,
|
||||
@@ -90,3 +90,13 @@ impl ReportableError for tokio::time::error::Elapsed {
|
||||
ErrorKind::RateLimit
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for tokio_postgres::error::Error {
|
||||
fn get_error_kind(&self) -> ErrorKind {
|
||||
if self.as_db_error().is_some() {
|
||||
ErrorKind::Postgres
|
||||
} else {
|
||||
ErrorKind::Compute
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,6 +152,15 @@ pub static NUM_OPEN_CLIENTS_IN_HTTP_POOL: Lazy<IntGauge> = Lazy::new(|| {
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
pub static NUM_CANCELLATION_REQUESTS: Lazy<IntCounterVec> = Lazy::new(|| {
|
||||
register_int_counter_vec!(
|
||||
"proxy_cancellation_requests_total",
|
||||
"Number of cancellation requests (per found/not_found).",
|
||||
&["source", "kind"],
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct LatencyTimer {
|
||||
// time since the stopwatch was started
|
||||
@@ -200,8 +209,9 @@ impl LatencyTimer {
|
||||
|
||||
pub fn success(&mut self) {
|
||||
// stop the stopwatch and record the time that we have accumulated
|
||||
let start = self.start.take().expect("latency timer should be started");
|
||||
self.accumulated += start.elapsed();
|
||||
if let Some(start) = self.start.take() {
|
||||
self.accumulated += start.elapsed();
|
||||
}
|
||||
|
||||
// success
|
||||
self.outcome = "success";
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
mod tests;
|
||||
|
||||
pub mod connect_compute;
|
||||
mod copy_bidirectional;
|
||||
pub mod handshake;
|
||||
pub mod passthrough;
|
||||
pub mod retry;
|
||||
@@ -9,7 +10,7 @@ pub mod wake_compute;
|
||||
|
||||
use crate::{
|
||||
auth,
|
||||
cancellation::{self, CancelMap},
|
||||
cancellation::{self, CancellationHandler},
|
||||
compute,
|
||||
config::{ProxyConfig, TlsConfig},
|
||||
context::RequestMonitoring,
|
||||
@@ -61,6 +62,7 @@ pub async fn task_main(
|
||||
listener: tokio::net::TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("proxy has shut down");
|
||||
@@ -71,7 +73,6 @@ pub async fn task_main(
|
||||
socket2::SockRef::from(&listener).set_keepalive(true)?;
|
||||
|
||||
let connections = tokio_util::task::task_tracker::TaskTracker::new();
|
||||
let cancel_map = Arc::new(CancelMap::default());
|
||||
|
||||
while let Some(accept_result) =
|
||||
run_until_cancelled(listener.accept(), &cancellation_token).await
|
||||
@@ -79,7 +80,7 @@ pub async fn task_main(
|
||||
let (socket, peer_addr) = accept_result?;
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancel_map = Arc::clone(&cancel_map);
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
|
||||
let session_span = info_span!(
|
||||
@@ -112,7 +113,7 @@ pub async fn task_main(
|
||||
let res = handle_client(
|
||||
config,
|
||||
&mut ctx,
|
||||
cancel_map,
|
||||
cancellation_handler,
|
||||
socket,
|
||||
ClientMode::Tcp,
|
||||
endpoint_rate_limiter,
|
||||
@@ -162,14 +163,14 @@ pub enum ClientMode {
|
||||
|
||||
/// Abstracts the logic of handling TCP vs WS clients
|
||||
impl ClientMode {
|
||||
fn allow_cleartext(&self) -> bool {
|
||||
pub fn allow_cleartext(&self) -> bool {
|
||||
match self {
|
||||
ClientMode::Tcp => false,
|
||||
ClientMode::Websockets { .. } => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
|
||||
pub fn allow_self_signed_compute(&self, config: &ProxyConfig) -> bool {
|
||||
match self {
|
||||
ClientMode::Tcp => config.allow_self_signed_compute,
|
||||
ClientMode::Websockets { .. } => false,
|
||||
@@ -226,7 +227,7 @@ impl ReportableError for ClientRequestError {
|
||||
pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
cancel_map: Arc<CancelMap>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
stream: S,
|
||||
mode: ClientMode,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
@@ -252,8 +253,8 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
|
||||
HandshakeData::Startup(stream, params) => (stream, params),
|
||||
HandshakeData::Cancel(cancel_key_data) => {
|
||||
return Ok(cancel_map
|
||||
.cancel_session(cancel_key_data)
|
||||
return Ok(cancellation_handler
|
||||
.cancel_session(cancel_key_data, ctx.session_id)
|
||||
.await
|
||||
.map(|()| None)?)
|
||||
}
|
||||
@@ -286,7 +287,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
|
||||
let user = user_info.get_user().to_owned();
|
||||
let (mut node_info, user_info) = match user_info
|
||||
let user_info = match user_info
|
||||
.authenticate(
|
||||
ctx,
|
||||
&mut stream,
|
||||
@@ -305,19 +306,16 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
};
|
||||
|
||||
node_info.allow_self_signed_compute = mode.allow_self_signed_compute(config);
|
||||
|
||||
let aux = node_info.aux.clone();
|
||||
let mut node = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism { params: ¶ms },
|
||||
node_info,
|
||||
&user_info,
|
||||
mode.allow_self_signed_compute(config),
|
||||
)
|
||||
.or_else(|e| stream.throw_error(e))
|
||||
.await?;
|
||||
|
||||
let session = cancel_map.get_session();
|
||||
let session = cancellation_handler.get_session();
|
||||
prepare_client_connection(&node, &session, &mut stream).await?;
|
||||
|
||||
// Before proxy passing, forward to compute whatever data is left in the
|
||||
@@ -329,10 +327,11 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
client: stream,
|
||||
aux: node.aux.clone(),
|
||||
compute: node,
|
||||
aux,
|
||||
req: _request_gauge,
|
||||
conn: _client_gauge,
|
||||
cancel: session,
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
use crate::{
|
||||
auth,
|
||||
auth::backend::ComputeCredentialKeys,
|
||||
compute::{self, PostgresConnection},
|
||||
console::{self, errors::WakeComputeError},
|
||||
console::{self, errors::WakeComputeError, CachedNodeInfo, NodeInfo},
|
||||
context::RequestMonitoring,
|
||||
error::ReportableError,
|
||||
metrics::NUM_CONNECTION_FAILURES,
|
||||
proxy::{
|
||||
retry::{retry_after, ShouldRetry},
|
||||
@@ -20,7 +21,7 @@ const CONNECT_TIMEOUT: time::Duration = time::Duration::from_secs(2);
|
||||
/// (e.g. the compute node's address might've changed at the wrong time).
|
||||
/// Invalidate the cache entry (if any) to prevent subsequent errors.
|
||||
#[tracing::instrument(name = "invalidate_cache", skip_all)]
|
||||
pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg {
|
||||
pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> NodeInfo {
|
||||
let is_cached = node_info.cached();
|
||||
if is_cached {
|
||||
warn!("invalidating stalled compute node info cache entry");
|
||||
@@ -31,13 +32,13 @@ pub fn invalidate_cache(node_info: console::CachedNodeInfo) -> compute::ConnCfg
|
||||
};
|
||||
NUM_CONNECTION_FAILURES.with_label_values(&[label]).inc();
|
||||
|
||||
node_info.invalidate().config
|
||||
node_info.invalidate()
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ConnectMechanism {
|
||||
type Connection;
|
||||
type ConnectError;
|
||||
type ConnectError: ReportableError;
|
||||
type Error: From<Self::ConnectError>;
|
||||
async fn connect_once(
|
||||
&self,
|
||||
@@ -49,6 +50,16 @@ pub trait ConnectMechanism {
|
||||
fn update_connect_config(&self, conf: &mut compute::ConnCfg);
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ComputeConnectBackend {
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
) -> Result<CachedNodeInfo, console::errors::WakeComputeError>;
|
||||
|
||||
fn get_keys(&self) -> Option<&ComputeCredentialKeys>;
|
||||
}
|
||||
|
||||
pub struct TcpMechanism<'a> {
|
||||
/// KV-dictionary with PostgreSQL connection params.
|
||||
pub params: &'a StartupMessageParams,
|
||||
@@ -67,11 +78,7 @@ impl ConnectMechanism for TcpMechanism<'_> {
|
||||
node_info: &console::CachedNodeInfo,
|
||||
timeout: time::Duration,
|
||||
) -> Result<PostgresConnection, Self::Error> {
|
||||
let allow_self_signed_compute = node_info.allow_self_signed_compute;
|
||||
node_info
|
||||
.config
|
||||
.connect(ctx, allow_self_signed_compute, timeout)
|
||||
.await
|
||||
node_info.connect(ctx, timeout).await
|
||||
}
|
||||
|
||||
fn update_connect_config(&self, config: &mut compute::ConnCfg) {
|
||||
@@ -82,16 +89,23 @@ impl ConnectMechanism for TcpMechanism<'_> {
|
||||
/// Try to connect to the compute node, retrying if necessary.
|
||||
/// This function might update `node_info`, so we take it by `&mut`.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn connect_to_compute<M: ConnectMechanism>(
|
||||
pub async fn connect_to_compute<M: ConnectMechanism, B: ComputeConnectBackend>(
|
||||
ctx: &mut RequestMonitoring,
|
||||
mechanism: &M,
|
||||
mut node_info: console::CachedNodeInfo,
|
||||
user_info: &auth::BackendType<'_, auth::backend::ComputeUserInfo>,
|
||||
user_info: &B,
|
||||
allow_self_signed_compute: bool,
|
||||
) -> Result<M::Connection, M::Error>
|
||||
where
|
||||
M::ConnectError: ShouldRetry + std::fmt::Debug,
|
||||
M::Error: From<WakeComputeError>,
|
||||
{
|
||||
let mut num_retries = 0;
|
||||
let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?;
|
||||
if let Some(keys) = user_info.get_keys() {
|
||||
node_info.set_keys(keys);
|
||||
}
|
||||
node_info.allow_self_signed_compute = allow_self_signed_compute;
|
||||
// let mut node_info = credentials.get_node_info(ctx, user_info).await?;
|
||||
mechanism.update_connect_config(&mut node_info.config);
|
||||
|
||||
// try once
|
||||
@@ -108,28 +122,31 @@ where
|
||||
|
||||
error!(error = ?err, "could not connect to compute node");
|
||||
|
||||
let mut num_retries = 1;
|
||||
|
||||
match user_info {
|
||||
auth::BackendType::Console(api, info) => {
|
||||
let node_info =
|
||||
if err.get_error_kind() == crate::error::ErrorKind::Postgres || !node_info.cached() {
|
||||
// If the error is Postgres, that means that we managed to connect to the compute node, but there was an error.
|
||||
// Do not need to retrieve a new node_info, just return the old one.
|
||||
if !err.should_retry(num_retries) {
|
||||
return Err(err.into());
|
||||
}
|
||||
node_info
|
||||
} else {
|
||||
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
|
||||
info!("compute node's state has likely changed; requesting a wake-up");
|
||||
|
||||
ctx.latency_timer.cache_miss();
|
||||
let config = invalidate_cache(node_info);
|
||||
node_info = wake_compute(&mut num_retries, ctx, api, info).await?;
|
||||
let old_node_info = invalidate_cache(node_info);
|
||||
let mut node_info = wake_compute(&mut num_retries, ctx, user_info).await?;
|
||||
node_info.reuse_settings(old_node_info);
|
||||
|
||||
node_info.config.reuse_password(&config);
|
||||
mechanism.update_connect_config(&mut node_info.config);
|
||||
}
|
||||
// nothing to do?
|
||||
auth::BackendType::Link(_) => {}
|
||||
};
|
||||
node_info
|
||||
};
|
||||
|
||||
// now that we have a new node, try connect to it repeatedly.
|
||||
// this can error for a few reasons, for instance:
|
||||
// * DNS connection settings haven't quite propagated yet
|
||||
info!("wake_compute success. attempting to connect");
|
||||
num_retries = 1;
|
||||
loop {
|
||||
match mechanism
|
||||
.connect_once(ctx, &node_info, CONNECT_TIMEOUT)
|
||||
|
||||
256
proxy/src/proxy/copy_bidirectional.rs
Normal file
256
proxy/src/proxy/copy_bidirectional.rs
Normal file
@@ -0,0 +1,256 @@
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
use std::future::poll_fn;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::task::{ready, Context, Poll};
|
||||
|
||||
#[derive(Debug)]
|
||||
enum TransferState {
|
||||
Running(CopyBuffer),
|
||||
ShuttingDown(u64),
|
||||
Done(u64),
|
||||
}
|
||||
|
||||
fn transfer_one_direction<A, B>(
|
||||
cx: &mut Context<'_>,
|
||||
state: &mut TransferState,
|
||||
r: &mut A,
|
||||
w: &mut B,
|
||||
) -> Poll<io::Result<u64>>
|
||||
where
|
||||
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
|
||||
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
|
||||
{
|
||||
let mut r = Pin::new(r);
|
||||
let mut w = Pin::new(w);
|
||||
loop {
|
||||
match state {
|
||||
TransferState::Running(buf) => {
|
||||
let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
|
||||
*state = TransferState::ShuttingDown(count);
|
||||
}
|
||||
TransferState::ShuttingDown(count) => {
|
||||
ready!(w.as_mut().poll_shutdown(cx))?;
|
||||
*state = TransferState::Done(*count);
|
||||
}
|
||||
TransferState::Done(count) => return Poll::Ready(Ok(*count)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn copy_bidirectional<A, B>(
|
||||
a: &mut A,
|
||||
b: &mut B,
|
||||
) -> Result<(u64, u64), std::io::Error>
|
||||
where
|
||||
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
|
||||
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
|
||||
{
|
||||
let mut a_to_b = TransferState::Running(CopyBuffer::new());
|
||||
let mut b_to_a = TransferState::Running(CopyBuffer::new());
|
||||
|
||||
poll_fn(|cx| {
|
||||
let mut a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
|
||||
let mut b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;
|
||||
|
||||
// Early termination checks
|
||||
if let TransferState::Done(_) = a_to_b {
|
||||
if let TransferState::Running(buf) = &b_to_a {
|
||||
// Initiate shutdown
|
||||
b_to_a = TransferState::ShuttingDown(buf.amt);
|
||||
b_to_a_result = transfer_one_direction(cx, &mut b_to_a, b, a)?;
|
||||
}
|
||||
}
|
||||
if let TransferState::Done(_) = b_to_a {
|
||||
if let TransferState::Running(buf) = &a_to_b {
|
||||
// Initiate shutdown
|
||||
a_to_b = TransferState::ShuttingDown(buf.amt);
|
||||
a_to_b_result = transfer_one_direction(cx, &mut a_to_b, a, b)?;
|
||||
}
|
||||
}
|
||||
|
||||
// It is not a problem if ready! returns early ... (comment remains the same)
|
||||
let a_to_b = ready!(a_to_b_result);
|
||||
let b_to_a = ready!(b_to_a_result);
|
||||
|
||||
Poll::Ready(Ok((a_to_b, b_to_a)))
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) struct CopyBuffer {
|
||||
read_done: bool,
|
||||
need_flush: bool,
|
||||
pos: usize,
|
||||
cap: usize,
|
||||
amt: u64,
|
||||
buf: Box<[u8]>,
|
||||
}
|
||||
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
|
||||
|
||||
impl CopyBuffer {
|
||||
pub(super) fn new() -> Self {
|
||||
Self {
|
||||
read_done: false,
|
||||
need_flush: false,
|
||||
pos: 0,
|
||||
cap: 0,
|
||||
amt: 0,
|
||||
buf: vec![0; DEFAULT_BUF_SIZE].into_boxed_slice(),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_fill_buf<R>(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
reader: Pin<&mut R>,
|
||||
) -> Poll<io::Result<()>>
|
||||
where
|
||||
R: AsyncRead + ?Sized,
|
||||
{
|
||||
let me = &mut *self;
|
||||
let mut buf = ReadBuf::new(&mut me.buf);
|
||||
buf.set_filled(me.cap);
|
||||
|
||||
let res = reader.poll_read(cx, &mut buf);
|
||||
if let Poll::Ready(Ok(())) = res {
|
||||
let filled_len = buf.filled().len();
|
||||
me.read_done = me.cap == filled_len;
|
||||
me.cap = filled_len;
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
fn poll_write_buf<R, W>(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
mut reader: Pin<&mut R>,
|
||||
mut writer: Pin<&mut W>,
|
||||
) -> Poll<io::Result<usize>>
|
||||
where
|
||||
R: AsyncRead + ?Sized,
|
||||
W: AsyncWrite + ?Sized,
|
||||
{
|
||||
let me = &mut *self;
|
||||
match writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]) {
|
||||
Poll::Pending => {
|
||||
// Top up the buffer towards full if we can read a bit more
|
||||
// data - this should improve the chances of a large write
|
||||
if !me.read_done && me.cap < me.buf.len() {
|
||||
ready!(me.poll_fill_buf(cx, reader.as_mut()))?;
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
res => res,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn poll_copy<R, W>(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
mut reader: Pin<&mut R>,
|
||||
mut writer: Pin<&mut W>,
|
||||
) -> Poll<io::Result<u64>>
|
||||
where
|
||||
R: AsyncRead + ?Sized,
|
||||
W: AsyncWrite + ?Sized,
|
||||
{
|
||||
loop {
|
||||
// If our buffer is empty, then we need to read some data to
|
||||
// continue.
|
||||
if self.pos == self.cap && !self.read_done {
|
||||
self.pos = 0;
|
||||
self.cap = 0;
|
||||
|
||||
match self.poll_fill_buf(cx, reader.as_mut()) {
|
||||
Poll::Ready(Ok(())) => (),
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Poll::Pending => {
|
||||
// Try flushing when the reader has no progress to avoid deadlock
|
||||
// when the reader depends on buffered writer.
|
||||
if self.need_flush {
|
||||
ready!(writer.as_mut().poll_flush(cx))?;
|
||||
self.need_flush = false;
|
||||
}
|
||||
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If our buffer has some data, let's write it out!
|
||||
while self.pos < self.cap {
|
||||
let i = ready!(self.poll_write_buf(cx, reader.as_mut(), writer.as_mut()))?;
|
||||
if i == 0 {
|
||||
return Poll::Ready(Err(io::Error::new(
|
||||
io::ErrorKind::WriteZero,
|
||||
"write zero byte into writer",
|
||||
)));
|
||||
} else {
|
||||
self.pos += i;
|
||||
self.amt += i as u64;
|
||||
self.need_flush = true;
|
||||
}
|
||||
}
|
||||
|
||||
// If pos larger than cap, this loop will never stop.
|
||||
// In particular, user's wrong poll_write implementation returning
|
||||
// incorrect written length may lead to thread blocking.
|
||||
debug_assert!(
|
||||
self.pos <= self.cap,
|
||||
"writer returned length larger than input slice"
|
||||
);
|
||||
|
||||
// If we've written all the data and we've seen EOF, flush out the
|
||||
// data and finish the transfer.
|
||||
if self.pos == self.cap && self.read_done {
|
||||
ready!(writer.as_mut().poll_flush(cx))?;
|
||||
return Poll::Ready(Ok(self.amt));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_early_termination_a_to_d() {
|
||||
let (mut a_mock, mut b_mock) = tokio::io::duplex(8); // Create a mock duplex stream
|
||||
let (mut c_mock, mut d_mock) = tokio::io::duplex(32); // Create a mock duplex stream
|
||||
|
||||
// Simulate 'a' finishing while there's still data for 'b'
|
||||
a_mock.write_all(b"hello").await.unwrap();
|
||||
a_mock.shutdown().await.unwrap();
|
||||
d_mock.write_all(b"Neon Serverless Postgres").await.unwrap();
|
||||
|
||||
let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap();
|
||||
|
||||
// Assert correct transferred amounts
|
||||
let (a_to_d_count, d_to_a_count) = result;
|
||||
assert_eq!(a_to_d_count, 5); // 'hello' was transferred
|
||||
assert!(d_to_a_count <= 8); // response only partially transferred or not at all
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_early_termination_d_to_a() {
|
||||
let (mut a_mock, mut b_mock) = tokio::io::duplex(32); // Create a mock duplex stream
|
||||
let (mut c_mock, mut d_mock) = tokio::io::duplex(8); // Create a mock duplex stream
|
||||
|
||||
// Simulate 'a' finishing while there's still data for 'b'
|
||||
d_mock.write_all(b"hello").await.unwrap();
|
||||
d_mock.shutdown().await.unwrap();
|
||||
a_mock.write_all(b"Neon Serverless Postgres").await.unwrap();
|
||||
|
||||
let result = copy_bidirectional(&mut b_mock, &mut c_mock).await.unwrap();
|
||||
|
||||
// Assert correct transferred amounts
|
||||
let (a_to_d_count, d_to_a_count) = result;
|
||||
assert_eq!(d_to_a_count, 5); // 'hello' was transferred
|
||||
assert!(a_to_d_count <= 8); // response only partially transferred or not at all
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
use crate::{
|
||||
cancellation,
|
||||
compute::PostgresConnection,
|
||||
console::messages::MetricsAuxInfo,
|
||||
metrics::NUM_BYTES_PROXIED_COUNTER,
|
||||
@@ -45,7 +46,7 @@ pub async fn proxy_pass(
|
||||
|
||||
// Starting from here we only proxy the client's traffic.
|
||||
info!("performing the proxy pass...");
|
||||
let _ = tokio::io::copy_bidirectional(&mut client, &mut compute).await?;
|
||||
let _ = crate::proxy::copy_bidirectional::copy_bidirectional(&mut client, &mut compute).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -57,6 +58,7 @@ pub struct ProxyPassthrough<S> {
|
||||
|
||||
pub req: IntCounterPairGuard,
|
||||
pub conn: IntCounterPairGuard,
|
||||
pub cancel: cancellation::Session,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
|
||||
|
||||
@@ -2,13 +2,19 @@
|
||||
|
||||
mod mitm;
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use super::connect_compute::ConnectMechanism;
|
||||
use super::retry::ShouldRetry;
|
||||
use super::*;
|
||||
use crate::auth::backend::{ComputeUserInfo, MaybeOwned, TestBackend};
|
||||
use crate::auth::backend::{
|
||||
ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, TestBackend,
|
||||
};
|
||||
use crate::config::CertResolver;
|
||||
use crate::console::caches::NodeInfoCache;
|
||||
use crate::console::provider::{CachedAllowedIps, CachedRoleSecret, ConsoleBackend};
|
||||
use crate::console::{self, CachedNodeInfo, NodeInfo};
|
||||
use crate::error::ErrorKind;
|
||||
use crate::proxy::retry::{retry_after, NUM_RETRIES_CONNECT};
|
||||
use crate::{auth, http, sasl, scram};
|
||||
use async_trait::async_trait;
|
||||
@@ -144,7 +150,7 @@ impl TestAuth for Scram {
|
||||
stream: &mut PqStream<Stream<S>>,
|
||||
) -> anyhow::Result<()> {
|
||||
let outcome = auth::AuthFlow::new(stream)
|
||||
.begin(auth::Scram(&self.0))
|
||||
.begin(auth::Scram(&self.0, &mut RequestMonitoring::test()))
|
||||
.await?
|
||||
.authenticate()
|
||||
.await?;
|
||||
@@ -369,12 +375,15 @@ enum ConnectAction {
|
||||
Connect,
|
||||
Retry,
|
||||
Fail,
|
||||
RetryPg,
|
||||
FailPg,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct TestConnectMechanism {
|
||||
counter: Arc<std::sync::Mutex<usize>>,
|
||||
sequence: Vec<ConnectAction>,
|
||||
cache: &'static NodeInfoCache,
|
||||
}
|
||||
|
||||
impl TestConnectMechanism {
|
||||
@@ -393,6 +402,12 @@ impl TestConnectMechanism {
|
||||
Self {
|
||||
counter: Arc::new(std::sync::Mutex::new(0)),
|
||||
sequence,
|
||||
cache: Box::leak(Box::new(NodeInfoCache::new(
|
||||
"test",
|
||||
1,
|
||||
Duration::from_secs(100),
|
||||
false,
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -403,6 +418,13 @@ struct TestConnection;
|
||||
#[derive(Debug)]
|
||||
struct TestConnectError {
|
||||
retryable: bool,
|
||||
kind: crate::error::ErrorKind,
|
||||
}
|
||||
|
||||
impl ReportableError for TestConnectError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
self.kind
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TestConnectError {
|
||||
@@ -436,8 +458,22 @@ impl ConnectMechanism for TestConnectMechanism {
|
||||
*counter += 1;
|
||||
match action {
|
||||
ConnectAction::Connect => Ok(TestConnection),
|
||||
ConnectAction::Retry => Err(TestConnectError { retryable: true }),
|
||||
ConnectAction::Fail => Err(TestConnectError { retryable: false }),
|
||||
ConnectAction::Retry => Err(TestConnectError {
|
||||
retryable: true,
|
||||
kind: ErrorKind::Compute,
|
||||
}),
|
||||
ConnectAction::Fail => Err(TestConnectError {
|
||||
retryable: false,
|
||||
kind: ErrorKind::Compute,
|
||||
}),
|
||||
ConnectAction::FailPg => Err(TestConnectError {
|
||||
retryable: false,
|
||||
kind: ErrorKind::Postgres,
|
||||
}),
|
||||
ConnectAction::RetryPg => Err(TestConnectError {
|
||||
retryable: true,
|
||||
kind: ErrorKind::Postgres,
|
||||
}),
|
||||
x => panic!("expecting action {:?}, connect is called instead", x),
|
||||
}
|
||||
}
|
||||
@@ -451,7 +487,7 @@ impl TestBackend for TestConnectMechanism {
|
||||
let action = self.sequence[*counter];
|
||||
*counter += 1;
|
||||
match action {
|
||||
ConnectAction::Wake => Ok(helper_create_cached_node_info()),
|
||||
ConnectAction::Wake => Ok(helper_create_cached_node_info(self.cache)),
|
||||
ConnectAction::WakeFail => {
|
||||
let err = console::errors::ApiError::Console {
|
||||
status: http::StatusCode::FORBIDDEN,
|
||||
@@ -483,37 +519,41 @@ impl TestBackend for TestConnectMechanism {
|
||||
}
|
||||
}
|
||||
|
||||
fn helper_create_cached_node_info() -> CachedNodeInfo {
|
||||
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
|
||||
let node = NodeInfo {
|
||||
config: compute::ConnCfg::new(),
|
||||
aux: Default::default(),
|
||||
allow_self_signed_compute: false,
|
||||
};
|
||||
CachedNodeInfo::new_uncached(node)
|
||||
let (_, node) = cache.insert("key".into(), node);
|
||||
node
|
||||
}
|
||||
|
||||
fn helper_create_connect_info(
|
||||
mechanism: &TestConnectMechanism,
|
||||
) -> (CachedNodeInfo, auth::BackendType<'static, ComputeUserInfo>) {
|
||||
let cache = helper_create_cached_node_info();
|
||||
) -> auth::BackendType<'static, ComputeCredentials, &()> {
|
||||
let user_info = auth::BackendType::Console(
|
||||
MaybeOwned::Owned(ConsoleBackend::Test(Box::new(mechanism.clone()))),
|
||||
ComputeUserInfo {
|
||||
endpoint: "endpoint".into(),
|
||||
user: "user".into(),
|
||||
options: NeonOptions::parse_options_raw(""),
|
||||
ComputeCredentials {
|
||||
info: ComputeUserInfo {
|
||||
endpoint: "endpoint".into(),
|
||||
user: "user".into(),
|
||||
options: NeonOptions::parse_options_raw(""),
|
||||
},
|
||||
keys: ComputeCredentialKeys::Password("password".into()),
|
||||
},
|
||||
);
|
||||
(cache, user_info)
|
||||
user_info
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_success() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Connect]);
|
||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -521,24 +561,52 @@ async fn connect_to_compute_success() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_retry() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Connect]);
|
||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_retry_pg() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, RetryPg, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_fail_pg() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, FailPg]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
}
|
||||
|
||||
/// Test that we don't retry if the error is not retryable.
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_non_retry_1() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Retry, Wake, Retry, Fail]);
|
||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
@@ -547,11 +615,12 @@ async fn connect_to_compute_non_retry_1() {
|
||||
/// Even for non-retryable errors, we should retry at least once.
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_non_retry_2() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Fail, Wake, Retry, Connect]);
|
||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
||||
let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -560,15 +629,16 @@ async fn connect_to_compute_non_retry_2() {
|
||||
/// Retry for at most `NUM_RETRIES_CONNECT` times.
|
||||
#[tokio::test]
|
||||
async fn connect_to_compute_non_retry_3() {
|
||||
let _ = env_logger::try_init();
|
||||
assert_eq!(NUM_RETRIES_CONNECT, 16);
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![
|
||||
Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry,
|
||||
Retry, Retry, Retry, Retry, /* the 17th time */ Retry,
|
||||
Wake, Retry, Wake, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry, Retry,
|
||||
Retry, Retry, Retry, Retry, Retry, /* the 17th time */ Retry,
|
||||
]);
|
||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
@@ -577,11 +647,12 @@ async fn connect_to_compute_non_retry_3() {
|
||||
/// Should retry wake compute.
|
||||
#[tokio::test]
|
||||
async fn wake_retry() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Retry, WakeRetry, Wake, Connect]);
|
||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
||||
let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||
.await
|
||||
.unwrap();
|
||||
mechanism.verify();
|
||||
@@ -590,11 +661,12 @@ async fn wake_retry() {
|
||||
/// Wake failed with a non-retryable error.
|
||||
#[tokio::test]
|
||||
async fn wake_non_retry() {
|
||||
let _ = env_logger::try_init();
|
||||
use ConnectAction::*;
|
||||
let mut ctx = RequestMonitoring::test();
|
||||
let mechanism = TestConnectMechanism::new(vec![Retry, WakeFail]);
|
||||
let (cache, user_info) = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, cache, &user_info)
|
||||
let mechanism = TestConnectMechanism::new(vec![WakeRetry, WakeFail]);
|
||||
let user_info = helper_create_connect_info(&mechanism);
|
||||
connect_to_compute(&mut ctx, &mechanism, &user_info, false)
|
||||
.await
|
||||
.unwrap_err();
|
||||
mechanism.verify();
|
||||
|
||||
@@ -1,9 +1,4 @@
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::console::{
|
||||
errors::WakeComputeError,
|
||||
provider::{CachedNodeInfo, ConsoleBackend},
|
||||
Api,
|
||||
};
|
||||
use crate::console::{errors::WakeComputeError, provider::CachedNodeInfo};
|
||||
use crate::context::RequestMonitoring;
|
||||
use crate::metrics::{bool_to_str, NUM_WAKEUP_FAILURES};
|
||||
use crate::proxy::retry::retry_after;
|
||||
@@ -11,17 +6,16 @@ use hyper::StatusCode;
|
||||
use std::ops::ControlFlow;
|
||||
use tracing::{error, warn};
|
||||
|
||||
use super::connect_compute::ComputeConnectBackend;
|
||||
use super::retry::ShouldRetry;
|
||||
|
||||
/// wake a compute (or retrieve an existing compute session from cache)
|
||||
pub async fn wake_compute(
|
||||
pub async fn wake_compute<B: ComputeConnectBackend>(
|
||||
num_retries: &mut u32,
|
||||
ctx: &mut RequestMonitoring,
|
||||
api: &ConsoleBackend,
|
||||
info: &ComputeUserInfo,
|
||||
api: &B,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
loop {
|
||||
let wake_res = api.wake_compute(ctx, info).await;
|
||||
let wake_res = api.wake_compute(ctx).await;
|
||||
match handle_try_wake(wake_res, *num_retries) {
|
||||
Err(e) => {
|
||||
error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node");
|
||||
|
||||
@@ -4,4 +4,4 @@ mod limiter;
|
||||
pub use aimd::Aimd;
|
||||
pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig};
|
||||
pub use limiter::Limiter;
|
||||
pub use limiter::{EndpointRateLimiter, RateBucketInfo};
|
||||
pub use limiter::{EndpointRateLimiter, RateBucketInfo, RedisRateLimiter};
|
||||
|
||||
@@ -22,6 +22,44 @@ use super::{
|
||||
RateLimiterConfig,
|
||||
};
|
||||
|
||||
pub struct RedisRateLimiter {
|
||||
data: Vec<RateBucket>,
|
||||
info: &'static [RateBucketInfo],
|
||||
}
|
||||
|
||||
impl RedisRateLimiter {
|
||||
pub fn new(info: &'static [RateBucketInfo]) -> Self {
|
||||
Self {
|
||||
data: vec![
|
||||
RateBucket {
|
||||
start: Instant::now(),
|
||||
count: 0,
|
||||
};
|
||||
info.len()
|
||||
],
|
||||
info,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check that number of connections is below `max_rps` rps.
|
||||
pub fn check(&mut self) -> bool {
|
||||
let now = Instant::now();
|
||||
|
||||
let should_allow_request = self
|
||||
.data
|
||||
.iter_mut()
|
||||
.zip(self.info)
|
||||
.all(|(bucket, info)| bucket.should_allow_request(info, now));
|
||||
|
||||
if should_allow_request {
|
||||
// only increment the bucket counts if the request will actually be accepted
|
||||
self.data.iter_mut().for_each(RateBucket::inc);
|
||||
}
|
||||
|
||||
should_allow_request
|
||||
}
|
||||
}
|
||||
|
||||
// Simple per-endpoint rate limiter.
|
||||
//
|
||||
// Check that number of connections to the endpoint is below `max_rps` rps.
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
pub mod notifications;
|
||||
pub mod publisher;
|
||||
|
||||
@@ -1,38 +1,44 @@
|
||||
use std::{convert::Infallible, sync::Arc};
|
||||
|
||||
use futures::StreamExt;
|
||||
use pq_proto::CancelKeyData;
|
||||
use redis::aio::PubSub;
|
||||
use serde::Deserialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
cache::project_info::ProjectInfoCache,
|
||||
cancellation::{CancelMap, CancellationHandler, NotificationsCancellationHandler},
|
||||
intern::{ProjectIdInt, RoleNameInt},
|
||||
};
|
||||
|
||||
const CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
|
||||
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
|
||||
pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
|
||||
const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
|
||||
const INVALIDATION_LAG: std::time::Duration = std::time::Duration::from_secs(20);
|
||||
|
||||
struct ConsoleRedisClient {
|
||||
struct RedisConsumerClient {
|
||||
client: redis::Client,
|
||||
}
|
||||
|
||||
impl ConsoleRedisClient {
|
||||
impl RedisConsumerClient {
|
||||
pub fn new(url: &str) -> anyhow::Result<Self> {
|
||||
let client = redis::Client::open(url)?;
|
||||
Ok(Self { client })
|
||||
}
|
||||
async fn try_connect(&self) -> anyhow::Result<PubSub> {
|
||||
let mut conn = self.client.get_async_connection().await?.into_pubsub();
|
||||
tracing::info!("subscribing to a channel `{CHANNEL_NAME}`");
|
||||
conn.subscribe(CHANNEL_NAME).await?;
|
||||
tracing::info!("subscribing to a channel `{CPLANE_CHANNEL_NAME}`");
|
||||
conn.subscribe(CPLANE_CHANNEL_NAME).await?;
|
||||
tracing::info!("subscribing to a channel `{PROXY_CHANNEL_NAME}`");
|
||||
conn.subscribe(PROXY_CHANNEL_NAME).await?;
|
||||
Ok(conn)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
#[serde(tag = "topic", content = "data")]
|
||||
enum Notification {
|
||||
pub(crate) enum Notification {
|
||||
#[serde(
|
||||
rename = "/allowed_ips_updated",
|
||||
deserialize_with = "deserialize_json_string"
|
||||
@@ -45,16 +51,25 @@ enum Notification {
|
||||
deserialize_with = "deserialize_json_string"
|
||||
)]
|
||||
PasswordUpdate { password_update: PasswordUpdate },
|
||||
#[serde(rename = "/cancel_session")]
|
||||
Cancel(CancelSession),
|
||||
}
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
struct AllowedIpsUpdate {
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct AllowedIpsUpdate {
|
||||
project_id: ProjectIdInt,
|
||||
}
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
struct PasswordUpdate {
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct PasswordUpdate {
|
||||
project_id: ProjectIdInt,
|
||||
role_name: RoleNameInt,
|
||||
}
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
|
||||
pub(crate) struct CancelSession {
|
||||
pub region_id: Option<String>,
|
||||
pub cancel_key_data: CancelKeyData,
|
||||
pub session_id: Uuid,
|
||||
}
|
||||
|
||||
fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
|
||||
where
|
||||
T: for<'de2> serde::Deserialize<'de2>,
|
||||
@@ -64,6 +79,88 @@ where
|
||||
serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
|
||||
}
|
||||
|
||||
struct MessageHandler<
|
||||
C: ProjectInfoCache + Send + Sync + 'static,
|
||||
H: NotificationsCancellationHandler + Send + Sync + 'static,
|
||||
> {
|
||||
cache: Arc<C>,
|
||||
cancellation_handler: Arc<H>,
|
||||
region_id: String,
|
||||
}
|
||||
|
||||
impl<
|
||||
C: ProjectInfoCache + Send + Sync + 'static,
|
||||
H: NotificationsCancellationHandler + Send + Sync + 'static,
|
||||
> MessageHandler<C, H>
|
||||
{
|
||||
pub fn new(cache: Arc<C>, cancellation_handler: Arc<H>, region_id: String) -> Self {
|
||||
Self {
|
||||
cache,
|
||||
cancellation_handler,
|
||||
region_id,
|
||||
}
|
||||
}
|
||||
pub fn disable_ttl(&self) {
|
||||
self.cache.disable_ttl();
|
||||
}
|
||||
pub fn enable_ttl(&self) {
|
||||
self.cache.enable_ttl();
|
||||
}
|
||||
#[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))]
|
||||
async fn handle_message(&self, msg: redis::Msg) -> anyhow::Result<()> {
|
||||
use Notification::*;
|
||||
let payload: String = msg.get_payload()?;
|
||||
tracing::debug!(?payload, "received a message payload");
|
||||
|
||||
let msg: Notification = match serde_json::from_str(&payload) {
|
||||
Ok(msg) => msg,
|
||||
Err(e) => {
|
||||
tracing::error!("broken message: {e}");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
tracing::debug!(?msg, "received a message");
|
||||
match msg {
|
||||
Cancel(cancel_session) => {
|
||||
tracing::Span::current().record(
|
||||
"session_id",
|
||||
&tracing::field::display(cancel_session.session_id),
|
||||
);
|
||||
if let Some(cancel_region) = cancel_session.region_id {
|
||||
// If the message is not for this region, ignore it.
|
||||
if cancel_region != self.region_id {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
// This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
|
||||
match self
|
||||
.cancellation_handler
|
||||
.cancel_session_no_publish(cancel_session.cancel_key_data)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
tracing::error!("failed to cancel session: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
invalidate_cache(self.cache.clone(), msg.clone());
|
||||
// It might happen that the invalid entry is on the way to be cached.
|
||||
// To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
|
||||
// TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
|
||||
let cache = self.cache.clone();
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(INVALIDATION_LAG).await;
|
||||
invalidate_cache(cache, msg);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
|
||||
use Notification::*;
|
||||
match msg {
|
||||
@@ -74,50 +171,33 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
|
||||
password_update.project_id,
|
||||
password_update.role_name,
|
||||
),
|
||||
Cancel(_) => unreachable!("cancel message should be handled separately"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(cache))]
|
||||
fn handle_message<C>(msg: redis::Msg, cache: Arc<C>) -> anyhow::Result<()>
|
||||
where
|
||||
C: ProjectInfoCache + Send + Sync + 'static,
|
||||
{
|
||||
let payload: String = msg.get_payload()?;
|
||||
tracing::debug!(?payload, "received a message payload");
|
||||
|
||||
let msg: Notification = match serde_json::from_str(&payload) {
|
||||
Ok(msg) => msg,
|
||||
Err(e) => {
|
||||
tracing::error!("broken message: {e}");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
tracing::debug!(?msg, "received a message");
|
||||
invalidate_cache(cache.clone(), msg.clone());
|
||||
// It might happen that the invalid entry is on the way to be cached.
|
||||
// To make sure that the entry is invalidated, let's repeat the invalidation in INVALIDATION_LAG seconds.
|
||||
// TODO: include the version (or the timestamp) in the message and invalidate only if the entry is cached before the message.
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(INVALIDATION_LAG).await;
|
||||
invalidate_cache(cache, msg.clone());
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle console's invalidation messages.
|
||||
#[tracing::instrument(name = "console_notifications", skip_all)]
|
||||
pub async fn task_main<C>(url: String, cache: Arc<C>) -> anyhow::Result<Infallible>
|
||||
pub async fn task_main<C>(
|
||||
url: String,
|
||||
cache: Arc<C>,
|
||||
cancel_map: CancelMap,
|
||||
region_id: String,
|
||||
) -> anyhow::Result<Infallible>
|
||||
where
|
||||
C: ProjectInfoCache + Send + Sync + 'static,
|
||||
{
|
||||
cache.enable_ttl();
|
||||
let handler = MessageHandler::new(
|
||||
cache,
|
||||
Arc::new(CancellationHandler::new(cancel_map, None)),
|
||||
region_id,
|
||||
);
|
||||
|
||||
loop {
|
||||
let redis = ConsoleRedisClient::new(&url)?;
|
||||
let redis = RedisConsumerClient::new(&url)?;
|
||||
let conn = match redis.try_connect().await {
|
||||
Ok(conn) => {
|
||||
cache.disable_ttl();
|
||||
handler.disable_ttl();
|
||||
conn
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -130,7 +210,7 @@ where
|
||||
};
|
||||
let mut stream = conn.into_on_message();
|
||||
while let Some(msg) = stream.next().await {
|
||||
match handle_message(msg, cache.clone()) {
|
||||
match handler.handle_message(msg).await {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
tracing::error!("failed to handle message: {e}, will try to reconnect");
|
||||
@@ -138,7 +218,7 @@ where
|
||||
}
|
||||
}
|
||||
}
|
||||
cache.enable_ttl();
|
||||
handler.enable_ttl();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,6 +278,33 @@ mod tests {
|
||||
}
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
#[test]
|
||||
fn parse_cancel_session() -> anyhow::Result<()> {
|
||||
let cancel_key_data = CancelKeyData {
|
||||
backend_pid: 42,
|
||||
cancel_key: 41,
|
||||
};
|
||||
let uuid = uuid::Uuid::new_v4();
|
||||
let msg = Notification::Cancel(CancelSession {
|
||||
cancel_key_data,
|
||||
region_id: None,
|
||||
session_id: uuid,
|
||||
});
|
||||
let text = serde_json::to_string(&msg)?;
|
||||
let result: Notification = serde_json::from_str(&text)?;
|
||||
assert_eq!(msg, result);
|
||||
|
||||
let msg = Notification::Cancel(CancelSession {
|
||||
cancel_key_data,
|
||||
region_id: Some("region".to_string()),
|
||||
session_id: uuid,
|
||||
});
|
||||
let text = serde_json::to_string(&msg)?;
|
||||
let result: Notification = serde_json::from_str(&text)?;
|
||||
assert_eq!(msg, result,);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
80
proxy/src/redis/publisher.rs
Normal file
80
proxy/src/redis/publisher.rs
Normal file
@@ -0,0 +1,80 @@
|
||||
use pq_proto::CancelKeyData;
|
||||
use redis::AsyncCommands;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter};
|
||||
|
||||
use super::notifications::{CancelSession, Notification, PROXY_CHANNEL_NAME};
|
||||
|
||||
pub struct RedisPublisherClient {
|
||||
client: redis::Client,
|
||||
publisher: Option<redis::aio::Connection>,
|
||||
region_id: String,
|
||||
limiter: RedisRateLimiter,
|
||||
}
|
||||
|
||||
impl RedisPublisherClient {
|
||||
pub fn new(
|
||||
url: &str,
|
||||
region_id: String,
|
||||
info: &'static [RateBucketInfo],
|
||||
) -> anyhow::Result<Self> {
|
||||
let client = redis::Client::open(url)?;
|
||||
Ok(Self {
|
||||
client,
|
||||
publisher: None,
|
||||
region_id,
|
||||
limiter: RedisRateLimiter::new(info),
|
||||
})
|
||||
}
|
||||
pub async fn try_publish(
|
||||
&mut self,
|
||||
cancel_key_data: CancelKeyData,
|
||||
session_id: Uuid,
|
||||
) -> anyhow::Result<()> {
|
||||
if !self.limiter.check() {
|
||||
tracing::info!("Rate limit exceeded. Skipping cancellation message");
|
||||
return Err(anyhow::anyhow!("Rate limit exceeded"));
|
||||
}
|
||||
match self.publish(cancel_key_data, session_id).await {
|
||||
Ok(()) => return Ok(()),
|
||||
Err(e) => {
|
||||
tracing::error!("failed to publish a message: {e}");
|
||||
self.publisher = None;
|
||||
}
|
||||
}
|
||||
tracing::info!("Publisher is disconnected. Reconnectiong...");
|
||||
self.try_connect().await?;
|
||||
self.publish(cancel_key_data, session_id).await
|
||||
}
|
||||
|
||||
async fn publish(
|
||||
&mut self,
|
||||
cancel_key_data: CancelKeyData,
|
||||
session_id: Uuid,
|
||||
) -> anyhow::Result<()> {
|
||||
let conn = self
|
||||
.publisher
|
||||
.as_mut()
|
||||
.ok_or_else(|| anyhow::anyhow!("not connected"))?;
|
||||
let payload = serde_json::to_string(&Notification::Cancel(CancelSession {
|
||||
region_id: Some(self.region_id.clone()),
|
||||
cancel_key_data,
|
||||
session_id,
|
||||
}))?;
|
||||
conn.publish(PROXY_CHANNEL_NAME, payload).await?;
|
||||
Ok(())
|
||||
}
|
||||
pub async fn try_connect(&mut self) -> anyhow::Result<()> {
|
||||
match self.client.get_async_connection().await {
|
||||
Ok(conn) => {
|
||||
self.publisher = Some(conn);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("failed to connect to redis: {e}");
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -24,7 +24,7 @@ use crate::metrics::NUM_CLIENT_CONNECTION_GAUGE;
|
||||
use crate::protocol2::{ProxyProtocolAccept, WithClientIp};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::serverless::backend::PoolingBackend;
|
||||
use crate::{cancellation::CancelMap, config::ProxyConfig};
|
||||
use crate::{cancellation::CancellationHandler, config::ProxyConfig};
|
||||
use futures::StreamExt;
|
||||
use hyper::{
|
||||
server::{
|
||||
@@ -50,6 +50,7 @@ pub async fn task_main(
|
||||
ws_listener: TcpListener,
|
||||
cancellation_token: CancellationToken,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
info!("websocket server has shut down");
|
||||
@@ -115,7 +116,7 @@ pub async fn task_main(
|
||||
let backend = backend.clone();
|
||||
let ws_connections = ws_connections.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
|
||||
let cancellation_handler = cancellation_handler.clone();
|
||||
async move {
|
||||
let peer_addr = match client_addr {
|
||||
Some(addr) => addr,
|
||||
@@ -127,9 +128,9 @@ pub async fn task_main(
|
||||
let backend = backend.clone();
|
||||
let ws_connections = ws_connections.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
let cancellation_handler = cancellation_handler.clone();
|
||||
|
||||
async move {
|
||||
let cancel_map = Arc::new(CancelMap::default());
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
|
||||
request_handler(
|
||||
@@ -137,7 +138,7 @@ pub async fn task_main(
|
||||
config,
|
||||
backend,
|
||||
ws_connections,
|
||||
cancel_map,
|
||||
cancellation_handler,
|
||||
session_id,
|
||||
peer_addr.ip(),
|
||||
endpoint_rate_limiter,
|
||||
@@ -205,7 +206,7 @@ async fn request_handler(
|
||||
config: &'static ProxyConfig,
|
||||
backend: Arc<PoolingBackend>,
|
||||
ws_connections: TaskTracker,
|
||||
cancel_map: Arc<CancelMap>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
session_id: uuid::Uuid,
|
||||
peer_addr: IpAddr,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
@@ -232,7 +233,7 @@ async fn request_handler(
|
||||
config,
|
||||
ctx,
|
||||
websocket,
|
||||
cancel_map,
|
||||
cancellation_handler,
|
||||
host,
|
||||
endpoint_rate_limiter,
|
||||
)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tracing::info;
|
||||
use tracing::{field::display, info};
|
||||
|
||||
use crate::{
|
||||
auth::{backend::ComputeCredentialKeys, check_peer_addr_is_in_list, AuthError},
|
||||
auth::{backend::ComputeCredentials, check_peer_addr_is_in_list, AuthError},
|
||||
compute,
|
||||
config::ProxyConfig,
|
||||
console::{
|
||||
@@ -15,7 +15,7 @@ use crate::{
|
||||
proxy::connect_compute::ConnectMechanism,
|
||||
};
|
||||
|
||||
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool, APP_NAME};
|
||||
use super::conn_pool::{poll_client, Client, ConnInfo, GlobalConnPool};
|
||||
|
||||
pub struct PoolingBackend {
|
||||
pub pool: Arc<GlobalConnPool<tokio_postgres::Client>>,
|
||||
@@ -27,7 +27,7 @@ impl PoolingBackend {
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
conn_info: &ConnInfo,
|
||||
) -> Result<ComputeCredentialKeys, AuthError> {
|
||||
) -> Result<ComputeCredentials, AuthError> {
|
||||
let user_info = conn_info.user_info.clone();
|
||||
let backend = self.config.auth_backend.as_ref().map(|_| user_info.clone());
|
||||
let (allowed_ips, maybe_secret) = backend.get_allowed_ips_and_secret(ctx).await?;
|
||||
@@ -49,13 +49,17 @@ impl PoolingBackend {
|
||||
};
|
||||
let auth_outcome =
|
||||
crate::auth::validate_password_and_exchange(&conn_info.password, secret)?;
|
||||
match auth_outcome {
|
||||
let res = match auth_outcome {
|
||||
crate::sasl::Outcome::Success(key) => Ok(key),
|
||||
crate::sasl::Outcome::Failure(reason) => {
|
||||
info!("auth backend failed with an error: {reason}");
|
||||
Err(AuthError::auth_failed(&*conn_info.user_info.user))
|
||||
}
|
||||
}
|
||||
};
|
||||
res.map(|key| ComputeCredentials {
|
||||
info: user_info,
|
||||
keys: key,
|
||||
})
|
||||
}
|
||||
|
||||
// Wake up the destination if needed. Code here is a bit involved because
|
||||
@@ -66,7 +70,7 @@ impl PoolingBackend {
|
||||
&self,
|
||||
ctx: &mut RequestMonitoring,
|
||||
conn_info: ConnInfo,
|
||||
keys: ComputeCredentialKeys,
|
||||
keys: ComputeCredentials,
|
||||
force_new: bool,
|
||||
) -> Result<Client<tokio_postgres::Client>, HttpConnError> {
|
||||
let maybe_client = if !force_new {
|
||||
@@ -81,27 +85,9 @@ impl PoolingBackend {
|
||||
return Ok(client);
|
||||
}
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
||||
ctx.set_application(Some(APP_NAME));
|
||||
let backend = self
|
||||
.config
|
||||
.auth_backend
|
||||
.as_ref()
|
||||
.map(|_| conn_info.user_info.clone());
|
||||
|
||||
let mut node_info = backend
|
||||
.wake_compute(ctx)
|
||||
.await?
|
||||
.ok_or(HttpConnError::NoComputeInfo)?;
|
||||
|
||||
match keys {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
ComputeCredentialKeys::Password(password) => node_info.config.password(password),
|
||||
ComputeCredentialKeys::AuthKeys(auth_keys) => node_info.config.auth_keys(auth_keys),
|
||||
};
|
||||
|
||||
ctx.set_project(node_info.aux.clone());
|
||||
|
||||
let backend = self.config.auth_backend.as_ref().map(|_| keys);
|
||||
crate::proxy::connect_compute::connect_to_compute(
|
||||
ctx,
|
||||
&TokioMechanism {
|
||||
@@ -109,8 +95,8 @@ impl PoolingBackend {
|
||||
conn_info,
|
||||
pool: self.pool.clone(),
|
||||
},
|
||||
node_info,
|
||||
&backend,
|
||||
false, // do not allow self signed compute for http flow
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -129,8 +115,6 @@ pub enum HttpConnError {
|
||||
AuthError(#[from] AuthError),
|
||||
#[error("wake_compute returned error")]
|
||||
WakeCompute(#[from] WakeComputeError),
|
||||
#[error("wake_compute returned nothing")]
|
||||
NoComputeInfo,
|
||||
}
|
||||
|
||||
struct TokioMechanism {
|
||||
|
||||
@@ -4,7 +4,6 @@ use metrics::IntCounterPairGuard;
|
||||
use parking_lot::RwLock;
|
||||
use rand::Rng;
|
||||
use smallvec::SmallVec;
|
||||
use smol_str::SmolStr;
|
||||
use std::{collections::HashMap, pin::pin, sync::Arc, sync::Weak, time::Duration};
|
||||
use std::{
|
||||
fmt,
|
||||
@@ -31,8 +30,6 @@ use tracing::{info, info_span, Instrument};
|
||||
|
||||
use super::backend::HttpConnError;
|
||||
|
||||
pub const APP_NAME: SmolStr = SmolStr::new_inline("/sql_over_http");
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConnInfo {
|
||||
pub user_info: ComputeUserInfo,
|
||||
@@ -379,12 +376,13 @@ impl<C: ClientInnerExt> GlobalConnPool<C> {
|
||||
info!("pool: cached connection '{conn_info}' is closed, opening a new one");
|
||||
return Ok(None);
|
||||
} else {
|
||||
info!("pool: reusing connection '{conn_info}'");
|
||||
client.session.send(ctx.session_id)?;
|
||||
tracing::Span::current().record("conn_id", tracing::field::display(client.conn_id));
|
||||
tracing::Span::current().record(
|
||||
"pid",
|
||||
&tracing::field::display(client.inner.get_process_id()),
|
||||
);
|
||||
info!("pool: reusing connection '{conn_info}'");
|
||||
client.session.send(ctx.session_id)?;
|
||||
ctx.latency_timer.pool_hit();
|
||||
ctx.latency_timer.success();
|
||||
return Ok(Some(Client::new(client, conn_info.clone(), endpoint_pool)));
|
||||
@@ -577,7 +575,6 @@ pub struct Client<C: ClientInnerExt> {
|
||||
}
|
||||
|
||||
pub struct Discard<'a, C: ClientInnerExt> {
|
||||
conn_id: uuid::Uuid,
|
||||
conn_info: &'a ConnInfo,
|
||||
pool: &'a mut Weak<RwLock<EndpointConnPool<C>>>,
|
||||
}
|
||||
@@ -603,14 +600,7 @@ impl<C: ClientInnerExt> Client<C> {
|
||||
span: _,
|
||||
} = self;
|
||||
let inner = inner.as_mut().expect("client inner should not be removed");
|
||||
(
|
||||
&mut inner.inner,
|
||||
Discard {
|
||||
pool,
|
||||
conn_info,
|
||||
conn_id: inner.conn_id,
|
||||
},
|
||||
)
|
||||
(&mut inner.inner, Discard { pool, conn_info })
|
||||
}
|
||||
|
||||
pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
|
||||
@@ -625,13 +615,13 @@ impl<C: ClientInnerExt> Discard<'_, C> {
|
||||
pub fn check_idle(&mut self, status: ReadyForQueryStatus) {
|
||||
let conn_info = &self.conn_info;
|
||||
if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 {
|
||||
info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is not idle")
|
||||
info!("pool: throwing away connection '{conn_info}' because connection is not idle")
|
||||
}
|
||||
}
|
||||
pub fn discard(&mut self) {
|
||||
let conn_info = &self.conn_info;
|
||||
if std::mem::take(self.pool).strong_count() > 0 {
|
||||
info!(conn_id = %self.conn_id, "pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
|
||||
info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,6 +36,8 @@ use crate::error::ReportableError;
|
||||
use crate::metrics::HTTP_CONTENT_LENGTH;
|
||||
use crate::metrics::NUM_CONNECTION_REQUESTS_GAUGE;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::serverless::backend::HttpConnError;
|
||||
use crate::DbName;
|
||||
use crate::RoleName;
|
||||
|
||||
use super::backend::PoolingBackend;
|
||||
@@ -117,6 +119,9 @@ fn get_conn_info(
|
||||
headers: &HeaderMap,
|
||||
tls: &TlsConfig,
|
||||
) -> Result<ConnInfo, ConnInfoError> {
|
||||
// HTTP only uses cleartext (for now and likely always)
|
||||
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
|
||||
|
||||
let connection_string = headers
|
||||
.get("Neon-Connection-String")
|
||||
.ok_or(ConnInfoError::InvalidHeader("Neon-Connection-String"))?
|
||||
@@ -134,7 +139,8 @@ fn get_conn_info(
|
||||
.path_segments()
|
||||
.ok_or(ConnInfoError::MissingDbName)?;
|
||||
|
||||
let dbname = url_path.next().ok_or(ConnInfoError::InvalidDbName)?;
|
||||
let dbname: DbName = url_path.next().ok_or(ConnInfoError::InvalidDbName)?.into();
|
||||
ctx.set_dbname(dbname.clone());
|
||||
|
||||
let username = RoleName::from(urlencoding::decode(connection_url.username())?);
|
||||
if username.is_empty() {
|
||||
@@ -174,7 +180,7 @@ fn get_conn_info(
|
||||
|
||||
Ok(ConnInfo {
|
||||
user_info,
|
||||
dbname: dbname.into(),
|
||||
dbname,
|
||||
password: match password {
|
||||
std::borrow::Cow::Borrowed(b) => b.into(),
|
||||
std::borrow::Cow::Owned(b) => b.into(),
|
||||
@@ -300,7 +306,14 @@ pub async fn handle(
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
#[instrument(name = "sql-over-http", fields(pid = tracing::field::Empty), skip_all)]
|
||||
#[instrument(
|
||||
name = "sql-over-http",
|
||||
skip_all,
|
||||
fields(
|
||||
pid = tracing::field::Empty,
|
||||
conn_id = tracing::field::Empty
|
||||
)
|
||||
)]
|
||||
async fn handle_inner(
|
||||
config: &'static ProxyConfig,
|
||||
ctx: &mut RequestMonitoring,
|
||||
@@ -354,12 +367,10 @@ async fn handle_inner(
|
||||
let txn_read_only = headers.get(&TXN_READ_ONLY) == Some(&HEADER_VALUE_TRUE);
|
||||
let txn_deferrable = headers.get(&TXN_DEFERRABLE) == Some(&HEADER_VALUE_TRUE);
|
||||
|
||||
let paused = ctx.latency_timer.pause();
|
||||
let request_content_length = match request.body().size_hint().upper() {
|
||||
Some(v) => v,
|
||||
None => MAX_REQUEST_SIZE + 1,
|
||||
};
|
||||
drop(paused);
|
||||
info!(request_content_length, "request size in bytes");
|
||||
HTTP_CONTENT_LENGTH.observe(request_content_length as f64);
|
||||
|
||||
@@ -375,15 +386,20 @@ async fn handle_inner(
|
||||
let body = hyper::body::to_bytes(request.into_body())
|
||||
.await
|
||||
.map_err(anyhow::Error::from)?;
|
||||
info!(length = body.len(), "request payload read");
|
||||
let payload: Payload = serde_json::from_slice(&body)?;
|
||||
Ok::<Payload, anyhow::Error>(payload) // Adjust error type accordingly
|
||||
};
|
||||
|
||||
let authenticate_and_connect = async {
|
||||
let keys = backend.authenticate(ctx, &conn_info).await?;
|
||||
backend
|
||||
let client = backend
|
||||
.connect_to_compute(ctx, conn_info, keys, !allow_pool)
|
||||
.await
|
||||
.await?;
|
||||
// not strictly necessary to mark success here,
|
||||
// but it's just insurance for if we forget it somewhere else
|
||||
ctx.latency_timer.success();
|
||||
Ok::<_, HttpConnError>(client)
|
||||
};
|
||||
|
||||
// Run both operations in parallel
|
||||
@@ -415,6 +431,7 @@ async fn handle_inner(
|
||||
results
|
||||
}
|
||||
Payload::Batch(statements) => {
|
||||
info!("starting transaction");
|
||||
let (inner, mut discard) = client.inner();
|
||||
let mut builder = inner.build_transaction();
|
||||
if let Some(isolation_level) = txn_isolation_level {
|
||||
@@ -444,6 +461,7 @@ async fn handle_inner(
|
||||
.await
|
||||
{
|
||||
Ok(results) => {
|
||||
info!("commit");
|
||||
let status = transaction.commit().await.map_err(|e| {
|
||||
// if we cannot commit - for now don't return connection to pool
|
||||
// TODO: get a query status from the error
|
||||
@@ -454,6 +472,7 @@ async fn handle_inner(
|
||||
results
|
||||
}
|
||||
Err(err) => {
|
||||
info!("rollback");
|
||||
let status = transaction.rollback().await.map_err(|e| {
|
||||
// if we cannot rollback - for now don't return connection to pool
|
||||
// TODO: get a query status from the error
|
||||
@@ -528,8 +547,10 @@ async fn query_to_json<T: GenericClient>(
|
||||
raw_output: bool,
|
||||
default_array_mode: bool,
|
||||
) -> anyhow::Result<(ReadyForQueryStatus, Value)> {
|
||||
info!("executing query");
|
||||
let query_params = data.params;
|
||||
let row_stream = client.query_raw_txt(&data.query, query_params).await?;
|
||||
info!("finished executing query");
|
||||
|
||||
// Manually drain the stream into a vector to leave row_stream hanging
|
||||
// around to get a command tag. Also check that the response is not too
|
||||
@@ -564,6 +585,13 @@ async fn query_to_json<T: GenericClient>(
|
||||
}
|
||||
.and_then(|s| s.parse::<i64>().ok());
|
||||
|
||||
info!(
|
||||
rows = rows.len(),
|
||||
?ready,
|
||||
command_tag,
|
||||
"finished reading rows"
|
||||
);
|
||||
|
||||
let mut fields = vec![];
|
||||
let mut columns = vec![];
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::{
|
||||
cancellation::CancelMap,
|
||||
cancellation::CancellationHandler,
|
||||
config::ProxyConfig,
|
||||
context::RequestMonitoring,
|
||||
error::{io_error, ReportableError},
|
||||
@@ -133,7 +133,7 @@ pub async fn serve_websocket(
|
||||
config: &'static ProxyConfig,
|
||||
mut ctx: RequestMonitoring,
|
||||
websocket: HyperWebsocket,
|
||||
cancel_map: Arc<CancelMap>,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
hostname: Option<String>,
|
||||
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
|
||||
) -> anyhow::Result<()> {
|
||||
@@ -141,7 +141,7 @@ pub async fn serve_websocket(
|
||||
let res = handle_client(
|
||||
config,
|
||||
&mut ctx,
|
||||
cancel_map,
|
||||
cancellation_handler,
|
||||
WebSocketRw::new(websocket),
|
||||
ClientMode::Websockets { hostname },
|
||||
endpoint_rate_limiter,
|
||||
|
||||
@@ -61,3 +61,10 @@ tokio-stream.workspace = true
|
||||
utils.workspace = true
|
||||
|
||||
workspace_hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
walproposer.workspace = true
|
||||
rand.workspace = true
|
||||
desim.workspace = true
|
||||
tracing.workspace = true
|
||||
tracing-subscriber = { workspace = true, features = ["json"] }
|
||||
|
||||
155
safekeeper/tests/misc_test.rs
Normal file
155
safekeeper/tests/misc_test.rs
Normal file
@@ -0,0 +1,155 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use tracing::{info, warn};
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
use crate::walproposer_sim::{
|
||||
log::{init_logger, init_tracing_logger},
|
||||
simulation::{generate_network_opts, generate_schedule, Schedule, TestAction, TestConfig},
|
||||
};
|
||||
|
||||
pub mod walproposer_sim;
|
||||
|
||||
// Test that simulation supports restarting (crashing) safekeepers.
|
||||
#[test]
|
||||
fn crash_safekeeper() {
|
||||
let clock = init_logger();
|
||||
let config = TestConfig::new(Some(clock));
|
||||
let test = config.start(1337);
|
||||
|
||||
let lsn = test.sync_safekeepers().unwrap();
|
||||
assert_eq!(lsn, Lsn(0));
|
||||
info!("Sucessfully synced empty safekeepers at 0/0");
|
||||
|
||||
let mut wp = test.launch_walproposer(lsn);
|
||||
|
||||
// Write some WAL and crash safekeeper 0 without waiting for replication.
|
||||
test.poll_for_duration(30);
|
||||
wp.write_tx(3);
|
||||
test.servers[0].restart();
|
||||
|
||||
// Wait some time, so that walproposer can reconnect.
|
||||
test.poll_for_duration(2000);
|
||||
}
|
||||
|
||||
// Test that walproposer can be crashed (stopped).
|
||||
#[test]
|
||||
fn test_simple_restart() {
|
||||
let clock = init_logger();
|
||||
let config = TestConfig::new(Some(clock));
|
||||
let test = config.start(1337);
|
||||
|
||||
let lsn = test.sync_safekeepers().unwrap();
|
||||
assert_eq!(lsn, Lsn(0));
|
||||
info!("Sucessfully synced empty safekeepers at 0/0");
|
||||
|
||||
let mut wp = test.launch_walproposer(lsn);
|
||||
|
||||
test.poll_for_duration(30);
|
||||
wp.write_tx(3);
|
||||
test.poll_for_duration(100);
|
||||
|
||||
wp.stop();
|
||||
drop(wp);
|
||||
|
||||
let lsn = test.sync_safekeepers().unwrap();
|
||||
info!("Sucessfully synced safekeepers at {}", lsn);
|
||||
}
|
||||
|
||||
// Test runnning a simple schedule, restarting everything a several times.
|
||||
#[test]
|
||||
fn test_simple_schedule() -> anyhow::Result<()> {
|
||||
let clock = init_logger();
|
||||
let mut config = TestConfig::new(Some(clock));
|
||||
config.network.keepalive_timeout = Some(100);
|
||||
let test = config.start(1337);
|
||||
|
||||
let schedule: Schedule = vec![
|
||||
(0, TestAction::RestartWalProposer),
|
||||
(50, TestAction::WriteTx(5)),
|
||||
(100, TestAction::RestartSafekeeper(0)),
|
||||
(100, TestAction::WriteTx(5)),
|
||||
(110, TestAction::RestartSafekeeper(1)),
|
||||
(110, TestAction::WriteTx(5)),
|
||||
(120, TestAction::RestartSafekeeper(2)),
|
||||
(120, TestAction::WriteTx(5)),
|
||||
(201, TestAction::RestartWalProposer),
|
||||
(251, TestAction::RestartSafekeeper(0)),
|
||||
(251, TestAction::RestartSafekeeper(1)),
|
||||
(251, TestAction::RestartSafekeeper(2)),
|
||||
(251, TestAction::WriteTx(5)),
|
||||
(255, TestAction::WriteTx(5)),
|
||||
(1000, TestAction::WriteTx(5)),
|
||||
];
|
||||
|
||||
test.run_schedule(&schedule)?;
|
||||
info!("Test finished, stopping all threads");
|
||||
test.world.deallocate();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Test that simulation can process 10^4 transactions.
|
||||
#[test]
|
||||
fn test_many_tx() -> anyhow::Result<()> {
|
||||
let clock = init_logger();
|
||||
let config = TestConfig::new(Some(clock));
|
||||
let test = config.start(1337);
|
||||
|
||||
let mut schedule: Schedule = vec![];
|
||||
for i in 0..100 {
|
||||
schedule.push((i * 10, TestAction::WriteTx(100)));
|
||||
}
|
||||
|
||||
test.run_schedule(&schedule)?;
|
||||
info!("Test finished, stopping all threads");
|
||||
test.world.stop_all();
|
||||
|
||||
let events = test.world.take_events();
|
||||
info!("Events: {:?}", events);
|
||||
let last_commit_lsn = events
|
||||
.iter()
|
||||
.filter_map(|event| {
|
||||
if event.data.starts_with("commit_lsn;") {
|
||||
let lsn: u64 = event.data.split(';').nth(1).unwrap().parse().unwrap();
|
||||
return Some(lsn);
|
||||
}
|
||||
None
|
||||
})
|
||||
.last()
|
||||
.unwrap();
|
||||
|
||||
let initdb_lsn = 21623024;
|
||||
let diff = last_commit_lsn - initdb_lsn;
|
||||
info!("Last commit lsn: {}, diff: {}", last_commit_lsn, diff);
|
||||
// each tx is at least 8 bytes, it's written a 100 times for in a loop for 100 times
|
||||
assert!(diff > 100 * 100 * 8);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Checks that we don't have nasty circular dependencies, preventing Arc from deallocating.
|
||||
// This test doesn't really assert anything, you need to run it manually to check if there
|
||||
// is any issue.
|
||||
#[test]
|
||||
fn test_res_dealloc() -> anyhow::Result<()> {
|
||||
let clock = init_tracing_logger(true);
|
||||
let mut config = TestConfig::new(Some(clock));
|
||||
|
||||
let seed = 123456;
|
||||
config.network = generate_network_opts(seed);
|
||||
let test = config.start(seed);
|
||||
warn!("Running test with seed {}", seed);
|
||||
|
||||
let schedule = generate_schedule(seed);
|
||||
info!("schedule: {:?}", schedule);
|
||||
test.run_schedule(&schedule).unwrap();
|
||||
test.world.stop_all();
|
||||
|
||||
let world = test.world.clone();
|
||||
drop(test);
|
||||
info!("world strong count: {}", Arc::strong_count(&world));
|
||||
world.deallocate();
|
||||
info!("world strong count: {}", Arc::strong_count(&world));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
56
safekeeper/tests/random_test.rs
Normal file
56
safekeeper/tests/random_test.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
use rand::Rng;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::walproposer_sim::{
|
||||
log::{init_logger, init_tracing_logger},
|
||||
simulation::{generate_network_opts, generate_schedule, TestConfig},
|
||||
simulation_logs::validate_events,
|
||||
};
|
||||
|
||||
pub mod walproposer_sim;
|
||||
|
||||
// Generates 2000 random seeds and runs a schedule for each of them.
|
||||
// If you seed this test fail, please report the last seed to the
|
||||
// @safekeeper team.
|
||||
#[test]
|
||||
fn test_random_schedules() -> anyhow::Result<()> {
|
||||
let clock = init_logger();
|
||||
let mut config = TestConfig::new(Some(clock));
|
||||
|
||||
for _ in 0..2000 {
|
||||
let seed: u64 = rand::thread_rng().gen();
|
||||
config.network = generate_network_opts(seed);
|
||||
|
||||
let test = config.start(seed);
|
||||
warn!("Running test with seed {}", seed);
|
||||
|
||||
let schedule = generate_schedule(seed);
|
||||
test.run_schedule(&schedule).unwrap();
|
||||
validate_events(test.world.take_events());
|
||||
test.world.deallocate();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// After you found a seed that fails, you can insert this seed here
|
||||
// and run the test to see the full debug output.
|
||||
#[test]
|
||||
fn test_one_schedule() -> anyhow::Result<()> {
|
||||
let clock = init_tracing_logger(true);
|
||||
let mut config = TestConfig::new(Some(clock));
|
||||
|
||||
let seed = 11047466935058776390;
|
||||
config.network = generate_network_opts(seed);
|
||||
info!("network: {:?}", config.network);
|
||||
let test = config.start(seed);
|
||||
warn!("Running test with seed {}", seed);
|
||||
|
||||
let schedule = generate_schedule(seed);
|
||||
info!("schedule: {:?}", schedule);
|
||||
test.run_schedule(&schedule).unwrap();
|
||||
validate_events(test.world.take_events());
|
||||
test.world.deallocate();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
45
safekeeper/tests/simple_test.rs
Normal file
45
safekeeper/tests/simple_test.rs
Normal file
@@ -0,0 +1,45 @@
|
||||
use tracing::info;
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
use crate::walproposer_sim::{log::init_logger, simulation::TestConfig};
|
||||
|
||||
pub mod walproposer_sim;
|
||||
|
||||
// Check that first start of sync_safekeepers() returns 0/0 on empty safekeepers.
|
||||
#[test]
|
||||
fn sync_empty_safekeepers() {
|
||||
let clock = init_logger();
|
||||
let config = TestConfig::new(Some(clock));
|
||||
let test = config.start(1337);
|
||||
|
||||
let lsn = test.sync_safekeepers().unwrap();
|
||||
assert_eq!(lsn, Lsn(0));
|
||||
info!("Sucessfully synced empty safekeepers at 0/0");
|
||||
|
||||
let lsn = test.sync_safekeepers().unwrap();
|
||||
assert_eq!(lsn, Lsn(0));
|
||||
info!("Sucessfully synced (again) empty safekeepers at 0/0");
|
||||
}
|
||||
|
||||
// Check that there are no panics when we are writing and streaming WAL to safekeepers.
|
||||
#[test]
|
||||
fn run_walproposer_generate_wal() {
|
||||
let clock = init_logger();
|
||||
let config = TestConfig::new(Some(clock));
|
||||
let test = config.start(1337);
|
||||
|
||||
let lsn = test.sync_safekeepers().unwrap();
|
||||
assert_eq!(lsn, Lsn(0));
|
||||
info!("Sucessfully synced empty safekeepers at 0/0");
|
||||
|
||||
let mut wp = test.launch_walproposer(lsn);
|
||||
|
||||
// wait for walproposer to start
|
||||
test.poll_for_duration(30);
|
||||
|
||||
// just write some WAL
|
||||
for _ in 0..100 {
|
||||
wp.write_tx(1);
|
||||
test.poll_for_duration(5);
|
||||
}
|
||||
}
|
||||
57
safekeeper/tests/walproposer_sim/block_storage.rs
Normal file
57
safekeeper/tests/walproposer_sim/block_storage.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
const BLOCK_SIZE: usize = 8192;
|
||||
|
||||
/// A simple in-memory implementation of a block storage. Can be used to implement external
|
||||
/// storage in tests.
|
||||
pub struct BlockStorage {
|
||||
blocks: HashMap<u64, [u8; BLOCK_SIZE]>,
|
||||
}
|
||||
|
||||
impl Default for BlockStorage {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl BlockStorage {
|
||||
pub fn new() -> Self {
|
||||
BlockStorage {
|
||||
blocks: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read(&self, pos: u64, buf: &mut [u8]) {
|
||||
let mut buf_offset = 0;
|
||||
let mut storage_pos = pos;
|
||||
while buf_offset < buf.len() {
|
||||
let block_id = storage_pos / BLOCK_SIZE as u64;
|
||||
let block = self.blocks.get(&block_id).unwrap_or(&[0; BLOCK_SIZE]);
|
||||
let block_offset = storage_pos % BLOCK_SIZE as u64;
|
||||
let block_len = BLOCK_SIZE as u64 - block_offset;
|
||||
let buf_len = buf.len() - buf_offset;
|
||||
let copy_len = std::cmp::min(block_len as usize, buf_len);
|
||||
buf[buf_offset..buf_offset + copy_len]
|
||||
.copy_from_slice(&block[block_offset as usize..block_offset as usize + copy_len]);
|
||||
buf_offset += copy_len;
|
||||
storage_pos += copy_len as u64;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write(&mut self, pos: u64, buf: &[u8]) {
|
||||
let mut buf_offset = 0;
|
||||
let mut storage_pos = pos;
|
||||
while buf_offset < buf.len() {
|
||||
let block_id = storage_pos / BLOCK_SIZE as u64;
|
||||
let block = self.blocks.entry(block_id).or_insert([0; BLOCK_SIZE]);
|
||||
let block_offset = storage_pos % BLOCK_SIZE as u64;
|
||||
let block_len = BLOCK_SIZE as u64 - block_offset;
|
||||
let buf_len = buf.len() - buf_offset;
|
||||
let copy_len = std::cmp::min(block_len as usize, buf_len);
|
||||
block[block_offset as usize..block_offset as usize + copy_len]
|
||||
.copy_from_slice(&buf[buf_offset..buf_offset + copy_len]);
|
||||
buf_offset += copy_len;
|
||||
storage_pos += copy_len as u64
|
||||
}
|
||||
}
|
||||
}
|
||||
77
safekeeper/tests/walproposer_sim/log.rs
Normal file
77
safekeeper/tests/walproposer_sim/log.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
use std::{fmt, sync::Arc};
|
||||
|
||||
use desim::time::Timing;
|
||||
use once_cell::sync::OnceCell;
|
||||
use parking_lot::Mutex;
|
||||
use tracing_subscriber::fmt::{format::Writer, time::FormatTime};
|
||||
|
||||
/// SimClock can be plugged into tracing logger to print simulation time.
|
||||
#[derive(Clone)]
|
||||
pub struct SimClock {
|
||||
clock_ptr: Arc<Mutex<Option<Arc<Timing>>>>,
|
||||
}
|
||||
|
||||
impl Default for SimClock {
|
||||
fn default() -> Self {
|
||||
SimClock {
|
||||
clock_ptr: Arc::new(Mutex::new(None)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SimClock {
|
||||
pub fn set_clock(&self, clock: Arc<Timing>) {
|
||||
*self.clock_ptr.lock() = Some(clock);
|
||||
}
|
||||
}
|
||||
|
||||
impl FormatTime for SimClock {
|
||||
fn format_time(&self, w: &mut Writer<'_>) -> fmt::Result {
|
||||
let clock = self.clock_ptr.lock();
|
||||
|
||||
if let Some(clock) = clock.as_ref() {
|
||||
let now = clock.now();
|
||||
write!(w, "[{}]", now)
|
||||
} else {
|
||||
write!(w, "[?]")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static LOGGING_DONE: OnceCell<SimClock> = OnceCell::new();
|
||||
|
||||
/// Returns ptr to clocks attached to tracing logger to update them when the
|
||||
/// world is (re)created.
|
||||
pub fn init_tracing_logger(debug_enabled: bool) -> SimClock {
|
||||
LOGGING_DONE
|
||||
.get_or_init(|| {
|
||||
let clock = SimClock::default();
|
||||
let base_logger = tracing_subscriber::fmt()
|
||||
.with_target(false)
|
||||
// prefix log lines with simulated time timestamp
|
||||
.with_timer(clock.clone())
|
||||
// .with_ansi(true) TODO
|
||||
.with_max_level(match debug_enabled {
|
||||
true => tracing::Level::DEBUG,
|
||||
false => tracing::Level::WARN,
|
||||
})
|
||||
.with_writer(std::io::stdout);
|
||||
base_logger.init();
|
||||
|
||||
// logging::replace_panic_hook_with_tracing_panic_hook().forget();
|
||||
|
||||
if !debug_enabled {
|
||||
std::panic::set_hook(Box::new(|_| {}));
|
||||
}
|
||||
|
||||
clock
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
|
||||
pub fn init_logger() -> SimClock {
|
||||
// RUST_TRACEBACK envvar controls whether we print all logs or only warnings.
|
||||
let debug_enabled = std::env::var("RUST_TRACEBACK").is_ok();
|
||||
|
||||
init_tracing_logger(debug_enabled)
|
||||
}
|
||||
8
safekeeper/tests/walproposer_sim/mod.rs
Normal file
8
safekeeper/tests/walproposer_sim/mod.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
pub mod block_storage;
|
||||
pub mod log;
|
||||
pub mod safekeeper;
|
||||
pub mod safekeeper_disk;
|
||||
pub mod simulation;
|
||||
pub mod simulation_logs;
|
||||
pub mod walproposer_api;
|
||||
pub mod walproposer_disk;
|
||||
410
safekeeper/tests/walproposer_sim/safekeeper.rs
Normal file
410
safekeeper/tests/walproposer_sim/safekeeper.rs
Normal file
@@ -0,0 +1,410 @@
|
||||
//! Safekeeper communication endpoint to WAL proposer (compute node).
|
||||
//! Gets messages from the network, passes them down to consensus module and
|
||||
//! sends replies back.
|
||||
|
||||
use std::{collections::HashMap, sync::Arc, time::Duration};
|
||||
|
||||
use anyhow::{bail, Result};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use camino::Utf8PathBuf;
|
||||
use desim::{
|
||||
executor::{self, PollSome},
|
||||
network::TCP,
|
||||
node_os::NodeOs,
|
||||
proto::{AnyMessage, NetEvent, NodeEvent},
|
||||
};
|
||||
use hyper::Uri;
|
||||
use safekeeper::{
|
||||
safekeeper::{ProposerAcceptorMessage, SafeKeeper, ServerInfo, UNKNOWN_SERVER_VERSION},
|
||||
state::TimelinePersistentState,
|
||||
timeline::TimelineError,
|
||||
wal_storage::Storage,
|
||||
SafeKeeperConf,
|
||||
};
|
||||
use tracing::{debug, info_span};
|
||||
use utils::{
|
||||
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
|
||||
lsn::Lsn,
|
||||
};
|
||||
|
||||
use super::safekeeper_disk::{DiskStateStorage, DiskWALStorage, SafekeeperDisk, TimelineDisk};
|
||||
|
||||
struct SharedState {
|
||||
sk: SafeKeeper<DiskStateStorage, DiskWALStorage>,
|
||||
disk: Arc<TimelineDisk>,
|
||||
}
|
||||
|
||||
struct GlobalMap {
|
||||
timelines: HashMap<TenantTimelineId, SharedState>,
|
||||
conf: SafeKeeperConf,
|
||||
disk: Arc<SafekeeperDisk>,
|
||||
}
|
||||
|
||||
impl GlobalMap {
|
||||
/// Restores global state from disk.
|
||||
fn new(disk: Arc<SafekeeperDisk>, conf: SafeKeeperConf) -> Result<Self> {
|
||||
let mut timelines = HashMap::new();
|
||||
|
||||
for (&ttid, disk) in disk.timelines.lock().iter() {
|
||||
debug!("loading timeline {}", ttid);
|
||||
let state = disk.state.lock().clone();
|
||||
|
||||
if state.server.wal_seg_size == 0 {
|
||||
bail!(TimelineError::UninitializedWalSegSize(ttid));
|
||||
}
|
||||
|
||||
if state.server.pg_version == UNKNOWN_SERVER_VERSION {
|
||||
bail!(TimelineError::UninitialinzedPgVersion(ttid));
|
||||
}
|
||||
|
||||
if state.commit_lsn < state.local_start_lsn {
|
||||
bail!(
|
||||
"commit_lsn {} is higher than local_start_lsn {}",
|
||||
state.commit_lsn,
|
||||
state.local_start_lsn
|
||||
);
|
||||
}
|
||||
|
||||
let control_store = DiskStateStorage::new(disk.clone());
|
||||
let wal_store = DiskWALStorage::new(disk.clone(), &control_store)?;
|
||||
|
||||
let sk = SafeKeeper::new(control_store, wal_store, conf.my_id)?;
|
||||
timelines.insert(
|
||||
ttid,
|
||||
SharedState {
|
||||
sk,
|
||||
disk: disk.clone(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
timelines,
|
||||
conf,
|
||||
disk,
|
||||
})
|
||||
}
|
||||
|
||||
fn create(&mut self, ttid: TenantTimelineId, server_info: ServerInfo) -> Result<()> {
|
||||
if self.timelines.contains_key(&ttid) {
|
||||
bail!("timeline {} already exists", ttid);
|
||||
}
|
||||
|
||||
debug!("creating new timeline {}", ttid);
|
||||
|
||||
let commit_lsn = Lsn::INVALID;
|
||||
let local_start_lsn = Lsn::INVALID;
|
||||
|
||||
let state =
|
||||
TimelinePersistentState::new(&ttid, server_info, vec![], commit_lsn, local_start_lsn);
|
||||
|
||||
if state.server.wal_seg_size == 0 {
|
||||
bail!(TimelineError::UninitializedWalSegSize(ttid));
|
||||
}
|
||||
|
||||
if state.server.pg_version == UNKNOWN_SERVER_VERSION {
|
||||
bail!(TimelineError::UninitialinzedPgVersion(ttid));
|
||||
}
|
||||
|
||||
if state.commit_lsn < state.local_start_lsn {
|
||||
bail!(
|
||||
"commit_lsn {} is higher than local_start_lsn {}",
|
||||
state.commit_lsn,
|
||||
state.local_start_lsn
|
||||
);
|
||||
}
|
||||
|
||||
let disk_timeline = self.disk.put_state(&ttid, state);
|
||||
let control_store = DiskStateStorage::new(disk_timeline.clone());
|
||||
let wal_store = DiskWALStorage::new(disk_timeline.clone(), &control_store)?;
|
||||
|
||||
let sk = SafeKeeper::new(control_store, wal_store, self.conf.my_id)?;
|
||||
|
||||
self.timelines.insert(
|
||||
ttid,
|
||||
SharedState {
|
||||
sk,
|
||||
disk: disk_timeline,
|
||||
},
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get(&mut self, ttid: &TenantTimelineId) -> &mut SharedState {
|
||||
self.timelines.get_mut(ttid).expect("timeline must exist")
|
||||
}
|
||||
|
||||
fn has_tli(&self, ttid: &TenantTimelineId) -> bool {
|
||||
self.timelines.contains_key(ttid)
|
||||
}
|
||||
}
|
||||
|
||||
/// State of a single connection to walproposer.
|
||||
struct ConnState {
|
||||
tcp: TCP,
|
||||
|
||||
greeting: bool,
|
||||
ttid: TenantTimelineId,
|
||||
flush_pending: bool,
|
||||
|
||||
runtime: tokio::runtime::Runtime,
|
||||
}
|
||||
|
||||
pub fn run_server(os: NodeOs, disk: Arc<SafekeeperDisk>) -> Result<()> {
|
||||
let _enter = info_span!("safekeeper", id = os.id()).entered();
|
||||
debug!("started server");
|
||||
os.log_event("started;safekeeper".to_owned());
|
||||
let conf = SafeKeeperConf {
|
||||
workdir: Utf8PathBuf::from("."),
|
||||
my_id: NodeId(os.id() as u64),
|
||||
listen_pg_addr: String::new(),
|
||||
listen_http_addr: String::new(),
|
||||
no_sync: false,
|
||||
broker_endpoint: "/".parse::<Uri>().unwrap(),
|
||||
broker_keepalive_interval: Duration::from_secs(0),
|
||||
heartbeat_timeout: Duration::from_secs(0),
|
||||
remote_storage: None,
|
||||
max_offloader_lag_bytes: 0,
|
||||
wal_backup_enabled: false,
|
||||
listen_pg_addr_tenant_only: None,
|
||||
advertise_pg_addr: None,
|
||||
availability_zone: None,
|
||||
peer_recovery_enabled: false,
|
||||
backup_parallel_jobs: 0,
|
||||
pg_auth: None,
|
||||
pg_tenant_only_auth: None,
|
||||
http_auth: None,
|
||||
current_thread_runtime: false,
|
||||
};
|
||||
|
||||
let mut global = GlobalMap::new(disk, conf.clone())?;
|
||||
let mut conns: HashMap<usize, ConnState> = HashMap::new();
|
||||
|
||||
for (&_ttid, shared_state) in global.timelines.iter_mut() {
|
||||
let flush_lsn = shared_state.sk.wal_store.flush_lsn();
|
||||
let commit_lsn = shared_state.sk.state.commit_lsn;
|
||||
os.log_event(format!("tli_loaded;{};{}", flush_lsn.0, commit_lsn.0));
|
||||
}
|
||||
|
||||
let node_events = os.node_events();
|
||||
let mut epoll_vec: Vec<Box<dyn PollSome>> = vec![];
|
||||
let mut epoll_idx: Vec<usize> = vec![];
|
||||
|
||||
// TODO: batch events processing (multiple events per tick)
|
||||
loop {
|
||||
epoll_vec.clear();
|
||||
epoll_idx.clear();
|
||||
|
||||
// node events channel
|
||||
epoll_vec.push(Box::new(node_events.clone()));
|
||||
epoll_idx.push(0);
|
||||
|
||||
// tcp connections
|
||||
for conn in conns.values() {
|
||||
epoll_vec.push(Box::new(conn.tcp.recv_chan()));
|
||||
epoll_idx.push(conn.tcp.connection_id());
|
||||
}
|
||||
|
||||
// waiting for the next message
|
||||
let index = executor::epoll_chans(&epoll_vec, -1).unwrap();
|
||||
|
||||
if index == 0 {
|
||||
// got a new connection
|
||||
match node_events.must_recv() {
|
||||
NodeEvent::Accept(tcp) => {
|
||||
conns.insert(
|
||||
tcp.connection_id(),
|
||||
ConnState {
|
||||
tcp,
|
||||
greeting: false,
|
||||
ttid: TenantTimelineId::empty(),
|
||||
flush_pending: false,
|
||||
runtime: tokio::runtime::Builder::new_current_thread().build()?,
|
||||
},
|
||||
);
|
||||
}
|
||||
NodeEvent::Internal(_) => unreachable!(),
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
let connection_id = epoll_idx[index];
|
||||
let conn = conns.get_mut(&connection_id).unwrap();
|
||||
let mut next_event = Some(conn.tcp.recv_chan().must_recv());
|
||||
|
||||
loop {
|
||||
let event = match next_event {
|
||||
Some(event) => event,
|
||||
None => break,
|
||||
};
|
||||
|
||||
match event {
|
||||
NetEvent::Message(msg) => {
|
||||
let res = conn.process_any(msg, &mut global);
|
||||
if res.is_err() {
|
||||
debug!("conn {:?} error: {:#}", connection_id, res.unwrap_err());
|
||||
conns.remove(&connection_id);
|
||||
break;
|
||||
}
|
||||
}
|
||||
NetEvent::Closed => {
|
||||
// TODO: remove from conns?
|
||||
}
|
||||
}
|
||||
|
||||
next_event = conn.tcp.recv_chan().try_recv();
|
||||
}
|
||||
|
||||
conns.retain(|_, conn| {
|
||||
let res = conn.flush(&mut global);
|
||||
if res.is_err() {
|
||||
debug!("conn {:?} error: {:?}", conn.tcp, res);
|
||||
}
|
||||
res.is_ok()
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl ConnState {
|
||||
/// Process a message from the network. It can be START_REPLICATION request or a valid ProposerAcceptorMessage message.
|
||||
fn process_any(&mut self, any: AnyMessage, global: &mut GlobalMap) -> Result<()> {
|
||||
if let AnyMessage::Bytes(copy_data) = any {
|
||||
let repl_prefix = b"START_REPLICATION ";
|
||||
if !self.greeting && copy_data.starts_with(repl_prefix) {
|
||||
self.process_start_replication(copy_data.slice(repl_prefix.len()..), global)?;
|
||||
bail!("finished processing START_REPLICATION")
|
||||
}
|
||||
|
||||
let msg = ProposerAcceptorMessage::parse(copy_data)?;
|
||||
debug!("got msg: {:?}", msg);
|
||||
self.process(msg, global)
|
||||
} else {
|
||||
bail!("unexpected message, expected AnyMessage::Bytes");
|
||||
}
|
||||
}
|
||||
|
||||
/// Process START_REPLICATION request.
|
||||
fn process_start_replication(
|
||||
&mut self,
|
||||
copy_data: Bytes,
|
||||
global: &mut GlobalMap,
|
||||
) -> Result<()> {
|
||||
// format is "<tenant_id> <timeline_id> <start_lsn> <end_lsn>"
|
||||
let str = String::from_utf8(copy_data.to_vec())?;
|
||||
|
||||
let mut parts = str.split(' ');
|
||||
let tenant_id = parts.next().unwrap().parse::<TenantId>()?;
|
||||
let timeline_id = parts.next().unwrap().parse::<TimelineId>()?;
|
||||
let start_lsn = parts.next().unwrap().parse::<u64>()?;
|
||||
let end_lsn = parts.next().unwrap().parse::<u64>()?;
|
||||
|
||||
let ttid = TenantTimelineId::new(tenant_id, timeline_id);
|
||||
let shared_state = global.get(&ttid);
|
||||
|
||||
// read bytes from start_lsn to end_lsn
|
||||
let mut buf = vec![0; (end_lsn - start_lsn) as usize];
|
||||
shared_state.disk.wal.lock().read(start_lsn, &mut buf);
|
||||
|
||||
// send bytes to the client
|
||||
self.tcp.send(AnyMessage::Bytes(Bytes::from(buf)));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get or create a timeline.
|
||||
fn init_timeline(
|
||||
&mut self,
|
||||
ttid: TenantTimelineId,
|
||||
server_info: ServerInfo,
|
||||
global: &mut GlobalMap,
|
||||
) -> Result<()> {
|
||||
self.ttid = ttid;
|
||||
if global.has_tli(&ttid) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
global.create(ttid, server_info)
|
||||
}
|
||||
|
||||
/// Process a ProposerAcceptorMessage.
|
||||
fn process(&mut self, msg: ProposerAcceptorMessage, global: &mut GlobalMap) -> Result<()> {
|
||||
if !self.greeting {
|
||||
self.greeting = true;
|
||||
|
||||
match msg {
|
||||
ProposerAcceptorMessage::Greeting(ref greeting) => {
|
||||
tracing::info!(
|
||||
"start handshake with walproposer {:?} {:?}",
|
||||
self.tcp,
|
||||
greeting
|
||||
);
|
||||
let server_info = ServerInfo {
|
||||
pg_version: greeting.pg_version,
|
||||
system_id: greeting.system_id,
|
||||
wal_seg_size: greeting.wal_seg_size,
|
||||
};
|
||||
let ttid = TenantTimelineId::new(greeting.tenant_id, greeting.timeline_id);
|
||||
self.init_timeline(ttid, server_info, global)?
|
||||
}
|
||||
_ => {
|
||||
bail!("unexpected message {msg:?} instead of greeting");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let tli = global.get(&self.ttid);
|
||||
|
||||
match msg {
|
||||
ProposerAcceptorMessage::AppendRequest(append_request) => {
|
||||
self.flush_pending = true;
|
||||
self.process_sk_msg(
|
||||
tli,
|
||||
&ProposerAcceptorMessage::NoFlushAppendRequest(append_request),
|
||||
)?;
|
||||
}
|
||||
other => {
|
||||
self.process_sk_msg(tli, &other)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Process FlushWAL if needed.
|
||||
fn flush(&mut self, global: &mut GlobalMap) -> Result<()> {
|
||||
// TODO: try to add extra flushes in simulation, to verify that extra flushes don't break anything
|
||||
if !self.flush_pending {
|
||||
return Ok(());
|
||||
}
|
||||
self.flush_pending = false;
|
||||
let shared_state = global.get(&self.ttid);
|
||||
self.process_sk_msg(shared_state, &ProposerAcceptorMessage::FlushWAL)
|
||||
}
|
||||
|
||||
/// Make safekeeper process a message and send a reply to the TCP
|
||||
fn process_sk_msg(
|
||||
&mut self,
|
||||
shared_state: &mut SharedState,
|
||||
msg: &ProposerAcceptorMessage,
|
||||
) -> Result<()> {
|
||||
let mut reply = self.runtime.block_on(shared_state.sk.process_msg(msg))?;
|
||||
if let Some(reply) = &mut reply {
|
||||
// TODO: if this is AppendResponse, fill in proper hot standby feedback and disk consistent lsn
|
||||
|
||||
let mut buf = BytesMut::with_capacity(128);
|
||||
reply.serialize(&mut buf)?;
|
||||
|
||||
self.tcp.send(AnyMessage::Bytes(buf.into()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ConnState {
|
||||
fn drop(&mut self) {
|
||||
debug!("dropping conn: {:?}", self.tcp);
|
||||
if !std::thread::panicking() {
|
||||
self.tcp.close();
|
||||
}
|
||||
// TODO: clean up non-fsynced WAL
|
||||
}
|
||||
}
|
||||
278
safekeeper/tests/walproposer_sim/safekeeper_disk.rs
Normal file
278
safekeeper/tests/walproposer_sim/safekeeper_disk.rs
Normal file
@@ -0,0 +1,278 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use parking_lot::Mutex;
|
||||
use safekeeper::state::TimelinePersistentState;
|
||||
use utils::id::TenantTimelineId;
|
||||
|
||||
use super::block_storage::BlockStorage;
|
||||
|
||||
use std::{ops::Deref, time::Instant};
|
||||
|
||||
use anyhow::Result;
|
||||
use bytes::{Buf, BytesMut};
|
||||
use futures::future::BoxFuture;
|
||||
use postgres_ffi::{waldecoder::WalStreamDecoder, XLogSegNo};
|
||||
use safekeeper::{control_file, metrics::WalStorageMetrics, wal_storage};
|
||||
use tracing::{debug, info};
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
/// All safekeeper state that is usually saved to disk.
|
||||
pub struct SafekeeperDisk {
|
||||
pub timelines: Mutex<HashMap<TenantTimelineId, Arc<TimelineDisk>>>,
|
||||
}
|
||||
|
||||
impl Default for SafekeeperDisk {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl SafekeeperDisk {
|
||||
pub fn new() -> Self {
|
||||
SafekeeperDisk {
|
||||
timelines: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn put_state(
|
||||
&self,
|
||||
ttid: &TenantTimelineId,
|
||||
state: TimelinePersistentState,
|
||||
) -> Arc<TimelineDisk> {
|
||||
self.timelines
|
||||
.lock()
|
||||
.entry(*ttid)
|
||||
.and_modify(|e| {
|
||||
let mut mu = e.state.lock();
|
||||
*mu = state.clone();
|
||||
})
|
||||
.or_insert_with(|| {
|
||||
Arc::new(TimelineDisk {
|
||||
state: Mutex::new(state),
|
||||
wal: Mutex::new(BlockStorage::new()),
|
||||
})
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Control file state and WAL storage.
|
||||
pub struct TimelineDisk {
|
||||
pub state: Mutex<TimelinePersistentState>,
|
||||
pub wal: Mutex<BlockStorage>,
|
||||
}
|
||||
|
||||
/// Implementation of `control_file::Storage` trait.
|
||||
pub struct DiskStateStorage {
|
||||
persisted_state: TimelinePersistentState,
|
||||
disk: Arc<TimelineDisk>,
|
||||
last_persist_at: Instant,
|
||||
}
|
||||
|
||||
impl DiskStateStorage {
|
||||
pub fn new(disk: Arc<TimelineDisk>) -> Self {
|
||||
let guard = disk.state.lock();
|
||||
let state = guard.clone();
|
||||
drop(guard);
|
||||
DiskStateStorage {
|
||||
persisted_state: state,
|
||||
disk,
|
||||
last_persist_at: Instant::now(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl control_file::Storage for DiskStateStorage {
|
||||
/// Persist safekeeper state on disk and update internal state.
|
||||
async fn persist(&mut self, s: &TimelinePersistentState) -> Result<()> {
|
||||
self.persisted_state = s.clone();
|
||||
*self.disk.state.lock() = s.clone();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Timestamp of last persist.
|
||||
fn last_persist_at(&self) -> Instant {
|
||||
// TODO: don't rely on it in tests
|
||||
self.last_persist_at
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for DiskStateStorage {
|
||||
type Target = TimelinePersistentState;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.persisted_state
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of `wal_storage::Storage` trait.
|
||||
pub struct DiskWALStorage {
|
||||
/// Written to disk, but possibly still in the cache and not fully persisted.
|
||||
/// Also can be ahead of record_lsn, if happen to be in the middle of a WAL record.
|
||||
write_lsn: Lsn,
|
||||
|
||||
/// The LSN of the last WAL record written to disk. Still can be not fully flushed.
|
||||
write_record_lsn: Lsn,
|
||||
|
||||
/// The LSN of the last WAL record flushed to disk.
|
||||
flush_record_lsn: Lsn,
|
||||
|
||||
/// Decoder is required for detecting boundaries of WAL records.
|
||||
decoder: WalStreamDecoder,
|
||||
|
||||
/// Bytes of WAL records that are not yet written to disk.
|
||||
unflushed_bytes: BytesMut,
|
||||
|
||||
/// Contains BlockStorage for WAL.
|
||||
disk: Arc<TimelineDisk>,
|
||||
}
|
||||
|
||||
impl DiskWALStorage {
|
||||
pub fn new(disk: Arc<TimelineDisk>, state: &TimelinePersistentState) -> Result<Self> {
|
||||
let write_lsn = if state.commit_lsn == Lsn(0) {
|
||||
Lsn(0)
|
||||
} else {
|
||||
Self::find_end_of_wal(disk.clone(), state.commit_lsn)?
|
||||
};
|
||||
|
||||
let flush_lsn = write_lsn;
|
||||
Ok(DiskWALStorage {
|
||||
write_lsn,
|
||||
write_record_lsn: flush_lsn,
|
||||
flush_record_lsn: flush_lsn,
|
||||
decoder: WalStreamDecoder::new(flush_lsn, 16),
|
||||
unflushed_bytes: BytesMut::new(),
|
||||
disk,
|
||||
})
|
||||
}
|
||||
|
||||
fn find_end_of_wal(disk: Arc<TimelineDisk>, start_lsn: Lsn) -> Result<Lsn> {
|
||||
let mut buf = [0; 8192];
|
||||
let mut pos = start_lsn.0;
|
||||
let mut decoder = WalStreamDecoder::new(start_lsn, 16);
|
||||
let mut result = start_lsn;
|
||||
loop {
|
||||
disk.wal.lock().read(pos, &mut buf);
|
||||
pos += buf.len() as u64;
|
||||
decoder.feed_bytes(&buf);
|
||||
|
||||
loop {
|
||||
match decoder.poll_decode() {
|
||||
Ok(Some(record)) => result = record.0,
|
||||
Err(e) => {
|
||||
debug!(
|
||||
"find_end_of_wal reached end at {:?}, decode error: {:?}",
|
||||
result, e
|
||||
);
|
||||
return Ok(result);
|
||||
}
|
||||
Ok(None) => break, // need more data
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl wal_storage::Storage for DiskWALStorage {
|
||||
/// LSN of last durably stored WAL record.
|
||||
fn flush_lsn(&self) -> Lsn {
|
||||
self.flush_record_lsn
|
||||
}
|
||||
|
||||
/// Write piece of WAL from buf to disk, but not necessarily sync it.
|
||||
async fn write_wal(&mut self, startpos: Lsn, buf: &[u8]) -> Result<()> {
|
||||
if self.write_lsn != startpos {
|
||||
panic!("write_wal called with wrong startpos");
|
||||
}
|
||||
|
||||
self.unflushed_bytes.extend_from_slice(buf);
|
||||
self.write_lsn += buf.len() as u64;
|
||||
|
||||
if self.decoder.available() != startpos {
|
||||
info!(
|
||||
"restart decoder from {} to {}",
|
||||
self.decoder.available(),
|
||||
startpos,
|
||||
);
|
||||
self.decoder = WalStreamDecoder::new(startpos, 16);
|
||||
}
|
||||
self.decoder.feed_bytes(buf);
|
||||
loop {
|
||||
match self.decoder.poll_decode()? {
|
||||
None => break, // no full record yet
|
||||
Some((lsn, _rec)) => {
|
||||
self.write_record_lsn = lsn;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Truncate WAL at specified LSN, which must be the end of WAL record.
|
||||
async fn truncate_wal(&mut self, end_pos: Lsn) -> Result<()> {
|
||||
if self.write_lsn != Lsn(0) && end_pos > self.write_lsn {
|
||||
panic!(
|
||||
"truncate_wal called on non-written WAL, write_lsn={}, end_pos={}",
|
||||
self.write_lsn, end_pos
|
||||
);
|
||||
}
|
||||
|
||||
self.flush_wal().await?;
|
||||
|
||||
// write zeroes to disk from end_pos until self.write_lsn
|
||||
let buf = [0; 8192];
|
||||
let mut pos = end_pos.0;
|
||||
while pos < self.write_lsn.0 {
|
||||
self.disk.wal.lock().write(pos, &buf);
|
||||
pos += buf.len() as u64;
|
||||
}
|
||||
|
||||
self.write_lsn = end_pos;
|
||||
self.write_record_lsn = end_pos;
|
||||
self.flush_record_lsn = end_pos;
|
||||
self.unflushed_bytes.clear();
|
||||
self.decoder = WalStreamDecoder::new(end_pos, 16);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Durably store WAL on disk, up to the last written WAL record.
|
||||
async fn flush_wal(&mut self) -> Result<()> {
|
||||
if self.flush_record_lsn == self.write_record_lsn {
|
||||
// no need to do extra flush
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let num_bytes = self.write_record_lsn.0 - self.flush_record_lsn.0;
|
||||
|
||||
self.disk.wal.lock().write(
|
||||
self.flush_record_lsn.0,
|
||||
&self.unflushed_bytes[..num_bytes as usize],
|
||||
);
|
||||
self.unflushed_bytes.advance(num_bytes as usize);
|
||||
self.flush_record_lsn = self.write_record_lsn;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove all segments <= given segno. Returns function doing that as we
|
||||
/// want to perform it without timeline lock.
|
||||
fn remove_up_to(&self, _segno_up_to: XLogSegNo) -> BoxFuture<'static, anyhow::Result<()>> {
|
||||
Box::pin(async move { Ok(()) })
|
||||
}
|
||||
|
||||
/// Release resources associated with the storage -- technically, close FDs.
|
||||
/// Currently we don't remove timelines until restart (#3146), so need to
|
||||
/// spare descriptors. This would be useful for temporary tli detach as
|
||||
/// well.
|
||||
fn close(&mut self) {}
|
||||
|
||||
/// Get metrics for this timeline.
|
||||
fn get_metrics(&self) -> WalStorageMetrics {
|
||||
WalStorageMetrics::default()
|
||||
}
|
||||
}
|
||||
436
safekeeper/tests/walproposer_sim/simulation.rs
Normal file
436
safekeeper/tests/walproposer_sim/simulation.rs
Normal file
@@ -0,0 +1,436 @@
|
||||
use std::{cell::Cell, str::FromStr, sync::Arc};
|
||||
|
||||
use crate::walproposer_sim::{safekeeper::run_server, walproposer_api::SimulationApi};
|
||||
use desim::{
|
||||
executor::{self, ExternalHandle},
|
||||
node_os::NodeOs,
|
||||
options::{Delay, NetworkOptions},
|
||||
proto::{AnyMessage, NodeEvent},
|
||||
world::Node,
|
||||
world::World,
|
||||
};
|
||||
use rand::{Rng, SeedableRng};
|
||||
use tracing::{debug, info_span, warn};
|
||||
use utils::{id::TenantTimelineId, lsn::Lsn};
|
||||
use walproposer::walproposer::{Config, Wrapper};
|
||||
|
||||
use super::{
|
||||
log::SimClock, safekeeper_disk::SafekeeperDisk, walproposer_api,
|
||||
walproposer_disk::DiskWalProposer,
|
||||
};
|
||||
|
||||
/// Simulated safekeeper node.
|
||||
pub struct SafekeeperNode {
|
||||
pub node: Arc<Node>,
|
||||
pub id: u32,
|
||||
pub disk: Arc<SafekeeperDisk>,
|
||||
pub thread: Cell<ExternalHandle>,
|
||||
}
|
||||
|
||||
impl SafekeeperNode {
|
||||
/// Create and start a safekeeper at the specified Node.
|
||||
pub fn new(node: Arc<Node>) -> Self {
|
||||
let disk = Arc::new(SafekeeperDisk::new());
|
||||
let thread = Cell::new(SafekeeperNode::launch(disk.clone(), node.clone()));
|
||||
|
||||
Self {
|
||||
id: node.id,
|
||||
node,
|
||||
disk,
|
||||
thread,
|
||||
}
|
||||
}
|
||||
|
||||
fn launch(disk: Arc<SafekeeperDisk>, node: Arc<Node>) -> ExternalHandle {
|
||||
// start the server thread
|
||||
node.launch(move |os| {
|
||||
run_server(os, disk).expect("server should finish without errors");
|
||||
})
|
||||
}
|
||||
|
||||
/// Restart the safekeeper.
|
||||
pub fn restart(&self) {
|
||||
let new_thread = SafekeeperNode::launch(self.disk.clone(), self.node.clone());
|
||||
let old_thread = self.thread.replace(new_thread);
|
||||
old_thread.crash_stop();
|
||||
}
|
||||
}
|
||||
|
||||
/// Simulated walproposer node.
|
||||
pub struct WalProposer {
|
||||
thread: ExternalHandle,
|
||||
node: Arc<Node>,
|
||||
disk: Arc<DiskWalProposer>,
|
||||
sync_safekeepers: bool,
|
||||
}
|
||||
|
||||
impl WalProposer {
|
||||
/// Generic start function for both modes.
|
||||
fn start(
|
||||
os: NodeOs,
|
||||
disk: Arc<DiskWalProposer>,
|
||||
ttid: TenantTimelineId,
|
||||
addrs: Vec<String>,
|
||||
lsn: Option<Lsn>,
|
||||
) {
|
||||
let sync_safekeepers = lsn.is_none();
|
||||
|
||||
let _enter = if sync_safekeepers {
|
||||
info_span!("sync", started = executor::now()).entered()
|
||||
} else {
|
||||
info_span!("walproposer", started = executor::now()).entered()
|
||||
};
|
||||
|
||||
os.log_event(format!("started;walproposer;{}", sync_safekeepers as i32));
|
||||
|
||||
let config = Config {
|
||||
ttid,
|
||||
safekeepers_list: addrs,
|
||||
safekeeper_reconnect_timeout: 1000,
|
||||
safekeeper_connection_timeout: 5000,
|
||||
sync_safekeepers,
|
||||
};
|
||||
let args = walproposer_api::Args {
|
||||
os,
|
||||
config: config.clone(),
|
||||
disk,
|
||||
redo_start_lsn: lsn,
|
||||
};
|
||||
let api = SimulationApi::new(args);
|
||||
let wp = Wrapper::new(Box::new(api), config);
|
||||
wp.start();
|
||||
}
|
||||
|
||||
/// Start walproposer in a sync_safekeepers mode.
|
||||
pub fn launch_sync(ttid: TenantTimelineId, addrs: Vec<String>, node: Arc<Node>) -> Self {
|
||||
debug!("sync_safekeepers started at node {}", node.id);
|
||||
let disk = DiskWalProposer::new();
|
||||
let disk_wp = disk.clone();
|
||||
|
||||
// start the client thread
|
||||
let handle = node.launch(move |os| {
|
||||
WalProposer::start(os, disk_wp, ttid, addrs, None);
|
||||
});
|
||||
|
||||
Self {
|
||||
thread: handle,
|
||||
node,
|
||||
disk,
|
||||
sync_safekeepers: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Start walproposer in a normal mode.
|
||||
pub fn launch_walproposer(
|
||||
ttid: TenantTimelineId,
|
||||
addrs: Vec<String>,
|
||||
node: Arc<Node>,
|
||||
lsn: Lsn,
|
||||
) -> Self {
|
||||
debug!("walproposer started at node {}", node.id);
|
||||
let disk = DiskWalProposer::new();
|
||||
disk.lock().reset_to(lsn);
|
||||
let disk_wp = disk.clone();
|
||||
|
||||
// start the client thread
|
||||
let handle = node.launch(move |os| {
|
||||
WalProposer::start(os, disk_wp, ttid, addrs, Some(lsn));
|
||||
});
|
||||
|
||||
Self {
|
||||
thread: handle,
|
||||
node,
|
||||
disk,
|
||||
sync_safekeepers: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write_tx(&mut self, cnt: usize) {
|
||||
let start_lsn = self.disk.lock().flush_rec_ptr();
|
||||
|
||||
for _ in 0..cnt {
|
||||
self.disk
|
||||
.lock()
|
||||
.insert_logical_message("prefix", b"message")
|
||||
.expect("failed to generate logical message");
|
||||
}
|
||||
|
||||
let end_lsn = self.disk.lock().flush_rec_ptr();
|
||||
|
||||
// log event
|
||||
self.node
|
||||
.log_event(format!("write_wal;{};{};{}", start_lsn.0, end_lsn.0, cnt));
|
||||
|
||||
// now we need to set "Latch" in walproposer
|
||||
self.node
|
||||
.node_events()
|
||||
.send(NodeEvent::Internal(AnyMessage::Just32(0)));
|
||||
}
|
||||
|
||||
pub fn stop(&self) {
|
||||
self.thread.crash_stop();
|
||||
}
|
||||
}
|
||||
|
||||
/// Holds basic simulation settings, such as network options.
|
||||
pub struct TestConfig {
|
||||
pub network: NetworkOptions,
|
||||
pub timeout: u64,
|
||||
pub clock: Option<SimClock>,
|
||||
}
|
||||
|
||||
impl TestConfig {
|
||||
/// Create a new TestConfig with default settings.
|
||||
pub fn new(clock: Option<SimClock>) -> Self {
|
||||
Self {
|
||||
network: NetworkOptions {
|
||||
keepalive_timeout: Some(2000),
|
||||
connect_delay: Delay {
|
||||
min: 1,
|
||||
max: 5,
|
||||
fail_prob: 0.0,
|
||||
},
|
||||
send_delay: Delay {
|
||||
min: 1,
|
||||
max: 5,
|
||||
fail_prob: 0.0,
|
||||
},
|
||||
},
|
||||
timeout: 1_000 * 10,
|
||||
clock,
|
||||
}
|
||||
}
|
||||
|
||||
/// Start a new simulation with the specified seed.
|
||||
pub fn start(&self, seed: u64) -> Test {
|
||||
let world = Arc::new(World::new(seed, Arc::new(self.network.clone())));
|
||||
|
||||
if let Some(clock) = &self.clock {
|
||||
clock.set_clock(world.clock());
|
||||
}
|
||||
|
||||
let servers = [
|
||||
SafekeeperNode::new(world.new_node()),
|
||||
SafekeeperNode::new(world.new_node()),
|
||||
SafekeeperNode::new(world.new_node()),
|
||||
];
|
||||
|
||||
let server_ids = [servers[0].id, servers[1].id, servers[2].id];
|
||||
let safekeepers_addrs = server_ids.map(|id| format!("node:{}", id)).to_vec();
|
||||
|
||||
let ttid = TenantTimelineId::generate();
|
||||
|
||||
Test {
|
||||
world,
|
||||
servers,
|
||||
sk_list: safekeepers_addrs,
|
||||
ttid,
|
||||
timeout: self.timeout,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Holds simulation state.
|
||||
pub struct Test {
|
||||
pub world: Arc<World>,
|
||||
pub servers: [SafekeeperNode; 3],
|
||||
pub sk_list: Vec<String>,
|
||||
pub ttid: TenantTimelineId,
|
||||
pub timeout: u64,
|
||||
}
|
||||
|
||||
impl Test {
|
||||
/// Start a sync_safekeepers thread and wait for it to finish.
|
||||
pub fn sync_safekeepers(&self) -> anyhow::Result<Lsn> {
|
||||
let wp = self.launch_sync_safekeepers();
|
||||
|
||||
// poll until exit or timeout
|
||||
let time_limit = self.timeout;
|
||||
while self.world.step() && self.world.now() < time_limit && !wp.thread.is_finished() {}
|
||||
|
||||
if !wp.thread.is_finished() {
|
||||
anyhow::bail!("timeout or idle stuck");
|
||||
}
|
||||
|
||||
let res = wp.thread.result();
|
||||
if res.0 != 0 {
|
||||
anyhow::bail!("non-zero exitcode: {:?}", res);
|
||||
}
|
||||
let lsn = Lsn::from_str(&res.1)?;
|
||||
Ok(lsn)
|
||||
}
|
||||
|
||||
/// Spawn a new sync_safekeepers thread.
|
||||
pub fn launch_sync_safekeepers(&self) -> WalProposer {
|
||||
WalProposer::launch_sync(self.ttid, self.sk_list.clone(), self.world.new_node())
|
||||
}
|
||||
|
||||
/// Spawn a new walproposer thread.
|
||||
pub fn launch_walproposer(&self, lsn: Lsn) -> WalProposer {
|
||||
let lsn = if lsn.0 == 0 {
|
||||
// usual LSN after basebackup
|
||||
Lsn(21623024)
|
||||
} else {
|
||||
lsn
|
||||
};
|
||||
|
||||
WalProposer::launch_walproposer(self.ttid, self.sk_list.clone(), self.world.new_node(), lsn)
|
||||
}
|
||||
|
||||
/// Execute the simulation for the specified duration.
|
||||
pub fn poll_for_duration(&self, duration: u64) {
|
||||
let time_limit = std::cmp::min(self.world.now() + duration, self.timeout);
|
||||
while self.world.step() && self.world.now() < time_limit {}
|
||||
}
|
||||
|
||||
/// Execute the simulation together with events defined in some schedule.
|
||||
pub fn run_schedule(&self, schedule: &Schedule) -> anyhow::Result<()> {
|
||||
// scheduling empty events so that world will stop in those points
|
||||
{
|
||||
let clock = self.world.clock();
|
||||
|
||||
let now = self.world.now();
|
||||
for (time, _) in schedule {
|
||||
if *time < now {
|
||||
continue;
|
||||
}
|
||||
clock.schedule_fake(*time - now);
|
||||
}
|
||||
}
|
||||
|
||||
let mut wp = self.launch_sync_safekeepers();
|
||||
|
||||
let mut skipped_tx = 0;
|
||||
let mut started_tx = 0;
|
||||
|
||||
let mut schedule_ptr = 0;
|
||||
|
||||
loop {
|
||||
if wp.sync_safekeepers && wp.thread.is_finished() {
|
||||
let res = wp.thread.result();
|
||||
if res.0 != 0 {
|
||||
warn!("sync non-zero exitcode: {:?}", res);
|
||||
debug!("restarting sync_safekeepers");
|
||||
// restart the sync_safekeepers
|
||||
wp = self.launch_sync_safekeepers();
|
||||
continue;
|
||||
}
|
||||
let lsn = Lsn::from_str(&res.1)?;
|
||||
debug!("sync_safekeepers finished at LSN {}", lsn);
|
||||
wp = self.launch_walproposer(lsn);
|
||||
debug!("walproposer started at thread {}", wp.thread.id());
|
||||
}
|
||||
|
||||
let now = self.world.now();
|
||||
while schedule_ptr < schedule.len() && schedule[schedule_ptr].0 <= now {
|
||||
if now != schedule[schedule_ptr].0 {
|
||||
warn!("skipped event {:?} at {}", schedule[schedule_ptr], now);
|
||||
}
|
||||
|
||||
let action = &schedule[schedule_ptr].1;
|
||||
match action {
|
||||
TestAction::WriteTx(size) => {
|
||||
if !wp.sync_safekeepers && !wp.thread.is_finished() {
|
||||
started_tx += *size;
|
||||
wp.write_tx(*size);
|
||||
debug!("written {} transactions", size);
|
||||
} else {
|
||||
skipped_tx += size;
|
||||
debug!("skipped {} transactions", size);
|
||||
}
|
||||
}
|
||||
TestAction::RestartSafekeeper(id) => {
|
||||
debug!("restarting safekeeper {}", id);
|
||||
self.servers[*id].restart();
|
||||
}
|
||||
TestAction::RestartWalProposer => {
|
||||
debug!("restarting sync_safekeepers");
|
||||
wp.stop();
|
||||
wp = self.launch_sync_safekeepers();
|
||||
}
|
||||
}
|
||||
schedule_ptr += 1;
|
||||
}
|
||||
|
||||
if schedule_ptr == schedule.len() {
|
||||
break;
|
||||
}
|
||||
let next_event_time = schedule[schedule_ptr].0;
|
||||
|
||||
// poll until the next event
|
||||
if wp.thread.is_finished() {
|
||||
while self.world.step() && self.world.now() < next_event_time {}
|
||||
} else {
|
||||
while self.world.step()
|
||||
&& self.world.now() < next_event_time
|
||||
&& !wp.thread.is_finished()
|
||||
{}
|
||||
}
|
||||
}
|
||||
|
||||
debug!(
|
||||
"finished schedule, total steps: {}",
|
||||
self.world.get_thread_step_count()
|
||||
);
|
||||
debug!("skipped_tx: {}", skipped_tx);
|
||||
debug!("started_tx: {}", started_tx);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum TestAction {
|
||||
WriteTx(usize),
|
||||
RestartSafekeeper(usize),
|
||||
RestartWalProposer,
|
||||
}
|
||||
|
||||
pub type Schedule = Vec<(u64, TestAction)>;
|
||||
|
||||
pub fn generate_schedule(seed: u64) -> Schedule {
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
|
||||
let mut schedule = Vec::new();
|
||||
let mut time = 0;
|
||||
|
||||
let cnt = rng.gen_range(1..100);
|
||||
|
||||
for _ in 0..cnt {
|
||||
time += rng.gen_range(0..500);
|
||||
let action = match rng.gen_range(0..3) {
|
||||
0 => TestAction::WriteTx(rng.gen_range(1..10)),
|
||||
1 => TestAction::RestartSafekeeper(rng.gen_range(0..3)),
|
||||
2 => TestAction::RestartWalProposer,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
schedule.push((time, action));
|
||||
}
|
||||
|
||||
schedule
|
||||
}
|
||||
|
||||
pub fn generate_network_opts(seed: u64) -> NetworkOptions {
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
|
||||
|
||||
let timeout = rng.gen_range(100..2000);
|
||||
let max_delay = rng.gen_range(1..2 * timeout);
|
||||
let min_delay = rng.gen_range(1..=max_delay);
|
||||
|
||||
let max_fail_prob = rng.gen_range(0.0..0.9);
|
||||
let connect_fail_prob = rng.gen_range(0.0..max_fail_prob);
|
||||
let send_fail_prob = rng.gen_range(0.0..connect_fail_prob);
|
||||
|
||||
NetworkOptions {
|
||||
keepalive_timeout: Some(timeout),
|
||||
connect_delay: Delay {
|
||||
min: min_delay,
|
||||
max: max_delay,
|
||||
fail_prob: connect_fail_prob,
|
||||
},
|
||||
send_delay: Delay {
|
||||
min: min_delay,
|
||||
max: max_delay,
|
||||
fail_prob: send_fail_prob,
|
||||
},
|
||||
}
|
||||
}
|
||||
187
safekeeper/tests/walproposer_sim/simulation_logs.rs
Normal file
187
safekeeper/tests/walproposer_sim/simulation_logs.rs
Normal file
@@ -0,0 +1,187 @@
|
||||
use desim::proto::SimEvent;
|
||||
use tracing::debug;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
enum NodeKind {
|
||||
Unknown,
|
||||
Safekeeper,
|
||||
WalProposer,
|
||||
}
|
||||
|
||||
impl Default for NodeKind {
|
||||
fn default() -> Self {
|
||||
Self::Unknown
|
||||
}
|
||||
}
|
||||
|
||||
/// Simulation state of walproposer/safekeeper, derived from the simulation logs.
|
||||
#[derive(Clone, Debug, Default)]
|
||||
struct NodeInfo {
|
||||
kind: NodeKind,
|
||||
|
||||
// walproposer
|
||||
is_sync: bool,
|
||||
term: u64,
|
||||
epoch_lsn: u64,
|
||||
|
||||
// safekeeper
|
||||
commit_lsn: u64,
|
||||
flush_lsn: u64,
|
||||
}
|
||||
|
||||
impl NodeInfo {
|
||||
fn init_kind(&mut self, kind: NodeKind) {
|
||||
if self.kind == NodeKind::Unknown {
|
||||
self.kind = kind;
|
||||
} else {
|
||||
assert!(self.kind == kind);
|
||||
}
|
||||
}
|
||||
|
||||
fn started(&mut self, data: &str) {
|
||||
let mut parts = data.split(';');
|
||||
assert!(parts.next().unwrap() == "started");
|
||||
match parts.next().unwrap() {
|
||||
"safekeeper" => {
|
||||
self.init_kind(NodeKind::Safekeeper);
|
||||
}
|
||||
"walproposer" => {
|
||||
self.init_kind(NodeKind::WalProposer);
|
||||
let is_sync: u8 = parts.next().unwrap().parse().unwrap();
|
||||
self.is_sync = is_sync != 0;
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Global state of the simulation, derived from the simulation logs.
|
||||
#[derive(Debug, Default)]
|
||||
struct GlobalState {
|
||||
nodes: Vec<NodeInfo>,
|
||||
commit_lsn: u64,
|
||||
write_lsn: u64,
|
||||
max_write_lsn: u64,
|
||||
|
||||
written_wal: u64,
|
||||
written_records: u64,
|
||||
}
|
||||
|
||||
impl GlobalState {
|
||||
fn new() -> Self {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
fn get(&mut self, id: u32) -> &mut NodeInfo {
|
||||
let id = id as usize;
|
||||
if id >= self.nodes.len() {
|
||||
self.nodes.resize(id + 1, NodeInfo::default());
|
||||
}
|
||||
&mut self.nodes[id]
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to find inconsistencies in the simulation log.
|
||||
pub fn validate_events(events: Vec<SimEvent>) {
|
||||
const INITDB_LSN: u64 = 21623024;
|
||||
|
||||
let hook = std::panic::take_hook();
|
||||
scopeguard::defer_on_success! {
|
||||
std::panic::set_hook(hook);
|
||||
};
|
||||
|
||||
let mut state = GlobalState::new();
|
||||
state.max_write_lsn = INITDB_LSN;
|
||||
|
||||
for event in events {
|
||||
debug!("{:?}", event);
|
||||
|
||||
let node = state.get(event.node);
|
||||
if event.data.starts_with("started;") {
|
||||
node.started(&event.data);
|
||||
continue;
|
||||
}
|
||||
assert!(node.kind != NodeKind::Unknown);
|
||||
|
||||
// drop reference to unlock state
|
||||
let mut node = node.clone();
|
||||
|
||||
let mut parts = event.data.split(';');
|
||||
match node.kind {
|
||||
NodeKind::Safekeeper => match parts.next().unwrap() {
|
||||
"tli_loaded" => {
|
||||
let flush_lsn: u64 = parts.next().unwrap().parse().unwrap();
|
||||
let commit_lsn: u64 = parts.next().unwrap().parse().unwrap();
|
||||
node.flush_lsn = flush_lsn;
|
||||
node.commit_lsn = commit_lsn;
|
||||
}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
NodeKind::WalProposer => {
|
||||
match parts.next().unwrap() {
|
||||
"prop_elected" => {
|
||||
let prop_lsn: u64 = parts.next().unwrap().parse().unwrap();
|
||||
let prop_term: u64 = parts.next().unwrap().parse().unwrap();
|
||||
let prev_lsn: u64 = parts.next().unwrap().parse().unwrap();
|
||||
let prev_term: u64 = parts.next().unwrap().parse().unwrap();
|
||||
|
||||
assert!(prop_lsn >= prev_lsn);
|
||||
assert!(prop_term >= prev_term);
|
||||
|
||||
assert!(prop_lsn >= state.commit_lsn);
|
||||
|
||||
if prop_lsn > state.write_lsn {
|
||||
assert!(prop_lsn <= state.max_write_lsn);
|
||||
debug!(
|
||||
"moving write_lsn up from {} to {}",
|
||||
state.write_lsn, prop_lsn
|
||||
);
|
||||
state.write_lsn = prop_lsn;
|
||||
}
|
||||
if prop_lsn < state.write_lsn {
|
||||
debug!(
|
||||
"moving write_lsn down from {} to {}",
|
||||
state.write_lsn, prop_lsn
|
||||
);
|
||||
state.write_lsn = prop_lsn;
|
||||
}
|
||||
|
||||
node.epoch_lsn = prop_lsn;
|
||||
node.term = prop_term;
|
||||
}
|
||||
"write_wal" => {
|
||||
assert!(!node.is_sync);
|
||||
let start_lsn: u64 = parts.next().unwrap().parse().unwrap();
|
||||
let end_lsn: u64 = parts.next().unwrap().parse().unwrap();
|
||||
let cnt: u64 = parts.next().unwrap().parse().unwrap();
|
||||
|
||||
let size = end_lsn - start_lsn;
|
||||
state.written_wal += size;
|
||||
state.written_records += cnt;
|
||||
|
||||
// TODO: If we allow writing WAL before winning the election
|
||||
|
||||
assert!(start_lsn >= state.commit_lsn);
|
||||
assert!(end_lsn >= start_lsn);
|
||||
// assert!(start_lsn == state.write_lsn);
|
||||
state.write_lsn = end_lsn;
|
||||
|
||||
if end_lsn > state.max_write_lsn {
|
||||
state.max_write_lsn = end_lsn;
|
||||
}
|
||||
}
|
||||
"commit_lsn" => {
|
||||
let lsn: u64 = parts.next().unwrap().parse().unwrap();
|
||||
assert!(lsn >= state.commit_lsn);
|
||||
state.commit_lsn = lsn;
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
||||
// update the node in the state struct
|
||||
*state.get(event.node) = node;
|
||||
}
|
||||
}
|
||||
676
safekeeper/tests/walproposer_sim/walproposer_api.rs
Normal file
676
safekeeper/tests/walproposer_sim/walproposer_api.rs
Normal file
@@ -0,0 +1,676 @@
|
||||
use std::{
|
||||
cell::{RefCell, RefMut, UnsafeCell},
|
||||
ffi::CStr,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use bytes::Bytes;
|
||||
use desim::{
|
||||
executor::{self, PollSome},
|
||||
network::TCP,
|
||||
node_os::NodeOs,
|
||||
proto::{AnyMessage, NetEvent, NodeEvent},
|
||||
world::NodeId,
|
||||
};
|
||||
use tracing::debug;
|
||||
use utils::lsn::Lsn;
|
||||
use walproposer::{
|
||||
api_bindings::Level,
|
||||
bindings::{
|
||||
pg_atomic_uint64, NeonWALReadResult, PageserverFeedback, SafekeeperStateDesiredEvents,
|
||||
WL_SOCKET_READABLE, WL_SOCKET_WRITEABLE,
|
||||
},
|
||||
walproposer::{ApiImpl, Config},
|
||||
};
|
||||
|
||||
use super::walproposer_disk::DiskWalProposer;
|
||||
|
||||
/// Special state for each wp->sk connection.
|
||||
struct SafekeeperConn {
|
||||
host: String,
|
||||
port: String,
|
||||
node_id: NodeId,
|
||||
// socket is Some(..) equals to connection is established
|
||||
socket: Option<TCP>,
|
||||
// connection is in progress
|
||||
is_connecting: bool,
|
||||
// START_WAL_PUSH is in progress
|
||||
is_start_wal_push: bool,
|
||||
// pointer to Safekeeper in walproposer for callbacks
|
||||
raw_ptr: *mut walproposer::bindings::Safekeeper,
|
||||
}
|
||||
|
||||
impl SafekeeperConn {
|
||||
pub fn new(host: String, port: String) -> Self {
|
||||
// port number is the same as NodeId
|
||||
let port_num = port.parse::<u32>().unwrap();
|
||||
Self {
|
||||
host,
|
||||
port,
|
||||
node_id: port_num,
|
||||
socket: None,
|
||||
is_connecting: false,
|
||||
is_start_wal_push: false,
|
||||
raw_ptr: std::ptr::null_mut(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Simulation version of a postgres WaitEventSet. At pos 0 there is always
|
||||
/// a special NodeEvents channel, which is used as a latch.
|
||||
struct EventSet {
|
||||
os: NodeOs,
|
||||
// all pollable channels, 0 is always NodeEvent channel
|
||||
chans: Vec<Box<dyn PollSome>>,
|
||||
// 0 is always nullptr
|
||||
sk_ptrs: Vec<*mut walproposer::bindings::Safekeeper>,
|
||||
// event mask for each channel
|
||||
masks: Vec<u32>,
|
||||
}
|
||||
|
||||
impl EventSet {
|
||||
pub fn new(os: NodeOs) -> Self {
|
||||
let node_events = os.node_events();
|
||||
Self {
|
||||
os,
|
||||
chans: vec![Box::new(node_events)],
|
||||
sk_ptrs: vec![std::ptr::null_mut()],
|
||||
masks: vec![WL_SOCKET_READABLE],
|
||||
}
|
||||
}
|
||||
|
||||
/// Leaves all readable channels at the beginning of the array.
|
||||
fn sort_readable(&mut self) -> usize {
|
||||
let mut cnt = 1;
|
||||
for i in 1..self.chans.len() {
|
||||
if self.masks[i] & WL_SOCKET_READABLE != 0 {
|
||||
self.chans.swap(i, cnt);
|
||||
self.sk_ptrs.swap(i, cnt);
|
||||
self.masks.swap(i, cnt);
|
||||
cnt += 1;
|
||||
}
|
||||
}
|
||||
cnt
|
||||
}
|
||||
|
||||
fn update_event_set(&mut self, conn: &SafekeeperConn, event_mask: u32) {
|
||||
let index = self
|
||||
.sk_ptrs
|
||||
.iter()
|
||||
.position(|&ptr| ptr == conn.raw_ptr)
|
||||
.expect("safekeeper should exist in event set");
|
||||
self.masks[index] = event_mask;
|
||||
}
|
||||
|
||||
fn add_safekeeper(&mut self, sk: &SafekeeperConn, event_mask: u32) {
|
||||
for ptr in self.sk_ptrs.iter() {
|
||||
assert!(*ptr != sk.raw_ptr);
|
||||
}
|
||||
|
||||
self.chans.push(Box::new(
|
||||
sk.socket
|
||||
.as_ref()
|
||||
.expect("socket should not be closed")
|
||||
.recv_chan(),
|
||||
));
|
||||
self.sk_ptrs.push(sk.raw_ptr);
|
||||
self.masks.push(event_mask);
|
||||
}
|
||||
|
||||
fn remove_safekeeper(&mut self, sk: &SafekeeperConn) {
|
||||
let index = self.sk_ptrs.iter().position(|&ptr| ptr == sk.raw_ptr);
|
||||
if index.is_none() {
|
||||
debug!("remove_safekeeper: sk={:?} not found", sk.raw_ptr);
|
||||
return;
|
||||
}
|
||||
let index = index.unwrap();
|
||||
|
||||
self.chans.remove(index);
|
||||
self.sk_ptrs.remove(index);
|
||||
self.masks.remove(index);
|
||||
|
||||
// to simulate the actual behaviour
|
||||
self.refresh_event_set();
|
||||
}
|
||||
|
||||
/// Updates all masks to match the result of a SafekeeperStateDesiredEvents.
|
||||
fn refresh_event_set(&mut self) {
|
||||
for (i, mask) in self.masks.iter_mut().enumerate() {
|
||||
if i == 0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut mask_sk: u32 = 0;
|
||||
let mut mask_nwr: u32 = 0;
|
||||
unsafe { SafekeeperStateDesiredEvents(self.sk_ptrs[i], &mut mask_sk, &mut mask_nwr) };
|
||||
|
||||
if mask_sk != *mask {
|
||||
debug!(
|
||||
"refresh_event_set: sk={:?}, old_mask={:#b}, new_mask={:#b}",
|
||||
self.sk_ptrs[i], *mask, mask_sk
|
||||
);
|
||||
*mask = mask_sk;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wait for events on all channels.
|
||||
fn wait(&mut self, timeout_millis: i64) -> walproposer::walproposer::WaitResult {
|
||||
// all channels are always writeable
|
||||
for (i, mask) in self.masks.iter().enumerate() {
|
||||
if *mask & WL_SOCKET_WRITEABLE != 0 {
|
||||
return walproposer::walproposer::WaitResult::Network(
|
||||
self.sk_ptrs[i],
|
||||
WL_SOCKET_WRITEABLE,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let cnt = self.sort_readable();
|
||||
|
||||
let slice = &self.chans[0..cnt];
|
||||
match executor::epoll_chans(slice, timeout_millis) {
|
||||
None => walproposer::walproposer::WaitResult::Timeout,
|
||||
Some(0) => {
|
||||
let msg = self.os.node_events().must_recv();
|
||||
match msg {
|
||||
NodeEvent::Internal(AnyMessage::Just32(0)) => {
|
||||
// got a notification about new WAL available
|
||||
}
|
||||
NodeEvent::Internal(_) => unreachable!(),
|
||||
NodeEvent::Accept(_) => unreachable!(),
|
||||
}
|
||||
walproposer::walproposer::WaitResult::Latch
|
||||
}
|
||||
Some(index) => walproposer::walproposer::WaitResult::Network(
|
||||
self.sk_ptrs[index],
|
||||
WL_SOCKET_READABLE,
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// This struct handles all calls from walproposer into walproposer_api.
|
||||
pub struct SimulationApi {
|
||||
os: NodeOs,
|
||||
safekeepers: RefCell<Vec<SafekeeperConn>>,
|
||||
disk: Arc<DiskWalProposer>,
|
||||
redo_start_lsn: Option<Lsn>,
|
||||
shmem: UnsafeCell<walproposer::bindings::WalproposerShmemState>,
|
||||
config: Config,
|
||||
event_set: RefCell<Option<EventSet>>,
|
||||
}
|
||||
|
||||
pub struct Args {
|
||||
pub os: NodeOs,
|
||||
pub config: Config,
|
||||
pub disk: Arc<DiskWalProposer>,
|
||||
pub redo_start_lsn: Option<Lsn>,
|
||||
}
|
||||
|
||||
impl SimulationApi {
|
||||
pub fn new(args: Args) -> Self {
|
||||
// initialize connection state for each safekeeper
|
||||
let sk_conns = args
|
||||
.config
|
||||
.safekeepers_list
|
||||
.iter()
|
||||
.map(|s| {
|
||||
SafekeeperConn::new(
|
||||
s.split(':').next().unwrap().to_string(),
|
||||
s.split(':').nth(1).unwrap().to_string(),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Self {
|
||||
os: args.os,
|
||||
safekeepers: RefCell::new(sk_conns),
|
||||
disk: args.disk,
|
||||
redo_start_lsn: args.redo_start_lsn,
|
||||
shmem: UnsafeCell::new(walproposer::bindings::WalproposerShmemState {
|
||||
mutex: 0,
|
||||
feedback: PageserverFeedback {
|
||||
currentClusterSize: 0,
|
||||
last_received_lsn: 0,
|
||||
disk_consistent_lsn: 0,
|
||||
remote_consistent_lsn: 0,
|
||||
replytime: 0,
|
||||
},
|
||||
mineLastElectedTerm: 0,
|
||||
backpressureThrottlingTime: pg_atomic_uint64 { value: 0 },
|
||||
}),
|
||||
config: args.config,
|
||||
event_set: RefCell::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get SafekeeperConn for the given Safekeeper.
|
||||
fn get_conn(&self, sk: &mut walproposer::bindings::Safekeeper) -> RefMut<'_, SafekeeperConn> {
|
||||
let sk_port = unsafe { CStr::from_ptr(sk.port).to_str().unwrap() };
|
||||
let state = self.safekeepers.borrow_mut();
|
||||
RefMut::map(state, |v| {
|
||||
v.iter_mut()
|
||||
.find(|conn| conn.port == sk_port)
|
||||
.expect("safekeeper conn not found by port")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl ApiImpl for SimulationApi {
|
||||
fn get_current_timestamp(&self) -> i64 {
|
||||
debug!("get_current_timestamp");
|
||||
// PG TimestampTZ is microseconds, but simulation unit is assumed to be
|
||||
// milliseconds, so add 10^3
|
||||
self.os.now() as i64 * 1000
|
||||
}
|
||||
|
||||
fn conn_status(
|
||||
&self,
|
||||
_: &mut walproposer::bindings::Safekeeper,
|
||||
) -> walproposer::bindings::WalProposerConnStatusType {
|
||||
debug!("conn_status");
|
||||
// break the connection with a 10% chance
|
||||
if self.os.random(100) < 10 {
|
||||
walproposer::bindings::WalProposerConnStatusType_WP_CONNECTION_BAD
|
||||
} else {
|
||||
walproposer::bindings::WalProposerConnStatusType_WP_CONNECTION_OK
|
||||
}
|
||||
}
|
||||
|
||||
fn conn_connect_start(&self, sk: &mut walproposer::bindings::Safekeeper) {
|
||||
debug!("conn_connect_start");
|
||||
let mut conn = self.get_conn(sk);
|
||||
|
||||
assert!(conn.socket.is_none());
|
||||
let socket = self.os.open_tcp(conn.node_id);
|
||||
conn.socket = Some(socket);
|
||||
conn.raw_ptr = sk;
|
||||
conn.is_connecting = true;
|
||||
}
|
||||
|
||||
fn conn_connect_poll(
|
||||
&self,
|
||||
_: &mut walproposer::bindings::Safekeeper,
|
||||
) -> walproposer::bindings::WalProposerConnectPollStatusType {
|
||||
debug!("conn_connect_poll");
|
||||
// TODO: break the connection here
|
||||
walproposer::bindings::WalProposerConnectPollStatusType_WP_CONN_POLLING_OK
|
||||
}
|
||||
|
||||
fn conn_send_query(&self, sk: &mut walproposer::bindings::Safekeeper, query: &str) -> bool {
|
||||
debug!("conn_send_query: {}", query);
|
||||
self.get_conn(sk).is_start_wal_push = true;
|
||||
true
|
||||
}
|
||||
|
||||
fn conn_get_query_result(
|
||||
&self,
|
||||
_: &mut walproposer::bindings::Safekeeper,
|
||||
) -> walproposer::bindings::WalProposerExecStatusType {
|
||||
debug!("conn_get_query_result");
|
||||
// TODO: break the connection here
|
||||
walproposer::bindings::WalProposerExecStatusType_WP_EXEC_SUCCESS_COPYBOTH
|
||||
}
|
||||
|
||||
fn conn_async_read(
|
||||
&self,
|
||||
sk: &mut walproposer::bindings::Safekeeper,
|
||||
vec: &mut Vec<u8>,
|
||||
) -> walproposer::bindings::PGAsyncReadResult {
|
||||
debug!("conn_async_read");
|
||||
let mut conn = self.get_conn(sk);
|
||||
|
||||
let socket = if let Some(socket) = conn.socket.as_mut() {
|
||||
socket
|
||||
} else {
|
||||
// socket is already closed
|
||||
return walproposer::bindings::PGAsyncReadResult_PG_ASYNC_READ_FAIL;
|
||||
};
|
||||
|
||||
let msg = socket.recv_chan().try_recv();
|
||||
|
||||
match msg {
|
||||
None => {
|
||||
// no message is ready
|
||||
walproposer::bindings::PGAsyncReadResult_PG_ASYNC_READ_TRY_AGAIN
|
||||
}
|
||||
Some(NetEvent::Closed) => {
|
||||
// connection is closed
|
||||
debug!("conn_async_read: connection is closed");
|
||||
conn.socket = None;
|
||||
walproposer::bindings::PGAsyncReadResult_PG_ASYNC_READ_FAIL
|
||||
}
|
||||
Some(NetEvent::Message(msg)) => {
|
||||
// got a message
|
||||
let b = match msg {
|
||||
desim::proto::AnyMessage::Bytes(b) => b,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
vec.extend_from_slice(&b);
|
||||
walproposer::bindings::PGAsyncReadResult_PG_ASYNC_READ_SUCCESS
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn conn_blocking_write(&self, sk: &mut walproposer::bindings::Safekeeper, buf: &[u8]) -> bool {
|
||||
let mut conn = self.get_conn(sk);
|
||||
debug!("conn_blocking_write to {}: {:?}", conn.node_id, buf);
|
||||
let socket = conn.socket.as_mut().unwrap();
|
||||
socket.send(desim::proto::AnyMessage::Bytes(Bytes::copy_from_slice(buf)));
|
||||
true
|
||||
}
|
||||
|
||||
fn conn_async_write(
|
||||
&self,
|
||||
sk: &mut walproposer::bindings::Safekeeper,
|
||||
buf: &[u8],
|
||||
) -> walproposer::bindings::PGAsyncWriteResult {
|
||||
let mut conn = self.get_conn(sk);
|
||||
debug!("conn_async_write to {}: {:?}", conn.node_id, buf);
|
||||
if let Some(socket) = conn.socket.as_mut() {
|
||||
socket.send(desim::proto::AnyMessage::Bytes(Bytes::copy_from_slice(buf)));
|
||||
} else {
|
||||
// connection is already closed
|
||||
debug!("conn_async_write: writing to a closed socket!");
|
||||
// TODO: maybe we should return error here?
|
||||
}
|
||||
walproposer::bindings::PGAsyncWriteResult_PG_ASYNC_WRITE_SUCCESS
|
||||
}
|
||||
|
||||
fn wal_reader_allocate(&self, _: &mut walproposer::bindings::Safekeeper) -> NeonWALReadResult {
|
||||
debug!("wal_reader_allocate");
|
||||
walproposer::bindings::NeonWALReadResult_NEON_WALREAD_SUCCESS
|
||||
}
|
||||
|
||||
fn wal_read(
|
||||
&self,
|
||||
_sk: &mut walproposer::bindings::Safekeeper,
|
||||
buf: &mut [u8],
|
||||
startpos: u64,
|
||||
) -> NeonWALReadResult {
|
||||
self.disk.lock().read(startpos, buf);
|
||||
walproposer::bindings::NeonWALReadResult_NEON_WALREAD_SUCCESS
|
||||
}
|
||||
|
||||
fn init_event_set(&self, _: &mut walproposer::bindings::WalProposer) {
|
||||
debug!("init_event_set");
|
||||
let new_event_set = EventSet::new(self.os.clone());
|
||||
let old_event_set = self.event_set.replace(Some(new_event_set));
|
||||
assert!(old_event_set.is_none());
|
||||
}
|
||||
|
||||
fn update_event_set(&self, sk: &mut walproposer::bindings::Safekeeper, event_mask: u32) {
|
||||
debug!(
|
||||
"update_event_set, sk={:?}, events_mask={:#b}",
|
||||
sk as *mut walproposer::bindings::Safekeeper, event_mask
|
||||
);
|
||||
let conn = self.get_conn(sk);
|
||||
|
||||
self.event_set
|
||||
.borrow_mut()
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.update_event_set(&conn, event_mask);
|
||||
}
|
||||
|
||||
fn add_safekeeper_event_set(
|
||||
&self,
|
||||
sk: &mut walproposer::bindings::Safekeeper,
|
||||
event_mask: u32,
|
||||
) {
|
||||
debug!(
|
||||
"add_safekeeper_event_set, sk={:?}, events_mask={:#b}",
|
||||
sk as *mut walproposer::bindings::Safekeeper, event_mask
|
||||
);
|
||||
|
||||
self.event_set
|
||||
.borrow_mut()
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.add_safekeeper(&self.get_conn(sk), event_mask);
|
||||
}
|
||||
|
||||
fn rm_safekeeper_event_set(&self, sk: &mut walproposer::bindings::Safekeeper) {
|
||||
debug!(
|
||||
"rm_safekeeper_event_set, sk={:?}",
|
||||
sk as *mut walproposer::bindings::Safekeeper,
|
||||
);
|
||||
|
||||
self.event_set
|
||||
.borrow_mut()
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.remove_safekeeper(&self.get_conn(sk));
|
||||
}
|
||||
|
||||
fn active_state_update_event_set(&self, sk: &mut walproposer::bindings::Safekeeper) {
|
||||
debug!("active_state_update_event_set");
|
||||
|
||||
assert!(sk.state == walproposer::bindings::SafekeeperState_SS_ACTIVE);
|
||||
self.event_set
|
||||
.borrow_mut()
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.refresh_event_set();
|
||||
}
|
||||
|
||||
fn wal_reader_events(&self, _sk: &mut walproposer::bindings::Safekeeper) -> u32 {
|
||||
0
|
||||
}
|
||||
|
||||
fn wait_event_set(
|
||||
&self,
|
||||
_: &mut walproposer::bindings::WalProposer,
|
||||
timeout_millis: i64,
|
||||
) -> walproposer::walproposer::WaitResult {
|
||||
// TODO: handle multiple stages as part of the simulation (e.g. connect, start_wal_push, etc)
|
||||
let mut conns = self.safekeepers.borrow_mut();
|
||||
for conn in conns.iter_mut() {
|
||||
if conn.socket.is_some() && conn.is_connecting {
|
||||
conn.is_connecting = false;
|
||||
debug!("wait_event_set, connecting to {}:{}", conn.host, conn.port);
|
||||
return walproposer::walproposer::WaitResult::Network(
|
||||
conn.raw_ptr,
|
||||
WL_SOCKET_READABLE | WL_SOCKET_WRITEABLE,
|
||||
);
|
||||
}
|
||||
if conn.socket.is_some() && conn.is_start_wal_push {
|
||||
conn.is_start_wal_push = false;
|
||||
debug!(
|
||||
"wait_event_set, start wal push to {}:{}",
|
||||
conn.host, conn.port
|
||||
);
|
||||
return walproposer::walproposer::WaitResult::Network(
|
||||
conn.raw_ptr,
|
||||
WL_SOCKET_READABLE,
|
||||
);
|
||||
}
|
||||
}
|
||||
drop(conns);
|
||||
|
||||
let res = self
|
||||
.event_set
|
||||
.borrow_mut()
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.wait(timeout_millis);
|
||||
|
||||
debug!(
|
||||
"wait_event_set, timeout_millis={}, res={:?}",
|
||||
timeout_millis, res,
|
||||
);
|
||||
res
|
||||
}
|
||||
|
||||
fn strong_random(&self, buf: &mut [u8]) -> bool {
|
||||
debug!("strong_random");
|
||||
buf.fill(0);
|
||||
true
|
||||
}
|
||||
|
||||
fn finish_sync_safekeepers(&self, lsn: u64) {
|
||||
debug!("finish_sync_safekeepers, lsn={}", lsn);
|
||||
executor::exit(0, Lsn(lsn).to_string());
|
||||
}
|
||||
|
||||
fn log_internal(&self, _wp: &mut walproposer::bindings::WalProposer, level: Level, msg: &str) {
|
||||
debug!("wp_log[{}] {}", level, msg);
|
||||
if level == Level::Fatal || level == Level::Panic {
|
||||
if msg.contains("rejects our connection request with term") {
|
||||
// collected quorum with lower term, then got rejected by next connected safekeeper
|
||||
executor::exit(1, msg.to_owned());
|
||||
}
|
||||
if msg.contains("collected propEpochStartLsn") && msg.contains(", but basebackup LSN ")
|
||||
{
|
||||
// sync-safekeepers collected wrong quorum, walproposer collected another quorum
|
||||
executor::exit(1, msg.to_owned());
|
||||
}
|
||||
if msg.contains("failed to download WAL for logical replicaiton") {
|
||||
// Recovery connection broken and recovery was failed
|
||||
executor::exit(1, msg.to_owned());
|
||||
}
|
||||
if msg.contains("missing majority of votes, collected") {
|
||||
// Voting bug when safekeeper disconnects after voting
|
||||
executor::exit(1, msg.to_owned());
|
||||
}
|
||||
panic!("unknown FATAL error from walproposer: {}", msg);
|
||||
}
|
||||
}
|
||||
|
||||
fn after_election(&self, wp: &mut walproposer::bindings::WalProposer) {
|
||||
let prop_lsn = wp.propEpochStartLsn;
|
||||
let prop_term = wp.propTerm;
|
||||
|
||||
let mut prev_lsn: u64 = 0;
|
||||
let mut prev_term: u64 = 0;
|
||||
|
||||
unsafe {
|
||||
let history = wp.propTermHistory.entries;
|
||||
let len = wp.propTermHistory.n_entries as usize;
|
||||
if len > 1 {
|
||||
let entry = *history.wrapping_add(len - 2);
|
||||
prev_lsn = entry.lsn;
|
||||
prev_term = entry.term;
|
||||
}
|
||||
}
|
||||
|
||||
let msg = format!(
|
||||
"prop_elected;{};{};{};{}",
|
||||
prop_lsn, prop_term, prev_lsn, prev_term
|
||||
);
|
||||
|
||||
debug!(msg);
|
||||
self.os.log_event(msg);
|
||||
}
|
||||
|
||||
fn get_redo_start_lsn(&self) -> u64 {
|
||||
debug!("get_redo_start_lsn -> {:?}", self.redo_start_lsn);
|
||||
self.redo_start_lsn.expect("redo_start_lsn is not set").0
|
||||
}
|
||||
|
||||
fn get_shmem_state(&self) -> *mut walproposer::bindings::WalproposerShmemState {
|
||||
self.shmem.get()
|
||||
}
|
||||
|
||||
fn start_streaming(
|
||||
&self,
|
||||
startpos: u64,
|
||||
callback: &walproposer::walproposer::StreamingCallback,
|
||||
) {
|
||||
let disk = &self.disk;
|
||||
let disk_lsn = disk.lock().flush_rec_ptr().0;
|
||||
debug!("start_streaming at {} (disk_lsn={})", startpos, disk_lsn);
|
||||
if startpos < disk_lsn {
|
||||
debug!("startpos < disk_lsn, it means we wrote some transaction even before streaming started");
|
||||
}
|
||||
assert!(startpos <= disk_lsn);
|
||||
let mut broadcasted = Lsn(startpos);
|
||||
|
||||
loop {
|
||||
let available = disk.lock().flush_rec_ptr();
|
||||
assert!(available >= broadcasted);
|
||||
callback.broadcast(broadcasted, available);
|
||||
broadcasted = available;
|
||||
callback.poll();
|
||||
}
|
||||
}
|
||||
|
||||
fn process_safekeeper_feedback(
|
||||
&self,
|
||||
wp: &mut walproposer::bindings::WalProposer,
|
||||
commit_lsn: u64,
|
||||
) {
|
||||
debug!("process_safekeeper_feedback, commit_lsn={}", commit_lsn);
|
||||
if commit_lsn > wp.lastSentCommitLsn {
|
||||
self.os.log_event(format!("commit_lsn;{}", commit_lsn));
|
||||
}
|
||||
}
|
||||
|
||||
fn get_flush_rec_ptr(&self) -> u64 {
|
||||
let lsn = self.disk.lock().flush_rec_ptr();
|
||||
debug!("get_flush_rec_ptr: {}", lsn);
|
||||
lsn.0
|
||||
}
|
||||
|
||||
fn recovery_download(
|
||||
&self,
|
||||
wp: &mut walproposer::bindings::WalProposer,
|
||||
sk: &mut walproposer::bindings::Safekeeper,
|
||||
) -> bool {
|
||||
let mut startpos = wp.truncateLsn;
|
||||
let endpos = wp.propEpochStartLsn;
|
||||
|
||||
if startpos == endpos {
|
||||
debug!("recovery_download: nothing to download");
|
||||
return true;
|
||||
}
|
||||
|
||||
debug!("recovery_download from {} to {}", startpos, endpos,);
|
||||
|
||||
let replication_prompt = format!(
|
||||
"START_REPLICATION {} {} {} {}",
|
||||
self.config.ttid.tenant_id, self.config.ttid.timeline_id, startpos, endpos,
|
||||
);
|
||||
let async_conn = self.get_conn(sk);
|
||||
|
||||
let conn = self.os.open_tcp(async_conn.node_id);
|
||||
conn.send(desim::proto::AnyMessage::Bytes(replication_prompt.into()));
|
||||
|
||||
let chan = conn.recv_chan();
|
||||
while startpos < endpos {
|
||||
let event = chan.recv();
|
||||
match event {
|
||||
NetEvent::Closed => {
|
||||
debug!("connection closed in recovery");
|
||||
break;
|
||||
}
|
||||
NetEvent::Message(AnyMessage::Bytes(b)) => {
|
||||
debug!("got recovery bytes from safekeeper");
|
||||
self.disk.lock().write(startpos, &b);
|
||||
startpos += b.len() as u64;
|
||||
}
|
||||
NetEvent::Message(_) => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
debug!("recovery finished at {}", startpos);
|
||||
|
||||
startpos == endpos
|
||||
}
|
||||
|
||||
fn conn_finish(&self, sk: &mut walproposer::bindings::Safekeeper) {
|
||||
let mut conn = self.get_conn(sk);
|
||||
debug!("conn_finish to {}", conn.node_id);
|
||||
if let Some(socket) = conn.socket.as_mut() {
|
||||
socket.close();
|
||||
} else {
|
||||
// connection is already closed
|
||||
}
|
||||
conn.socket = None;
|
||||
}
|
||||
|
||||
fn conn_error_message(&self, _sk: &mut walproposer::bindings::Safekeeper) -> String {
|
||||
"connection is closed, probably".into()
|
||||
}
|
||||
}
|
||||
314
safekeeper/tests/walproposer_sim/walproposer_disk.rs
Normal file
314
safekeeper/tests/walproposer_sim/walproposer_disk.rs
Normal file
@@ -0,0 +1,314 @@
|
||||
use std::{ffi::CString, sync::Arc};
|
||||
|
||||
use byteorder::{LittleEndian, WriteBytesExt};
|
||||
use crc32c::crc32c_append;
|
||||
use parking_lot::{Mutex, MutexGuard};
|
||||
use postgres_ffi::{
|
||||
pg_constants::{
|
||||
RM_LOGICALMSG_ID, XLOG_LOGICAL_MESSAGE, XLP_LONG_HEADER, XLR_BLOCK_ID_DATA_LONG,
|
||||
XLR_BLOCK_ID_DATA_SHORT,
|
||||
},
|
||||
v16::{
|
||||
wal_craft_test_export::{XLogLongPageHeaderData, XLogPageHeaderData, XLOG_PAGE_MAGIC},
|
||||
xlog_utils::{
|
||||
XLogSegNoOffsetToRecPtr, XlLogicalMessage, XLOG_RECORD_CRC_OFFS,
|
||||
XLOG_SIZE_OF_XLOG_LONG_PHD, XLOG_SIZE_OF_XLOG_RECORD, XLOG_SIZE_OF_XLOG_SHORT_PHD,
|
||||
XLP_FIRST_IS_CONTRECORD,
|
||||
},
|
||||
XLogRecord,
|
||||
},
|
||||
WAL_SEGMENT_SIZE, XLOG_BLCKSZ,
|
||||
};
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
use super::block_storage::BlockStorage;
|
||||
|
||||
/// Simulation implementation of walproposer WAL storage.
|
||||
pub struct DiskWalProposer {
|
||||
state: Mutex<State>,
|
||||
}
|
||||
|
||||
impl DiskWalProposer {
|
||||
pub fn new() -> Arc<DiskWalProposer> {
|
||||
Arc::new(DiskWalProposer {
|
||||
state: Mutex::new(State {
|
||||
internal_available_lsn: Lsn(0),
|
||||
prev_lsn: Lsn(0),
|
||||
disk: BlockStorage::new(),
|
||||
}),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn lock(&self) -> MutexGuard<State> {
|
||||
self.state.lock()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct State {
|
||||
// flush_lsn
|
||||
internal_available_lsn: Lsn,
|
||||
// needed for WAL generation
|
||||
prev_lsn: Lsn,
|
||||
// actual WAL storage
|
||||
disk: BlockStorage,
|
||||
}
|
||||
|
||||
impl State {
|
||||
pub fn read(&self, pos: u64, buf: &mut [u8]) {
|
||||
self.disk.read(pos, buf);
|
||||
// TODO: fail on reading uninitialized data
|
||||
}
|
||||
|
||||
pub fn write(&mut self, pos: u64, buf: &[u8]) {
|
||||
self.disk.write(pos, buf);
|
||||
}
|
||||
|
||||
/// Update the internal available LSN to the given value.
|
||||
pub fn reset_to(&mut self, lsn: Lsn) {
|
||||
self.internal_available_lsn = lsn;
|
||||
}
|
||||
|
||||
/// Get current LSN.
|
||||
pub fn flush_rec_ptr(&self) -> Lsn {
|
||||
self.internal_available_lsn
|
||||
}
|
||||
|
||||
/// Generate a new WAL record at the current LSN.
|
||||
pub fn insert_logical_message(&mut self, prefix: &str, msg: &[u8]) -> anyhow::Result<()> {
|
||||
let prefix_cstr = CString::new(prefix)?;
|
||||
let prefix_bytes = prefix_cstr.as_bytes_with_nul();
|
||||
|
||||
let lm = XlLogicalMessage {
|
||||
db_id: 0,
|
||||
transactional: 0,
|
||||
prefix_size: prefix_bytes.len() as ::std::os::raw::c_ulong,
|
||||
message_size: msg.len() as ::std::os::raw::c_ulong,
|
||||
};
|
||||
|
||||
let record_bytes = lm.encode();
|
||||
let rdatas: Vec<&[u8]> = vec![&record_bytes, prefix_bytes, msg];
|
||||
insert_wal_record(self, rdatas, RM_LOGICALMSG_ID, XLOG_LOGICAL_MESSAGE)
|
||||
}
|
||||
}
|
||||
|
||||
fn insert_wal_record(
|
||||
state: &mut State,
|
||||
rdatas: Vec<&[u8]>,
|
||||
rmid: u8,
|
||||
info: u8,
|
||||
) -> anyhow::Result<()> {
|
||||
// bytes right after the header, in the same rdata block
|
||||
let mut scratch = Vec::new();
|
||||
let mainrdata_len: usize = rdatas.iter().map(|rdata| rdata.len()).sum();
|
||||
|
||||
if mainrdata_len > 0 {
|
||||
if mainrdata_len > 255 {
|
||||
scratch.push(XLR_BLOCK_ID_DATA_LONG);
|
||||
// TODO: verify endiness
|
||||
let _ = scratch.write_u32::<LittleEndian>(mainrdata_len as u32);
|
||||
} else {
|
||||
scratch.push(XLR_BLOCK_ID_DATA_SHORT);
|
||||
scratch.push(mainrdata_len as u8);
|
||||
}
|
||||
}
|
||||
|
||||
let total_len: u32 = (XLOG_SIZE_OF_XLOG_RECORD + scratch.len() + mainrdata_len) as u32;
|
||||
let size = maxalign(total_len);
|
||||
assert!(size as usize > XLOG_SIZE_OF_XLOG_RECORD);
|
||||
|
||||
let start_bytepos = recptr_to_bytepos(state.internal_available_lsn);
|
||||
let end_bytepos = start_bytepos + size as u64;
|
||||
|
||||
let start_recptr = bytepos_to_recptr(start_bytepos);
|
||||
let end_recptr = bytepos_to_recptr(end_bytepos);
|
||||
|
||||
assert!(recptr_to_bytepos(start_recptr) == start_bytepos);
|
||||
assert!(recptr_to_bytepos(end_recptr) == end_bytepos);
|
||||
|
||||
let mut crc = crc32c_append(0, &scratch);
|
||||
for rdata in &rdatas {
|
||||
crc = crc32c_append(crc, rdata);
|
||||
}
|
||||
|
||||
let mut header = XLogRecord {
|
||||
xl_tot_len: total_len,
|
||||
xl_xid: 0,
|
||||
xl_prev: state.prev_lsn.0,
|
||||
xl_info: info,
|
||||
xl_rmid: rmid,
|
||||
__bindgen_padding_0: [0u8; 2usize],
|
||||
xl_crc: crc,
|
||||
};
|
||||
|
||||
// now we have the header and can finish the crc
|
||||
let header_bytes = header.encode()?;
|
||||
let crc = crc32c_append(crc, &header_bytes[0..XLOG_RECORD_CRC_OFFS]);
|
||||
header.xl_crc = crc;
|
||||
|
||||
let mut header_bytes = header.encode()?.to_vec();
|
||||
assert!(header_bytes.len() == XLOG_SIZE_OF_XLOG_RECORD);
|
||||
|
||||
header_bytes.extend_from_slice(&scratch);
|
||||
|
||||
// finish rdatas
|
||||
let mut rdatas = rdatas;
|
||||
rdatas.insert(0, &header_bytes);
|
||||
|
||||
write_walrecord_to_disk(state, total_len as u64, rdatas, start_recptr, end_recptr)?;
|
||||
|
||||
state.internal_available_lsn = end_recptr;
|
||||
state.prev_lsn = start_recptr;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_walrecord_to_disk(
|
||||
state: &mut State,
|
||||
total_len: u64,
|
||||
rdatas: Vec<&[u8]>,
|
||||
start: Lsn,
|
||||
end: Lsn,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut curr_ptr = start;
|
||||
let mut freespace = insert_freespace(curr_ptr);
|
||||
let mut written: usize = 0;
|
||||
|
||||
assert!(freespace >= std::mem::size_of::<u32>());
|
||||
|
||||
for mut rdata in rdatas {
|
||||
while rdata.len() >= freespace {
|
||||
assert!(
|
||||
curr_ptr.segment_offset(WAL_SEGMENT_SIZE) >= XLOG_SIZE_OF_XLOG_SHORT_PHD
|
||||
|| freespace == 0
|
||||
);
|
||||
|
||||
state.write(curr_ptr.0, &rdata[..freespace]);
|
||||
rdata = &rdata[freespace..];
|
||||
written += freespace;
|
||||
curr_ptr = Lsn(curr_ptr.0 + freespace as u64);
|
||||
|
||||
let mut new_page = XLogPageHeaderData {
|
||||
xlp_magic: XLOG_PAGE_MAGIC as u16,
|
||||
xlp_info: XLP_BKP_REMOVABLE,
|
||||
xlp_tli: 1,
|
||||
xlp_pageaddr: curr_ptr.0,
|
||||
xlp_rem_len: (total_len - written as u64) as u32,
|
||||
..Default::default() // Put 0 in padding fields.
|
||||
};
|
||||
if new_page.xlp_rem_len > 0 {
|
||||
new_page.xlp_info |= XLP_FIRST_IS_CONTRECORD;
|
||||
}
|
||||
|
||||
if curr_ptr.segment_offset(WAL_SEGMENT_SIZE) == 0 {
|
||||
new_page.xlp_info |= XLP_LONG_HEADER;
|
||||
let long_page = XLogLongPageHeaderData {
|
||||
std: new_page,
|
||||
xlp_sysid: 0,
|
||||
xlp_seg_size: WAL_SEGMENT_SIZE as u32,
|
||||
xlp_xlog_blcksz: XLOG_BLCKSZ as u32,
|
||||
};
|
||||
let header_bytes = long_page.encode()?;
|
||||
assert!(header_bytes.len() == XLOG_SIZE_OF_XLOG_LONG_PHD);
|
||||
state.write(curr_ptr.0, &header_bytes);
|
||||
curr_ptr = Lsn(curr_ptr.0 + header_bytes.len() as u64);
|
||||
} else {
|
||||
let header_bytes = new_page.encode()?;
|
||||
assert!(header_bytes.len() == XLOG_SIZE_OF_XLOG_SHORT_PHD);
|
||||
state.write(curr_ptr.0, &header_bytes);
|
||||
curr_ptr = Lsn(curr_ptr.0 + header_bytes.len() as u64);
|
||||
}
|
||||
freespace = insert_freespace(curr_ptr);
|
||||
}
|
||||
|
||||
assert!(
|
||||
curr_ptr.segment_offset(WAL_SEGMENT_SIZE) >= XLOG_SIZE_OF_XLOG_SHORT_PHD
|
||||
|| rdata.is_empty()
|
||||
);
|
||||
state.write(curr_ptr.0, rdata);
|
||||
curr_ptr = Lsn(curr_ptr.0 + rdata.len() as u64);
|
||||
written += rdata.len();
|
||||
freespace -= rdata.len();
|
||||
}
|
||||
|
||||
assert!(written == total_len as usize);
|
||||
curr_ptr.0 = maxalign(curr_ptr.0);
|
||||
assert!(curr_ptr == end);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn maxalign<T>(size: T) -> T
|
||||
where
|
||||
T: std::ops::BitAnd<Output = T>
|
||||
+ std::ops::Add<Output = T>
|
||||
+ std::ops::Not<Output = T>
|
||||
+ From<u8>,
|
||||
{
|
||||
(size + T::from(7)) & !T::from(7)
|
||||
}
|
||||
|
||||
fn insert_freespace(ptr: Lsn) -> usize {
|
||||
if ptr.block_offset() == 0 {
|
||||
0
|
||||
} else {
|
||||
(XLOG_BLCKSZ as u64 - ptr.block_offset()) as usize
|
||||
}
|
||||
}
|
||||
|
||||
const XLP_BKP_REMOVABLE: u16 = 0x0004;
|
||||
const USABLE_BYTES_IN_PAGE: u64 = (XLOG_BLCKSZ - XLOG_SIZE_OF_XLOG_SHORT_PHD) as u64;
|
||||
const USABLE_BYTES_IN_SEGMENT: u64 = ((WAL_SEGMENT_SIZE / XLOG_BLCKSZ) as u64
|
||||
* USABLE_BYTES_IN_PAGE)
|
||||
- (XLOG_SIZE_OF_XLOG_RECORD - XLOG_SIZE_OF_XLOG_SHORT_PHD) as u64;
|
||||
|
||||
fn bytepos_to_recptr(bytepos: u64) -> Lsn {
|
||||
let fullsegs = bytepos / USABLE_BYTES_IN_SEGMENT;
|
||||
let mut bytesleft = bytepos % USABLE_BYTES_IN_SEGMENT;
|
||||
|
||||
let seg_offset = if bytesleft < (XLOG_BLCKSZ - XLOG_SIZE_OF_XLOG_SHORT_PHD) as u64 {
|
||||
// fits on first page of segment
|
||||
bytesleft + XLOG_SIZE_OF_XLOG_SHORT_PHD as u64
|
||||
} else {
|
||||
// account for the first page on segment with long header
|
||||
bytesleft -= (XLOG_BLCKSZ - XLOG_SIZE_OF_XLOG_SHORT_PHD) as u64;
|
||||
let fullpages = bytesleft / USABLE_BYTES_IN_PAGE;
|
||||
bytesleft %= USABLE_BYTES_IN_PAGE;
|
||||
|
||||
XLOG_BLCKSZ as u64
|
||||
+ fullpages * XLOG_BLCKSZ as u64
|
||||
+ bytesleft
|
||||
+ XLOG_SIZE_OF_XLOG_SHORT_PHD as u64
|
||||
};
|
||||
|
||||
Lsn(XLogSegNoOffsetToRecPtr(
|
||||
fullsegs,
|
||||
seg_offset as u32,
|
||||
WAL_SEGMENT_SIZE,
|
||||
))
|
||||
}
|
||||
|
||||
fn recptr_to_bytepos(ptr: Lsn) -> u64 {
|
||||
let fullsegs = ptr.segment_number(WAL_SEGMENT_SIZE);
|
||||
let offset = ptr.segment_offset(WAL_SEGMENT_SIZE) as u64;
|
||||
|
||||
let fullpages = offset / XLOG_BLCKSZ as u64;
|
||||
let offset = offset % XLOG_BLCKSZ as u64;
|
||||
|
||||
if fullpages == 0 {
|
||||
fullsegs * USABLE_BYTES_IN_SEGMENT
|
||||
+ if offset > 0 {
|
||||
assert!(offset >= XLOG_SIZE_OF_XLOG_SHORT_PHD as u64);
|
||||
offset - XLOG_SIZE_OF_XLOG_SHORT_PHD as u64
|
||||
} else {
|
||||
0
|
||||
}
|
||||
} else {
|
||||
fullsegs * USABLE_BYTES_IN_SEGMENT
|
||||
+ (XLOG_BLCKSZ - XLOG_SIZE_OF_XLOG_SHORT_PHD) as u64
|
||||
+ (fullpages - 1) * USABLE_BYTES_IN_PAGE
|
||||
+ if offset > 0 {
|
||||
assert!(offset >= XLOG_SIZE_OF_XLOG_SHORT_PHD as u64);
|
||||
offset - XLOG_SIZE_OF_XLOG_SHORT_PHD as u64
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -96,5 +96,6 @@ PAGESERVER_PER_TENANT_METRICS: Tuple[str, ...] = (
|
||||
"pageserver_evictions_total",
|
||||
"pageserver_evictions_with_low_residence_duration_total",
|
||||
*PAGESERVER_PER_TENANT_REMOTE_TIMELINE_CLIENT_METRICS,
|
||||
# "pageserver_directory_entries_count", -- only used if above a certain threshold
|
||||
# "pageserver_broken_tenants_count" -- used only for broken
|
||||
)
|
||||
|
||||
@@ -899,7 +899,7 @@ class NeonEnvBuilder:
|
||||
|
||||
if self.scrub_on_exit:
|
||||
try:
|
||||
S3Scrubber(self.test_output_dir, self).scan_metadata()
|
||||
S3Scrubber(self).scan_metadata()
|
||||
except Exception as e:
|
||||
log.error(f"Error during remote storage scrub: {e}")
|
||||
cleanup_error = e
|
||||
@@ -3659,9 +3659,9 @@ class SafekeeperHttpClient(requests.Session):
|
||||
|
||||
|
||||
class S3Scrubber:
|
||||
def __init__(self, log_dir: Path, env: NeonEnvBuilder):
|
||||
def __init__(self, env: NeonEnvBuilder, log_dir: Optional[Path] = None):
|
||||
self.env = env
|
||||
self.log_dir = log_dir
|
||||
self.log_dir = log_dir or env.test_output_dir
|
||||
|
||||
def scrubber_cli(self, args: list[str], timeout) -> str:
|
||||
assert isinstance(self.env.pageserver_remote_storage, S3Storage)
|
||||
@@ -3682,7 +3682,7 @@ class S3Scrubber:
|
||||
args = base_args + args
|
||||
|
||||
(output_path, stdout, status_code) = subprocess_capture(
|
||||
self.log_dir,
|
||||
self.env.test_output_dir,
|
||||
args,
|
||||
echo_stderr=True,
|
||||
echo_stdout=True,
|
||||
@@ -3967,27 +3967,24 @@ def list_files_to_compare(pgdata_dir: Path) -> List[str]:
|
||||
|
||||
# pg is the existing and running compute node, that we want to compare with a basebackup
|
||||
def check_restored_datadir_content(test_output_dir: Path, env: NeonEnv, endpoint: Endpoint):
|
||||
pg_bin = PgBin(test_output_dir, env.pg_distrib_dir, env.pg_version)
|
||||
|
||||
# Get the timeline ID. We need it for the 'basebackup' command
|
||||
timeline_id = TimelineId(endpoint.safe_psql("SHOW neon.timeline_id")[0][0])
|
||||
|
||||
# many tests already checkpoint, but do it just in case
|
||||
with closing(endpoint.connect()) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("CHECKPOINT")
|
||||
|
||||
# wait for pageserver to catch up
|
||||
wait_for_last_flush_lsn(env, endpoint, endpoint.tenant_id, timeline_id)
|
||||
# stop postgres to ensure that files won't change
|
||||
endpoint.stop()
|
||||
|
||||
# Read the shutdown checkpoint's LSN
|
||||
pg_controldata_path = os.path.join(pg_bin.pg_bin_path, "pg_controldata")
|
||||
cmd = f"{pg_controldata_path} -D {endpoint.pgdata_dir}"
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, shell=True)
|
||||
checkpoint_lsn = re.findall(
|
||||
"Latest checkpoint location:\\s+([0-9A-F]+/[0-9A-F]+)", result.stdout
|
||||
)[0]
|
||||
log.debug(f"last checkpoint at {checkpoint_lsn}")
|
||||
|
||||
# Take a basebackup from pageserver
|
||||
restored_dir_path = env.repo_dir / f"{endpoint.endpoint_id}_restored_datadir"
|
||||
restored_dir_path.mkdir(exist_ok=True)
|
||||
|
||||
pg_bin = PgBin(test_output_dir, env.pg_distrib_dir, env.pg_version)
|
||||
psql_path = os.path.join(pg_bin.pg_bin_path, "psql")
|
||||
|
||||
pageserver_id = env.attachment_service.locate(endpoint.tenant_id)[0]["node_id"]
|
||||
@@ -3995,7 +3992,7 @@ def check_restored_datadir_content(test_output_dir: Path, env: NeonEnv, endpoint
|
||||
{psql_path} \
|
||||
--no-psqlrc \
|
||||
postgres://localhost:{env.get_pageserver(pageserver_id).service_port.pg} \
|
||||
-c 'basebackup {endpoint.tenant_id} {timeline_id} {checkpoint_lsn}' \
|
||||
-c 'basebackup {endpoint.tenant_id} {timeline_id}' \
|
||||
| tar -x -C {restored_dir_path}
|
||||
"""
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user