Compare commits

11 Commits

Author SHA1 Message Date
piercypixel
4f4a96ea25 Skip custom extensions in fast import 2025-06-03 14:18:53 +00:00
Erik Grinaker
5bdba70f7d page_api: only validate Protobuf → domain type conversion (#12115)
## Problem

Currently, `page_api` domain types validate message invariants both when
converting Protobuf → domain and domain → Protobuf. This is annoying for
clients, because they can't use stream combinators to convert streamed
requests (needed for hot path performance), and it also performs the
validation twice in the common case.

Blocks #12099.

## Summary of changes

Only validate the Protobuf → domain type conversion, i.e. on the
receiver side, and make domain → Protobuf infallible. This is where it
matters -- the Protobuf types are less strict than the domain types, and
receivers should expect all sorts of junk from senders (they're not
required to validate anyway, and can just construct an invalid message
manually).

Also adds a missing `impl From<CheckRelExistsRequest> for
proto::CheckRelExistsRequest`.
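
A minimal sketch of the receiver-side-only validation pattern, using only the standard library. `WireReadLsn`, `DomainReadLsn`, `InvalidField` and `encode_all` are hypothetical stand-ins for illustration, not the actual `page_api` or generated Protobuf types; the two checks mirror the `ReadLsn` conversion in the diff further down.

```rust
use std::convert::TryFrom;

/// Hypothetical stand-in for a Protobuf-generated message: plain integers,
/// where 0 means "not set".
#[derive(Debug, Clone, PartialEq)]
pub struct WireReadLsn {
    pub request_lsn: u64,
    pub not_modified_since_lsn: u64,
}

/// Hypothetical stand-in for the stricter domain type.
#[derive(Debug, Clone, PartialEq)]
pub struct DomainReadLsn {
    pub request_lsn: u64,
    pub not_modified_since_lsn: Option<u64>,
}

/// Hypothetical error type naming the offending field.
#[derive(Debug, PartialEq)]
pub struct InvalidField(pub &'static str);

// Receiver side: the conversion is fallible and checks the invariants.
impl TryFrom<WireReadLsn> for DomainReadLsn {
    type Error = InvalidField;

    fn try_from(pb: WireReadLsn) -> Result<Self, Self::Error> {
        if pb.request_lsn == 0 {
            return Err(InvalidField("request_lsn"));
        }
        if pb.not_modified_since_lsn > pb.request_lsn {
            return Err(InvalidField("not_modified_since_lsn"));
        }
        Ok(Self {
            request_lsn: pb.request_lsn,
            not_modified_since_lsn: match pb.not_modified_since_lsn {
                0 => None,
                lsn => Some(lsn),
            },
        })
    }
}

// Sender side: the conversion is infallible and does no validation.
impl From<DomainReadLsn> for WireReadLsn {
    fn from(lsn: DomainReadLsn) -> Self {
        Self {
            request_lsn: lsn.request_lsn,
            not_modified_since_lsn: lsn.not_modified_since_lsn.unwrap_or_default(),
        }
    }
}

/// With an infallible conversion, a sender can use a plain `map` without
/// threading errors through the pipeline.
pub fn encode_all(requests: Vec<DomainReadLsn>) -> Vec<WireReadLsn> {
    requests.into_iter().map(WireReadLsn::from).collect()
}
```

The same shape should carry over to stream combinators: because the sender-side closure cannot fail, no error type has to be merged into the stream's item type.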
2025-06-03 13:50:41 +00:00
Trung Dinh
25fffd3a55 Validate max_batch_size against max_get_vectored_keys (#12052)
## Problem
Setting `max_batch_size` to anything higher than
`Timeline::MAX_GET_VECTORED_KEYS` will cause a runtime error. We should
rather fail fast at startup if this is the case.

## Summary of changes
* Create `max_get_vectored_keys` as a new configuration (defaults to 32);
* Validate `max_batch_size` against `max_get_vectored_keys` right at
config parsing and validation.
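
A rough sketch of the fail-fast check, assuming a trimmed-down `Config` struct. The struct, the `config` helper and the error text are hypothetical and only illustrate the idea; the real check lives in the pageserver's config parsing.

```rust
use std::num::NonZeroUsize;

/// Hypothetical, trimmed-down stand-in for the parsed pageserver config.
pub struct Config {
    pub max_batch_size: NonZeroUsize,
    pub max_get_vectored_keys: NonZeroUsize,
}

impl Config {
    /// Rejects an inconsistent configuration at startup instead of letting
    /// it surface later as a runtime error on the read path.
    pub fn validate(&self) -> Result<(), String> {
        if self.max_batch_size > self.max_get_vectored_keys {
            return Err(format!(
                "max_batch_size ({}) must not exceed max_get_vectored_keys ({})",
                self.max_batch_size, self.max_get_vectored_keys
            ));
        }
        Ok(())
    }
}

/// Hypothetical helper that builds a config from raw integers.
/// Returns `None` if either value is zero.
pub fn config(max_batch_size: usize, max_get_vectored_keys: usize) -> Option<Config> {
    Some(Config {
        max_batch_size: NonZeroUsize::new(max_batch_size)?,
        max_get_vectored_keys: NonZeroUsize::new(max_get_vectored_keys)?,
    })
}
```

With the default of 32 keys, a batch size of 32 passes this check and 33 is rejected before the server starts serving requests.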

Closes https://github.com/neondatabase/neon/issues/11994
2025-06-03 13:37:11 +00:00
Erik Grinaker
e00fd45bba page_api: remove smallvec (#12095)
## Problem

The gRPC `page_api` domain types used smallvecs to avoid heap
allocations in the common case where a single page is requested.

However, this is pointless: the Protobuf types use a normal vec, and
converting a smallvec into a vec always causes a heap allocation anyway.

## Summary of changes

Use a normal `Vec` instead of a `SmallVec` in `page_api` domain types.
2025-06-03 12:20:34 +00:00
Vlad Lazar
3b8be98b67 pageserver: remove backtrace in info level log (#12108)
## Problem

We print a backtrace in an info level log every 10 seconds while waiting
for the import data to land in the bucket.

## Summary of changes

The backtrace is not useful. Remove it.
2025-06-03 09:07:07 +00:00
a-masterov
3e72edede5 Use full hostname for ONNX URL (#12064)
## Problem
We should use the full host name for computes, according to
https://github.com/neondatabase/cloud/issues/26005 , but currently a
truncated host name is used.
## Summary of changes
The URL for REMOTE_ONNX is rewritten using the FQDN.
2025-06-03 07:23:17 +00:00
Alex Chi Z.
a650f7f5af fix(pageserver): only deserialize reldir key once during get_db_size (#12102)
## Problem

Fixes https://github.com/neondatabase/neon/issues/12101; this is a quick
hack and we need a better API in the future.

In `get_db_size`, we call `get_reldir_size` for every relation. However,
each call deserializes the same reldir directory again, so the work is
repeated once per relation. This creates huge CPU overhead.

## Summary of changes

Get and deserialize the reldir v1 key once and use it across all
get_rel_size requests.
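
As an illustration only, the hoisting can be sketched with standard-library types. `decode_reldir`, `db_size`, the comma-separated encoding and the decode counter are all hypothetical; the real reldir key format and the pageserver API differ.

```rust
use std::collections::{HashMap, HashSet};

/// Hypothetical decoder: parses a comma-separated list of relation IDs and
/// bumps `decodes` so callers can see how often the expensive step ran.
fn decode_reldir(raw: &str, decodes: &mut usize) -> HashSet<u32> {
    *decodes += 1;
    raw.split(',')
        .filter_map(|id| id.trim().parse().ok())
        .collect()
}

/// Sums the sizes of all relations that exist in the directory.
/// Returns `(total_blocks, number_of_decodes)`.
pub fn db_size(raw_reldir: &str, rels: &[u32], sizes: &HashMap<u32, u64>) -> (u64, usize) {
    let mut decodes = 0;
    // Decode once, outside the loop, and reuse the result for every relation.
    let reldir = decode_reldir(raw_reldir, &mut decodes);

    let mut total = 0;
    for rel in rels {
        if reldir.contains(rel) {
            total += sizes.get(rel).copied().unwrap_or(0);
        }
    }
    (total, decodes)
}
```

Decoding inside the loop would run the decoder once per relation; hoisting it makes the decode count independent of the number of relations.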

---------

Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-06-03 05:00:34 +00:00
Erik Grinaker
fc3994eb71 pageserver: initial gRPC page service implementation (#12094)
## Problem

We should expose the page service over gRPC.

Requires #12093.
Touches #11728.

## Summary of changes

This patch adds an initial page service implementation over gRPC. It
ties in with the existing `PageServerHandler` request logic, to avoid
the implementations drifting apart for the core read path.

This is just a bare-bones functional implementation. Several important
aspects have been omitted, and will be addressed in follow-up PRs:

* Limited observability: minimal tracing, no logging, limited metrics
and timing, etc.
* Rate limiting will currently block.
* No performance optimization.
* No cancellation handling.
* No tests.

I've only done rudimentary testing of this, but Pagebench passes at
least.
2025-06-02 17:15:18 +00:00
Conrad Ludgate
781bf4945d proxy: optimise future layout allocations (#12104)
A smaller version of #12066 that is somewhat easier to review.

Now that I've been using https://crates.io/crates/top-type-sizes I've
found a lot more of the low hanging fruit that can be tweaked to reduce
the memory usage.

Some context for the optimisations:

Rust's stack allocation in futures is quite naive. Stack variables, even
if moved, often still end up taking space in the future. Rearranging the
order in which variables are defined, and properly scoping them can go a
long way.

`async fn` and `async move {}` have a consequence that they always
duplicate the "upvars" (aka captures). All captures are permanently
allocated in the future, even if moved. We can be mindful when writing
futures to only capture as little as possible.

TlsStream is massive. Needs boxing so it doesn't contribute to the above
issue.
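
A small, self-contained sketch of the boxing effect, assuming only the standard library. The 4096-byte array stands in for a large value such as `TlsStream`; the function names are hypothetical and this is not the proxy code.

```rust
use std::future::ready;
use std::mem::size_of_val;

/// Keeps a large value inline, so it becomes part of the future's state
/// because it is still live across the `.await`.
async fn inline_buffer() -> u8 {
    let buf = [0u8; 4096];
    ready(()).await;
    buf[0]
}

/// Boxes the large value, so the future only has to store a pointer.
async fn boxed_buffer() -> u8 {
    let buf = Box::new([0u8; 4096]);
    ready(()).await;
    buf[0]
}

/// Returns the sizes in bytes of the two (unpolled) futures.
pub fn future_sizes() -> (usize, usize) {
    (size_of_val(&inline_buffer()), size_of_val(&boxed_buffer()))
}
```

The first size has to be at least 4096 bytes, since the array lives across the await point. The second only needs room for the `Box` pointer plus a little state, at the cost of one heap allocation.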

## Measurements from `top-type-sizes`:

### Before

```
10328 {async block@proxy::proxy::task_main::{closure#0}::{closure#0}} align=8
6120 {async fn body of proxy::proxy::handle_client<proxy::protocol2::ChainRW<tokio::net::TcpStream>>()} align=8
```

### After

```
4040 {async block@proxy::proxy::task_main::{closure#0}::{closure#0}}
4704 {async fn body of proxy::proxy::handle_client<proxy::protocol2::ChainRW<tokio::net::TcpStream>>()} align=8
```
2025-06-02 16:13:30 +00:00
Erik Grinaker
a21c1174ed pagebench: add gRPC support for get-page-latest-lsn (#12077)
## Problem

We need gRPC support in Pagebench to benchmark the new gRPC Pageserver
implementation.

Touches #11728.

## Summary of changes

Adds a `Client` trait to make the client transport swappable, and a gRPC
client via a `--protocol grpc` parameter. This must also specify the
connstring with the gRPC port:

```
pagebench get-page-latest-lsn --protocol grpc --page-service-connstring grpc://localhost:51051
```

The client is implemented using the raw Tonic-generated gRPC client, to
minimize client overhead.
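
A minimal, synchronous sketch of a swappable client transport. The trait shape, the `LibpqClient` and `GrpcClient` structs, the `libpq` protocol name and the fake 8 KiB pages are assumptions for illustration; the actual Pagebench `Client` trait is async and talks to a real Pageserver.

```rust
/// Hypothetical transport-agnostic client interface.
pub trait Client {
    /// Name of the underlying transport.
    fn protocol(&self) -> &'static str;
    /// Fetches one page; this sketch returns a zeroed 8 KiB buffer.
    fn get_page(&mut self, block: u32) -> Result<Vec<u8>, String>;
}

/// Hypothetical client for the existing libpq-based protocol.
pub struct LibpqClient;
/// Hypothetical client for the gRPC protocol.
pub struct GrpcClient;

impl Client for LibpqClient {
    fn protocol(&self) -> &'static str {
        "libpq"
    }
    fn get_page(&mut self, _block: u32) -> Result<Vec<u8>, String> {
        Ok(vec![0u8; 8192])
    }
}

impl Client for GrpcClient {
    fn protocol(&self) -> &'static str {
        "grpc"
    }
    fn get_page(&mut self, _block: u32) -> Result<Vec<u8>, String> {
        Ok(vec![0u8; 8192])
    }
}

/// Picks a transport by name, the way a `--protocol` flag might.
pub fn new_client(protocol: &str) -> Result<Box<dyn Client>, String> {
    match protocol {
        "libpq" => Ok(Box::new(LibpqClient)),
        "grpc" => Ok(Box::new(GrpcClient)),
        other => Err(format!("unknown protocol: {other}")),
    }
}
```

The benchmark loop then only depends on `dyn Client`, so adding a transport does not touch the measurement code.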
2025-06-02 14:50:49 +00:00
Erik Grinaker
8d7ed2a4ee pageserver: add gRPC observability middleware (#12093)
## Problem

The page service logic asserts that a tracing span is present with
tenant/timeline/shard IDs. An initial gRPC page service implementation
thus requires a tracing span.

Touches https://github.com/neondatabase/neon/issues/11728.

## Summary of changes

Adds an `ObservabilityLayer` middleware that generates a tracing span
and decorates it with IDs from the gRPC metadata.

This is a minimal implementation to address the tracing span assertion.
It will be extended with additional observability in later PRs.
2025-06-02 11:46:50 +00:00
50 changed files with 1494 additions and 1549 deletions

8
.gitmodules vendored

@@ -1,16 +1,16 @@
[submodule "vendor/postgres-v14"]
path = vendor/postgres-v14
url = https://github.com/neondatabase/postgres.git
branch = REL_14_STABLE_neon
branch = 28934-pg-dump-schema-no-create-v14
[submodule "vendor/postgres-v15"]
path = vendor/postgres-v15
url = https://github.com/neondatabase/postgres.git
branch = REL_15_STABLE_neon
branch = 28934-pg-dump-schema-no-create-v15
[submodule "vendor/postgres-v16"]
path = vendor/postgres-v16
url = https://github.com/neondatabase/postgres.git
branch = REL_16_STABLE_neon
branch = 28934-pg-dump-schema-no-create-v16
[submodule "vendor/postgres-v17"]
path = vendor/postgres-v17
url = https://github.com/neondatabase/postgres.git
branch = REL_17_STABLE_neon
branch = 28934-pg-dump-schema-no-create-v17

187
Cargo.lock generated

@@ -29,41 +29,6 @@ version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627"
[[package]]
name = "aead"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0"
dependencies = [
"crypto-common",
"generic-array",
]
[[package]]
name = "aes"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0"
dependencies = [
"cfg-if",
"cipher",
"cpufeatures",
]
[[package]]
name = "aes-gcm"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1"
dependencies = [
"aead",
"aes",
"cipher",
"ctr",
"ghash",
"subtle",
]
[[package]]
name = "ahash"
version = "0.8.11"
@@ -788,7 +753,6 @@ dependencies = [
"axum",
"axum-core",
"bytes",
"cookie",
"futures-util",
"headers",
"http 1.1.0",
@@ -1209,16 +1173,6 @@ dependencies = [
"half",
]
[[package]]
name = "cipher"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
dependencies = [
"crypto-common",
"inout",
]
[[package]]
name = "clang-sys"
version = "1.6.1"
@@ -1510,21 +1464,6 @@ dependencies = [
"workspace_hack",
]
[[package]]
name = "cookie"
version = "0.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747"
dependencies = [
"aes-gcm",
"base64 0.22.1",
"percent-encoding",
"rand 0.8.5",
"subtle",
"time",
"version_check",
]
[[package]]
name = "core-foundation"
version = "0.9.3"
@@ -1718,19 +1657,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
dependencies = [
"generic-array",
"rand_core 0.6.4",
"typenum",
]
[[package]]
name = "ctr"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835"
dependencies = [
"cipher",
]
[[package]]
name = "curve25519-dalek"
version = "4.1.3"
@@ -2581,16 +2510,6 @@ dependencies = [
"winapi",
]
[[package]]
name = "ghash"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1"
dependencies = [
"opaque-debug",
"polyval",
]
[[package]]
name = "gimli"
version = "0.31.1"
@@ -3362,15 +3281,6 @@ dependencies = [
"libc",
]
[[package]]
name = "inout"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01"
dependencies = [
"generic-array",
]
[[package]]
name = "instant"
version = "0.1.12"
@@ -3884,15 +3794,6 @@ version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
[[package]]
name = "nanoid"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ffa00dec017b5b1a8b7cf5e2c008bfda1aa7e0697ac1508b491fdf2622fb4d8"
dependencies = [
"rand 0.8.5",
]
[[package]]
name = "neon-shmem"
version = "0.1.0"
@@ -4165,12 +4066,6 @@ version = "11.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575"
[[package]]
name = "opaque-debug"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
[[package]]
name = "openssl-probe"
version = "0.1.5"
@@ -4341,6 +4236,7 @@ name = "pagebench"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"camino",
"clap",
"futures",
@@ -4349,12 +4245,15 @@ dependencies = [
"humantime-serde",
"pageserver_api",
"pageserver_client",
"pageserver_page_api",
"rand 0.8.5",
"reqwest",
"serde",
"serde_json",
"tokio",
"tokio-stream",
"tokio-util",
"tonic 0.13.1",
"tracing",
"utils",
"workspace_hack",
@@ -4410,6 +4309,7 @@ dependencies = [
"hashlink",
"hex",
"hex-literal",
"http 1.1.0",
"http-utils",
"humantime",
"humantime-serde",
@@ -4472,6 +4372,7 @@ dependencies = [
"toml_edit",
"tonic 0.13.1",
"tonic-reflection",
"tower 0.5.2",
"tracing",
"tracing-utils",
"twox-hash",
@@ -4568,7 +4469,6 @@ dependencies = [
"pageserver_api",
"postgres_ffi",
"prost 0.13.5",
"smallvec",
"thiserror 1.0.69",
"tonic 0.13.1",
"tonic-build",
@@ -4690,31 +4590,6 @@ version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c"
[[package]]
name = "paster"
version = "0.1.0"
dependencies = [
"anyhow",
"axum",
"axum-extra",
"base64 0.13.1",
"chrono",
"nanoid",
"rand 0.8.5",
"reqwest",
"rustls 0.23.27",
"rustls-native-certs 0.8.0",
"serde",
"serde_json",
"time",
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tracing",
"tracing-subscriber",
"workspace_hack",
]
[[package]]
name = "pbkdf2"
version = "0.12.2"
@@ -4887,18 +4762,6 @@ dependencies = [
"never-say-never",
]
[[package]]
name = "polyval"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25"
dependencies = [
"cfg-if",
"cpufeatures",
"opaque-debug",
"universal-hash",
]
[[package]]
name = "portable-atomic"
version = "1.10.0"
@@ -6701,32 +6564,6 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "shortener"
version = "0.1.0"
dependencies = [
"anyhow",
"axum",
"axum-extra",
"base64 0.13.1",
"chrono",
"cookie",
"nanoid",
"rand 0.8.5",
"reqwest",
"rustls 0.23.27",
"rustls-native-certs 0.8.0",
"serde",
"serde_json",
"time",
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tracing",
"tracing-subscriber",
"workspace_hack",
]
[[package]]
name = "signal-hook"
version = "0.3.15"
@@ -8095,16 +7932,6 @@ version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c"
[[package]]
name = "universal-hash"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea"
dependencies = [
"crypto-common",
"subtle",
]
[[package]]
name = "untrusted"
version = "0.9.0"
@@ -8737,7 +8564,6 @@ dependencies = [
"anyhow",
"axum",
"axum-core",
"axum-extra",
"base64 0.13.1",
"base64 0.21.7",
"base64ct",
@@ -8749,7 +8575,6 @@ dependencies = [
"clap_builder",
"const-oid",
"crypto-bigint 0.5.5",
"crypto-common",
"der 0.7.8",
"deranged",
"digest",

View File

@@ -13,8 +13,6 @@ members = [
"proxy",
"safekeeper",
"safekeeper/client",
"shortener",
"paster",
"storage_broker",
"storage_controller",
"storage_controller/client",

View File

@@ -1180,14 +1180,14 @@ RUN cd exts/rag && \
RUN cd exts/rag_bge_small_en_v15 && \
sed -i 's/pgrx = "0.14.1"/pgrx = { version = "0.14.1", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
ORT_LIB_LOCATION=/ext-src/onnxruntime-src/build/Linux \
REMOTE_ONNX_URL=http://pg-ext-s3-gateway/pgrag-data/bge_small_en_v15.onnx \
REMOTE_ONNX_URL=http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/pgrag-data/bge_small_en_v15.onnx \
cargo pgrx install --release --features remote_onnx && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/rag_bge_small_en_v15.control
RUN cd exts/rag_jina_reranker_v1_tiny_en && \
sed -i 's/pgrx = "0.14.1"/pgrx = { version = "0.14.1", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
ORT_LIB_LOCATION=/ext-src/onnxruntime-src/build/Linux \
REMOTE_ONNX_URL=http://pg-ext-s3-gateway/pgrag-data/jina_reranker_v1_tiny_en.onnx \
REMOTE_ONNX_URL=http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/pgrag-data/jina_reranker_v1_tiny_en.onnx \
cargo pgrx install --release --features remote_onnx && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/rag_jina_reranker_v1_tiny_en.control

View File

@@ -70,6 +70,14 @@ enum Command {
/// and maintenance_work_mem.
#[clap(long, env = "NEON_IMPORTER_MEMORY_MB")]
memory_mb: Option<usize>,
/// List of schemas to dump.
#[clap(long)]
schema: Vec<String>,
/// List of extensions to dump.
#[clap(long)]
extension: Vec<String>,
},
/// Runs pg_dump-pg_restore from source to destination without running local postgres.
@@ -82,6 +90,12 @@ enum Command {
/// real scenario uses encrypted connection string in spec.json from s3.
#[clap(long)]
destination_connection_string: Option<String>,
/// List of schemas to dump.
#[clap(long)]
schema: Vec<String>,
/// List of extensions to dump.
#[clap(long)]
extension: Vec<String>,
},
}
@@ -117,6 +131,8 @@ struct Spec {
source_connstring_ciphertext_base64: Vec<u8>,
#[serde_as(as = "Option<serde_with::base64::Base64>")]
destination_connstring_ciphertext_base64: Option<Vec<u8>>,
schemas: Option<Vec<String>>,
extensions: Option<Vec<String>>,
}
#[derive(serde::Deserialize)]
@@ -337,6 +353,8 @@ async fn run_dump_restore(
pg_lib_dir: Utf8PathBuf,
source_connstring: String,
destination_connstring: String,
schemas: Vec<String>,
extensions: Vec<String>,
) -> Result<(), anyhow::Error> {
let dumpdir = workdir.join("dumpdir");
let num_jobs = num_cpus::get().to_string();
@@ -351,6 +369,7 @@ async fn run_dump_restore(
"--no-subscriptions".to_string(),
"--no-tablespaces".to_string(),
"--no-event-triggers".to_string(),
"--enable-row-security".to_string(),
// format
"--format".to_string(),
"directory".to_string(),
@@ -361,10 +380,36 @@ async fn run_dump_restore(
"--verbose".to_string(),
];
let mut pg_dump_args = vec![
// this makes sure any unsupported extensions are not included in the dump
// even if we don't specify supported extensions explicitly
"--extension".to_string(),
"plpgsql".to_string(),
];
// if no schemas are specified, try to import all schemas
if !schemas.is_empty() {
// always include public schema objects
// but never create the schema itself
// it already exists in any pg cluster by default
pg_dump_args.push("--schema-no-create".to_string());
pg_dump_args.push("public".to_string());
for schema in &schemas {
pg_dump_args.push("--schema".to_string());
pg_dump_args.push(schema.clone());
}
}
for extension in &extensions {
pg_dump_args.push("--extension".to_string());
pg_dump_args.push(extension.clone());
}
info!("dump into the working directory");
{
let mut pg_dump = tokio::process::Command::new(pg_bin_dir.join("pg_dump"))
.args(&common_args)
.args(&pg_dump_args)
.arg("-f")
.arg(&dumpdir)
.arg("--no-sync")
@@ -455,6 +500,8 @@ async fn cmd_pgdata(
maybe_s3_prefix: Option<s3_uri::S3Uri>,
maybe_spec: Option<Spec>,
source_connection_string: Option<String>,
schemas: Vec<String>,
extensions: Vec<String>,
interactive: bool,
pg_port: u16,
workdir: Utf8PathBuf,
@@ -470,19 +517,25 @@ async fn cmd_pgdata(
bail!("only one of spec or source_connection_string can be provided");
}
let source_connection_string = if let Some(spec) = maybe_spec {
let (source_connection_string, schemas, extensions) = if let Some(spec) = maybe_spec {
match spec.encryption_secret {
EncryptionSecret::KMS { key_id } => {
decode_connstring(
let schemas = spec.schemas.unwrap_or(vec![]);
let extensions = spec.extensions.unwrap_or(vec![]);
let source = decode_connstring(
kms_client.as_ref().unwrap(),
&key_id,
spec.source_connstring_ciphertext_base64,
)
.await?
.await
.context("decrypt source connection string")?;
(source, schemas, extensions)
}
}
} else {
source_connection_string.unwrap()
(source_connection_string.unwrap(), schemas, extensions)
};
let superuser = "cloud_admin";
@@ -504,6 +557,8 @@ async fn cmd_pgdata(
pg_lib_dir,
source_connection_string,
destination_connstring,
schemas,
extensions,
)
.await?;
@@ -546,18 +601,26 @@ async fn cmd_pgdata(
Ok(())
}
#[allow(clippy::too_many_arguments)]
async fn cmd_dumprestore(
kms_client: Option<aws_sdk_kms::Client>,
maybe_spec: Option<Spec>,
source_connection_string: Option<String>,
destination_connection_string: Option<String>,
schemas: Vec<String>,
extensions: Vec<String>,
workdir: Utf8PathBuf,
pg_bin_dir: Utf8PathBuf,
pg_lib_dir: Utf8PathBuf,
) -> Result<(), anyhow::Error> {
let (source_connstring, destination_connstring) = if let Some(spec) = maybe_spec {
let (source_connstring, destination_connstring, schemas, extensions) = if let Some(spec) =
maybe_spec
{
match spec.encryption_secret {
EncryptionSecret::KMS { key_id } => {
let schemas = spec.schemas.unwrap_or(vec![]);
let extensions = spec.extensions.unwrap_or(vec![]);
let source = decode_connstring(
kms_client.as_ref().unwrap(),
&key_id,
@@ -578,18 +641,17 @@ async fn cmd_dumprestore(
);
};
(source, dest)
(source, dest, schemas, extensions)
}
}
} else {
(
source_connection_string.unwrap(),
if let Some(val) = destination_connection_string {
val
} else {
bail!("destination connection string must be provided for dump_restore command");
},
)
let dest = if let Some(val) = destination_connection_string {
val
} else {
bail!("destination connection string must be provided for dump_restore command");
};
(source_connection_string.unwrap(), dest, schemas, extensions)
};
run_dump_restore(
@@ -598,6 +660,8 @@ async fn cmd_dumprestore(
pg_lib_dir,
source_connstring,
destination_connstring,
schemas,
extensions,
)
.await
}
@@ -679,6 +743,8 @@ pub(crate) async fn main() -> anyhow::Result<()> {
pg_port,
num_cpus,
memory_mb,
schema,
extension,
} => {
cmd_pgdata(
s3_client.as_ref(),
@@ -686,6 +752,8 @@ pub(crate) async fn main() -> anyhow::Result<()> {
args.s3_prefix.clone(),
spec,
source_connection_string,
schema,
extension,
interactive,
pg_port,
args.working_directory.clone(),
@@ -699,12 +767,16 @@ pub(crate) async fn main() -> anyhow::Result<()> {
Command::DumpRestore {
source_connection_string,
destination_connection_string,
schema,
extension,
} => {
cmd_dumprestore(
kms_client,
spec,
source_connection_string,
destination_connection_string,
schema,
extension,
args.working_directory.clone(),
args.pg_bin_dir,
args.pg_lib_dir,

View File

@@ -181,6 +181,7 @@ pub struct ConfigToml {
pub virtual_file_io_engine: Option<crate::models::virtual_file::IoEngineKind>,
pub ingest_batch_size: u64,
pub max_vectored_read_bytes: MaxVectoredReadBytes,
pub max_get_vectored_keys: MaxGetVectoredKeys,
pub image_compression: ImageCompressionAlgorithm,
pub timeline_offloading: bool,
pub ephemeral_bytes_per_memory_kb: usize,
@@ -229,7 +230,7 @@ pub enum PageServicePipeliningConfig {
}
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct PageServicePipeliningConfigPipelined {
/// Causes runtime errors if larger than max get_vectored batch size.
/// Failed config parsing and validation if larger than `max_get_vectored_keys`.
pub max_batch_size: NonZeroUsize,
pub execution: PageServiceProtocolPipelinedExecutionStrategy,
// The default below is such that new versions of the software can start
@@ -403,6 +404,16 @@ impl Default for EvictionOrder {
#[serde(transparent)]
pub struct MaxVectoredReadBytes(pub NonZeroUsize);
#[derive(Copy, Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(transparent)]
pub struct MaxGetVectoredKeys(NonZeroUsize);
impl MaxGetVectoredKeys {
pub fn get(&self) -> usize {
self.0.get()
}
}
/// Tenant-level configuration values, used for various purposes.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(default)]
@@ -587,6 +598,8 @@ pub mod defaults {
/// That is, slightly above 128 kB.
pub const DEFAULT_MAX_VECTORED_READ_BYTES: usize = 130 * 1024; // 130 KiB
pub const DEFAULT_MAX_GET_VECTORED_KEYS: usize = 32;
pub const DEFAULT_IMAGE_COMPRESSION: ImageCompressionAlgorithm =
ImageCompressionAlgorithm::Zstd { level: Some(1) };
@@ -685,6 +698,9 @@ impl Default for ConfigToml {
max_vectored_read_bytes: (MaxVectoredReadBytes(
NonZeroUsize::new(DEFAULT_MAX_VECTORED_READ_BYTES).unwrap(),
)),
max_get_vectored_keys: (MaxGetVectoredKeys(
NonZeroUsize::new(DEFAULT_MAX_GET_VECTORED_KEYS).unwrap(),
)),
image_compression: (DEFAULT_IMAGE_COMPRESSION),
timeline_offloading: true,
ephemeral_bytes_per_memory_kb: (DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB),

View File

@@ -1934,7 +1934,7 @@ pub enum PagestreamFeMessage {
}
// Wrapped in libpq CopyData
#[derive(strum_macros::EnumProperty)]
#[derive(Debug, strum_macros::EnumProperty)]
pub enum PagestreamBeMessage {
Exists(PagestreamExistsResponse),
Nblocks(PagestreamNblocksResponse),
@@ -2045,7 +2045,7 @@ pub enum PagestreamProtocolVersion {
pub type RequestId = u64;
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct PagestreamRequest {
pub reqid: RequestId,
pub request_lsn: Lsn,
@@ -2064,7 +2064,7 @@ pub struct PagestreamNblocksRequest {
pub rel: RelTag,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct PagestreamGetPageRequest {
pub hdr: PagestreamRequest,
pub rel: RelTag,

View File

@@ -24,7 +24,7 @@ use serde::{Deserialize, Serialize};
// FIXME: should move 'forknum' as last field to keep this consistent with Postgres.
// Then we could replace the custom Ord and PartialOrd implementations below with
// deriving them. This will require changes in walredoproc.c.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)]
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct RelTag {
pub forknum: u8,
pub spcnode: Oid,
@@ -184,12 +184,12 @@ pub enum SlruKind {
MultiXactOffsets,
}
impl SlruKind {
pub fn to_str(&self) -> &'static str {
impl fmt::Display for SlruKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Clog => "pg_xact",
Self::MultiXactMembers => "pg_multixact/members",
Self::MultiXactOffsets => "pg_multixact/offsets",
Self::Clog => write!(f, "pg_xact"),
Self::MultiXactMembers => write!(f, "pg_multixact/members"),
Self::MultiXactOffsets => write!(f, "pg_multixact/offsets"),
}
}
}

View File

@@ -73,6 +73,7 @@ pub mod error;
/// async timeout helper
pub mod timeout;
pub mod span;
pub mod sync;
pub mod failpoint_support;

19
libs/utils/src/span.rs Normal file

@@ -0,0 +1,19 @@
//! Tracing span helpers.
/// Records the given fields in the current span, as a single call. The fields must already have
/// been declared for the span (typically with empty values).
#[macro_export]
macro_rules! span_record {
($($tokens:tt)*) => {$crate::span_record_in!(::tracing::Span::current(), $($tokens)*)};
}
/// Records the given fields in the given span, as a single call. The fields must already have been
/// declared for the span (typically with empty values).
#[macro_export]
macro_rules! span_record_in {
($span:expr, $($tokens:tt)*) => {
if let Some(meta) = $span.metadata() {
$span.record_all(&tracing::valueset!(meta.fields(), $($tokens)*));
}
};
}

View File

@@ -34,6 +34,7 @@ fail.workspace = true
futures.workspace = true
hashlink.workspace = true
hex.workspace = true
http.workspace = true
http-utils.workspace = true
humantime-serde.workspace = true
humantime.workspace = true
@@ -93,6 +94,7 @@ tokio-util.workspace = true
toml_edit = { workspace = true, features = [ "serde" ] }
tonic.workspace = true
tonic-reflection.workspace = true
tower.workspace = true
tracing.workspace = true
tracing-utils.workspace = true
url.workspace = true

View File

@@ -9,7 +9,6 @@ bytes.workspace = true
pageserver_api.workspace = true
postgres_ffi.workspace = true
prost.workspace = true
smallvec.workspace = true
thiserror.workspace = true
tonic.workspace = true
utils.workspace = true

View File

@@ -9,10 +9,16 @@
//! - Use more precise datatypes, e.g. Lsn and uints shorter than 32 bits.
//!
//! - Validate protocol invariants, via try_from() and try_into().
//!
//! Validation only happens on the receiver side, i.e. when converting from Protobuf to domain
//! types. This is where it matters -- the Protobuf types are less strict than the domain types, and
//! receivers should expect all sorts of junk from senders. This also allows the sender to use e.g.
//! stream combinators without dealing with errors, and avoids validating the same message twice.
use std::fmt::Display;
use bytes::Bytes;
use postgres_ffi::Oid;
use smallvec::SmallVec;
// TODO: split out Lsn, RelTag, SlruKind, Oid and other basic types to a separate crate, to avoid
// pulling in all of their other crate dependencies when building the client.
use utils::lsn::Lsn;
@@ -48,7 +54,8 @@ pub struct ReadLsn {
pub request_lsn: Lsn,
/// If given, the caller guarantees that the page has not been modified since this LSN. Must be
/// smaller than or equal to request_lsn. This allows the Pageserver to serve an old page
/// without waiting for the request LSN to arrive. Valid for all request types.
/// without waiting for the request LSN to arrive. If not given, the request will read at the
/// request_lsn and wait for it to arrive if necessary. Valid for all request types.
///
/// It is undefined behaviour to make a request such that the page was, in fact, modified
/// between request_lsn and not_modified_since_lsn. The Pageserver might detect it and return an
@@ -58,19 +65,14 @@ pub struct ReadLsn {
pub not_modified_since_lsn: Option<Lsn>,
}
impl ReadLsn {
/// Validates the ReadLsn.
pub fn validate(&self) -> Result<(), ProtocolError> {
if self.request_lsn == Lsn::INVALID {
return Err(ProtocolError::invalid("request_lsn", self.request_lsn));
impl Display for ReadLsn {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let req_lsn = self.request_lsn;
if let Some(mod_lsn) = self.not_modified_since_lsn {
write!(f, "{req_lsn}>={mod_lsn}")
} else {
req_lsn.fmt(f)
}
if self.not_modified_since_lsn > Some(self.request_lsn) {
return Err(ProtocolError::invalid(
"not_modified_since_lsn",
self.not_modified_since_lsn,
));
}
Ok(())
}
}
@@ -78,27 +80,31 @@ impl TryFrom<proto::ReadLsn> for ReadLsn {
type Error = ProtocolError;
fn try_from(pb: proto::ReadLsn) -> Result<Self, Self::Error> {
let read_lsn = Self {
if pb.request_lsn == 0 {
return Err(ProtocolError::invalid("request_lsn", pb.request_lsn));
}
if pb.not_modified_since_lsn > pb.request_lsn {
return Err(ProtocolError::invalid(
"not_modified_since_lsn",
pb.not_modified_since_lsn,
));
}
Ok(Self {
request_lsn: Lsn(pb.request_lsn),
not_modified_since_lsn: match pb.not_modified_since_lsn {
0 => None,
lsn => Some(Lsn(lsn)),
},
};
read_lsn.validate()?;
Ok(read_lsn)
})
}
}
impl TryFrom<ReadLsn> for proto::ReadLsn {
type Error = ProtocolError;
fn try_from(read_lsn: ReadLsn) -> Result<Self, Self::Error> {
read_lsn.validate()?;
Ok(Self {
impl From<ReadLsn> for proto::ReadLsn {
fn from(read_lsn: ReadLsn) -> Self {
Self {
request_lsn: read_lsn.request_lsn.0,
not_modified_since_lsn: read_lsn.not_modified_since_lsn.unwrap_or_default().0,
})
}
}
}
@@ -153,6 +159,15 @@ impl TryFrom<proto::CheckRelExistsRequest> for CheckRelExistsRequest {
}
}
impl From<CheckRelExistsRequest> for proto::CheckRelExistsRequest {
fn from(request: CheckRelExistsRequest) -> Self {
Self {
read_lsn: Some(request.read_lsn.into()),
rel: Some(request.rel.into()),
}
}
}
pub type CheckRelExistsResponse = bool;
impl From<proto::CheckRelExistsResponse> for CheckRelExistsResponse {
@@ -190,14 +205,12 @@ impl TryFrom<proto::GetBaseBackupRequest> for GetBaseBackupRequest {
}
}
impl TryFrom<GetBaseBackupRequest> for proto::GetBaseBackupRequest {
type Error = ProtocolError;
fn try_from(request: GetBaseBackupRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: Some(request.read_lsn.try_into()?),
impl From<GetBaseBackupRequest> for proto::GetBaseBackupRequest {
fn from(request: GetBaseBackupRequest) -> Self {
Self {
read_lsn: Some(request.read_lsn.into()),
replica: request.replica,
})
}
}
}
@@ -214,14 +227,9 @@ impl TryFrom<proto::GetBaseBackupResponseChunk> for GetBaseBackupResponseChunk {
}
}
impl TryFrom<GetBaseBackupResponseChunk> for proto::GetBaseBackupResponseChunk {
type Error = ProtocolError;
fn try_from(chunk: GetBaseBackupResponseChunk) -> Result<Self, Self::Error> {
if chunk.is_empty() {
return Err(ProtocolError::Missing("chunk"));
}
Ok(Self { chunk })
impl From<GetBaseBackupResponseChunk> for proto::GetBaseBackupResponseChunk {
fn from(chunk: GetBaseBackupResponseChunk) -> Self {
Self { chunk }
}
}
@@ -246,14 +254,12 @@ impl TryFrom<proto::GetDbSizeRequest> for GetDbSizeRequest {
}
}
impl TryFrom<GetDbSizeRequest> for proto::GetDbSizeRequest {
type Error = ProtocolError;
fn try_from(request: GetDbSizeRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: Some(request.read_lsn.try_into()?),
impl From<GetDbSizeRequest> for proto::GetDbSizeRequest {
fn from(request: GetDbSizeRequest) -> Self {
Self {
read_lsn: Some(request.read_lsn.into()),
db_oid: request.db_oid,
})
}
}
}
@@ -288,7 +294,7 @@ pub struct GetPageRequest {
/// Multiple pages will be executed as a single batch by the Pageserver, amortizing layer access
/// costs and parallelizing them. This may increase the latency of any individual request, but
/// improves the overall latency and throughput of the batch as a whole.
pub block_numbers: SmallVec<[u32; 1]>,
pub block_numbers: Vec<u32>,
}
impl TryFrom<proto::GetPageRequest> for GetPageRequest {
@@ -306,25 +312,20 @@ impl TryFrom<proto::GetPageRequest> for GetPageRequest {
.ok_or(ProtocolError::Missing("read_lsn"))?
.try_into()?,
rel: pb.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?,
block_numbers: pb.block_number.into(),
block_numbers: pb.block_number,
})
}
}
impl TryFrom<GetPageRequest> for proto::GetPageRequest {
type Error = ProtocolError;
fn try_from(request: GetPageRequest) -> Result<Self, Self::Error> {
if request.block_numbers.is_empty() {
return Err(ProtocolError::Missing("block_number"));
}
Ok(Self {
impl From<GetPageRequest> for proto::GetPageRequest {
fn from(request: GetPageRequest) -> Self {
Self {
request_id: request.request_id,
request_class: request.request_class.into(),
read_lsn: Some(request.read_lsn.try_into()?),
read_lsn: Some(request.read_lsn.into()),
rel: Some(request.rel.into()),
block_number: request.block_numbers.into_vec(),
})
block_number: request.block_numbers,
}
}
}
@@ -396,7 +397,7 @@ pub struct GetPageResponse {
/// A string describing the status, if any.
pub reason: Option<String>,
/// The 8KB page images, in the same order as the request. Empty if status != OK.
pub page_images: SmallVec<[Bytes; 1]>,
pub page_images: Vec<Bytes>,
}
impl From<proto::GetPageResponse> for GetPageResponse {
@@ -405,7 +406,7 @@ impl From<proto::GetPageResponse> for GetPageResponse {
request_id: pb.request_id,
status_code: pb.status_code.into(),
reason: Some(pb.reason).filter(|r| !r.is_empty()),
page_images: pb.page_image.into(),
page_images: pb.page_image,
}
}
}
@@ -416,7 +417,7 @@ impl From<GetPageResponse> for proto::GetPageResponse {
request_id: response.request_id,
status_code: response.status_code.into(),
reason: response.reason.unwrap_or_default(),
page_image: response.page_images.into_vec(),
page_image: response.page_images,
}
}
}
@@ -505,14 +506,12 @@ impl TryFrom<proto::GetRelSizeRequest> for GetRelSizeRequest {
}
}
impl TryFrom<GetRelSizeRequest> for proto::GetRelSizeRequest {
type Error = ProtocolError;
fn try_from(request: GetRelSizeRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: Some(request.read_lsn.try_into()?),
impl From<GetRelSizeRequest> for proto::GetRelSizeRequest {
fn from(request: GetRelSizeRequest) -> Self {
Self {
read_lsn: Some(request.read_lsn.into()),
rel: Some(request.rel.into()),
})
}
}
}
@@ -555,15 +554,13 @@ impl TryFrom<proto::GetSlruSegmentRequest> for GetSlruSegmentRequest {
}
}
impl TryFrom<GetSlruSegmentRequest> for proto::GetSlruSegmentRequest {
type Error = ProtocolError;
fn try_from(request: GetSlruSegmentRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: Some(request.read_lsn.try_into()?),
impl From<GetSlruSegmentRequest> for proto::GetSlruSegmentRequest {
fn from(request: GetSlruSegmentRequest) -> Self {
Self {
read_lsn: Some(request.read_lsn.into()),
kind: request.kind as u32,
segno: request.segno,
})
}
}
}
@@ -580,14 +577,9 @@ impl TryFrom<proto::GetSlruSegmentResponse> for GetSlruSegmentResponse {
}
}
impl TryFrom<GetSlruSegmentResponse> for proto::GetSlruSegmentResponse {
type Error = ProtocolError;
fn try_from(segment: GetSlruSegmentResponse) -> Result<Self, Self::Error> {
if segment.is_empty() {
return Err(ProtocolError::Missing("segment"));
}
Ok(Self { segment })
impl From<GetSlruSegmentResponse> for proto::GetSlruSegmentResponse {
fn from(segment: GetSlruSegmentResponse) -> Self {
Self { segment }
}
}
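
The conversions above validate only on the receiving side (Protobuf → domain) and make the sending side (domain → Protobuf) infallible. Below is a minimal, std-only sketch of that pattern. `WireReadLsn`, `DomainReadLsn`, `ConvertError` and `encode_all` are hypothetical stand-ins rather than the real `page_api` types, and the zero check on `request_lsn` is an assumption, since that condition is cut off at the top of this hunk.

```rust
use std::convert::TryFrom;

/// Hypothetical wire type: permissive, with 0 encoding "unset".
#[derive(Debug, Clone, PartialEq)]
pub struct WireReadLsn {
    pub request_lsn: u64,
    pub not_modified_since_lsn: u64,
}

/// Hypothetical domain type: invariants are checked when decoding from the wire.
#[derive(Debug, Clone, PartialEq)]
pub struct DomainReadLsn {
    pub request_lsn: u64,
    pub not_modified_since_lsn: Option<u64>,
}

#[derive(Debug, PartialEq)]
pub enum ConvertError {
    Invalid(&'static str, u64),
}

// Receiver side: fallible, because the wire type accepts values the domain type rejects.
impl TryFrom<WireReadLsn> for DomainReadLsn {
    type Error = ConvertError;

    fn try_from(pb: WireReadLsn) -> Result<Self, Self::Error> {
        // Assumption: a zero request LSN is treated as invalid.
        if pb.request_lsn == 0 {
            return Err(ConvertError::Invalid("request_lsn", pb.request_lsn));
        }
        if pb.not_modified_since_lsn > pb.request_lsn {
            return Err(ConvertError::Invalid(
                "not_modified_since_lsn",
                pb.not_modified_since_lsn,
            ));
        }
        Ok(Self {
            request_lsn: pb.request_lsn,
            not_modified_since_lsn: match pb.not_modified_since_lsn {
                0 => None,
                lsn => Some(lsn),
            },
        })
    }
}

// Sender side: infallible, so it composes with a plain `map`.
impl From<DomainReadLsn> for WireReadLsn {
    fn from(d: DomainReadLsn) -> Self {
        Self {
            request_lsn: d.request_lsn,
            not_modified_since_lsn: d.not_modified_since_lsn.unwrap_or_default(),
        }
    }
}

/// Encodes a batch without any error handling in the mapping closure.
pub fn encode_all(reqs: Vec<DomainReadLsn>) -> Vec<WireReadLsn> {
    reqs.into_iter().map(WireReadLsn::from).collect()
}
```

Because `From` cannot fail, the same `map(WireReadLsn::from)` call would also work on a stream of requests, which is the use case the commit message gives for dropping sender-side validation.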

@@ -8,6 +8,7 @@ license.workspace = true
[dependencies]
anyhow.workspace = true
async-trait.workspace = true
camino.workspace = true
clap.workspace = true
futures.workspace = true
@@ -15,14 +16,17 @@ hdrhistogram.workspace = true
humantime.workspace = true
humantime-serde.workspace = true
rand.workspace = true
reqwest.workspace=true
reqwest.workspace = true
serde.workspace = true
serde_json.workspace = true
tracing.workspace = true
tokio.workspace = true
tokio-stream.workspace = true
tokio-util.workspace = true
tonic.workspace = true
pageserver_client.workspace = true
pageserver_api.workspace = true
pageserver_page_api.workspace = true
utils = { path = "../../libs/utils/" }
workspace_hack = { version = "0.1", path = "../../workspace_hack" }

@@ -7,11 +7,15 @@ use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use anyhow::Context;
use async_trait::async_trait;
use camino::Utf8PathBuf;
use pageserver_api::key::Key;
use pageserver_api::keyspace::KeySpaceAccum;
use pageserver_api::models::{PagestreamGetPageRequest, PagestreamRequest};
use pageserver_api::models::{
PagestreamGetPageRequest, PagestreamGetPageResponse, PagestreamRequest,
};
use pageserver_api::shard::TenantShardId;
use pageserver_page_api::proto;
use rand::prelude::*;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
@@ -22,6 +26,12 @@ use utils::lsn::Lsn;
use crate::util::tokio_thread_local_stats::AllThreadLocalStats;
use crate::util::{request_stats, tokio_thread_local_stats};
#[derive(clap::ValueEnum, Clone, Debug)]
enum Protocol {
Libpq,
Grpc,
}
/// GetPage@LatestLSN, uniformly distributed across the compute-accessible keyspace.
#[derive(clap::Parser)]
pub(crate) struct Args {
@@ -35,6 +45,8 @@ pub(crate) struct Args {
num_clients: NonZeroUsize,
#[clap(long)]
runtime: Option<humantime::Duration>,
#[clap(long, value_enum, default_value = "libpq")]
protocol: Protocol,
/// Each client sends requests at the given rate.
///
/// If a request takes too long and we should be issuing a new request already,
@@ -303,7 +315,20 @@ async fn main_impl(
.unwrap();
Box::pin(async move {
client_libpq(args, worker_id, ss, cancel, rps_period, ranges, weights).await
let client: Box<dyn Client> = match args.protocol {
Protocol::Libpq => Box::new(
LibpqClient::new(args.page_service_connstring.clone(), worker_id.timeline)
.await
.unwrap(),
),
Protocol::Grpc => Box::new(
GrpcClient::new(args.page_service_connstring.clone(), worker_id.timeline)
.await
.unwrap(),
),
};
run_worker(args, client, ss, cancel, rps_period, ranges, weights).await
})
};
@@ -355,23 +380,15 @@ async fn main_impl(
anyhow::Ok(())
}
async fn client_libpq(
async fn run_worker(
args: &Args,
worker_id: WorkerId,
mut client: Box<dyn Client>,
shared_state: Arc<SharedState>,
cancel: CancellationToken,
rps_period: Option<Duration>,
ranges: Vec<KeyRange>,
weights: rand::distributions::weighted::WeightedIndex<i128>,
) {
let client = pageserver_client::page_service::Client::new(args.page_service_connstring.clone())
.await
.unwrap();
let mut client = client
.pagestream(worker_id.timeline.tenant_id, worker_id.timeline.timeline_id)
.await
.unwrap();
shared_state.start_work_barrier.wait().await;
let client_start = Instant::now();
let mut ticks_processed = 0;
@@ -415,12 +432,12 @@ async fn client_libpq(
blkno: block_no,
}
};
client.getpage_send(req).await.unwrap();
client.send_get_page(req).await.unwrap();
inflight.push_back(start);
}
let start = inflight.pop_front().unwrap();
client.getpage_recv().await.unwrap();
client.recv_get_page().await.unwrap();
let end = Instant::now();
shared_state.live_stats.request_done();
ticks_processed += 1;
@@ -442,3 +459,104 @@ async fn client_libpq(
}
}
}
/// A benchmark client, to allow switching out the transport protocol.
///
/// For simplicity, this just uses separate asynchronous send/recv methods. The send method could
/// return a future that resolves when the response is received, but we don't really need it.
#[async_trait]
trait Client: Send {
/// Sends an asynchronous GetPage request to the pageserver.
async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()>;
/// Receives the next GetPage response from the pageserver.
async fn recv_get_page(&mut self) -> anyhow::Result<PagestreamGetPageResponse>;
}
/// A libpq-based Pageserver client.
struct LibpqClient {
inner: pageserver_client::page_service::PagestreamClient,
}
impl LibpqClient {
async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result<Self> {
let inner = pageserver_client::page_service::Client::new(connstring)
.await?
.pagestream(ttid.tenant_id, ttid.timeline_id)
.await?;
Ok(Self { inner })
}
}
#[async_trait]
impl Client for LibpqClient {
async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()> {
self.inner.getpage_send(req).await
}
async fn recv_get_page(&mut self) -> anyhow::Result<PagestreamGetPageResponse> {
self.inner.getpage_recv().await
}
}
/// A gRPC client using the raw, no-frills gRPC client.
struct GrpcClient {
req_tx: tokio::sync::mpsc::Sender<proto::GetPageRequest>,
resp_rx: tonic::Streaming<proto::GetPageResponse>,
}
impl GrpcClient {
async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result<Self> {
let mut client = pageserver_page_api::proto::PageServiceClient::connect(connstring).await?;
// The channel has a buffer size of 1, since 0 is not allowed. It does not matter, since the
// benchmark will control the queue depth (i.e. in-flight requests) anyway, and requests are
// buffered by Tonic and the OS too.
let (req_tx, req_rx) = tokio::sync::mpsc::channel(1);
let req_stream = tokio_stream::wrappers::ReceiverStream::new(req_rx);
let mut req = tonic::Request::new(req_stream);
let metadata = req.metadata_mut();
metadata.insert("neon-tenant-id", ttid.tenant_id.to_string().try_into()?);
metadata.insert("neon-timeline-id", ttid.timeline_id.to_string().try_into()?);
metadata.insert("neon-shard-id", "0000".try_into()?);
let resp = client.get_pages(req).await?;
let resp_stream = resp.into_inner();
Ok(Self {
req_tx,
resp_rx: resp_stream,
})
}
}
#[async_trait]
impl Client for GrpcClient {
async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()> {
let req = proto::GetPageRequest {
request_id: 0,
request_class: proto::GetPageClass::Normal as i32,
read_lsn: Some(proto::ReadLsn {
request_lsn: req.hdr.request_lsn.0,
not_modified_since_lsn: req.hdr.not_modified_since.0,
}),
rel: Some(req.rel.into()),
block_number: vec![req.blkno],
};
self.req_tx.send(req).await?;
Ok(())
}
async fn recv_get_page(&mut self) -> anyhow::Result<PagestreamGetPageResponse> {
let resp = self.resp_rx.message().await?.unwrap();
anyhow::ensure!(
resp.status_code == proto::GetPageStatusCode::Ok as i32,
"unexpected status code: {}",
resp.status_code
);
Ok(PagestreamGetPageResponse {
page: resp.page_image[0].clone(),
req: PagestreamGetPageRequest::default(), // dummy
})
}
}
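
The `Client` trait above splits sending and receiving so that the worker loop, not the transport, controls how many requests are in flight. The following is a simplified, synchronous, std-only analogue of that pattern, not the benchmark's actual async code. `PageClient`, `LoopbackClient` and this `run_worker` signature are hypothetical names introduced for illustration.

```rust
use std::collections::VecDeque;
use std::time::Instant;

/// Hypothetical, synchronous stand-in for the benchmark's `Client` trait.
pub trait PageClient {
    /// Queues a request for the given block number.
    fn send_get_page(&mut self, blkno: u32) -> Result<(), String>;
    /// Returns the response for the oldest outstanding request.
    fn recv_get_page(&mut self) -> Result<Vec<u8>, String>;
}

/// In-memory transport that answers requests in FIFO order.
pub struct LoopbackClient {
    pending: VecDeque<u32>,
}

impl LoopbackClient {
    pub fn new() -> Self {
        Self { pending: VecDeque::new() }
    }
}

impl PageClient for LoopbackClient {
    fn send_get_page(&mut self, blkno: u32) -> Result<(), String> {
        self.pending.push_back(blkno);
        Ok(())
    }

    fn recv_get_page(&mut self) -> Result<Vec<u8>, String> {
        let blkno = self.pending.pop_front().ok_or("no request in flight")?;
        Ok(blkno.to_be_bytes().to_vec())
    }
}

/// Issues `total` requests while keeping at most `queue_depth` in flight.
/// Returns the number of completed requests.
pub fn run_worker(
    client: &mut dyn PageClient,
    total: u32,
    queue_depth: usize,
) -> Result<usize, String> {
    let mut inflight: VecDeque<Instant> = VecDeque::new();
    let mut next = 0u32;
    let mut completed = 0usize;
    while completed < total as usize {
        // Fill the pipeline up to the configured depth.
        while inflight.len() < queue_depth && next < total {
            client.send_get_page(next)?;
            inflight.push_back(Instant::now());
            next += 1;
        }
        // Responses arrive in request order, so the oldest start time matches.
        let start = inflight.pop_front().ok_or("pipeline unexpectedly empty")?;
        client.recv_get_page()?;
        let _latency = start.elapsed();
        completed += 1;
    }
    Ok(completed)
}
```

Taking the client as a trait object mirrors the `Box<dyn Client>` in the diff: the loop stays identical whichever transport is selected by `--protocol`.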

@@ -65,6 +65,30 @@ impl From<GetVectoredError> for BasebackupError {
}
}
impl From<BasebackupError> for postgres_backend::QueryError {
fn from(err: BasebackupError) -> Self {
use postgres_backend::QueryError;
use pq_proto::framed::ConnectionError;
match err {
BasebackupError::Client(err, _) => QueryError::Disconnected(ConnectionError::Io(err)),
BasebackupError::Server(err) => QueryError::Other(err),
BasebackupError::Shutdown => QueryError::Shutdown,
}
}
}
impl From<BasebackupError> for tonic::Status {
fn from(err: BasebackupError) -> Self {
use tonic::Code;
let code = match &err {
BasebackupError::Client(_, _) => Code::Cancelled,
BasebackupError::Server(_) => Code::Internal,
BasebackupError::Shutdown => Code::Unavailable,
};
tonic::Status::new(code, err.to_string())
}
}
/// Create basebackup with non-rel data in it.
/// Only include relational data if 'full_backup' is true.
///
@@ -248,7 +272,7 @@ where
async fn flush(&mut self) -> Result<(), BasebackupError> {
let nblocks = self.buf.len() / BLCKSZ as usize;
let (kind, segno) = self.current_segment.take().unwrap();
let segname = format!("{}/{:>04X}", kind.to_str(), segno);
let segname = format!("{kind}/{segno:>04X}");
let header = new_tar_header(&segname, self.buf.len() as u64)?;
self.ar
.append(&header, self.buf.as_slice())
@@ -347,7 +371,7 @@ where
.await?
.partition(
self.timeline.get_shard_identity(),
Timeline::MAX_GET_VECTORED_KEYS * BLCKSZ as u64,
self.timeline.conf.max_get_vectored_keys.get() as u64 * BLCKSZ as u64,
);
let mut slru_builder = SlruSegmentsBuilder::new(&mut self.ar);
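
The `From<BasebackupError> for tonic::Status` impl above picks a status code from the error variant and reuses the error's `Display` output as the message. Here is a std-only sketch of the same mapping. `Code`, `Status` and `BackupError` are simplified stand-ins defined locally, not the `tonic` or Pageserver types, and the message strings are made up for illustration.

```rust
use std::fmt;

/// Hypothetical subset of gRPC status codes.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Code {
    Cancelled,
    Internal,
    Unavailable,
}

/// Hypothetical status type: a code plus a human-readable message.
#[derive(Debug)]
pub struct Status {
    pub code: Code,
    pub message: String,
}

/// Hypothetical error type with the same three cases as in the diff.
#[derive(Debug)]
pub enum BackupError {
    Client(String),
    Server(String),
    Shutdown,
}

impl fmt::Display for BackupError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            BackupError::Client(msg) => write!(f, "client error: {msg}"),
            BackupError::Server(msg) => write!(f, "server error: {msg}"),
            BackupError::Shutdown => write!(f, "shutting down"),
        }
    }
}

impl From<BackupError> for Status {
    fn from(err: BackupError) -> Self {
        // Borrow for the match so the error is still available for the message.
        let code = match &err {
            BackupError::Client(_) => Code::Cancelled,
            BackupError::Server(_) => Code::Internal,
            BackupError::Shutdown => Code::Unavailable,
        };
        Status { code, message: err.to_string() }
    }
}
```

With such an impl in place, a handler that returns `Result<_, Status>` can use `?` directly on a `Result<_, BackupError>`.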

@@ -804,7 +804,7 @@ fn start_pageserver(
} else {
None
},
basebackup_cache.clone(),
basebackup_cache,
);
// Spawn a Pageserver gRPC server task. It will spawn separate tasks for
@@ -816,12 +816,10 @@ fn start_pageserver(
let mut page_service_grpc = None;
if let Some(grpc_listener) = grpc_listener {
page_service_grpc = Some(page_service::spawn_grpc(
conf,
tenant_manager.clone(),
grpc_auth,
otel_guard.as_ref().map(|g| g.dispatch.clone()),
grpc_listener,
basebackup_cache,
)?);
}

@@ -14,7 +14,10 @@ use std::time::Duration;
use anyhow::{Context, bail, ensure};
use camino::{Utf8Path, Utf8PathBuf};
use once_cell::sync::OnceCell;
use pageserver_api::config::{DiskUsageEvictionTaskConfig, MaxVectoredReadBytes, PostHogConfig};
use pageserver_api::config::{
DiskUsageEvictionTaskConfig, MaxGetVectoredKeys, MaxVectoredReadBytes,
PageServicePipeliningConfig, PageServicePipeliningConfigPipelined, PostHogConfig,
};
use pageserver_api::models::ImageCompressionAlgorithm;
use pageserver_api::shard::TenantShardId;
use pem::Pem;
@@ -185,6 +188,9 @@ pub struct PageServerConf {
pub max_vectored_read_bytes: MaxVectoredReadBytes,
/// Maximum number of keys to be read in a single get_vectored call.
pub max_get_vectored_keys: MaxGetVectoredKeys,
pub image_compression: ImageCompressionAlgorithm,
/// Whether to offload archived timelines automatically
@@ -404,6 +410,7 @@ impl PageServerConf {
secondary_download_concurrency,
ingest_batch_size,
max_vectored_read_bytes,
max_get_vectored_keys,
image_compression,
timeline_offloading,
ephemeral_bytes_per_memory_kb,
@@ -470,6 +477,7 @@ impl PageServerConf {
secondary_download_concurrency,
ingest_batch_size,
max_vectored_read_bytes,
max_get_vectored_keys,
image_compression,
timeline_offloading,
ephemeral_bytes_per_memory_kb,
@@ -598,6 +606,19 @@ impl PageServerConf {
)
})?;
if let PageServicePipeliningConfig::Pipelined(PageServicePipeliningConfigPipelined {
max_batch_size,
..
}) = conf.page_service_pipelining
{
if max_batch_size.get() > conf.max_get_vectored_keys.get() {
return Err(anyhow::anyhow!(
"`max_batch_size` ({max_batch_size}) must be less than or equal to `max_get_vectored_keys` ({})",
conf.max_get_vectored_keys.get()
));
}
};
Ok(conf)
}
@@ -685,6 +706,7 @@ impl ConfigurableSemaphore {
mod tests {
use camino::Utf8PathBuf;
use rstest::rstest;
use utils::id::NodeId;
use super::PageServerConf;
@@ -724,4 +746,28 @@ mod tests {
PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir)
.expect_err("parse_and_validate should fail for endpoint without scheme");
}
#[rstest]
#[case(32, 32, true)]
#[case(64, 32, false)]
#[case(64, 64, true)]
#[case(128, 128, true)]
fn test_config_max_batch_size_is_valid(
#[case] max_batch_size: usize,
#[case] max_get_vectored_keys: usize,
#[case] is_valid: bool,
) {
let input = format!(
r#"
control_plane_api = "http://localhost:6666"
max_get_vectored_keys = {max_get_vectored_keys}
page_service_pipelining = {{ mode="pipelined", execution="concurrent-futures", max_batch_size={max_batch_size}, batching="uniform-lsn" }}
"#,
);
let config_toml = toml_edit::de::from_str::<pageserver_api::config::ConfigToml>(&input)
.expect("config has valid fields");
let workdir = Utf8PathBuf::from("/nonexistent");
let result = PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir);
assert_eq!(result.is_ok(), is_valid);
}
}
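
The validation added to `parse_and_validate` fails at startup when `max_batch_size` exceeds `max_get_vectored_keys`. A minimal std-only sketch of that check follows. `Limits`, `limits` and `validate` are hypothetical names, and modelling both settings as `NonZeroUsize` is an assumption based on the `.get()` calls in the diff.

```rust
use std::num::NonZeroUsize;

/// Hypothetical pair of related limits, both required to be non-zero.
pub struct Limits {
    pub max_batch_size: NonZeroUsize,
    pub max_get_vectored_keys: NonZeroUsize,
}

/// Builds the limits, returning `None` if either value is zero.
pub fn limits(batch: usize, keys: usize) -> Option<Limits> {
    Some(Limits {
        max_batch_size: NonZeroUsize::new(batch)?,
        max_get_vectored_keys: NonZeroUsize::new(keys)?,
    })
}

/// Rejects configurations that could only fail later, at request time.
pub fn validate(limits: &Limits) -> Result<(), String> {
    let batch = limits.max_batch_size.get();
    let keys = limits.max_get_vectored_keys.get();
    if batch > keys {
        return Err(format!(
            "`max_batch_size` ({batch}) must be less than or equal to `max_get_vectored_keys` ({keys})"
        ));
    }
    Ok(())
}
```

The cases in the `rstest` table above map directly onto this function: equal values pass, and a batch size of 64 against a key limit of 32 is rejected.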

@@ -15,6 +15,7 @@ use metrics::{
register_int_gauge, register_int_gauge_vec, register_uint_gauge, register_uint_gauge_vec,
};
use once_cell::sync::Lazy;
use pageserver_api::config::defaults::DEFAULT_MAX_GET_VECTORED_KEYS;
use pageserver_api::config::{
PageServicePipeliningConfig, PageServicePipeliningConfigPipelined,
PageServiceProtocolPipelinedBatchingStrategy, PageServiceProtocolPipelinedExecutionStrategy,
@@ -32,7 +33,6 @@ use crate::config::PageServerConf;
use crate::context::{PageContentKind, RequestContext};
use crate::pgdatadir_mapping::DatadirModificationStats;
use crate::task_mgr::TaskKind;
use crate::tenant::Timeline;
use crate::tenant::layer_map::LayerMap;
use crate::tenant::mgr::TenantSlot;
use crate::tenant::storage_layer::{InMemoryLayer, PersistentLayerDesc};
@@ -1939,7 +1939,7 @@ static SMGR_QUERY_TIME_GLOBAL: Lazy<HistogramVec> = Lazy::new(|| {
});
static PAGE_SERVICE_BATCH_SIZE_BUCKETS_GLOBAL: Lazy<Vec<f64>> = Lazy::new(|| {
(1..=u32::try_from(Timeline::MAX_GET_VECTORED_KEYS).unwrap())
(1..=u32::try_from(DEFAULT_MAX_GET_VECTORED_KEYS).unwrap())
.map(|v| v.into())
.collect()
});
@@ -1957,7 +1957,7 @@ static PAGE_SERVICE_BATCH_SIZE_BUCKETS_PER_TIMELINE: Lazy<Vec<f64>> = Lazy::new(
let mut buckets = Vec::new();
for i in 0.. {
let bucket = 1 << i;
if bucket > u32::try_from(Timeline::MAX_GET_VECTORED_KEYS).unwrap() {
if bucket > u32::try_from(DEFAULT_MAX_GET_VECTORED_KEYS).unwrap() {
break;
}
buckets.push(bucket.into());

File diff suppressed because it is too large

@@ -431,10 +431,10 @@ impl Timeline {
GetVectoredError::InvalidLsn(e) => {
Err(anyhow::anyhow!("invalid LSN: {e:?}").into())
}
// NB: this should never happen in practice because we limit MAX_GET_VECTORED_KEYS
// NB: this should never happen in practice because we limit the batch size to at most max_get_vectored_keys
// TODO: we can prevent this error class by moving this check into the type system
GetVectoredError::Oversized(err) => {
Err(anyhow::anyhow!("batching oversized: {err:?}").into())
GetVectoredError::Oversized(err, max) => {
Err(anyhow::anyhow!("batching oversized: {err} > {max}").into())
}
};
@@ -471,8 +471,19 @@ impl Timeline {
let rels = self.list_rels(spcnode, dbnode, version, ctx).await?;
if rels.is_empty() {
return Ok(0);
}
// Pre-deserialize the rel directory to avoid duplicated work in `get_relsize_cached`.
let reldir_key = rel_dir_to_key(spcnode, dbnode);
let buf = version.get(self, reldir_key, ctx).await?;
let reldir = RelDirectory::des(&buf)?;
for rel in rels {
let n_blocks = self.get_rel_size(rel, version, ctx).await?;
let n_blocks = self
.get_rel_size_in_reldir(rel, version, Some((reldir_key, &reldir)), ctx)
.await?;
total_blocks += n_blocks as usize;
}
Ok(total_blocks)
@@ -487,6 +498,19 @@ impl Timeline {
tag: RelTag,
version: Version<'_>,
ctx: &RequestContext,
) -> Result<BlockNumber, PageReconstructError> {
self.get_rel_size_in_reldir(tag, version, None, ctx).await
}
/// Get size of a relation file. The relation must exist, otherwise an error is returned.
///
/// See [`Self::get_rel_exists_in_reldir`] on why we need `deserialized_reldir_v1`.
pub(crate) async fn get_rel_size_in_reldir(
&self,
tag: RelTag,
version: Version<'_>,
deserialized_reldir_v1: Option<(Key, &RelDirectory)>,
ctx: &RequestContext,
) -> Result<BlockNumber, PageReconstructError> {
if tag.relnode == 0 {
return Err(PageReconstructError::Other(
@@ -499,7 +523,9 @@ impl Timeline {
}
if (tag.forknum == FSM_FORKNUM || tag.forknum == VISIBILITYMAP_FORKNUM)
&& !self.get_rel_exists(tag, version, ctx).await?
&& !self
.get_rel_exists_in_reldir(tag, version, deserialized_reldir_v1, ctx)
.await?
{
// FIXME: Postgres sometimes calls smgrcreate() to create
// FSM, and smgrnblocks() on it immediately afterwards,
@@ -521,11 +547,28 @@ impl Timeline {
///
/// Only shard 0 has a full view of the relations. Other shards only know about relations that
/// the shard stores pages for.
///
pub(crate) async fn get_rel_exists(
&self,
tag: RelTag,
version: Version<'_>,
ctx: &RequestContext,
) -> Result<bool, PageReconstructError> {
self.get_rel_exists_in_reldir(tag, version, None, ctx).await
}
/// Does the relation exist? With a cached deserialized `RelDirectory`.
///
/// There are some cases where the caller loops across all relations. In that specific case,
/// the caller should obtain the deserialized `RelDirectory` first and then call this function
/// to avoid duplicated work of deserialization. This is a hack and should be removed by introducing
/// a new API (e.g., `get_rel_exists_batched`).
pub(crate) async fn get_rel_exists_in_reldir(
&self,
tag: RelTag,
version: Version<'_>,
deserialized_reldir_v1: Option<(Key, &RelDirectory)>,
ctx: &RequestContext,
) -> Result<bool, PageReconstructError> {
if tag.relnode == 0 {
return Err(PageReconstructError::Other(
@@ -568,6 +611,17 @@ impl Timeline {
// fetch directory listing (old)
let key = rel_dir_to_key(tag.spcnode, tag.dbnode);
if let Some((cached_key, dir)) = deserialized_reldir_v1 {
if cached_key == key {
return Ok(dir.rels.contains(&(tag.relnode, tag.forknum)));
} else if cfg!(test) || cfg!(feature = "testing") {
panic!("cached reldir key mismatch: {cached_key} != {key}");
} else {
warn!("cached reldir key mismatch: {cached_key} != {key}");
}
// Fallback to reading the directory from the datadir.
}
let buf = version.get(self, key, ctx).await?;
let dir = RelDirectory::des(&buf)?;
@@ -665,7 +719,7 @@ impl Timeline {
let batches = keyspace.partition(
self.get_shard_identity(),
Timeline::MAX_GET_VECTORED_KEYS * BLCKSZ as u64,
self.conf.max_get_vectored_keys.get() as u64 * BLCKSZ as u64,
);
let io_concurrency = IoConcurrency::spawn_from_conf(
@@ -905,7 +959,7 @@ impl Timeline {
let batches = keyspace.partition(
self.get_shard_identity(),
Timeline::MAX_GET_VECTORED_KEYS * BLCKSZ as u64,
self.conf.max_get_vectored_keys.get() as u64 * BLCKSZ as u64,
);
let io_concurrency = IoConcurrency::spawn_from_conf(
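
`get_rel_exists_in_reldir` accepts an optional pre-deserialized directory together with its key, uses it only when the key matches, and otherwise falls back to loading the directory. A std-only sketch of that lookup is shown below. `Store`, `RelDir`, `DirKey`, `RelEntry` and the `loads` counter are hypothetical and exist only to make the saved work observable; the panic-in-tests and `warn!` branches from the diff are omitted.

```rust
use std::collections::{HashMap, HashSet};

/// Hypothetical directory key, e.g. (spcnode, dbnode).
pub type DirKey = (u32, u32);
/// Hypothetical relation entry, e.g. (relnode, forknum).
pub type RelEntry = (u32, u8);

#[derive(Debug, Clone, Default)]
pub struct RelDir {
    pub rels: HashSet<RelEntry>,
}

/// Hypothetical backing store that counts how often a directory is loaded.
pub struct Store {
    dirs: HashMap<DirKey, RelDir>,
    pub loads: usize,
}

impl Store {
    pub fn new(dirs: HashMap<DirKey, RelDir>) -> Self {
        Self { dirs, loads: 0 }
    }

    /// Stands in for the expensive read-and-deserialize step.
    fn load_dir(&mut self, key: DirKey) -> Option<RelDir> {
        self.loads += 1;
        self.dirs.get(&key).cloned()
    }

    /// Checks existence, preferring the caller's cached directory when its key matches.
    pub fn rel_exists(
        &mut self,
        key: DirKey,
        rel: RelEntry,
        cached: Option<(DirKey, &RelDir)>,
    ) -> bool {
        if let Some((cached_key, dir)) = cached {
            if cached_key == key {
                return dir.rels.contains(&rel);
            }
            // Key mismatch: ignore the cache and fall back to loading.
        }
        self.load_dir(key)
            .map_or(false, |dir| dir.rels.contains(&rel))
    }
}
```

A caller that loops over all relations in one database can load the directory once and pass it to every call, which is the saving the diff aims for in the database size computation.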

@@ -7197,7 +7197,7 @@ mod tests {
let end = desc
.key_range
.start
.add(Timeline::MAX_GET_VECTORED_KEYS.try_into().unwrap());
.add(tenant.conf.max_get_vectored_keys.get() as u32);
reads.push(KeySpace {
ranges: vec![start..end],
});
@@ -11260,11 +11260,11 @@ mod tests {
let mut keyspaces_at_lsn: HashMap<Lsn, KeySpaceRandomAccum> = HashMap::default();
let mut used_keys: HashSet<Key> = HashSet::default();
while used_keys.len() < Timeline::MAX_GET_VECTORED_KEYS as usize {
while used_keys.len() < tenant.conf.max_get_vectored_keys.get() {
let selected_lsn = interesting_lsns.choose(&mut random).expect("not empty");
let mut selected_key = start_key.add(random.gen_range(0..KEY_DIMENSION_SIZE));
while used_keys.len() < Timeline::MAX_GET_VECTORED_KEYS as usize {
while used_keys.len() < tenant.conf.max_get_vectored_keys.get() {
if used_keys.contains(&selected_key)
|| selected_key >= start_key.add(KEY_DIMENSION_SIZE)
{

@@ -817,8 +817,8 @@ pub(crate) enum GetVectoredError {
#[error("timeline shutting down")]
Cancelled,
#[error("requested too many keys: {0} > {}", Timeline::MAX_GET_VECTORED_KEYS)]
Oversized(u64),
#[error("requested too many keys: {0} > {1}")]
Oversized(u64, u64),
#[error("requested at invalid LSN: {0}")]
InvalidLsn(Lsn),
@@ -950,6 +950,18 @@ pub(crate) enum WaitLsnError {
Timeout(String),
}
impl From<WaitLsnError> for tonic::Status {
fn from(err: WaitLsnError) -> Self {
use tonic::Code;
let code = match &err {
WaitLsnError::Timeout(_) => Code::Internal,
WaitLsnError::BadState(_) => Code::Internal,
WaitLsnError::Shutdown => Code::Unavailable,
};
tonic::Status::new(code, err.to_string())
}
}
// The impls below achieve cancellation mapping for errors.
// Perhaps there's a way of achieving this with less cruft.
@@ -1007,7 +1019,7 @@ impl From<GetVectoredError> for PageReconstructError {
match e {
GetVectoredError::Cancelled => PageReconstructError::Cancelled,
GetVectoredError::InvalidLsn(_) => PageReconstructError::Other(anyhow!("Invalid LSN")),
err @ GetVectoredError::Oversized(_) => PageReconstructError::Other(err.into()),
err @ GetVectoredError::Oversized(_, _) => PageReconstructError::Other(err.into()),
GetVectoredError::MissingKey(err) => PageReconstructError::MissingKey(err),
GetVectoredError::GetReadyAncestorError(err) => PageReconstructError::from(err),
GetVectoredError::Other(err) => PageReconstructError::Other(err),
@@ -1187,7 +1199,6 @@ impl Timeline {
}
}
pub(crate) const MAX_GET_VECTORED_KEYS: u64 = 32;
pub(crate) const LAYERS_VISITED_WARN_THRESHOLD: u32 = 100;
/// Look up multiple page versions at a given LSN
@@ -1202,9 +1213,12 @@ impl Timeline {
) -> Result<BTreeMap<Key, Result<Bytes, PageReconstructError>>, GetVectoredError> {
let total_keyspace = query.total_keyspace();
let key_count = total_keyspace.total_raw_size().try_into().unwrap();
if key_count > Timeline::MAX_GET_VECTORED_KEYS {
return Err(GetVectoredError::Oversized(key_count));
let key_count = total_keyspace.total_raw_size();
if key_count > self.conf.max_get_vectored_keys.get() {
return Err(GetVectoredError::Oversized(
key_count as u64,
self.conf.max_get_vectored_keys.get() as u64,
));
}
for range in &total_keyspace.ranges {
@@ -5258,7 +5272,7 @@ impl Timeline {
key = key.next();
// Maybe flush `key_request_accum`
if key_request_accum.raw_size() >= Timeline::MAX_GET_VECTORED_KEYS
if key_request_accum.raw_size() >= self.conf.max_get_vectored_keys.get() as u64
|| (last_key_in_range && key_request_accum.raw_size() > 0)
{
let query =
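
With the constant gone, `GetVectoredError::Oversized` now carries both the requested key count and the configured limit, so the message no longer depends on a compile-time value. A std-only sketch of such an error is given below. `BatchError` and `check_key_count` are hypothetical names, and `Display` is written by hand here where the original uses the `thiserror` `#[error(...)]` attribute.

```rust
use std::fmt;

/// Hypothetical error carrying (requested, limit), like `Oversized(u64, u64)`.
#[derive(Debug, PartialEq)]
pub enum BatchError {
    Oversized(u64, u64),
}

impl fmt::Display for BatchError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            BatchError::Oversized(requested, max) => {
                write!(f, "requested too many keys: {requested} > {max}")
            }
        }
    }
}

impl std::error::Error for BatchError {}

/// Checks a key count against a limit that comes from runtime configuration.
pub fn check_key_count(key_count: usize, max_keys: usize) -> Result<(), BatchError> {
    if key_count > max_keys {
        return Err(BatchError::Oversized(key_count as u64, max_keys as u64));
    }
    Ok(())
}
```

Carrying the limit in the error keeps log lines self-explanatory when different deployments run with different `max_get_vectored_keys` values.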

@@ -201,8 +201,8 @@ async fn prepare_import(
.await;
match res {
Ok(_) => break,
Err(err) => {
info!(?err, "indefinitely waiting for pgdata to finish");
Err(_err) => {
info!("indefinitely waiting for pgdata to finish");
if tokio::time::timeout(std::time::Duration::from_secs(10), cancel.cancelled())
.await
.is_ok()

@@ -471,6 +471,8 @@ impl Plan {
last_completed_job_idx = job_idx;
if last_completed_job_idx % checkpoint_every == 0 {
tracing::info!(last_completed_job_idx, jobs=%jobs_in_plan, "Checkpointing import status");
let progress = ShardImportProgressV1 {
jobs: jobs_in_plan,
completed: last_completed_job_idx,
@@ -492,8 +494,6 @@ impl Plan {
anyhow::anyhow!("Shut down while putting timeline import status")
})?;
}
tracing::info!(last_completed_job_idx, jobs=%jobs_in_plan, "Checkpointing import status");
},
Some(Err(_)) => {
anyhow::bail!(

@@ -1,25 +0,0 @@
[package]
name = "paster"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
anyhow.workspace = true
axum-extra = { workspace = true, features = ["cookie", "cookie-private"] }
axum.workspace = true
base64.workspace = true
chrono.workspace = true
nanoid = { version = "0.4.0", default-features = false }
rand.workspace = true
reqwest.workspace = true
rustls-native-certs.workspace = true
rustls.workspace = true
serde.workspace = true
serde_json.workspace = true
time = { version = "0.3.36", default-features = false }
tokio-postgres-rustls.workspace = true
tokio-postgres.workspace = true
tokio.workspace = true
tracing-subscriber.workspace = true
tracing.workspace = true
workspace_hack.workspace = true

@@ -1,18 +0,0 @@
CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
sub VARCHAR(100) NOT NULL UNIQUE
);
CREATE TABLE IF NOT EXISTS sessions (
id SERIAL PRIMARY KEY,
user_id INT NOT NULL UNIQUE REFERENCES users(id),
session_id VARCHAR NOT NULL,
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
);
CREATE TABLE IF NOT EXISTS pastes (
id SERIAL PRIMARY KEY,
user_id INT NOT NULL REFERENCES users(id),
paste text NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
)

@@ -1,353 +0,0 @@
//! Paster is a service to share logs or code snippets outside of
//! Slack, not relying on public services
use anyhow::Result;
use shortener::google_oauth_gate::{AuthRequest, State, UserId};
use axum::Form;
use axum::extract::{FromRef, FromRequestParts, Path, Query, State as AxumStateT};
use axum::http::StatusCode;
use axum::response::{Html, IntoResponse};
use axum::response::{Redirect, Response};
use axum::routing::get;
use axum_extra::extract::PrivateCookieJar;
use axum_extra::extract::cookie::{Cookie, Key};
use chrono::{Duration, Local, TimeZone, Utc};
use core::num::NonZeroI32;
use serde::Deserialize;
use std::env;
use std::sync::Arc;
use tracing::{error, info};
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
const SOCKET: &str = "127.0.0.1:12344";
const HOST: &str = "http://127.0.0.1:12344";
const ALLOWED_OAUTH_DOMAIN: &str = "neon.tech";
fn oauth_redirect_url() -> String {
format!("{HOST}{AUTHORIZED_ROUTE}")
}
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| format!("{}=info", env!("CARGO_CRATE_NAME")).into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
let oauth_client_id = env::var("OAUTH_CLIENT_ID").expect("Missing OAUTH_CLIENT_ID");
let oauth_client_secret = env::var("OAUTH_CLIENT_SECRET").expect("Missing OAUTH_CLIENT_SECRET");
let db_connstr = env::var("DB_CONNSTR").expect("Missing DB_CONNSTR");
let mut roots = rustls::RootCertStore::empty();
for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") {
roots.add(cert).unwrap();
}
let config = rustls::ClientConfig::builder()
.with_root_certificates(roots)
.with_no_client_auth();
let tls = tokio_postgres_rustls::MakeRustlsConnect::new(config);
info!("initialized TLS");
let (db_client, db_conn) = tokio_postgres::connect(&db_connstr, tls).await?;
tokio::spawn(async move {
if let Err(err) = db_conn.await {
error!(%err, "connecting to database");
std::process::exit(1);
}
});
info!("connected to database");
let state = InnerState {
db_client,
cookie_jar_key: Key::generate(),
oauth_client_id,
oauth_client_secret,
};
let router = axum::Router::new()
.route("/", get(index).post(paste))
.route("/authorize", get(authorize))
.route(AUTHORIZED_ROUTE, get(authorized))
.route("/{id}", get(view_paste))
.with_state(State { 0: Arc::new(state) });
let listener = tokio::net::TcpListener::bind(SOCKET)
.await
.expect("failed to bind TcpListener");
info!("listening on {SOCKET}");
axum::serve(listener, router).await.unwrap();
Ok(())
}
#[derive(Deserialize)]
pub struct UserId {
id: NonZeroI32,
}
impl axum::extract::OptionalFromRequestParts<State> for UserId {
type Rejection = Response;
async fn from_request_parts(
parts: &mut axum::http::request::Parts,
state: &State,
) -> Result<Option<Self>, Self::Rejection> {
let jar: PrivateCookieJar = PrivateCookieJar::from_request_parts(parts, state)
.await
.unwrap(); // infallible
let Some(session_id) = jar.get(COOKIE_SID).map(|cookie| cookie.value().to_owned()) else {
return Ok(None);
};
let client = &state.db_client;
let query = client
.query_opt(
"SELECT user_id FROM sessions WHERE session_id = $1",
&[&session_id],
)
.await;
let id = match query {
Ok(Some(row)) => row.get::<usize, i32>(0),
Ok(None) => return Ok(None),
Err(err) => {
error!(%err, "querying user session");
return Ok(None);
}
};
let id = NonZeroI32::new(id).unwrap(); // postgres id guaranteed not to be zero
Ok(Some(Self { id }))
}
}
#[derive(Deserialize)]
struct Paste {
paste: String,
}
fn paste_form() -> Html<String> {
Html(
r#"
<form method="post">
<textarea name="paste" style="width:100%;height:80%"></textarea>
<input type="submit" value="Paste" style="margin-top:10px">
</form>"#
.to_string(),
)
}
fn authorize_link(paste_id: i32) -> String {
format!("<a href=\"/authorize?paste_id={paste_id}\">Authorize</a>")
}
async fn index(user: Option<UserId>) -> Html<String> {
if user.is_some() {
return paste_form();
}
Html(authorize_link(0))
}
async fn paste(
state: AxumState,
user: Option<UserId>,
Form(Paste { paste }): Form<Paste>,
) -> Response {
let user_id = match user {
None => return StatusCode::FORBIDDEN.into_response(),
Some(user) => user.id,
};
if paste.is_empty() {
return paste_form().into_response();
}
let query = state
.db_client
.query_one(
"INSERT INTO pastes (user_id, paste) VALUES ($1, $2) RETURNING id",
&[&user_id.get(), &paste],
)
.await;
let id = match query {
Ok(row) => row.get::<usize, i32>(0),
Err(err) => {
error!(%err, "inserting paste");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
Redirect::to(&format!("/{id}")).into_response()
}
async fn view_paste(state: AxumState, user: Option<UserId>, Path(paste_id): Path<i32>) -> Response {
let user_id = match user {
None => return Html(authorize_link(paste_id)).into_response(),
Some(user) => user.id,
};
let query = state
.db_client
.query_opt("SELECT paste FROM pastes WHERE id = $1", &[&paste_id])
.await;
let row = match query {
Ok(None) => return StatusCode::NOT_FOUND.into_response(),
Ok(Some(row)) => row,
Err(err) => {
error!(%err, %paste_id, %user_id, "querying paste");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
row.get::<usize, String>(0).into_response()
}
#[derive(Deserialize)]
struct AuthRequest {
code: String,
}
#[derive(Deserialize)]
struct AuthResponse {
access_token: String,
id_token: String,
expires_in: u64,
}
#[derive(Deserialize)]
struct UserInfo {
hd: String,
sub: String,
}
fn decode_id_token(token: String) -> Option<UserInfo> {
let payload = token.split(".").skip(1).take(1).collect::<Vec<&str>>();
let decoded = base64::decode_config(payload.get(0)?, base64::STANDARD_NO_PAD).ok()?;
serde_json::from_slice::<UserInfo>(&decoded).ok()
}
#[derive(Deserialize)]
struct AuthorizeQuery {
paste_id: i32,
}
fn generate_csrf_token(num_bytes: u32) -> String {
use rand::{Rng, thread_rng};
let random_bytes: Vec<u8> = (0..num_bytes).map(|_| thread_rng().r#gen::<u8>()).collect();
base64::encode_config(&random_bytes, base64::URL_SAFE_NO_PAD)
}
async fn authorize(
state: AxumState,
jar: PrivateCookieJar,
Query(AuthorizeQuery { paste_id }): Query<AuthorizeQuery>,
) -> (PrivateCookieJar, Redirect) {
let csrf_token = generate_csrf_token(16);
let client_id = &state.oauth_client_id;
let redirect_uri = oauth_redirect_url();
let auth_url = format!(
"{OAUTH_BASE_URL}?response_type=code\
&client_id={client_id}\
&state={csrf_token}\
&redirect_uri={redirect_uri}\
&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email"
);
let redirect_cookie = Cookie::build((COOKIE_REDIRECT, paste_id.to_string()))
.path("/")
//.TODO secure(true) not true for localhost
//.domain(COOKIE_DOMAIN)
.secure(false)
.same_site(axum_extra::extract::cookie::SameSite::Lax)
.http_only(true)
.build();
let csrf_cookie = Cookie::build((COOKIE_CSRF, csrf_token))
.path("/")
//.TODO secure(true) not true for localhost
//.domain(COOKIE_DOMAIN)
.secure(false)
.same_site(axum_extra::extract::cookie::SameSite::Lax)
.http_only(true)
.build();
let jar = jar.add(redirect_cookie).add(csrf_cookie);
let url = Into::<String>::into(auth_url);
(jar, Redirect::to(&url))
}
async fn authorized(
state: AxumState,
jar: PrivateCookieJar,
Query(auth_request): Query<AuthRequest>,
) -> Result<(PrivateCookieJar, Redirect), Response> {
let params = [
("grant_type", "authorization_code"),
("redirect_uri", &oauth_redirect_url()),
("code", &auth_request.code),
("client_id", &state.oauth_client_id),
("client_secret", &state.oauth_client_secret),
];
let auth_response = reqwest::Client::new()
.post(OAUTH_TOKEN_URL)
.form(&params)
.send()
.await
.map_err(|err| {
error!(%err, "exchanging oauth code for token");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
})?
.json::<AuthResponse>()
.await
.map_err(|err| {
error!(%err, "deserializing access token response");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
})?;
let Some(UserInfo { hd, sub }) = decode_id_token(auth_response.id_token) else {
error!("Failed to decode response id token");
return Err(StatusCode::UNAUTHORIZED.into_response());
};
if hd != ALLOWED_OAUTH_DOMAIN {
error!(hd, "Domain doesn't match {ALLOWED_OAUTH_DOMAIN}");
return Err(StatusCode::UNAUTHORIZED.into_response());
}
let token_duration = Duration::try_seconds(auth_response.expires_in as i64).unwrap();
let expires_at = Utc.from_utc_datetime(&(Local::now().naive_local() + token_duration));
let cookie_max_age = time::Duration::new(token_duration.num_seconds(), 0);
let session_cookie = Cookie::build((COOKIE_SID, auth_response.access_token.clone()))
.path("/")
//.TODO secure(true) not true for localhost
//.domain(COOKIE_DOMAIN)
.secure(false)
.same_site(axum_extra::extract::cookie::SameSite::Lax)
.http_only(true)
.max_age(cookie_max_age)
.build();
state
.db_client
.query(
"WITH user_insert AS (\
INSERT INTO users (sub) VALUES ($1) \
ON CONFLICT (sub) DO UPDATE SET sub = excluded.sub RETURNING id)\
INSERT INTO sessions (user_id, session_id, expires_at) \
SELECT id, $2, $3 FROM user_insert \
ON CONFLICT (user_id) DO UPDATE SET \
session_id = excluded.session_id, \
expires_at = excluded.expires_at",
&[&sub, &auth_response.access_token, &expires_at],
)
.await
.map_err(|err| {
error!(%err, %sub, "updating session");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
})?;
let csrf_cookie = jar.get(COOKIE_CSRF).unwrap(); // set in authorize()
let jar = jar.remove(csrf_cookie).add(session_cookie);
match jar.get(COOKIE_REDIRECT) {
Some(redirect_cookie) => {
let mut value = redirect_cookie.value_trimmed();
if value == "0" {
value = "";
}
let redirect_url = format!("/{value}");
Ok((jar.remove(redirect_cookie), Redirect::to(&redirect_url)))
}
None => Ok((jar, Redirect::to("/"))),
}
}

View File

@@ -25,19 +25,15 @@ pub(super) async fn authenticate(
}
AuthSecret::Scram(secret) => {
debug!("auth endpoint chooses SCRAM");
let scram = auth::Scram(&secret, ctx);
let auth_outcome = tokio::time::timeout(config.scram_protocol_timeout, async {
AuthFlow::new(client, scram)
.authenticate()
.await
.inspect_err(|error| {
warn!(?error, "error processing scram messages");
})
})
let auth_outcome = tokio::time::timeout(
config.scram_protocol_timeout,
AuthFlow::new(client, auth::Scram(&secret, ctx)).authenticate(),
)
.await
.inspect_err(|_| warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs()))
.map_err(auth::AuthError::user_timeout)??;
.map_err(auth::AuthError::user_timeout)?
.inspect_err(|error| warn!(?error, "error processing scram messages"))?;
let client_key = match auth_outcome {
sasl::Outcome::Success(key) => key,
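
The rewritten hunk above unwraps two layers of failure in sequence: the outer `Result` produced by the timeout, then the inner one produced by the authentication flow. A minimal stdlib-only sketch of that pattern follows; `TimedOut`, `AuthError`, and `finish` are hypothetical stand-ins for the tokio and proxy types, not the real API.

```rust
/// Hypothetical stand-in for the error a timeout wrapper returns.
#[derive(Debug, PartialEq)]
struct TimedOut;

/// Hypothetical stand-in for the proxy's authentication error type.
#[derive(Debug, PartialEq)]
enum AuthError {
    UserTimeout,
    BadPassword,
}

/// Unwraps a `Result<Result<T, AuthError>, TimedOut>` layer by layer.
fn finish(outcome: Result<Result<u32, AuthError>, TimedOut>) -> Result<u32, AuthError> {
    let key = outcome
        // Outer layer: the deadline elapsed before the flow completed.
        .map_err(|_elapsed| AuthError::UserTimeout)?
        // Inner layer: the flow completed but reported an error; log it once.
        .inspect_err(|error| eprintln!("error processing messages: {error:?}"))?;
    Ok(key)
}
```

Handling each layer with its own `?` keeps the timeout mapping and the inner-error logging separate, so neither message has to describe both cases.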

View File

@@ -159,7 +159,7 @@ pub async fn task_main(
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
config: &'static ProxyConfig,
backend: &'static ConsoleRedirectBackend,
ctx: &RequestContext,

View File

@@ -7,7 +7,9 @@ use std::time::Duration;
use ::http::HeaderName;
use ::http::header::AUTHORIZATION;
use bytes::Bytes;
use futures::TryFutureExt;
use hyper::StatusCode;
use postgres_client::config::SslMode;
use tokio::time::Instant;
use tracing::{Instrument, debug, info, info_span, warn};
@@ -72,28 +74,34 @@ impl NeonControlPlaneClient {
role: &RoleName,
) -> Result<AuthInfo, GetAuthInfoError> {
async {
let request = self
.endpoint
.get_path("get_endpoint_access_control")
.header(X_REQUEST_ID, ctx.session_id().to_string())
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
.query(&[
("application_name", ctx.console_application_name().as_str()),
("endpointish", endpoint.as_str()),
("role", role.as_str()),
])
.build()?;
debug!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
let response = {
let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane);
self.endpoint.execute(request).await?
};
info!(duration = ?start.elapsed(), "received http response");
let request = self
.endpoint
.get_path("get_endpoint_access_control")
.header(X_REQUEST_ID, ctx.session_id().to_string())
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
.query(&[
("application_name", ctx.console_application_name().as_str()),
("endpointish", endpoint.as_str()),
("role", role.as_str()),
])
.build()?;
let body = match parse_body::<GetEndpointAccessControl>(response).await {
debug!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane);
let response = self.endpoint.execute(request).await?;
info!(duration = ?start.elapsed(), "received http response");
response
};
let body = match parse_body::<GetEndpointAccessControl>(
response.status(),
response.bytes().await?,
) {
Ok(body) => body,
// Error 404 is special: it's ok not to have a secret.
// TODO(anna): retry
@@ -184,7 +192,10 @@ impl NeonControlPlaneClient {
drop(pause);
info!(duration = ?start.elapsed(), "received http response");
let body = parse_body::<EndpointJwksResponse>(response).await?;
let body = parse_body::<EndpointJwksResponse>(
response.status(),
response.bytes().await.map_err(ControlPlaneError::from)?,
)?;
let rules = body
.jwks
@@ -236,7 +247,7 @@ impl NeonControlPlaneClient {
let response = self.endpoint.execute(request).await?;
drop(pause);
info!(duration = ?start.elapsed(), "received http response");
let body = parse_body::<WakeCompute>(response).await?;
let body = parse_body::<WakeCompute>(response.status(), response.bytes().await?)?;
// Unfortunately, ownership won't let us use `Option::ok_or` here.
let (host, port) = match parse_host_port(&body.address) {
@@ -487,33 +498,33 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
}
/// Parse http response body, taking status code into account.
async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
response: http::Response,
fn parse_body<T: for<'a> serde::Deserialize<'a>>(
status: StatusCode,
body: Bytes,
) -> Result<T, ControlPlaneError> {
let status = response.status();
if status.is_success() {
// We shouldn't log raw body because it may contain secrets.
info!("request succeeded, processing the body");
return Ok(response.json().await?);
return Ok(serde_json::from_slice(&body).map_err(std::io::Error::other)?);
}
let s = response.bytes().await?;
// Log plaintext so we can detect cases not covered by the error struct.
info!("response_error plaintext: {:?}", s);
info!("response_error plaintext: {:?}", body);
// Don't throw an error here because it's not as important
// as the fact that the request itself has failed.
let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| {
let mut body = serde_json::from_slice(&body).unwrap_or_else(|e| {
warn!("failed to parse error body: {e}");
ControlPlaneErrorMessage {
Box::new(ControlPlaneErrorMessage {
error: "reason unclear (malformed error message)".into(),
http_status_code: status,
status: None,
}
})
});
body.http_status_code = status;
warn!("console responded with an error ({status}): {body:?}");
Err(ControlPlaneError::Message(Box::new(body)))
Err(ControlPlaneError::Message(body))
}
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
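
The change to `parse_body` above makes it a plain synchronous function of the status code and the already-received bytes. A rough stdlib-only sketch of that shape is below. `ApiError`, the `parse` callback, and `parse_u32` are hypothetical illustrations; the real code uses `serde_json` and `ControlPlaneError`.

```rust
/// Hypothetical error type standing in for `ControlPlaneError`.
#[derive(Debug, PartialEq)]
enum ApiError {
    Malformed(String),
    Message { status: u16, error: String },
}

/// Synchronous body parsing: the caller has already awaited the status and bytes.
fn parse_body<T>(
    status: u16,
    body: &[u8],
    parse: impl Fn(&[u8]) -> Result<T, String>,
) -> Result<T, ApiError> {
    if (200..300).contains(&status) {
        return parse(body).map_err(ApiError::Malformed);
    }
    // Fall back to a generic message when the error body is not valid UTF-8.
    let error = std::str::from_utf8(body)
        .map(str::to_owned)
        .unwrap_or_else(|_| "reason unclear (malformed error message)".to_owned());
    Err(ApiError::Message { status, error })
}

/// Example parser used only for illustration: reads a decimal integer.
fn parse_u32(body: &[u8]) -> Result<u32, String> {
    let text = std::str::from_utf8(body).map_err(|e| e.to_string())?;
    text.trim().parse::<u32>().map_err(|e| e.to_string())
}
```

Because nothing here awaits, the function can be exercised with literal byte slices, for example `parse_body(200, b"42", parse_u32)`.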

View File

@@ -4,9 +4,10 @@
pub mod health_server;
use std::time::Duration;
use std::time::{Duration, Instant};
use bytes::Bytes;
use futures::FutureExt;
use http::Method;
use http_body_util::BodyExt;
use hyper::body::Body;
@@ -109,15 +110,31 @@ impl Endpoint {
}
/// Execute a [request](reqwest::Request).
pub(crate) async fn execute(&self, request: Request) -> Result<Response, Error> {
let _timer = Metrics::get()
pub(crate) fn execute(
&self,
request: Request,
) -> impl Future<Output = Result<Response, Error>> {
let metric = Metrics::get()
.proxy
.console_request_latency
.start_timer(ConsoleRequest {
.with_labels(ConsoleRequest {
request: request.url().path(),
});
self.client.execute(request).await
let req = self.client.execute(request).boxed();
async move {
let start = Instant::now();
scopeguard::defer!({
Metrics::get()
.proxy
.console_request_latency
.get_metric(metric)
.observe_duration_since(start);
});
req.await
}
}
}
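
The new `execute` above records latency from a deferred block, so the observation runs whenever the async block is exited. The same idea can be sketched with a plain `Drop` guard using only the standard library. `LatencyGuard`, `timed`, and the `RefCell<Vec<Duration>>` sink are assumptions made for illustration; they are not the proxy's metrics API or the `scopeguard` crate.

```rust
use std::cell::RefCell;
use std::time::{Duration, Instant};

/// Records the elapsed time into `sink` when dropped.
struct LatencyGuard<'a> {
    start: Instant,
    sink: &'a RefCell<Vec<Duration>>,
}

impl Drop for LatencyGuard<'_> {
    fn drop(&mut self) {
        self.sink.borrow_mut().push(self.start.elapsed());
    }
}

/// Runs `work` and records its latency, including when `work` returns an error.
fn timed<T>(sink: &RefCell<Vec<Duration>>, work: impl FnOnce() -> T) -> T {
    let _guard = LatencyGuard { start: Instant::now(), sink };
    work()
}
```

The guard is dropped after `work()` has produced its value, so one sample is pushed per call regardless of how the closure finishes.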

View File

@@ -186,7 +186,7 @@ where
pub async fn read_message<'a, S>(
stream: &mut S,
buf: &'a mut Vec<u8>,
max: usize,
max: u32,
) -> io::Result<(u8, &'a mut [u8])>
where
S: AsyncRead + Unpin,
@@ -206,7 +206,7 @@ where
let header = read!(stream => Header);
// as described above, the length must be at least 4.
let Some(len) = (header.len.get() as usize).checked_sub(4) else {
let Some(len) = header.len.get().checked_sub(4) else {
return Err(io::Error::other(format!(
"invalid startup message length {}, must be at least 4.",
header.len,
@@ -222,7 +222,7 @@ where
}
// read in our entire message.
buf.resize(len, 0);
buf.resize(len as usize, 0);
stream.read_exact(buf).await?;
Ok((header.tag, buf))
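
The hunks above keep the message length as a `u32` until the buffer is resized. A small stdlib-only sketch of the same length handling on an in-memory slice is shown here. `split_message` is a hypothetical helper, and the position of the `max` check is an assumption, since that part of `read_message` is not visible in the diff.

```rust
/// Splits one tagged, length-prefixed message off the front of `input`.
/// The 4-byte big-endian length counts itself but not the tag byte.
fn split_message(input: &[u8], max: u32) -> Result<(u8, &[u8]), String> {
    let (&tag, rest) = input.split_first().ok_or("missing tag")?;
    if rest.len() < 4 {
        return Err("missing length".to_owned());
    }
    let raw_len = u32::from_be_bytes([rest[0], rest[1], rest[2], rest[3]]);
    // The length includes its own 4 bytes, so anything below 4 is invalid.
    let len = raw_len
        .checked_sub(4)
        .ok_or("invalid message length, must be at least 4")?;
    if len > max {
        return Err(format!("message too long: {len} > {max}"));
    }
    // Convert to usize only at the point where the slice is indexed.
    let body = rest[4..].get(..len as usize).ok_or("truncated body")?;
    Ok((tag, body))
}
```

Doing the subtraction and the limit comparison in `u32` means the only cast left is the final `len as usize`, which happens after the limit check.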

View File

@@ -1,3 +1,4 @@
use futures::{FutureExt, TryFutureExt};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info, warn};
@@ -57,7 +58,7 @@ pub(crate) enum HandshakeData<S> {
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
/// we also take extra care to propagate only select handshake errors to the client.
#[tracing::instrument(skip_all)]
pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin + Send>(
ctx: &RequestContext,
stream: S,
mut tls: Option<&TlsConfig>,
@@ -108,7 +109,9 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
}
}
});
})
.map_ok(Box::new)
.boxed();
res?;
@@ -146,7 +149,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
tls.cert_resolver.resolve(conn_info.server_name());
let tls = Stream::Tls {
tls: Box::new(tls_stream),
tls: tls_stream,
tls_server_end_point,
};
(stream, msg) = PqStream::parse_startup(tls).await?;

View File

@@ -270,7 +270,7 @@ impl ReportableError for ClientRequestError {
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
ctx: &RequestContext,

View File

@@ -1,3 +1,4 @@
use futures::FutureExt;
use smol_str::SmolStr;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::debug;
@@ -89,6 +90,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
.compute
.cancel_closure
.try_cancel_query(compute_config)
.boxed()
.await
{
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");

View File

@@ -30,52 +30,53 @@ where
F: FnOnce(&str) -> super::Result<M>,
M: Mechanism,
{
let sasl = {
let (mut mechanism, mut input) = {
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// Initial client message contains the chosen auth method's name.
let msg = stream.read_password_message().await?;
super::FirstMessage::parse(msg).ok_or(super::Error::BadClientMessage("bad sasl message"))?
let sasl = super::FirstMessage::parse(msg)
.ok_or(super::Error::BadClientMessage("bad sasl message"))?;
(mechanism(sasl.method)?, sasl.message)
};
let mut mechanism = mechanism(sasl.method)?;
let mut input = sasl.message;
loop {
let step = mechanism
.exchange(input)
.inspect_err(|error| tracing::info!(?error, "error during SASL exchange"))?;
match step {
Step::Continue(moved_mechanism, reply) => {
match mechanism.exchange(input) {
Ok(Step::Continue(moved_mechanism, reply)) => {
mechanism = moved_mechanism;
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// write reply
let sasl_msg = BeAuthenticationSaslMessage::Continue(reply.as_bytes());
stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
// get next input
stream.flush().await?;
let msg = stream.read_password_message().await?;
input = std::str::from_utf8(msg)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
drop(reply);
}
Step::Success(result, reply) => {
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
Ok(Step::Success(result, reply)) => {
// write reply
let sasl_msg = BeAuthenticationSaslMessage::Final(reply.as_bytes());
stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
stream.write_message(BeMessage::AuthenticationOk);
// exit with success
break Ok(Outcome::Success(result));
}
// exit with failure
Step::Failure(reason) => break Ok(Outcome::Failure(reason)),
Ok(Step::Failure(reason)) => break Ok(Outcome::Failure(reason)),
Err(error) => {
tracing::info!(?error, "error during SASL exchange");
return Err(error);
}
}
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// get next input
stream.flush().await?;
let msg = stream.read_password_message().await?;
input = std::str::from_utf8(msg)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
}
}

View File

@@ -72,7 +72,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> PqStream<S> {
impl<S: AsyncRead + Unpin> PqStream<S> {
/// Read a raw postgres packet, which will respect the max length requested.
/// This is not cancel safe.
async fn read_raw_expect(&mut self, tag: u8, max: usize) -> io::Result<&mut [u8]> {
async fn read_raw_expect(&mut self, tag: u8, max: u32) -> io::Result<&mut [u8]> {
let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
if actual_tag != tag {
return Err(io::Error::other(format!(
@@ -89,7 +89,7 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
// passwords are usually pretty short
// and SASL SCRAM messages are no longer than 256 bytes in my testing
// (a few hashes and random bytes, encoded into base64).
const MAX_PASSWORD_LENGTH: usize = 512;
const MAX_PASSWORD_LENGTH: u32 = 512;
self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
.await
}

View File

@@ -31,7 +31,9 @@ mod private {
type Output = io::Result<RustlsStream<S>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream)
Pin::new(&mut self.inner)
.poll(cx)
.map_ok(|s| RustlsStream(Box::new(s)))
}
}
@@ -57,7 +59,7 @@ mod private {
}
}
pub struct RustlsStream<S>(TlsStream<S>);
pub struct RustlsStream<S>(Box<TlsStream<S>>);
impl<S> postgres_client::tls::TlsStream for RustlsStream<S>
where

View File

@@ -1,26 +0,0 @@
[package]
name = "shortener"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
anyhow.workspace = true
axum-extra = { workspace = true, features = ["cookie", "cookie-private"] }
axum.workspace = true
base64.workspace = true
chrono.workspace = true
cookie = "0.18.1"
nanoid = { version = "0.4.0", default-features = false }
rand.workspace = true
reqwest.workspace = true
rustls-native-certs.workspace = true
rustls.workspace = true
serde.workspace = true
serde_json.workspace = true
time = { version = "0.3.36", default-features = false }
tokio-postgres-rustls.workspace = true
tokio-postgres.workspace = true
tokio.workspace = true
tracing-subscriber.workspace = true
tracing.workspace = true
workspace_hack.workspace = true

View File

@@ -1,19 +0,0 @@
CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
sub VARCHAR(100) NOT NULL UNIQUE
);
CREATE TABLE IF NOT EXISTS sessions (
id SERIAL PRIMARY KEY,
user_id INT NOT NULL UNIQUE REFERENCES users(id),
session_id VARCHAR NOT NULL,
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
);
CREATE TABLE IF NOT EXISTS urls (
id SERIAL PRIMARY KEY,
user_id INT NOT NULL REFERENCES users(id),
short_url VARCHAR(6) NOT NULL UNIQUE,
long_url VARCHAR NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
)

View File

@@ -1,222 +0,0 @@
//! Library to gate infrastructure behind Google OAuth for a domain.
//!
//! Why not oauth-rs? Its `.exchange_code()` fails with "request failed". Also, we can't get
//! the id token from it, and I don't want to pull in a whole OpenID library just for that.
//! The id token saves us a request to the OpenID endpoint and one OAuth scope we don't use.
use anyhow::{Context, Result, bail};
use axum::extract::{FromRef, FromRequestParts, Query, State as AxumState};
use axum::response::Redirect;
use axum_extra::extract::PrivateCookieJar;
use axum_extra::extract::cookie::{Cookie, Key};
use chrono::{Duration, Local, TimeZone, Utc};
use cookie::CookieBuilder;
use core::num::NonZeroI32;
use reqwest::StatusCode;
use serde::Deserialize;
use std::sync::Arc;
use tokio_postgres::Socket;
const OAUTH_BASE_URL: &str = "https://accounts.google.com/o/oauth2/v2/auth";
const OAUTH_TOKEN_URL: &str = "https://oauth2.googleapis.com/token";
const COOKIE_SID: &str = "sid";
const COOKIE_CSRF: &str = "csrf";
pub struct Config {
pub oauth_client_id: String,
pub oauth_client_secret: String,
pub oauth_redirect_url: String,
pub oauth_allowed_domain: String,
pub cookie_settings: fn(CookieBuilder) -> CookieBuilder,
}
pub struct InnerState {
config: Config,
cookie_jar_key: Key,
pub db_client: tokio_postgres::Client,
}
#[derive(Clone)]
pub struct State(Arc<InnerState>);
type DbConn = tokio_postgres::Connection<Socket, tokio_postgres_rustls::RustlsStream<Socket>>;
impl State {
pub async fn new(config: Config, db_connstr: &str) -> Result<(Self, DbConn)> {
let mut roots = rustls::RootCertStore::empty();
for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs")
{
roots.add(cert).unwrap();
}
let tls_config = rustls::ClientConfig::builder()
.with_root_certificates(roots)
.with_no_client_auth();
let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);
let (db_client, db_conn) = tokio_postgres::connect(&db_connstr, tls).await?;
let inner = InnerState {
config,
cookie_jar_key: Key::generate(),
db_client,
};
Ok((Self { 0: Arc::new(inner) }, db_conn))
}
}
impl std::ops::Deref for State {
type Target = InnerState;
fn deref(&self) -> &Self::Target {
&*self.0
}
}
impl FromRef<State> for Key {
fn from_ref(state: &State) -> Self {
state.cookie_jar_key.clone()
}
}
#[derive(Deserialize)]
pub struct UserId {
pub id: NonZeroI32,
}
#[derive(Deserialize)]
pub struct AuthRequest {
code: String,
}
#[derive(Deserialize)]
struct AuthResponse {
access_token: String,
id_token: String,
expires_in: u64,
}
#[derive(Deserialize)]
struct UserInfo {
hd: String,
sub: String,
}
impl axum::extract::OptionalFromRequestParts<State> for UserId {
type Rejection = StatusCode;
async fn from_request_parts(
parts: &mut axum::http::request::Parts,
state: &State,
) -> Result<Option<Self>, Self::Rejection> {
let jar: PrivateCookieJar = PrivateCookieJar::from_request_parts(parts, state)
.await
.unwrap(); // infallible
let Some(session_id) = jar.get(COOKIE_SID).map(|cookie| cookie.value().to_owned()) else {
return Ok(None);
};
let client = &state.db_client;
let query = client
.query_opt(
"SELECT user_id FROM sessions WHERE session_id = $1",
&[&session_id],
)
.await;
let id = match query {
Ok(Some(row)) => row.get::<usize, i32>(0),
Ok(None) => return Ok(None),
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
};
let id = NonZeroI32::new(id).unwrap(); // postgres id guaranteed not to be zero
Ok(Some(Self { id }))
}
}
fn decode_id_token(token: String) -> Option<UserInfo> {
let payload = token.split(".").skip(1).take(1).collect::<Vec<&str>>();
let decoded = base64::decode_config(payload.get(0)?, base64::STANDARD_NO_PAD).ok()?;
serde_json::from_slice::<UserInfo>(&decoded).ok()
}
fn generate_csrf_token(num_bytes: u32) -> String {
use rand::{Rng, thread_rng};
let random_bytes: Vec<u8> = (0..num_bytes).map(|_| thread_rng().r#gen::<u8>()).collect();
base64::encode_config(&random_bytes, base64::URL_SAFE_NO_PAD)
}
pub async fn authorize(
state: AxumState<State>,
jar: PrivateCookieJar,
) -> (PrivateCookieJar, Redirect) {
let csrf_token = generate_csrf_token(16);
let client_id = &state.config.oauth_client_id;
let redirect_uri = &state.config.oauth_redirect_url;
let auth_url = format!(
"{OAUTH_BASE_URL}?response_type=code\
&client_id={client_id}\
&state={csrf_token}\
&redirect_uri={redirect_uri}\
&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email"
);
let csrf_cookie =
(state.config.cookie_settings)(Cookie::build((COOKIE_CSRF, csrf_token))).build();
let url = Into::<String>::into(auth_url);
(jar.add(csrf_cookie), Redirect::to(&url))
}
pub async fn authorized(
state: AxumState<State>,
jar: PrivateCookieJar,
Query(auth_request): Query<AuthRequest>,
) -> Result<PrivateCookieJar> {
let params = [
("grant_type", "authorization_code"),
("redirect_uri", &state.config.oauth_redirect_url),
("code", &auth_request.code),
("client_id", &state.config.oauth_client_id),
("client_secret", &state.config.oauth_client_secret),
];
let auth_response = reqwest::Client::new()
.post(OAUTH_TOKEN_URL)
.form(&params)
.send()
.await
.context("exchanging oauth code for token")?
.json::<AuthResponse>()
.await
.context("deserializing access_token response")?;
let Some(UserInfo { hd, sub }) = decode_id_token(auth_response.id_token) else {
bail!("failed to decode id token")
};
let allowed_domain = &state.config.oauth_allowed_domain;
if hd != *allowed_domain {
bail!("{hd} doesn't match {allowed_domain}")
}
let token_duration = Duration::try_seconds(auth_response.expires_in as i64).unwrap();
let expires_at = Utc.from_utc_datetime(&(Local::now().naive_local() + token_duration));
let cookie_max_age = time::Duration::new(token_duration.num_seconds(), 0);
let session_cookie = (state.config.cookie_settings)(Cookie::build((
COOKIE_SID,
auth_response.access_token.clone(),
)))
.max_age(cookie_max_age)
.build();
state
.db_client
.query(
"WITH user_insert AS (\
INSERT INTO users (sub) VALUES ($1) \
ON CONFLICT (sub) DO UPDATE SET sub = excluded.sub RETURNING id)\
INSERT INTO sessions (user_id, session_id, expires_at) \
SELECT id, $2, $3 FROM user_insert \
ON CONFLICT (user_id) DO UPDATE SET \
session_id = excluded.session_id, \
expires_at = excluded.expires_at",
&[&sub, &auth_response.access_token, &expires_at],
)
.await
.with_context(|| format!("updating session for {sub}"))?;
let csrf_cookie = jar.get(COOKIE_CSRF).unwrap(); // set in authorize()
Ok(jar.remove(csrf_cookie).add(session_cookie))
}

View File

@@ -1,240 +0,0 @@
//! Shortener is a service to gate access to internal infrastructure
//! URLs behind team authorisation to expose less private information.
pub mod google_oauth_gate;
use crate::google_oauth_gate::{AuthRequest, State, UserId};
use anyhow::Result;
use axum::Form;
use axum::extract::State as AxumState;
use axum::extract::{Path, Query};
use axum::http::StatusCode;
use axum::response::{Html, IntoResponse};
use axum::response::{Redirect, Response};
use axum::routing::get;
use axum_extra::extract::PrivateCookieJar;
use axum_extra::extract::cookie::Cookie;
use cookie::CookieBuilder;
use google_oauth_gate::Config;
use serde::Deserialize;
use std::env;
use tracing::{error, info};
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
const SOCKET: &str = "127.0.0.1:12344";
const HOST: &str = "http://127.0.0.1:12344";
const COOKIE_REDIRECT: &str = "redirect";
const ALLOWED_OAUTH_DOMAIN: &str = "neon.tech";
const AUTHORIZED_ROUTE: &str = "/authorized";
const SHORT_URL_LEN: usize = 6;
fn cookie_settings(b: CookieBuilder) -> CookieBuilder {
if HOST.contains("127.0.0.1") {
b.path("/")
.secure(false)
.same_site(axum_extra::extract::cookie::SameSite::Lax)
.http_only(true)
} else {
b.path("/")
.domain(ALLOWED_OAUTH_DOMAIN)
.secure(true)
.http_only(false)
}
}
fn oauth_redirect_url() -> String {
format!("{HOST}{AUTHORIZED_ROUTE}")
}
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| format!("{}=info", env!("CARGO_CRATE_NAME")).into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
let oauth_client_id = env::var("OAUTH_CLIENT_ID").expect("Missing OAUTH_CLIENT_ID");
let oauth_client_secret = env::var("OAUTH_CLIENT_SECRET").expect("Missing OAUTH_CLIENT_SECRET");
let db_connstr = env::var("DB_CONNSTR").expect("Missing DB_CONNSTR");
let config = Config {
oauth_client_id,
oauth_client_secret,
oauth_redirect_url: oauth_redirect_url(),
oauth_allowed_domain: ALLOWED_OAUTH_DOMAIN.to_string(),
cookie_settings,
};
let (state, db_conn) = State::new(config, &db_connstr).await?;
tokio::spawn(async move {
if let Err(err) = db_conn.await {
error!(%err, "connecting to database");
std::process::exit(1);
}
});
let router = axum::Router::new()
.route("/", get(index).post(shorten))
.route("/authorize", get(authorize))
.route(AUTHORIZED_ROUTE, get(authorized))
.route("/{short_url}", get(redirect))
.with_state(state);
let listener = tokio::net::TcpListener::bind(SOCKET)
.await
.expect("failed to bind TcpListener");
info!("listening on {SOCKET}");
axum::serve(listener, router).await.unwrap();
Ok(())
}
#[derive(Deserialize)]
struct LongUrl {
url: String,
}
fn shorten_form(short_url: &str) -> Html<String> {
let mut form = r#"
<div style="margin:auto;width:50%;padding:10px">
<form method="post">
<input type="text" name="url" style="width:100%">
<input type="submit" value="Shorten" style="margin-top:10px">
</form>"#
.to_string();
if !short_url.is_empty() {
form += &format!(
r#"
<p>
<a id="short" href="{0}">{0}</a>
<button onclick="copy()">Copy</button>
</p>
<script>
function copy() {{
navigator.clipboard.writeText(document.querySelector("\#short").textContent);
}}
</script>"#,
short_url
);
}
form += "</div>";
Html(form)
}
fn authorize_link(short_url: &str) -> Html<String> {
Html(format!(
"<a href=\"/authorize?short_url={short_url}\">Authorize</a>"
))
}
async fn index(user: Option<UserId>) -> Html<String> {
if user.is_some() {
return shorten_form("");
}
authorize_link("")
}
async fn shorten(
state: AxumState<State>,
user: Option<UserId>,
Form(LongUrl { url }): Form<LongUrl>,
) -> Response {
let user_id = match user {
None => return StatusCode::FORBIDDEN.into_response(),
Some(user) => user.id.get(),
};
if url.is_empty() {
return shorten_form("").into_response();
}
let mut short_url = "".to_string();
for i in 0..20 {
short_url = nanoid::nanoid!(SHORT_URL_LEN);
let query = state
.db_client
.query_opt(
"INSERT INTO urls (user_id, short_url, long_url) VALUES ($1, $2, $3) \
ON CONFLICT (short_url) DO NOTHING \
RETURNING short_url",
&[&user_id, &short_url, &url],
)
.await;
match query {
Ok(Some(_)) => break,
Ok(None) => {
info!(short_url, "url clash, retry {i}");
continue;
}
Err(err) => {
error!(%err, "inserting shortened url");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
}
shorten_form(&format!("{HOST}/{short_url}")).into_response()
}
async fn redirect(
state: AxumState<State>,
user: Option<UserId>,
Path(short_url): Path<String>,
) -> Response {
let user_id = match user {
None => return authorize_link(&short_url).into_response(),
Some(user) => user.id,
};
let query = state
.db_client
.query_opt(
"SELECT long_url FROM urls WHERE short_url = $1",
&[&short_url],
)
.await;
match query {
Ok(Some(row)) => Redirect::permanent(row.get(0)).into_response(),
Ok(None) => StatusCode::NOT_FOUND.into_response(),
Err(err) => {
error!(%err, %short_url, %user_id, "querying long url");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
}
#[derive(Deserialize)]
struct AuthorizeQuery {
short_url: String,
}
async fn authorize(
state: AxumState<State>,
jar: PrivateCookieJar,
Query(AuthorizeQuery { short_url }): Query<AuthorizeQuery>,
) -> (PrivateCookieJar, Redirect) {
let (jar, auth_redirect) = google_oauth_gate::authorize(state, jar).await;
let redirect_cookie = Cookie::build((COOKIE_REDIRECT, short_url))
.path("/")
//.TODO secure(true) not true for localhost
//.domain(COOKIE_DOMAIN)
.secure(false)
.same_site(axum_extra::extract::cookie::SameSite::Lax)
.http_only(true)
.build();
(jar.add(redirect_cookie), auth_redirect)
}
async fn authorized(
state: AxumState<State>,
jar: PrivateCookieJar,
query: Query<AuthRequest>,
) -> Result<(PrivateCookieJar, Redirect), Response> {
use google_oauth_gate::authorized;
let jar = authorized(state, jar, query).await.map_err(|err| {
error!(%err, "authorizing");
StatusCode::UNAUTHORIZED.into_response()
})?;
let Some(redirect_cookie) = jar.get(COOKIE_REDIRECT) else {
return Ok((jar, Redirect::to("/")));
};
let redirect_url = Redirect::to(&format!("/{}", redirect_cookie.value_trimmed()));
Ok((jar.remove(redirect_cookie), redirect_url))
}
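
The `shorten` handler above retries the insert up to 20 times when a generated short code clashes with an existing row. As written, if every attempt clashed, the loop would end and the handler would still return the last generated code. Below is a minimal sketch of the same bounded-retry idea with that case reported explicitly. It is not the author's implementation: `Store`, `insert_if_absent`, and `allocate_code` are hypothetical names, and an in-memory `HashSet` stands in for the `urls` table and its `ON CONFLICT (short_url) DO NOTHING` insert.

```rust
use std::collections::HashSet;

/// Hypothetical stand-in for the `urls` table: the set of short codes already taken.
pub struct Store {
    taken: HashSet<String>,
}

impl Store {
    pub fn new() -> Self {
        Store { taken: HashSet::new() }
    }

    /// Mimics `INSERT ... ON CONFLICT DO NOTHING RETURNING short_url`:
    /// returns true if the code was inserted, false if it already existed.
    pub fn insert_if_absent(&mut self, code: &str) -> bool {
        self.taken.insert(code.to_string())
    }
}

/// Tries up to `max_attempts` candidates produced by `generate`.
/// Returns the first code that was inserted, or `None` if every attempt clashed,
/// so the caller can answer with an error instead of an unregistered short URL.
pub fn allocate_code<G>(store: &mut Store, mut generate: G, max_attempts: usize) -> Option<String>
where
    G: FnMut() -> String,
{
    for _ in 0..max_attempts {
        let code = generate();
        if store.insert_if_absent(&code) {
            return Some(code);
        }
    }
    None
}
```

In a handler shaped like `shorten`, a `None` result could map to an error status, while `Some(code)` would be formatted into the response as before. The generator is passed in as a closure only so the sketch can be exercised deterministically.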


@@ -9,7 +9,6 @@ from pathlib import Path
from threading import Event
import psycopg2
import psycopg2.errors
import pytest
from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId
from fixtures.fast_import import (
@@ -1070,15 +1069,41 @@ def test_fast_import_restore_to_connstring_from_s3_spec(
return mock_kms.encrypt(KeyId=key_id, Plaintext=x)
# Start source postgres and ingest data
vanilla_pg.configure(["shared_preload_libraries='neon,neon_utils,neon_rmgr'"])
vanilla_pg.start()
vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);")
res = vanilla_pg.safe_psql("SHOW shared_preload_libraries;")
log.info(f"shared_preload_libraries: {res}")
res = vanilla_pg.safe_psql("SELECT name FROM pg_available_extensions;")
log.info(f"pg_available_extensions: {res}")
res = vanilla_pg.safe_psql("SELECT extname FROM pg_extension;")
log.info(f"pg_extension: {res}")
# Create a number of extensions; only selected ones will be dumped
vanilla_pg.safe_psql("CREATE EXTENSION neon;")
vanilla_pg.safe_psql("CREATE EXTENSION neon_utils;")
vanilla_pg.safe_psql("CREATE EXTENSION pg_visibility;")
# Default schema is always dumped
vanilla_pg.safe_psql(
"CREATE TABLE public.foo (a int); INSERT INTO public.foo SELECT generate_series(1, 7);"
)
# Create a number of schemas; only selected ones will be dumped
vanilla_pg.safe_psql("CREATE SCHEMA custom;")
vanilla_pg.safe_psql(
"CREATE TABLE custom.foo (a int); INSERT INTO custom.foo SELECT generate_series(1, 13);"
)
vanilla_pg.safe_psql("CREATE SCHEMA other;")
vanilla_pg.safe_psql(
"CREATE TABLE other.foo (a int); INSERT INTO other.foo SELECT generate_series(1, 42);"
)
# Start target postgres
pgdatadir = test_output_dir / "destination-pgdata"
pg_bin = PgBin(test_output_dir, pg_distrib_dir, pg_version)
port = port_distributor.get_port()
with VanillaPostgres(pgdatadir, pg_bin, port) as destination_vanilla_pg:
destination_vanilla_pg.configure(["shared_preload_libraries='neon_rmgr'"])
destination_vanilla_pg.configure(["shared_preload_libraries='neon,neon_utils,neon_rmgr'"])
destination_vanilla_pg.start()
# Encrypt connstrings and put spec into S3
@@ -1092,6 +1117,8 @@ def test_fast_import_restore_to_connstring_from_s3_spec(
"destination_connstring_ciphertext_base64": base64.b64encode(
destination_connstring_encrypted["CiphertextBlob"]
).decode("utf-8"),
"schemas": ["custom"],
"extensions": ["plpgsql", "neon"],
}
bucket = "test-bucket"
@@ -1117,9 +1144,31 @@ def test_fast_import_restore_to_connstring_from_s3_spec(
}, f"got status: {job_status}"
vanilla_pg.stop()
res = destination_vanilla_pg.safe_psql("SELECT count(*) FROM foo;")
res = destination_vanilla_pg.safe_psql("SELECT count(*) FROM public.foo;")
log.info(f"Result: {res}")
assert res[0][0] == 10
assert res[0][0] == 7
res = destination_vanilla_pg.safe_psql("SELECT count(*) FROM custom.foo;")
log.info(f"Result: {res}")
assert res[0][0] == 13
# Check that other schema is not restored
with pytest.raises(psycopg2.errors.UndefinedTable):
destination_vanilla_pg.safe_psql("SELECT count(*) FROM other.foo;")
# Check that all schemas are listed correctly
res = destination_vanilla_pg.safe_psql("SELECT nspname FROM pg_namespace;")
log.info(f"Result: {res}")
schemas = [row[0] for row in res]
assert "other" not in schemas
# Check that only selected extensions are restored
res = destination_vanilla_pg.safe_psql("SELECT extname FROM pg_extension;")
log.info(f"Result: {res}")
assert len(res) == 2
extensions = {str(row[0]) for row in res}
assert "plpgsql" in extensions
assert "neon" in extensions
def test_fast_import_restore_to_connstring_error_to_s3_bad_destination(


@@ -1,18 +1,18 @@
{
"v17": [
"17.5",
"8be779fd3ab9e87206da96a7e4842ef1abf04f44"
"10c002910447b3138e13213befca662df7cbe1d0"
],
"v16": [
"16.9",
"0bf96bd6d70301a0b43b0b3457bb3cf8fb43c198"
"94ad7e11cd43cce32f5af5674af29b3f334551a7"
],
"v15": [
"15.13",
"de7640f55da07512834d5cc40c4b3fb376b5f04f"
"cd0b534a761c18d8ef4654d4f749c63c5663215f"
],
"v14": [
"14.18",
"55c0d45abe6467c02084c2192bca117eda6ce1e7"
"b1e9959858f0529ea33d3cc5e833c0acc43f583a"
]
}


@@ -20,7 +20,6 @@ anstream = { version = "0.6" }
anyhow = { version = "1", features = ["backtrace"] }
axum = { version = "0.8", features = ["ws"] }
axum-core = { version = "0.5", default-features = false, features = ["tracing"] }
axum-extra = { version = "0.10", features = ["cookie-private", "typed-header"] }
base64-594e8ee84c453af0 = { package = "base64", version = "0.13", features = ["alloc"] }
base64-647d43efb71741da = { package = "base64", version = "0.21" }
base64ct = { version = "1", default-features = false, features = ["std"] }
@@ -31,7 +30,6 @@ clap = { version = "4", features = ["derive", "env", "string"] }
clap_builder = { version = "4", default-features = false, features = ["color", "env", "help", "std", "string", "suggestions", "usage"] }
const-oid = { version = "0.9", default-features = false, features = ["db", "std"] }
crypto-bigint = { version = "0.5", features = ["generic-array", "zeroize"] }
crypto-common = { version = "0.1", default-features = false, features = ["getrandom", "std"] }
der = { version = "0.7", default-features = false, features = ["derive", "flagset", "oid", "pem", "std"] }
deranged = { version = "0.3", default-features = false, features = ["powerfmt", "serde", "std"] }
digest = { version = "0.10", features = ["mac", "oid", "std"] }