Compare commits

..

1 Commits

Author SHA1 Message Date
Anastasia Lubennikova
0c8ed22939 Use saturating_sub for resident_physical_size_gauge
to avoid metric overflow, if the gauge is not initialized yet.
2023-02-28 16:00:41 +02:00
110 changed files with 3246 additions and 4798 deletions

View File

@@ -551,48 +551,6 @@ jobs:
- name: Cleanup ECR folder
run: rm -rf ~/.ecr
neon-image-depot:
# For testing this will run side-by-side for a few merges.
# This action is not really optimized yet, but gets the job done
runs-on: [ self-hosted, gen3, small ]
needs: [ tag ]
container: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/base:pinned
permissions:
contents: read
id-token: write
steps:
- name: Checkout
uses: actions/checkout@v3
with:
submodules: true
fetch-depth: 0
- name: Setup go
uses: actions/setup-go@v3
with:
go-version: '1.19'
- name: Set up Depot CLI
uses: depot/setup-action@v1
- name: Install Crane & ECR helper
run: go install github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cli/docker-credential-ecr-login@69c85dc22db6511932bbf119e1a0cc5c90c69a7f # v0.6.0
- name: Configure ECR login
run: |
mkdir /github/home/.docker/
echo "{\"credsStore\":\"ecr-login\"}" > /github/home/.docker/config.json
- name: Build and push
uses: depot/build-push-action@v1
with:
# if no depot.json file is at the root of your repo, you must specify the project id
project: nrdv0s4kcs
push: true
tags: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/neon:depot-${{needs.tag.outputs.build-tag}}
compute-tools-image:
runs-on: [ self-hosted, gen3, large ]
needs: [ tag ]

53
Cargo.lock generated
View File

@@ -851,7 +851,6 @@ dependencies = [
"futures",
"hyper",
"notify",
"num_cpus",
"opentelemetry",
"postgres",
"regex",
@@ -914,7 +913,6 @@ dependencies = [
"once_cell",
"pageserver_api",
"postgres",
"postgres_backend",
"postgres_connection",
"regex",
"reqwest",
@@ -2456,7 +2454,6 @@ dependencies = [
"postgres",
"postgres-protocol",
"postgres-types",
"postgres_backend",
"postgres_connection",
"postgres_ffi",
"pq_proto",
@@ -2679,28 +2676,6 @@ dependencies = [
"postgres-protocol",
]
[[package]]
name = "postgres_backend"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"bytes",
"futures",
"once_cell",
"pq_proto",
"rustls",
"rustls-pemfile",
"serde",
"thiserror",
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls",
"tracing",
"workspace_hack",
]
[[package]]
name = "postgres_connection"
version = "0.1.0"
@@ -2748,7 +2723,7 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
name = "pq_proto"
version = "0.1.0"
dependencies = [
"byteorder",
"anyhow",
"bytes",
"pin-project-lite",
"postgres-protocol",
@@ -2923,7 +2898,6 @@ dependencies = [
"opentelemetry",
"parking_lot",
"pin-project-lite",
"postgres_backend",
"pq_proto",
"prometheus",
"rand",
@@ -3303,6 +3277,15 @@ dependencies = [
"base64 0.21.0",
]
[[package]]
name = "rustls-split"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78802c9612b4689d207acff746f38132ca1b12dadb55d471aa5f10fd580f47d3"
dependencies = [
"rustls",
]
[[package]]
name = "rustversion"
version = "1.0.11"
@@ -3324,7 +3307,6 @@ dependencies = [
"async-trait",
"byteorder",
"bytes",
"chrono",
"clap 4.1.4",
"const_format",
"crc32c",
@@ -3334,11 +3316,11 @@ dependencies = [
"humantime",
"hyper",
"metrics",
"nix",
"once_cell",
"parking_lot",
"postgres",
"postgres-protocol",
"postgres_backend",
"postgres_ffi",
"pq_proto",
"regex",
@@ -4523,7 +4505,7 @@ dependencies = [
"byteorder",
"bytes",
"criterion",
"futures",
"git-version",
"heapless",
"hex",
"hex-literal",
@@ -4532,8 +4514,12 @@ dependencies = [
"metrics",
"nix",
"once_cell",
"pq_proto",
"rand",
"routerify",
"rustls",
"rustls-pemfile",
"rustls-split",
"sentry",
"serde",
"serde_json",
@@ -4544,10 +4530,10 @@ dependencies = [
"tempfile",
"thiserror",
"tokio",
"tokio-rustls",
"tracing",
"tracing-subscriber",
"url",
"uuid",
"workspace_hack",
]
@@ -4847,19 +4833,15 @@ name = "workspace_hack"
version = "0.1.0"
dependencies = [
"anyhow",
"byteorder",
"bytes",
"chrono",
"clap 4.1.4",
"crossbeam-utils",
"digest",
"either",
"fail",
"futures",
"futures-channel",
"futures-core",
"futures-executor",
"futures-sink",
"futures-util",
"hashbrown 0.12.3",
"indexmap",
@@ -4884,7 +4866,6 @@ dependencies = [
"socket2",
"syn",
"tokio",
"tokio-rustls",
"tokio-util",
"tonic",
"tower",

View File

@@ -64,7 +64,6 @@ md5 = "0.7.0"
memoffset = "0.8"
nix = "0.26"
notify = "5.0.0"
num_cpus = "1.15"
num-traits = "0.2.15"
once_cell = "1.13"
opentelemetry = "0.18.0"
@@ -134,7 +133,6 @@ heapless = { default-features=false, features=[], git = "https://github.com/japa
consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" }
metrics = { version = "0.1", path = "./libs/metrics/" }
pageserver_api = { version = "0.1", path = "./libs/pageserver_api/" }
postgres_backend = { version = "0.1", path = "./libs/postgres_backend/" }
postgres_connection = { version = "0.1", path = "./libs/postgres_connection/" }
postgres_ffi = { version = "0.1", path = "./libs/postgres_ffi/" }
pq_proto = { version = "0.1", path = "./libs/pq_proto/" }

View File

@@ -39,7 +39,7 @@ ARG CACHEPOT_BUCKET=neon-github-dev
COPY --from=pg-build /home/nonroot/pg_install/v14/include/postgresql/server pg_install/v14/include/postgresql/server
COPY --from=pg-build /home/nonroot/pg_install/v15/include/postgresql/server pg_install/v15/include/postgresql/server
COPY --chown=nonroot . .
COPY . .
# Show build caching stats to check if it was used in the end.
# Has to be the part of the same RUN since cachepot daemon is killed in the end of this RUN, losing the compilation stats.

View File

@@ -225,81 +225,6 @@ RUN wget https://github.com/iCyberon/pg_hashids/archive/refs/tags/v1.2.1.tar.gz
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pg_hashids.control
#########################################################################################
#
# Layer "rum-pg-build"
# compile rum extension
#
#########################################################################################
FROM build-deps AS rum-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/postgrespro/rum/archive/refs/tags/1.3.13.tar.gz -O rum.tar.gz && \
mkdir rum-src && cd rum-src && tar xvzf ../rum.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/rum.control
#########################################################################################
#
# Layer "pgtap-pg-build"
# compile pgTAP extension
#
#########################################################################################
FROM build-deps AS pgtap-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/theory/pgtap/archive/refs/tags/v1.2.0.tar.gz -O pgtap.tar.gz && \
mkdir pgtap-src && cd pgtap-src && tar xvzf ../pgtap.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgtap.control
#########################################################################################
#
# Layer "prefix-pg-build"
# compile Prefix extension
#
#########################################################################################
FROM build-deps AS prefix-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/dimitri/prefix/archive/refs/tags/v1.2.9.tar.gz -O prefix.tar.gz && \
mkdir prefix-src && cd prefix-src && tar xvzf ../prefix.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/prefix.control
#########################################################################################
#
# Layer "hll-pg-build"
# compile hll extension
#
#########################################################################################
FROM build-deps AS hll-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/citusdata/postgresql-hll/archive/refs/tags/v2.17.tar.gz -O hll.tar.gz && \
mkdir hll-src && cd hll-src && tar xvzf ../hll.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/hll.control
#########################################################################################
#
# Layer "plpgsql-check-pg-build"
# compile plpgsql_check extension
#
#########################################################################################
FROM build-deps AS plpgsql-check-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN wget https://github.com/okbob/plpgsql_check/archive/refs/tags/v2.3.2.tar.gz -O plpgsql_check.tar.gz && \
mkdir plpgsql_check-src && cd plpgsql_check-src && tar xvzf ../plpgsql_check.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/plpgsql_check.control
#########################################################################################
#
# Layer "rust extensions"
@@ -323,7 +248,7 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux
chmod +x rustup-init && \
./rustup-init -y --no-modify-path --profile minimal --default-toolchain stable && \
rm rustup-init && \
cargo install --locked --version 0.7.3 cargo-pgx && \
cargo install --git https://github.com/vadim2404/pgx --branch neon_abi_v0.6.1 --locked cargo-pgx && \
/bin/bash -c 'cargo pgx init --pg${PG_VERSION:1}=/usr/local/pgsql/bin/pg_config'
USER root
@@ -337,11 +262,11 @@ USER root
FROM rust-extensions-build AS pg-jsonschema-pg-build
# there is no release tag yet, but we need it due to the superuser fix in the control file
RUN wget https://github.com/supabase/pg_jsonschema/archive/caeab60d70b2fd3ae421ec66466a3abbb37b7ee6.tar.gz -O pg_jsonschema.tar.gz && \
mkdir pg_jsonschema-src && cd pg_jsonschema-src && tar xvzf ../pg_jsonschema.tar.gz --strip-components=1 -C . && \
sed -i 's/pgx = "0.7.1"/pgx = { version = "0.7.3", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
RUN git clone --depth=1 --single-branch --branch neon_abi_v0.1.4 https://github.com/vadim2404/pg_jsonschema/ && \
cd pg_jsonschema && \
cargo pgx install --release && \
# it's needed to enable extension because it uses untrusted C language
sed -i 's/superuser = false/superuser = true/g' /usr/local/pgsql/share/extension/pg_jsonschema.control && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_jsonschema.control
#########################################################################################
@@ -353,32 +278,13 @@ RUN wget https://github.com/supabase/pg_jsonschema/archive/caeab60d70b2fd3ae421e
FROM rust-extensions-build AS pg-graphql-pg-build
# Currently pgx version bump to >= 0.7.2 causes "call to unsafe function" compliation errors in
# pgx-contrib-spiext. There is a branch that removes that dependency, so use it. It is on the
# same 1.1 version we've used before.
RUN git clone -b remove-pgx-contrib-spiext --single-branch https://github.com/yrashk/pg_graphql && \
cd pg_graphql && \
sed -i 's/pgx = "~0.7.1"/pgx = { version = "0.7.3", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
sed -i 's/pgx-tests = "~0.7.1"/pgx-tests = "0.7.3"/g' Cargo.toml && \
RUN git clone --depth=1 --single-branch --branch neon_abi_v1.1.0 https://github.com/vadim2404/pg_graphql && \
cd pg_graphql && \
cargo pgx install --release && \
# it's needed to enable extension because it uses untrusted C language
sed -i 's/superuser = false/superuser = true/g' /usr/local/pgsql/share/extension/pg_graphql.control && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_graphql.control
#########################################################################################
#
# Layer "pg-tiktoken-build"
# Compile "pg_tiktoken" extension
#
#########################################################################################
FROM rust-extensions-build AS pg-tiktoken-pg-build
RUN git clone --depth=1 --single-branch https://github.com/kelvich/pg_tiktoken && \
cd pg_tiktoken && \
cargo pgx install --release && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_tiktoken.control
#########################################################################################
#
# Layer "neon-pg-ext-build"
@@ -396,23 +302,13 @@ COPY --from=vector-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pgjwt-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-jsonschema-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-graphql-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-tiktoken-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=hypopg-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-hashids-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=rum-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pgtap-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=prefix-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=hll-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=plpgsql-check-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY pgxn/ pgxn/
RUN make -j $(getconf _NPROCESSORS_ONLN) \
PG_CONFIG=/usr/local/pgsql/bin/pg_config \
-C pgxn/neon \
-s install && \
make -j $(getconf _NPROCESSORS_ONLN) \
PG_CONFIG=/usr/local/pgsql/bin/pg_config \
-C pgxn/neon_utils \
-s install
#########################################################################################
@@ -467,7 +363,7 @@ COPY --from=compute-tools --chown=postgres /home/nonroot/target/release-line-deb
# Install:
# libreadline8 for psql
# libicu67, locales for collations (including ICU and plpgsql_check)
# libicu67, locales for collations (including ICU)
# libossp-uuid16 for extension ossp-uuid
# libgeos, libgdal, libsfcgal1, libproj and libprotobuf-c1 for PostGIS
# libxml2, libxslt1.1 for xml2

View File

@@ -10,16 +10,23 @@ RUN set -e \
&& rm -f /etc/inittab \
&& touch /etc/inittab
ADD vm-cgconfig.conf /etc/cgconfig.conf
RUN set -e \
&& echo "::respawn:su vm-informant -c '/usr/local/bin/vm-informant --auto-restart'" >> /etc/inittab
&& echo "::sysinit:cgconfigparser -l /etc/cgconfig.conf -s 1664" >> /etc/inittab \
&& echo "::respawn:su vm-informant -c '/usr/local/bin/vm-informant --auto-restart --cgroup=neon-postgres'" >> /etc/inittab
# Combine, starting from non-VM compute node image.
FROM $SRC_IMAGE as base
# Temporarily set user back to root so we can run adduser
# Temporarily set user back to root so we can run apt update and adduser
USER root
RUN apt update && \
apt install --no-install-recommends -y \
cgroup-tools
RUN adduser vm-informant --disabled-password --no-create-home
USER postgres
COPY --from=informant /etc/inittab /etc/inittab
COPY --from=informant /usr/bin/vm-informant /usr/local/bin/vm-informant
ENTRYPOINT ["/usr/sbin/cgexec", "-g", "*:neon-postgres", "/usr/local/bin/compute_ctl"]

View File

@@ -133,11 +133,6 @@ neon-pg-ext-%: postgres-%
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-test-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_test_utils/Makefile install
+@echo "Compiling neon_utils $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-utils-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
-C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile install
.PHONY: neon-pg-ext-clean-%
neon-pg-ext-clean-%:
@@ -150,9 +145,6 @@ neon-pg-ext-clean-%:
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config \
-C $(POSTGRES_INSTALL_DIR)/build/neon-test-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_test_utils/Makefile clean
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config \
-C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile clean
.PHONY: neon-pg-ext
neon-pg-ext: \

View File

@@ -11,7 +11,6 @@ clap.workspace = true
futures.workspace = true
hyper = { workspace = true, features = ["full"] }
notify.workspace = true
num_cpus.workspace = true
opentelemetry.workspace = true
postgres.workspace = true
regex.workspace = true

View File

@@ -25,7 +25,6 @@ use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use postgres::{Client, NoTls};
use serde::{Serialize, Serializer};
use tokio_postgres;
use tracing::{info, instrument, warn};
use crate::checker::create_writability_check_data;
@@ -285,7 +284,6 @@ impl ComputeNode {
handle_role_deletions(self, &mut client)?;
handle_grants(self, &mut client)?;
create_writability_check_data(&mut client)?;
handle_extensions(&self.spec, &mut client)?;
// 'Close' connection
drop(client);
@@ -402,43 +400,4 @@ impl ComputeNode {
Ok(())
}
/// Select `pg_stat_statements` data and return it as a stringified JSON
pub async fn collect_insights(&self) -> String {
let mut result_rows: Vec<String> = Vec::new();
let connect_result = tokio_postgres::connect(self.connstr.as_str(), NoTls).await;
let (client, connection) = connect_result.unwrap();
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let result = client
.simple_query(
"SELECT
row_to_json(pg_stat_statements)
FROM
pg_stat_statements
WHERE
userid != 'cloud_admin'::regrole::oid
ORDER BY
(mean_exec_time + mean_plan_time) DESC
LIMIT 100",
)
.await;
if let Ok(raw_rows) = result {
for message in raw_rows.iter() {
if let postgres::SimpleQueryMessage::Row(row) = message {
if let Some(json) = row.get(0) {
result_rows.push(json.to_string());
}
}
}
format!("{{\"pg_stat_statements\": [{}]}}", result_rows.join(","))
} else {
"{{\"pg_stat_statements\": []}}".to_string()
}
}
}

View File

@@ -7,7 +7,6 @@ use crate::compute::ComputeNode;
use anyhow::Result;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode};
use num_cpus;
use serde_json;
use tracing::{error, info};
use tracing_utils::http::OtelName;
@@ -34,13 +33,6 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
Response::new(Body::from(serde_json::to_string(&compute.metrics).unwrap()))
}
// Collect Postgres current usage insights
(&Method::GET, "/insights") => {
info!("serving /insights GET request");
let insights = compute.collect_insights().await;
Response::new(Body::from(insights))
}
(&Method::POST, "/check_writability") => {
info!("serving /check_writability POST request");
let res = crate::checker::check_writability(compute).await;
@@ -50,17 +42,6 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
}
}
(&Method::GET, "/info") => {
let num_cpus = num_cpus::get_physical();
info!("serving /info GET request. num_cpus: {}", num_cpus);
Response::new(Body::from(
serde_json::json!({
"num_cpus": num_cpus,
})
.to_string(),
))
}
// Return the `404 Not Found` for any other routes.
_ => {
let mut not_found = Response::new(Body::from("404 Not Found"));

View File

@@ -10,12 +10,12 @@ paths:
/status:
get:
tags:
- Info
- "info"
summary: Get compute node internal status
description: ""
operationId: getComputeStatus
responses:
200:
"200":
description: ComputeState
content:
application/json:
@@ -25,58 +25,27 @@ paths:
/metrics.json:
get:
tags:
- Info
- "info"
summary: Get compute node startup metrics in JSON format
description: ""
operationId: getComputeMetricsJSON
responses:
200:
"200":
description: ComputeMetrics
content:
application/json:
schema:
$ref: "#/components/schemas/ComputeMetrics"
/insights:
get:
tags:
- Info
summary: Get current compute insights in JSON format
description: |
Note, that this doesn't include any historical data
operationId: getComputeInsights
responses:
200:
description: Compute insights
content:
application/json:
schema:
$ref: "#/components/schemas/ComputeInsights"
/info:
get:
tags:
- "info"
summary: Get info about the compute Pod/VM
description: ""
operationId: getInfo
responses:
"200":
description: Info
content:
application/json:
schema:
$ref: "#/components/schemas/Info"
/check_writability:
post:
tags:
- Check
- "check"
summary: Check that we can write new data on this compute
description: ""
operationId: checkComputeWritability
responses:
200:
"200":
description: Check result
content:
text/plain:
@@ -111,15 +80,6 @@ components:
total_startup_ms:
type: integer
Info:
type: object
description: Information about VM/Pod
required:
- num_cpus
properties:
num_cpus:
type: integer
ComputeState:
type: object
required:
@@ -136,15 +96,6 @@ components:
type: string
description: Text of the error during compute startup, if any
ComputeInsights:
type: object
properties:
pg_stat_statements:
description: Contains raw output from pg_stat_statements in JSON format
type: array
items:
type: object
ComputeStatus:
type: string
enum:

View File

@@ -47,23 +47,12 @@ pub struct GenericOption {
/// declare a `trait` on it.
pub type GenericOptions = Option<Vec<GenericOption>>;
/// Escape a string for including it in a SQL literal
fn escape_literal(s: &str) -> String {
s.replace('\'', "''").replace('\\', "\\\\")
}
/// Escape a string so that it can be used in postgresql.conf.
/// Same as escape_literal, currently.
fn escape_conf_value(s: &str) -> String {
s.replace('\'', "''").replace('\\', "\\\\")
}
impl GenericOption {
/// Represent `GenericOption` as SQL statement parameter.
pub fn to_pg_option(&self) -> String {
if let Some(val) = &self.value {
match self.vartype.as_ref() {
"string" => format!("{} '{}'", self.name, escape_literal(val)),
"string" => format!("{} '{}'", self.name, val),
_ => format!("{} {}", self.name, val),
}
} else {
@@ -74,8 +63,6 @@ impl GenericOption {
/// Represent `GenericOption` as configuration option.
pub fn to_pg_setting(&self) -> String {
if let Some(val) = &self.value {
// TODO: check in the console DB that we don't have these settings
// set for any non-deleted project and drop this override.
let name = match self.name.as_str() {
"safekeepers" => "neon.safekeepers",
"wal_acceptor_reconnect" => "neon.safekeeper_reconnect_timeout",
@@ -84,7 +71,7 @@ impl GenericOption {
};
match self.vartype.as_ref() {
"string" => format!("{} = '{}'", name, escape_conf_value(val)),
"string" => format!("{} = '{}'", name, val),
_ => format!("{} = {}", name, val),
}
} else {
@@ -120,7 +107,6 @@ impl PgOptionsSerialize for GenericOptions {
.map(|op| op.to_pg_setting())
.collect::<Vec<String>>()
.join("\n")
+ "\n" // newline after last setting
} else {
"".to_string()
}

View File

@@ -515,18 +515,3 @@ pub fn handle_grants(node: &ComputeNode, client: &mut Client) -> Result<()> {
Ok(())
}
/// Create required system extensions
#[instrument(skip_all)]
pub fn handle_extensions(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
if let Some(libs) = spec.cluster.settings.find("shared_preload_libraries") {
if libs.contains("pg_stat_statements") {
// Create extension only if this compute really needs it
let query = "CREATE EXTENSION IF NOT EXISTS pg_stat_statements";
info!("creating system extensions with query: {}", query);
client.simple_query(query)?;
}
}
Ok(())
}

View File

@@ -178,11 +178,6 @@
"name": "neon.pageserver_connstring",
"value": "host=127.0.0.1 port=6400",
"vartype": "string"
},
{
"name": "test.escaping",
"value": "here's a backslash \\ and a quote ' and a double-quote \" hooray",
"vartype": "string"
}
]
},

View File

@@ -28,30 +28,7 @@ mod pg_helpers_tests {
assert_eq!(
spec.cluster.settings.as_pg_settings(),
r#"fsync = off
wal_level = replica
hot_standby = on
neon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501'
wal_log_hints = on
log_connections = on
shared_buffers = 32768
port = 55432
max_connections = 100
max_wal_senders = 10
listen_addresses = '0.0.0.0'
wal_sender_timeout = 0
password_encryption = md5
maintenance_work_mem = 65536
max_parallel_workers = 8
max_worker_processes = 8
neon.tenant_id = 'b0554b632bd4d547a63b86c3630317e8'
max_replication_slots = 10
neon.timeline_id = '2414a61ffc94e428f14b5758fe308e13'
shared_preload_libraries = 'neon'
synchronous_standby_names = 'walproposer'
neon.pageserver_connstring = 'host=127.0.0.1 port=6400'
test.escaping = 'here''s a backslash \\ and a quote '' and a double-quote " hooray'
"#
"fsync = off\nwal_level = replica\nhot_standby = on\nneon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501'\nwal_log_hints = on\nlog_connections = on\nshared_buffers = 32768\nport = 55432\nmax_connections = 100\nmax_wal_senders = 10\nlisten_addresses = '0.0.0.0'\nwal_sender_timeout = 0\npassword_encryption = md5\nmaintenance_work_mem = 65536\nmax_parallel_workers = 8\nmax_worker_processes = 8\nneon.tenant_id = 'b0554b632bd4d547a63b86c3630317e8'\nmax_replication_slots = 10\nneon.timeline_id = '2414a61ffc94e428f14b5758fe308e13'\nshared_preload_libraries = 'neon'\nsynchronous_standby_names = 'walproposer'\nneon.pageserver_connstring = 'host=127.0.0.1 port=6400'"
);
}

View File

@@ -24,7 +24,6 @@ url.workspace = true
# Note: Do not directly depend on pageserver or safekeeper; use pageserver_api or safekeeper_api
# instead, so that recompile times are better.
pageserver_api.workspace = true
postgres_backend.workspace = true
safekeeper_api.workspace = true
postgres_connection.workspace = true
storage_broker.workspace = true

View File

@@ -17,7 +17,6 @@ use pageserver_api::{
DEFAULT_HTTP_LISTEN_ADDR as DEFAULT_PAGESERVER_HTTP_ADDR,
DEFAULT_PG_LISTEN_ADDR as DEFAULT_PAGESERVER_PG_ADDR,
};
use postgres_backend::AuthType;
use safekeeper_api::{
DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT,
DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT,
@@ -31,6 +30,7 @@ use utils::{
auth::{Claims, Scope},
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
project_git_version,
};

View File

@@ -11,10 +11,10 @@ use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, Result};
use postgres_backend::AuthType;
use utils::{
id::{TenantId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
};
use crate::local_env::{LocalEnv, DEFAULT_PG_VERSION};

View File

@@ -5,7 +5,6 @@
use anyhow::{bail, ensure, Context};
use postgres_backend::AuthType;
use reqwest::Url;
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};
@@ -20,6 +19,7 @@ use std::process::{Command, Stdio};
use utils::{
auth::{encode_from_key_file, Claims, Scope},
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
postgres_backend::AuthType,
};
use crate::safekeeper::SafekeeperNode;

View File

@@ -11,7 +11,6 @@ use anyhow::{bail, Context};
use pageserver_api::models::{
TenantConfigRequest, TenantCreateRequest, TenantInfo, TimelineCreateRequest, TimelineInfo,
};
use postgres_backend::AuthType;
use postgres_connection::{parse_host_port, PgConnectionConfig};
use reqwest::blocking::{Client, RequestBuilder, Response};
use reqwest::{IntoUrl, Method};
@@ -21,6 +20,7 @@ use utils::{
http::error::HttpErrorBody,
id::{TenantId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
};
use crate::{background_process, local_env::LocalEnv};

View File

@@ -29,41 +29,6 @@ These components should not have access to the private key and may only get toke
The key pair is generated once for an installation of compute/pageserver/safekeeper, e.g. by `neon_local init`.
There is currently no way to rotate the key without bringing down all components.
### Token format
The JWT tokens in Neon use RSA as the algorithm. Example:
Header:
```
{
"alg": "RS512", # RS256, RS384, or RS512
"typ": "JWT"
}
```
Payload:
```
{
"scope": "tenant", # "tenant", "pageserverapi", or "safekeeperdata"
"tenant_id": "5204921ff44f09de8094a1390a6a50f6",
}
```
Meanings of scope:
"tenant": Provides access to all data for a specific tenant
"pageserverapi": Provides blanket access to all tenants on the pageserver plus pageserver-wide APIs.
Should only be used e.g. for status check/tenant creation/list.
"safekeeperdata": Provides blanket access to all data on the safekeeper plus safekeeper-wide APIs.
Should only be used e.g. for status check.
Currently also used for connection from any pageserver to any safekeeper.
### CLI
CLI generates a key pair during call to `neon_local init` with the following commands:

View File

@@ -1,26 +0,0 @@
[package]
name = "postgres_backend"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
async-trait.workspace = true
anyhow.workspace = true
bytes.workspace = true
futures.workspace = true
rustls.workspace = true
serde.workspace = true
thiserror.workspace = true
tokio.workspace = true
tokio-rustls.workspace = true
tracing.workspace = true
pq_proto.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
once_cell.workspace = true
rustls-pemfile.workspace = true
tokio-postgres.workspace = true
tokio-postgres-rustls.workspace = true

View File

@@ -1,910 +0,0 @@
//! Server-side asynchronous Postgres connection, as limited as we need.
//! To use, create PostgresBackend and run() it, passing the Handler
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use anyhow::Context;
use bytes::Bytes;
use futures::pin_mut;
use serde::{Deserialize, Serialize};
use std::io::ErrorKind;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{ready, Poll};
use std::{fmt, io};
use std::{future::Future, str::FromStr};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::TlsAcceptor;
use tracing::{debug, error, info, trace};
use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
use pq_proto::{
BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_INTERNAL_ERROR,
SQLSTATE_SUCCESSFUL_COMPLETION,
};
/// An error, occurred during query processing:
/// either during the connection ([`ConnectionError`]) or before/after it.
#[derive(thiserror::Error, Debug)]
pub enum QueryError {
/// The connection was lost while processing the query.
#[error(transparent)]
Disconnected(#[from] ConnectionError),
/// Some other error
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl From<io::Error> for QueryError {
fn from(e: io::Error) -> Self {
Self::Disconnected(ConnectionError::Io(e))
}
}
impl QueryError {
pub fn pg_error_code(&self) -> &'static [u8; 5] {
match self {
Self::Disconnected(_) => b"08006", // connection failure
Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
}
}
}
pub fn is_expected_io_error(e: &io::Error) -> bool {
use io::ErrorKind::*;
matches!(
e.kind(),
ConnectionRefused | ConnectionAborted | ConnectionReset
)
}
#[async_trait::async_trait]
pub trait Handler {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might be not what we want after CopyData streaming, but currently we don't
/// care). It will also flush out the output buffer.
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
query_string: &str,
) -> Result<(), QueryError>;
/// Called on startup packet receival, allows to process params.
///
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
/// to override whole init logic in implementations.
fn startup(
&mut self,
_pgb: &mut PostgresBackend,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
}
/// Check auth jwt
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend,
_jwt_response: &[u8],
) -> Result<(), QueryError> {
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
}
}
/// PostgresBackend protocol state.
/// XXX: The order of the constructors matters.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
/// Nothing happened yet.
Initialization,
/// Encryption handshake is done; waiting for encrypted Startup message.
Encrypted,
/// Waiting for password (auth token).
Authentication,
/// Performed handshake and auth, ReadyForQuery is issued.
Established,
Closed,
}
#[derive(Clone, Copy)]
pub enum ProcessMsgResult {
Continue,
Break,
}
/// Either plain TCP stream or encrypted one, implementing AsyncRead + AsyncWrite.
pub enum MaybeTlsStream {
Unencrypted(tokio::net::TcpStream),
Tls(Box<tokio_rustls::server::TlsStream<tokio::net::TcpStream>>),
}
impl AsyncWrite for MaybeTlsStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}
impl AsyncRead for MaybeTlsStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AuthType {
Trust,
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
NeonJWT,
}
impl FromStr for AuthType {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"Trust" => Ok(Self::Trust),
"NeonJWT" => Ok(Self::NeonJWT),
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
}
}
}
impl fmt::Display for AuthType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
AuthType::Trust => "Trust",
AuthType::NeonJWT => "NeonJWT",
})
}
}
/// Either full duplex Framed or write only half; the latter is left in
/// PostgresBackend after call to `split`. In principle we could always store a
/// pair of splitted handles, but that would force to to pay splitting price
/// (Arc and kinda mutex inside polling) for all uses (e.g. pageserver).
enum MaybeWriteOnly {
Full(Framed<MaybeTlsStream>),
WriteOnly(FramedWriter<MaybeTlsStream>),
Broken, // temporary value palmed off during the split
}
impl MaybeWriteOnly {
async fn read_startup_message(&mut self) -> Result<Option<FeStartupPacket>, ConnectionError> {
match self {
MaybeWriteOnly::Full(framed) => framed.read_startup_message().await,
MaybeWriteOnly::WriteOnly(_) => {
Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
}
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
match self {
MaybeWriteOnly::Full(framed) => framed.read_message().await,
MaybeWriteOnly::WriteOnly(_) => {
Err(io::Error::new(ErrorKind::Other, "reading from write only half").into())
}
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
match self {
MaybeWriteOnly::Full(framed) => framed.write_message(msg),
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.write_message_noflush(msg),
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
async fn flush(&mut self) -> io::Result<()> {
match self {
MaybeWriteOnly::Full(framed) => framed.flush().await,
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.flush().await,
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
async fn shutdown(&mut self) -> io::Result<()> {
match self {
MaybeWriteOnly::Full(framed) => framed.shutdown().await,
MaybeWriteOnly::WriteOnly(framed_writer) => framed_writer.shutdown().await,
MaybeWriteOnly::Broken => panic!("IO on invalid MaybeWriteOnly"),
}
}
}
pub struct PostgresBackend {
framed: MaybeWriteOnly,
pub state: ProtoState,
auth_type: AuthType,
peer_addr: SocketAddr,
pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
let mut query_string = query_string.to_vec();
if let Some(ch) = query_string.last() {
if *ch == 0 {
query_string.pop();
}
}
query_string
}
/// Cast a byte slice to a string slice, dropping null terminator if there's one.
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
std::str::from_utf8(without_null).map_err(|e| e.into())
}
impl PostgresBackend {
pub fn new(
socket: tokio::net::TcpStream,
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
) -> io::Result<Self> {
let peer_addr = socket.peer_addr()?;
let stream = MaybeTlsStream::Unencrypted(socket);
Ok(Self {
framed: MaybeWriteOnly::Full(Framed::new(stream)),
state: ProtoState::Initialization,
auth_type,
tls_config,
peer_addr,
})
}
pub fn get_peer_addr(&self) -> &SocketAddr {
&self.peer_addr
}
/// Read full message or return None if connection is cleanly closed with no
/// unprocessed data.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
if let ProtoState::Closed = self.state {
Ok(None)
} else {
let m = self.framed.read_message().await?;
trace!("read msg {:?}", m);
Ok(m)
}
}
/// Write message into internal output buffer, doesn't flush it. Technically
/// error type can be only ProtocolError here (if, unlikely, serialization
/// fails), but callers typically wrap it anyway.
pub fn write_message_noflush(
&mut self,
message: &BeMessage<'_>,
) -> Result<&mut Self, ConnectionError> {
self.framed.write_message_noflush(message)?;
trace!("wrote msg {:?}", message);
Ok(self)
}
/// Flush output buffer into the socket.
pub async fn flush(&mut self) -> io::Result<()> {
self.framed.flush().await
}
/// Polling version of `flush()`, saves the caller need to pin.
pub fn poll_flush(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let flush_fut = self.flush();
pin_mut!(flush_fut);
flush_fut.poll(cx)
}
/// Write message into internal output buffer and flush it to the stream.
pub async fn write_message(
&mut self,
message: &BeMessage<'_>,
) -> Result<&mut Self, ConnectionError> {
self.write_message_noflush(message)?;
self.flush().await?;
Ok(self)
}
/// Returns an AsyncWrite implementation that wraps all the data written
/// to it in CopyData messages, and writes them to the connection
///
/// The caller is responsible for sending CopyOutResponse and CopyDone messages.
pub fn copyout_writer(&mut self) -> CopyDataWriter {
CopyDataWriter { pgb: self }
}
/// Wrapper for run_message_loop() that shuts down socket when we are done
pub async fn run<F, S>(
mut self,
handler: &mut impl Handler,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S,
S: Future,
{
let ret = self.run_message_loop(handler, shutdown_watcher).await;
// socket might be already closed, e.g. if previously received error,
// so ignore result.
self.framed.shutdown().await.ok();
ret
}
async fn run_message_loop<F, S>(
&mut self,
handler: &mut impl Handler,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S,
S: Future,
{
trace!("postgres backend to {:?} started", self.peer_addr);
tokio::select!(
biased;
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received during handshake");
return Ok(())
},
result = self.handshake(handler) => {
// Handshake complete.
result?;
if self.state == ProtoState::Closed {
return Ok(()); // EOF during handshake
}
}
);
// Authentication completed
let mut query_string = Bytes::new();
while let Some(msg) = tokio::select!(
biased;
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received in run_message_loop");
Ok(None)
},
msg = self.read_message() => { msg },
)? {
trace!("got message {:?}", msg);
let result = self.process_message(handler, msg, &mut query_string).await;
self.flush().await?;
match result? {
ProcessMsgResult::Continue => {
self.flush().await?;
continue;
}
ProcessMsgResult::Break => break,
}
}
trace!("postgres backend to {:?} exited", self.peer_addr);
Ok(())
}
/// Try to upgrade MaybeTlsStream into actual TLS one, performing handshake.
async fn tls_upgrade(
src: MaybeTlsStream,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<MaybeTlsStream> {
match src {
MaybeTlsStream::Unencrypted(s) => {
let acceptor = TlsAcceptor::from(tls_config);
let tls_stream = acceptor.accept(s).await?;
Ok(MaybeTlsStream::Tls(Box::new(tls_stream)))
}
MaybeTlsStream::Tls(_) => {
anyhow::bail!("TLS already started");
}
}
}
async fn start_tls(&mut self) -> anyhow::Result<()> {
// temporary replace stream with fake to cook TLS one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(framed) => {
let tls_config = self
.tls_config
.as_ref()
.context("start_tls called without conf")?
.clone();
let tls_framed = framed
.map_stream(|s| PostgresBackend::tls_upgrade(s, tls_config))
.await?;
// push back ready TLS stream
self.framed = MaybeWriteOnly::Full(tls_framed);
Ok(())
}
MaybeWriteOnly::WriteOnly(_) => {
anyhow::bail!("TLS upgrade attempt in split state")
}
MaybeWriteOnly::Broken => panic!("TLS upgrade on framed in invalid state"),
}
}
/// Split off owned read part from which messages can be read in different
/// task/thread.
pub fn split(&mut self) -> anyhow::Result<PostgresBackendReader> {
// temporary replace stream with fake to cook split one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(framed) => {
let (reader, writer) = framed.split();
self.framed = MaybeWriteOnly::WriteOnly(writer);
Ok(PostgresBackendReader(reader))
}
MaybeWriteOnly::WriteOnly(_) => {
anyhow::bail!("PostgresBackend is already split")
}
MaybeWriteOnly::Broken => panic!("split on framed in invalid state"),
}
}
/// Join read part back.
pub fn unsplit(&mut self, reader: PostgresBackendReader) -> anyhow::Result<()> {
// temporary replace stream with fake to cook joined one, Indiana Jones style
match std::mem::replace(&mut self.framed, MaybeWriteOnly::Broken) {
MaybeWriteOnly::Full(_) => {
anyhow::bail!("PostgresBackend is not split")
}
MaybeWriteOnly::WriteOnly(writer) => {
let joined = Framed::unsplit(reader.0, writer);
self.framed = MaybeWriteOnly::Full(joined);
Ok(())
}
MaybeWriteOnly::Broken => panic!("unsplit on framed in invalid state"),
}
}
/// Perform handshake with the client, transitioning to Established.
/// In case of EOF during handshake logs this, sets state to Closed and returns Ok(()).
async fn handshake(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
while self.state < ProtoState::Authentication {
match self.framed.read_startup_message().await? {
Some(msg) => {
self.process_startup_message(handler, msg).await?;
}
None => {
trace!(
"postgres backend to {:?} received EOF during handshake",
self.peer_addr
);
self.state = ProtoState::Closed;
return Ok(());
}
}
}
// Perform auth, if needed.
if self.state == ProtoState::Authentication {
match self.framed.read_message().await? {
Some(FeMessage::PasswordMessage(m)) => {
assert!(self.auth_type == AuthType::NeonJWT);
let (_, jwt_response) = m.split_last().context("protocol violation")?;
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
self.write_message_noflush(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
return Err(e);
}
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
.write_message(&BeMessage::ReadyForQuery)
.await?;
self.state = ProtoState::Established;
}
Some(m) => {
return Err(QueryError::Other(anyhow::anyhow!(
"Unexpected message {:?} while waiting for handshake",
m
)));
}
None => {
trace!(
"postgres backend to {:?} received EOF during auth",
self.peer_addr
);
self.state = ProtoState::Closed;
return Ok(());
}
}
}
Ok(())
}
/// Process startup packet:
/// - transition to Established if auth type is trust
/// - transition to Authentication if auth type is NeonJWT.
/// - or perform TLS handshake -- then need to call this again to receive
/// actual startup packet.
async fn process_startup_message(
&mut self,
handler: &mut impl Handler,
msg: FeStartupPacket,
) -> Result<(), QueryError> {
assert!(self.state < ProtoState::Authentication);
let have_tls = self.tls_config.is_some();
match msg {
FeStartupPacket::SslRequest => {
debug!("SSL requested");
self.write_message(&BeMessage::EncryptionResponse(have_tls))
.await?;
if have_tls {
self.start_tls().await?;
self.state = ProtoState::Encrypted;
}
}
FeStartupPacket::GssEncRequest => {
debug!("GSS requested");
self.write_message(&BeMessage::EncryptionResponse(false))
.await?;
}
FeStartupPacket::StartupMessage { .. } => {
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
self.write_message(&BeMessage::ErrorResponse("must connect with TLS", None))
.await?;
return Err(QueryError::Other(anyhow::anyhow!(
"client did not connect with TLS"
)));
}
// NB: startup() may change self.auth_type -- we are using that in proxy code
// to bypass auth for new users.
handler.startup(self, &msg)?;
match self.auth_type {
AuthType::Trust => {
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
.write_message_noflush(&BeMessage::INTEGER_DATETIMES)?
// The async python driver requires a valid server_version
.write_message_noflush(&BeMessage::server_version("14.1"))?
.write_message(&BeMessage::ReadyForQuery)
.await?;
self.state = ProtoState::Established;
}
AuthType::NeonJWT => {
self.write_message(&BeMessage::AuthenticationCleartextPassword)
.await?;
self.state = ProtoState::Authentication;
}
}
}
FeStartupPacket::CancelRequest { .. } => {
return Err(QueryError::Other(anyhow::anyhow!(
"Unexpected CancelRequest message during handshake"
)));
}
}
Ok(())
}
async fn process_message(
&mut self,
handler: &mut impl Handler,
msg: FeMessage,
unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult, QueryError> {
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
assert!(self.state == ProtoState::Established);
match msg {
FeMessage::Query(body) => {
// remove null terminator
let query_string = cstr_to_str(&body)?;
trace!("got query {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e);
let short_error = short_error(&e);
self.write_message_noflush(&BeMessage::ErrorResponse(
&short_error,
Some(e.pg_error_code()),
))?;
}
self.write_message_noflush(&BeMessage::ReadyForQuery)?;
}
FeMessage::Parse(m) => {
*unnamed_query_string = m.query_string;
self.write_message_noflush(&BeMessage::ParseComplete)?;
}
FeMessage::Describe(_) => {
self.write_message_noflush(&BeMessage::ParameterDescription)?
.write_message_noflush(&BeMessage::NoData)?;
}
FeMessage::Bind(_) => {
self.write_message_noflush(&BeMessage::BindComplete)?;
}
FeMessage::Close(_) => {
self.write_message_noflush(&BeMessage::CloseComplete)?;
}
FeMessage::Execute(_) => {
let query_string = cstr_to_str(unnamed_query_string)?;
trace!("got execute {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e);
self.write_message_noflush(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
}
// NOTE there is no ReadyForQuery message. This handler is used
// for basebackup and it uses CopyOut which doesn't require
// ReadyForQuery message and backend just switches back to
// processing mode after sending CopyDone or ErrorResponse.
}
FeMessage::Sync => {
self.write_message_noflush(&BeMessage::ReadyForQuery)?;
}
FeMessage::Terminate => {
return Ok(ProcessMsgResult::Break);
}
// We prefer explicit pattern matching to wildcards, because
// this helps us spot the places where new variants are missing
FeMessage::CopyData(_)
| FeMessage::CopyDone
| FeMessage::CopyFail
| FeMessage::PasswordMessage(_) => {
return Err(QueryError::Other(anyhow::anyhow!(
"unexpected message type: {msg:?}",
)));
}
}
Ok(ProcessMsgResult::Continue)
}
/// Log as info/error result of handling COPY stream and send back
/// ErrorResponse if that makes sense. Shutdown the stream if we got
/// Terminate. TODO: transition into waiting for Sync msg if we initiate the
/// close.
pub async fn handle_copy_stream_end(&mut self, end: CopyStreamHandlerEnd) {
use CopyStreamHandlerEnd::*;
let expected_end = match &end {
ServerInitiated(_) | CopyDone | CopyFail | Terminate | EOF => true,
CopyStreamHandlerEnd::Disconnected(ConnectionError::Io(io_error))
if is_expected_io_error(io_error) =>
{
true
}
_ => false,
};
if expected_end {
info!("terminated: {:#}", end);
} else {
error!("terminated: {:?}", end);
}
// Note: no current usages ever send this
if let CopyDone = &end {
if let Err(e) = self.write_message(&BeMessage::CopyDone).await {
error!("failed to send CopyDone: {}", e);
}
}
if let Terminate = &end {
self.state = ProtoState::Closed;
}
let err_to_send_and_errcode = match &end {
ServerInitiated(_) => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
Other(_) => Some((end.to_string(), SQLSTATE_INTERNAL_ERROR)),
// Note: CopyFail in duplex copy is somewhat unexpected (at least to
// PG walsender; evidently and per my docs reading client should
// finish it with CopyDone). It is not a problem to recover from it
// finishing the stream in both directions like we do, but note that
// sync rust-postgres client (which we don't use anymore) hangs if
// socket is not closed here.
// https://github.com/sfackler/rust-postgres/issues/755
// https://github.com/neondatabase/neon/issues/935
//
// Currently, the version of tokio_postgres replication patch we use
// sends this when it closes the stream (e.g. pageserver decided to
// switch conn to another safekeeper and client gets dropped).
// Moreover, seems like 'connection' task errors with 'unexpected
// message from server' when it receives ErrorResponse (anything but
// CopyData/CopyDone) back.
CopyFail => Some((end.to_string(), SQLSTATE_SUCCESSFUL_COMPLETION)),
_ => None,
};
if let Some((err, errcode)) = err_to_send_and_errcode {
if let Err(ee) = self
.write_message(&BeMessage::ErrorResponse(&err, Some(errcode)))
.await
{
error!("failed to send ErrorResponse: {}", ee);
}
}
}
}
pub struct PostgresBackendReader(FramedReader<MaybeTlsStream>);
impl PostgresBackendReader {
/// Read full message or return None if connection is cleanly closed with no
/// unprocessed data.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
let m = self.0.read_message().await?;
trace!("read msg {:?}", m);
Ok(m)
}
/// Get CopyData contents of the next message in COPY stream or error
/// closing it. The error type is wider than actual errors which can happen
/// here -- it includes 'Other' and 'ServerInitiated', but that's ok for
/// current callers.
pub async fn read_copy_message(&mut self) -> Result<Bytes, CopyStreamHandlerEnd> {
match self.read_message().await? {
Some(msg) => match msg {
FeMessage::CopyData(m) => Ok(m),
FeMessage::CopyDone => Err(CopyStreamHandlerEnd::CopyDone),
FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail),
FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate),
_ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol(
ProtocolError::Protocol(format!("unexpected message in COPY stream {:?}", msg)),
))),
},
None => Err(CopyStreamHandlerEnd::EOF),
}
}
}
///
/// A futures::AsyncWrite implementation that wraps all data written to it in CopyData
/// messages.
///
pub struct CopyDataWriter<'a> {
pgb: &'a mut PostgresBackend,
}
impl<'a> AsyncWrite for CopyDataWriter<'a> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
let this = self.get_mut();
// It's not strictly required to flush between each message, but makes it easier
// to view in wireshark, and usually the messages that the callers write are
// decently-sized anyway.
if let Err(err) = ready!(this.pgb.poll_flush(cx)) {
return Poll::Ready(Err(err));
}
// CopyData
// XXX: if the input is large, we should split it into multiple messages.
// Not sure what the threshold should be, but the ultimate hard limit is that
// the length cannot exceed u32.
this.pgb
.write_message_noflush(&BeMessage::CopyData(buf))
// write_message only writes to the buffer, so it can fail iff the
// message is invaid, but CopyData can't be invalid.
.map_err(|_| io::Error::new(ErrorKind::Other, "failed to serialize CopyData"))?;
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
this.pgb.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
this.pgb.poll_flush(cx)
}
}
pub fn short_error(e: &QueryError) -> String {
match e {
QueryError::Disconnected(connection_error) => connection_error.to_string(),
QueryError::Other(e) => format!("{e:#}"),
}
}
fn log_query_error(query: &str, e: &QueryError) {
match e {
QueryError::Disconnected(ConnectionError::Io(io_error)) => {
if is_expected_io_error(io_error) {
info!("query handler for '{query}' failed with expected io error: {io_error}");
} else {
error!("query handler for '{query}' failed with io error: {io_error}");
}
}
QueryError::Disconnected(other_connection_error) => {
error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
}
QueryError::Other(e) => {
error!("query handler for '{query}' failed: {e:?}");
}
}
}
/// Something finishing handling of COPY stream, see handle_copy_stream_end.
/// This is not always a real error, but it allows to use ? and thiserror impls.
#[derive(thiserror::Error, Debug)]
pub enum CopyStreamHandlerEnd {
/// Handler initiates the end of streaming.
#[error("{0}")]
ServerInitiated(String),
#[error("received CopyDone")]
CopyDone,
#[error("received CopyFail")]
CopyFail,
#[error("received Terminate")]
Terminate,
#[error("EOF on COPY stream")]
EOF,
/// The connection was lost
#[error(transparent)]
Disconnected(#[from] ConnectionError),
/// Some other error
#[error(transparent)]
Other(#[from] anyhow::Error),
}

View File

@@ -1,139 +0,0 @@
/// Test postgres_backend_async with tokio_postgres
use once_cell::sync::Lazy;
use postgres_backend::{AuthType, Handler, PostgresBackend, QueryError};
use pq_proto::{BeMessage, RowDescriptor};
use std::io::Cursor;
use std::{future, sync::Arc};
use tokio::net::{TcpListener, TcpStream};
use tokio_postgres::config::SslMode;
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::{Config, NoTls, SimpleQueryMessage};
use tokio_postgres_rustls::MakeRustlsConnect;
// generate client, server test streams
async fn make_tcp_pair() -> (TcpStream, TcpStream) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let client_stream = TcpStream::connect(addr).await.unwrap();
let (server_stream, _) = listener.accept().await.unwrap();
(client_stream, server_stream)
}
struct TestHandler {}
#[async_trait::async_trait]
impl Handler for TestHandler {
// return single col 'hey' for any query
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
_query_string: &str,
) -> Result<(), QueryError> {
pgb.write_message_noflush(&BeMessage::RowDescription(&[RowDescriptor::text_col(
b"hey",
)]))?
.write_message_noflush(&BeMessage::DataRow(&[Some("hey".as_bytes())]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
Ok(())
}
}
// test that basic select works
#[tokio::test]
async fn simple_select() {
let (client_sock, server_sock) = make_tcp_pair().await;
// create and run pgbackend
let pgbackend =
PostgresBackend::new(server_sock, AuthType::Trust, None).expect("pgbackend creation");
tokio::spawn(async move {
let mut handler = TestHandler {};
pgbackend.run(&mut handler, future::pending::<()>).await
});
let conf = Config::new();
let (client, connection) = conf.connect_raw(client_sock, NoTls).await.expect("connect");
// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
if let SimpleQueryMessage::Row(row) = first_val {
let first_col = row.get(0).expect("first column");
assert_eq!(first_col, "hey");
} else {
panic!("expected SimpleQueryMessage::Row");
}
}
static KEY: Lazy<rustls::PrivateKey> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("key.pem"));
rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone())
});
static CERT: Lazy<rustls::Certificate> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("cert.pem"));
rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone())
});
// test that basic select with ssl works
#[tokio::test]
async fn simple_select_ssl() {
let (client_sock, server_sock) = make_tcp_pair().await;
let server_cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![CERT.clone()], KEY.clone())
.unwrap();
let tls_config = Some(Arc::new(server_cfg));
let pgbackend =
PostgresBackend::new(server_sock, AuthType::Trust, tls_config).expect("pgbackend creation");
tokio::spawn(async move {
let mut handler = TestHandler {};
pgbackend.run(&mut handler, future::pending::<()>).await
});
let client_cfg = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates({
let mut store = rustls::RootCertStore::empty();
store.add(&CERT).unwrap();
store
})
.with_no_client_auth();
let mut make_tls_connect = tokio_postgres_rustls::MakeRustlsConnect::new(client_cfg);
let tls_connect = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::make_tls_connect(
&mut make_tls_connect,
"localhost",
)
.expect("make_tls_connect");
let mut conf = Config::new();
conf.ssl_mode(SslMode::Require);
let (client, connection) = conf
.connect_raw(client_sock, tls_connect)
.await
.expect("connect");
// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let first_val = &(client.simple_query("SELECT 42;").await.expect("select"))[0];
if let SimpleQueryMessage::Row(row) = first_val {
let first_col = row.get(0).expect("first column");
assert_eq!(first_col, "hey");
} else {
panic!("expected SimpleQueryMessage::Row");
}
}

View File

@@ -5,8 +5,8 @@ edition.workspace = true
license.workspace = true
[dependencies]
anyhow.workspace = true
bytes.workspace = true
byteorder.workspace = true
pin-project-lite.workspace = true
postgres-protocol.workspace = true
rand.workspace = true

View File

@@ -1,244 +0,0 @@
//! Provides `Framed` -- writing/flushing and reading Postgres messages to/from
//! the async stream based on (and buffered with) BytesMut. All functions are
//! cancellation safe.
//!
//! It is similar to what tokio_util::codec::Framed with appropriate codec
//! provides, but `FramedReader` and `FramedWriter` read/write parts can be used
//! separately without using split from futures::stream::StreamExt (which
//! allocates box[1] in polling internally). tokio::io::split is used for splitting
//! instead. Plus we customize error messages more than a single type for all io
//! calls.
//!
//! [1] https://docs.rs/futures-util/0.3.26/src/futures_util/lock/bilock.rs.html#107
use bytes::{Buf, BytesMut};
use std::{
future::Future,
io::{self, ErrorKind},
};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf};
use crate::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
const INITIAL_CAPACITY: usize = 8 * 1024;
/// Error on postgres connection: either IO (physical transport error) or
/// protocol violation.
#[derive(thiserror::Error, Debug)]
pub enum ConnectionError {
#[error(transparent)]
Io(#[from] io::Error),
#[error(transparent)]
Protocol(#[from] ProtocolError),
}
impl ConnectionError {
/// Proxy stream.rs uses only io::Error; provide it.
pub fn into_io_error(self) -> io::Error {
match self {
ConnectionError::Io(io) => io,
ConnectionError::Protocol(pe) => io::Error::new(io::ErrorKind::Other, pe.to_string()),
}
}
}
/// Wraps async io `stream`, providing messages to write/flush + read Postgres
/// messages.
pub struct Framed<S> {
stream: S,
read_buf: BytesMut,
write_buf: BytesMut,
}
impl<S> Framed<S> {
pub fn new(stream: S) -> Self {
Self {
stream,
read_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
write_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
}
}
/// Get a shared reference to the underlying stream.
pub fn get_ref(&self) -> &S {
&self.stream
}
/// Extract the underlying stream.
pub fn into_inner(self) -> S {
self.stream
}
/// Return new Framed with stream type transformed by async f, for TLS
/// upgrade.
pub async fn map_stream<S2, E, F, Fut>(self, f: F) -> Result<Framed<S2>, E>
where
F: FnOnce(S) -> Fut,
Fut: Future<Output = Result<S2, E>>,
{
let stream = f(self.stream).await?;
Ok(Framed {
stream,
read_buf: self.read_buf,
write_buf: self.write_buf,
})
}
}
impl<S: AsyncRead + Unpin> Framed<S> {
pub async fn read_startup_message(
&mut self,
) -> Result<Option<FeStartupPacket>, ConnectionError> {
read_message(&mut self.stream, &mut self.read_buf, FeStartupPacket::parse).await
}
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
}
}
impl<S: AsyncWrite + Unpin> Framed<S> {
/// Write next message to the output buffer; doesn't flush.
pub fn write_message(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
BeMessage::write(&mut self.write_buf, msg)
}
/// Flush out the buffer. This function is cancellation safe: it can be
/// interrupted and flushing will be continued in the next call.
pub async fn flush(&mut self) -> Result<(), io::Error> {
flush(&mut self.stream, &mut self.write_buf).await
}
/// Flush out the buffer and shutdown the stream.
pub async fn shutdown(&mut self) -> Result<(), io::Error> {
shutdown(&mut self.stream, &mut self.write_buf).await
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> Framed<S> {
/// Split into owned read and write parts. Beware of potential issues with
/// using halves in different tasks on TLS stream:
/// https://github.com/tokio-rs/tls/issues/40
pub fn split(self) -> (FramedReader<S>, FramedWriter<S>) {
let (read_half, write_half) = tokio::io::split(self.stream);
let reader = FramedReader {
stream: read_half,
read_buf: self.read_buf,
};
let writer = FramedWriter {
stream: write_half,
write_buf: self.write_buf,
};
(reader, writer)
}
/// Join read and write parts back.
pub fn unsplit(reader: FramedReader<S>, writer: FramedWriter<S>) -> Self {
Self {
stream: reader.stream.unsplit(writer.stream),
read_buf: reader.read_buf,
write_buf: writer.write_buf,
}
}
}
/// Read-only version of `Framed`.
pub struct FramedReader<S> {
stream: ReadHalf<S>,
read_buf: BytesMut,
}
impl<S: AsyncRead + Unpin> FramedReader<S> {
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, ConnectionError> {
read_message(&mut self.stream, &mut self.read_buf, FeMessage::parse).await
}
}
/// Write-only version of `Framed`.
pub struct FramedWriter<S> {
stream: WriteHalf<S>,
write_buf: BytesMut,
}
impl<S: AsyncWrite + Unpin> FramedWriter<S> {
/// Write next message to the output buffer; doesn't flush.
pub fn write_message_noflush(&mut self, msg: &BeMessage<'_>) -> Result<(), ProtocolError> {
BeMessage::write(&mut self.write_buf, msg)
}
/// Flush out the buffer. This function is cancellation safe: it can be
/// interrupted and flushing will be continued in the next call.
pub async fn flush(&mut self) -> Result<(), io::Error> {
flush(&mut self.stream, &mut self.write_buf).await
}
/// Flush out the buffer and shutdown the stream.
pub async fn shutdown(&mut self) -> Result<(), io::Error> {
shutdown(&mut self.stream, &mut self.write_buf).await
}
}
/// Read next message from the stream. Returns Ok(None), if EOF happened and we
/// don't have remaining data in the buffer. This function is cancellation safe:
/// you can drop future which is not yet complete and finalize reading message
/// with the next call.
///
/// Parametrized to allow reading startup or usual message, having different
/// format.
async fn read_message<S: AsyncRead + Unpin, M, P>(
stream: &mut S,
read_buf: &mut BytesMut,
parse: P,
) -> Result<Option<M>, ConnectionError>
where
P: Fn(&mut BytesMut) -> Result<Option<M>, ProtocolError>,
{
loop {
if let Some(msg) = parse(read_buf)? {
return Ok(Some(msg));
}
// If we can't build a frame yet, try to read more data and try again.
// Make sure we've got room for at least one byte to read to ensure
// that we don't get a spurious 0 that looks like EOF.
read_buf.reserve(1);
if stream.read_buf(read_buf).await? == 0 {
if read_buf.has_remaining() {
return Err(io::Error::new(
ErrorKind::UnexpectedEof,
"EOF with unprocessed data in the buffer",
)
.into());
} else {
return Ok(None); // clean EOF
}
}
}
}
async fn flush<S: AsyncWrite + Unpin>(
stream: &mut S,
write_buf: &mut BytesMut,
) -> Result<(), io::Error> {
while write_buf.has_remaining() {
let bytes_written = stream.write(write_buf.chunk()).await?;
if bytes_written == 0 {
return Err(io::Error::new(
ErrorKind::WriteZero,
"failed to write message",
));
}
// The advanced part will be garbage collected, likely during shifting
// data left on next attempt to write to buffer when free space is not
// enough.
write_buf.advance(bytes_written);
}
write_buf.clear();
stream.flush().await
}
async fn shutdown<S: AsyncWrite + Unpin>(
stream: &mut S,
write_buf: &mut BytesMut,
) -> Result<(), io::Error> {
flush(stream, write_buf).await?;
stream.shutdown().await
}

View File

@@ -2,18 +2,24 @@
//! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
//! on message formats.
pub mod framed;
// Tools for calling certain async methods in sync contexts.
pub mod sync;
use byteorder::{BigEndian, ReadBytesExt};
use anyhow::{ensure, Context, Result};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use postgres_protocol::PG_EPOCH;
use serde::{Deserialize, Serialize};
use std::{
borrow::Cow,
collections::HashMap,
fmt, io, str,
fmt,
future::Future,
io::{self, Cursor},
str,
time::{Duration, SystemTime},
};
use sync::{AsyncishRead, SyncFuture};
use tokio::io::AsyncReadExt;
use tracing::{trace, warn};
pub type Oid = u32;
@@ -25,6 +31,7 @@ pub const TEXT_OID: Oid = 25;
#[derive(Debug)]
pub enum FeMessage {
StartupPacket(FeStartupPacket),
// Simple query.
Query(Bytes),
// Extended query protocol.
@@ -184,207 +191,260 @@ pub struct FeExecuteMessage {
#[derive(Debug)]
pub struct FeCloseMessage;
/// An error occured while parsing or serializing raw stream into Postgres
/// messages.
#[derive(thiserror::Error, Debug)]
pub enum ProtocolError {
/// Invalid packet was received from the client (e.g. unexpected message
/// type or broken len).
#[error("Protocol error: {0}")]
Protocol(String),
/// Failed to parse or, (unlikely), serialize a protocol message.
#[error("Message parse error: {0}")]
BadMessage(String),
/// Retry a read on EINTR
///
/// This runs the enclosed expression, and if it returns
/// Err(io::ErrorKind::Interrupted), retries it.
macro_rules! retry_read {
( $x:expr ) => {
loop {
match $x {
Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
res => break res,
}
}
};
}
impl ProtocolError {
/// Proxy stream.rs uses only io::Error; provide it.
/// An error occured during connection being open.
#[derive(thiserror::Error, Debug)]
pub enum ConnectionError {
/// IO error during writing to or reading from the connection socket.
#[error("Socket IO error: {0}")]
Socket(std::io::Error),
/// Invalid packet was received from client
#[error("Protocol error: {0}")]
Protocol(String),
/// Failed to parse a protocol mesage
#[error("Message parse error: {0}")]
MessageParse(anyhow::Error),
}
impl From<anyhow::Error> for ConnectionError {
fn from(e: anyhow::Error) -> Self {
Self::MessageParse(e)
}
}
impl ConnectionError {
pub fn into_io_error(self) -> io::Error {
io::Error::new(io::ErrorKind::Other, self.to_string())
match self {
ConnectionError::Socket(io) => io,
other => io::Error::new(io::ErrorKind::Other, other.to_string()),
}
}
}
impl FeMessage {
/// Read and parse one message from the `buf` input buffer. If there is at
/// least one valid message, returns it, advancing `buf`; redundant copies
/// are avoided, as thanks to `bytes` crate ptrs in parsed message point
/// directly into the `buf` (processed data is garbage collected after
/// parsed message is dropped).
/// Read one message from the stream.
/// This function returns `Ok(None)` in case of EOF.
/// One way to handle this properly:
///
/// Returns None if `buf` doesn't contain enough data for a single message.
/// For efficiency, tries to reserve large enough space in `buf` for the
/// next message in this case to save the repeated calls.
/// ```
/// # use std::io;
/// # use pq_proto::FeMessage;
/// #
/// # fn process_message(msg: FeMessage) -> anyhow::Result<()> {
/// # Ok(())
/// # };
/// #
/// fn do_the_job(stream: &mut (impl io::Read + Unpin)) -> anyhow::Result<()> {
/// while let Some(msg) = FeMessage::read(stream)? {
/// process_message(msg)?;
/// }
///
/// Returns Error if message is malformed, the only possible ErrorKind is
/// InvalidInput.
//
// Inspired by rust-postgres Message::parse.
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeMessage>, ProtocolError> {
// Every message contains message type byte and 4 bytes len; can't do
// much without them.
if buf.len() < 5 {
let to_read = 5 - buf.len();
buf.reserve(to_read);
return Ok(None);
}
/// Ok(())
/// }
/// ```
#[inline(never)]
pub fn read(
stream: &mut (impl io::Read + Unpin),
) -> Result<Option<FeMessage>, ConnectionError> {
Self::read_fut(&mut AsyncishRead(stream)).wait()
}
// We shouldn't advance `buf` as probably full message is not there yet,
// so can't directly use Bytes::get_u32 etc.
let tag = buf[0];
let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
if len < 4 {
return Err(ProtocolError::Protocol(format!(
"invalid message length {}",
len
)));
}
/// Read one message from the stream.
/// See documentation for `Self::read`.
pub fn read_fut<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = Result<Option<FeMessage>, ConnectionError>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
// We return a Future that's sync (has a `wait` method) if and only if the provided stream is SyncProof.
// SyncFuture contract: we are only allowed to await on sync-proof futures, the AsyncRead and
// AsyncReadExt methods of the stream.
SyncFuture::new(async move {
// Each libpq message begins with a message type byte, followed by message length
// If the client closes the connection, return None. But if the client closes the
// connection in the middle of a message, we will return an error.
let tag = match retry_read!(stream.read_u8().await) {
Ok(b) => b,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ConnectionError::Socket(e)),
};
// length field includes itself, but not message type.
let total_len = len as usize + 1;
if buf.len() < total_len {
// Don't have full message yet.
let to_read = total_len - buf.len();
buf.reserve(to_read);
return Ok(None);
}
// The message length includes itself, so it better be at least 4.
let len = retry_read!(stream.read_u32().await)
.map_err(ConnectionError::Socket)?
.checked_sub(4)
.ok_or_else(|| ConnectionError::Protocol("invalid message length".to_string()))?;
// got the message, advance buffer
let mut msg = buf.split_to(total_len).freeze();
msg.advance(5); // consume message type and len
let body = {
let mut buffer = vec![0u8; len as usize];
stream
.read_exact(&mut buffer)
.await
.map_err(ConnectionError::Socket)?;
Bytes::from(buffer)
};
match tag {
b'Q' => Ok(Some(FeMessage::Query(msg))),
b'P' => Ok(Some(FeParseMessage::parse(msg)?)),
b'D' => Ok(Some(FeDescribeMessage::parse(msg)?)),
b'E' => Ok(Some(FeExecuteMessage::parse(msg)?)),
b'B' => Ok(Some(FeBindMessage::parse(msg)?)),
b'C' => Ok(Some(FeCloseMessage::parse(msg)?)),
b'S' => Ok(Some(FeMessage::Sync)),
b'X' => Ok(Some(FeMessage::Terminate)),
b'd' => Ok(Some(FeMessage::CopyData(msg))),
b'c' => Ok(Some(FeMessage::CopyDone)),
b'f' => Ok(Some(FeMessage::CopyFail)),
b'p' => Ok(Some(FeMessage::PasswordMessage(msg))),
tag => {
return Err(ProtocolError::Protocol(format!(
"unknown message tag: {tag},'{msg:?}'"
)))
match tag {
b'Q' => Ok(Some(FeMessage::Query(body))),
b'P' => Ok(Some(FeParseMessage::parse(body)?)),
b'D' => Ok(Some(FeDescribeMessage::parse(body)?)),
b'E' => Ok(Some(FeExecuteMessage::parse(body)?)),
b'B' => Ok(Some(FeBindMessage::parse(body)?)),
b'C' => Ok(Some(FeCloseMessage::parse(body)?)),
b'S' => Ok(Some(FeMessage::Sync)),
b'X' => Ok(Some(FeMessage::Terminate)),
b'd' => Ok(Some(FeMessage::CopyData(body))),
b'c' => Ok(Some(FeMessage::CopyDone)),
b'f' => Ok(Some(FeMessage::CopyFail)),
b'p' => Ok(Some(FeMessage::PasswordMessage(body))),
tag => {
return Err(ConnectionError::Protocol(format!(
"unknown message tag: {tag},'{body:?}'"
)))
}
}
}
})
}
}
impl FeStartupPacket {
/// Read and parse startup message from the `buf` input buffer. It is
/// different from [`FeMessage::parse`] because startup messages don't have
/// message type byte; otherwise, its comments apply.
pub fn parse(buf: &mut BytesMut) -> Result<Option<FeStartupPacket>, ProtocolError> {
/// Read startup message from the stream.
// XXX: It's tempting yet undesirable to accept `stream` by value,
// since such a change will cause user-supplied &mut references to be consumed
pub fn read(
stream: &mut (impl io::Read + Unpin),
) -> Result<Option<FeMessage>, ConnectionError> {
Self::read_fut(&mut AsyncishRead(stream)).wait()
}
/// Read startup message from the stream.
// XXX: It's tempting yet undesirable to accept `stream` by value,
// since such a change will cause user-supplied &mut references to be consumed
pub fn read_fut<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = Result<Option<FeMessage>, ConnectionError>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
const MAX_STARTUP_PACKET_LENGTH: usize = 10000;
const RESERVED_INVALID_MAJOR_VERSION: u32 = 1234;
const CANCEL_REQUEST_CODE: u32 = 5678;
const NEGOTIATE_SSL_CODE: u32 = 5679;
const NEGOTIATE_GSS_CODE: u32 = 5680;
// need at least 4 bytes with packet len
if buf.len() < 4 {
let to_read = 4 - buf.len();
buf.reserve(to_read);
return Ok(None);
}
SyncFuture::new(async move {
// Read length. If the connection is closed before reading anything (or before
// reading 4 bytes, to be precise), return None to indicate that the connection
// was closed. This matches the PostgreSQL server's behavior, which avoids noise
// in the log if the client opens connection but closes it immediately.
let len = match retry_read!(stream.read_u32().await) {
Ok(len) => len as usize,
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ConnectionError::Socket(e)),
};
// We shouldn't advance `buf` as probably full message is not there yet,
// so can't directly use Bytes::get_u32 etc.
let len = (&buf[0..4]).read_u32::<BigEndian>().unwrap() as usize;
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
return Err(ProtocolError::Protocol(format!(
"invalid startup packet message length {}",
len
)));
}
if buf.len() < len {
// Don't have full message yet.
let to_read = len - buf.len();
buf.reserve(to_read);
return Ok(None);
}
// got the message, advance buffer
let mut msg = buf.split_to(len).freeze();
msg.advance(4); // consume len
let request_code = msg.get_u32();
let req_hi = request_code >> 16;
let req_lo = request_code & ((1 << 16) - 1);
// StartupMessage, CancelRequest, SSLRequest etc are differentiated by request code.
let message = match (req_hi, req_lo) {
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
if msg.remaining() != 8 {
return Err(ProtocolError::BadMessage(
"CancelRequest message is malformed, backend PID / secret key missing"
.to_owned(),
));
}
FeStartupPacket::CancelRequest(CancelKeyData {
backend_pid: msg.get_i32(),
cancel_key: msg.get_i32(),
})
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
// Requested upgrade to SSL (aka TLS)
FeStartupPacket::SslRequest
}
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
// Requested upgrade to GSSAPI
FeStartupPacket::GssEncRequest
}
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
return Err(ProtocolError::Protocol(format!(
"Unrecognized request code {unrecognized_code}"
#[allow(clippy::manual_range_contains)]
if len < 4 || len > MAX_STARTUP_PACKET_LENGTH {
return Err(ConnectionError::Protocol(format!(
"invalid message length {len}"
)));
}
// TODO bail if protocol major_version is not 3?
(major_version, minor_version) => {
// StartupMessage
// Parse pairs of null-terminated strings (key, value).
// See `postgres: ProcessStartupPacket, build_startup_packet`.
let mut tokens = str::from_utf8(&msg)
.map_err(|_e| {
ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
})?
.strip_suffix('\0') // drop packet's own null
.ok_or_else(|| {
ProtocolError::Protocol(
"StartupMessage params: missing null terminator".to_string(),
)
})?
.split_terminator('\0');
let request_code =
retry_read!(stream.read_u32().await).map_err(ConnectionError::Socket)?;
let mut params = HashMap::new();
while let Some(name) = tokens.next() {
let value = tokens.next().ok_or_else(|| {
ProtocolError::Protocol(
"StartupMessage params: key without value".to_string(),
)
})?;
// the rest of startup packet are params
let params_len = len - 8;
let mut params_bytes = vec![0u8; params_len];
stream
.read_exact(params_bytes.as_mut())
.await
.map_err(ConnectionError::Socket)?;
params.insert(name.to_owned(), value.to_owned());
// Parse params depending on request code
let req_hi = request_code >> 16;
let req_lo = request_code & ((1 << 16) - 1);
let message = match (req_hi, req_lo) {
(RESERVED_INVALID_MAJOR_VERSION, CANCEL_REQUEST_CODE) => {
if params_len != 8 {
return Err(ConnectionError::Protocol(
"expected 8 bytes for CancelRequest params".to_string(),
));
}
let mut cursor = Cursor::new(params_bytes);
FeStartupPacket::CancelRequest(CancelKeyData {
backend_pid: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
cancel_key: cursor.read_i32().await.map_err(ConnectionError::Socket)?,
})
}
FeStartupPacket::StartupMessage {
major_version,
minor_version,
params: StartupMessageParams { params },
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_SSL_CODE) => {
// Requested upgrade to SSL (aka TLS)
FeStartupPacket::SslRequest
}
}
};
Ok(Some(message))
(RESERVED_INVALID_MAJOR_VERSION, NEGOTIATE_GSS_CODE) => {
// Requested upgrade to GSSAPI
FeStartupPacket::GssEncRequest
}
(RESERVED_INVALID_MAJOR_VERSION, unrecognized_code) => {
return Err(ConnectionError::Protocol(format!(
"Unrecognized request code {unrecognized_code}"
)));
}
// TODO bail if protocol major_version is not 3?
(major_version, minor_version) => {
// Parse pairs of null-terminated strings (key, value).
// See `postgres: ProcessStartupPacket, build_startup_packet`.
let mut tokens = str::from_utf8(&params_bytes)
.context("StartupMessage params: invalid utf-8")?
.strip_suffix('\0') // drop packet's own null
.ok_or_else(|| {
ConnectionError::Protocol(
"StartupMessage params: missing null terminator".to_string(),
)
})?
.split_terminator('\0');
let mut params = HashMap::new();
while let Some(name) = tokens.next() {
let value = tokens.next().ok_or_else(|| {
ConnectionError::Protocol(
"StartupMessage params: key without value".to_string(),
)
})?;
params.insert(name.to_owned(), value.to_owned());
}
FeStartupPacket::StartupMessage {
major_version,
minor_version,
params: StartupMessageParams { params },
}
}
};
Ok(Some(FeMessage::StartupPacket(message)))
})
}
}
impl FeParseMessage {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
// FIXME: the rust-postgres driver uses a named prepared statement
// for copy_out(). We're not prepared to handle that correctly. For
// now, just ignore the statement name, assuming that the client never
@@ -392,82 +452,55 @@ impl FeParseMessage {
let _pstmt_name = read_cstr(&mut buf)?;
let query_string = read_cstr(&mut buf)?;
if buf.remaining() < 2 {
return Err(ProtocolError::BadMessage(
"Parse message is malformed, nparams missing".to_string(),
));
}
let nparams = buf.get_i16();
if nparams != 0 {
return Err(ProtocolError::BadMessage(
"query params not implemented".to_string(),
));
}
ensure!(nparams == 0, "query params not implemented");
Ok(FeMessage::Parse(FeParseMessage { query_string }))
}
}
impl FeDescribeMessage {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let kind = buf.get_u8();
let _pstmt_name = read_cstr(&mut buf)?;
// FIXME: see FeParseMessage::parse
if kind != b'S' {
return Err(ProtocolError::BadMessage(
"only prepared statemement Describe is implemented".to_string(),
));
}
ensure!(
kind == b'S',
"only prepared statemement Describe is implemented"
);
Ok(FeMessage::Describe(FeDescribeMessage { kind }))
}
}
impl FeExecuteMessage {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let portal_name = read_cstr(&mut buf)?;
if buf.remaining() < 4 {
return Err(ProtocolError::BadMessage(
"FeExecuteMessage message is malformed, maxrows missing".to_string(),
));
}
let maxrows = buf.get_i32();
if !portal_name.is_empty() {
return Err(ProtocolError::BadMessage(
"named portals not implemented".to_string(),
));
}
if maxrows != 0 {
return Err(ProtocolError::BadMessage(
"row limit in Execute message not implemented".to_string(),
));
}
ensure!(portal_name.is_empty(), "named portals not implemented");
ensure!(maxrows == 0, "row limit in Execute message not implemented");
Ok(FeMessage::Execute(FeExecuteMessage { maxrows }))
}
}
impl FeBindMessage {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let portal_name = read_cstr(&mut buf)?;
let _pstmt_name = read_cstr(&mut buf)?;
// FIXME: see FeParseMessage::parse
if !portal_name.is_empty() {
return Err(ProtocolError::BadMessage(
"named portals not implemented".to_string(),
));
}
ensure!(portal_name.is_empty(), "named portals not implemented");
Ok(FeMessage::Bind(FeBindMessage))
}
}
impl FeCloseMessage {
fn parse(mut buf: Bytes) -> Result<FeMessage, ProtocolError> {
fn parse(mut buf: Bytes) -> anyhow::Result<FeMessage> {
let _kind = buf.get_u8();
let _pstmt_or_portal_name = read_cstr(&mut buf)?;
@@ -496,7 +529,6 @@ pub enum BeMessage<'a> {
CloseComplete,
// None means column is NULL
DataRow(&'a [Option<&'a [u8]>]),
// None errcode means internal_error will be sent.
ErrorResponse(&'a str, Option<&'a [u8; 5]>),
/// Single byte - used in response to SSLRequest/GSSENCRequest.
EncryptionResponse(bool),
@@ -527,11 +559,6 @@ impl<'a> BeMessage<'a> {
value: b"UTF8",
};
pub const INTEGER_DATETIMES: Self = Self::ParameterStatus {
name: b"integer_datetimes",
value: b"on",
};
/// Build a [`BeMessage::ParameterStatus`] holding the server version.
pub fn server_version(version: &'a str) -> Self {
Self::ParameterStatus {
@@ -610,7 +637,7 @@ impl RowDescriptor<'_> {
#[derive(Debug)]
pub struct XLogDataBody<'a> {
pub wal_start: u64,
pub wal_end: u64, // current end of WAL on the server
pub wal_end: u64,
pub timestamp: i64,
pub data: &'a [u8],
}
@@ -650,11 +677,12 @@ fn write_body<R>(buf: &mut BytesMut, f: impl FnOnce(&mut BytesMut) -> R) -> R {
}
/// Safe write of s into buf as cstring (String in the protocol).
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolError> {
fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> io::Result<()> {
let bytes = s.as_ref();
if bytes.contains(&0) {
return Err(ProtocolError::BadMessage(
"string contains embedded null".to_owned(),
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"string contains embedded null",
));
}
buf.put_slice(bytes);
@@ -662,27 +690,22 @@ fn write_cstr(s: impl AsRef<[u8]>, buf: &mut BytesMut) -> Result<(), ProtocolErr
Ok(())
}
/// Read cstring from buf, advancing it.
fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
let pos = buf
.iter()
.position(|x| *x == 0)
.ok_or_else(|| ProtocolError::BadMessage("missing cstring terminator".to_owned()))?;
let result = buf.split_to(pos);
fn read_cstr(buf: &mut Bytes) -> anyhow::Result<Bytes> {
let pos = buf.iter().position(|x| *x == 0);
let result = buf.split_to(pos.context("missing terminator")?);
buf.advance(1); // drop the null terminator
Ok(result)
}
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
impl<'a> BeMessage<'a> {
/// Serialize `message` to the given `buf`.
/// Apart from smart memory managemet, BytesMut is good here as msg len
/// precedes its body and it is handy to write it down first and then fill
/// the length. With Write we would have to either calc it manually or have
/// one more buffer.
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
/// Write message to the given buf.
// Unlike the reading side, we use BytesMut
// here as msg len precedes its body and it is handy to write it down first
// and then fill the length. With Write we would have to either calc it
// manually or have one more buffer.
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> io::Result<()> {
match message {
BeMessage::AuthenticationOk => {
buf.put_u8(b'R');
@@ -727,7 +750,7 @@ impl<'a> BeMessage<'a> {
buf.put_slice(extra);
}
}
Ok(())
Ok::<_, io::Error>(())
})?;
}
@@ -831,7 +854,7 @@ impl<'a> BeMessage<'a> {
write_cstr(error_msg, buf)?;
buf.put_u8(0); // terminator
Ok(())
Ok::<_, io::Error>(())
})?;
}
@@ -854,7 +877,7 @@ impl<'a> BeMessage<'a> {
write_cstr(error_msg.as_bytes(), buf)?;
buf.put_u8(0); // terminator
Ok(())
Ok::<_, io::Error>(())
})?;
}
@@ -909,7 +932,7 @@ impl<'a> BeMessage<'a> {
buf.put_i32(-1); /* typmod */
buf.put_i16(0); /* format code */
}
Ok(())
Ok::<_, io::Error>(())
})?;
}
@@ -976,7 +999,7 @@ impl ReplicationFeedback {
// null-terminated string - key,
// uint32 - value length in bytes
// value itself
pub fn serialize(&self, buf: &mut BytesMut) {
pub fn serialize(&self, buf: &mut BytesMut) -> Result<()> {
buf.put_u8(REPLICATION_FEEDBACK_FIELDS_NUMBER); // # of keys
buf.put_slice(b"current_timeline_size\0");
buf.put_i32(8);
@@ -1001,6 +1024,7 @@ impl ReplicationFeedback {
buf.put_slice(b"ps_replytime\0");
buf.put_i32(8);
buf.put_i64(timestamp);
Ok(())
}
// Deserialize ReplicationFeedback message
@@ -1068,7 +1092,7 @@ mod tests {
// because it is rounded up to microseconds during serialization.
rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
let mut data = BytesMut::new();
rf.serialize(&mut data);
rf.serialize(&mut data).unwrap();
let rf_parsed = ReplicationFeedback::parse(data.freeze());
assert_eq!(rf, rf_parsed);
@@ -1083,7 +1107,7 @@ mod tests {
// because it is rounded up to microseconds during serialization.
rf.ps_replytime = *PG_EPOCH + Duration::from_secs(100_000_000);
let mut data = BytesMut::new();
rf.serialize(&mut data);
rf.serialize(&mut data).unwrap();
// Add an extra field to the buffer and adjust number of keys
if let Some(first) = data.first_mut() {
@@ -1125,6 +1149,15 @@ mod tests {
let params = make_params("foo\\ bar \\ \\\\ baz\\ lol");
assert_eq!(split_options(&params), ["foo bar", " \\", "baz ", "lol"]);
}
// Make sure that `read` is sync/async callable
async fn _assert(stream: &mut (impl tokio::io::AsyncRead + Unpin)) {
let _ = FeMessage::read(&mut [].as_ref());
let _ = FeMessage::read_fut(stream).await;
let _ = FeStartupPacket::read(&mut [].as_ref());
let _ = FeStartupPacket::read_fut(stream).await;
}
}
fn terminate_code(code: &[u8; 5]) -> [u8; 6] {

179
libs/pq_proto/src/sync.rs Normal file
View File

@@ -0,0 +1,179 @@
use pin_project_lite::pin_project;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::{io, task};
pin_project! {
/// We use this future to mark certain methods
/// as callable in both sync and async modes.
#[repr(transparent)]
pub struct SyncFuture<S, T: Future> {
#[pin]
inner: T,
_marker: PhantomData<S>,
}
}
/// This wrapper lets us synchronously wait for inner future's completion
/// (see [`SyncFuture::wait`]) **provided that `S` implements [`SyncProof`]**.
/// For instance, `S` may be substituted with types implementing
/// [`tokio::io::AsyncRead`], but it's not the only viable option.
impl<S, T: Future> SyncFuture<S, T> {
/// NOTE: caller should carefully pick a type for `S`,
/// because we don't want to enable [`SyncFuture::wait`] when
/// it's in fact impossible to run the future synchronously.
/// Violation of this contract will not cause UB, but
/// panics and async event loop freezes won't please you.
///
/// Example:
///
/// ```
/// # use pq_proto::sync::SyncFuture;
/// # use std::future::Future;
/// # use tokio::io::AsyncReadExt;
/// #
/// // Parse a pair of numbers from a stream
/// pub fn parse_pair<Reader>(
/// stream: &mut Reader,
/// ) -> SyncFuture<Reader, impl Future<Output = anyhow::Result<(u32, u64)>> + '_>
/// where
/// Reader: tokio::io::AsyncRead + Unpin,
/// {
/// // If `Reader` is a `SyncProof`, this will give caller
/// // an opportunity to use `SyncFuture::wait`, because
/// // `.await` will always result in `Poll::Ready`.
/// SyncFuture::new(async move {
/// let x = stream.read_u32().await?;
/// let y = stream.read_u64().await?;
/// Ok((x, y))
/// })
/// }
/// ```
pub fn new(inner: T) -> Self {
Self {
inner,
_marker: PhantomData,
}
}
}
impl<S, T: Future> Future for SyncFuture<S, T> {
type Output = T::Output;
/// In async code, [`SyncFuture`] behaves like a regular wrapper.
#[inline(always)]
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
self.project().inner.poll(cx)
}
}
/// Postulates that we can call [`SyncFuture::wait`].
/// If implementer is also a [`Future`], it should always
/// return [`task::Poll::Ready`] from [`Future::poll`].
///
/// Each implementation should document which futures
/// specifically are being declared sync-proof.
pub trait SyncPostulate {}
impl<T: SyncPostulate> SyncPostulate for &T {}
impl<T: SyncPostulate> SyncPostulate for &mut T {}
impl<P: SyncPostulate, T: Future> SyncFuture<P, T> {
/// Synchronously wait for future completion.
pub fn wait(mut self) -> T::Output {
const RAW_WAKER: task::RawWaker = task::RawWaker::new(
std::ptr::null(),
&task::RawWakerVTable::new(
|_| RAW_WAKER,
|_| panic!("SyncFuture: failed to wake"),
|_| panic!("SyncFuture: failed to wake by ref"),
|_| { /* drop is no-op */ },
),
);
// SAFETY: We never move `self` during this call;
// furthermore, it will be dropped in the end regardless of panics
let this = unsafe { Pin::new_unchecked(&mut self) };
// SAFETY: This waker doesn't do anything apart from panicking
let waker = unsafe { task::Waker::from_raw(RAW_WAKER) };
let context = &mut task::Context::from_waker(&waker);
match this.poll(context) {
task::Poll::Ready(res) => res,
_ => panic!("SyncFuture: unexpected pending!"),
}
}
}
/// This wrapper turns any [`std::io::Read`] into a blocking [`tokio::io::AsyncRead`],
/// which lets us abstract over sync & async readers in methods returning [`SyncFuture`].
/// NOTE: you **should not** use this in async code.
#[repr(transparent)]
pub struct AsyncishRead<T: io::Read + Unpin>(pub T);
/// This lets us call [`SyncFuture<AsyncishRead<_>, _>::wait`],
/// and allows the future to await on any of the [`AsyncRead`]
/// and [`AsyncReadExt`] methods on `AsyncishRead`.
impl<T: io::Read + Unpin> SyncPostulate for AsyncishRead<T> {}
impl<T: io::Read + Unpin> tokio::io::AsyncRead for AsyncishRead<T> {
#[inline(always)]
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
task::Poll::Ready(
// `Read::read` will block, meaning we don't need a real event loop!
self.0
.read(buf.initialize_unfilled())
.map(|sz| buf.advance(sz)),
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
// async helper(stream: &mut impl AsyncRead) -> io::Result<u32>
fn bytes_add<Reader>(
stream: &mut Reader,
) -> SyncFuture<Reader, impl Future<Output = io::Result<u32>> + '_>
where
Reader: tokio::io::AsyncRead + Unpin,
{
SyncFuture::new(async move {
let a = stream.read_u32().await?;
let b = stream.read_u32().await?;
Ok(a + b)
})
}
#[test]
fn test_sync() {
let bytes = [100u32.to_be_bytes(), 200u32.to_be_bytes()].concat();
let res = bytes_add(&mut AsyncishRead(&mut &bytes[..]))
.wait()
.unwrap();
assert_eq!(res, 300);
}
// We need a single-threaded executor for this test
#[tokio::test(flavor = "current_thread")]
async fn test_async() {
let (mut tx, mut rx) = tokio::net::UnixStream::pair().unwrap();
let write = async move {
tx.write_u32(100).await?;
tx.write_u32(200).await?;
Ok(())
};
let (res, ()) = tokio::try_join!(bytes_add(&mut rx), write).unwrap();
assert_eq!(res, 300);
}
}

View File

@@ -111,7 +111,7 @@ pub trait RemoteStorage: Send + Sync + 'static {
}
pub struct Download {
pub download_stream: Pin<Box<dyn io::AsyncRead + Unpin + Send + Sync>>,
pub download_stream: Pin<Box<dyn io::AsyncRead + Unpin + Send>>,
/// Extra key-value data, associated with the current remote file.
pub metadata: Option<StorageMetadata>,
}

View File

@@ -12,36 +12,41 @@ anyhow.workspace = true
bincode.workspace = true
bytes.workspace = true
heapless.workspace = true
hex = { workspace = true, features = ["serde"] }
hyper = { workspace = true, features = ["full"] }
futures = { workspace = true}
jsonwebtoken.workspace = true
nix.workspace = true
once_cell.workspace = true
routerify.workspace = true
serde.workspace = true
serde_json.workspace = true
signal-hook.workspace = true
thiserror.workspace = true
tokio.workspace = true
tokio-rustls.workspace = true
tracing.workspace = true
tracing-subscriber = { workspace = true, features = ["json"] }
nix.workspace = true
signal-hook.workspace = true
rand.workspace = true
jsonwebtoken.workspace = true
hex = { workspace = true, features = ["serde"] }
rustls.workspace = true
rustls-split.workspace = true
git-version.workspace = true
serde_with.workspace = true
once_cell.workspace = true
strum.workspace = true
strum_macros.workspace = true
url.workspace = true
uuid = { version = "1.2", features = ["v4", "serde"] }
metrics.workspace = true
pq_proto.workspace = true
workspace_hack.workspace = true
url.workspace = true
[dev-dependencies]
byteorder.workspace = true
bytes.workspace = true
criterion.workspace = true
hex-literal.workspace = true
tempfile.workspace = true
criterion.workspace = true
rustls-pemfile.workspace = true
[[bench]]
name = "benchmarks"

View File

@@ -9,28 +9,16 @@ use std::path::Path;
use anyhow::Result;
use jsonwebtoken::{
decode, encode, Algorithm, Algorithm::*, DecodingKey, EncodingKey, Header, TokenData,
Validation,
decode, encode, Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation,
};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, DisplayFromStr};
use crate::id::TenantId;
/// Algorithms accepted during validation.
///
/// Accept all RSA-based algorithms. We pass this list to jsonwebtoken::decode,
/// which checks that the algorithm in the token is one of these.
///
/// XXX: It also fails the validation if there are any algorithms in this list that belong
/// to different family than the token's algorithm. In other words, we can *not* list any
/// non-RSA algorithms here, or the validation always fails with InvalidAlgorithm error.
const ACCEPTED_ALGORITHMS: &[Algorithm] = &[RS256, RS384, RS512];
const JWT_ALGORITHM: Algorithm = Algorithm::RS256;
/// Algorithm to use when generating a new token in [`encode_from_key_file`]
const ENCODE_ALGORITHM: Algorithm = Algorithm::RS256;
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
pub enum Scope {
// Provides access to all data for a specific tenant (specified in `struct Claims` below)
@@ -45,9 +33,8 @@ pub enum Scope {
SafekeeperData,
}
/// JWT payload. See docs/authentication.md for the format
#[serde_as]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
#[serde(default)]
#[serde_as(as = "Option<DisplayFromStr>")]
@@ -68,8 +55,7 @@ pub struct JwtAuth {
impl JwtAuth {
pub fn new(decoding_key: DecodingKey) -> Self {
let mut validation = Validation::default();
validation.algorithms = ACCEPTED_ALGORITHMS.into();
let mut validation = Validation::new(JWT_ALGORITHM);
// The default 'required_spec_claims' is 'exp'. But we don't want to require
// expiration.
validation.required_spec_claims = [].into();
@@ -100,113 +86,5 @@ impl std::fmt::Debug for JwtAuth {
// this function is used only for testing purposes in CLI e g generate tokens during init
pub fn encode_from_key_file(claims: &Claims, key_data: &[u8]) -> Result<String> {
let key = EncodingKey::from_rsa_pem(key_data)?;
Ok(encode(&Header::new(ENCODE_ALGORITHM), claims, &key)?)
}
#[cfg(test)]
mod tests {
use super::*;
use std::str::FromStr;
// generated with:
//
// openssl genpkey -algorithm rsa -out storage-auth-priv.pem
// openssl pkey -in storage-auth-priv.pem -pubout -out storage-auth-pub.pem
const TEST_PUB_KEY_RSA: &[u8] = br#"
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAy6OZ+/kQXcueVJA/KTzO
v4ljxylc/Kcb0sXWuXg1GB8k3nDA1gK66LFYToH0aTnqrnqG32Vu6wrhwuvqsZA7
jQvP0ZePAbWhpEqho7EpNunDPcxZ/XDy5TQlB1P58F9I3lkJXDC+DsHYLuuzwhAv
vo2MtWRdYlVHblCVLyZtANHhUMp2HUhgjHnJh5UrLIKOl4doCBxkM3rK0wjKsNCt
M92PCR6S9rvYzldfeAYFNppBkEQrXt2CgUqZ4KaS4LXtjTRUJxljijA4HWffhxsr
euRu3ufq8kVqie7fum0rdZZSkONmce0V0LesQ4aE2jB+2Sn48h6jb4dLXGWdq8TV
wQIDAQAB
-----END PUBLIC KEY-----
"#;
const TEST_PRIV_KEY_RSA: &[u8] = br#"
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDLo5n7+RBdy55U
kD8pPM6/iWPHKVz8pxvSxda5eDUYHyTecMDWArrosVhOgfRpOequeobfZW7rCuHC
6+qxkDuNC8/Rl48BtaGkSqGjsSk26cM9zFn9cPLlNCUHU/nwX0jeWQlcML4Owdgu
67PCEC++jYy1ZF1iVUduUJUvJm0A0eFQynYdSGCMecmHlSssgo6Xh2gIHGQzesrT
CMqw0K0z3Y8JHpL2u9jOV194BgU2mkGQRCte3YKBSpngppLgte2NNFQnGWOKMDgd
Z9+HGyt65G7e5+ryRWqJ7t+6bSt1llKQ42Zx7RXQt6xDhoTaMH7ZKfjyHqNvh0tc
ZZ2rxNXBAgMBAAECggEAVz3u4Wlx3o02dsoZlSQs+xf0PEX3RXKeU+1YMbtTG9Nz
6yxpIQaoZrpbt76rJE2gwkFR+PEu1NmjoOuLb6j4KlQuI4AHz1auOoGSwFtM6e66
K4aZ4x95oEJ3vqz2fkmEIWYJwYpMUmwvnuJx76kZm0xvROMLsu4QHS2+zCVtO5Tr
hvS05IMVuZ2TdQBZw0+JaFdwXbgDjQnQGY5n9MoTWSx1a4s/FF4Eby65BbDutcpn
Vt3jQAOmO1X2kbPeWSGuPJRzyUs7Kg8qfeglBIR3ppGP3vPYAdWX+ho00bmsVkSp
Q8vjul6C3WiM+kjwDxotHSDgbl/xldAl7OqPh0bfAQKBgQDnycXuq14Vg8nZvyn9
rTnvucO8RBz5P6G+FZ+44cAS2x79+85onARmMnm+9MKYLSMo8fOvsK034NDI68XM
04QQ/vlfouvFklMTGJIurgEImTZbGCmlMYCvFyIxaEWixon8OpeI4rFe4Hmbiijh
PxhxWg221AwvBS2sco8J/ylEkQKBgQDg6Rh2QYb/j0Wou1rJPbuy3NhHofd5Rq35
4YV3f2lfVYcPrgRhwe3T9SVII7Dx8LfwzsX5TAlf48ESlI3Dzv40uOCDM+xdtBRI
r96SfSm+jup6gsXU3AsdNkrRK3HoOG9Z/TkrUp213QAIlVnvIx65l4ckFMlpnPJ0
lo1LDXZWMQKBgFArzjZ7N5OhfdO+9zszC3MLgdRAivT7OWqR+CjujIz5FYMr8Xzl
WfAvTUTrS9Nu6VZkObFvHrrRG+YjBsuN7YQjbQXTSFGSBwH34bgbn2fl9pMTjHQC
50uoaL9GHa/rlBaV/YvvPQJgCi/uXa1rMX0jdNLkDULGO8IF7cu7Yf7BAoGBAIUU
J29BkpmAst0GDs/ogTlyR18LTR0rXyHt+UUd1MGeH859TwZw80JpWWf4BmkB4DTS
hH3gKePdJY7S65ci0XNsuRupC4DeXuorde0DtkGU2tUmr9wlX0Ynq9lcdYfMbMa4
eK1TsxG69JwfkxlWlIWITWRiEFM3lJa7xlrUWmLhAoGAFpKWF/hn4zYg3seU9gai
EYHKSbhxA4mRb+F0/9IlCBPMCqFrL5yftUsYIh2XFKn8+QhO97Nmk8wJSK6TzQ5t
ZaSRmgySrUUhx4nZ/MgqWCFv8VUbLM5MBzwxPKhXkSTfR4z2vLYLJwVY7Tb4kZtp
8ismApXVGHpOCstzikV9W7k=
-----END PRIVATE KEY-----
"#;
#[test]
fn test_decode() -> Result<(), anyhow::Error> {
let expected_claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081")?),
scope: Scope::Tenant,
};
// Here are tokens containing the following payload, signed using TEST_PRIV_KEY_RSA
// using RS512, RS384 and RS256 algorithms:
//
// ```
// {
// "scope": "tenant",
// "tenant_id": "3d1f7595b468230304e0b73cecbcb081",
// "iss": "neon.controlplane",
// "exp": 1709200879,
// "iat": 1678442479
// }
// ```
//
// These were encoded with the online debugger at https://jwt.io
//
let encoded_rs512 = "eyJhbGciOiJSUzUxMiIsInR5cCI6IkpXVCJ9.eyJzY29wZSI6InRlbmFudCIsInRlbmFudF9pZCI6IjNkMWY3NTk1YjQ2ODIzMDMwNGUwYjczY2VjYmNiMDgxIiwiaXNzIjoibmVvbi5jb250cm9scGxhbmUiLCJleHAiOjE3MDkyMDA4NzksImlhdCI6MTY3ODQ0MjQ3OX0.QmqfteDQmDGoxQ5EFkasbt35Lx0W0Nh63muQnYZvFq93DSh4ZbOG9Mc4yaiXZoiS5HgeKtFKv3mbWkDqjz3En06aY17hWwguBtAsGASX48lYeCPADYGlGAuaWnOnVRwe3iiOC7tvPFvwX_45S84X73sNUXyUiXv6nLdcDqVXudtNrGST_DnZDnjuUJX11w7sebtKqQQ8l9-iGHiXOl5yevpMCoB1OcTWcT6DfDtffoNuMHDC3fyhmEGG5oKAt1qBybqAIiyC9-UBAowRZXhdfxrzUl-I9jzKWvk85c5ulhVRwbPeP6TTTlPKwFzBNHg1i2U-1GONew5osQ3aoptwsA";
let encoded_rs384 = "eyJhbGciOiJSUzM4NCIsInR5cCI6IkpXVCJ9.eyJzY29wZSI6InRlbmFudCIsInRlbmFudF9pZCI6IjNkMWY3NTk1YjQ2ODIzMDMwNGUwYjczY2VjYmNiMDgxIiwiaXNzIjoibmVvbi5jb250cm9scGxhbmUiLCJleHAiOjE3MDkyMDA4NzksImlhdCI6MTY3ODQ0MjQ3OX0.qqk4nkxKzOJP38c_g57_w_SfdQVmCsDT_bsLmdFj_N6LIB22gr6U6_P_5mvk3pIAsp0VCTDwPrCU908TxqjibEkwvQoJwbogHamSGHpD7eJBxGblSnA-Nr3MlEMxpFtec8QokSm6C5mH7DoBYjB2xzeOlxAmpR2GAzInKiMkU4kZ_OcqqrmVcMXY_6VnbxZWMekuw56zE1-PP_qNF1HvYOH-P08ONP8qdo5UPtBG7QBEFlCqZXJZCFihQaI4Vzil9rDuZGCm3I7xQJ8-yh1PX3BTbGo8EzqLdRyBeTpr08UTuRbp_MJDWevHpP3afvJetAItqZXIoZQrbJjcByHqKw";
let encoded_rs256 = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzY29wZSI6InRlbmFudCIsInRlbmFudF9pZCI6IjNkMWY3NTk1YjQ2ODIzMDMwNGUwYjczY2VjYmNiMDgxIiwiaXNzIjoibmVvbi5jb250cm9scGxhbmUiLCJleHAiOjE3MDkyMDA4NzksImlhdCI6MTY3ODQ0MjQ3OX0.dF2N9KXG8ftFKHYbd5jQtXMQqv0Ej8FISGp1b_dmqOCotXj5S1y2AWjwyB_EXHM77JXfbEoJPAPrFFBNfd8cWtkCSTvpxWoHaecGzegDFGv5ZSc5AECFV1Daahc3PI3jii9wEiGkFOiwiBNfZ5INomOAsV--XXxlqIwKbTcgSYI7lrOTfecXAbAHiMKQlQYiIBSGnytRCgafhRkyGzPAL8ismthFJ9RHfeejyskht-9GbVHURw02bUyijuHEulpf9eEY3ZiB28de6jnCdU7ftIYaUMaYWt0nZQGkzxKPSfSLZNy14DTOYLDS04DVstWQPqnCUW_ojg0wJETOOfo9Zw";
// Check that RS512, RS384 and RS256 tokens can all be validated
let auth = JwtAuth::new(DecodingKey::from_rsa_pem(TEST_PUB_KEY_RSA)?);
for encoded in [encoded_rs512, encoded_rs384, encoded_rs256] {
let claims_from_token = auth.decode(encoded)?.claims;
assert_eq!(claims_from_token, expected_claims);
}
Ok(())
}
#[test]
fn test_encode() -> Result<(), anyhow::Error> {
let claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081")?),
scope: Scope::Tenant,
};
let encoded = encode_from_key_file(&claims, TEST_PRIV_KEY_RSA)?;
// decode it back
let auth = JwtAuth::new(DecodingKey::from_rsa_pem(TEST_PUB_KEY_RSA)?);
let decoded = auth.decode(&encoded)?;
assert_eq!(decoded.claims, claims);
Ok(())
}
Ok(encode(&Header::new(JWT_ALGORITHM), claims, &key)?)
}

View File

@@ -3,14 +3,15 @@ use crate::http::error;
use anyhow::{anyhow, Context};
use hyper::header::{HeaderName, AUTHORIZATION};
use hyper::http::HeaderValue;
use hyper::Method;
use hyper::{header::CONTENT_TYPE, Body, Request, Response, Server};
use hyper::{Method, StatusCode};
use metrics::{register_int_counter, Encoder, IntCounter, TextEncoder};
use once_cell::sync::Lazy;
use routerify::ext::RequestExt;
use routerify::{Middleware, RequestInfo, Router, RouterBuilder, RouterService};
use routerify::RequestInfo;
use routerify::{Middleware, Router, RouterBuilder, RouterService};
use tokio::task::JoinError;
use tracing::{self, debug, info, info_span, warn, Instrument};
use tracing;
use std::future::Future;
use std::net::TcpListener;
@@ -26,83 +27,16 @@ static SERVE_METRICS_COUNT: Lazy<IntCounter> = Lazy::new(|| {
.expect("failed to define a metric")
});
static X_REQUEST_ID_HEADER_STR: &str = "x-request-id";
static X_REQUEST_ID_HEADER: HeaderName = HeaderName::from_static(X_REQUEST_ID_HEADER_STR);
#[derive(Debug, Default, Clone)]
struct RequestId(String);
/// Adds a tracing info_span! instrumentation around the handler events,
/// logs the request start and end events for non-GET requests and non-200 responses.
///
/// Use this to distinguish between logs of different HTTP requests: every request handler wrapped
/// in this type will get request info logged in the wrapping span, including the unique request ID.
///
/// There could be other ways to implement similar functionality:
///
/// * procmacros placed on top of all handler methods
/// With all the drawbacks of procmacros, brings no difference implementation-wise,
/// and little code reduction compared to the existing approach.
///
/// * Another `TraitExt` with e.g. the `get_with_span`, `post_with_span` methods to do similar logic,
/// implemented for [`RouterBuilder`].
/// Could be simpler, but we don't want to depend on [`routerify`] more, targeting to use other library later.
///
/// * In theory, a span guard could've been created in a pre-request middleware and placed into a global collection, to be dropped
/// later, in a post-response middleware.
/// Due to suspendable nature of the futures, would give contradictive results which is exactly the opposite of what `tracing-futures`
/// tries to achive with its `.instrument` used in the current approach.
///
/// If needed, a declarative macro to substitute the |r| ... closure boilerplate could be introduced.
pub struct RequestSpan<E, R, H>(pub H)
where
E: Into<Box<dyn std::error::Error + Send + Sync>> + 'static,
R: Future<Output = Result<Response<Body>, E>> + Send + 'static,
H: Fn(Request<Body>) -> R + Send + Sync + 'static;
impl<E, R, H> RequestSpan<E, R, H>
where
E: Into<Box<dyn std::error::Error + Send + Sync>> + 'static,
R: Future<Output = Result<Response<Body>, E>> + Send + 'static,
H: Fn(Request<Body>) -> R + Send + Sync + 'static,
{
/// Creates a tracing span around inner request handler and executes the request handler in the contex of that span.
/// Use as `|r| RequestSpan(my_handler).handle(r)` instead of `my_handler` as the request handler to get the span enabled.
pub async fn handle(self, request: Request<Body>) -> Result<Response<Body>, E> {
let request_id = request.context::<RequestId>().unwrap_or_default().0;
let method = request.method();
let path = request.uri().path();
let request_span = info_span!("request", %method, %path, %request_id);
let log_quietly = method == Method::GET;
async move {
if log_quietly {
debug!("Handling request");
} else {
info!("Handling request");
}
// Note that we reuse `error::handler` here and not returning and error at all,
// yet cannot use `!` directly in the method signature due to `routerify::RouterBuilder` limitation.
// Usage of the error handler also means that we expect only the `ApiError` errors to be raised in this call.
//
// Panics are not handled separately, there's a `tracing_panic_hook` from another module to do that globally.
match (self.0)(request).await {
Ok(response) => {
let response_status = response.status();
if log_quietly && response_status.is_success() {
debug!("Request handled, status: {response_status}");
} else {
info!("Request handled, status: {response_status}");
}
Ok(response)
}
Err(e) => Ok(error::handler(e.into()).await),
}
}
.instrument(request_span)
.await
async fn logger(res: Response<Body>, info: RequestInfo) -> Result<Response<Body>, ApiError> {
// cannot factor out the Level to avoid the repetition
// because tracing can only work with const Level
// which is not the case here
if info.method() == Method::GET && res.status() == StatusCode::OK {
tracing::debug!("{} {} {}", info.method(), info.uri().path(), res.status());
} else {
tracing::info!("{} {} {}", info.method(), info.uri().path(), res.status());
}
Ok(res)
}
async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
@@ -129,48 +63,10 @@ async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body
Ok(response)
}
pub fn add_request_id_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
) -> Middleware<B, ApiError> {
Middleware::pre(move |req| async move {
let request_id = match req.headers().get(&X_REQUEST_ID_HEADER) {
Some(request_id) => request_id
.to_str()
.expect("extract request id value")
.to_owned(),
None => {
let request_id = uuid::Uuid::new_v4();
request_id.to_string()
}
};
req.set_context(RequestId(request_id));
Ok(req)
})
}
async fn add_request_id_header_to_response(
mut res: Response<Body>,
req_info: RequestInfo,
) -> Result<Response<Body>, ApiError> {
if let Some(request_id) = req_info.context::<RequestId>() {
if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) {
res.headers_mut()
.insert(&X_REQUEST_ID_HEADER, request_header_value);
};
};
Ok(res)
}
pub fn make_router() -> RouterBuilder<hyper::Body, ApiError> {
Router::builder()
.middleware(add_request_id_middleware())
.middleware(Middleware::post_with_info(
add_request_id_header_to_response,
))
.get("/metrics", |r| {
RequestSpan(prometheus_metrics_handler).handle(r)
})
.middleware(Middleware::post_with_info(logger))
.get("/metrics", prometheus_metrics_handler)
.err_handler(error::handler)
}
@@ -180,43 +76,40 @@ pub fn attach_openapi_ui(
spec_mount_path: &'static str,
ui_mount_path: &'static str,
) -> RouterBuilder<hyper::Body, ApiError> {
router_builder
.get(spec_mount_path, move |r| {
RequestSpan(move |_| async move { Ok(Response::builder().body(Body::from(spec)).unwrap()) })
.handle(r)
})
.get(ui_mount_path, move |r| RequestSpan( move |_| async move {
Ok(Response::builder().body(Body::from(format!(r#"
<!DOCTYPE html>
<html lang="en">
<head>
<title>rweb</title>
<link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
<script>
window.onload = function() {{
const ui = SwaggerUIBundle({{
"dom_id": "\#swagger-ui",
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout",
deepLinking: true,
showExtensions: true,
showCommonExtensions: true,
url: "{}",
}})
window.ui = ui;
}};
</script>
</body>
</html>
"#, spec_mount_path))).unwrap())
}).handle(r))
router_builder.get(spec_mount_path, move |_| async move {
Ok(Response::builder().body(Body::from(spec)).unwrap())
}).get(ui_mount_path, move |_| async move {
Ok(Response::builder().body(Body::from(format!(r#"
<!DOCTYPE html>
<html lang="en">
<head>
<title>rweb</title>
<link href="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui.css" rel="stylesheet">
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://cdn.jsdelivr.net/npm/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"> </script>
<script>
window.onload = function() {{
const ui = SwaggerUIBundle({{
"dom_id": "\#swagger-ui",
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout",
deepLinking: true,
showExtensions: true,
showCommonExtensions: true,
url: "{}",
}})
window.ui = ui;
}};
</script>
</body>
</html>
"#, spec_mount_path))).unwrap())
})
}
fn parse_token(header_value: &str) -> Result<&str, ApiError> {
@@ -278,7 +171,7 @@ where
async move {
let headers = response.headers_mut();
if headers.contains_key(&name) {
warn!(
tracing::warn!(
"{} response already contains header {:?}",
request_info.uri(),
&name,
@@ -318,7 +211,7 @@ pub fn serve_thread_main<S>(
where
S: Future<Output = ()> + Send + Sync,
{
info!("Starting an HTTP endpoint at {}", listener.local_addr()?);
tracing::info!("Starting an HTTP endpoint at {}", listener.local_addr()?);
// Create a Service from the router above to handle incoming requests.
let service = RouterService::new(router_builder.build().map_err(|err| anyhow!(err))?).unwrap();
@@ -338,48 +231,3 @@ where
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use futures::future::poll_fn;
use hyper::service::Service;
use routerify::RequestServiceBuilder;
use std::net::{IpAddr, SocketAddr};
#[tokio::test]
async fn test_request_id_returned() {
let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
let mut service = builder.build(remote_addr);
if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
panic!("request service is not ready: {:?}", e);
}
let mut req: Request<Body> = Request::default();
req.headers_mut()
.append(&X_REQUEST_ID_HEADER, HeaderValue::from_str("42").unwrap());
let resp: Response<hyper::body::Body> = service.call(req).await.unwrap();
let header_val = resp.headers().get(&X_REQUEST_ID_HEADER).unwrap();
assert!(header_val == "42", "response header mismatch");
}
#[tokio::test]
async fn test_request_id_empty() {
let builder = RequestServiceBuilder::new(make_router().build().unwrap()).unwrap();
let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
let mut service = builder.build(remote_addr);
if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await {
panic!("request service is not ready: {:?}", e);
}
let req: Request<Body> = Request::default();
let resp: Response<hyper::body::Body> = service.call(req).await.unwrap();
let header_val = resp.headers().get(&X_REQUEST_ID_HEADER);
assert_ne!(header_val, None, "response header should NOT be empty");
}
}

View File

@@ -1,9 +1,7 @@
use std::fmt::Display;
use anyhow::Context;
use bytes::Buf;
use hyper::{header, Body, Request, Response, StatusCode};
use serde::{Deserialize, Serialize, Serializer};
use serde::{Deserialize, Serialize};
use super::error::ApiError;
@@ -33,12 +31,3 @@ pub fn json_response<T: Serialize>(
.map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response)
}
/// Serialize through Display trait.
pub fn display_serialize<S, F>(z: &F, s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
F: Display,
{
s.serialize_str(&format!("{}", z))
}

View File

@@ -13,6 +13,8 @@ pub mod simple_rcu;
pub mod vec_map;
pub mod bin_ser;
pub mod postgres_backend;
pub mod postgres_backend_async;
// helper functions for creating and fsyncing
pub mod crashsafe;
@@ -25,6 +27,9 @@ pub mod id;
// http endpoint utils
pub mod http;
// socket splitting utils
pub mod sock_split;
// common log initialisation routine
pub mod logging;

View File

@@ -0,0 +1,485 @@
//! Server-side synchronous Postgres connection, as limited as we need.
//! To use, create PostgresBackend and run() it, passing the Handler
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use crate::postgres_backend_async::{log_query_error, short_error, QueryError};
use crate::sock_split::{BidiStream, ReadStream, WriteStream};
use anyhow::Context;
use bytes::{Bytes, BytesMut};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket};
use serde::{Deserialize, Serialize};
use std::fmt;
use std::io::{self, Write};
use std::net::{Shutdown, SocketAddr, TcpStream};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tracing::*;
pub trait Handler {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might be not what we want after CopyData streaming, but currently we don't
/// care).
fn process_query(
&mut self,
pgb: &mut PostgresBackend,
query_string: &str,
) -> Result<(), QueryError>;
/// Called on startup packet receival, allows to process params.
///
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
/// to override whole init logic in implementations.
fn startup(
&mut self,
_pgb: &mut PostgresBackend,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
}
/// Check auth jwt
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend,
_jwt_response: &[u8],
) -> Result<(), QueryError> {
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
}
fn is_shutdown_requested(&self) -> bool {
false
}
}
/// PostgresBackend protocol state.
/// XXX: The order of the constructors matters.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
Initialization,
Encrypted,
Authentication,
Established,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum AuthType {
Trust,
// This mimics postgres's AuthenticationCleartextPassword but instead of password expects JWT
NeonJWT,
}
impl FromStr for AuthType {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"Trust" => Ok(Self::Trust),
"NeonJWT" => Ok(Self::NeonJWT),
_ => anyhow::bail!("invalid value \"{s}\" for auth type"),
}
}
}
impl fmt::Display for AuthType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
AuthType::Trust => "Trust",
AuthType::NeonJWT => "NeonJWT",
})
}
}
#[derive(Clone, Copy)]
pub enum ProcessMsgResult {
Continue,
Break,
}
/// Always-writeable sock_split stream.
/// May not be readable. See [`PostgresBackend::take_stream_in`]
pub enum Stream {
Bidirectional(BidiStream),
WriteOnly(WriteStream),
}
impl Stream {
fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Bidirectional(bidi_stream) => bidi_stream.shutdown(how),
Self::WriteOnly(write_stream) => write_stream.shutdown(how),
}
}
}
impl io::Write for Stream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Self::Bidirectional(bidi_stream) => bidi_stream.write(buf),
Self::WriteOnly(write_stream) => write_stream.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Self::Bidirectional(bidi_stream) => bidi_stream.flush(),
Self::WriteOnly(write_stream) => write_stream.flush(),
}
}
}
pub struct PostgresBackend {
stream: Option<Stream>,
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
buf_out: BytesMut,
pub state: ProtoState,
auth_type: AuthType,
peer_addr: SocketAddr,
pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
let mut query_string = query_string.to_vec();
if let Some(ch) = query_string.last() {
if *ch == 0 {
query_string.pop();
}
}
query_string
}
// Helper function for socket read loops
pub fn is_socket_read_timed_out(error: &anyhow::Error) -> bool {
for cause in error.chain() {
if let Some(io_error) = cause.downcast_ref::<io::Error>() {
if io_error.kind() == std::io::ErrorKind::WouldBlock {
return true;
}
}
}
false
}
// Cast a byte slice to a string slice, dropping null terminator if there's one.
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
std::str::from_utf8(without_null).map_err(|e| e.into())
}
impl PostgresBackend {
pub fn new(
socket: TcpStream,
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
set_read_timeout: bool,
) -> io::Result<Self> {
let peer_addr = socket.peer_addr()?;
if set_read_timeout {
socket
.set_read_timeout(Some(Duration::from_secs(5)))
.unwrap();
}
Ok(Self {
stream: Some(Stream::Bidirectional(BidiStream::from_tcp(socket))),
buf_out: BytesMut::with_capacity(10 * 1024),
state: ProtoState::Initialization,
auth_type,
tls_config,
peer_addr,
})
}
pub fn into_stream(self) -> Stream {
self.stream.unwrap()
}
/// Get direct reference (into the Option) to the read stream.
fn get_stream_in(&mut self) -> anyhow::Result<&mut BidiStream> {
match &mut self.stream {
Some(Stream::Bidirectional(stream)) => Ok(stream),
_ => anyhow::bail!("reader taken"),
}
}
pub fn get_peer_addr(&self) -> &SocketAddr {
&self.peer_addr
}
pub fn take_stream_in(&mut self) -> Option<ReadStream> {
let stream = self.stream.take();
match stream {
Some(Stream::Bidirectional(bidi_stream)) => {
let (read, write) = bidi_stream.split();
self.stream = Some(Stream::WriteOnly(write));
Some(read)
}
stream => {
self.stream = stream;
None
}
}
}
/// Read full message or return None if connection is closed.
pub fn read_message(&mut self) -> Result<Option<FeMessage>, QueryError> {
let (state, stream) = (self.state, self.get_stream_in()?);
use ProtoState::*;
match state {
Initialization | Encrypted => FeStartupPacket::read(stream),
Authentication | Established => FeMessage::read(stream),
}
.map_err(QueryError::from)
}
/// Write message into internal output buffer.
pub fn write_message_noflush(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
BeMessage::write(&mut self.buf_out, message)?;
Ok(self)
}
/// Flush output buffer into the socket.
pub fn flush(&mut self) -> io::Result<&mut Self> {
let stream = self.stream.as_mut().unwrap();
stream.write_all(&self.buf_out)?;
self.buf_out.clear();
Ok(self)
}
/// Write message into internal buffer and flush it.
pub fn write_message(&mut self, message: &BeMessage) -> io::Result<&mut Self> {
self.write_message_noflush(message)?;
self.flush()
}
// Wrapper for run_message_loop() that shuts down socket when we are done
pub fn run(mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
let ret = self.run_message_loop(handler);
if let Some(stream) = self.stream.as_mut() {
let _ = stream.shutdown(Shutdown::Both);
}
ret
}
fn run_message_loop(&mut self, handler: &mut impl Handler) -> Result<(), QueryError> {
trace!("postgres backend to {:?} started", self.peer_addr);
let mut unnamed_query_string = Bytes::new();
while !handler.is_shutdown_requested() {
match self.read_message() {
Ok(message) => {
if let Some(msg) = message {
trace!("got message {msg:?}");
match self.process_message(handler, msg, &mut unnamed_query_string)? {
ProcessMsgResult::Continue => continue,
ProcessMsgResult::Break => break,
}
} else {
break;
}
}
Err(e) => {
if let QueryError::Other(e) = &e {
if is_socket_read_timed_out(e) {
continue;
}
}
return Err(e);
}
}
}
trace!("postgres backend to {:?} exited", self.peer_addr);
Ok(())
}
pub fn start_tls(&mut self) -> anyhow::Result<()> {
match self.stream.take() {
Some(Stream::Bidirectional(bidi_stream)) => {
let conn = rustls::ServerConnection::new(self.tls_config.clone().unwrap())?;
self.stream = Some(Stream::Bidirectional(bidi_stream.start_tls(conn)?));
Ok(())
}
stream => {
self.stream = stream;
anyhow::bail!("can't start TLs without bidi stream");
}
}
}
fn process_message(
&mut self,
handler: &mut impl Handler,
msg: FeMessage,
unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult, QueryError> {
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
if self.state < ProtoState::Established
&& !matches!(
msg,
FeMessage::PasswordMessage(_) | FeMessage::StartupPacket(_)
)
{
return Err(QueryError::Other(anyhow::anyhow!("protocol violation")));
}
let have_tls = self.tls_config.is_some();
match msg {
FeMessage::StartupPacket(m) => {
trace!("got startup message {m:?}");
match m {
FeStartupPacket::SslRequest => {
debug!("SSL requested");
self.write_message(&BeMessage::EncryptionResponse(have_tls))?;
if have_tls {
self.start_tls()?;
self.state = ProtoState::Encrypted;
}
}
FeStartupPacket::GssEncRequest => {
debug!("GSS requested");
self.write_message(&BeMessage::EncryptionResponse(false))?;
}
FeStartupPacket::StartupMessage { .. } => {
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
self.write_message(&BeMessage::ErrorResponse(
"must connect with TLS",
None,
))?;
return Err(QueryError::Other(anyhow::anyhow!(
"client did not connect with TLS"
)));
}
// NB: startup() may change self.auth_type -- we are using that in proxy code
// to bypass auth for new users.
handler.startup(self, &m)?;
match self.auth_type {
AuthType::Trust => {
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
// The async python driver requires a valid server_version
.write_message_noflush(&BeMessage::server_version("14.1"))?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
AuthType::NeonJWT => {
self.write_message(&BeMessage::AuthenticationCleartextPassword)?;
self.state = ProtoState::Authentication;
}
}
}
FeStartupPacket::CancelRequest { .. } => {
return Ok(ProcessMsgResult::Break);
}
}
}
FeMessage::PasswordMessage(m) => {
trace!("got password message '{:?}'", m);
assert!(self.state == ProtoState::Authentication);
match self.auth_type {
AuthType::Trust => unreachable!(),
AuthType::NeonJWT => {
let (_, jwt_response) = m.split_last().context("protocol violation")?;
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
self.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
return Err(e);
}
}
}
self.write_message_noflush(&BeMessage::AuthenticationOk)?
.write_message_noflush(&BeMessage::CLIENT_ENCODING)?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
FeMessage::Query(body) => {
// remove null terminator
let query_string = cstr_to_str(&body)?;
trace!("got query {query_string:?}");
if let Err(e) = handler.process_query(self, query_string) {
log_query_error(query_string, &e);
let short_error = short_error(&e);
self.write_message_noflush(&BeMessage::ErrorResponse(
&short_error,
Some(e.pg_error_code()),
))?;
}
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Parse(m) => {
*unnamed_query_string = m.query_string;
self.write_message(&BeMessage::ParseComplete)?;
}
FeMessage::Describe(_) => {
self.write_message_noflush(&BeMessage::ParameterDescription)?
.write_message(&BeMessage::NoData)?;
}
FeMessage::Bind(_) => {
self.write_message(&BeMessage::BindComplete)?;
}
FeMessage::Close(_) => {
self.write_message(&BeMessage::CloseComplete)?;
}
FeMessage::Execute(_) => {
let query_string = cstr_to_str(unnamed_query_string)?;
trace!("got execute {query_string:?}");
if let Err(e) = handler.process_query(self, query_string) {
log_query_error(query_string, &e);
self.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
}
// NOTE there is no ReadyForQuery message. This handler is used
// for basebackup and it uses CopyOut which doesn't require
// ReadyForQuery message and backend just switches back to
// processing mode after sending CopyDone or ErrorResponse.
}
FeMessage::Sync => {
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Terminate => {
return Ok(ProcessMsgResult::Break);
}
// We prefer explicit pattern matching to wildcards, because
// this helps us spot the places where new variants are missing
FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => {
return Err(QueryError::Other(anyhow::anyhow!(
"unexpected message type: {msg:?}"
)));
}
}
Ok(ProcessMsgResult::Continue)
}
}

View File

@@ -0,0 +1,634 @@
//! Server-side asynchronous Postgres connection, as limited as we need.
//! To use, create PostgresBackend and run() it, passing the Handler
//! implementation determining how to process the queries. Currently its API
//! is rather narrow, but we can extend it once required.
use crate::postgres_backend::AuthType;
use anyhow::Context;
use bytes::{Buf, Bytes, BytesMut};
use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket, SQLSTATE_INTERNAL_ERROR};
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;
use std::{future::Future, task::ready};
use tracing::{debug, error, info, trace};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
use tokio_rustls::TlsAcceptor;
pub fn is_expected_io_error(e: &io::Error) -> bool {
use io::ErrorKind::*;
matches!(
e.kind(),
ConnectionRefused | ConnectionAborted | ConnectionReset
)
}
/// An error, occurred during query processing:
/// either during the connection ([`ConnectionError`]) or before/after it.
#[derive(thiserror::Error, Debug)]
pub enum QueryError {
/// The connection was lost while processing the query.
#[error(transparent)]
Disconnected(#[from] ConnectionError),
/// Some other error
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl From<io::Error> for QueryError {
fn from(e: io::Error) -> Self {
Self::Disconnected(ConnectionError::Socket(e))
}
}
impl QueryError {
pub fn pg_error_code(&self) -> &'static [u8; 5] {
match self {
Self::Disconnected(_) => b"08006", // connection failure
Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
}
}
}
#[async_trait::async_trait]
pub trait Handler {
/// Handle single query.
/// postgres_backend will issue ReadyForQuery after calling this (this
/// might be not what we want after CopyData streaming, but currently we don't
/// care).
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
query_string: &str,
) -> Result<(), QueryError>;
/// Called on startup packet receival, allows to process params.
///
/// If Ok(false) is returned postgres_backend will skip auth -- that is needed for new users
/// creation is the proxy code. That is quite hacky and ad-hoc solution, may be we could allow
/// to override whole init logic in implementations.
fn startup(
&mut self,
_pgb: &mut PostgresBackend,
_sm: &FeStartupPacket,
) -> Result<(), QueryError> {
Ok(())
}
/// Check auth jwt
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend,
_jwt_response: &[u8],
) -> Result<(), QueryError> {
Err(QueryError::Other(anyhow::anyhow!("JWT auth failed")))
}
}
/// PostgresBackend protocol state.
/// XXX: The order of the constructors matters.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd)]
pub enum ProtoState {
Initialization,
Encrypted,
Authentication,
Established,
Closed,
}
#[derive(Clone, Copy)]
pub enum ProcessMsgResult {
Continue,
Break,
}
/// Always-writeable sock_split stream.
/// May not be readable. See [`PostgresBackend::take_stream_in`]
pub enum Stream {
Unencrypted(BufReader<tokio::net::TcpStream>),
Tls(Box<tokio_rustls::server::TlsStream<BufReader<tokio::net::TcpStream>>>),
Broken,
}
impl AsyncWrite for Stream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_write(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
Self::Broken => unreachable!(),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_flush(cx),
Self::Tls(stream) => Pin::new(stream).poll_flush(cx),
Self::Broken => unreachable!(),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
Self::Broken => unreachable!(),
}
}
}
impl AsyncRead for Stream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Self::Unencrypted(stream) => Pin::new(stream).poll_read(cx, buf),
Self::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
Self::Broken => unreachable!(),
}
}
}
pub struct PostgresBackend {
stream: Stream,
// Output buffer. c.f. BeMessage::write why we are using BytesMut here.
// The data between 0 and "current position" as tracked by the bytes::Buf
// implementation of BytesMut, have already been written.
buf_out: BytesMut,
pub state: ProtoState,
auth_type: AuthType,
peer_addr: SocketAddr,
pub tls_config: Option<Arc<rustls::ServerConfig>>,
}
pub fn query_from_cstring(query_string: Bytes) -> Vec<u8> {
let mut query_string = query_string.to_vec();
if let Some(ch) = query_string.last() {
if *ch == 0 {
query_string.pop();
}
}
query_string
}
// Cast a byte slice to a string slice, dropping null terminator if there's one.
fn cstr_to_str(bytes: &[u8]) -> anyhow::Result<&str> {
let without_null = bytes.strip_suffix(&[0]).unwrap_or(bytes);
std::str::from_utf8(without_null).map_err(|e| e.into())
}
impl PostgresBackend {
pub fn new(
socket: tokio::net::TcpStream,
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
) -> io::Result<Self> {
let peer_addr = socket.peer_addr()?;
Ok(Self {
stream: Stream::Unencrypted(BufReader::new(socket)),
buf_out: BytesMut::with_capacity(10 * 1024),
state: ProtoState::Initialization,
auth_type,
tls_config,
peer_addr,
})
}
pub fn get_peer_addr(&self) -> &SocketAddr {
&self.peer_addr
}
/// Read full message or return None if connection is closed.
pub async fn read_message(&mut self) -> Result<Option<FeMessage>, QueryError> {
use ProtoState::*;
match self.state {
Initialization | Encrypted => FeStartupPacket::read_fut(&mut self.stream).await,
Authentication | Established => FeMessage::read_fut(&mut self.stream).await,
Closed => Ok(None),
}
.map_err(QueryError::from)
}
/// Flush output buffer into the socket.
pub async fn flush(&mut self) -> io::Result<()> {
while self.buf_out.has_remaining() {
let bytes_written = self.stream.write(self.buf_out.chunk()).await?;
self.buf_out.advance(bytes_written);
}
self.buf_out.clear();
Ok(())
}
/// Write message into internal output buffer.
pub fn write_message(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
BeMessage::write(&mut self.buf_out, message)?;
Ok(self)
}
/// Returns an AsyncWrite implementation that wraps all the data written
/// to it in CopyData messages, and writes them to the connection
///
/// The caller is responsible for sending CopyOutResponse and CopyDone messages.
pub fn copyout_writer(&mut self) -> CopyDataWriter {
CopyDataWriter { pgb: self }
}
/// A polling function that tries to write all the data from 'buf_out' to the
/// underlying stream.
fn poll_write_buf(
&mut self,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
while self.buf_out.has_remaining() {
match ready!(Pin::new(&mut self.stream).poll_write(cx, self.buf_out.chunk())) {
Ok(bytes_written) => self.buf_out.advance(bytes_written),
Err(err) => return Poll::Ready(Err(err)),
}
}
Poll::Ready(Ok(()))
}
fn poll_flush(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.stream).poll_flush(cx)
}
// Wrapper for run_message_loop() that shuts down socket when we are done
pub async fn run<F, S>(
mut self,
handler: &mut impl Handler,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S,
S: Future,
{
let ret = self.run_message_loop(handler, shutdown_watcher).await;
let _ = self.stream.shutdown();
ret
}
async fn run_message_loop<F, S>(
&mut self,
handler: &mut impl Handler,
shutdown_watcher: F,
) -> Result<(), QueryError>
where
F: Fn() -> S,
S: Future,
{
trace!("postgres backend to {:?} started", self.peer_addr);
tokio::select!(
biased;
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received during handshake");
return Ok(())
},
result = async {
while self.state < ProtoState::Established {
if let Some(msg) = self.read_message().await? {
trace!("got message {msg:?} during handshake");
match self.process_handshake_message(handler, msg).await? {
ProcessMsgResult::Continue => {
self.flush().await?;
continue;
}
ProcessMsgResult::Break => {
trace!("postgres backend to {:?} exited during handshake", self.peer_addr);
return Ok(());
}
}
} else {
trace!("postgres backend to {:?} exited during handshake", self.peer_addr);
return Ok(());
}
}
Ok::<(), QueryError>(())
} => {
// Handshake complete.
result?;
}
);
// Authentication completed
let mut query_string = Bytes::new();
while let Some(msg) = tokio::select!(
biased;
_ = shutdown_watcher() => {
// We were requested to shut down.
tracing::info!("shutdown request received in run_message_loop");
Ok(None)
},
msg = self.read_message() => { msg },
)? {
trace!("got message {:?}", msg);
let result = self.process_message(handler, msg, &mut query_string).await;
self.flush().await?;
match result? {
ProcessMsgResult::Continue => {
self.flush().await?;
continue;
}
ProcessMsgResult::Break => break,
}
}
trace!("postgres backend to {:?} exited", self.peer_addr);
Ok(())
}
async fn start_tls(&mut self) -> anyhow::Result<()> {
if let Stream::Unencrypted(plain_stream) =
std::mem::replace(&mut self.stream, Stream::Broken)
{
let acceptor = TlsAcceptor::from(self.tls_config.clone().unwrap());
let tls_stream = acceptor.accept(plain_stream).await?;
self.stream = Stream::Tls(Box::new(tls_stream));
return Ok(());
};
anyhow::bail!("TLS already started");
}
async fn process_handshake_message(
&mut self,
handler: &mut impl Handler,
msg: FeMessage,
) -> Result<ProcessMsgResult, QueryError> {
assert!(self.state < ProtoState::Established);
let have_tls = self.tls_config.is_some();
match msg {
FeMessage::StartupPacket(m) => {
trace!("got startup message {m:?}");
match m {
FeStartupPacket::SslRequest => {
debug!("SSL requested");
self.write_message(&BeMessage::EncryptionResponse(have_tls))?;
if have_tls {
self.start_tls().await?;
self.state = ProtoState::Encrypted;
}
}
FeStartupPacket::GssEncRequest => {
debug!("GSS requested");
self.write_message(&BeMessage::EncryptionResponse(false))?;
}
FeStartupPacket::StartupMessage { .. } => {
if have_tls && !matches!(self.state, ProtoState::Encrypted) {
self.write_message(&BeMessage::ErrorResponse(
"must connect with TLS",
None,
))?;
return Err(QueryError::Other(anyhow::anyhow!(
"client did not connect with TLS"
)));
}
// NB: startup() may change self.auth_type -- we are using that in proxy code
// to bypass auth for new users.
handler.startup(self, &m)?;
match self.auth_type {
AuthType::Trust => {
self.write_message(&BeMessage::AuthenticationOk)?
.write_message(&BeMessage::CLIENT_ENCODING)?
// The async python driver requires a valid server_version
.write_message(&BeMessage::server_version("14.1"))?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
AuthType::NeonJWT => {
self.write_message(&BeMessage::AuthenticationCleartextPassword)?;
self.state = ProtoState::Authentication;
}
}
}
FeStartupPacket::CancelRequest { .. } => {
self.state = ProtoState::Closed;
return Ok(ProcessMsgResult::Break);
}
}
}
FeMessage::PasswordMessage(m) => {
trace!("got password message '{:?}'", m);
assert!(self.state == ProtoState::Authentication);
match self.auth_type {
AuthType::Trust => unreachable!(),
AuthType::NeonJWT => {
let (_, jwt_response) = m.split_last().context("protocol violation")?;
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
self.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
return Err(e);
}
}
}
self.write_message(&BeMessage::AuthenticationOk)?
.write_message(&BeMessage::CLIENT_ENCODING)?
.write_message(&BeMessage::ReadyForQuery)?;
self.state = ProtoState::Established;
}
_ => {
self.state = ProtoState::Closed;
return Ok(ProcessMsgResult::Break);
}
}
Ok(ProcessMsgResult::Continue)
}
async fn process_message(
&mut self,
handler: &mut impl Handler,
msg: FeMessage,
unnamed_query_string: &mut Bytes,
) -> Result<ProcessMsgResult, QueryError> {
// Allow only startup and password messages during auth. Otherwise client would be able to bypass auth
// TODO: change that to proper top-level match of protocol state with separate message handling for each state
assert!(self.state == ProtoState::Established);
match msg {
FeMessage::StartupPacket(_) | FeMessage::PasswordMessage(_) => {
return Err(QueryError::Other(anyhow::anyhow!("protocol violation")));
}
FeMessage::Query(body) => {
// remove null terminator
let query_string = cstr_to_str(&body)?;
trace!("got query {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e);
let short_error = short_error(&e);
self.write_message(&BeMessage::ErrorResponse(
&short_error,
Some(e.pg_error_code()),
))?;
}
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Parse(m) => {
*unnamed_query_string = m.query_string;
self.write_message(&BeMessage::ParseComplete)?;
}
FeMessage::Describe(_) => {
self.write_message(&BeMessage::ParameterDescription)?
.write_message(&BeMessage::NoData)?;
}
FeMessage::Bind(_) => {
self.write_message(&BeMessage::BindComplete)?;
}
FeMessage::Close(_) => {
self.write_message(&BeMessage::CloseComplete)?;
}
FeMessage::Execute(_) => {
let query_string = cstr_to_str(unnamed_query_string)?;
trace!("got execute {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
log_query_error(query_string, &e);
self.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?;
}
// NOTE there is no ReadyForQuery message. This handler is used
// for basebackup and it uses CopyOut which doesn't require
// ReadyForQuery message and backend just switches back to
// processing mode after sending CopyDone or ErrorResponse.
}
FeMessage::Sync => {
self.write_message(&BeMessage::ReadyForQuery)?;
}
FeMessage::Terminate => {
return Ok(ProcessMsgResult::Break);
}
// We prefer explicit pattern matching to wildcards, because
// this helps us spot the places where new variants are missing
FeMessage::CopyData(_) | FeMessage::CopyDone | FeMessage::CopyFail => {
return Err(QueryError::Other(anyhow::anyhow!(
"unexpected message type: {:?}",
msg
)));
}
}
Ok(ProcessMsgResult::Continue)
}
}
///
/// A futures::AsyncWrite implementation that wraps all data written to it in CopyData
/// messages.
///
pub struct CopyDataWriter<'a> {
pgb: &'a mut PostgresBackend,
}
impl<'a> AsyncWrite for CopyDataWriter<'a> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
let this = self.get_mut();
// It's not strictly required to flush between each message, but makes it easier
// to view in wireshark, and usually the messages that the callers write are
// decently-sized anyway.
match ready!(this.pgb.poll_write_buf(cx)) {
Ok(()) => {}
Err(err) => return Poll::Ready(Err(err)),
}
// CopyData
// XXX: if the input is large, we should split it into multiple messages.
// Not sure what the threshold should be, but the ultimate hard limit is that
// the length cannot exceed u32.
this.pgb.write_message(&BeMessage::CopyData(buf))?;
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
match ready!(this.pgb.poll_write_buf(cx)) {
Ok(()) => {}
Err(err) => return Poll::Ready(Err(err)),
}
this.pgb.poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let this = self.get_mut();
match ready!(this.pgb.poll_write_buf(cx)) {
Ok(()) => {}
Err(err) => return Poll::Ready(Err(err)),
}
this.pgb.poll_flush(cx)
}
}
pub fn short_error(e: &QueryError) -> String {
match e {
QueryError::Disconnected(connection_error) => connection_error.to_string(),
QueryError::Other(e) => format!("{e:#}"),
}
}
pub(super) fn log_query_error(query: &str, e: &QueryError) {
match e {
QueryError::Disconnected(ConnectionError::Socket(io_error)) => {
if is_expected_io_error(io_error) {
info!("query handler for '{query}' failed with expected io error: {io_error}");
} else {
error!("query handler for '{query}' failed with io error: {io_error}");
}
}
QueryError::Disconnected(other_connection_error) => {
error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
}
QueryError::Other(e) => {
error!("query handler for '{query}' failed: {e:?}");
}
}
}

View File

@@ -0,0 +1,206 @@
use std::{
io::{self, BufReader, Write},
net::{Shutdown, TcpStream},
sync::Arc,
};
use rustls::Connection;
/// Wrapper supporting reads of a shared TcpStream.
pub struct ArcTcpRead(Arc<TcpStream>);
impl io::Read for ArcTcpRead {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
(&*self.0).read(buf)
}
}
impl std::ops::Deref for ArcTcpRead {
type Target = TcpStream;
fn deref(&self) -> &Self::Target {
self.0.deref()
}
}
/// Wrapper around a TCP Stream supporting buffered reads.
pub struct BufStream(BufReader<ArcTcpRead>);
impl io::Read for BufStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl io::Write for BufStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.get_ref().write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.get_ref().flush()
}
}
impl BufStream {
/// Unwrap into the internal BufReader.
fn into_reader(self) -> BufReader<ArcTcpRead> {
self.0
}
/// Returns a reference to the underlying TcpStream.
fn get_ref(&self) -> &TcpStream {
&self.0.get_ref().0
}
}
pub enum ReadStream {
Tcp(BufReader<ArcTcpRead>),
Tls(rustls_split::ReadHalf),
}
impl io::Read for ReadStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Self::Tcp(reader) => reader.read(buf),
Self::Tls(read_half) => read_half.read(buf),
}
}
}
impl ReadStream {
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.get_ref().shutdown(how),
Self::Tls(write_half) => write_half.shutdown(how),
}
}
}
pub enum WriteStream {
Tcp(Arc<TcpStream>),
Tls(rustls_split::WriteHalf),
}
impl WriteStream {
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.shutdown(how),
Self::Tls(write_half) => write_half.shutdown(how),
}
}
}
impl io::Write for WriteStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.as_ref().write(buf),
Self::Tls(write_half) => write_half.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.as_ref().flush(),
Self::Tls(write_half) => write_half.flush(),
}
}
}
type TlsStream<T> = rustls::StreamOwned<rustls::ServerConnection, T>;
pub enum BidiStream {
Tcp(BufStream),
/// This variant is boxed, because [`rustls::ServerConnection`] is quite larger than [`BufStream`].
Tls(Box<TlsStream<BufStream>>),
}
impl BidiStream {
pub fn from_tcp(stream: TcpStream) -> Self {
Self::Tcp(BufStream(BufReader::new(ArcTcpRead(Arc::new(stream)))))
}
pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.get_ref().shutdown(how),
Self::Tls(tls_boxed) => {
if how == Shutdown::Read {
tls_boxed.sock.get_ref().shutdown(how)
} else {
tls_boxed.conn.send_close_notify();
let res = tls_boxed.flush();
tls_boxed.sock.get_ref().shutdown(how)?;
res
}
}
}
}
/// Split the bi-directional stream into two owned read and write halves.
pub fn split(self) -> (ReadStream, WriteStream) {
match self {
Self::Tcp(stream) => {
let reader = stream.into_reader();
let stream: Arc<TcpStream> = reader.get_ref().0.clone();
(ReadStream::Tcp(reader), WriteStream::Tcp(stream))
}
Self::Tls(tls_boxed) => {
let reader = tls_boxed.sock.into_reader();
let buffer_data = reader.buffer().to_owned();
let read_buf_cfg = rustls_split::BufCfg::with_data(buffer_data, 8192);
let write_buf_cfg = rustls_split::BufCfg::with_capacity(8192);
// TODO would be nice to avoid the Arc here
let socket = Arc::try_unwrap(reader.into_inner().0).unwrap();
let (read_half, write_half) = rustls_split::split(
socket,
Connection::Server(tls_boxed.conn),
read_buf_cfg,
write_buf_cfg,
);
(ReadStream::Tls(read_half), WriteStream::Tls(write_half))
}
}
}
pub fn start_tls(self, mut conn: rustls::ServerConnection) -> io::Result<Self> {
match self {
Self::Tcp(mut stream) => {
conn.complete_io(&mut stream)?;
assert!(!conn.is_handshaking());
Ok(Self::Tls(Box::new(TlsStream::new(conn, stream))))
}
Self::Tls { .. } => Err(io::Error::new(
io::ErrorKind::InvalidInput,
"TLS is already started on this stream",
)),
}
}
}
impl io::Read for BidiStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.read(buf),
Self::Tls(tls_boxed) => tls_boxed.read(buf),
}
}
}
impl io::Write for BidiStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Self::Tcp(stream) => stream.write(buf),
Self::Tls(tls_boxed) => tls_boxed.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Self::Tcp(stream) => stream.flush(),
Self::Tls(tls_boxed) => tls_boxed.flush(),
}
}
}

View File

@@ -0,0 +1,238 @@
use std::{
collections::HashMap,
io::{Cursor, Read, Write},
net::{TcpListener, TcpStream},
sync::Arc,
};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use once_cell::sync::Lazy;
use utils::{
postgres_backend::{AuthType, Handler, PostgresBackend},
postgres_backend_async::QueryError,
};
fn make_tcp_pair() -> (TcpStream, TcpStream) {
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let addr = listener.local_addr().unwrap();
let client_stream = TcpStream::connect(addr).unwrap();
let (server_stream, _) = listener.accept().unwrap();
(server_stream, client_stream)
}
static KEY: Lazy<rustls::PrivateKey> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("key.pem"));
rustls::PrivateKey(rustls_pemfile::rsa_private_keys(&mut cursor).unwrap()[0].clone())
});
static CERT: Lazy<rustls::Certificate> = Lazy::new(|| {
let mut cursor = Cursor::new(include_bytes!("cert.pem"));
rustls::Certificate(rustls_pemfile::certs(&mut cursor).unwrap()[0].clone())
});
#[test]
// [false-positive](https://github.com/rust-lang/rust-clippy/issues/9274),
// we resize the vector so doing some modifications after all
#[allow(clippy::read_zero_byte_vec)]
fn ssl() {
let (mut client_sock, server_sock) = make_tcp_pair();
const QUERY: &str = "hello world";
let client_jh = std::thread::spawn(move || {
// SSLRequest
client_sock.write_u32::<BigEndian>(8).unwrap();
client_sock.write_u32::<BigEndian>(80877103).unwrap();
let ssl_response = client_sock.read_u8().unwrap();
assert_eq!(b'S', ssl_response);
let cfg = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates({
let mut store = rustls::RootCertStore::empty();
store.add(&CERT).unwrap();
store
})
.with_no_client_auth();
let client_config = Arc::new(cfg);
let dns_name = "localhost".try_into().unwrap();
let mut conn = rustls::ClientConnection::new(client_config, dns_name).unwrap();
conn.complete_io(&mut client_sock).unwrap();
assert!(!conn.is_handshaking());
let mut stream = rustls::Stream::new(&mut conn, &mut client_sock);
// StartupMessage
stream.write_u32::<BigEndian>(9).unwrap();
stream.write_u32::<BigEndian>(196608).unwrap();
stream.write_u8(0).unwrap();
stream.flush().unwrap();
// wait for ReadyForQuery
let mut msg_buf = Vec::new();
loop {
let msg = stream.read_u8().unwrap();
let size = stream.read_u32::<BigEndian>().unwrap() - 4;
msg_buf.resize(size as usize, 0);
stream.read_exact(&mut msg_buf).unwrap();
if msg == b'Z' {
// ReadyForQuery
break;
}
}
// Query
stream.write_u8(b'Q').unwrap();
stream
.write_u32::<BigEndian>(4u32 + QUERY.len() as u32)
.unwrap();
stream.write_all(QUERY.as_ref()).unwrap();
stream.flush().unwrap();
// ReadyForQuery
let msg = stream.read_u8().unwrap();
assert_eq!(msg, b'Z');
});
struct TestHandler {
got_query: bool,
}
impl Handler for TestHandler {
fn process_query(
&mut self,
_pgb: &mut PostgresBackend,
query_string: &str,
) -> Result<(), QueryError> {
self.got_query = query_string == QUERY;
Ok(())
}
}
let mut handler = TestHandler { got_query: false };
let cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![CERT.clone()], KEY.clone())
.unwrap();
let tls_config = Some(Arc::new(cfg));
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config, true).unwrap();
pgb.run(&mut handler).unwrap();
assert!(handler.got_query);
client_jh.join().unwrap();
// TODO consider shutdown behavior
}
#[test]
fn no_ssl() {
let (mut client_sock, server_sock) = make_tcp_pair();
let client_jh = std::thread::spawn(move || {
let mut buf = BytesMut::new();
// SSLRequest
buf.put_u32(8);
buf.put_u32(80877103);
client_sock.write_all(&buf).unwrap();
buf.clear();
let ssl_response = client_sock.read_u8().unwrap();
assert_eq!(b'N', ssl_response);
});
struct TestHandler;
impl Handler for TestHandler {
fn process_query(
&mut self,
_pgb: &mut PostgresBackend,
_query_string: &str,
) -> Result<(), QueryError> {
panic!()
}
}
let mut handler = TestHandler;
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, None, true).unwrap();
pgb.run(&mut handler).unwrap();
client_jh.join().unwrap();
}
#[test]
fn server_forces_ssl() {
let (mut client_sock, server_sock) = make_tcp_pair();
let client_jh = std::thread::spawn(move || {
// StartupMessage
client_sock.write_u32::<BigEndian>(9).unwrap();
client_sock.write_u32::<BigEndian>(196608).unwrap();
client_sock.write_u8(0).unwrap();
client_sock.flush().unwrap();
// ErrorResponse
assert_eq!(client_sock.read_u8().unwrap(), b'E');
let len = client_sock.read_u32::<BigEndian>().unwrap() - 4;
let mut body = vec![0; len as usize];
client_sock.read_exact(&mut body).unwrap();
let mut body = Bytes::from(body);
let mut errors = HashMap::new();
loop {
let field_type = body.get_u8();
if field_type == 0u8 {
break;
}
let end_idx = body.iter().position(|&b| b == 0u8).unwrap();
let mut value = body.split_to(end_idx + 1);
assert_eq!(value[end_idx], 0u8);
value.truncate(end_idx);
let old = errors.insert(field_type, value);
assert!(old.is_none());
}
assert!(!body.has_remaining());
assert_eq!("must connect with TLS", errors.get(&b'M').unwrap());
// TODO read failure
});
struct TestHandler;
impl Handler for TestHandler {
fn process_query(
&mut self,
_pgb: &mut PostgresBackend,
_query_string: &str,
) -> Result<(), QueryError> {
panic!()
}
}
let mut handler = TestHandler;
let cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(vec![CERT.clone()], KEY.clone())
.unwrap();
let tls_config = Some(Arc::new(cfg));
let pgb = PostgresBackend::new(server_sock, AuthType::Trust, tls_config, true).unwrap();
let res = pgb.run(&mut handler).unwrap_err();
assert_eq!("client did not connect with TLS", format!("{}", res));
client_jh.join().unwrap();
// TODO consider shutdown behavior
}

View File

@@ -37,7 +37,6 @@ num-traits.workspace = true
once_cell.workspace = true
pin-project-lite.workspace = true
postgres.workspace = true
postgres_backend.workspace = true
postgres-protocol.workspace = true
postgres-types.workspace = true
rand.workspace = true

View File

@@ -23,10 +23,11 @@ use pageserver::{
tenant::mgr,
virtual_file,
};
use postgres_backend::AuthType;
use utils::{
auth::JwtAuth,
logging, project_git_version,
logging,
postgres_backend::AuthType,
project_git_version,
sentry_init::init_sentry,
signals::{self, Signal},
tcp_listener,
@@ -280,17 +281,33 @@ fn start_pageserver(
};
info!("Using auth: {:#?}", conf.auth_type);
match var("NEON_AUTH_TOKEN") {
Ok(v) => {
// TODO: remove ZENITH_AUTH_TOKEN once it's not used anywhere in development/staging/prod configuration.
match (var("ZENITH_AUTH_TOKEN"), var("NEON_AUTH_TOKEN")) {
(old, Ok(v)) => {
info!("Loaded JWT token for authentication with Safekeeper");
if let Ok(v_old) = old {
warn!(
"JWT token for Safekeeper is specified twice, ZENITH_AUTH_TOKEN is deprecated"
);
if v_old != v {
warn!("JWT token for Safekeeper has two different values, choosing NEON_AUTH_TOKEN");
}
}
pageserver::config::SAFEKEEPER_AUTH_TOKEN
.set(Arc::new(v))
.map_err(|_| anyhow!("Could not initialize SAFEKEEPER_AUTH_TOKEN"))?;
}
Err(VarError::NotPresent) => {
(Ok(v), _) => {
info!("Loaded JWT token for authentication with Safekeeper");
warn!("Please update pageserver configuration: the JWT token should be NEON_AUTH_TOKEN, not ZENITH_AUTH_TOKEN");
pageserver::config::SAFEKEEPER_AUTH_TOKEN
.set(Arc::new(v))
.map_err(|_| anyhow!("Could not initialize SAFEKEEPER_AUTH_TOKEN"))?;
}
(_, Err(VarError::NotPresent)) => {
info!("No JWT token for authentication with Safekeeper detected");
}
Err(e) => {
(_, Err(e)) => {
return Err(e).with_context(|| {
"Failed to either load to detect non-present NEON_AUTH_TOKEN environment variable"
})

View File

@@ -21,10 +21,10 @@ use std::time::Duration;
use toml_edit;
use toml_edit::{Document, Item};
use postgres_backend::AuthType;
use utils::{
id::{NodeId, TenantId, TimelineId},
logging::LogFormat,
postgres_backend::AuthType,
};
use crate::tenant::config::TenantConf;
@@ -698,12 +698,6 @@ impl PageServerConf {
Some(parse_toml_u64("compaction_threshold", compaction_threshold)?.try_into()?);
}
if let Some(image_creation_threshold) = item.get("image_creation_threshold") {
t_conf.image_creation_threshold = Some(
parse_toml_u64("image_creation_threshold", image_creation_threshold)?.try_into()?,
);
}
if let Some(gc_horizon) = item.get("gc_horizon") {
t_conf.gc_horizon = Some(parse_toml_u64("gc_horizon", gc_horizon)?);
}

View File

@@ -245,53 +245,6 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/tenant/{tenant_id}/timeline/{timeline_id}/do_gc:
parameters:
- name: tenant_id
in: path
required: true
schema:
type: string
format: hex
- name: timeline_id
in: path
required: true
schema:
type: string
format: hex
put:
description: Garbage collect given timeline
responses:
"200":
description: OK
content:
application/json:
schema:
type: string
"400":
description: Error when no tenant id found in path, no timeline id or invalid timestamp
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
"401":
description: Unauthorized Error
content:
application/json:
schema:
$ref: "#/components/schemas/UnauthorizedError"
"403":
description: Forbidden Error
content:
application/json:
schema:
$ref: "#/components/schemas/ForbiddenError"
"500":
description: Generic operation error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
/v1/tenant/{tenant_id}/attach:
parameters:
- name: tenant_id

View File

@@ -10,7 +10,6 @@ use remote_storage::GenericRemoteStorage;
use tenant_size_model::{SizeResult, StorageModel};
use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::http::endpoint::RequestSpan;
use utils::http::request::{get_request_param, must_get_query_param, parse_query_param};
use super::models::{
@@ -972,22 +971,19 @@ async fn timeline_checkpoint_handler(request: Request<Body>) -> Result<Response<
let tenant_id: TenantId = parse_request_param(&request, "tenant_id")?;
let timeline_id: TimelineId = parse_request_param(&request, "timeline_id")?;
check_permission(&request, Some(tenant_id))?;
async {
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
let timeline = active_timeline_of_active_tenant(tenant_id, timeline_id).await?;
timeline
.freeze_and_flush()
.await
.map_err(ApiError::InternalServerError)?;
timeline
.compact(&ctx)
.await
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
.instrument(info_span!("manual_checkpoint", tenant_id = %tenant_id, timeline_id = %timeline_id))
.await
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
let timeline = active_timeline_of_active_tenant(tenant_id, timeline_id).await?;
timeline
.freeze_and_flush()
.await
.map_err(ApiError::InternalServerError)?;
timeline
.compact(&ctx)
.await
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
async fn timeline_download_remote_layers_handler_post(
@@ -1092,8 +1088,7 @@ pub fn make_router(
let handler = $handler;
#[cfg(not(feature = "testing"))]
let handler = cfg_disabled;
move |r| RequestSpan(handler).handle(r)
handler
}};
}
@@ -1101,55 +1096,35 @@ pub fn make_router(
.data(Arc::new(
State::new(conf, auth, remote_storage).context("Failed to initialize router state")?,
))
.get("/v1/status", |r| RequestSpan(status_handler).handle(r))
.get("/v1/status", status_handler)
.put(
"/v1/failpoints",
testing_api!("manage failpoints", failpoints_handler),
)
.get("/v1/tenant", |r| RequestSpan(tenant_list_handler).handle(r))
.post("/v1/tenant", |r| {
RequestSpan(tenant_create_handler).handle(r)
})
.get("/v1/tenant/:tenant_id", |r| {
RequestSpan(tenant_status).handle(r)
})
.get("/v1/tenant/:tenant_id/synthetic_size", |r| {
RequestSpan(tenant_size_handler).handle(r)
})
.put("/v1/tenant/config", |r| {
RequestSpan(update_tenant_config_handler).handle(r)
})
.get("/v1/tenant/:tenant_id/config", |r| {
RequestSpan(get_tenant_config_handler).handle(r)
})
.get("/v1/tenant/:tenant_id/timeline", |r| {
RequestSpan(timeline_list_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/timeline", |r| {
RequestSpan(timeline_create_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/attach", |r| {
RequestSpan(tenant_attach_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/detach", |r| {
RequestSpan(tenant_detach_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/load", |r| {
RequestSpan(tenant_load_handler).handle(r)
})
.post("/v1/tenant/:tenant_id/ignore", |r| {
RequestSpan(tenant_ignore_handler).handle(r)
})
.get("/v1/tenant/:tenant_id/timeline/:timeline_id", |r| {
RequestSpan(timeline_detail_handler).handle(r)
})
.get("/v1/tenant", tenant_list_handler)
.post("/v1/tenant", tenant_create_handler)
.get("/v1/tenant/:tenant_id", tenant_status)
.get("/v1/tenant/:tenant_id/synthetic_size", tenant_size_handler)
.put("/v1/tenant/config", update_tenant_config_handler)
.get("/v1/tenant/:tenant_id/config", get_tenant_config_handler)
.get("/v1/tenant/:tenant_id/timeline", timeline_list_handler)
.post("/v1/tenant/:tenant_id/timeline", timeline_create_handler)
.post("/v1/tenant/:tenant_id/attach", tenant_attach_handler)
.post("/v1/tenant/:tenant_id/detach", tenant_detach_handler)
.post("/v1/tenant/:tenant_id/load", tenant_load_handler)
.post("/v1/tenant/:tenant_id/ignore", tenant_ignore_handler)
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id",
timeline_detail_handler,
)
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id/get_lsn_by_timestamp",
|r| RequestSpan(get_lsn_by_timestamp_handler).handle(r),
get_lsn_by_timestamp_handler,
)
.put(
"/v1/tenant/:tenant_id/timeline/:timeline_id/do_gc",
timeline_gc_handler,
)
.put("/v1/tenant/:tenant_id/timeline/:timeline_id/do_gc", |r| {
RequestSpan(timeline_gc_handler).handle(r)
})
.put(
"/v1/tenant/:tenant_id/timeline/:timeline_id/compact",
testing_api!("run timeline compaction", timeline_compact_handler),
@@ -1160,26 +1135,28 @@ pub fn make_router(
)
.post(
"/v1/tenant/:tenant_id/timeline/:timeline_id/download_remote_layers",
|r| RequestSpan(timeline_download_remote_layers_handler_post).handle(r),
timeline_download_remote_layers_handler_post,
)
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id/download_remote_layers",
|r| RequestSpan(timeline_download_remote_layers_handler_get).handle(r),
timeline_download_remote_layers_handler_get,
)
.delete(
"/v1/tenant/:tenant_id/timeline/:timeline_id",
timeline_delete_handler,
)
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id/layer",
layer_map_info_handler,
)
.delete("/v1/tenant/:tenant_id/timeline/:timeline_id", |r| {
RequestSpan(timeline_delete_handler).handle(r)
})
.get("/v1/tenant/:tenant_id/timeline/:timeline_id/layer", |r| {
RequestSpan(layer_map_info_handler).handle(r)
})
.get(
"/v1/tenant/:tenant_id/timeline/:timeline_id/layer/:layer_file_name",
|r| RequestSpan(layer_download_handler).handle(r),
layer_download_handler,
)
.delete(
"/v1/tenant/:tenant_id/timeline/:timeline_id/layer/:layer_file_name",
|r| RequestSpan(evict_timeline_layer_handler).handle(r),
evict_timeline_layer_handler,
)
.get("/v1/panic", |r| RequestSpan(always_panic_handler).handle(r))
.get("/v1/panic", always_panic_handler)
.any(handler_404))
}

View File

@@ -123,22 +123,6 @@ static REMOTE_PHYSICAL_SIZE: Lazy<UIntGaugeVec> = Lazy::new(|| {
.expect("failed to define a metric")
});
pub static REMOTE_ONDEMAND_DOWNLOADED_LAYERS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"pageserver_remote_ondemand_downloaded_layers_total",
"Total on-demand downloaded layers"
)
.unwrap()
});
pub static REMOTE_ONDEMAND_DOWNLOADED_BYTES: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"pageserver_remote_ondemand_downloaded_bytes_total",
"Total bytes of layers on-demand downloaded",
)
.unwrap()
});
static CURRENT_LOGICAL_SIZE: Lazy<UIntGaugeVec> = Lazy::new(|| {
register_uint_gauge_vec!(
"pageserver_current_logical_size",

View File

@@ -20,8 +20,7 @@ use pageserver_api::models::{
PagestreamFeMessage, PagestreamGetPageRequest, PagestreamGetPageResponse,
PagestreamNblocksRequest, PagestreamNblocksResponse,
};
use postgres_backend::{self, is_expected_io_error, AuthType, PostgresBackend, QueryError};
use pq_proto::framed::ConnectionError;
use pq_proto::ConnectionError;
use pq_proto::FeStartupPacket;
use pq_proto::{BeMessage, FeMessage, RowDescriptor};
use std::io;
@@ -36,6 +35,8 @@ use utils::{
auth::{Claims, JwtAuth, Scope},
id::{TenantId, TimelineId},
lsn::Lsn,
postgres_backend::AuthType,
postgres_backend_async::{self, is_expected_io_error, PostgresBackend, QueryError},
simple_rcu::RcuReadGuard,
};
@@ -63,11 +64,11 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
_ = task_mgr::shutdown_watcher() => {
// We were requested to shut down.
let msg = format!("pageserver is shutting down");
let _ = pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None));
let _ = pgb.write_message(&BeMessage::ErrorResponse(&msg, None));
Err(QueryError::Other(anyhow::anyhow!(msg)))
}
msg = pgb.read_message() => { msg.map_err(QueryError::from)}
msg = pgb.read_message() => { msg }
};
match msg {
@@ -78,16 +79,14 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
FeMessage::Sync => continue,
FeMessage::Terminate => {
let msg = "client terminated connection with Terminate message during COPY";
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
// error can't happen here, ErrorResponse serialization should be always ok
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?;
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
break;
}
m => {
let msg = format!("unexpected message {m:?}");
// error can't happen here, ErrorResponse serialization should be always ok
pgb.write_message_noflush(&BeMessage::ErrorResponse(&msg, None)).map_err(|e| e.into_io_error())?;
pgb.write_message(&BeMessage::ErrorResponse(&msg, None))?;
Err(io::Error::new(io::ErrorKind::Other, msg))?;
break;
}
@@ -97,17 +96,16 @@ fn copyin_stream(pgb: &mut PostgresBackend) -> impl Stream<Item = io::Result<Byt
}
Ok(None) => {
let msg = "client closed connection during COPY";
let query_error = QueryError::Disconnected(ConnectionError::Io(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
// error can't happen here, ErrorResponse serialization should be always ok
pgb.write_message_noflush(&BeMessage::ErrorResponse(msg, Some(query_error.pg_error_code()))).map_err(|e| e.into_io_error())?;
let query_error_error = QueryError::Disconnected(ConnectionError::Socket(io::Error::new(io::ErrorKind::ConnectionReset, msg)));
pgb.write_message(&BeMessage::ErrorResponse(msg, Some(query_error_error.pg_error_code())))?;
pgb.flush().await?;
Err(io::Error::new(io::ErrorKind::ConnectionReset, msg))?;
}
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => {
Err(io_error)?;
}
Err(other) => {
Err(io::Error::new(io::ErrorKind::Other, other.to_string()))?;
Err(io::Error::new(io::ErrorKind::Other, other))?;
}
};
}
@@ -214,7 +212,7 @@ async fn page_service_conn_main(
// we've been requested to shut down
Ok(())
}
Err(QueryError::Disconnected(ConnectionError::Io(io_error))) => {
Err(QueryError::Disconnected(ConnectionError::Socket(io_error))) => {
if is_expected_io_error(&io_error) {
info!("Postgres client disconnected ({io_error})");
Ok(())
@@ -313,7 +311,7 @@ impl PageServerHandler {
let timeline = tenant.get_timeline(timeline_id, true)?;
// switch client to COPYBOTH
pgb.write_message_noflush(&BeMessage::CopyBothResponse)?;
pgb.write_message(&BeMessage::CopyBothResponse)?;
pgb.flush().await?;
let metrics = PageRequestMetrics::new(&tenant_id, &timeline_id);
@@ -382,7 +380,7 @@ impl PageServerHandler {
})
});
pgb.write_message_noflush(&BeMessage::CopyData(&response.serialize()))?;
pgb.write_message(&BeMessage::CopyData(&response.serialize()))?;
pgb.flush().await?;
}
Ok(())
@@ -418,7 +416,7 @@ impl PageServerHandler {
// Import basebackup provided via CopyData
info!("importing basebackup");
pgb.write_message_noflush(&BeMessage::CopyInResponse)?;
pgb.write_message(&BeMessage::CopyInResponse)?;
pgb.flush().await?;
let mut copyin_stream = Box::pin(copyin_stream(pgb));
@@ -470,7 +468,7 @@ impl PageServerHandler {
// Import wal provided via CopyData
info!("importing wal");
pgb.write_message_noflush(&BeMessage::CopyInResponse)?;
pgb.write_message(&BeMessage::CopyInResponse)?;
pgb.flush().await?;
let mut copyin_stream = Box::pin(copyin_stream(pgb));
let mut reader = tokio_util::io::StreamReader::new(&mut copyin_stream);
@@ -680,7 +678,7 @@ impl PageServerHandler {
}
// switch client to COPYOUT
pgb.write_message_noflush(&BeMessage::CopyOutResponse)?;
pgb.write_message(&BeMessage::CopyOutResponse)?;
pgb.flush().await?;
// Send a tarball of the latest layer on the timeline
@@ -697,7 +695,7 @@ impl PageServerHandler {
.await?;
}
pgb.write_message_noflush(&BeMessage::CopyDone)?;
pgb.write_message(&BeMessage::CopyDone)?;
pgb.flush().await?;
info!("basebackup complete");
@@ -723,7 +721,7 @@ impl PageServerHandler {
}
#[async_trait::async_trait]
impl postgres_backend::Handler for PageServerHandler {
impl postgres_backend_async::Handler for PageServerHandler {
fn check_auth_jwt(
&mut self,
_pgb: &mut PostgresBackend,
@@ -814,7 +812,7 @@ impl postgres_backend::Handler for PageServerHandler {
// Check that the timeline exists
self.handle_basebackup_request(pgb, tenant_id, timeline_id, lsn, None, false, ctx)
.await?;
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
}
// return pair of prev_lsn and last_lsn
else if query_string.starts_with("get_last_record_rlsn ") {
@@ -837,15 +835,15 @@ impl postgres_backend::Handler for PageServerHandler {
let end_of_timeline = timeline.get_last_record_rlsn();
pgb.write_message_noflush(&BeMessage::RowDescription(&[
pgb.write_message(&BeMessage::RowDescription(&[
RowDescriptor::text_col(b"prev_lsn"),
RowDescriptor::text_col(b"last_lsn"),
]))?
.write_message_noflush(&BeMessage::DataRow(&[
.write_message(&BeMessage::DataRow(&[
Some(end_of_timeline.prev.to_string().as_bytes()),
Some(end_of_timeline.last.to_string().as_bytes()),
]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
}
// same as basebackup, but result includes relational data as well
else if query_string.starts_with("fullbackup ") {
@@ -886,7 +884,7 @@ impl postgres_backend::Handler for PageServerHandler {
// Check that the timeline exists
self.handle_basebackup_request(pgb, tenant_id, timeline_id, lsn, prev_lsn, true, ctx)
.await?;
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("import basebackup ") {
// Import the `base` section (everything but the wal) of a basebackup.
// Assumes the tenant already exists on this pageserver.
@@ -931,10 +929,10 @@ impl postgres_backend::Handler for PageServerHandler {
)
.await
{
Ok(()) => pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?,
Ok(()) => pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?,
Err(e) => {
error!("error importing base backup between {base_lsn} and {end_lsn}: {e:?}");
pgb.write_message_noflush(&BeMessage::ErrorResponse(
pgb.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?
@@ -967,10 +965,10 @@ impl postgres_backend::Handler for PageServerHandler {
.handle_import_wal(pgb, tenant_id, timeline_id, start_lsn, end_lsn, ctx)
.await
{
Ok(()) => pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?,
Ok(()) => pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?,
Err(e) => {
error!("error importing WAL between {start_lsn} and {end_lsn}: {e:?}");
pgb.write_message_noflush(&BeMessage::ErrorResponse(
pgb.write_message(&BeMessage::ErrorResponse(
&e.to_string(),
Some(e.pg_error_code()),
))?
@@ -979,7 +977,7 @@ impl postgres_backend::Handler for PageServerHandler {
} else if query_string.to_ascii_lowercase().starts_with("set ") {
// important because psycopg2 executes "SET datestyle TO 'ISO'"
// on connect
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else if query_string.starts_with("show ") {
// show <tenant_id>
let (_, params_raw) = query_string.split_at("show ".len());
@@ -995,7 +993,7 @@ impl postgres_backend::Handler for PageServerHandler {
self.check_permission(Some(tenant_id))?;
let tenant = get_active_tenant_with_timeout(tenant_id, &ctx).await?;
pgb.write_message_noflush(&BeMessage::RowDescription(&[
pgb.write_message(&BeMessage::RowDescription(&[
RowDescriptor::int8_col(b"checkpoint_distance"),
RowDescriptor::int8_col(b"checkpoint_timeout"),
RowDescriptor::int8_col(b"compaction_target_size"),
@@ -1006,7 +1004,7 @@ impl postgres_backend::Handler for PageServerHandler {
RowDescriptor::int8_col(b"image_creation_threshold"),
RowDescriptor::int8_col(b"pitr_interval"),
]))?
.write_message_noflush(&BeMessage::DataRow(&[
.write_message(&BeMessage::DataRow(&[
Some(tenant.get_checkpoint_distance().to_string().as_bytes()),
Some(
tenant
@@ -1029,7 +1027,7 @@ impl postgres_backend::Handler for PageServerHandler {
Some(tenant.get_image_creation_threshold().to_string().as_bytes()),
Some(tenant.get_pitr_interval().as_secs().to_string().as_bytes()),
]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
} else {
return Err(QueryError::Other(anyhow::anyhow!(
"unknown command {query_string}"
@@ -1057,7 +1055,7 @@ impl From<GetActiveTenantError> for QueryError {
fn from(e: GetActiveTenantError) -> Self {
match e {
GetActiveTenantError::WaitForActiveTimeout { .. } => QueryError::Disconnected(
ConnectionError::Io(io::Error::new(io::ErrorKind::TimedOut, e.to_string())),
ConnectionError::Socket(io::Error::new(io::ErrorKind::TimedOut, e.to_string())),
),
GetActiveTenantError::Other(e) => QueryError::Other(e),
}

View File

@@ -103,7 +103,6 @@ pub struct TenantConfOpt {
pub checkpoint_distance: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(with = "humantime_serde")]
#[serde(default)]
pub checkpoint_timeout: Option<Duration>,

View File

@@ -218,10 +218,9 @@ use tracing::{debug, info, warn};
use tracing::{info_span, Instrument};
use utils::lsn::Lsn;
use crate::metrics::{
MeasureRemoteOp, RemoteOpFileKind, RemoteOpKind, RemoteTimelineClientMetrics,
REMOTE_ONDEMAND_DOWNLOADED_BYTES, REMOTE_ONDEMAND_DOWNLOADED_LAYERS,
};
use crate::metrics::RemoteOpFileKind;
use crate::metrics::RemoteOpKind;
use crate::metrics::{MeasureRemoteOp, RemoteTimelineClientMetrics};
use crate::tenant::remote_timeline_client::index::LayerFileMetadata;
use crate::{
config::PageServerConf,
@@ -447,10 +446,6 @@ impl RemoteTimelineClient {
);
}
}
REMOTE_ONDEMAND_DOWNLOADED_LAYERS.inc();
REMOTE_ONDEMAND_DOWNLOADED_BYTES.inc_by(downloaded_size);
Ok(downloaded_size)
}

View File

@@ -6,13 +6,11 @@
use std::collections::HashSet;
use std::future::Future;
use std::path::Path;
use std::time::Duration;
use anyhow::{anyhow, Context};
use tokio::fs;
use tokio::io::AsyncWriteExt;
use tracing::{info, warn};
use tracing::{error, info, warn};
use crate::config::PageServerConf;
use crate::tenant::storage_layer::LayerFileName;
@@ -28,8 +26,6 @@ async fn fsync_path(path: impl AsRef<std::path::Path>) -> Result<(), std::io::Er
fs::File::open(path).await?.sync_all().await
}
static MAX_DOWNLOAD_DURATION: Duration = Duration::from_secs(120);
///
/// If 'metadata' is given, we will validate that the downloaded file's size matches that
/// in the metadata. (In the future, we might do more cross-checks, like CRC validation)
@@ -68,28 +64,22 @@ pub async fn download_layer_file<'a>(
// TODO: this doesn't use the cached fd for some reason?
let mut destination_file = fs::File::create(&temp_file_path).await.with_context(|| {
format!(
"create a destination file for layer '{}'",
"Failed to create a destination file for layer '{}'",
temp_file_path.display()
)
})
.map_err(DownloadError::Other)?;
let mut download = storage.download(&remote_path).await.with_context(|| {
format!(
"open a download stream for layer with remote storage path '{remote_path:?}'"
"Failed to open a download stream for layer with remote storage path '{remote_path:?}'"
)
})
.map_err(DownloadError::Other)?;
let bytes_amount = tokio::time::timeout(MAX_DOWNLOAD_DURATION, tokio::io::copy(&mut download.download_stream, &mut destination_file))
.await
.map_err(|e| DownloadError::Other(anyhow::anyhow!("Timed out {:?}", e)))?
.with_context(|| {
format!("Failed to download layer with remote storage path '{remote_path:?}' into file {temp_file_path:?}")
})
.map_err(DownloadError::Other)?;
let bytes_amount = tokio::io::copy(&mut download.download_stream, &mut destination_file).await.with_context(|| {
format!("Failed to download layer with remote storage path '{remote_path:?}' into file {temp_file_path:?}")
})
.map_err(DownloadError::Other)?;
Ok((destination_file, bytes_amount))
},
&format!("download {remote_path:?}"),
).await?;
@@ -310,7 +300,7 @@ where
}
Err(DownloadError::Other(ref err)) => {
// Operation failed FAILED_DOWNLOAD_RETRIES times. Time to give up.
warn!("{description} still failed after {attempts} retries, giving up: {err:?}");
error!("{description} still failed after {attempts} retries, giving up: {err:?}");
return result;
}
}

View File

@@ -364,7 +364,7 @@ pub trait PersistentLayer: Layer {
}
/// Permanently remove this layer from disk.
fn delete_resident_layer_file(&self) -> Result<()>;
fn delete(&self) -> Result<()>;
fn downcast_remote_layer(self: Arc<Self>) -> Option<std::sync::Arc<RemoteLayer>> {
None

View File

@@ -438,7 +438,7 @@ impl PersistentLayer for DeltaLayer {
))
}
fn delete_resident_layer_file(&self) -> Result<()> {
fn delete(&self) -> Result<()> {
// delete underlying file
fs::remove_file(self.path())?;
Ok(())

View File

@@ -252,7 +252,7 @@ impl PersistentLayer for ImageLayer {
unimplemented!();
}
fn delete_resident_layer_file(&self) -> Result<()> {
fn delete(&self) -> Result<()> {
// delete underlying file
fs::remove_file(self.path())?;
Ok(())

View File

@@ -155,8 +155,8 @@ impl PersistentLayer for RemoteLayer {
bail!("cannot iterate a remote layer");
}
fn delete_resident_layer_file(&self) -> Result<()> {
bail!("remote layer has no layer file");
fn delete(&self) -> Result<()> {
Ok(())
}
fn downcast_remote_layer<'a>(self: Arc<Self>) -> Option<std::sync::Arc<RemoteLayer>> {

View File

@@ -662,8 +662,8 @@ impl Timeline {
// update the index file on next flush iteration too. But it
// could take a while until that happens.
//
// Additionally, only do this once before we return from this function.
if last_round || res.is_ok() {
// Additionally, only do this on the terminal round before sleeping.
if last_round {
if let Some(remote_client) = &self.remote_client {
remote_client.schedule_index_upload_for_file_changes()?;
}
@@ -1047,12 +1047,11 @@ impl Timeline {
return Ok(false);
}
let layer_file_size = local_layer
.file_size()
.expect("Local layer should have a file size");
let layer_metadata = LayerFileMetadata::new(layer_file_size);
let layer_metadata = LayerFileMetadata::new(
local_layer
.file_size()
.expect("Local layer should have a file size"),
);
let new_remote_layer = Arc::new(match local_layer.filename() {
LayerFileName::Image(image_name) => RemoteLayer::new_img(
self.tenant_id,
@@ -1076,22 +1075,18 @@ impl Timeline {
let replaced = match batch_updates.replace_historic(local_layer, new_remote_layer)? {
Replacement::Replaced { .. } => {
if let Err(e) = local_layer.delete_resident_layer_file() {
let layer_size = local_layer.file_size();
if let Err(e) = local_layer.delete() {
error!("failed to remove layer file on evict after replacement: {e:#?}");
}
// Always decrement the physical size gauge, even if we failed to delete the file.
// Rationale: we already replaced the layer with a remote layer in the layer map,
// and any subsequent download_remote_layer will
// 1. overwrite the file on disk and
// 2. add the downloaded size to the resident size gauge.
//
// If there is no re-download, and we restart the pageserver, then load_layer_map
// will treat the file as a local layer again, count it towards resident size,
// and it'll be like the layer removal never happened.
// The bump in resident size is perhaps unexpected but overall a robust behavior.
self.metrics
.resident_physical_size_gauge
.sub(layer_file_size);
if let Some(layer_size) = layer_size {
let current_size = self.metrics.resident_physical_size_gauge.get();
self.metrics
.resident_physical_size_gauge
.set(current_size.saturating_sub(layer_size));
}
true
}
@@ -1512,7 +1507,11 @@ impl Timeline {
assert!(local_layer_path.exists(), "we would leave the local_layer without a file if this does not hold: {}", local_layer_path.display());
anyhow::bail!("could not rename file {local_layer_path:?}: {err:?}");
} else {
self.metrics.resident_physical_size_gauge.sub(local_size);
let current_size = self.metrics.resident_physical_size_gauge.get();
self.metrics
.resident_physical_size_gauge
.set(current_size.saturating_sub(local_size));
updates.remove_historic(local_layer);
// fall-through to adding the remote layer
}
@@ -1950,14 +1949,14 @@ impl Timeline {
layer: Arc<dyn PersistentLayer>,
updates: &mut BatchedUpdates<'_, dyn PersistentLayer>,
) -> anyhow::Result<()> {
if !layer.is_remote_layer() {
layer.delete_resident_layer_file()?;
let layer_file_size = layer
.file_size()
.expect("Local layer should have a file size");
let layer_size = layer.file_size();
layer.delete()?;
if let Some(layer_size) = layer_size {
let current_size = self.metrics.resident_physical_size_gauge.get();
self.metrics
.resident_physical_size_gauge
.sub(layer_file_size);
.set(current_size.saturating_sub(layer_size));
}
// TODO Removing from the bottom of the layer map is expensive.
@@ -3819,7 +3818,7 @@ impl Timeline {
remote_layer.ongoing_download.close();
} else {
// Keep semaphore open. We'll drop the permit at the end of the function.
error!("on-demand download failed: {:?}", result.as_ref().unwrap_err());
info!("on-demand download failed: {:?}", result.as_ref().unwrap_err());
}
// Don't treat it as an error if the task that triggered the download

View File

@@ -33,11 +33,10 @@ use crate::{
walingest::WalIngest,
walrecord::DecodedWALRecord,
};
use postgres_backend::is_expected_io_error;
use postgres_connection::PgConnectionConfig;
use postgres_ffi::waldecoder::WalStreamDecoder;
use pq_proto::ReplicationFeedback;
use utils::lsn::Lsn;
use utils::{lsn::Lsn, postgres_backend_async::is_expected_io_error};
/// Status of the connection.
#[derive(Debug, Clone, Copy)]
@@ -354,7 +353,7 @@ pub async fn handle_walreceiver_connection(
debug!("neon_status_update {status_update:?}");
let mut data = BytesMut::new();
status_update.serialize(&mut data);
status_update.serialize(&mut data)?;
physical_stream
.as_mut()
.zenith_status_update(data.len() as u64, &data)
@@ -435,8 +434,8 @@ fn ignore_expected_errors(pg_error: postgres::Error) -> anyhow::Result<postgres:
{
return Ok(pg_error);
} else if let Some(db_error) = pg_error.as_db_error() {
if db_error.code() == &SqlState::SUCCESSFUL_COMPLETION
&& db_error.message().contains("ending streaming")
if db_error.code() == &SqlState::CONNECTION_FAILURE
&& db_error.message().contains("end streaming")
{
return Ok(pg_error);
}

View File

@@ -23,11 +23,13 @@ use bytes::{BufMut, Bytes, BytesMut};
use nix::poll::*;
use serde::Serialize;
use std::collections::VecDeque;
use std::fs::OpenOptions;
use std::io::prelude::*;
use std::io::{Error, ErrorKind};
use std::ops::{Deref, DerefMut};
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::prelude::CommandExt;
use std::path::PathBuf;
use std::process::Stdio;
use std::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command};
use std::sync::{Mutex, MutexGuard};
@@ -254,53 +256,52 @@ impl PostgresRedoManager {
pg_version: u32,
) -> Result<Bytes, WalRedoError> {
let (rel, blknum) = key_to_rel_block(key).or(Err(WalRedoError::InvalidRecord))?;
const MAX_RETRY_ATTEMPTS: u32 = 1;
let start_time = Instant::now();
let mut n_attempts = 0u32;
loop {
let mut proc = self.stdin.lock().unwrap();
let lock_time = Instant::now();
// launch the WAL redo process on first use
if proc.is_none() {
self.launch(&mut proc, pg_version)?;
}
WAL_REDO_WAIT_TIME.observe(lock_time.duration_since(start_time).as_secs_f64());
let mut proc = self.stdin.lock().unwrap();
let lock_time = Instant::now();
// Relational WAL records are applied using wal-redo-postgres
let buf_tag = BufferTag { rel, blknum };
let result = self
.apply_wal_records(proc, buf_tag, &base_img, records, wal_redo_timeout)
.map_err(WalRedoError::IoError);
// launch the WAL redo process on first use
if proc.is_none() {
self.launch(&mut proc, pg_version)?;
}
WAL_REDO_WAIT_TIME.observe(lock_time.duration_since(start_time).as_secs_f64());
let end_time = Instant::now();
let duration = end_time.duration_since(lock_time);
// Relational WAL records are applied using wal-redo-postgres
let buf_tag = BufferTag { rel, blknum };
let result = self
.apply_wal_records(proc, buf_tag, base_img, records, wal_redo_timeout)
.map_err(WalRedoError::IoError);
let len = records.len();
let nbytes = records.iter().fold(0, |acumulator, record| {
acumulator
+ match &record.1 {
NeonWalRecord::Postgres { rec, .. } => rec.len(),
_ => unreachable!("Only PostgreSQL records are accepted in this batch"),
}
});
let end_time = Instant::now();
let duration = end_time.duration_since(lock_time);
WAL_REDO_TIME.observe(duration.as_secs_f64());
WAL_REDO_RECORDS_HISTOGRAM.observe(len as f64);
WAL_REDO_BYTES_HISTOGRAM.observe(nbytes as f64);
let len = records.len();
let nbytes = records.iter().fold(0, |acumulator, record| {
acumulator
+ match &record.1 {
NeonWalRecord::Postgres { rec, .. } => rec.len(),
_ => unreachable!("Only PostgreSQL records are accepted in this batch"),
}
});
debug!(
"postgres applied {} WAL records ({} bytes) in {} us to reconstruct page image at LSN {}",
len,
nbytes,
duration.as_micros(),
lsn
);
WAL_REDO_TIME.observe(duration.as_secs_f64());
WAL_REDO_RECORDS_HISTOGRAM.observe(len as f64);
WAL_REDO_BYTES_HISTOGRAM.observe(nbytes as f64);
// If something went wrong, don't try to reuse the process. Kill it, and
// next request will launch a new one.
if result.is_err() {
error!(
debug!(
"postgres applied {} WAL records ({} bytes) in {} us to reconstruct page image at LSN {}",
len,
nbytes,
duration.as_micros(),
lsn
);
// If something went wrong, don't try to reuse the process. Kill it, and
// next request will launch a new one.
if result.is_err() {
error!(
"error applying {} WAL records {}..{} ({} bytes) to base image with LSN {} to reconstruct page image at LSN {}",
records.len(),
records.first().map(|p| p.0).unwrap_or(Lsn(0)),
@@ -309,28 +310,24 @@ impl PostgresRedoManager {
base_img_lsn,
lsn
);
// self.stdin only holds stdin & stderr as_raw_fd().
// Dropping it as part of take() doesn't close them.
// The owning objects (ChildStdout and ChildStderr) are stored in
// self.stdout and self.stderr, respsectively.
// We intentionally keep them open here to avoid a race between
// currently running `apply_wal_records()` and a `launch()` call
// after we return here.
// The currently running `apply_wal_records()` must not read from
// the newly launched process.
// By keeping self.stdout and self.stderr open here, `launch()` will
// get other file descriptors for the new child's stdout and stderr,
// and hence the current `apply_wal_records()` calls will observe
// `output.stdout.as_raw_fd() != stdout_fd` .
if let Some(proc) = self.stdin.lock().unwrap().take() {
proc.child.kill_and_wait();
}
}
n_attempts += 1;
if n_attempts > MAX_RETRY_ATTEMPTS || result.is_ok() {
return result;
// self.stdin only holds stdin & stderr as_raw_fd().
// Dropping it as part of take() doesn't close them.
// The owning objects (ChildStdout and ChildStderr) are stored in
// self.stdout and self.stderr, respsectively.
// We intentionally keep them open here to avoid a race between
// currently running `apply_wal_records()` and a `launch()` call
// after we return here.
// The currently running `apply_wal_records()` must not read from
// the newly launched process.
// By keeping self.stdout and self.stderr open here, `launch()` will
// get other file descriptors for the new child's stdout and stderr,
// and hence the current `apply_wal_records()` calls will observe
// `output.stdout.as_raw_fd() != stdout_fd` .
if let Some(proc) = self.stdin.lock().unwrap().take() {
proc.child.kill_and_wait();
}
}
result
}
///
@@ -637,26 +634,26 @@ impl PostgresRedoManager {
input: &mut MutexGuard<Option<ProcessInput>>,
pg_version: u32,
) -> Result<(), Error> {
// Previous versions of wal-redo required data directory and that directories
// occupied some space on disk. Remove it if we face it.
//
// This code could be dropped after one release cycle.
let legacy_datadir = path_with_suffix_extension(
// FIXME: We need a dummy Postgres cluster to run the process in. Currently, we
// just create one with constant name. That fails if you try to launch more than
// one WAL redo manager concurrently.
let datadir = path_with_suffix_extension(
self.conf
.tenant_path(&self.tenant_id)
.join("wal-redo-datadir"),
TEMP_FILE_SUFFIX,
);
if legacy_datadir.exists() {
info!("legacy wal-redo datadir {legacy_datadir:?} exists, removing");
fs::remove_dir_all(&legacy_datadir).map_err(|e| {
// Create empty data directory for wal-redo postgres, deleting old one first.
if datadir.exists() {
info!("old temporary datadir {datadir:?} exists, removing");
fs::remove_dir_all(&datadir).map_err(|e| {
Error::new(
e.kind(),
format!("legacy wal-redo datadir {legacy_datadir:?} removal failure: {e}"),
format!("Old temporary dir {datadir:?} removal failure: {e}"),
)
})?;
}
let pg_bin_dir_path = self
.conf
.pg_bin_dir(pg_version)
@@ -666,6 +663,35 @@ impl PostgresRedoManager {
.pg_lib_dir(pg_version)
.map_err(|e| Error::new(ErrorKind::Other, format!("incorrect pg_lib_dir path: {e}")))?;
info!("running initdb in {}", datadir.display());
let initdb = Command::new(pg_bin_dir_path.join("initdb"))
.args(["-D", &datadir.to_string_lossy()])
.arg("-N")
.env_clear()
.env("LD_LIBRARY_PATH", &pg_lib_dir_path)
.env("DYLD_LIBRARY_PATH", &pg_lib_dir_path) // macOS
.close_fds()
.output()
.map_err(|e| Error::new(e.kind(), format!("failed to execute initdb: {e}")))?;
if !initdb.status.success() {
return Err(Error::new(
ErrorKind::Other,
format!(
"initdb failed\nstdout: {}\nstderr:\n{}",
String::from_utf8_lossy(&initdb.stdout),
String::from_utf8_lossy(&initdb.stderr)
),
));
} else {
// Limit shared cache for wal-redo-postgres
let mut config = OpenOptions::new()
.append(true)
.open(PathBuf::from(&datadir).join("postgresql.conf"))?;
config.write_all(b"shared_buffers=128kB\n")?;
config.write_all(b"fsync=off\n")?;
}
// Start postgres itself
let child = Command::new(pg_bin_dir_path.join("postgres"))
.arg("--wal-redo")
@@ -675,6 +701,7 @@ impl PostgresRedoManager {
.env_clear()
.env("LD_LIBRARY_PATH", &pg_lib_dir_path)
.env("DYLD_LIBRARY_PATH", &pg_lib_dir_path)
.env("PGDATA", &datadir)
// The redo process is not trusted, and runs in seccomp mode that
// doesn't allow it to open any files. We have to also make sure it
// doesn't inherit any file descriptors from the pageserver, that
@@ -744,7 +771,7 @@ impl PostgresRedoManager {
&self,
mut input: MutexGuard<Option<ProcessInput>>,
tag: BufferTag,
base_img: &Option<Bytes>,
base_img: Option<Bytes>,
records: &[(Lsn, NeonWalRecord)],
wal_redo_timeout: Duration,
) -> Result<Bytes, std::io::Error> {
@@ -760,7 +787,7 @@ impl PostgresRedoManager {
let mut writebuf: Vec<u8> = Vec::with_capacity((BLCKSZ as usize) * 3);
build_begin_redo_for_block_msg(tag, &mut writebuf);
if let Some(img) = base_img {
build_push_page_msg(tag, img, &mut writebuf);
build_push_page_msg(tag, &img, &mut writebuf);
}
for (lsn, rec) in records.iter() {
if let NeonWalRecord::Postgres {

View File

@@ -32,9 +32,6 @@
#define PageStoreTrace DEBUG5
#define MAX_RECONNECT_ATTEMPTS 5
#define RECONNECT_INTERVAL_USEC 1000000
bool connected = false;
PGconn *pageserver_conn = NULL;
@@ -55,8 +52,8 @@ int readahead_buffer_size = 128;
static void pageserver_flush(void);
static bool
pageserver_connect(int elevel)
static void
pageserver_connect()
{
char *query;
int ret;
@@ -72,11 +69,10 @@ pageserver_connect(int elevel)
PQfinish(pageserver_conn);
pageserver_conn = NULL;
ereport(elevel,
ereport(ERROR,
(errcode(ERRCODE_SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION),
errmsg(NEON_TAG "could not establish connection to pageserver"),
errdetail_internal("%s", msg)));
return false;
}
query = psprintf("pagestream %s %s", neon_tenant, neon_timeline);
@@ -85,8 +81,7 @@ pageserver_connect(int elevel)
{
PQfinish(pageserver_conn);
pageserver_conn = NULL;
neon_log(elevel, "could not send pagestream command to pageserver");
return false;
neon_log(ERROR, "could not send pagestream command to pageserver");
}
pageserver_conn_wes = CreateWaitEventSet(TopMemoryContext, 3);
@@ -118,9 +113,8 @@ pageserver_connect(int elevel)
FreeWaitEventSet(pageserver_conn_wes);
pageserver_conn_wes = NULL;
neon_log(elevel, "could not complete handshake with pageserver: %s",
neon_log(ERROR, "could not complete handshake with pageserver: %s",
msg);
return false;
}
}
}
@@ -128,7 +122,6 @@ pageserver_connect(int elevel)
neon_log(LOG, "libpagestore: connected to '%s'", page_server_connstring_raw);
connected = true;
return true;
}
/*
@@ -156,11 +149,8 @@ retry:
if (event.events & WL_SOCKET_READABLE)
{
if (!PQconsumeInput(pageserver_conn))
{
neon_log(LOG, "could not get response from pageserver: %s",
neon_log(ERROR, "could not get response from pageserver: %s",
PQerrorMessage(pageserver_conn));
return -1;
}
}
goto retry;
@@ -200,62 +190,31 @@ static void
pageserver_send(NeonRequest * request)
{
StringInfoData req_buff;
int n_reconnect_attempts = 0;
/* If the connection was lost for some reason, reconnect */
if (connected && PQstatus(pageserver_conn) == CONNECTION_BAD)
pageserver_disconnect();
if (!connected)
pageserver_connect();
req_buff = nm_pack_request(request);
/*
* If pageserver is stopped, the connections from compute node are broken.
* The compute node doesn't notice that immediately, but it will cause the next request to fail, usually on the next query.
* That causes user-visible errors if pageserver is restarted, or the tenant is moved from one pageserver to another.
* See https://github.com/neondatabase/neon/issues/1138
* So try to reestablish connection in case of failure.
* Send request.
*
* In principle, this could block if the output buffer is full, and we
* should use async mode and check for interrupts while waiting. In
* practice, our requests are small enough to always fit in the output and
* TCP buffer.
*/
while (true)
if (PQputCopyData(pageserver_conn, req_buff.data, req_buff.len) <= 0)
{
if (!connected)
{
if (!pageserver_connect(n_reconnect_attempts < MAX_RECONNECT_ATTEMPTS ? LOG : ERROR))
{
n_reconnect_attempts += 1;
pg_usleep(RECONNECT_INTERVAL_USEC);
continue;
}
}
char *msg = pchomp(PQerrorMessage(pageserver_conn));
/*
* Send request.
*
* In principle, this could block if the output buffer is full, and we
* should use async mode and check for interrupts while waiting. In
* practice, our requests are small enough to always fit in the output and
* TCP buffer.
*/
if (PQputCopyData(pageserver_conn, req_buff.data, req_buff.len) <= 0)
{
char *msg = pchomp(PQerrorMessage(pageserver_conn));
if (n_reconnect_attempts < MAX_RECONNECT_ATTEMPTS)
{
neon_log(LOG, "failed to send page request (try to reconnect): %s", msg);
if (n_reconnect_attempts != 0) /* do not sleep before first reconnect attempt, assuming that pageserver is already restarted */
pg_usleep(RECONNECT_INTERVAL_USEC);
n_reconnect_attempts += 1;
continue;
}
else
{
pageserver_disconnect();
neon_log(ERROR, "failed to send page request: %s", msg);
}
}
break;
pageserver_disconnect();
neon_log(ERROR, "failed to send page request: %s", msg);
}
pfree(req_buff.data);
n_unflushed_requests++;

View File

@@ -1,15 +0,0 @@
# pgxs/neon_utils/Makefile
MODULE_big = neon_utils
OBJS = \
$(WIN32RES) \
neon_utils.o
EXTENSION = neon_utils
DATA = neon_utils--1.0.sql
PGFILEDESC = "neon_utils - small useful functions"
PG_CONFIG = pg_config
PGXS := $(shell $(PG_CONFIG) --pgxs)
include $(PGXS)

View File

@@ -1,6 +0,0 @@
CREATE FUNCTION num_cpus()
RETURNS int
AS 'MODULE_PATHNAME', 'num_cpus'
LANGUAGE C STRICT
PARALLEL UNSAFE
VOLATILE;

View File

@@ -1,35 +0,0 @@
/*-------------------------------------------------------------------------
*
* neon_utils.c
* neon_utils - small useful functions
*
* IDENTIFICATION
* contrib/neon_utils/neon_utils.c
*
*-------------------------------------------------------------------------
*/
#ifdef _WIN32
#include <windows.h>
#else
#include <unistd.h>
#endif
#include "postgres.h"
#include "fmgr.h"
PG_MODULE_MAGIC;
PG_FUNCTION_INFO_V1(num_cpus);
Datum
num_cpus(PG_FUNCTION_ARGS)
{
#ifdef _WIN32
SYSTEM_INFO sysinfo;
GetSystemInfo(&sysinfo);
uint32 num_cpus = (uint32) sysinfo.dwNumberOfProcessors;
#else
uint32 num_cpus = (uint32) sysconf(_SC_NPROCESSORS_ONLN);
#endif
PG_RETURN_UINT32(num_cpus);
}

View File

@@ -1,6 +0,0 @@
# neon_utils extension
comment = 'neon_utils - small useful functions'
default_version = '1.0'
module_pathname = '$libdir/neon_utils'
relocatable = true
trusted = true

View File

@@ -65,14 +65,6 @@
#include "rusagestub.h"
#endif
#include "access/clog.h"
#include "access/commit_ts.h"
#include "access/heapam.h"
#include "access/multixact.h"
#include "access/nbtree.h"
#include "access/subtrans.h"
#include "access/syncscan.h"
#include "access/twophase.h"
#include "access/xlog.h"
#include "access/xlog_internal.h"
#if PG_VERSION_NUM >= 150000
@@ -80,36 +72,18 @@
#endif
#include "access/xlogutils.h"
#include "catalog/pg_class.h"
#include "commands/async.h"
#include "libpq/libpq.h"
#include "libpq/pqformat.h"
#include "miscadmin.h"
#include "pgstat.h"
#include "postmaster/autovacuum.h"
#include "postmaster/bgworker_internals.h"
#include "postmaster/bgwriter.h"
#include "postmaster/postmaster.h"
#include "replication/logicallauncher.h"
#include "replication/origin.h"
#include "replication/slot.h"
#include "replication/walreceiver.h"
#include "replication/walsender.h"
#include "storage/buf_internals.h"
#include "storage/bufmgr.h"
#include "storage/dsm.h"
#include "storage/ipc.h"
#include "storage/pg_shmem.h"
#include "storage/pmsignal.h"
#include "storage/predicate.h"
#include "storage/proc.h"
#include "storage/procarray.h"
#include "storage/procsignal.h"
#include "storage/sinvaladt.h"
#include "storage/smgr.h"
#include "storage/spin.h"
#include "tcop/tcopprot.h"
#include "utils/memutils.h"
#include "utils/ps_status.h"
#include "utils/snapmgr.h"
#include "inmem_smgr.h"
@@ -127,7 +101,6 @@ static void apply_error_callback(void *arg);
static bool redo_block_filter(XLogReaderState *record, uint8 block_id);
static void GetPage(StringInfo input_message);
static ssize_t buffered_read(void *buf, size_t count);
static void CreateFakeSharedMemoryAndSemaphores();
static BufferTag target_redo_tag;
@@ -168,7 +141,7 @@ enter_seccomp_mode(void)
PG_SCMP_ALLOW(shmctl),
PG_SCMP_ALLOW(shmdt),
PG_SCMP_ALLOW(unlink), // shm_unlink
*/
*/
};
#ifdef MALLOC_NO_MMAP
@@ -204,7 +177,6 @@ WalRedoMain(int argc, char *argv[])
* buffers. So let's keep it small (default value is 1024)
*/
num_temp_buffers = 4;
NBuffers = 4;
/*
* install the simple in-memory smgr
@@ -212,33 +184,49 @@ WalRedoMain(int argc, char *argv[])
smgr_hook = smgr_inmem;
smgr_init_hook = smgr_init_inmem;
/*
* Validate we have been given a reasonable-looking DataDir and change into it.
*/
checkDataDir();
ChangeToDataDir();
/*
* Create lockfile for data directory.
*/
CreateDataDirLockFile(false);
/* read control file (error checking and contains config ) */
LocalProcessControlFile(false);
/*
* process any libraries that should be preloaded at postmaster start
*/
process_shared_preload_libraries();
/* Initialize MaxBackends (if under postmaster, was done already) */
MaxConnections = 1;
max_worker_processes = 0;
max_parallel_workers = 0;
max_wal_senders = 0;
InitializeMaxBackends();
/* Disable lastWrittenLsnCache */
lastWrittenLsnCacheSize = 0;
#if PG_VERSION_NUM >= 150000
/*
* Give preloaded libraries a chance to request additional shared memory.
*/
process_shmem_requests();
/*
* Now that loadable modules have had their chance to request additional
* shared memory, determine the value of any runtime-computed GUCs that
* depend on the amount of shared memory required.
*/
InitializeShmemGUCs();
/*
* This will try to access data directory which we do not set.
* Seems to be pretty safe to disable.
* Now that modules have been loaded, we can process any custom resource
* managers specified in the wal_consistency_checking GUC.
*/
/* InitializeWalConsistencyChecking(); */
InitializeWalConsistencyChecking();
#endif
/*
* We have our own version of CreateSharedMemoryAndSemaphores() that
* sets up local memory instead of shared one.
*/
CreateFakeSharedMemoryAndSemaphores();
CreateSharedMemoryAndSemaphores();
/*
* Remember stand-alone backend startup time,roughly at the same point
@@ -366,172 +354,6 @@ WalRedoMain(int argc, char *argv[])
}
/*
* Initialize dummy shmem.
*
* This code follows CreateSharedMemoryAndSemaphores() but manually sets up
* the shmem header and skips few initialization steps that are not needed for
* WAL redo.
*
* I've also tried removing most of initialization functions that request some
* memory (like ApplyLauncherShmemInit and friends) but in reality it haven't had
* any sizeable effect on RSS, so probably such clean up not worth the risk of having
* half-initialized postgres.
*/
static void
CreateFakeSharedMemoryAndSemaphores()
{
PGShmemHeader *shim = NULL;
PGShmemHeader *hdr;
Size size;
int numSemas;
char cwd[MAXPGPATH];
#if PG_VERSION_NUM >= 150000
size = CalculateShmemSize(&numSemas);
#else
/*
* Postgres v14 doesn't have a separate CalculateShmemSize(). Use result of the
* corresponging calculation in CreateSharedMemoryAndSemaphores()
*/
size = 1409024;
numSemas = 10;
#endif
/* Dummy implementation of PGSharedMemoryCreate() */
{
hdr = (PGShmemHeader *) malloc(size);
if (!hdr)
ereport(FATAL,
(errcode(ERRCODE_OUT_OF_MEMORY),
errmsg("[neon-wal-redo] can not allocate (pseudo-) shared memory")));
hdr->creatorPID = getpid();
hdr->magic = PGShmemMagic;
hdr->dsm_control = 0;
hdr->device = 42; /* not relevant for non-shared memory */
hdr->inode = 43; /* not relevant for non-shared memory */
hdr->totalsize = size;
hdr->freeoffset = MAXALIGN(sizeof(PGShmemHeader));
shim = hdr;
UsedShmemSegAddr = hdr;
UsedShmemSegID = (unsigned long) 42; /* not relevant for non-shared memory */
}
InitShmemAccess(hdr);
/*
* Reserve semaphores uses dir name as a source of entropy. Set it to cwd(). Rest
* of the code does not need DataDir access so nullify DataDir after
* PGReserveSemaphores() to error out if something will try to access it.
*/
if (!getcwd(cwd, MAXPGPATH))
ereport(FATAL,
(errcode(ERRCODE_INTERNAL_ERROR),
errmsg("[neon-wal-redo] can not read current directory name")));
DataDir = cwd;
PGReserveSemaphores(numSemas);
DataDir = NULL;
/*
* The rest of function follows CreateSharedMemoryAndSemaphores() closely,
* skipped parts are marked with comments.
*/
InitShmemAllocation();
/*
* Now initialize LWLocks, which do shared memory allocation and are
* needed for InitShmemIndex.
*/
CreateLWLocks();
/*
* Set up shmem.c index hashtable
*/
InitShmemIndex();
dsm_shmem_init();
/*
* Set up xlog, clog, and buffers
*/
XLOGShmemInit();
CLOGShmemInit();
CommitTsShmemInit();
SUBTRANSShmemInit();
MultiXactShmemInit();
InitBufferPool();
/*
* Set up lock manager
*/
InitLocks();
/*
* Set up predicate lock manager
*/
InitPredicateLocks();
/*
* Set up process table
*/
if (!IsUnderPostmaster)
InitProcGlobal();
CreateSharedProcArray();
CreateSharedBackendStatus();
TwoPhaseShmemInit();
BackgroundWorkerShmemInit();
/*
* Set up shared-inval messaging
*/
CreateSharedInvalidationState();
/*
* Set up interprocess signaling mechanisms
*/
PMSignalShmemInit();
ProcSignalShmemInit();
CheckpointerShmemInit();
AutoVacuumShmemInit();
ReplicationSlotsShmemInit();
ReplicationOriginShmemInit();
WalSndShmemInit();
WalRcvShmemInit();
PgArchShmemInit();
ApplyLauncherShmemInit();
/*
* Set up other modules that need some shared memory space
*/
SnapMgrInit();
BTreeShmemInit();
SyncScanShmemInit();
/* Skip due to the 'pg_notify' directory check */
/* AsyncShmemInit(); */
#ifdef EXEC_BACKEND
/*
* Alloc the win32 shared backend array
*/
if (!IsUnderPostmaster)
ShmemBackendArrayAllocation();
#endif
/* Initialize dynamic shared memory facilities. */
if (!IsUnderPostmaster)
dsm_postmaster_startup(shim);
/*
* Now give loadable modules a chance to set up their shmem allocations
*/
if (shmem_startup_hook)
shmem_startup_hook();
}
/* Version compatility wrapper for ReadBufferWithoutRelcache */
static inline Buffer
NeonRedoReadBuffer(RelFileNode rnode,

View File

@@ -31,7 +31,6 @@ once_cell.workspace = true
opentelemetry.workspace = true
parking_lot.workspace = true
pin-project-lite.workspace = true
postgres_backend.workspace = true
pq_proto.workspace = true
prometheus.workspace = true
rand.workspace = true

View File

@@ -25,11 +25,12 @@ impl CancelMap {
cancel_closure.try_cancel_query().await
}
/// Create a new session, with a new client-facing random cancellation key.
///
/// Use `enable_query_cancellation` to register the Postgres backend's cancellation
/// key with it.
pub fn new_session<'a>(&'a self) -> anyhow::Result<Session<'a>> {
/// Run async action within an ephemeral session identified by [`CancelKeyData`].
pub async fn with_session<'a, F, R, V>(&'a self, f: F) -> anyhow::Result<V>
where
F: FnOnce(Session<'a>) -> R,
R: std::future::Future<Output = anyhow::Result<V>>,
{
// HACK: We'd rather get the real backend_pid but tokio_postgres doesn't
// expose it and we don't want to do another roundtrip to query
// for it. The client will be able to notice that this is not the
@@ -43,9 +44,17 @@ impl CancelMap {
.write()
.try_insert(key, None)
.map_err(|_| anyhow!("query cancellation key already exists: {key}"))?;
info!("registered new query cancellation key {key}");
Ok(Session::new(key, self))
// This will guarantee that the session gets dropped
// as soon as the future is finished.
scopeguard::defer! {
self.0.write().remove(&key);
info!("dropped query cancellation key {key}");
}
info!("registered new query cancellation key {key}");
let session = Session::new(key, self);
f(session).await
}
#[cfg(test)]
@@ -102,7 +111,7 @@ impl<'a> Session<'a> {
impl Session<'_> {
/// Store the cancel token for the given session.
/// This enables query cancellation in [`crate::proxy::handshake`].
pub fn enable_query_cancellation(&self, cancel_closure: CancelClosure) -> CancelKeyData {
pub fn enable_query_cancellation(self, cancel_closure: CancelClosure) -> CancelKeyData {
info!("enabling query cancellation for this session");
self.cancel_map
.0
@@ -113,14 +122,6 @@ impl Session<'_> {
}
}
impl<'a> Drop for Session<'a> {
fn drop(&mut self) {
let key = &self.key;
self.cancel_map.0.write().remove(key);
info!("dropped query cancellation key {key}");
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -131,14 +132,14 @@ mod tests {
static CANCEL_MAP: Lazy<CancelMap> = Lazy::new(Default::default);
let (tx, rx) = tokio::sync::oneshot::channel();
let session = CANCEL_MAP.new_session()?;
let task = tokio::spawn(async move {
let task = tokio::spawn(CANCEL_MAP.with_session(|session| async move {
assert!(CANCEL_MAP.contains(&session));
tx.send(()).expect("failed to send");
futures::future::pending::<()>().await; // sleep forever
});
Ok(())
}));
// Wait until the task has been spawned.
rx.await.context("failed to hear from the task")?;

View File

@@ -4,11 +4,13 @@ use crate::{
};
use anyhow::Context;
use once_cell::sync::Lazy;
use postgres_backend::{self, AuthType, PostgresBackend, QueryError};
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
use std::future;
use tokio::net::{TcpListener, TcpStream};
use std::{net::TcpStream, thread};
use tracing::{error, info, info_span};
use utils::{
postgres_backend::{self, AuthType, PostgresBackend},
postgres_backend_async::QueryError,
};
static CPLANE_WAITERS: Lazy<Waiters<ComputeReady>> = Lazy::new(Default::default);
@@ -31,7 +33,7 @@ pub fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::N
/// Console management API listener task.
/// It spawns console response handlers needed for the link auth.
pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> {
pub async fn task_main(listener: tokio::net::TcpListener) -> anyhow::Result<()> {
scopeguard::defer! {
info!("mgmt has shut down");
}
@@ -40,12 +42,18 @@ pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> {
let (socket, peer_addr) = listener.accept().await?;
info!("accepted connection from {peer_addr}");
let socket = socket.into_std()?;
socket
.set_nodelay(true)
.context("failed to set client socket option")?;
socket
.set_nonblocking(false)
.context("failed to set client socket option")?;
tokio::task::spawn(async move {
let span = info_span!("mgmt", peer = %peer_addr);
// TODO: replace with async tasks.
thread::spawn(move || {
let tid = std::thread::current().id();
let span = info_span!("mgmt", thread = format_args!("{tid:?}"));
let _enter = span.enter();
info!("started a new console management API thread");
@@ -53,16 +61,16 @@ pub async fn task_main(listener: TcpListener) -> anyhow::Result<()> {
info!("console management API thread is about to finish");
}
if let Err(e) = handle_connection(socket).await {
if let Err(e) = handle_connection(socket) {
error!("thread failed with an error: {e}");
}
});
}
}
async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None)?;
pgbackend.run(&mut MgmtHandler, future::pending::<()>).await
fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None, true)?;
pgbackend.run(&mut MgmtHandler)
}
/// A message received by `mgmt` when a compute node is ready.
@@ -70,21 +78,16 @@ pub type ComputeReady = Result<DatabaseInfo, String>;
// TODO: replace with an http-based protocol.
struct MgmtHandler;
#[async_trait::async_trait]
impl postgres_backend::Handler for MgmtHandler {
async fn process_query(
&mut self,
pgb: &mut PostgresBackend,
query: &str,
) -> Result<(), QueryError> {
try_process_query(pgb, query).await.map_err(|e| {
fn process_query(&mut self, pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> {
try_process_query(pgb, query).map_err(|e| {
error!("failed to process response: {e:?}");
e
})
}
}
async fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> {
fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(), QueryError> {
let resp: KickSession = serde_json::from_str(query).context("Failed to parse query as json")?;
let span = info_span!("event", session_id = resp.session_id);
@@ -95,11 +98,11 @@ async fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> Result<(),
Ok(()) => {
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
}
Err(e) => {
error!("failed to deliver response to per-client task");
pgb.write_message_noflush(&BeMessage::ErrorResponse(&e.to_string(), None))?;
pgb.write_message(&BeMessage::ErrorResponse(&e.to_string(), None))?;
}
}

View File

@@ -21,7 +21,6 @@ const PROXY_IO_BYTES_PER_CLIENT: &str = "proxy_io_bytes_per_client";
#[derive(Eq, Hash, PartialEq, Serialize, Debug)]
pub struct Ids {
pub endpoint_id: String,
pub branch_id: String,
}
pub async fn task_main(config: &MetricCollectionConfig) -> anyhow::Result<()> {
@@ -75,23 +74,12 @@ fn gather_proxy_io_bytes_per_client() -> Vec<(Ids, (u64, DateTime<Utc>))> {
.find(|l| l.get_name() == "endpoint_id")
.unwrap()
.get_value();
let branch_id = ms
.get_label()
.iter()
.find(|l| l.get_name() == "branch_id")
.unwrap()
.get_value();
let value = ms.get_counter().get_value() as u64;
debug!(
"branch_id {} endpoint_id {} val: {}",
branch_id, endpoint_id, value
);
debug!("endpoint_id:val - {}: {}", endpoint_id, value);
current_metrics.push((
Ids {
endpoint_id: endpoint_id.to_string(),
branch_id: "".to_string(),
},
(value, Utc::now()),
));
@@ -143,7 +131,6 @@ async fn collect_metrics_iteration(
value,
extra: Ids {
endpoint_id: curr_key.endpoint_id.clone(),
branch_id: curr_key.branch_id.clone(),
},
})
})
@@ -185,7 +172,6 @@ async fn collect_metrics_iteration(
cached_metrics
.entry(Ids {
endpoint_id: send_metric.extra.endpoint_id.clone(),
branch_id: send_metric.extra.branch_id.clone(),
})
// update cached value (add delta) and time
.and_modify(|e| {

View File

@@ -133,14 +133,10 @@ pub async fn handle_ws_client(
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let client = Client::new(
stream,
creds,
&params,
session_id,
cancel_map.new_session()?,
);
client.connect_to_db(true).await
let client = Client::new(stream, creds, &params, session_id);
cancel_map
.with_session(|session| client.connect_to_db(session, true))
.await
}
#[tracing::instrument(fields(session_id), skip_all)]
@@ -176,14 +172,10 @@ async fn handle_client(
async { result }.or_else(|e| stream.throw_error(e)).await?
};
let client = Client::new(
stream,
creds,
&params,
session_id,
cancel_map.new_session()?,
);
client.connect_to_db(false).await
let client = Client::new(stream, creds, &params, session_id);
cancel_map
.with_session(|session| client.connect_to_db(session, false))
.await
}
/// Establish a (most probably, secure) connection with the client.
@@ -389,8 +381,6 @@ struct Client<'a, S> {
params: &'a StartupMessageParams,
/// Unique connection ID.
session_id: uuid::Uuid,
session: cancellation::Session<'a>,
}
impl<'a, S> Client<'a, S> {
@@ -400,27 +390,28 @@ impl<'a, S> Client<'a, S> {
creds: auth::BackendType<'a, auth::ClientCredentials<'a>>,
params: &'a StartupMessageParams,
session_id: uuid::Uuid,
session: cancellation::Session<'a>,
) -> Self {
Self {
stream,
creds,
params,
session_id,
session,
}
}
}
impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
/// Let the client authenticate and connect to the designated compute node.
async fn connect_to_db(self, allow_cleartext: bool) -> anyhow::Result<()> {
async fn connect_to_db(
self,
session: cancellation::Session<'_>,
allow_cleartext: bool,
) -> anyhow::Result<()> {
let Self {
mut stream,
mut creds,
params,
session_id,
session,
} = self;
let extra = console::ConsoleReqExtra {

View File

@@ -1,40 +1,45 @@
use crate::error::UserFacingError;
use anyhow::bail;
use bytes::BytesMut;
use pin_project_lite::pin_project;
use pq_proto::framed::{ConnectionError, Framed};
use pq_proto::{BeMessage, FeMessage, FeStartupPacket, ProtocolError};
use pq_proto::{BeMessage, ConnectionError, FeMessage, FeStartupPacket};
use rustls::ServerConfig;
use std::pin::Pin;
use std::sync::Arc;
use std::{io, task};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio_rustls::server::TlsStream;
/// Stream wrapper which implements libpq's protocol.
/// NOTE: This object deliberately doesn't implement [`AsyncRead`]
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
/// to pass random malformed bytes through the connection).
pub struct PqStream<S> {
framed: Framed<S>,
pin_project! {
/// Stream wrapper which implements libpq's protocol.
/// NOTE: This object deliberately doesn't implement [`AsyncRead`]
/// or [`AsyncWrite`] to prevent subtle errors (e.g. trying
/// to pass random malformed bytes through the connection).
pub struct PqStream<S> {
#[pin]
stream: S,
buffer: BytesMut,
}
}
impl<S> PqStream<S> {
/// Construct a new libpq protocol wrapper.
pub fn new(stream: S) -> Self {
Self {
framed: Framed::new(stream),
stream,
buffer: Default::default(),
}
}
/// Extract the underlying stream.
pub fn into_inner(self) -> S {
self.framed.into_inner()
self.stream
}
/// Get a shared reference to the underlying stream.
pub fn get_ref(&self) -> &S {
self.framed.get_ref()
&self.stream
}
}
@@ -45,19 +50,16 @@ fn err_connection() -> io::Error {
impl<S: AsyncRead + Unpin> PqStream<S> {
/// Receive [`FeStartupPacket`], which is a first packet sent by a client.
pub async fn read_startup_packet(&mut self) -> io::Result<FeStartupPacket> {
self.framed
.read_startup_message()
// TODO: `FeStartupPacket::read_fut` should return `FeStartupPacket`
let msg = FeStartupPacket::read_fut(&mut self.stream)
.await
.map_err(ConnectionError::into_io_error)?
.ok_or_else(err_connection)
}
.ok_or_else(err_connection)?;
async fn read_message(&mut self) -> io::Result<FeMessage> {
self.framed
.read_message()
.await
.map_err(ConnectionError::into_io_error)?
.ok_or_else(err_connection)
match msg {
FeMessage::StartupPacket(packet) => Ok(packet),
_ => panic!("unreachable state"),
}
}
pub async fn read_password_message(&mut self) -> io::Result<bytes::Bytes> {
@@ -69,14 +71,19 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
)),
}
}
async fn read_message(&mut self) -> io::Result<FeMessage> {
FeMessage::read_fut(&mut self.stream)
.await
.map_err(ConnectionError::into_io_error)?
.ok_or_else(err_connection)
}
}
impl<S: AsyncWrite + Unpin> PqStream<S> {
/// Write the message into an internal buffer, but don't flush the underlying stream.
pub fn write_message_noflush(&mut self, message: &BeMessage<'_>) -> io::Result<&mut Self> {
self.framed
.write_message(message)
.map_err(ProtocolError::into_io_error)?;
BeMessage::write(&mut self.buffer, message)?;
Ok(self)
}
@@ -89,7 +96,9 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
/// Flush the output buffer into the underlying stream.
pub async fn flush(&mut self) -> io::Result<&mut Self> {
self.framed.flush().await?;
self.stream.write_all(&self.buffer).await?;
self.buffer.clear();
self.stream.flush().await?;
Ok(self)
}

View File

@@ -11,18 +11,12 @@
# Not every feature is supported in macOS builds. Avoid running regular linting
# script that checks every feature.
#
# manual-range-contains wants
# !(4..=MAX_STARTUP_PACKET_LENGTH).contains(&len)
# instead of
# len < 4 || len > MAX_STARTUP_PACKET_LENGTH
# , let's disagree.
if [[ "$OSTYPE" == "darwin"* ]]; then
# no extra features to test currently, add more here when needed
cargo clippy --locked --all --all-targets --features testing -- -A unknown_lints -A clippy::manual-range-contains -D warnings
cargo clippy --locked --all --all-targets --features testing -- -A unknown_lints -D warnings
else
# * `-A unknown_lints` do not warn about unknown lint suppressions
# that people with newer toolchains might use
# * `-D warnings` - fail on any warnings (`cargo` returns non-zero exit status)
cargo clippy --locked --all --all-targets --all-features -- -A unknown_lints -A clippy::manual-range-contains -D warnings
cargo clippy --locked --all --all-targets --all-features -- -A unknown_lints -D warnings
fi

View File

@@ -10,7 +10,6 @@ anyhow.workspace = true
async-trait.workspace = true
byteorder.workspace = true
bytes.workspace = true
chrono.workspace = true
clap = { workspace = true, features = ["derive"] }
const_format.workspace = true
crc32c.workspace = true
@@ -19,6 +18,7 @@ git-version.workspace = true
hex.workspace = true
humantime.workspace = true
hyper.workspace = true
nix.workspace = true
once_cell.workspace = true
parking_lot.workspace = true
postgres.workspace = true
@@ -35,7 +35,6 @@ toml_edit.workspace = true
tracing.workspace = true
url.workspace = true
metrics.workspace = true
postgres_backend.workspace = true
postgres_ffi.workspace = true
pq_proto.workspace = true
remote_storage.workspace = true

View File

@@ -236,7 +236,7 @@ fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> {
let conf_cloned = conf.clone();
let safekeeper_thread = thread::Builder::new()
.name("WAL service thread".into())
.name("safekeeper thread".into())
.spawn(|| wal_service::thread_main(conf_cloned, pg_listener))
.unwrap();

View File

@@ -1,264 +0,0 @@
//! Utils for dumping full state of the safekeeper.
use std::fs;
use std::fs::DirEntry;
use std::io::BufReader;
use std::io::Read;
use std::path::PathBuf;
use anyhow::Result;
use chrono::{DateTime, Utc};
use postgres_ffi::XLogSegNo;
use serde::Serialize;
use utils::http::json::display_serialize;
use utils::id::NodeId;
use utils::id::TenantTimelineId;
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
use crate::safekeeper::SafeKeeperState;
use crate::safekeeper::SafekeeperMemState;
use crate::safekeeper::TermHistory;
use crate::SafeKeeperConf;
use crate::timeline::ReplicaState;
use crate::GlobalTimelines;
/// Various filters that influence the resulting JSON output.
#[derive(Debug, Serialize)]
pub struct Args {
/// Dump all available safekeeper state. False by default.
pub dump_all: bool,
/// Dump control_file content. Uses value of `dump_all` by default.
pub dump_control_file: bool,
/// Dump in-memory state. Uses value of `dump_all` by default.
pub dump_memory: bool,
/// Dump all disk files in a timeline directory. Uses value of `dump_all` by default.
pub dump_disk_content: bool,
/// Dump full term history. True by default.
pub dump_term_history: bool,
/// Filter timelines by tenant_id.
pub tenant_id: Option<TenantId>,
/// Filter timelines by timeline_id.
pub timeline_id: Option<TimelineId>,
}
/// Response for debug dump request.
#[derive(Debug, Serialize)]
pub struct Response {
pub start_time: DateTime<Utc>,
pub finish_time: DateTime<Utc>,
pub timelines: Vec<Timeline>,
pub timelines_count: usize,
pub config: Config,
}
/// Safekeeper configuration.
#[derive(Debug, Serialize)]
pub struct Config {
pub id: NodeId,
pub workdir: PathBuf,
pub listen_pg_addr: String,
pub listen_http_addr: String,
pub no_sync: bool,
pub max_offloader_lag_bytes: u64,
pub wal_backup_enabled: bool,
}
#[derive(Debug, Serialize)]
pub struct Timeline {
#[serde(serialize_with = "display_serialize")]
pub tenant_id: TenantId,
#[serde(serialize_with = "display_serialize")]
pub timeline_id: TimelineId,
pub control_file: Option<SafeKeeperState>,
pub memory: Option<Memory>,
pub disk_content: Option<DiskContent>,
}
#[derive(Debug, Serialize)]
pub struct Memory {
pub is_cancelled: bool,
pub peers_info_len: usize,
pub replicas: Vec<Option<ReplicaState>>,
pub wal_backup_active: bool,
pub active: bool,
pub num_computes: u32,
pub last_removed_segno: XLogSegNo,
pub epoch_start_lsn: Lsn,
pub mem_state: SafekeeperMemState,
// PhysicalStorage state.
pub write_lsn: Lsn,
pub write_record_lsn: Lsn,
pub flush_lsn: Lsn,
pub file_open: bool,
}
#[derive(Debug, Serialize)]
pub struct DiskContent {
pub files: Vec<FileInfo>,
}
#[derive(Debug, Serialize)]
pub struct FileInfo {
pub name: String,
pub size: u64,
pub created: DateTime<Utc>,
pub modified: DateTime<Utc>,
pub start_zeroes: u64,
pub end_zeroes: u64,
// TODO: add sha256 checksum
}
/// Build debug dump response, using the provided [`Args`] filters.
pub fn build(args: Args) -> Result<Response> {
let start_time = Utc::now();
let timelines_count = GlobalTimelines::timelines_count();
let ptrs_snapshot = if args.tenant_id.is_some() && args.timeline_id.is_some() {
// If both tenant_id and timeline_id are specified, we can just get the
// timeline directly, without taking a snapshot of the whole list.
let ttid = TenantTimelineId::new(args.tenant_id.unwrap(), args.timeline_id.unwrap());
if let Ok(tli) = GlobalTimelines::get(ttid) {
vec![tli]
} else {
vec![]
}
} else {
// Otherwise, take a snapshot of the whole list.
GlobalTimelines::get_all()
};
// TODO: return Stream instead of Vec
let mut timelines = Vec::new();
for tli in ptrs_snapshot {
let ttid = tli.ttid;
if let Some(tenant_id) = args.tenant_id {
if tenant_id != ttid.tenant_id {
continue;
}
}
if let Some(timeline_id) = args.timeline_id {
if timeline_id != ttid.timeline_id {
continue;
}
}
let control_file = if args.dump_control_file {
let mut state = tli.get_state().1;
if !args.dump_term_history {
state.acceptor_state.term_history = TermHistory(vec![]);
}
Some(state)
} else {
None
};
let memory = if args.dump_memory {
Some(tli.memory_dump())
} else {
None
};
let disk_content = if args.dump_disk_content {
// build_disk_content can fail, but we don't want to fail the whole
// request because of that.
build_disk_content(&tli.timeline_dir).ok()
} else {
None
};
let timeline = Timeline {
tenant_id: ttid.tenant_id,
timeline_id: ttid.timeline_id,
control_file,
memory,
disk_content,
};
timelines.push(timeline);
}
let config = GlobalTimelines::get_global_config();
Ok(Response {
start_time,
finish_time: Utc::now(),
timelines,
timelines_count,
config: build_config(config),
})
}
/// Builds DiskContent from a directory path. It can fail if the directory
/// is deleted between the time we get the path and the time we try to open it.
fn build_disk_content(path: &std::path::Path) -> Result<DiskContent> {
let mut files = Vec::new();
for entry in fs::read_dir(path)? {
if entry.is_err() {
continue;
}
let file = build_file_info(entry?);
if file.is_err() {
continue;
}
files.push(file?);
}
Ok(DiskContent { files })
}
/// Builds FileInfo from DirEntry. Sometimes it can return an error
/// if the file is deleted between the time we get the DirEntry
/// and the time we try to open it.
fn build_file_info(entry: DirEntry) -> Result<FileInfo> {
let metadata = entry.metadata()?;
let path = entry.path();
let name = path
.file_name()
.and_then(|x| x.to_str())
.unwrap_or("")
.to_owned();
let mut file = fs::File::open(path)?;
let mut reader = BufReader::new(&mut file).bytes().filter_map(|x| x.ok());
let start_zeroes = reader.by_ref().take_while(|&x| x == 0).count() as u64;
let mut end_zeroes = 0;
for b in reader {
if b == 0 {
end_zeroes += 1;
} else {
end_zeroes = 0;
}
}
Ok(FileInfo {
name,
size: metadata.len(),
created: DateTime::from(metadata.created()?),
modified: DateTime::from(metadata.modified()?),
start_zeroes,
end_zeroes,
})
}
/// Converts SafeKeeperConf to Config, filtering out the fields that are not
/// supposed to be exposed.
fn build_config(config: SafeKeeperConf) -> Config {
Config {
id: config.my_id,
workdir: config.workdir,
listen_pg_addr: config.listen_pg_addr,
listen_http_addr: config.listen_http_addr,
no_sync: config.no_sync,
max_offloader_lag_bytes: config.max_offloader_lag_bytes,
wal_backup_enabled: config.wal_backup_enabled,
}
}

View File

@@ -1,24 +1,27 @@
//! Part of Safekeeper pretending to be Postgres, i.e. handling Postgres
//! protocol commands.
use anyhow::Context;
use std::str;
use tracing::{info, info_span, Instrument};
use crate::auth::check_permission;
use crate::json_ctrl::{handle_json_ctrl, AppendLogicalMessage};
use crate::receive_wal::ReceiveWalConn;
use crate::send_wal::ReplicationConn;
use crate::wal_service::ConnectionId;
use crate::{GlobalTimelines, SafeKeeperConf};
use postgres_backend::QueryError;
use postgres_backend::{self, PostgresBackend};
use anyhow::Context;
use postgres_ffi::PG_TLI;
use pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID, TEXT_OID};
use regex::Regex;
use pq_proto::{BeMessage, FeStartupPacket, RowDescriptor, INT4_OID, TEXT_OID};
use std::str;
use tracing::info;
use utils::auth::{Claims, Scope};
use utils::postgres_backend_async::QueryError;
use utils::{
id::{TenantId, TenantTimelineId, TimelineId},
lsn::Lsn,
postgres_backend::{self, PostgresBackend},
};
/// Safekeeper handler of postgres commands
@@ -29,8 +32,6 @@ pub struct SafekeeperPostgresHandler {
pub tenant_id: Option<TenantId>,
pub timeline_id: Option<TimelineId>,
pub ttid: TenantTimelineId,
/// Unique connection id is logged in spans for observability.
pub conn_id: ConnectionId,
claims: Option<Claims>,
}
@@ -52,7 +53,7 @@ fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
let start_lsn = caps
.next()
.map(|cap| cap[1].parse::<Lsn>())
.context("parse start LSN from START_REPLICATION command")??;
.context("failed to parse start LSN from START_REPLICATION command")??;
Ok(SafekeeperPostgresCommand::StartReplication { start_lsn })
} else if cmd.starts_with("IDENTIFY_SYSTEM") {
Ok(SafekeeperPostgresCommand::IdentifySystem)
@@ -66,7 +67,6 @@ fn parse_cmd(cmd: &str) -> anyhow::Result<SafekeeperPostgresCommand> {
}
}
#[async_trait::async_trait]
impl postgres_backend::Handler for SafekeeperPostgresHandler {
// tenant_id and timeline_id are passed in connection string params
fn startup(
@@ -137,7 +137,7 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
Ok(())
}
async fn process_query(
fn process_query(
&mut self,
pgb: &mut PostgresBackend,
query_string: &str,
@@ -147,10 +147,9 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
.starts_with("set datestyle to ")
{
// important for debug because psycopg2 executes "SET datestyle TO 'ISO'" on connect
pgb.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
pgb.write_message(&BeMessage::CommandComplete(b"SELECT 1"))?;
return Ok(());
}
let cmd = parse_cmd(query_string)?;
info!(
@@ -162,36 +161,38 @@ impl postgres_backend::Handler for SafekeeperPostgresHandler {
let timeline_id = self.timeline_id.context("timelineid is required")?;
self.check_permission(Some(tenant_id))?;
self.ttid = TenantTimelineId::new(tenant_id, timeline_id);
let span_ttid = self.ttid; // satisfy borrow checker
match cmd {
SafekeeperPostgresCommand::StartWalPush => {
self.handle_start_wal_push(pgb)
.instrument(info_span!("WAL receiver", ttid = %span_ttid))
.await
}
let res = match cmd {
SafekeeperPostgresCommand::StartWalPush => ReceiveWalConn::new(pgb).run(self),
SafekeeperPostgresCommand::StartReplication { start_lsn } => {
self.handle_start_replication(pgb, start_lsn)
.instrument(info_span!("WAL sender", ttid = %span_ttid))
.await
ReplicationConn::new(pgb).run(self, pgb, start_lsn)
}
SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb).await,
SafekeeperPostgresCommand::JSONCtrl { ref cmd } => {
handle_json_ctrl(self, pgb, cmd).await
SafekeeperPostgresCommand::IdentifySystem => self.handle_identify_system(pgb),
SafekeeperPostgresCommand::JSONCtrl { ref cmd } => handle_json_ctrl(self, pgb, cmd),
};
match res {
Ok(()) => Ok(()),
Err(QueryError::Disconnected(connection_error)) => {
info!("Timeline {tenant_id}/{timeline_id} query failed with connection error: {connection_error}");
Err(QueryError::Disconnected(connection_error))
}
Err(QueryError::Other(e)) => Err(QueryError::Other(e.context(format!(
"Failed to process query for timeline {}",
self.ttid
)))),
}
}
}
impl SafekeeperPostgresHandler {
pub fn new(conf: SafeKeeperConf, conn_id: u32) -> Self {
pub fn new(conf: SafeKeeperConf) -> Self {
SafekeeperPostgresHandler {
conf,
appname: None,
tenant_id: None,
timeline_id: None,
ttid: TenantTimelineId::empty(),
conn_id,
claims: None,
}
}
@@ -216,11 +217,8 @@ impl SafekeeperPostgresHandler {
///
/// Handle IDENTIFY_SYSTEM replication command
///
async fn handle_identify_system(
&mut self,
pgb: &mut PostgresBackend,
) -> Result<(), QueryError> {
let tli = GlobalTimelines::get(self.ttid).map_err(|e| QueryError::Other(e.into()))?;
fn handle_identify_system(&mut self, pgb: &mut PostgresBackend) -> Result<(), QueryError> {
let tli = GlobalTimelines::get(self.ttid)?;
let lsn = if self.is_walproposer_recovery() {
// walproposer should get all local WAL until flush_lsn
@@ -269,7 +267,7 @@ impl SafekeeperPostgresHandler {
Some(lsn_bytes),
None,
]))?
.write_message_noflush(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?;
.write_message(&BeMessage::CommandComplete(b"IDENTIFY_SYSTEM"))?;
Ok(())
}

View File

@@ -119,12 +119,6 @@ paths:
$ref: "#/components/responses/ForbiddenError"
default:
$ref: "#/components/responses/GenericError"
"404":
description: Timeline not found
content:
application/json:
schema:
$ref: "#/components/schemas/NotFoundError"
delete:
tags:

View File

@@ -1,19 +1,18 @@
use hyper::{Body, Request, Response, StatusCode, Uri};
use anyhow::Context;
use once_cell::sync::Lazy;
use postgres_ffi::WAL_SEGMENT_SIZE;
use safekeeper_api::models::SkTimelineInfo;
use serde::Serialize;
use serde::Serializer;
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::str::FromStr;
use std::fmt::Display;
use std::sync::Arc;
use storage_broker::proto::SafekeeperTimelineInfo;
use storage_broker::proto::TenantTimelineId as ProtoTenantTimelineId;
use tokio::task::JoinError;
use utils::http::json::display_serialize;
use crate::debug_dump;
use crate::safekeeper::ServerInfo;
use crate::safekeeper::Term;
@@ -55,6 +54,15 @@ fn get_conf(request: &Request<Body>) -> &SafeKeeperConf {
.as_ref()
}
/// Serialize through Display trait.
fn display_serialize<S, F>(z: &F, s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
F: Display,
{
s.serialize_str(&format!("{}", z))
}
/// Same as TermSwitchEntry, but serializes LSN using display serializer
/// in Postgres format, i.e. 0/FFFFFFFF. Used only for the API response.
#[derive(Debug, Serialize)]
@@ -111,7 +119,12 @@ async fn timeline_status_handler(request: Request<Body>) -> Result<Response<Body
);
check_permission(&request, Some(ttid.tenant_id))?;
let tli = GlobalTimelines::get(ttid).map_err(ApiError::from)?;
let tli = GlobalTimelines::get(ttid)
// FIXME: Currently, the only errors from `GlobalTimelines::get` will be client errors
// because the provided timeline isn't there. However, the method can in theory change and
// fail from internal errors later. Remove this comment once it the method returns
// something other than `anyhow::Result`.
.map_err(ApiError::InternalServerError)?;
let (inmem, state) = tli.get_state();
let flush_lsn = tli.get_flush_lsn();
@@ -168,9 +181,12 @@ async fn timeline_create_handler(mut request: Request<Body>) -> Result<Response<
.commit_lsn
.segment_lsn(server_info.wal_seg_size as usize)
});
GlobalTimelines::create(ttid, server_info, request_data.commit_lsn, local_start_lsn)
.await
.map_err(ApiError::InternalServerError)?;
tokio::task::spawn_blocking(move || {
GlobalTimelines::create(ttid, server_info, request_data.commit_lsn, local_start_lsn)
})
.await
.map_err(|e| ApiError::InternalServerError(e.into()))?
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, ())
}
@@ -244,7 +260,15 @@ async fn record_safekeeper_info(mut request: Request<Body>) -> Result<Response<B
local_start_lsn: sk_info.local_start_lsn.0,
};
let tli = GlobalTimelines::get(ttid).map_err(ApiError::from)?;
let tli = GlobalTimelines::get(ttid)
// `GlobalTimelines::get` returns an error when it can't find the timeline.
.with_context(|| {
format!(
"Couldn't get timeline {} for tenant {}",
ttid.timeline_id, ttid.tenant_id
)
})
.map_err(ApiError::NotFound)?;
tli.record_safekeeper_info(&proto_sk_info)
.await
.map_err(ApiError::InternalServerError)?;
@@ -252,69 +276,6 @@ async fn record_safekeeper_info(mut request: Request<Body>) -> Result<Response<B
json_response(StatusCode::OK, ())
}
fn parse_kv_str<E: fmt::Display, T: FromStr<Err = E>>(k: &str, v: &str) -> Result<T, ApiError> {
v.parse()
.map_err(|e| ApiError::BadRequest(anyhow::anyhow!("cannot parse {k}: {e}")))
}
/// Dump debug info about all available safekeeper state.
async fn dump_debug_handler(mut request: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permission(&request, None)?;
ensure_no_body(&mut request).await?;
let mut dump_all: Option<bool> = None;
let mut dump_control_file: Option<bool> = None;
let mut dump_memory: Option<bool> = None;
let mut dump_disk_content: Option<bool> = None;
let mut dump_term_history: Option<bool> = None;
let mut tenant_id: Option<TenantId> = None;
let mut timeline_id: Option<TimelineId> = None;
let query = request.uri().query().unwrap_or("");
let mut values = url::form_urlencoded::parse(query.as_bytes());
for (k, v) in &mut values {
match k.as_ref() {
"dump_all" => dump_all = Some(parse_kv_str(&k, &v)?),
"dump_control_file" => dump_control_file = Some(parse_kv_str(&k, &v)?),
"dump_memory" => dump_memory = Some(parse_kv_str(&k, &v)?),
"dump_disk_content" => dump_disk_content = Some(parse_kv_str(&k, &v)?),
"dump_term_history" => dump_term_history = Some(parse_kv_str(&k, &v)?),
"tenant_id" => tenant_id = Some(parse_kv_str(&k, &v)?),
"timeline_id" => timeline_id = Some(parse_kv_str(&k, &v)?),
_ => Err(ApiError::BadRequest(anyhow::anyhow!(
"Unknown query parameter: {}",
k
)))?,
}
}
let dump_all = dump_all.unwrap_or(false);
let dump_control_file = dump_control_file.unwrap_or(dump_all);
let dump_memory = dump_memory.unwrap_or(dump_all);
let dump_disk_content = dump_disk_content.unwrap_or(dump_all);
let dump_term_history = dump_term_history.unwrap_or(true);
let args = debug_dump::Args {
dump_all,
dump_control_file,
dump_memory,
dump_disk_content,
dump_term_history,
tenant_id,
timeline_id,
};
let resp = tokio::task::spawn_blocking(move || {
debug_dump::build(args).map_err(ApiError::InternalServerError)
})
.await
.map_err(|e: JoinError| ApiError::InternalServerError(e.into()))??;
// TODO: use streaming response
json_response(StatusCode::OK, resp)
}
/// Safekeeper http router.
pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder<hyper::Body, ApiError> {
let mut router = endpoint::make_router();
@@ -355,7 +316,6 @@ pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder<hyper::Body, ApiError>
"/v1/record_safekeeper_info/:tenant_id/:timeline_id",
record_safekeeper_info,
)
.get("/v1/debug_dump", dump_debug_handler)
}
#[cfg(test)]

View File

@@ -10,10 +10,10 @@ use std::sync::Arc;
use anyhow::Context;
use bytes::Bytes;
use postgres_backend::QueryError;
use serde::{Deserialize, Serialize};
use tracing::*;
use utils::id::TenantTimelineId;
use utils::postgres_backend_async::QueryError;
use crate::handler::SafekeeperPostgresHandler;
use crate::safekeeper::{AcceptorProposerMessage, AppendResponse, ServerInfo};
@@ -23,30 +23,29 @@ use crate::safekeeper::{
use crate::safekeeper::{SafeKeeperState, Term, TermHistory, TermSwitchEntry};
use crate::timeline::Timeline;
use crate::GlobalTimelines;
use postgres_backend::PostgresBackend;
use postgres_ffi::encode_logical_message;
use postgres_ffi::WAL_SEGMENT_SIZE;
use pq_proto::{BeMessage, RowDescriptor, TEXT_OID};
use utils::lsn::Lsn;
use utils::{lsn::Lsn, postgres_backend::PostgresBackend};
#[derive(Serialize, Deserialize, Debug)]
pub struct AppendLogicalMessage {
// prefix and message to build LogicalMessage
pub lm_prefix: String,
pub lm_message: String,
lm_prefix: String,
lm_message: String,
// if true, commit_lsn will match flush_lsn after append
pub set_commit_lsn: bool,
set_commit_lsn: bool,
// if true, ProposerElected will be sent before append
pub send_proposer_elected: bool,
send_proposer_elected: bool,
// fields from AppendRequestHeader
pub term: Term,
pub epoch_start_lsn: Lsn,
pub begin_lsn: Lsn,
pub truncate_lsn: Lsn,
pub pg_version: u32,
term: Term,
epoch_start_lsn: Lsn,
begin_lsn: Lsn,
truncate_lsn: Lsn,
pg_version: u32,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -60,7 +59,7 @@ struct AppendResult {
/// Handles command to craft logical message WAL record with given
/// content, and then append it with specified term and lsn. This
/// function is used to test safekeepers in different scenarios.
pub async fn handle_json_ctrl(
pub fn handle_json_ctrl(
spg: &SafekeeperPostgresHandler,
pgb: &mut PostgresBackend,
append_request: &AppendLogicalMessage,
@@ -68,7 +67,7 @@ pub async fn handle_json_ctrl(
info!("JSON_CTRL request: {append_request:?}");
// need to init safekeeper state before AppendRequest
let tli = prepare_safekeeper(spg.ttid, append_request.pg_version).await?;
let tli = prepare_safekeeper(spg.ttid, append_request.pg_version)?;
// if send_proposer_elected is true, we need to update local history
if append_request.send_proposer_elected {
@@ -90,16 +89,13 @@ pub async fn handle_json_ctrl(
..Default::default()
}]))?
.write_message_noflush(&BeMessage::DataRow(&[Some(&response_data)]))?
.write_message_noflush(&BeMessage::CommandComplete(b"JSON_CTRL"))?;
.write_message(&BeMessage::CommandComplete(b"JSON_CTRL"))?;
Ok(())
}
/// Prepare safekeeper to process append requests without crashes,
/// by sending ProposerGreeting with default server.wal_seg_size.
async fn prepare_safekeeper(
ttid: TenantTimelineId,
pg_version: u32,
) -> anyhow::Result<Arc<Timeline>> {
fn prepare_safekeeper(ttid: TenantTimelineId, pg_version: u32) -> anyhow::Result<Arc<Timeline>> {
GlobalTimelines::create(
ttid,
ServerInfo {
@@ -110,7 +106,6 @@ async fn prepare_safekeeper(
Lsn::INVALID,
Lsn::INVALID,
)
.await
}
fn send_proposer_elected(tli: &Arc<Timeline>, term: Term, lsn: Lsn) -> anyhow::Result<()> {
@@ -133,15 +128,15 @@ fn send_proposer_elected(tli: &Arc<Timeline>, term: Term, lsn: Lsn) -> anyhow::R
}
#[derive(Debug, Serialize, Deserialize)]
pub struct InsertedWAL {
struct InsertedWAL {
begin_lsn: Lsn,
pub end_lsn: Lsn,
end_lsn: Lsn,
append_response: AppendResponse,
}
/// Extend local WAL with new LogicalMessage record. To do that,
/// create AppendRequest with new WAL and pass it to safekeeper.
pub fn append_logical_message(
fn append_logical_message(
tli: &Arc<Timeline>,
msg: &AppendLogicalMessage,
) -> anyhow::Result<InsertedWAL> {

View File

@@ -1,8 +1,8 @@
use storage_broker::Uri;
//
use remote_storage::RemoteStorageConfig;
use std::path::PathBuf;
use std::time::Duration;
use storage_broker::Uri;
use utils::id::{NodeId, TenantId, TenantTimelineId};
@@ -10,7 +10,6 @@ mod auth;
pub mod broker;
pub mod control_file;
pub mod control_file_upgrade;
pub mod debug_dump;
pub mod handler;
pub mod http;
pub mod json_ctrl;

View File

@@ -2,134 +2,72 @@
//! Gets messages from the network, passes them down to consensus module and
//! sends replies back.
use crate::handler::SafekeeperPostgresHandler;
use crate::safekeeper::AcceptorProposerMessage;
use crate::safekeeper::ProposerAcceptorMessage;
use anyhow::anyhow;
use anyhow::Context;
use bytes::BytesMut;
use tracing::*;
use utils::lsn::Lsn;
use utils::postgres_backend_async::QueryError;
use crate::safekeeper::ServerInfo;
use crate::timeline::Timeline;
use crate::wal_service::ConnectionId;
use crate::GlobalTimelines;
use anyhow::{anyhow, Context};
use bytes::BytesMut;
use postgres_backend::CopyStreamHandlerEnd;
use postgres_backend::PostgresBackend;
use postgres_backend::PostgresBackendReader;
use postgres_backend::QueryError;
use pq_proto::BeMessage;
use std::net::SocketAddr;
use std::sync::mpsc::channel;
use std::sync::mpsc::Receiver;
use std::sync::Arc;
use std::thread;
use std::thread::JoinHandle;
use tokio::sync::mpsc::channel;
use tokio::sync::mpsc::error::TryRecvError;
use tokio::sync::mpsc::Receiver;
use tokio::sync::mpsc::Sender;
use tokio::task::spawn_blocking;
use tracing::*;
use utils::id::TenantTimelineId;
use utils::lsn::Lsn;
const MSG_QUEUE_SIZE: usize = 256;
const REPLY_QUEUE_SIZE: usize = 16;
use crate::safekeeper::AcceptorProposerMessage;
use crate::safekeeper::ProposerAcceptorMessage;
impl SafekeeperPostgresHandler {
/// Wrapper around handle_start_wal_push_guts handling result. Error is
/// handled here while we're still in walreceiver ttid span; with API
/// extension, this can probably be moved into postgres_backend.
pub async fn handle_start_wal_push(
&mut self,
pgb: &mut PostgresBackend,
) -> Result<(), QueryError> {
if let Err(end) = self.handle_start_wal_push_guts(pgb).await {
// Log the result and probably send it to the client, closing the stream.
pgb.handle_copy_stream_end(end).await;
use crate::handler::SafekeeperPostgresHandler;
use pq_proto::{BeMessage, FeMessage};
use utils::{postgres_backend::PostgresBackend, sock_split::ReadStream};
pub struct ReceiveWalConn<'pg> {
/// Postgres connection
pg_backend: &'pg mut PostgresBackend,
/// The cached result of `pg_backend.socket().peer_addr()` (roughly)
peer_addr: SocketAddr,
}
impl<'pg> ReceiveWalConn<'pg> {
pub fn new(pg: &'pg mut PostgresBackend) -> ReceiveWalConn<'pg> {
let peer_addr = *pg.get_peer_addr();
ReceiveWalConn {
pg_backend: pg,
peer_addr,
}
}
// Send message to the postgres
fn write_msg(&mut self, msg: &AcceptorProposerMessage) -> anyhow::Result<()> {
let mut buf = BytesMut::with_capacity(128);
msg.serialize(&mut buf)?;
self.pg_backend.write_message(&BeMessage::CopyData(&buf))?;
Ok(())
}
pub async fn handle_start_wal_push_guts(
&mut self,
pgb: &mut PostgresBackend,
) -> Result<(), CopyStreamHandlerEnd> {
/// Receive WAL from wal_proposer
pub fn run(&mut self, spg: &mut SafekeeperPostgresHandler) -> Result<(), QueryError> {
let _enter = info_span!("WAL acceptor", ttid = %spg.ttid).entered();
// Notify the libpq client that it's allowed to send `CopyData` messages
pgb.write_message(&BeMessage::CopyBothResponse).await?;
self.pg_backend
.write_message(&BeMessage::CopyBothResponse)?;
// Experiments [1] confirm that doing network IO in one (this) thread and
// processing with disc IO in another significantly improves
// performance; we spawn off WalAcceptor thread for message processing
// to this end.
//
// [1] https://github.com/neondatabase/neon/pull/1318
let (msg_tx, msg_rx) = channel(MSG_QUEUE_SIZE);
let (reply_tx, reply_rx) = channel(REPLY_QUEUE_SIZE);
let mut acceptor_handle: Option<JoinHandle<anyhow::Result<()>>> = None;
let r = self
.pg_backend
.take_stream_in()
.ok_or_else(|| anyhow!("failed to take read stream from pgbackend"))?;
let mut poll_reader = ProposerPollStream::new(r)?;
// Concurrently receive and send data; replies are not synchronized with
// sends, so this avoids deadlocks.
let mut pgb_reader = pgb.split().context("START_WAL_PUSH split")?;
let peer_addr = *pgb.get_peer_addr();
let network_reader = NetworkReader {
ttid: self.ttid,
conn_id: self.conn_id,
pgb_reader: &mut pgb_reader,
peer_addr,
acceptor_handle: &mut acceptor_handle,
};
let res = tokio::select! {
// todo: add read|write .context to these errors
r = network_reader.run(msg_tx, msg_rx, reply_tx) => r,
r = network_write(pgb, reply_rx) => r,
};
// Join pg backend back.
pgb.unsplit(pgb_reader)?;
// Join the spawned WalAcceptor. At this point chans to/from it passed
// to network routines are dropped, so it will exit as soon as it
// touches them.
match acceptor_handle {
None => {
// failed even before spawning; read_network should have error
Err(res.expect_err("no error with WalAcceptor not spawn"))
}
Some(handle) => {
let wal_acceptor_res = handle.join();
// If there was any network error, return it.
res?;
// Otherwise, WalAcceptor thread must have errored.
match wal_acceptor_res {
Ok(Ok(_)) => Ok(()), // can't happen currently; would be if we add graceful termination
Ok(Err(e)) => Err(CopyStreamHandlerEnd::Other(e.context("WAL acceptor"))),
Err(_) => Err(CopyStreamHandlerEnd::Other(anyhow!(
"WalAcceptor thread panicked",
))),
}
}
}
}
}
struct NetworkReader<'a> {
ttid: TenantTimelineId,
conn_id: ConnectionId,
pgb_reader: &'a mut PostgresBackendReader,
peer_addr: SocketAddr,
// WalAcceptor is spawned when we learn server info from walproposer and
// create timeline; handle is put here.
acceptor_handle: &'a mut Option<JoinHandle<anyhow::Result<()>>>,
}
impl<'a> NetworkReader<'a> {
async fn run(
self,
msg_tx: Sender<ProposerAcceptorMessage>,
msg_rx: Receiver<ProposerAcceptorMessage>,
reply_tx: Sender<AcceptorProposerMessage>,
) -> Result<(), CopyStreamHandlerEnd> {
// Receive information about server to create timeline, if not yet.
let next_msg = read_message(self.pgb_reader).await?;
// Receive information about server
let next_msg = poll_reader.recv_msg()?;
let tli = match next_msg {
ProposerAcceptorMessage::Greeting(ref greeting) => {
info!(
@@ -141,158 +79,127 @@ impl<'a> NetworkReader<'a> {
system_id: greeting.system_id,
wal_seg_size: greeting.wal_seg_size,
};
GlobalTimelines::create(self.ttid, server_info, Lsn::INVALID, Lsn::INVALID).await?
GlobalTimelines::create(spg.ttid, server_info, Lsn::INVALID, Lsn::INVALID)?
}
_ => {
return Err(CopyStreamHandlerEnd::Other(anyhow::anyhow!(
return Err(QueryError::Other(anyhow::anyhow!(
"unexpected message {next_msg:?} instead of greeting"
)))
}
};
*self.acceptor_handle = Some(
WalAcceptor::spawn(tli.clone(), msg_rx, reply_tx, self.conn_id)
.context("spawn WalAcceptor thread")?,
);
// Forward all messages to WalAcceptor
read_network_loop(self.pgb_reader, msg_tx, next_msg).await
}
}
/// Read next message from walproposer.
/// TODO: Return Ok(None) on graceful termination.
async fn read_message(
pgb_reader: &mut PostgresBackendReader,
) -> Result<ProposerAcceptorMessage, CopyStreamHandlerEnd> {
let copy_data = pgb_reader.read_copy_message().await?;
let msg = ProposerAcceptorMessage::parse(copy_data)?;
Ok(msg)
}
async fn read_network_loop(
pgb_reader: &mut PostgresBackendReader,
msg_tx: Sender<ProposerAcceptorMessage>,
mut next_msg: ProposerAcceptorMessage,
) -> Result<(), CopyStreamHandlerEnd> {
loop {
if msg_tx.send(next_msg).await.is_err() {
return Ok(()); // chan closed, WalAcceptor terminated
}
next_msg = read_message(pgb_reader).await?;
}
}
/// Read replies from WalAcceptor and pass them back to socket. Returns Ok(())
/// if reply_rx closed; it must mean WalAcceptor terminated, joining it should
/// tell the error.
async fn network_write(
pgb_writer: &mut PostgresBackend,
mut reply_rx: Receiver<AcceptorProposerMessage>,
) -> Result<(), CopyStreamHandlerEnd> {
let mut buf = BytesMut::with_capacity(128);
loop {
match reply_rx.recv().await {
Some(msg) => {
buf.clear();
msg.serialize(&mut buf)?;
pgb_writer.write_message(&BeMessage::CopyData(&buf)).await?;
}
None => return Ok(()), // chan closed, WalAcceptor terminated
}
}
}
/// Takes messages from msg_rx, processes and pushes replies to reply_tx.
struct WalAcceptor {
tli: Arc<Timeline>,
msg_rx: Receiver<ProposerAcceptorMessage>,
reply_tx: Sender<AcceptorProposerMessage>,
}
impl WalAcceptor {
/// Spawn thread with WalAcceptor running, return handle to it.
fn spawn(
tli: Arc<Timeline>,
msg_rx: Receiver<ProposerAcceptorMessage>,
reply_tx: Sender<AcceptorProposerMessage>,
conn_id: ConnectionId,
) -> anyhow::Result<JoinHandle<anyhow::Result<()>>> {
let thread_name = format!("WAL acceptor {}", tli.ttid);
thread::Builder::new()
.name(thread_name)
.spawn(move || -> anyhow::Result<()> {
let mut wa = WalAcceptor {
tli,
msg_rx,
reply_tx,
};
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let span_ttid = wa.tli.ttid; // satisfy borrow checker
runtime.block_on(
wa.run()
.instrument(info_span!("WAL acceptor", cid = %conn_id, ttid = %span_ttid)),
)
})
.map_err(anyhow::Error::from)
}
/// The main loop. Returns Ok(()) if either msg_rx or reply_tx got closed;
/// it must mean that network thread terminated.
async fn run(&mut self) -> anyhow::Result<()> {
// Register the connection and defer unregister.
self.tli.on_compute_connect().await?;
let _guard = ComputeConnectionGuard {
timeline: Arc::clone(&self.tli),
};
let mut next_msg: ProposerAcceptorMessage;
let mut next_msg = Some(next_msg);
let mut first_time_through = true;
let mut _guard: Option<ComputeConnectionGuard> = None;
loop {
let opt_msg = self.msg_rx.recv().await;
if opt_msg.is_none() {
return Ok(()); // chan closed, streaming terminated
}
next_msg = opt_msg.unwrap();
if matches!(next_msg, Some(ProposerAcceptorMessage::AppendRequest(_))) {
// poll AppendRequest's without blocking and write WAL to disk without flushing,
// while it's readily available
while let Some(ProposerAcceptorMessage::AppendRequest(append_request)) = next_msg {
let msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request);
if matches!(next_msg, ProposerAcceptorMessage::AppendRequest(_)) {
// loop through AppendRequest's while it's readily available to
// write as many WAL as possible without fsyncing
while let ProposerAcceptorMessage::AppendRequest(append_request) = next_msg {
let noflush_msg = ProposerAcceptorMessage::NoFlushAppendRequest(append_request);
if let Some(reply) = self.tli.process_msg(&noflush_msg)? {
if self.reply_tx.send(reply).await.is_err() {
return Ok(()); // chan closed, streaming terminated
}
let reply = tli.process_msg(&msg)?;
if let Some(reply) = reply {
self.write_msg(&reply)?;
}
match self.msg_rx.try_recv() {
Ok(msg) => next_msg = msg,
Err(TryRecvError::Empty) => break,
Err(TryRecvError::Disconnected) => return Ok(()), // chan closed, streaming terminated
}
next_msg = poll_reader.poll_msg();
}
// flush all written WAL to the disk
if let Some(reply) = self.tli.process_msg(&ProposerAcceptorMessage::FlushWAL)? {
if self.reply_tx.send(reply).await.is_err() {
return Ok(()); // chan closed, streaming terminated
}
let reply = tli.process_msg(&ProposerAcceptorMessage::FlushWAL)?;
if let Some(reply) = reply {
self.write_msg(&reply)?;
}
} else {
// process message other than AppendRequest
if let Some(reply) = self.tli.process_msg(&next_msg)? {
if self.reply_tx.send(reply).await.is_err() {
return Ok(()); // chan closed, streaming terminated
}
} else if let Some(msg) = next_msg.take() {
// process other message
let reply = tli.process_msg(&msg)?;
if let Some(reply) = reply {
self.write_msg(&reply)?;
}
}
if first_time_through {
// Register the connection and defer unregister. Do that only
// after processing first message, as it sets wal_seg_size,
// wanted by many.
tli.on_compute_connect()?;
_guard = Some(ComputeConnectionGuard {
timeline: Arc::clone(&tli),
});
first_time_through = false;
}
// blocking wait for the next message
if next_msg.is_none() {
next_msg = Some(poll_reader.recv_msg()?);
}
}
}
}
struct ProposerPollStream {
msg_rx: Receiver<ProposerAcceptorMessage>,
read_thread: Option<thread::JoinHandle<Result<(), QueryError>>>,
}
impl ProposerPollStream {
fn new(mut r: ReadStream) -> anyhow::Result<Self> {
let (msg_tx, msg_rx) = channel();
let read_thread = thread::Builder::new()
.name("Read WAL thread".into())
.spawn(move || -> Result<(), QueryError> {
loop {
let copy_data = match FeMessage::read(&mut r)? {
Some(FeMessage::CopyData(bytes)) => Ok(bytes),
Some(msg) => Err(QueryError::Other(anyhow::anyhow!(
"expected `CopyData` message, found {msg:?}"
))),
None => Err(QueryError::from(std::io::Error::new(
std::io::ErrorKind::ConnectionAborted,
"walproposer closed the connection",
))),
}?;
let msg = ProposerAcceptorMessage::parse(copy_data)?;
msg_tx
.send(msg)
.context("Failed to send the proposer message")?;
}
// msg_tx will be dropped here, this will also close msg_rx
})?;
Ok(Self {
msg_rx,
read_thread: Some(read_thread),
})
}
fn recv_msg(&mut self) -> Result<ProposerAcceptorMessage, QueryError> {
self.msg_rx.recv().map_err(|_| {
// return error from the read thread
let res = match self.read_thread.take() {
Some(thread) => thread.join(),
None => return QueryError::Other(anyhow::anyhow!("read thread is gone")),
};
match res {
Ok(Ok(())) => {
QueryError::Other(anyhow::anyhow!("unexpected result from read thread"))
}
Err(err) => QueryError::Other(anyhow::anyhow!("read thread panicked: {err:?}")),
Ok(Err(err)) => err,
}
})
}
fn poll_msg(&mut self) -> Option<ProposerAcceptorMessage> {
let res = self.msg_rx.try_recv();
match res {
Err(_) => None,
Ok(msg) => Some(msg),
}
}
}
@@ -303,13 +210,8 @@ struct ComputeConnectionGuard {
impl Drop for ComputeConnectionGuard {
fn drop(&mut self) {
let tli = self.timeline.clone();
// tokio forbids to call blocking_send inside the runtime, and see
// comments in on_compute_disconnect why we call blocking_send.
spawn_blocking(move || {
if let Err(e) = tli.on_compute_disconnect() {
error!("failed to unregister compute connection: {}", e);
}
});
if let Err(e) = self.timeline.on_compute_disconnect() {
error!("failed to unregister compute connection: {}", e);
}
}
}

View File

@@ -191,8 +191,7 @@ pub struct SafeKeeperState {
/// Minimal LSN which may be needed for recovery of some safekeeper (end_lsn
/// of last record streamed to everyone). Persisting it helps skipping
/// recovery in walproposer, generally we compute it from peers. In
/// walproposer proto called 'truncate_lsn'. Updates are currently drived
/// only by walproposer.
/// walproposer proto called 'truncate_lsn'.
pub peer_horizon_lsn: Lsn,
/// LSN of the oldest known checkpoint made by pageserver and successfully
/// pushed to s3. We don't remove WAL beyond it. Persisted only for
@@ -205,7 +204,7 @@ pub struct SafeKeeperState {
pub peers: PersistedPeers,
}
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone)]
// In memory safekeeper state. Fields mirror ones in `SafeKeeperState`; values
// are not flushed yet.
pub struct SafekeeperMemState {
@@ -213,7 +212,6 @@ pub struct SafekeeperMemState {
pub backup_lsn: Lsn,
pub peer_horizon_lsn: Lsn,
pub remote_consistent_lsn: Lsn,
#[serde(with = "hex")]
pub proposer_uuid: PgUuid,
}
@@ -488,7 +486,7 @@ impl AcceptorProposerMessage {
buf.put_u64_le(msg.hs_feedback.xmin);
buf.put_u64_le(msg.hs_feedback.catalog_xmin);
msg.pageserver_feedback.serialize(buf);
msg.pageserver_feedback.serialize(buf)?
}
}
@@ -683,7 +681,7 @@ where
term: self.state.acceptor_state.term,
vote_given: false as u64,
flush_lsn: self.flush_lsn(),
truncate_lsn: self.inmem.peer_horizon_lsn,
truncate_lsn: self.state.peer_horizon_lsn,
term_history: self.get_term_history(),
timeline_start_lsn: self.state.timeline_start_lsn,
};
@@ -879,13 +877,7 @@ where
if msg.h.commit_lsn != Lsn(0) {
self.update_commit_lsn(msg.h.commit_lsn)?;
}
// Value calculated by walproposer can always lag:
// - safekeepers can forget inmem value and send to proposer lower
// persisted one on restart;
// - if we make safekeepers always send persistent value,
// any compute restart would pull it down.
// Thus, take max before adopting.
self.inmem.peer_horizon_lsn = max(self.inmem.peer_horizon_lsn, msg.h.truncate_lsn);
self.inmem.peer_horizon_lsn = msg.h.truncate_lsn;
// Update truncate and commit LSN in control file.
// To avoid negative impact on performance of extra fsync, do it only

View File

@@ -5,22 +5,24 @@ use crate::handler::SafekeeperPostgresHandler;
use crate::timeline::{ReplicaState, Timeline};
use crate::wal_storage::WalReader;
use crate::GlobalTimelines;
use anyhow::Context as AnyhowContext;
use anyhow::Context;
use bytes::Bytes;
use postgres_backend::PostgresBackend;
use postgres_backend::{CopyStreamHandlerEnd, PostgresBackendReader, QueryError};
use postgres_ffi::get_current_timestamp;
use postgres_ffi::{TimestampTz, MAX_SEND_SIZE};
use pq_proto::{BeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody};
use serde::{Deserialize, Serialize};
use std::cmp::min;
use std::str;
use std::net::Shutdown;
use std::sync::Arc;
use std::time::Duration;
use std::{io, str, thread};
use utils::postgres_backend_async::QueryError;
use pq_proto::{BeMessage, FeMessage, ReplicationFeedback, WalSndKeepAlive, XLogDataBody};
use tokio::sync::watch::Receiver;
use tokio::time::timeout;
use tracing::*;
use utils::{bin_ser::BeSer, lsn::Lsn};
use utils::{bin_ser::BeSer, lsn::Lsn, postgres_backend::PostgresBackend, sock_split::ReadStream};
// See: https://www.postgresql.org/docs/13/protocol-replication.html
const HOT_STANDBY_FEEDBACK_TAG_BYTE: u8 = b'h';
@@ -58,6 +60,13 @@ pub struct StandbyReply {
pub reply_requested: bool,
}
/// A network connection that's speaking the replication protocol.
pub struct ReplicationConn {
/// This is an `Option` because we will spawn a background thread that will
/// `take` it from us.
stream_in: Option<ReadStream>,
}
/// Scope guard to unregister replication connection from timeline
struct ReplicationConnGuard {
replica: usize, // replica internal ID assigned by timeline
@@ -70,275 +79,230 @@ impl Drop for ReplicationConnGuard {
}
}
impl SafekeeperPostgresHandler {
/// Wrapper around handle_start_replication_guts handling result. Error is
/// handled here while we're still in walsender ttid span; with API
/// extension, this can probably be moved into postgres_backend.
pub async fn handle_start_replication(
&mut self,
pgb: &mut PostgresBackend,
start_pos: Lsn,
) -> Result<(), QueryError> {
if let Err(end) = self.handle_start_replication_guts(pgb, start_pos).await {
// Log the result and probably send it to the client, closing the stream.
pgb.handle_copy_stream_end(end).await;
impl ReplicationConn {
/// Create a new `ReplicationConn`
pub fn new(pgb: &mut PostgresBackend) -> Self {
Self {
stream_in: pgb.take_stream_in(),
}
}
/// Handle incoming messages from the network.
/// This is spawned into the background by `handle_start_replication`.
fn background_thread(
mut stream_in: ReadStream,
replica_guard: Arc<ReplicationConnGuard>,
) -> anyhow::Result<()> {
let replica_id = replica_guard.replica;
let timeline = &replica_guard.timeline;
let mut state = ReplicaState::new();
// Wait for replica's feedback.
while let Some(msg) = FeMessage::read(&mut stream_in)? {
match &msg {
FeMessage::CopyData(m) => {
// There's three possible data messages that the client is supposed to send here:
// `HotStandbyFeedback` and `StandbyStatusUpdate` and `NeonStandbyFeedback`.
match m.first().cloned() {
Some(HOT_STANDBY_FEEDBACK_TAG_BYTE) => {
// Note: deserializing is on m[1..] because we skip the tag byte.
state.hs_feedback = HotStandbyFeedback::des(&m[1..])
.context("failed to deserialize HotStandbyFeedback")?;
timeline.update_replica_state(replica_id, state);
}
Some(STANDBY_STATUS_UPDATE_TAG_BYTE) => {
let _reply = StandbyReply::des(&m[1..])
.context("failed to deserialize StandbyReply")?;
// This must be a regular postgres replica,
// because pageserver doesn't send this type of messages to safekeeper.
// Currently this is not implemented, so this message is ignored.
warn!("unexpected StandbyReply. Read-only postgres replicas are not supported in safekeepers yet.");
// timeline.update_replica_state(replica_id, Some(state));
}
Some(NEON_STATUS_UPDATE_TAG_BYTE) => {
// Note: deserializing is on m[9..] because we skip the tag byte and len bytes.
let buf = Bytes::copy_from_slice(&m[9..]);
let reply = ReplicationFeedback::parse(buf);
trace!("ReplicationFeedback is {:?}", reply);
// Only pageserver sends ReplicationFeedback, so set the flag.
// This replica is the source of information to resend to compute.
state.pageserver_feedback = Some(reply);
timeline.update_replica_state(replica_id, state);
}
_ => warn!("unexpected message {:?}", msg),
}
}
FeMessage::Sync => {}
FeMessage::CopyFail => {
// Shutdown the connection, because rust-postgres client cannot be dropped
// when connection is alive.
let _ = stream_in.shutdown(Shutdown::Both);
anyhow::bail!("Copy failed");
}
_ => {
// We only handle `CopyData`, 'Sync', 'CopyFail' messages. Anything else is ignored.
info!("unexpected message {:?}", msg);
}
}
}
Ok(())
}
pub async fn handle_start_replication_guts(
///
/// Handle START_REPLICATION replication command
///
pub fn run(
&mut self,
spg: &mut SafekeeperPostgresHandler,
pgb: &mut PostgresBackend,
start_pos: Lsn,
) -> Result<(), CopyStreamHandlerEnd> {
let appname = self.appname.clone();
let tli =
GlobalTimelines::get(self.ttid).map_err(|e| CopyStreamHandlerEnd::Other(e.into()))?;
mut start_pos: Lsn,
) -> Result<(), QueryError> {
let _enter = info_span!("WAL sender", ttid = %spg.ttid).entered();
let tli = GlobalTimelines::get(spg.ttid)?;
// spawn the background thread which receives HotStandbyFeedback messages.
let bg_timeline = Arc::clone(&tli);
let bg_stream_in = self.stream_in.take().unwrap();
let bg_timeline_id = spg.timeline_id.unwrap();
let state = ReplicaState::new();
// This replica_id is used below to check if it's time to stop replication.
let replica_id = tli.add_replica(state);
let replica_id = bg_timeline.add_replica(state);
// Use a guard object to remove our entry from the timeline, when the background
// thread and us have both finished using it.
let _guard = Arc::new(ReplicationConnGuard {
let replica_guard = Arc::new(ReplicationConnGuard {
replica: replica_id,
timeline: tli.clone(),
timeline: bg_timeline,
});
let bg_replica_guard = Arc::clone(&replica_guard);
// Walproposer gets special handling: safekeeper must give proposer all
// local WAL till the end, whether committed or not (walproposer will
// hang otherwise). That's because walproposer runs the consensus and
// synchronizes safekeepers on the most advanced one.
//
// There is a small risk of this WAL getting concurrently garbaged if
// another compute rises which collects majority and starts fixing log
// on this safekeeper itself. That's ok as (old) proposer will never be
// able to commit such WAL.
let stop_pos: Option<Lsn> = if self.is_walproposer_recovery() {
let wal_end = tli.get_flush_lsn();
Some(wal_end)
} else {
None
};
let end_pos = stop_pos.unwrap_or(Lsn::INVALID);
info!(
"starting streaming from {:?} till {:?}",
start_pos, stop_pos
);
// switch to copy
pgb.write_message(&BeMessage::CopyBothResponse).await?;
let (_, persisted_state) = tli.get_state();
let wal_reader = WalReader::new(
self.conf.workdir.clone(),
self.conf.timeline_dir(&tli.ttid),
&persisted_state,
start_pos,
self.conf.wal_backup_enabled,
)?;
// Split to concurrently receive and send data; replies are generally
// not synchronized with sends, so this avoids deadlocks.
let reader = pgb.split().context("START_REPLICATION split")?;
let mut sender = WalSender {
pgb,
tli: tli.clone(),
appname,
start_pos,
end_pos,
stop_pos,
commit_lsn_watch_rx: tli.get_commit_lsn_watch_rx(),
replica_id,
wal_reader,
send_buf: [0; MAX_SEND_SIZE],
};
let mut reply_reader = ReplyReader {
reader,
tli,
replica_id,
feedback: ReplicaState::new(),
};
let res = tokio::select! {
// todo: add read|write .context to these errors
r = sender.run() => r,
r = reply_reader.run() => r,
};
// Join pg backend back.
pgb.unsplit(reply_reader.reader)?;
res
}
}
/// A half driving sending WAL.
struct WalSender<'a> {
pgb: &'a mut PostgresBackend,
tli: Arc<Timeline>,
appname: Option<String>,
// Position since which we are sending next chunk.
start_pos: Lsn,
// WAL up to this position is known to be locally available.
end_pos: Lsn,
// If present, terminate after reaching this position; used by walproposer
// in recovery.
stop_pos: Option<Lsn>,
commit_lsn_watch_rx: Receiver<Lsn>,
replica_id: usize,
wal_reader: WalReader,
// buffer for readling WAL into to send it
send_buf: [u8; MAX_SEND_SIZE],
}
impl WalSender<'_> {
/// Send WAL until
/// - an error occurs
/// - if we are streaming to walproposer, we've streamed until stop_pos
/// (recovery finished)
/// - receiver is caughtup and there is no computes
///
/// Err(CopyStreamHandlerEnd) is always returned; Result is used only for ?
/// convenience.
async fn run(&mut self) -> Result<(), CopyStreamHandlerEnd> {
loop {
// If we are streaming to walproposer, check it is time to stop.
if let Some(stop_pos) = self.stop_pos {
if self.start_pos >= stop_pos {
// recovery finished
return Err(CopyStreamHandlerEnd::ServerInitiated(format!(
"ending streaming to walproposer at {}, recovery finished",
self.start_pos
)));
// TODO: here we got two threads, one for writing WAL and one for receiving
// feedback. If one of them fails, we should shutdown the other one too.
let _ = thread::Builder::new()
.name("HotStandbyFeedback thread".into())
.spawn(move || {
let _enter =
info_span!("HotStandbyFeedback thread", timeline = %bg_timeline_id).entered();
if let Err(err) = Self::background_thread(bg_stream_in, bg_replica_guard) {
error!("Replication background thread failed: {}", err);
}
})?;
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
runtime.block_on(async move {
let (inmem_state, persisted_state) = tli.get_state();
// add persisted_state.timeline_start_lsn == Lsn(0) check
// Walproposer gets special handling: safekeeper must give proposer all
// local WAL till the end, whether committed or not (walproposer will
// hang otherwise). That's because walproposer runs the consensus and
// synchronizes safekeepers on the most advanced one.
//
// There is a small risk of this WAL getting concurrently garbaged if
// another compute rises which collects majority and starts fixing log
// on this safekeeper itself. That's ok as (old) proposer will never be
// able to commit such WAL.
let stop_pos: Option<Lsn> = if spg.is_walproposer_recovery() {
let wal_end = tli.get_flush_lsn();
Some(wal_end)
} else {
// Wait for the next portion if it is not there yet, or just
// update our end of WAL available for sending value, we
// communicate it to the receiver.
self.wait_wal().await?;
}
None
};
// try to send as much as available, capped by MAX_SEND_SIZE
let mut send_size = self
.end_pos
.checked_sub(self.start_pos)
.context("reading wal without waiting for it first")?
.0 as usize;
send_size = min(send_size, self.send_buf.len());
let send_buf = &mut self.send_buf[..send_size];
// read wal into buffer
send_size = self.wal_reader.read(send_buf).await?;
let send_buf = &send_buf[..send_size];
info!("Start replication from {:?} till {:?}", start_pos, stop_pos);
// and send it
self.pgb
.write_message(&BeMessage::XLogData(XLogDataBody {
wal_start: self.start_pos.0,
wal_end: self.end_pos.0,
// switch to copy
pgb.write_message(&BeMessage::CopyBothResponse)?;
let mut end_pos = stop_pos.unwrap_or(inmem_state.commit_lsn);
let mut wal_reader = WalReader::new(
spg.conf.workdir.clone(),
spg.conf.timeline_dir(&tli.ttid),
&persisted_state,
start_pos,
spg.conf.wal_backup_enabled,
)?;
// buffer for wal sending, limited by MAX_SEND_SIZE
let mut send_buf = vec![0u8; MAX_SEND_SIZE];
// watcher for commit_lsn updates
let mut commit_lsn_watch_rx = tli.get_commit_lsn_watch_rx();
loop {
if let Some(stop_pos) = stop_pos {
if start_pos >= stop_pos {
break; /* recovery finished */
}
end_pos = stop_pos;
} else {
/* Wait until we have some data to stream */
let lsn = wait_for_lsn(&mut commit_lsn_watch_rx, start_pos).await?;
if let Some(lsn) = lsn {
end_pos = lsn;
} else {
// TODO: also check once in a while whether we are walsender
// to right pageserver.
if tli.should_walsender_stop(replica_id) {
// Shut down, timeline is suspended.
return Err(QueryError::from(io::Error::new(
io::ErrorKind::ConnectionAborted,
format!("end streaming to {:?}", spg.appname),
)));
}
// timeout expired: request pageserver status
pgb.write_message(&BeMessage::KeepAlive(WalSndKeepAlive {
sent_ptr: end_pos.0,
timestamp: get_current_timestamp(),
request_reply: true,
}))?;
continue;
}
}
let send_size = end_pos.checked_sub(start_pos).unwrap().0 as usize;
let send_size = min(send_size, send_buf.len());
let send_buf = &mut send_buf[..send_size];
// read wal into buffer
let send_size = wal_reader.read(send_buf).await?;
let send_buf = &send_buf[..send_size];
// Write some data to the network socket.
pgb.write_message(&BeMessage::XLogData(XLogDataBody {
wal_start: start_pos.0,
wal_end: end_pos.0,
timestamp: get_current_timestamp(),
data: send_buf,
}))
.await?;
.context("Failed to send XLogData")?;
trace!(
"sent {} bytes of WAL {}-{}",
send_size,
self.start_pos,
self.start_pos + send_size as u64
);
self.start_pos += send_size as u64;
}
}
/// wait until we have WAL to stream, sending keepalives and checking for
/// exit in the meanwhile
async fn wait_wal(&mut self) -> Result<(), CopyStreamHandlerEnd> {
loop {
if let Some(lsn) = wait_for_lsn(&mut self.commit_lsn_watch_rx, self.start_pos).await? {
self.end_pos = lsn;
return Ok(());
start_pos += send_size as u64;
trace!("sent WAL up to {}", start_pos);
}
// Timed out waiting for WAL, check for termination and send KA
if self.tli.should_walsender_stop(self.replica_id) {
// Terminate if there is nothing more to send.
// TODO close the stream properly
return Err(CopyStreamHandlerEnd::ServerInitiated(format!(
"ending streaming to {:?} at {}, receiver is caughtup and there is no computes",
self.appname, self.start_pos,
)));
}
self.pgb
.write_message(&BeMessage::KeepAlive(WalSndKeepAlive {
sent_ptr: self.end_pos.0,
timestamp: get_current_timestamp(),
request_reply: true,
}))
.await?;
}
}
}
/// A half driving receiving replies.
struct ReplyReader {
reader: PostgresBackendReader,
tli: Arc<Timeline>,
replica_id: usize,
feedback: ReplicaState,
}
impl ReplyReader {
async fn run(&mut self) -> Result<(), CopyStreamHandlerEnd> {
loop {
let msg = self.reader.read_copy_message().await?;
self.handle_feedback(&msg)?
}
}
fn handle_feedback(&mut self, msg: &Bytes) -> anyhow::Result<()> {
match msg.first().cloned() {
Some(HOT_STANDBY_FEEDBACK_TAG_BYTE) => {
// Note: deserializing is on m[1..] because we skip the tag byte.
self.feedback.hs_feedback = HotStandbyFeedback::des(&msg[1..])
.context("failed to deserialize HotStandbyFeedback")?;
self.tli
.update_replica_state(self.replica_id, self.feedback);
}
Some(STANDBY_STATUS_UPDATE_TAG_BYTE) => {
let _reply =
StandbyReply::des(&msg[1..]).context("failed to deserialize StandbyReply")?;
// This must be a regular postgres replica,
// because pageserver doesn't send this type of messages to safekeeper.
// Currently we just ignore this, tracking progress for them is not supported.
}
Some(NEON_STATUS_UPDATE_TAG_BYTE) => {
// pageserver sends this.
// Note: deserializing is on m[9..] because we skip the tag byte and len bytes.
let buf = Bytes::copy_from_slice(&msg[9..]);
let reply = ReplicationFeedback::parse(buf);
trace!("ReplicationFeedback is {:?}", reply);
// Only pageserver sends ReplicationFeedback, so set the flag.
// This replica is the source of information to resend to compute.
self.feedback.pageserver_feedback = Some(reply);
self.tli
.update_replica_state(self.replica_id, self.feedback);
}
_ => warn!("unexpected message {:?}", msg),
}
Ok(())
Ok(())
})
}
}
const POLL_STATE_TIMEOUT: Duration = Duration::from_secs(1);
/// Wait until we have commit_lsn > lsn or timeout expires. Returns
/// - Ok(Some(commit_lsn)) if needed lsn is successfully observed;
/// - Ok(None) if timeout expired;
/// - Err in case of error (if watch channel is in trouble, shouldn't happen).
// Wait until we have commit_lsn > lsn or timeout expires. Returns latest commit_lsn.
async fn wait_for_lsn(rx: &mut Receiver<Lsn>, lsn: Lsn) -> anyhow::Result<Option<Lsn>> {
let commit_lsn: Lsn = *rx.borrow();
if commit_lsn > lsn {

View File

@@ -1,11 +1,10 @@
//! This module implements Timeline lifecycle management and has all necessary code
//! This module implements Timeline lifecycle management and has all neccessary code
//! to glue together SafeKeeper and all other background services.
use anyhow::{anyhow, bail, Result};
use anyhow::{bail, Result};
use parking_lot::{Mutex, MutexGuard};
use postgres_ffi::XLogSegNo;
use pq_proto::ReplicationFeedback;
use serde::Serialize;
use std::cmp::{max, min};
use std::path::PathBuf;
use tokio::{
@@ -13,7 +12,6 @@ use tokio::{
time::Instant,
};
use tracing::*;
use utils::http::error::ApiError;
use utils::{
id::{NodeId, TenantTimelineId},
lsn::Lsn,
@@ -30,9 +28,9 @@ use crate::send_wal::HotStandbyFeedback;
use crate::{control_file, safekeeper::UNKNOWN_SERVER_VERSION};
use crate::metrics::FullTimelineInfo;
use crate::wal_storage;
use crate::wal_storage::Storage as wal_storage_iface;
use crate::SafeKeeperConf;
use crate::{debug_dump, wal_storage};
/// Things safekeeper should know about timeline state on peers.
#[derive(Debug, Clone)]
@@ -82,7 +80,7 @@ impl PeersInfo {
}
/// Replica status update + hot standby feedback
#[derive(Debug, Clone, Copy, Serialize)]
#[derive(Debug, Clone, Copy)]
pub struct ReplicaState {
/// last known lsn received by replica
pub last_received_lsn: Lsn, // None means we don't know
@@ -357,18 +355,6 @@ pub enum TimelineError {
UninitialinzedPgVersion(TenantTimelineId),
}
// Convert to HTTP API error.
impl From<TimelineError> for ApiError {
fn from(te: TimelineError) -> ApiError {
match te {
TimelineError::NotFound(ttid) => {
ApiError::NotFound(anyhow!("timeline {} not found", ttid))
}
_ => ApiError::InternalServerError(anyhow!("{}", te)),
}
}
}
/// Timeline struct manages lifecycle (creation, deletion, restore) of a safekeeper timeline.
/// It also holds SharedState and provides mutually exclusive access to it.
pub struct Timeline {
@@ -395,7 +381,7 @@ pub struct Timeline {
cancellation_rx: watch::Receiver<bool>,
/// Directory where timeline state is stored.
pub timeline_dir: PathBuf,
timeline_dir: PathBuf,
}
impl Timeline {
@@ -532,7 +518,7 @@ impl Timeline {
/// Register compute connection, starting timeline-related activity if it is
/// not running yet.
pub async fn on_compute_connect(&self) -> Result<()> {
pub fn on_compute_connect(&self) -> Result<()> {
if self.is_cancelled() {
bail!(TimelineError::Cancelled(self.ttid));
}
@@ -546,7 +532,7 @@ impl Timeline {
// Wake up wal backup launcher, if offloading not started yet.
if is_wal_backup_action_pending {
// Can fail only if channel to a static thread got closed, which is not normal at all.
self.wal_backup_launcher_tx.send(self.ttid).await?;
self.wal_backup_launcher_tx.blocking_send(self.ttid)?;
}
Ok(())
}
@@ -563,11 +549,6 @@ impl Timeline {
// Wake up wal backup launcher, if it is time to stop the offloading.
if is_wal_backup_action_pending {
// Can fail only if channel to a static thread got closed, which is not normal at all.
//
// Note: this is blocking_send because on_compute_disconnect is called in Drop, there is
// no async Drop and we use current thread runtimes. With current thread rt spawning
// task in drop impl is racy, as thread along with runtime might finish before the task.
// This should be switched send.await when/if we go to full async.
self.wal_backup_launcher_tx.blocking_send(self.ttid)?;
}
Ok(())
@@ -607,6 +588,38 @@ impl Timeline {
self.write_shared_state().wal_backup_attend()
}
/// Returns full timeline info, required for the metrics. If the timeline is
/// not active, returns None instead.
pub fn info_for_metrics(&self) -> Option<FullTimelineInfo> {
if self.is_cancelled() {
return None;
}
let state = self.write_shared_state();
if state.active {
Some(FullTimelineInfo {
ttid: self.ttid,
replicas: state
.replicas
.iter()
.filter_map(|r| r.as_ref())
.copied()
.collect(),
wal_backup_active: state.wal_backup_active,
timeline_is_active: state.active,
num_computes: state.num_computes,
last_removed_segno: state.last_removed_segno,
epoch_start_lsn: state.sk.epoch_start_lsn,
mem_state: state.sk.inmem.clone(),
persisted_state: state.sk.state.clone(),
flush_lsn: state.sk.wal_store.flush_lsn(),
wal_storage: state.sk.wal_store.get_metrics(),
})
} else {
None
}
}
/// Returns commit_lsn watch channel.
pub fn get_commit_lsn_watch_rx(&self) -> watch::Receiver<Lsn> {
self.commit_lsn_watch_rx.clone()
@@ -771,62 +784,6 @@ impl Timeline {
shared_state.last_removed_segno = horizon_segno;
Ok(())
}
/// Returns full timeline info, required for the metrics. If the timeline is
/// not active, returns None instead.
pub fn info_for_metrics(&self) -> Option<FullTimelineInfo> {
if self.is_cancelled() {
return None;
}
let state = self.write_shared_state();
if state.active {
Some(FullTimelineInfo {
ttid: self.ttid,
replicas: state
.replicas
.iter()
.filter_map(|r| r.as_ref())
.copied()
.collect(),
wal_backup_active: state.wal_backup_active,
timeline_is_active: state.active,
num_computes: state.num_computes,
last_removed_segno: state.last_removed_segno,
epoch_start_lsn: state.sk.epoch_start_lsn,
mem_state: state.sk.inmem.clone(),
persisted_state: state.sk.state.clone(),
flush_lsn: state.sk.wal_store.flush_lsn(),
wal_storage: state.sk.wal_store.get_metrics(),
})
} else {
None
}
}
/// Returns in-memory timeline state to build a full debug dump.
pub fn memory_dump(&self) -> debug_dump::Memory {
let state = self.write_shared_state();
let (write_lsn, write_record_lsn, flush_lsn, file_open) =
state.sk.wal_store.internal_state();
debug_dump::Memory {
is_cancelled: self.is_cancelled(),
peers_info_len: state.peers_info.0.len(),
replicas: state.replicas.clone(),
wal_backup_active: state.wal_backup_active,
active: state.active,
num_computes: state.num_computes,
last_removed_segno: state.last_removed_segno,
epoch_start_lsn: state.sk.epoch_start_lsn,
mem_state: state.sk.inmem.clone(),
write_lsn,
write_record_lsn,
flush_lsn,
file_open,
}
}
}
/// Deletes directory and it's contents. Returns false if directory does not exist.

View File

@@ -5,7 +5,7 @@
use crate::safekeeper::ServerInfo;
use crate::timeline::{Timeline, TimelineError};
use crate::SafeKeeperConf;
use anyhow::{bail, Context, Result};
use anyhow::{anyhow, bail, Context, Result};
use once_cell::sync::Lazy;
use serde::Serialize;
use std::collections::HashMap;
@@ -50,11 +50,11 @@ impl GlobalTimelinesState {
}
/// Get timeline from the map. Returns error if timeline doesn't exist.
fn get(&self, ttid: &TenantTimelineId) -> Result<Arc<Timeline>, TimelineError> {
fn get(&self, ttid: &TenantTimelineId) -> Result<Arc<Timeline>> {
self.timelines
.get(ttid)
.cloned()
.ok_or(TimelineError::NotFound(*ttid))
.ok_or_else(|| anyhow!(TimelineError::NotFound(*ttid)))
}
}
@@ -159,19 +159,9 @@ impl GlobalTimelines {
Ok(())
}
/// Get the number of timelines in the map.
pub fn timelines_count() -> usize {
TIMELINES_STATE.lock().unwrap().timelines.len()
}
/// Get the global safekeeper config.
pub fn get_global_config() -> SafeKeeperConf {
TIMELINES_STATE.lock().unwrap().get_conf().clone()
}
/// Create a new timeline with the given id. If the timeline already exists, returns
/// an existing timeline.
pub async fn create(
pub fn create(
ttid: TenantTimelineId,
server_info: ServerInfo,
commit_lsn: Lsn,
@@ -199,20 +189,28 @@ impl GlobalTimelines {
// Take a lock and finish the initialization holding this mutex. No other threads
// can interfere with creation after we will insert timeline into the map.
{
let mut shared_state = timeline.write_shared_state();
let mut shared_state = timeline.write_shared_state();
// We can get a race condition here in case of concurrent create calls, but only
// in theory. create() will return valid timeline on the next try.
TIMELINES_STATE
.lock()
.unwrap()
.try_insert(timeline.clone())?;
// We can get a race condition here in case of concurrent create calls, but only
// in theory. create() will return valid timeline on the next try.
TIMELINES_STATE
.lock()
.unwrap()
.try_insert(timeline.clone())?;
// Write the new timeline to the disk and start background workers.
// Bootstrap is transactional, so if it fails, the timeline will be deleted,
// and the state on disk should remain unchanged.
if let Err(e) = timeline.bootstrap(&mut shared_state) {
// Write the new timeline to the disk and start background workers.
// Bootstrap is transactional, so if it fails, the timeline will be deleted,
// and the state on disk should remain unchanged.
match timeline.bootstrap(&mut shared_state) {
Ok(_) => {
// We are done with bootstrap, release the lock, return the timeline.
drop(shared_state);
timeline
.wal_backup_launcher_tx
.blocking_send(timeline.ttid)?;
Ok(timeline)
}
Err(e) => {
// Note: the most likely reason for bootstrap failure is that the timeline
// directory already exists on disk. This happens when timeline is corrupted
// and wasn't loaded from disk on startup because of that. We want to preserve
@@ -224,33 +222,29 @@ impl GlobalTimelines {
// Timeline failed to bootstrap, it cannot be used. Remove it from the map.
TIMELINES_STATE.lock().unwrap().timelines.remove(&ttid);
return Err(e);
Err(e)
}
// We are done with bootstrap, release the lock, return the timeline.
// {} block forces release before .await
}
timeline.wal_backup_launcher_tx.send(timeline.ttid).await?;
Ok(timeline)
}
/// Get a timeline from the global map. If it's not present, it doesn't exist on disk,
/// or was corrupted and couldn't be loaded on startup. Returned timeline is always valid,
/// i.e. loaded in memory and not cancelled.
pub fn get(ttid: TenantTimelineId) -> Result<Arc<Timeline>, TimelineError> {
pub fn get(ttid: TenantTimelineId) -> Result<Arc<Timeline>> {
let res = TIMELINES_STATE.lock().unwrap().get(&ttid);
match res {
Ok(tli) => {
if tli.is_cancelled() {
return Err(TimelineError::Cancelled(ttid));
anyhow::bail!(TimelineError::Cancelled(ttid));
}
Ok(tli)
}
_ => res,
Err(e) => Err(e),
}
}
/// Returns all timelines. This is used for background timeline processes.
/// Returns all timelines. This is used for background timeline proccesses.
pub fn get_all() -> Vec<Arc<Timeline>> {
let global_lock = TIMELINES_STATE.lock().unwrap();
global_lock

View File

@@ -191,7 +191,7 @@ async fn wal_backup_launcher_main_loop(
.map(|c| GenericRemoteStorage::from_config(c).expect("failed to create remote storage"))
});
// Presence in this map means launcher is aware s3 offloading is needed for
// Presense in this map means launcher is aware s3 offloading is needed for
// the timeline, but task is started only if it makes sense for to offload
// from this safekeeper.
let mut tasks: HashMap<TenantTimelineId, WalBackupTimelineEntry> = HashMap::new();
@@ -467,7 +467,7 @@ async fn backup_object(source_file: &Path, target_file: &RemotePath, size: usize
pub async fn read_object(
file_path: &RemotePath,
offset: u64,
) -> anyhow::Result<Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>>> {
) -> anyhow::Result<Pin<Box<dyn tokio::io::AsyncRead>>> {
let storage = REMOTE_STORAGE
.get()
.context("Failed to get remote storage")?

View File

@@ -2,70 +2,50 @@
//! WAL service listens for client connections and
//! receive WAL from wal_proposer and send it to WAL receivers
//!
use anyhow::{Context, Result};
use postgres_backend::QueryError;
use std::{future, thread};
use tokio::net::TcpStream;
use regex::Regex;
use std::net::{TcpListener, TcpStream};
use std::thread;
use tracing::*;
use utils::postgres_backend_async::QueryError;
use crate::handler::SafekeeperPostgresHandler;
use crate::SafeKeeperConf;
use postgres_backend::{AuthType, PostgresBackend};
use utils::postgres_backend::{AuthType, PostgresBackend};
/// Accept incoming TCP connections and spawn them into a background thread.
pub fn thread_main(conf: SafeKeeperConf, pg_listener: std::net::TcpListener) {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.context("create runtime")
// todo catch error in main thread
.expect("failed to create runtime");
pub fn thread_main(conf: SafeKeeperConf, listener: TcpListener) -> ! {
loop {
match listener.accept() {
Ok((socket, peer_addr)) => {
debug!("accepted connection from {}", peer_addr);
let conf = conf.clone();
runtime
.block_on(async move {
// Tokio's from_std won't do this for us, per its comment.
pg_listener.set_nonblocking(true)?;
let listener = tokio::net::TcpListener::from_std(pg_listener)?;
let mut connection_count: ConnectionCount = 0;
loop {
match listener.accept().await {
Ok((socket, peer_addr)) => {
debug!("accepted connection from {}", peer_addr);
let conf = conf.clone();
let conn_id = issue_connection_id(&mut connection_count);
let _ = thread::Builder::new()
.name("WAL service thread".into())
.spawn(move || {
if let Err(err) = handle_socket(socket, conf, conn_id) {
error!("connection handler exited: {}", err);
}
})
.unwrap();
}
Err(e) => error!("Failed to accept connection: {}", e),
}
let _ = thread::Builder::new()
.name("WAL service thread".into())
.spawn(move || {
if let Err(err) = handle_socket(socket, conf) {
error!("connection handler exited: {}", err);
}
})
.unwrap();
}
#[allow(unreachable_code)] // hint compiler the closure return type
Ok::<(), anyhow::Error>(())
})
.expect("listener failed")
Err(e) => error!("Failed to accept connection: {}", e),
}
}
}
// Get unique thread id (Rust internal), with ThreadId removed for shorter printing
fn get_tid() -> u64 {
let tids = format!("{:?}", thread::current().id());
let r = Regex::new(r"ThreadId\((\d+)\)").unwrap();
let caps = r.captures(&tids).unwrap();
caps.get(1).unwrap().as_str().parse().unwrap()
}
/// This is run by `thread_main` above, inside a background thread.
///
fn handle_socket(
socket: TcpStream,
conf: SafeKeeperConf,
conn_id: ConnectionId,
) -> Result<(), QueryError> {
let _enter = info_span!("", cid = %conn_id).entered();
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let local = tokio::task::LocalSet::new();
fn handle_socket(socket: TcpStream, conf: SafeKeeperConf) -> Result<(), QueryError> {
let _enter = info_span!("", tid = ?get_tid()).entered();
socket.set_nodelay(true)?;
@@ -73,23 +53,10 @@ fn handle_socket(
None => AuthType::Trust,
Some(_) => AuthType::NeonJWT,
};
let mut conn_handler = SafekeeperPostgresHandler::new(conf, conn_id);
let pgbackend = PostgresBackend::new(socket, auth_type, None)?;
// libpq protocol between safekeeper and walproposer / pageserver
// We don't use shutdown.
local.block_on(
&runtime,
pgbackend.run(&mut conn_handler, future::pending::<()>),
)?;
let mut conn_handler = SafekeeperPostgresHandler::new(conf);
let pgbackend = PostgresBackend::new(socket, auth_type, None, false)?;
// libpq replication protocol between safekeeper and replicas/pagers
pgbackend.run(&mut conn_handler)?;
Ok(())
}
/// Unique WAL service connection ids are logged in spans for observability.
pub type ConnectionId = u32;
pub type ConnectionCount = u32;
pub fn issue_connection_id(count: &mut ConnectionCount) -> ConnectionId {
*count = count.wrapping_add(1);
*count
}

View File

@@ -165,16 +165,6 @@ impl PhysicalStorage {
})
}
/// Get all known state of the storage.
pub fn internal_state(&self) -> (Lsn, Lsn, Lsn, bool) {
(
self.write_lsn,
self.write_record_lsn,
self.flush_record_lsn,
self.file.is_some(),
)
}
/// Call fdatasync if config requires so.
fn fdatasync_file(&mut self, file: &mut File) -> Result<()> {
if !self.conf.no_sync {
@@ -471,7 +461,7 @@ pub struct WalReader {
timeline_dir: PathBuf,
wal_seg_size: usize,
pos: Lsn,
wal_segment: Option<Pin<Box<dyn AsyncRead + Send + Sync>>>,
wal_segment: Option<Pin<Box<dyn AsyncRead>>>,
// S3 will be used to read WAL if LSN is not available locally
enable_remote_read: bool,
@@ -538,7 +528,7 @@ impl WalReader {
}
/// Open WAL segment at the current position of the reader.
async fn open_segment(&self) -> Result<Pin<Box<dyn AsyncRead + Send + Sync>>> {
async fn open_segment(&self) -> Result<Pin<Box<dyn AsyncRead>>> {
let xlogoff = self.pos.segment_offset(self.wal_seg_size);
let segno = self.pos.segment_number(self.wal_seg_size);
let wal_file_name = XLogFileName(PG_TLI, segno, self.wal_seg_size);

View File

@@ -354,26 +354,29 @@ class NeonBenchmarker:
"""
Fetch the "cumulative # of bytes written" metric from the pageserver
"""
return self.get_int_counter_value(
pageserver, "libmetrics_disk_io_bytes_total", {"io_operation": "write"}
)
metric_name = r'libmetrics_disk_io_bytes_total{io_operation="write"}'
return self.get_int_counter_value(pageserver, metric_name)
def get_peak_mem(self, pageserver: NeonPageserver) -> int:
"""
Fetch the "maxrss" metric from the pageserver
"""
return self.get_int_counter_value(pageserver, "libmetrics_maxrss_kb")
metric_name = r"libmetrics_maxrss_kb"
return self.get_int_counter_value(pageserver, metric_name)
def get_int_counter_value(
self,
pageserver: NeonPageserver,
metric_name: str,
label_filters: Optional[Dict[str, str]] = None,
) -> int:
def get_int_counter_value(self, pageserver: NeonPageserver, metric_name: str) -> int:
"""Fetch the value of given int counter from pageserver metrics."""
# TODO: If we start to collect more of the prometheus metrics in the
# performance test suite like this, we should refactor this to load and
# parse all the metrics into a more convenient structure in one go.
#
# The metric should be an integer, as it's a number of bytes. But in general
# all prometheus metrics are floats. So to be pedantic, read it as a float
# and round to integer.
all_metrics = pageserver.http_client().get_metrics()
sample = all_metrics.query_one(metric_name, label_filters)
return int(round(sample.value))
matches = re.search(rf"^{metric_name} (\S+)$", all_metrics, re.MULTILINE)
assert matches, f"metric {metric_name} not found"
return int(round(float(matches.group(1))))
def get_timeline_size(
self, repo_dir: Path, tenant_id: TenantId, timeline_id: TimelineId

View File

@@ -144,12 +144,12 @@ class NeonCompare(PgCompare):
"size", timeline_size / (1024 * 1024), "MB", report=MetricReport.LOWER_IS_BETTER
)
metric_filters = {"tenant_id": str(self.tenant), "timeline_id": str(self.timeline)}
params = f'{{tenant_id="{self.tenant}",timeline_id="{self.timeline}"}}'
total_files = self.zenbenchmark.get_int_counter_value(
self.env.pageserver, "pageserver_created_persistent_files_total", metric_filters
self.env.pageserver, "pageserver_created_persistent_files_total" + params
)
total_bytes = self.zenbenchmark.get_int_counter_value(
self.env.pageserver, "pageserver_written_persistent_bytes_total", metric_filters
self.env.pageserver, "pageserver_written_persistent_bytes_total" + params
)
self.zenbenchmark.record(
"data_uploaded", total_bytes / (1024 * 1024), "MB", report=MetricReport.LOWER_IS_BETTER

View File

@@ -13,8 +13,7 @@ class Metrics:
self.metrics = defaultdict(list)
self.name = name
def query_all(self, name: str, filter: Optional[Dict[str, str]] = None) -> List[Sample]:
filter = filter or {}
def query_all(self, name: str, filter: Dict[str, str]) -> List[Sample]:
res = []
for sample in self.metrics[name]:
try:

View File

@@ -14,7 +14,6 @@ import tempfile
import textwrap
import time
import uuid
from collections import defaultdict
from contextlib import closing, contextmanager
from dataclasses import dataclass, field
from enum import Flag, auto
@@ -29,6 +28,7 @@ import asyncpg
import backoff # type: ignore
import boto3
import jwt
import prometheus_client
import psycopg2
import pytest
import requests
@@ -36,7 +36,7 @@ from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
from fixtures.log_helper import log
from fixtures.metrics import Metrics, parse_metrics
from fixtures.metrics import parse_metrics
from fixtures.types import Lsn, TenantId, TimelineId
from fixtures.utils import (
ATTACHMENT_NAME_REGEX,
@@ -45,6 +45,7 @@ from fixtures.utils import (
get_self_dir,
subprocess_capture,
)
from prometheus_client.parser import text_string_to_metric_families
# Type-related stuff
from psycopg2.extensions import connection as PgConnection
@@ -1435,27 +1436,22 @@ class PageserverHttpClient(requests.Session):
assert completed["successful_download_count"] > 0
return completed
def get_metrics_str(self) -> str:
"""You probably want to use get_metrics() instead."""
def get_metrics(self) -> str:
res = self.get(f"http://localhost:{self.port}/metrics")
self.verbose_error(res)
return res.text
def get_metrics(self) -> Metrics:
res = self.get_metrics_str()
return parse_metrics(res)
def get_timeline_metric(
self, tenant_id: TenantId, timeline_id: TimelineId, metric_name: str
) -> float:
metrics = self.get_metrics()
return metrics.query_one(
metric_name,
filter={
"tenant_id": str(tenant_id),
"timeline_id": str(timeline_id),
},
).value
def get_timeline_metric(self, tenant_id: TenantId, timeline_id: TimelineId, metric_name: str):
raw = self.get_metrics()
family: List[prometheus_client.Metric] = list(text_string_to_metric_families(raw))
[metric] = [m for m in family if m.name == metric_name]
[sample] = [
s
for s in metric.samples
if s.labels["tenant_id"] == str(tenant_id)
and s.labels["timeline_id"] == str(timeline_id)
]
return sample.value
def get_remote_timeline_client_metric(
self,
@@ -1465,7 +1461,7 @@ class PageserverHttpClient(requests.Session):
file_kind: str,
op_kind: str,
) -> Optional[float]:
metrics = self.get_metrics()
metrics = parse_metrics(self.get_metrics(), "pageserver")
matches = metrics.query_all(
name=metric_name,
filter={
@@ -1484,16 +1480,14 @@ class PageserverHttpClient(requests.Session):
assert len(matches) < 2, "above filter should uniquely identify metric"
return value
def get_metric_value(
self, name: str, filter: Optional[Dict[str, str]] = None
) -> Optional[float]:
def get_metric_value(self, name: str) -> Optional[str]:
metrics = self.get_metrics()
results = metrics.query_all(name, filter=filter)
if not results:
relevant = [line for line in metrics.splitlines() if line.startswith(name)]
if len(relevant) == 0:
log.info(f'could not find metric "{name}"')
return None
assert len(results) == 1, f"metric {name} with given filters is not unique, got: {results}"
return results[0].value
assert len(relevant) == 1
return relevant[0].lstrip(name).strip()
def layer_map_info(
self,
@@ -1522,11 +1516,6 @@ class PageserverHttpClient(requests.Session):
assert res.status_code == 200
def evict_all_layers(self, tenant_id: TenantId, timeline_id: TimelineId):
info = self.layer_map_info(tenant_id, timeline_id)
for layer in info.historic_layers:
self.evict_layer(tenant_id, timeline_id, layer.layer_file_name)
@dataclass
class TenantConfig:
@@ -1562,14 +1551,6 @@ class LayerMapInfo:
return info
def kind_count(self) -> Dict[str, int]:
counts: Dict[str, int] = defaultdict(int)
for inmem_layer in self.in_memory_layers:
counts[inmem_layer.kind] += 1
for hist_layer in self.historic_layers:
counts[hist_layer.kind] += 1
return counts
@dataclass
class InMemoryLayerInfo:
@@ -1586,7 +1567,7 @@ class InMemoryLayerInfo:
)
@dataclass(frozen=True)
@dataclass
class HistoricLayerInfo:
kind: str
layer_file_name: str
@@ -2068,10 +2049,8 @@ class NeonPageserver(PgProtocol):
".*Connection aborted: connection error: error communicating with the server: Broken pipe.*",
".*Connection aborted: connection error: error communicating with the server: Transport endpoint is not connected.*",
".*Connection aborted: connection error: error communicating with the server: Connection reset by peer.*",
# FIXME: replication patch for tokio_postgres regards any but CopyDone/CopyData message in CopyBoth stream as unexpected
".*Connection aborted: connection error: unexpected message from server*",
".*kill_and_wait_impl.*: wait successful.*",
".*Replication stream finished: db error:.*ending streaming to Some*",
".*Replication stream finished: db error: ERROR: Socket IO error: end streaming to Some.*",
".*query handler for 'pagestream.*failed: Broken pipe.*", # pageserver notices compute shut down
".*query handler for 'pagestream.*failed: Connection reset by peer.*", # pageserver notices compute shut down
# safekeeper connection can fail with this, in the window between timeline creation
@@ -3009,13 +2988,6 @@ class SafekeeperHttpClient(requests.Session):
def check_status(self):
self.get(f"http://localhost:{self.port}/v1/status").raise_for_status()
def debug_dump(self, params: Dict[str, str] = {}) -> Dict[str, Any]:
res = self.get(f"http://localhost:{self.port}/v1/debug_dump", params=params)
res.raise_for_status()
res_json = res.json()
assert isinstance(res_json, dict)
return res_json
def timeline_create(
self, tenant_id: TenantId, timeline_id: TimelineId, pg_version: int, commit_lsn: Lsn
):
@@ -3544,23 +3516,3 @@ def wait_for_sk_commit_lsn_to_reach_remote_storage(
ps_http.timeline_checkpoint(tenant_id, timeline_id)
wait_for_upload(ps_http, tenant_id, timeline_id, lsn)
return lsn
def wait_for_upload_queue_empty(
pageserver: NeonPageserver, tenant_id: TenantId, timeline_id: TimelineId
):
ps_http = pageserver.http_client()
while True:
all_metrics = ps_http.get_metrics()
tl = all_metrics.query_all(
"pageserver_remote_timeline_client_calls_unfinished",
{
"tenant_id": str(tenant_id),
"timeline_id": str(timeline_id),
},
)
assert len(tl) > 0
log.info(f"upload queue for {tenant_id}/{timeline_id}: {tl}")
if all(m.value == 0 for m in tl):
return
time.sleep(0.2)

View File

@@ -1,58 +0,0 @@
from contextlib import closing
import pytest
from fixtures.compare_fixtures import NeonCompare
from fixtures.neon_fixtures import wait_for_last_flush_lsn
#
# Test compaction and image layer creation performance.
#
# This creates a few tables and runs some simple INSERTs and UPDATEs on them to generate
# some delta layers. Then it runs manual compaction, measuring how long it takes.
#
@pytest.mark.timeout(1000)
def test_compaction(neon_compare: NeonCompare):
env = neon_compare.env
pageserver_http = env.pageserver.http_client()
tenant_id, timeline_id = env.neon_cli.create_tenant(
conf={
# Disable background GC and compaction, we'll run compaction manually.
"gc_period": "0s",
"compaction_period": "0s",
# Make checkpoint distance somewhat smaller than default, to create
# more delta layers quicker, to trigger compaction.
"checkpoint_distance": "25000000", # 25 MB
# Force image layer creation when we run compaction.
"image_creation_threshold": "1",
}
)
neon_compare.tenant = tenant_id
neon_compare.timeline = timeline_id
# Create some tables, and run a bunch of INSERTs and UPDATes on them,
# to generate WAL and layers
pg = env.postgres.create_start(
"main", tenant_id=tenant_id, config_lines=["shared_buffers=512MB"]
)
with closing(pg.connect()) as conn:
with conn.cursor() as cur:
for i in range(100):
cur.execute(f"create table tbl{i} (i int, j int);")
cur.execute(f"insert into tbl{i} values (generate_series(1, 1000), 0);")
for j in range(100):
cur.execute(f"update tbl{i} set j = {j};")
wait_for_last_flush_lsn(env, pg, tenant_id, timeline_id)
# First compaction generates L1 layers
with neon_compare.zenbenchmark.record_duration("compaction"):
pageserver_http.timeline_compact(tenant_id, timeline_id)
# And second compaction triggers image layer creation
with neon_compare.zenbenchmark.record_duration("image_creation"):
pageserver_http.timeline_compact(tenant_id, timeline_id)
neon_compare.report_size()

View File

@@ -8,7 +8,7 @@ def test_build_info_metric(neon_env_builder: NeonEnvBuilder, link_proxy: NeonPro
parsed_metrics = {}
parsed_metrics["pageserver"] = parse_metrics(env.pageserver.http_client().get_metrics_str())
parsed_metrics["pageserver"] = parse_metrics(env.pageserver.http_client().get_metrics())
parsed_metrics["safekeeper"] = parse_metrics(env.safekeepers[0].http_client().get_metrics_str())
parsed_metrics["proxy"] = parse_metrics(link_proxy.get_metrics())

View File

@@ -220,12 +220,9 @@ def prepare_snapshot(
for tenant in (repo_dir / "pgdatadirs" / "tenants").glob("*"):
shutil.rmtree(tenant)
# Remove wal-redo temp directory if it exists. Newer pageserver versions don't create
# them anymore, but old versions did.
# Remove wal-redo temp directory
for tenant in (repo_dir / "tenants").glob("*"):
wal_redo_dir = tenant / "wal-redo-datadir.___temp"
if wal_redo_dir.exists() and wal_redo_dir.is_dir():
shutil.rmtree(wal_redo_dir)
shutil.rmtree(tenant / "wal-redo-datadir.___temp")
# Update paths and ports in config files
pageserver_toml = repo_dir / "pageserver.toml"

View File

@@ -4,6 +4,7 @@ import random
import pytest
from fixtures.log_helper import log
from fixtures.metrics import parse_metrics
from fixtures.neon_fixtures import (
NeonEnv,
NeonEnvBuilder,
@@ -133,7 +134,7 @@ def test_gc_index_upload(neon_env_builder: NeonEnvBuilder, remote_storage_kind:
# Helper function that gets the number of given kind of remote ops from the metrics
def get_num_remote_ops(file_kind: str, op_kind: str) -> int:
ps_metrics = env.pageserver.http_client().get_metrics()
ps_metrics = parse_metrics(env.pageserver.http_client().get_metrics(), "pageserver")
total = 0.0
for sample in ps_metrics.query_all(
name="pageserver_remote_operation_seconds_count",

View File

@@ -1,13 +1,8 @@
import time
import pytest
from fixtures.log_helper import log
from fixtures.neon_fixtures import (
NeonEnvBuilder,
RemoteStorageKind,
wait_for_last_flush_lsn,
wait_for_last_record_lsn,
wait_for_sk_commit_lsn_to_reach_remote_storage,
wait_for_upload,
)
from fixtures.types import Lsn, TenantId, TimelineId
@@ -143,160 +138,3 @@ def test_basic_eviction(
assert (
redownloaded_layer_map_info == initial_layer_map_info
), "Should have the same layer map after redownloading the evicted layers"
def test_gc_of_remote_layers(neon_env_builder: NeonEnvBuilder):
neon_env_builder.enable_remote_storage(
remote_storage_kind=RemoteStorageKind.LOCAL_FS,
test_name="test_gc_of_remote_layers",
)
env = neon_env_builder.init_start()
tenant_config = {
"pitr_interval": "1s", # set to non-zero, so GC actually does something
"gc_period": "0s", # we want to control when GC runs
"compaction_period": "0s", # we want to control when compaction runs
"checkpoint_timeout": "24h", # something we won't reach
"checkpoint_distance": f"{50 * (1024**2)}", # something we won't reach, we checkpoint manually
"compaction_threshold": "3",
# "image_creation_threshold": set at runtime
"compaction_target_size": f"{128 * (1024**2)}", # make it so that we only have 1 partition => image coverage for delta layers => enables gc of delta layers
}
def tenant_update_config(changes):
tenant_config.update(changes)
env.neon_cli.config_tenant(tenant_id, tenant_config)
tenant_id, timeline_id = env.neon_cli.create_tenant(conf=tenant_config)
log.info("tenant id is %s", tenant_id)
env.initial_tenant = tenant_id # update_and_gc relies on this
ps_http = env.pageserver.http_client()
pg = env.postgres.create_start("main")
log.info("fill with data, creating delta & image layers, some of which are GC'able after")
# no particular reason to create the layers like this, but we are sure
# not to hit the image_creation_threshold here.
with pg.cursor() as cur:
cur.execute("create table a (id bigserial primary key, some_value bigint not null)")
cur.execute("insert into a(some_value) select i from generate_series(1, 10000) s(i)")
wait_for_last_flush_lsn(env, pg, tenant_id, timeline_id)
ps_http.timeline_checkpoint(tenant_id, timeline_id)
# Create delta layers, then turn them into image layers.
# Do it multiple times so that there's something to GC.
for k in range(0, 2):
# produce delta layers => disable image layer creation by setting high threshold
tenant_update_config({"image_creation_threshold": "100"})
for i in range(0, 2):
for j in range(0, 3):
# create a minimal amount of "delta difficulty" for this table
with pg.cursor() as cur:
cur.execute("update a set some_value = -some_value + %s", (j,))
with pg.cursor() as cur:
# vacuuming should aid to reuse keys, though it's not really important
# with image_creation_threshold=1 which we will use on the last compaction
cur.execute("vacuum")
wait_for_last_flush_lsn(env, pg, tenant_id, timeline_id)
if i == 1 and j == 2 and k == 1:
# last iteration; stop before checkpoint to avoid leaving an inmemory layer
pg.stop_and_destroy()
ps_http.timeline_checkpoint(tenant_id, timeline_id)
# images should not yet be created, because threshold is too high,
# but these will be reshuffled to L1 layers
ps_http.timeline_compact(tenant_id, timeline_id)
for _ in range(0, 20):
# loop in case flushing is still in progress
layers = ps_http.layer_map_info(tenant_id, timeline_id)
if not layers.in_memory_layers:
break
time.sleep(0.2)
# now that we've grown some delta layers, turn them into image layers
tenant_update_config({"image_creation_threshold": "1"})
ps_http.timeline_compact(tenant_id, timeline_id)
# wait for all uploads to finish
wait_for_sk_commit_lsn_to_reach_remote_storage(
tenant_id, timeline_id, env.safekeepers, env.pageserver
)
# shutdown safekeepers to avoid on-demand downloads from walreceiver
for sk in env.safekeepers:
sk.stop()
ps_http.timeline_checkpoint(tenant_id, timeline_id)
log.info("ensure the code above produced image and delta layers")
pre_evict_info = ps_http.layer_map_info(tenant_id, timeline_id)
log.info("layer map dump: %s", pre_evict_info)
by_kind = pre_evict_info.kind_count()
log.info("by kind: %s", by_kind)
assert by_kind["Image"] > 0
assert by_kind["Delta"] > 0
assert by_kind["InMemory"] == 0
resident_layers = list(env.timeline_dir(tenant_id, timeline_id).glob("*-*_*"))
log.info("resident layers count before eviction: %s", len(resident_layers))
log.info("evict all layers")
ps_http.evict_all_layers(tenant_id, timeline_id)
def ensure_resident_and_remote_size_metrics():
log.info("ensure that all the layers are gone")
resident_layers = list(env.timeline_dir(tenant_id, timeline_id).glob("*-*_*"))
# we have disabled all background loops, so, this should hold
assert len(resident_layers) == 0
info = ps_http.layer_map_info(tenant_id, timeline_id)
log.info("layer map dump: %s", info)
log.info("ensure that resident_physical_size metric is zero")
resident_physical_size_metric = ps_http.get_timeline_metric(
tenant_id, timeline_id, "pageserver_resident_physical_size"
)
assert resident_physical_size_metric == 0
log.info("ensure that resident_physical_size metric corresponds to layer map dump")
assert resident_physical_size_metric == sum(
[layer.layer_file_size or 0 for layer in info.historic_layers if not layer.remote]
)
log.info("ensure that remote_physical_size metric matches layer map")
remote_physical_size_metric = ps_http.get_timeline_metric(
tenant_id, timeline_id, "pageserver_remote_physical_size"
)
log.info("ensure that remote_physical_size metric corresponds to layer map dump")
assert remote_physical_size_metric == sum(
layer.layer_file_size or 0 for layer in info.historic_layers if layer.remote
)
log.info("before runnning GC, ensure that remote_physical size is zero")
ensure_resident_and_remote_size_metrics()
log.info("run GC")
time.sleep(2) # let pitr_interval + 1 second pass
ps_http.timeline_gc(tenant_id, timeline_id, 0)
time.sleep(1)
assert not env.pageserver.log_contains("Nothing to GC")
log.info("ensure GC deleted some layers, otherwise this test is pointless")
post_gc_info = ps_http.layer_map_info(tenant_id, timeline_id)
log.info("layer map dump: %s", post_gc_info)
log.info("by kind: %s", post_gc_info.kind_count())
pre_evict_layers = set([layer.layer_file_name for layer in pre_evict_info.historic_layers])
post_gc_layers = set([layer.layer_file_name for layer in post_gc_info.historic_layers])
assert post_gc_layers.issubset(pre_evict_layers)
assert len(post_gc_layers) < len(pre_evict_layers)
log.info("update_gc_info might download some layers. Evict them again.")
ps_http.evict_all_layers(tenant_id, timeline_id)
log.info("after running GC, ensure that resident size is still zero")
ensure_resident_and_remote_size_metrics()

View File

@@ -9,6 +9,7 @@ from typing import Iterator
import pytest
from fixtures.log_helper import log
from fixtures.metrics import parse_metrics
from fixtures.neon_fixtures import (
PSQL,
NeonEnvBuilder,
@@ -142,7 +143,7 @@ def test_metric_collection(
# Helper function that gets the number of given kind of remote ops from the metrics
def get_num_remote_ops(file_kind: str, op_kind: str) -> int:
ps_metrics = env.pageserver.http_client().get_metrics()
ps_metrics = parse_metrics(env.pageserver.http_client().get_metrics(), "pageserver")
total = 0.0
for sample in ps_metrics.query_all(
name="pageserver_remote_operation_seconds_count",

View File

@@ -4,14 +4,13 @@
import time
from collections import defaultdict
from pathlib import Path
from typing import Any, DefaultDict, Dict, Tuple
from typing import Any, DefaultDict, Dict
import pytest
from fixtures.log_helper import log
from fixtures.neon_fixtures import (
NeonEnvBuilder,
PageserverApiException,
PageserverHttpClient,
RemoteStorageKind,
assert_tenant_status,
available_remote_storages,
@@ -26,16 +25,9 @@ from fixtures.types import Lsn
from fixtures.utils import query_scalar
def get_num_downloaded_layers(client: PageserverHttpClient, tenant_id, timeline_id):
def get_num_downloaded_layers(client, tenant_id, timeline_id):
value = client.get_metric_value(
"pageserver_remote_operation_seconds_count",
{
"file_kind": "layer",
"op_kind": "download",
"status": "success",
"tenant_id": tenant_id,
"timeline_id": timeline_id,
},
f'pageserver_remote_operation_seconds_count{{file_kind="layer",op_kind="download",status="success",tenant_id="{tenant_id}",timeline_id="{timeline_id}"}}'
)
if value is None:
return 0
@@ -497,17 +489,6 @@ def test_compaction_downloads_on_demand_without_image_creation(
# pitr_interval and gc_horizon are not interesting because we dont run gc
}
def downloaded_bytes_and_count(pageserver_http: PageserverHttpClient) -> Tuple[int, int]:
m = pageserver_http.get_metrics()
# these are global counters
total_bytes = m.query_one("pageserver_remote_ondemand_downloaded_bytes_total").value
assert (
total_bytes < 2**53 and total_bytes.is_integer()
), "bytes should still be safe integer-in-f64"
count = m.query_one("pageserver_remote_ondemand_downloaded_layers_total").value
assert count < 2**53 and count.is_integer(), "count should still be safe integer-in-f64"
return (int(total_bytes), int(count))
# Override defaults, to create more layers
tenant_id, timeline_id = env.neon_cli.create_tenant(conf=stringify(conf))
env.initial_tenant = tenant_id
@@ -528,14 +509,10 @@ def test_compaction_downloads_on_demand_without_image_creation(
layers = pageserver_http.layer_map_info(tenant_id, timeline_id)
assert not layers.in_memory_layers, "no inmemory layers expected after post-commit checkpoint"
assert len(layers.historic_layers) == 1 + 2, "should have initdb layer and 2 deltas"
layer_sizes = 0
assert len(layers.historic_layers) == 1 + 2, "should have inidb layer and 2 deltas"
for layer in layers.historic_layers:
log.info(f"pre-compact: {layer}")
assert layer.layer_file_size is not None, "we must know layer file sizes"
layer_sizes += layer.layer_file_size
pageserver_http.evict_layer(tenant_id, timeline_id, layer.layer_file_name)
env.neon_cli.config_tenant(tenant_id, {"compaction_threshold": "3"})
@@ -546,12 +523,6 @@ def test_compaction_downloads_on_demand_without_image_creation(
log.info(f"post compact: {layer}")
assert len(layers.historic_layers) == 1, "should have compacted to single layer"
post_compact = downloaded_bytes_and_count(pageserver_http)
# use gte to allow pageserver to do other random stuff; this test could be run on a shared pageserver
assert post_compact[0] >= layer_sizes
assert post_compact[1] >= 3, "should had downloaded the three layers"
@pytest.mark.parametrize("remote_storage_kind", [RemoteStorageKind.MOCK_S3])
def test_compaction_downloads_on_demand_with_image_creation(

View File

@@ -45,6 +45,14 @@ def test_pageserver_restart(neon_env_builder: NeonEnvBuilder):
env.pageserver.stop()
env.pageserver.start()
# Stopping the pageserver breaks the connection from the postgres backend to
# the page server, and causes the next query on the connection to fail. Start a new
# postgres connection too, to avoid that error. (Ideally, the compute node would
# handle that and retry internally, without propagating the error to the user, but
# currently it doesn't...)
pg_conn = pg.connect()
cur = pg_conn.cursor()
cur.execute("SELECT count(*) FROM foo")
assert cur.fetchone() == (100000,)
@@ -62,6 +70,8 @@ def test_pageserver_restart(neon_env_builder: NeonEnvBuilder):
assert tenant_status["state"] == "Loading"
# Try to read. This waits until the loading finishes, and then return normally.
pg_conn = pg.connect()
cur = pg_conn.cursor()
cur.execute("SELECT count(*) FROM foo")
assert cur.fetchone() == (100000,)
@@ -122,6 +132,14 @@ def test_pageserver_chaos(neon_env_builder: NeonEnvBuilder):
env.pageserver.stop(immediate=True)
env.pageserver.start()
# Stopping the pageserver breaks the connection from the postgres backend to
# the page server, and causes the next query on the connection to fail. Start a new
# postgres connection too, to avoid that error. (Ideally, the compute node would
# handle that and retry internally, without propagating the error to the user, but
# currently it doesn't...)
pg_conn = pg.connect()
cur = pg_conn.cursor()
# Check that all the updates are visible
num_updates = pg.safe_psql("SELECT sum(updates) FROM foo")[0][0]
assert num_updates == i * 100000

View File

@@ -1,35 +0,0 @@
# This test spawns pgbench in a thread in the background and concurrently restarts pageserver,
# checking how client is able to transparently restore connection to pageserver
#
import threading
import time
from fixtures.log_helper import log
from fixtures.neon_fixtures import NeonEnv, PgBin, Postgres
# Test restarting page server, while safekeeper and compute node keep
# running.
def test_pageserver_restarts_under_worload(neon_simple_env: NeonEnv, pg_bin: PgBin):
env = neon_simple_env
env.neon_cli.create_branch("test_pageserver_restarts")
pg = env.postgres.create_start("test_pageserver_restarts")
n_restarts = 10
scale = 10
def run_pgbench(pg: Postgres):
connstr = pg.connstr()
log.info(f"Start a pgbench workload on pg {connstr}")
pg_bin.run_capture(["pgbench", "-i", f"-s{scale}", connstr])
pg_bin.run_capture(["pgbench", f"-T{n_restarts}", connstr])
thread = threading.Thread(target=run_pgbench, args=(pg,), daemon=True)
thread.start()
for i in range(n_restarts):
# Stop the pageserver gracefully and restart it.
time.sleep(1)
env.pageserver.stop()
env.pageserver.start()
thread.join()

View File

@@ -233,8 +233,8 @@ def test_remote_storage_upload_queue_retries(
# disable background compaction and GC. We invoke it manually when we want it to happen.
"gc_period": "0s",
"compaction_period": "0s",
# create image layers eagerly, so that GC can remove some layers
"image_creation_threshold": "1",
# don't create image layers, that causes just noise
"image_creation_threshold": "10000",
}
)
@@ -301,7 +301,7 @@ def test_remote_storage_upload_queue_retries(
# Create more churn to generate all upload ops.
# The checkpoint / compact / gc ops will block because they call remote_client.wait_completion().
# So, run this in a different thread.
# So, run this in a differen thread.
churn_thread_result = [False]
def churn_while_failpoints_active(result):
@@ -395,8 +395,8 @@ def test_remote_timeline_client_calls_started_metric(
# disable background compaction and GC. We invoke it manually when we want it to happen.
"gc_period": "0s",
"compaction_period": "0s",
# create image layers eagerly, so that GC can remove some layers
"image_creation_threshold": "1",
# don't create image layers, that causes just noise
"image_creation_threshold": "10000",
}
)
@@ -618,9 +618,6 @@ def test_timeline_deletion_with_files_stuck_in_upload_queue(
# checkpoint operations. Hence, checkpoint is allowed to fail now.
log.info("sending delete request")
checkpoint_allowed_to_fail.set()
env.pageserver.allowed_errors.append(
".* ERROR .*Error processing HTTP request: InternalServerError\\(timeline is Stopping"
)
client.timeline_delete(tenant_id, timeline_id)
assert not timeline_path.exists()

Some files were not shown because too many files have changed in this diff Show More