Compare commits


2 Commits

Author  SHA1  Message  Date
Bojan Serafimov  800811957b  Use pytest plugin  2022-02-15 21:25:55 -05:00
Bojan Serafimov  d650e42ae3  Add large seqscan test  2022-02-15 18:03:04 -05:00
84 changed files with 1895 additions and 3065 deletions

Cargo.lock (generated) · 98 changed lines
View File

@@ -23,17 +23,6 @@ version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "739f4a8db6605981345c5654f3a85b056ce52f37a39d34da03f25bf2151ea16e"
[[package]]
name = "ahash"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47"
dependencies = [
"getrandom",
"once_cell",
"version_check",
]
[[package]]
name = "aho-corasick"
version = "0.7.18"
@@ -552,17 +541,6 @@ dependencies = [
"termcolor",
]
[[package]]
name = "fail"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec3245a0ca564e7f3c797d20d833a6870f57a728ac967d5225b3ffdef4465011"
dependencies = [
"lazy_static",
"log",
"rand",
]
[[package]]
name = "fallible-iterator"
version = "0.2.0"
@@ -791,7 +769,7 @@ version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7afe4a420e3fe79967a00898cc1f4db7c8a49a9333a29f8a4bd76a253d5cd04"
dependencies = [
"ahash 0.4.7",
"ahash",
]
[[package]]
@@ -799,9 +777,6 @@ name = "hashbrown"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e"
dependencies = [
"ahash 0.7.6",
]
[[package]]
name = "hermit-abi"
@@ -921,7 +896,7 @@ dependencies = [
"hyper",
"rustls 0.20.2",
"tokio",
"tokio-rustls 0.23.2",
"tokio-rustls",
]
[[package]]
@@ -1006,7 +981,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "afabcc15e437a6484fc4f12d0fd63068fe457bf93f1c148d3d9649c60b103f32"
dependencies = [
"base64 0.12.3",
"pem 0.8.3",
"pem",
"ring",
"serde",
"serde_json",
@@ -1300,7 +1275,6 @@ dependencies = [
"crc32c",
"crossbeam-utils",
"daemonize",
"fail",
"futures",
"hex",
"hex-literal",
@@ -1378,15 +1352,6 @@ dependencies = [
"regex",
]
[[package]]
name = "pem"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9a3b09a20e374558580a4914d3b7d89bd61b954a5a5e1dcbea98753addb1947"
dependencies = [
"base64 0.13.0",
]
[[package]]
name = "percent-encoding"
version = "2.1.0"
@@ -1591,25 +1556,17 @@ dependencies = [
"anyhow",
"bytes",
"clap 3.0.14",
"futures",
"hashbrown 0.11.2",
"hex",
"hyper",
"lazy_static",
"md5",
"parking_lot",
"pin-project-lite",
"rand",
"rcgen",
"reqwest",
"rustls 0.19.1",
"scopeguard",
"serde",
"serde_json",
"tokio",
"tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)",
"tokio-postgres-rustls",
"tokio-rustls 0.22.0",
"zenith_metrics",
"zenith_utils",
]
@@ -1663,18 +1620,6 @@ dependencies = [
"rand_core",
]
[[package]]
name = "rcgen"
version = "0.8.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5911d1403f4143c9d56a702069d593e8d0f3fab880a85e103604d0893ea31ba7"
dependencies = [
"chrono",
"pem 1.0.2",
"ring",
"yasna",
]
[[package]]
name = "redox_syscall"
version = "0.2.10"
@@ -1758,7 +1703,7 @@ dependencies = [
"serde_json",
"serde_urlencoded",
"tokio",
"tokio-rustls 0.23.2",
"tokio-rustls",
"tokio-util",
"url",
"wasm-bindgen",
@@ -2320,32 +2265,6 @@ dependencies = [
"tokio-util",
]
[[package]]
name = "tokio-postgres-rustls"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7bd8c37d8c23cb6ecdc32fc171bade4e9c7f1be65f693a17afbaad02091a0a19"
dependencies = [
"futures",
"ring",
"rustls 0.19.1",
"tokio",
"tokio-postgres 0.7.1 (git+https://github.com/zenithdb/rust-postgres.git?rev=2949d98df52587d562986aad155dd4e889e408b7)",
"tokio-rustls 0.22.0",
"webpki 0.21.4",
]
[[package]]
name = "tokio-rustls"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc6844de72e57df1980054b38be3a9f4702aba4858be64dd700181a8a6d0e1b6"
dependencies = [
"rustls 0.19.1",
"tokio",
"webpki 0.21.4",
]
[[package]]
name = "tokio-rustls"
version = "0.23.2"
@@ -2811,15 +2730,6 @@ version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2d7d3948613f75c98fd9328cfdcc45acc4d360655289d0a7d4ec931392200a3"
[[package]]
name = "yasna"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e262a29d0e61ccf2b6190d7050d4b237535fc76ce4c1210d9caa316f71dffa75"
dependencies = [
"chrono",
]
[[package]]
name = "zenith"
version = "0.1.0"

View File

@@ -16,8 +16,3 @@ members = [
# This is useful for profiling and, to some extent, debug.
# Besides, debug info should not affect the performance.
debug = true
# This is only needed for proxy's tests
# TODO: we should probably fork tokio-postgres-rustls instead
[patch.crates-io]
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }

View File

@@ -334,26 +334,14 @@ impl PostgresNode {
if let Some(lsn) = self.lsn {
conf.append("recovery_target_lsn", &lsn.to_string());
}
conf.append_line("");
// Configure backpressure
// - Replication write lag depends on how fast the walreceiver can process incoming WAL.
// This lag determines latency of get_page_at_lsn. Speed of applying WAL is about 10MB/sec,
// so to avoid expiration of 1 minute timeout, this lag should not be larger than 600MB.
// Actually latency should be much smaller (better if < 1sec). But we assume that recently
// updated pages are not requested from pageserver.
// - Replication flush lag depends on speed of persisting data by checkpointer (creation of
// delta/image layers) and advancing disk_consistent_lsn. Safekeepers are able to
// remove/archive WAL only beyond disk_consistent_lsn. Too large a lag can cause long
// recovery time (in case of pageserver crash) and disk space overflow at safekeepers.
// - Replication apply lag depends on speed of uploading changes to S3 by uploader thread.
// To be able to restore database in case of pageserver node crash, safekeeper should not
// remove WAL beyond this point. Too large lag can cause space exhaustion in safekeepers
// (if they are not able to upload WAL to S3).
conf.append("max_replication_write_lag", "500MB");
conf.append("max_replication_flush_lag", "10GB");
if !self.env.safekeepers.is_empty() {
// Configure backpressure
// In setup with safekeepers apply_lag depends on
// speed of data checkpointing on pageserver (see disk_consistent_lsn).
conf.append("max_replication_apply_lag", "1500MB");
// Configure the node to connect to the safekeepers
conf.append("synchronous_standby_names", "walproposer");
@@ -366,6 +354,11 @@ impl PostgresNode {
.join(",");
conf.append("wal_acceptors", &wal_acceptors);
} else {
// Configure backpressure
// In setup without safekeepers, flush_lag depends on
// speed of data checkpointing on pageserver (see disk_consistent_lsn)
conf.append("max_replication_flush_lag", "1500MB");
// We only use setup without safekeepers for tests,
// and don't care about data durability on pageserver,
// so set more relaxed synchronous_commit.
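A quick sanity check of the write-lag limit above, as a Python sketch using only the numbers stated in the comment (10MB/sec apply speed, 1 minute timeout); the 500MB setting is the one applied in this hunk:

# WAL apply speed and get_page_at_lsn timeout, as stated in the comment above.
wal_apply_speed_mb_per_s = 10   # "Speed of applying WAL is about 10MB/sec"
timeout_s = 60                  # 1 minute timeout

# Largest write lag that can still be replayed before the timeout expires.
max_tolerable_write_lag_mb = wal_apply_speed_mb_per_s * timeout_s
assert max_tolerable_write_lag_mb == 600

# The configured max_replication_write_lag of 500MB stays below that bound.
assert 500 < max_tolerable_write_lag_mb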

View File

@@ -41,7 +41,6 @@ url = "2"
nix = "0.23"
once_cell = "1.8.0"
crossbeam-utils = "0.8.5"
fail = "0.5.0"
rust-s3 = { version = "0.28", default-features = false, features = ["no-verify-ssl", "tokio-rustls-tls"] }
async-compression = {version = "0.3", features = ["zstd", "tokio"]}

View File

@@ -234,7 +234,9 @@ paths:
content:
application/json:
schema:
$ref: "#/components/schemas/BranchInfo"
type: array
items:
$ref: "#/components/schemas/BranchInfo"
"400":
description: Malformed branch create request
content:
@@ -368,15 +370,12 @@ components:
format: hex
ancestor_id:
type: string
format: hex
ancestor_lsn:
type: string
current_logical_size:
type: integer
current_logical_size_non_incremental:
type: integer
latest_valid_lsn:
type: integer
TimelineInfo:
type: object
required:

View File

@@ -27,10 +27,13 @@ use zenith_utils::lsn::Lsn;
use zenith_utils::postgres_backend::is_socket_read_timed_out;
use zenith_utils::postgres_backend::PostgresBackend;
use zenith_utils::postgres_backend::{self, AuthType};
use zenith_utils::pq_proto::{BeMessage, FeMessage, RowDescriptor, SINGLE_COL_ROWDESC};
use zenith_utils::pq_proto::{
BeMessage, FeMessage, RowDescriptor, HELLO_WORLD_ROW, SINGLE_COL_ROWDESC,
};
use zenith_utils::zid::{ZTenantId, ZTimelineId};
use crate::basebackup;
use crate::branches;
use crate::config::PageServerConf;
use crate::relish::*;
use crate::repository::Timeline;
@@ -659,21 +662,79 @@ impl postgres_backend::Handler for PageServerHandler {
walreceiver::launch_wal_receiver(self.conf, tenantid, timelineid, &connstr)?;
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("branch_create ") {
let err = || format!("invalid branch_create: '{}'", query_string);
// branch_create <tenantid> <branchname> <startpoint>
// TODO lazy static
// TODO: escaping, to allow branch names with spaces
let re = Regex::new(r"^branch_create ([[:xdigit:]]+) (\S+) ([^\r\n\s;]+)[\r\n\s;]*;?$")
.unwrap();
let caps = re.captures(query_string).with_context(err)?;
let tenantid = ZTenantId::from_str(caps.get(1).unwrap().as_str())?;
let branchname = caps.get(2).with_context(err)?.as_str().to_owned();
let startpoint_str = caps.get(3).with_context(err)?.as_str().to_owned();
self.check_permission(Some(tenantid))?;
let _enter =
info_span!("branch_create", name = %branchname, tenant = %tenantid).entered();
let branch =
branches::create_branch(self.conf, &branchname, &startpoint_str, &tenantid)?;
let branch = serde_json::to_vec(&branch)?;
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(&branch)]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("branch_list ") {
// branch_list <zenith tenantid as hex string>
let re = Regex::new(r"^branch_list ([[:xdigit:]]+)$").unwrap();
let caps = re
.captures(query_string)
.with_context(|| format!("invalid branch_list: '{}'", query_string))?;
let tenantid = ZTenantId::from_str(caps.get(1).unwrap().as_str())?;
// since these handlers for tenant/branch commands are deprecated (in favor of the http-based ones),
// just pass false for the include-non-incremental-logical-size option
let branches = crate::branches::get_branches(self.conf, &tenantid, false)?;
let branches_buf = serde_json::to_vec(&branches)?;
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(&branches_buf)]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("tenant_list") {
let tenants = crate::tenant_mgr::list_tenants()?;
let tenants_buf = serde_json::to_vec(&tenants)?;
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(&tenants_buf)]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("tenant_create") {
let err = || format!("invalid tenant_create: '{}'", query_string);
// tenant_create <tenantid>
let re = Regex::new(r"^tenant_create ([[:xdigit:]]+)$").unwrap();
let caps = re.captures(query_string).with_context(err)?;
self.check_permission(None)?;
let tenantid = ZTenantId::from_str(caps.get(1).unwrap().as_str())?;
tenant_mgr::create_repository_for_tenant(self.conf, tenantid)?;
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("status") {
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&HELLO_WORLD_ROW)?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.to_ascii_lowercase().starts_with("set ") {
// important because psycopg2 executes "SET datestyle TO 'ISO'"
// on connect
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("failpoints ") {
let (_, failpoints) = query_string.split_at("failpoints ".len());
for failpoint in failpoints.split(';') {
if let Some((name, actions)) = failpoint.split_once('=') {
info!("cfg failpoint: {} {}", name, actions);
fail::cfg(name, actions).unwrap();
} else {
bail!("Invalid failpoints format");
}
}
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("do_gc ") {
// Run GC immediately on given timeline.
// FIXME: This is just for tests. See test_runner/batch_others/test_gc.py.
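As a usage note for the failpoints command handled in the hunk above: it takes a semicolon-separated list of name=action pairs and passes each to fail::cfg. A hedged Python sketch of how a test might drive it over a regular libpq connection follows; the port and user are illustrative, "return" is a standard fail-rs action, and the failpoint name walreceiver-after-ingest appears in the walreceiver hunk below:

import psycopg2

# Connect to the pageserver's libpq-style endpoint (address is illustrative).
conn = psycopg2.connect(host="localhost", port=64000, user="zenith_admin")
conn.autocommit = True
with conn.cursor() as cur:
    # Multiple failpoints can be separated with ';', matching the parser above.
    cur.execute("failpoints walreceiver-after-ingest=return")
conn.close()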

View File

@@ -12,7 +12,6 @@ use crate::thread_mgr::ThreadKind;
use crate::walingest::WalIngest;
use anyhow::{bail, Context, Error, Result};
use bytes::BytesMut;
use fail::fail_point;
use lazy_static::lazy_static;
use postgres_ffi::waldecoder::*;
use postgres_protocol::message::backend::ReplicationMessage;
@@ -32,7 +31,6 @@ use zenith_utils::lsn::Lsn;
use zenith_utils::pq_proto::ZenithFeedback;
use zenith_utils::zid::ZTenantId;
use zenith_utils::zid::ZTimelineId;
//
// We keep one WAL Receiver active per timeline.
//
@@ -256,8 +254,6 @@ fn walreceiver_main(
let writer = timeline.writer();
walingest.ingest_record(writer.as_ref(), recdata, lsn)?;
fail_point!("walreceiver-after-ingest");
last_rec_lsn = lsn;
}

poetry.lock (generated) · 27 changed lines
View File

@@ -814,7 +814,7 @@ python-versions = "*"
[[package]]
name = "moto"
version = "3.0.0"
version = "3.0.3"
description = "A library that allows your python tests to easily mock out the boto library"
category = "main"
optional = false
@@ -849,6 +849,7 @@ xmltodict = "*"
[package.extras]
all = ["PyYAML (>=5.1)", "python-jose[cryptography] (>=3.1.0,<4.0.0)", "ecdsa (!=0.15)", "docker (>=2.5.1)", "graphql-core", "jsondiff (>=1.1.2)", "aws-xray-sdk (>=0.93,!=0.96)", "idna (>=2.5,<4)", "cfn-lint (>=0.4.0)", "sshpubkeys (>=3.1.0)", "setuptools"]
apigateway = ["python-jose[cryptography] (>=3.1.0,<4.0.0)", "ecdsa (!=0.15)"]
apigatewayv2 = ["PyYAML (>=5.1)"]
appsync = ["graphql-core"]
awslambda = ["docker (>=2.5.1)"]
batch = ["docker (>=2.5.1)"]
@@ -1059,6 +1060,20 @@ python-versions = ">=3.6"
py = "*"
pytest = ">=3.10"
[[package]]
name = "pytest-skip-slow"
version = "0.0.2"
description = "A pytest plugin to skip `@pytest.mark.slow` tests by default. "
category = "main"
optional = false
python-versions = ">=3.7"
[package.dependencies]
pytest = ">=6.2.0"
[package.extras]
test = ["tox"]
[[package]]
name = "pytest-xdist"
version = "2.5.0"
@@ -1352,7 +1367,7 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest-
[metadata]
lock-version = "1.1"
python-versions = "^3.7"
content-hash = "0fa6c9377fbc827240d18d8b7e3742def37e90fc3277fddf8525d82dabd13090"
content-hash = "d59ea97fb78d13dcede2719734c55b7d6b9dcc7f86001a9228c8808ed6e58eb7"
[metadata.files]
aiopg = [
@@ -1666,8 +1681,8 @@ mccabe = [
{file = "mccabe-0.6.1.tar.gz", hash = "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f"},
]
moto = [
{file = "moto-3.0.0-py2.py3-none-any.whl", hash = "sha256:762d33bbad3642c687f6495e69331318bef43f9aa662174397706ec3ad2a3578"},
{file = "moto-3.0.0.tar.gz", hash = "sha256:d6b00a2663290e7ebb06823d5ffcb124c8dc9bf526b878539ef7c4a377fd8255"},
{file = "moto-3.0.3-py2.py3-none-any.whl", hash = "sha256:445a574395b8a43a249ae0f932bf10c5cc677054198bfa1ff92e6fbd60e72c38"},
{file = "moto-3.0.3.tar.gz", hash = "sha256:fa3fbdc22c55d7e70b407e2f2639c48ac82b074f472b167609405c0c1e3a2ccb"},
]
mypy = [
{file = "mypy-0.910-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:a155d80ea6cee511a3694b108c4494a39f42de11ee4e61e72bc424c490e46457"},
@@ -1842,6 +1857,10 @@ pytest-forked = [
{file = "pytest-forked-1.4.0.tar.gz", hash = "sha256:8b67587c8f98cbbadfdd804539ed5455b6ed03802203485dd2f53c1422d7440e"},
{file = "pytest_forked-1.4.0-py3-none-any.whl", hash = "sha256:bbbb6717efc886b9d64537b41fb1497cfaf3c9601276be8da2cccfea5a3c8ad8"},
]
pytest-skip-slow = [
{file = "pytest-skip-slow-0.0.2.tar.gz", hash = "sha256:06b8353d8b1e2168c22b8c34172329b4f592eea648dac141099cc881fd466718"},
{file = "pytest_skip_slow-0.0.2-py3-none-any.whl", hash = "sha256:9029eb8c258e6c3c4b499848d864c9e9c48499d8b8bf5bb413044f59874cfd06"},
]
pytest-xdist = [
{file = "pytest-xdist-2.5.0.tar.gz", hash = "sha256:4580deca3ff04ddb2ac53eba39d76cb5dd5edeac050cb6fbc768b0dd712b4edf"},
{file = "pytest_xdist-2.5.0-py3-none-any.whl", hash = "sha256:6fe5c74fec98906deb8f2d2b616b5c782022744978e7bd4695d39c8f42d0ce65"},
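The pytest-skip-slow entry added to this lock file is the plugin referenced by the "Use pytest plugin" commit: tests marked @pytest.mark.slow are skipped unless explicitly requested (the plugin documents a --slow flag for opting in). A minimal sketch of such a test; the test name is illustrative:

import pytest

@pytest.mark.slow
def test_large_seqscan():
    # Long-running scenario; skipped by a plain `pytest` run,
    # executed when pytest is invoked with `--slow`.
    ...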

View File

@@ -6,28 +6,18 @@ edition = "2021"
[dependencies]
anyhow = "1.0"
bytes = { version = "1.0.1", features = ['serde'] }
clap = "3.0"
futures = "0.3.13"
hashbrown = "0.11.2"
hex = "0.4.3"
hyper = "0.14"
lazy_static = "1.4.0"
md5 = "0.7.0"
parking_lot = "0.11.2"
pin-project-lite = "0.2.7"
rand = "0.8.3"
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] }
rustls = "0.19.1"
scopeguard = "1.1.0"
hex = "0.4.3"
hyper = "0.14"
serde = "1"
serde_json = "1"
tokio = { version = "1.11", features = ["macros"] }
tokio-postgres = { git = "https://github.com/zenithdb/rust-postgres.git", rev="2949d98df52587d562986aad155dd4e889e408b7" }
tokio-rustls = "0.22.0"
clap = "3.0"
rustls = "0.19.1"
reqwest = { version = "0.11", default-features = false, features = ["blocking", "json", "rustls-tls"] }
zenith_utils = { path = "../zenith_utils" }
zenith_metrics = { path = "../zenith_metrics" }
[dev-dependencies]
tokio-postgres-rustls = "0.8.0"
rcgen = "0.8.14"

View File

@@ -1,169 +0,0 @@
use crate::compute::DatabaseInfo;
use crate::config::ProxyConfig;
use crate::cplane_api::{self, CPlaneApi};
use crate::stream::PqStream;
use anyhow::{anyhow, bail, Context};
use std::collections::HashMap;
use tokio::io::{AsyncRead, AsyncWrite};
use zenith_utils::pq_proto::{BeMessage as Be, BeParameterStatusMessage, FeMessage as Fe};
/// Various client credentials which we use for authentication.
#[derive(Debug, PartialEq, Eq)]
pub struct ClientCredentials {
pub user: String,
pub dbname: String,
}
impl TryFrom<HashMap<String, String>> for ClientCredentials {
type Error = anyhow::Error;
fn try_from(mut value: HashMap<String, String>) -> Result<Self, Self::Error> {
let mut get_param = |key| {
value
.remove(key)
.with_context(|| format!("{} is missing in startup packet", key))
};
let user = get_param("user")?;
let db = get_param("database")?;
Ok(Self { user, dbname: db })
}
}
impl ClientCredentials {
/// Use credentials to authenticate the user.
pub async fn authenticate(
self,
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> anyhow::Result<DatabaseInfo> {
use crate::config::ClientAuthMethod::*;
use crate::config::RouterConfig::*;
let db_info = match &config.router_config {
Static { host, port } => handle_static(host.clone(), *port, client, self).await,
Dynamic(Mixed) => {
if self.user.ends_with("@zenith") {
handle_existing_user(config, client, self).await
} else {
handle_new_user(config, client).await
}
}
Dynamic(Password) => handle_existing_user(config, client, self).await,
Dynamic(Link) => handle_new_user(config, client).await,
};
db_info.context("failed to authenticate client")
}
}
fn new_psql_session_id() -> String {
hex::encode(rand::random::<[u8; 8]>())
}
async fn handle_static(
host: String,
port: u16,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: ClientCredentials,
) -> anyhow::Result<DatabaseInfo> {
client
.write_message(&Be::AuthenticationCleartextPassword)
.await?;
// Read client's password bytes
let msg = match client.read_message().await? {
Fe::PasswordMessage(msg) => msg,
bad => bail!("unexpected message type: {:?}", bad),
};
let cleartext_password = std::str::from_utf8(&msg)?.split('\0').next().unwrap();
let db_info = DatabaseInfo {
host,
port,
dbname: creds.dbname.clone(),
user: creds.user.clone(),
password: Some(cleartext_password.into()),
};
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(db_info)
}
async fn handle_existing_user(
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: ClientCredentials,
) -> anyhow::Result<DatabaseInfo> {
let psql_session_id = new_psql_session_id();
let md5_salt = rand::random();
client
.write_message(&Be::AuthenticationMD5Password(&md5_salt))
.await?;
// Read client's password hash
let msg = match client.read_message().await? {
Fe::PasswordMessage(msg) => msg,
bad => bail!("unexpected message type: {:?}", bad),
};
let (_trailing_null, md5_response) = msg
.split_last()
.ok_or_else(|| anyhow!("unexpected password message"))?;
let cplane = CPlaneApi::new(&config.auth_endpoint);
let db_info = cplane
.authenticate_proxy_request(creds, md5_response, &md5_salt, &psql_session_id)
.await?;
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(db_info)
}
async fn handle_new_user(
config: &ProxyConfig,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin>,
) -> anyhow::Result<DatabaseInfo> {
let psql_session_id = new_psql_session_id();
let greeting = hello_message(&config.redirect_uri, &psql_session_id);
let db_info = cplane_api::with_waiter(psql_session_id, |waiter| async {
// Give user a URL to spawn a new database
client
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?
.write_message(&Be::NoticeResponse(greeting))
.await?;
// Wait for web console response
waiter.await?.map_err(|e| anyhow!(e))
})
.await?;
client.write_message_noflush(&Be::NoticeResponse("Connecting to database.".into()))?;
Ok(db_info)
}
fn hello_message(redirect_uri: &str, session_id: &str) -> String {
format!(
concat![
"☀️ Welcome to Zenith!\n",
"To proceed with database creation, open the following link:\n\n",
" {redirect_uri}{session_id}\n\n",
"It needs to be done once and we will send you '.pgpass' file,\n",
"which will allow you to access or create ",
"databases without opening your web browser."
],
redirect_uri = redirect_uri,
session_id = session_id,
)
}
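For reference, the md5_response read in the code above is the client's answer to AuthenticationMD5Password and follows the standard PostgreSQL scheme. A small Python sketch of how a client computes it; user, password, and salt are placeholders:

import hashlib

def md5_password_response(user: str, password: str, salt: bytes) -> str:
    # Standard scheme: "md5" + hex(md5(hex(md5(password + user)) + salt))
    inner = hashlib.md5((password + user).encode()).hexdigest()
    return "md5" + hashlib.md5(inner.encode() + salt).hexdigest()

print(md5_password_response("alice", "secret", b"\x01\x02\x03\x04"))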

View File

@@ -1,106 +0,0 @@
use anyhow::{anyhow, Context};
use hashbrown::HashMap;
use parking_lot::Mutex;
use std::net::SocketAddr;
use tokio::net::TcpStream;
use tokio_postgres::{CancelToken, NoTls};
use zenith_utils::pq_proto::CancelKeyData;
/// Enables serving CancelRequests.
#[derive(Default)]
pub struct CancelMap(Mutex<HashMap<CancelKeyData, Option<CancelClosure>>>);
impl CancelMap {
/// Cancel a running query for the corresponding connection.
pub async fn cancel_session(&self, key: CancelKeyData) -> anyhow::Result<()> {
let cancel_closure = self
.0
.lock()
.get(&key)
.and_then(|x| x.clone())
.with_context(|| format!("unknown session: {:?}", key))?;
cancel_closure.try_cancel_query().await
}
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
where
F: FnOnce(Session<'a>) -> R,
R: std::future::Future<Output = anyhow::Result<V>>,
{
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
// expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the
// actual backend_pid, but backend_pid is not used for anything
// so it doesn't matter.
let key = rand::random();
// Random key collisions are unlikely to happen here, but they're still possible,
// which is why we have to take care not to rewrite an existing key.
self.0
.lock()
.try_insert(key, None)
.map_err(|_| anyhow!("session already exists: {:?}", key))?;
// This will guarantee that the session gets dropped
// as soon as the future is finished.
scopeguard::defer! {
self.0.lock().remove(&key);
}
let session = Session::new(key, self);
f(session).await
}
}
/// This should've been a [`std::future::Future`], but
/// it's impossible to name a type of an unboxed future
/// (we'd need something like `#![feature(type_alias_impl_trait)]`).
#[derive(Clone)]
pub struct CancelClosure {
socket_addr: SocketAddr,
cancel_token: CancelToken,
}
impl CancelClosure {
pub fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self {
Self {
socket_addr,
cancel_token,
}
}
/// Cancels the query running on user's compute node.
pub async fn try_cancel_query(self) -> anyhow::Result<()> {
let socket = TcpStream::connect(self.socket_addr).await?;
self.cancel_token.cancel_query_raw(socket, NoTls).await?;
Ok(())
}
}
/// Helper for registering query cancellation tokens.
pub struct Session<'a> {
/// The user-facing key identifying this session.
key: CancelKeyData,
/// The [`CancelMap`] this session belongs to.
cancel_map: &'a CancelMap,
}
impl<'a> Session<'a> {
fn new(key: CancelKeyData, cancel_map: &'a CancelMap) -> Self {
Self { key, cancel_map }
}
/// Store the cancel token for the given session.
/// This enables query cancellation in [`crate::proxy::handshake`].
pub fn enable_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
self.cancel_map
.0
.lock()
.insert(self.key, Some(cancel_closure));
self.key
}
}

View File

@@ -1,42 +0,0 @@
use anyhow::Context;
use serde::{Deserialize, Serialize};
use std::net::{SocketAddr, ToSocketAddrs};
/// Compute node connection params.
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct DatabaseInfo {
pub host: String,
pub port: u16,
pub dbname: String,
pub user: String,
pub password: Option<String>,
}
impl DatabaseInfo {
pub fn socket_addr(&self) -> anyhow::Result<SocketAddr> {
let host_port = format!("{}:{}", self.host, self.port);
host_port
.to_socket_addrs()
.with_context(|| format!("cannot resolve {} to SocketAddr", host_port))?
.next()
.context("cannot resolve at least one SocketAddr")
}
}
impl From<DatabaseInfo> for tokio_postgres::Config {
fn from(db_info: DatabaseInfo) -> Self {
let mut config = tokio_postgres::Config::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
if let Some(password) = db_info.password {
config.password(password);
}
config
}
}

View File

@@ -1,79 +1,18 @@
use crate::auth::ClientCredentials;
use crate::compute::DatabaseInfo;
use crate::waiters::{Waiter, Waiters};
use anyhow::{anyhow, bail};
use lazy_static::lazy_static;
use anyhow::{anyhow, bail, Context};
use serde::{Deserialize, Serialize};
use std::net::{SocketAddr, ToSocketAddrs};
lazy_static! {
static ref CPLANE_WAITERS: Waiters<Result<DatabaseInfo, String>> = Default::default();
use crate::state::ProxyWaiters;
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct DatabaseInfo {
pub host: String,
pub port: u16,
pub dbname: String,
pub user: String,
pub password: Option<String>,
}
/// Give caller an opportunity to wait for cplane's reply.
pub async fn with_waiter<F, R, T>(psql_session_id: impl Into<String>, f: F) -> anyhow::Result<T>
where
F: FnOnce(Waiter<'static, Result<DatabaseInfo, String>>) -> R,
R: std::future::Future<Output = anyhow::Result<T>>,
{
let waiter = CPLANE_WAITERS.register(psql_session_id.into())?;
f(waiter).await
}
pub fn notify(psql_session_id: &str, msg: Result<DatabaseInfo, String>) -> anyhow::Result<()> {
CPLANE_WAITERS.notify(psql_session_id, msg)
}
/// Zenith console API wrapper.
pub struct CPlaneApi<'a> {
auth_endpoint: &'a str,
}
impl<'a> CPlaneApi<'a> {
pub fn new(auth_endpoint: &'a str) -> Self {
Self { auth_endpoint }
}
}
impl CPlaneApi<'_> {
pub async fn authenticate_proxy_request(
&self,
creds: ClientCredentials,
md5_response: &[u8],
salt: &[u8; 4],
psql_session_id: &str,
) -> anyhow::Result<DatabaseInfo> {
let mut url = reqwest::Url::parse(self.auth_endpoint)?;
url.query_pairs_mut()
.append_pair("login", &creds.user)
.append_pair("database", &creds.dbname)
.append_pair("md5response", std::str::from_utf8(md5_response)?)
.append_pair("salt", &hex::encode(salt))
.append_pair("psql_session_id", psql_session_id);
with_waiter(psql_session_id, |waiter| async {
println!("cplane request: {}", url);
// TODO: leverage `reqwest::Client` to reuse connections
let resp = reqwest::get(url).await?;
if !resp.status().is_success() {
bail!("Auth failed: {}", resp.status())
}
let auth_info: ProxyAuthResponse = serde_json::from_str(resp.text().await?.as_str())?;
println!("got auth info: #{:?}", auth_info);
use ProxyAuthResponse::*;
match auth_info {
Ready { conn_info } => Ok(conn_info),
Error { error } => bail!(error),
NotReady { .. } => waiter.await?.map_err(|e| anyhow!(e)),
}
})
.await
}
}
// NOTE: the order of constructors is important.
// https://serde.rs/enum-representations.html#untagged
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
enum ProxyAuthResponse {
@@ -82,6 +21,86 @@ enum ProxyAuthResponse {
NotReady { ready: bool }, // TODO: get rid of `ready`
}
impl DatabaseInfo {
pub fn socket_addr(&self) -> anyhow::Result<SocketAddr> {
let host_port = format!("{}:{}", self.host, self.port);
host_port
.to_socket_addrs()
.with_context(|| format!("cannot resolve {} to SocketAddr", host_port))?
.next()
.context("cannot resolve at least one SocketAddr")
}
}
impl From<DatabaseInfo> for tokio_postgres::Config {
fn from(db_info: DatabaseInfo) -> Self {
let mut config = tokio_postgres::Config::new();
config
.host(&db_info.host)
.port(db_info.port)
.dbname(&db_info.dbname)
.user(&db_info.user);
if let Some(password) = db_info.password {
config.password(password);
}
config
}
}
pub struct CPlaneApi<'a> {
auth_endpoint: &'a str,
waiters: &'a ProxyWaiters,
}
impl<'a> CPlaneApi<'a> {
pub fn new(auth_endpoint: &'a str, waiters: &'a ProxyWaiters) -> Self {
Self {
auth_endpoint,
waiters,
}
}
}
impl CPlaneApi<'_> {
pub fn authenticate_proxy_request(
&self,
user: &str,
database: &str,
md5_response: &[u8],
salt: &[u8; 4],
psql_session_id: &str,
) -> anyhow::Result<DatabaseInfo> {
let mut url = reqwest::Url::parse(self.auth_endpoint)?;
url.query_pairs_mut()
.append_pair("login", user)
.append_pair("database", database)
.append_pair("md5response", std::str::from_utf8(md5_response)?)
.append_pair("salt", &hex::encode(salt))
.append_pair("psql_session_id", psql_session_id);
let waiter = self.waiters.register(psql_session_id.to_owned());
println!("cplane request: {}", url);
let resp = reqwest::blocking::get(url)?;
if !resp.status().is_success() {
bail!("Auth failed: {}", resp.status())
}
let auth_info: ProxyAuthResponse = serde_json::from_str(resp.text()?.as_str())?;
println!("got auth info: #{:?}", auth_info);
use ProxyAuthResponse::*;
match auth_info {
Ready { conn_info } => Ok(conn_info),
Error { error } => bail!(error),
NotReady { .. } => waiter.wait()?.map_err(|e| anyhow!(e)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -1,30 +1,15 @@
use anyhow::anyhow;
use hyper::{Body, Request, Response, StatusCode};
use std::net::TcpListener;
use zenith_utils::http::RouterBuilder;
use zenith_utils::http::endpoint;
use zenith_utils::http::error::ApiError;
use zenith_utils::http::json::json_response;
use zenith_utils::http::{RouterBuilder, RouterService};
async fn status_handler(_: Request<Body>) -> Result<Response<Body>, ApiError> {
Ok(json_response(StatusCode::OK, "")?)
}
fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
let router = endpoint::make_router();
router.get("/v1/status", status_handler)
}
pub async fn thread_main(http_listener: TcpListener) -> anyhow::Result<()> {
scopeguard::defer! {
println!("http has shut down");
}
let service = || RouterService::new(make_router().build()?);
hyper::Server::from_tcp(http_listener)?
.serve(service().map_err(|e| anyhow!(e))?)
.await?;
Ok(())
}

View File

@@ -5,36 +5,21 @@
/// (control plane API in our case) and can create new databases and accounts
/// in somewhat transparent manner (again via communication with control plane API).
///
use anyhow::{bail, Context};
use anyhow::bail;
use clap::{App, Arg};
use config::ProxyConfig;
use futures::FutureExt;
use std::future::Future;
use tokio::{net::TcpListener, task::JoinError};
use zenith_utils::GIT_VERSION;
use state::{ProxyConfig, ProxyState};
use std::thread;
use zenith_utils::http::endpoint;
use zenith_utils::{tcp_listener, GIT_VERSION};
use crate::config::{ClientAuthMethod, RouterConfig};
mod auth;
mod cancellation;
mod compute;
mod config;
mod cplane_api;
mod http;
mod mgmt;
mod proxy;
mod stream;
mod state;
mod waiters;
/// Flattens Result<Result<T>> into Result<T>.
async fn flatten_err(
f: impl Future<Output = Result<anyhow::Result<()>, JoinError>>,
) -> anyhow::Result<()> {
f.map(|r| r.context("join error").and_then(|x| x)).await
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
fn main() -> anyhow::Result<()> {
zenith_metrics::set_common_metrics_prefix("zenith_proxy");
let arg_matches = App::new("Zenith proxy/router")
.version(GIT_VERSION)
@@ -46,20 +31,6 @@ async fn main() -> anyhow::Result<()> {
.help("listen for incoming client connections on ip:port")
.default_value("127.0.0.1:4432"),
)
.arg(
Arg::new("auth-method")
.long("auth-method")
.takes_value(true)
.help("Possible values: password | link | mixed")
.default_value("mixed"),
)
.arg(
Arg::new("static-router")
.short('s')
.long("static-router")
.takes_value(true)
.help("Route all clients to host:port"),
)
.arg(
Arg::new("mgmt")
.short('m')
@@ -108,59 +79,63 @@ async fn main() -> anyhow::Result<()> {
)
.get_matches();
let tls_config = match (
let ssl_config = match (
arg_matches.value_of("ssl-key"),
arg_matches.value_of("ssl-cert"),
) {
(Some(key_path), Some(cert_path)) => Some(config::configure_ssl(key_path, cert_path)?),
(Some(key_path), Some(cert_path)) => {
Some(crate::state::configure_ssl(key_path, cert_path)?)
}
(None, None) => None,
_ => bail!("either both or neither ssl-key and ssl-cert must be specified"),
};
let auth_method = arg_matches.value_of("auth-method").unwrap().parse()?;
let router_config = match arg_matches.value_of("static-router") {
None => RouterConfig::Dynamic(auth_method),
Some(addr) => {
if let ClientAuthMethod::Password = auth_method {
let (host, port) = addr.split_once(":").unwrap();
RouterConfig::Static {
host: host.to_string(),
port: port.parse().unwrap(),
}
} else {
bail!("static-router requires --auth-method password")
}
}
};
let config: &ProxyConfig = Box::leak(Box::new(ProxyConfig {
router_config,
let config = ProxyConfig {
proxy_address: arg_matches.value_of("proxy").unwrap().parse()?,
mgmt_address: arg_matches.value_of("mgmt").unwrap().parse()?,
http_address: arg_matches.value_of("http").unwrap().parse()?,
redirect_uri: arg_matches.value_of("uri").unwrap().parse()?,
auth_endpoint: arg_matches.value_of("auth-endpoint").unwrap().parse()?,
tls_config,
}));
ssl_config,
};
let state: &ProxyState = Box::leak(Box::new(ProxyState::new(config)));
println!("Version: {}", GIT_VERSION);
// Check that we can bind to address before further initialization
println!("Starting http on {}", config.http_address);
let http_listener = TcpListener::bind(config.http_address).await?.into_std()?;
println!("Starting http on {}", state.conf.http_address);
let http_listener = tcp_listener::bind(state.conf.http_address)?;
println!("Starting mgmt on {}", config.mgmt_address);
let mgmt_listener = TcpListener::bind(config.mgmt_address).await?.into_std()?;
println!("Starting proxy on {}", state.conf.proxy_address);
let pageserver_listener = tcp_listener::bind(state.conf.proxy_address)?;
println!("Starting proxy on {}", config.proxy_address);
let proxy_listener = TcpListener::bind(config.proxy_address).await?;
println!("Starting mgmt on {}", state.conf.mgmt_address);
let mgmt_listener = tcp_listener::bind(state.conf.mgmt_address)?;
let http = tokio::spawn(http::thread_main(http_listener));
let proxy = tokio::spawn(proxy::thread_main(config, proxy_listener));
let mgmt = tokio::task::spawn_blocking(move || mgmt::thread_main(mgmt_listener));
let threads = [
thread::Builder::new()
.name("Http thread".into())
.spawn(move || {
let router = http::make_router();
endpoint::serve_thread_main(
router,
http_listener,
std::future::pending(), // never shut down
)
})?,
// Spawn a thread to listen for connections. It will spawn further threads
// for each connection.
thread::Builder::new()
.name("Listener thread".into())
.spawn(move || proxy::thread_main(state, pageserver_listener))?,
thread::Builder::new()
.name("Mgmt thread".into())
.spawn(move || mgmt::thread_main(state, mgmt_listener))?,
];
let tasks = [flatten_err(http), flatten_err(proxy), flatten_err(mgmt)];
let _: Vec<()> = futures::future::try_join_all(tasks).await?;
for t in threads {
t.join().unwrap()?;
}
Ok(())
}

View File

@@ -1,49 +1,44 @@
use crate::{compute::DatabaseInfo, cplane_api};
use anyhow::Context;
use serde::Deserialize;
use std::{
net::{TcpListener, TcpStream},
thread,
};
use serde::Deserialize;
use zenith_utils::{
postgres_backend::{self, AuthType, PostgresBackend},
pq_proto::{BeMessage, SINGLE_COL_ROWDESC},
};
use crate::{cplane_api::DatabaseInfo, ProxyState};
///
/// Main proxy listener loop.
///
/// Listens for connections, and launches a new handler thread for each.
///
pub fn thread_main(listener: TcpListener) -> anyhow::Result<()> {
scopeguard::defer! {
println!("mgmt has shut down");
}
listener
.set_nonblocking(false)
.context("failed to set listener to blocking")?;
pub fn thread_main(state: &'static ProxyState, listener: TcpListener) -> anyhow::Result<()> {
loop {
let (socket, peer_addr) = listener.accept().context("failed to accept a new client")?;
let (socket, peer_addr) = listener.accept()?;
println!("accepted connection from {}", peer_addr);
socket
.set_nodelay(true)
.context("failed to set client socket option")?;
socket.set_nodelay(true).unwrap();
thread::spawn(move || {
if let Err(err) = handle_connection(socket) {
if let Err(err) = handle_connection(state, socket) {
println!("error: {}", err);
}
});
}
}
fn handle_connection(socket: TcpStream) -> anyhow::Result<()> {
fn handle_connection(state: &ProxyState, socket: TcpStream) -> anyhow::Result<()> {
let mut conn_handler = MgmtHandler { state };
let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None, true)?;
pgbackend.run(&mut MgmtHandler)
pgbackend.run(&mut conn_handler)
}
struct MgmtHandler;
struct MgmtHandler<'a> {
state: &'a ProxyState,
}
/// Serialized examples:
// {
@@ -79,13 +74,13 @@ enum PsqlSessionResult {
Failure(String),
}
impl postgres_backend::Handler for MgmtHandler {
impl postgres_backend::Handler for MgmtHandler<'_> {
fn process_query(
&mut self,
pgb: &mut PostgresBackend,
query_string: &str,
) -> anyhow::Result<()> {
let res = try_process_query(pgb, query_string);
let res = try_process_query(self, pgb, query_string);
// intercept and log error message
if res.is_err() {
println!("Mgmt query failed: #{:?}", res);
@@ -94,7 +89,11 @@ impl postgres_backend::Handler for MgmtHandler {
}
}
fn try_process_query(pgb: &mut PostgresBackend, query_string: &str) -> anyhow::Result<()> {
fn try_process_query(
mgmt: &mut MgmtHandler,
pgb: &mut PostgresBackend,
query_string: &str,
) -> anyhow::Result<()> {
println!("Got mgmt query: '{}'", query_string);
let resp: PsqlSessionResponse = serde_json::from_str(query_string)?;
@@ -105,7 +104,7 @@ fn try_process_query(pgb: &mut PostgresBackend, query_string: &str) -> anyhow::R
Failure(message) => Err(message),
};
match cplane_api::notify(&resp.session_id, msg) {
match mgmt.state.waiters.notify(&resp.session_id, msg) {
Ok(()) => {
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?

View File

@@ -1,332 +1,389 @@
use crate::auth;
use crate::cancellation::{self, CancelClosure, CancelMap};
use crate::compute::DatabaseInfo;
use crate::config::{ProxyConfig, TlsConfig};
use crate::stream::{MetricsStream, PqStream, Stream};
use anyhow::{bail, Context};
use crate::cplane_api::{CPlaneApi, DatabaseInfo};
use crate::ProxyState;
use anyhow::{anyhow, bail, Context};
use lazy_static::lazy_static;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use rand::prelude::StdRng;
use rand::{Rng, SeedableRng};
use std::cell::Cell;
use std::collections::HashMap;
use std::net::{SocketAddr, TcpStream};
use std::sync::Mutex;
use std::{io, thread};
use tokio_postgres::NoTls;
use zenith_metrics::{new_common_metric_name, register_int_counter, IntCounter};
use zenith_utils::pq_proto::{BeMessage as Be, *};
use zenith_utils::postgres_backend::{self, PostgresBackend, ProtoState, Stream};
use zenith_utils::pq_proto::{BeMessage as Be, FeMessage as Fe, *};
use zenith_utils::sock_split::{ReadStream, WriteStream};
lazy_static! {
static ref NUM_CONNECTIONS_ACCEPTED_COUNTER: IntCounter = register_int_counter!(
new_common_metric_name("num_connections_accepted"),
"Number of TCP client connections accepted."
)
.unwrap();
static ref NUM_CONNECTIONS_CLOSED_COUNTER: IntCounter = register_int_counter!(
new_common_metric_name("num_connections_closed"),
"Number of TCP client connections closed."
)
.unwrap();
static ref NUM_BYTES_PROXIED_COUNTER: IntCounter = register_int_counter!(
new_common_metric_name("num_bytes_proxied"),
"Number of bytes sent/received between any client and backend."
)
.unwrap();
struct CancelClosure {
socket_addr: SocketAddr,
cancel_token: tokio_postgres::CancelToken,
}
async fn log_error<R, F>(future: F) -> F::Output
where
F: std::future::Future<Output = anyhow::Result<R>>,
{
future.await.map_err(|err| {
println!("error: {}", err);
err
})
}
pub async fn thread_main(
config: &'static ProxyConfig,
listener: tokio::net::TcpListener,
) -> anyhow::Result<()> {
scopeguard::defer! {
println!("proxy has shut down");
}
let cancel_map = Arc::new(CancelMap::default());
loop {
let (socket, peer_addr) = listener.accept().await?;
println!("accepted connection from {}", peer_addr);
let cancel_map = Arc::clone(&cancel_map);
tokio::spawn(log_error(async move {
socket
.set_nodelay(true)
.context("failed to set socket option")?;
handle_client(config, &cancel_map, socket).await
}));
}
}
async fn handle_client(
config: &ProxyConfig,
cancel_map: &CancelMap,
stream: impl AsyncRead + AsyncWrite + Unpin,
) -> anyhow::Result<()> {
// The `closed` counter will increase when this future is destroyed.
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
scopeguard::defer! {
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
}
let tls = config.tls_config.clone();
if let Some((client, creds)) = handshake(stream, tls, cancel_map).await? {
cancel_map
.with_session(|session| async {
connect_client_to_db(config, session, client, creds).await
})
.await?;
}
Ok(())
}
/// Handle a connection from one client.
/// For better testing experience, `stream` can be
/// any object satisfying the traits.
async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
stream: S,
mut tls: Option<TlsConfig>,
cancel_map: &CancelMap,
) -> anyhow::Result<Option<(PqStream<Stream<S>>, auth::ClientCredentials)>> {
// Client may try upgrading to each protocol only once
let (mut tried_ssl, mut tried_gss) = (false, false);
let mut stream = PqStream::new(Stream::from_raw(stream));
loop {
let msg = stream.read_startup_packet().await?;
println!("got message: {:?}", msg);
use FeStartupPacket::*;
match msg {
SslRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_ssl => {
tried_ssl = true;
// We can't perform TLS handshake without a config
let enc = tls.is_some();
stream.write_message(&Be::EncryptionResponse(enc)).await?;
if let Some(tls) = tls.take() {
// Upgrade raw stream into a secure TLS-backed stream.
// NOTE: We've consumed `tls`; this fact will be used later.
stream = PqStream::new(stream.into_inner().upgrade(tls).await?);
}
}
_ => bail!("protocol violation"),
},
GssEncRequest => match stream.get_ref() {
Stream::Raw { .. } if !tried_gss => {
tried_gss = true;
// Currently, we don't support GSSAPI
stream.write_message(&Be::EncryptionResponse(false)).await?;
}
_ => bail!("protocol violation"),
},
StartupMessage { params, .. } => {
// Check that the config has been consumed during upgrade
// OR we didn't provide it at all (for dev purposes).
if tls.is_some() {
let msg = "connection is insecure (try using `sslmode=require`)";
stream.write_message(&Be::ErrorResponse(msg)).await?;
bail!(msg);
}
break Ok(Some((stream, params.try_into()?)));
}
CancelRequest(cancel_key_data) => {
cancel_map.cancel_session(cancel_key_data).await?;
break Ok(None);
}
impl CancelClosure {
async fn try_cancel_query(&self) {
if let Ok(socket) = tokio::net::TcpStream::connect(self.socket_addr).await {
// NOTE ignoring the result because:
// 1. This is a best effort attempt, the database doesn't have to listen
// 2. Being opaque about errors here helps avoid leaking info to unauthenticated user
let _ = self.cancel_token.cancel_query_raw(socket, NoTls).await;
}
}
}
async fn connect_client_to_db(
config: &ProxyConfig,
session: cancellation::Session<'_>,
mut client: PqStream<impl AsyncRead + AsyncWrite + Unpin>,
creds: auth::ClientCredentials,
lazy_static! {
// Enables serving CancelRequests
static ref CANCEL_MAP: Mutex<HashMap<CancelKeyData, CancelClosure>> = Mutex::new(HashMap::new());
// Metrics
static ref NUM_CONNECTIONS_ACCEPTED_COUNTER: IntCounter = register_int_counter!(
new_common_metric_name("num_connections_accepted"),
"Number of TCP client connections accepted."
).unwrap();
static ref NUM_CONNECTIONS_CLOSED_COUNTER: IntCounter = register_int_counter!(
new_common_metric_name("num_connections_closed"),
"Number of TCP client connections closed."
).unwrap();
static ref NUM_CONNECTIONS_FAILED_COUNTER: IntCounter = register_int_counter!(
new_common_metric_name("num_connections_failed"),
"Number of TCP client connections that closed due to error."
).unwrap();
static ref NUM_BYTES_PROXIED_COUNTER: IntCounter = register_int_counter!(
new_common_metric_name("num_bytes_proxied"),
"Number of bytes sent/received between any client and backend."
).unwrap();
}
thread_local! {
// Used to clean up the CANCEL_MAP. Might not be necessary if we use tokio thread pool in main loop.
static THREAD_CANCEL_KEY_DATA: Cell<Option<CancelKeyData>> = Cell::new(None);
}
///
/// Main proxy listener loop.
///
/// Listens for connections, and launches a new handler thread for each.
///
pub fn thread_main(
state: &'static ProxyState,
listener: std::net::TcpListener,
) -> anyhow::Result<()> {
let db_info = creds.authenticate(config, &mut client).await?;
let (db, version, cancel_closure) = connect_to_db(db_info).await?;
let cancel_key_data = session.enable_cancellation(cancel_closure);
loop {
let (socket, peer_addr) = listener.accept()?;
println!("accepted connection from {}", peer_addr);
NUM_CONNECTIONS_ACCEPTED_COUNTER.inc();
socket.set_nodelay(true).unwrap();
client
.write_message_noflush(&BeMessage::ParameterStatus(
BeParameterStatusMessage::ServerVersion(&version),
))?
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
.write_message(&BeMessage::ReadyForQuery)
.await?;
// TODO Use a threadpool instead. Maybe use tokio's threadpool by
// spawning a future into its runtime. Tokio's JoinError should
// allow us to handle cleanup properly even if the future panics.
thread::Builder::new()
.name("Proxy thread".into())
.spawn(move || {
if let Err(err) = proxy_conn_main(state, socket) {
NUM_CONNECTIONS_FAILED_COUNTER.inc();
println!("error: {}", err);
}
// This function will be called for writes to either direction.
fn inc_proxied(cnt: usize) {
// Consider inventing something more sophisticated
// if this ever becomes a bottleneck (cacheline bouncing).
NUM_BYTES_PROXIED_COUNTER.inc_by(cnt as u64);
// Clean up CANCEL_MAP.
NUM_CONNECTIONS_CLOSED_COUNTER.inc();
THREAD_CANCEL_KEY_DATA.with(|cell| {
if let Some(cancel_key_data) = cell.get() {
CANCEL_MAP.lock().unwrap().remove(&cancel_key_data);
};
});
})?;
}
}
// TODO: clean up fields
struct ProxyConnection {
state: &'static ProxyState,
psql_session_id: String,
pgb: PostgresBackend,
}
pub fn proxy_conn_main(state: &'static ProxyState, socket: TcpStream) -> anyhow::Result<()> {
let conn = ProxyConnection {
state,
psql_session_id: hex::encode(rand::random::<[u8; 8]>()),
pgb: PostgresBackend::new(
socket,
postgres_backend::AuthType::MD5,
state.conf.ssl_config.clone(),
false,
)?,
};
let (client, server) = match conn.handle_client()? {
Some(x) => x,
None => return Ok(()),
};
let server = zenith_utils::sock_split::BidiStream::from_tcp(server);
let client = match client {
Stream::Bidirectional(bidi_stream) => bidi_stream,
_ => panic!("invalid stream type"),
};
proxy(client.split(), server.split())
}
impl ProxyConnection {
/// Returns Ok(None) when connection was successfully closed.
fn handle_client(mut self) -> anyhow::Result<Option<(Stream, TcpStream)>> {
let mut authenticate = || {
let (username, dbname) = match self.handle_startup()? {
Some(x) => x,
None => return Ok(None),
};
// Both scenarios here should end up producing database credentials
if username.ends_with("@zenith") {
self.handle_existing_user(&username, &dbname).map(Some)
} else {
self.handle_new_user().map(Some)
}
};
let conn = match authenticate() {
Ok(Some(db_info)) => connect_to_db(db_info),
Ok(None) => return Ok(None),
Err(e) => {
// Report the error to the client
self.pgb.write_message(&Be::ErrorResponse(&e.to_string()))?;
bail!("failed to handle client: {:?}", e);
}
};
// We'll get rid of this once migration to async is complete
let (pg_version, db_stream) = {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let (pg_version, stream, cancel_key_data) = runtime.block_on(conn)?;
self.pgb
.write_message(&BeMessage::BackendKeyData(cancel_key_data))?;
let stream = stream.into_std()?;
stream.set_nonblocking(false)?;
(pg_version, stream)
};
// Let the client send new requests
self.pgb
.write_message_noflush(&BeMessage::ParameterStatus(
BeParameterStatusMessage::ServerVersion(&pg_version),
))?
.write_message(&Be::ReadyForQuery)?;
Ok(Some((self.pgb.into_stream(), db_stream)))
}
let mut db = MetricsStream::new(db, inc_proxied);
let mut client = MetricsStream::new(client.into_inner(), inc_proxied);
let _ = tokio::io::copy_bidirectional(&mut client, &mut db).await?;
/// Returns Ok(None) when connection was successfully closed.
fn handle_startup(&mut self) -> anyhow::Result<Option<(String, String)>> {
let have_tls = self.pgb.tls_config.is_some();
let mut encrypted = false;
loop {
let msg = match self.pgb.read_message()? {
Some(Fe::StartupPacket(msg)) => msg,
None => bail!("connection is lost"),
bad => bail!("unexpected message type: {:?}", bad),
};
println!("got message: {:?}", msg);
match msg {
FeStartupPacket::GssEncRequest => {
self.pgb.write_message(&Be::EncryptionResponse(false))?;
}
FeStartupPacket::SslRequest => {
self.pgb.write_message(&Be::EncryptionResponse(have_tls))?;
if have_tls {
self.pgb.start_tls()?;
encrypted = true;
}
}
FeStartupPacket::StartupMessage { mut params, .. } => {
if have_tls && !encrypted {
bail!("must connect with TLS");
}
let mut get_param = |key| {
params
.remove(key)
.with_context(|| format!("{} is missing in startup packet", key))
};
return Ok(Some((get_param("user")?, get_param("database")?)));
}
FeStartupPacket::CancelRequest(cancel_key_data) => {
if let Some(cancel_closure) = CANCEL_MAP.lock().unwrap().get(&cancel_key_data) {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
runtime.block_on(cancel_closure.try_cancel_query());
}
return Ok(None);
}
}
}
}
fn handle_existing_user(&mut self, user: &str, db: &str) -> anyhow::Result<DatabaseInfo> {
let md5_salt = rand::random::<[u8; 4]>();
// Ask password
self.pgb
.write_message(&Be::AuthenticationMD5Password(&md5_salt))?;
self.pgb.state = ProtoState::Authentication; // XXX
// Check password
let msg = match self.pgb.read_message()? {
Some(Fe::PasswordMessage(msg)) => msg,
None => bail!("connection is lost"),
bad => bail!("unexpected message type: {:?}", bad),
};
println!("got message: {:?}", msg);
let (_trailing_null, md5_response) = msg
.split_last()
.ok_or_else(|| anyhow!("unexpected password message"))?;
let cplane = CPlaneApi::new(&self.state.conf.auth_endpoint, &self.state.waiters);
let db_info = cplane.authenticate_proxy_request(
user,
db,
md5_response,
&md5_salt,
&self.psql_session_id,
)?;
self.pgb
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?;
Ok(db_info)
}
fn handle_new_user(&mut self) -> anyhow::Result<DatabaseInfo> {
let greeting = hello_message(&self.state.conf.redirect_uri, &self.psql_session_id);
// First, register this session
let waiter = self.state.waiters.register(self.psql_session_id.clone());
// Give user a URL to spawn a new database
self.pgb
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?
.write_message(&Be::NoticeResponse(greeting))?;
// Wait for web console response
let db_info = waiter.wait()?.map_err(|e| anyhow!(e))?;
self.pgb
.write_message_noflush(&Be::NoticeResponse("Connecting to database.".into()))?;
Ok(db_info)
}
}
fn hello_message(redirect_uri: &str, session_id: &str) -> String {
format!(
concat![
"☀️ Welcome to Zenith!\n",
"To proceed with database creation, open the following link:\n\n",
" {redirect_uri}{session_id}\n\n",
"It needs to be done once and we will send you '.pgpass' file,\n",
"which will allow you to access or create ",
"databases without opening your web browser."
],
redirect_uri = redirect_uri,
session_id = session_id,
)
}
/// Create a TCP connection to a postgres database, authenticate with it, and receive the ReadyForQuery message
async fn connect_to_db(
db_info: DatabaseInfo,
) -> anyhow::Result<(String, tokio::net::TcpStream, CancelKeyData)> {
// Make raw connection. When connect_raw finishes we've received ReadyForQuery.
let socket_addr = db_info.socket_addr()?;
let mut socket = tokio::net::TcpStream::connect(socket_addr).await?;
let config = tokio_postgres::Config::from(db_info);
// NOTE We effectively ignore some ParameterStatus and NoticeResponse
// messages here. Not sure if that could break something.
let (client, conn) = config.connect_raw(&mut socket, NoTls).await?;
// Save info for potentially cancelling the query later
let mut rng = StdRng::from_entropy();
let cancel_key_data = CancelKeyData {
// HACK We'd rather get the real backend_pid but tokio_postgres doesn't
// expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the
// actual backend_pid, but backend_pid is not used for anything
// so it doesn't matter.
backend_pid: rng.gen(),
cancel_key: rng.gen(),
};
let cancel_closure = CancelClosure {
socket_addr,
cancel_token: client.cancel_token(),
};
CANCEL_MAP
.lock()
.unwrap()
.insert(cancel_key_data, cancel_closure);
THREAD_CANCEL_KEY_DATA.with(|cell| {
let prev_value = cell.replace(Some(cancel_key_data));
assert!(
prev_value.is_none(),
"THREAD_CANCEL_KEY_DATA was already set"
);
});
let version = conn.parameter("server_version").unwrap();
Ok((version.into(), socket, cancel_key_data))
}
/// Concurrently proxy both directions of the client and server connections
fn proxy(
(client_read, client_write): (ReadStream, WriteStream),
(server_read, server_write): (ReadStream, WriteStream),
) -> anyhow::Result<()> {
fn do_proxy(mut reader: impl io::Read, mut writer: WriteStream) -> io::Result<u64> {
/// FlushWriter will make sure that every message is sent as soon as possible
struct FlushWriter<W>(W);
impl<W: io::Write> io::Write for FlushWriter<W> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
// `std::io::copy` is guaranteed to exit if we return an error,
// so we can afford to lose `res` in case `flush` fails
let res = self.0.write(buf);
if let Ok(count) = res {
NUM_BYTES_PROXIED_COUNTER.inc_by(count as u64);
self.flush()?;
}
res
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
let res = std::io::copy(&mut reader, &mut FlushWriter(&mut writer));
writer.shutdown(std::net::Shutdown::Both)?;
res
}
let client_to_server_jh = thread::spawn(move || do_proxy(client_read, server_write));
do_proxy(server_read, client_write)?;
client_to_server_jh.join().unwrap()?;
Ok(())
}
/// Connect to a corresponding compute node.
async fn connect_to_db(
db_info: DatabaseInfo,
) -> anyhow::Result<(TcpStream, String, CancelClosure)> {
// TODO: establish a secure connection to the DB
let socket_addr = db_info.socket_addr()?;
let mut socket = TcpStream::connect(socket_addr).await?;
let (client, conn) = tokio_postgres::Config::from(db_info)
.connect_raw(&mut socket, NoTls)
.await?;
let version = conn
.parameter("server_version")
.context("failed to fetch postgres server version")?
.into();
let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token());
Ok((socket, version, cancel_closure))
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::DuplexStream;
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres_rustls::MakeRustlsConnect;
async fn dummy_proxy(
client: impl AsyncRead + AsyncWrite + Unpin,
tls: Option<TlsConfig>,
) -> anyhow::Result<()> {
let cancel_map = CancelMap::default();
// TODO: add some infra + tests for credentials
let (mut stream, _creds) = handshake(client, tls, &cancel_map)
.await?
.context("no stream")?;
stream
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&BeParameterStatusMessage::encoding())?
.write_message(&BeMessage::ReadyForQuery)
.await?;
Ok(())
}
fn generate_certs(
hostname: &str,
) -> anyhow::Result<(rustls::Certificate, rustls::Certificate, rustls::PrivateKey)> {
let ca = rcgen::Certificate::from_params({
let mut params = rcgen::CertificateParams::default();
params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
params
})?;
let cert = rcgen::generate_simple_self_signed(vec![hostname.into()])?;
Ok((
rustls::Certificate(ca.serialize_der()?),
rustls::Certificate(cert.serialize_der_with_signer(&ca)?),
rustls::PrivateKey(cert.serialize_private_key_der()),
))
}
#[tokio::test]
async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let server_config = {
let (_ca, cert, key) = generate_certs("localhost")?;
let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
config.set_single_cert(vec![cert], key)?;
config
};
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into())));
tokio_postgres::Config::new()
.user("john_doe")
.dbname("earth")
.ssl_mode(SslMode::Disable)
.connect_raw(server, NoTls)
.await
.err() // -> Option<E>
.context("client shouldn't be able to connect")?;
proxy
.await?
.err() // -> Option<E>
.context("server shouldn't accept client")?;
Ok(())
}
#[tokio::test]
async fn handshake_tls() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let (ca, cert, key) = generate_certs("localhost")?;
let server_config = {
let mut config = rustls::ServerConfig::new(rustls::NoClientAuth::new());
config.set_single_cert(vec![cert], key)?;
config
};
let proxy = tokio::spawn(dummy_proxy(client, Some(server_config.into())));
let client_config = {
let mut config = rustls::ClientConfig::new();
config.root_store.add(&ca)?;
config
};
let mut mk = MakeRustlsConnect::new(client_config);
let tls = MakeTlsConnect::<DuplexStream>::make_tls_connect(&mut mk, "localhost")?;
let (_client, _conn) = tokio_postgres::Config::new()
.user("john_doe")
.dbname("earth")
.ssl_mode(SslMode::Require)
.connect_raw(server, tls)
.await?;
proxy.await?
}
#[tokio::test]
async fn handshake_raw() -> anyhow::Result<()> {
let (client, server) = tokio::io::duplex(1024);
let proxy = tokio::spawn(dummy_proxy(client, None));
let (_client, _conn) = tokio_postgres::Config::new()
.user("john_doe")
.dbname("earth")
.ssl_mode(SslMode::Prefer)
.connect_raw(server, NoTls)
.await?;
proxy.await?
}
}

View File

@@ -1,46 +1,15 @@
use crate::cplane_api::DatabaseInfo;
use anyhow::{anyhow, ensure, Context};
use rustls::{internal::pemfile, NoClientAuth, ProtocolVersion, ServerConfig};
use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::Arc;
pub type TlsConfig = Arc<ServerConfig>;
#[non_exhaustive]
pub enum ClientAuthMethod {
Password,
Link,
/// Use password auth only if username ends with "@zenith"
Mixed,
}
pub enum RouterConfig {
Static { host: String, port: u16 },
Dynamic(ClientAuthMethod),
}
impl FromStr for ClientAuthMethod {
type Err = anyhow::Error;
fn from_str(s: &str) -> anyhow::Result<Self> {
use ClientAuthMethod::*;
match s {
"password" => Ok(Password),
"link" => Ok(Link),
"mixed" => Ok(Mixed),
_ => Err(anyhow::anyhow!("Invalid option for router")),
}
}
}
pub type SslConfig = Arc<ServerConfig>;
pub struct ProxyConfig {
/// main entrypoint for users to connect to
pub proxy_address: SocketAddr,
/// method of assigning compute nodes
pub router_config: RouterConfig,
/// internally used for status and prometheus metrics
pub http_address: SocketAddr,
@@ -55,10 +24,26 @@ pub struct ProxyConfig {
/// control plane address where we would check auth.
pub auth_endpoint: String,
pub tls_config: Option<TlsConfig>,
pub ssl_config: Option<SslConfig>,
}
pub fn configure_ssl(key_path: &str, cert_path: &str) -> anyhow::Result<TlsConfig> {
pub type ProxyWaiters = crate::waiters::Waiters<Result<DatabaseInfo, String>>;
pub struct ProxyState {
pub conf: ProxyConfig,
pub waiters: ProxyWaiters,
}
impl ProxyState {
pub fn new(conf: ProxyConfig) -> Self {
Self {
conf,
waiters: ProxyWaiters::default(),
}
}
}
pub fn configure_ssl(key_path: &str, cert_path: &str) -> anyhow::Result<SslConfig> {
let key = {
let key_bytes = std::fs::read(key_path).context("SSL key file")?;
let mut keys = pemfile::pkcs8_private_keys(&mut &key_bytes[..])

View File

@@ -1,230 +0,0 @@
use anyhow::Context;
use bytes::BytesMut;
use pin_project_lite::pin_project;
use rustls::ServerConfig;
use std::pin::Pin;
use std::sync::Arc;
use std::{io, task};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio_rustls::server::TlsStream;
use zenith_utils::pq_proto::{BeMessage, FeMessage, FeStartupPacket};
pin_project! {
/// Stream wrapper which implements libpq's protocol.
/// NOTE: This object deliberately doesn't implement [`AsyncRead`]
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
/// to pass random malformed bytes through the connection).
pub struct PqStream<S> {
#[pin]
stream: S,
buffer: BytesMut,
}
}
impl<S> PqStream<S> {
/// Construct a new libpq protocol wrapper.
pub fn new(stream: S) -> Self {
Self {
stream,
buffer: Default::default(),
}
}
/// Extract the underlying stream.
pub fn into_inner(self) -> S {
self.stream
}
/// Get a reference to the underlying stream.
pub fn get_ref(&self) -> &S {
&self.stream
}
}
impl<S: AsyncRead + Unpin> PqStream<S> {
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
pub async fn read_startup_packet(&mut self) -> anyhow::Result<FeStartupPacket> {
match FeStartupPacket::read_fut(&mut self.stream).await? {
Some(FeMessage::StartupPacket(packet)) => Ok(packet),
None => anyhow::bail!("connection is lost"),
other => anyhow::bail!("bad message type: {:?}", other),
}
}
pub async fn read_message(&mut self) -> anyhow::Result<FeMessage> {
FeMessage::read_fut(&mut self.stream)
.await?
.context("connection is lost")
}
}
impl<S: AsyncWrite + Unpin> PqStream<S> {
/// Write the message into an internal buffer, but don't flush the underlying stream.
pub fn write_message_noflush<'a>(&mut self, message: &BeMessage<'a>) -> io::Result<&mut Self> {
BeMessage::write(&mut self.buffer, message)?;
Ok(self)
}
/// Write the message into an internal buffer and flush it.
pub async fn write_message<'a>(&mut self, message: &BeMessage<'a>) -> io::Result<&mut Self> {
self.write_message_noflush(message)?;
self.flush().await?;
Ok(self)
}
/// Flush the output buffer into the underlying stream.
pub async fn flush(&mut self) -> io::Result<&mut Self> {
self.stream.write_all(&self.buffer).await?;
self.buffer.clear();
self.stream.flush().await?;
Ok(self)
}
}
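A hedged usage sketch of the buffering API above: several backend messages can be queued with write_message_noflush and sent together by the final write_message, which flushes the whole buffer. The send_greeting function is illustrative; the message constructors are the ones used by the proxy handshake elsewhere in this diff.
// Illustrative only: assumes a connected stream and the Be/BeMessage/
// BeParameterStatusMessage helpers imported as in the proxy code above.
async fn send_greeting<S>(stream: S) -> anyhow::Result<()>
where
    S: tokio::io::AsyncWrite + Unpin,
{
    let mut pq = PqStream::new(stream);
    pq.write_message_noflush(&Be::AuthenticationOk)?
        .write_message_noflush(&BeParameterStatusMessage::encoding())?
        // write_message appends ReadyForQuery and flushes all three messages at once.
        .write_message(&BeMessage::ReadyForQuery)
        .await?;
    Ok(())
}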
pin_project! {
/// Wrapper for upgrading raw streams into secure streams.
/// NOTE: it should be possible to decompose this object as necessary.
#[project = StreamProj]
pub enum Stream<S> {
/// We always begin with a raw stream,
/// which may then be upgraded into a secure stream.
Raw { #[pin] raw: S },
/// We box [`TlsStream`] since it can be quite large.
Tls { #[pin] tls: Box<TlsStream<S>> },
}
}
impl<S> Stream<S> {
/// Construct a new instance from a raw stream.
pub fn from_raw(raw: S) -> Self {
Self::Raw { raw }
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> Stream<S> {
/// If possible, upgrade raw stream into a secure TLS-based stream.
pub async fn upgrade(self, cfg: Arc<ServerConfig>) -> anyhow::Result<Self> {
match self {
Stream::Raw { raw } => {
let tls = Box::new(tokio_rustls::TlsAcceptor::from(cfg).accept(raw).await?);
Ok(Stream::Tls { tls })
}
Stream::Tls { .. } => anyhow::bail!("can't upgrade TLS stream"),
}
}
}
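A hedged sketch of how this upgrade is meant to be driven; the real handshake also answers the client's SslRequest first, which is elided here, and the maybe_secure name is illustrative.
// Illustrative only: wrap an accepted connection and, if a TLS config is present,
// run the rustls server handshake via Stream::upgrade as defined above.
async fn maybe_secure<S>(
    raw: S,
    tls: Option<std::sync::Arc<rustls::ServerConfig>>,
) -> anyhow::Result<Stream<S>>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    let stream = Stream::from_raw(raw);
    match tls {
        Some(cfg) => stream.upgrade(cfg).await, // accept() performs the TLS handshake
        None => Ok(stream),
    }
}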
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Stream<S> {
fn poll_read(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_read(context, buf),
Tls { tls } => tls.poll_read(context, buf),
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Stream<S> {
fn poll_write(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<io::Result<usize>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_write(context, buf),
Tls { tls } => tls.poll_write(context, buf),
}
}
fn poll_flush(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_flush(context),
Tls { tls } => tls.poll_flush(context),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
use StreamProj::*;
match self.project() {
Raw { raw } => raw.poll_shutdown(context),
Tls { tls } => tls.poll_shutdown(context),
}
}
}
pin_project! {
/// This stream tracks all writes and calls user provided
/// callback when the underlying stream is flushed.
pub struct MetricsStream<S, W> {
#[pin]
stream: S,
write_count: usize,
inc_write_count: W,
}
}
impl<S, W> MetricsStream<S, W> {
pub fn new(stream: S, inc_write_count: W) -> Self {
Self {
stream,
write_count: 0,
inc_write_count,
}
}
}
impl<S: AsyncRead + Unpin, W> AsyncRead for MetricsStream<S, W> {
fn poll_read(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
self.project().stream.poll_read(context, buf)
}
}
impl<S: AsyncWrite + Unpin, W: FnMut(usize)> AsyncWrite for MetricsStream<S, W> {
fn poll_write(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
buf: &[u8],
) -> task::Poll<io::Result<usize>> {
let this = self.project();
this.stream.poll_write(context, buf).map_ok(|cnt| {
// Increment the write count.
*this.write_count += cnt;
cnt
})
}
fn poll_flush(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
let this = self.project();
this.stream.poll_flush(context).map_ok(|()| {
// Call the user provided callback and reset the write count.
(this.inc_write_count)(*this.write_count);
*this.write_count = 0;
})
}
fn poll_shutdown(
self: Pin<&mut Self>,
context: &mut task::Context<'_>,
) -> task::Poll<io::Result<()>> {
self.project().stream.poll_shutdown(context)
}
}
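A hedged usage sketch for MetricsStream: the callback fires only when the wrapped stream is flushed, receiving the number of bytes written since the previous flush. The duplex pipe and atomic counter below are illustrative scaffolding, not part of this change.
// Illustrative only: count proxied bytes at flush time, as the proxy metrics do.
async fn metrics_stream_demo() -> std::io::Result<()> {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use tokio::io::AsyncWriteExt;

    let flushed = Arc::new(AtomicUsize::new(0));
    let (near, _far) = tokio::io::duplex(64);

    let counter = Arc::clone(&flushed);
    let mut stream = MetricsStream::new(near, move |n| {
        counter.fetch_add(n, Ordering::Relaxed);
    });

    stream.write_all(b"hello").await?; // buffered; the callback has not fired yet
    stream.flush().await?;             // the callback now observes the 5 bytes written
    assert_eq!(flushed.load(Ordering::Relaxed), 5);
    Ok(())
}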

View File

@@ -1,12 +1,8 @@
use anyhow::{anyhow, Context};
use hashbrown::HashMap;
use parking_lot::Mutex;
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::task;
use tokio::sync::oneshot;
use anyhow::Context;
use std::collections::HashMap;
use std::sync::{mpsc, Mutex};
pub struct Waiters<T>(pub(self) Mutex<HashMap<String, oneshot::Sender<T>>>);
pub struct Waiters<T>(pub(self) Mutex<HashMap<String, mpsc::Sender<T>>>);
impl<T> Default for Waiters<T> {
fn default() -> Self {
@@ -15,86 +11,48 @@ impl<T> Default for Waiters<T> {
}
impl<T> Waiters<T> {
pub fn register(&self, key: String) -> anyhow::Result<Waiter<T>> {
let (tx, rx) = oneshot::channel();
pub fn register(&self, key: String) -> Waiter<T> {
let (tx, rx) = mpsc::channel();
self.0
.lock()
.try_insert(key.clone(), tx)
.map_err(|_| anyhow!("waiter already registered"))?;
// TODO: use `try_insert` (unstable)
let prev = self.0.lock().unwrap().insert(key.clone(), tx);
assert!(matches!(prev, None)); // assert_matches! is nightly-only
Ok(Waiter {
Waiter {
receiver: rx,
guard: DropKey {
registry: self,
key,
},
})
registry: self,
key,
}
}
pub fn notify(&self, key: &str, value: T) -> anyhow::Result<()>
where
T: Send + Sync,
T: Send + Sync + 'static,
{
let tx = self
.0
.lock()
.unwrap()
.remove(key)
.with_context(|| format!("key {} not found", key))?;
tx.send(value).map_err(|_| anyhow!("waiter channel hangup"))
tx.send(value).context("channel hangup")
}
}
struct DropKey<'a, T> {
key: String,
pub struct Waiter<'a, T> {
receiver: mpsc::Receiver<T>,
registry: &'a Waiters<T>,
key: String,
}
impl<'a, T> Drop for DropKey<'a, T> {
impl<T> Waiter<'_, T> {
pub fn wait(self) -> anyhow::Result<T> {
self.receiver.recv().context("channel hangup")
}
}
impl<T> Drop for Waiter<'_, T> {
fn drop(&mut self) {
self.registry.0.lock().remove(&self.key);
}
}
pin_project! {
pub struct Waiter<'a, T> {
#[pin]
receiver: oneshot::Receiver<T>,
guard: DropKey<'a, T>,
}
}
impl<T> std::future::Future for Waiter<'_, T> {
type Output = anyhow::Result<T>;
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
self.project()
.receiver
.poll(cx)
.map_err(|_| anyhow!("channel hangup"))
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
#[tokio::test]
async fn test_waiter() -> anyhow::Result<()> {
let waiters = Arc::new(Waiters::default());
let key = "Key";
let waiter = waiters.register(key.to_owned())?;
let waiters = Arc::clone(&waiters);
let notifier = tokio::spawn(async move {
waiters.notify(key, Default::default())?;
Ok(())
});
let () = waiter.await?;
notifier.await?
self.registry.0.lock().unwrap().remove(&self.key);
}
}
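A hedged usage sketch of the new synchronous Waiters API above: one thread registers a key and blocks in wait(), another thread resolves it with notify(). The key and value strings are illustrative.
// Illustrative only: the waiter parks until some other thread notifies its key.
fn waiters_demo() -> anyhow::Result<()> {
    use std::sync::Arc;

    let waiters: Arc<Waiters<String>> = Arc::new(Waiters::default());
    let waiter = waiters.register("psql_session_id".to_owned());

    let notifier = {
        let waiters = Arc::clone(&waiters);
        std::thread::spawn(move || waiters.notify("psql_session_id", "db info".to_owned()))
    };

    // Blocks until notify() delivers the value; dropping the waiter removes the map entry.
    assert_eq!(waiter.wait()?, "db info");
    notifier.join().unwrap()
}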

View File

@@ -21,6 +21,7 @@ types-psycopg2 = "^2.9.6"
boto3 = "^1.20.40"
boto3-stubs = "^1.20.40"
moto = {version = "^3.0.0", extras = ["server"]}
pytest-skip-slow = "^0.0.2"
[tool.poetry.dev-dependencies]
yapf = "==0.31.0"

View File

@@ -1,10 +1,12 @@
from contextlib import closing
from typing import Iterator
from uuid import UUID, uuid4
from uuid import uuid4
import psycopg2
from fixtures.zenith_fixtures import ZenithEnvBuilder, ZenithPageserverApiException
from fixtures.zenith_fixtures import ZenithEnvBuilder
import pytest
pytest_plugins = ("fixtures.zenith_fixtures")
def test_pageserver_auth(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.pageserver_auth_enabled = True
@@ -12,38 +14,32 @@ def test_pageserver_auth(zenith_env_builder: ZenithEnvBuilder):
ps = env.pageserver
tenant_token = env.auth_keys.generate_tenant_token(env.initial_tenant.hex)
tenant_http_client = env.pageserver.http_client(tenant_token)
tenant_token = env.auth_keys.generate_tenant_token(env.initial_tenant)
invalid_tenant_token = env.auth_keys.generate_tenant_token(uuid4().hex)
invalid_tenant_http_client = env.pageserver.http_client(invalid_tenant_token)
management_token = env.auth_keys.generate_management_token()
management_http_client = env.pageserver.http_client(management_token)
# this does not invoke auth check and only decodes jwt and checks it for validity
# check both tokens
ps.safe_psql("set FOO", password=tenant_token)
ps.safe_psql("set FOO", password=management_token)
ps.safe_psql("status", password=tenant_token)
ps.safe_psql("status", password=management_token)
# tenant can create branches
tenant_http_client.branch_create(env.initial_tenant, 'new1', 'main')
ps.safe_psql(f"branch_create {env.initial_tenant} new1 main", password=tenant_token)
# console can create branches for tenant
management_http_client.branch_create(env.initial_tenant, 'new2', 'main')
ps.safe_psql(f"branch_create {env.initial_tenant} new2 main", password=management_token)
# fail to create branch using token with different tenant_id
with pytest.raises(ZenithPageserverApiException,
match='Forbidden: Tenant id mismatch. Permission denied'):
invalid_tenant_http_client.branch_create(env.initial_tenant, "new3", "main")
# fail to create branch using token with different tenant_id
with pytest.raises(psycopg2.DatabaseError, match='Tenant id mismatch. Permission denied'):
ps.safe_psql(f"branch_create {env.initial_tenant} new2 main", password=invalid_tenant_token)
# create tenant using management token
management_http_client.tenant_create(uuid4())
ps.safe_psql(f"tenant_create {uuid4().hex}", password=management_token)
# fail to create tenant using tenant token
with pytest.raises(
ZenithPageserverApiException,
match='Forbidden: Attempt to access management api with tenant scope. Permission denied'
):
tenant_http_client.tenant_create(uuid4())
psycopg2.DatabaseError,
match='Attempt to access management api with tenant scope. Permission denied'):
ps.safe_psql(f"tenant_create {uuid4().hex}", password=tenant_token)
@pytest.mark.parametrize('with_wal_acceptors', [False, True])
@@ -54,7 +50,7 @@ def test_compute_auth_to_pageserver(zenith_env_builder: ZenithEnvBuilder, with_w
env = zenith_env_builder.init()
branch = f"test_compute_auth_to_pageserver{with_wal_acceptors}"
env.zenith_cli.create_branch(branch, "main")
env.zenith_cli(["branch", branch, "main"])
pg = env.postgres.create_start(branch)

View File

@@ -1,154 +0,0 @@
from contextlib import closing, contextmanager
import psycopg2.extras
from fixtures.zenith_fixtures import ZenithEnvBuilder
from fixtures.log_helper import log
import os
import time
import asyncpg
from fixtures.zenith_fixtures import Postgres
import threading
pytest_plugins = ("fixtures.zenith_fixtures")
@contextmanager
def pg_cur(pg):
with closing(pg.connect()) as conn:
with conn.cursor() as cur:
yield cur
# Periodically check that all backpressure lags are below the configured threshold,
# and fail an assertion if they are not.
# If the check query fails, stop the thread. The main thread should notice that and stop the test.
def check_backpressure(pg: Postgres, stop_event: threading.Event, polling_interval=5):
log.info("checks started")
with pg_cur(pg) as cur:
cur.execute("CREATE EXTENSION zenith") # TODO move it to zenith_fixtures?
cur.execute("select pg_size_bytes(current_setting('max_replication_write_lag'))")
res = cur.fetchone()
max_replication_write_lag_bytes = res[0]
log.info(f"max_replication_write_lag: {max_replication_write_lag_bytes} bytes")
cur.execute("select pg_size_bytes(current_setting('max_replication_flush_lag'))")
res = cur.fetchone()
max_replication_flush_lag_bytes = res[0]
log.info(f"max_replication_flush_lag: {max_replication_flush_lag_bytes} bytes")
cur.execute("select pg_size_bytes(current_setting('max_replication_apply_lag'))")
res = cur.fetchone()
max_replication_apply_lag_bytes = res[0]
log.info(f"max_replication_apply_lag: {max_replication_apply_lag_bytes} bytes")
with pg_cur(pg) as cur:
while not stop_event.is_set():
try:
cur.execute('''
select pg_wal_lsn_diff(pg_current_wal_flush_lsn(),received_lsn) as received_lsn_lag,
pg_wal_lsn_diff(pg_current_wal_flush_lsn(),disk_consistent_lsn) as disk_consistent_lsn_lag,
pg_wal_lsn_diff(pg_current_wal_flush_lsn(),remote_consistent_lsn) as remote_consistent_lsn_lag,
pg_size_pretty(pg_wal_lsn_diff(pg_current_wal_flush_lsn(),received_lsn)),
pg_size_pretty(pg_wal_lsn_diff(pg_current_wal_flush_lsn(),disk_consistent_lsn)),
pg_size_pretty(pg_wal_lsn_diff(pg_current_wal_flush_lsn(),remote_consistent_lsn))
from backpressure_lsns();
''')
res = cur.fetchone()
received_lsn_lag = res[0]
disk_consistent_lsn_lag = res[1]
remote_consistent_lsn_lag = res[2]
log.info(f"received_lsn_lag = {received_lsn_lag} ({res[3]}), "
f"disk_consistent_lsn_lag = {disk_consistent_lsn_lag} ({res[4]}), "
f"remote_consistent_lsn_lag = {remote_consistent_lsn_lag} ({res[5]})")
# Since feedback from pageserver is not immediate, we should allow some lag overflow
lag_overflow = 5 * 1024 * 1024 # 5MB
if max_replication_write_lag_bytes > 0:
assert received_lsn_lag < max_replication_write_lag_bytes + lag_overflow
if max_replication_flush_lag_bytes > 0:
assert disk_consistent_lsn_lag < max_replication_flush_lag_bytes + lag_overflow
if max_replication_apply_lag_bytes > 0:
assert remote_consistent_lsn_lag < max_replication_apply_lag_bytes + lag_overflow
time.sleep(polling_interval)
except Exception as e:
log.info(f"backpressure check query failed: {e}")
stop_event.set()
log.info('check thread stopped')
# This test illustrates how to tune backpressure to control the lag
# between the WAL flushed on compute node and WAL digested by pageserver.
#
# To test it, throttle walreceiver ingest using failpoint and run heavy write load.
# If backpressure is disabled or not tuned properly, the query will time out, because the walreceiver cannot keep up.
# If backpressure is enabled and tuned properly, insertion will be throttled, but the query will not time out.
def test_backpressure_received_lsn_lag(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 1
env = zenith_env_builder.init()
# Create a branch for us
env.zenith_cli.create_branch("test_backpressure", "main")
pg = env.postgres.create_start('test_backpressure',
config_lines=['max_replication_write_lag=30MB'])
log.info("postgres is running on 'test_backpressure' branch")
# setup check thread
check_stop_event = threading.Event()
check_thread = threading.Thread(target=check_backpressure, args=(pg, check_stop_event))
check_thread.start()
# Configure failpoint to slow down walreceiver ingest
with closing(env.pageserver.connect()) as psconn:
with psconn.cursor(cursor_factory=psycopg2.extras.DictCursor) as pscur:
pscur.execute("failpoints walreceiver-after-ingest=sleep(20)")
# FIXME
# Wait for the check thread to start
#
# Now if load starts too soon,
# check thread cannot auth, because it is not able to connect to the database
# because of the lag and waiting for lsn to replay to arrive.
time.sleep(2)
with pg_cur(pg) as cur:
# Create and initialize test table
cur.execute("CREATE TABLE foo(x bigint)")
inserts_to_do = 2000000
rows_inserted = 0
while check_thread.is_alive() and rows_inserted < inserts_to_do:
try:
cur.execute("INSERT INTO foo select from generate_series(1, 100000)")
rows_inserted += 100000
except Exception as e:
if check_thread.is_alive():
log.info('stopping check thread')
check_stop_event.set()
check_thread.join()
assert False, f"Exception {e} while inserting rows, but WAL lag is within configured threshold. That means backpressure is not tuned properly"
else:
assert False, f"Exception {e} while inserting rows and WAL lag overflowed configured threshold. That means backpressure doesn't work."
log.info(f"inserted {rows_inserted} rows")
if check_thread.is_alive():
log.info('stopping check thread')
check_stop_event.set()
check_thread.join()
log.info('check thread stopped')
else:
assert False, "WAL lag overflowed configured threshold. That means backpressure doesn't work."
#TODO test_backpressure_disk_consistent_lsn_lag. Play with pageserver's checkpoint settings
#TODO test_backpressure_remote_consistent_lsn_lag

View File

@@ -7,6 +7,8 @@ from fixtures.log_helper import log
from fixtures.utils import print_gc_result
from fixtures.zenith_fixtures import ZenithEnvBuilder
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Create a couple of branches off the main branch, at a historical point in time.
@@ -22,7 +24,7 @@ def test_branch_behind(zenith_env_builder: ZenithEnvBuilder):
env = zenith_env_builder.init()
# Branch at the point where only 100 rows were inserted
env.zenith_cli.create_branch("test_branch_behind", "main")
env.zenith_cli(["branch", "test_branch_behind", "main"])
pgmain = env.postgres.create_start('test_branch_behind')
log.info("postgres is running on 'test_branch_behind' branch")
@@ -60,7 +62,7 @@ def test_branch_behind(zenith_env_builder: ZenithEnvBuilder):
log.info(f'LSN after 200100 rows: {lsn_b}')
# Branch at the point where only 100 rows were inserted
env.zenith_cli.create_branch("test_branch_behind_hundred", "test_branch_behind@" + lsn_a)
env.zenith_cli(["branch", "test_branch_behind_hundred", "test_branch_behind@" + lsn_a])
# Insert many more rows. This generates enough WAL to fill a few segments.
main_cur.execute('''
@@ -75,7 +77,7 @@ def test_branch_behind(zenith_env_builder: ZenithEnvBuilder):
log.info(f'LSN after 400100 rows: {lsn_c}')
# Branch at the point where only 200100 rows were inserted
env.zenith_cli.create_branch("test_branch_behind_more", "test_branch_behind@" + lsn_b)
env.zenith_cli(["branch", "test_branch_behind_more", "test_branch_behind@" + lsn_b])
pg_hundred = env.postgres.create_start("test_branch_behind_hundred")
pg_more = env.postgres.create_start("test_branch_behind_more")
@@ -99,7 +101,7 @@ def test_branch_behind(zenith_env_builder: ZenithEnvBuilder):
# Check bad lsn's for branching
# branch at segment boundary
env.zenith_cli.create_branch("test_branch_segment_boundary", "test_branch_behind@0/3000000")
env.zenith_cli(["branch", "test_branch_segment_boundary", "test_branch_behind@0/3000000"])
pg = env.postgres.create_start("test_branch_segment_boundary")
cur = pg.connect().cursor()
cur.execute('SELECT 1')
@@ -107,23 +109,23 @@ def test_branch_behind(zenith_env_builder: ZenithEnvBuilder):
# branch at pre-initdb lsn
with pytest.raises(Exception, match="invalid branch start lsn"):
env.zenith_cli.create_branch("test_branch_preinitdb", "main@0/42")
env.zenith_cli(["branch", "test_branch_preinitdb", "main@0/42"])
# branch at pre-ancestor lsn
with pytest.raises(Exception, match="less than timeline ancestor lsn"):
env.zenith_cli.create_branch("test_branch_preinitdb", "test_branch_behind@0/42")
env.zenith_cli(["branch", "test_branch_preinitdb", "test_branch_behind@0/42"])
# check that we cannot create branch based on garbage collected data
with closing(env.pageserver.connect()) as psconn:
with psconn.cursor(cursor_factory=psycopg2.extras.DictCursor) as pscur:
# call gc to advance latest_gc_cutoff_lsn
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
row = pscur.fetchone()
print_gc_result(row)
with pytest.raises(Exception, match="invalid branch start lsn"):
# this gced_lsn is pretty random, so if gc is disabled this wouldn't fail
env.zenith_cli.create_branch("test_branch_create_fail", f"test_branch_behind@{gced_lsn}")
env.zenith_cli(["branch", "test_branch_create_fail", f"test_branch_behind@{gced_lsn}"])
# check that after gc everything is still there
hundred_cur.execute('SELECT count(*) FROM foo')

View File

@@ -6,13 +6,16 @@ from contextlib import closing
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Test compute node start after clog truncation
#
def test_clog_truncate(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_clog_truncate", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_clog_truncate", "empty"])
# set aggressive autovacuum to make sure that truncation will happen
config = [
@@ -62,8 +65,8 @@ def test_clog_truncate(zenith_simple_env: ZenithEnv):
# create new branch after clog truncation and start a compute node on it
log.info(f'create branch at lsn_after_truncation {lsn_after_truncation}')
env.zenith_cli.create_branch("test_clog_truncate_new",
"test_clog_truncate@" + lsn_after_truncation)
env.zenith_cli(
["branch", "test_clog_truncate_new", "test_clog_truncate@" + lsn_after_truncation])
pg2 = env.postgres.create_start('test_clog_truncate_new')
log.info('postgres is running on test_clog_truncate_new branch')

View File

@@ -3,13 +3,16 @@ from contextlib import closing
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Test starting Postgres with custom options
#
def test_config(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_config", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_config", "empty"])
# change config
pg = env.postgres.create_start('test_config', config_lines=['log_min_messages=debug1'])

View File

@@ -5,13 +5,15 @@ from contextlib import closing
from fixtures.zenith_fixtures import ZenithEnv, check_restored_datadir_content
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Test CREATE DATABASE when there have been relmapper changes
#
def test_createdb(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_createdb", "empty")
env.zenith_cli(["branch", "test_createdb", "empty"])
pg = env.postgres.create_start('test_createdb')
log.info("postgres is running on 'test_createdb' branch")
@@ -27,7 +29,7 @@ def test_createdb(zenith_simple_env: ZenithEnv):
lsn = cur.fetchone()[0]
# Create a branch
env.zenith_cli.create_branch("test_createdb2", "test_createdb@" + lsn)
env.zenith_cli(["branch", "test_createdb2", "test_createdb@" + lsn])
pg2 = env.postgres.create_start('test_createdb2')
@@ -41,7 +43,7 @@ def test_createdb(zenith_simple_env: ZenithEnv):
#
def test_dropdb(zenith_simple_env: ZenithEnv, test_output_dir):
env = zenith_simple_env
env.zenith_cli.create_branch("test_dropdb", "empty")
env.zenith_cli(["branch", "test_dropdb", "empty"])
pg = env.postgres.create_start('test_dropdb')
log.info("postgres is running on 'test_dropdb' branch")
@@ -66,10 +68,10 @@ def test_dropdb(zenith_simple_env: ZenithEnv, test_output_dir):
lsn_after_drop = cur.fetchone()[0]
# Create two branches before and after database drop.
env.zenith_cli.create_branch("test_before_dropdb", "test_dropdb@" + lsn_before_drop)
env.zenith_cli(["branch", "test_before_dropdb", "test_dropdb@" + lsn_before_drop])
pg_before = env.postgres.create_start('test_before_dropdb')
env.zenith_cli.create_branch("test_after_dropdb", "test_dropdb@" + lsn_after_drop)
env.zenith_cli(["branch", "test_after_dropdb", "test_dropdb@" + lsn_after_drop])
pg_after = env.postgres.create_start('test_after_dropdb')
# Test that database exists on the branch before drop

View File

@@ -3,13 +3,15 @@ from contextlib import closing
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Test CREATE USER to check shared catalog restore
#
def test_createuser(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_createuser", "empty")
env.zenith_cli(["branch", "test_createuser", "empty"])
pg = env.postgres.create_start('test_createuser')
log.info("postgres is running on 'test_createuser' branch")
@@ -25,7 +27,7 @@ def test_createuser(zenith_simple_env: ZenithEnv):
lsn = cur.fetchone()[0]
# Create a branch
env.zenith_cli.create_branch("test_createuser2", "test_createuser@" + lsn)
env.zenith_cli(["branch", "test_createuser2", "test_createuser@" + lsn])
pg2 = env.postgres.create_start('test_createuser2')

View File

@@ -7,6 +7,8 @@ import random
from fixtures.zenith_fixtures import ZenithEnv, Postgres, Safekeeper
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
# Test configuration
#
# Create a table with {num_rows} rows, and perform {updates_to_perform} random
@@ -34,7 +36,7 @@ async def gc(env: ZenithEnv, timeline: str):
psconn = await env.pageserver.connect_async()
while updates_performed < updates_to_perform:
await psconn.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
await psconn.execute(f"do_gc {env.initial_tenant} {timeline} 0")
# At the same time, run UPDATEs and GC
@@ -55,7 +57,9 @@ async def update_and_gc(env: ZenithEnv, pg: Postgres, timeline: str):
#
def test_gc_aggressive(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_gc_aggressive", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_gc_aggressive", "empty"])
pg = env.postgres.create_start('test_gc_aggressive')
log.info('postgres is running on test_gc_aggressive branch')

View File

@@ -1,6 +1,8 @@
from fixtures.zenith_fixtures import ZenithEnv, check_restored_datadir_content
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Test multixact state after branching
@@ -10,7 +12,8 @@ from fixtures.log_helper import log
#
def test_multixact(zenith_simple_env: ZenithEnv, test_output_dir):
env = zenith_simple_env
env.zenith_cli.create_branch("test_multixact", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_multixact", "empty"])
pg = env.postgres.create_start('test_multixact')
log.info("postgres is running on 'test_multixact' branch")
@@ -60,7 +63,7 @@ def test_multixact(zenith_simple_env: ZenithEnv, test_output_dir):
assert int(next_multixact_id) > int(next_multixact_id_old)
# Branch at this point
env.zenith_cli.create_branch("test_multixact_new", "test_multixact@" + lsn)
env.zenith_cli(["branch", "test_multixact_new", "test_multixact@" + lsn])
pg_new = env.postgres.create_start('test_multixact_new')
log.info("postgres is running on 'test_multixact_new' branch")

View File

@@ -5,6 +5,8 @@ import time
from fixtures.zenith_fixtures import ZenithEnvBuilder
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
# Test restarting page server, while safekeeper and compute node keep
# running.

View File

@@ -3,6 +3,8 @@ from contextlib import closing
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Test where Postgres generates a lot of WAL, and it's garbage collected away, but
@@ -16,7 +18,8 @@ from fixtures.log_helper import log
#
def test_old_request_lsn(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_old_request_lsn", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_old_request_lsn", "empty"])
pg = env.postgres.create_start('test_old_request_lsn')
log.info('postgres is running on test_old_request_lsn branch')
@@ -54,7 +57,7 @@ def test_old_request_lsn(zenith_simple_env: ZenithEnv):
# Make a lot of updates on a single row, generating a lot of WAL. Trigger
# garbage collections so that the page server will remove old page versions.
for i in range(10):
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
for j in range(100):
cur.execute('UPDATE foo SET val = val + 1 WHERE id = 1;')

View File

@@ -1,15 +1,95 @@
import json
from uuid import uuid4, UUID
import pytest
import psycopg2
import requests
from fixtures.zenith_fixtures import ZenithEnv, ZenithEnvBuilder, ZenithPageserverHttpClient
from typing import cast
import pytest, psycopg2
pytest_plugins = ("fixtures.zenith_fixtures")
def check_client(client: ZenithPageserverHttpClient, initial_tenant: UUID):
def test_status_psql(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
assert env.pageserver.safe_psql('status') == [
('hello world', ),
]
def test_branch_list_psql(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
# Create a branch for us
env.zenith_cli(["branch", "test_branch_list_main", "empty"])
conn = env.pageserver.connect()
cur = conn.cursor()
cur.execute(f'branch_list {env.initial_tenant}')
branches = json.loads(cur.fetchone()[0])
# Filter out branches created by other tests
branches = [x for x in branches if x['name'].startswith('test_branch_list')]
assert len(branches) == 1
assert branches[0]['name'] == 'test_branch_list_main'
assert 'timeline_id' in branches[0]
assert 'latest_valid_lsn' in branches[0]
assert 'ancestor_id' in branches[0]
assert 'ancestor_lsn' in branches[0]
# Create another branch, and start Postgres on it
env.zenith_cli(['branch', 'test_branch_list_experimental', 'test_branch_list_main'])
env.zenith_cli(['pg', 'create', 'test_branch_list_experimental'])
cur.execute(f'branch_list {env.initial_tenant}')
new_branches = json.loads(cur.fetchone()[0])
# Filter out branches created by other tests
new_branches = [x for x in new_branches if x['name'].startswith('test_branch_list')]
assert len(new_branches) == 2
new_branches.sort(key=lambda k: k['name'])
assert new_branches[0]['name'] == 'test_branch_list_experimental'
assert new_branches[0]['timeline_id'] != branches[0]['timeline_id']
# TODO: do the LSNs have to match here?
assert new_branches[1] == branches[0]
conn.close()
def test_tenant_list_psql(zenith_env_builder: ZenithEnvBuilder):
# don't use zenith_simple_env, because there might be other tenants there,
# left over from other tests.
env = zenith_env_builder.init()
res = env.zenith_cli(["tenant", "list"])
res.check_returncode()
tenants = sorted(map(lambda t: t.split()[0], res.stdout.splitlines()))
assert tenants == [env.initial_tenant]
conn = env.pageserver.connect()
cur = conn.cursor()
# check same tenant cannot be created twice
with pytest.raises(psycopg2.DatabaseError,
match=f'repo for {env.initial_tenant} already exists'):
cur.execute(f'tenant_create {env.initial_tenant}')
# create one more tenant
tenant1 = uuid4().hex
cur.execute(f'tenant_create {tenant1}')
cur.execute('tenant_list')
# compare tenants list
new_tenants = sorted(map(lambda t: cast(str, t['id']), json.loads(cur.fetchone()[0])))
assert sorted([env.initial_tenant, tenant1]) == new_tenants
def check_client(client: ZenithPageserverHttpClient, initial_tenant: str):
client.check_status()
# check initial tenant is there
assert initial_tenant.hex in {t['id'] for t in client.tenant_list()}
assert initial_tenant in {t['id'] for t in client.tenant_list()}
# create new tenant and check it is also there
tenant_id = uuid4()

View File

@@ -7,6 +7,8 @@ from multiprocessing import Process, Value
from fixtures.zenith_fixtures import ZenithEnvBuilder
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
# Test safekeeper sync and pageserver catch up
# while initial compute node is down and pageserver is lagging behind safekeepers.
@@ -16,7 +18,7 @@ def test_pageserver_catchup_while_compute_down(zenith_env_builder: ZenithEnvBuil
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_pageserver_catchup_while_compute_down", "main")
env.zenith_cli(["branch", "test_pageserver_catchup_while_compute_down", "main"])
pg = env.postgres.create_start('test_pageserver_catchup_while_compute_down')
pg_conn = pg.connect()

View File

@@ -7,6 +7,8 @@ from multiprocessing import Process, Value
from fixtures.zenith_fixtures import ZenithEnvBuilder
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
# Test restarting page server, while safekeeper and compute node keep
# running.
@@ -15,7 +17,7 @@ def test_pageserver_restart(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 1
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_pageserver_restart", "main")
env.zenith_cli(["branch", "test_pageserver_restart", "main"])
pg = env.postgres.create_start('test_pageserver_restart')
pg_conn = pg.connect()

View File

@@ -5,6 +5,8 @@ import subprocess
from fixtures.zenith_fixtures import ZenithEnv, Postgres
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
async def repeat_bytes(buf, repetitions: int):
for i in range(repetitions):
@@ -37,7 +39,9 @@ async def parallel_load_same_table(pg: Postgres, n_parallel: int):
# Load data into one table with COPY TO from 5 parallel connections
def test_parallel_copy(zenith_simple_env: ZenithEnv, n_parallel=5):
env = zenith_simple_env
env.zenith_cli.create_branch("test_parallel_copy", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_parallel_copy", "empty"])
pg = env.postgres.create_start('test_parallel_copy')
log.info("postgres is running on 'test_parallel_copy' branch")

View File

@@ -1,10 +1,14 @@
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
def test_pgbench(zenith_simple_env: ZenithEnv, pg_bin):
env = zenith_simple_env
env.zenith_cli.create_branch("test_pgbench", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_pgbench", "empty"])
pg = env.postgres.create_start('test_pgbench')
log.info("postgres is running on 'test_pgbench' branch")

View File

@@ -1,54 +0,0 @@
import pytest
import subprocess
import signal
import time
def test_proxy_select_1(static_proxy):
static_proxy.safe_psql("select 1;")
def test_proxy_cancel(static_proxy):
"""Test that we can cancel a big generate_series query."""
conn = static_proxy.connect()
conn.cancel()
with conn.cursor() as cur:
from psycopg2.errors import QueryCanceled
with pytest.raises(QueryCanceled):
cur.execute("select * from generate_series(1, 100000000);")
def test_proxy_pgbench_cancel(static_proxy, pg_bin):
"""Test that we can cancel the init phase of pgbench."""
start_time = static_proxy.safe_psql("select now();")[0]
def get_running_queries():
magic_string = "fsdsdfhdfhfgbcbfgbfgbf"
with static_proxy.connect() as conn:
with conn.cursor() as cur:
cur.execute(f"""
-- {magic_string}
select query
from pg_stat_activity
where pg_stat_activity.query_start > %s
""", start_time)
return [
row[0]
for row in cur.fetchall()
if magic_string not in row[0]
]
# Let pgbench init run for 1 second
p = subprocess.Popen(['pgbench', '-s500', '-i', static_proxy.connstr()])
time.sleep(1)
# Make sure something is still running, and that get_running_queries works
assert len(get_running_queries()) > 0
# Send sigint, which would cancel any pgbench queries
p.send_signal(signal.SIGINT)
# Assert that nothing is running
time.sleep(1)
assert len(get_running_queries()) == 0

View File

@@ -2,6 +2,8 @@ import pytest
from fixtures.log_helper import log
from fixtures.zenith_fixtures import ZenithEnv
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Create read-only compute nodes, anchored at historical points in time.
@@ -11,7 +13,7 @@ from fixtures.zenith_fixtures import ZenithEnv
#
def test_readonly_node(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_readonly_node", "empty")
env.zenith_cli(["branch", "test_readonly_node", "empty"])
pgmain = env.postgres.create_start('test_readonly_node')
log.info("postgres is running on 'test_readonly_node' branch")
@@ -86,5 +88,4 @@ def test_readonly_node(zenith_simple_env: ZenithEnv):
# Create node at pre-initdb lsn
with pytest.raises(Exception, match="invalid basebackup lsn"):
# compute node startup with invalid LSN should fail
env.zenith_cli.pg_start("test_readonly_node_preinitdb",
timeline_spec="test_readonly_node@0/42")
env.zenith_cli(["pg", "start", "test_readonly_node_preinitdb", "test_readonly_node@0/42"])

View File

@@ -9,6 +9,8 @@ from fixtures.zenith_fixtures import ZenithEnvBuilder
from fixtures.log_helper import log
import pytest
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Tests that a piece of data is backed up and restored correctly:

View File

@@ -4,6 +4,8 @@ from contextlib import closing
from fixtures.zenith_fixtures import ZenithEnvBuilder
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Test restarting and recreating a postgres instance
@@ -15,7 +17,7 @@ def test_restart_compute(zenith_env_builder: ZenithEnvBuilder, with_wal_acceptor
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_restart_compute", "main")
env.zenith_cli(["branch", "test_restart_compute", "main"])
pg = env.postgres.create_start('test_restart_compute')
log.info("postgres is running on 'test_restart_compute' branch")

View File

@@ -5,6 +5,8 @@ from fixtures.utils import print_gc_result
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Test Garbage Collection of old layer files
@@ -14,7 +16,7 @@ from fixtures.log_helper import log
#
def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_layerfiles_gc", "empty")
env.zenith_cli(["branch", "test_layerfiles_gc", "empty"])
pg = env.postgres.create_start('test_layerfiles_gc')
with closing(pg.connect()) as conn:
@@ -48,7 +50,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
cur.execute("DELETE FROM foo")
log.info("Running GC before test")
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
row = pscur.fetchone()
print_gc_result(row)
# remember the number of files
@@ -61,7 +63,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
# removing the old image and delta layer.
log.info("Inserting one row and running GC")
cur.execute("INSERT INTO foo VALUES (1)")
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
row = pscur.fetchone()
print_gc_result(row)
assert row['layer_relfiles_total'] == layer_relfiles_remain + 2
@@ -75,7 +77,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
cur.execute("INSERT INTO foo VALUES (2)")
cur.execute("INSERT INTO foo VALUES (3)")
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
row = pscur.fetchone()
print_gc_result(row)
assert row['layer_relfiles_total'] == layer_relfiles_remain + 2
@@ -87,7 +89,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
cur.execute("INSERT INTO foo VALUES (2)")
cur.execute("INSERT INTO foo VALUES (3)")
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
row = pscur.fetchone()
print_gc_result(row)
assert row['layer_relfiles_total'] == layer_relfiles_remain + 2
@@ -96,7 +98,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
# Run GC again, with no changes in the database. Should not remove anything.
log.info("Run GC again, with nothing to do")
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
row = pscur.fetchone()
print_gc_result(row)
assert row['layer_relfiles_total'] == layer_relfiles_remain
@@ -109,7 +111,7 @@ def test_layerfiles_gc(zenith_simple_env: ZenithEnv):
log.info("Drop table and run GC again")
cur.execute("DROP TABLE foo")
pscur.execute(f"do_gc {env.initial_tenant.hex} {timeline} 0")
pscur.execute(f"do_gc {env.initial_tenant} {timeline} 0")
row = pscur.fetchone()
print_gc_result(row)

View File

@@ -1,6 +1,8 @@
from fixtures.zenith_fixtures import ZenithEnv, check_restored_datadir_content
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
# Test subtransactions
#
@@ -10,7 +12,8 @@ from fixtures.log_helper import log
# CLOG.
def test_subxacts(zenith_simple_env: ZenithEnv, test_output_dir):
env = zenith_simple_env
env.zenith_cli.create_branch("test_subxacts", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_subxacts", "empty"])
pg = env.postgres.create_start('test_subxacts')
log.info("postgres is running on 'test_subxacts' branch")

View File

@@ -108,8 +108,8 @@ def load(pg: Postgres, stop_event: threading.Event, load_ok_event: threading.Eve
log.info('load thread stopped')
def assert_local(pageserver_http_client: ZenithPageserverHttpClient, tenant: UUID, timeline: str):
timeline_detail = pageserver_http_client.timeline_detail(tenant, UUID(timeline))
def assert_local(pageserver_http_client: ZenithPageserverHttpClient, tenant: str, timeline: str):
timeline_detail = pageserver_http_client.timeline_detail(UUID(tenant), UUID(timeline))
assert timeline_detail.get('type') == "Local", timeline_detail
return timeline_detail
@@ -127,10 +127,10 @@ def test_tenant_relocation(zenith_env_builder: ZenithEnvBuilder,
# create folder for remote storage mock
remote_storage_mock_path = env.repo_dir / 'local_fs_remote_storage'
tenant = env.create_tenant(UUID("74ee8b079a0e437eb0afea7d26a07209"))
tenant = env.create_tenant("74ee8b079a0e437eb0afea7d26a07209")
log.info("tenant to relocate %s", tenant)
env.zenith_cli.create_branch("test_tenant_relocation", "main", tenant_id=tenant)
env.zenith_cli(["branch", "test_tenant_relocation", "main", f"--tenantid={tenant}"])
tenant_pg = env.postgres.create_start(
"test_tenant_relocation",
@@ -167,11 +167,11 @@ def test_tenant_relocation(zenith_env_builder: ZenithEnvBuilder,
# run checkpoint manually to be sure that data landed in remote storage
with closing(env.pageserver.connect()) as psconn:
with psconn.cursor() as pscur:
pscur.execute(f"do_gc {tenant.hex} {timeline}")
pscur.execute(f"do_gc {tenant} {timeline}")
# ensure upload is completed
pageserver_http_client = env.pageserver.http_client()
timeline_detail = pageserver_http_client.timeline_detail(tenant, UUID(timeline))
timeline_detail = pageserver_http_client.timeline_detail(UUID(tenant), UUID(timeline))
assert timeline_detail['disk_consistent_lsn'] == timeline_detail['timeline_state']['Ready']
log.info("inititalizing new pageserver")
@@ -194,7 +194,7 @@ def test_tenant_relocation(zenith_env_builder: ZenithEnvBuilder,
new_pageserver_http_port):
# call to attach timeline to new pageserver
new_pageserver_http_client.timeline_attach(tenant, UUID(timeline))
new_pageserver_http_client.timeline_attach(UUID(tenant), UUID(timeline))
# FIXME cannot handle duplicate download requests, subject to fix in https://github.com/zenithdb/zenith/issues/997
time.sleep(5)
# new pageserver should in sync (modulo wal tail or vacuum activity) with the old one because there was no new writes since checkpoint
@@ -241,7 +241,7 @@ def test_tenant_relocation(zenith_env_builder: ZenithEnvBuilder,
# detach tenant from old pageserver before we check
# that all the data is there to be sure that old pageserver
# is no longer involved, and if it is, we will see the errors
pageserver_http_client.timeline_detach(tenant, UUID(timeline))
pageserver_http_client.timeline_detach(UUID(tenant), UUID(timeline))
with pg_cur(tenant_pg) as cur:
# check that data is still there

View File

@@ -15,12 +15,18 @@ def test_tenants_normal_work(zenith_env_builder: ZenithEnvBuilder, with_wal_acce
tenant_1 = env.create_tenant()
tenant_2 = env.create_tenant()
env.zenith_cli.create_branch(f"test_tenants_normal_work_with_wal_acceptors{with_wal_acceptors}",
"main",
tenant_id=tenant_1)
env.zenith_cli.create_branch(f"test_tenants_normal_work_with_wal_acceptors{with_wal_acceptors}",
"main",
tenant_id=tenant_2)
env.zenith_cli([
"branch",
f"test_tenants_normal_work_with_wal_acceptors{with_wal_acceptors}",
"main",
f"--tenantid={tenant_1}"
])
env.zenith_cli([
"branch",
f"test_tenants_normal_work_with_wal_acceptors{with_wal_acceptors}",
"main",
f"--tenantid={tenant_2}"
])
pg_tenant1 = env.postgres.create_start(
f"test_tenants_normal_work_with_wal_acceptors{with_wal_acceptors}",

View File

@@ -10,10 +10,10 @@ import time
def test_timeline_size(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
# Branch at the point where only 100 rows were inserted
env.zenith_cli.create_branch("test_timeline_size", "empty")
env.zenith_cli(["branch", "test_timeline_size", "empty"])
client = env.pageserver.http_client()
res = client.branch_detail(env.initial_tenant, "test_timeline_size")
res = client.branch_detail(UUID(env.initial_tenant), "test_timeline_size")
assert res["current_logical_size"] == res["current_logical_size_non_incremental"]
pgmain = env.postgres.create_start("test_timeline_size")
@@ -31,36 +31,36 @@ def test_timeline_size(zenith_simple_env: ZenithEnv):
FROM generate_series(1, 10) g
""")
res = client.branch_detail(env.initial_tenant, "test_timeline_size")
res = client.branch_detail(UUID(env.initial_tenant), "test_timeline_size")
assert res["current_logical_size"] == res["current_logical_size_non_incremental"]
cur.execute("TRUNCATE foo")
res = client.branch_detail(env.initial_tenant, "test_timeline_size")
res = client.branch_detail(UUID(env.initial_tenant), "test_timeline_size")
assert res["current_logical_size"] == res["current_logical_size_non_incremental"]
# wait until received_lsn_lag is 0
# wait until write_lag is 0
def wait_for_pageserver_catchup(pgmain: Postgres, polling_interval=1, timeout=60):
started_at = time.time()
received_lsn_lag = 1
while received_lsn_lag > 0:
write_lag = 1
while write_lag > 0:
elapsed = time.time() - started_at
if elapsed > timeout:
raise RuntimeError(
f"timed out waiting for pageserver to reach pg_current_wal_flush_lsn()")
raise RuntimeError(f"timed out waiting for pageserver to reach pg_current_wal_lsn()")
with closing(pgmain.connect()) as conn:
with conn.cursor() as cur:
cur.execute('''
select pg_size_pretty(pg_cluster_size()),
pg_wal_lsn_diff(pg_current_wal_flush_lsn(),received_lsn) as received_lsn_lag
FROM backpressure_lsns();
pg_wal_lsn_diff(pg_current_wal_lsn(),write_lsn) as write_lag,
pg_wal_lsn_diff(pg_current_wal_lsn(),sent_lsn) as pending_lag
FROM pg_stat_get_wal_senders();
''')
res = cur.fetchone()
log.info(f"pg_cluster_size = {res[0]}, received_lsn_lag = {res[1]}")
received_lsn_lag = res[1]
log.info(
f"pg_cluster_size = {res[0]}, write_lag = {res[1]}, pending_lag = {res[2]}")
write_lag = res[1]
time.sleep(polling_interval)
@@ -68,10 +68,10 @@ def wait_for_pageserver_catchup(pgmain: Postgres, polling_interval=1, timeout=60
def test_timeline_size_quota(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 1
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_timeline_size_quota", "main")
env.zenith_cli(["branch", "test_timeline_size_quota", "main"])
client = env.pageserver.http_client()
res = client.branch_detail(env.initial_tenant, "test_timeline_size_quota")
res = client.branch_detail(UUID(env.initial_tenant), "test_timeline_size_quota")
assert res["current_logical_size"] == res["current_logical_size_non_incremental"]
pgmain = env.postgres.create_start(

View File

@@ -3,13 +3,15 @@ import os
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Test branching, when a transaction is in prepared state
#
def test_twophase(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_twophase", "empty")
env.zenith_cli(["branch", "test_twophase", "empty"])
pg = env.postgres.create_start('test_twophase', config_lines=['max_prepared_transactions=5'])
log.info("postgres is running on 'test_twophase' branch")
@@ -56,7 +58,7 @@ def test_twophase(zenith_simple_env: ZenithEnv):
assert len(twophase_files) == 2
# Create a branch with the transaction in prepared state
env.zenith_cli.create_branch("test_twophase_prepared", "test_twophase")
env.zenith_cli(["branch", "test_twophase_prepared", "test_twophase"])
# Start compute on the new branch
pg2 = env.postgres.create_start(

View File

@@ -1,6 +1,8 @@
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
#
# Test that the VM bit is cleared correctly at a HEAP_DELETE and
@@ -9,7 +11,8 @@ from fixtures.log_helper import log
def test_vm_bit_clear(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
env.zenith_cli.create_branch("test_vm_bit_clear", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_vm_bit_clear", "empty"])
pg = env.postgres.create_start('test_vm_bit_clear')
log.info("postgres is running on 'test_vm_bit_clear' branch")
@@ -33,7 +36,7 @@ def test_vm_bit_clear(zenith_simple_env: ZenithEnv):
cur.execute('UPDATE vmtest_update SET id = 5000 WHERE id = 1')
# Branch at this point, to test that later
env.zenith_cli.create_branch("test_vm_bit_clear_new", "test_vm_bit_clear")
env.zenith_cli(["branch", "test_vm_bit_clear_new", "test_vm_bit_clear"])
# Clear the buffer cache, to force the VM page to be re-fetched from
# the page server

View File

@@ -17,6 +17,8 @@ from fixtures.utils import lsn_to_hex, mkdir_if_needed
from fixtures.log_helper import log
from typing import List, Optional, Any
pytest_plugins = ("fixtures.zenith_fixtures")
# basic test, write something in setup with wal acceptors, ensure that commits
# succeed and data is written
@@ -24,7 +26,7 @@ def test_normal_work(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_wal_acceptors_normal_work", "main")
env.zenith_cli(["branch", "test_wal_acceptors_normal_work", "main"])
pg = env.postgres.create_start('test_wal_acceptors_normal_work')
@@ -60,10 +62,10 @@ def test_many_timelines(zenith_env_builder: ZenithEnvBuilder):
# start postgres on each timeline
pgs = []
for branch in branches:
env.zenith_cli.create_branch(branch, "main")
env.zenith_cli(["branch", branch, "main"])
pgs.append(env.postgres.create_start(branch))
tenant_id = env.initial_tenant
tenant_id = uuid.UUID(env.initial_tenant)
def collect_metrics(message: str) -> List[BranchMetrics]:
with env.pageserver.http_client() as pageserver_http:
@@ -90,8 +92,8 @@ def test_many_timelines(zenith_env_builder: ZenithEnvBuilder):
latest_valid_lsn=branch_detail["latest_valid_lsn"],
)
for sk_m in sk_metrics:
m.flush_lsns.append(sk_m.flush_lsn_inexact[(tenant_id.hex, timeline_id)])
m.commit_lsns.append(sk_m.commit_lsn_inexact[(tenant_id.hex, timeline_id)])
m.flush_lsns.append(sk_m.flush_lsn_inexact[timeline_id])
m.commit_lsns.append(sk_m.commit_lsn_inexact[timeline_id])
for flush_lsn, commit_lsn in zip(m.flush_lsns, m.commit_lsns):
# Invariant. May be < when transaction is in progress.
@@ -183,7 +185,7 @@ def test_restarts(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = n_acceptors
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_wal_acceptors_restarts", "main")
env.zenith_cli(["branch", "test_wal_acceptors_restarts", "main"])
pg = env.postgres.create_start('test_wal_acceptors_restarts')
# we rely upon autocommit after each statement
@@ -220,7 +222,7 @@ def test_unavailability(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 2
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_wal_acceptors_unavailability", "main")
env.zenith_cli(["branch", "test_wal_acceptors_unavailability", "main"])
pg = env.postgres.create_start('test_wal_acceptors_unavailability')
# we rely upon autocommit after each statement
@@ -291,7 +293,7 @@ def test_race_conditions(zenith_env_builder: ZenithEnvBuilder, stop_value):
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_wal_acceptors_race_conditions", "main")
env.zenith_cli(["branch", "test_wal_acceptors_race_conditions", "main"])
pg = env.postgres.create_start('test_wal_acceptors_race_conditions')
# we rely upon autocommit after each statement
@@ -319,16 +321,16 @@ class ProposerPostgres(PgProtocol):
def __init__(self,
pgdata_dir: str,
pg_bin,
timeline_id: uuid.UUID,
tenant_id: uuid.UUID,
timeline_id: str,
tenant_id: str,
listen_addr: str,
port: int):
super().__init__(host=listen_addr, port=port, username='zenith_admin')
self.pgdata_dir: str = pgdata_dir
self.pg_bin: PgBin = pg_bin
self.timeline_id: uuid.UUID = timeline_id
self.tenant_id: uuid.UUID = tenant_id
self.timeline_id: str = timeline_id
self.tenant_id: str = tenant_id
self.listen_addr: str = listen_addr
self.port: int = port
@@ -348,8 +350,8 @@ class ProposerPostgres(PgProtocol):
cfg = [
"synchronous_standby_names = 'walproposer'\n",
"shared_preload_libraries = 'zenith'\n",
f"zenith.zenith_timeline = '{self.timeline_id.hex}'\n",
f"zenith.zenith_tenant = '{self.tenant_id.hex}'\n",
f"zenith.zenith_timeline = '{self.timeline_id}'\n",
f"zenith.zenith_tenant = '{self.tenant_id}'\n",
f"zenith.page_server_connstring = ''\n",
f"wal_acceptors = '{wal_acceptors}'\n",
f"listen_addresses = '{self.listen_addr}'\n",
@@ -406,8 +408,8 @@ def test_sync_safekeepers(zenith_env_builder: ZenithEnvBuilder,
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init()
timeline_id = uuid.uuid4()
tenant_id = uuid.uuid4()
timeline_id = uuid.uuid4().hex
tenant_id = uuid.uuid4().hex
# write config for proposer
pgdata_dir = os.path.join(env.repo_dir, "proposer_pgdata")
@@ -456,7 +458,7 @@ def test_timeline_status(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 1
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_timeline_status", "main")
env.zenith_cli(["branch", "test_timeline_status", "main"])
pg = env.postgres.create_start('test_timeline_status')
wa = env.safekeepers[0]
@@ -493,15 +495,15 @@ class SafekeeperEnv:
self.bin_safekeeper = os.path.join(str(zenith_binpath), 'safekeeper')
self.safekeepers: Optional[List[subprocess.CompletedProcess[Any]]] = None
self.postgres: Optional[ProposerPostgres] = None
self.tenant_id: Optional[uuid.UUID] = None
self.timeline_id: Optional[uuid.UUID] = None
self.tenant_id: Optional[str] = None
self.timeline_id: Optional[str] = None
def init(self) -> "SafekeeperEnv":
assert self.postgres is None, "postgres is already initialized"
assert self.safekeepers is None, "safekeepers are already initialized"
self.timeline_id = uuid.uuid4()
self.tenant_id = uuid.uuid4()
self.timeline_id = uuid.uuid4().hex
self.tenant_id = uuid.uuid4().hex
mkdir_if_needed(str(self.repo_dir))
# Create config and a Safekeeper object for each safekeeper
@@ -634,7 +636,7 @@ def test_replace_safekeeper(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 4
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_replace_safekeeper", "main")
env.zenith_cli(["branch", "test_replace_safekeeper", "main"])
log.info("Use only first 3 safekeepers")
env.safekeepers[3].stop()

View File

@@ -9,6 +9,7 @@ from fixtures.utils import lsn_from_hex, lsn_to_hex
from typing import List
log = getLogger('root.wal_acceptor_async')
pytest_plugins = ("fixtures.zenith_fixtures")
class BankClient(object):
@@ -202,7 +203,7 @@ def test_restarts_under_load(zenith_env_builder: ZenithEnvBuilder):
zenith_env_builder.num_safekeepers = 3
env = zenith_env_builder.init()
env.zenith_cli.create_branch("test_wal_acceptors_restarts_under_load", "main")
env.zenith_cli(["branch", "test_wal_acceptors_restarts_under_load", "main"])
pg = env.postgres.create_start('test_wal_acceptors_restarts_under_load')
asyncio.run(run_restarts_under_load(pg, env.safekeepers))

View File

@@ -3,26 +3,30 @@ import uuid
import requests
from psycopg2.extensions import cursor as PgCursor
from fixtures.zenith_fixtures import ZenithEnv, ZenithEnvBuilder, ZenithPageserverHttpClient
from fixtures.zenith_fixtures import ZenithEnv, ZenithEnvBuilder
from typing import cast
pytest_plugins = ("fixtures.zenith_fixtures")
def helper_compare_branch_list(pageserver_http_client: ZenithPageserverHttpClient,
env: ZenithEnv,
initial_tenant: uuid.UUID):
def helper_compare_branch_list(page_server_cur: PgCursor, env: ZenithEnv, initial_tenant: str):
"""
Compare branches list returned by CLI and directly via API.
Filters out branches created by other tests.
"""
branches = pageserver_http_client.branch_list(initial_tenant)
branches_api = sorted(map(lambda b: cast(str, b['name']), branches))
page_server_cur.execute(f'branch_list {initial_tenant}')
branches_api = sorted(
map(lambda b: cast(str, b['name']), json.loads(page_server_cur.fetchone()[0])))
branches_api = [b for b in branches_api if b.startswith('test_cli_') or b in ('empty', 'main')]
res = env.zenith_cli.list_branches()
res = env.zenith_cli(["branch"])
res.check_returncode()
branches_cli = sorted(map(lambda b: b.split(':')[-1].strip(), res.stdout.strip().split("\n")))
branches_cli = [b for b in branches_cli if b.startswith('test_cli_') or b in ('empty', 'main')]
res = env.zenith_cli.list_branches(tenant_id=initial_tenant)
res = env.zenith_cli(["branch", f"--tenantid={initial_tenant}"])
res.check_returncode()
branches_cli_with_tenant_arg = sorted(
map(lambda b: b.split(':')[-1].strip(), res.stdout.strip().split("\n")))
branches_cli_with_tenant_arg = [
@@ -34,20 +38,24 @@ def helper_compare_branch_list(pageserver_http_client: ZenithPageserverHttpClien
def test_cli_branch_list(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
pageserver_http_client = env.pageserver.http_client()
page_server_conn = env.pageserver.connect()
page_server_cur = page_server_conn.cursor()
# Initial sanity check
helper_compare_branch_list(pageserver_http_client, env, env.initial_tenant)
env.zenith_cli.create_branch("test_cli_branch_list_main", "empty")
helper_compare_branch_list(pageserver_http_client, env, env.initial_tenant)
helper_compare_branch_list(page_server_cur, env, env.initial_tenant)
# Create a branch for us
res = env.zenith_cli(["branch", "test_cli_branch_list_main", "empty"])
assert res.stderr == ''
helper_compare_branch_list(page_server_cur, env, env.initial_tenant)
# Create a nested branch
res = env.zenith_cli.create_branch("test_cli_branch_list_nested", "test_cli_branch_list_main")
res = env.zenith_cli(["branch", "test_cli_branch_list_nested", "test_cli_branch_list_main"])
assert res.stderr == ''
helper_compare_branch_list(pageserver_http_client, env, env.initial_tenant)
helper_compare_branch_list(page_server_cur, env, env.initial_tenant)
# Check that all new branches are visible via CLI
res = env.zenith_cli.list_branches()
res = env.zenith_cli(["branch"])
assert res.stderr == ''
branches_cli = sorted(map(lambda b: b.split(':')[-1].strip(), res.stdout.strip().split("\n")))
@@ -55,11 +63,12 @@ def test_cli_branch_list(zenith_simple_env: ZenithEnv):
assert 'test_cli_branch_list_nested' in branches_cli
def helper_compare_tenant_list(pageserver_http_client: ZenithPageserverHttpClient, env: ZenithEnv):
tenants = pageserver_http_client.tenant_list()
tenants_api = sorted(map(lambda t: cast(str, t['id']), tenants))
def helper_compare_tenant_list(page_server_cur: PgCursor, env: ZenithEnv):
page_server_cur.execute(f'tenant_list')
tenants_api = sorted(
map(lambda t: cast(str, t['id']), json.loads(page_server_cur.fetchone()[0])))
res = env.zenith_cli.list_tenants()
res = env.zenith_cli(["tenant", "list"])
assert res.stderr == ''
tenants_cli = sorted(map(lambda t: t.split()[0], res.stdout.splitlines()))
@@ -68,30 +77,35 @@ def helper_compare_tenant_list(pageserver_http_client: ZenithPageserverHttpClien
def test_cli_tenant_list(zenith_simple_env: ZenithEnv):
env = zenith_simple_env
pageserver_http_client = env.pageserver.http_client()
page_server_conn = env.pageserver.connect()
page_server_cur = page_server_conn.cursor()
# Initial sanity check
helper_compare_tenant_list(pageserver_http_client, env)
helper_compare_tenant_list(page_server_cur, env)
# Create new tenant
tenant1 = uuid.uuid4()
env.zenith_cli.create_tenant(tenant1)
tenant1 = uuid.uuid4().hex
res = env.zenith_cli(["tenant", "create", tenant1])
res.check_returncode()
# check tenant1 appeared
helper_compare_tenant_list(pageserver_http_client, env)
helper_compare_tenant_list(page_server_cur, env)
# Create new tenant
tenant2 = uuid.uuid4()
env.zenith_cli.create_tenant(tenant2)
tenant2 = uuid.uuid4().hex
res = env.zenith_cli(["tenant", "create", tenant2])
res.check_returncode()
# check tenant2 appeared
helper_compare_tenant_list(pageserver_http_client, env)
helper_compare_tenant_list(page_server_cur, env)
res = env.zenith_cli.list_tenants()
res = env.zenith_cli(["tenant", "list"])
res.check_returncode()
tenants = sorted(map(lambda t: t.split()[0], res.stdout.splitlines()))
assert env.initial_tenant.hex in tenants
assert tenant1.hex in tenants
assert tenant2.hex in tenants
assert env.initial_tenant in tenants
assert tenant1 in tenants
assert tenant2 in tenants
def test_cli_ipv4_listeners(zenith_env_builder: ZenithEnvBuilder):
@@ -109,21 +123,3 @@ def test_cli_ipv4_listeners(zenith_env_builder: ZenithEnvBuilder):
# Connect to ps port on v4 loopback
# res = requests.get(f'http://127.0.0.1:{env.pageserver.service_port.http}/v1/status')
# assert res.ok
def test_cli_start_stop(zenith_env_builder: ZenithEnvBuilder):
# Start with single sk
zenith_env_builder.num_safekeepers = 1
env = zenith_env_builder.init()
# Stop default ps/sk
env.zenith_cli.pageserver_stop()
env.zenith_cli.safekeeper_stop()
# Default start
res = env.zenith_cli.raw_cli(["start"])
res.check_returncode()
# Default stop
res = env.zenith_cli.raw_cli(["stop"])
res.check_returncode()

View File

@@ -3,11 +3,15 @@ import os
from fixtures.utils import mkdir_if_needed
from fixtures.zenith_fixtures import ZenithEnv, base_dir, pg_distrib_dir
pytest_plugins = ("fixtures.zenith_fixtures")
def test_isolation(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, capsys):
env = zenith_simple_env
env.zenith_cli.create_branch("test_isolation", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_isolation", "empty"])
# Connect to postgres and create a database called "regression".
# isolation tests use prepared transactions, so enable them
pg = env.postgres.create_start('test_isolation', config_lines=['max_prepared_transactions=100'])

View File

@@ -3,11 +3,15 @@ import os
from fixtures.utils import mkdir_if_needed
from fixtures.zenith_fixtures import ZenithEnv, check_restored_datadir_content, base_dir, pg_distrib_dir
pytest_plugins = ("fixtures.zenith_fixtures")
def test_pg_regress(zenith_simple_env: ZenithEnv, test_output_dir: str, pg_bin, capsys):
env = zenith_simple_env
env.zenith_cli.create_branch("test_pg_regress", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_pg_regress", "empty"])
# Connect to postgres and create a database called "regression".
pg = env.postgres.create_start('test_pg_regress')
pg.safe_psql('CREATE DATABASE regression')

View File

@@ -7,11 +7,15 @@ from fixtures.zenith_fixtures import (ZenithEnv,
pg_distrib_dir)
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
def test_zenith_regress(zenith_simple_env: ZenithEnv, test_output_dir, pg_bin, capsys):
env = zenith_simple_env
env.zenith_cli.create_branch("test_zenith_regress", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_zenith_regress", "empty"])
# Connect to postgres and create a database called "regression".
pg = env.postgres.create_start('test_zenith_regress')
pg.safe_psql('CREATE DATABASE regression')

View File

@@ -1,6 +1 @@
pytest_plugins = (
"fixtures.zenith_fixtures",
"fixtures.benchmark_fixture",
"fixtures.compare_fixtures",
"fixtures.slow",
)
pytest_plugins = ("fixtures.zenith_fixtures", "fixtures.benchmark_fixture")

View File

@@ -8,7 +8,6 @@ import timeit
import calendar
import enum
from datetime import datetime
import uuid
import pytest
from _pytest.config import Config
from _pytest.terminal import TerminalReporter
@@ -27,6 +26,8 @@ benchmark, and then record the result by calling zenbenchmark.record. For example:
import timeit
from fixtures.zenith_fixtures import ZenithEnv
pytest_plugins = ("fixtures.zenith_fixtures", "fixtures.benchmark_fixture")
def test_mybench(zenith_simple_env: env, zenbenchmark):
# Initialize the test
@@ -39,8 +40,6 @@ def test_mybench(zenith_simple_env: env, zenbenchmark):
# Record another measurement
zenbenchmark.record('speed_of_light', 300000, 'km/s')
There's no need to import this file to use it. It should be declared as a plugin
inside conftest.py, and that makes it available to all tests.
You can measure multiple things in one test, and record each one with a separate
call to zenbenchmark. For example, you could time the bulk loading that happens
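As the note above says, the fixtures are wired in through conftest.py rather than per-test imports; a minimal sketch of such a conftest.py (mirroring the plugin tuples shown elsewhere in this diff) would be:
# conftest.py -- illustrative sketch, not part of this change
pytest_plugins = (
    "fixtures.zenith_fixtures",
    "fixtures.benchmark_fixture",
)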
@@ -277,11 +276,11 @@ class ZenithBenchmarker:
assert matches
return int(round(float(matches.group(1))))
def get_timeline_size(self, repo_dir: Path, tenantid: uuid.UUID, timelineid: str):
def get_timeline_size(self, repo_dir: Path, tenantid: str, timelineid: str):
"""
Calculate the on-disk size of a timeline
"""
path = "{}/tenants/{}/timelines/{}".format(repo_dir, tenantid.hex, timelineid)
path = "{}/tenants/{}/timelines/{}".format(repo_dir, tenantid, timelineid)
totalbytes = 0
for root, dirs, files in os.walk(path):

View File

@@ -65,7 +65,7 @@ class ZenithCompare(PgCompare):
# We only use one branch and one timeline
self.branch = branch_name
self.env.zenith_cli.create_branch(self.branch, "empty")
self.env.zenith_cli(["branch", self.branch, "empty"])
self._pg = self.env.postgres.create_start(self.branch)
self.timeline = self.pg.safe_psql("SHOW zenith.zenith_timeline")[0][0]
@@ -86,7 +86,7 @@ class ZenithCompare(PgCompare):
return self._pg_bin
def flush(self):
self.pscur.execute(f"do_gc {self.env.initial_tenant.hex} {self.timeline} 0")
self.pscur.execute(f"do_gc {self.env.initial_tenant} {self.timeline} 0")
def report_peak_memory_use(self) -> None:
self.zenbenchmark.record("peak_mem",

View File

@@ -1,26 +0,0 @@
import pytest
"""
This plugin allows tests to be marked as slow using pytest.mark.slow. By default slow
tests are excluded. They need to be specifically requested with the --runslow flag in
order to run.
Copied from here: https://docs.pytest.org/en/latest/example/simple.html
"""
def pytest_addoption(parser):
parser.addoption("--runslow", action="store_true", default=False, help="run slow tests")
def pytest_configure(config):
config.addinivalue_line("markers", "slow: mark test as slow to run")
def pytest_collection_modifyitems(config, items):
if config.getoption("--runslow"):
# --runslow given in cli: do not skip slow tests
return
skip_slow = pytest.mark.skip(reason="need --runslow option to run")
for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)
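For reference, the marker this (now-removed) plugin defined was used roughly as follows; an illustrative sketch, assuming the plugin is loaded:
import pytest

@pytest.mark.slow
def test_something_expensive():
    ...

# slow tests are skipped by default; opt in with:  pytest --runslow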

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass, field
import textwrap
from cached_property import cached_property
import asyncpg
import os
@@ -27,7 +26,7 @@ from dataclasses import dataclass
# Type-related stuff
from psycopg2.extensions import connection as PgConnection
from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, cast, Union, Tuple
from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, cast, Union
from typing_extensions import Literal
import pytest
@@ -45,8 +44,9 @@ the standard pytest.fixture with some extra behavior.
There are several environment variables that can control the running of tests:
ZENITH_BIN, POSTGRES_DISTRIB_DIR, etc. See README.md for more information.
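For illustration only (both paths are placeholders; README.md is the authoritative reference):
ZENITH_BIN=/path/to/zenith/binaries POSTGRES_DISTRIB_DIR=/path/to/postgres/install pytest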
There's no need to import this file to use it. It should be declared as a plugin
inside conftest.py, and that makes it available to all tests.
To use fixtures in a test file, add this line of code:
>>> pytest_plugins = ("fixtures.zenith_fixtures")
Don't import functions from this file, or pytest will emit warnings. Instead
put directly-importable functions into utils.py or another separate file.
@@ -237,13 +237,10 @@ def port_distributor(worker_base_port):
class PgProtocol:
""" Reusable connection logic """
def __init__(self, host: str, port: int,
username: Optional[str] = None,
password: Optional[str] = None):
def __init__(self, host: str, port: int, username: Optional[str] = None):
self.host = host
self.port = port
self.username = username
self.password = password
def connstr(self,
*,
@@ -255,7 +252,6 @@ class PgProtocol:
"""
username = username or self.username
password = password or self.password
res = f'host={self.host} port={self.port} dbname={dbname}'
if username:
@@ -520,7 +516,6 @@ class ZenithEnv:
self.rust_log_override = config.rust_log_override
self.port_distributor = config.port_distributor
self.s3_mock_server = config.s3_mock_server
self.zenith_cli = ZenithCli(env=self)
self.postgres = PostgresFactory(self)
@@ -528,12 +523,12 @@ class ZenithEnv:
# generate initial tenant ID here instead of letting 'zenith init' generate it,
# so that we don't need to dig it out of the config file afterwards.
self.initial_tenant = uuid.uuid4()
self.initial_tenant = uuid.uuid4().hex
# Create a config file corresponding to the options
toml = textwrap.dedent(f"""
default_tenantid = '{self.initial_tenant.hex}'
""")
toml = f"""
default_tenantid = '{self.initial_tenant}'
"""
# Create config for pageserver
pageserver_port = PageserverPort(
@@ -542,12 +537,12 @@ class ZenithEnv:
)
pageserver_auth_type = "ZenithJWT" if config.pageserver_auth_enabled else "Trust"
toml += textwrap.dedent(f"""
[pageserver]
listen_pg_addr = 'localhost:{pageserver_port.pg}'
listen_http_addr = 'localhost:{pageserver_port.http}'
auth_type = '{pageserver_auth_type}'
""")
toml += f"""
[pageserver]
listen_pg_addr = 'localhost:{pageserver_port.pg}'
listen_http_addr = 'localhost:{pageserver_port.http}'
auth_type = '{pageserver_auth_type}'
"""
# Create a corresponding ZenithPageserver object
self.pageserver = ZenithPageserver(self,
@@ -577,7 +572,15 @@ sync = false # Disable fsyncs to make the tests go faster
log.info(f"Config: {toml}")
self.zenith_cli.init(toml)
# Run 'zenith init' using the config file we constructed
with tempfile.NamedTemporaryFile(mode='w+') as tmp:
tmp.write(toml)
tmp.flush()
cmd = ['init', f'--config={tmp.name}']
append_pageserver_param_overrides(cmd, config.pageserver_remote_storage)
self.zenith_cli(cmd)
# Start up the page server and all the safekeepers
self.pageserver.start()
@@ -589,12 +592,69 @@ sync = false # Disable fsyncs to make the tests go faster
""" Get list of safekeeper endpoints suitable for wal_acceptors GUC """
return ','.join([f'localhost:{wa.port.pg}' for wa in self.safekeepers])
def create_tenant(self, tenant_id: Optional[uuid.UUID] = None) -> uuid.UUID:
def create_tenant(self, tenant_id: Optional[str] = None):
if tenant_id is None:
tenant_id = uuid.uuid4()
self.zenith_cli.create_tenant(tenant_id)
tenant_id = uuid.uuid4().hex
res = self.zenith_cli(['tenant', 'create', tenant_id])
res.check_returncode()
return tenant_id
def zenith_cli(self, arguments: List[str]) -> 'subprocess.CompletedProcess[str]':
"""
Run "zenith" with the specified arguments.
Arguments must be in list form, e.g. ['pg', 'create']
Return both stdout and stderr, which can be accessed as
>>> result = env.zenith_cli(...)
>>> assert result.stderr == ""
>>> log.info(result.stdout)
"""
assert type(arguments) == list
bin_zenith = os.path.join(str(zenith_binpath), 'zenith')
args = [bin_zenith] + arguments
log.info('Running command "{}"'.format(' '.join(args)))
log.info(f'Running in "{self.repo_dir}"')
env_vars = os.environ.copy()
env_vars['ZENITH_REPO_DIR'] = str(self.repo_dir)
env_vars['POSTGRES_DISTRIB_DIR'] = str(pg_distrib_dir)
if self.rust_log_override is not None:
env_vars['RUST_LOG'] = self.rust_log_override
# Pass coverage settings
var = 'LLVM_PROFILE_FILE'
val = os.environ.get(var)
if val:
env_vars[var] = val
# Intercept CalledProcessError and print more info
try:
res = subprocess.run(args,
env=env_vars,
check=True,
universal_newlines=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
log.info(f"Run success: {res.stdout}")
except subprocess.CalledProcessError as exc:
# this way the command output will be recorded and shown in the CI failure message
msg = f"""\
Run failed: {exc}
stdout: {exc.stdout}
stderr: {exc.stderr}
"""
log.info(msg)
raise Exception(msg) from exc
return res
@cached_property
def auth_keys(self) -> AuthKeys:
pub = (Path(self.repo_dir) / 'auth_public_key.pem').read_bytes()
@@ -622,7 +682,7 @@ def _shared_simple_env(request: Any, port_distributor) -> Iterator[ZenithEnv]:
env = builder.init()
# For convenience in tests, create a branch from the freshly-initialized cluster.
env.zenith_cli.create_branch("empty", "main")
env.zenith_cli(["branch", "empty", "main"])
# Return the builder to the caller
yield env
@@ -668,10 +728,6 @@ def zenith_env_builder(test_output_dir, port_distributor) -> Iterator[ZenithEnvB
yield builder
class ZenithPageserverApiException(Exception):
pass
class ZenithPageserverHttpClient(requests.Session):
def __init__(self, port: int, auth_token: Optional[str] = None) -> None:
super().__init__()
@@ -681,32 +737,22 @@ class ZenithPageserverHttpClient(requests.Session):
if auth_token is not None:
self.headers['Authorization'] = f'Bearer {auth_token}'
def verbose_error(self, res: requests.Response):
try:
res.raise_for_status()
except requests.RequestException as e:
try:
msg = res.json()['msg']
except:
msg = ''
raise ZenithPageserverApiException(msg) from e
def check_status(self):
self.get(f"http://localhost:{self.port}/v1/status").raise_for_status()
def timeline_attach(self, tenant_id: uuid.UUID, timeline_id: uuid.UUID):
res = self.post(
f"http://localhost:{self.port}/v1/timeline/{tenant_id.hex}/{timeline_id.hex}/attach", )
self.verbose_error(res)
res.raise_for_status()
def timeline_detach(self, tenant_id: uuid.UUID, timeline_id: uuid.UUID):
res = self.post(
f"http://localhost:{self.port}/v1/timeline/{tenant_id.hex}/{timeline_id.hex}/detach", )
self.verbose_error(res)
res.raise_for_status()
def branch_list(self, tenant_id: uuid.UUID) -> List[Dict[Any, Any]]:
res = self.get(f"http://localhost:{self.port}/v1/branch/{tenant_id.hex}")
self.verbose_error(res)
res.raise_for_status()
res_json = res.json()
assert isinstance(res_json, list)
return res_json
@@ -718,7 +764,7 @@ class ZenithPageserverHttpClient(requests.Session):
'name': name,
'start_point': start_point,
})
self.verbose_error(res)
res.raise_for_status()
res_json = res.json()
assert isinstance(res_json, dict)
return res_json
@@ -727,14 +773,14 @@ class ZenithPageserverHttpClient(requests.Session):
res = self.get(
f"http://localhost:{self.port}/v1/branch/{tenant_id.hex}/{name}?include-non-incremental-logical-size=1",
)
self.verbose_error(res)
res.raise_for_status()
res_json = res.json()
assert isinstance(res_json, dict)
return res_json
def tenant_list(self) -> List[Dict[Any, Any]]:
res = self.get(f"http://localhost:{self.port}/v1/tenant")
self.verbose_error(res)
res.raise_for_status()
res_json = res.json()
assert isinstance(res_json, list)
return res_json
@@ -746,27 +792,27 @@ class ZenithPageserverHttpClient(requests.Session):
'tenant_id': tenant_id.hex,
},
)
self.verbose_error(res)
res.raise_for_status()
return res.json()
def timeline_list(self, tenant_id: uuid.UUID) -> List[str]:
res = self.get(f"http://localhost:{self.port}/v1/timeline/{tenant_id.hex}")
self.verbose_error(res)
res.raise_for_status()
res_json = res.json()
assert isinstance(res_json, list)
return res_json
def timeline_detail(self, tenant_id: uuid.UUID, timeline_id: uuid.UUID) -> Dict[Any, Any]:
def timeline_detail(self, tenant_id: uuid.UUID, timeline_id: uuid.UUID):
res = self.get(
f"http://localhost:{self.port}/v1/timeline/{tenant_id.hex}/{timeline_id.hex}")
self.verbose_error(res)
res.raise_for_status()
res_json = res.json()
assert isinstance(res_json, dict)
return res_json
def get_metrics(self) -> str:
res = self.get(f"http://localhost:{self.port}/metrics")
self.verbose_error(res)
res.raise_for_status()
return res.text
@@ -793,189 +839,6 @@ class S3Storage:
RemoteStorage = Union[LocalFsStorage, S3Storage]
class ZenithCli:
"""
A typed wrapper around the `zenith` CLI tool.
Supports main commands via typed methods and a way to run arbitrary command directly via CLI.
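A usage sketch (illustrative, using the methods defined below):
>>> env.zenith_cli.create_branch('my_branch', 'main')
>>> res = env.zenith_cli.list_branches()
>>> assert res.stderr == ''
>>> env.zenith_cli.raw_cli(['tenant', 'list'])  # escape hatch for arbitrary commands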
"""
def __init__(self, env: ZenithEnv) -> None:
self.env = env
pass
def create_tenant(self, tenant_id: Optional[uuid.UUID] = None) -> uuid.UUID:
if tenant_id is None:
tenant_id = uuid.uuid4()
self.raw_cli(['tenant', 'create', tenant_id.hex])
return tenant_id
def list_tenants(self) -> 'subprocess.CompletedProcess[str]':
return self.raw_cli(['tenant', 'list'])
def create_branch(self,
branch_name: str,
starting_point: str,
tenant_id: Optional[uuid.UUID] = None) -> 'subprocess.CompletedProcess[str]':
args = ['branch']
if tenant_id is not None:
args.extend(['--tenantid', tenant_id.hex])
args.extend([branch_name, starting_point])
return self.raw_cli(args)
def list_branches(self,
tenant_id: Optional[uuid.UUID] = None) -> 'subprocess.CompletedProcess[str]':
args = ['branch']
if tenant_id is not None:
args.extend(['--tenantid', tenant_id.hex])
return self.raw_cli(args)
def init(self, config_toml: str) -> 'subprocess.CompletedProcess[str]':
with tempfile.NamedTemporaryFile(mode='w+') as tmp:
tmp.write(config_toml)
tmp.flush()
cmd = ['init', f'--config={tmp.name}']
append_pageserver_param_overrides(cmd, self.env.pageserver.remote_storage)
return self.raw_cli(cmd)
def pageserver_start(self) -> 'subprocess.CompletedProcess[str]':
start_args = ['pageserver', 'start']
append_pageserver_param_overrides(start_args, self.env.pageserver.remote_storage)
return self.raw_cli(start_args)
def pageserver_stop(self, immediate=False) -> 'subprocess.CompletedProcess[str]':
cmd = ['pageserver', 'stop']
if immediate:
cmd.extend(['-m', 'immediate'])
log.info(f"Stopping pageserver with {cmd}")
return self.raw_cli(cmd)
def safekeeper_start(self, name: str) -> 'subprocess.CompletedProcess[str]':
return self.raw_cli(['safekeeper', 'start', name])
def safekeeper_stop(self,
name: Optional[str] = None,
immediate=False) -> 'subprocess.CompletedProcess[str]':
args = ['safekeeper', 'stop']
if immediate:
args.extend(['-m', 'immediate'])
if name is not None:
args.append(name)
return self.raw_cli(args)
def pg_create(
self,
node_name: str,
tenant_id: Optional[uuid.UUID] = None,
timeline_spec: Optional[str] = None,
port: Optional[int] = None,
) -> 'subprocess.CompletedProcess[str]':
args = ['pg', 'create']
if tenant_id is not None:
args.extend(['--tenantid', tenant_id.hex])
if port is not None:
args.append(f'--port={port}')
args.append(node_name)
if timeline_spec is not None:
args.append(timeline_spec)
return self.raw_cli(args)
def pg_start(
self,
node_name: str,
tenant_id: Optional[uuid.UUID] = None,
timeline_spec: Optional[str] = None,
port: Optional[int] = None,
) -> 'subprocess.CompletedProcess[str]':
args = ['pg', 'start']
if tenant_id is not None:
args.extend(['--tenantid', tenant_id.hex])
if port is not None:
args.append(f'--port={port}')
args.append(node_name)
if timeline_spec is not None:
args.append(timeline_spec)
return self.raw_cli(args)
def pg_stop(
self,
node_name: str,
tenant_id: Optional[uuid.UUID] = None,
destroy=False,
) -> 'subprocess.CompletedProcess[str]':
args = ['pg', 'stop']
if tenant_id is not None:
args.extend(['--tenantid', tenant_id.hex])
if destroy:
args.append('--destroy')
args.append(node_name)
return self.raw_cli(args)
def raw_cli(self,
arguments: List[str],
check_return_code=True) -> 'subprocess.CompletedProcess[str]':
"""
Run "zenith" with the specified arguments.
Arguments must be in list form, e.g. ['pg', 'create']
Return both stdout and stderr, which can be accessed as
>>> result = env.zenith_cli.raw_cli(...)
>>> assert result.stderr == ""
>>> log.info(result.stdout)
"""
assert type(arguments) == list
bin_zenith = os.path.join(str(zenith_binpath), 'zenith')
args = [bin_zenith] + arguments
log.info('Running command "{}"'.format(' '.join(args)))
log.info(f'Running in "{self.env.repo_dir}"')
env_vars = os.environ.copy()
env_vars['ZENITH_REPO_DIR'] = str(self.env.repo_dir)
env_vars['POSTGRES_DISTRIB_DIR'] = str(pg_distrib_dir)
if self.env.rust_log_override is not None:
env_vars['RUST_LOG'] = self.env.rust_log_override
# Pass coverage settings
var = 'LLVM_PROFILE_FILE'
val = os.environ.get(var)
if val:
env_vars[var] = val
# Intercept CalledProcessError and print more info
try:
res = subprocess.run(args,
env=env_vars,
check=True,
universal_newlines=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
log.info(f"Run success: {res.stdout}")
except subprocess.CalledProcessError as exc:
# this way the command output will be recorded and shown in the CI failure message
msg = f"""\
Run failed: {exc}
stdout: {exc.stdout}
stderr: {exc.stderr}
"""
log.info(msg)
raise Exception(msg) from exc
if check_return_code:
res.check_returncode()
return res
class ZenithPageserver(PgProtocol):
"""
An object representing a running pageserver.
@@ -1000,7 +863,10 @@ class ZenithPageserver(PgProtocol):
"""
assert self.running == False
self.env.zenith_cli.pageserver_start()
start_args = ['pageserver', 'start']
append_pageserver_param_overrides(start_args, self.remote_storage)
self.env.zenith_cli(start_args)
self.running = True
return self
@@ -1009,8 +875,13 @@ class ZenithPageserver(PgProtocol):
Stop the page server.
Returns self.
"""
cmd = ['pageserver', 'stop']
if immediate:
cmd.extend(['-m', 'immediate'])
log.info(f"Stopping pageserver with {cmd}")
if self.running:
self.env.zenith_cli.pageserver_stop(immediate)
self.env.zenith_cli(cmd)
self.running = False
return self
@@ -1161,62 +1032,9 @@ def vanilla_pg(test_output_dir: str) -> Iterator[VanillaPostgres]:
yield vanilla_pg
class ZenithProxy(PgProtocol):
def __init__(self, port: int):
super().__init__(host="127.0.0.1", username="pytest", password="pytest", port=port)
self.running = False
def start_static(self, addr="127.0.0.1:5432") -> None:
assert not self.running
self.running = True
http_port = "7001"
args = [
# TODO is cargo run the right thing to do?
"cargo",
"run",
"--bin", "proxy",
"--",
"--http", http_port,
"--proxy", f"{self.host}:{self.port}",
"--auth-method", "password",
"--static-router", addr,
]
self.popen = subprocess.Popen(args)
# Readiness probe
requests.get(f"http://{self.host}:{http_port}/v1/status")
def stop(self) -> None:
assert self.running
self.running = False
# NOTE the process will die when we're done with tests anyway, because
# it's a child process. This is mostly to clean up in between different tests.
self.popen.kill()
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
if self.running:
self.stop()
@pytest.fixture(scope='function')
def static_proxy(vanilla_pg) -> Iterator[ZenithProxy]:
vanilla_pg.start()
vanilla_pg.safe_psql("create user pytest with password 'pytest';")
with ZenithProxy(4432) as proxy:
proxy.start_static()
yield proxy
class Postgres(PgProtocol):
""" An object representing a running postgres daemon. """
def __init__(self, env: ZenithEnv, tenant_id: uuid.UUID, port: int):
def __init__(self, env: ZenithEnv, tenant_id: str, port: int):
super().__init__(host='localhost', port=port, username='zenith_admin')
self.env = env
@@ -1243,12 +1061,16 @@ class Postgres(PgProtocol):
if branch is None:
branch = node_name
self.env.zenith_cli.pg_create(node_name,
tenant_id=self.tenant_id,
port=self.port,
timeline_spec=branch)
self.env.zenith_cli([
'pg',
'create',
f'--tenantid={self.tenant_id}',
f'--port={self.port}',
node_name,
branch
])
self.node_name = node_name
path = pathlib.Path('pgdatadirs') / 'tenants' / self.tenant_id.hex / self.node_name
path = pathlib.Path('pgdatadirs') / 'tenants' / self.tenant_id / self.node_name
self.pgdata_dir = os.path.join(self.env.repo_dir, path)
if config_lines is None:
@@ -1267,9 +1089,8 @@ class Postgres(PgProtocol):
log.info(f"Starting postgres node {self.node_name}")
run_result = self.env.zenith_cli.pg_start(self.node_name,
tenant_id=self.tenant_id,
port=self.port)
run_result = self.env.zenith_cli(
['pg', 'start', f'--tenantid={self.tenant_id}', f'--port={self.port}', self.node_name])
self.running = True
log.info(f"stdout: {run_result.stdout}")
@@ -1279,7 +1100,7 @@ class Postgres(PgProtocol):
def pg_data_dir_path(self) -> str:
""" Path to data directory """
assert self.node_name
path = pathlib.Path('pgdatadirs') / 'tenants' / self.tenant_id.hex / self.node_name
path = pathlib.Path('pgdatadirs') / 'tenants' / self.tenant_id / self.node_name
return os.path.join(self.env.repo_dir, path)
def pg_xact_dir_path(self) -> str:
@@ -1339,7 +1160,7 @@ class Postgres(PgProtocol):
if self.running:
assert self.node_name is not None
self.env.zenith_cli.pg_stop(self.node_name, tenant_id=self.tenant_id)
self.env.zenith_cli(['pg', 'stop', self.node_name, f'--tenantid={self.tenant_id}'])
self.running = False
return self
@@ -1351,7 +1172,8 @@ class Postgres(PgProtocol):
"""
assert self.node_name is not None
self.env.zenith_cli.pg_stop(self.node_name, self.tenant_id, destroy=True)
self.env.zenith_cli(
['pg', 'stop', '--destroy', self.node_name, f'--tenantid={self.tenant_id}'])
self.node_name = None
return self
@@ -1393,7 +1215,7 @@ class PostgresFactory:
def create_start(self,
node_name: str = "main",
branch: Optional[str] = None,
tenant_id: Optional[uuid.UUID] = None,
tenant_id: Optional[str] = None,
config_lines: Optional[List[str]] = None) -> Postgres:
pg = Postgres(
@@ -1413,7 +1235,7 @@ class PostgresFactory:
def create(self,
node_name: str = "main",
branch: Optional[str] = None,
tenant_id: Optional[uuid.UUID] = None,
tenant_id: Optional[str] = None,
config_lines: Optional[List[str]] = None) -> Postgres:
pg = Postgres(
@@ -1458,7 +1280,7 @@ class Safekeeper:
auth_token: Optional[str] = None
def start(self) -> 'Safekeeper':
self.env.zenith_cli.safekeeper_start(self.name)
self.env.zenith_cli(['safekeeper', 'start', self.name])
# wait for wal acceptor start by checking its status
started_at = time.time()
@@ -1477,13 +1299,16 @@ class Safekeeper:
return self
def stop(self, immediate=False) -> 'Safekeeper':
cmd = ['safekeeper', 'stop']
if immediate:
cmd.extend(['-m', 'immediate'])
cmd.append(self.name)
log.info('Stopping safekeeper {}'.format(self.name))
self.env.zenith_cli.safekeeper_stop(self.name, immediate)
self.env.zenith_cli(cmd)
return self
def append_logical_message(self,
tenant_id: uuid.UUID,
timeline_id: uuid.UUID,
def append_logical_message(self, tenant_id: str, timeline_id: str,
request: Dict[str, Any]) -> Dict[str, Any]:
"""
Send JSON_CTRL query to append LogicalMessage to WAL and modify
@@ -1493,7 +1318,7 @@ class Safekeeper:
# "replication=0" hacks psycopg not to send additional queries
# on startup, see https://github.com/psycopg/psycopg2/pull/482
connstr = f"host=localhost port={self.port.pg} replication=0 options='-c ztimelineid={timeline_id.hex} ztenantid={tenant_id.hex}'"
connstr = f"host=localhost port={self.port.pg} replication=0 options='-c ztimelineid={timeline_id} ztenantid={tenant_id}'"
with closing(psycopg2.connect(connstr)) as conn:
# server doesn't support transactions
@@ -1522,8 +1347,8 @@ class SafekeeperTimelineStatus:
class SafekeeperMetrics:
# These are metrics from Prometheus which uses float64 internally.
# As a consequence, values may differ from real original int64s.
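# (Illustrative detail: float64 has a 53-bit mantissa, so integer values above
# 2**53 (~9e15) are not exactly representable after a round-trip through Prometheus.)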
flush_lsn_inexact: Dict[Tuple[str, str], int] = field(default_factory=dict)
commit_lsn_inexact: Dict[Tuple[str, str], int] = field(default_factory=dict)
flush_lsn_inexact: Dict[str, int] = field(default_factory=dict)
commit_lsn_inexact: Dict[str, int] = field(default_factory=dict)
class SafekeeperHttpClient(requests.Session):
@@ -1547,16 +1372,14 @@ class SafekeeperHttpClient(requests.Session):
all_metrics_text = request_result.text
metrics = SafekeeperMetrics()
for match in re.finditer(
r'^safekeeper_flush_lsn{tenant_id="([0-9a-f]+)",timeline_id="([0-9a-f]+)"} (\S+)$',
all_metrics_text,
re.MULTILINE):
metrics.flush_lsn_inexact[(match.group(1), match.group(2))] = int(match.group(3))
for match in re.finditer(
r'^safekeeper_commit_lsn{tenant_id="([0-9a-f]+)",timeline_id="([0-9a-f]+)"} (\S+)$',
all_metrics_text,
re.MULTILINE):
metrics.commit_lsn_inexact[(match.group(1), match.group(2))] = int(match.group(3))
for match in re.finditer(r'^safekeeper_flush_lsn{ztli="([0-9a-f]+)"} (\S+)$',
all_metrics_text,
re.MULTILINE):
metrics.flush_lsn_inexact[match.group(1)] = int(match.group(2))
for match in re.finditer(r'^safekeeper_commit_lsn{ztli="([0-9a-f]+)"} (\S+)$',
all_metrics_text,
re.MULTILINE):
metrics.commit_lsn_inexact[match.group(1)] = int(match.group(2))
return metrics
@@ -1665,7 +1488,7 @@ def check_restored_datadir_content(test_output_dir: str, env: ZenithEnv, pg: Pos
{psql_path} \
--no-psqlrc \
postgres://localhost:{env.pageserver.service_port.pg} \
-c 'basebackup {pg.tenant_id.hex} {timeline}' \
-c 'basebackup {pg.tenant_id} {timeline}' \
| tar -x -C {restored_dir_path}
"""

View File

@@ -4,6 +4,12 @@ from fixtures.log_helper import log
from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker
from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
pytest_plugins = (
"fixtures.zenith_fixtures",
"fixtures.benchmark_fixture",
"fixtures.compare_fixtures",
)
#
# Run bulk INSERT test.

View File

@@ -4,6 +4,8 @@ import pytest
from fixtures.zenith_fixtures import ZenithEnvBuilder
pytest_plugins = ("fixtures.benchmark_fixture")
# Run bulk tenant creation test.
#
# Collects metrics:
@@ -31,10 +33,12 @@ def test_bulk_tenant_create(
start = timeit.default_timer()
tenant = env.create_tenant()
env.zenith_cli.create_branch(
env.zenith_cli([
"branch",
f"test_bulk_tenant_create_{tenants_count}_{i}_{use_wal_acceptors}",
"main",
tenant_id=tenant)
f"--tenantid={tenant}"
])
# FIXME: We used to start new safekeepers here. Did that make sense? Should we do it now?
#if use_wal_acceptors == 'with_wa':

View File

@@ -6,6 +6,12 @@ from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
from io import BufferedReader, RawIOBase
from itertools import repeat
pytest_plugins = (
"fixtures.zenith_fixtures",
"fixtures.benchmark_fixture",
"fixtures.compare_fixtures",
)
class CopyTestData(RawIOBase):
def __init__(self, rows: int):

View File

@@ -5,6 +5,12 @@ from fixtures.zenith_fixtures import ZenithEnv
from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
from fixtures.log_helper import log
pytest_plugins = (
"fixtures.zenith_fixtures",
"fixtures.benchmark_fixture",
"fixtures.compare_fixtures",
)
#
# Test buffering GisT build. It WAL-logs the whole relation, in 32-page chunks.

View File

@@ -6,6 +6,12 @@ from fixtures.log_helper import log
from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker
from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
pytest_plugins = (
"fixtures.zenith_fixtures",
"fixtures.benchmark_fixture",
"fixtures.compare_fixtures",
)
async def repeat_bytes(buf, repetitions: int):
for i in range(repetitions):

View File

@@ -5,6 +5,12 @@ from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker
from fixtures.log_helper import log
pytest_plugins = (
"fixtures.zenith_fixtures",
"fixtures.benchmark_fixture",
"fixtures.compare_fixtures",
)
#
# Run a very short pgbench test.

View File

@@ -9,6 +9,8 @@ import calendar
import timeit
import os
pytest_plugins = ("fixtures.benchmark_fixture", )
def utc_now_timestamp() -> int:
return calendar.timegm(datetime.utcnow().utctimetuple())

View File

@@ -1,79 +0,0 @@
import os
from contextlib import closing
from fixtures.benchmark_fixture import MetricReport
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
from fixtures.log_helper import log
import psycopg2.extras
import random
import time
from fixtures.utils import print_gc_result
# This is a clear-box test that demonstrates the worst case scenario for the
# "1 segment per layer" implementation of the pageserver. It writes to random
# rows, while almost never writing to the same segment twice before flushing.
# A naive pageserver implementation would create a full image layer for each
# dirty segment, leading to write_amplification = segment_size / page_size,
# when compared to vanilla postgres. With segment_size = 10MB, that's 1250.
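# (Illustrative arithmetic: with Postgres's default 8 kB page size, a 10 MB segment
# holds 10*1024*1024 / 8192 = 1280 pages, roughly the 1250 quoted above, so a single
# dirtied page can force a full segment image to be written.)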
def test_random_writes(zenith_with_baseline: PgCompare):
env = zenith_with_baseline
# Number of rows in the test database. 1M rows runs quickly, but implies
# a small effective_checkpoint_distance, which makes the test less realistic.
# Using a 300 TB database would imply a 250 MB effective_checkpoint_distance,
# but it will take a very long time to run. From what I've seen so far,
# increasing n_rows doesn't have an impact on the (zenith_runtime / vanilla_runtime)
# performance ratio.
n_rows = 1 * 1000 * 1000 # around 36 MB table
# Number of writes per 3 segments. A value of 1 should produce a random
# workload where we almost never write to the same segment twice. Larger
# values of load_factor produce a larger effective_checkpoint_distance,
# making the test more realistic, but less effective. If you want a realistic
# worst case scenario and you have time to wait you should increase n_rows instead.
load_factor = 1
# Not sure why but this matters in a weird way (up to 2x difference in perf).
# TODO look into it
n_iterations = 1
with closing(env.pg.connect()) as conn:
with conn.cursor() as cur:
# Create the test table
with env.record_duration('init'):
cur.execute("""
CREATE TABLE Big(
pk integer primary key,
count integer default 0
);
""")
cur.execute(f"INSERT INTO Big (pk) values (generate_series(1,{n_rows}))")
# Get table size (can't be predicted because padding and alignment)
cur.execute("SELECT pg_relation_size('Big');")
row = cur.fetchone()
table_size = row[0]
env.zenbenchmark.record("table_size", table_size, 'bytes', MetricReport.TEST_PARAM)
# Decide how much to write, based on knowledge of pageserver implementation.
# Avoiding segment collisions maximizes (zenith_runtime / vanilla_runtime).
segment_size = 10 * 1024 * 1024
n_segments = table_size // segment_size
n_writes = load_factor * n_segments // 3
# The closer this is to 250 MB, the more realistic the test is.
effective_checkpoint_distance = table_size * n_writes // n_rows
env.zenbenchmark.record("effective_checkpoint_distance",
effective_checkpoint_distance,
'bytes',
MetricReport.TEST_PARAM)
# Update random keys
with env.record_duration('run'):
for it in range(n_iterations):
for i in range(n_writes):
key = random.randint(1, n_rows)
cur.execute(f"update Big set count=count+1 where pk={key}")
env.flush()

View File

@@ -11,6 +11,12 @@ from fixtures.benchmark_fixture import MetricReport, ZenithBenchmarker
from fixtures.compare_fixtures import PgCompare
import pytest
pytest_plugins = (
"fixtures.zenith_fixtures",
"fixtures.benchmark_fixture",
"fixtures.compare_fixtures",
)
@pytest.mark.parametrize('rows', [
pytest.param(100000),

View File

@@ -17,6 +17,12 @@ from fixtures.zenith_fixtures import ZenithEnv
from fixtures.compare_fixtures import PgCompare, VanillaCompare, ZenithCompare
from fixtures.log_helper import log
pytest_plugins = (
"fixtures.zenith_fixtures",
"fixtures.benchmark_fixture",
"fixtures.compare_fixtures",
)
def test_write_amplification(zenith_with_baseline: PgCompare):
env = zenith_with_baseline

View File

@@ -3,6 +3,8 @@ import os
from fixtures.zenith_fixtures import ZenithEnv
from fixtures.log_helper import log
pytest_plugins = ("fixtures.zenith_fixtures")
"""
Use this test to see what happens when tests fail.
@@ -21,7 +23,9 @@ run_broken = pytest.mark.skipif(os.environ.get('RUN_BROKEN') is None,
def test_broken(zenith_simple_env: ZenithEnv, pg_bin):
env = zenith_simple_env
env.zenith_cli.create_branch("test_broken", "empty")
# Create a branch for us
env.zenith_cli(["branch", "test_broken", "empty"])
env.postgres.create_start("test_broken")
log.info('postgres is running')

View File

@@ -10,7 +10,7 @@ use std::fs::File;
use std::path::{Path, PathBuf};
use std::thread;
use tracing::*;
use walkeeper::control_file::{self, CreateControlFile};
use walkeeper::timeline::{CreateControlFile, FileStorage};
use zenith_utils::http::endpoint;
use zenith_utils::{logging, tcp_listener, GIT_VERSION};
@@ -96,10 +96,7 @@ fn main() -> Result<()> {
.get_matches();
if let Some(addr) = arg_matches.value_of("dump-control-file") {
let state = control_file::FileStorage::load_control_file(
Path::new(addr),
CreateControlFile::False,
)?;
let state = FileStorage::load_control_file(Path::new(addr), CreateControlFile::False)?;
let json = serde_json::to_string(&state)?;
print!("{}", json);
return Ok(());

View File

@@ -1,297 +0,0 @@
//! Control file serialization, deserialization and persistence.
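// On-disk layout, as written by `persist` and checked by `load_control_file` below
// (all integers little-endian):
//   [ magic: u32 ][ format version: u32 ][ SafeKeeperState, LeSer-encoded ][ crc32c of all preceding bytes: u32 ]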
use anyhow::{bail, ensure, Context, Result};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use lazy_static::lazy_static;
use std::fs::{self, File, OpenOptions};
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use tracing::*;
use zenith_metrics::{register_histogram_vec, Histogram, HistogramVec, DISK_WRITE_SECONDS_BUCKETS};
use zenith_utils::bin_ser::LeSer;
use zenith_utils::zid::ZTenantTimelineId;
use crate::control_file_upgrade::upgrade_control_file;
use crate::safekeeper::{SafeKeeperState, SK_FORMAT_VERSION, SK_MAGIC};
use crate::SafeKeeperConf;
use std::convert::TryInto;
// contains persistent metadata for safekeeper
const CONTROL_FILE_NAME: &str = "safekeeper.control";
// needed to atomically update the state using `rename`
const CONTROL_FILE_NAME_PARTIAL: &str = "safekeeper.control.partial";
pub const CHECKSUM_SIZE: usize = std::mem::size_of::<u32>();
// A named boolean.
#[derive(Debug)]
pub enum CreateControlFile {
True,
False,
}
lazy_static! {
static ref PERSIST_CONTROL_FILE_SECONDS: HistogramVec = register_histogram_vec!(
"safekeeper_persist_control_file_seconds",
"Seconds to persist and sync control file, grouped by timeline",
&["tenant_id", "timeline_id"],
DISK_WRITE_SECONDS_BUCKETS.to_vec()
)
.expect("Failed to register safekeeper_persist_control_file_seconds histogram vec");
}
pub trait Storage {
/// Persist safekeeper state on disk.
fn persist(&mut self, s: &SafeKeeperState) -> Result<()>;
}
#[derive(Debug)]
pub struct FileStorage {
// save timeline dir to avoid reconstructing it every time
timeline_dir: PathBuf,
conf: SafeKeeperConf,
persist_control_file_seconds: Histogram,
}
impl FileStorage {
pub fn new(zttid: &ZTenantTimelineId, conf: &SafeKeeperConf) -> FileStorage {
let timeline_dir = conf.timeline_dir(zttid);
let tenant_id = zttid.tenant_id.to_string();
let timeline_id = zttid.timeline_id.to_string();
FileStorage {
timeline_dir,
conf: conf.clone(),
persist_control_file_seconds: PERSIST_CONTROL_FILE_SECONDS
.with_label_values(&[&tenant_id, &timeline_id]),
}
}
// Check the magic/version in the on-disk data and deserialize it, if possible.
fn deser_sk_state(buf: &mut &[u8]) -> Result<SafeKeeperState> {
// Read the version independent part
let magic = buf.read_u32::<LittleEndian>()?;
if magic != SK_MAGIC {
bail!(
"bad control file magic: {:X}, expected {:X}",
magic,
SK_MAGIC
);
}
let version = buf.read_u32::<LittleEndian>()?;
if version == SK_FORMAT_VERSION {
let res = SafeKeeperState::des(buf)?;
return Ok(res);
}
// try to upgrade
upgrade_control_file(buf, version)
}
// Load control file for given zttid at path specified by conf.
pub fn load_control_file_conf(
conf: &SafeKeeperConf,
zttid: &ZTenantTimelineId,
create: CreateControlFile,
) -> Result<SafeKeeperState> {
let path = conf.timeline_dir(zttid).join(CONTROL_FILE_NAME);
Self::load_control_file(path, create)
}
/// Read in the control file.
/// If create=false and file doesn't exist, bails out.
pub fn load_control_file<P: AsRef<Path>>(
control_file_path: P,
create: CreateControlFile,
) -> Result<SafeKeeperState> {
info!(
"loading control file {}, create={:?}",
control_file_path.as_ref().display(),
create,
);
let mut control_file = OpenOptions::new()
.read(true)
.write(true)
.create(matches!(create, CreateControlFile::True))
.open(&control_file_path)
.with_context(|| {
format!(
"failed to open control file at {}",
control_file_path.as_ref().display(),
)
})?;
// Empty file is legit on 'create', don't try to deser from it.
let state = if control_file.metadata().unwrap().len() == 0 {
if let CreateControlFile::False = create {
bail!("control file is empty");
}
SafeKeeperState::new()
} else {
let mut buf = Vec::new();
control_file
.read_to_end(&mut buf)
.context("failed to read control file")?;
let calculated_checksum = crc32c::crc32c(&buf[..buf.len() - CHECKSUM_SIZE]);
let expected_checksum_bytes: &[u8; CHECKSUM_SIZE] =
buf[buf.len() - CHECKSUM_SIZE..].try_into()?;
let expected_checksum = u32::from_le_bytes(*expected_checksum_bytes);
ensure!(
calculated_checksum == expected_checksum,
format!(
"safekeeper control file checksum mismatch: expected {} got {}",
expected_checksum, calculated_checksum
)
);
FileStorage::deser_sk_state(&mut &buf[..buf.len() - CHECKSUM_SIZE]).with_context(
|| {
format!(
"while reading control file {}",
control_file_path.as_ref().display(),
)
},
)?
};
Ok(state)
}
}
impl Storage for FileStorage {
// persists state durably to underlying storage
// for description see https://lwn.net/Articles/457667/
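// The steps below mirror postgres' durable_rename(): write and fsync the .partial
// file, rename() it over the control file, fsync the renamed file, then fsync the
// containing directory so the rename itself is durable.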
fn persist(&mut self, s: &SafeKeeperState) -> Result<()> {
let _timer = &self.persist_control_file_seconds.start_timer();
// write data to safekeeper.control.partial
let control_partial_path = self.timeline_dir.join(CONTROL_FILE_NAME_PARTIAL);
let mut control_partial = File::create(&control_partial_path).with_context(|| {
format!(
"failed to create partial control file at: {}",
&control_partial_path.display()
)
})?;
let mut buf: Vec<u8> = Vec::new();
buf.write_u32::<LittleEndian>(SK_MAGIC)?;
buf.write_u32::<LittleEndian>(SK_FORMAT_VERSION)?;
s.ser_into(&mut buf)?;
// calculate checksum before resize
let checksum = crc32c::crc32c(&buf);
buf.extend_from_slice(&checksum.to_le_bytes());
control_partial.write_all(&buf).with_context(|| {
format!(
"failed to write safekeeper state into control file at: {}",
control_partial_path.display()
)
})?;
// fsync the file
control_partial.sync_all().with_context(|| {
format!(
"failed to sync partial control file at {}",
control_partial_path.display()
)
})?;
let control_path = self.timeline_dir.join(CONTROL_FILE_NAME);
// rename should be atomic
fs::rename(&control_partial_path, &control_path)?;
// this sync is not required by any standard but postgres does this (see durable_rename)
File::open(&control_path)
.and_then(|f| f.sync_all())
.with_context(|| {
format!(
"failed to sync control file at: {}",
&control_path.display()
)
})?;
// fsync the directory (linux specific)
File::open(&self.timeline_dir)
.and_then(|f| f.sync_all())
.context("failed to sync control file directory")?;
Ok(())
}
}
#[cfg(test)]
mod test {
use super::FileStorage;
use super::*;
use crate::{safekeeper::SafeKeeperState, SafeKeeperConf, ZTenantTimelineId};
use anyhow::Result;
use std::fs;
use zenith_utils::lsn::Lsn;
fn stub_conf() -> SafeKeeperConf {
let workdir = tempfile::tempdir().unwrap().into_path();
SafeKeeperConf {
workdir,
..Default::default()
}
}
fn load_from_control_file(
conf: &SafeKeeperConf,
zttid: &ZTenantTimelineId,
create: CreateControlFile,
) -> Result<(FileStorage, SafeKeeperState)> {
fs::create_dir_all(&conf.timeline_dir(zttid)).expect("failed to create timeline dir");
Ok((
FileStorage::new(zttid, conf),
FileStorage::load_control_file_conf(conf, zttid, create)?,
))
}
#[test]
fn test_read_write_safekeeper_state() {
let conf = stub_conf();
let zttid = ZTenantTimelineId::generate();
{
let (mut storage, mut state) =
load_from_control_file(&conf, &zttid, CreateControlFile::True)
.expect("failed to read state");
// change something
state.wal_start_lsn = Lsn(42);
storage.persist(&state).expect("failed to persist state");
}
let (_, state) = load_from_control_file(&conf, &zttid, CreateControlFile::False)
.expect("failed to read state");
assert_eq!(state.wal_start_lsn, Lsn(42));
}
#[test]
fn test_safekeeper_state_checksum_mismatch() {
let conf = stub_conf();
let zttid = ZTenantTimelineId::generate();
{
let (mut storage, mut state) =
load_from_control_file(&conf, &zttid, CreateControlFile::True)
.expect("failed to read state");
// change something
state.wal_start_lsn = Lsn(42);
storage.persist(&state).expect("failed to persist state");
}
let control_path = conf.timeline_dir(&zttid).join(CONTROL_FILE_NAME);
let mut data = fs::read(&control_path).unwrap();
data[0] += 1; // change the first byte of the file to fail checksum validation
fs::write(&control_path, &data).expect("failed to write control file");
match load_from_control_file(&conf, &zttid, CreateControlFile::False) {
Err(err) => assert!(err
.to_string()
.contains("safekeeper control file checksum mismatch")),
Ok(_) => panic!("expected error"),
}
}
}

View File

@@ -19,7 +19,7 @@ use zenith_utils::pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID
use zenith_utils::zid::{ZTenantId, ZTenantTimelineId, ZTimelineId};
use crate::callmemaybe::CallmeEvent;
use crate::control_file::CreateControlFile;
use crate::timeline::CreateControlFile;
use tokio::sync::mpsc::UnboundedSender;
/// Safekeeper handler of postgres commands

View File

@@ -7,9 +7,9 @@ use zenith_utils::http::{RequestExt, RouterBuilder};
use zenith_utils::lsn::Lsn;
use zenith_utils::zid::ZTenantTimelineId;
use crate::control_file::CreateControlFile;
use crate::safekeeper::Term;
use crate::safekeeper::TermHistory;
use crate::timeline::CreateControlFile;
use crate::timeline::GlobalTimelines;
use crate::SafeKeeperConf;
use zenith_utils::http::endpoint;

View File

@@ -5,8 +5,6 @@ use std::time::Duration;
use zenith_utils::zid::ZTenantTimelineId;
pub mod callmemaybe;
pub mod control_file;
pub mod control_file_upgrade;
pub mod handler;
pub mod http;
pub mod json_ctrl;
@@ -15,8 +13,8 @@ pub mod s3_offload;
pub mod safekeeper;
pub mod send_wal;
pub mod timeline;
pub mod upgrade;
pub mod wal_service;
pub mod wal_storage;
pub mod defaults {
use const_format::formatcp;

View File

@@ -3,7 +3,7 @@
use anyhow::{bail, Context, Result};
use byteorder::{LittleEndian, ReadBytesExt};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use postgres_ffi::waldecoder::WalStreamDecoder;
use postgres_ffi::xlog_utils::TimeLineID;
use serde::{Deserialize, Serialize};
use std::cmp::min;
@@ -13,11 +13,12 @@ use tracing::*;
use lazy_static::lazy_static;
use crate::control_file;
use crate::send_wal::HotStandbyFeedback;
use crate::wal_storage;
use postgres_ffi::xlog_utils::MAX_SEND_SIZE;
use zenith_metrics::{register_gauge_vec, Gauge, GaugeVec};
use zenith_metrics::{
register_gauge_vec, register_histogram_vec, Gauge, GaugeVec, Histogram, HistogramVec,
DISK_WRITE_SECONDS_BUCKETS,
};
use zenith_utils::bin_ser::LeSer;
use zenith_utils::lsn::Lsn;
use zenith_utils::pq_proto::SystemId;
@@ -406,87 +407,130 @@ impl AcceptorProposerMessage {
}
}
pub trait Storage {
/// Persist safekeeper state on disk.
fn persist(&mut self, s: &SafeKeeperState) -> Result<()>;
/// Write piece of wal in buf to disk and sync it.
fn write_wal(&mut self, server: &ServerInfo, startpos: Lsn, buf: &[u8]) -> Result<()>;
// Truncate WAL at specified LSN
fn truncate_wal(&mut self, s: &ServerInfo, endpos: Lsn) -> Result<()>;
}
lazy_static! {
// The prometheus crate does not support u64 yet, i64 only (see `IntGauge`).
// i64 is faster than f64, so update to u64 when available.
static ref FLUSH_LSN_GAUGE: GaugeVec = register_gauge_vec!(
"safekeeper_flush_lsn",
"Current flush_lsn, grouped by timeline",
&["ztli"]
)
.expect("Failed to register safekeeper_flush_lsn gauge vec");
static ref COMMIT_LSN_GAUGE: GaugeVec = register_gauge_vec!(
"safekeeper_commit_lsn",
"Current commit_lsn (not necessarily persisted to disk), grouped by timeline",
&["tenant_id", "timeline_id"]
&["ztli"]
)
.expect("Failed to register safekeeper_commit_lsn gauge vec");
static ref WRITE_WAL_BYTES: HistogramVec = register_histogram_vec!(
"safekeeper_write_wal_bytes",
"Bytes written to WAL in a single request, grouped by timeline",
&["timeline_id"],
vec![1.0, 10.0, 100.0, 1024.0, 8192.0, 128.0 * 1024.0, 1024.0 * 1024.0, 10.0 * 1024.0 * 1024.0]
)
.expect("Failed to register safekeeper_write_wal_bytes histogram vec");
static ref WRITE_WAL_SECONDS: HistogramVec = register_histogram_vec!(
"safekeeper_write_wal_seconds",
"Seconds spent writing and syncing WAL to a disk in a single request, grouped by timeline",
&["timeline_id"],
DISK_WRITE_SECONDS_BUCKETS.to_vec()
)
.expect("Failed to register safekeeper_write_wal_seconds histogram vec");
}
struct SafeKeeperMetrics {
flush_lsn: Gauge,
commit_lsn: Gauge,
write_wal_bytes: Histogram,
write_wal_seconds: Histogram,
}
impl SafeKeeperMetrics {
fn new(tenant_id: ZTenantId, timeline_id: ZTimelineId, commit_lsn: Lsn) -> Self {
let tenant_id = tenant_id.to_string();
let timeline_id = timeline_id.to_string();
let m = Self {
commit_lsn: COMMIT_LSN_GAUGE.with_label_values(&[&tenant_id, &timeline_id]),
struct SafeKeeperMetricsBuilder {
ztli: ZTimelineId,
flush_lsn: Lsn,
commit_lsn: Lsn,
}
impl SafeKeeperMetricsBuilder {
fn build(self) -> SafeKeeperMetrics {
let ztli_str = format!("{}", self.ztli);
let m = SafeKeeperMetrics {
flush_lsn: FLUSH_LSN_GAUGE.with_label_values(&[&ztli_str]),
commit_lsn: COMMIT_LSN_GAUGE.with_label_values(&[&ztli_str]),
write_wal_bytes: WRITE_WAL_BYTES.with_label_values(&[&ztli_str]),
write_wal_seconds: WRITE_WAL_SECONDS.with_label_values(&[&ztli_str]),
};
m.commit_lsn.set(u64::from(commit_lsn) as f64);
m.flush_lsn.set(u64::from(self.flush_lsn) as f64);
m.commit_lsn.set(u64::from(self.commit_lsn) as f64);
m
}
}
/// SafeKeeper which consumes events (messages from compute) and provides
/// replies.
pub struct SafeKeeper<CTRL: control_file::Storage, WAL: wal_storage::Storage> {
pub struct SafeKeeper<ST: Storage> {
/// Locally flushed part of WAL with full records (end_lsn of last record).
/// Established by reading wal.
pub flush_lsn: Lsn,
// Cached metrics so we don't have to recompute labels on each update.
metrics: SafeKeeperMetrics,
/// not-yet-flushed pairs of same named fields in s.*
pub commit_lsn: Lsn,
pub truncate_lsn: Lsn,
pub storage: ST,
pub s: SafeKeeperState, // persistent part
pub control_store: CTRL,
pub wal_store: WAL,
decoder: WalStreamDecoder,
}
impl<CTRL, WAL> SafeKeeper<CTRL, WAL>
impl<ST> SafeKeeper<ST>
where
CTRL: control_file::Storage,
WAL: wal_storage::Storage,
ST: Storage,
{
// constructor
pub fn new(
ztli: ZTimelineId,
control_store: CTRL,
wal_store: WAL,
flush_lsn: Lsn,
storage: ST,
state: SafeKeeperState,
) -> SafeKeeper<CTRL, WAL> {
) -> SafeKeeper<ST> {
if state.server.timeline_id != ZTimelineId::from([0u8; 16])
&& ztli != state.server.timeline_id
{
panic!("Calling SafeKeeper::new with inconsistent ztli ({}) and SafeKeeperState.server.timeline_id ({})", ztli, state.server.timeline_id);
}
SafeKeeper {
metrics: SafeKeeperMetrics::new(state.server.tenant_id, ztli, state.commit_lsn),
flush_lsn,
metrics: SafeKeeperMetricsBuilder {
ztli,
flush_lsn,
commit_lsn: state.commit_lsn,
}
.build(),
commit_lsn: state.commit_lsn,
truncate_lsn: state.truncate_lsn,
storage,
s: state,
control_store,
wal_store,
decoder: WalStreamDecoder::new(Lsn(0)),
}
}
/// Get history of term switches for the available WAL
fn get_term_history(&self) -> TermHistory {
self.s
.acceptor_state
.term_history
.up_to(self.wal_store.flush_lsn())
self.s.acceptor_state.term_history.up_to(self.flush_lsn)
}
#[cfg(test)]
fn get_epoch(&self) -> Term {
self.s.acceptor_state.get_epoch(self.wal_store.flush_lsn())
self.s.acceptor_state.get_epoch(self.flush_lsn)
}
/// Process message from proposer and possibly form reply. Concurrent
@@ -528,20 +572,20 @@ where
}
// set basic info about server, if not yet
// TODO: verify that it doesn't change afterwards
self.s.server.system_id = msg.system_id;
self.s.server.tenant_id = msg.tenant_id;
self.s.server.timeline_id = msg.ztli;
self.s.server.wal_seg_size = msg.wal_seg_size;
self.control_store
self.storage
.persist(&self.s)
.context("failed to persist shared state")?;
// pass wal_seg_size to read WAL and find flush_lsn
self.wal_store.init_storage(&self.s)?;
// update tenant_id/timeline_id in metrics
self.metrics = SafeKeeperMetrics::new(msg.tenant_id, msg.ztli, self.commit_lsn);
self.metrics = SafeKeeperMetricsBuilder {
ztli: self.s.server.timeline_id,
flush_lsn: self.flush_lsn,
commit_lsn: self.commit_lsn,
}
.build();
info!(
"processed greeting from proposer {:?}, sending term {:?}",
@@ -561,14 +605,14 @@ where
let mut resp = VoteResponse {
term: self.s.acceptor_state.term,
vote_given: false as u64,
flush_lsn: self.wal_store.flush_lsn(),
flush_lsn: self.flush_lsn,
truncate_lsn: self.s.truncate_lsn,
term_history: self.get_term_history(),
};
if self.s.acceptor_state.term < msg.term {
self.s.acceptor_state.term = msg.term;
// persist vote before sending it out
self.control_store.persist(&self.s)?;
self.storage.persist(&self.s)?;
resp.term = self.s.acceptor_state.term;
resp.vote_given = true as u64;
}
@@ -580,7 +624,7 @@ where
fn bump_if_higher(&mut self, term: Term) -> Result<()> {
if self.s.acceptor_state.term < term {
self.s.acceptor_state.term = term;
self.control_store.persist(&self.s)?;
self.storage.persist(&self.s)?;
}
Ok(())
}
@@ -589,7 +633,7 @@ where
fn append_response(&self) -> AppendResponse {
AppendResponse {
term: self.s.acceptor_state.term,
flush_lsn: self.wal_store.flush_lsn(),
flush_lsn: self.flush_lsn,
commit_lsn: self.s.commit_lsn,
// will be filled by the upper code to avoid bothering safekeeper
hs_feedback: HotStandbyFeedback::empty(),
@@ -605,12 +649,22 @@ where
return Ok(None);
}
// truncate wal, update the lsns
self.wal_store.truncate_wal(msg.start_streaming_at)?;
// TODO: cross check divergence point
// streaming must not create a hole
assert!(self.flush_lsn == Lsn(0) || self.flush_lsn >= msg.start_streaming_at);
// truncate obsolete part of WAL
if self.flush_lsn != Lsn(0) {
self.storage
.truncate_wal(&self.s.server, msg.start_streaming_at)?;
}
// update our end of WAL pointer
self.flush_lsn = msg.start_streaming_at;
self.metrics.flush_lsn.set(u64::from(self.flush_lsn) as f64);
// and now adopt term history from proposer
self.s.acceptor_state.term_history = msg.term_history.clone();
self.control_store.persist(&self.s)?;
self.storage.persist(&self.s)?;
info!("start receiving WAL since {:?}", msg.start_streaming_at);
@@ -636,14 +690,42 @@ where
// After ProposerElected, which performs truncation, we should get only
// indeed append requests (but flush_lsn is advanced only on record
// boundary, so might be less).
assert!(self.wal_store.flush_lsn() <= msg.h.begin_lsn);
assert!(self.flush_lsn <= msg.h.begin_lsn);
self.s.proposer_uuid = msg.h.proposer_uuid;
let mut sync_control_file = false;
// do the job
let mut last_rec_lsn = Lsn(0);
if !msg.wal_data.is_empty() {
self.wal_store.write_wal(msg.h.begin_lsn, &msg.wal_data)?;
self.metrics
.write_wal_bytes
.observe(msg.wal_data.len() as f64);
{
let _timer = self.metrics.write_wal_seconds.start_timer();
self.storage
.write_wal(&self.s.server, msg.h.begin_lsn, &msg.wal_data)?;
}
// figure out last record's end lsn for reporting (if we got the
// whole record)
if self.decoder.available() != msg.h.begin_lsn {
info!(
"restart decoder from {} to {}",
self.decoder.available(),
msg.h.begin_lsn,
);
self.decoder = WalStreamDecoder::new(msg.h.begin_lsn);
}
self.decoder.feed_bytes(&msg.wal_data);
loop {
match self.decoder.poll_decode()? {
None => break, // no full record yet
Some((lsn, _rec)) => {
last_rec_lsn = lsn;
}
}
}
// If this was the first record we ever received, remember LSN to help
// find_end_of_wal skip the hole in the beginning.
@@ -653,11 +735,16 @@ where
}
}
if last_rec_lsn > self.flush_lsn {
self.flush_lsn = last_rec_lsn;
self.metrics.flush_lsn.set(u64::from(self.flush_lsn) as f64);
}
// Advance commit_lsn taking into account what we have locally.
// commit_lsn can be 0, being unknown to a new walproposer while it hasn't
// collected a majority of epoch acks yet; ignore it in this case.
if msg.h.commit_lsn != Lsn(0) {
let commit_lsn = min(msg.h.commit_lsn, self.wal_store.flush_lsn());
let commit_lsn = min(msg.h.commit_lsn, self.flush_lsn);
// If new commit_lsn reached epoch switch, force sync of control
// file: walproposer in sync mode is very interested when this
// happens. Note: this is for sync-safekeepers mode only, as
@@ -683,7 +770,7 @@ where
}
if sync_control_file {
self.control_store.persist(&self.s)?;
self.storage.persist(&self.s)?;
}
let resp = self.append_response();
@@ -702,52 +789,34 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::wal_storage::Storage;
// fake storage for tests
struct InMemoryState {
struct InMemoryStorage {
persisted_state: SafeKeeperState,
}
impl control_file::Storage for InMemoryState {
impl Storage for InMemoryStorage {
fn persist(&mut self, s: &SafeKeeperState) -> Result<()> {
self.persisted_state = s.clone();
Ok(())
}
}
struct DummyWalStore {
lsn: Lsn,
}
impl wal_storage::Storage for DummyWalStore {
fn flush_lsn(&self) -> Lsn {
self.lsn
}
fn init_storage(&mut self, _state: &SafeKeeperState) -> Result<()> {
fn write_wal(&mut self, _server: &ServerInfo, _startpos: Lsn, _buf: &[u8]) -> Result<()> {
Ok(())
}
fn write_wal(&mut self, startpos: Lsn, buf: &[u8]) -> Result<()> {
self.lsn = startpos + buf.len() as u64;
Ok(())
}
fn truncate_wal(&mut self, end_pos: Lsn) -> Result<()> {
self.lsn = end_pos;
fn truncate_wal(&mut self, _server: &ServerInfo, _end_pos: Lsn) -> Result<()> {
Ok(())
}
}
#[test]
fn test_voting() {
let storage = InMemoryState {
let storage = InMemoryStorage {
persisted_state: SafeKeeperState::new(),
};
let wal_store = DummyWalStore { lsn: Lsn(0) };
let ztli = ZTimelineId::from([0u8; 16]);
let mut sk = SafeKeeper::new(ztli, storage, wal_store, SafeKeeperState::new());
let mut sk = SafeKeeper::new(ztli, Lsn(0), storage, SafeKeeperState::new());
// check voting for 1 is ok
let vote_request = ProposerAcceptorMessage::VoteRequest(VoteRequest { term: 1 });
@@ -758,11 +827,11 @@ mod tests {
}
// reboot...
let state = sk.control_store.persisted_state.clone();
let storage = InMemoryState {
let state = sk.storage.persisted_state.clone();
let storage = InMemoryStorage {
persisted_state: state.clone(),
};
sk = SafeKeeper::new(ztli, storage, sk.wal_store, state);
sk = SafeKeeper::new(ztli, Lsn(0), storage, state);
// and ensure voting second time for 1 is not ok
vote_resp = sk.process_msg(&vote_request);
@@ -774,12 +843,11 @@ mod tests {
#[test]
fn test_epoch_switch() {
let storage = InMemoryState {
let storage = InMemoryStorage {
persisted_state: SafeKeeperState::new(),
};
let wal_store = DummyWalStore { lsn: Lsn(0) };
let ztli = ZTimelineId::from([0u8; 16]);
let mut sk = SafeKeeper::new(ztli, storage, wal_store, SafeKeeperState::new());
let mut sk = SafeKeeper::new(ztli, Lsn(0), storage, SafeKeeperState::new());
let mut ar_hdr = AppendRequestHeader {
term: 1,
@@ -820,7 +888,7 @@ mod tests {
};
let resp = sk.process_msg(&ProposerAcceptorMessage::AppendRequest(append_request));
assert!(resp.is_ok());
sk.wal_store.truncate_wal(Lsn(3)).unwrap(); // imitate the complete record at 3 %)
sk.flush_lsn = Lsn(3); // imitate the complete record at 3 %)
assert_eq!(sk.get_epoch(), 1);
}
}
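Aside: the append path above advances flush_lsn only on complete record boundaries, which it detects by feeding the incoming bytes through WalStreamDecoder. A minimal sketch of that loop in isolation, using the same feed_bytes/poll_decode API as the code above (illustrative only, not part of this change):
use anyhow::Result;
use postgres_ffi::waldecoder::WalStreamDecoder;
use zenith_utils::lsn::Lsn;

/// Return the end LSN of the last complete record in `wal`, if any.
fn last_record_end(start: Lsn, wal: &[u8]) -> Result<Option<Lsn>> {
    let mut decoder = WalStreamDecoder::new(start);
    decoder.feed_bytes(wal);
    let mut last = None;
    // poll_decode yields one (end_lsn, record) pair per fully decoded record.
    while let Some((lsn, _rec)) = decoder.poll_decode()? {
        last = Some(lsn);
    }
    Ok(last)
}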

View File

@@ -3,16 +3,20 @@
use crate::handler::SafekeeperPostgresHandler;
use crate::timeline::{ReplicaState, Timeline, TimelineTools};
use crate::wal_storage::WalReader;
use anyhow::{bail, Context, Result};
use postgres_ffi::xlog_utils::{get_current_timestamp, TimestampTz, MAX_SEND_SIZE};
use postgres_ffi::xlog_utils::{
get_current_timestamp, TimestampTz, XLogFileName, MAX_SEND_SIZE, PG_TLI,
};
use crate::callmemaybe::{CallmeEvent, SubscriptionStateKey};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::cmp::min;
use std::fs::File;
use std::io::{Read, Seek, SeekFrom};
use std::net::Shutdown;
use std::path::Path;
use std::sync::Arc;
use std::thread::sleep;
use std::time::Duration;
@@ -190,6 +194,24 @@ impl ReplicationConn {
Ok(())
}
/// Helper function for opening a wal file.
fn open_wal_file(wal_file_path: &Path) -> Result<File> {
// First try to open the .partial file.
let mut partial_path = wal_file_path.to_owned();
partial_path.set_extension("partial");
if let Ok(opened_file) = File::open(&partial_path) {
return Ok(opened_file);
}
// If that failed, try it without the .partial extension.
File::open(&wal_file_path)
.with_context(|| format!("Failed to open WAL file {:?}", wal_file_path))
.map_err(|e| {
error!("{}", e);
e
})
}
///
/// Handle START_REPLICATION replication command
///
@@ -289,15 +311,7 @@ impl ReplicationConn {
pgb.write_message(&BeMessage::CopyBothResponse)?;
let mut end_pos = Lsn(0);
let mut wal_reader = WalReader::new(
spg.conf.timeline_dir(&spg.timeline.get().zttid),
wal_seg_size,
start_pos,
);
// buffer for wal sending, limited by MAX_SEND_SIZE
let mut send_buf = vec![0u8; MAX_SEND_SIZE];
let mut wal_file: Option<File> = None;
loop {
if let Some(stop_pos) = stop_pos {
@@ -331,26 +345,53 @@ impl ReplicationConn {
}
}
// Take the `File` from `wal_file`, or open a new file.
let mut file = match wal_file.take() {
Some(file) => file,
None => {
// Open a new file.
let segno = start_pos.segment_number(wal_seg_size);
let wal_file_name = XLogFileName(PG_TLI, segno, wal_seg_size);
let wal_file_path = spg
.conf
.timeline_dir(&spg.timeline.get().zttid)
.join(wal_file_name);
Self::open_wal_file(&wal_file_path)?
}
};
let xlogoff = start_pos.segment_offset(wal_seg_size) as usize;
// How much to read and send in a message? We cannot cross the WAL file
// boundary, and we don't want to send more than MAX_SEND_SIZE.
let send_size = end_pos.checked_sub(start_pos).unwrap().0 as usize;
let send_size = min(send_size, send_buf.len());
let send_size = min(send_size, wal_seg_size - xlogoff);
let send_size = min(send_size, MAX_SEND_SIZE);
let send_buf = &mut send_buf[..send_size];
// read wal into buffer
let send_size = wal_reader.read(send_buf)?;
let send_buf = &send_buf[..send_size];
// Read some data from the file.
let mut file_buf = vec![0u8; send_size];
file.seek(SeekFrom::Start(xlogoff as u64))
.and_then(|_| file.read_exact(&mut file_buf))
.context("Failed to read data from WAL file")?;
// Write some data to the network socket.
pgb.write_message(&BeMessage::XLogData(XLogDataBody {
wal_start: start_pos.0,
wal_end: end_pos.0,
timestamp: get_current_timestamp(),
data: send_buf,
data: &file_buf,
}))
.context("Failed to send XLogData")?;
start_pos += send_size as u64;
trace!("sent WAL up to {}", start_pos);
// Decide whether to reuse this file. If we don't set wal_file here
// a new file will be opened next time.
if start_pos.segment_offset(wal_seg_size) != 0 {
wal_file = Some(file);
}
}
Ok(())
}
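The chunk size computed in the streaming loop above is simply the remaining WAL clamped to the current segment boundary and to MAX_SEND_SIZE. A small sketch of that arithmetic (function name hypothetical):
use std::cmp::min;

/// Bytes to put into one XLogData message: never past end_pos, never across
/// a WAL segment boundary, and never more than max_send_size.
fn chunk_size(start_pos: u64, end_pos: u64, wal_seg_size: u64, max_send_size: u64) -> u64 {
    let xlogoff = start_pos % wal_seg_size; // offset inside the current segment
    min(end_pos - start_pos, min(wal_seg_size - xlogoff, max_send_size))
}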

View File

@@ -1,35 +1,42 @@
//! This module contains timeline id -> safekeeper state map with file-backed
//! persistence and support for interaction between sending and receiving wal.
use anyhow::{Context, Result};
use anyhow::{bail, ensure, Context, Result};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use lazy_static::lazy_static;
use postgres_ffi::xlog_utils::{find_end_of_wal, XLogSegNo, PG_TLI};
use std::cmp::{max, min};
use std::collections::HashMap;
use std::fs::{self};
use std::fs::{self, File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Condvar, Mutex};
use std::time::Duration;
use tokio::sync::mpsc::UnboundedSender;
use tracing::*;
use zenith_metrics::{register_histogram_vec, Histogram, HistogramVec, DISK_WRITE_SECONDS_BUCKETS};
use zenith_utils::bin_ser::LeSer;
use zenith_utils::lsn::Lsn;
use zenith_utils::zid::ZTenantTimelineId;
use crate::callmemaybe::{CallmeEvent, SubscriptionStateKey};
use crate::control_file::{self, CreateControlFile};
use crate::safekeeper::{
AcceptorProposerMessage, ProposerAcceptorMessage, SafeKeeper, SafeKeeperState,
AcceptorProposerMessage, ProposerAcceptorMessage, SafeKeeper, SafeKeeperState, ServerInfo,
Storage, SK_FORMAT_VERSION, SK_MAGIC,
};
use crate::send_wal::HotStandbyFeedback;
use crate::wal_storage::{self, Storage};
use crate::upgrade::upgrade_control_file;
use crate::SafeKeeperConf;
use postgres_ffi::xlog_utils::{XLogFileName, XLOG_BLCKSZ};
use std::convert::TryInto;
use zenith_utils::pq_proto::ZenithFeedback;
// contains persistent metadata for safekeeper
const CONTROL_FILE_NAME: &str = "safekeeper.control";
// needed to atomically update the state using `rename`
const CONTROL_FILE_NAME_PARTIAL: &str = "safekeeper.control.partial";
const POLL_STATE_TIMEOUT: Duration = Duration::from_secs(1);
pub const CHECKSUM_SIZE: usize = std::mem::size_of::<u32>();
/// Replica status update + hot standby feedback
#[derive(Debug, Clone, Copy)]
@@ -68,7 +75,7 @@ impl ReplicaState {
/// Shared state associated with database instance
struct SharedState {
/// Safekeeper object
sk: SafeKeeper<control_file::FileStorage, wal_storage::PhysicalStorage>,
sk: SafeKeeper<FileStorage>,
/// For receiving-sending wal cooperation
/// quorum commit LSN we've notified walsenders about
notified_commit_lsn: Lsn,
@@ -86,6 +93,23 @@ struct SharedState {
pageserver_connstr: Option<String>,
}
// A named boolean.
#[derive(Debug)]
pub enum CreateControlFile {
True,
False,
}
lazy_static! {
static ref PERSIST_CONTROL_FILE_SECONDS: HistogramVec = register_histogram_vec!(
"safekeeper_persist_control_file_seconds",
"Seconds to persist and sync control file, grouped by timeline",
&["timeline_id"],
DISK_WRITE_SECONDS_BUCKETS.to_vec()
)
.expect("Failed to register safekeeper_persist_control_file_seconds histogram vec");
}
impl SharedState {
/// Restore SharedState from control file.
/// If create=false and file doesn't exist, bails out.
@@ -94,18 +118,32 @@ impl SharedState {
zttid: &ZTenantTimelineId,
create: CreateControlFile,
) -> Result<Self> {
let state = control_file::FileStorage::load_control_file_conf(conf, zttid, create)
let state = FileStorage::load_control_file_conf(conf, zttid, create)
.context("failed to load from control file")?;
let control_store = control_file::FileStorage::new(zttid, conf);
let wal_store = wal_storage::PhysicalStorage::new(zttid, conf);
info!("timeline {} created or restored", zttid.timeline_id);
let file_storage = FileStorage::new(zttid, conf);
let flush_lsn = if state.server.wal_seg_size != 0 {
let wal_dir = conf.timeline_dir(zttid);
Lsn(find_end_of_wal(
&wal_dir,
state.server.wal_seg_size as usize,
true,
state.wal_start_lsn,
)?
.0)
} else {
Lsn(0)
};
info!(
"timeline {} created or restored: flush_lsn={}, commit_lsn={}, truncate_lsn={}",
zttid.timeline_id, flush_lsn, state.commit_lsn, state.truncate_lsn,
);
if flush_lsn < state.commit_lsn || flush_lsn < state.truncate_lsn {
warn!("timeline {} potential data loss: flush_lsn by find_end_of_wal is less than either commit_lsn or truncate_lsn from control file", zttid.timeline_id);
}
Ok(Self {
notified_commit_lsn: Lsn(0),
sk: SafeKeeper::new(zttid.timeline_id, control_store, wal_store, state),
sk: SafeKeeper::new(zttid.timeline_id, flush_lsn, file_storage, state),
replicas: Vec::new(),
active: false,
num_computes: 0,
@@ -412,7 +450,7 @@ impl Timeline {
pub fn get_end_of_wal(&self) -> Lsn {
let shared_state = self.mutex.lock().unwrap();
shared_state.sk.wal_store.flush_lsn()
shared_state.sk.flush_lsn
}
}
@@ -488,3 +526,397 @@ impl GlobalTimelines {
}
}
}
#[derive(Debug)]
pub struct FileStorage {
// save timeline dir to avoid reconstructing it every time
timeline_dir: PathBuf,
conf: SafeKeeperConf,
persist_control_file_seconds: Histogram,
}
impl FileStorage {
fn new(zttid: &ZTenantTimelineId, conf: &SafeKeeperConf) -> FileStorage {
let timeline_dir = conf.timeline_dir(zttid);
let timelineid_str = format!("{}", zttid);
FileStorage {
timeline_dir,
conf: conf.clone(),
persist_control_file_seconds: PERSIST_CONTROL_FILE_SECONDS
.with_label_values(&[&timelineid_str]),
}
}
// Check the magic/version in the on-disk data and deserialize it, if possible.
fn deser_sk_state(buf: &mut &[u8]) -> Result<SafeKeeperState> {
// Read the version independent part
let magic = buf.read_u32::<LittleEndian>()?;
if magic != SK_MAGIC {
bail!(
"bad control file magic: {:X}, expected {:X}",
magic,
SK_MAGIC
);
}
let version = buf.read_u32::<LittleEndian>()?;
if version == SK_FORMAT_VERSION {
let res = SafeKeeperState::des(buf)?;
return Ok(res);
}
// try to upgrade
upgrade_control_file(buf, version)
}
// Load control file for given zttid at path specified by conf.
fn load_control_file_conf(
conf: &SafeKeeperConf,
zttid: &ZTenantTimelineId,
create: CreateControlFile,
) -> Result<SafeKeeperState> {
let path = conf.timeline_dir(zttid).join(CONTROL_FILE_NAME);
Self::load_control_file(path, create)
}
/// Read in the control file.
/// If create=false and file doesn't exist, bails out.
pub fn load_control_file<P: AsRef<Path>>(
control_file_path: P,
create: CreateControlFile,
) -> Result<SafeKeeperState> {
info!(
"loading control file {}, create={:?}",
control_file_path.as_ref().display(),
create,
);
let mut control_file = OpenOptions::new()
.read(true)
.write(true)
.create(matches!(create, CreateControlFile::True))
.open(&control_file_path)
.with_context(|| {
format!(
"failed to open control file at {}",
control_file_path.as_ref().display(),
)
})?;
// Empty file is legit on 'create', don't try to deser from it.
let state = if control_file.metadata().unwrap().len() == 0 {
if let CreateControlFile::False = create {
bail!("control file is empty");
}
SafeKeeperState::new()
} else {
let mut buf = Vec::new();
control_file
.read_to_end(&mut buf)
.context("failed to read control file")?;
let calculated_checksum = crc32c::crc32c(&buf[..buf.len() - CHECKSUM_SIZE]);
let expected_checksum_bytes: &[u8; CHECKSUM_SIZE] =
buf[buf.len() - CHECKSUM_SIZE..].try_into()?;
let expected_checksum = u32::from_le_bytes(*expected_checksum_bytes);
ensure!(
calculated_checksum == expected_checksum,
format!(
"safekeeper control file checksum mismatch: expected {} got {}",
expected_checksum, calculated_checksum
)
);
FileStorage::deser_sk_state(&mut &buf[..buf.len() - CHECKSUM_SIZE]).with_context(
|| {
format!(
"while reading control file {}",
control_file_path.as_ref().display(),
)
},
)?
};
Ok(state)
}
/// Helper returning full path to WAL segment file and its .partial brother.
fn wal_file_paths(&self, segno: XLogSegNo, wal_seg_size: usize) -> (PathBuf, PathBuf) {
let wal_file_name = XLogFileName(PG_TLI, segno, wal_seg_size);
let wal_file_path = self.timeline_dir.join(wal_file_name.clone());
let wal_file_partial_path = self.timeline_dir.join(wal_file_name + ".partial");
(wal_file_path, wal_file_partial_path)
}
}
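As a worked example of the segment arithmetic used by wal_file_paths and the write/truncate code below: an LSN maps to a segment number and an in-segment offset via wal_seg_size, and the segment number is rendered into the usual PostgreSQL file name (the PG timeline is always PG_TLI = 1 here; the segment still being written additionally carries a .partial suffix on disk). A sketch reusing the helpers already imported in this file (illustrative only):
use postgres_ffi::xlog_utils::{XLogFileName, PG_TLI};
use zenith_utils::lsn::Lsn;

// For the default 16 MiB segment size, Lsn(0x0100_0040) falls into segment 1 at
// offset 0x40, i.e. file "000000010000000000000001" (or "...0001.partial" while open).
fn locate(lsn: Lsn, wal_seg_size: usize) -> (String, usize) {
    let segno = lsn.segment_number(wal_seg_size);
    let xlogoff = lsn.segment_offset(wal_seg_size) as usize;
    (XLogFileName(PG_TLI, segno, wal_seg_size), xlogoff)
}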
impl Storage for FileStorage {
// persists state durably to underlying storage
// for description see https://lwn.net/Articles/457667/
fn persist(&mut self, s: &SafeKeeperState) -> Result<()> {
let _timer = &self.persist_control_file_seconds.start_timer();
// write data to safekeeper.control.partial
let control_partial_path = self.timeline_dir.join(CONTROL_FILE_NAME_PARTIAL);
let mut control_partial = File::create(&control_partial_path).with_context(|| {
format!(
"failed to create partial control file at: {}",
&control_partial_path.display()
)
})?;
let mut buf: Vec<u8> = Vec::new();
buf.write_u32::<LittleEndian>(SK_MAGIC)?;
buf.write_u32::<LittleEndian>(SK_FORMAT_VERSION)?;
s.ser_into(&mut buf)?;
// calculate checksum before resize
let checksum = crc32c::crc32c(&buf);
buf.extend_from_slice(&checksum.to_le_bytes());
control_partial.write_all(&buf).with_context(|| {
format!(
"failed to write safekeeper state into control file at: {}",
control_partial_path.display()
)
})?;
// fsync the file
control_partial.sync_all().with_context(|| {
format!(
"failed to sync partial control file at {}",
control_partial_path.display()
)
})?;
let control_path = self.timeline_dir.join(CONTROL_FILE_NAME);
// rename should be atomic
fs::rename(&control_partial_path, &control_path)?;
// this sync is not required by any standard but postgres does this (see durable_rename)
File::open(&control_path)
.and_then(|f| f.sync_all())
.with_context(|| {
format!(
"failed to sync control file at: {}",
&control_path.display()
)
})?;
// fsync the directory (linux specific)
File::open(&self.timeline_dir)
.and_then(|f| f.sync_all())
.context("failed to sync control file directory")?;
Ok(())
}
fn write_wal(&mut self, server: &ServerInfo, startpos: Lsn, buf: &[u8]) -> Result<()> {
let mut bytes_left: usize = buf.len();
let mut bytes_written: usize = 0;
let mut partial;
let mut start_pos = startpos;
const ZERO_BLOCK: &[u8] = &[0u8; XLOG_BLCKSZ];
let wal_seg_size = server.wal_seg_size as usize;
/* Extract WAL location for this block */
let mut xlogoff = start_pos.segment_offset(wal_seg_size) as usize;
while bytes_left != 0 {
let bytes_to_write;
/*
* If crossing a WAL boundary, only write up until we reach wal
* segment size.
*/
if xlogoff + bytes_left > wal_seg_size {
bytes_to_write = wal_seg_size - xlogoff;
} else {
bytes_to_write = bytes_left;
}
/* Open file */
let segno = start_pos.segment_number(wal_seg_size);
let (wal_file_path, wal_file_partial_path) = self.wal_file_paths(segno, wal_seg_size);
{
let mut wal_file: File;
/* Try to open already completed segment */
if let Ok(file) = OpenOptions::new().write(true).open(&wal_file_path) {
wal_file = file;
partial = false;
} else if let Ok(file) = OpenOptions::new().write(true).open(&wal_file_partial_path)
{
/* Try to open an existing partial file */
wal_file = file;
partial = true;
} else {
/* Create and fill new partial file */
partial = true;
match OpenOptions::new()
.create(true)
.write(true)
.open(&wal_file_partial_path)
{
Ok(mut file) => {
for _ in 0..(wal_seg_size / XLOG_BLCKSZ) {
file.write_all(ZERO_BLOCK)?;
}
wal_file = file;
}
Err(e) => {
error!("Failed to open log file {:?}: {}", &wal_file_path, e);
return Err(e.into());
}
}
}
wal_file.seek(SeekFrom::Start(xlogoff as u64))?;
wal_file.write_all(&buf[bytes_written..(bytes_written + bytes_to_write)])?;
// Flush the file, unless syncing is disabled (no_sync)
if !self.conf.no_sync {
wal_file.sync_all()?;
}
}
/* Write was successful, advance our position */
bytes_written += bytes_to_write;
bytes_left -= bytes_to_write;
start_pos += bytes_to_write as u64;
xlogoff += bytes_to_write;
/* Did we reach the end of a WAL segment? */
if start_pos.segment_offset(wal_seg_size) == 0 {
xlogoff = 0;
if partial {
fs::rename(&wal_file_partial_path, &wal_file_path)?;
}
}
}
Ok(())
}
fn truncate_wal(&mut self, server: &ServerInfo, end_pos: Lsn) -> Result<()> {
let partial;
const ZERO_BLOCK: &[u8] = &[0u8; XLOG_BLCKSZ];
let wal_seg_size = server.wal_seg_size as usize;
/* Extract WAL location for this block */
let mut xlogoff = end_pos.segment_offset(wal_seg_size) as usize;
/* Open file */
let mut segno = end_pos.segment_number(wal_seg_size);
let (wal_file_path, wal_file_partial_path) = self.wal_file_paths(segno, wal_seg_size);
{
let mut wal_file: File;
/* Try to open already completed segment */
if let Ok(file) = OpenOptions::new().write(true).open(&wal_file_path) {
wal_file = file;
partial = false;
} else {
wal_file = OpenOptions::new()
.write(true)
.open(&wal_file_partial_path)?;
partial = true;
}
wal_file.seek(SeekFrom::Start(xlogoff as u64))?;
while xlogoff < wal_seg_size {
let bytes_to_write = min(XLOG_BLCKSZ, wal_seg_size - xlogoff);
wal_file.write_all(&ZERO_BLOCK[0..bytes_to_write])?;
xlogoff += bytes_to_write;
}
// Flush the file, unless syncing is disabled (no_sync)
if !self.conf.no_sync {
wal_file.sync_all()?;
}
}
if !partial {
// Make segment partial once again
fs::rename(&wal_file_path, &wal_file_partial_path)?;
}
// Remove all subsequent segments
loop {
segno += 1;
let (wal_file_path, wal_file_partial_path) = self.wal_file_paths(segno, wal_seg_size);
// TODO: better to use fs::try_exists, which is currently available only in nightly builds
if wal_file_path.exists() {
fs::remove_file(&wal_file_path)?;
} else if wal_file_partial_path.exists() {
fs::remove_file(&wal_file_partial_path)?;
} else {
break;
}
}
Ok(())
}
}
#[cfg(test)]
mod test {
use super::FileStorage;
use crate::{
safekeeper::{SafeKeeperState, Storage},
timeline::{CreateControlFile, CONTROL_FILE_NAME},
SafeKeeperConf, ZTenantTimelineId,
};
use anyhow::Result;
use std::fs;
use zenith_utils::lsn::Lsn;
fn stub_conf() -> SafeKeeperConf {
let workdir = tempfile::tempdir().unwrap().into_path();
SafeKeeperConf {
workdir,
..Default::default()
}
}
fn load_from_control_file(
conf: &SafeKeeperConf,
zttid: &ZTenantTimelineId,
create: CreateControlFile,
) -> Result<(FileStorage, SafeKeeperState)> {
fs::create_dir_all(&conf.timeline_dir(zttid)).expect("failed to create timeline dir");
Ok((
FileStorage::new(zttid, conf),
FileStorage::load_control_file_conf(conf, zttid, create)?,
))
}
#[test]
fn test_read_write_safekeeper_state() {
let conf = stub_conf();
let zttid = ZTenantTimelineId::generate();
{
let (mut storage, mut state) =
load_from_control_file(&conf, &zttid, CreateControlFile::True)
.expect("failed to read state");
// change something
state.wal_start_lsn = Lsn(42);
storage.persist(&state).expect("failed to persist state");
}
let (_, state) = load_from_control_file(&conf, &zttid, CreateControlFile::False)
.expect("failed to read state");
assert_eq!(state.wal_start_lsn, Lsn(42));
}
#[test]
fn test_safekeeper_state_checksum_mismatch() {
let conf = stub_conf();
let zttid = ZTenantTimelineId::generate();
{
let (mut storage, mut state) =
load_from_control_file(&conf, &zttid, CreateControlFile::True)
.expect("failed to read state");
// change something
state.wal_start_lsn = Lsn(42);
storage.persist(&state).expect("failed to persist state");
}
let control_path = conf.timeline_dir(&zttid).join(CONTROL_FILE_NAME);
let mut data = fs::read(&control_path).unwrap();
data[0] += 1; // change the first byte of the file to fail checksum validation
fs::write(&control_path, &data).expect("failed to write control file");
match load_from_control_file(&conf, &zttid, CreateControlFile::False) {
Err(err) => assert!(err
.to_string()
.contains("safekeeper control file checksum mismatch")),
Ok(_) => panic!("expected error"),
}
}
}
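For reference, the on-disk control file layout implied by persist() and load_control_file() above is: a u32 magic, a u32 format version, the serialized SafeKeeperState, and a trailing little-endian crc32c computed over everything before it. A minimal sketch of the checksum verification step (illustrative only):
use std::convert::TryInto;

// Layout: [ magic: u32 LE ][ version: u32 LE ][ SafeKeeperState bytes ][ crc32c: u32 LE ]
fn checksum_ok(buf: &[u8]) -> bool {
    const CHECKSUM_SIZE: usize = std::mem::size_of::<u32>();
    if buf.len() < CHECKSUM_SIZE {
        return false;
    }
    let (payload, tail) = buf.split_at(buf.len() - CHECKSUM_SIZE);
    crc32c::crc32c(payload) == u32::from_le_bytes(tail.try_into().unwrap())
}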

View File

@@ -1,493 +0,0 @@
//! This module has everything to deal with WAL -- reading and writing to disk.
//!
//! Safekeeper WAL is stored in the timeline directory, in a format similar to pg_wal.
//! PG timeline is always 1, so WAL segments usually have names like this:
//! - 000000010000000000000001
//! - 000000010000000000000002.partial
//!
//! Note that the last file has a `.partial` suffix; that's different from postgres.
use anyhow::{anyhow, Context, Result};
use std::io::{Read, Seek, SeekFrom};
use lazy_static::lazy_static;
use postgres_ffi::xlog_utils::{find_end_of_wal, XLogSegNo, PG_TLI};
use std::cmp::min;
use std::fs::{self, File, OpenOptions};
use std::io::Write;
use std::path::{Path, PathBuf};
use tracing::*;
use zenith_utils::lsn::Lsn;
use zenith_utils::zid::ZTenantTimelineId;
use crate::safekeeper::SafeKeeperState;
use crate::SafeKeeperConf;
use postgres_ffi::xlog_utils::{XLogFileName, XLOG_BLCKSZ};
use postgres_ffi::waldecoder::WalStreamDecoder;
use zenith_metrics::{
register_gauge_vec, register_histogram_vec, Gauge, GaugeVec, Histogram, HistogramVec,
DISK_WRITE_SECONDS_BUCKETS,
};
lazy_static! {
// The prometheus crate does not support u64 yet, i64 only (see `IntGauge`).
// i64 is faster than f64, so update to u64 when available.
static ref FLUSH_LSN_GAUGE: GaugeVec = register_gauge_vec!(
"safekeeper_flush_lsn",
"Current flush_lsn, grouped by timeline",
&["tenant_id", "timeline_id"]
)
.expect("Failed to register safekeeper_flush_lsn gauge vec");
static ref WRITE_WAL_BYTES: HistogramVec = register_histogram_vec!(
"safekeeper_write_wal_bytes",
"Bytes written to WAL in a single request, grouped by timeline",
&["tenant_id", "timeline_id"],
vec![1.0, 10.0, 100.0, 1024.0, 8192.0, 128.0 * 1024.0, 1024.0 * 1024.0, 10.0 * 1024.0 * 1024.0]
)
.expect("Failed to register safekeeper_write_wal_bytes histogram vec");
static ref WRITE_WAL_SECONDS: HistogramVec = register_histogram_vec!(
"safekeeper_write_wal_seconds",
"Seconds spent writing and syncing WAL to a disk in a single request, grouped by timeline",
&["tenant_id", "timeline_id"],
DISK_WRITE_SECONDS_BUCKETS.to_vec()
)
.expect("Failed to register safekeeper_write_wal_seconds histogram vec");
}
struct WalStorageMetrics {
flush_lsn: Gauge,
write_wal_bytes: Histogram,
write_wal_seconds: Histogram,
}
impl WalStorageMetrics {
fn new(zttid: &ZTenantTimelineId) -> Self {
let tenant_id = zttid.tenant_id.to_string();
let timeline_id = zttid.timeline_id.to_string();
Self {
flush_lsn: FLUSH_LSN_GAUGE.with_label_values(&[&tenant_id, &timeline_id]),
write_wal_bytes: WRITE_WAL_BYTES.with_label_values(&[&tenant_id, &timeline_id]),
write_wal_seconds: WRITE_WAL_SECONDS.with_label_values(&[&tenant_id, &timeline_id]),
}
}
}
pub trait Storage {
/// lsn of last durably stored WAL record.
fn flush_lsn(&self) -> Lsn;
/// Init storage with wal_seg_size and read WAL from disk to get latest lsn.
fn init_storage(&mut self, state: &SafeKeeperState) -> Result<()>;
/// Write piece of wal in buf to disk and sync it.
fn write_wal(&mut self, startpos: Lsn, buf: &[u8]) -> Result<()>;
// Truncate WAL at specified LSN.
fn truncate_wal(&mut self, end_pos: Lsn) -> Result<()>;
}
pub struct PhysicalStorage {
metrics: WalStorageMetrics,
zttid: ZTenantTimelineId,
timeline_dir: PathBuf,
conf: SafeKeeperConf,
// fields below are filled upon initialization
// None if uninitialized, Some(wal_seg_size) once the storage is initialized
wal_seg_size: Option<usize>,
// Relationship of lsns:
// `write_lsn` >= `write_record_lsn` >= `flush_record_lsn`
//
// All lsns are zeroes, if storage is just created, and there are no segments on disk.
// Written to disk, but possibly still in the cache and not fully persisted.
// It can also be ahead of record_lsn if it happens to fall in the middle of a WAL record.
write_lsn: Lsn,
// The LSN of the last WAL record written to disk. Still can be not fully flushed.
write_record_lsn: Lsn,
// The LSN of the last WAL record flushed to disk.
flush_record_lsn: Lsn,
// Decoder is required for detecting boundaries of WAL records.
decoder: WalStreamDecoder,
}
impl PhysicalStorage {
pub fn new(zttid: &ZTenantTimelineId, conf: &SafeKeeperConf) -> PhysicalStorage {
let timeline_dir = conf.timeline_dir(zttid);
PhysicalStorage {
metrics: WalStorageMetrics::new(zttid),
zttid: *zttid,
timeline_dir,
conf: conf.clone(),
wal_seg_size: None,
write_lsn: Lsn(0),
write_record_lsn: Lsn(0),
flush_record_lsn: Lsn(0),
decoder: WalStreamDecoder::new(Lsn(0)),
}
}
// wrapper for flush_lsn updates that also updates metrics
fn update_flush_lsn(&mut self) {
self.flush_record_lsn = self.write_record_lsn;
self.metrics.flush_lsn.set(self.flush_record_lsn.0 as f64);
}
/// Helper returning full path to WAL segment file and its .partial brother.
fn wal_file_paths(&self, segno: XLogSegNo) -> Result<(PathBuf, PathBuf)> {
let wal_seg_size = self
.wal_seg_size
.ok_or_else(|| anyhow!("wal_seg_size is not initialized"))?;
let wal_file_name = XLogFileName(PG_TLI, segno, wal_seg_size);
let wal_file_path = self.timeline_dir.join(wal_file_name.clone());
let wal_file_partial_path = self.timeline_dir.join(wal_file_name + ".partial");
Ok((wal_file_path, wal_file_partial_path))
}
// TODO: this function is going to be refactored soon; what will change:
// - flush will be called separately from write_wal, this function
// will only write bytes to disk
// - File will be cached in PhysicalStorage, to remove extra syscalls,
// such as open(), seek(), close()
fn write_and_flush(&mut self, startpos: Lsn, buf: &[u8]) -> Result<()> {
let wal_seg_size = self
.wal_seg_size
.ok_or_else(|| anyhow!("wal_seg_size is not initialized"))?;
let mut bytes_left: usize = buf.len();
let mut bytes_written: usize = 0;
let mut partial;
let mut start_pos = startpos;
const ZERO_BLOCK: &[u8] = &[0u8; XLOG_BLCKSZ];
/* Extract WAL location for this block */
let mut xlogoff = start_pos.segment_offset(wal_seg_size) as usize;
while bytes_left != 0 {
let bytes_to_write;
/*
* If crossing a WAL boundary, only write up until we reach wal
* segment size.
*/
if xlogoff + bytes_left > wal_seg_size {
bytes_to_write = wal_seg_size - xlogoff;
} else {
bytes_to_write = bytes_left;
}
/* Open file */
let segno = start_pos.segment_number(wal_seg_size);
let (wal_file_path, wal_file_partial_path) = self.wal_file_paths(segno)?;
{
let mut wal_file: File;
/* Try to open already completed segment */
if let Ok(file) = OpenOptions::new().write(true).open(&wal_file_path) {
wal_file = file;
partial = false;
} else if let Ok(file) = OpenOptions::new().write(true).open(&wal_file_partial_path)
{
/* Try to open an existing partial file */
wal_file = file;
partial = true;
} else {
/* Create and fill new partial file */
partial = true;
match OpenOptions::new()
.create(true)
.write(true)
.open(&wal_file_partial_path)
{
Ok(mut file) => {
for _ in 0..(wal_seg_size / XLOG_BLCKSZ) {
file.write_all(ZERO_BLOCK)?;
}
wal_file = file;
}
Err(e) => {
error!("Failed to open log file {:?}: {}", &wal_file_path, e);
return Err(e.into());
}
}
}
wal_file.seek(SeekFrom::Start(xlogoff as u64))?;
wal_file.write_all(&buf[bytes_written..(bytes_written + bytes_to_write)])?;
// Flush the file, unless syncing is disabled (no_sync)
if !self.conf.no_sync {
wal_file.sync_all()?;
}
}
/* Write was successful, advance our position */
bytes_written += bytes_to_write;
bytes_left -= bytes_to_write;
start_pos += bytes_to_write as u64;
xlogoff += bytes_to_write;
/* Did we reach the end of a WAL segment? */
if start_pos.segment_offset(wal_seg_size) == 0 {
xlogoff = 0;
if partial {
fs::rename(&wal_file_partial_path, &wal_file_path)?;
}
}
}
Ok(())
}
}
impl Storage for PhysicalStorage {
// flush_lsn returns lsn of last durably stored WAL record.
fn flush_lsn(&self) -> Lsn {
self.flush_record_lsn
}
// Storage needs to know wal_seg_size to know which segment to read/write, but
// wal_seg_size is not always known at the moment of storage creation. This method
// allows postponing its initialization.
fn init_storage(&mut self, state: &SafeKeeperState) -> Result<()> {
if state.server.wal_seg_size == 0 {
// wal_seg_size is still unknown
return Ok(());
}
if let Some(wal_seg_size) = self.wal_seg_size {
// physical storage is already initialized
assert_eq!(wal_seg_size, state.server.wal_seg_size as usize);
return Ok(());
}
// initialize physical storage
let wal_seg_size = state.server.wal_seg_size as usize;
self.wal_seg_size = Some(wal_seg_size);
// we need to read WAL from disk to know which LSNs are stored on disk
self.write_lsn =
Lsn(find_end_of_wal(&self.timeline_dir, wal_seg_size, true, state.wal_start_lsn)?.0);
self.write_record_lsn = self.write_lsn;
// TODO: do we really know that write_lsn is fully flushed to disk?
// If not, maybe it's better to call fsync() here to be sure?
self.update_flush_lsn();
info!(
"initialized storage for timeline {}, flush_lsn={}, commit_lsn={}, truncate_lsn={}",
self.zttid.timeline_id, self.flush_record_lsn, state.commit_lsn, state.truncate_lsn,
);
if self.flush_record_lsn < state.commit_lsn || self.flush_record_lsn < state.truncate_lsn {
warn!("timeline {} potential data loss: flush_lsn by find_end_of_wal is less than either commit_lsn or truncate_lsn from control file", self.zttid.timeline_id);
}
Ok(())
}
// Write and flush WAL to disk.
fn write_wal(&mut self, startpos: Lsn, buf: &[u8]) -> Result<()> {
if self.write_lsn > startpos {
warn!(
"write_wal rewrites WAL written before, write_lsn={}, startpos={}",
self.write_lsn, startpos
);
}
if self.write_lsn < startpos {
warn!(
"write_wal creates gap in written WAL, write_lsn={}, startpos={}",
self.write_lsn, startpos
);
// TODO: return error if write_lsn is not zero
}
{
let _timer = self.metrics.write_wal_seconds.start_timer();
self.write_and_flush(startpos, buf)?;
}
// WAL is written and flushed, updating lsns
self.write_lsn = startpos + buf.len() as u64;
self.metrics.write_wal_bytes.observe(buf.len() as f64);
// figure out last record's end lsn for reporting (if we got the
// whole record)
if self.decoder.available() != startpos {
info!(
"restart decoder from {} to {}",
self.decoder.available(),
startpos,
);
self.decoder = WalStreamDecoder::new(startpos);
}
self.decoder.feed_bytes(buf);
loop {
match self.decoder.poll_decode()? {
None => break, // no full record yet
Some((lsn, _rec)) => {
self.write_record_lsn = lsn;
}
}
}
self.update_flush_lsn();
Ok(())
}
// Truncate written WAL by removing all WAL segments after the given LSN.
// end_pos must point to the end of the WAL record.
fn truncate_wal(&mut self, end_pos: Lsn) -> Result<()> {
let wal_seg_size = self
.wal_seg_size
.ok_or_else(|| anyhow!("wal_seg_size is not initialized"))?;
// TODO: cross check divergence point
// nothing to truncate
if self.write_lsn == Lsn(0) {
return Ok(());
}
// Streaming must not create a hole, so truncate cannot be called on a non-written LSN
assert!(self.write_lsn >= end_pos);
// open segment files and delete or fill end with zeroes
let partial;
const ZERO_BLOCK: &[u8] = &[0u8; XLOG_BLCKSZ];
/* Extract WAL location for this block */
let mut xlogoff = end_pos.segment_offset(wal_seg_size) as usize;
/* Open file */
let mut segno = end_pos.segment_number(wal_seg_size);
let (wal_file_path, wal_file_partial_path) = self.wal_file_paths(segno)?;
{
let mut wal_file: File;
/* Try to open already completed segment */
if let Ok(file) = OpenOptions::new().write(true).open(&wal_file_path) {
wal_file = file;
partial = false;
} else {
wal_file = OpenOptions::new()
.write(true)
.open(&wal_file_partial_path)?;
partial = true;
}
wal_file.seek(SeekFrom::Start(xlogoff as u64))?;
while xlogoff < wal_seg_size {
let bytes_to_write = min(XLOG_BLCKSZ, wal_seg_size - xlogoff);
wal_file.write_all(&ZERO_BLOCK[0..bytes_to_write])?;
xlogoff += bytes_to_write;
}
// Flush the file, unless syncing is disabled (no_sync)
if !self.conf.no_sync {
wal_file.sync_all()?;
}
}
if !partial {
// Make segment partial once again
fs::rename(&wal_file_path, &wal_file_partial_path)?;
}
// Remove all subsequent segments
loop {
segno += 1;
let (wal_file_path, wal_file_partial_path) = self.wal_file_paths(segno)?;
// TODO: better to use fs::try_exists, which is currently available only in nightly builds
if wal_file_path.exists() {
fs::remove_file(&wal_file_path)?;
} else if wal_file_partial_path.exists() {
fs::remove_file(&wal_file_partial_path)?;
} else {
break;
}
}
// Update lsns
self.write_lsn = end_pos;
self.write_record_lsn = end_pos;
self.update_flush_lsn();
Ok(())
}
}
pub struct WalReader {
timeline_dir: PathBuf,
wal_seg_size: usize,
pos: Lsn,
file: Option<File>,
}
impl WalReader {
pub fn new(timeline_dir: PathBuf, wal_seg_size: usize, pos: Lsn) -> Self {
Self {
timeline_dir,
wal_seg_size,
pos,
file: None,
}
}
pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
// Take the `File` from `wal_file`, or open a new file.
let mut file = match self.file.take() {
Some(file) => file,
None => {
// Open a new file.
let segno = self.pos.segment_number(self.wal_seg_size);
let wal_file_name = XLogFileName(PG_TLI, segno, self.wal_seg_size);
let wal_file_path = self.timeline_dir.join(wal_file_name);
Self::open_wal_file(&wal_file_path)?
}
};
let xlogoff = self.pos.segment_offset(self.wal_seg_size) as usize;
// How much to read and send in a message? We cannot cross the WAL file
// boundary, and we don't want to send more than the provided buffer.
let send_size = min(buf.len(), self.wal_seg_size - xlogoff);
// Read some data from the file.
let buf = &mut buf[0..send_size];
file.seek(SeekFrom::Start(xlogoff as u64))
.and_then(|_| file.read_exact(buf))
.context("Failed to read data from WAL file")?;
self.pos += send_size as u64;
// Decide whether to reuse this file. If we don't set wal_file here
// a new file will be opened next time.
if self.pos.segment_offset(self.wal_seg_size) != 0 {
self.file = Some(file);
}
Ok(send_size)
}
/// Helper function for opening a wal file.
fn open_wal_file(wal_file_path: &Path) -> Result<File> {
// First try to open the .partial file.
let mut partial_path = wal_file_path.to_owned();
partial_path.set_extension("partial");
if let Ok(opened_file) = File::open(&partial_path) {
return Ok(opened_file);
}
// If that failed, try it without the .partial extension.
File::open(&wal_file_path)
.with_context(|| format!("Failed to open WAL file {:?}", wal_file_path))
.map_err(|e| {
error!("{}", e);
e
})
}
}
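For orientation, a sketch of how the WalReader above is driven from a sender loop (buffer size and paths are illustrative; the real caller passes the timeline directory and MAX_SEND_SIZE):
use zenith_utils::lsn::Lsn;

fn send_some_wal(timeline_dir: std::path::PathBuf, wal_seg_size: usize, start: Lsn) -> anyhow::Result<()> {
    let mut reader = WalReader::new(timeline_dir, wal_seg_size, start);
    // A bounded read buffer; each read() stops at the next segment boundary.
    let mut buf = vec![0u8; 128 * 1024];
    let n = reader.read(&mut buf)?;
    // ... hand buf[..n] to the network writer ...
    let _ = &buf[..n];
    Ok(())
}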

View File

@@ -5,4 +5,4 @@ pub mod request;
/// Current fast way to apply simple http routing in various Zenith binaries.
/// Re-exported for the sake of a uniform approach; this could later be replaced with better alternatives if needed.
pub use routerify::{ext::RequestExt, RouterBuilder, RouterService};
pub use routerify::{ext::RequestExt, RouterBuilder};

View File

@@ -430,11 +430,11 @@ impl PostgresBackend {
trace!("got query {:?}", query_string);
// xxx distinguish fatal and recoverable errors?
if let Err(e) = handler.process_query(self, query_string) {
// ":?" uses the alternate formatting style, which makes anyhow display the
// full cause of the error, not just the top-level context + its trace.
// We don't want to send that in the ErrorResponse though,
// because it's not relevant to the compute node logs.
error!("query handler for '{}' failed: {:?}", query_string, e);
// ":#" uses the alternate formatting style, which makes anyhow display the
// full cause of the error, not just the top-level context. We don't want to
// send that in the ErrorResponse though, because it's not relevant to the
// compute node logs.
warn!("query handler for {} failed: {:#}", query_string, e);
self.write_message_noflush(&BeMessage::ErrorResponse(&e.to_string()))?;
// TODO: untangle convoluted control flow
if e.to_string().contains("failed to run") {
@@ -467,7 +467,7 @@ impl PostgresBackend {
trace!("got execute {:?}", query_string);
// xxx distinguish fatal and recoverable errors?
if let Err(e) = handler.process_query(self, query_string) {
error!("query handler for '{}' failed: {:?}", query_string, e);
warn!("query handler for {:?} failed: {:#}", query_string, e);
self.write_message(&BeMessage::ErrorResponse(&e.to_string()))?;
}
// NOTE there is no ReadyForQuery message. This handler is used
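The logging change above switches from the Debug format to anyhow's alternate Display format. A tiny standalone sketch of what each produces for a context-wrapped error (illustrative; output paraphrased in the comments):
use anyhow::anyhow;

fn main() {
    let err = anyhow!("connection reset").context("failed to run query");
    println!("{}", err);   // "failed to run query"                   -- outermost context only
    println!("{:#}", err); // "failed to run query: connection reset" -- whole chain on one line
    println!("{:?}", err); // multi-line report with "Caused by:" (and a backtrace if captured)
}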

View File

@@ -57,16 +57,6 @@ pub struct CancelKeyData {
pub cancel_key: i32,
}
use rand::distributions::{Distribution, Standard};
impl Distribution<CancelKeyData> for Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
CancelKeyData {
backend_pid: rng.gen(),
cancel_key: rng.gen(),
}
}
}
#[derive(Debug)]
pub struct FeQueryMessage {
pub body: Bytes,