mirror of https://github.com/neondatabase/neon.git
synced 2026-02-10 22:20:38 +00:00

Compare commits: bojan/prox...bojan/slow (2 commits)

Commits: 800811957b, d650e42ae3

Cargo.lock (generated, 98 lines changed)
@@ -23,17 +23,6 @@ version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "739f4a8db6605981345c5654f3a85b056ce52f37a39d34da03f25bf2151ea16e"

[[package]]
name = "ahash"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47"
dependencies = [
 "getrandom",
 "once_cell",
 "version_check",
]

[[package]]
name = "aho-corasick"
version = "0.7.18"

@@ -552,17 +541,6 @@ dependencies = [
 "termcolor",
]

[[package]]
name = "fail"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec3245a0ca564e7f3c797d20d833a6870f57a728ac967d5225b3ffdef4465011"
dependencies = [
 "lazy_static",
 "log",
 "rand",
]

[[package]]
name = "fallible-iterator"
version = "0.2.0"

@@ -791,7 +769,7 @@ version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7afe4a420e3fe79967a00898cc1f4db7c8a49a9333a29f8a4bd76a253d5cd04"
dependencies = [
 "ahash 0.4.7",
 "ahash",
]

[[package]]

@@ -799,9 +777,6 @@ name = "hashbrown"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e"
dependencies = [
 "ahash 0.7.6",
]

[[package]]
name = "hermit-abi"

@@ -921,7 +896,7 @@ dependencies = [
 "hyper",
 "rustls 0.20.2",
 "tokio",
 "tokio-rustls 0.23.2",
 "tokio-rustls",
]

[[package]]

@@ -1006,7 +981,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "afabcc15e437a6484fc4f12d0fd63068fe457bf93f1c148d3d9649c60b103f32"
dependencies = [
 "base64 0.12.3",
 "pem 0.8.3",
 "pem",
 "ring",
 "serde",
 "serde_json",

@@ -1300,7 +1275,6 @@ dependencies = [
 "crc32c",
 "crossbeam-utils",
 "daemonize",
 "fail",
 "futures",
 "hex",
 "hex-literal",

@@ -1378,15 +1352,6 @@ dependencies = [
 "regex",
]

[[package]]
name = "pem"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9a3b09a20e374558580a4914d3b7d89bd61b954a5a5e1dcbea98753addb1947"
dependencies = [
 "base64 0.13.0",
]

[[package]]
name = "percent-encoding"
version = "2.1.0"

@@ -1591,25 +1556,17 @@ dependencies = [
 "anyhow",
 "bytes",
 "clap 3.0.14",
 "futures",
 "hashbrown 0.11.2",
 "hex",
 "hyper",
 "lazy_static",
 "md5",
 "parking_lot",
 "pin-project-lite",
 "rand",
 "rcgen",
 "reqwest",
 "rustls 0.19.1",
 "scopeguard",
 "serde",
 "serde_json",
 "tokio",
 "tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)",
 "tokio-postgres-rustls",
 "tokio-rustls 0.22.0",
 "zenith_metrics",
 "zenith_utils",
]

@@ -1663,18 +1620,6 @@ dependencies = [
 "rand_core",
]

[[package]]
name = "rcgen"
version = "0.8.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5911d1403f4143c9d56a702069d593e8d0f3fab880a85e103604d0893ea31ba7"
dependencies = [
 "chrono",
 "pem 1.0.2",
 "ring",
 "yasna",
]

[[package]]
name = "redox_syscall"
version = "0.2.10"

@@ -1758,7 +1703,7 @@ dependencies = [
 "serde_json",
 "serde_urlencoded",
 "tokio",
 "tokio-rustls 0.23.2",
 "tokio-rustls",
 "tokio-util",
 "url",
 "wasm-bindgen",

@@ -2320,32 +2265,6 @@ dependencies = [
 "tokio-util",
]

[[package]]
name = "tokio-postgres-rustls"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7bd8c37d8c23cb6ecdc32fc171bade4e9c7f1be65f693a17afbaad02091a0a19"
dependencies = [
 "futures",
 "ring",
 "rustls 0.19.1",
 "tokio",
 "tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)",
 "tokio-rustls 0.22.0",
 "webpki 0.21.4",
]

[[package]]
name = "tokio-rustls"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc6844de72e57df1980054b38be3a9f4702aba4858be64dd700181a8a6d0e1b6"
dependencies = [
 "rustls 0.19.1",
 "tokio",
 "webpki 0.21.4",
]

[[package]]
name = "tokio-rustls"
version = "0.23.2"

@@ -2811,15 +2730,6 @@ version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2d7d3948613f75c98fd9328cfdcc45acc4d360655289d0a7d4ec931392200a3"

[[package]]
name = "yasna"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e262a29d0e61ccf2b6190d7050d4b237535fc76ce4c1210d9caa316f71dffa75"
dependencies = [
 "chrono",
]

[[package]]
name = "zenith"
version = "0.1.0"
@@ -16,8 +16,3 @@ members = [
# This is useful for profiling and, to some extent, debugging.
# Besides, debug info should not affect the performance.
debug = true

# This is only needed for proxy's tests
# TODO: we should probably fork tokio-postgres-rustls instead
[patch.crates-io]
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
@@ -334,26 +334,14 @@ impl PostgresNode {
        if let Some(lsn) = self.lsn {
            conf.append("recovery_target_lsn", &lsn.to_string());
        }

        conf.append_line("");
        // Configure backpressure
        // - Replication write lag depends on how fast the walreceiver can process incoming WAL.
        //   This lag determines latency of get_page_at_lsn. Speed of applying WAL is about 10MB/sec,
        //   so to avoid expiration of 1 minute timeout, this lag should not be larger than 600MB.
        //   Actually latency should be much smaller (better if < 1sec). But we assume that recently
        //   updated pages are not requested from pageserver.
        // - Replication flush lag depends on speed of persisting data by checkpointer (creation of
        //   delta/image layers) and advancing disk_consistent_lsn. Safekeepers are able to
        //   remove/archive WAL only beyond disk_consistent_lsn. Too large a lag can cause long
        //   recovery time (in case of pageserver crash) and disk space overflow at safekeepers.
        // - Replication apply lag depends on speed of uploading changes to S3 by uploader thread.
        //   To be able to restore database in case of pageserver node crash, safekeeper should not
        //   remove WAL beyond this point. Too large a lag can cause space exhaustion in safekeepers
        //   (if they are not able to upload WAL to S3).
        conf.append("max_replication_write_lag", "500MB");
        conf.append("max_replication_flush_lag", "10GB");

        if !self.env.safekeepers.is_empty() {
            // Configure backpressure
            // In setup with safekeepers apply_lag depends on
            // speed of data checkpointing on pageserver (see disk_consistent_lsn).
            conf.append("max_replication_apply_lag", "1500MB");

            // Configure the node to connect to the safekeepers
            conf.append("synchronous_standby_names", "walproposer");

@@ -366,6 +354,11 @@ impl PostgresNode {
                .join(",");
            conf.append("wal_acceptors", &wal_acceptors);
        } else {
            // Configure backpressure
            // In setup without safekeepers, flush_lag depends on
            // speed of data checkpointing on pageserver (see disk_consistent_lsn)
            conf.append("max_replication_flush_lag", "1500MB");

            // We only use setup without safekeepers for tests,
            // and don't care about data durability on pageserver,
            // so set more relaxed synchronous_commit.
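
For reference, the backpressure knobs appended above end up in the compute node's postgresql.conf roughly as follows in the safekeeper case (a sketch; the no-safekeeper branch instead overrides max_replication_flush_lag to 1500MB):

    max_replication_write_lag = 500MB
    max_replication_flush_lag = 10GB
    max_replication_apply_lag = 1500MB
    synchronous_standby_names = walproposer
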
@@ -41,7 +41,6 @@ url = "2"
nix = "0.23"
once_cell = "1.8.0"
crossbeam-utils = "0.8.5"
fail = "0.5.0"

rust-s3 = { version = "0.28", default-features = false, features = ["no-verify-ssl", "tokio-rustls-tls"] }
async-compression = {version = "0.3", features = ["zstd", "tokio"]}
@@ -234,7 +234,9 @@ paths:
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/BranchInfo"
                type: array
                items:
                  $ref: "#/components/schemas/BranchInfo"
        "400":
          description: Malformed branch create request
          content:

@@ -368,15 +370,12 @@ components:
          format: hex
        ancestor_id:
          type: string
          format: hex
        ancestor_lsn:
          type: string
        current_logical_size:
          type: integer
        current_logical_size_non_incremental:
          type: integer
        latest_valid_lsn:
          type: integer
    TimelineInfo:
      type: object
      required:
@@ -27,10 +27,13 @@ use zenith_utils::lsn::Lsn;
use zenith_utils::postgres_backend::is_socket_read_timed_out;
use zenith_utils::postgres_backend::PostgresBackend;
use zenith_utils::postgres_backend::{self, AuthType};
use zenith_utils::pq_proto::{BeMessage, FeMessage, RowDescriptor, SINGLE_COL_ROWDESC};
use zenith_utils::pq_proto::{
    BeMessage, FeMessage, RowDescriptor, HELLO_WORLD_ROW, SINGLE_COL_ROWDESC,
};
use zenith_utils::zid::{ZTenantId, ZTimelineId};

use crate::basebackup;
use crate::branches;
use crate::config::PageServerConf;
use crate::relish::*;
use crate::repository::Timeline;

@@ -659,21 +662,79 @@ impl postgres_backend::Handler for PageServerHandler {
            walreceiver::launch_wal_receiver(self.conf, tenantid, timelineid, &connstr)?;

            pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
        } else if query_string.starts_with("branch_create ") {
            let err = || format!("invalid branch_create: '{}'", query_string);

            // branch_create <tenantid> <branchname> <startpoint>
            // TODO lazy static
            // TODO: escaping, to allow branch names with spaces
            let re = Regex::new(r"^branch_create ([[:xdigit:]]+) (\S+) ([^\r\n\s;]+)[\r\n\s;]*;?$")
                .unwrap();
            let caps = re.captures(query_string).with_context(err)?;

            let tenantid = ZTenantId::from_str(caps.get(1).unwrap().as_str())?;
            let branchname = caps.get(2).with_context(err)?.as_str().to_owned();
            let startpoint_str = caps.get(3).with_context(err)?.as_str().to_owned();

            self.check_permission(Some(tenantid))?;

            let _enter =
                info_span!("branch_create", name = %branchname, tenant = %tenantid).entered();

            let branch =
                branches::create_branch(self.conf, &branchname, &startpoint_str, &tenantid)?;
            let branch = serde_json::to_vec(&branch)?;

            pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
                .write_message_noflush(&BeMessage::DataRow(&[Some(&branch)]))?
                .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
        } else if query_string.starts_with("branch_list ") {
            // branch_list <zenith tenantid as hex string>
            let re = Regex::new(r"^branch_list ([[:xdigit:]]+)$").unwrap();
            let caps = re
                .captures(query_string)
                .with_context(|| format!("invalid branch_list: '{}'", query_string))?;

            let tenantid = ZTenantId::from_str(caps.get(1).unwrap().as_str())?;

            // since these handlers for tenant/branch commands are deprecated (in favor of http based ones)
            // just use false in place of include non incremental logical size
            let branches = crate::branches::get_branches(self.conf, &tenantid, false)?;
            let branches_buf = serde_json::to_vec(&branches)?;

            pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
                .write_message_noflush(&BeMessage::DataRow(&[Some(&branches_buf)]))?
                .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
        } else if query_string.starts_with("tenant_list") {
            let tenants = crate::tenant_mgr::list_tenants()?;
            let tenants_buf = serde_json::to_vec(&tenants)?;

            pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
                .write_message_noflush(&BeMessage::DataRow(&[Some(&tenants_buf)]))?
                .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
        } else if query_string.starts_with("tenant_create") {
            let err = || format!("invalid tenant_create: '{}'", query_string);

            // tenant_create <tenantid>
            let re = Regex::new(r"^tenant_create ([[:xdigit:]]+)$").unwrap();
            let caps = re.captures(query_string).with_context(err)?;

            self.check_permission(None)?;

            let tenantid = ZTenantId::from_str(caps.get(1).unwrap().as_str())?;

            tenant_mgr::create_repository_for_tenant(self.conf, tenantid)?;

            pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
                .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
        } else if query_string.starts_with("status") {
            pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
                .write_message_noflush(&HELLO_WORLD_ROW)?
                .write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
        } else if query_string.to_ascii_lowercase().starts_with("set ") {
            // important because psycopg2 executes "SET datestyle TO 'ISO'"
            // on connect
            pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
        } else if query_string.starts_with("failpoints ") {
            let (_, failpoints) = query_string.split_at("failpoints ".len());
            for failpoint in failpoints.split(';') {
                if let Some((name, actions)) = failpoint.split_once('=') {
                    info!("cfg failpoint: {} {}", name, actions);
                    fail::cfg(name, actions).unwrap();
                } else {
                    bail!("Invalid failpoints format");
                }
            }
            pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
        } else if query_string.starts_with("do_gc ") {
            // Run GC immediately on given timeline.
            // FIXME: This is just for tests. See test_runner/batch_others/test_gc.py.

@@ -12,7 +12,6 @@ use crate::thread_mgr::ThreadKind;
use crate::walingest::WalIngest;
use anyhow::{bail, Context, Error, Result};
use bytes::BytesMut;
use fail::fail_point;
use lazy_static::lazy_static;
use postgres_ffi::waldecoder::*;
use postgres_protocol::message::backend::ReplicationMessage;

@@ -32,7 +31,6 @@ use zenith_utils::lsn::Lsn;
use zenith_utils::pq_proto::ZenithFeedback;
use zenith_utils::zid::ZTenantId;
use zenith_utils::zid::ZTimelineId;

//
// We keep one WAL Receiver active per timeline.
//

@@ -256,8 +254,6 @@ fn walreceiver_main(
                let writer = timeline.writer();
                walingest.ingest_record(writer.as_ref(), recdata, lsn)?;

                fail_point!("walreceiver-after-ingest");

                last_rec_lsn = lsn;
            }
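
The failpoints handler above is how tests toggle fail points at runtime: the command travels as an ordinary query string over the pageserver's libpq port, one name=actions pair per failpoint, joined with ';'. A hedged example (the actions — return, sleep(N), panic, off — are standard `fail` crate actions; the psql connection flags depend on your setup):

    psql -c "failpoints walreceiver-after-ingest=sleep(1000)"

Issuing the same command with `off` as the action restores normal behavior.
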
poetry.lock (generated, 27 lines changed)
@@ -814,7 +814,7 @@ python-versions = "*"

[[package]]
name = "moto"
version = "3.0.0"
version = "3.0.3"
description = "A library that allows your python tests to easily mock out the boto library"
category = "main"
optional = false

@@ -849,6 +849,7 @@ xmltodict = "*"
[package.extras]
all = ["PyYAML (>=5.1)", "python-jose[cryptography] (>=3.1.0,<4.0.0)", "ecdsa (!=0.15)", "docker (>=2.5.1)", "graphql-core", "jsondiff (>=1.1.2)", "aws-xray-sdk (>=0.93,!=0.96)", "idna (>=2.5,<4)", "cfn-lint (>=0.4.0)", "sshpubkeys (>=3.1.0)", "setuptools"]
apigateway = ["python-jose[cryptography] (>=3.1.0,<4.0.0)", "ecdsa (!=0.15)"]
apigatewayv2 = ["PyYAML (>=5.1)"]
appsync = ["graphql-core"]
awslambda = ["docker (>=2.5.1)"]
batch = ["docker (>=2.5.1)"]

@@ -1059,6 +1060,20 @@ python-versions = ">=3.6"
py = "*"
pytest = ">=3.10"

[[package]]
name = "pytest-skip-slow"
version = "0.0.2"
description = "A pytest plugin to skip `@pytest.mark.slow` tests by default. "
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
pytest = ">=6.2.0"

[package.extras]
test = ["tox"]

[[package]]
name = "pytest-xdist"
version = "2.5.0"

@@ -1352,7 +1367,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
[metadata]
lock-version = "1.1"
python-versions = "^3.7"
content-hash = "0fa6c9377fbc827240d18d8b7e3742def37e90fc3277fddf8525d82dabd13090"
content-hash = "d59ea97fb78d13dcede2719734c55b7d6b9dcc7f86001a9228c8808ed6e58eb7"

[metadata.files]
aiopg = [

@@ -1666,8 +1681,8 @@ mccabe = [
    {file = "mccabe-0.6.1.tar.gz", hash = "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f"},
]
moto = [
    {file = "moto-3.0.0-py2.py3-none-any.whl", hash = "sha256:762d33bbad3642c687f6495e69331318bef43f9aa662174397706ec3ad2a3578"},
    {file = "moto-3.0.0.tar.gz", hash = "sha256:d6b00a2663290e7ebb06823d5ffcb124c8dc9bf526b878539ef7c4a377fd8255"},
    {file = "moto-3.0.3-py2.py3-none-any.whl", hash = "sha256:445a574395b8a43a249ae0f932bf10c5cc677054198bfa1ff92e6fbd60e72c38"},
    {file = "moto-3.0.3.tar.gz", hash = "sha256:fa3fbdc22c55d7e70b407e2f2639c48ac82b074f472b167609405c0c1e3a2ccb"},
]
mypy = [
    {file = "mypy-0.910-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:a155d80ea6cee511a3694b108c4494a39f42de11ee4e61e72bc424c490e46457"},

@@ -1842,6 +1857,10 @@ pytest-forked = [
    {file = "pytest-forked-1.4.0.tar.gz", hash = "sha256:8b67587c8f98cbbadfdd804539ed5455b6ed03802203485dd2f53c1422d7440e"},
    {file = "pytest_forked-1.4.0-py3-none-any.whl", hash = "sha256:bbbb6717efc886b9d64537b41fb1497cfaf3c9601276be8da2cccfea5a3c8ad8"},
]
pytest-skip-slow = [
    {file = "pytest-skip-slow-0.0.2.tar.gz", hash = "sha256:06b8353d8b1e2168c22b8c34172329b4f592eea648dac141099cc881fd466718"},
    {file = "pytest_skip_slow-0.0.2-py3-none-any.whl", hash = "sha256:9029eb8c258e6c3c4b499848d864c9e9c48499d8b8bf5bb413044f59874cfd06"},
]
pytest-xdist = [
    {file = "pytest-xdist-2.5.0.tar.gz", hash = "sha256:4580deca3ff04ddb2ac53eba39d76cb5dd5edeac050cb6fbc768b0dd712b4edf"},
    {file = "pytest_xdist-2.5.0-py3-none-any.whl", hash = "sha256:6fe5c74fec98906deb8f2d2b616b5c782022744978e7bd4695d39c8f42d0ce65"},
@@ -6,28 +6,18 @@ edition = "2021"
[dependencies]
anyhow = "1.0"
bytes = { version = "1.0.1", features = ['serde'] }
clap = "3.0"
futures = "0.3.13"
hashbrown = "0.11.2"
hex = "0.4.3"
hyper = "0.14"
lazy_static = "1.4.0"
md5 = "0.7.0"
parking_lot = "0.11.2"
pin-project-lite = "0.2.7"
rand = "0.8.3"
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] }
rustls = "0.19.1"
scopeguard = "1.1.0"
hex = "0.4.3"
hyper = "0.14"
serde = "1"
serde_json = "1"
tokio = { version = "1.11", features = ["macros"] }
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
tokio-rustls = "0.22.0"
clap = "3.0"
rustls = "0.19.1"
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] }

zenith_utils = { path = "../zenith_utils" }
zenith_metrics = { path = "../zenith_metrics" }

[dev-dependencies]
tokio-postgres-rustls = "0.8.0"
rcgen = "0.8.14"
@@ -1,169 +0,0 @@
use crate::compute::DatabaseInfo;
use crate::config::ProxyConfig;
use crate::cplane_api::{self, CPlaneApi};
use crate::stream::PqStream;
use anyhow::{anyhow, bail, Context};
use std::collections::HashMap;
use tokio::io::{AsyncRead, AsyncWrite};
use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage, FeMessage as Fe};

/// Various client credentials which we use for authentication.
#[derive(Debug, PartialEq, Eq)]
pub struct ClientCredentials {
    pub user: String,
    pub dbname: String,
}

impl TryFrom<HashMap<String, String>> for ClientCredentials {
    type Error = anyhow::Error;

    fn try_from(mut value: HashMap<String, String>) -> Result<Self, Self::Error> {
        let mut get_param = |key| {
            value
                .remove(key)
                .with_context(|| format!("{} is missing in startup packet", key))
        };

        let user = get_param("user")?;
        let db = get_param("database")?;

        Ok(Self { user, dbname: db })
    }
}

impl ClientCredentials {
    /// Use credentials to authenticate the user.
    pub async fn authenticate(
        self,
        config: &ProxyConfig,
        client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
    ) -> anyhow::Result<DatabaseInfo> {
        use crate::config::ClientAuthMethod::*;
        use crate::config::RouterConfig::*;
        let db_info = match &config.router_config {
            Static { host, port } => handle_static(host.clone(), *port, client, self).await,
            Dynamic(Mixed) => {
                if self.user.ends_with("@zenith") {
                    handle_existing_user(config, client, self).await
                } else {
                    handle_new_user(config, client).await
                }
            }
            Dynamic(Password) => handle_existing_user(config, client, self).await,
            Dynamic(Link) => handle_new_user(config, client).await,
        };

        db_info.context("failed to authenticate client")
    }
}

fn new_psql_session_id() -> String {
    hex::encode(rand::random::<[u8; 8]>())
}

async fn handle_static(
    host: String,
    port: u16,
    client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
    creds: ClientCredentials,
) -> anyhow::Result<DatabaseInfo> {
    client
        .write_message(&Be::AuthenticationCleartextPassword)
        .await?;

    // Read client's password bytes
    let msg = match client.read_message().await? {
        Fe::PasswordMessage(msg) => msg,
        bad => bail!("unexpected message type: {:?}", bad),
    };

    let cleartext_password = std::str::from_utf8(&msg)?.split('\0').next().unwrap();

    let db_info = DatabaseInfo {
        host,
        port,
        dbname: creds.dbname.clone(),
        user: creds.user.clone(),
        password: Some(cleartext_password.into()),
    };

    client
        .write_message_noflush(&Be::AuthenticationOk)?
        .write_message_noflush(&BeParameterStatusMessage::encoding())?;

    Ok(db_info)
}

async fn handle_existing_user(
    config: &ProxyConfig,
    client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
    creds: ClientCredentials,
) -> anyhow::Result<DatabaseInfo> {
    let psql_session_id = new_psql_session_id();
    let md5_salt = rand::random();

    client
        .write_message(&Be::AuthenticationMD5Password(&md5_salt))
        .await?;

    // Read client's password hash
    let msg = match client.read_message().await? {
        Fe::PasswordMessage(msg) => msg,
        bad => bail!("unexpected message type: {:?}", bad),
    };

    let (_trailing_null, md5_response) = msg
        .split_last()
        .ok_or_else(|| anyhow!("unexpected password message"))?;

    let cplane = CPlaneApi::new(&config.auth_endpoint);
    let db_info = cplane
        .authenticate_proxy_request(creds, md5_response, &md5_salt, &psql_session_id)
        .await?;

    client
        .write_message_noflush(&Be::AuthenticationOk)?
        .write_message_noflush(&BeParameterStatusMessage::encoding())?;

    Ok(db_info)
}

async fn handle_new_user(
    config: &ProxyConfig,
    client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> anyhow::Result<DatabaseInfo> {
    let psql_session_id = new_psql_session_id();
    let greeting = hello_message(&config.redirect_uri, &psql_session_id);

    let db_info = cplane_api::with_waiter(psql_session_id, |waiter| async {
        // Give user a URL to spawn a new database
        client
            .write_message_noflush(&Be::AuthenticationOk)?
            .write_message_noflush(&BeParameterStatusMessage::encoding())?
            .write_message(&Be::NoticeResponse(greeting))
            .await?;

        // Wait for web console response
        waiter.await?.map_err(|e| anyhow!(e))
    })
    .await?;

    client.write_message_noflush(&Be::NoticeResponse("Connecting to database.".into()))?;

    Ok(db_info)
}

fn hello_message(redirect_uri: &str, session_id: &str) -> String {
    format!(
        concat![
            "☀️ Welcome to Zenith!\n",
            "To proceed with database creation, open the following link:\n\n",
            "    {redirect_uri}{session_id}\n\n",
            "It needs to be done once and we will send you '.pgpass' file,\n",
            "which will allow you to access or create ",
            "databases without opening your web browser."
        ],
        redirect_uri = redirect_uri,
        session_id = session_id,
    )
}
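
For context on `md5_response` above: PostgreSQL MD5 authentication expects the client to send "md5" + hex(md5(hex(md5(password + username)) + salt)). A minimal sketch of that computation using the `md5` crate from this crate's dependencies (the helper name is illustrative, not part of the diff):

fn md5_password_response(password: &str, user: &str, salt: &[u8; 4]) -> String {
    // Inner hash: md5(password || username), hex-encoded.
    let inner = format!("{:x}", md5::compute([password.as_bytes(), user.as_bytes()].concat()));
    // Outer hash: md5(inner_hex || salt), hex-encoded and prefixed with "md5".
    let outer = format!("{:x}", md5::compute([inner.as_bytes(), &salt[..]].concat()));
    format!("md5{}", outer)
}
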
@@ -1,106 +0,0 @@
use anyhow::{anyhow, Context};
use hashbrown::HashMap;
use parking_lot::Mutex;
use std::net::SocketAddr;
use tokio::net::TcpStream;
use tokio_postgres::{CancelToken, NoTls};
use zenith_utils::pq_proto::CancelKeyData;

/// Enables serving CancelRequests.
#[derive(Default)]
pub struct CancelMap(Mutex<HashMap<CancelKeyData, Option<CancelClosure>>>);

impl CancelMap {
    /// Cancel a running query for the corresponding connection.
    pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
        let cancel_closure = self
            .0
            .lock()
            .get(&key)
            .and_then(|x| x.clone())
            .with_context(|| format!("unknown session: {:?}", key))?;

        cancel_closure.try_cancel_query().await
    }

    /// Run async action within an ephemeral session identified by [`CancelKeyData`].
    pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
    where
        F: FnOnce(Session<'a>) -> R,
        R: std::future::Future<Output = anyhow::Result<V>>,
    {
        // HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
        //       expose it and we don't want to do another roundtrip to query
        //       for it. The client will be able to notice that this is not the
        //       actual backend_pid, but backend_pid is not used for anything
        //       so it doesn't matter.
        let key = rand::random();

        // Random key collisions are unlikely to happen here, but they're still possible,
        // which is why we have to take care not to rewrite an existing key.
        self.0
            .lock()
            .try_insert(key, None)
            .map_err(|_| anyhow!("session already exists: {:?}", key))?;

        // This will guarantee that the session gets dropped
        // as soon as the future is finished.
        scopeguard::defer! {
            self.0.lock().remove(&key);
        }

        let session = Session::new(key, self);
        f(session).await
    }
}

/// This should've been a [`std::future::Future`], but
/// it's impossible to name a type of an unboxed future
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
#[derive(Clone)]
pub struct CancelClosure {
    socket_addr: SocketAddr,
    cancel_token: CancelToken,
}

impl CancelClosure {
    pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
        Self {
            socket_addr,
            cancel_token,
        }
    }

    /// Cancels the query running on user's compute node.
    pub async fn try_cancel_query(self) -> anyhow::Result<()> {
        let socket = TcpStream::connect(self.socket_addr).await?;
        self.cancel_token.cancel_query_raw(socket, NoTls).await?;

        Ok(())
    }
}

/// Helper for registering query cancellation tokens.
pub struct Session<'a> {
    /// The user-facing key identifying this session.
    key: CancelKeyData,
    /// The [`CancelMap`] this session belongs to.
    cancel_map: &'a CancelMap,
}

impl<'a> Session<'a> {
    fn new(key: CancelKeyData, cancel_map: &'a CancelMap) -> Self {
        Self { key, cancel_map }
    }

    /// Store the cancel token for the given session.
    /// This enables query cancellation in [`crate::proxy::handshake`].
    pub fn enable_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
        self.cancel_map
            .0
            .lock()
            .insert(self.key, Some(cancel_closure));

        self.key
    }
}
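
Taken together, the lifecycle of this (now-removed) module: a connection runs inside `with_session`, publishes its `CancelClosure` via `enable_cancellation`, and a CancelRequest arriving on another connection finds it through `cancel_session`. A minimal sketch (the address/token arguments stand in for an established compute-node connection):

async fn serve_one_connection(
    cancel_map: &CancelMap,
    addr: std::net::SocketAddr,
    token: tokio_postgres::CancelToken,
) -> anyhow::Result<()> {
    cancel_map
        .with_session(|session| async move {
            // Publish the cancel closure; the returned key is what goes
            // into the BackendKeyData message sent to the client.
            let _key = session.enable_cancellation(CancelClosure::new(addr, token));
            // ... proxy traffic here; the scopeguard inside with_session
            // removes the key once this future completes ...
            Ok(())
        })
        .await
}
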
@@ -1,42 +0,0 @@
use anyhow::Context;
use serde::{Deserialize, Serialize};
use std::net::{SocketAddr, ToSocketAddrs};

/// Compute node connection params.
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct DatabaseInfo {
    pub host: String,
    pub port: u16,
    pub dbname: String,
    pub user: String,
    pub password: Option<String>,
}

impl DatabaseInfo {
    pub fn socket_addr(&self) -> anyhow::Result<SocketAddr> {
        let host_port = format!("{}:{}", self.host, self.port);
        host_port
            .to_socket_addrs()
            .with_context(|| format!("cannot resolve {} to SocketAddr", host_port))?
            .next()
            .context("cannot resolve at least one SocketAddr")
    }
}

impl From<DatabaseInfo> for tokio_postgres::Config {
    fn from(db_info: DatabaseInfo) -> Self {
        let mut config = tokio_postgres::Config::new();

        config
            .host(&db_info.host)
            .port(db_info.port)
            .dbname(&db_info.dbname)
            .user(&db_info.user);

        if let Some(password) = db_info.password {
            config.password(password);
        }

        config
    }
}
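
The `From` impl above is what lets the proxy hand a `DatabaseInfo` straight to tokio-postgres. A minimal usage sketch (field values made up):

async fn connect_example() -> anyhow::Result<()> {
    let db_info = DatabaseInfo {
        host: "127.0.0.1".into(),
        port: 5432,
        dbname: "postgres".into(),
        user: "zenith".into(),
        password: None,
    };

    // From<DatabaseInfo> yields a ready-made tokio_postgres::Config.
    let config: tokio_postgres::Config = db_info.into();
    let (client, connection) = config.connect(tokio_postgres::NoTls).await?;
    // tokio-postgres requires driving the connection future separately.
    tokio::spawn(connection);

    let row = client.query_one("SELECT 1::int4", &[]).await?;
    assert_eq!(row.get::<_, i32>(0), 1);
    Ok(())
}
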
@@ -1,79 +1,18 @@
use crate::auth::ClientCredentials;
use crate::compute::DatabaseInfo;
use crate::waiters::{Waiter, Waiters};
use anyhow::{anyhow, bail};
use lazy_static::lazy_static;
use anyhow::{anyhow, bail, Context};
use serde::{Deserialize, Serialize};
use std::net::{SocketAddr, ToSocketAddrs};

lazy_static! {
    static ref CPLANE_WAITERS: Waiters<Result<DatabaseInfo, String>> = Default::default();
use crate::state::ProxyWaiters;

#[derive(Serialize, Deserialize, Debug, Default)]
pub struct DatabaseInfo {
    pub host: String,
    pub port: u16,
    pub dbname: String,
    pub user: String,
    pub password: Option<String>,
}

/// Give caller an opportunity to wait for cplane's reply.
pub async fn with_waiter<F, R, T>(psql_session_id: impl Into<String>, f: F) -> anyhow::Result<T>
where
    F: FnOnce(Waiter<'static, Result<DatabaseInfo, String>>) -> R,
    R: std::future::Future<Output = anyhow::Result<T>>,
{
    let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
    f(waiter).await
}

pub fn notify(psql_session_id: &str, msg: Result<DatabaseInfo, String>) -> anyhow::Result<()> {
    CPLANE_WAITERS.notify(psql_session_id, msg)
}

/// Zenith console API wrapper.
pub struct CPlaneApi<'a> {
    auth_endpoint: &'a str,
}

impl<'a> CPlaneApi<'a> {
    pub fn new(auth_endpoint: &'a str) -> Self {
        Self { auth_endpoint }
    }
}

impl CPlaneApi<'_> {
    pub async fn authenticate_proxy_request(
        &self,
        creds: ClientCredentials,
        md5_response: &[u8],
        salt: &[u8; 4],
        psql_session_id: &str,
    ) -> anyhow::Result<DatabaseInfo> {
        let mut url = reqwest::Url::parse(self.auth_endpoint)?;
        url.query_pairs_mut()
            .append_pair("login", &creds.user)
            .append_pair("database", &creds.dbname)
            .append_pair("md5response", std::str::from_utf8(md5_response)?)
            .append_pair("salt", &hex::encode(salt))
            .append_pair("psql_session_id", psql_session_id);

        with_waiter(psql_session_id, |waiter| async {
            println!("cplane request: {}", url);
            // TODO: leverage `reqwest::Client` to reuse connections
            let resp = reqwest::get(url).await?;
            if !resp.status().is_success() {
                bail!("Auth failed: {}", resp.status())
            }

            let auth_info: ProxyAuthResponse = serde_json::from_str(resp.text().await?.as_str())?;
            println!("got auth info: #{:?}", auth_info);

            use ProxyAuthResponse::*;
            match auth_info {
                Ready { conn_info } => Ok(conn_info),
                Error { error } => bail!(error),
                NotReady { .. } => waiter.await?.map_err(|e| anyhow!(e)),
            }
        })
        .await
    }
}

// NOTE: the order of constructors is important.
// https://serde.rs/enum-representations.html#untagged
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
enum ProxyAuthResponse {

@@ -82,6 +21,86 @@ enum ProxyAuthResponse {
    NotReady { ready: bool }, // TODO: get rid of `ready`
}

impl DatabaseInfo {
    pub fn socket_addr(&self) -> anyhow::Result<SocketAddr> {
        let host_port = format!("{}:{}", self.host, self.port);
        host_port
            .to_socket_addrs()
            .with_context(|| format!("cannot resolve {} to SocketAddr", host_port))?
            .next()
            .context("cannot resolve at least one SocketAddr")
    }
}

impl From<DatabaseInfo> for tokio_postgres::Config {
    fn from(db_info: DatabaseInfo) -> Self {
        let mut config = tokio_postgres::Config::new();

        config
            .host(&db_info.host)
            .port(db_info.port)
            .dbname(&db_info.dbname)
            .user(&db_info.user);

        if let Some(password) = db_info.password {
            config.password(password);
        }

        config
    }
}

pub struct CPlaneApi<'a> {
    auth_endpoint: &'a str,
    waiters: &'a ProxyWaiters,
}

impl<'a> CPlaneApi<'a> {
    pub fn new(auth_endpoint: &'a str, waiters: &'a ProxyWaiters) -> Self {
        Self {
            auth_endpoint,
            waiters,
        }
    }
}

impl CPlaneApi<'_> {
    pub fn authenticate_proxy_request(
        &self,
        user: &str,
        database: &str,
        md5_response: &[u8],
        salt: &[u8; 4],
        psql_session_id: &str,
    ) -> anyhow::Result<DatabaseInfo> {
        let mut url = reqwest::Url::parse(self.auth_endpoint)?;
        url.query_pairs_mut()
            .append_pair("login", user)
            .append_pair("database", database)
            .append_pair("md5response", std::str::from_utf8(md5_response)?)
            .append_pair("salt", &hex::encode(salt))
            .append_pair("psql_session_id", psql_session_id);

        let waiter = self.waiters.register(psql_session_id.to_owned());

        println!("cplane request: {}", url);
        let resp = reqwest::blocking::get(url)?;
        if !resp.status().is_success() {
            bail!("Auth failed: {}", resp.status())
        }

        let auth_info: ProxyAuthResponse = serde_json::from_str(resp.text()?.as_str())?;
        println!("got auth info: #{:?}", auth_info);

        use ProxyAuthResponse::*;
        match auth_info {
            Ready { conn_info } => Ok(conn_info),
            Error { error } => bail!(error),
            NotReady { .. } => waiter.wait()?.map_err(|e| anyhow!(e)),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
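
The NOTE above `ProxyAuthResponse` matters because `#[serde(untagged)]` tries variants in declaration order and picks the first whose fields match. A minimal illustration (field names mirror, but simplify, the ones in this file):

use serde::Deserialize;

#[derive(Deserialize, Debug)]
#[serde(untagged)]
enum Demo {
    // Declared first, so a payload carrying both fields deserializes as Full.
    Full { ready: bool, conn_info: String },
    // If this were declared first it would also match that payload
    // (unknown fields are ignored by default) and conn_info would be lost.
    NotReady { ready: bool },
}

fn main() {
    let x: Demo = serde_json::from_str(r#"{"ready": true, "conn_info": "db=..."}"#).unwrap();
    assert!(matches!(x, Demo::Full { .. }));
}
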
@@ -1,30 +1,15 @@
use anyhow::anyhow;
use hyper::{Body, Request, Response, StatusCode};
use std::net::TcpListener;
use zenith_utils::http::RouterBuilder;

use zenith_utils::http::endpoint;
use zenith_utils::http::error::ApiError;
use zenith_utils::http::json::json_response;
use zenith_utils::http::{RouterBuilder, RouterService};

async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
    Ok(json_response(StatusCode::OK, "")?)
}

fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
    let router = endpoint::make_router();
    router.get("/v1/status", status_handler)
}

pub async fn thread_main(http_listener: TcpListener) -> anyhow::Result<()> {
    scopeguard::defer! {
        println!("http has shut down");
    }

    let service = || RouterService::new(make_router().build()?);

    hyper::Server::from_tcp(http_listener)?
        .serve(service().map_err(|e| anyhow!(e))?)
        .await?;

    Ok(())
}
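
Once the proxy is running, the route registered above can be probed directly. A sketch using reqwest, which is already a dependency here (the address is an assumption — use whatever `--http` was set to):

async fn check_status() -> anyhow::Result<()> {
    // Hypothetical address; pass the real --http listen address here.
    let resp = reqwest::get("http://127.0.0.1:7001/v1/status").await?;
    anyhow::ensure!(resp.status().is_success(), "status endpoint returned {}", resp.status());
    Ok(())
}
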
@@ -5,36 +5,21 @@
/// (control plane API in our case) and can create new databases and accounts
/// in somewhat transparent manner (again via communication with control plane API).
///
use anyhow::{bail, Context};
use anyhow::bail;
use clap::{App, Arg};
use config::ProxyConfig;
use futures::FutureExt;
use std::future::Future;
use tokio::{net::TcpListener, task::JoinError};
use zenith_utils::GIT_VERSION;
use state::{ProxyConfig, ProxyState};
use std::thread;
use zenith_utils::http::endpoint;
use zenith_utils::{tcp_listener, GIT_VERSION};

use crate::config::{ClientAuthMethod, RouterConfig};

mod auth;
mod cancellation;
mod compute;
mod config;
mod cplane_api;
mod http;
mod mgmt;
mod proxy;
mod stream;
mod state;
mod waiters;

/// Flattens Result<Result<T>> into Result<T>.
async fn flatten_err(
    f: impl Future<Output = Result<anyhow::Result<()>, JoinError>>,
) -> anyhow::Result<()> {
    f.map(|r| r.context("join error").and_then(|x| x)).await
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
fn main() -> anyhow::Result<()> {
    zenith_metrics::set_common_metrics_prefix("zenith_proxy");
    let arg_matches = App::new("Zenith proxy/router")
        .version(GIT_VERSION)

@@ -46,20 +31,6 @@ async fn main() -> anyhow::Result<()> {
                .help("listen for incoming client connections on ip:port")
                .default_value("127.0.0.1:4432"),
        )
        .arg(
            Arg::new("auth-method")
                .long("auth-method")
                .takes_value(true)
                .help("Possible values: password | link | mixed")
                .default_value("mixed"),
        )
        .arg(
            Arg::new("static-router")
                .short('s')
                .long("static-router")
                .takes_value(true)
                .help("Route all clients to host:port"),
        )
        .arg(
            Arg::new("mgmt")
                .short('m')

@@ -108,59 +79,63 @@ async fn main() -> anyhow::Result<()> {
        )
        .get_matches();

    let tls_config = match (
    let ssl_config = match (
        arg_matches.value_of("ssl-key"),
        arg_matches.value_of("ssl-cert"),
    ) {
        (Some(key_path), Some(cert_path)) => Some(config::configure_ssl(key_path, cert_path)?),
        (Some(key_path), Some(cert_path)) => {
            Some(crate::state::configure_ssl(key_path, cert_path)?)
        }
        (None, None) => None,
        _ => bail!("either both or neither ssl-key and ssl-cert must be specified"),
    };

    let auth_method = arg_matches.value_of("auth-method").unwrap().parse()?;
    let router_config = match arg_matches.value_of("static-router") {
        None => RouterConfig::Dynamic(auth_method),
        Some(addr) => {
            if let ClientAuthMethod::Password = auth_method {
                let (host, port) = addr.split_once(":").unwrap();
                RouterConfig::Static {
                    host: host.to_string(),
                    port: port.parse().unwrap(),
                }
            } else {
                bail!("static-router requires --auth-method password")
            }
        }
    };

    let config: &ProxyConfig = Box::leak(Box::new(ProxyConfig {
        router_config,
    let config = ProxyConfig {
        proxy_address: arg_matches.value_of("proxy").unwrap().parse()?,
        mgmt_address: arg_matches.value_of("mgmt").unwrap().parse()?,
        http_address: arg_matches.value_of("http").unwrap().parse()?,
        redirect_uri: arg_matches.value_of("uri").unwrap().parse()?,
        auth_endpoint: arg_matches.value_of("auth-endpoint").unwrap().parse()?,
        tls_config,
    }));
        ssl_config,
    };
    let state: &ProxyState = Box::leak(Box::new(ProxyState::new(config)));

    println!("Version: {}", GIT_VERSION);

    // Check that we can bind to address before further initialization
    println!("Starting http on {}", config.http_address);
    let http_listener = TcpListener::bind(config.http_address).await?.into_std()?;
    println!("Starting http on {}", state.conf.http_address);
    let http_listener = tcp_listener::bind(state.conf.http_address)?;

    println!("Starting mgmt on {}", config.mgmt_address);
    let mgmt_listener = TcpListener::bind(config.mgmt_address).await?.into_std()?;
    println!("Starting proxy on {}", state.conf.proxy_address);
    let pageserver_listener = tcp_listener::bind(state.conf.proxy_address)?;

    println!("Starting proxy on {}", config.proxy_address);
    let proxy_listener = TcpListener::bind(config.proxy_address).await?;
    println!("Starting mgmt on {}", state.conf.mgmt_address);
    let mgmt_listener = tcp_listener::bind(state.conf.mgmt_address)?;

    let http = tokio::spawn(http::thread_main(http_listener));
    let proxy = tokio::spawn(proxy::thread_main(config, proxy_listener));
    let mgmt = tokio::task::spawn_blocking(move || mgmt::thread_main(mgmt_listener));
    let threads = [
        thread::Builder::new()
            .name("Http thread".into())
            .spawn(move || {
                let router = http::make_router();
                endpoint::serve_thread_main(
                    router,
                    http_listener,
                    std::future::pending(), // never shut down
                )
            })?,
        // Spawn a thread to listen for connections. It will spawn further threads
        // for each connection.
        thread::Builder::new()
            .name("Listener thread".into())
            .spawn(move || proxy::thread_main(state, pageserver_listener))?,
        thread::Builder::new()
            .name("Mgmt thread".into())
            .spawn(move || mgmt::thread_main(state, mgmt_listener))?,
    ];

    let tasks = [flatten_err(http), flatten_err(proxy), flatten_err(mgmt)];
    let _: Vec<()> = futures::future::try_join_all(tasks).await?;
    for t in threads {
        t.join().unwrap()?;
    }

    Ok(())
}
@@ -1,49 +1,44 @@
use crate::{compute::DatabaseInfo, cplane_api};
use anyhow::Context;
use serde::Deserialize;
use std::{
    net::{TcpListener, TcpStream},
    thread,
};

use serde::Deserialize;
use zenith_utils::{
    postgres_backend::{self, AuthType, PostgresBackend},
    pq_proto::{BeMessage, SINGLE_COL_ROWDESC},
};

use crate::{cplane_api::DatabaseInfo, ProxyState};

///
/// Main proxy listener loop.
///
/// Listens for connections, and launches a new handler thread for each.
///
pub fn thread_main(listener: TcpListener) -> anyhow::Result<()> {
    scopeguard::defer! {
        println!("mgmt has shut down");
    }

    listener
        .set_nonblocking(false)
        .context("failed to set listener to blocking")?;
pub fn thread_main(state: &'static ProxyState, listener: TcpListener) -> anyhow::Result<()> {
    loop {
        let (socket, peer_addr) = listener.accept().context("failed to accept a new client")?;
        let (socket, peer_addr) = listener.accept()?;
        println!("accepted connection from {}", peer_addr);
        socket
            .set_nodelay(true)
            .context("failed to set client socket option")?;
        socket.set_nodelay(true).unwrap();

        thread::spawn(move || {
            if let Err(err) = handle_connection(socket) {
            if let Err(err) = handle_connection(state, socket) {
                println!("error: {}", err);
            }
        });
    }
}

fn handle_connection(socket: TcpStream) -> anyhow::Result<()> {
fn handle_connection(state: &ProxyState, socket: TcpStream) -> anyhow::Result<()> {
    let mut conn_handler = MgmtHandler { state };
    let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None, true)?;
    pgbackend.run(&mut MgmtHandler)
    pgbackend.run(&mut conn_handler)
}

struct MgmtHandler;
struct MgmtHandler<'a> {
    state: &'a ProxyState,
}

/// Serialized examples:
// {

@@ -79,13 +74,13 @@ enum PsqlSessionResult {
    Failure(String),
}

impl postgres_backend::Handler for MgmtHandler {
impl postgres_backend::Handler for MgmtHandler<'_> {
    fn process_query(
        &mut self,
        pgb: &mut PostgresBackend,
        query_string: &str,
    ) -> anyhow::Result<()> {
        let res = try_process_query(pgb, query_string);
        let res = try_process_query(self, pgb, query_string);
        // intercept and log error message
        if res.is_err() {
            println!("Mgmt query failed: #{:?}", res);

@@ -94,7 +89,11 @@ impl postgres_backend::Handler for MgmtHandler {
    }
}

fn try_process_query(pgb: &mut PostgresBackend, query_string: &str) -> anyhow::Result<()> {
fn try_process_query(
    mgmt: &mut MgmtHandler,
    pgb: &mut PostgresBackend,
    query_string: &str,
) -> anyhow::Result<()> {
    println!("Got mgmt query: '{}'", query_string);

    let resp: PsqlSessionResponse = serde_json::from_str(query_string)?;

@@ -105,7 +104,7 @@ fn try_process_query(pgb: &mut PostgresBackend, query_string: &str) -> anyhow::R
        Failure(message) => Err(message),
    };

    match cplane_api::notify(&resp.session_id, msg) {
    match mgmt.state.waiters.notify(&resp.session_id, msg) {
        Ok(()) => {
            pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
                .write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
|
||||
use crate::auth;
|
||||
use crate::cancellation::{self, CancelClosure, CancelMap};
|
||||
use crate::compute::DatabaseInfo;
|
||||
use crate::config::{ProxyConfig, TlsConfig};
|
||||
use crate::stream::{MetricsStream, PqStream, Stream};
|
||||
use anyhow::{bail, Context};
|
||||
use crate::cplane_api::{CPlaneApi, DatabaseInfo};
|
||||
use crate::ProxyState;
|
||||
use anyhow::{anyhow, bail, Context};
|
||||
use lazy_static::lazy_static;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
use rand::prelude::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use std::cell::Cell;
|
||||
use std::collections::HashMap;
|
||||
use std::net::{SocketAddr, TcpStream};
|
||||
use std::sync::Mutex;
|
||||
use std::{io, thread};
|
||||
use tokio_postgres::NoTls;
|
||||
use zenith_metrics::{new_common_metric_name, register_int_counter, IntCounter};
|
||||
use zenith_utils::pq_proto::{BeMessage as Be, *};
|
||||
use zenith_utils::postgres_backend::{self, PostgresBackend, ProtoState, Stream};
|
||||
use zenith_utils::pq_proto::{BeMessage as Be, FeMessage as Fe, *};
|
||||
use zenith_utils::sock_split::{ReadStream, WriteStream};
|
||||
|
||||
lazy_static! {
|
||||
static ref NUM_CONNECTIONS_ACCEPTED_COUNTER: IntCounter = register_int_counter!(
|
||||
new_common_metric_name("num_connections_accepted"),
|
||||
"Number of TCP client connections accepted."
|
||||
)
|
||||
.unwrap();
|
||||
static ref NUM_CONNECTIONS_CLOSED_COUNTER: IntCounter = register_int_counter!(
|
||||
new_common_metric_name("num_connections_closed"),
|
||||
"Number of TCP client connections closed."
|
||||
)
|
||||
.unwrap();
|
||||
static ref NUM_BYTES_PROXIED_COUNTER: IntCounter = register_int_counter!(
|
||||
new_common_metric_name("num_bytes_proxied"),
|
||||
"Number of bytes sent/received between any client and backend."
|
||||
)
|
||||
.unwrap();
|
||||
struct CancelClosure {
|
||||
socket_addr: SocketAddr,
|
||||
cancel_token: tokio_postgres::CancelToken,
|
||||
}
|
||||
|
||||
async fn log_error<R, F>(future: F) -> F::Output
|
||||
where
|
||||
F: std::future::Future<Output = anyhow::Result<R>>,
|
||||
{
|
||||
future.await.map_err(|err| {
|
||||
println!("error: {}", err);
|
||||
err
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn thread_main(
|
||||
config: &'static ProxyConfig,
|
||||
listener: tokio::net::TcpListener,
|
||||
) -> anyhow::Result<()> {
|
||||
scopeguard::defer! {
|
||||
println!("proxy has shut down");
|
||||
}
|
||||
|
||||
let cancel_map = Arc::new(CancelMap::default());
|
||||
loop {
|
||||
let (socket, peer_addr) = listener.accept().await?;
|
||||
println!("accepted connection from {}", peer_addr);
|
||||
|
||||
let cancel_map = Arc::clone(&cancel_map);
|
||||
tokio::spawn(log_error(async move {
|
||||
socket
|
||||
.set_nodelay(true)
|
||||
.context("failed to set socket option")?;
|
||||
|
||||
handle_client(config, &cancel_map, socket).await
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_client(
|
||||
config: &ProxyConfig,
|
||||
cancel_map: &CancelMap,
|
||||
stream: impl AsyncRead + AsyncWrite + Unpin,
|
||||
) -> anyhow::Result<()> {
|
||||
// The `closed` counter will increase when this future is destroyed.
|
||||
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
|
||||
scopeguard::defer! {
|
||||
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
|
||||
}
|
||||
|
||||
let tls = config.tls_config.clone();
|
||||
if let Some((client, creds)) = handshake(stream, tls, cancel_map).await? {
|
||||
cancel_map
|
||||
.with_session(|session| async {
|
||||
connect_client_to_db(config, session, client, creds).await
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle a connection from one client.
|
||||
/// For better testing experience, `stream` can be
|
||||
/// any object satisfying the traits.
|
||||
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
stream: S,
|
||||
mut tls: Option<TlsConfig>,
|
||||
cancel_map: &CancelMap,
|
||||
) -> anyhow::Result<Option<(PqStream<Stream<S>>, auth::ClientCredentials)>> {
|
||||
// Client may try upgrading to each protocol only once
|
||||
let (mut tried_ssl, mut tried_gss) = (false, false);
|
||||
|
||||
let mut stream = PqStream::new(Stream::from_raw(stream));
|
||||
loop {
|
||||
let msg = stream.read_startup_packet().await?;
|
||||
println!("got message: {:?}", msg);
|
||||
|
||||
use FeStartupPacket::*;
|
||||
match msg {
|
||||
SslRequest => match stream.get_ref() {
|
||||
Stream::Raw { .. } if !tried_ssl => {
|
||||
tried_ssl = true;
|
||||
|
||||
// We can't perform TLS handshake without a config
|
||||
let enc = tls.is_some();
|
||||
stream.write_message(&Be::EncryptionResponse(enc)).await?;
|
||||
|
||||
if let Some(tls) = tls.take() {
|
||||
// Upgrade raw stream into a secure TLS-backed stream.
|
||||
// NOTE: We've consumed `tls`; this fact will be used later.
|
||||
stream = PqStream::new(stream.into_inner().upgrade(tls).await?);
|
||||
}
|
||||
}
|
||||
_ => bail!("protocol violation"),
|
||||
},
|
||||
GssEncRequest => match stream.get_ref() {
|
||||
Stream::Raw { .. } if !tried_gss => {
|
||||
tried_gss = true;
|
||||
|
||||
// Currently, we don't support GSSAPI
|
||||
stream.write_message(&Be::EncryptionResponse(false)).await?;
|
||||
}
|
||||
_ => bail!("protocol violation"),
|
||||
},
|
||||
StartupMessage { params, .. } => {
|
||||
// Check that the config has been consumed during upgrade
|
||||
// OR we didn't provide it at all (for dev purposes).
|
||||
if tls.is_some() {
|
||||
let msg = "connection is insecure (try using `sslmode=require`)";
|
||||
stream.write_message(&Be::ErrorResponse(msg)).await?;
|
||||
bail!(msg);
|
||||
}
|
||||
|
||||
break Ok(Some((stream, params.try_into()?)));
|
||||
}
|
||||
CancelRequest(cancel_key_data) => {
|
||||
cancel_map.cancel_session(cancel_key_data).await?;
|
||||
|
||||
break Ok(None);
|
||||
}
|
||||
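
For protocol context on the `SslRequest`/`GssEncRequest`/`CancelRequest` arms above: each arrives as a special startup packet whose version field is a magic code (80877103 for SSLRequest, 80877104 for GSSENCRequest, 80877102 for CancelRequest). A sketch of the SSLRequest bytes a client sends (not part of the diff, just the wire format):

fn ssl_request_packet() -> [u8; 8] {
    let mut buf = [0u8; 8];
    buf[..4].copy_from_slice(&8i32.to_be_bytes()); // total packet length
    buf[4..].copy_from_slice(&80877103i32.to_be_bytes()); // SSLRequest magic code
    buf
}
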
impl CancelClosure {
|
||||
async fn try_cancel_query(&self) {
|
||||
if let Ok(socket) = tokio::net::TcpStream::connect(self.socket_addr).await {
|
||||
// NOTE ignoring the result because:
|
||||
// 1. This is a best effort attempt, the database doesn't have to listen
|
||||
// 2. Being opaque about errors here helps avoid leaking info to unauthenticated user
|
||||
let _ = self.cancel_token.cancel_query_raw(socket, NoTls).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_client_to_db(
|
||||
config: &ProxyConfig,
|
||||
session: cancellation::Session<'_>,
|
||||
mut client: PqStream<impl AsyncRead + AsyncWrite + Unpin>,
|
||||
creds: auth::ClientCredentials,
|
||||
lazy_static! {
|
||||
// Enables serving CancelRequests
|
||||
static ref CANCEL_MAP: Mutex<HashMap<CancelKeyData, CancelClosure>> = Mutex::new(HashMap::new());
|
||||
|
||||
// Metrics
|
||||
static ref NUM_CONNECTIONS_ACCEPTED_COUNTER: IntCounter = register_int_counter!(
|
||||
new_common_metric_name("num_connections_accepted"),
|
||||
"Number of TCP client connections accepted."
|
||||
).unwrap();
|
||||
static ref NUM_CONNECTIONS_CLOSED_COUNTER: IntCounter = register_int_counter!(
|
||||
new_common_metric_name("num_connections_closed"),
|
||||
"Number of TCP client connections closed."
|
||||
).unwrap();
|
||||
static ref NUM_CONNECTIONS_FAILED_COUNTER: IntCounter = register_int_counter!(
|
||||
new_common_metric_name("num_connections_failed"),
|
||||
"Number of TCP client connections that closed due to error."
|
||||
).unwrap();
|
||||
static ref NUM_BYTES_PROXIED_COUNTER: IntCounter = register_int_counter!(
|
||||
new_common_metric_name("num_bytes_proxied"),
|
||||
"Number of bytes sent/received between any client and backend."
|
||||
).unwrap();
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
// Used to clean up the CANCEL_MAP. Might not be necessary if we use tokio thread pool in main loop.
|
||||
static THREAD_CANCEL_KEY_DATA: Cell<Option<CancelKeyData>> = Cell::new(None);
|
||||
}

///
/// Main proxy listener loop.
///
/// Listens for connections and launches a new handler thread for each one.
///
pub fn thread_main(
    state: &'static ProxyState,
    listener: std::net::TcpListener,
) -> anyhow::Result<()> {
    let db_info = creds.authenticate(config, &mut client).await?;
    let (db, version, cancel_closure) = connect_to_db(db_info).await?;
    let cancel_key_data = session.enable_cancellation(cancel_closure);
    loop {
        let (socket, peer_addr) = listener.accept()?;
        println!("accepted connection from {}", peer_addr);
        NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
        socket.set_nodelay(true).unwrap();

        client
            .write_message_noflush(&BeMessage::ParameterStatus(
                BeParameterStatusMessage::ServerVersion(&version),
            ))?
            .write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
            .write_message(&BeMessage::ReadyForQuery)
            .await?;
        // TODO: Use a threadpool instead. Maybe use tokio's threadpool by
        // spawning a future into its runtime. Tokio's JoinError should
        // allow us to handle cleanup properly even if the future panics.
        thread::Builder::new()
            .name("Proxy thread".into())
            .spawn(move || {
                if let Err(err) = proxy_conn_main(state, socket) {
                    NUM_CONNECTIONS_FAILED_COUNTER.inc();
                    println!("error: {}", err);
                }

                // This function will be called for writes in either direction.
                fn inc_proxied(cnt: usize) {
                    // Consider inventing something more sophisticated
                    // if this ever becomes a bottleneck (cacheline bouncing).
                    NUM_BYTES_PROXIED_COUNTER.inc_by(cnt as u64);
                // Clean up CANCEL_MAP.
                NUM_CONNECTIONS_CLOSED_COUNTER.inc();
                THREAD_CANCEL_KEY_DATA.with(|cell| {
                    if let Some(cancel_key_data) = cell.get() {
                        CANCEL_MAP.lock().unwrap().remove(&cancel_key_data);
                    };
                });
            })?;
    }
}

// TODO: clean up fields
struct ProxyConnection {
    state: &'static ProxyState,
    psql_session_id: String,
    pgb: PostgresBackend,
}

pub fn proxy_conn_main(state: &'static ProxyState, socket: TcpStream) -> anyhow::Result<()> {
    let conn = ProxyConnection {
        state,
        psql_session_id: hex::encode(rand::random::<[u8; 8]>()),
        pgb: PostgresBackend::new(
            socket,
            postgres_backend::AuthType::MD5,
            state.conf.ssl_config.clone(),
            false,
        )?,
    };

    let (client, server) = match conn.handle_client()? {
        Some(x) => x,
        None => return Ok(()),
    };

    let server = zenith_utils::sock_split::BidiStream::from_tcp(server);

    let client = match client {
        Stream::Bidirectional(bidi_stream) => bidi_stream,
        _ => panic!("invalid stream type"),
    };

    proxy(client.split(), server.split())
}

impl ProxyConnection {
    /// Returns Ok(None) when the connection was successfully closed.
    fn handle_client(mut self) -> anyhow::Result<Option<(Stream, TcpStream)>> {
        let mut authenticate = || {
            let (username, dbname) = match self.handle_startup()? {
                Some(x) => x,
                None => return Ok(None),
            };

            // Both scenarios here should end up producing database credentials
            if username.ends_with("@zenith") {
                self.handle_existing_user(&username, &dbname).map(Some)
            } else {
                self.handle_new_user().map(Some)
            }
        };

        let conn = match authenticate() {
            Ok(Some(db_info)) => connect_to_db(db_info),
            Ok(None) => return Ok(None),
            Err(e) => {
                // Report the error to the client
                self.pgb.write_message(&Be::ErrorResponse(&e.to_string()))?;
                bail!("failed to handle client: {:?}", e);
            }
        };

        // We'll get rid of this once the migration to async is complete
        let (pg_version, db_stream) = {
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()?;

            let (pg_version, stream, cancel_key_data) = runtime.block_on(conn)?;
            self.pgb
                .write_message(&BeMessage::BackendKeyData(cancel_key_data))?;
            let stream = stream.into_std()?;
            stream.set_nonblocking(false)?;

            (pg_version, stream)
        };

        // Let the client send new requests
        self.pgb
            .write_message_noflush(&BeMessage::ParameterStatus(
                BeParameterStatusMessage::ServerVersion(&pg_version),
            ))?
            .write_message(&Be::ReadyForQuery)?;

        Ok(Some((self.pgb.into_stream(), db_stream)))
    }

        let mut db = MetricsStream::new(db, inc_proxied);
        let mut client = MetricsStream::new(client.into_inner(), inc_proxied);
        let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?;

    /// Returns Ok(None) when the connection was successfully closed.
    fn handle_startup(&mut self) -> anyhow::Result<Option<(String, String)>> {
        let have_tls = self.pgb.tls_config.is_some();
        let mut encrypted = false;

        loop {
            let msg = match self.pgb.read_message()? {
                Some(Fe::StartupPacket(msg)) => msg,
                None => bail!("connection is lost"),
                bad => bail!("unexpected message type: {:?}", bad),
            };
            println!("got message: {:?}", msg);

            match msg {
                FeStartupPacket::GssEncRequest => {
                    self.pgb.write_message(&Be::EncryptionResponse(false))?;
                }
                FeStartupPacket::SslRequest => {
                    self.pgb.write_message(&Be::EncryptionResponse(have_tls))?;
                    if have_tls {
                        self.pgb.start_tls()?;
                        encrypted = true;
                    }
                }
                FeStartupPacket::StartupMessage { mut params, .. } => {
                    if have_tls && !encrypted {
                        bail!("must connect with TLS");
                    }

                    let mut get_param = |key| {
                        params
                            .remove(key)
                            .with_context(|| format!("{} is missing in startup packet", key))
                    };

                    return Ok(Some((get_param("user")?, get_param("database")?)));
                }
                FeStartupPacket::CancelRequest(cancel_key_data) => {
                    if let Some(cancel_closure) = CANCEL_MAP.lock().unwrap().get(&cancel_key_data) {
                        let runtime = tokio::runtime::Builder::new_current_thread()
                            .enable_all()
                            .build()
                            .unwrap();
                        runtime.block_on(cancel_closure.try_cancel_query());
                    }
                    return Ok(None);
                }
            }
        }
    }

    fn handle_existing_user(&mut self, user: &str, db: &str) -> anyhow::Result<DatabaseInfo> {
        let md5_salt = rand::random::<[u8; 4]>();

        // Ask for the password
        self.pgb
            .write_message(&Be::AuthenticationMD5Password(&md5_salt))?;
        self.pgb.state = ProtoState::Authentication; // XXX

        // Check the password
        let msg = match self.pgb.read_message()? {
            Some(Fe::PasswordMessage(msg)) => msg,
            None => bail!("connection is lost"),
            bad => bail!("unexpected message type: {:?}", bad),
        };
        println!("got message: {:?}", msg);

        let (_trailing_null, md5_response) = msg
            .split_last()
            .ok_or_else(|| anyhow!("unexpected password message"))?;

        let cplane = CPlaneApi::new(&self.state.conf.auth_endpoint, &self.state.waiters);
        let db_info = cplane.authenticate_proxy_request(
            user,
            db,
            md5_response,
            &md5_salt,
            &self.psql_session_id,
        )?;

        self.pgb
            .write_message_noflush(&Be::AuthenticationOk)?
            .write_message_noflush(&BeParameterStatusMessage::encoding())?;

        Ok(db_info)
    }

    fn handle_new_user(&mut self) -> anyhow::Result<DatabaseInfo> {
        let greeting = hello_message(&self.state.conf.redirect_uri, &self.psql_session_id);

        // First, register this session
        let waiter = self.state.waiters.register(self.psql_session_id.clone());

        // Give the user a URL they can use to spawn a new database
        self.pgb
            .write_message_noflush(&Be::AuthenticationOk)?
            .write_message_noflush(&BeParameterStatusMessage::encoding())?
            .write_message(&Be::NoticeResponse(greeting))?;

        // Wait for the web console response
        let db_info = waiter.wait()?.map_err(|e| anyhow!(e))?;

        self.pgb
            .write_message_noflush(&Be::NoticeResponse("Connecting to database.".into()))?;

        Ok(db_info)
    }
}

fn hello_message(redirect_uri: &str, session_id: &str) -> String {
    format!(
        concat![
            "☀️ Welcome to Zenith!\n",
            "To proceed with database creation, open the following link:\n\n",
            "    {redirect_uri}{session_id}\n\n",
            "This needs to be done only once, and we will then send you a '.pgpass' file,\n",
            "which will allow you to access or create ",
            "databases without opening your web browser."
        ],
        redirect_uri = redirect_uri,
        session_id = session_id,
    )
}

/// Create a TCP connection to a postgres database, authenticate with it, and receive the ReadyForQuery message
async fn connect_to_db(
    db_info: DatabaseInfo,
) -> anyhow::Result<(String, tokio::net::TcpStream, CancelKeyData)> {
    // Make a raw connection. When connect_raw finishes, we've received ReadyForQuery.
    let socket_addr = db_info.socket_addr()?;
    let mut socket = tokio::net::TcpStream::connect(socket_addr).await?;
    let config = tokio_postgres::Config::from(db_info);
    // NOTE: We effectively ignore some ParameterStatus and NoticeResponse
    // messages here. Not sure if that could break something.
    let (client, conn) = config.connect_raw(&mut socket, NoTls).await?;

    // Save the info needed to potentially cancel the query later
    let mut rng = StdRng::from_entropy();
    let cancel_key_data = CancelKeyData {
        // HACK: We'd rather get the real backend_pid, but tokio_postgres doesn't
        // expose it and we don't want to do another roundtrip to query
        // for it. The client will be able to notice that this is not the
        // actual backend_pid, but backend_pid is not used for anything,
        // so it doesn't matter.
        backend_pid: rng.gen(),
        cancel_key: rng.gen(),
    };
    let cancel_closure = CancelClosure {
        socket_addr,
        cancel_token: client.cancel_token(),
    };
    CANCEL_MAP
        .lock()
        .unwrap()
        .insert(cancel_key_data, cancel_closure);
    THREAD_CANCEL_KEY_DATA.with(|cell| {
        let prev_value = cell.replace(Some(cancel_key_data));
        assert!(
            prev_value.is_none(),
            "THREAD_CANCEL_KEY_DATA was already set"
        );
    });

    let version = conn.parameter("server_version").unwrap();
    Ok((version.into(), socket, cancel_key_data))
}
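
The HACK above leans on how libpq cancellation works on the wire: the client opens a second TCP connection and replays whatever BackendKeyData it was given, so any unique (backend_pid, cancel_key) pair the proxy invents round-trips back to it unchanged. A minimal sketch of that client-side packet, for illustration only (the helper name and raw-I/O approach are assumptions, not part of this change):

use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;

/// Hypothetical helper: send a libpq CancelRequest to `addr`.
/// Wire format: int32 length (16), int32 request code (80877102),
/// int32 backend_pid, int32 cancel_key -- all big-endian.
async fn send_cancel_request(addr: &str, backend_pid: i32, cancel_key: i32) -> std::io::Result<()> {
    let mut socket = TcpStream::connect(addr).await?;
    let mut packet = Vec::with_capacity(16);
    packet.extend_from_slice(&16i32.to_be_bytes()); // total packet length
    packet.extend_from_slice(&80877102i32.to_be_bytes()); // CancelRequest code
    packet.extend_from_slice(&backend_pid.to_be_bytes());
    packet.extend_from_slice(&cancel_key.to_be_bytes());
    socket.write_all(&packet).await?;
    Ok(()) // the server closes the connection without replying
}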

/// Concurrently proxy both directions of the client and server connections
fn proxy(
    (client_read, client_write): (ReadStream, WriteStream),
    (server_read, server_write): (ReadStream, WriteStream),
) -> anyhow::Result<()> {
    fn do_proxy(mut reader: impl io::Read, mut writer: WriteStream) -> io::Result<u64> {
        /// FlushWriter will make sure that every message is sent as soon as possible
        struct FlushWriter<W>(W);

        impl<W: io::Write> io::Write for FlushWriter<W> {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                // `std::io::copy` is guaranteed to exit if we return an error,
                // so we can afford to lose `res` in case `flush` fails
                let res = self.0.write(buf);
                if let Ok(count) = res {
                    NUM_BYTES_PROXIED_COUNTER.inc_by(count as u64);
                    self.flush()?;
                }
                res
            }

            fn flush(&mut self) -> io::Result<()> {
                self.0.flush()
            }
        }

        let res = std::io::copy(&mut reader, &mut FlushWriter(&mut writer));
        writer.shutdown(std::net::Shutdown::Both)?;
        res
    }

    let client_to_server_jh = thread::spawn(move || do_proxy(client_read, server_write));

    do_proxy(server_read, client_write)?;
    client_to_server_jh.join().unwrap()?;

    Ok(())
}
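
For contrast, the async path this change moves toward collapses the two copy threads into a single task; a minimal sketch using tokio's copy_bidirectional, mirroring the MetricsStream fragment spliced in above rather than adding anything new:

async fn proxy_async(
    mut client: impl tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
    mut db: impl tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
) -> anyhow::Result<()> {
    // copy_bidirectional pumps both directions until each side hits EOF
    // and returns the byte counts (ignored here, as in the fragment above).
    let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?;
    Ok(())
}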

/// Connect to a corresponding compute node.
async fn connect_to_db(
    db_info: DatabaseInfo,
) -> anyhow::Result<(TcpStream, String, CancelClosure)> {
    // TODO: establish a secure connection to the DB
    let socket_addr = db_info.socket_addr()?;
    let mut socket = TcpStream::connect(socket_addr).await?;

    let (client, conn) = tokio_postgres::Config::from(db_info)
        .connect_raw(&mut socket, NoTls)
        .await?;

    let version = conn
        .parameter("server_version")
        .context("failed to fetch postgres server version")?
        .into();

    let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());

    Ok((socket, version, cancel_closure))
}

#[cfg(test)]
mod tests {
    use super::*;

    use tokio::io::DuplexStream;
    use tokio_postgres::config::SslMode;
    use tokio_postgres::tls::MakeTlsConnect;
    use tokio_postgres_rustls::MakeRustlsConnect;

    async fn dummy_proxy(
        client: impl AsyncRead + AsyncWrite + Unpin,
        tls: Option<TlsConfig>,
    ) -> anyhow::Result<()> {
        let cancel_map = CancelMap::default();

        // TODO: add some infra + tests for credentials
        let (mut stream, _creds) = handshake(client, tls, &cancel_map)
            .await?
            .context("no stream")?;

        stream
            .write_message_noflush(&Be::AuthenticationOk)?
            .write_message_noflush(&BeParameterStatusMessage::encoding())?
            .write_message(&BeMessage::ReadyForQuery)
            .await?;

        Ok(())
    }

    fn generate_certs(
        hostname: &str,
    ) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> {
        let ca = rcgen::Certificate::from_params({
            let mut params = rcgen::CertificateParams::default();
            params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
            params
        })?;

        let cert = rcgen::generate_simple_self_signed(vec![hostname.into()])?;
        Ok((
            rustls::Certificate(ca.serialize_der()?),
            rustls::Certificate(cert.serialize_der_with_signer(&ca)?),
            rustls::PrivateKey(cert.serialize_private_key_der()),
        ))
    }

    #[tokio::test]
    async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
        let (client, server) = tokio::io::duplex(1024);

        let server_config = {
            let (_ca, cert, key) = generate_certs("localhost")?;

            let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
            config.set_single_cert(vec![cert], key)?;
            config
        };

        let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into())));

        tokio_postgres::Config::new()
            .user("john_doe")
            .dbname("earth")
            .ssl_mode(SslMode::Disable)
            .connect_raw(server, NoTls)
            .await
            .err() // -> Option<E>
            .context("client shouldn't be able to connect")?;

        proxy
            .await?
            .err() // -> Option<E>
            .context("server shouldn't accept client")?;

        Ok(())
    }

    #[tokio::test]
    async fn handshake_tls() -> anyhow::Result<()> {
        let (client, server) = tokio::io::duplex(1024);

        let (ca, cert, key) = generate_certs("localhost")?;

        let server_config = {
            let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
            config.set_single_cert(vec![cert], key)?;
            config
        };

        let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into())));

        let client_config = {
            let mut config = rustls::ClientConfig::new();
            config.root_store.add(&ca)?;
            config
        };

        let mut mk = MakeRustlsConnect::new(client_config);
        let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&mut mk, "localhost")?;

        let (_client, _conn) = tokio_postgres::Config::new()
            .user("john_doe")
            .dbname("earth")
            .ssl_mode(SslMode::Require)
            .connect_raw(server, tls)
            .await?;

        proxy.await?
    }

    #[tokio::test]
    async fn handshake_raw() -> anyhow::Result<()> {
        let (client, server) = tokio::io::duplex(1024);

        let proxy = tokio::spawn(dummy_proxy(client, None));

        let (_client, _conn) = tokio_postgres::Config::new()
            .user("john_doe")
            .dbname("earth")
            .ssl_mode(SslMode::Prefer)
            .connect_raw(server, NoTls)
            .await?;

        proxy.await?
    }
}
@@ -1,46 +1,15 @@
use crate::cplane_api::DatabaseInfo;
use anyhow::{anyhow, ensure, Context};
use rustls::{internal::pemfile, NoClientAuth, ProtocolVersion, ServerConfig};
use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::Arc;

pub type TlsConfig = Arc<ServerConfig>;

#[non_exhaustive]
pub enum ClientAuthMethod {
    Password,
    Link,

    /// Use password auth only if the username ends with "@zenith"
    Mixed,
}

pub enum RouterConfig {
    Static { host: String, port: u16 },
    Dynamic(ClientAuthMethod),
}

impl FromStr for ClientAuthMethod {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> anyhow::Result<Self> {
        use ClientAuthMethod::*;
        match s {
            "password" => Ok(Password),
            "link" => Ok(Link),
            "mixed" => Ok(Mixed),
            _ => Err(anyhow::anyhow!("Invalid option for router")),
        }
    }
}
pub type SslConfig = Arc<ServerConfig>;

pub struct ProxyConfig {
    /// main entrypoint for users to connect to
    pub proxy_address: SocketAddr,

    /// method of assigning compute nodes
    pub router_config: RouterConfig,

    /// internally used for status and prometheus metrics
    pub http_address: SocketAddr,

@@ -55,10 +24,26 @@ pub struct ProxyConfig {
    /// control plane address where we would check auth.
    pub auth_endpoint: String,

    pub tls_config: Option<TlsConfig>,
    pub ssl_config: Option<SslConfig>,
}

pub fn configure_ssl(key_path: &str, cert_path: &str) -> anyhow::Result<TlsConfig> {
pub type ProxyWaiters = crate::waiters::Waiters<Result<DatabaseInfo, String>>;

pub struct ProxyState {
    pub conf: ProxyConfig,
    pub waiters: ProxyWaiters,
}

impl ProxyState {
    pub fn new(conf: ProxyConfig) -> Self {
        Self {
            conf,
            waiters: ProxyWaiters::default(),
        }
    }
}

pub fn configure_ssl(key_path: &str, cert_path: &str) -> anyhow::Result<SslConfig> {
    let key = {
        let key_bytes = std::fs::read(key_path).context("SSL key file")?;
        let mut keys = pemfile::pkcs8_private_keys(&mut &key_bytes[..])
@@ -1,230 +0,0 @@
use anyhow::Context;
use bytes::BytesMut;
use pin_project_lite::pin_project;
use rustls::ServerConfig;
use std::pin::Pin;
use std::sync::Arc;
use std::{io, task};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio_rustls::server::TlsStream;
use zenith_utils::pq_proto::{BeMessage, FeMessage, FeStartupPacket};

pin_project! {
    /// Stream wrapper which implements libpq's protocol.
    /// NOTE: This object deliberately doesn't implement [`AsyncRead`]
    /// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
    /// to pass random malformed bytes through the connection).
    pub struct PqStream<S> {
        #[pin]
        stream: S,
        buffer: BytesMut,
    }
}

impl<S> PqStream<S> {
    /// Construct a new libpq protocol wrapper.
    pub fn new(stream: S) -> Self {
        Self {
            stream,
            buffer: Default::default(),
        }
    }

    /// Extract the underlying stream.
    pub fn into_inner(self) -> S {
        self.stream
    }

    /// Get a reference to the underlying stream.
    pub fn get_ref(&self) -> &S {
        &self.stream
    }
}

impl<S: AsyncRead + Unpin> PqStream<S> {
    /// Receive [`FeStartupPacket`], which is the first packet sent by a client.
    pub async fn read_startup_packet(&mut self) -> anyhow::Result<FeStartupPacket> {
        match FeStartupPacket::read_fut(&mut self.stream).await? {
            Some(FeMessage::StartupPacket(packet)) => Ok(packet),
            None => anyhow::bail!("connection is lost"),
            other => anyhow::bail!("bad message type: {:?}", other),
        }
    }

    pub async fn read_message(&mut self) -> anyhow::Result<FeMessage> {
        FeMessage::read_fut(&mut self.stream)
            .await?
            .context("connection is lost")
    }
}

impl<S: AsyncWrite + Unpin> PqStream<S> {
    /// Write the message into an internal buffer, but don't flush the underlying stream.
    pub fn write_message_noflush<'a>(&mut self, message: &BeMessage<'a>) -> io::Result<&mut Self> {
        BeMessage::write(&mut self.buffer, message)?;
        Ok(self)
    }

    /// Write the message into an internal buffer and flush it.
    pub async fn write_message<'a>(&mut self, message: &BeMessage<'a>) -> io::Result<&mut Self> {
        self.write_message_noflush(message)?;
        self.flush().await?;
        Ok(self)
    }

    /// Flush the output buffer into the underlying stream.
    pub async fn flush(&mut self) -> io::Result<&mut Self> {
        self.stream.write_all(&self.buffer).await?;
        self.buffer.clear();
        self.stream.flush().await?;
        Ok(self)
    }
}
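
Because PqStream deliberately hides the raw bytes, callers speak in whole protocol messages. A minimal usage sketch, assuming the `BeMessage::AuthenticationOk` and `ReadyForQuery` variants are in scope, mirroring how dummy_proxy in the proxy tests above chains buffered writes into one flush:

// A hypothetical greeting: read the startup packet, then buffer two
// messages and flush them to the socket with a single write.
async fn greet(socket: tokio::net::TcpStream) -> anyhow::Result<()> {
    let mut stream = PqStream::new(socket);
    let _params = stream.read_startup_packet().await?;

    stream
        .write_message_noflush(&BeMessage::AuthenticationOk)?
        .write_message(&BeMessage::ReadyForQuery)
        .await?;

    Ok(())
}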

pin_project! {
    /// Wrapper for upgrading raw streams into secure streams.
    /// NOTE: it should be possible to decompose this object as necessary.
    #[project = StreamProj]
    pub enum Stream<S> {
        /// We always begin with a raw stream,
        /// which may then be upgraded into a secure stream.
        Raw { #[pin] raw: S },
        /// We box [`TlsStream`] since it can be quite large.
        Tls { #[pin] tls: Box<TlsStream<S>> },
    }
}

impl<S> Stream<S> {
    /// Construct a new instance from a raw stream.
    pub fn from_raw(raw: S) -> Self {
        Self::Raw { raw }
    }
}

impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
    /// If possible, upgrade raw stream into a secure TLS-based stream.
    pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> anyhow::Result<Self> {
        match self {
            Stream::Raw { raw } => {
                let tls = Box::new(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?);
                Ok(Stream::Tls { tls })
            }
            Stream::Tls { .. } => anyhow::bail!("can't upgrade TLS stream"),
        }
    }
}
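
A minimal sketch of the intended call pattern, matching the handshake earlier in this diff (the TcpStream and ready-made rustls config are assumptions): start raw, then upgrade exactly once after the client's SslRequest.

async fn accept_tls(
    socket: tokio::net::TcpStream,
    cfg: std::sync::Arc<rustls::ServerConfig>,
) -> anyhow::Result<Stream<tokio::net::TcpStream>> {
    // `upgrade` consumes the Raw variant; calling it again on the
    // resulting Tls variant would bail with "can't upgrade TLS stream".
    Stream::from_raw(socket).upgrade(cfg).await
}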

impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
    fn poll_read(
        self: Pin<&mut Self>,
        context: &mut task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> task::Poll<io::Result<()>> {
        use StreamProj::*;
        match self.project() {
            Raw { raw } => raw.poll_read(context, buf),
            Tls { tls } => tls.poll_read(context, buf),
        }
    }
}

impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
    fn poll_write(
        self: Pin<&mut Self>,
        context: &mut task::Context<'_>,
        buf: &[u8],
    ) -> task::Poll<io::Result<usize>> {
        use StreamProj::*;
        match self.project() {
            Raw { raw } => raw.poll_write(context, buf),
            Tls { tls } => tls.poll_write(context, buf),
        }
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        context: &mut task::Context<'_>,
    ) -> task::Poll<io::Result<()>> {
        use StreamProj::*;
        match self.project() {
            Raw { raw } => raw.poll_flush(context),
            Tls { tls } => tls.poll_flush(context),
        }
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        context: &mut task::Context<'_>,
    ) -> task::Poll<io::Result<()>> {
        use StreamProj::*;
        match self.project() {
            Raw { raw } => raw.poll_shutdown(context),
            Tls { tls } => tls.poll_shutdown(context),
        }
    }
}

pin_project! {
    /// This stream tracks all writes and calls the user-provided
    /// callback when the underlying stream is flushed.
    pub struct MetricsStream<S, W> {
        #[pin]
        stream: S,
        write_count: usize,
        inc_write_count: W,
    }
}

impl<S, W> MetricsStream<S, W> {
    pub fn new(stream: S, inc_write_count: W) -> Self {
        Self {
            stream,
            write_count: 0,
            inc_write_count,
        }
    }
}

impl<S: AsyncRead + Unpin, W> AsyncRead for MetricsStream<S, W> {
    fn poll_read(
        self: Pin<&mut Self>,
        context: &mut task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> task::Poll<io::Result<()>> {
        self.project().stream.poll_read(context, buf)
    }
}

impl<S: AsyncWrite + Unpin, W: FnMut(usize)> AsyncWrite for MetricsStream<S, W> {
    fn poll_write(
        self: Pin<&mut Self>,
        context: &mut task::Context<'_>,
        buf: &[u8],
    ) -> task::Poll<io::Result<usize>> {
        let this = self.project();
        this.stream.poll_write(context, buf).map_ok(|cnt| {
            // Increment the write count.
            *this.write_count += cnt;
            cnt
        })
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        context: &mut task::Context<'_>,
    ) -> task::Poll<io::Result<()>> {
        let this = self.project();
        this.stream.poll_flush(context).map_ok(|()| {
            // Call the user-provided callback and reset the write count.
            (this.inc_write_count)(*this.write_count);
            *this.write_count = 0;
        })
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        context: &mut task::Context<'_>,
    ) -> task::Poll<io::Result<()>> {
        self.project().stream.poll_shutdown(context)
    }
}
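
This is the adapter the new proxy pass wires into NUM_BYTES_PROXIED_COUNTER; a minimal sketch with plain closures as callbacks (the TcpStream endpoints are illustrative):

async fn count_copied(
    client: tokio::net::TcpStream,
    db: tokio::net::TcpStream,
) -> anyhow::Result<()> {
    // The callback fires on every flush with the number of bytes
    // written since the previous flush.
    let mut client = MetricsStream::new(client, |cnt| println!("client got {} bytes", cnt));
    let mut db = MetricsStream::new(db, |cnt| println!("db got {} bytes", cnt));

    let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?;
    Ok(())
}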

@@ -1,12 +1,8 @@
use anyhow::{anyhow, Context};
use hashbrown::HashMap;
use parking_lot::Mutex;
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::task;
use tokio::sync::oneshot;
use anyhow::Context;
use std::collections::HashMap;
use std::sync::{mpsc, Mutex};

pub struct Waiters<T>(pub(self) Mutex<HashMap<String, oneshot::Sender<T>>>);
pub struct Waiters<T>(pub(self) Mutex<HashMap<String, mpsc::Sender<T>>>);

impl<T> Default for Waiters<T> {
    fn default() -> Self {
@@ -15,86 +11,48 @@ impl<T> Default for Waiters<T> {
}

impl<T> Waiters<T> {
    pub fn register(&self, key: String) -> anyhow::Result<Waiter<T>> {
        let (tx, rx) = oneshot::channel();
    pub fn register(&self, key: String) -> Waiter<T> {
        let (tx, rx) = mpsc::channel();

        self.0
            .lock()
            .try_insert(key.clone(), tx)
            .map_err(|_| anyhow!("waiter already registered"))?;
        // TODO: use `try_insert` (unstable)
        let prev = self.0.lock().unwrap().insert(key.clone(), tx);
        assert!(matches!(prev, None)); // assert_matches! is nightly-only

        Ok(Waiter {
        Waiter {
            receiver: rx,
            guard: DropKey {
                registry: self,
                key,
            },
        })
            registry: self,
            key,
        }
    }

    pub fn notify(&self, key: &str, value: T) -> anyhow::Result<()>
    where
        T: Send + Sync,
        T: Send + Sync + 'static,
    {
        let tx = self
            .0
            .lock()
            .unwrap()
            .remove(key)
            .with_context(|| format!("key {} not found", key))?;

        tx.send(value).map_err(|_| anyhow!("waiter channel hangup"))
        tx.send(value).context("channel hangup")
    }
}

struct DropKey<'a, T> {
    key: String,
pub struct Waiter<'a, T> {
    receiver: mpsc::Receiver<T>,
    registry: &'a Waiters<T>,
    key: String,
}

impl<'a, T> Drop for DropKey<'a, T> {
impl<T> Waiter<'_, T> {
    pub fn wait(self) -> anyhow::Result<T> {
        self.receiver.recv().context("channel hangup")
    }
}

impl<T> Drop for Waiter<'_, T> {
    fn drop(&mut self) {
        self.registry.0.lock().remove(&self.key);
    }
}

pin_project! {
    pub struct Waiter<'a, T> {
        #[pin]
        receiver: oneshot::Receiver<T>,
        guard: DropKey<'a, T>,
    }
}

impl<T> std::future::Future for Waiter<'_, T> {
    type Output = anyhow::Result<T>;

    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
        self.project()
            .receiver
            .poll(cx)
            .map_err(|_| anyhow!("channel hangup"))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[tokio::test]
    async fn test_waiter() -> anyhow::Result<()> {
        let waiters = Arc::new(Waiters::default());

        let key = "Key";
        let waiter = waiters.register(key.to_owned())?;

        let waiters = Arc::clone(&waiters);
        let notifier = tokio::spawn(async move {
            waiters.notify(key, Default::default())?;
            Ok(())
        });

        let () = waiter.await?;
        notifier.await?
        self.registry.0.lock().unwrap().remove(&self.key);
    }
}
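
The mpsc-flavored API on the incoming side of this hunk is synchronous and keyed by session id; a minimal sketch of the register/notify/wait round trip (the payload type is chosen for illustration):

fn waiters_roundtrip() -> anyhow::Result<()> {
    let waiters: Waiters<String> = Waiters::default();

    // Registering stores the sender under the key; dropping the
    // returned Waiter removes it again.
    let waiter = waiters.register("session-1".to_owned());

    // Some other thread resolves the wait by key.
    waiters.notify("session-1", "db is ready".to_owned())?;

    let payload = waiter.wait()?;
    assert_eq!(payload, "db is ready");
    Ok(())
}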

@@ -21,6 +21,7 @@ types-psycopg2 = "^2.9.6"
boto3 = "^1.20.40"
boto3-stubs = "^1.20.40"
moto = {version = "^3.0.0", extras = ["server"]}
pytest-skip-slow = "^0.0.2"

[tool.poetry.dev-dependencies]
yapf = "==0.31.0"
@@ -1,10 +1,12 @@
from contextlib import closing
from typing import Iterator
from uuid import UUID, uuid4
from uuid import uuid4
import psycopg2
from fixtures.zenith_fixtures import ZenithEnvBuilder, ZenithPageserverApiException
from fixtures.zenith_fixtures import ZenithEnvBuilder
import pytest

pytest_plugins = ("fixtures.zenith_fixtures")


def test_pageserver_auth(zenith_env_builder: ZenithEnvBuilder):
    zenith_env_builder.pageserver_auth_enabled = True
@@ -12,38 +14,32 @@ def test_pageserver_auth(zenith_env_builder: ZenithEnvBuilder):

    ps = env.pageserver

    tenant_token = env.auth_keys.generate_tenant_token(env.initial_tenant.hex)
    tenant_http_client = env.pageserver.http_client(tenant_token)
    tenant_token = env.auth_keys.generate_tenant_token(env.initial_tenant)
    invalid_tenant_token = env.auth_keys.generate_tenant_token(uuid4().hex)
    invalid_tenant_http_client = env.pageserver.http_client(invalid_tenant_token)

    management_token = env.auth_keys.generate_management_token()
    management_http_client = env.pageserver.http_client(management_token)

    # this does not invoke the auth check; it only decodes the jwt and checks its validity
    # check both tokens
    ps.safe_psql("set FOO", password=tenant_token)
    ps.safe_psql("set FOO", password=management_token)
    ps.safe_psql("status", password=tenant_token)
    ps.safe_psql("status", password=management_token)

    # tenant can create branches
    tenant_http_client.branch_create(env.initial_tenant, 'new1', 'main')
    ps.safe_psql(f"branch_create {env.initial_tenant} new1 main", password=tenant_token)
    # console can create branches for tenant
    management_http_client.branch_create(env.initial_tenant, 'new2', 'main')
    ps.safe_psql(f"branch_create {env.initial_tenant} new2 main", password=management_token)

    # fail to create branch using token with different tenant_id
    with pytest.raises(ZenithPageserverApiException,
                       match='Forbidden: Tenant id mismatch. Permission denied'):
        invalid_tenant_http_client.branch_create(env.initial_tenant, "new3", "main")
    # fail to create branch using token with different tenant_id
    with pytest.raises(psycopg2.DatabaseError, match='Tenant id mismatch. Permission denied'):
        ps.safe_psql(f"branch_create {env.initial_tenant} new2 main", password=invalid_tenant_token)

    # create tenant using management token
    management_http_client.tenant_create(uuid4())
    ps.safe_psql(f"tenant_create {uuid4().hex}", password=management_token)

    # fail to create tenant using tenant token
    with pytest.raises(
            ZenithPageserverApiException,
            match='Forbidden: Attempt to access management api with tenant scope. Permission denied'
    ):
        tenant_http_client.tenant_create(uuid4())
            psycopg2.DatabaseError,
            match='Attempt to access management api with tenant scope. Permission denied'):
        ps.safe_psql(f"tenant_create {uuid4().hex}", password=tenant_token)


@pytest.mark.parametrize('with_wal_acceptors', [False, True])
@@ -54,7 +50,7 @@ def test_compute_auth_to_pageserver(zenith_env_builder: ZenithEnvBuilder, with_w
    env = zenith_env_builder.init()

    branch = f"test_compute_auth_to_pageserver{with_wal_acceptors}"
    env.zenith_cli.create_branch(branch, "main")
    env.zenith_cli(["branch", branch, "main"])

    pg = env.postgres.create_start(branch)
@@ -1,154 +0,0 @@
from contextlib import closing, contextmanager
import psycopg2.extras
from fixtures.zenith_fixtures import ZenithEnvBuilder
from fixtures.log_helper import log
import os
import time
import asyncpg
from fixtures.zenith_fixtures import Postgres
import threading

pytest_plugins = ("fixtures.zenith_fixtures")


@contextmanager
def pg_cur(pg):
    with closing(pg.connect()) as conn:
        with conn.cursor() as cur:
            yield cur


# Periodically check that all backpressure lags are below the configured threshold,
# and assert if they are not.
# If the check query fails, stop the thread. The main thread should notice that and stop the test.
def check_backpressure(pg: Postgres, stop_event: threading.Event, polling_interval=5):
    log.info("checks started")

    with pg_cur(pg) as cur:
        cur.execute("CREATE EXTENSION zenith")  # TODO move it to zenith_fixtures?

        cur.execute("select pg_size_bytes(current_setting('max_replication_write_lag'))")
        res = cur.fetchone()
        max_replication_write_lag_bytes = res[0]
        log.info(f"max_replication_write_lag: {max_replication_write_lag_bytes} bytes")

        cur.execute("select pg_size_bytes(current_setting('max_replication_flush_lag'))")
        res = cur.fetchone()
        max_replication_flush_lag_bytes = res[0]
        log.info(f"max_replication_flush_lag: {max_replication_flush_lag_bytes} bytes")

        cur.execute("select pg_size_bytes(current_setting('max_replication_apply_lag'))")
        res = cur.fetchone()
        max_replication_apply_lag_bytes = res[0]
        log.info(f"max_replication_apply_lag: {max_replication_apply_lag_bytes} bytes")

    with pg_cur(pg) as cur:
        while not stop_event.is_set():
            try:
                cur.execute('''
                select pg_wal_lsn_diff(pg_current_wal_flush_lsn(),received_lsn) as received_lsn_lag,
                pg_wal_lsn_diff(pg_current_wal_flush_lsn(),disk_consistent_lsn) as disk_consistent_lsn_lag,
                pg_wal_lsn_diff(pg_current_wal_flush_lsn(),remote_consistent_lsn) as remote_consistent_lsn_lag,
                pg_size_pretty(pg_wal_lsn_diff(pg_current_wal_flush_lsn(),received_lsn)),
                pg_size_pretty(pg_wal_lsn_diff(pg_current_wal_flush_lsn(),disk_consistent_lsn)),
                pg_size_pretty(pg_wal_lsn_diff(pg_current_wal_flush_lsn(),remote_consistent_lsn))
                from backpressure_lsns();
                ''')

                res = cur.fetchone()
                received_lsn_lag = res[0]
                disk_consistent_lsn_lag = res[1]
                remote_consistent_lsn_lag = res[2]

                log.info(f"received_lsn_lag = {received_lsn_lag} ({res[3]}), "
                         f"disk_consistent_lsn_lag = {disk_consistent_lsn_lag} ({res[4]}), "
                         f"remote_consistent_lsn_lag = {remote_consistent_lsn_lag} ({res[5]})")

                # Since feedback from the pageserver is not immediate, we should allow some lag overflow
                lag_overflow = 5 * 1024 * 1024  # 5MB

                if max_replication_write_lag_bytes > 0:
                    assert received_lsn_lag < max_replication_write_lag_bytes + lag_overflow
                if max_replication_flush_lag_bytes > 0:
                    assert disk_consistent_lsn_lag < max_replication_flush_lag_bytes + lag_overflow
                if max_replication_apply_lag_bytes > 0:
                    assert remote_consistent_lsn_lag < max_replication_apply_lag_bytes + lag_overflow

                time.sleep(polling_interval)

            except Exception as e:
                log.info(f"backpressure check query failed: {e}")
                stop_event.set()

    log.info('check thread stopped')


# This test illustrates how to tune backpressure to control the lag
# between the WAL flushed on the compute node and the WAL digested by the pageserver.
#
# To test it, throttle walreceiver ingest using a failpoint and run a heavy write load.
# If backpressure is disabled or not tuned properly, the query will time out, because the walreceiver cannot keep up.
# If backpressure is enabled and tuned properly, insertion will be throttled, but the query will not time out.


def test_backpressure_received_lsn_lag(zenith_env_builder: ZenithEnvBuilder):
    zenith_env_builder.num_safekeepers = 1
    env = zenith_env_builder.init()
    # Create a branch for us
    env.zenith_cli.create_branch("test_backpressure", "main")

    pg = env.postgres.create_start('test_backpressure',
                                   config_lines=['max_replication_write_lag=30MB'])
    log.info("postgres is running on 'test_backpressure' branch")

    # set up the check thread
    check_stop_event = threading.Event()
    check_thread = threading.Thread(target=check_backpressure, args=(pg, check_stop_event))
    check_thread.start()

    # Configure a failpoint to slow down walreceiver ingest
    with closing(env.pageserver.connect()) as psconn:
        with psconn.cursor(cursor_factory=psycopg2.extras.DictCursor) as pscur:
            pscur.execute("failpoints walreceiver-after-ingest=sleep(20)")

    # FIXME: wait for the check thread to start.
    #
    # If the load starts too soon, the check thread cannot authenticate,
    # because it cannot connect to the database while the compute is
    # waiting for the replay LSN to catch up.
    time.sleep(2)

    with pg_cur(pg) as cur:
        # Create and initialize the test table
        cur.execute("CREATE TABLE foo(x bigint)")

        inserts_to_do = 2000000
        rows_inserted = 0

        while check_thread.is_alive() and rows_inserted < inserts_to_do:
            try:
                cur.execute("INSERT INTO foo select from generate_series(1, 100000)")
                rows_inserted += 100000
            except Exception as e:
                if check_thread.is_alive():
                    log.info('stopping check thread')
                    check_stop_event.set()
                    check_thread.join()
                    assert False, f"Exception {e} while inserting rows, but the WAL lag is within the configured threshold. That means backpressure is not tuned properly."
                else:
                    assert False, f"Exception {e} while inserting rows and the WAL lag overflowed the configured threshold. That means backpressure doesn't work."

        log.info(f"inserted {rows_inserted} rows")

    if check_thread.is_alive():
        log.info('stopping check thread')
        check_stop_event.set()
        check_thread.join()
        log.info('check thread stopped')
    else:
        assert False, "The WAL lag overflowed the configured threshold. That means backpressure doesn't work."


# TODO: test_backpressure_disk_consistent_lsn_lag. Play with the pageserver's checkpoint settings.
# TODO: test_backpressure_remote_consistent_lsn_lag
@@ -7,6 +7,8 @@ from fixtures.log_helper import log
from fixtures.utils import print_gc_result
from fixtures.zenith_fixtures import ZenithEnvBuilder

pytest_plugins = ("fixtures.zenith_fixtures")


#
# Create a couple of branches off the main branch, at a historical point in time.
@@ -22,7 +24,7 @@ def test_branch_behind(zenith_env_builder: ZenithEnvBuilder):
    env = zenith_env_builder.init()

    # Branch at the point where only 100 rows were inserted
    env.zenith_cli.create_branch("test_branch_behind", "main")
    env.zenith_cli(["branch", "test_branch_behind", "main"])

    pgmain = env.postgres.create_start('test_branch_behind')
    log.info("postgres is running on 'test_branch_behind' branch")
@@ -60,7 +62,7 @@ def test_branch_behind(zenith_env_builder: ZenithEnvBuilder):
    log.info(f'LSN after 200100 rows: {lsn_b}')

    # Branch at the point where only 100 rows were inserted
    env.zenith_cli.create_branch("test_branch_behind_hundred", "test_branch_behind@" + lsn_a)
    env.zenith_cli(["branch", "test_branch_behind_hundred", "test_branch_behind@" + lsn_a])

    # Insert many more rows. This generates enough WAL to fill a few segments.
    main_cur.execute('''
@@ -75,7 +77,7 @@ def test_branch_behind(zenith_env_builder: ZenithEnvBuilder):
    log.info(f'LSN after 400100 rows: {lsn_c}')

    # Branch at the point where only 200100 rows were inserted
    env.zenith_cli.create_branch("test_branch_behind_more", "test_branch_behind@" + lsn_b)
    env.zenith_cli(["branch", "test_branch_behind_more", "test_branch_behind@" + lsn_b])

    pg_hundred = env.postgres.create_start("test_branch_behind_hundred")
    pg_more = env.postgres.create_start("test_branch_behind_more")
@@ -99,7 +101,7 @@ def test_branch_behind(zenith_env_builder: ZenithEnvBuilder):
    # Check bad LSNs for branching

    # branch at a segment boundary
    env.zenith_cli.create_branch("test_branch_segment_boundary", "test_branch_behind@0/3000000")
    env.zenith_cli(["branch", "test_branch_segment_boundary", "test_branch_behind@0/3000000"])
    pg = env.postgres.create_start("test_branch_segment_boundary")
    cur = pg.connect().cursor()
    cur.execute('SELECT 1')
@@ -107,23 +109,23 @@ def test_branch_behind(zenith_env_builder: ZenithEnvBuilder):

    # branch at a pre-initdb lsn
    with pytest.raises(Exception, match="invalid branch start lsn"):
        env.zenith_cli.create_branch("test_branch_preinitdb", "main@0/42")
        env.zenith_cli(["branch", "test_branch_preinitdb", "main@0/42"])

    # branch at a pre-ancestor lsn
    with pytest.raises(Exception, match="less than timeline ancestor lsn"):
        env.zenith_cli.create_branch("test_branch_preinitdb", "test_branch_behind@0/42")
        env.zenith_cli(["branch", "test_branch_preinitdb", "test_branch_behind@0/42"])

    # check that we cannot create a branch based on garbage-collected data
    with closing(env.pageserver.connect()) as psconn:
        with psconn.cursor(cursor_factory=psycopg2.extras.DictCursor) as pscur:
            # call gc to advance latest_gc_cutoff_lsn
            pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
            pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
            row = pscur.fetchone()
            print_gc_result(row)

    with pytest.raises(Exception, match="invalid branch start lsn"):
        # this gced_lsn is pretty random, so if gc is disabled this wouldn't fail
        env.zenith_cli.create_branch("test_branch_create_fail", f"test_branch_behind@{gced_lsn}")
        env.zenith_cli(["branch", "test_branch_create_fail", f"test_branch_behind@{gced_lsn}"])

    # check that after gc everything is still there
    hundred_cur.execute('SELECT count(*) FROM foo')
@@ -6,13 +6,16 @@ from contextlib import closing
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log

pytest_plugins = ("fixtures.zenith_fixtures")


#
# Test compute node start after clog truncation
#
def test_clog_truncate(zenith_simple_env: ZenithEnv):
    env = zenith_simple_env
    env.zenith_cli.create_branch("test_clog_truncate", "empty")
    # Create a branch for us
    env.zenith_cli(["branch", "test_clog_truncate", "empty"])

    # set aggressive autovacuum to make sure that truncation will happen
    config = [
@@ -62,8 +65,8 @@ def test_clog_truncate(zenith_simple_env: ZenithEnv):

    # create a new branch after clog truncation and start a compute node on it
    log.info(f'create branch at lsn_after_truncation {lsn_after_truncation}')
    env.zenith_cli.create_branch("test_clog_truncate_new",
                                 "test_clog_truncate@" + lsn_after_truncation)
    env.zenith_cli(
        ["branch", "test_clog_truncate_new", "test_clog_truncate@" + lsn_after_truncation])

    pg2 = env.postgres.create_start('test_clog_truncate_new')
    log.info('postgres is running on test_clog_truncate_new branch')
@@ -3,13 +3,16 @@ from contextlib import closing
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log

pytest_plugins = ("fixtures.zenith_fixtures")


#
# Test starting Postgres with custom options
#
def test_config(zenith_simple_env: ZenithEnv):
    env = zenith_simple_env
    env.zenith_cli.create_branch("test_config", "empty")
    # Create a branch for us
    env.zenith_cli(["branch", "test_config", "empty"])

    # change config
    pg = env.postgres.create_start('test_config', config_lines=['log_min_messages=debug1'])
@@ -5,13 +5,15 @@ from contextlib import closing
from fixtures.zenith_fixtures import ZenithEnv, check_restored_datadir_content
from fixtures.log_helper import log

pytest_plugins = ("fixtures.zenith_fixtures")


#
# Test CREATE DATABASE when there have been relmapper changes
#
def test_createdb(zenith_simple_env: ZenithEnv):
    env = zenith_simple_env
    env.zenith_cli.create_branch("test_createdb", "empty")
    env.zenith_cli(["branch", "test_createdb", "empty"])

    pg = env.postgres.create_start('test_createdb')
    log.info("postgres is running on 'test_createdb' branch")
@@ -27,7 +29,7 @@ def test_createdb(zenith_simple_env: ZenithEnv):
    lsn = cur.fetchone()[0]

    # Create a branch
    env.zenith_cli.create_branch("test_createdb2", "test_createdb@" + lsn)
    env.zenith_cli(["branch", "test_createdb2", "test_createdb@" + lsn])

    pg2 = env.postgres.create_start('test_createdb2')

@@ -41,7 +43,7 @@ def test_createdb(zenith_simple_env: ZenithEnv):
#
def test_dropdb(zenith_simple_env: ZenithEnv, test_output_dir):
    env = zenith_simple_env
    env.zenith_cli.create_branch("test_dropdb", "empty")
    env.zenith_cli(["branch", "test_dropdb", "empty"])

    pg = env.postgres.create_start('test_dropdb')
    log.info("postgres is running on 'test_dropdb' branch")
@@ -66,10 +68,10 @@ def test_dropdb(zenith_simple_env: ZenithEnv, test_output_dir):
    lsn_after_drop = cur.fetchone()[0]

    # Create two branches before and after database drop.
    env.zenith_cli.create_branch("test_before_dropdb", "test_dropdb@" + lsn_before_drop)
    env.zenith_cli(["branch", "test_before_dropdb", "test_dropdb@" + lsn_before_drop])
    pg_before = env.postgres.create_start('test_before_dropdb')

    env.zenith_cli.create_branch("test_after_dropdb", "test_dropdb@" + lsn_after_drop)
    env.zenith_cli(["branch", "test_after_dropdb", "test_dropdb@" + lsn_after_drop])
    pg_after = env.postgres.create_start('test_after_dropdb')

    # Test that database exists on the branch before drop
@@ -3,13 +3,15 @@ from contextlib import closing
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log

pytest_plugins = ("fixtures.zenith_fixtures")


#
# Test CREATE USER to check shared catalog restore
#
def test_createuser(zenith_simple_env: ZenithEnv):
    env = zenith_simple_env
    env.zenith_cli.create_branch("test_createuser", "empty")
    env.zenith_cli(["branch", "test_createuser", "empty"])

    pg = env.postgres.create_start('test_createuser')
    log.info("postgres is running on 'test_createuser' branch")
@@ -25,7 +27,7 @@ def test_createuser(zenith_simple_env: ZenithEnv):
    lsn = cur.fetchone()[0]

    # Create a branch
    env.zenith_cli.create_branch("test_createuser2", "test_createuser@" + lsn)
    env.zenith_cli(["branch", "test_createuser2", "test_createuser@" + lsn])

    pg2 = env.postgres.create_start('test_createuser2')
@@ -7,6 +7,8 @@ import random
from fixtures.zenith_fixtures import ZenithEnv, Postgres, Safekeeper
from fixtures.log_helper import log

pytest_plugins = ("fixtures.zenith_fixtures")

# Test configuration
#
# Create a table with {num_rows} rows, and perform {updates_to_perform} random
@@ -34,7 +36,7 @@ async def gc(env: ZenithEnv, timeline: str):
    psconn = await env.pageserver.connect_async()

    while updates_performed < updates_to_perform:
        await psconn.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
        await psconn.execute(f"do_gc {env.initial_tenant} {timeline} 0")


# At the same time, run UPDATEs and GC
@@ -55,7 +57,9 @@ async def update_and_gc(env: ZenithEnv, pg: Postgres, timeline: str):
#
def test_gc_aggressive(zenith_simple_env: ZenithEnv):
    env = zenith_simple_env
    env.zenith_cli.create_branch("test_gc_aggressive", "empty")
    # Create a branch for us
    env.zenith_cli(["branch", "test_gc_aggressive", "empty"])

    pg = env.postgres.create_start('test_gc_aggressive')
    log.info('postgres is running on test_gc_aggressive branch')
@@ -1,6 +1,8 @@
from fixtures.zenith_fixtures import ZenithEnv, check_restored_datadir_content
from fixtures.log_helper import log

pytest_plugins = ("fixtures.zenith_fixtures")


#
# Test multixact state after branching
@@ -10,7 +12,8 @@ from fixtures.log_helper import log
#
def test_multixact(zenith_simple_env: ZenithEnv, test_output_dir):
    env = zenith_simple_env
    env.zenith_cli.create_branch("test_multixact", "empty")
    # Create a branch for us
    env.zenith_cli(["branch", "test_multixact", "empty"])
    pg = env.postgres.create_start('test_multixact')

    log.info("postgres is running on 'test_multixact' branch")
@@ -60,7 +63,7 @@ def test_multixact(zenith_simple_env: ZenithEnv, test_output_dir):
    assert int(next_multixact_id) > int(next_multixact_id_old)

    # Branch at this point
    env.zenith_cli.create_branch("test_multixact_new", "test_multixact@" + lsn)
    env.zenith_cli(["branch", "test_multixact_new", "test_multixact@" + lsn])
    pg_new = env.postgres.create_start('test_multixact_new')

    log.info("postgres is running on 'test_multixact_new' branch")
@@ -5,6 +5,8 @@ import time
from fixtures.zenith_fixtures import ZenithEnvBuilder
from fixtures.log_helper import log

pytest_plugins = ("fixtures.zenith_fixtures")


# Test restarting page server, while safekeeper and compute node keep
# running.
@@ -3,6 +3,8 @@ from contextlib import closing
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log

pytest_plugins = ("fixtures.zenith_fixtures")


#
# Test where Postgres generates a lot of WAL, and it's garbage collected away, but
@@ -16,7 +18,8 @@ from fixtures.log_helper import log
#
def test_old_request_lsn(zenith_simple_env: ZenithEnv):
    env = zenith_simple_env
    env.zenith_cli.create_branch("test_old_request_lsn", "empty")
    # Create a branch for us
    env.zenith_cli(["branch", "test_old_request_lsn", "empty"])
    pg = env.postgres.create_start('test_old_request_lsn')
    log.info('postgres is running on test_old_request_lsn branch')

@@ -54,7 +57,7 @@ def test_old_request_lsn(zenith_simple_env: ZenithEnv):
    # Make a lot of updates on a single row, generating a lot of WAL. Trigger
    # garbage collections so that the page server will remove old page versions.
    for i in range(10):
        pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
        pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
        for j in range(100):
            cur.execute('UPDATE foo SET val = val + 1 WHERE id = 1;')
@@ -1,15 +1,95 @@
|
||||
import json
|
||||
from uuid import uuid4, UUID
|
||||
import pytest
|
||||
import psycopg2
|
||||
import requests
|
||||
from fixtures.zenith_fixtures import ZenithEnv, ZenithEnvBuilder, ZenithPageserverHttpClient
|
||||
from typing import cast
|
||||
import pytest, psycopg2
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
|
||||
|
||||
def check_client(client: ZenithPageserverHttpClient, initial_tenant: UUID):
def test_status_psql(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
assert env.pageserver.safe_psql('status') == [
('hello world', ),
]


def test_branch_list_psql(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
# Create a branch for us
env.zenith_cli(["branch", "test_branch_list_main", "empty"])

conn = env.pageserver.connect()
cur = conn.cursor()

cur.execute(f'branch_list {env.initial_tenant}')
branches = json.loads(cur.fetchone()[0])
# Filter out branches created by other tests
branches = [x for x in branches if x['name'].startswith('test_branch_list')]

assert len(branches) == 1
assert branches[0]['name'] == 'test_branch_list_main'
assert 'timeline_id' in branches[0]
assert 'latest_valid_lsn' in branches[0]
assert 'ancestor_id' in branches[0]
assert 'ancestor_lsn' in branches[0]

# Create another branch, and start Postgres on it
env.zenith_cli(['branch', 'test_branch_list_experimental', 'test_branch_list_main'])
env.zenith_cli(['pg', 'create', 'test_branch_list_experimental'])

cur.execute(f'branch_list {env.initial_tenant}')
new_branches = json.loads(cur.fetchone()[0])
# Filter out branches created by other tests
new_branches = [x for x in new_branches if x['name'].startswith('test_branch_list')]
assert len(new_branches) == 2
new_branches.sort(key=lambda k: k['name'])

assert new_branches[0]['name'] == 'test_branch_list_experimental'
assert new_branches[0]['timeline_id'] != branches[0]['timeline_id']

# TODO: do the LSNs have to match here?
assert new_branches[1] == branches[0]

conn.close()


def test_tenant_list_psql(zenith_env_builder: ZenithEnvBuilder):
# don't use zenith_simple_env, because there might be other tenants there,
# left over from other tests.
env = zenith_env_builder.init()

res = env.zenith_cli(["tenant", "list"])
res.check_returncode()
tenants = sorted(map(lambda t: t.split()[0], res.stdout.splitlines()))
assert tenants == [env.initial_tenant]

conn = env.pageserver.connect()
cur = conn.cursor()

# check same tenant cannot be created twice
with pytest.raises(psycopg2.DatabaseError,
match=f'repo for {env.initial_tenant} already exists'):
cur.execute(f'tenant_create {env.initial_tenant}')

# create one more tenant
tenant1 = uuid4().hex
cur.execute(f'tenant_create {tenant1}')

cur.execute('tenant_list')

# compare tenants list
new_tenants = sorted(map(lambda t: cast(str, t['id']), json.loads(cur.fetchone()[0])))
assert sorted([env.initial_tenant, tenant1]) == new_tenants


def check_client(client: ZenithPageserverHttpClient, initial_tenant: str):
client.check_status()

# check initial tenant is there
assert initial_tenant.hex in {t['id'] for t in client.tenant_list()}
assert initial_tenant in {t['id'] for t in client.tenant_list()}

# create new tenant and check it is also there
tenant_id = uuid4()

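Many hunks in this comparison switch tenant and timeline IDs between uuid.UUID objects and plain hex strings. The two representations round-trip directly, which is what makes the mechanical rewrite safe:

import uuid

t = uuid.uuid4()
assert uuid.UUID(t.hex) == t              # hex string back to a UUID
assert t.hex == str(t).replace('-', '')   # .hex is the dashless rendering
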
@@ -7,6 +7,8 @@ from multiprocessing import Process, Value
from fixtures.zenith_fixtures import ZenithEnvBuilder
from fixtures.log_helper import log

pytest_plugins = ("fixtures.zenith_fixtures")


# Test safekeeper sync and pageserver catch up
# while initial compute node is down and pageserver is lagging behind safekeepers.
@@ -16,7 +18,7 @@ def test_pageserver_catchup_while_compute_down(zenith_env_builder: ZenithEnvBuil
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init()

env.zenith_cli.create_branch("test_pageserver_catchup_while_compute_down", "main")
env.zenith_cli(["branch", "test_pageserver_catchup_while_compute_down", "main"])
pg = env.postgres.create_start('test_pageserver_catchup_while_compute_down')

pg_conn = pg.connect()

@@ -7,6 +7,8 @@ from multiprocessing import Process, Value
from fixtures.zenith_fixtures import ZenithEnvBuilder
from fixtures.log_helper import log

pytest_plugins = ("fixtures.zenith_fixtures")


# Test restarting page server, while safekeeper and compute node keep
# running.
@@ -15,7 +17,7 @@ def test_pageserver_restart(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 1
env = zenith_env_builder.init()

env.zenith_cli.create_branch("test_pageserver_restart", "main")
env.zenith_cli(["branch", "test_pageserver_restart", "main"])
pg = env.postgres.create_start('test_pageserver_restart')

pg_conn = pg.connect()

@@ -5,6 +5,8 @@ import subprocess
from fixtures.zenith_fixtures import ZenithEnv, Postgres
from fixtures.log_helper import log

pytest_plugins = ("fixtures.zenith_fixtures")


async def repeat_bytes(buf, repetitions: int):
for i in range(repetitions):
@@ -37,7 +39,9 @@ async def parallel_load_same_table(pg: Postgres, n_parallel: int):
# Load data into one table with COPY TO from 5 parallel connections
def test_parallel_copy(zenith_simple_env: ZenithEnv, n_parallel=5):
env = zenith_simple_env
env.zenith_cli.create_branch("test_parallel_copy", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_parallel_copy", "empty"])

pg = env.postgres.create_start('test_parallel_copy')
log.info("postgres is running on 'test_parallel_copy' branch")

@@ -1,10 +1,14 @@
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log

pytest_plugins = ("fixtures.zenith_fixtures")


def test_pgbench(zenith_simple_env: ZenithEnv, pg_bin):
env = zenith_simple_env
env.zenith_cli.create_branch("test_pgbench", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_pgbench", "empty"])

pg = env.postgres.create_start('test_pgbench')
log.info("postgres is running on 'test_pgbench' branch")

@@ -1,54 +0,0 @@
import pytest
import subprocess
import signal
import time


def test_proxy_select_1(static_proxy):
static_proxy.safe_psql("select 1;")


def test_proxy_cancel(static_proxy):
"""Test that we can cancel a big generate_series query."""
conn = static_proxy.connect()
conn.cancel()

with conn.cursor() as cur:
from psycopg2.errors import QueryCanceled
with pytest.raises(QueryCanceled):
cur.execute("select * from generate_series(1, 100000000);")


def test_proxy_pgbench_cancel(static_proxy, pg_bin):
"""Test that we can cancel the init phase of pgbench."""
start_time = static_proxy.safe_psql("select now();")[0]

def get_running_queries():
# Tag our own pg_stat_activity query with a magic string so it can
# filter itself out of the results below.
magic_string = "fsdsdfhdfhfgbcbfgbfgbf"
with static_proxy.connect() as conn:
with conn.cursor() as cur:
cur.execute(f"""
-- {magic_string}
select query
from pg_stat_activity
where pg_stat_activity.query_start > %s
""", start_time)
return [
row[0]
for row in cur.fetchall()
if magic_string not in row[0]
]

# Let pgbench init run for 1 second
p = subprocess.Popen(['pgbench', '-s500', '-i', static_proxy.connstr()])
time.sleep(1)

# Make sure something is still running, and that get_running_queries works
assert len(get_running_queries()) > 0

# Send sigint, which would cancel any pgbench queries
p.send_signal(signal.SIGINT)

# Assert that nothing is running
time.sleep(1)
assert len(get_running_queries()) == 0
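The deleted test above cancels pgbench by delivering SIGINT to the whole process, the same signal Ctrl-C would send. The pattern in isolation, with a stand-in command (illustrative only):

import signal
import subprocess
import time

p = subprocess.Popen(['sleep', '60'])    # stand-in for the pgbench invocation
time.sleep(1)
p.send_signal(signal.SIGINT)             # pgbench cancels its running queries on SIGINT
p.wait()
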
@@ -2,6 +2,8 @@ import pytest
from fixtures.log_helper import log
from fixtures.zenith_fixtures import ZenithEnv

pytest_plugins = ("fixtures.zenith_fixtures")


#
# Create read-only compute nodes, anchored at historical points in time.
@@ -11,7 +13,7 @@ from fixtures.zenith_fixtures import ZenithEnv
#
def test_readonly_node(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_readonly_node", "empty")
env.zenith_cli(["branch", "test_readonly_node", "empty"])

pgmain = env.postgres.create_start('test_readonly_node')
log.info("postgres is running on 'test_readonly_node' branch")
@@ -86,5 +88,4 @@ def test_readonly_node(zenith_simple_env: ZenithEnv):
# Create node at pre-initdb lsn
with pytest.raises(Exception, match="invalid basebackup lsn"):
# compute node startup with invalid LSN should fail
env.zenith_cli.pg_start("test_readonly_node_preinitdb",
timeline_spec="test_readonly_node@0/42")
env.zenith_cli(["pg", "start", "test_readonly_node_preinitdb", "test_readonly_node@0/42"])

@@ -9,6 +9,8 @@ from fixtures.zenith_fixtures import ZenithEnvBuilder
from fixtures.log_helper import log
import pytest

pytest_plugins = ("fixtures.zenith_fixtures")


#
# Tests that a piece of data is backed up and restored correctly:

@@ -4,6 +4,8 @@ from contextlib import closing
from fixtures.zenith_fixtures import ZenithEnvBuilder
from fixtures.log_helper import log

pytest_plugins = ("fixtures.zenith_fixtures")


#
# Test restarting and recreating a postgres instance
@@ -15,7 +17,7 @@ def test_restart_compute(zenith_env_builder: ZenithEnvBuilder, with_wal_acceptor
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init()

env.zenith_cli.create_branch("test_restart_compute", "main")
env.zenith_cli(["branch", "test_restart_compute", "main"])

pg = env.postgres.create_start('test_restart_compute')
log.info("postgres is running on 'test_restart_compute' branch")

@@ -5,6 +5,8 @@ from fixtures.utils import print_gc_result
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log

pytest_plugins = ("fixtures.zenith_fixtures")


#
# Test Garbage Collection of old layer files
@@ -14,7 +16,7 @@ from fixtures.log_helper import log
#
def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_layerfiles_gc", "empty")
env.zenith_cli(["branch", "test_layerfiles_gc", "empty"])
pg = env.postgres.create_start('test_layerfiles_gc')

with closing(pg.connect()) as conn:
@@ -48,7 +50,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
cur.execute("DELETE FROM foo")

log.info("Running GC before test")
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
row = pscur.fetchone()
print_gc_result(row)
# remember the number of files
@@ -61,7 +63,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
# removing the old image and delta layer.
log.info("Inserting one row and running GC")
cur.execute("INSERT INTO foo VALUES (1)")
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
row = pscur.fetchone()
print_gc_result(row)
assert row['layer_relfiles_total'] == layer_relfiles_remain + 2
@@ -75,7 +77,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
cur.execute("INSERT INTO foo VALUES (2)")
cur.execute("INSERT INTO foo VALUES (3)")

pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
row = pscur.fetchone()
print_gc_result(row)
assert row['layer_relfiles_total'] == layer_relfiles_remain + 2
@@ -87,7 +89,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
cur.execute("INSERT INTO foo VALUES (2)")
cur.execute("INSERT INTO foo VALUES (3)")

pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
row = pscur.fetchone()
print_gc_result(row)
assert row['layer_relfiles_total'] == layer_relfiles_remain + 2
@@ -96,7 +98,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):

# Run GC again, with no changes in the database. Should not remove anything.
log.info("Run GC again, with nothing to do")
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
row = pscur.fetchone()
print_gc_result(row)
assert row['layer_relfiles_total'] == layer_relfiles_remain
@@ -109,7 +111,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
log.info("Drop table and run GC again")
cur.execute("DROP TABLE foo")

pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
row = pscur.fetchone()
print_gc_result(row)

@@ -1,6 +1,8 @@
from fixtures.zenith_fixtures import ZenithEnv, check_restored_datadir_content
from fixtures.log_helper import log

pytest_plugins = ("fixtures.zenith_fixtures")


# Test subtransactions
#
@@ -10,7 +12,8 @@ from fixtures.log_helper import log
# CLOG.
def test_subxacts(zenith_simple_env: ZenithEnv, test_output_dir):
env = zenith_simple_env
env.zenith_cli.create_branch("test_subxacts", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_subxacts", "empty"])
pg = env.postgres.create_start('test_subxacts')

log.info("postgres is running on 'test_subxacts' branch")

@@ -108,8 +108,8 @@ def load(pg: Postgres, stop_event: threading.Event, load_ok_event: threading.Eve
log.info('load thread stopped')


def assert_local(pageserver_http_client: ZenithPageserverHttpClient, tenant: UUID, timeline: str):
timeline_detail = pageserver_http_client.timeline_detail(tenant, UUID(timeline))
def assert_local(pageserver_http_client: ZenithPageserverHttpClient, tenant: str, timeline: str):
timeline_detail = pageserver_http_client.timeline_detail(UUID(tenant), UUID(timeline))
assert timeline_detail.get('type') == "Local", timeline_detail
return timeline_detail

@@ -127,10 +127,10 @@ def test_tenant_relocation(zenith_env_builder: ZenithEnvBuilder,
# create folder for remote storage mock
remote_storage_mock_path = env.repo_dir / 'local_fs_remote_storage'

tenant = env.create_tenant(UUID("74ee8b079a0e437eb0afea7d26a07209"))
tenant = env.create_tenant("74ee8b079a0e437eb0afea7d26a07209")
log.info("tenant to relocate %s", tenant)

env.zenith_cli.create_branch("test_tenant_relocation", "main", tenant_id=tenant)
env.zenith_cli(["branch", "test_tenant_relocation", "main", f"--tenantid={tenant}"])

tenant_pg = env.postgres.create_start(
"test_tenant_relocation",
@@ -167,11 +167,11 @@ def test_tenant_relocation(zenith_env_builder: ZenithEnvBuilder,
# run checkpoint manually to be sure that data landed in remote storage
with closing(env.pageserver.connect()) as psconn:
with psconn.cursor() as pscur:
pscur.execute(f"do_gc {tenant.hex} {timeline}")
pscur.execute(f"do_gc {tenant} {timeline}")

# ensure upload is completed
pageserver_http_client = env.pageserver.http_client()
timeline_detail = pageserver_http_client.timeline_detail(tenant, UUID(timeline))
timeline_detail = pageserver_http_client.timeline_detail(UUID(tenant), UUID(timeline))
assert timeline_detail['disk_consistent_lsn'] == timeline_detail['timeline_state']['Ready']

log.info("initializing new pageserver")
@@ -194,7 +194,7 @@ def test_tenant_relocation(zenith_env_builder: ZenithEnvBuilder,
new_pageserver_http_port):

# call to attach timeline to new pageserver
new_pageserver_http_client.timeline_attach(tenant, UUID(timeline))
new_pageserver_http_client.timeline_attach(UUID(tenant), UUID(timeline))
# FIXME cannot handle duplicate download requests, subject to fix in https://github.com/zenithdb/zenith/issues/997
time.sleep(5)
# new pageserver should be in sync (modulo WAL tail or vacuum activity) with the old one, because there were no new writes since the checkpoint
@@ -241,7 +241,7 @@ def test_tenant_relocation(zenith_env_builder: ZenithEnvBuilder,
# detach tenant from old pageserver before we check
# that all the data is there to be sure that old pageserver
# is no longer involved, and if it is, we will see the errors
pageserver_http_client.timeline_detach(tenant, UUID(timeline))
pageserver_http_client.timeline_detach(UUID(tenant), UUID(timeline))

with pg_cur(tenant_pg) as cur:
# check that data is still there

@@ -15,12 +15,18 @@ def test_tenants_normal_work(zenith_env_builder: ZenithEnvBuilder, with_wal_acce
tenant_1 = env.create_tenant()
tenant_2 = env.create_tenant()

env.zenith_cli.create_branch(f"test_tenants_normal_work_with_wal_acceptors{with_wal_acceptors}",
"main",
tenant_id=tenant_1)
env.zenith_cli.create_branch(f"test_tenants_normal_work_with_wal_acceptors{with_wal_acceptors}",
"main",
tenant_id=tenant_2)
env.zenith_cli([
"branch",
f"test_tenants_normal_work_with_wal_acceptors{with_wal_acceptors}",
"main",
f"--tenantid={tenant_1}"
])
env.zenith_cli([
"branch",
f"test_tenants_normal_work_with_wal_acceptors{with_wal_acceptors}",
"main",
f"--tenantid={tenant_2}"
])

pg_tenant1 = env.postgres.create_start(
f"test_tenants_normal_work_with_wal_acceptors{with_wal_acceptors}",

@@ -10,10 +10,10 @@ import time
def test_timeline_size(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
# Branch at the point where only 100 rows were inserted
env.zenith_cli.create_branch("test_timeline_size", "empty")
env.zenith_cli(["branch", "test_timeline_size", "empty"])

client = env.pageserver.http_client()
res = client.branch_detail(env.initial_tenant, "test_timeline_size")
res = client.branch_detail(UUID(env.initial_tenant), "test_timeline_size")
assert res["current_logical_size"] == res["current_logical_size_non_incremental"]

pgmain = env.postgres.create_start("test_timeline_size")
@@ -31,36 +31,36 @@ def test_timeline_size(zenith_simple_env: ZenithEnv):
FROM generate_series(1, 10) g
""")

res = client.branch_detail(env.initial_tenant, "test_timeline_size")
res = client.branch_detail(UUID(env.initial_tenant), "test_timeline_size")
assert res["current_logical_size"] == res["current_logical_size_non_incremental"]
cur.execute("TRUNCATE foo")

res = client.branch_detail(env.initial_tenant, "test_timeline_size")
res = client.branch_detail(UUID(env.initial_tenant), "test_timeline_size")
assert res["current_logical_size"] == res["current_logical_size_non_incremental"]

# wait until received_lsn_lag is 0
# wait until write_lag is 0
def wait_for_pageserver_catchup(pgmain: Postgres, polling_interval=1, timeout=60):
started_at = time.time()

received_lsn_lag = 1
while received_lsn_lag > 0:
write_lag = 1
while write_lag > 0:
elapsed = time.time() - started_at
if elapsed > timeout:
raise RuntimeError(
f"timed out waiting for pageserver to reach pg_current_wal_flush_lsn()")
raise RuntimeError(f"timed out waiting for pageserver to reach pg_current_wal_lsn()")

with closing(pgmain.connect()) as conn:
with conn.cursor() as cur:

cur.execute('''
select pg_size_pretty(pg_cluster_size()),
pg_wal_lsn_diff(pg_current_wal_flush_lsn(),received_lsn) as received_lsn_lag
FROM backpressure_lsns();
pg_wal_lsn_diff(pg_current_wal_lsn(),write_lsn) as write_lag,
pg_wal_lsn_diff(pg_current_wal_lsn(),sent_lsn) as pending_lag
FROM pg_stat_get_wal_senders();
''')
res = cur.fetchone()
log.info(f"pg_cluster_size = {res[0]}, received_lsn_lag = {res[1]}")
received_lsn_lag = res[1]
log.info(
f"pg_cluster_size = {res[0]}, write_lag = {res[1]}, pending_lag = {res[2]}")
write_lag = res[1]

time.sleep(polling_interval)

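The rewrite above keeps the same poll-until-caught-up shape while switching the lag source from backpressure_lsns() to pg_stat_get_wal_senders(). The loop itself, abstracted over where the lag number comes from (a sketch; names are illustrative):

import time

def wait_until_zero(read_lag, polling_interval=1, timeout=60):
    # Poll read_lag() until it reports no lag, or give up after `timeout` seconds.
    started_at = time.time()
    while read_lag() > 0:
        if time.time() - started_at > timeout:
            raise RuntimeError("timed out waiting for lag to reach 0")
        time.sleep(polling_interval)
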
@@ -68,10 +68,10 @@ def wait_for_pageserver_catchup(pgmain: Postgres, polling_interval=1, timeout=60
def test_timeline_size_quota(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 1
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_timeline_size_quota", "main")
env.zenith_cli(["branch", "test_timeline_size_quota", "main"])

client = env.pageserver.http_client()
res = client.branch_detail(env.initial_tenant, "test_timeline_size_quota")
res = client.branch_detail(UUID(env.initial_tenant), "test_timeline_size_quota")
assert res["current_logical_size"] == res["current_logical_size_non_incremental"]

pgmain = env.postgres.create_start(

@@ -3,13 +3,15 @@ import os
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log

pytest_plugins = ("fixtures.zenith_fixtures")


#
# Test branching, when a transaction is in prepared state
#
def test_twophase(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_twophase", "empty")
env.zenith_cli(["branch", "test_twophase", "empty"])

pg = env.postgres.create_start('test_twophase', config_lines=['max_prepared_transactions=5'])
log.info("postgres is running on 'test_twophase' branch")
@@ -56,7 +58,7 @@ def test_twophase(zenith_simple_env: ZenithEnv):
assert len(twophase_files) == 2

# Create a branch with the transaction in prepared state
env.zenith_cli.create_branch("test_twophase_prepared", "test_twophase")
env.zenith_cli(["branch", "test_twophase_prepared", "test_twophase"])

# Start compute on the new branch
pg2 = env.postgres.create_start(

@@ -1,6 +1,8 @@
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log

pytest_plugins = ("fixtures.zenith_fixtures")


#
# Test that the VM bit is cleared correctly at a HEAP_DELETE and
@@ -9,7 +11,8 @@ from fixtures.log_helper import log
def test_vm_bit_clear(zenith_simple_env: ZenithEnv):
env = zenith_simple_env

env.zenith_cli.create_branch("test_vm_bit_clear", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_vm_bit_clear", "empty"])
pg = env.postgres.create_start('test_vm_bit_clear')

log.info("postgres is running on 'test_vm_bit_clear' branch")
@@ -33,7 +36,7 @@ def test_vm_bit_clear(zenith_simple_env: ZenithEnv):
cur.execute('UPDATE vmtest_update SET id = 5000 WHERE id = 1')

# Branch at this point, to test that later
env.zenith_cli.create_branch("test_vm_bit_clear_new", "test_vm_bit_clear")
env.zenith_cli(["branch", "test_vm_bit_clear_new", "test_vm_bit_clear"])

# Clear the buffer cache, to force the VM page to be re-fetched from
# the page server

@@ -17,6 +17,8 @@ from fixtures.utils import lsn_to_hex, mkdir_if_needed
from fixtures.log_helper import log
from typing import List, Optional, Any

pytest_plugins = ("fixtures.zenith_fixtures")


# basic test, write something in setup with wal acceptors, ensure that commits
# succeed and data is written
@@ -24,7 +26,7 @@ def test_normal_work(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init()

env.zenith_cli.create_branch("test_wal_acceptors_normal_work", "main")
env.zenith_cli(["branch", "test_wal_acceptors_normal_work", "main"])

pg = env.postgres.create_start('test_wal_acceptors_normal_work')

@@ -60,10 +62,10 @@ def test_many_timelines(zenith_env_builder: ZenithEnvBuilder):
# start postgres on each timeline
pgs = []
for branch in branches:
env.zenith_cli.create_branch(branch, "main")
env.zenith_cli(["branch", branch, "main"])
pgs.append(env.postgres.create_start(branch))

tenant_id = env.initial_tenant
tenant_id = uuid.UUID(env.initial_tenant)

def collect_metrics(message: str) -> List[BranchMetrics]:
with env.pageserver.http_client() as pageserver_http:
@@ -90,8 +92,8 @@ def test_many_timelines(zenith_env_builder: ZenithEnvBuilder):
latest_valid_lsn=branch_detail["latest_valid_lsn"],
)
for sk_m in sk_metrics:
m.flush_lsns.append(sk_m.flush_lsn_inexact[(tenant_id.hex, timeline_id)])
m.commit_lsns.append(sk_m.commit_lsn_inexact[(tenant_id.hex, timeline_id)])
m.flush_lsns.append(sk_m.flush_lsn_inexact[timeline_id])
m.commit_lsns.append(sk_m.commit_lsn_inexact[timeline_id])

for flush_lsn, commit_lsn in zip(m.flush_lsns, m.commit_lsns):
# Invariant. May be < when transaction is in progress.
@@ -183,7 +185,7 @@ def test_restarts(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = n_acceptors
env = zenith_env_builder.init()

env.zenith_cli.create_branch("test_wal_acceptors_restarts", "main")
env.zenith_cli(["branch", "test_wal_acceptors_restarts", "main"])
pg = env.postgres.create_start('test_wal_acceptors_restarts')

# we rely upon autocommit after each statement
@@ -220,7 +222,7 @@ def test_unavailability(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 2
env = zenith_env_builder.init()

env.zenith_cli.create_branch("test_wal_acceptors_unavailability", "main")
env.zenith_cli(["branch", "test_wal_acceptors_unavailability", "main"])
pg = env.postgres.create_start('test_wal_acceptors_unavailability')

# we rely upon autocommit after each statement
@@ -291,7 +293,7 @@ def test_race_conditions(zenith_env_builder: ZenithEnvBuilder, stop_value):
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init()

env.zenith_cli.create_branch("test_wal_acceptors_race_conditions", "main")
env.zenith_cli(["branch", "test_wal_acceptors_race_conditions", "main"])
pg = env.postgres.create_start('test_wal_acceptors_race_conditions')

# we rely upon autocommit after each statement

@@ -319,16 +321,16 @@ class ProposerPostgres(PgProtocol):
def __init__(self,
pgdata_dir: str,
pg_bin,
timeline_id: uuid.UUID,
tenant_id: uuid.UUID,
timeline_id: str,
tenant_id: str,
listen_addr: str,
port: int):
super().__init__(host=listen_addr, port=port, username='zenith_admin')

self.pgdata_dir: str = pgdata_dir
self.pg_bin: PgBin = pg_bin
self.timeline_id: uuid.UUID = timeline_id
self.tenant_id: uuid.UUID = tenant_id
self.timeline_id: str = timeline_id
self.tenant_id: str = tenant_id
self.listen_addr: str = listen_addr
self.port: int = port

@@ -348,8 +350,8 @@ class ProposerPostgres(PgProtocol):
cfg = [
"synchronous_standby_names = 'walproposer'\n",
"shared_preload_libraries = 'zenith'\n",
f"zenith.zenith_timeline = '{self.timeline_id.hex}'\n",
f"zenith.zenith_tenant = '{self.tenant_id.hex}'\n",
f"zenith.zenith_timeline = '{self.timeline_id}'\n",
f"zenith.zenith_tenant = '{self.tenant_id}'\n",
f"zenith.page_server_connstring = ''\n",
f"wal_acceptors = '{wal_acceptors}'\n",
f"listen_addresses = '{self.listen_addr}'\n",

@@ -406,8 +408,8 @@ def test_sync_safekeepers(zenith_env_builder: ZenithEnvBuilder,
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init()

timeline_id = uuid.uuid4()
tenant_id = uuid.uuid4()
timeline_id = uuid.uuid4().hex
tenant_id = uuid.uuid4().hex

# write config for proposer
pgdata_dir = os.path.join(env.repo_dir, "proposer_pgdata")
@@ -456,7 +458,7 @@ def test_timeline_status(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 1
env = zenith_env_builder.init()

env.zenith_cli.create_branch("test_timeline_status", "main")
env.zenith_cli(["branch", "test_timeline_status", "main"])
pg = env.postgres.create_start('test_timeline_status')

wa = env.safekeepers[0]
@@ -493,15 +495,15 @@ class SafekeeperEnv:
self.bin_safekeeper = os.path.join(str(zenith_binpath), 'safekeeper')
self.safekeepers: Optional[List[subprocess.CompletedProcess[Any]]] = None
self.postgres: Optional[ProposerPostgres] = None
self.tenant_id: Optional[uuid.UUID] = None
self.timeline_id: Optional[uuid.UUID] = None
self.tenant_id: Optional[str] = None
self.timeline_id: Optional[str] = None

def init(self) -> "SafekeeperEnv":
assert self.postgres is None, "postgres is already initialized"
assert self.safekeepers is None, "safekeepers are already initialized"

self.timeline_id = uuid.uuid4()
self.tenant_id = uuid.uuid4()
self.timeline_id = uuid.uuid4().hex
self.tenant_id = uuid.uuid4().hex
mkdir_if_needed(str(self.repo_dir))

# Create config and a Safekeeper object for each safekeeper
@@ -634,7 +636,7 @@ def test_replace_safekeeper(zenith_env_builder: ZenithEnvBuilder):

zenith_env_builder.num_safekeepers = 4
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_replace_safekeeper", "main")
env.zenith_cli(["branch", "test_replace_safekeeper", "main"])

log.info("Use only first 3 safekeepers")
env.safekeepers[3].stop()

@@ -9,6 +9,7 @@ from fixtures.utils import lsn_from_hex, lsn_to_hex
from typing import List

log = getLogger('root.wal_acceptor_async')
pytest_plugins = ("fixtures.zenith_fixtures")


class BankClient(object):
@@ -202,7 +203,7 @@ def test_restarts_under_load(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init()

env.zenith_cli.create_branch("test_wal_acceptors_restarts_under_load", "main")
env.zenith_cli(["branch", "test_wal_acceptors_restarts_under_load", "main"])
pg = env.postgres.create_start('test_wal_acceptors_restarts_under_load')

asyncio.run(run_restarts_under_load(pg, env.safekeepers))

@@ -3,26 +3,30 @@ import uuid
import requests

from psycopg2.extensions import cursor as PgCursor
from fixtures.zenith_fixtures import ZenithEnv, ZenithEnvBuilder, ZenithPageserverHttpClient
from fixtures.zenith_fixtures import ZenithEnv, ZenithEnvBuilder
from typing import cast

pytest_plugins = ("fixtures.zenith_fixtures")

def helper_compare_branch_list(pageserver_http_client: ZenithPageserverHttpClient,
env: ZenithEnv,
initial_tenant: uuid.UUID):

def helper_compare_branch_list(page_server_cur: PgCursor, env: ZenithEnv, initial_tenant: str):
"""
Compare branches list returned by CLI and directly via API.
Filters out branches created by other tests.
"""
branches = pageserver_http_client.branch_list(initial_tenant)
branches_api = sorted(map(lambda b: cast(str, b['name']), branches))

page_server_cur.execute(f'branch_list {initial_tenant}')
branches_api = sorted(
map(lambda b: cast(str, b['name']), json.loads(page_server_cur.fetchone()[0])))
branches_api = [b for b in branches_api if b.startswith('test_cli_') or b in ('empty', 'main')]

res = env.zenith_cli.list_branches()
res = env.zenith_cli(["branch"])
res.check_returncode()
branches_cli = sorted(map(lambda b: b.split(':')[-1].strip(), res.stdout.strip().split("\n")))
branches_cli = [b for b in branches_cli if b.startswith('test_cli_') or b in ('empty', 'main')]

res = env.zenith_cli.list_branches(tenant_id=initial_tenant)
res = env.zenith_cli(["branch", f"--tenantid={initial_tenant}"])
res.check_returncode()
branches_cli_with_tenant_arg = sorted(
map(lambda b: b.split(':')[-1].strip(), res.stdout.strip().split("\n")))
branches_cli_with_tenant_arg = [
@@ -34,20 +38,24 @@ def helper_compare_branch_list(pageserver_http_client: ZenithPageserverHttpClien

def test_cli_branch_list(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
pageserver_http_client = env.pageserver.http_client()
page_server_conn = env.pageserver.connect()
page_server_cur = page_server_conn.cursor()

# Initial sanity check
helper_compare_branch_list(pageserver_http_client, env, env.initial_tenant)
env.zenith_cli.create_branch("test_cli_branch_list_main", "empty")
helper_compare_branch_list(pageserver_http_client, env, env.initial_tenant)
helper_compare_branch_list(page_server_cur, env, env.initial_tenant)

# Create a branch for us
res = env.zenith_cli(["branch", "test_cli_branch_list_main", "empty"])
assert res.stderr == ''
helper_compare_branch_list(page_server_cur, env, env.initial_tenant)

# Create a nested branch
res = env.zenith_cli.create_branch("test_cli_branch_list_nested", "test_cli_branch_list_main")
res = env.zenith_cli(["branch", "test_cli_branch_list_nested", "test_cli_branch_list_main"])
assert res.stderr == ''
helper_compare_branch_list(pageserver_http_client, env, env.initial_tenant)
helper_compare_branch_list(page_server_cur, env, env.initial_tenant)

# Check that all new branches are visible via CLI
res = env.zenith_cli.list_branches()
res = env.zenith_cli(["branch"])
assert res.stderr == ''
branches_cli = sorted(map(lambda b: b.split(':')[-1].strip(), res.stdout.strip().split("\n")))

@@ -55,11 +63,12 @@ def test_cli_branch_list(zenith_simple_env: ZenithEnv):
assert 'test_cli_branch_list_nested' in branches_cli


def helper_compare_tenant_list(pageserver_http_client: ZenithPageserverHttpClient, env: ZenithEnv):
tenants = pageserver_http_client.tenant_list()
tenants_api = sorted(map(lambda t: cast(str, t['id']), tenants))
def helper_compare_tenant_list(page_server_cur: PgCursor, env: ZenithEnv):
page_server_cur.execute(f'tenant_list')
tenants_api = sorted(
map(lambda t: cast(str, t['id']), json.loads(page_server_cur.fetchone()[0])))

res = env.zenith_cli.list_tenants()
res = env.zenith_cli(["tenant", "list"])
assert res.stderr == ''
tenants_cli = sorted(map(lambda t: t.split()[0], res.stdout.splitlines()))

@@ -68,30 +77,35 @@ def helper_compare_tenant_list(pageserver_http_client: ZenithPageserverHttpClien

def test_cli_tenant_list(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
pageserver_http_client = env.pageserver.http_client()
page_server_conn = env.pageserver.connect()
page_server_cur = page_server_conn.cursor()

# Initial sanity check
helper_compare_tenant_list(pageserver_http_client, env)
helper_compare_tenant_list(page_server_cur, env)

# Create new tenant
tenant1 = uuid.uuid4()
env.zenith_cli.create_tenant(tenant1)
tenant1 = uuid.uuid4().hex
res = env.zenith_cli(["tenant", "create", tenant1])
res.check_returncode()

# check tenant1 appeared
helper_compare_tenant_list(pageserver_http_client, env)
helper_compare_tenant_list(page_server_cur, env)

# Create new tenant
tenant2 = uuid.uuid4()
env.zenith_cli.create_tenant(tenant2)
tenant2 = uuid.uuid4().hex
res = env.zenith_cli(["tenant", "create", tenant2])
res.check_returncode()

# check tenant2 appeared
helper_compare_tenant_list(pageserver_http_client, env)
helper_compare_tenant_list(page_server_cur, env)

res = env.zenith_cli.list_tenants()
res = env.zenith_cli(["tenant", "list"])
res.check_returncode()
tenants = sorted(map(lambda t: t.split()[0], res.stdout.splitlines()))

assert env.initial_tenant.hex in tenants
assert tenant1.hex in tenants
assert tenant2.hex in tenants
assert env.initial_tenant in tenants
assert tenant1 in tenants
assert tenant2 in tenants


def test_cli_ipv4_listeners(zenith_env_builder: ZenithEnvBuilder):
@@ -109,21 +123,3 @@ def test_cli_ipv4_listeners(zenith_env_builder: ZenithEnvBuilder):
# Connect to ps port on v4 loopback
# res = requests.get(f'http://127.0.0.1:{env.pageserver.service_port.http}/v1/status')
# assert res.ok


def test_cli_start_stop(zenith_env_builder: ZenithEnvBuilder):
# Start with single sk
zenith_env_builder.num_safekeepers = 1
env = zenith_env_builder.init()

# Stop default ps/sk
env.zenith_cli.pageserver_stop()
env.zenith_cli.safekeeper_stop()

# Default start
res = env.zenith_cli.raw_cli(["start"])
res.check_returncode()

# Default stop
res = env.zenith_cli.raw_cli(["stop"])
res.check_returncode()

@@ -3,11 +3,15 @@ import os
from fixtures.utils import mkdir_if_needed
from fixtures.zenith_fixtures import ZenithEnv, base_dir, pg_distrib_dir

pytest_plugins = ("fixtures.zenith_fixtures")


def test_isolation(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, capsys):
env = zenith_simple_env

env.zenith_cli.create_branch("test_isolation", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_isolation", "empty"])

# Connect to postgres and create a database called "regression".
# isolation tests use prepared transactions, so enable them
pg = env.postgres.create_start('test_isolation', config_lines=['max_prepared_transactions=100'])

@@ -3,11 +3,15 @@ import os
from fixtures.utils import mkdir_if_needed
from fixtures.zenith_fixtures import ZenithEnv, check_restored_datadir_content, base_dir, pg_distrib_dir

pytest_plugins = ("fixtures.zenith_fixtures")


def test_pg_regress(zenith_simple_env: ZenithEnv, test_output_dir: str, pg_bin, capsys):
env = zenith_simple_env

env.zenith_cli.create_branch("test_pg_regress", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_pg_regress", "empty"])

# Connect to postgres and create a database called "regression".
pg = env.postgres.create_start('test_pg_regress')
pg.safe_psql('CREATE DATABASE regression')

@@ -7,11 +7,15 @@ from fixtures.zenith_fixtures import (ZenithEnv,
pg_distrib_dir)
from fixtures.log_helper import log

pytest_plugins = ("fixtures.zenith_fixtures")


def test_zenith_regress(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, capsys):
env = zenith_simple_env

env.zenith_cli.create_branch("test_zenith_regress", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_zenith_regress", "empty"])

# Connect to postgres and create a database called "regression".
pg = env.postgres.create_start('test_zenith_regress')
pg.safe_psql('CREATE DATABASE regression')

@@ -1,6 +1 @@
pytest_plugins = (
"fixtures.zenith_fixtures",
"fixtures.benchmark_fixture",
"fixtures.compare_fixtures",
"fixtures.slow",
)
pytest_plugins = ("fixtures.zenith_fixtures", "fixtures.benchmark_fixture")

@@ -8,7 +8,6 @@ import timeit
import calendar
import enum
from datetime import datetime
import uuid
import pytest
from _pytest.config import Config
from _pytest.terminal import TerminalReporter
@@ -27,6 +26,8 @@ benchmark, and then record the result by calling zenbenchmark.record. For example
import timeit
from fixtures.zenith_fixtures import ZenithEnv

pytest_plugins = ("fixtures.zenith_fixtures", "fixtures.benchmark_fixture")

def test_mybench(zenith_simple_env: env, zenbenchmark):

# Initialize the test
@@ -39,8 +40,6 @@ def test_mybench(zenith_simple_env: env, zenbenchmark):
# Record another measurement
zenbenchmark.record('speed_of_light', 300000, 'km/s')

There's no need to import this file to use it. It should be declared as a plugin
inside conftest.py, and that makes it available to all tests.

You can measure multiple things in one test, and record each one with a separate
call to zenbenchmark. For example, you could time the bulk loading that happens
@@ -277,11 +276,11 @@ class ZenithBenchmarker:
assert matches
return int(round(float(matches.group(1))))

def get_timeline_size(self, repo_dir: Path, tenantid: uuid.UUID, timelineid: str):
def get_timeline_size(self, repo_dir: Path, tenantid: str, timelineid: str):
"""
Calculate the on-disk size of a timeline
"""
path = "{}/tenants/{}/timelines/{}".format(repo_dir, tenantid.hex, timelineid)
path = "{}/tenants/{}/timelines/{}".format(repo_dir, tenantid, timelineid)

totalbytes = 0
for root, dirs, files in os.walk(path):

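The hunk cuts off right after the os.walk() line in get_timeline_size(). The body below the walk is not shown in this diff; for reference, the standard way to finish such a size accumulation is:

import os

totalbytes = 0
for root, dirs, files in os.walk(path):
    for name in files:
        totalbytes += os.path.getsize(os.path.join(root, name))
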
@@ -65,7 +65,7 @@ class ZenithCompare(PgCompare):

# We only use one branch and one timeline
self.branch = branch_name
self.env.zenith_cli.create_branch(self.branch, "empty")
self.env.zenith_cli(["branch", self.branch, "empty"])
self._pg = self.env.postgres.create_start(self.branch)
self.timeline = self.pg.safe_psql("SHOW zenith.zenith_timeline")[0][0]

@@ -86,7 +86,7 @@ class ZenithCompare(PgCompare):
return self._pg_bin

def flush(self):
self.pscur.execute(f"do_gc {self.env.initial_tenant.hex} {self.timeline} 0")
self.pscur.execute(f"do_gc {self.env.initial_tenant} {self.timeline} 0")

def report_peak_memory_use(self) -> None:
self.zenbenchmark.record("peak_mem",

@@ -1,26 +0,0 @@
import pytest
"""
This plugin allows tests to be marked as slow using pytest.mark.slow. By default slow
tests are excluded. They need to be specifically requested with the --runslow flag in
order to run.

Copied from here: https://docs.pytest.org/en/latest/example/simple.html
"""


def pytest_addoption(parser):
parser.addoption("--runslow", action="store_true", default=False, help="run slow tests")


def pytest_configure(config):
config.addinivalue_line("markers", "slow: mark test as slow to run")


def pytest_collection_modifyitems(config, items):
if config.getoption("--runslow"):
# --runslow given in cli: do not skip slow tests
return
skip_slow = pytest.mark.skip(reason="need --runslow option to run")
for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)
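For reference, usage of the plugin being deleted above: a test opts in with the marker and is then skipped unless --runslow is passed (the test body is illustrative):

import pytest

@pytest.mark.slow
def test_bulk_load():
    ...  # collected but skipped unless invoked as: pytest --runslow
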
@@ -1,7 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass, field
import textwrap
from cached_property import cached_property
import asyncpg
import os
@@ -27,7 +26,7 @@ from dataclasses import dataclass

# Type-related stuff
from psycopg2.extensions import connection as PgConnection
from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, cast, Union, Tuple
from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, cast, Union
from typing_extensions import Literal
import pytest

@@ -45,8 +44,9 @@ the standard pytest.fixture with some extra behavior.
There are several environment variables that can control the running of tests:
ZENITH_BIN, POSTGRES_DISTRIB_DIR, etc. See README.md for more information.

There's no need to import this file to use it. It should be declared as a plugin
inside conftest.py, and that makes it available to all tests.
To use fixtures in a test file, add this line of code:

>>> pytest_plugins = ("fixtures.zenith_fixtures")

Don't import functions from this file, or pytest will emit warnings. Instead
put directly-importable functions into utils.py or another separate file.

@@ -237,13 +237,10 @@ def port_distributor(worker_base_port):

class PgProtocol:
""" Reusable connection logic """
def __init__(self, host: str, port: int,
username: Optional[str] = None,
password: Optional[str] = None):
def __init__(self, host: str, port: int, username: Optional[str] = None):
self.host = host
self.port = port
self.username = username
self.password = password

def connstr(self,
*,
@@ -255,7 +252,6 @@ class PgProtocol:
"""

username = username or self.username
password = password or self.password
res = f'host={self.host} port={self.port} dbname={dbname}'

if username:
@@ -520,7 +516,6 @@ class ZenithEnv:
self.rust_log_override = config.rust_log_override
self.port_distributor = config.port_distributor
self.s3_mock_server = config.s3_mock_server
self.zenith_cli = ZenithCli(env=self)

self.postgres = PostgresFactory(self)

@@ -528,12 +523,12 @@ class ZenithEnv:

# generate initial tenant ID here instead of letting 'zenith init' generate it,
# so that we don't need to dig it out of the config file afterwards.
self.initial_tenant = uuid.uuid4()
self.initial_tenant = uuid.uuid4().hex

# Create a config file corresponding to the options
toml = textwrap.dedent(f"""
default_tenantid = '{self.initial_tenant.hex}'
""")
toml = f"""
default_tenantid = '{self.initial_tenant}'
"""

# Create config for pageserver
pageserver_port = PageserverPort(
@@ -542,12 +537,12 @@ class ZenithEnv:
)
pageserver_auth_type = "ZenithJWT" if config.pageserver_auth_enabled else "Trust"

toml += textwrap.dedent(f"""
[pageserver]
listen_pg_addr = 'localhost:{pageserver_port.pg}'
listen_http_addr = 'localhost:{pageserver_port.http}'
auth_type = '{pageserver_auth_type}'
""")
toml += f"""
[pageserver]
listen_pg_addr = 'localhost:{pageserver_port.pg}'
listen_http_addr = 'localhost:{pageserver_port.http}'
auth_type = '{pageserver_auth_type}'
"""

# Create a corresponding ZenithPageserver object
self.pageserver = ZenithPageserver(self,

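The config hunks above drop textwrap.dedent() in favor of flush-left f-strings. dedent() only matters when the literal is indented to line up with surrounding code; with the content written at column zero, the two forms produce the same TOML:

import textwrap

raw = """
    [pageserver]
    auth_type = 'Trust'
"""
assert textwrap.dedent(raw) == "\n[pageserver]\nauth_type = 'Trust'\n"
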
@@ -577,7 +572,15 @@ sync = false # Disable fsyncs to make the tests go faster

log.info(f"Config: {toml}")

self.zenith_cli.init(toml)
# Run 'zenith init' using the config file we constructed
with tempfile.NamedTemporaryFile(mode='w+') as tmp:
tmp.write(toml)
tmp.flush()

cmd = ['init', f'--config={tmp.name}']
append_pageserver_param_overrides(cmd, config.pageserver_remote_storage)

self.zenith_cli(cmd)

# Start up the page server and all the safekeepers
self.pageserver.start()
@@ -589,12 +592,69 @@ sync = false # Disable fsyncs to make the tests go faster
""" Get list of safekeeper endpoints suitable for wal_acceptors GUC """
return ','.join([f'localhost:{wa.port.pg}' for wa in self.safekeepers])

def create_tenant(self, tenant_id: Optional[uuid.UUID] = None) -> uuid.UUID:
def create_tenant(self, tenant_id: Optional[str] = None):
if tenant_id is None:
tenant_id = uuid.uuid4()
self.zenith_cli.create_tenant(tenant_id)
tenant_id = uuid.uuid4().hex
res = self.zenith_cli(['tenant', 'create', tenant_id])
res.check_returncode()
return tenant_id

def zenith_cli(self, arguments: List[str]) -> 'subprocess.CompletedProcess[str]':
"""
Run "zenith" with the specified arguments.

Arguments must be in list form, e.g. ['pg', 'create']

Return both stdout and stderr, which can be accessed as

>>> result = env.zenith_cli(...)
>>> assert result.stderr == ""
>>> log.info(result.stdout)
"""

assert type(arguments) == list

bin_zenith = os.path.join(str(zenith_binpath), 'zenith')

args = [bin_zenith] + arguments
log.info('Running command "{}"'.format(' '.join(args)))
log.info(f'Running in "{self.repo_dir}"')

env_vars = os.environ.copy()
env_vars['ZENITH_REPO_DIR'] = str(self.repo_dir)
env_vars['POSTGRES_DISTRIB_DIR'] = str(pg_distrib_dir)

if self.rust_log_override is not None:
env_vars['RUST_LOG'] = self.rust_log_override

# Pass coverage settings
var = 'LLVM_PROFILE_FILE'
val = os.environ.get(var)
if val:
env_vars[var] = val

# Intercept CalledProcessError and print more info
try:
res = subprocess.run(args,
env=env_vars,
check=True,
universal_newlines=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
log.info(f"Run success: {res.stdout}")
except subprocess.CalledProcessError as exc:
# this way command output will be recorded and shown in the CI failure message
msg = f"""\
Run failed: {exc}
stdout: {exc.stdout}
stderr: {exc.stderr}
"""
log.info(msg)

raise Exception(msg) from exc

return res

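A typical call through the inlined zenith_cli() helper above then looks like this, matching the pattern used throughout the tests (the branch name is illustrative):

res = env.zenith_cli(["branch", "my_branch", "main"])
res.check_returncode()      # raises if the CLI exited non-zero
assert res.stderr == ''
log.info(res.stdout)
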
@cached_property
def auth_keys(self) -> AuthKeys:
pub = (Path(self.repo_dir) / 'auth_public_key.pem').read_bytes()
@@ -622,7 +682,7 @@ def _shared_simple_env(request: Any, port_distributor) -> Iterator[ZenithEnv]:
env = builder.init()

# For convenience in tests, create a branch from the freshly-initialized cluster.
env.zenith_cli.create_branch("empty", "main")
env.zenith_cli(["branch", "empty", "main"])

# Return the builder to the caller
yield env
@@ -668,10 +728,6 @@ def zenith_env_builder(test_output_dir, port_distributor) -> Iterator[ZenithEnvB
yield builder


class ZenithPageserverApiException(Exception):
pass


class ZenithPageserverHttpClient(requests.Session):
def __init__(self, port: int, auth_token: Optional[str] = None) -> None:
super().__init__()
@@ -681,32 +737,22 @@ class ZenithPageserverHttpClient(requests.Session):
if auth_token is not None:
self.headers['Authorization'] = f'Bearer {auth_token}'

def verbose_error(self, res: requests.Response):
try:
res.raise_for_status()
except requests.RequestException as e:
try:
msg = res.json()['msg']
except:
msg = ''
raise ZenithPageserverApiException(msg) from e

def check_status(self):
self.get(f"http://localhost:{self.port}/v1/status").raise_for_status()

def timeline_attach(self, tenant_id: uuid.UUID, timeline_id: uuid.UUID):
res = self.post(
f"http://localhost:{self.port}/v1/timeline/{tenant_id.hex}/{timeline_id.hex}/attach", )
self.verbose_error(res)
res.raise_for_status()

def timeline_detach(self, tenant_id: uuid.UUID, timeline_id: uuid.UUID):
res = self.post(
f"http://localhost:{self.port}/v1/timeline/{tenant_id.hex}/{timeline_id.hex}/detach", )
self.verbose_error(res)
res.raise_for_status()

def branch_list(self, tenant_id: uuid.UUID) -> List[Dict[Any, Any]]:
res = self.get(f"http://localhost:{self.port}/v1/branch/{tenant_id.hex}")
self.verbose_error(res)
res.raise_for_status()
res_json = res.json()
assert isinstance(res_json, list)
return res_json
@@ -718,7 +764,7 @@ class ZenithPageserverHttpClient(requests.Session):
'name': name,
'start_point': start_point,
})
self.verbose_error(res)
res.raise_for_status()
res_json = res.json()
assert isinstance(res_json, dict)
return res_json
@@ -727,14 +773,14 @@ class ZenithPageserverHttpClient(requests.Session):
res = self.get(
f"http://localhost:{self.port}/v1/branch/{tenant_id.hex}/{name}?include-non-incremental-logical-size=1",
)
self.verbose_error(res)
res.raise_for_status()
res_json = res.json()
assert isinstance(res_json, dict)
return res_json

def tenant_list(self) -> List[Dict[Any, Any]]:
res = self.get(f"http://localhost:{self.port}/v1/tenant")
self.verbose_error(res)
res.raise_for_status()
res_json = res.json()
assert isinstance(res_json, list)
return res_json
@@ -746,27 +792,27 @@ class ZenithPageserverHttpClient(requests.Session):
'tenant_id': tenant_id.hex,
},
)
self.verbose_error(res)
res.raise_for_status()
return res.json()

def timeline_list(self, tenant_id: uuid.UUID) -> List[str]:
res = self.get(f"http://localhost:{self.port}/v1/timeline/{tenant_id.hex}")
self.verbose_error(res)
res.raise_for_status()
res_json = res.json()
assert isinstance(res_json, list)
return res_json

def timeline_detail(self, tenant_id: uuid.UUID, timeline_id: uuid.UUID) -> Dict[Any, Any]:
def timeline_detail(self, tenant_id: uuid.UUID, timeline_id: uuid.UUID):
res = self.get(
f"http://localhost:{self.port}/v1/timeline/{tenant_id.hex}/{timeline_id.hex}")
self.verbose_error(res)
res.raise_for_status()
res_json = res.json()
assert isinstance(res_json, dict)
return res_json

def get_metrics(self) -> str:
res = self.get(f"http://localhost:{self.port}/metrics")
self.verbose_error(res)
res.raise_for_status()
return res.text

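One effect of routing every response through verbose_error() above is that callers get the pageserver's own 'msg' field in the exception instead of a bare HTTP status. A minimal sketch of what that buys a caller (assuming the pageserver rejects a tenant it has never seen):

import uuid

client = env.pageserver.http_client()
try:
    client.branch_list(uuid.uuid4())   # tenant that was never created
except ZenithPageserverApiException as err:
    log.info(f"pageserver returned an error: {err}")
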
@@ -793,189 +839,6 @@ class S3Storage:
|
||||
RemoteStorage = Union[LocalFsStorage, S3Storage]
|
||||
|
||||
|
||||
class ZenithCli:
|
||||
"""
|
||||
A typed wrapper around the `zenith` CLI tool.
|
||||
Supports main commands via typed methods and a way to run arbitrary command directly via CLI.
|
||||
"""
|
||||
def __init__(self, env: ZenithEnv) -> None:
|
||||
self.env = env
|
||||
pass
|
||||
|
||||
def create_tenant(self, tenant_id: Optional[uuid.UUID] = None) -> uuid.UUID:
|
||||
if tenant_id is None:
|
||||
tenant_id = uuid.uuid4()
|
||||
self.raw_cli(['tenant', 'create', tenant_id.hex])
|
||||
return tenant_id
|
||||
|
||||
def list_tenants(self) -> 'subprocess.CompletedProcess[str]':
|
||||
return self.raw_cli(['tenant', 'list'])
|
||||
|
||||
def create_branch(self,
|
||||
branch_name: str,
|
||||
starting_point: str,
|
||||
tenant_id: Optional[uuid.UUID] = None) -> 'subprocess.CompletedProcess[str]':
|
||||
args = ['branch']
|
||||
if tenant_id is not None:
|
||||
args.extend(['--tenantid', tenant_id.hex])
|
||||
args.extend([branch_name, starting_point])
|
||||
|
||||
return self.raw_cli(args)
|
||||
|
||||
def list_branches(self,
|
||||
tenant_id: Optional[uuid.UUID] = None) -> 'subprocess.CompletedProcess[str]':
|
||||
args = ['branch']
|
||||
if tenant_id is not None:
|
||||
args.extend(['--tenantid', tenant_id.hex])
|
||||
return self.raw_cli(args)
|
||||
|
||||
def init(self, config_toml: str) -> 'subprocess.CompletedProcess[str]':
|
||||
with tempfile.NamedTemporaryFile(mode='w+') as tmp:
|
||||
tmp.write(config_toml)
|
||||
tmp.flush()
|
||||
|
||||
cmd = ['init', f'--config={tmp.name}']
|
||||
append_pageserver_param_overrides(cmd, self.env.pageserver.remote_storage)
|
||||
|
||||
return self.raw_cli(cmd)
|
||||
|
||||
def pageserver_start(self) -> 'subprocess.CompletedProcess[str]':
|
||||
start_args = ['pageserver', 'start']
|
||||
append_pageserver_param_overrides(start_args, self.env.pageserver.remote_storage)
|
||||
return self.raw_cli(start_args)
|
||||
|
||||
def pageserver_stop(self, immediate=False) -> 'subprocess.CompletedProcess[str]':
|
||||
cmd = ['pageserver', 'stop']
|
||||
if immediate:
|
||||
cmd.extend(['-m', 'immediate'])
|
||||
|
||||
log.info(f"Stopping pageserver with {cmd}")
|
||||
return self.raw_cli(cmd)
|
||||
|
||||
def safekeeper_start(self, name: str) -> 'subprocess.CompletedProcess[str]':
|
||||
return self.raw_cli(['safekeeper', 'start', name])
|
||||
|
||||
def safekeeper_stop(self,
|
||||
name: Optional[str] = None,
|
||||
immediate=False) -> 'subprocess.CompletedProcess[str]':
|
||||
args = ['safekeeper', 'stop']
|
||||
if immediate:
|
||||
args.extend(['-m', 'immediate'])
|
||||
if name is not None:
|
||||
args.append(name)
|
||||
return self.raw_cli(args)
|
||||
|
||||
def pg_create(
|
||||
self,
|
||||
node_name: str,
|
||||
tenant_id: Optional[uuid.UUID] = None,
|
||||
timeline_spec: Optional[str] = None,
|
||||
port: Optional[int] = None,
|
||||
) -> 'subprocess.CompletedProcess[str]':
|
||||
args = ['pg', 'create']
|
||||
if tenant_id is not None:
|
||||
args.extend(['--tenantid', tenant_id.hex])
|
||||
if port is not None:
|
||||
args.append(f'--port={port}')
|
||||
args.append(node_name)
|
||||
if timeline_spec is not None:
|
||||
args.append(timeline_spec)
|
||||
return self.raw_cli(args)
|
||||
|
||||
def pg_start(
|
||||
self,
|
||||
node_name: str,
|
||||
tenant_id: Optional[uuid.UUID] = None,
|
||||
timeline_spec: Optional[str] = None,
|
||||
port: Optional[int] = None,
|
||||
) -> 'subprocess.CompletedProcess[str]':
|
||||
args = ['pg', 'start']
|
||||
if tenant_id is not None:
|
||||
args.extend(['--tenantid', tenant_id.hex])
|
||||
if port is not None:
|
||||
args.append(f'--port={port}')
|
||||
args.append(node_name)
|
||||
if timeline_spec is not None:
|
||||
args.append(timeline_spec)
|
||||
|
||||
return self.raw_cli(args)
|
||||
|
||||
def pg_stop(
|
||||
self,
|
||||
node_name: str,
|
||||
tenant_id: Optional[uuid.UUID] = None,
|
||||
destroy=False,
|
||||
) -> 'subprocess.CompletedProcess[str]':
|
||||
args = ['pg', 'stop']
|
||||
if tenant_id is not None:
|
||||
args.extend(['--tenantid', tenant_id.hex])
|
||||
if destroy:
|
||||
args.append('--destroy')
|
||||
args.append(node_name)
|
||||
|
||||
return self.raw_cli(args)
|
||||
|
||||
def raw_cli(self,
|
||||
arguments: List[str],
|
||||
check_return_code=True) -> 'subprocess.CompletedProcess[str]':
|
||||
"""
|
||||
Run "zenith" with the specified arguments.
|
||||
|
||||
Arguments must be in list form, e.g. ['pg', 'create']
|
||||
|
||||
Return both stdout and stderr, which can be accessed as
|
||||
|
||||
>>> result = env.zenith_cli.raw_cli(...)
|
||||
>>> assert result.stderr == ""
|
||||
>>> log.info(result.stdout)
|
||||
"""
|
||||
|
||||
assert isinstance(arguments, list)
|
||||
|
||||
bin_zenith = os.path.join(str(zenith_binpath), 'zenith')
|
||||
|
||||
args = [bin_zenith] + arguments
|
||||
log.info('Running command "{}"'.format(' '.join(args)))
|
||||
log.info(f'Running in "{self.env.repo_dir}"')
|
||||
|
||||
env_vars = os.environ.copy()
|
||||
env_vars['ZENITH_REPO_DIR'] = str(self.env.repo_dir)
|
||||
env_vars['POSTGRES_DISTRIB_DIR'] = str(pg_distrib_dir)
|
||||
|
||||
if self.env.rust_log_override is not None:
|
||||
env_vars['RUST_LOG'] = self.env.rust_log_override
|
||||
|
||||
# Pass coverage settings
|
||||
var = 'LLVM_PROFILE_FILE'
|
||||
val = os.environ.get(var)
|
||||
if val:
|
||||
env_vars[var] = val
|
||||
|
||||
# Intercept CalledProcessError and print more info
|
||||
try:
|
||||
res = subprocess.run(args,
|
||||
env=env_vars,
|
||||
check=True,
|
||||
universal_newlines=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE)
|
||||
log.info(f"Run success: {res.stdout}")
|
||||
except subprocess.CalledProcessError as exc:
|
||||
# this way the command output will be recorded and shown in the CI failure message
|
||||
msg = f"""\
|
||||
Run failed: {exc}
|
||||
stdout: {exc.stdout}
|
||||
stderr: {exc.stderr}
|
||||
"""
|
||||
log.info(msg)
|
||||
|
||||
raise Exception(msg) from exc
|
||||
|
||||
if check_return_code:
|
||||
res.check_returncode()
|
||||
return res
|
||||
|
||||
|
||||
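A minimal, self-contained sketch of the error-interception pattern used by raw_cli above: run the child process, capture both streams, and attach them to the exception so CI failure messages include the CLI's actual output. The helper name is illustrative, not part of the fixtures.

import subprocess

def run_checked(args):
    # Capture stdout/stderr so a failure message can include them verbatim.
    try:
        return subprocess.run(args,
                              check=True,
                              universal_newlines=True,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as exc:
        # Re-raise with the child's output attached, mirroring raw_cli.
        raise Exception(
            f"Run failed: {exc}\nstdout: {exc.stdout}\nstderr: {exc.stderr}") from exc
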
class ZenithPageserver(PgProtocol):
|
||||
"""
|
||||
An object representing a running pageserver.
|
||||
@@ -1000,7 +863,10 @@ class ZenithPageserver(PgProtocol):
|
||||
"""
|
||||
assert self.running == False
|
||||
|
||||
self.env.zenith_cli.pageserver_start()
|
||||
start_args = ['pageserver', 'start']
|
||||
append_pageserver_param_overrides(start_args, self.remote_storage)
|
||||
|
||||
self.env.zenith_cli(start_args)
|
||||
self.running = True
|
||||
return self
|
||||
|
||||
@@ -1009,8 +875,13 @@ class ZenithPageserver(PgProtocol):
|
||||
Stop the page server.
|
||||
Returns self.
|
||||
"""
|
||||
cmd = ['pageserver', 'stop']
|
||||
if immediate:
|
||||
cmd.extend(['-m', 'immediate'])
|
||||
|
||||
log.info(f"Stopping pageserver with {cmd}")
|
||||
if self.running:
|
||||
self.env.zenith_cli.pageserver_stop(immediate)
|
||||
self.env.zenith_cli(cmd)
|
||||
self.running = False
|
||||
|
||||
return self
|
||||
@@ -1161,62 +1032,9 @@ def vanilla_pg(test_output_dir: str) -> Iterator[VanillaPostgres]:
|
||||
yield vanilla_pg
|
||||
|
||||
|
||||
class ZenithProxy(PgProtocol):
|
||||
def __init__(self, port: int):
|
||||
super().__init__(host="127.0.0.1", username="pytest", password="pytest", port=port)
|
||||
self.running = False
|
||||
|
||||
def start_static(self, addr="127.0.0.1:5432") -> None:
|
||||
assert not self.running
|
||||
self.running = True
|
||||
|
||||
http_port = "7001"
|
||||
args = [
|
||||
# TODO is cargo run the right thing to do?
|
||||
"cargo",
|
||||
"run",
|
||||
"--bin", "proxy",
|
||||
"--",
|
||||
"--http", http_port,
|
||||
"--proxy", f"{self.host}:{self.port}",
|
||||
"--auth-method", "password",
|
||||
"--static-router", addr,
|
||||
]
|
||||
self.popen = subprocess.Popen(args)
|
||||
|
||||
# Readiness probe
|
||||
requests.get(f"http://{self.host}:{http_port}/v1/status")
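A single GET can race proxy startup; if the readiness probe ever proves flaky, a bounded retry loop is the usual fix. A sketch under that assumption (not what the fixture currently does):

import time
import requests

def wait_for_status(url: str, timeout_s: float = 10.0) -> None:
    # Poll until the endpoint answers 200 or the deadline passes.
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            requests.get(url).raise_for_status()
            return
        except requests.RequestException:
            if time.monotonic() > deadline:
                raise
            time.sleep(0.2)
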
|
||||
|
||||
|
||||
def stop(self) -> None:
|
||||
assert self.running
|
||||
self.running = False
|
||||
|
||||
# NOTE the process will die when we're done with tests anyway, because
|
||||
# it's a child process. This is mostly to clean up in between different tests.
|
||||
self.popen.kill()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
if self.running:
|
||||
self.stop()
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def static_proxy(vanilla_pg) -> Iterator[ZenithProxy]:
|
||||
vanilla_pg.start()
|
||||
vanilla_pg.safe_psql("create user pytest with password 'pytest';")
|
||||
|
||||
with ZenithProxy(4432) as proxy:
|
||||
proxy.start_static()
|
||||
yield proxy
|
||||
|
||||
|
||||
class Postgres(PgProtocol):
|
||||
""" An object representing a running postgres daemon. """
|
||||
def __init__(self, env: ZenithEnv, tenant_id: uuid.UUID, port: int):
|
||||
def __init__(self, env: ZenithEnv, tenant_id: str, port: int):
|
||||
super().__init__(host='localhost', port=port, username='zenith_admin')
|
||||
|
||||
self.env = env
|
||||
@@ -1243,12 +1061,16 @@ class Postgres(PgProtocol):
|
||||
if branch is None:
|
||||
branch = node_name
|
||||
|
||||
self.env.zenith_cli.pg_create(node_name,
|
||||
tenant_id=self.tenant_id,
|
||||
port=self.port,
|
||||
timeline_spec=branch)
|
||||
self.env.zenith_cli([
|
||||
'pg',
|
||||
'create',
|
||||
f'--tenantid={self.tenant_id}',
|
||||
f'--port={self.port}',
|
||||
node_name,
|
||||
branch
|
||||
])
|
||||
self.node_name = node_name
|
||||
path = pathlib.Path('pgdatadirs') / 'tenants' / self.tenant_id.hex / self.node_name
|
||||
path = pathlib.Path('pgdatadirs') / 'tenants' / self.tenant_id / self.node_name
|
||||
self.pgdata_dir = os.path.join(self.env.repo_dir, path)
|
||||
|
||||
if config_lines is None:
|
||||
@@ -1267,9 +1089,8 @@ class Postgres(PgProtocol):
|
||||
|
||||
log.info(f"Starting postgres node {self.node_name}")
|
||||
|
||||
run_result = self.env.zenith_cli.pg_start(self.node_name,
|
||||
tenant_id=self.tenant_id,
|
||||
port=self.port)
|
||||
run_result = self.env.zenith_cli(
|
||||
['pg', 'start', f'--tenantid={self.tenant_id}', f'--port={self.port}', self.node_name])
|
||||
self.running = True
|
||||
|
||||
log.info(f"stdout: {run_result.stdout}")
|
||||
@@ -1279,7 +1100,7 @@ class Postgres(PgProtocol):
|
||||
def pg_data_dir_path(self) -> str:
|
||||
""" Path to data directory """
|
||||
assert self.node_name
|
||||
path = pathlib.Path('pgdatadirs') / 'tenants' / self.tenant_id.hex / self.node_name
|
||||
path = pathlib.Path('pgdatadirs') / 'tenants' / self.tenant_id / self.node_name
|
||||
return os.path.join(self.env.repo_dir, path)
|
||||
|
||||
def pg_xact_dir_path(self) -> str:
|
||||
@@ -1339,7 +1160,7 @@ class Postgres(PgProtocol):
|
||||
|
||||
if self.running:
|
||||
assert self.node_name is not None
|
||||
self.env.zenith_cli.pg_stop(self.node_name, tenant_id=self.tenant_id)
|
||||
self.env.zenith_cli(['pg', 'stop', self.node_name, f'--tenantid={self.tenant_id}'])
|
||||
self.running = False
|
||||
|
||||
return self
|
||||
@@ -1351,7 +1172,8 @@ class Postgres(PgProtocol):
|
||||
"""
|
||||
|
||||
assert self.node_name is not None
|
||||
self.env.zenith_cli.pg_stop(self.node_name, self.tenant_id, destroy=True)
|
||||
self.env.zenith_cli(
|
||||
['pg', 'stop', '--destroy', self.node_name, f'--tenantid={self.tenant_id}'])
|
||||
self.node_name = None
|
||||
|
||||
return self
|
||||
@@ -1393,7 +1215,7 @@ class PostgresFactory:
|
||||
def create_start(self,
|
||||
node_name: str = "main",
|
||||
branch: Optional[str] = None,
|
||||
tenant_id: Optional[uuid.UUID] = None,
|
||||
tenant_id: Optional[str] = None,
|
||||
config_lines: Optional[List[str]] = None) -> Postgres:
|
||||
|
||||
pg = Postgres(
|
||||
@@ -1413,7 +1235,7 @@ class PostgresFactory:
|
||||
def create(self,
|
||||
node_name: str = "main",
|
||||
branch: Optional[str] = None,
|
||||
tenant_id: Optional[uuid.UUID] = None,
|
||||
tenant_id: Optional[str] = None,
|
||||
config_lines: Optional[List[str]] = None) -> Postgres:
|
||||
|
||||
pg = Postgres(
|
||||
@@ -1458,7 +1280,7 @@ class Safekeeper:
|
||||
auth_token: Optional[str] = None
|
||||
|
||||
def start(self) -> 'Safekeeper':
|
||||
self.env.zenith_cli.safekeeper_start(self.name)
|
||||
self.env.zenith_cli(['safekeeper', 'start', self.name])
|
||||
|
||||
# wait for the wal acceptor to start by checking its status
|
||||
started_at = time.time()
|
||||
@@ -1477,13 +1299,16 @@ class Safekeeper:
|
||||
return self
|
||||
|
||||
def stop(self, immediate=False) -> 'Safekeeper':
|
||||
cmd = ['safekeeper', 'stop']
|
||||
if immediate:
|
||||
cmd.extend(['-m', 'immediate'])
|
||||
cmd.append(self.name)
|
||||
|
||||
log.info('Stopping safekeeper {}'.format(self.name))
|
||||
self.env.zenith_cli.safekeeper_stop(self.name, immediate)
|
||||
self.env.zenith_cli(cmd)
|
||||
return self
|
||||
|
||||
def append_logical_message(self,
|
||||
tenant_id: uuid.UUID,
|
||||
timeline_id: uuid.UUID,
|
||||
def append_logical_message(self, tenant_id: str, timeline_id: str,
|
||||
request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Send JSON_CTRL query to append LogicalMessage to WAL and modify
|
||||
@@ -1493,7 +1318,7 @@ class Safekeeper:
|
||||
|
||||
# "replication=0" hacks psycopg not to send additional queries
|
||||
# on startup, see https://github.com/psycopg/psycopg2/pull/482
|
||||
connstr = f"host=localhost port={self.port.pg} replication=0 options='-c ztimelineid={timeline_id.hex} ztenantid={tenant_id.hex}'"
|
||||
connstr = f"host=localhost port={self.port.pg} replication=0 options='-c ztimelineid={timeline_id} ztenantid={tenant_id}'"
|
||||
|
||||
with closing(psycopg2.connect(connstr)) as conn:
|
||||
# server doesn't support transactions
|
||||
@@ -1522,8 +1347,8 @@ class SafekeeperTimelineStatus:
|
||||
class SafekeeperMetrics:
|
||||
# These are metrics from Prometheus which uses float64 internally.
|
||||
# As a consequence, values may differ from real original int64s.
|
||||
flush_lsn_inexact: Dict[Tuple[str, str], int] = field(default_factory=dict)
|
||||
commit_lsn_inexact: Dict[Tuple[str, str], int] = field(default_factory=dict)
|
||||
flush_lsn_inexact: Dict[str, int] = field(default_factory=dict)
|
||||
commit_lsn_inexact: Dict[str, int] = field(default_factory=dict)
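The float64 caveat in the comment above is easy to demonstrate: Prometheus values carry a 53-bit mantissa, so integers beyond 2**53 lose exactness on the way through.

lsn = 2**53 + 1
# 9007199254740993 rounds to 9007199254740992 when routed through float64.
assert int(float(lsn)) != lsn
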
|
||||
|
||||
|
||||
class SafekeeperHttpClient(requests.Session):
|
||||
@@ -1547,16 +1372,14 @@ class SafekeeperHttpClient(requests.Session):
|
||||
all_metrics_text = request_result.text
|
||||
|
||||
metrics = SafekeeperMetrics()
|
||||
for match in re.finditer(
|
||||
r'^safekeeper_flush_lsn{tenant_id="([0-9a-f]+)",timeline_id="([0-9a-f]+)"} (\S+)$',
|
||||
all_metrics_text,
|
||||
re.MULTILINE):
|
||||
metrics.flush_lsn_inexact[(match.group(1), match.group(2))] = int(match.group(3))
|
||||
for match in re.finditer(
|
||||
r'^safekeeper_commit_lsn{tenant_id="([0-9a-f]+)",timeline_id="([0-9a-f]+)"} (\S+)$',
|
||||
all_metrics_text,
|
||||
re.MULTILINE):
|
||||
metrics.commit_lsn_inexact[(match.group(1), match.group(2))] = int(match.group(3))
|
||||
for match in re.finditer(r'^safekeeper_flush_lsn{ztli="([0-9a-f]+)"} (\S+)$',
|
||||
all_metrics_text,
|
||||
re.MULTILINE):
|
||||
metrics.flush_lsn_inexact[match.group(1)] = int(match.group(2))
|
||||
for match in re.finditer(r'^safekeeper_commit_lsn{ztli="([0-9a-f]+)"} (\S+)$',
|
||||
all_metrics_text,
|
||||
re.MULTILINE):
|
||||
metrics.commit_lsn_inexact[match.group(1)] = int(match.group(2))
|
||||
return metrics
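For reference, a line the ztli regexes above are meant to match (hypothetical id and value):

import re

sample = 'safekeeper_flush_lsn{ztli="0123456789abcdef0123456789abcdef"} 22928216'
m = re.match(r'^safekeeper_flush_lsn{ztli="([0-9a-f]+)"} (\S+)$', sample)
assert m and int(m.group(2)) == 22928216
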
|
||||
|
||||
|
||||
@@ -1665,7 +1488,7 @@ def check_restored_datadir_content(test_output_dir: str, env: ZenithEnv, pg: Pos
|
||||
{psql_path} \
|
||||
--no-psqlrc \
|
||||
postgres://localhost:{env.pageserver.service_port.pg} \
|
||||
-c 'basebackup {pg.tenant_id.hex} {timeline}' \
|
||||
-c 'basebackup {pg.tenant_id} {timeline}' \
|
||||
| tar -x -C {restored_dir_path}
|
||||
"""
|
||||
|
||||
|
||||
@@ -4,6 +4,12 @@ from fixtures.log_helper import log
|
||||
from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker
|
||||
from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
|
||||
|
||||
pytest_plugins = (
|
||||
"fixtures.zenith_fixtures",
|
||||
"fixtures.benchmark_fixture",
|
||||
"fixtures.compare_fixtures",
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# Run bulk INSERT test.
|
||||
|
||||
@@ -4,6 +4,8 @@ import pytest
|
||||
|
||||
from fixtures.zenith_fixtures import ZenithEnvBuilder
|
||||
|
||||
pytest_plugins = ("fixtures.benchmark_fixture")
|
||||
|
||||
# Run bulk tenant creation test.
|
||||
#
|
||||
# Collects metrics:
|
||||
@@ -31,10 +33,12 @@ def test_bulk_tenant_create(
|
||||
start = timeit.default_timer()
|
||||
|
||||
tenant = env.create_tenant()
|
||||
env.zenith_cli.create_branch(
|
||||
env.zenith_cli([
|
||||
"branch",
|
||||
f"test_bulk_tenant_create_{tenants_count}_{i}_{use_wal_acceptors}",
|
||||
"main",
|
||||
tenant_id=tenant)
|
||||
f"--tenantid={tenant}"
|
||||
])
|
||||
|
||||
# FIXME: We used to start new safekeepers here. Did that make sense? Should we do it now?
|
||||
#if use_wal_acceptors == 'with_wa':
|
||||
|
||||
@@ -6,6 +6,12 @@ from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
|
||||
from io import BufferedReader, RawIOBase
|
||||
from itertools import repeat
|
||||
|
||||
pytest_plugins = (
|
||||
"fixtures.zenith_fixtures",
|
||||
"fixtures.benchmark_fixture",
|
||||
"fixtures.compare_fixtures",
|
||||
)
|
||||
|
||||
|
||||
class CopyTestData(RawIOBase):
|
||||
def __init__(self, rows: int):
|
||||
|
||||
@@ -5,6 +5,12 @@ from fixtures.zenith_fixtures import ZenithEnv
|
||||
from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
|
||||
from fixtures.log_helper import log
|
||||
|
||||
pytest_plugins = (
|
||||
"fixtures.zenith_fixtures",
|
||||
"fixtures.benchmark_fixture",
|
||||
"fixtures.compare_fixtures",
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# Test buffering GiST build. It WAL-logs the whole relation, in 32-page chunks.
|
||||
|
||||
@@ -6,6 +6,12 @@ from fixtures.log_helper import log
|
||||
from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker
|
||||
from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
|
||||
|
||||
pytest_plugins = (
|
||||
"fixtures.zenith_fixtures",
|
||||
"fixtures.benchmark_fixture",
|
||||
"fixtures.compare_fixtures",
|
||||
)
|
||||
|
||||
|
||||
async def repeat_bytes(buf, repetitions: int):
|
||||
for i in range(repetitions):
|
||||
|
||||
@@ -5,6 +5,12 @@ from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
|
||||
from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker
|
||||
from fixtures.log_helper import log
|
||||
|
||||
pytest_plugins = (
|
||||
"fixtures.zenith_fixtures",
|
||||
"fixtures.benchmark_fixture",
|
||||
"fixtures.compare_fixtures",
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# Run a very short pgbench test.
|
||||
|
||||
@@ -9,6 +9,8 @@ import calendar
|
||||
import timeit
|
||||
import os
|
||||
|
||||
pytest_plugins = ("fixtures.benchmark_fixture", )
|
||||
|
||||
|
||||
def utc_now_timestamp() -> int:
|
||||
return calendar.timegm(datetime.utcnow().utctimetuple())
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
import os
|
||||
from contextlib import closing
|
||||
from fixtures.benchmark_fixture import MetricReport
|
||||
from fixtures.zenith_fixtures import ZenithEnv
|
||||
from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
|
||||
from fixtures.log_helper import log
|
||||
|
||||
import psycopg2.extras
|
||||
import random
|
||||
import time
|
||||
from fixtures.utils import print_gc_result
|
||||
|
||||
|
||||
# This is a clear-box test that demonstrates the worst case scenario for the
|
||||
# "1 segment per layer" implementation of the pageserver. It writes to random
|
||||
# rows, while almost never writing to the same segment twice before flushing.
|
||||
# A naive pageserver implementation would create a full image layer for each
|
||||
# dirty segment, leading to write_amplification = segment_size / page_size,
|
||||
# when compared to vanilla postgres. With segment_size = 10MB, that's 1250.
|
||||
def test_random_writes(zenith_with_baseline: PgCompare):
|
||||
env = zenith_with_baseline
|
||||
|
||||
# Number of rows in the test database. 1M rows runs quickly, but implies
|
||||
# a small effective_checkpoint_distance, which makes the test less realistic.
|
||||
# Using a 300 TB database would imply a 250 MB effective_checkpoint_distance,
|
||||
# but it will take a very long time to run. From what I've seen so far,
|
||||
# increasing n_rows doesn't have an impact on the (zenith_runtime / vanilla_runtime)
|
||||
# performance ratio.
|
||||
n_rows = 1 * 1000 * 1000 # around 36 MB table
|
||||
|
||||
# Number of writes per 3 segments. A value of 1 should produce a random
|
||||
# workload where we almost never write to the same segment twice. Larger
|
||||
# values of load_factor produce a larger effective_checkpoint_distance,
|
||||
# making the test more realistic, but less effective. If you want a realistic
|
||||
# worst-case scenario and you have time to wait, you should increase n_rows instead.
|
||||
load_factor = 1
|
||||
|
||||
# Not sure why but this matters in a weird way (up to 2x difference in perf).
|
||||
# TODO look into it
|
||||
n_iterations = 1
|
||||
|
||||
with closing(env.pg.connect()) as conn:
|
||||
with conn.cursor() as cur:
|
||||
# Create the test table
|
||||
with env.record_duration('init'):
|
||||
cur.execute("""
|
||||
CREATE TABLE Big(
|
||||
pk integer primary key,
|
||||
count integer default 0
|
||||
);
|
||||
""")
|
||||
cur.execute(f"INSERT INTO Big (pk) values (generate_series(1,{n_rows}))")
|
||||
|
||||
# Get table size (can't be predicted exactly because of padding and alignment)
|
||||
cur.execute("SELECT pg_relation_size('Big');")
|
||||
row = cur.fetchone()
|
||||
table_size = row[0]
|
||||
env.zenbenchmark.record("table_size", table_size, 'bytes', MetricReport.TEST_PARAM)
|
||||
|
||||
# Decide how much to write, based on knowledge of pageserver implementation.
|
||||
# Avoiding segment collisions maximizes (zenith_runtime / vanilla_runtime).
|
||||
segment_size = 10 * 1024 * 1024
|
||||
n_segments = table_size // segment_size
|
||||
n_writes = load_factor * n_segments // 3
|
||||
|
||||
# The closer this is to 250 MB, the more realistic the test is.
|
||||
effective_checkpoint_distance = table_size * n_writes // n_rows
|
||||
env.zenbenchmark.record("effective_checkpoint_distance",
|
||||
effective_checkpoint_distance,
|
||||
'bytes',
|
||||
MetricReport.TEST_PARAM)
|
||||
|
||||
# Update random keys
|
||||
with env.record_duration('run'):
|
||||
for it in range(n_iterations):
|
||||
for i in range(n_writes):
|
||||
key = random.randint(1, n_rows)
|
||||
cur.execute(f"update Big set count=count+1 where pk={key}")
|
||||
env.flush()
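The worst-case amplification figure quoted in the test's header comment falls out of simple arithmetic, assuming decimal units (10 MB segments, 8 kB pages):

segment_size = 10 * 1000 * 1000
page_size = 8 * 1000
print(segment_size // page_size)          # 1250x write amplification, naive case

# And the effective_checkpoint_distance formula for the defaults above:
n_rows, table_size = 1_000_000, 36 * 1000 * 1000
n_writes = (table_size // segment_size) // 3   # load_factor = 1
print(table_size * n_writes // n_rows)         # effective_checkpoint_distance
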
|
||||
@@ -11,6 +11,12 @@ from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker
|
||||
from fixtures.compare_fixtures import PgCompare
|
||||
import pytest
|
||||
|
||||
pytest_plugins = (
|
||||
"fixtures.zenith_fixtures",
|
||||
"fixtures.benchmark_fixture",
|
||||
"fixtures.compare_fixtures",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('rows', [
|
||||
pytest.param(100000),
|
||||
|
||||
@@ -17,6 +17,12 @@ from fixtures.zenith_fixtures import ZenithEnv
|
||||
from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
|
||||
from fixtures.log_helper import log
|
||||
|
||||
pytest_plugins = (
|
||||
"fixtures.zenith_fixtures",
|
||||
"fixtures.benchmark_fixture",
|
||||
"fixtures.compare_fixtures",
|
||||
)
|
||||
|
||||
|
||||
def test_write_amplification(zenith_with_baseline: PgCompare):
|
||||
env = zenith_with_baseline
|
||||
|
||||
@@ -3,6 +3,8 @@ import os
|
||||
|
||||
from fixtures.zenith_fixtures import ZenithEnv
|
||||
from fixtures.log_helper import log
|
||||
|
||||
pytest_plugins = ("fixtures.zenith_fixtures")
|
||||
"""
|
||||
Use this test to see what happens when tests fail.
|
||||
|
||||
@@ -21,7 +23,9 @@ run_broken = pytest.mark.skipif(os.environ.get('RUN_BROKEN') is None,
|
||||
def test_broken(zenith_simple_env: ZenithEnv, pg_bin):
|
||||
env = zenith_simple_env
|
||||
|
||||
env.zenith_cli.create_branch("test_broken", "empty")
|
||||
# Create a branch for us
|
||||
env.zenith_cli(["branch", "test_broken", "empty"])
|
||||
|
||||
env.postgres.create_start("test_broken")
|
||||
log.info('postgres is running')
|
||||
|
||||
|
||||
Submodule vendor/postgres updated: a3709cc364...d914790e6c
@@ -10,7 +10,7 @@ use std::fs::File;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::thread;
|
||||
use tracing::*;
|
||||
use walkeeper::control_file::{self, CreateControlFile};
|
||||
use walkeeper::timeline::{CreateControlFile, FileStorage};
|
||||
use zenith_utils::http::endpoint;
|
||||
use zenith_utils::{logging, tcp_listener, GIT_VERSION};
|
||||
|
||||
@@ -96,10 +96,7 @@ fn main() -> Result<()> {
|
||||
.get_matches();
|
||||
|
||||
if let Some(addr) = arg_matches.value_of("dump-control-file") {
|
||||
let state = control_file::FileStorage::load_control_file(
|
||||
Path::new(addr),
|
||||
CreateControlFile::False,
|
||||
)?;
|
||||
let state = FileStorage::load_control_file(Path::new(addr), CreateControlFile::False)?;
|
||||
let json = serde_json::to_string(&state)?;
|
||||
print!("{}", json);
|
||||
return Ok(());
|
||||
|
||||
@@ -1,297 +0,0 @@
|
||||
//! Control file serialization, deserialization and persistence.
|
||||
|
||||
use anyhow::{bail, ensure, Context, Result};
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
use std::fs::{self, File, OpenOptions};
|
||||
use std::io::{Read, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use tracing::*;
|
||||
use zenith_metrics::{register_histogram_vec, Histogram, HistogramVec, DISK_WRITE_SECONDS_BUCKETS};
|
||||
use zenith_utils::bin_ser::LeSer;
|
||||
|
||||
use zenith_utils::zid::ZTenantTimelineId;
|
||||
|
||||
use crate::control_file_upgrade::upgrade_control_file;
|
||||
use crate::safekeeper::{SafeKeeperState, SK_FORMAT_VERSION, SK_MAGIC};
|
||||
|
||||
use crate::SafeKeeperConf;
|
||||
|
||||
use std::convert::TryInto;
|
||||
|
||||
// contains persistent metadata for safekeeper
|
||||
const CONTROL_FILE_NAME: &str = "safekeeper.control";
|
||||
// needed to atomically update the state using `rename`
|
||||
const CONTROL_FILE_NAME_PARTIAL: &str = "safekeeper.control.partial";
|
||||
pub const CHECKSUM_SIZE: usize = std::mem::size_of::<u32>();
|
||||
|
||||
// A named boolean.
|
||||
#[derive(Debug)]
|
||||
pub enum CreateControlFile {
|
||||
True,
|
||||
False,
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
static ref PERSIST_CONTROL_FILE_SECONDS: HistogramVec = register_histogram_vec!(
|
||||
"safekeeper_persist_control_file_seconds",
|
||||
"Seconds to persist and sync control file, grouped by timeline",
|
||||
&["tenant_id", "timeline_id"],
|
||||
DISK_WRITE_SECONDS_BUCKETS.to_vec()
|
||||
)
|
||||
.expect("Failed to register safekeeper_persist_control_file_seconds histogram vec");
|
||||
}
|
||||
|
||||
pub trait Storage {
|
||||
/// Persist safekeeper state on disk.
|
||||
fn persist(&mut self, s: &SafeKeeperState) -> Result<()>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FileStorage {
|
||||
// save timeline dir to avoid reconstructing it every time
|
||||
timeline_dir: PathBuf,
|
||||
conf: SafeKeeperConf,
|
||||
persist_control_file_seconds: Histogram,
|
||||
}
|
||||
|
||||
impl FileStorage {
|
||||
pub fn new(zttid: &ZTenantTimelineId, conf: &SafeKeeperConf) -> FileStorage {
|
||||
let timeline_dir = conf.timeline_dir(zttid);
|
||||
let tenant_id = zttid.tenant_id.to_string();
|
||||
let timeline_id = zttid.timeline_id.to_string();
|
||||
FileStorage {
|
||||
timeline_dir,
|
||||
conf: conf.clone(),
|
||||
persist_control_file_seconds: PERSIST_CONTROL_FILE_SECONDS
|
||||
.with_label_values(&[&tenant_id, &timeline_id]),
|
||||
}
|
||||
}
|
||||
|
||||
// Check the magic/version in the on-disk data and deserialize it, if possible.
|
||||
fn deser_sk_state(buf: &mut &[u8]) -> Result<SafeKeeperState> {
|
||||
// Read the version independent part
|
||||
let magic = buf.read_u32::<LittleEndian>()?;
|
||||
if magic != SK_MAGIC {
|
||||
bail!(
|
||||
"bad control file magic: {:X}, expected {:X}",
|
||||
magic,
|
||||
SK_MAGIC
|
||||
);
|
||||
}
|
||||
let version = buf.read_u32::<LittleEndian>()?;
|
||||
if version == SK_FORMAT_VERSION {
|
||||
let res = SafeKeeperState::des(buf)?;
|
||||
return Ok(res);
|
||||
}
|
||||
// try to upgrade
|
||||
upgrade_control_file(buf, version)
|
||||
}
|
||||
|
||||
// Load control file for given zttid at path specified by conf.
|
||||
pub fn load_control_file_conf(
|
||||
conf: &SafeKeeperConf,
|
||||
zttid: &ZTenantTimelineId,
|
||||
create: CreateControlFile,
|
||||
) -> Result<SafeKeeperState> {
|
||||
let path = conf.timeline_dir(zttid).join(CONTROL_FILE_NAME);
|
||||
Self::load_control_file(path, create)
|
||||
}
|
||||
|
||||
/// Read in the control file.
|
||||
/// If create=false and file doesn't exist, bails out.
|
||||
pub fn load_control_file<P: AsRef<Path>>(
|
||||
control_file_path: P,
|
||||
create: CreateControlFile,
|
||||
) -> Result<SafeKeeperState> {
|
||||
info!(
|
||||
"loading control file {}, create={:?}",
|
||||
control_file_path.as_ref().display(),
|
||||
create,
|
||||
);
|
||||
|
||||
let mut control_file = OpenOptions::new()
|
||||
.read(true)
|
||||
.write(true)
|
||||
.create(matches!(create, CreateControlFile::True))
|
||||
.open(&control_file_path)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to open control file at {}",
|
||||
control_file_path.as_ref().display(),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Empty file is legit on 'create', don't try to deser from it.
|
||||
let state = if control_file.metadata().unwrap().len() == 0 {
|
||||
if let CreateControlFile::False = create {
|
||||
bail!("control file is empty");
|
||||
}
|
||||
SafeKeeperState::new()
|
||||
} else {
|
||||
let mut buf = Vec::new();
|
||||
control_file
|
||||
.read_to_end(&mut buf)
|
||||
.context("failed to read control file")?;
|
||||
|
||||
let calculated_checksum = crc32c::crc32c(&buf[..buf.len() - CHECKSUM_SIZE]);
|
||||
|
||||
let expected_checksum_bytes: &[u8; CHECKSUM_SIZE] =
|
||||
buf[buf.len() - CHECKSUM_SIZE..].try_into()?;
|
||||
let expected_checksum = u32::from_le_bytes(*expected_checksum_bytes);
|
||||
|
||||
ensure!(
|
||||
calculated_checksum == expected_checksum,
|
||||
format!(
|
||||
"safekeeper control file checksum mismatch: expected {} got {}",
|
||||
expected_checksum, calculated_checksum
|
||||
)
|
||||
);
|
||||
|
||||
FileStorage::deser_sk_state(&mut &buf[..buf.len() - CHECKSUM_SIZE]).with_context(
|
||||
|| {
|
||||
format!(
|
||||
"while reading control file {}",
|
||||
control_file_path.as_ref().display(),
|
||||
)
|
||||
},
|
||||
)?
|
||||
};
|
||||
Ok(state)
|
||||
}
|
||||
}
|
||||
|
||||
impl Storage for FileStorage {
|
||||
// persists state durably to underlying storage
|
||||
// for description see https://lwn.net/Articles/457667/
|
||||
fn persist(&mut self, s: &SafeKeeperState) -> Result<()> {
|
||||
let _timer = &self.persist_control_file_seconds.start_timer();
|
||||
|
||||
// write data to safekeeper.control.partial
|
||||
let control_partial_path = self.timeline_dir.join(CONTROL_FILE_NAME_PARTIAL);
|
||||
let mut control_partial = File::create(&control_partial_path).with_context(|| {
|
||||
format!(
|
||||
"failed to create partial control file at: {}",
|
||||
&control_partial_path.display()
|
||||
)
|
||||
})?;
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
buf.write_u32::<LittleEndian>(SK_MAGIC)?;
|
||||
buf.write_u32::<LittleEndian>(SK_FORMAT_VERSION)?;
|
||||
s.ser_into(&mut buf)?;
|
||||
|
||||
// calculate checksum before resize
|
||||
let checksum = crc32c::crc32c(&buf);
|
||||
buf.extend_from_slice(&checksum.to_le_bytes());
|
||||
|
||||
control_partial.write_all(&buf).with_context(|| {
|
||||
format!(
|
||||
"failed to write safekeeper state into control file at: {}",
|
||||
control_partial_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
// fsync the file
|
||||
control_partial.sync_all().with_context(|| {
|
||||
format!(
|
||||
"failed to sync partial control file at {}",
|
||||
control_partial_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
let control_path = self.timeline_dir.join(CONTROL_FILE_NAME);
|
||||
|
||||
// rename should be atomic
|
||||
fs::rename(&control_partial_path, &control_path)?;
|
||||
// this sync is not required by any standard but postgres does this (see durable_rename)
|
||||
File::open(&control_path)
|
||||
.and_then(|f| f.sync_all())
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"failed to sync control file at: {}",
|
||||
&control_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
// fsync the directory (linux specific)
|
||||
File::open(&self.timeline_dir)
|
||||
.and_then(|f| f.sync_all())
|
||||
.context("failed to sync control file directory")?;
|
||||
Ok(())
|
||||
}
|
||||
}
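The persist() sequence above is the classic durable-rename pattern from the LWN article it cites: write a sibling file, fsync it, rename over the target, then fsync the target and its directory. The same shape in Python, as an illustration only:

import os

def durable_replace(path: str, data: bytes) -> None:
    partial = path + ".partial"
    with open(partial, "wb") as f:
        f.write(data)
        f.flush()
        os.fsync(f.fileno())      # data must hit disk before the rename
    os.replace(partial, path)     # atomic on POSIX filesystems
    fd = os.open(path, os.O_RDONLY)
    try:
        os.fsync(fd)              # not required by any standard; matches durable_rename
    finally:
        os.close(fd)
    dir_fd = os.open(os.path.dirname(path) or ".", os.O_RDONLY)
    try:
        os.fsync(dir_fd)          # persist the directory entry itself
    finally:
        os.close(dir_fd)
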
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::FileStorage;
|
||||
use super::*;
|
||||
use crate::{safekeeper::SafeKeeperState, SafeKeeperConf, ZTenantTimelineId};
|
||||
use anyhow::Result;
|
||||
use std::fs;
|
||||
use zenith_utils::lsn::Lsn;
|
||||
|
||||
fn stub_conf() -> SafeKeeperConf {
|
||||
let workdir = tempfile::tempdir().unwrap().into_path();
|
||||
SafeKeeperConf {
|
||||
workdir,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn load_from_control_file(
|
||||
conf: &SafeKeeperConf,
|
||||
zttid: &ZTenantTimelineId,
|
||||
create: CreateControlFile,
|
||||
) -> Result<(FileStorage, SafeKeeperState)> {
|
||||
fs::create_dir_all(&conf.timeline_dir(zttid)).expect("failed to create timeline dir");
|
||||
Ok((
|
||||
FileStorage::new(zttid, conf),
|
||||
FileStorage::load_control_file_conf(conf, zttid, create)?,
|
||||
))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_read_write_safekeeper_state() {
|
||||
let conf = stub_conf();
|
||||
let zttid = ZTenantTimelineId::generate();
|
||||
{
|
||||
let (mut storage, mut state) =
|
||||
load_from_control_file(&conf, &zttid, CreateControlFile::True)
|
||||
.expect("failed to read state");
|
||||
// change something
|
||||
state.wal_start_lsn = Lsn(42);
|
||||
storage.persist(&state).expect("failed to persist state");
|
||||
}
|
||||
|
||||
let (_, state) = load_from_control_file(&conf, &zttid, CreateControlFile::False)
|
||||
.expect("failed to read state");
|
||||
assert_eq!(state.wal_start_lsn, Lsn(42));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_safekeeper_state_checksum_mismatch() {
|
||||
let conf = stub_conf();
|
||||
let zttid = ZTenantTimelineId::generate();
|
||||
{
|
||||
let (mut storage, mut state) =
|
||||
load_from_control_file(&conf, &zttid, CreateControlFile::True)
|
||||
.expect("failed to read state");
|
||||
// change something
|
||||
state.wal_start_lsn = Lsn(42);
|
||||
storage.persist(&state).expect("failed to persist state");
|
||||
}
|
||||
let control_path = conf.timeline_dir(&zttid).join(CONTROL_FILE_NAME);
|
||||
let mut data = fs::read(&control_path).unwrap();
|
||||
data[0] += 1; // change the first byte of the file to fail checksum validation
|
||||
fs::write(&control_path, &data).expect("failed to write control file");
|
||||
|
||||
match load_from_control_file(&conf, &zttid, CreateControlFile::False) {
|
||||
Err(err) => assert!(err
|
||||
.to_string()
|
||||
.contains("safekeeper control file checksum mismatch")),
|
||||
Ok(_) => panic!("expected error"),
|
||||
}
|
||||
}
|
||||
}
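The on-disk framing this (now deleted) module reads and writes is: magic u32 LE, format version u32 LE, serialized state, then a crc32c of everything preceding it. A Python rendering of the framing, with hypothetical constants and zlib.crc32 standing in for crc32c (a different polynomial; illustration only):

import struct
import zlib

SK_MAGIC = 0xCAFECEEF          # hypothetical; the real value lives in safekeeper.rs
SK_FORMAT_VERSION = 4          # hypothetical

def frame(body: bytes) -> bytes:
    buf = struct.pack("<II", SK_MAGIC, SK_FORMAT_VERSION) + body
    return buf + struct.pack("<I", zlib.crc32(buf))

def unframe(buf: bytes) -> bytes:
    payload, checksum = buf[:-4], struct.unpack("<I", buf[-4:])[0]
    if zlib.crc32(payload) != checksum:
        raise ValueError("control file checksum mismatch")
    magic, _version = struct.unpack_from("<II", payload)
    if magic != SK_MAGIC:
        raise ValueError(f"bad control file magic: {magic:X}")
    return payload[8:]

assert unframe(frame(b"state")) == b"state"
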
|
||||
@@ -19,7 +19,7 @@ use zenith_utils::pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID
|
||||
use zenith_utils::zid::{ZTenantId, ZTenantTimelineId, ZTimelineId};
|
||||
|
||||
use crate::callmemaybe::CallmeEvent;
|
||||
use crate::control_file::CreateControlFile;
|
||||
use crate::timeline::CreateControlFile;
|
||||
use tokio::sync::mpsc::UnboundedSender;
|
||||
|
||||
/// Safekeeper handler of postgres commands
|
||||
|
||||
@@ -7,9 +7,9 @@ use zenith_utils::http::{RequestExt, RouterBuilder};
|
||||
use zenith_utils::lsn::Lsn;
|
||||
use zenith_utils::zid::ZTenantTimelineId;
|
||||
|
||||
use crate::control_file::CreateControlFile;
|
||||
use crate::safekeeper::Term;
|
||||
use crate::safekeeper::TermHistory;
|
||||
use crate::timeline::CreateControlFile;
|
||||
use crate::timeline::GlobalTimelines;
|
||||
use crate::SafeKeeperConf;
|
||||
use zenith_utils::http::endpoint;
|
||||
|
||||
@@ -5,8 +5,6 @@ use std::time::Duration;
|
||||
use zenith_utils::zid::ZTenantTimelineId;
|
||||
|
||||
pub mod callmemaybe;
|
||||
pub mod control_file;
|
||||
pub mod control_file_upgrade;
|
||||
pub mod handler;
|
||||
pub mod http;
|
||||
pub mod json_ctrl;
|
||||
@@ -15,8 +13,8 @@ pub mod s3_offload;
|
||||
pub mod safekeeper;
|
||||
pub mod send_wal;
|
||||
pub mod timeline;
|
||||
pub mod upgrade;
|
||||
pub mod wal_service;
|
||||
pub mod wal_storage;
|
||||
|
||||
pub mod defaults {
|
||||
use const_format::formatcp;
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
use anyhow::{bail, Context, Result};
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
|
||||
use postgres_ffi::waldecoder::WalStreamDecoder;
|
||||
use postgres_ffi::xlog_utils::TimeLineID;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cmp::min;
|
||||
@@ -13,11 +13,12 @@ use tracing::*;
|
||||
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
use crate::control_file;
|
||||
use crate::send_wal::HotStandbyFeedback;
|
||||
use crate::wal_storage;
|
||||
use postgres_ffi::xlog_utils::MAX_SEND_SIZE;
|
||||
use zenith_metrics::{register_gauge_vec, Gauge, GaugeVec};
|
||||
use zenith_metrics::{
|
||||
register_gauge_vec, register_histogram_vec, Gauge, GaugeVec, Histogram, HistogramVec,
|
||||
DISK_WRITE_SECONDS_BUCKETS,
|
||||
};
|
||||
use zenith_utils::bin_ser::LeSer;
|
||||
use zenith_utils::lsn::Lsn;
|
||||
use zenith_utils::pq_proto::SystemId;
|
||||
@@ -406,87 +407,130 @@ impl AcceptorProposerMessage {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Storage {
|
||||
/// Persist safekeeper state on disk.
|
||||
fn persist(&mut self, s: &SafeKeeperState) -> Result<()>;
|
||||
/// Write piece of wal in buf to disk and sync it.
|
||||
fn write_wal(&mut self, server: &ServerInfo, startpos: Lsn, buf: &[u8]) -> Result<()>;
|
||||
// Truncate WAL at specified LSN
|
||||
fn truncate_wal(&mut self, s: &ServerInfo, endpos: Lsn) -> Result<()>;
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
// The prometheus crate does not support u64 yet, i64 only (see `IntGauge`).
|
||||
// i64 is faster than f64, so update to u64 when available.
|
||||
static ref FLUSH_LSN_GAUGE: GaugeVec = register_gauge_vec!(
|
||||
"safekeeper_flush_lsn",
|
||||
"Current flush_lsn, grouped by timeline",
|
||||
&["ztli"]
|
||||
)
|
||||
.expect("Failed to register safekeeper_flush_lsn gauge vec");
|
||||
static ref COMMIT_LSN_GAUGE: GaugeVec = register_gauge_vec!(
|
||||
"safekeeper_commit_lsn",
|
||||
"Current commit_lsn (not necessarily persisted to disk), grouped by timeline",
|
||||
&["tenant_id", "timeline_id"]
|
||||
&["ztli"]
|
||||
)
|
||||
.expect("Failed to register safekeeper_commit_lsn gauge vec");
|
||||
static ref WRITE_WAL_BYTES: HistogramVec = register_histogram_vec!(
|
||||
"safekeeper_write_wal_bytes",
|
||||
"Bytes written to WAL in a single request, grouped by timeline",
|
||||
&["timeline_id"],
|
||||
vec![1.0, 10.0, 100.0, 1024.0, 8192.0, 128.0 * 1024.0, 1024.0 * 1024.0, 10.0 * 1024.0 * 1024.0]
|
||||
)
|
||||
.expect("Failed to register safekeeper_write_wal_bytes histogram vec");
|
||||
static ref WRITE_WAL_SECONDS: HistogramVec = register_histogram_vec!(
|
||||
"safekeeper_write_wal_seconds",
|
||||
"Seconds spent writing and syncing WAL to a disk in a single request, grouped by timeline",
|
||||
&["timeline_id"],
|
||||
DISK_WRITE_SECONDS_BUCKETS.to_vec()
|
||||
)
|
||||
.expect("Failed to register safekeeper_write_wal_seconds histogram vec");
|
||||
}
|
||||
|
||||
struct SafeKeeperMetrics {
|
||||
flush_lsn: Gauge,
|
||||
commit_lsn: Gauge,
|
||||
write_wal_bytes: Histogram,
|
||||
write_wal_seconds: Histogram,
|
||||
}
|
||||
|
||||
impl SafeKeeperMetrics {
|
||||
fn new(tenant_id: ZTenantId, timeline_id: ZTimelineId, commit_lsn: Lsn) -> Self {
|
||||
let tenant_id = tenant_id.to_string();
|
||||
let timeline_id = timeline_id.to_string();
|
||||
let m = Self {
|
||||
commit_lsn: COMMIT_LSN_GAUGE.with_label_values(&[&tenant_id, &timeline_id]),
|
||||
struct SafeKeeperMetricsBuilder {
|
||||
ztli: ZTimelineId,
|
||||
flush_lsn: Lsn,
|
||||
commit_lsn: Lsn,
|
||||
}
|
||||
|
||||
impl SafeKeeperMetricsBuilder {
|
||||
fn build(self) -> SafeKeeperMetrics {
|
||||
let ztli_str = format!("{}", self.ztli);
|
||||
let m = SafeKeeperMetrics {
|
||||
flush_lsn: FLUSH_LSN_GAUGE.with_label_values(&[&ztli_str]),
|
||||
commit_lsn: COMMIT_LSN_GAUGE.with_label_values(&[&ztli_str]),
|
||||
write_wal_bytes: WRITE_WAL_BYTES.with_label_values(&[&ztli_str]),
|
||||
write_wal_seconds: WRITE_WAL_SECONDS.with_label_values(&[&ztli_str]),
|
||||
};
|
||||
m.commit_lsn.set(u64::from(commit_lsn) as f64);
|
||||
m.flush_lsn.set(u64::from(self.flush_lsn) as f64);
|
||||
m.commit_lsn.set(u64::from(self.commit_lsn) as f64);
|
||||
m
|
||||
}
|
||||
}
|
||||
|
||||
/// SafeKeeper which consumes events (messages from compute) and provides
|
||||
/// replies.
|
||||
pub struct SafeKeeper<CTRL: control_file::Storage, WAL: wal_storage::Storage> {
|
||||
pub struct SafeKeeper<ST: Storage> {
|
||||
/// Locally flushed part of WAL with full records (end_lsn of last record).
|
||||
/// Established by reading wal.
|
||||
pub flush_lsn: Lsn,
|
||||
// Cached metrics so we don't have to recompute labels on each update.
|
||||
metrics: SafeKeeperMetrics,
|
||||
|
||||
/// not-yet-flushed pairs of the same-named fields in s.*
|
||||
pub commit_lsn: Lsn,
|
||||
pub truncate_lsn: Lsn,
|
||||
pub storage: ST,
|
||||
pub s: SafeKeeperState, // persistent part
|
||||
|
||||
pub control_store: CTRL,
|
||||
pub wal_store: WAL,
|
||||
decoder: WalStreamDecoder,
|
||||
}
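The cached-metrics field exists because with_label_values() resolves labels through a hash lookup; doing it once per timeline keeps the hot path cheap. The same idea with the Python client, for illustration:

from prometheus_client import Gauge

FLUSH_LSN = Gauge("safekeeper_flush_lsn", "Current flush_lsn", ["ztli"])

class TimelineMetrics:
    def __init__(self, ztli: str):
        self.flush_lsn = FLUSH_LSN.labels(ztli)   # label lookup done once

m = TimelineMetrics("0123456789abcdef")
m.flush_lsn.set(42)                               # hot path: no label hashing
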
|
||||
|
||||
impl<CTRL, WAL> SafeKeeper<CTRL, WAL>
|
||||
impl<ST> SafeKeeper<ST>
|
||||
where
|
||||
CTRL: control_file::Storage,
|
||||
WAL: wal_storage::Storage,
|
||||
ST: Storage,
|
||||
{
|
||||
// constructor
|
||||
pub fn new(
|
||||
ztli: ZTimelineId,
|
||||
control_store: CTRL,
|
||||
wal_store: WAL,
|
||||
flush_lsn: Lsn,
|
||||
storage: ST,
|
||||
state: SafeKeeperState,
|
||||
) -> SafeKeeper<CTRL, WAL> {
|
||||
) -> SafeKeeper<ST> {
|
||||
if state.server.timeline_id != ZTimelineId::from([0u8; 16])
|
||||
&& ztli != state.server.timeline_id
|
||||
{
|
||||
panic!("Calling SafeKeeper::new with inconsistent ztli ({}) and SafeKeeperState.server.timeline_id ({})", ztli, state.server.timeline_id);
|
||||
}
|
||||
|
||||
SafeKeeper {
|
||||
metrics: SafeKeeperMetrics::new(state.server.tenant_id, ztli, state.commit_lsn),
|
||||
flush_lsn,
|
||||
metrics: SafeKeeperMetricsBuilder {
|
||||
ztli,
|
||||
flush_lsn,
|
||||
commit_lsn: state.commit_lsn,
|
||||
}
|
||||
.build(),
|
||||
commit_lsn: state.commit_lsn,
|
||||
truncate_lsn: state.truncate_lsn,
|
||||
storage,
|
||||
s: state,
|
||||
control_store,
|
||||
wal_store,
|
||||
decoder: WalStreamDecoder::new(Lsn(0)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get history of term switches for the available WAL
|
||||
fn get_term_history(&self) -> TermHistory {
|
||||
self.s
|
||||
.acceptor_state
|
||||
.term_history
|
||||
.up_to(self.wal_store.flush_lsn())
|
||||
self.s.acceptor_state.term_history.up_to(self.flush_lsn)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn get_epoch(&self) -> Term {
|
||||
self.s.acceptor_state.get_epoch(self.wal_store.flush_lsn())
|
||||
self.s.acceptor_state.get_epoch(self.flush_lsn)
|
||||
}
|
||||
|
||||
/// Process message from proposer and possibly form reply. Concurrent
|
||||
@@ -528,20 +572,20 @@ where
|
||||
}
|
||||
|
||||
// set basic info about server, if not yet
|
||||
// TODO: verify that it doesn't change afterwards
|
||||
self.s.server.system_id = msg.system_id;
|
||||
self.s.server.tenant_id = msg.tenant_id;
|
||||
self.s.server.timeline_id = msg.ztli;
|
||||
self.s.server.wal_seg_size = msg.wal_seg_size;
|
||||
self.control_store
|
||||
self.storage
|
||||
.persist(&self.s)
|
||||
.context("failed to persist shared state")?;
|
||||
|
||||
// pass wal_seg_size to read WAL and find flush_lsn
|
||||
self.wal_store.init_storage(&self.s)?;
|
||||
|
||||
// update tenant_id/timeline_id in metrics
|
||||
self.metrics = SafeKeeperMetrics::new(msg.tenant_id, msg.ztli, self.commit_lsn);
|
||||
self.metrics = SafeKeeperMetricsBuilder {
|
||||
ztli: self.s.server.timeline_id,
|
||||
flush_lsn: self.flush_lsn,
|
||||
commit_lsn: self.commit_lsn,
|
||||
}
|
||||
.build();
|
||||
|
||||
info!(
|
||||
"processed greeting from proposer {:?}, sending term {:?}",
|
||||
@@ -561,14 +605,14 @@ where
|
||||
let mut resp = VoteResponse {
|
||||
term: self.s.acceptor_state.term,
|
||||
vote_given: false as u64,
|
||||
flush_lsn: self.wal_store.flush_lsn(),
|
||||
flush_lsn: self.flush_lsn,
|
||||
truncate_lsn: self.s.truncate_lsn,
|
||||
term_history: self.get_term_history(),
|
||||
};
|
||||
if self.s.acceptor_state.term < msg.term {
|
||||
self.s.acceptor_state.term = msg.term;
|
||||
// persist vote before sending it out
|
||||
self.control_store.persist(&self.s)?;
|
||||
self.storage.persist(&self.s)?;
|
||||
resp.term = self.s.acceptor_state.term;
|
||||
resp.vote_given = true as u64;
|
||||
}
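The "persist vote before sending it out" comment is the acceptor-side safety rule of the consensus protocol: a promise must survive a crash before the proposer can observe it. A toy model of just that ordering:

class Acceptor:
    def __init__(self, disk: dict):
        self.disk = disk                      # stands in for the control file

    def handle_vote_request(self, term: int) -> bool:
        if self.disk.get("term", 0) < term:
            self.disk["term"] = term          # durably record the promise first...
            return True                       # ...only then grant the vote
        return False

a = Acceptor({})
assert a.handle_vote_request(1) is True       # first vote for term 1 granted
assert a.handle_vote_request(1) is False      # same term again: refused
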
|
||||
@@ -580,7 +624,7 @@ where
|
||||
fn bump_if_higher(&mut self, term: Term) -> Result<()> {
|
||||
if self.s.acceptor_state.term < term {
|
||||
self.s.acceptor_state.term = term;
|
||||
self.control_store.persist(&self.s)?;
|
||||
self.storage.persist(&self.s)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -589,7 +633,7 @@ where
|
||||
fn append_response(&self) -> AppendResponse {
|
||||
AppendResponse {
|
||||
term: self.s.acceptor_state.term,
|
||||
flush_lsn: self.wal_store.flush_lsn(),
|
||||
flush_lsn: self.flush_lsn,
|
||||
commit_lsn: self.s.commit_lsn,
|
||||
// will be filled by the upper code to avoid bothering safekeeper
|
||||
hs_feedback: HotStandbyFeedback::empty(),
|
||||
@@ -605,12 +649,22 @@ where
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// truncate wal, update the lsns
|
||||
self.wal_store.truncate_wal(msg.start_streaming_at)?;
|
||||
// TODO: cross check divergence point
|
||||
|
||||
// streaming must not create a hole
|
||||
assert!(self.flush_lsn == Lsn(0) || self.flush_lsn >= msg.start_streaming_at);
|
||||
|
||||
// truncate obsolete part of WAL
|
||||
if self.flush_lsn != Lsn(0) {
|
||||
self.storage
|
||||
.truncate_wal(&self.s.server, msg.start_streaming_at)?;
|
||||
}
|
||||
// update our end of WAL pointer
|
||||
self.flush_lsn = msg.start_streaming_at;
|
||||
self.metrics.flush_lsn.set(u64::from(self.flush_lsn) as f64);
|
||||
// and now adopt term history from proposer
|
||||
self.s.acceptor_state.term_history = msg.term_history.clone();
|
||||
self.control_store.persist(&self.s)?;
|
||||
self.storage.persist(&self.s)?;
|
||||
|
||||
info!("start receiving WAL since {:?}", msg.start_streaming_at);
|
||||
|
||||
@@ -636,14 +690,42 @@ where
|
||||
// After ProposerElected, which performs truncation, we should only get
// actual append requests (but flush_lsn is advanced only on record
// boundaries, so it might be less).
|
||||
assert!(self.wal_store.flush_lsn() <= msg.h.begin_lsn);
|
||||
assert!(self.flush_lsn <= msg.h.begin_lsn);
|
||||
|
||||
self.s.proposer_uuid = msg.h.proposer_uuid;
|
||||
let mut sync_control_file = false;
|
||||
|
||||
// do the job
|
||||
let mut last_rec_lsn = Lsn(0);
|
||||
if !msg.wal_data.is_empty() {
|
||||
self.wal_store.write_wal(msg.h.begin_lsn, &msg.wal_data)?;
|
||||
self.metrics
|
||||
.write_wal_bytes
|
||||
.observe(msg.wal_data.len() as f64);
|
||||
{
|
||||
let _timer = self.metrics.write_wal_seconds.start_timer();
|
||||
self.storage
|
||||
.write_wal(&self.s.server, msg.h.begin_lsn, &msg.wal_data)?;
|
||||
}
|
||||
|
||||
// figure out last record's end lsn for reporting (if we got the
|
||||
// whole record)
|
||||
if self.decoder.available() != msg.h.begin_lsn {
|
||||
info!(
|
||||
"restart decoder from {} to {}",
|
||||
self.decoder.available(),
|
||||
msg.h.begin_lsn,
|
||||
);
|
||||
self.decoder = WalStreamDecoder::new(msg.h.begin_lsn);
|
||||
}
|
||||
self.decoder.feed_bytes(&msg.wal_data);
|
||||
loop {
|
||||
match self.decoder.poll_decode()? {
|
||||
None => break, // no full record yet
|
||||
Some((lsn, _rec)) => {
|
||||
last_rec_lsn = lsn;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If this was the first record we ever received, remember the LSN to help
|
||||
// find_end_of_wal skip the hole in the beginning.
|
||||
@@ -653,11 +735,16 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
if last_rec_lsn > self.flush_lsn {
|
||||
self.flush_lsn = last_rec_lsn;
|
||||
self.metrics.flush_lsn.set(u64::from(self.flush_lsn) as f64);
|
||||
}
|
||||
|
||||
// Advance commit_lsn taking into account what we have locally.
|
||||
// commit_lsn can be 0, being unknown to a new walproposer while it hasn't
|
||||
// collected majority of its epoch acks yet, ignore it in this case.
|
||||
if msg.h.commit_lsn != Lsn(0) {
|
||||
let commit_lsn = min(msg.h.commit_lsn, self.wal_store.flush_lsn());
|
||||
let commit_lsn = min(msg.h.commit_lsn, self.flush_lsn);
|
||||
// If new commit_lsn reached epoch switch, force sync of control
|
||||
// file: walproposer in sync mode is very interested when this
|
||||
// happens. Note: this is for sync-safekeepers mode only, as
|
||||
@@ -683,7 +770,7 @@ where
|
||||
}
|
||||
|
||||
if sync_control_file {
|
||||
self.control_store.persist(&self.s)?;
|
||||
self.storage.persist(&self.s)?;
|
||||
}
|
||||
|
||||
let resp = self.append_response();
|
||||
@@ -702,52 +789,34 @@ where
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::wal_storage::Storage;
|
||||
|
||||
// fake storage for tests
|
||||
struct InMemoryState {
|
||||
struct InMemoryStorage {
|
||||
persisted_state: SafeKeeperState,
|
||||
}
|
||||
|
||||
impl control_file::Storage for InMemoryState {
|
||||
impl Storage for InMemoryStorage {
|
||||
fn persist(&mut self, s: &SafeKeeperState) -> Result<()> {
|
||||
self.persisted_state = s.clone();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
struct DummyWalStore {
|
||||
lsn: Lsn,
|
||||
}
|
||||
|
||||
impl wal_storage::Storage for DummyWalStore {
|
||||
fn flush_lsn(&self) -> Lsn {
|
||||
self.lsn
|
||||
}
|
||||
|
||||
fn init_storage(&mut self, _state: &SafeKeeperState) -> Result<()> {
|
||||
fn write_wal(&mut self, _server: &ServerInfo, _startpos: Lsn, _buf: &[u8]) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn write_wal(&mut self, startpos: Lsn, buf: &[u8]) -> Result<()> {
|
||||
self.lsn = startpos + buf.len() as u64;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn truncate_wal(&mut self, end_pos: Lsn) -> Result<()> {
|
||||
self.lsn = end_pos;
|
||||
fn truncate_wal(&mut self, _server: &ServerInfo, _end_pos: Lsn) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_voting() {
|
||||
let storage = InMemoryState {
|
||||
let storage = InMemoryStorage {
|
||||
persisted_state: SafeKeeperState::new(),
|
||||
};
|
||||
let wal_store = DummyWalStore { lsn: Lsn(0) };
|
||||
let ztli = ZTimelineId::from([0u8; 16]);
|
||||
let mut sk = SafeKeeper::new(ztli, storage, wal_store, SafeKeeperState::new());
|
||||
let mut sk = SafeKeeper::new(ztli, Lsn(0), storage, SafeKeeperState::new());
|
||||
|
||||
// check voting for 1 is ok
|
||||
let vote_request = ProposerAcceptorMessage::VoteRequest(VoteRequest { term: 1 });
|
||||
@@ -758,11 +827,11 @@ mod tests {
|
||||
}
|
||||
|
||||
// reboot...
|
||||
let state = sk.control_store.persisted_state.clone();
|
||||
let storage = InMemoryState {
|
||||
let state = sk.storage.persisted_state.clone();
|
||||
let storage = InMemoryStorage {
|
||||
persisted_state: state.clone(),
|
||||
};
|
||||
sk = SafeKeeper::new(ztli, storage, sk.wal_store, state);
|
||||
sk = SafeKeeper::new(ztli, Lsn(0), storage, state);
|
||||
|
||||
// and ensure voting second time for 1 is not ok
|
||||
vote_resp = sk.process_msg(&vote_request);
|
||||
@@ -774,12 +843,11 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_epoch_switch() {
|
||||
let storage = InMemoryState {
|
||||
let storage = InMemoryStorage {
|
||||
persisted_state: SafeKeeperState::new(),
|
||||
};
|
||||
let wal_store = DummyWalStore { lsn: Lsn(0) };
|
||||
let ztli = ZTimelineId::from([0u8; 16]);
|
||||
let mut sk = SafeKeeper::new(ztli, storage, wal_store, SafeKeeperState::new());
|
||||
let mut sk = SafeKeeper::new(ztli, Lsn(0), storage, SafeKeeperState::new());
|
||||
|
||||
let mut ar_hdr = AppendRequestHeader {
|
||||
term: 1,
|
||||
@@ -820,7 +888,7 @@ mod tests {
|
||||
};
|
||||
let resp = sk.process_msg(&ProposerAcceptorMessage::AppendRequest(append_request));
|
||||
assert!(resp.is_ok());
|
||||
sk.wal_store.truncate_wal(Lsn(3)).unwrap(); // imitate the complete record at 3 %)
|
||||
sk.flush_lsn = Lsn(3); // imitate the complete record at 3 %)
|
||||
assert_eq!(sk.get_epoch(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,16 +3,20 @@
|
||||
|
||||
use crate::handler::SafekeeperPostgresHandler;
|
||||
use crate::timeline::{ReplicaState, Timeline, TimelineTools};
|
||||
use crate::wal_storage::WalReader;
|
||||
use anyhow::{bail, Context, Result};
|
||||
|
||||
use postgres_ffi::xlog_utils::{get_current_timestamp, TimestampTz, MAX_SEND_SIZE};
|
||||
use postgres_ffi::xlog_utils::{
|
||||
get_current_timestamp, TimestampTz, XLogFileName, MAX_SEND_SIZE, PG_TLI,
|
||||
};
|
||||
|
||||
use crate::callmemaybe::{CallmeEvent, SubscriptionStateKey};
|
||||
use bytes::Bytes;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::cmp::min;
|
||||
use std::fs::File;
|
||||
use std::io::{Read, Seek, SeekFrom};
|
||||
use std::net::Shutdown;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::thread::sleep;
|
||||
use std::time::Duration;
|
||||
@@ -190,6 +194,24 @@ impl ReplicationConn {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Helper function for opening a wal file.
|
||||
fn open_wal_file(wal_file_path: &Path) -> Result<File> {
|
||||
// First try to open the .partial file.
|
||||
let mut partial_path = wal_file_path.to_owned();
|
||||
partial_path.set_extension("partial");
|
||||
if let Ok(opened_file) = File::open(&partial_path) {
|
||||
return Ok(opened_file);
|
||||
}
|
||||
|
||||
// If that failed, try it without the .partial extension.
|
||||
File::open(&wal_file_path)
|
||||
.with_context(|| format!("Failed to open WAL file {:?}", wal_file_path))
|
||||
.map_err(|e| {
|
||||
error!("{}", e);
|
||||
e
|
||||
})
|
||||
}
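The same open-with-fallback in Python, assuming the safekeeper's convention that an in-progress segment carries a .partial suffix:

from pathlib import Path

def open_wal_file(path: Path):
    partial = Path(str(path) + ".partial")    # in-progress segment, if any
    try:
        return open(partial, "rb")
    except FileNotFoundError:
        return open(path, "rb")               # fall back to the finished segment
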
|
||||
|
||||
///
|
||||
/// Handle START_REPLICATION replication command
|
||||
///
|
||||
@@ -289,15 +311,7 @@ impl ReplicationConn {
|
||||
pgb.write_message(&BeMessage::CopyBothResponse)?;
|
||||
|
||||
let mut end_pos = Lsn(0);
|
||||
|
||||
let mut wal_reader = WalReader::new(
|
||||
spg.conf.timeline_dir(&spg.timeline.get().zttid),
|
||||
wal_seg_size,
|
||||
start_pos,
|
||||
);
|
||||
|
||||
// buffer for wal sending, limited by MAX_SEND_SIZE
|
||||
let mut send_buf = vec![0u8; MAX_SEND_SIZE];
|
||||
let mut wal_file: Option<File> = None;
|
||||
|
||||
loop {
|
||||
if let Some(stop_pos) = stop_pos {
|
||||
@@ -331,26 +345,53 @@ impl ReplicationConn {
|
||||
}
|
||||
}
|
||||
|
||||
// Take the `File` from `wal_file`, or open a new file.
|
||||
let mut file = match wal_file.take() {
|
||||
Some(file) => file,
|
||||
None => {
|
||||
// Open a new file.
|
||||
let segno = start_pos.segment_number(wal_seg_size);
|
||||
let wal_file_name = XLogFileName(PG_TLI, segno, wal_seg_size);
|
||||
let wal_file_path = spg
|
||||
.conf
|
||||
.timeline_dir(&spg.timeline.get().zttid)
|
||||
.join(wal_file_name);
|
||||
Self::open_wal_file(&wal_file_path)?
|
||||
}
|
||||
};
|
||||
|
||||
let xlogoff = start_pos.segment_offset(wal_seg_size) as usize;
|
||||
|
||||
// How much to read and send in message? We cannot cross the WAL file
|
||||
// boundary, and we don't want to send more than MAX_SEND_SIZE.
|
||||
let send_size = end_pos.checked_sub(start_pos).unwrap().0 as usize;
|
||||
let send_size = min(send_size, send_buf.len());
|
||||
let send_size = min(send_size, wal_seg_size - xlogoff);
|
||||
let send_size = min(send_size, MAX_SEND_SIZE);
|
||||
|
||||
let send_buf = &mut send_buf[..send_size];
|
||||
|
||||
// read wal into buffer
|
||||
let send_size = wal_reader.read(send_buf)?;
|
||||
let send_buf = &send_buf[..send_size];
|
||||
// Read some data from the file.
|
||||
let mut file_buf = vec![0u8; send_size];
|
||||
file.seek(SeekFrom::Start(xlogoff as u64))
|
||||
.and_then(|_| file.read_exact(&mut file_buf))
|
||||
.context("Failed to read data from WAL file")?;
|
||||
|
||||
// Write some data to the network socket.
|
||||
pgb.write_message(&BeMessage::XLogData(XLogDataBody {
|
||||
wal_start: start_pos.0,
|
||||
wal_end: end_pos.0,
|
||||
timestamp: get_current_timestamp(),
|
||||
data: send_buf,
|
||||
data: &file_buf,
|
||||
}))
|
||||
.context("Failed to send XLogData")?;
|
||||
|
||||
start_pos += send_size as u64;
|
||||
|
||||
trace!("sent WAL up to {}", start_pos);
|
||||
|
||||
// Decide whether to reuse this file. If we don't set wal_file here
|
||||
// a new file will be opened next time.
|
||||
if start_pos.segment_offset(wal_seg_size) != 0 {
|
||||
wal_file = Some(file);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
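The send-size clamping above reduces to one min() over three limits: bytes available, bytes left in the current segment, and MAX_SEND_SIZE. A compact restatement with made-up sizes:

def clamp_send_size(start_pos: int, end_pos: int,
                    wal_seg_size: int, max_send: int) -> int:
    xlogoff = start_pos % wal_seg_size          # offset within the segment
    return min(end_pos - start_pos,             # what we actually have
               wal_seg_size - xlogoff,          # don't cross the file boundary
               max_send)                        # don't exceed the send buffer

# 100 bytes short of a 16 MiB segment boundary: only 100 bytes go out.
assert clamp_send_size(16 * 2**20 - 100, 16 * 2**20 + 100,
                       16 * 2**20, 128 * 1024) == 100
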
|
||||
|
||||
@@ -1,35 +1,42 @@
//! This module contains timeline id -> safekeeper state map with file-backed
//! persistence and support for interaction between sending and receiving wal.

use anyhow::{Context, Result};

use anyhow::{bail, ensure, Context, Result};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use lazy_static::lazy_static;

use postgres_ffi::xlog_utils::{find_end_of_wal, XLogSegNo, PG_TLI};
use std::cmp::{max, min};
use std::collections::HashMap;
use std::fs::{self};

use std::fs::{self, File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Condvar, Mutex};
use std::time::Duration;
use tokio::sync::mpsc::UnboundedSender;
use tracing::*;

use zenith_metrics::{register_histogram_vec, Histogram, HistogramVec, DISK_WRITE_SECONDS_BUCKETS};
use zenith_utils::bin_ser::LeSer;
use zenith_utils::lsn::Lsn;
use zenith_utils::zid::ZTenantTimelineId;

use crate::callmemaybe::{CallmeEvent, SubscriptionStateKey};
use crate::control_file::{self, CreateControlFile};

use crate::safekeeper::{
    AcceptorProposerMessage, ProposerAcceptorMessage, SafeKeeper, SafeKeeperState,
    AcceptorProposerMessage, ProposerAcceptorMessage, SafeKeeper, SafeKeeperState, ServerInfo,
    Storage, SK_FORMAT_VERSION, SK_MAGIC,
};
use crate::send_wal::HotStandbyFeedback;
use crate::wal_storage::{self, Storage};
use crate::upgrade::upgrade_control_file;
use crate::SafeKeeperConf;

use postgres_ffi::xlog_utils::{XLogFileName, XLOG_BLCKSZ};
use std::convert::TryInto;
use zenith_utils::pq_proto::ZenithFeedback;

// contains persistent metadata for safekeeper
const CONTROL_FILE_NAME: &str = "safekeeper.control";
// needed to atomically update the state using `rename`
const CONTROL_FILE_NAME_PARTIAL: &str = "safekeeper.control.partial";
const POLL_STATE_TIMEOUT: Duration = Duration::from_secs(1);
pub const CHECKSUM_SIZE: usize = std::mem::size_of::<u32>();

/// Replica status update + hot standby feedback
#[derive(Debug, Clone, Copy)]
@@ -68,7 +75,7 @@ impl ReplicaState {
/// Shared state associated with database instance
struct SharedState {
    /// Safekeeper object
    sk: SafeKeeper<control_file::FileStorage, wal_storage::PhysicalStorage>,
    sk: SafeKeeper<FileStorage>,
    /// For receiving-sending wal cooperation
    /// quorum commit LSN we've notified walsenders about
    notified_commit_lsn: Lsn,
@@ -86,6 +93,23 @@ struct SharedState {
    pageserver_connstr: Option<String>,
}

// A named boolean.
#[derive(Debug)]
pub enum CreateControlFile {
    True,
    False,
}
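// An enum rather than a bare `bool` keeps call sites self-describing:
// `load_control_file(path, CreateControlFile::True)` reads clearer than `..., true)`.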

lazy_static! {
    static ref PERSIST_CONTROL_FILE_SECONDS: HistogramVec = register_histogram_vec!(
        "safekeeper_persist_control_file_seconds",
        "Seconds to persist and sync control file, grouped by timeline",
        &["timeline_id"],
        DISK_WRITE_SECONDS_BUCKETS.to_vec()
    )
    .expect("Failed to register safekeeper_persist_control_file_seconds histogram vec");
}

impl SharedState {
    /// Restore SharedState from control file.
    /// If create=false and file doesn't exist, bails out.
@@ -94,18 +118,32 @@ impl SharedState {
        zttid: &ZTenantTimelineId,
        create: CreateControlFile,
    ) -> Result<Self> {
        let state = control_file::FileStorage::load_control_file_conf(conf, zttid, create)
        let state = FileStorage::load_control_file_conf(conf, zttid, create)
            .context("failed to load from control file")?;

        let control_store = control_file::FileStorage::new(zttid, conf);

        let wal_store = wal_storage::PhysicalStorage::new(zttid, conf);

        info!("timeline {} created or restored", zttid.timeline_id);
        let file_storage = FileStorage::new(zttid, conf);
        let flush_lsn = if state.server.wal_seg_size != 0 {
            let wal_dir = conf.timeline_dir(zttid);
            Lsn(find_end_of_wal(
                &wal_dir,
                state.server.wal_seg_size as usize,
                true,
                state.wal_start_lsn,
            )?
            .0)
        } else {
            Lsn(0)
        };
        info!(
            "timeline {} created or restored: flush_lsn={}, commit_lsn={}, truncate_lsn={}",
            zttid.timeline_id, flush_lsn, state.commit_lsn, state.truncate_lsn,
        );
        if flush_lsn < state.commit_lsn || flush_lsn < state.truncate_lsn {
            warn!("timeline {} potential data loss: flush_lsn by find_end_of_wal is less than either commit_lsn or truncate_lsn from control file", zttid.timeline_id);
        }

        Ok(Self {
            notified_commit_lsn: Lsn(0),
            sk: SafeKeeper::new(zttid.timeline_id, control_store, wal_store, state),
            sk: SafeKeeper::new(zttid.timeline_id, flush_lsn, file_storage, state),
            replicas: Vec::new(),
            active: false,
            num_computes: 0,
@@ -412,7 +450,7 @@ impl Timeline {

    pub fn get_end_of_wal(&self) -> Lsn {
        let shared_state = self.mutex.lock().unwrap();
        shared_state.sk.wal_store.flush_lsn()
        shared_state.sk.flush_lsn
    }
}

@@ -488,3 +526,397 @@ impl GlobalTimelines {
        }
    }
}

#[derive(Debug)]
pub struct FileStorage {
    // save timeline dir to avoid reconstructing it every time
    timeline_dir: PathBuf,
    conf: SafeKeeperConf,
    persist_control_file_seconds: Histogram,
}

impl FileStorage {
    fn new(zttid: &ZTenantTimelineId, conf: &SafeKeeperConf) -> FileStorage {
        let timeline_dir = conf.timeline_dir(zttid);
        let timelineid_str = format!("{}", zttid);
        FileStorage {
            timeline_dir,
            conf: conf.clone(),
            persist_control_file_seconds: PERSIST_CONTROL_FILE_SECONDS
                .with_label_values(&[&timelineid_str]),
        }
    }

    // Check the magic/version in the on-disk data and deserialize it, if possible.
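    // On-disk layout, as written by `persist` below (all integers little-endian):
    //   magic u32 | format version u32 | serialized SafeKeeperState | crc32c u32
    // By the time we get here the caller has already verified and stripped the
    // trailing checksum.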
    fn deser_sk_state(buf: &mut &[u8]) -> Result<SafeKeeperState> {
        // Read the version independent part
        let magic = buf.read_u32::<LittleEndian>()?;
        if magic != SK_MAGIC {
            bail!(
                "bad control file magic: {:X}, expected {:X}",
                magic,
                SK_MAGIC
            );
        }
        let version = buf.read_u32::<LittleEndian>()?;
        if version == SK_FORMAT_VERSION {
            let res = SafeKeeperState::des(buf)?;
            return Ok(res);
        }
        // try to upgrade
        upgrade_control_file(buf, version)
    }

    // Load control file for given zttid at path specified by conf.
    fn load_control_file_conf(
        conf: &SafeKeeperConf,
        zttid: &ZTenantTimelineId,
        create: CreateControlFile,
    ) -> Result<SafeKeeperState> {
        let path = conf.timeline_dir(zttid).join(CONTROL_FILE_NAME);
        Self::load_control_file(path, create)
    }

    /// Read in the control file.
    /// If create=false and file doesn't exist, bails out.
    pub fn load_control_file<P: AsRef<Path>>(
        control_file_path: P,
        create: CreateControlFile,
    ) -> Result<SafeKeeperState> {
        info!(
            "loading control file {}, create={:?}",
            control_file_path.as_ref().display(),
            create,
        );

        let mut control_file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(matches!(create, CreateControlFile::True))
            .open(&control_file_path)
            .with_context(|| {
                format!(
                    "failed to open control file at {}",
                    control_file_path.as_ref().display(),
                )
            })?;

        // Empty file is legit on 'create', don't try to deser from it.
        let state = if control_file.metadata().unwrap().len() == 0 {
            if let CreateControlFile::False = create {
                bail!("control file is empty");
            }
            SafeKeeperState::new()
        } else {
            let mut buf = Vec::new();
            control_file
                .read_to_end(&mut buf)
                .context("failed to read control file")?;

            let calculated_checksum = crc32c::crc32c(&buf[..buf.len() - CHECKSUM_SIZE]);

            let expected_checksum_bytes: &[u8; CHECKSUM_SIZE] =
                buf[buf.len() - CHECKSUM_SIZE..].try_into()?;
            let expected_checksum = u32::from_le_bytes(*expected_checksum_bytes);

            ensure!(
                calculated_checksum == expected_checksum,
                format!(
                    "safekeeper control file checksum mismatch: expected {} got {}",
                    expected_checksum, calculated_checksum
                )
            );

            FileStorage::deser_sk_state(&mut &buf[..buf.len() - CHECKSUM_SIZE]).with_context(
                || {
                    format!(
                        "while reading control file {}",
                        control_file_path.as_ref().display(),
                    )
                },
            )?
        };
        Ok(state)
    }

    /// Helper returning full path to WAL segment file and its .partial brother.
    fn wal_file_paths(&self, segno: XLogSegNo, wal_seg_size: usize) -> (PathBuf, PathBuf) {
        let wal_file_name = XLogFileName(PG_TLI, segno, wal_seg_size);
        let wal_file_path = self.timeline_dir.join(wal_file_name.clone());
        let wal_file_partial_path = self.timeline_dir.join(wal_file_name + ".partial");
        (wal_file_path, wal_file_partial_path)
    }
}

impl Storage for FileStorage {
    // persists state durably to underlying storage
    // for description see https://lwn.net/Articles/457667/
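    // In outline, this mirrors postgres' durable_rename():
    //   1. write the new state to safekeeper.control.partial and fsync it
    //   2. rename() the partial file over safekeeper.control
    //   3. fsync the renamed file and its directory, so the rename itself is durable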
    fn persist(&mut self, s: &SafeKeeperState) -> Result<()> {
        let _timer = &self.persist_control_file_seconds.start_timer();

        // write data to safekeeper.control.partial
        let control_partial_path = self.timeline_dir.join(CONTROL_FILE_NAME_PARTIAL);
        let mut control_partial = File::create(&control_partial_path).with_context(|| {
            format!(
                "failed to create partial control file at: {}",
                &control_partial_path.display()
            )
        })?;
        let mut buf: Vec<u8> = Vec::new();
        buf.write_u32::<LittleEndian>(SK_MAGIC)?;
        buf.write_u32::<LittleEndian>(SK_FORMAT_VERSION)?;
        s.ser_into(&mut buf)?;

        // calculate the checksum before appending it (the append resizes the buffer)
        let checksum = crc32c::crc32c(&buf);
        buf.extend_from_slice(&checksum.to_le_bytes());

        control_partial.write_all(&buf).with_context(|| {
            format!(
                "failed to write safekeeper state into control file at: {}",
                control_partial_path.display()
            )
        })?;

        // fsync the file
        control_partial.sync_all().with_context(|| {
            format!(
                "failed to sync partial control file at {}",
                control_partial_path.display()
            )
        })?;

        let control_path = self.timeline_dir.join(CONTROL_FILE_NAME);

        // rename should be atomic
        fs::rename(&control_partial_path, &control_path)?;
        // this sync is not required by any standard but postgres does this (see durable_rename)
        File::open(&control_path)
            .and_then(|f| f.sync_all())
            .with_context(|| {
                format!(
                    "failed to sync control file at: {}",
                    &control_path.display()
                )
            })?;

        // fsync the directory (linux specific)
        File::open(&self.timeline_dir)
            .and_then(|f| f.sync_all())
            .context("failed to sync control file directory")?;
        Ok(())
    }

    fn write_wal(&mut self, server: &ServerInfo, startpos: Lsn, buf: &[u8]) -> Result<()> {
        let mut bytes_left: usize = buf.len();
        let mut bytes_written: usize = 0;
        let mut partial;
        let mut start_pos = startpos;
        const ZERO_BLOCK: &[u8] = &[0u8; XLOG_BLCKSZ];
        let wal_seg_size = server.wal_seg_size as usize;

        /* Extract WAL location for this block */
        let mut xlogoff = start_pos.segment_offset(wal_seg_size) as usize;

        while bytes_left != 0 {
            let bytes_to_write;

            /*
             * If crossing a WAL boundary, only write up until we reach wal
             * segment size.
             */
            if xlogoff + bytes_left > wal_seg_size {
                bytes_to_write = wal_seg_size - xlogoff;
            } else {
                bytes_to_write = bytes_left;
            }

            /* Open file */
            let segno = start_pos.segment_number(wal_seg_size);
            let (wal_file_path, wal_file_partial_path) = self.wal_file_paths(segno, wal_seg_size);
            {
                let mut wal_file: File;
                /* Try to open already completed segment */
                if let Ok(file) = OpenOptions::new().write(true).open(&wal_file_path) {
                    wal_file = file;
                    partial = false;
                } else if let Ok(file) = OpenOptions::new().write(true).open(&wal_file_partial_path)
                {
                    /* Try to open an existing partial file */
                    wal_file = file;
                    partial = true;
                } else {
                    /* Create and fill new partial file */
                    partial = true;
                    match OpenOptions::new()
                        .create(true)
                        .write(true)
                        .open(&wal_file_partial_path)
                    {
                        Ok(mut file) => {
                            for _ in 0..(wal_seg_size / XLOG_BLCKSZ) {
                                file.write_all(ZERO_BLOCK)?;
                            }
                            wal_file = file;
                        }
                        Err(e) => {
                            error!("Failed to open log file {:?}: {}", &wal_file_partial_path, e);
                            return Err(e.into());
                        }
                    }
                }
                wal_file.seek(SeekFrom::Start(xlogoff as u64))?;
                wal_file.write_all(&buf[bytes_written..(bytes_written + bytes_to_write)])?;

                // Flush the file, unless syncing is disabled
                if !self.conf.no_sync {
                    wal_file.sync_all()?;
                }
            }
            /* Write was successful, advance our position */
            bytes_written += bytes_to_write;
            bytes_left -= bytes_to_write;
            start_pos += bytes_to_write as u64;
            xlogoff += bytes_to_write;

            /* Did we reach the end of a WAL segment? */
            if start_pos.segment_offset(wal_seg_size) == 0 {
                xlogoff = 0;
                if partial {
                    fs::rename(&wal_file_partial_path, &wal_file_path)?;
                }
            }
        }
        Ok(())
    }

    fn truncate_wal(&mut self, server: &ServerInfo, end_pos: Lsn) -> Result<()> {
        let partial;
        const ZERO_BLOCK: &[u8] = &[0u8; XLOG_BLCKSZ];
        let wal_seg_size = server.wal_seg_size as usize;

        /* Extract WAL location for this block */
        let mut xlogoff = end_pos.segment_offset(wal_seg_size) as usize;

        /* Open file */
        let mut segno = end_pos.segment_number(wal_seg_size);
        let (wal_file_path, wal_file_partial_path) = self.wal_file_paths(segno, wal_seg_size);
        {
            let mut wal_file: File;
            /* Try to open already completed segment */
            if let Ok(file) = OpenOptions::new().write(true).open(&wal_file_path) {
                wal_file = file;
                partial = false;
            } else {
                wal_file = OpenOptions::new()
                    .write(true)
                    .open(&wal_file_partial_path)?;
                partial = true;
            }
            wal_file.seek(SeekFrom::Start(xlogoff as u64))?;
            while xlogoff < wal_seg_size {
                let bytes_to_write = min(XLOG_BLCKSZ, wal_seg_size - xlogoff);
                wal_file.write_all(&ZERO_BLOCK[0..bytes_to_write])?;
                xlogoff += bytes_to_write;
            }
            // Flush the file, unless syncing is disabled
            if !self.conf.no_sync {
                wal_file.sync_all()?;
            }
        }
        if !partial {
            // Make segment partial once again
            fs::rename(&wal_file_path, &wal_file_partial_path)?;
        }
        // Remove all subsequent segments
        loop {
            segno += 1;
            let (wal_file_path, wal_file_partial_path) = self.wal_file_paths(segno, wal_seg_size);
            // TODO: better use fs::try_exists, which is currently available only in nightly builds
            if wal_file_path.exists() {
                fs::remove_file(&wal_file_path)?;
            } else if wal_file_partial_path.exists() {
                fs::remove_file(&wal_file_partial_path)?;
            } else {
                break;
            }
        }
        Ok(())
    }
}

#[cfg(test)]
mod test {
    use super::FileStorage;
    use crate::{
        safekeeper::{SafeKeeperState, Storage},
        timeline::{CreateControlFile, CONTROL_FILE_NAME},
        SafeKeeperConf, ZTenantTimelineId,
    };
    use anyhow::Result;
    use std::fs;
    use zenith_utils::lsn::Lsn;

    fn stub_conf() -> SafeKeeperConf {
        let workdir = tempfile::tempdir().unwrap().into_path();
        SafeKeeperConf {
            workdir,
            ..Default::default()
        }
    }

    fn load_from_control_file(
        conf: &SafeKeeperConf,
        zttid: &ZTenantTimelineId,
        create: CreateControlFile,
    ) -> Result<(FileStorage, SafeKeeperState)> {
        fs::create_dir_all(&conf.timeline_dir(zttid)).expect("failed to create timeline dir");
        Ok((
            FileStorage::new(zttid, conf),
            FileStorage::load_control_file_conf(conf, zttid, create)?,
        ))
    }

    #[test]
    fn test_read_write_safekeeper_state() {
        let conf = stub_conf();
        let zttid = ZTenantTimelineId::generate();
        {
            let (mut storage, mut state) =
                load_from_control_file(&conf, &zttid, CreateControlFile::True)
                    .expect("failed to read state");
            // change something
            state.wal_start_lsn = Lsn(42);
            storage.persist(&state).expect("failed to persist state");
        }

        let (_, state) = load_from_control_file(&conf, &zttid, CreateControlFile::False)
            .expect("failed to read state");
        assert_eq!(state.wal_start_lsn, Lsn(42));
    }

    #[test]
    fn test_safekeeper_state_checksum_mismatch() {
        let conf = stub_conf();
        let zttid = ZTenantTimelineId::generate();
        {
            let (mut storage, mut state) =
                load_from_control_file(&conf, &zttid, CreateControlFile::True)
                    .expect("failed to read state");
            // change something
            state.wal_start_lsn = Lsn(42);
            storage.persist(&state).expect("failed to persist state");
        }
        let control_path = conf.timeline_dir(&zttid).join(CONTROL_FILE_NAME);
        let mut data = fs::read(&control_path).unwrap();
        data[0] += 1; // change the first byte of the file to fail checksum validation
        fs::write(&control_path, &data).expect("failed to write control file");

        match load_from_control_file(&conf, &zttid, CreateControlFile::False) {
            Err(err) => assert!(err
                .to_string()
                .contains("safekeeper control file checksum mismatch")),
            Ok(_) => panic!("expected error"),
        }
    }
}

@@ -1,493 +0,0 @@
//! This module has everything to deal with WAL -- reading and writing to disk.
//!
//! Safekeeper WAL is stored in the timeline directory, in a format similar to pg_wal.
//! PG timeline is always 1, so WAL segments usually have names like this:
//! - 000000010000000000000001
//! - 000000010000000000000002.partial
//!
//! Note that the last file has a `.partial` suffix, which is different from postgres.
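//! A worked example, assuming the default 16 MiB segment size: the segment
//! covering LSNs 0/2000000..0/3000000 has segment number 2 on timeline 1, so
//! XLogFileName yields `000000010000000000000002` (with `.partial` appended
//! while the segment is still being filled).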

use anyhow::{anyhow, Context, Result};
use std::io::{Read, Seek, SeekFrom};

use lazy_static::lazy_static;
use postgres_ffi::xlog_utils::{find_end_of_wal, XLogSegNo, PG_TLI};
use std::cmp::min;

use std::fs::{self, File, OpenOptions};
use std::io::Write;
use std::path::{Path, PathBuf};

use tracing::*;

use zenith_utils::lsn::Lsn;
use zenith_utils::zid::ZTenantTimelineId;

use crate::safekeeper::SafeKeeperState;

use crate::SafeKeeperConf;
use postgres_ffi::xlog_utils::{XLogFileName, XLOG_BLCKSZ};

use postgres_ffi::waldecoder::WalStreamDecoder;

use zenith_metrics::{
    register_gauge_vec, register_histogram_vec, Gauge, GaugeVec, Histogram, HistogramVec,
    DISK_WRITE_SECONDS_BUCKETS,
};

lazy_static! {
    // The prometheus crate does not support u64 yet, i64 only (see `IntGauge`).
    // i64 is faster than f64, so update to u64 when available.
    static ref FLUSH_LSN_GAUGE: GaugeVec = register_gauge_vec!(
        "safekeeper_flush_lsn",
        "Current flush_lsn, grouped by timeline",
        &["tenant_id", "timeline_id"]
    )
    .expect("Failed to register safekeeper_flush_lsn gauge vec");
    static ref WRITE_WAL_BYTES: HistogramVec = register_histogram_vec!(
        "safekeeper_write_wal_bytes",
        "Bytes written to WAL in a single request, grouped by timeline",
        &["tenant_id", "timeline_id"],
        vec![1.0, 10.0, 100.0, 1024.0, 8192.0, 128.0 * 1024.0, 1024.0 * 1024.0, 10.0 * 1024.0 * 1024.0]
    )
    .expect("Failed to register safekeeper_write_wal_bytes histogram vec");
    static ref WRITE_WAL_SECONDS: HistogramVec = register_histogram_vec!(
        "safekeeper_write_wal_seconds",
        "Seconds spent writing and syncing WAL to a disk in a single request, grouped by timeline",
        &["tenant_id", "timeline_id"],
        DISK_WRITE_SECONDS_BUCKETS.to_vec()
    )
    .expect("Failed to register safekeeper_write_wal_seconds histogram vec");
}

struct WalStorageMetrics {
    flush_lsn: Gauge,
    write_wal_bytes: Histogram,
    write_wal_seconds: Histogram,
}

impl WalStorageMetrics {
    fn new(zttid: &ZTenantTimelineId) -> Self {
        let tenant_id = zttid.tenant_id.to_string();
        let timeline_id = zttid.timeline_id.to_string();
        Self {
            flush_lsn: FLUSH_LSN_GAUGE.with_label_values(&[&tenant_id, &timeline_id]),
            write_wal_bytes: WRITE_WAL_BYTES.with_label_values(&[&tenant_id, &timeline_id]),
            write_wal_seconds: WRITE_WAL_SECONDS.with_label_values(&[&tenant_id, &timeline_id]),
        }
    }
}

pub trait Storage {
    /// LSN of the last durably stored WAL record.
    fn flush_lsn(&self) -> Lsn;

    /// Initialize storage with wal_seg_size and read WAL from disk to find the latest LSN.
    fn init_storage(&mut self, state: &SafeKeeperState) -> Result<()>;

    /// Write a piece of WAL from buf to disk and sync it.
    fn write_wal(&mut self, startpos: Lsn, buf: &[u8]) -> Result<()>;

    // Truncate WAL at specified LSN.
    fn truncate_wal(&mut self, end_pos: Lsn) -> Result<()>;
}
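// Call-order sketch for an implementor (using `PhysicalStorage` below; a
// minimal illustration, not part of this patch):
//   let mut storage = PhysicalStorage::new(&zttid, &conf);
//   storage.init_storage(&state)?;        // learns wal_seg_size, scans WAL on disk
//   storage.write_wal(start_lsn, &buf)?;  // write + fsync
//   assert!(storage.flush_lsn() >= start_lsn);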

pub struct PhysicalStorage {
    metrics: WalStorageMetrics,
    zttid: ZTenantTimelineId,
    timeline_dir: PathBuf,
    conf: SafeKeeperConf,

    // fields below are filled upon initialization

    // None if uninitialized, Some(wal_seg_size) once the storage is initialized
    wal_seg_size: Option<usize>,

    // Relationship of lsns:
    // `write_lsn` >= `write_record_lsn` >= `flush_record_lsn`
    //
    // All lsns are zero if the storage has just been created and there are no segments on disk.

    // Written to disk, but possibly still in the cache and not fully persisted.
    // It can also be ahead of record_lsn, if we happen to be in the middle of a WAL record.
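    // Example: after ingesting bytes that end mid-record, write_lsn points past
    // the written bytes, write_record_lsn still points at the end of the previous
    // complete record, and flush_record_lsn catches up only after the fsync.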
    write_lsn: Lsn,

    // The LSN of the last WAL record written to disk. It may not be fully flushed yet.
    write_record_lsn: Lsn,

    // The LSN of the last WAL record flushed to disk.
    flush_record_lsn: Lsn,

    // Decoder is required for detecting boundaries of WAL records.
    decoder: WalStreamDecoder,
}

impl PhysicalStorage {
    pub fn new(zttid: &ZTenantTimelineId, conf: &SafeKeeperConf) -> PhysicalStorage {
        let timeline_dir = conf.timeline_dir(zttid);
        PhysicalStorage {
            metrics: WalStorageMetrics::new(zttid),
            zttid: *zttid,
            timeline_dir,
            conf: conf.clone(),
            wal_seg_size: None,
            write_lsn: Lsn(0),
            write_record_lsn: Lsn(0),
            flush_record_lsn: Lsn(0),
            decoder: WalStreamDecoder::new(Lsn(0)),
        }
    }

    // wrapper for flush_lsn updates that also updates metrics
    fn update_flush_lsn(&mut self) {
        self.flush_record_lsn = self.write_record_lsn;
        self.metrics.flush_lsn.set(self.flush_record_lsn.0 as f64);
    }

    /// Helper returning full path to WAL segment file and its .partial brother.
    fn wal_file_paths(&self, segno: XLogSegNo) -> Result<(PathBuf, PathBuf)> {
        let wal_seg_size = self
            .wal_seg_size
            .ok_or_else(|| anyhow!("wal_seg_size is not initialized"))?;

        let wal_file_name = XLogFileName(PG_TLI, segno, wal_seg_size);
        let wal_file_path = self.timeline_dir.join(wal_file_name.clone());
        let wal_file_partial_path = self.timeline_dir.join(wal_file_name + ".partial");
        Ok((wal_file_path, wal_file_partial_path))
    }

    // TODO: this function is going to be refactored soon, what will change:
    // - flush will be called separately from write_wal, this function
    //   will only write bytes to disk
    // - File will be cached in PhysicalStorage, to remove extra syscalls,
    //   such as open(), seek(), close()
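    // A possible post-refactoring shape (sketch only; the field and method
    // names below are hypothetical, not part of this patch):
    //   struct PhysicalStorage { /* ... */ curr_file: Option<(XLogSegNo, File)> }
    //   fn write_exact(&mut self, pos: Lsn, buf: &[u8]) -> Result<()>; // write, no fsync
    //   fn flush_wal(&mut self) -> Result<()>;                         // explicit fsync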
    fn write_and_flush(&mut self, startpos: Lsn, buf: &[u8]) -> Result<()> {
        let wal_seg_size = self
            .wal_seg_size
            .ok_or_else(|| anyhow!("wal_seg_size is not initialized"))?;

        let mut bytes_left: usize = buf.len();
        let mut bytes_written: usize = 0;
        let mut partial;
        let mut start_pos = startpos;
        const ZERO_BLOCK: &[u8] = &[0u8; XLOG_BLCKSZ];

        /* Extract WAL location for this block */
        let mut xlogoff = start_pos.segment_offset(wal_seg_size) as usize;

        while bytes_left != 0 {
            let bytes_to_write;

            /*
             * If crossing a WAL boundary, only write up until we reach wal
             * segment size.
             */
            if xlogoff + bytes_left > wal_seg_size {
                bytes_to_write = wal_seg_size - xlogoff;
            } else {
                bytes_to_write = bytes_left;
            }

            /* Open file */
            let segno = start_pos.segment_number(wal_seg_size);
            let (wal_file_path, wal_file_partial_path) = self.wal_file_paths(segno)?;
            {
                let mut wal_file: File;
                /* Try to open already completed segment */
                if let Ok(file) = OpenOptions::new().write(true).open(&wal_file_path) {
                    wal_file = file;
                    partial = false;
                } else if let Ok(file) = OpenOptions::new().write(true).open(&wal_file_partial_path)
                {
                    /* Try to open an existing partial file */
                    wal_file = file;
                    partial = true;
                } else {
                    /* Create and fill new partial file */
                    partial = true;
                    match OpenOptions::new()
                        .create(true)
                        .write(true)
                        .open(&wal_file_partial_path)
                    {
                        Ok(mut file) => {
                            for _ in 0..(wal_seg_size / XLOG_BLCKSZ) {
                                file.write_all(ZERO_BLOCK)?;
                            }
                            wal_file = file;
                        }
                        Err(e) => {
                            error!("Failed to open log file {:?}: {}", &wal_file_partial_path, e);
                            return Err(e.into());
                        }
                    }
                }
                wal_file.seek(SeekFrom::Start(xlogoff as u64))?;
                wal_file.write_all(&buf[bytes_written..(bytes_written + bytes_to_write)])?;

                // Flush the file, unless syncing is disabled
                if !self.conf.no_sync {
                    wal_file.sync_all()?;
                }
            }
            /* Write was successful, advance our position */
            bytes_written += bytes_to_write;
            bytes_left -= bytes_to_write;
            start_pos += bytes_to_write as u64;
            xlogoff += bytes_to_write;

            /* Did we reach the end of a WAL segment? */
            if start_pos.segment_offset(wal_seg_size) == 0 {
                xlogoff = 0;
                if partial {
                    fs::rename(&wal_file_partial_path, &wal_file_path)?;
                }
            }
        }
        Ok(())
    }
}

impl Storage for PhysicalStorage {
    // flush_lsn returns lsn of last durably stored WAL record.
    fn flush_lsn(&self) -> Lsn {
        self.flush_record_lsn
    }

    // Storage needs to know wal_seg_size to know which segment to read/write, but
    // wal_seg_size is not always known at the moment of storage creation. This method
    // allows postponing its initialization.
    fn init_storage(&mut self, state: &SafeKeeperState) -> Result<()> {
        if state.server.wal_seg_size == 0 {
            // wal_seg_size is still unknown
            return Ok(());
        }

        if let Some(wal_seg_size) = self.wal_seg_size {
            // physical storage is already initialized
            assert_eq!(wal_seg_size, state.server.wal_seg_size as usize);
            return Ok(());
        }

        // initialize physical storage
        let wal_seg_size = state.server.wal_seg_size as usize;
        self.wal_seg_size = Some(wal_seg_size);

        // we need to read WAL from disk to know which LSNs are stored on disk
        self.write_lsn =
            Lsn(find_end_of_wal(&self.timeline_dir, wal_seg_size, true, state.wal_start_lsn)?.0);

        self.write_record_lsn = self.write_lsn;

        // TODO: do we really know that write_lsn is fully flushed to disk?
        // If not, maybe it's better to call fsync() here to be sure?
        self.update_flush_lsn();

        info!(
            "initialized storage for timeline {}, flush_lsn={}, commit_lsn={}, truncate_lsn={}",
            self.zttid.timeline_id, self.flush_record_lsn, state.commit_lsn, state.truncate_lsn,
        );
        if self.flush_record_lsn < state.commit_lsn || self.flush_record_lsn < state.truncate_lsn {
            warn!("timeline {} potential data loss: flush_lsn by find_end_of_wal is less than either commit_lsn or truncate_lsn from control file", self.zttid.timeline_id);
        }

        Ok(())
    }

    // Write and flush WAL to disk.
    fn write_wal(&mut self, startpos: Lsn, buf: &[u8]) -> Result<()> {
        if self.write_lsn > startpos {
            warn!(
                "write_wal rewrites WAL written before, write_lsn={}, startpos={}",
                self.write_lsn, startpos
            );
        }
        if self.write_lsn < startpos {
            warn!(
                "write_wal creates gap in written WAL, write_lsn={}, startpos={}",
                self.write_lsn, startpos
            );
            // TODO: return error if write_lsn is not zero
        }

        {
            let _timer = self.metrics.write_wal_seconds.start_timer();
            self.write_and_flush(startpos, buf)?;
        }

        // WAL is written and flushed; update the LSNs
        self.write_lsn = startpos + buf.len() as u64;
        self.metrics.write_wal_bytes.observe(buf.len() as f64);

        // figure out last record's end lsn for reporting (if we got the
        // whole record)
        if self.decoder.available() != startpos {
            info!(
                "restart decoder from {} to {}",
                self.decoder.available(),
                startpos,
            );
            self.decoder = WalStreamDecoder::new(startpos);
        }
        self.decoder.feed_bytes(buf);
        loop {
            match self.decoder.poll_decode()? {
                None => break, // no full record yet
                Some((lsn, _rec)) => {
                    self.write_record_lsn = lsn;
                }
            }
        }

        self.update_flush_lsn();
        Ok(())
    }

    // Truncate written WAL by removing all WAL segments after the given LSN.
    // end_pos must point to the end of the WAL record.
    fn truncate_wal(&mut self, end_pos: Lsn) -> Result<()> {
        let wal_seg_size = self
            .wal_seg_size
            .ok_or_else(|| anyhow!("wal_seg_size is not initialized"))?;

        // TODO: cross check divergence point

        // nothing to truncate
        if self.write_lsn == Lsn(0) {
            return Ok(());
        }

        // Streaming must not create a hole, so truncate cannot be called on a non-written LSN
        assert!(self.write_lsn >= end_pos);

        // open segment files and delete or fill end with zeroes

        let partial;
        const ZERO_BLOCK: &[u8] = &[0u8; XLOG_BLCKSZ];

        /* Extract WAL location for this block */
        let mut xlogoff = end_pos.segment_offset(wal_seg_size) as usize;

        /* Open file */
        let mut segno = end_pos.segment_number(wal_seg_size);
        let (wal_file_path, wal_file_partial_path) = self.wal_file_paths(segno)?;
        {
            let mut wal_file: File;
            /* Try to open already completed segment */
            if let Ok(file) = OpenOptions::new().write(true).open(&wal_file_path) {
                wal_file = file;
                partial = false;
            } else {
                wal_file = OpenOptions::new()
                    .write(true)
                    .open(&wal_file_partial_path)?;
                partial = true;
            }
            wal_file.seek(SeekFrom::Start(xlogoff as u64))?;
            while xlogoff < wal_seg_size {
                let bytes_to_write = min(XLOG_BLCKSZ, wal_seg_size - xlogoff);
                wal_file.write_all(&ZERO_BLOCK[0..bytes_to_write])?;
                xlogoff += bytes_to_write;
            }
            // Flush the file, unless syncing is disabled
            if !self.conf.no_sync {
                wal_file.sync_all()?;
            }
        }
        if !partial {
            // Make segment partial once again
            fs::rename(&wal_file_path, &wal_file_partial_path)?;
        }
        // Remove all subsequent segments
        loop {
            segno += 1;
            let (wal_file_path, wal_file_partial_path) = self.wal_file_paths(segno)?;
            // TODO: better use fs::try_exists, which is currently available only in nightly builds
            if wal_file_path.exists() {
                fs::remove_file(&wal_file_path)?;
            } else if wal_file_partial_path.exists() {
                fs::remove_file(&wal_file_partial_path)?;
            } else {
                break;
            }
        }

        // Update lsns
        self.write_lsn = end_pos;
        self.write_record_lsn = end_pos;
        self.update_flush_lsn();
        Ok(())
    }
}

pub struct WalReader {
    timeline_dir: PathBuf,
    wal_seg_size: usize,
    pos: Lsn,
    file: Option<File>,
}

impl WalReader {
    pub fn new(timeline_dir: PathBuf, wal_seg_size: usize, pos: Lsn) -> Self {
        Self {
            timeline_dir,
            wal_seg_size,
            pos,
            file: None,
        }
    }

    pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        // Take the `File` from `self.file`, or open a new file.
        let mut file = match self.file.take() {
            Some(file) => file,
            None => {
                // Open a new file.
                let segno = self.pos.segment_number(self.wal_seg_size);
                let wal_file_name = XLogFileName(PG_TLI, segno, self.wal_seg_size);
                let wal_file_path = self.timeline_dir.join(wal_file_name);
                Self::open_wal_file(&wal_file_path)?
            }
        };

        let xlogoff = self.pos.segment_offset(self.wal_seg_size) as usize;

        // How much to read and send in message? We cannot cross the WAL file
        // boundary, and we don't want to send more than the provided buffer.
        let send_size = min(buf.len(), self.wal_seg_size - xlogoff);

        // Read some data from the file.
        let buf = &mut buf[0..send_size];
        file.seek(SeekFrom::Start(xlogoff as u64))
            .and_then(|_| file.read_exact(buf))
            .context("Failed to read data from WAL file")?;

        self.pos += send_size as u64;

        // Decide whether to reuse this file. If we don't set self.file here
        // a new file will be opened next time.
        if self.pos.segment_offset(self.wal_seg_size) != 0 {
            self.file = Some(file);
        }

        Ok(send_size)
    }

    /// Helper function for opening a wal file.
    fn open_wal_file(wal_file_path: &Path) -> Result<File> {
        // First try to open the .partial file.
        let mut partial_path = wal_file_path.to_owned();
        partial_path.set_extension("partial");
        if let Ok(opened_file) = File::open(&partial_path) {
            return Ok(opened_file);
        }

        // If that failed, try it without the .partial extension.
        File::open(&wal_file_path)
            .with_context(|| format!("Failed to open WAL file {:?}", wal_file_path))
            .map_err(|e| {
                error!("{}", e);
                e
            })
    }
}
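// Usage sketch (hypothetical caller, not part of this patch): stream WAL
// sequentially from `start_lsn` without worrying about segment boundaries,
// since `read` never crosses a segment:
//   let mut reader = WalReader::new(timeline_dir, wal_seg_size, start_lsn);
//   let n = reader.read(&mut buf)?; // bytes read; advance and repeat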
@@ -5,4 +5,4 @@ pub mod request;

/// Current fast way to apply simple http routing in various Zenith binaries.
/// Re-exported for the sake of a uniform approach that could later be replaced with better alternatives, if needed.
pub use routerify::{ext::RequestExt, RouterBuilder, RouterService};
pub use routerify::{ext::RequestExt, RouterBuilder};

@@ -430,11 +430,11 @@ impl PostgresBackend {
        trace!("got query {:?}", query_string);
        // xxx distinguish fatal and recoverable errors?
        if let Err(e) = handler.process_query(self, query_string) {
            // ":?" uses the alternate formatting style, which makes anyhow display the
            // full cause of the error, not just the top-level context + its trace.
            // We don't want to send that in the ErrorResponse though,
            // because it's not relevant to the compute node logs.
            error!("query handler for '{}' failed: {:?}", query_string, e);
            // ":#" uses the alternate formatting style, which makes anyhow display the
            // full cause of the error, not just the top-level context. We don't want to
            // send that in the ErrorResponse though, because it's not relevant to the
            // compute node logs.
            warn!("query handler for {} failed: {:#}", query_string, e);
            self.write_message_noflush(&BeMessage::ErrorResponse(&e.to_string()))?;
            // TODO: untangle convoluted control flow
            if e.to_string().contains("failed to run") {
@@ -467,7 +467,7 @@ impl PostgresBackend {
        trace!("got execute {:?}", query_string);
        // xxx distinguish fatal and recoverable errors?
        if let Err(e) = handler.process_query(self, query_string) {
            error!("query handler for '{}' failed: {:?}", query_string, e);
            warn!("query handler for {:?} failed: {:#}", query_string, e);
            self.write_message(&BeMessage::ErrorResponse(&e.to_string()))?;
        }
        // NOTE there is no ReadyForQuery message. This handler is used

@@ -57,16 +57,6 @@ pub struct CancelKeyData {
    pub cancel_key: i32,
}

use rand::distributions::{Distribution, Standard};
impl Distribution<CancelKeyData> for Standard {
    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
        CancelKeyData {
            backend_pid: rng.gen(),
            cancel_key: rng.gen(),
        }
    }
}

#[derive(Debug)]
pub struct FeQueryMessage {
    pub body: Bytes,