mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-21 23:20:40 +00:00
Compare commits
11 Commits
myrrc/past
...
28934-pg-d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4f4a96ea25 | ||
|
|
5bdba70f7d | ||
|
|
25fffd3a55 | ||
|
|
e00fd45bba | ||
|
|
3b8be98b67 | ||
|
|
3e72edede5 | ||
|
|
a650f7f5af | ||
|
|
fc3994eb71 | ||
|
|
781bf4945d | ||
|
|
a21c1174ed | ||
|
|
8d7ed2a4ee |
8
.gitmodules
vendored
8
.gitmodules
vendored
@@ -1,16 +1,16 @@
|
||||
[submodule "vendor/postgres-v14"]
|
||||
path = vendor/postgres-v14
|
||||
url = https://github.com/neondatabase/postgres.git
|
||||
branch = REL_14_STABLE_neon
|
||||
branch = 28934-pg-dump-schema-no-create-v14
|
||||
[submodule "vendor/postgres-v15"]
|
||||
path = vendor/postgres-v15
|
||||
url = https://github.com/neondatabase/postgres.git
|
||||
branch = REL_15_STABLE_neon
|
||||
branch = 28934-pg-dump-schema-no-create-v15
|
||||
[submodule "vendor/postgres-v16"]
|
||||
path = vendor/postgres-v16
|
||||
url = https://github.com/neondatabase/postgres.git
|
||||
branch = REL_16_STABLE_neon
|
||||
branch = 28934-pg-dump-schema-no-create-v16
|
||||
[submodule "vendor/postgres-v17"]
|
||||
path = vendor/postgres-v17
|
||||
url = https://github.com/neondatabase/postgres.git
|
||||
branch = REL_17_STABLE_neon
|
||||
branch = 28934-pg-dump-schema-no-create-v17
|
||||
|
||||
187
Cargo.lock
generated
187
Cargo.lock
generated
@@ -29,41 +29,6 @@ version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627"
|
||||
|
||||
[[package]]
|
||||
name = "aead"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0"
|
||||
dependencies = [
|
||||
"crypto-common",
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aes"
|
||||
version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cipher",
|
||||
"cpufeatures",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aes-gcm"
|
||||
version = "0.10.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1"
|
||||
dependencies = [
|
||||
"aead",
|
||||
"aes",
|
||||
"cipher",
|
||||
"ctr",
|
||||
"ghash",
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ahash"
|
||||
version = "0.8.11"
|
||||
@@ -788,7 +753,6 @@ dependencies = [
|
||||
"axum",
|
||||
"axum-core",
|
||||
"bytes",
|
||||
"cookie",
|
||||
"futures-util",
|
||||
"headers",
|
||||
"http 1.1.0",
|
||||
@@ -1209,16 +1173,6 @@ dependencies = [
|
||||
"half",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cipher"
|
||||
version = "0.4.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
|
||||
dependencies = [
|
||||
"crypto-common",
|
||||
"inout",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clang-sys"
|
||||
version = "1.6.1"
|
||||
@@ -1510,21 +1464,6 @@ dependencies = [
|
||||
"workspace_hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cookie"
|
||||
version = "0.18.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747"
|
||||
dependencies = [
|
||||
"aes-gcm",
|
||||
"base64 0.22.1",
|
||||
"percent-encoding",
|
||||
"rand 0.8.5",
|
||||
"subtle",
|
||||
"time",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.9.3"
|
||||
@@ -1718,19 +1657,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
"rand_core 0.6.4",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ctr"
|
||||
version = "0.9.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835"
|
||||
dependencies = [
|
||||
"cipher",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "curve25519-dalek"
|
||||
version = "4.1.3"
|
||||
@@ -2581,16 +2510,6 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ghash"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1"
|
||||
dependencies = [
|
||||
"opaque-debug",
|
||||
"polyval",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gimli"
|
||||
version = "0.31.1"
|
||||
@@ -3362,15 +3281,6 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "inout"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01"
|
||||
dependencies = [
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "instant"
|
||||
version = "0.1.12"
|
||||
@@ -3884,15 +3794,6 @@ version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
|
||||
|
||||
[[package]]
|
||||
name = "nanoid"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3ffa00dec017b5b1a8b7cf5e2c008bfda1aa7e0697ac1508b491fdf2622fb4d8"
|
||||
dependencies = [
|
||||
"rand 0.8.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "neon-shmem"
|
||||
version = "0.1.0"
|
||||
@@ -4165,12 +4066,6 @@ version = "11.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575"
|
||||
|
||||
[[package]]
|
||||
name = "opaque-debug"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
|
||||
|
||||
[[package]]
|
||||
name = "openssl-probe"
|
||||
version = "0.1.5"
|
||||
@@ -4341,6 +4236,7 @@ name = "pagebench"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"camino",
|
||||
"clap",
|
||||
"futures",
|
||||
@@ -4349,12 +4245,15 @@ dependencies = [
|
||||
"humantime-serde",
|
||||
"pageserver_api",
|
||||
"pageserver_client",
|
||||
"pageserver_page_api",
|
||||
"rand 0.8.5",
|
||||
"reqwest",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tokio-util",
|
||||
"tonic 0.13.1",
|
||||
"tracing",
|
||||
"utils",
|
||||
"workspace_hack",
|
||||
@@ -4410,6 +4309,7 @@ dependencies = [
|
||||
"hashlink",
|
||||
"hex",
|
||||
"hex-literal",
|
||||
"http 1.1.0",
|
||||
"http-utils",
|
||||
"humantime",
|
||||
"humantime-serde",
|
||||
@@ -4472,6 +4372,7 @@ dependencies = [
|
||||
"toml_edit",
|
||||
"tonic 0.13.1",
|
||||
"tonic-reflection",
|
||||
"tower 0.5.2",
|
||||
"tracing",
|
||||
"tracing-utils",
|
||||
"twox-hash",
|
||||
@@ -4568,7 +4469,6 @@ dependencies = [
|
||||
"pageserver_api",
|
||||
"postgres_ffi",
|
||||
"prost 0.13.5",
|
||||
"smallvec",
|
||||
"thiserror 1.0.69",
|
||||
"tonic 0.13.1",
|
||||
"tonic-build",
|
||||
@@ -4690,31 +4590,6 @@ version = "1.0.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c"
|
||||
|
||||
[[package]]
|
||||
name = "paster"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"axum",
|
||||
"axum-extra",
|
||||
"base64 0.13.1",
|
||||
"chrono",
|
||||
"nanoid",
|
||||
"rand 0.8.5",
|
||||
"reqwest",
|
||||
"rustls 0.23.27",
|
||||
"rustls-native-certs 0.8.0",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"time",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres-rustls",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"workspace_hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pbkdf2"
|
||||
version = "0.12.2"
|
||||
@@ -4887,18 +4762,6 @@ dependencies = [
|
||||
"never-say-never",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "polyval"
|
||||
version = "0.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"opaque-debug",
|
||||
"universal-hash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "portable-atomic"
|
||||
version = "1.10.0"
|
||||
@@ -6701,32 +6564,6 @@ version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
|
||||
|
||||
[[package]]
|
||||
name = "shortener"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"axum",
|
||||
"axum-extra",
|
||||
"base64 0.13.1",
|
||||
"chrono",
|
||||
"cookie",
|
||||
"nanoid",
|
||||
"rand 0.8.5",
|
||||
"reqwest",
|
||||
"rustls 0.23.27",
|
||||
"rustls-native-certs 0.8.0",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"time",
|
||||
"tokio",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres-rustls",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"workspace_hack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "signal-hook"
|
||||
version = "0.3.15"
|
||||
@@ -8095,16 +7932,6 @@ version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c"
|
||||
|
||||
[[package]]
|
||||
name = "universal-hash"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea"
|
||||
dependencies = [
|
||||
"crypto-common",
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "untrusted"
|
||||
version = "0.9.0"
|
||||
@@ -8737,7 +8564,6 @@ dependencies = [
|
||||
"anyhow",
|
||||
"axum",
|
||||
"axum-core",
|
||||
"axum-extra",
|
||||
"base64 0.13.1",
|
||||
"base64 0.21.7",
|
||||
"base64ct",
|
||||
@@ -8749,7 +8575,6 @@ dependencies = [
|
||||
"clap_builder",
|
||||
"const-oid",
|
||||
"crypto-bigint 0.5.5",
|
||||
"crypto-common",
|
||||
"der 0.7.8",
|
||||
"deranged",
|
||||
"digest",
|
||||
|
||||
@@ -13,8 +13,6 @@ members = [
|
||||
"proxy",
|
||||
"safekeeper",
|
||||
"safekeeper/client",
|
||||
"shortener",
|
||||
"paster",
|
||||
"storage_broker",
|
||||
"storage_controller",
|
||||
"storage_controller/client",
|
||||
|
||||
@@ -1180,14 +1180,14 @@ RUN cd exts/rag && \
|
||||
RUN cd exts/rag_bge_small_en_v15 && \
|
||||
sed -i 's/pgrx = "0.14.1"/pgrx = { version = "0.14.1", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
|
||||
ORT_LIB_LOCATION=/ext-src/onnxruntime-src/build/Linux \
|
||||
REMOTE_ONNX_URL=http://pg-ext-s3-gateway/pgrag-data/bge_small_en_v15.onnx \
|
||||
REMOTE_ONNX_URL=http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/pgrag-data/bge_small_en_v15.onnx \
|
||||
cargo pgrx install --release --features remote_onnx && \
|
||||
echo "trusted = true" >> /usr/local/pgsql/share/extension/rag_bge_small_en_v15.control
|
||||
|
||||
RUN cd exts/rag_jina_reranker_v1_tiny_en && \
|
||||
sed -i 's/pgrx = "0.14.1"/pgrx = { version = "0.14.1", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
|
||||
ORT_LIB_LOCATION=/ext-src/onnxruntime-src/build/Linux \
|
||||
REMOTE_ONNX_URL=http://pg-ext-s3-gateway/pgrag-data/jina_reranker_v1_tiny_en.onnx \
|
||||
REMOTE_ONNX_URL=http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/pgrag-data/jina_reranker_v1_tiny_en.onnx \
|
||||
cargo pgrx install --release --features remote_onnx && \
|
||||
echo "trusted = true" >> /usr/local/pgsql/share/extension/rag_jina_reranker_v1_tiny_en.control
|
||||
|
||||
|
||||
@@ -70,6 +70,14 @@ enum Command {
|
||||
/// and maintenance_work_mem.
|
||||
#[clap(long, env = "NEON_IMPORTER_MEMORY_MB")]
|
||||
memory_mb: Option<usize>,
|
||||
|
||||
/// List of schemas to dump.
|
||||
#[clap(long)]
|
||||
schema: Vec<String>,
|
||||
|
||||
/// List of extensions to dump.
|
||||
#[clap(long)]
|
||||
extension: Vec<String>,
|
||||
},
|
||||
|
||||
/// Runs pg_dump-pg_restore from source to destination without running local postgres.
|
||||
@@ -82,6 +90,12 @@ enum Command {
|
||||
/// real scenario uses encrypted connection string in spec.json from s3.
|
||||
#[clap(long)]
|
||||
destination_connection_string: Option<String>,
|
||||
/// List of schemas to dump.
|
||||
#[clap(long)]
|
||||
schema: Vec<String>,
|
||||
/// List of extensions to dump.
|
||||
#[clap(long)]
|
||||
extension: Vec<String>,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -117,6 +131,8 @@ struct Spec {
|
||||
source_connstring_ciphertext_base64: Vec<u8>,
|
||||
#[serde_as(as = "Option<serde_with::base64::Base64>")]
|
||||
destination_connstring_ciphertext_base64: Option<Vec<u8>>,
|
||||
schemas: Option<Vec<String>>,
|
||||
extensions: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
@@ -337,6 +353,8 @@ async fn run_dump_restore(
|
||||
pg_lib_dir: Utf8PathBuf,
|
||||
source_connstring: String,
|
||||
destination_connstring: String,
|
||||
schemas: Vec<String>,
|
||||
extensions: Vec<String>,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let dumpdir = workdir.join("dumpdir");
|
||||
let num_jobs = num_cpus::get().to_string();
|
||||
@@ -351,6 +369,7 @@ async fn run_dump_restore(
|
||||
"--no-subscriptions".to_string(),
|
||||
"--no-tablespaces".to_string(),
|
||||
"--no-event-triggers".to_string(),
|
||||
"--enable-row-security".to_string(),
|
||||
// format
|
||||
"--format".to_string(),
|
||||
"directory".to_string(),
|
||||
@@ -361,10 +380,36 @@ async fn run_dump_restore(
|
||||
"--verbose".to_string(),
|
||||
];
|
||||
|
||||
let mut pg_dump_args = vec![
|
||||
// this makes sure any unsupported extensions are not included in the dump
|
||||
// even if we don't specify supported extensions explicitly
|
||||
"--extension".to_string(),
|
||||
"plpgsql".to_string(),
|
||||
];
|
||||
|
||||
// if no schemas are specified, try to import all schemas
|
||||
if !schemas.is_empty() {
|
||||
// always include public schema objects
|
||||
// but never create the schema itself
|
||||
// it already exists in any pg cluster by default
|
||||
pg_dump_args.push("--schema-no-create".to_string());
|
||||
pg_dump_args.push("public".to_string());
|
||||
for schema in &schemas {
|
||||
pg_dump_args.push("--schema".to_string());
|
||||
pg_dump_args.push(schema.clone());
|
||||
}
|
||||
}
|
||||
|
||||
for extension in &extensions {
|
||||
pg_dump_args.push("--extension".to_string());
|
||||
pg_dump_args.push(extension.clone());
|
||||
}
|
||||
|
||||
info!("dump into the working directory");
|
||||
{
|
||||
let mut pg_dump = tokio::process::Command::new(pg_bin_dir.join("pg_dump"))
|
||||
.args(&common_args)
|
||||
.args(&pg_dump_args)
|
||||
.arg("-f")
|
||||
.arg(&dumpdir)
|
||||
.arg("--no-sync")
|
||||
@@ -455,6 +500,8 @@ async fn cmd_pgdata(
|
||||
maybe_s3_prefix: Option<s3_uri::S3Uri>,
|
||||
maybe_spec: Option<Spec>,
|
||||
source_connection_string: Option<String>,
|
||||
schemas: Vec<String>,
|
||||
extensions: Vec<String>,
|
||||
interactive: bool,
|
||||
pg_port: u16,
|
||||
workdir: Utf8PathBuf,
|
||||
@@ -470,19 +517,25 @@ async fn cmd_pgdata(
|
||||
bail!("only one of spec or source_connection_string can be provided");
|
||||
}
|
||||
|
||||
let source_connection_string = if let Some(spec) = maybe_spec {
|
||||
let (source_connection_string, schemas, extensions) = if let Some(spec) = maybe_spec {
|
||||
match spec.encryption_secret {
|
||||
EncryptionSecret::KMS { key_id } => {
|
||||
decode_connstring(
|
||||
let schemas = spec.schemas.unwrap_or(vec![]);
|
||||
let extensions = spec.extensions.unwrap_or(vec![]);
|
||||
|
||||
let source = decode_connstring(
|
||||
kms_client.as_ref().unwrap(),
|
||||
&key_id,
|
||||
spec.source_connstring_ciphertext_base64,
|
||||
)
|
||||
.await?
|
||||
.await
|
||||
.context("decrypt source connection string")?;
|
||||
|
||||
(source, schemas, extensions)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
source_connection_string.unwrap()
|
||||
(source_connection_string.unwrap(), schemas, extensions)
|
||||
};
|
||||
|
||||
let superuser = "cloud_admin";
|
||||
@@ -504,6 +557,8 @@ async fn cmd_pgdata(
|
||||
pg_lib_dir,
|
||||
source_connection_string,
|
||||
destination_connstring,
|
||||
schemas,
|
||||
extensions,
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -546,18 +601,26 @@ async fn cmd_pgdata(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn cmd_dumprestore(
|
||||
kms_client: Option<aws_sdk_kms::Client>,
|
||||
maybe_spec: Option<Spec>,
|
||||
source_connection_string: Option<String>,
|
||||
destination_connection_string: Option<String>,
|
||||
schemas: Vec<String>,
|
||||
extensions: Vec<String>,
|
||||
workdir: Utf8PathBuf,
|
||||
pg_bin_dir: Utf8PathBuf,
|
||||
pg_lib_dir: Utf8PathBuf,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let (source_connstring, destination_connstring) = if let Some(spec) = maybe_spec {
|
||||
let (source_connstring, destination_connstring, schemas, extensions) = if let Some(spec) =
|
||||
maybe_spec
|
||||
{
|
||||
match spec.encryption_secret {
|
||||
EncryptionSecret::KMS { key_id } => {
|
||||
let schemas = spec.schemas.unwrap_or(vec![]);
|
||||
let extensions = spec.extensions.unwrap_or(vec![]);
|
||||
|
||||
let source = decode_connstring(
|
||||
kms_client.as_ref().unwrap(),
|
||||
&key_id,
|
||||
@@ -578,18 +641,17 @@ async fn cmd_dumprestore(
|
||||
);
|
||||
};
|
||||
|
||||
(source, dest)
|
||||
(source, dest, schemas, extensions)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
(
|
||||
source_connection_string.unwrap(),
|
||||
if let Some(val) = destination_connection_string {
|
||||
val
|
||||
} else {
|
||||
bail!("destination connection string must be provided for dump_restore command");
|
||||
},
|
||||
)
|
||||
let dest = if let Some(val) = destination_connection_string {
|
||||
val
|
||||
} else {
|
||||
bail!("destination connection string must be provided for dump_restore command");
|
||||
};
|
||||
|
||||
(source_connection_string.unwrap(), dest, schemas, extensions)
|
||||
};
|
||||
|
||||
run_dump_restore(
|
||||
@@ -598,6 +660,8 @@ async fn cmd_dumprestore(
|
||||
pg_lib_dir,
|
||||
source_connstring,
|
||||
destination_connstring,
|
||||
schemas,
|
||||
extensions,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -679,6 +743,8 @@ pub(crate) async fn main() -> anyhow::Result<()> {
|
||||
pg_port,
|
||||
num_cpus,
|
||||
memory_mb,
|
||||
schema,
|
||||
extension,
|
||||
} => {
|
||||
cmd_pgdata(
|
||||
s3_client.as_ref(),
|
||||
@@ -686,6 +752,8 @@ pub(crate) async fn main() -> anyhow::Result<()> {
|
||||
args.s3_prefix.clone(),
|
||||
spec,
|
||||
source_connection_string,
|
||||
schema,
|
||||
extension,
|
||||
interactive,
|
||||
pg_port,
|
||||
args.working_directory.clone(),
|
||||
@@ -699,12 +767,16 @@ pub(crate) async fn main() -> anyhow::Result<()> {
|
||||
Command::DumpRestore {
|
||||
source_connection_string,
|
||||
destination_connection_string,
|
||||
schema,
|
||||
extension,
|
||||
} => {
|
||||
cmd_dumprestore(
|
||||
kms_client,
|
||||
spec,
|
||||
source_connection_string,
|
||||
destination_connection_string,
|
||||
schema,
|
||||
extension,
|
||||
args.working_directory.clone(),
|
||||
args.pg_bin_dir,
|
||||
args.pg_lib_dir,
|
||||
|
||||
@@ -181,6 +181,7 @@ pub struct ConfigToml {
|
||||
pub virtual_file_io_engine: Option<crate::models::virtual_file::IoEngineKind>,
|
||||
pub ingest_batch_size: u64,
|
||||
pub max_vectored_read_bytes: MaxVectoredReadBytes,
|
||||
pub max_get_vectored_keys: MaxGetVectoredKeys,
|
||||
pub image_compression: ImageCompressionAlgorithm,
|
||||
pub timeline_offloading: bool,
|
||||
pub ephemeral_bytes_per_memory_kb: usize,
|
||||
@@ -229,7 +230,7 @@ pub enum PageServicePipeliningConfig {
|
||||
}
|
||||
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
|
||||
pub struct PageServicePipeliningConfigPipelined {
|
||||
/// Causes runtime errors if larger than max get_vectored batch size.
|
||||
/// Failed config parsing and validation if larger than `max_get_vectored_keys`.
|
||||
pub max_batch_size: NonZeroUsize,
|
||||
pub execution: PageServiceProtocolPipelinedExecutionStrategy,
|
||||
// The default below is such that new versions of the software can start
|
||||
@@ -403,6 +404,16 @@ impl Default for EvictionOrder {
|
||||
#[serde(transparent)]
|
||||
pub struct MaxVectoredReadBytes(pub NonZeroUsize);
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct MaxGetVectoredKeys(NonZeroUsize);
|
||||
|
||||
impl MaxGetVectoredKeys {
|
||||
pub fn get(&self) -> usize {
|
||||
self.0.get()
|
||||
}
|
||||
}
|
||||
|
||||
/// Tenant-level configuration values, used for various purposes.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
|
||||
#[serde(default)]
|
||||
@@ -587,6 +598,8 @@ pub mod defaults {
|
||||
/// That is, slightly above 128 kB.
|
||||
pub const DEFAULT_MAX_VECTORED_READ_BYTES: usize = 130 * 1024; // 130 KiB
|
||||
|
||||
pub const DEFAULT_MAX_GET_VECTORED_KEYS: usize = 32;
|
||||
|
||||
pub const DEFAULT_IMAGE_COMPRESSION: ImageCompressionAlgorithm =
|
||||
ImageCompressionAlgorithm::Zstd { level: Some(1) };
|
||||
|
||||
@@ -685,6 +698,9 @@ impl Default for ConfigToml {
|
||||
max_vectored_read_bytes: (MaxVectoredReadBytes(
|
||||
NonZeroUsize::new(DEFAULT_MAX_VECTORED_READ_BYTES).unwrap(),
|
||||
)),
|
||||
max_get_vectored_keys: (MaxGetVectoredKeys(
|
||||
NonZeroUsize::new(DEFAULT_MAX_GET_VECTORED_KEYS).unwrap(),
|
||||
)),
|
||||
image_compression: (DEFAULT_IMAGE_COMPRESSION),
|
||||
timeline_offloading: true,
|
||||
ephemeral_bytes_per_memory_kb: (DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB),
|
||||
|
||||
@@ -1934,7 +1934,7 @@ pub enum PagestreamFeMessage {
|
||||
}
|
||||
|
||||
// Wrapped in libpq CopyData
|
||||
#[derive(strum_macros::EnumProperty)]
|
||||
#[derive(Debug, strum_macros::EnumProperty)]
|
||||
pub enum PagestreamBeMessage {
|
||||
Exists(PagestreamExistsResponse),
|
||||
Nblocks(PagestreamNblocksResponse),
|
||||
@@ -2045,7 +2045,7 @@ pub enum PagestreamProtocolVersion {
|
||||
|
||||
pub type RequestId = u64;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
|
||||
pub struct PagestreamRequest {
|
||||
pub reqid: RequestId,
|
||||
pub request_lsn: Lsn,
|
||||
@@ -2064,7 +2064,7 @@ pub struct PagestreamNblocksRequest {
|
||||
pub rel: RelTag,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
|
||||
pub struct PagestreamGetPageRequest {
|
||||
pub hdr: PagestreamRequest,
|
||||
pub rel: RelTag,
|
||||
|
||||
@@ -24,7 +24,7 @@ use serde::{Deserialize, Serialize};
|
||||
// FIXME: should move 'forknum' as last field to keep this consistent with Postgres.
|
||||
// Then we could replace the custom Ord and PartialOrd implementations below with
|
||||
// deriving them. This will require changes in walredoproc.c.
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)]
|
||||
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct RelTag {
|
||||
pub forknum: u8,
|
||||
pub spcnode: Oid,
|
||||
@@ -184,12 +184,12 @@ pub enum SlruKind {
|
||||
MultiXactOffsets,
|
||||
}
|
||||
|
||||
impl SlruKind {
|
||||
pub fn to_str(&self) -> &'static str {
|
||||
impl fmt::Display for SlruKind {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Clog => "pg_xact",
|
||||
Self::MultiXactMembers => "pg_multixact/members",
|
||||
Self::MultiXactOffsets => "pg_multixact/offsets",
|
||||
Self::Clog => write!(f, "pg_xact"),
|
||||
Self::MultiXactMembers => write!(f, "pg_multixact/members"),
|
||||
Self::MultiXactOffsets => write!(f, "pg_multixact/offsets"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,6 +73,7 @@ pub mod error;
|
||||
/// async timeout helper
|
||||
pub mod timeout;
|
||||
|
||||
pub mod span;
|
||||
pub mod sync;
|
||||
|
||||
pub mod failpoint_support;
|
||||
|
||||
19
libs/utils/src/span.rs
Normal file
19
libs/utils/src/span.rs
Normal file
@@ -0,0 +1,19 @@
|
||||
//! Tracing span helpers.
|
||||
|
||||
/// Records the given fields in the current span, as a single call. The fields must already have
|
||||
/// been declared for the span (typically with empty values).
|
||||
#[macro_export]
|
||||
macro_rules! span_record {
|
||||
($($tokens:tt)*) => {$crate::span_record_in!(::tracing::Span::current(), $($tokens)*)};
|
||||
}
|
||||
|
||||
/// Records the given fields in the given span, as a single call. The fields must already have been
|
||||
/// declared for the span (typically with empty values).
|
||||
#[macro_export]
|
||||
macro_rules! span_record_in {
|
||||
($span:expr, $($tokens:tt)*) => {
|
||||
if let Some(meta) = $span.metadata() {
|
||||
$span.record_all(&tracing::valueset!(meta.fields(), $($tokens)*));
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -34,6 +34,7 @@ fail.workspace = true
|
||||
futures.workspace = true
|
||||
hashlink.workspace = true
|
||||
hex.workspace = true
|
||||
http.workspace = true
|
||||
http-utils.workspace = true
|
||||
humantime-serde.workspace = true
|
||||
humantime.workspace = true
|
||||
@@ -93,6 +94,7 @@ tokio-util.workspace = true
|
||||
toml_edit = { workspace = true, features = [ "serde" ] }
|
||||
tonic.workspace = true
|
||||
tonic-reflection.workspace = true
|
||||
tower.workspace = true
|
||||
tracing.workspace = true
|
||||
tracing-utils.workspace = true
|
||||
url.workspace = true
|
||||
|
||||
@@ -9,7 +9,6 @@ bytes.workspace = true
|
||||
pageserver_api.workspace = true
|
||||
postgres_ffi.workspace = true
|
||||
prost.workspace = true
|
||||
smallvec.workspace = true
|
||||
thiserror.workspace = true
|
||||
tonic.workspace = true
|
||||
utils.workspace = true
|
||||
|
||||
@@ -9,10 +9,16 @@
|
||||
//! - Use more precise datatypes, e.g. Lsn and uints shorter than 32 bits.
|
||||
//!
|
||||
//! - Validate protocol invariants, via try_from() and try_into().
|
||||
//!
|
||||
//! Validation only happens on the receiver side, i.e. when converting from Protobuf to domain
|
||||
//! types. This is where it matters -- the Protobuf types are less strict than the domain types, and
|
||||
//! receivers should expect all sorts of junk from senders. This also allows the sender to use e.g.
|
||||
//! stream combinators without dealing with errors, and avoids validating the same message twice.
|
||||
|
||||
use std::fmt::Display;
|
||||
|
||||
use bytes::Bytes;
|
||||
use postgres_ffi::Oid;
|
||||
use smallvec::SmallVec;
|
||||
// TODO: split out Lsn, RelTag, SlruKind, Oid and other basic types to a separate crate, to avoid
|
||||
// pulling in all of their other crate dependencies when building the client.
|
||||
use utils::lsn::Lsn;
|
||||
@@ -48,7 +54,8 @@ pub struct ReadLsn {
|
||||
pub request_lsn: Lsn,
|
||||
/// If given, the caller guarantees that the page has not been modified since this LSN. Must be
|
||||
/// smaller than or equal to request_lsn. This allows the Pageserver to serve an old page
|
||||
/// without waiting for the request LSN to arrive. Valid for all request types.
|
||||
/// without waiting for the request LSN to arrive. If not given, the request will read at the
|
||||
/// request_lsn and wait for it to arrive if necessary. Valid for all request types.
|
||||
///
|
||||
/// It is undefined behaviour to make a request such that the page was, in fact, modified
|
||||
/// between request_lsn and not_modified_since_lsn. The Pageserver might detect it and return an
|
||||
@@ -58,19 +65,14 @@ pub struct ReadLsn {
|
||||
pub not_modified_since_lsn: Option<Lsn>,
|
||||
}
|
||||
|
||||
impl ReadLsn {
|
||||
/// Validates the ReadLsn.
|
||||
pub fn validate(&self) -> Result<(), ProtocolError> {
|
||||
if self.request_lsn == Lsn::INVALID {
|
||||
return Err(ProtocolError::invalid("request_lsn", self.request_lsn));
|
||||
impl Display for ReadLsn {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let req_lsn = self.request_lsn;
|
||||
if let Some(mod_lsn) = self.not_modified_since_lsn {
|
||||
write!(f, "{req_lsn}>={mod_lsn}")
|
||||
} else {
|
||||
req_lsn.fmt(f)
|
||||
}
|
||||
if self.not_modified_since_lsn > Some(self.request_lsn) {
|
||||
return Err(ProtocolError::invalid(
|
||||
"not_modified_since_lsn",
|
||||
self.not_modified_since_lsn,
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,27 +80,31 @@ impl TryFrom<proto::ReadLsn> for ReadLsn {
|
||||
type Error = ProtocolError;
|
||||
|
||||
fn try_from(pb: proto::ReadLsn) -> Result<Self, Self::Error> {
|
||||
let read_lsn = Self {
|
||||
if pb.request_lsn == 0 {
|
||||
return Err(ProtocolError::invalid("request_lsn", pb.request_lsn));
|
||||
}
|
||||
if pb.not_modified_since_lsn > pb.request_lsn {
|
||||
return Err(ProtocolError::invalid(
|
||||
"not_modified_since_lsn",
|
||||
pb.not_modified_since_lsn,
|
||||
));
|
||||
}
|
||||
Ok(Self {
|
||||
request_lsn: Lsn(pb.request_lsn),
|
||||
not_modified_since_lsn: match pb.not_modified_since_lsn {
|
||||
0 => None,
|
||||
lsn => Some(Lsn(lsn)),
|
||||
},
|
||||
};
|
||||
read_lsn.validate()?;
|
||||
Ok(read_lsn)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<ReadLsn> for proto::ReadLsn {
|
||||
type Error = ProtocolError;
|
||||
|
||||
fn try_from(read_lsn: ReadLsn) -> Result<Self, Self::Error> {
|
||||
read_lsn.validate()?;
|
||||
Ok(Self {
|
||||
impl From<ReadLsn> for proto::ReadLsn {
|
||||
fn from(read_lsn: ReadLsn) -> Self {
|
||||
Self {
|
||||
request_lsn: read_lsn.request_lsn.0,
|
||||
not_modified_since_lsn: read_lsn.not_modified_since_lsn.unwrap_or_default().0,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -153,6 +159,15 @@ impl TryFrom<proto::CheckRelExistsRequest> for CheckRelExistsRequest {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CheckRelExistsRequest> for proto::CheckRelExistsRequest {
|
||||
fn from(request: CheckRelExistsRequest) -> Self {
|
||||
Self {
|
||||
read_lsn: Some(request.read_lsn.into()),
|
||||
rel: Some(request.rel.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type CheckRelExistsResponse = bool;
|
||||
|
||||
impl From<proto::CheckRelExistsResponse> for CheckRelExistsResponse {
|
||||
@@ -190,14 +205,12 @@ impl TryFrom<proto::GetBaseBackupRequest> for GetBaseBackupRequest {
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<GetBaseBackupRequest> for proto::GetBaseBackupRequest {
|
||||
type Error = ProtocolError;
|
||||
|
||||
fn try_from(request: GetBaseBackupRequest) -> Result<Self, Self::Error> {
|
||||
Ok(Self {
|
||||
read_lsn: Some(request.read_lsn.try_into()?),
|
||||
impl From<GetBaseBackupRequest> for proto::GetBaseBackupRequest {
|
||||
fn from(request: GetBaseBackupRequest) -> Self {
|
||||
Self {
|
||||
read_lsn: Some(request.read_lsn.into()),
|
||||
replica: request.replica,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -214,14 +227,9 @@ impl TryFrom<proto::GetBaseBackupResponseChunk> for GetBaseBackupResponseChunk {
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<GetBaseBackupResponseChunk> for proto::GetBaseBackupResponseChunk {
|
||||
type Error = ProtocolError;
|
||||
|
||||
fn try_from(chunk: GetBaseBackupResponseChunk) -> Result<Self, Self::Error> {
|
||||
if chunk.is_empty() {
|
||||
return Err(ProtocolError::Missing("chunk"));
|
||||
}
|
||||
Ok(Self { chunk })
|
||||
impl From<GetBaseBackupResponseChunk> for proto::GetBaseBackupResponseChunk {
|
||||
fn from(chunk: GetBaseBackupResponseChunk) -> Self {
|
||||
Self { chunk }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -246,14 +254,12 @@ impl TryFrom<proto::GetDbSizeRequest> for GetDbSizeRequest {
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<GetDbSizeRequest> for proto::GetDbSizeRequest {
|
||||
type Error = ProtocolError;
|
||||
|
||||
fn try_from(request: GetDbSizeRequest) -> Result<Self, Self::Error> {
|
||||
Ok(Self {
|
||||
read_lsn: Some(request.read_lsn.try_into()?),
|
||||
impl From<GetDbSizeRequest> for proto::GetDbSizeRequest {
|
||||
fn from(request: GetDbSizeRequest) -> Self {
|
||||
Self {
|
||||
read_lsn: Some(request.read_lsn.into()),
|
||||
db_oid: request.db_oid,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -288,7 +294,7 @@ pub struct GetPageRequest {
|
||||
/// Multiple pages will be executed as a single batch by the Pageserver, amortizing layer access
|
||||
/// costs and parallelizing them. This may increase the latency of any individual request, but
|
||||
/// improves the overall latency and throughput of the batch as a whole.
|
||||
pub block_numbers: SmallVec<[u32; 1]>,
|
||||
pub block_numbers: Vec<u32>,
|
||||
}
|
||||
|
||||
impl TryFrom<proto::GetPageRequest> for GetPageRequest {
|
||||
@@ -306,25 +312,20 @@ impl TryFrom<proto::GetPageRequest> for GetPageRequest {
|
||||
.ok_or(ProtocolError::Missing("read_lsn"))?
|
||||
.try_into()?,
|
||||
rel: pb.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?,
|
||||
block_numbers: pb.block_number.into(),
|
||||
block_numbers: pb.block_number,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<GetPageRequest> for proto::GetPageRequest {
|
||||
type Error = ProtocolError;
|
||||
|
||||
fn try_from(request: GetPageRequest) -> Result<Self, Self::Error> {
|
||||
if request.block_numbers.is_empty() {
|
||||
return Err(ProtocolError::Missing("block_number"));
|
||||
}
|
||||
Ok(Self {
|
||||
impl From<GetPageRequest> for proto::GetPageRequest {
|
||||
fn from(request: GetPageRequest) -> Self {
|
||||
Self {
|
||||
request_id: request.request_id,
|
||||
request_class: request.request_class.into(),
|
||||
read_lsn: Some(request.read_lsn.try_into()?),
|
||||
read_lsn: Some(request.read_lsn.into()),
|
||||
rel: Some(request.rel.into()),
|
||||
block_number: request.block_numbers.into_vec(),
|
||||
})
|
||||
block_number: request.block_numbers,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -396,7 +397,7 @@ pub struct GetPageResponse {
|
||||
/// A string describing the status, if any.
|
||||
pub reason: Option<String>,
|
||||
/// The 8KB page images, in the same order as the request. Empty if status != OK.
|
||||
pub page_images: SmallVec<[Bytes; 1]>,
|
||||
pub page_images: Vec<Bytes>,
|
||||
}
|
||||
|
||||
impl From<proto::GetPageResponse> for GetPageResponse {
|
||||
@@ -405,7 +406,7 @@ impl From<proto::GetPageResponse> for GetPageResponse {
|
||||
request_id: pb.request_id,
|
||||
status_code: pb.status_code.into(),
|
||||
reason: Some(pb.reason).filter(|r| !r.is_empty()),
|
||||
page_images: pb.page_image.into(),
|
||||
page_images: pb.page_image,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -416,7 +417,7 @@ impl From<GetPageResponse> for proto::GetPageResponse {
|
||||
request_id: response.request_id,
|
||||
status_code: response.status_code.into(),
|
||||
reason: response.reason.unwrap_or_default(),
|
||||
page_image: response.page_images.into_vec(),
|
||||
page_image: response.page_images,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -505,14 +506,12 @@ impl TryFrom<proto::GetRelSizeRequest> for GetRelSizeRequest {
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<GetRelSizeRequest> for proto::GetRelSizeRequest {
|
||||
type Error = ProtocolError;
|
||||
|
||||
fn try_from(request: GetRelSizeRequest) -> Result<Self, Self::Error> {
|
||||
Ok(Self {
|
||||
read_lsn: Some(request.read_lsn.try_into()?),
|
||||
impl From<GetRelSizeRequest> for proto::GetRelSizeRequest {
|
||||
fn from(request: GetRelSizeRequest) -> Self {
|
||||
Self {
|
||||
read_lsn: Some(request.read_lsn.into()),
|
||||
rel: Some(request.rel.into()),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -555,15 +554,13 @@ impl TryFrom<proto::GetSlruSegmentRequest> for GetSlruSegmentRequest {
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<GetSlruSegmentRequest> for proto::GetSlruSegmentRequest {
|
||||
type Error = ProtocolError;
|
||||
|
||||
fn try_from(request: GetSlruSegmentRequest) -> Result<Self, Self::Error> {
|
||||
Ok(Self {
|
||||
read_lsn: Some(request.read_lsn.try_into()?),
|
||||
impl From<GetSlruSegmentRequest> for proto::GetSlruSegmentRequest {
|
||||
fn from(request: GetSlruSegmentRequest) -> Self {
|
||||
Self {
|
||||
read_lsn: Some(request.read_lsn.into()),
|
||||
kind: request.kind as u32,
|
||||
segno: request.segno,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -580,14 +577,9 @@ impl TryFrom<proto::GetSlruSegmentResponse> for GetSlruSegmentResponse {
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<GetSlruSegmentResponse> for proto::GetSlruSegmentResponse {
|
||||
type Error = ProtocolError;
|
||||
|
||||
fn try_from(segment: GetSlruSegmentResponse) -> Result<Self, Self::Error> {
|
||||
if segment.is_empty() {
|
||||
return Err(ProtocolError::Missing("segment"));
|
||||
}
|
||||
Ok(Self { segment })
|
||||
impl From<GetSlruSegmentResponse> for proto::GetSlruSegmentResponse {
|
||||
fn from(segment: GetSlruSegmentResponse) -> Self {
|
||||
Self { segment }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
async-trait.workspace = true
|
||||
camino.workspace = true
|
||||
clap.workspace = true
|
||||
futures.workspace = true
|
||||
@@ -15,14 +16,17 @@ hdrhistogram.workspace = true
|
||||
humantime.workspace = true
|
||||
humantime-serde.workspace = true
|
||||
rand.workspace = true
|
||||
reqwest.workspace=true
|
||||
reqwest.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
tracing.workspace = true
|
||||
tokio.workspace = true
|
||||
tokio-stream.workspace = true
|
||||
tokio-util.workspace = true
|
||||
tonic.workspace = true
|
||||
|
||||
pageserver_client.workspace = true
|
||||
pageserver_api.workspace = true
|
||||
pageserver_page_api.workspace = true
|
||||
utils = { path = "../../libs/utils/" }
|
||||
workspace_hack = { version = "0.1", path = "../../workspace_hack" }
|
||||
|
||||
@@ -7,11 +7,15 @@ use std::sync::{Arc, Mutex};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use camino::Utf8PathBuf;
|
||||
use pageserver_api::key::Key;
|
||||
use pageserver_api::keyspace::KeySpaceAccum;
|
||||
use pageserver_api::models::{PagestreamGetPageRequest, PagestreamRequest};
|
||||
use pageserver_api::models::{
|
||||
PagestreamGetPageRequest, PagestreamGetPageResponse, PagestreamRequest,
|
||||
};
|
||||
use pageserver_api::shard::TenantShardId;
|
||||
use pageserver_page_api::proto;
|
||||
use rand::prelude::*;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
@@ -22,6 +26,12 @@ use utils::lsn::Lsn;
|
||||
use crate::util::tokio_thread_local_stats::AllThreadLocalStats;
|
||||
use crate::util::{request_stats, tokio_thread_local_stats};
|
||||
|
||||
#[derive(clap::ValueEnum, Clone, Debug)]
|
||||
enum Protocol {
|
||||
Libpq,
|
||||
Grpc,
|
||||
}
|
||||
|
||||
/// GetPage@LatestLSN, uniformly distributed across the compute-accessible keyspace.
|
||||
#[derive(clap::Parser)]
|
||||
pub(crate) struct Args {
|
||||
@@ -35,6 +45,8 @@ pub(crate) struct Args {
|
||||
num_clients: NonZeroUsize,
|
||||
#[clap(long)]
|
||||
runtime: Option<humantime::Duration>,
|
||||
#[clap(long, value_enum, default_value = "libpq")]
|
||||
protocol: Protocol,
|
||||
/// Each client sends requests at the given rate.
|
||||
///
|
||||
/// If a request takes too long and we should be issuing a new request already,
|
||||
@@ -303,7 +315,20 @@ async fn main_impl(
|
||||
.unwrap();
|
||||
|
||||
Box::pin(async move {
|
||||
client_libpq(args, worker_id, ss, cancel, rps_period, ranges, weights).await
|
||||
let client: Box<dyn Client> = match args.protocol {
|
||||
Protocol::Libpq => Box::new(
|
||||
LibpqClient::new(args.page_service_connstring.clone(), worker_id.timeline)
|
||||
.await
|
||||
.unwrap(),
|
||||
),
|
||||
|
||||
Protocol::Grpc => Box::new(
|
||||
GrpcClient::new(args.page_service_connstring.clone(), worker_id.timeline)
|
||||
.await
|
||||
.unwrap(),
|
||||
),
|
||||
};
|
||||
run_worker(args, client, ss, cancel, rps_period, ranges, weights).await
|
||||
})
|
||||
};
|
||||
|
||||
@@ -355,23 +380,15 @@ async fn main_impl(
|
||||
anyhow::Ok(())
|
||||
}
|
||||
|
||||
async fn client_libpq(
|
||||
async fn run_worker(
|
||||
args: &Args,
|
||||
worker_id: WorkerId,
|
||||
mut client: Box<dyn Client>,
|
||||
shared_state: Arc<SharedState>,
|
||||
cancel: CancellationToken,
|
||||
rps_period: Option<Duration>,
|
||||
ranges: Vec<KeyRange>,
|
||||
weights: rand::distributions::weighted::WeightedIndex<i128>,
|
||||
) {
|
||||
let client = pageserver_client::page_service::Client::new(args.page_service_connstring.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
let mut client = client
|
||||
.pagestream(worker_id.timeline.tenant_id, worker_id.timeline.timeline_id)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
shared_state.start_work_barrier.wait().await;
|
||||
let client_start = Instant::now();
|
||||
let mut ticks_processed = 0;
|
||||
@@ -415,12 +432,12 @@ async fn client_libpq(
|
||||
blkno: block_no,
|
||||
}
|
||||
};
|
||||
client.getpage_send(req).await.unwrap();
|
||||
client.send_get_page(req).await.unwrap();
|
||||
inflight.push_back(start);
|
||||
}
|
||||
|
||||
let start = inflight.pop_front().unwrap();
|
||||
client.getpage_recv().await.unwrap();
|
||||
client.recv_get_page().await.unwrap();
|
||||
let end = Instant::now();
|
||||
shared_state.live_stats.request_done();
|
||||
ticks_processed += 1;
|
||||
@@ -442,3 +459,104 @@ async fn client_libpq(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A benchmark client, to allow switching out the transport protocol.
|
||||
///
|
||||
/// For simplicity, this just uses separate asynchronous send/recv methods. The send method could
|
||||
/// return a future that resolves when the response is received, but we don't really need it.
|
||||
#[async_trait]
|
||||
trait Client: Send {
|
||||
/// Sends an asynchronous GetPage request to the pageserver.
|
||||
async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()>;
|
||||
|
||||
/// Receives the next GetPage response from the pageserver.
|
||||
async fn recv_get_page(&mut self) -> anyhow::Result<PagestreamGetPageResponse>;
|
||||
}
|
||||
|
||||
/// A libpq-based Pageserver client.
|
||||
struct LibpqClient {
|
||||
inner: pageserver_client::page_service::PagestreamClient,
|
||||
}
|
||||
|
||||
impl LibpqClient {
|
||||
async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result<Self> {
|
||||
let inner = pageserver_client::page_service::Client::new(connstring)
|
||||
.await?
|
||||
.pagestream(ttid.tenant_id, ttid.timeline_id)
|
||||
.await?;
|
||||
Ok(Self { inner })
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Client for LibpqClient {
|
||||
async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()> {
|
||||
self.inner.getpage_send(req).await
|
||||
}
|
||||
|
||||
async fn recv_get_page(&mut self) -> anyhow::Result<PagestreamGetPageResponse> {
|
||||
self.inner.getpage_recv().await
|
||||
}
|
||||
}
|
||||
|
||||
/// A gRPC client using the raw, no-frills gRPC client.
|
||||
struct GrpcClient {
|
||||
req_tx: tokio::sync::mpsc::Sender<proto::GetPageRequest>,
|
||||
resp_rx: tonic::Streaming<proto::GetPageResponse>,
|
||||
}
|
||||
|
||||
impl GrpcClient {
|
||||
async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result<Self> {
|
||||
let mut client = pageserver_page_api::proto::PageServiceClient::connect(connstring).await?;
|
||||
|
||||
// The channel has a buffer size of 1, since 0 is not allowed. It does not matter, since the
|
||||
// benchmark will control the queue depth (i.e. in-flight requests) anyway, and requests are
|
||||
// buffered by Tonic and the OS too.
|
||||
let (req_tx, req_rx) = tokio::sync::mpsc::channel(1);
|
||||
let req_stream = tokio_stream::wrappers::ReceiverStream::new(req_rx);
|
||||
let mut req = tonic::Request::new(req_stream);
|
||||
let metadata = req.metadata_mut();
|
||||
metadata.insert("neon-tenant-id", ttid.tenant_id.to_string().try_into()?);
|
||||
metadata.insert("neon-timeline-id", ttid.timeline_id.to_string().try_into()?);
|
||||
metadata.insert("neon-shard-id", "0000".try_into()?);
|
||||
|
||||
let resp = client.get_pages(req).await?;
|
||||
let resp_stream = resp.into_inner();
|
||||
|
||||
Ok(Self {
|
||||
req_tx,
|
||||
resp_rx: resp_stream,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Client for GrpcClient {
|
||||
async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()> {
|
||||
let req = proto::GetPageRequest {
|
||||
request_id: 0,
|
||||
request_class: proto::GetPageClass::Normal as i32,
|
||||
read_lsn: Some(proto::ReadLsn {
|
||||
request_lsn: req.hdr.request_lsn.0,
|
||||
not_modified_since_lsn: req.hdr.not_modified_since.0,
|
||||
}),
|
||||
rel: Some(req.rel.into()),
|
||||
block_number: vec![req.blkno],
|
||||
};
|
||||
self.req_tx.send(req).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recv_get_page(&mut self) -> anyhow::Result<PagestreamGetPageResponse> {
|
||||
let resp = self.resp_rx.message().await?.unwrap();
|
||||
anyhow::ensure!(
|
||||
resp.status_code == proto::GetPageStatusCode::Ok as i32,
|
||||
"unexpected status code: {}",
|
||||
resp.status_code
|
||||
);
|
||||
Ok(PagestreamGetPageResponse {
|
||||
page: resp.page_image[0].clone(),
|
||||
req: PagestreamGetPageRequest::default(), // dummy
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,6 +65,30 @@ impl From<GetVectoredError> for BasebackupError {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<BasebackupError> for postgres_backend::QueryError {
|
||||
fn from(err: BasebackupError) -> Self {
|
||||
use postgres_backend::QueryError;
|
||||
use pq_proto::framed::ConnectionError;
|
||||
match err {
|
||||
BasebackupError::Client(err, _) => QueryError::Disconnected(ConnectionError::Io(err)),
|
||||
BasebackupError::Server(err) => QueryError::Other(err),
|
||||
BasebackupError::Shutdown => QueryError::Shutdown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<BasebackupError> for tonic::Status {
|
||||
fn from(err: BasebackupError) -> Self {
|
||||
use tonic::Code;
|
||||
let code = match &err {
|
||||
BasebackupError::Client(_, _) => Code::Cancelled,
|
||||
BasebackupError::Server(_) => Code::Internal,
|
||||
BasebackupError::Shutdown => Code::Unavailable,
|
||||
};
|
||||
tonic::Status::new(code, err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Create basebackup with non-rel data in it.
|
||||
/// Only include relational data if 'full_backup' is true.
|
||||
///
|
||||
@@ -248,7 +272,7 @@ where
|
||||
async fn flush(&mut self) -> Result<(), BasebackupError> {
|
||||
let nblocks = self.buf.len() / BLCKSZ as usize;
|
||||
let (kind, segno) = self.current_segment.take().unwrap();
|
||||
let segname = format!("{}/{:>04X}", kind.to_str(), segno);
|
||||
let segname = format!("{kind}/{segno:>04X}");
|
||||
let header = new_tar_header(&segname, self.buf.len() as u64)?;
|
||||
self.ar
|
||||
.append(&header, self.buf.as_slice())
|
||||
@@ -347,7 +371,7 @@ where
|
||||
.await?
|
||||
.partition(
|
||||
self.timeline.get_shard_identity(),
|
||||
Timeline::MAX_GET_VECTORED_KEYS * BLCKSZ as u64,
|
||||
self.timeline.conf.max_get_vectored_keys.get() as u64 * BLCKSZ as u64,
|
||||
);
|
||||
|
||||
let mut slru_builder = SlruSegmentsBuilder::new(&mut self.ar);
|
||||
|
||||
@@ -804,7 +804,7 @@ fn start_pageserver(
|
||||
} else {
|
||||
None
|
||||
},
|
||||
basebackup_cache.clone(),
|
||||
basebackup_cache,
|
||||
);
|
||||
|
||||
// Spawn a Pageserver gRPC server task. It will spawn separate tasks for
|
||||
@@ -816,12 +816,10 @@ fn start_pageserver(
|
||||
let mut page_service_grpc = None;
|
||||
if let Some(grpc_listener) = grpc_listener {
|
||||
page_service_grpc = Some(page_service::spawn_grpc(
|
||||
conf,
|
||||
tenant_manager.clone(),
|
||||
grpc_auth,
|
||||
otel_guard.as_ref().map(|g| g.dispatch.clone()),
|
||||
grpc_listener,
|
||||
basebackup_cache,
|
||||
)?);
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,10 @@ use std::time::Duration;
|
||||
use anyhow::{Context, bail, ensure};
|
||||
use camino::{Utf8Path, Utf8PathBuf};
|
||||
use once_cell::sync::OnceCell;
|
||||
use pageserver_api::config::{DiskUsageEvictionTaskConfig, MaxVectoredReadBytes, PostHogConfig};
|
||||
use pageserver_api::config::{
|
||||
DiskUsageEvictionTaskConfig, MaxGetVectoredKeys, MaxVectoredReadBytes,
|
||||
PageServicePipeliningConfig, PageServicePipeliningConfigPipelined, PostHogConfig,
|
||||
};
|
||||
use pageserver_api::models::ImageCompressionAlgorithm;
|
||||
use pageserver_api::shard::TenantShardId;
|
||||
use pem::Pem;
|
||||
@@ -185,6 +188,9 @@ pub struct PageServerConf {
|
||||
|
||||
pub max_vectored_read_bytes: MaxVectoredReadBytes,
|
||||
|
||||
/// Maximum number of keys to be read in a single get_vectored call.
|
||||
pub max_get_vectored_keys: MaxGetVectoredKeys,
|
||||
|
||||
pub image_compression: ImageCompressionAlgorithm,
|
||||
|
||||
/// Whether to offload archived timelines automatically
|
||||
@@ -404,6 +410,7 @@ impl PageServerConf {
|
||||
secondary_download_concurrency,
|
||||
ingest_batch_size,
|
||||
max_vectored_read_bytes,
|
||||
max_get_vectored_keys,
|
||||
image_compression,
|
||||
timeline_offloading,
|
||||
ephemeral_bytes_per_memory_kb,
|
||||
@@ -470,6 +477,7 @@ impl PageServerConf {
|
||||
secondary_download_concurrency,
|
||||
ingest_batch_size,
|
||||
max_vectored_read_bytes,
|
||||
max_get_vectored_keys,
|
||||
image_compression,
|
||||
timeline_offloading,
|
||||
ephemeral_bytes_per_memory_kb,
|
||||
@@ -598,6 +606,19 @@ impl PageServerConf {
|
||||
)
|
||||
})?;
|
||||
|
||||
if let PageServicePipeliningConfig::Pipelined(PageServicePipeliningConfigPipelined {
|
||||
max_batch_size,
|
||||
..
|
||||
}) = conf.page_service_pipelining
|
||||
{
|
||||
if max_batch_size.get() > conf.max_get_vectored_keys.get() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"`max_batch_size` ({max_batch_size}) must be less than or equal to `max_get_vectored_keys` ({})",
|
||||
conf.max_get_vectored_keys.get()
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(conf)
|
||||
}
|
||||
|
||||
@@ -685,6 +706,7 @@ impl ConfigurableSemaphore {
|
||||
mod tests {
|
||||
|
||||
use camino::Utf8PathBuf;
|
||||
use rstest::rstest;
|
||||
use utils::id::NodeId;
|
||||
|
||||
use super::PageServerConf;
|
||||
@@ -724,4 +746,28 @@ mod tests {
|
||||
PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir)
|
||||
.expect_err("parse_and_validate should fail for endpoint without scheme");
|
||||
}
|
||||
|
||||
#[rstest]
|
||||
#[case(32, 32, true)]
|
||||
#[case(64, 32, false)]
|
||||
#[case(64, 64, true)]
|
||||
#[case(128, 128, true)]
|
||||
fn test_config_max_batch_size_is_valid(
|
||||
#[case] max_batch_size: usize,
|
||||
#[case] max_get_vectored_keys: usize,
|
||||
#[case] is_valid: bool,
|
||||
) {
|
||||
let input = format!(
|
||||
r#"
|
||||
control_plane_api = "http://localhost:6666"
|
||||
max_get_vectored_keys = {max_get_vectored_keys}
|
||||
page_service_pipelining = {{ mode="pipelined", execution="concurrent-futures", max_batch_size={max_batch_size}, batching="uniform-lsn" }}
|
||||
"#,
|
||||
);
|
||||
let config_toml = toml_edit::de::from_str::<pageserver_api::config::ConfigToml>(&input)
|
||||
.expect("config has valid fields");
|
||||
let workdir = Utf8PathBuf::from("/nonexistent");
|
||||
let result = PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir);
|
||||
assert_eq!(result.is_ok(), is_valid);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ use metrics::{
|
||||
register_int_gauge, register_int_gauge_vec, register_uint_gauge, register_uint_gauge_vec,
|
||||
};
|
||||
use once_cell::sync::Lazy;
|
||||
use pageserver_api::config::defaults::DEFAULT_MAX_GET_VECTORED_KEYS;
|
||||
use pageserver_api::config::{
|
||||
PageServicePipeliningConfig, PageServicePipeliningConfigPipelined,
|
||||
PageServiceProtocolPipelinedBatchingStrategy, PageServiceProtocolPipelinedExecutionStrategy,
|
||||
@@ -32,7 +33,6 @@ use crate::config::PageServerConf;
|
||||
use crate::context::{PageContentKind, RequestContext};
|
||||
use crate::pgdatadir_mapping::DatadirModificationStats;
|
||||
use crate::task_mgr::TaskKind;
|
||||
use crate::tenant::Timeline;
|
||||
use crate::tenant::layer_map::LayerMap;
|
||||
use crate::tenant::mgr::TenantSlot;
|
||||
use crate::tenant::storage_layer::{InMemoryLayer, PersistentLayerDesc};
|
||||
@@ -1939,7 +1939,7 @@ static SMGR_QUERY_TIME_GLOBAL: Lazy<HistogramVec> = Lazy::new(|| {
|
||||
});
|
||||
|
||||
static PAGE_SERVICE_BATCH_SIZE_BUCKETS_GLOBAL: Lazy<Vec<f64>> = Lazy::new(|| {
|
||||
(1..=u32::try_from(Timeline::MAX_GET_VECTORED_KEYS).unwrap())
|
||||
(1..=u32::try_from(DEFAULT_MAX_GET_VECTORED_KEYS).unwrap())
|
||||
.map(|v| v.into())
|
||||
.collect()
|
||||
});
|
||||
@@ -1957,7 +1957,7 @@ static PAGE_SERVICE_BATCH_SIZE_BUCKETS_PER_TIMELINE: Lazy<Vec<f64>> = Lazy::new(
|
||||
let mut buckets = Vec::new();
|
||||
for i in 0.. {
|
||||
let bucket = 1 << i;
|
||||
if bucket > u32::try_from(Timeline::MAX_GET_VECTORED_KEYS).unwrap() {
|
||||
if bucket > u32::try_from(DEFAULT_MAX_GET_VECTORED_KEYS).unwrap() {
|
||||
break;
|
||||
}
|
||||
buckets.push(bucket.into());
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -431,10 +431,10 @@ impl Timeline {
|
||||
GetVectoredError::InvalidLsn(e) => {
|
||||
Err(anyhow::anyhow!("invalid LSN: {e:?}").into())
|
||||
}
|
||||
// NB: this should never happen in practice because we limit MAX_GET_VECTORED_KEYS
|
||||
// NB: this should never happen in practice because we limit batch size to be smaller than max_get_vectored_keys
|
||||
// TODO: we can prevent this error class by moving this check into the type system
|
||||
GetVectoredError::Oversized(err) => {
|
||||
Err(anyhow::anyhow!("batching oversized: {err:?}").into())
|
||||
GetVectoredError::Oversized(err, max) => {
|
||||
Err(anyhow::anyhow!("batching oversized: {err} > {max}").into())
|
||||
}
|
||||
};
|
||||
|
||||
@@ -471,8 +471,19 @@ impl Timeline {
|
||||
|
||||
let rels = self.list_rels(spcnode, dbnode, version, ctx).await?;
|
||||
|
||||
if rels.is_empty() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
// Pre-deserialize the rel directory to avoid duplicated work in `get_relsize_cached`.
|
||||
let reldir_key = rel_dir_to_key(spcnode, dbnode);
|
||||
let buf = version.get(self, reldir_key, ctx).await?;
|
||||
let reldir = RelDirectory::des(&buf)?;
|
||||
|
||||
for rel in rels {
|
||||
let n_blocks = self.get_rel_size(rel, version, ctx).await?;
|
||||
let n_blocks = self
|
||||
.get_rel_size_in_reldir(rel, version, Some((reldir_key, &reldir)), ctx)
|
||||
.await?;
|
||||
total_blocks += n_blocks as usize;
|
||||
}
|
||||
Ok(total_blocks)
|
||||
@@ -487,6 +498,19 @@ impl Timeline {
|
||||
tag: RelTag,
|
||||
version: Version<'_>,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<BlockNumber, PageReconstructError> {
|
||||
self.get_rel_size_in_reldir(tag, version, None, ctx).await
|
||||
}
|
||||
|
||||
/// Get size of a relation file. The relation must exist, otherwise an error is returned.
|
||||
///
|
||||
/// See [`Self::get_rel_exists_in_reldir`] on why we need `deserialized_reldir_v1`.
|
||||
pub(crate) async fn get_rel_size_in_reldir(
|
||||
&self,
|
||||
tag: RelTag,
|
||||
version: Version<'_>,
|
||||
deserialized_reldir_v1: Option<(Key, &RelDirectory)>,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<BlockNumber, PageReconstructError> {
|
||||
if tag.relnode == 0 {
|
||||
return Err(PageReconstructError::Other(
|
||||
@@ -499,7 +523,9 @@ impl Timeline {
|
||||
}
|
||||
|
||||
if (tag.forknum == FSM_FORKNUM || tag.forknum == VISIBILITYMAP_FORKNUM)
|
||||
&& !self.get_rel_exists(tag, version, ctx).await?
|
||||
&& !self
|
||||
.get_rel_exists_in_reldir(tag, version, deserialized_reldir_v1, ctx)
|
||||
.await?
|
||||
{
|
||||
// FIXME: Postgres sometimes calls smgrcreate() to create
|
||||
// FSM, and smgrnblocks() on it immediately afterwards,
|
||||
@@ -521,11 +547,28 @@ impl Timeline {
|
||||
///
|
||||
/// Only shard 0 has a full view of the relations. Other shards only know about relations that
|
||||
/// the shard stores pages for.
|
||||
///
|
||||
pub(crate) async fn get_rel_exists(
|
||||
&self,
|
||||
tag: RelTag,
|
||||
version: Version<'_>,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<bool, PageReconstructError> {
|
||||
self.get_rel_exists_in_reldir(tag, version, None, ctx).await
|
||||
}
|
||||
|
||||
/// Does the relation exist? With a cached deserialized `RelDirectory`.
|
||||
///
|
||||
/// There are some cases where the caller loops across all relations. In that specific case,
|
||||
/// the caller should obtain the deserialized `RelDirectory` first and then call this function
|
||||
/// to avoid duplicated work of deserliazation. This is a hack and should be removed by introducing
|
||||
/// a new API (e.g., `get_rel_exists_batched`).
|
||||
pub(crate) async fn get_rel_exists_in_reldir(
|
||||
&self,
|
||||
tag: RelTag,
|
||||
version: Version<'_>,
|
||||
deserialized_reldir_v1: Option<(Key, &RelDirectory)>,
|
||||
ctx: &RequestContext,
|
||||
) -> Result<bool, PageReconstructError> {
|
||||
if tag.relnode == 0 {
|
||||
return Err(PageReconstructError::Other(
|
||||
@@ -568,6 +611,17 @@ impl Timeline {
|
||||
// fetch directory listing (old)
|
||||
|
||||
let key = rel_dir_to_key(tag.spcnode, tag.dbnode);
|
||||
|
||||
if let Some((cached_key, dir)) = deserialized_reldir_v1 {
|
||||
if cached_key == key {
|
||||
return Ok(dir.rels.contains(&(tag.relnode, tag.forknum)));
|
||||
} else if cfg!(test) || cfg!(feature = "testing") {
|
||||
panic!("cached reldir key mismatch: {cached_key} != {key}");
|
||||
} else {
|
||||
warn!("cached reldir key mismatch: {cached_key} != {key}");
|
||||
}
|
||||
// Fallback to reading the directory from the datadir.
|
||||
}
|
||||
let buf = version.get(self, key, ctx).await?;
|
||||
|
||||
let dir = RelDirectory::des(&buf)?;
|
||||
@@ -665,7 +719,7 @@ impl Timeline {
|
||||
|
||||
let batches = keyspace.partition(
|
||||
self.get_shard_identity(),
|
||||
Timeline::MAX_GET_VECTORED_KEYS * BLCKSZ as u64,
|
||||
self.conf.max_get_vectored_keys.get() as u64 * BLCKSZ as u64,
|
||||
);
|
||||
|
||||
let io_concurrency = IoConcurrency::spawn_from_conf(
|
||||
@@ -905,7 +959,7 @@ impl Timeline {
|
||||
|
||||
let batches = keyspace.partition(
|
||||
self.get_shard_identity(),
|
||||
Timeline::MAX_GET_VECTORED_KEYS * BLCKSZ as u64,
|
||||
self.conf.max_get_vectored_keys.get() as u64 * BLCKSZ as u64,
|
||||
);
|
||||
|
||||
let io_concurrency = IoConcurrency::spawn_from_conf(
|
||||
|
||||
@@ -7197,7 +7197,7 @@ mod tests {
|
||||
let end = desc
|
||||
.key_range
|
||||
.start
|
||||
.add(Timeline::MAX_GET_VECTORED_KEYS.try_into().unwrap());
|
||||
.add(tenant.conf.max_get_vectored_keys.get() as u32);
|
||||
reads.push(KeySpace {
|
||||
ranges: vec![start..end],
|
||||
});
|
||||
@@ -11260,11 +11260,11 @@ mod tests {
|
||||
let mut keyspaces_at_lsn: HashMap<Lsn, KeySpaceRandomAccum> = HashMap::default();
|
||||
let mut used_keys: HashSet<Key> = HashSet::default();
|
||||
|
||||
while used_keys.len() < Timeline::MAX_GET_VECTORED_KEYS as usize {
|
||||
while used_keys.len() < tenant.conf.max_get_vectored_keys.get() {
|
||||
let selected_lsn = interesting_lsns.choose(&mut random).expect("not empty");
|
||||
let mut selected_key = start_key.add(random.gen_range(0..KEY_DIMENSION_SIZE));
|
||||
|
||||
while used_keys.len() < Timeline::MAX_GET_VECTORED_KEYS as usize {
|
||||
while used_keys.len() < tenant.conf.max_get_vectored_keys.get() {
|
||||
if used_keys.contains(&selected_key)
|
||||
|| selected_key >= start_key.add(KEY_DIMENSION_SIZE)
|
||||
{
|
||||
|
||||
@@ -817,8 +817,8 @@ pub(crate) enum GetVectoredError {
|
||||
#[error("timeline shutting down")]
|
||||
Cancelled,
|
||||
|
||||
#[error("requested too many keys: {0} > {}", Timeline::MAX_GET_VECTORED_KEYS)]
|
||||
Oversized(u64),
|
||||
#[error("requested too many keys: {0} > {1}")]
|
||||
Oversized(u64, u64),
|
||||
|
||||
#[error("requested at invalid LSN: {0}")]
|
||||
InvalidLsn(Lsn),
|
||||
@@ -950,6 +950,18 @@ pub(crate) enum WaitLsnError {
|
||||
Timeout(String),
|
||||
}
|
||||
|
||||
impl From<WaitLsnError> for tonic::Status {
|
||||
fn from(err: WaitLsnError) -> Self {
|
||||
use tonic::Code;
|
||||
let code = match &err {
|
||||
WaitLsnError::Timeout(_) => Code::Internal,
|
||||
WaitLsnError::BadState(_) => Code::Internal,
|
||||
WaitLsnError::Shutdown => Code::Unavailable,
|
||||
};
|
||||
tonic::Status::new(code, err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
// The impls below achieve cancellation mapping for errors.
|
||||
// Perhaps there's a way of achieving this with less cruft.
|
||||
|
||||
@@ -1007,7 +1019,7 @@ impl From<GetVectoredError> for PageReconstructError {
|
||||
match e {
|
||||
GetVectoredError::Cancelled => PageReconstructError::Cancelled,
|
||||
GetVectoredError::InvalidLsn(_) => PageReconstructError::Other(anyhow!("Invalid LSN")),
|
||||
err @ GetVectoredError::Oversized(_) => PageReconstructError::Other(err.into()),
|
||||
err @ GetVectoredError::Oversized(_, _) => PageReconstructError::Other(err.into()),
|
||||
GetVectoredError::MissingKey(err) => PageReconstructError::MissingKey(err),
|
||||
GetVectoredError::GetReadyAncestorError(err) => PageReconstructError::from(err),
|
||||
GetVectoredError::Other(err) => PageReconstructError::Other(err),
|
||||
@@ -1187,7 +1199,6 @@ impl Timeline {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) const MAX_GET_VECTORED_KEYS: u64 = 32;
|
||||
pub(crate) const LAYERS_VISITED_WARN_THRESHOLD: u32 = 100;
|
||||
|
||||
/// Look up multiple page versions at a given LSN
|
||||
@@ -1202,9 +1213,12 @@ impl Timeline {
|
||||
) -> Result<BTreeMap<Key, Result<Bytes, PageReconstructError>>, GetVectoredError> {
|
||||
let total_keyspace = query.total_keyspace();
|
||||
|
||||
let key_count = total_keyspace.total_raw_size().try_into().unwrap();
|
||||
if key_count > Timeline::MAX_GET_VECTORED_KEYS {
|
||||
return Err(GetVectoredError::Oversized(key_count));
|
||||
let key_count = total_keyspace.total_raw_size();
|
||||
if key_count > self.conf.max_get_vectored_keys.get() {
|
||||
return Err(GetVectoredError::Oversized(
|
||||
key_count as u64,
|
||||
self.conf.max_get_vectored_keys.get() as u64,
|
||||
));
|
||||
}
|
||||
|
||||
for range in &total_keyspace.ranges {
|
||||
@@ -5258,7 +5272,7 @@ impl Timeline {
|
||||
key = key.next();
|
||||
|
||||
// Maybe flush `key_rest_accum`
|
||||
if key_request_accum.raw_size() >= Timeline::MAX_GET_VECTORED_KEYS
|
||||
if key_request_accum.raw_size() >= self.conf.max_get_vectored_keys.get() as u64
|
||||
|| (last_key_in_range && key_request_accum.raw_size() > 0)
|
||||
{
|
||||
let query =
|
||||
|
||||
@@ -201,8 +201,8 @@ async fn prepare_import(
|
||||
.await;
|
||||
match res {
|
||||
Ok(_) => break,
|
||||
Err(err) => {
|
||||
info!(?err, "indefinitely waiting for pgdata to finish");
|
||||
Err(_err) => {
|
||||
info!("indefinitely waiting for pgdata to finish");
|
||||
if tokio::time::timeout(std::time::Duration::from_secs(10), cancel.cancelled())
|
||||
.await
|
||||
.is_ok()
|
||||
|
||||
@@ -471,6 +471,8 @@ impl Plan {
|
||||
last_completed_job_idx = job_idx;
|
||||
|
||||
if last_completed_job_idx % checkpoint_every == 0 {
|
||||
tracing::info!(last_completed_job_idx, jobs=%jobs_in_plan, "Checkpointing import status");
|
||||
|
||||
let progress = ShardImportProgressV1 {
|
||||
jobs: jobs_in_plan,
|
||||
completed: last_completed_job_idx,
|
||||
@@ -492,8 +494,6 @@ impl Plan {
|
||||
anyhow::anyhow!("Shut down while putting timeline import status")
|
||||
})?;
|
||||
}
|
||||
|
||||
tracing::info!(last_completed_job_idx, jobs=%jobs_in_plan, "Checkpointing import status");
|
||||
},
|
||||
Some(Err(_)) => {
|
||||
anyhow::bail!(
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
[package]
|
||||
name = "paster"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
axum-extra = { workspace = true, features = ["cookie", "cookie-private"] }
|
||||
axum.workspace = true
|
||||
base64.workspace = true
|
||||
chrono.workspace = true
|
||||
nanoid = { version = "0.4.0", default-features = false }
|
||||
rand.workspace = true
|
||||
reqwest.workspace = true
|
||||
rustls-native-certs.workspace = true
|
||||
rustls.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
time = { version = "0.3.36", default-features = false }
|
||||
tokio-postgres-rustls.workspace = true
|
||||
tokio-postgres.workspace = true
|
||||
tokio.workspace = true
|
||||
tracing-subscriber.workspace = true
|
||||
tracing.workspace = true
|
||||
workspace_hack.workspace = true
|
||||
@@ -1,18 +0,0 @@
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
sub VARCHAR(100) NOT NULL UNIQUE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id INT NOT NULL UNIQUE REFERENCES users(id),
|
||||
session_id VARCHAR NOT NULL,
|
||||
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS pastes (
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id INT NOT NULL REFERENCES users(id),
|
||||
paste text NOT NULL,
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
@@ -1,353 +0,0 @@
|
||||
//! Paster is a service to share logs or code snippets outside of
|
||||
//! Slack, not relying on public services
|
||||
use anyhow::Result;
|
||||
use shortener::google_oauth_gate::{AuthRequest, State, UserId};
|
||||
use axum::Form;
|
||||
use axum::extract::{FromRef, FromRequestParts, Path, Query, State as AxumStateT};
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::{Html, IntoResponse};
|
||||
use axum::response::{Redirect, Response};
|
||||
use axum::routing::get;
|
||||
use axum_extra::extract::PrivateCookieJar;
|
||||
use axum_extra::extract::cookie::{Cookie, Key};
|
||||
use chrono::{Duration, Local, TimeZone, Utc};
|
||||
use core::num::NonZeroI32;
|
||||
use serde::Deserialize;
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
use tracing::{error, info};
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
|
||||
const SOCKET: &str = "127.0.0.1:12344";
|
||||
const HOST: &str = "http://127.0.0.1:12344";
|
||||
const ALLOWED_OAUTH_DOMAIN: &str = "neon.tech";
|
||||
|
||||
fn oauth_redirect_url() -> String {
|
||||
format!("{HOST}{AUTHORIZED_ROUTE}")
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
tracing_subscriber::registry()
|
||||
.with(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| format!("{}=info", env!("CARGO_CRATE_NAME")).into()),
|
||||
)
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
|
||||
let oauth_client_id = env::var("OAUTH_CLIENT_ID").expect("Missing OAUTH_CLIENT_ID");
|
||||
let oauth_client_secret = env::var("OAUTH_CLIENT_SECRET").expect("Missing OAUTH_CLIENT_SECRET");
|
||||
|
||||
let db_connstr = env::var("DB_CONNSTR").expect("Missing DB_CONNSTR");
|
||||
let mut roots = rustls::RootCertStore::empty();
|
||||
for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") {
|
||||
roots.add(cert).unwrap();
|
||||
}
|
||||
let config = rustls::ClientConfig::builder()
|
||||
.with_root_certificates(roots)
|
||||
.with_no_client_auth();
|
||||
let tls = tokio_postgres_rustls::MakeRustlsConnect::new(config);
|
||||
info!("initialized TLS");
|
||||
|
||||
let (db_client, db_conn) = tokio_postgres::connect(&db_connstr, tls).await?;
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = db_conn.await {
|
||||
error!(%err, "connecting to database");
|
||||
std::process::exit(1);
|
||||
}
|
||||
});
|
||||
info!("connected to database");
|
||||
|
||||
let state = InnerState {
|
||||
db_client,
|
||||
cookie_jar_key: Key::generate(),
|
||||
oauth_client_id,
|
||||
oauth_client_secret,
|
||||
};
|
||||
let router = axum::Router::new()
|
||||
.route("/", get(index).post(paste))
|
||||
.route("/authorize", get(authorize))
|
||||
.route(AUTHORIZED_ROUTE, get(authorized))
|
||||
.route("/{id}", get(view_paste))
|
||||
.with_state(State { 0: Arc::new(state) });
|
||||
let listener = tokio::net::TcpListener::bind(SOCKET)
|
||||
.await
|
||||
.expect("failed to bind TcpListener");
|
||||
info!("listening on {SOCKET}");
|
||||
axum::serve(listener, router).await.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct UserId {
|
||||
id: NonZeroI32,
|
||||
}
|
||||
|
||||
impl axum::extract::OptionalFromRequestParts<State> for UserId {
|
||||
type Rejection = Response;
|
||||
async fn from_request_parts(
|
||||
parts: &mut axum::http::request::Parts,
|
||||
state: &State,
|
||||
) -> Result<Option<Self>, Self::Rejection> {
|
||||
let jar: PrivateCookieJar = PrivateCookieJar::from_request_parts(parts, state)
|
||||
.await
|
||||
.unwrap(); // infallible
|
||||
let Some(session_id) = jar.get(COOKIE_SID).map(|cookie| cookie.value().to_owned()) else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let client = &state.db_client;
|
||||
let query = client
|
||||
.query_opt(
|
||||
"SELECT user_id FROM sessions WHERE session_id = $1",
|
||||
&[&session_id],
|
||||
)
|
||||
.await;
|
||||
let id = match query {
|
||||
Ok(Some(row)) => row.get::<usize, i32>(0),
|
||||
Ok(None) => return Ok(None),
|
||||
Err(err) => {
|
||||
error!(%err, "querying user session");
|
||||
return Ok(None);
|
||||
}
|
||||
};
|
||||
let id = NonZeroI32::new(id).unwrap(); // postgres id guaranteed not to be zero
|
||||
Ok(Some(Self { id }))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Paste {
|
||||
paste: String,
|
||||
}
|
||||
|
||||
fn paste_form() -> Html<String> {
|
||||
Html(
|
||||
r#"
|
||||
<form method="post">
|
||||
<textarea name="paste" style="width:100%;height:80%"></textarea>
|
||||
<input type="submit" value="Paste" style="margin-top:10px">
|
||||
</form>"#
|
||||
.to_string(),
|
||||
)
|
||||
}
|
||||
|
||||
fn authorize_link(paste_id: i32) -> String {
|
||||
format!("<a href=\"/authorize?paste_id={paste_id}\">Authorize</a>")
|
||||
}
|
||||
|
||||
async fn index(user: Option<UserId>) -> Html<String> {
|
||||
if user.is_some() {
|
||||
return paste_form();
|
||||
}
|
||||
Html(authorize_link(0))
|
||||
}
|
||||
|
||||
async fn paste(
|
||||
state: AxumState,
|
||||
user: Option<UserId>,
|
||||
Form(Paste { paste }): Form<Paste>,
|
||||
) -> Response {
|
||||
let user_id = match user {
|
||||
None => return StatusCode::FORBIDDEN.into_response(),
|
||||
Some(user) => user.id,
|
||||
};
|
||||
if paste.is_empty() {
|
||||
return paste_form().into_response();
|
||||
}
|
||||
|
||||
let query = state
|
||||
.db_client
|
||||
.query_one(
|
||||
"INSERT INTO pastes (user_id, paste) VALUES ($1, $2) RETURNING id",
|
||||
&[&user_id.get(), &paste],
|
||||
)
|
||||
.await;
|
||||
let id = match query {
|
||||
Ok(row) => row.get::<usize, i32>(0),
|
||||
Err(err) => {
|
||||
error!(%err, "inserting paste");
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
Redirect::to(&format!("/{id}")).into_response()
|
||||
}
|
||||
|
||||
async fn view_paste(state: AxumState, user: Option<UserId>, Path(paste_id): Path<i32>) -> Response {
|
||||
let user_id = match user {
|
||||
None => return Html(authorize_link(paste_id)).into_response(),
|
||||
Some(user) => user.id,
|
||||
};
|
||||
|
||||
let query = state
|
||||
.db_client
|
||||
.query_opt("SELECT paste FROM pastes WHERE id = $1", &[&paste_id])
|
||||
.await;
|
||||
let row = match query {
|
||||
Ok(None) => return StatusCode::NOT_FOUND.into_response(),
|
||||
Ok(Some(row)) => row,
|
||||
Err(err) => {
|
||||
error!(%err, %paste_id, %user_id, "querying paste");
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
row.get::<usize, String>(0).into_response()
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AuthRequest {
|
||||
code: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AuthResponse {
|
||||
access_token: String,
|
||||
id_token: String,
|
||||
expires_in: u64,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct UserInfo {
|
||||
hd: String,
|
||||
sub: String,
|
||||
}
|
||||
|
||||
fn decode_id_token(token: String) -> Option<UserInfo> {
|
||||
let payload = token.split(".").skip(1).take(1).collect::<Vec<&str>>();
|
||||
let decoded = base64::decode_config(payload.get(0)?, base64::STANDARD_NO_PAD).ok()?;
|
||||
serde_json::from_slice::<UserInfo>(&decoded).ok()
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AuthorizeQuery {
|
||||
paste_id: i32,
|
||||
}
|
||||
|
||||
fn generate_csrf_token(num_bytes: u32) -> String {
|
||||
use rand::{Rng, thread_rng};
|
||||
let random_bytes: Vec<u8> = (0..num_bytes).map(|_| thread_rng().r#gen::<u8>()).collect();
|
||||
base64::encode_config(&random_bytes, base64::URL_SAFE_NO_PAD)
|
||||
}
|
||||
|
||||
async fn authorize(
|
||||
state: AxumState,
|
||||
jar: PrivateCookieJar,
|
||||
Query(AuthorizeQuery { paste_id }): Query<AuthorizeQuery>,
|
||||
) -> (PrivateCookieJar, Redirect) {
|
||||
let csrf_token = generate_csrf_token(16);
|
||||
let client_id = &state.oauth_client_id;
|
||||
let redirect_uri = oauth_redirect_url();
|
||||
let auth_url = format!(
|
||||
"{OAUTH_BASE_URL}?response_type=code\
|
||||
&client_id={client_id}\
|
||||
&state={csrf_token}\
|
||||
&redirect_uri={redirect_uri}\
|
||||
&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email"
|
||||
);
|
||||
|
||||
let redirect_cookie = Cookie::build((COOKIE_REDIRECT, paste_id.to_string()))
|
||||
.path("/")
|
||||
//.TODO secure(true) not true for localhost
|
||||
//.domain(COOKIE_DOMAIN)
|
||||
.secure(false)
|
||||
.same_site(axum_extra::extract::cookie::SameSite::Lax)
|
||||
.http_only(true)
|
||||
.build();
|
||||
let csrf_cookie = Cookie::build((COOKIE_CSRF, csrf_token))
|
||||
.path("/")
|
||||
//.TODO secure(true) not true for localhost
|
||||
//.domain(COOKIE_DOMAIN)
|
||||
.secure(false)
|
||||
.same_site(axum_extra::extract::cookie::SameSite::Lax)
|
||||
.http_only(true)
|
||||
.build();
|
||||
let jar = jar.add(redirect_cookie).add(csrf_cookie);
|
||||
let url = Into::<String>::into(auth_url);
|
||||
(jar, Redirect::to(&url))
|
||||
}
|
||||
|
||||
async fn authorized(
|
||||
state: AxumState,
|
||||
jar: PrivateCookieJar,
|
||||
Query(auth_request): Query<AuthRequest>,
|
||||
) -> Result<(PrivateCookieJar, Redirect), Response> {
|
||||
let params = [
|
||||
("grant_type", "authorization_code"),
|
||||
("redirect_uri", &oauth_redirect_url()),
|
||||
("code", &auth_request.code),
|
||||
("client_id", &state.oauth_client_id),
|
||||
("client_secret", &state.oauth_client_secret),
|
||||
];
|
||||
let auth_response = reqwest::Client::new()
|
||||
.post(OAUTH_TOKEN_URL)
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| {
|
||||
error!(%err, "exchanging oauth code for token");
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
})?
|
||||
.json::<AuthResponse>()
|
||||
.await
|
||||
.map_err(|err| {
|
||||
error!(%err, "deserializing access token response");
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
})?;
|
||||
let Some(UserInfo { hd, sub }) = decode_id_token(auth_response.id_token) else {
|
||||
error!("Failed to decode response id token");
|
||||
return Err(StatusCode::UNAUTHORIZED.into_response());
|
||||
};
|
||||
if hd != ALLOWED_OAUTH_DOMAIN {
|
||||
error!(hd, "Domain doesn't match {ALLOWED_OAUTH_DOMAIN}");
|
||||
return Err(StatusCode::UNAUTHORIZED.into_response());
|
||||
}
|
||||
|
||||
let token_duration = Duration::try_seconds(auth_response.expires_in as i64).unwrap();
|
||||
let expires_at = Utc.from_utc_datetime(&(Local::now().naive_local() + token_duration));
|
||||
let cookie_max_age = time::Duration::new(token_duration.num_seconds(), 0);
|
||||
|
||||
let session_cookie = Cookie::build((COOKIE_SID, auth_response.access_token.clone()))
|
||||
.path("/")
|
||||
//.TODO secure(true) not true for localhost
|
||||
//.domain(COOKIE_DOMAIN)
|
||||
.secure(false)
|
||||
.same_site(axum_extra::extract::cookie::SameSite::Lax)
|
||||
.http_only(true)
|
||||
.max_age(cookie_max_age)
|
||||
.build();
|
||||
|
||||
state
|
||||
.db_client
|
||||
.query(
|
||||
"WITH user_insert AS (\
|
||||
INSERT INTO users (sub) VALUES ($1) \
|
||||
ON CONFLICT (sub) DO UPDATE SET sub = excluded.sub RETURNING id)\
|
||||
INSERT INTO sessions (user_id, session_id, expires_at) \
|
||||
SELECT id, $2, $3 FROM user_insert \
|
||||
ON CONFLICT (user_id) DO UPDATE SET \
|
||||
session_id = excluded.session_id, \
|
||||
expires_at = excluded.expires_at",
|
||||
&[&sub, &auth_response.access_token, &expires_at],
|
||||
)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
error!(%err, %sub, "updating session");
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
})?;
|
||||
|
||||
let csrf_cookie = jar.get(COOKIE_CSRF).unwrap(); // set in authorize()
|
||||
let jar = jar.remove(csrf_cookie).add(session_cookie);
|
||||
match jar.get(COOKIE_REDIRECT) {
|
||||
Some(redirect_cookie) => {
|
||||
let mut value = redirect_cookie.value_trimmed();
|
||||
if value == "0" {
|
||||
value = "";
|
||||
}
|
||||
let redirect_url = format!("/{value}");
|
||||
Ok((jar.remove(redirect_cookie), Redirect::to(&redirect_url)))
|
||||
}
|
||||
None => Ok((jar, Redirect::to("/"))),
|
||||
}
|
||||
}
|
||||
@@ -25,19 +25,15 @@ pub(super) async fn authenticate(
|
||||
}
|
||||
AuthSecret::Scram(secret) => {
|
||||
debug!("auth endpoint chooses SCRAM");
|
||||
let scram = auth::Scram(&secret, ctx);
|
||||
|
||||
let auth_outcome = tokio::time::timeout(config.scram_protocol_timeout, async {
|
||||
AuthFlow::new(client, scram)
|
||||
.authenticate()
|
||||
.await
|
||||
.inspect_err(|error| {
|
||||
warn!(?error, "error processing scram messages");
|
||||
})
|
||||
})
|
||||
let auth_outcome = tokio::time::timeout(
|
||||
config.scram_protocol_timeout,
|
||||
AuthFlow::new(client, auth::Scram(&secret, ctx)).authenticate(),
|
||||
)
|
||||
.await
|
||||
.inspect_err(|_| warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs()))
|
||||
.map_err(auth::AuthError::user_timeout)??;
|
||||
.map_err(auth::AuthError::user_timeout)?
|
||||
.inspect_err(|error| warn!(?error, "error processing scram messages"))?;
|
||||
|
||||
let client_key = match auth_outcome {
|
||||
sasl::Outcome::Success(key) => key,
|
||||
|
||||
@@ -159,7 +159,7 @@ pub async fn task_main(
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
config: &'static ProxyConfig,
|
||||
backend: &'static ConsoleRedirectBackend,
|
||||
ctx: &RequestContext,
|
||||
|
||||
@@ -7,7 +7,9 @@ use std::time::Duration;
|
||||
|
||||
use ::http::HeaderName;
|
||||
use ::http::header::AUTHORIZATION;
|
||||
use bytes::Bytes;
|
||||
use futures::TryFutureExt;
|
||||
use hyper::StatusCode;
|
||||
use postgres_client::config::SslMode;
|
||||
use tokio::time::Instant;
|
||||
use tracing::{Instrument, debug, info, info_span, warn};
|
||||
@@ -72,28 +74,34 @@ impl NeonControlPlaneClient {
|
||||
role: &RoleName,
|
||||
) -> Result<AuthInfo, GetAuthInfoError> {
|
||||
async {
|
||||
let request = self
|
||||
.endpoint
|
||||
.get_path("get_endpoint_access_control")
|
||||
.header(X_REQUEST_ID, ctx.session_id().to_string())
|
||||
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
|
||||
.query(&[("session_id", ctx.session_id())])
|
||||
.query(&[
|
||||
("application_name", ctx.console_application_name().as_str()),
|
||||
("endpointish", endpoint.as_str()),
|
||||
("role", role.as_str()),
|
||||
])
|
||||
.build()?;
|
||||
|
||||
debug!(url = request.url().as_str(), "sending http request");
|
||||
let start = Instant::now();
|
||||
let response = {
|
||||
let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane);
|
||||
self.endpoint.execute(request).await?
|
||||
};
|
||||
info!(duration = ?start.elapsed(), "received http response");
|
||||
let request = self
|
||||
.endpoint
|
||||
.get_path("get_endpoint_access_control")
|
||||
.header(X_REQUEST_ID, ctx.session_id().to_string())
|
||||
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
|
||||
.query(&[("session_id", ctx.session_id())])
|
||||
.query(&[
|
||||
("application_name", ctx.console_application_name().as_str()),
|
||||
("endpointish", endpoint.as_str()),
|
||||
("role", role.as_str()),
|
||||
])
|
||||
.build()?;
|
||||
|
||||
let body = match parse_body::<GetEndpointAccessControl>(response).await {
|
||||
debug!(url = request.url().as_str(), "sending http request");
|
||||
let start = Instant::now();
|
||||
let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane);
|
||||
let response = self.endpoint.execute(request).await?;
|
||||
|
||||
info!(duration = ?start.elapsed(), "received http response");
|
||||
|
||||
response
|
||||
};
|
||||
|
||||
let body = match parse_body::<GetEndpointAccessControl>(
|
||||
response.status(),
|
||||
response.bytes().await?,
|
||||
) {
|
||||
Ok(body) => body,
|
||||
// Error 404 is special: it's ok not to have a secret.
|
||||
// TODO(anna): retry
|
||||
@@ -184,7 +192,10 @@ impl NeonControlPlaneClient {
|
||||
drop(pause);
|
||||
info!(duration = ?start.elapsed(), "received http response");
|
||||
|
||||
let body = parse_body::<EndpointJwksResponse>(response).await?;
|
||||
let body = parse_body::<EndpointJwksResponse>(
|
||||
response.status(),
|
||||
response.bytes().await.map_err(ControlPlaneError::from)?,
|
||||
)?;
|
||||
|
||||
let rules = body
|
||||
.jwks
|
||||
@@ -236,7 +247,7 @@ impl NeonControlPlaneClient {
|
||||
let response = self.endpoint.execute(request).await?;
|
||||
drop(pause);
|
||||
info!(duration = ?start.elapsed(), "received http response");
|
||||
let body = parse_body::<WakeCompute>(response).await?;
|
||||
let body = parse_body::<WakeCompute>(response.status(), response.bytes().await?)?;
|
||||
|
||||
// Unfortunately, ownership won't let us use `Option::ok_or` here.
|
||||
let (host, port) = match parse_host_port(&body.address) {
|
||||
@@ -487,33 +498,33 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
}
|
||||
|
||||
/// Parse http response body, taking status code into account.
|
||||
async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
|
||||
response: http::Response,
|
||||
fn parse_body<T: for<'a> serde::Deserialize<'a>>(
|
||||
status: StatusCode,
|
||||
body: Bytes,
|
||||
) -> Result<T, ControlPlaneError> {
|
||||
let status = response.status();
|
||||
if status.is_success() {
|
||||
// We shouldn't log raw body because it may contain secrets.
|
||||
info!("request succeeded, processing the body");
|
||||
return Ok(response.json().await?);
|
||||
return Ok(serde_json::from_slice(&body).map_err(std::io::Error::other)?);
|
||||
}
|
||||
let s = response.bytes().await?;
|
||||
|
||||
// Log plaintext to be able to detect, whether there are some cases not covered by the error struct.
|
||||
info!("response_error plaintext: {:?}", s);
|
||||
info!("response_error plaintext: {:?}", body);
|
||||
|
||||
// Don't throw an error here because it's not as important
|
||||
// as the fact that the request itself has failed.
|
||||
let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| {
|
||||
let mut body = serde_json::from_slice(&body).unwrap_or_else(|e| {
|
||||
warn!("failed to parse error body: {e}");
|
||||
ControlPlaneErrorMessage {
|
||||
Box::new(ControlPlaneErrorMessage {
|
||||
error: "reason unclear (malformed error message)".into(),
|
||||
http_status_code: status,
|
||||
status: None,
|
||||
}
|
||||
})
|
||||
});
|
||||
body.http_status_code = status;
|
||||
|
||||
warn!("console responded with an error ({status}): {body:?}");
|
||||
Err(ControlPlaneError::Message(Box::new(body)))
|
||||
Err(ControlPlaneError::Message(body))
|
||||
}
|
||||
|
||||
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
|
||||
|
||||
@@ -4,9 +4,10 @@
|
||||
|
||||
pub mod health_server;
|
||||
|
||||
use std::time::Duration;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::FutureExt;
|
||||
use http::Method;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::body::Body;
|
||||
@@ -109,15 +110,31 @@ impl Endpoint {
|
||||
}
|
||||
|
||||
/// Execute a [request](reqwest::Request).
|
||||
pub(crate) async fn execute(&self, request: Request) -> Result<Response, Error> {
|
||||
let _timer = Metrics::get()
|
||||
pub(crate) fn execute(
|
||||
&self,
|
||||
request: Request,
|
||||
) -> impl Future<Output = Result<Response, Error>> {
|
||||
let metric = Metrics::get()
|
||||
.proxy
|
||||
.console_request_latency
|
||||
.start_timer(ConsoleRequest {
|
||||
.with_labels(ConsoleRequest {
|
||||
request: request.url().path(),
|
||||
});
|
||||
|
||||
self.client.execute(request).await
|
||||
let req = self.client.execute(request).boxed();
|
||||
|
||||
async move {
|
||||
let start = Instant::now();
|
||||
scopeguard::defer!({
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.console_request_latency
|
||||
.get_metric(metric)
|
||||
.observe_duration_since(start);
|
||||
});
|
||||
|
||||
req.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -186,7 +186,7 @@ where
|
||||
pub async fn read_message<'a, S>(
|
||||
stream: &mut S,
|
||||
buf: &'a mut Vec<u8>,
|
||||
max: usize,
|
||||
max: u32,
|
||||
) -> io::Result<(u8, &'a mut [u8])>
|
||||
where
|
||||
S: AsyncRead + Unpin,
|
||||
@@ -206,7 +206,7 @@ where
|
||||
let header = read!(stream => Header);
|
||||
|
||||
// as described above, the length must be at least 4.
|
||||
let Some(len) = (header.len.get() as usize).checked_sub(4) else {
|
||||
let Some(len) = header.len.get().checked_sub(4) else {
|
||||
return Err(io::Error::other(format!(
|
||||
"invalid startup message length {}, must be at least 4.",
|
||||
header.len,
|
||||
@@ -222,7 +222,7 @@ where
|
||||
}
|
||||
|
||||
// read in our entire message.
|
||||
buf.resize(len, 0);
|
||||
buf.resize(len as usize, 0);
|
||||
stream.read_exact(buf).await?;
|
||||
|
||||
Ok((header.tag, buf))
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::{debug, info, warn};
|
||||
@@ -57,7 +58,7 @@ pub(crate) enum HandshakeData<S> {
|
||||
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
|
||||
/// we also take an extra care of propagating only the select handshake errors to client.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
ctx: &RequestContext,
|
||||
stream: S,
|
||||
mut tls: Option<&TlsConfig>,
|
||||
@@ -108,7 +109,9 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
})
|
||||
.map_ok(Box::new)
|
||||
.boxed();
|
||||
|
||||
res?;
|
||||
|
||||
@@ -146,7 +149,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
tls.cert_resolver.resolve(conn_info.server_name());
|
||||
|
||||
let tls = Stream::Tls {
|
||||
tls: Box::new(tls_stream),
|
||||
tls: tls_stream,
|
||||
tls_server_end_point,
|
||||
};
|
||||
(stream, msg) = PqStream::parse_startup(tls).await?;
|
||||
|
||||
@@ -270,7 +270,7 @@ impl ReportableError for ClientRequestError {
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
|
||||
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
config: &'static ProxyConfig,
|
||||
auth_backend: &'static auth::Backend<'static, ()>,
|
||||
ctx: &RequestContext,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use futures::FutureExt;
|
||||
use smol_str::SmolStr;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tracing::debug;
|
||||
@@ -89,6 +90,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
|
||||
.compute
|
||||
.cancel_closure
|
||||
.try_cancel_query(compute_config)
|
||||
.boxed()
|
||||
.await
|
||||
{
|
||||
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");
|
||||
|
||||
@@ -30,52 +30,53 @@ where
|
||||
F: FnOnce(&str) -> super::Result<M>,
|
||||
M: Mechanism,
|
||||
{
|
||||
let sasl = {
|
||||
let (mut mechanism, mut input) = {
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
// Initial client message contains the chosen auth method's name.
|
||||
let msg = stream.read_password_message().await?;
|
||||
super::FirstMessage::parse(msg).ok_or(super::Error::BadClientMessage("bad sasl message"))?
|
||||
|
||||
let sasl = super::FirstMessage::parse(msg)
|
||||
.ok_or(super::Error::BadClientMessage("bad sasl message"))?;
|
||||
|
||||
(mechanism(sasl.method)?, sasl.message)
|
||||
};
|
||||
|
||||
let mut mechanism = mechanism(sasl.method)?;
|
||||
let mut input = sasl.message;
|
||||
loop {
|
||||
let step = mechanism
|
||||
.exchange(input)
|
||||
.inspect_err(|error| tracing::info!(?error, "error during SASL exchange"))?;
|
||||
|
||||
match step {
|
||||
Step::Continue(moved_mechanism, reply) => {
|
||||
match mechanism.exchange(input) {
|
||||
Ok(Step::Continue(moved_mechanism, reply)) => {
|
||||
mechanism = moved_mechanism;
|
||||
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
// write reply
|
||||
let sasl_msg = BeAuthenticationSaslMessage::Continue(reply.as_bytes());
|
||||
stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
|
||||
|
||||
// get next input
|
||||
stream.flush().await?;
|
||||
let msg = stream.read_password_message().await?;
|
||||
input = std::str::from_utf8(msg)
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
|
||||
drop(reply);
|
||||
}
|
||||
Step::Success(result, reply) => {
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
Ok(Step::Success(result, reply)) => {
|
||||
// write reply
|
||||
let sasl_msg = BeAuthenticationSaslMessage::Final(reply.as_bytes());
|
||||
stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
|
||||
stream.write_message(BeMessage::AuthenticationOk);
|
||||
|
||||
// exit with success
|
||||
break Ok(Outcome::Success(result));
|
||||
}
|
||||
// exit with failure
|
||||
Step::Failure(reason) => break Ok(Outcome::Failure(reason)),
|
||||
Ok(Step::Failure(reason)) => break Ok(Outcome::Failure(reason)),
|
||||
Err(error) => {
|
||||
tracing::info!(?error, "error during SASL exchange");
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
|
||||
// pause the timer while we communicate with the client
|
||||
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
|
||||
|
||||
// get next input
|
||||
stream.flush().await?;
|
||||
let msg = stream.read_password_message().await?;
|
||||
input = std::str::from_utf8(msg)
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> PqStream<S> {
|
||||
impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||
/// Read a raw postgres packet, which will respect the max length requested.
|
||||
/// This is not cancel safe.
|
||||
async fn read_raw_expect(&mut self, tag: u8, max: usize) -> io::Result<&mut [u8]> {
|
||||
async fn read_raw_expect(&mut self, tag: u8, max: u32) -> io::Result<&mut [u8]> {
|
||||
let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
|
||||
if actual_tag != tag {
|
||||
return Err(io::Error::other(format!(
|
||||
@@ -89,7 +89,7 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
|
||||
// passwords are usually pretty short
|
||||
// and SASL SCRAM messages are no longer than 256 bytes in my testing
|
||||
// (a few hashes and random bytes, encoded into base64).
|
||||
const MAX_PASSWORD_LENGTH: usize = 512;
|
||||
const MAX_PASSWORD_LENGTH: u32 = 512;
|
||||
self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -31,7 +31,9 @@ mod private {
|
||||
type Output = io::Result<RustlsStream<S>>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream)
|
||||
Pin::new(&mut self.inner)
|
||||
.poll(cx)
|
||||
.map_ok(|s| RustlsStream(Box::new(s)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,7 +59,7 @@ mod private {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RustlsStream<S>(TlsStream<S>);
|
||||
pub struct RustlsStream<S>(Box<TlsStream<S>>);
|
||||
|
||||
impl<S> postgres_client::tls::TlsStream for RustlsStream<S>
|
||||
where
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
[package]
|
||||
name = "shortener"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
axum-extra = { workspace = true, features = ["cookie", "cookie-private"] }
|
||||
axum.workspace = true
|
||||
base64.workspace = true
|
||||
chrono.workspace = true
|
||||
cookie = "0.18.1"
|
||||
nanoid = { version = "0.4.0", default-features = false }
|
||||
rand.workspace = true
|
||||
reqwest.workspace = true
|
||||
rustls-native-certs.workspace = true
|
||||
rustls.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
time = { version = "0.3.36", default-features = false }
|
||||
tokio-postgres-rustls.workspace = true
|
||||
tokio-postgres.workspace = true
|
||||
tokio.workspace = true
|
||||
tracing-subscriber.workspace = true
|
||||
tracing.workspace = true
|
||||
workspace_hack.workspace = true
|
||||
@@ -1,19 +0,0 @@
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
sub VARCHAR(100) NOT NULL UNIQUE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id INT NOT NULL UNIQUE REFERENCES users(id),
|
||||
session_id VARCHAR NOT NULL,
|
||||
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS urls (
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id INT NOT NULL REFERENCES users(id),
|
||||
short_url VARCHAR(6) NOT NULL UNIQUE,
|
||||
long_url VARCHAR NOT NULL,
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
@@ -1,222 +0,0 @@
|
||||
//! Library to gate infrastructure behind Google Oauth for domain.
|
||||
//!
|
||||
//! Why not oauth-rs? Oauth .exchange_code() doesn't work with "request failed". Also, we can't get
|
||||
//! id token from it, and I don't want to pull in whole openid library just for that.
|
||||
//! Id token saves us a request to openid endpoint and one Oauth scope we don't use
|
||||
use anyhow::{Context, Result, bail};
|
||||
use axum::extract::{FromRef, FromRequestParts, Query, State as AxumState};
|
||||
use axum::response::Redirect;
|
||||
use axum_extra::extract::PrivateCookieJar;
|
||||
use axum_extra::extract::cookie::{Cookie, Key};
|
||||
use chrono::{Duration, Local, TimeZone, Utc};
|
||||
use cookie::CookieBuilder;
|
||||
use core::num::NonZeroI32;
|
||||
use reqwest::StatusCode;
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
use tokio_postgres::Socket;
|
||||
|
||||
const OAUTH_BASE_URL: &str = "https://accounts.google.com/o/oauth2/v2/auth";
|
||||
const OAUTH_TOKEN_URL: &str = "https://oauth2.googleapis.com/token";
|
||||
const COOKIE_SID: &str = "sid";
|
||||
const COOKIE_CSRF: &str = "csrf";
|
||||
|
||||
pub struct Config {
|
||||
pub oauth_client_id: String,
|
||||
pub oauth_client_secret: String,
|
||||
pub oauth_redirect_url: String,
|
||||
pub oauth_allowed_domain: String,
|
||||
pub cookie_settings: fn(CookieBuilder) -> CookieBuilder,
|
||||
}
|
||||
|
||||
pub struct InnerState {
|
||||
config: Config,
|
||||
cookie_jar_key: Key,
|
||||
pub db_client: tokio_postgres::Client,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct State(Arc<InnerState>);
|
||||
type DbConn = tokio_postgres::Connection<Socket, tokio_postgres_rustls::RustlsStream<Socket>>;
|
||||
|
||||
impl State {
|
||||
pub async fn new(config: Config, db_connstr: &str) -> Result<(Self, DbConn)> {
|
||||
let mut roots = rustls::RootCertStore::empty();
|
||||
for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs")
|
||||
{
|
||||
roots.add(cert).unwrap();
|
||||
}
|
||||
let tls_config = rustls::ClientConfig::builder()
|
||||
.with_root_certificates(roots)
|
||||
.with_no_client_auth();
|
||||
let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);
|
||||
|
||||
let (db_client, db_conn) = tokio_postgres::connect(&db_connstr, tls).await?;
|
||||
let inner = InnerState {
|
||||
config,
|
||||
cookie_jar_key: Key::generate(),
|
||||
db_client,
|
||||
};
|
||||
Ok((Self { 0: Arc::new(inner) }, db_conn))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for State {
|
||||
type Target = InnerState;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&*self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl FromRef<State> for Key {
|
||||
fn from_ref(state: &State) -> Self {
|
||||
state.cookie_jar_key.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct UserId {
|
||||
pub id: NonZeroI32,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct AuthRequest {
|
||||
code: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AuthResponse {
|
||||
access_token: String,
|
||||
id_token: String,
|
||||
expires_in: u64,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct UserInfo {
|
||||
hd: String,
|
||||
sub: String,
|
||||
}
|
||||
|
||||
impl axum::extract::OptionalFromRequestParts<State> for UserId {
|
||||
type Rejection = StatusCode;
|
||||
async fn from_request_parts(
|
||||
parts: &mut axum::http::request::Parts,
|
||||
state: &State,
|
||||
) -> Result<Option<Self>, Self::Rejection> {
|
||||
let jar: PrivateCookieJar = PrivateCookieJar::from_request_parts(parts, state)
|
||||
.await
|
||||
.unwrap(); // infallible
|
||||
let Some(session_id) = jar.get(COOKIE_SID).map(|cookie| cookie.value().to_owned()) else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let client = &state.db_client;
|
||||
let query = client
|
||||
.query_opt(
|
||||
"SELECT user_id FROM sessions WHERE session_id = $1",
|
||||
&[&session_id],
|
||||
)
|
||||
.await;
|
||||
let id = match query {
|
||||
Ok(Some(row)) => row.get::<usize, i32>(0),
|
||||
Ok(None) => return Ok(None),
|
||||
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
|
||||
};
|
||||
let id = NonZeroI32::new(id).unwrap(); // postgres id guaranteed not to be zero
|
||||
Ok(Some(Self { id }))
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_id_token(token: String) -> Option<UserInfo> {
|
||||
let payload = token.split(".").skip(1).take(1).collect::<Vec<&str>>();
|
||||
let decoded = base64::decode_config(payload.get(0)?, base64::STANDARD_NO_PAD).ok()?;
|
||||
serde_json::from_slice::<UserInfo>(&decoded).ok()
|
||||
}
|
||||
|
||||
fn generate_csrf_token(num_bytes: u32) -> String {
|
||||
use rand::{Rng, thread_rng};
|
||||
let random_bytes: Vec<u8> = (0..num_bytes).map(|_| thread_rng().r#gen::<u8>()).collect();
|
||||
base64::encode_config(&random_bytes, base64::URL_SAFE_NO_PAD)
|
||||
}
|
||||
|
||||
pub async fn authorize(
|
||||
state: AxumState<State>,
|
||||
jar: PrivateCookieJar,
|
||||
) -> (PrivateCookieJar, Redirect) {
|
||||
let csrf_token = generate_csrf_token(16);
|
||||
let client_id = &state.config.oauth_client_id;
|
||||
let redirect_uri = &state.config.oauth_redirect_url;
|
||||
let auth_url = format!(
|
||||
"{OAUTH_BASE_URL}?response_type=code\
|
||||
&client_id={client_id}\
|
||||
&state={csrf_token}\
|
||||
&redirect_uri={redirect_uri}\
|
||||
&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email"
|
||||
);
|
||||
|
||||
let csrf_cookie =
|
||||
(state.config.cookie_settings)(Cookie::build((COOKIE_CSRF, csrf_token))).build();
|
||||
let url = Into::<String>::into(auth_url);
|
||||
(jar.add(csrf_cookie), Redirect::to(&url))
|
||||
}
|
||||
|
||||
pub async fn authorized(
|
||||
state: AxumState<State>,
|
||||
jar: PrivateCookieJar,
|
||||
Query(auth_request): Query<AuthRequest>,
|
||||
) -> Result<PrivateCookieJar> {
|
||||
let params = [
|
||||
("grant_type", "authorization_code"),
|
||||
("redirect_uri", &state.config.oauth_redirect_url),
|
||||
("code", &auth_request.code),
|
||||
("client_id", &state.config.oauth_client_id),
|
||||
("client_secret", &state.config.oauth_client_secret),
|
||||
];
|
||||
let auth_response = reqwest::Client::new()
|
||||
.post(OAUTH_TOKEN_URL)
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await
|
||||
.context("exchanging oauth code for token")?
|
||||
.json::<AuthResponse>()
|
||||
.await
|
||||
.context("deserializing access_token response")?;
|
||||
let Some(UserInfo { hd, sub }) = decode_id_token(auth_response.id_token) else {
|
||||
bail!("failed to decode id token")
|
||||
};
|
||||
|
||||
let allowed_domain = &state.config.oauth_allowed_domain;
|
||||
if hd != *allowed_domain {
|
||||
bail!("{hd} doesn't match {allowed_domain}")
|
||||
}
|
||||
|
||||
let token_duration = Duration::try_seconds(auth_response.expires_in as i64).unwrap();
|
||||
let expires_at = Utc.from_utc_datetime(&(Local::now().naive_local() + token_duration));
|
||||
let cookie_max_age = time::Duration::new(token_duration.num_seconds(), 0);
|
||||
|
||||
let session_cookie = (state.config.cookie_settings)(Cookie::build((
|
||||
COOKIE_SID,
|
||||
auth_response.access_token.clone(),
|
||||
)))
|
||||
.max_age(cookie_max_age)
|
||||
.build();
|
||||
|
||||
state
|
||||
.db_client
|
||||
.query(
|
||||
"WITH user_insert AS (\
|
||||
INSERT INTO users (sub) VALUES ($1) \
|
||||
ON CONFLICT (sub) DO UPDATE SET sub = excluded.sub RETURNING id)\
|
||||
INSERT INTO sessions (user_id, session_id, expires_at) \
|
||||
SELECT id, $2, $3 FROM user_insert \
|
||||
ON CONFLICT (user_id) DO UPDATE SET \
|
||||
session_id = excluded.session_id, \
|
||||
expires_at = excluded.expires_at",
|
||||
&[&sub, &auth_response.access_token, &expires_at],
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("updating session for {sub}"))?;
|
||||
|
||||
let csrf_cookie = jar.get(COOKIE_CSRF).unwrap(); // set in authorize()
|
||||
Ok(jar.remove(csrf_cookie).add(session_cookie))
|
||||
}
|
||||
@@ -1,240 +0,0 @@
|
||||
//! Shortener is a service to gate access to internal infrastructure
|
||||
//! URLs behind team authorisation to expose less private information.
|
||||
pub mod google_oauth_gate;
|
||||
use crate::google_oauth_gate::{AuthRequest, State, UserId};
|
||||
use anyhow::Result;
|
||||
use axum::Form;
|
||||
use axum::extract::State as AxumState;
|
||||
use axum::extract::{Path, Query};
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::{Html, IntoResponse};
|
||||
use axum::response::{Redirect, Response};
|
||||
use axum::routing::get;
|
||||
use axum_extra::extract::PrivateCookieJar;
|
||||
use axum_extra::extract::cookie::Cookie;
|
||||
use cookie::CookieBuilder;
|
||||
use google_oauth_gate::Config;
|
||||
use serde::Deserialize;
|
||||
use std::env;
|
||||
use tracing::{error, info};
|
||||
use tracing_subscriber::layer::SubscriberExt;
|
||||
use tracing_subscriber::util::SubscriberInitExt;
|
||||
|
||||
const SOCKET: &str = "127.0.0.1:12344";
|
||||
const HOST: &str = "http://127.0.0.1:12344";
|
||||
const COOKIE_REDIRECT: &str = "redirect";
|
||||
const ALLOWED_OAUTH_DOMAIN: &str = "neon.tech";
|
||||
const AUTHORIZED_ROUTE: &str = "/authorized";
|
||||
const SHORT_URL_LEN: usize = 6;
|
||||
|
||||
fn cookie_settings(b: CookieBuilder) -> CookieBuilder {
|
||||
if HOST.contains("127.0.0.1") {
|
||||
b.path("/")
|
||||
.secure(false)
|
||||
.same_site(axum_extra::extract::cookie::SameSite::Lax)
|
||||
.http_only(true)
|
||||
} else {
|
||||
b.path("/")
|
||||
.domain(ALLOWED_OAUTH_DOMAIN)
|
||||
.secure(true)
|
||||
.http_only(false)
|
||||
}
|
||||
}
|
||||
|
||||
fn oauth_redirect_url() -> String {
|
||||
format!("{HOST}{AUTHORIZED_ROUTE}")
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
tracing_subscriber::registry()
|
||||
.with(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| format!("{}=info", env!("CARGO_CRATE_NAME")).into()),
|
||||
)
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
|
||||
let oauth_client_id = env::var("OAUTH_CLIENT_ID").expect("Missing OAUTH_CLIENT_ID");
|
||||
let oauth_client_secret = env::var("OAUTH_CLIENT_SECRET").expect("Missing OAUTH_CLIENT_SECRET");
|
||||
let db_connstr = env::var("DB_CONNSTR").expect("Missing DB_CONNSTR");
|
||||
|
||||
let config = Config {
|
||||
oauth_client_id,
|
||||
oauth_client_secret,
|
||||
oauth_redirect_url: oauth_redirect_url(),
|
||||
oauth_allowed_domain: ALLOWED_OAUTH_DOMAIN.to_string(),
|
||||
cookie_settings,
|
||||
};
|
||||
let (state, db_conn) = State::new(config, &db_connstr).await?;
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = db_conn.await {
|
||||
error!(%err, "connecting to database");
|
||||
std::process::exit(1);
|
||||
}
|
||||
});
|
||||
|
||||
let router = axum::Router::new()
|
||||
.route("/", get(index).post(shorten))
|
||||
.route("/authorize", get(authorize))
|
||||
.route(AUTHORIZED_ROUTE, get(authorized))
|
||||
.route("/{short_url}", get(redirect))
|
||||
.with_state(state);
|
||||
let listener = tokio::net::TcpListener::bind(SOCKET)
|
||||
.await
|
||||
.expect("failed to bind TcpListener");
|
||||
info!("listening on {SOCKET}");
|
||||
axum::serve(listener, router).await.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LongUrl {
|
||||
url: String,
|
||||
}
|
||||
|
||||
fn shorten_form(short_url: &str) -> Html<String> {
|
||||
let mut form = r#"
|
||||
<div style="margin:auto;width:50%;padding:10px">
|
||||
<form method="post">
|
||||
<input type="text" name="url" style="width:100%">
|
||||
<input type="submit" value="Shorten" style="margin-top:10px">
|
||||
</form>"#
|
||||
.to_string();
|
||||
if !short_url.is_empty() {
|
||||
form += &format!(
|
||||
r#"
|
||||
<p>
|
||||
<a id="short" href="{0}">{0}</a>
|
||||
<button onclick="copy()">Copy</button>
|
||||
</p>
|
||||
<script>
|
||||
function copy() {{
|
||||
navigator.clipboard.writeText(document.querySelector("\#short").textContent);
|
||||
}}
|
||||
</script>"#,
|
||||
short_url
|
||||
);
|
||||
}
|
||||
form += "</div>";
|
||||
Html(form)
|
||||
}
|
||||
|
||||
fn authorize_link(short_url: &str) -> Html<String> {
|
||||
Html(format!(
|
||||
"<a href=\"/authorize?short_url={short_url}\">Authorize</a>"
|
||||
))
|
||||
}
|
||||
|
||||
async fn index(user: Option<UserId>) -> Html<String> {
|
||||
if user.is_some() {
|
||||
return shorten_form("");
|
||||
}
|
||||
authorize_link("")
|
||||
}
|
||||
|
||||
async fn shorten(
|
||||
state: AxumState<State>,
|
||||
user: Option<UserId>,
|
||||
Form(LongUrl { url }): Form<LongUrl>,
|
||||
) -> Response {
|
||||
let user_id = match user {
|
||||
None => return StatusCode::FORBIDDEN.into_response(),
|
||||
Some(user) => user.id.get(),
|
||||
};
|
||||
if url.is_empty() {
|
||||
return shorten_form("").into_response();
|
||||
}
|
||||
|
||||
let mut short_url = "".to_string();
|
||||
for i in 0..20 {
|
||||
short_url = nanoid::nanoid!(SHORT_URL_LEN);
|
||||
let query = state
|
||||
.db_client
|
||||
.query_opt(
|
||||
"INSERT INTO urls (user_id, short_url, long_url) VALUES ($1, $2, $3) \
|
||||
ON CONFLICT (short_url) DO NOTHING \
|
||||
RETURNING short_url",
|
||||
&[&user_id, &short_url, &url],
|
||||
)
|
||||
.await;
|
||||
match query {
|
||||
Ok(Some(_)) => break,
|
||||
Ok(None) => {
|
||||
info!(short_url, "url clash, retry {i}");
|
||||
continue;
|
||||
}
|
||||
Err(err) => {
|
||||
error!(%err, "inserting shortened url");
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
};
|
||||
}
|
||||
shorten_form(&format!("{HOST}/{short_url}")).into_response()
|
||||
}
|
||||
|
||||
async fn redirect(
|
||||
state: AxumState<State>,
|
||||
user: Option<UserId>,
|
||||
Path(short_url): Path<String>,
|
||||
) -> Response {
|
||||
let user_id = match user {
|
||||
None => return authorize_link(&short_url).into_response(),
|
||||
Some(user) => user.id,
|
||||
};
|
||||
|
||||
let query = state
|
||||
.db_client
|
||||
.query_opt(
|
||||
"SELECT long_url FROM urls WHERE short_url = $1",
|
||||
&[&short_url],
|
||||
)
|
||||
.await;
|
||||
match query {
|
||||
Ok(Some(row)) => Redirect::permanent(row.get(0)).into_response(),
|
||||
Ok(None) => StatusCode::NOT_FOUND.into_response(),
|
||||
Err(err) => {
|
||||
error!(%err, %short_url, %user_id, "querying long url");
|
||||
StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct AuthorizeQuery {
|
||||
short_url: String,
|
||||
}
|
||||
|
||||
async fn authorize(
|
||||
state: AxumState<State>,
|
||||
jar: PrivateCookieJar,
|
||||
Query(AuthorizeQuery { short_url }): Query<AuthorizeQuery>,
|
||||
) -> (PrivateCookieJar, Redirect) {
|
||||
let (jar, auth_redirect) = google_oauth_gate::authorize(state, jar).await;
|
||||
let redirect_cookie = Cookie::build((COOKIE_REDIRECT, short_url))
|
||||
.path("/")
|
||||
//.TODO secure(true) not true for localhost
|
||||
//.domain(COOKIE_DOMAIN)
|
||||
.secure(false)
|
||||
.same_site(axum_extra::extract::cookie::SameSite::Lax)
|
||||
.http_only(true)
|
||||
.build();
|
||||
(jar.add(redirect_cookie), auth_redirect)
|
||||
}
|
||||
|
||||
async fn authorized(
|
||||
state: AxumState<State>,
|
||||
jar: PrivateCookieJar,
|
||||
query: Query<AuthRequest>,
|
||||
) -> Result<(PrivateCookieJar, Redirect), Response> {
|
||||
use google_oauth_gate::authorized;
|
||||
let jar = authorized(state, jar, query).await.map_err(|err| {
|
||||
error!(%err, "authorizing");
|
||||
return StatusCode::UNAUTHORIZED.into_response();
|
||||
})?;
|
||||
let Some(redirect_cookie) = jar.get(COOKIE_REDIRECT) else {
|
||||
return Ok((jar, Redirect::to("/")));
|
||||
};
|
||||
let redirect_url = Redirect::to(&format!("/{}", redirect_cookie.value_trimmed()));
|
||||
Ok((jar.remove(redirect_cookie), redirect_url))
|
||||
}
|
||||
@@ -9,7 +9,6 @@ from pathlib import Path
|
||||
from threading import Event
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.errors
|
||||
import pytest
|
||||
from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId
|
||||
from fixtures.fast_import import (
|
||||
@@ -1070,15 +1069,41 @@ def test_fast_import_restore_to_connstring_from_s3_spec(
|
||||
return mock_kms.encrypt(KeyId=key_id, Plaintext=x)
|
||||
|
||||
# Start source postgres and ingest data
|
||||
vanilla_pg.configure(["shared_preload_libraries='neon,neon_utils,neon_rmgr'"])
|
||||
vanilla_pg.start()
|
||||
vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);")
|
||||
res = vanilla_pg.safe_psql("SHOW shared_preload_libraries;")
|
||||
log.info(f"shared_preload_libraries: {res}")
|
||||
res = vanilla_pg.safe_psql("SELECT name FROM pg_available_extensions;")
|
||||
log.info(f"pg_available_extensions: {res}")
|
||||
res = vanilla_pg.safe_psql("SELECT extname FROM pg_extension;")
|
||||
log.info(f"pg_extension: {res}")
|
||||
|
||||
# Create a number of extensions, we only will dump selected ones
|
||||
vanilla_pg.safe_psql("CREATE EXTENSION neon;")
|
||||
vanilla_pg.safe_psql("CREATE EXTENSION neon_utils;")
|
||||
vanilla_pg.safe_psql("CREATE EXTENSION pg_visibility;")
|
||||
|
||||
# Default schema is always dumped
|
||||
vanilla_pg.safe_psql(
|
||||
"CREATE TABLE public.foo (a int); INSERT INTO public.foo SELECT generate_series(1, 7);"
|
||||
)
|
||||
|
||||
# Create a number of schemas, we only will dump selected ones
|
||||
vanilla_pg.safe_psql("CREATE SCHEMA custom;")
|
||||
vanilla_pg.safe_psql(
|
||||
"CREATE TABLE custom.foo (a int); INSERT INTO custom.foo SELECT generate_series(1, 13);"
|
||||
)
|
||||
vanilla_pg.safe_psql("CREATE SCHEMA other;")
|
||||
vanilla_pg.safe_psql(
|
||||
"CREATE TABLE other.foo (a int); INSERT INTO other.foo SELECT generate_series(1, 42);"
|
||||
)
|
||||
|
||||
# Start target postgres
|
||||
pgdatadir = test_output_dir / "destination-pgdata"
|
||||
pg_bin = PgBin(test_output_dir, pg_distrib_dir, pg_version)
|
||||
port = port_distributor.get_port()
|
||||
with VanillaPostgres(pgdatadir, pg_bin, port) as destination_vanilla_pg:
|
||||
destination_vanilla_pg.configure(["shared_preload_libraries='neon_rmgr'"])
|
||||
destination_vanilla_pg.configure(["shared_preload_libraries='neon,neon_utils,neon_rmgr'"])
|
||||
destination_vanilla_pg.start()
|
||||
|
||||
# Encrypt connstrings and put spec into S3
|
||||
@@ -1092,6 +1117,8 @@ def test_fast_import_restore_to_connstring_from_s3_spec(
|
||||
"destination_connstring_ciphertext_base64": base64.b64encode(
|
||||
destination_connstring_encrypted["CiphertextBlob"]
|
||||
).decode("utf-8"),
|
||||
"schemas": ["custom"],
|
||||
"extensions": ["plpgsql", "neon"],
|
||||
}
|
||||
|
||||
bucket = "test-bucket"
|
||||
@@ -1117,9 +1144,31 @@ def test_fast_import_restore_to_connstring_from_s3_spec(
|
||||
}, f"got status: {job_status}"
|
||||
vanilla_pg.stop()
|
||||
|
||||
res = destination_vanilla_pg.safe_psql("SELECT count(*) FROM foo;")
|
||||
res = destination_vanilla_pg.safe_psql("SELECT count(*) FROM public.foo;")
|
||||
log.info(f"Result: {res}")
|
||||
assert res[0][0] == 10
|
||||
assert res[0][0] == 7
|
||||
|
||||
res = destination_vanilla_pg.safe_psql("SELECT count(*) FROM custom.foo;")
|
||||
log.info(f"Result: {res}")
|
||||
assert res[0][0] == 13
|
||||
|
||||
# Check that other schema is not restored
|
||||
with pytest.raises(psycopg2.errors.UndefinedTable):
|
||||
destination_vanilla_pg.safe_psql("SELECT count(*) FROM other.foo;")
|
||||
|
||||
# Check that all schemas are listed correctly
|
||||
res = destination_vanilla_pg.safe_psql("SELECT nspname FROM pg_namespace;")
|
||||
log.info(f"Result: {res}")
|
||||
schemas = [row[0] for row in res]
|
||||
assert "other" not in schemas
|
||||
|
||||
# Check that only selected extensions are restored
|
||||
res = destination_vanilla_pg.safe_psql("SELECT extname FROM pg_extension;")
|
||||
log.info(f"Result: {res}")
|
||||
assert len(res) == 2
|
||||
extensions = set([str(row[0]) for row in res])
|
||||
assert "plpgsql" in extensions
|
||||
assert "neon" in extensions
|
||||
|
||||
|
||||
def test_fast_import_restore_to_connstring_error_to_s3_bad_destination(
|
||||
|
||||
2
vendor/postgres-v14
vendored
2
vendor/postgres-v14
vendored
Submodule vendor/postgres-v14 updated: b6eece3f52...b1e9959858
2
vendor/postgres-v15
vendored
2
vendor/postgres-v15
vendored
Submodule vendor/postgres-v15 updated: 20f8491225...cd0b534a76
2
vendor/postgres-v16
vendored
2
vendor/postgres-v16
vendored
Submodule vendor/postgres-v16 updated: 77c63bfebf...94ad7e11cd
2
vendor/postgres-v17
vendored
2
vendor/postgres-v17
vendored
Submodule vendor/postgres-v17 updated: 32d704d965...10c0029104
8
vendor/revisions.json
vendored
8
vendor/revisions.json
vendored
@@ -1,18 +1,18 @@
|
||||
{
|
||||
"v17": [
|
||||
"17.5",
|
||||
"8be779fd3ab9e87206da96a7e4842ef1abf04f44"
|
||||
"10c002910447b3138e13213befca662df7cbe1d0"
|
||||
],
|
||||
"v16": [
|
||||
"16.9",
|
||||
"0bf96bd6d70301a0b43b0b3457bb3cf8fb43c198"
|
||||
"94ad7e11cd43cce32f5af5674af29b3f334551a7"
|
||||
],
|
||||
"v15": [
|
||||
"15.13",
|
||||
"de7640f55da07512834d5cc40c4b3fb376b5f04f"
|
||||
"cd0b534a761c18d8ef4654d4f749c63c5663215f"
|
||||
],
|
||||
"v14": [
|
||||
"14.18",
|
||||
"55c0d45abe6467c02084c2192bca117eda6ce1e7"
|
||||
"b1e9959858f0529ea33d3cc5e833c0acc43f583a"
|
||||
]
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ anstream = { version = "0.6" }
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
axum = { version = "0.8", features = ["ws"] }
|
||||
axum-core = { version = "0.5", default-features = false, features = ["tracing"] }
|
||||
axum-extra = { version = "0.10", features = ["cookie-private", "typed-header"] }
|
||||
base64-594e8ee84c453af0 = { package = "base64", version = "0.13", features = ["alloc"] }
|
||||
base64-647d43efb71741da = { package = "base64", version = "0.21" }
|
||||
base64ct = { version = "1", default-features = false, features = ["std"] }
|
||||
@@ -31,7 +30,6 @@ clap = { version = "4", features = ["derive", "env", "string"] }
|
||||
clap_builder = { version = "4", default-features = false, features = ["color", "env", "help", "std", "string", "suggestions", "usage"] }
|
||||
const-oid = { version = "0.9", default-features = false, features = ["db", "std"] }
|
||||
crypto-bigint = { version = "0.5", features = ["generic-array", "zeroize"] }
|
||||
crypto-common = { version = "0.1", default-features = false, features = ["getrandom", "std"] }
|
||||
der = { version = "0.7", default-features = false, features = ["derive", "flagset", "oid", "pem", "std"] }
|
||||
deranged = { version = "0.3", default-features = false, features = ["powerfmt", "serde", "std"] }
|
||||
digest = { version = "0.10", features = ["mac", "oid", "std"] }
|
||||
|
||||
Reference in New Issue
Block a user