Merge remote-tracking branch 'origin/main' into problame/batching-sidecar-task

This commit is contained in:
Christian Schwarz
2024-11-29 14:26:37 +01:00
175 changed files with 14270 additions and 851 deletions

View File

@@ -46,6 +46,9 @@ workspace-members = [
"utils",
"wal_craft",
"walproposer",
"postgres-protocol2",
"postgres-types2",
"tokio-postgres2",
]
# Write out exact versions rather than a semver range. (Defaults to false.)

View File

@@ -36,8 +36,8 @@ inputs:
description: 'Region name for real s3 tests'
required: false
default: ''
rerun_flaky:
description: 'Whether to rerun flaky tests'
rerun_failed:
description: 'Whether to rerun failed tests'
required: false
default: 'false'
pg_version:
@@ -108,7 +108,7 @@ runs:
COMPATIBILITY_SNAPSHOT_DIR: /tmp/compatibility_snapshot_pg${{ inputs.pg_version }}
ALLOW_BACKWARD_COMPATIBILITY_BREAKAGE: contains(github.event.pull_request.labels.*.name, 'backward compatibility breakage')
ALLOW_FORWARD_COMPATIBILITY_BREAKAGE: contains(github.event.pull_request.labels.*.name, 'forward compatibility breakage')
RERUN_FLAKY: ${{ inputs.rerun_flaky }}
RERUN_FAILED: ${{ inputs.rerun_failed }}
PG_VERSION: ${{ inputs.pg_version }}
shell: bash -euxo pipefail {0}
run: |
@@ -154,15 +154,8 @@ runs:
EXTRA_PARAMS="--out-dir $PERF_REPORT_DIR $EXTRA_PARAMS"
fi
if [ "${RERUN_FLAKY}" == "true" ]; then
mkdir -p $TEST_OUTPUT
poetry run ./scripts/flaky_tests.py "${TEST_RESULT_CONNSTR}" \
--days 7 \
--output "$TEST_OUTPUT/flaky.json" \
--pg-version "${DEFAULT_PG_VERSION}" \
--build-type "${BUILD_TYPE}"
EXTRA_PARAMS="--flaky-tests-json $TEST_OUTPUT/flaky.json $EXTRA_PARAMS"
if [ "${RERUN_FAILED}" == "true" ]; then
EXTRA_PARAMS="--reruns 2 $EXTRA_PARAMS"
fi
# We use pytest-split plugin to run benchmarks in parallel on different CI runners

View File

@@ -19,8 +19,8 @@ on:
description: 'debug or release'
required: true
type: string
pg-versions:
description: 'a json array of postgres versions to run regression tests on'
test-cfg:
description: 'a json object of postgres versions and lfc states to run regression tests on'
required: true
type: string
@@ -276,14 +276,14 @@ jobs:
options: --init --shm-size=512mb --ulimit memlock=67108864:67108864
strategy:
fail-fast: false
matrix:
pg_version: ${{ fromJson(inputs.pg-versions) }}
matrix: ${{ fromJSON(format('{{"include":{0}}}', inputs.test-cfg)) }}
steps:
- uses: actions/checkout@v4
with:
submodules: true
- name: Pytest regression tests
continue-on-error: ${{ matrix.lfc_state == 'with-lfc' }}
uses: ./.github/actions/run-python-test-set
timeout-minutes: 60
with:
@@ -293,13 +293,14 @@ jobs:
run_with_real_s3: true
real_s3_bucket: neon-github-ci-tests
real_s3_region: eu-central-1
rerun_flaky: true
rerun_failed: true
pg_version: ${{ matrix.pg_version }}
env:
TEST_RESULT_CONNSTR: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }}
CHECK_ONDISK_DATA_COMPATIBILITY: nonempty
BUILD_TAG: ${{ inputs.build-tag }}
PAGESERVER_VIRTUAL_FILE_IO_ENGINE: tokio-epoll-uring
USE_LFC: ${{ matrix.lfc_state == 'with-lfc' && 'true' || 'false' }}
# Temporary disable this step until we figure out why it's so flaky
# Ref https://github.com/neondatabase/neon/issues/4540

View File

@@ -541,7 +541,7 @@ jobs:
runs-on: ${{ matrix.RUNNER }}
container:
image: neondatabase/build-tools:pinned
image: neondatabase/build-tools:pinned-bookworm
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
@@ -558,12 +558,12 @@ jobs:
arch=$(uname -m | sed 's/x86_64/amd64/g' | sed 's/aarch64/arm64/g')
cd /home/nonroot
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-17/libpq5_17.2-1.pgdg110+1_${arch}.deb"
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-client-16_16.6-1.pgdg110+1_${arch}.deb"
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-16_16.6-1.pgdg110+1_${arch}.deb"
dpkg -x libpq5_17.2-1.pgdg110+1_${arch}.deb pg
dpkg -x postgresql-16_16.6-1.pgdg110+1_${arch}.deb pg
dpkg -x postgresql-client-16_16.6-1.pgdg110+1_${arch}.deb pg
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-17/libpq5_17.2-1.pgdg120+1_${arch}.deb"
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-client-16_16.6-1.pgdg120+1_${arch}.deb"
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-16_16.6-1.pgdg120+1_${arch}.deb"
dpkg -x libpq5_17.2-1.pgdg120+1_${arch}.deb pg
dpkg -x postgresql-16_16.6-1.pgdg120+1_${arch}.deb pg
dpkg -x postgresql-client-16_16.6-1.pgdg120+1_${arch}.deb pg
mkdir -p /tmp/neon/pg_install/v16/bin
ln -s /home/nonroot/pg/usr/lib/postgresql/16/bin/pgbench /tmp/neon/pg_install/v16/bin/pgbench

View File

@@ -2,6 +2,17 @@ name: Build build-tools image
on:
workflow_call:
inputs:
archs:
description: "Json array of architectures to build"
# Default values are set in `check-image` job, `set-variables` step
type: string
required: false
debians:
description: "Json array of Debian versions to build"
# Default values are set in `check-image` job, `set-variables` step
type: string
required: false
outputs:
image-tag:
description: "build-tools tag"
@@ -32,25 +43,37 @@ jobs:
check-image:
runs-on: ubuntu-22.04
outputs:
tag: ${{ steps.get-build-tools-tag.outputs.image-tag }}
found: ${{ steps.check-image.outputs.found }}
archs: ${{ steps.set-variables.outputs.archs }}
debians: ${{ steps.set-variables.outputs.debians }}
tag: ${{ steps.set-variables.outputs.image-tag }}
everything: ${{ steps.set-more-variables.outputs.everything }}
found: ${{ steps.set-more-variables.outputs.found }}
steps:
- uses: actions/checkout@v4
- name: Get build-tools image tag for the current commit
id: get-build-tools-tag
- name: Set variables
id: set-variables
env:
ARCHS: ${{ inputs.archs || '["x64","arm64"]' }}
DEBIANS: ${{ inputs.debians || '["bullseye","bookworm"]' }}
IMAGE_TAG: |
${{ hashFiles('build-tools.Dockerfile',
'.github/workflows/build-build-tools-image.yml') }}
run: |
echo "image-tag=${IMAGE_TAG}" | tee -a $GITHUB_OUTPUT
echo "archs=${ARCHS}" | tee -a ${GITHUB_OUTPUT}
echo "debians=${DEBIANS}" | tee -a ${GITHUB_OUTPUT}
echo "image-tag=${IMAGE_TAG}" | tee -a ${GITHUB_OUTPUT}
- name: Check if such tag found in the registry
id: check-image
- name: Set more variables
id: set-more-variables
env:
IMAGE_TAG: ${{ steps.get-build-tools-tag.outputs.image-tag }}
IMAGE_TAG: ${{ steps.set-variables.outputs.image-tag }}
EVERYTHING: |
${{ contains(fromJson(steps.set-variables.outputs.archs), 'x64') &&
contains(fromJson(steps.set-variables.outputs.archs), 'arm64') &&
contains(fromJson(steps.set-variables.outputs.debians), 'bullseye') &&
contains(fromJson(steps.set-variables.outputs.debians), 'bookworm') }}
run: |
if docker manifest inspect neondatabase/build-tools:${IMAGE_TAG}; then
found=true
@@ -58,8 +81,8 @@ jobs:
found=false
fi
echo "found=${found}" | tee -a $GITHUB_OUTPUT
echo "everything=${EVERYTHING}" | tee -a ${GITHUB_OUTPUT}
echo "found=${found}" | tee -a ${GITHUB_OUTPUT}
build-image:
needs: [ check-image ]
@@ -67,8 +90,8 @@ jobs:
strategy:
matrix:
debian-version: [ bullseye, bookworm ]
arch: [ x64, arm64 ]
arch: ${{ fromJson(needs.check-image.outputs.archs) }}
debian: ${{ fromJson(needs.check-image.outputs.debians) }}
runs-on: ${{ fromJson(format('["self-hosted", "{0}"]', matrix.arch == 'arm64' && 'large-arm64' || 'large')) }}
@@ -99,11 +122,11 @@ jobs:
push: true
pull: true
build-args: |
DEBIAN_VERSION=${{ matrix.debian-version }}
cache-from: type=registry,ref=cache.neon.build/build-tools:cache-${{ matrix.debian-version }}-${{ matrix.arch }}
cache-to: ${{ github.ref_name == 'main' && format('type=registry,ref=cache.neon.build/build-tools:cache-{0}-{1},mode=max', matrix.debian-version, matrix.arch) || '' }}
DEBIAN_VERSION=${{ matrix.debian }}
cache-from: type=registry,ref=cache.neon.build/build-tools:cache-${{ matrix.debian }}-${{ matrix.arch }}
cache-to: ${{ github.ref_name == 'main' && format('type=registry,ref=cache.neon.build/build-tools:cache-{0}-{1},mode=max', matrix.debian, matrix.arch) || '' }}
tags: |
neondatabase/build-tools:${{ needs.check-image.outputs.tag }}-${{ matrix.debian-version }}-${{ matrix.arch }}
neondatabase/build-tools:${{ needs.check-image.outputs.tag }}-${{ matrix.debian }}-${{ matrix.arch }}
merge-images:
needs: [ check-image, build-image ]
@@ -117,16 +140,22 @@ jobs:
- name: Create multi-arch image
env:
DEFAULT_DEBIAN_VERSION: bullseye
DEFAULT_DEBIAN_VERSION: bookworm
ARCHS: ${{ join(fromJson(needs.check-image.outputs.archs), ' ') }}
DEBIANS: ${{ join(fromJson(needs.check-image.outputs.debians), ' ') }}
EVERYTHING: ${{ needs.check-image.outputs.everything }}
IMAGE_TAG: ${{ needs.check-image.outputs.tag }}
run: |
for debian_version in bullseye bookworm; do
tags=("-t" "neondatabase/build-tools:${IMAGE_TAG}-${debian_version}")
if [ "${debian_version}" == "${DEFAULT_DEBIAN_VERSION}" ]; then
for debian in ${DEBIANS}; do
tags=("-t" "neondatabase/build-tools:${IMAGE_TAG}-${debian}")
if [ "${EVERYTHING}" == "true" ] && [ "${debian}" == "${DEFAULT_DEBIAN_VERSION}" ]; then
tags+=("-t" "neondatabase/build-tools:${IMAGE_TAG}")
fi
docker buildx imagetools create "${tags[@]}" \
neondatabase/build-tools:${IMAGE_TAG}-${debian_version}-x64 \
neondatabase/build-tools:${IMAGE_TAG}-${debian_version}-arm64
for arch in ${ARCHS}; do
tags+=("neondatabase/build-tools:${IMAGE_TAG}-${debian}-${arch}")
done
docker buildx imagetools create "${tags[@]}"
done

View File

@@ -253,7 +253,14 @@ jobs:
build-tag: ${{ needs.tag.outputs.build-tag }}
build-type: ${{ matrix.build-type }}
# Run tests on all Postgres versions in release builds and only on the latest version in debug builds
pg-versions: ${{ matrix.build-type == 'release' && '["v14", "v15", "v16", "v17"]' || '["v17"]' }}
# run without LFC on v17 release only
test-cfg: |
${{ matrix.build-type == 'release' && '[{"pg_version":"v14", "lfc_state": "without-lfc"},
{"pg_version":"v15", "lfc_state": "without-lfc"},
{"pg_version":"v16", "lfc_state": "without-lfc"},
{"pg_version":"v17", "lfc_state": "without-lfc"},
{"pg_version":"v17", "lfc_state": "with-lfc"}]'
|| '[{"pg_version":"v17", "lfc_state": "without-lfc"}]' }}
secrets: inherit
# Keep `benchmarks` job outside of `build-and-test-locally` workflow to make job failures non-blocking

View File

@@ -29,7 +29,7 @@ jobs:
trigger_bench_on_ec2_machine_in_eu_central_1:
runs-on: [ self-hosted, small ]
container:
image: neondatabase/build-tools:pinned
image: neondatabase/build-tools:pinned-bookworm
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}

View File

@@ -94,7 +94,7 @@ jobs:
- name: Tag build-tools with `${{ env.TO_TAG }}` in Docker Hub, ECR, and ACR
env:
DEFAULT_DEBIAN_VERSION: bullseye
DEFAULT_DEBIAN_VERSION: bookworm
run: |
for debian_version in bullseye bookworm; do
tags=()

View File

@@ -23,6 +23,8 @@ jobs:
id: python-src
with:
files: |
.github/workflows/_check-codestyle-python.yml
.github/workflows/build-build-tools-image.yml
.github/workflows/pre-merge-checks.yml
**/**.py
poetry.lock
@@ -38,6 +40,10 @@ jobs:
if: needs.get-changed-files.outputs.python-changed == 'true'
needs: [ get-changed-files ]
uses: ./.github/workflows/build-build-tools-image.yml
with:
# Build only one combination to save time
archs: '["x64"]'
debians: '["bookworm"]'
secrets: inherit
check-codestyle-python:
@@ -45,7 +51,8 @@ jobs:
needs: [ get-changed-files, build-build-tools-image ]
uses: ./.github/workflows/_check-codestyle-python.yml
with:
build-tools-image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm
# `-bookworm-x64` suffix should match the combination in `build-build-tools-image`
build-tools-image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm-x64
secrets: inherit
# To get items from the merge queue merged into main we need to satisfy "Status checks that are required".

121
Cargo.lock generated
View File

@@ -2180,9 +2180,9 @@ dependencies = [
[[package]]
name = "futures-channel"
version = "0.3.30"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78"
checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10"
dependencies = [
"futures-core",
"futures-sink",
@@ -2190,9 +2190,9 @@ dependencies = [
[[package]]
name = "futures-core"
version = "0.3.30"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d"
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
[[package]]
name = "futures-executor"
@@ -2207,9 +2207,9 @@ dependencies = [
[[package]]
name = "futures-io"
version = "0.3.30"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"
checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
[[package]]
name = "futures-lite"
@@ -2228,9 +2228,9 @@ dependencies = [
[[package]]
name = "futures-macro"
version = "0.3.30"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
@@ -2239,15 +2239,15 @@ dependencies = [
[[package]]
name = "futures-sink"
version = "0.3.30"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5"
checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7"
[[package]]
name = "futures-task"
version = "0.3.30"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-timer"
@@ -2257,9 +2257,9 @@ checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c"
[[package]]
name = "futures-util"
version = "0.3.30"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [
"futures-channel",
"futures-core",
@@ -4139,7 +4139,7 @@ dependencies = [
[[package]]
name = "postgres"
version = "0.19.4"
source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#a130197713830a0ea0004b539b1f51a66b4c3e18"
source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#00940fcdb57a8e99e805297b75839e7c4c7b1796"
dependencies = [
"bytes",
"fallible-iterator",
@@ -4152,7 +4152,7 @@ dependencies = [
[[package]]
name = "postgres-protocol"
version = "0.6.4"
source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#2a2a7c56930dd5ad60676ce6da92e1cbe6fb3ef5"
source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#00940fcdb57a8e99e805297b75839e7c4c7b1796"
dependencies = [
"base64 0.20.0",
"byteorder",
@@ -4168,16 +4168,40 @@ dependencies = [
"tokio",
]
[[package]]
name = "postgres-protocol2"
version = "0.1.0"
dependencies = [
"base64 0.20.0",
"byteorder",
"bytes",
"fallible-iterator",
"hmac",
"md-5",
"memchr",
"rand 0.8.5",
"sha2",
"stringprep",
"tokio",
]
[[package]]
name = "postgres-types"
version = "0.2.4"
source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#2a2a7c56930dd5ad60676ce6da92e1cbe6fb3ef5"
source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#00940fcdb57a8e99e805297b75839e7c4c7b1796"
dependencies = [
"bytes",
"fallible-iterator",
"postgres-protocol",
"serde",
"serde_json",
]
[[package]]
name = "postgres-types2"
version = "0.1.0"
dependencies = [
"bytes",
"fallible-iterator",
"postgres-protocol2",
]
[[package]]
@@ -4188,7 +4212,7 @@ dependencies = [
"bytes",
"once_cell",
"pq_proto",
"rustls 0.23.16",
"rustls 0.23.18",
"rustls-pemfile 2.1.1",
"serde",
"thiserror",
@@ -4507,7 +4531,7 @@ dependencies = [
"parquet_derive",
"pbkdf2",
"pin-project-lite",
"postgres-protocol",
"postgres-protocol2",
"postgres_backend",
"pq_proto",
"prometheus",
@@ -4524,7 +4548,7 @@ dependencies = [
"rsa",
"rstest",
"rustc-hash",
"rustls 0.23.16",
"rustls 0.23.18",
"rustls-native-certs 0.8.0",
"rustls-pemfile 2.1.1",
"scopeguard",
@@ -4542,8 +4566,7 @@ dependencies = [
"tikv-jemalloc-ctl",
"tikv-jemallocator",
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-postgres2",
"tokio-rustls 0.26.0",
"tokio-tungstenite",
"tokio-util",
@@ -5237,9 +5260,9 @@ dependencies = [
[[package]]
name = "rustls"
version = "0.23.16"
version = "0.23.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eee87ff5d9b36712a58574e12e9f0ea80f915a5b0ac518d322b24a465617925e"
checksum = "9c9cc1d47e243d655ace55ed38201c19ae02c148ae56412ab8750e8f0166ab7f"
dependencies = [
"log",
"once_cell",
@@ -5370,6 +5393,7 @@ dependencies = [
"itertools 0.10.5",
"metrics",
"once_cell",
"pageserver_api",
"parking_lot 0.12.1",
"postgres",
"postgres-protocol",
@@ -5401,6 +5425,7 @@ dependencies = [
"tracing-subscriber",
"url",
"utils",
"wal_decoder",
"walproposer",
"workspace_hack",
]
@@ -5954,7 +5979,7 @@ dependencies = [
"once_cell",
"parking_lot 0.12.1",
"prost",
"rustls 0.23.16",
"rustls 0.23.18",
"tokio",
"tonic",
"tonic-build",
@@ -6037,7 +6062,7 @@ dependencies = [
"postgres_ffi",
"remote_storage",
"reqwest 0.12.4",
"rustls 0.23.16",
"rustls 0.23.18",
"rustls-native-certs 0.8.0",
"serde",
"serde_json",
@@ -6425,6 +6450,7 @@ dependencies = [
"libc",
"mio",
"num_cpus",
"parking_lot 0.12.1",
"pin-project-lite",
"signal-hook-registry",
"socket2",
@@ -6472,7 +6498,7 @@ dependencies = [
[[package]]
name = "tokio-postgres"
version = "0.7.7"
source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#2a2a7c56930dd5ad60676ce6da92e1cbe6fb3ef5"
source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#00940fcdb57a8e99e805297b75839e7c4c7b1796"
dependencies = [
"async-trait",
"byteorder",
@@ -6499,13 +6525,33 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04fb792ccd6bbcd4bba408eb8a292f70fc4a3589e5d793626f45190e6454b6ab"
dependencies = [
"ring",
"rustls 0.23.16",
"rustls 0.23.18",
"tokio",
"tokio-postgres",
"tokio-rustls 0.26.0",
"x509-certificate",
]
[[package]]
name = "tokio-postgres2"
version = "0.1.0"
dependencies = [
"async-trait",
"byteorder",
"bytes",
"fallible-iterator",
"futures-util",
"log",
"parking_lot 0.12.1",
"percent-encoding",
"phf",
"pin-project-lite",
"postgres-protocol2",
"postgres-types2",
"tokio",
"tokio-util",
]
[[package]]
name = "tokio-rustls"
version = "0.24.0"
@@ -6533,7 +6579,7 @@ version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4"
dependencies = [
"rustls 0.23.16",
"rustls 0.23.18",
"rustls-pki-types",
"tokio",
]
@@ -6942,7 +6988,7 @@ dependencies = [
"base64 0.22.1",
"log",
"once_cell",
"rustls 0.23.16",
"rustls 0.23.18",
"rustls-pki-types",
"url",
"webpki-roots 0.26.1",
@@ -7028,6 +7074,7 @@ dependencies = [
"serde_assert",
"serde_json",
"serde_path_to_error",
"serde_with",
"signal-hook",
"strum",
"strum_macros",
@@ -7124,10 +7171,16 @@ name = "wal_decoder"
version = "0.1.0"
dependencies = [
"anyhow",
"async-compression",
"bytes",
"pageserver_api",
"postgres_ffi",
"prost",
"serde",
"thiserror",
"tokio",
"tonic",
"tonic-build",
"tracing",
"utils",
"workspace_hack",
@@ -7595,7 +7648,6 @@ dependencies = [
"num-traits",
"once_cell",
"parquet",
"postgres-types",
"prettyplease",
"proc-macro2",
"prost",
@@ -7605,7 +7657,7 @@ dependencies = [
"regex-automata 0.4.3",
"regex-syntax 0.8.2",
"reqwest 0.12.4",
"rustls 0.23.16",
"rustls 0.23.18",
"scopeguard",
"serde",
"serde_json",
@@ -7620,7 +7672,6 @@ dependencies = [
"time",
"time-macros",
"tokio",
"tokio-postgres",
"tokio-rustls 0.26.0",
"tokio-stream",
"tokio-util",

View File

@@ -35,6 +35,9 @@ members = [
"libs/walproposer",
"libs/wal_decoder",
"libs/postgres_initdb",
"libs/proxy/postgres-protocol2",
"libs/proxy/postgres-types2",
"libs/proxy/tokio-postgres2",
]
[workspace.package]

View File

@@ -7,7 +7,7 @@ ARG IMAGE=build-tools
ARG TAG=pinned
ARG DEFAULT_PG_VERSION=17
ARG STABLE_PG_VERSION=16
ARG DEBIAN_VERSION=bullseye
ARG DEBIAN_VERSION=bookworm
ARG DEBIAN_FLAVOR=${DEBIAN_VERSION}-slim
# Build Postgres

View File

@@ -38,6 +38,7 @@ ifeq ($(UNAME_S),Linux)
# Seccomp BPF is only available for Linux
PG_CONFIGURE_OPTS += --with-libseccomp
else ifeq ($(UNAME_S),Darwin)
PG_CFLAGS += -DUSE_PREFETCH
ifndef DISABLE_HOMEBREW
# macOS with brew-installed openssl requires explicit paths
# It can be configured with OPENSSL_PREFIX variable
@@ -146,6 +147,8 @@ postgres-%: postgres-configure-% \
$(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pg_prewarm install
+@echo "Compiling pg_buffercache $*"
$(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pg_buffercache install
+@echo "Compiling pg_visibility $*"
$(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pg_visibility install
+@echo "Compiling pageinspect $*"
$(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pageinspect install
+@echo "Compiling amcheck $*"

View File

@@ -1,4 +1,4 @@
ARG DEBIAN_VERSION=bullseye
ARG DEBIAN_VERSION=bookworm
FROM debian:bookworm-slim AS pgcopydb_builder
ARG DEBIAN_VERSION
@@ -57,9 +57,9 @@ RUN mkdir -p /pgcopydb/bin && \
mkdir -p /pgcopydb/lib && \
chmod -R 755 /pgcopydb && \
chown -R nonroot:nonroot /pgcopydb
COPY --from=pgcopydb_builder /usr/lib/postgresql/16/bin/pgcopydb /pgcopydb/bin/pgcopydb
COPY --from=pgcopydb_builder /pgcopydb/lib/libpq.so.5 /pgcopydb/lib/libpq.so.5
COPY --from=pgcopydb_builder /usr/lib/postgresql/16/bin/pgcopydb /pgcopydb/bin/pgcopydb
COPY --from=pgcopydb_builder /pgcopydb/lib/libpq.so.5 /pgcopydb/lib/libpq.so.5
# System deps
#
@@ -258,14 +258,14 @@ WORKDIR /home/nonroot
# Rust
# Please keep the version of llvm (installed above) in sync with rust llvm (`rustc --version --verbose | grep LLVM`)
ENV RUSTC_VERSION=1.82.0
ENV RUSTC_VERSION=1.83.0
ENV RUSTUP_HOME="/home/nonroot/.rustup"
ENV PATH="/home/nonroot/.cargo/bin:${PATH}"
ARG RUSTFILT_VERSION=0.2.1
ARG CARGO_HAKARI_VERSION=0.9.30
ARG CARGO_DENY_VERSION=0.16.1
ARG CARGO_HACK_VERSION=0.6.31
ARG CARGO_NEXTEST_VERSION=0.9.72
ARG CARGO_HAKARI_VERSION=0.9.33
ARG CARGO_DENY_VERSION=0.16.2
ARG CARGO_HACK_VERSION=0.6.33
ARG CARGO_NEXTEST_VERSION=0.9.85
RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux-gnu/rustup-init && whoami && \
chmod +x rustup-init && \
./rustup-init -y --default-toolchain ${RUSTC_VERSION} && \
@@ -289,7 +289,7 @@ RUN whoami \
&& cargo --version --verbose \
&& rustup --version --verbose \
&& rustc --version --verbose \
&& clang --version
&& clang --version
RUN if [ "${DEBIAN_VERSION}" = "bookworm" ]; then \
LD_LIBRARY_PATH=/pgcopydb/lib /pgcopydb/bin/pgcopydb --version; \

View File

@@ -3,7 +3,7 @@ ARG REPOSITORY=neondatabase
ARG IMAGE=build-tools
ARG TAG=pinned
ARG BUILD_TAG
ARG DEBIAN_VERSION=bullseye
ARG DEBIAN_VERSION=bookworm
ARG DEBIAN_FLAVOR=${DEBIAN_VERSION}-slim
#########################################################################################

View File

@@ -58,7 +58,7 @@ use compute_tools::compute::{
forward_termination_signal, ComputeNode, ComputeState, ParsedSpec, PG_PID,
};
use compute_tools::configurator::launch_configurator;
use compute_tools::extension_server::get_pg_version;
use compute_tools::extension_server::get_pg_version_string;
use compute_tools::http::api::launch_http_server;
use compute_tools::logger::*;
use compute_tools::monitor::launch_monitor;
@@ -326,7 +326,7 @@ fn wait_spec(
connstr: Url::parse(connstr).context("cannot parse connstr as a URL")?,
pgdata: pgdata.to_string(),
pgbin: pgbin.to_string(),
pgversion: get_pg_version(pgbin),
pgversion: get_pg_version_string(pgbin),
live_config_allowed,
state: Mutex::new(new_state),
state_changed: Condvar::new(),

View File

@@ -29,6 +29,7 @@ use anyhow::Context;
use aws_config::BehaviorVersion;
use camino::{Utf8Path, Utf8PathBuf};
use clap::Parser;
use compute_tools::extension_server::{get_pg_version, PostgresMajorVersion};
use nix::unistd::Pid;
use tracing::{info, info_span, warn, Instrument};
use utils::fs_ext::is_directory_empty;
@@ -131,11 +132,17 @@ pub(crate) async fn main() -> anyhow::Result<()> {
//
// Initialize pgdata
//
let pg_version = match get_pg_version(pg_bin_dir.as_str()) {
PostgresMajorVersion::V14 => 14,
PostgresMajorVersion::V15 => 15,
PostgresMajorVersion::V16 => 16,
PostgresMajorVersion::V17 => 17,
};
let superuser = "cloud_admin"; // XXX: this shouldn't be hard-coded
postgres_initdb::do_run_initdb(postgres_initdb::RunInitdbArgs {
superuser,
locale: "en_US.UTF-8", // XXX: this shouldn't be hard-coded,
pg_version: 140000, // XXX: this shouldn't be hard-coded but derived from which compute image we're running in
pg_version,
initdb_bin: pg_bin_dir.join("initdb").as_ref(),
library_search_path: &pg_lib_dir, // TODO: is this right? Prob works in compute image, not sure about neon_local.
pgdata: &pgdata_dir,

View File

@@ -1,4 +1,3 @@
use compute_api::responses::CatalogObjects;
use futures::Stream;
use postgres::NoTls;
use std::{path::Path, process::Stdio, result::Result, sync::Arc};
@@ -13,7 +12,8 @@ use tokio_util::codec::{BytesCodec, FramedRead};
use tracing::warn;
use crate::compute::ComputeNode;
use crate::pg_helpers::{get_existing_dbs_async, get_existing_roles_async};
use crate::pg_helpers::{get_existing_dbs_async, get_existing_roles_async, postgres_conf_for_db};
use compute_api::responses::CatalogObjects;
pub async fn get_dbs_and_roles(compute: &Arc<ComputeNode>) -> anyhow::Result<CatalogObjects> {
let connstr = compute.connstr.clone();
@@ -43,6 +43,8 @@ pub enum SchemaDumpError {
DatabaseDoesNotExist,
#[error("Failed to execute pg_dump.")]
IO(#[from] std::io::Error),
#[error("Unexpected error.")]
Unexpected,
}
// It uses the pg_dump utility to dump the schema of the specified database.
@@ -60,11 +62,38 @@ pub async fn get_database_schema(
let pgbin = &compute.pgbin;
let basepath = Path::new(pgbin).parent().unwrap();
let pgdump = basepath.join("pg_dump");
let mut connstr = compute.connstr.clone();
connstr.set_path(dbname);
// Replace the DB in the connection string and disable it to parts.
// This is the only option to handle DBs with special characters.
let conf =
postgres_conf_for_db(&compute.connstr, dbname).map_err(|_| SchemaDumpError::Unexpected)?;
let host = conf
.get_hosts()
.first()
.ok_or(SchemaDumpError::Unexpected)?;
let host = match host {
tokio_postgres::config::Host::Tcp(ip) => ip.to_string(),
#[cfg(unix)]
tokio_postgres::config::Host::Unix(path) => path.to_string_lossy().to_string(),
};
let port = conf
.get_ports()
.first()
.ok_or(SchemaDumpError::Unexpected)?;
let user = conf.get_user().ok_or(SchemaDumpError::Unexpected)?;
let dbname = conf.get_dbname().ok_or(SchemaDumpError::Unexpected)?;
let mut cmd = Command::new(pgdump)
// XXX: this seems to be the only option to deal with DBs with `=` in the name
// See <https://www.postgresql.org/message-id/flat/20151023003445.931.91267%40wrigleys.postgresql.org>
.env("PGDATABASE", dbname)
.arg("--host")
.arg(host)
.arg("--port")
.arg(port.to_string())
.arg("--username")
.arg(user)
.arg("--schema-only")
.arg(connstr.as_str())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.kill_on_drop(true)

View File

@@ -34,9 +34,8 @@ use utils::measured_stream::MeasuredReader;
use nix::sys::signal::{kill, Signal};
use remote_storage::{DownloadError, RemotePath};
use tokio::spawn;
use url::Url;
use crate::installed_extensions::get_installed_extensions_sync;
use crate::installed_extensions::get_installed_extensions;
use crate::local_proxy;
use crate::pg_helpers::*;
use crate::spec::*;
@@ -816,30 +815,32 @@ impl ComputeNode {
Ok(())
}
async fn get_maintenance_client(url: &Url) -> Result<tokio_postgres::Client> {
let mut connstr = url.clone();
async fn get_maintenance_client(
conf: &tokio_postgres::Config,
) -> Result<tokio_postgres::Client> {
let mut conf = conf.clone();
connstr
.query_pairs_mut()
.append_pair("application_name", "apply_config");
conf.application_name("apply_config");
let (client, conn) = match tokio_postgres::connect(connstr.as_str(), NoTls).await {
let (client, conn) = match conf.connect(NoTls).await {
// If connection fails, it may be the old node with `zenith_admin` superuser.
//
// In this case we need to connect with old `zenith_admin` name
// and create new user. We cannot simply rename connected user,
// but we can create a new one and grant it all privileges.
Err(e) => match e.code() {
Some(&SqlState::INVALID_PASSWORD)
| Some(&SqlState::INVALID_AUTHORIZATION_SPECIFICATION) => {
// connect with zenith_admin if cloud_admin could not authenticate
// Connect with zenith_admin if cloud_admin could not authenticate
info!(
"cannot connect to postgres: {}, retrying with `zenith_admin` username",
e
);
let mut zenith_admin_connstr = connstr.clone();
zenith_admin_connstr
.set_username("zenith_admin")
.map_err(|_| anyhow::anyhow!("invalid connstr"))?;
let mut zenith_admin_conf = postgres::config::Config::from(conf.clone());
zenith_admin_conf.user("zenith_admin");
let mut client =
Client::connect(zenith_admin_connstr.as_str(), NoTls)
zenith_admin_conf.connect(NoTls)
.context("broken cloud_admin credential: tried connecting with cloud_admin but could not authenticate, and zenith_admin does not work either")?;
// Disable forwarding so that users don't get a cloud_admin role
@@ -853,8 +854,8 @@ impl ComputeNode {
drop(client);
// reconnect with connstring with expected name
tokio_postgres::connect(connstr.as_str(), NoTls).await?
// Reconnect with connstring with expected name
conf.connect(NoTls).await?
}
_ => return Err(e.into()),
},
@@ -885,7 +886,7 @@ impl ComputeNode {
pub fn apply_spec_sql(
&self,
spec: Arc<ComputeSpec>,
url: Arc<Url>,
conf: Arc<tokio_postgres::Config>,
concurrency: usize,
) -> Result<()> {
let rt = tokio::runtime::Builder::new_multi_thread()
@@ -897,7 +898,7 @@ impl ComputeNode {
rt.block_on(async {
// Proceed with post-startup configuration. Note, that order of operations is important.
let client = Self::get_maintenance_client(&url).await?;
let client = Self::get_maintenance_client(&conf).await?;
let spec = spec.clone();
let databases = get_existing_dbs_async(&client).await?;
@@ -931,7 +932,7 @@ impl ComputeNode {
RenameAndDeleteDatabases,
CreateAndAlterDatabases,
] {
debug!("Applying phase {:?}", &phase);
info!("Applying phase {:?}", &phase);
apply_operations(
spec.clone(),
ctx.clone(),
@@ -942,6 +943,7 @@ impl ComputeNode {
.await?;
}
info!("Applying RunInEachDatabase phase");
let concurrency_token = Arc::new(tokio::sync::Semaphore::new(concurrency));
let db_processes = spec
@@ -955,7 +957,7 @@ impl ComputeNode {
let spec = spec.clone();
let ctx = ctx.clone();
let jwks_roles = jwks_roles.clone();
let mut url = url.as_ref().clone();
let mut conf = conf.as_ref().clone();
let concurrency_token = concurrency_token.clone();
let db = db.clone();
@@ -964,14 +966,14 @@ impl ComputeNode {
match &db {
DB::SystemDB => {}
DB::UserDB(db) => {
url.set_path(db.name.as_str());
conf.dbname(db.name.as_str());
}
}
let url = Arc::new(url);
let conf = Arc::new(conf);
let fut = Self::apply_spec_sql_db(
spec.clone(),
url,
conf,
ctx.clone(),
jwks_roles.clone(),
concurrency_token.clone(),
@@ -1017,7 +1019,7 @@ impl ComputeNode {
/// semaphore. The caller has to make sure the semaphore isn't exhausted.
async fn apply_spec_sql_db(
spec: Arc<ComputeSpec>,
url: Arc<Url>,
conf: Arc<tokio_postgres::Config>,
ctx: Arc<tokio::sync::RwLock<MutableApplyContext>>,
jwks_roles: Arc<HashSet<String>>,
concurrency_token: Arc<tokio::sync::Semaphore>,
@@ -1046,7 +1048,7 @@ impl ComputeNode {
// that database.
|| async {
if client_conn.is_none() {
let db_client = Self::get_maintenance_client(&url).await?;
let db_client = Self::get_maintenance_client(&conf).await?;
client_conn.replace(db_client);
}
let client = client_conn.as_ref().unwrap();
@@ -1061,34 +1063,16 @@ impl ComputeNode {
Ok::<(), anyhow::Error>(())
}
/// Do initial configuration of the already started Postgres.
#[instrument(skip_all)]
pub fn apply_config(&self, compute_state: &ComputeState) -> Result<()> {
// If connection fails,
// it may be the old node with `zenith_admin` superuser.
//
// In this case we need to connect with old `zenith_admin` name
// and create new user. We cannot simply rename connected user,
// but we can create a new one and grant it all privileges.
let mut url = self.connstr.clone();
url.query_pairs_mut()
.append_pair("application_name", "apply_config");
let url = Arc::new(url);
let spec = Arc::new(
compute_state
.pspec
.as_ref()
.expect("spec must be set")
.spec
.clone(),
);
// Choose how many concurrent connections to use for applying the spec changes.
// If the cluster is not currently Running we don't have to deal with user connections,
/// Choose how many concurrent connections to use for applying the spec changes.
pub fn max_service_connections(
&self,
compute_state: &ComputeState,
spec: &ComputeSpec,
) -> usize {
// If the cluster is in Init state we don't have to deal with user connections,
// and can thus use all `max_connections` connection slots. However, that's generally not
// very efficient, so we generally still limit it to a smaller number.
let max_concurrent_connections = if compute_state.status != ComputeStatus::Running {
if compute_state.status == ComputeStatus::Init {
// If the settings contain 'max_connections', use that as template
if let Some(config) = spec.cluster.settings.find("max_connections") {
config.parse::<usize>().ok()
@@ -1144,10 +1128,29 @@ impl ComputeNode {
.map(|val| if val > 1 { val - 1 } else { 1 })
.last()
.unwrap_or(3)
};
}
}
/// Do initial configuration of the already started Postgres.
#[instrument(skip_all)]
pub fn apply_config(&self, compute_state: &ComputeState) -> Result<()> {
let mut conf = tokio_postgres::Config::from_str(self.connstr.as_str()).unwrap();
conf.application_name("apply_config");
let conf = Arc::new(conf);
let spec = Arc::new(
compute_state
.pspec
.as_ref()
.expect("spec must be set")
.spec
.clone(),
);
let max_concurrent_connections = self.max_service_connections(compute_state, &spec);
// Merge-apply spec & changes to PostgreSQL state.
self.apply_spec_sql(spec.clone(), url.clone(), max_concurrent_connections)?;
self.apply_spec_sql(spec.clone(), conf.clone(), max_concurrent_connections)?;
if let Some(ref local_proxy) = &spec.clone().local_proxy_config {
info!("configuring local_proxy");
@@ -1156,12 +1159,11 @@ impl ComputeNode {
// Run migrations separately to not hold up cold starts
thread::spawn(move || {
let mut connstr = url.as_ref().clone();
connstr
.query_pairs_mut()
.append_pair("application_name", "migrations");
let conf = conf.as_ref().clone();
let mut conf = postgres::config::Config::from(conf);
conf.application_name("migrations");
let mut client = Client::connect(connstr.as_str(), NoTls)?;
let mut client = conf.connect(NoTls)?;
handle_migrations(&mut client).context("apply_config handle_migrations")
});
@@ -1222,21 +1224,28 @@ impl ComputeNode {
let pgdata_path = Path::new(&self.pgdata);
let postgresql_conf_path = pgdata_path.join("postgresql.conf");
config::write_postgres_conf(&postgresql_conf_path, &spec, None)?;
// temporarily reset max_cluster_size in config
// TODO(ololobus): We need a concurrency during reconfiguration as well,
// but DB is already running and used by user. We can easily get out of
// `max_connections` limit, and the current code won't handle that.
// let compute_state = self.state.lock().unwrap().clone();
// let max_concurrent_connections = self.max_service_connections(&compute_state, &spec);
let max_concurrent_connections = 1;
// Temporarily reset max_cluster_size in config
// to avoid the possibility of hitting the limit, while we are reconfiguring:
// creating new extensions, roles, etc...
// creating new extensions, roles, etc.
config::with_compute_ctl_tmp_override(pgdata_path, "neon.max_cluster_size=-1", || {
self.pg_reload_conf()?;
if spec.mode == ComputeMode::Primary {
let mut url = self.connstr.clone();
url.query_pairs_mut()
.append_pair("application_name", "apply_config");
let url = Arc::new(url);
let mut conf = tokio_postgres::Config::from_str(self.connstr.as_str()).unwrap();
conf.application_name("apply_config");
let conf = Arc::new(conf);
let spec = Arc::new(spec.clone());
self.apply_spec_sql(spec, url, 1)?;
self.apply_spec_sql(spec, conf, max_concurrent_connections)?;
}
Ok(())
@@ -1362,7 +1371,17 @@ impl ComputeNode {
let connstr = self.connstr.clone();
thread::spawn(move || {
get_installed_extensions_sync(connstr).context("get_installed_extensions")
let res = get_installed_extensions(&connstr);
match res {
Ok(extensions) => {
info!(
"[NEON_EXT_STAT] {}",
serde_json::to_string(&extensions)
.expect("failed to serialize extensions list")
);
}
Err(err) => error!("could not get installed extensions: {err:?}"),
}
});
}

View File

@@ -103,14 +103,33 @@ fn get_pg_config(argument: &str, pgbin: &str) -> String {
.to_string()
}
pub fn get_pg_version(pgbin: &str) -> String {
pub fn get_pg_version(pgbin: &str) -> PostgresMajorVersion {
// pg_config --version returns a (platform specific) human readable string
// such as "PostgreSQL 15.4". We parse this to v14/v15/v16 etc.
let human_version = get_pg_config("--version", pgbin);
parse_pg_version(&human_version).to_string()
parse_pg_version(&human_version)
}
fn parse_pg_version(human_version: &str) -> &str {
pub fn get_pg_version_string(pgbin: &str) -> String {
match get_pg_version(pgbin) {
PostgresMajorVersion::V14 => "v14",
PostgresMajorVersion::V15 => "v15",
PostgresMajorVersion::V16 => "v16",
PostgresMajorVersion::V17 => "v17",
}
.to_owned()
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum PostgresMajorVersion {
V14,
V15,
V16,
V17,
}
fn parse_pg_version(human_version: &str) -> PostgresMajorVersion {
use PostgresMajorVersion::*;
// Normal releases have version strings like "PostgreSQL 15.4". But there
// are also pre-release versions like "PostgreSQL 17devel" or "PostgreSQL
// 16beta2" or "PostgreSQL 17rc1". And with the --with-extra-version
@@ -121,10 +140,10 @@ fn parse_pg_version(human_version: &str) -> &str {
.captures(human_version)
{
Some(captures) if captures.len() == 2 => match &captures["major"] {
"14" => return "v14",
"15" => return "v15",
"16" => return "v16",
"17" => return "v17",
"14" => return V14,
"15" => return V15,
"16" => return V16,
"17" => return V17,
_ => {}
},
_ => {}
@@ -263,24 +282,25 @@ mod tests {
#[test]
fn test_parse_pg_version() {
assert_eq!(parse_pg_version("PostgreSQL 15.4"), "v15");
assert_eq!(parse_pg_version("PostgreSQL 15.14"), "v15");
use super::PostgresMajorVersion::*;
assert_eq!(parse_pg_version("PostgreSQL 15.4"), V15);
assert_eq!(parse_pg_version("PostgreSQL 15.14"), V15);
assert_eq!(
parse_pg_version("PostgreSQL 15.4 (Ubuntu 15.4-0ubuntu0.23.04.1)"),
"v15"
V15
);
assert_eq!(parse_pg_version("PostgreSQL 14.15"), "v14");
assert_eq!(parse_pg_version("PostgreSQL 14.0"), "v14");
assert_eq!(parse_pg_version("PostgreSQL 14.15"), V14);
assert_eq!(parse_pg_version("PostgreSQL 14.0"), V14);
assert_eq!(
parse_pg_version("PostgreSQL 14.9 (Debian 14.9-1.pgdg120+1"),
"v14"
V14
);
assert_eq!(parse_pg_version("PostgreSQL 16devel"), "v16");
assert_eq!(parse_pg_version("PostgreSQL 16beta1"), "v16");
assert_eq!(parse_pg_version("PostgreSQL 16rc2"), "v16");
assert_eq!(parse_pg_version("PostgreSQL 16extra"), "v16");
assert_eq!(parse_pg_version("PostgreSQL 16devel"), V16);
assert_eq!(parse_pg_version("PostgreSQL 16beta1"), V16);
assert_eq!(parse_pg_version("PostgreSQL 16rc2"), V16);
assert_eq!(parse_pg_version("PostgreSQL 16extra"), V16);
}
#[test]

View File

@@ -296,7 +296,12 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
}
let connstr = compute.connstr.clone();
let res = crate::installed_extensions::get_installed_extensions(connstr).await;
let res = task::spawn_blocking(move || {
installed_extensions::get_installed_extensions(&connstr)
})
.await
.unwrap();
match res {
Ok(res) => render_json(Body::from(serde_json::to_string(&res).unwrap())),
Err(e) => render_json_error(

View File

@@ -2,17 +2,16 @@ use compute_api::responses::{InstalledExtension, InstalledExtensions};
use metrics::proto::MetricFamily;
use std::collections::HashMap;
use std::collections::HashSet;
use tracing::info;
use url::Url;
use anyhow::Result;
use postgres::{Client, NoTls};
use tokio::task;
use metrics::core::Collector;
use metrics::{register_uint_gauge_vec, UIntGaugeVec};
use once_cell::sync::Lazy;
use crate::pg_helpers::postgres_conf_for_db;
/// We don't reuse get_existing_dbs() just for code clarity
/// and to make database listing query here more explicit.
///
@@ -42,75 +41,51 @@ fn list_dbs(client: &mut Client) -> Result<Vec<String>> {
///
/// Same extension can be installed in multiple databases with different versions,
/// we only keep the highest and lowest version across all databases.
pub async fn get_installed_extensions(connstr: Url) -> Result<InstalledExtensions> {
let mut connstr = connstr.clone();
pub fn get_installed_extensions(connstr: &url::Url) -> Result<InstalledExtensions> {
let mut client = Client::connect(connstr.as_str(), NoTls)?;
let databases: Vec<String> = list_dbs(&mut client)?;
task::spawn_blocking(move || {
let mut client = Client::connect(connstr.as_str(), NoTls)?;
let databases: Vec<String> = list_dbs(&mut client)?;
let mut extensions_map: HashMap<String, InstalledExtension> = HashMap::new();
for db in databases.iter() {
let config = postgres_conf_for_db(connstr, db)?;
let mut db_client = config.connect(NoTls)?;
let extensions: Vec<(String, String)> = db_client
.query(
"SELECT extname, extversion FROM pg_catalog.pg_extension;",
&[],
)?
.iter()
.map(|row| (row.get("extname"), row.get("extversion")))
.collect();
let mut extensions_map: HashMap<String, InstalledExtension> = HashMap::new();
for db in databases.iter() {
connstr.set_path(db);
let mut db_client = Client::connect(connstr.as_str(), NoTls)?;
let extensions: Vec<(String, String)> = db_client
.query(
"SELECT extname, extversion FROM pg_catalog.pg_extension;",
&[],
)?
.iter()
.map(|row| (row.get("extname"), row.get("extversion")))
.collect();
for (extname, v) in extensions.iter() {
let version = v.to_string();
for (extname, v) in extensions.iter() {
let version = v.to_string();
// increment the number of databases where the version of extension is installed
INSTALLED_EXTENSIONS
.with_label_values(&[extname, &version])
.inc();
// increment the number of databases where the version of extension is installed
INSTALLED_EXTENSIONS
.with_label_values(&[extname, &version])
.inc();
extensions_map
.entry(extname.to_string())
.and_modify(|e| {
e.versions.insert(version.clone());
// count the number of databases where the extension is installed
e.n_databases += 1;
})
.or_insert(InstalledExtension {
extname: extname.to_string(),
versions: HashSet::from([version.clone()]),
n_databases: 1,
});
}
extensions_map
.entry(extname.to_string())
.and_modify(|e| {
e.versions.insert(version.clone());
// count the number of databases where the extension is installed
e.n_databases += 1;
})
.or_insert(InstalledExtension {
extname: extname.to_string(),
versions: HashSet::from([version.clone()]),
n_databases: 1,
});
}
}
let res = InstalledExtensions {
extensions: extensions_map.values().cloned().collect(),
};
let res = InstalledExtensions {
extensions: extensions_map.values().cloned().collect(),
};
Ok(res)
})
.await?
}
// Gather info about installed extensions
pub fn get_installed_extensions_sync(connstr: Url) -> Result<()> {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("failed to create runtime");
let result = rt
.block_on(crate::installed_extensions::get_installed_extensions(
connstr,
))
.expect("failed to get installed extensions");
info!(
"[NEON_EXT_STAT] {}",
serde_json::to_string(&result).expect("failed to serialize extensions list")
);
Ok(())
Ok(res)
}
static INSTALLED_EXTENSIONS: Lazy<UIntGaugeVec> = Lazy::new(|| {

View File

@@ -6,6 +6,7 @@ use std::io::{BufRead, BufReader};
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use std::process::Child;
use std::str::FromStr;
use std::thread::JoinHandle;
use std::time::{Duration, Instant};
@@ -13,8 +14,10 @@ use anyhow::{bail, Result};
use futures::StreamExt;
use ini::Ini;
use notify::{RecursiveMode, Watcher};
use postgres::config::Config;
use tokio::io::AsyncBufReadExt;
use tokio::time::timeout;
use tokio_postgres;
use tokio_postgres::NoTls;
use tracing::{debug, error, info, instrument};
@@ -542,3 +545,11 @@ async fn handle_postgres_logs_async(stderr: tokio::process::ChildStderr) -> Resu
Ok(())
}
/// `Postgres::config::Config` handles database names with whitespaces
/// and special characters properly.
pub fn postgres_conf_for_db(connstr: &url::Url, dbname: &str) -> Result<Config> {
let mut conf = Config::from_str(connstr.as_str())?;
conf.dbname(dbname);
Ok(conf)
}

View File

@@ -415,6 +415,11 @@ impl PageServerNode {
.map(|x| x.parse::<bool>())
.transpose()
.context("Failed to parse 'timeline_offloading' as bool")?,
wal_receiver_protocol_override: settings
.remove("wal_receiver_protocol_override")
.map(serde_json::from_str)
.transpose()
.context("parse `wal_receiver_protocol_override` from json")?,
};
if !settings.is_empty() {
bail!("Unrecognized tenant settings: {settings:?}")

View File

@@ -33,7 +33,6 @@ reason = "the marvin attack only affects private key decryption, not public key
[licenses]
allow = [
"Apache-2.0",
"Artistic-2.0",
"BSD-2-Clause",
"BSD-3-Clause",
"CC0-1.0",
@@ -67,7 +66,7 @@ registries = []
# More documentation about the 'bans' section can be found here:
# https://embarkstudios.github.io/cargo-deny/checks/bans/cfg.html
[bans]
multiple-versions = "warn"
multiple-versions = "allow"
wildcards = "allow"
highlight = "all"
workspace-default-features = "allow"

View File

@@ -18,7 +18,7 @@ use std::{
str::FromStr,
time::Duration,
};
use utils::logging::LogFormat;
use utils::{logging::LogFormat, postgres_client::PostgresClientProtocol};
use crate::models::ImageCompressionAlgorithm;
use crate::models::LsnLease;
@@ -118,6 +118,7 @@ pub struct ConfigToml {
pub virtual_file_io_mode: Option<crate::models::virtual_file::IoMode>,
#[serde(skip_serializing_if = "Option::is_none")]
pub no_sync: Option<bool>,
pub wal_receiver_protocol: PostgresClientProtocol,
pub page_service_pipelining: PageServicePipeliningConfig,
}
@@ -298,6 +299,8 @@ pub struct TenantConfigToml {
/// Enable auto-offloading of timelines.
/// (either this flag or the pageserver-global one need to be set)
pub timeline_offloading: bool,
pub wal_receiver_protocol_override: Option<PostgresClientProtocol>,
}
pub mod defaults {
@@ -349,6 +352,9 @@ pub mod defaults {
pub const DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB: usize = 0;
pub const DEFAULT_IO_BUFFER_ALIGNMENT: usize = 512;
pub const DEFAULT_WAL_RECEIVER_PROTOCOL: utils::postgres_client::PostgresClientProtocol =
utils::postgres_client::PostgresClientProtocol::Vanilla;
}
impl Default for ConfigToml {
@@ -435,6 +441,7 @@ impl Default for ConfigToml {
virtual_file_io_mode: None,
tenant_config: TenantConfigToml::default(),
no_sync: None,
wal_receiver_protocol: DEFAULT_WAL_RECEIVER_PROTOCOL,
page_service_pipelining: PageServicePipeliningConfig::Pipelined(
PageServicePipeliningConfigPipelined {
max_batch_size: NonZeroUsize::new(32).unwrap(),
@@ -528,6 +535,7 @@ impl Default for TenantConfigToml {
lsn_lease_length: LsnLease::DEFAULT_LENGTH,
lsn_lease_length_for_ts: LsnLease::DEFAULT_LENGTH_FOR_TS,
timeline_offloading: false,
wal_receiver_protocol_override: None,
}
}
}

View File

@@ -229,6 +229,18 @@ impl Key {
}
}
impl CompactKey {
pub fn raw(&self) -> i128 {
self.0
}
}
impl From<i128> for CompactKey {
fn from(value: i128) -> Self {
Self(value)
}
}
impl fmt::Display for Key {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(

View File

@@ -23,6 +23,7 @@ use utils::{
completion,
id::{NodeId, TenantId, TimelineId},
lsn::Lsn,
postgres_client::PostgresClientProtocol,
serde_system_time,
};
@@ -352,6 +353,7 @@ pub struct TenantConfig {
pub lsn_lease_length: Option<String>,
pub lsn_lease_length_for_ts: Option<String>,
pub timeline_offloading: Option<bool>,
pub wal_receiver_protocol_override: Option<PostgresClientProtocol>,
}
/// The policy for the aux file storage.

View File

@@ -562,6 +562,9 @@ pub enum BeMessage<'a> {
options: &'a [&'a str],
},
KeepAlive(WalSndKeepAlive),
/// Batch of interpreted, shard filtered WAL records,
/// ready for the pageserver to ingest
InterpretedWalRecords(InterpretedWalRecordsBody<'a>),
}
/// Common shorthands.
@@ -672,6 +675,22 @@ pub struct WalSndKeepAlive {
pub request_reply: bool,
}
/// Batch of interpreted WAL records used in the interpreted
/// safekeeper to pageserver protocol.
///
/// Note that the pageserver uses the RawInterpretedWalRecordsBody
/// counterpart of this from the neondatabase/rust-postgres repo.
/// If you're changing this struct, you likely need to change its
/// twin as well.
#[derive(Debug)]
pub struct InterpretedWalRecordsBody<'a> {
/// End of raw WAL in [`Self::data`]
pub streaming_lsn: u64,
/// Current end of WAL on the server
pub commit_lsn: u64,
pub data: &'a [u8],
}
pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);
// single text column
@@ -996,6 +1015,19 @@ impl BeMessage<'_> {
Ok(())
})?
}
BeMessage::InterpretedWalRecords(rec) => {
// We use the COPY_DATA_TAG for our custom message
// since this tag is interpreted as raw bytes.
buf.put_u8(b'd');
write_body(buf, |buf| {
buf.put_u8(b'0'); // matches INTERPRETED_WAL_RECORD_TAG in postgres-protocol
// dependency
buf.put_u64(rec.streaming_lsn);
buf.put_u64(rec.commit_lsn);
buf.put_slice(rec.data);
});
}
}
Ok(())
}

6
libs/proxy/README.md Normal file
View File

@@ -0,0 +1,6 @@
This directory contains libraries that are specific for proxy.
Currently, it contains a signficant fork/refactoring of rust-postgres that no longer reflects the API
of the original library. Since it was so significant, it made sense to upgrade it to it's own set of libraries.
Proxy needs unique access to the protocol, which explains why such heavy modifications were necessary.

View File

@@ -0,0 +1,21 @@
[package]
name = "postgres-protocol2"
version = "0.1.0"
edition = "2018"
license = "MIT/Apache-2.0"
[dependencies]
base64 = "0.20"
byteorder.workspace = true
bytes.workspace = true
fallible-iterator.workspace = true
hmac.workspace = true
md-5 = "0.10"
memchr = "2.0"
rand.workspace = true
sha2.workspace = true
stringprep = "0.1"
tokio = { workspace = true, features = ["rt"] }
[dev-dependencies]
tokio = { workspace = true, features = ["full"] }

View File

@@ -0,0 +1,37 @@
//! Authentication protocol support.
use md5::{Digest, Md5};
pub mod sasl;
/// Hashes authentication information in a way suitable for use in response
/// to an `AuthenticationMd5Password` message.
///
/// The resulting string should be sent back to the database in a
/// `PasswordMessage` message.
#[inline]
pub fn md5_hash(username: &[u8], password: &[u8], salt: [u8; 4]) -> String {
let mut md5 = Md5::new();
md5.update(password);
md5.update(username);
let output = md5.finalize_reset();
md5.update(format!("{:x}", output));
md5.update(salt);
format!("md5{:x}", md5.finalize())
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn md5() {
let username = b"md5_user";
let password = b"password";
let salt = [0x2a, 0x3d, 0x8f, 0xe0];
assert_eq!(
md5_hash(username, password, salt),
"md562af4dd09bbb41884907a838a3233294"
);
}
}

View File

@@ -0,0 +1,516 @@
//! SASL-based authentication support.
use hmac::{Hmac, Mac};
use rand::{self, Rng};
use sha2::digest::FixedOutput;
use sha2::{Digest, Sha256};
use std::fmt::Write;
use std::io;
use std::iter;
use std::mem;
use std::str;
use tokio::task::yield_now;
const NONCE_LENGTH: usize = 24;
/// The identifier of the SCRAM-SHA-256 SASL authentication mechanism.
pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
/// The identifier of the SCRAM-SHA-256-PLUS SASL authentication mechanism.
pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
// since postgres passwords are not required to exclude saslprep-prohibited
// characters or even be valid UTF8, we run saslprep if possible and otherwise
// return the raw password.
fn normalize(pass: &[u8]) -> Vec<u8> {
let pass = match str::from_utf8(pass) {
Ok(pass) => pass,
Err(_) => return pass.to_vec(),
};
match stringprep::saslprep(pass) {
Ok(pass) => pass.into_owned().into_bytes(),
Err(_) => pass.as_bytes().to_vec(),
}
}
pub(crate) async fn hi(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
let mut hmac =
Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
hmac.update(salt);
hmac.update(&[0, 0, 0, 1]);
let mut prev = hmac.finalize().into_bytes();
let mut hi = prev;
for i in 1..iterations {
let mut hmac = Hmac::<Sha256>::new_from_slice(str).expect("already checked above");
hmac.update(&prev);
prev = hmac.finalize().into_bytes();
for (hi, prev) in hi.iter_mut().zip(prev) {
*hi ^= prev;
}
// yield every ~250us
// hopefully reduces tail latencies
if i % 1024 == 0 {
yield_now().await
}
}
hi.into()
}
enum ChannelBindingInner {
Unrequested,
Unsupported,
TlsServerEndPoint(Vec<u8>),
}
/// The channel binding configuration for a SCRAM authentication exchange.
pub struct ChannelBinding(ChannelBindingInner);
impl ChannelBinding {
/// The server did not request channel binding.
pub fn unrequested() -> ChannelBinding {
ChannelBinding(ChannelBindingInner::Unrequested)
}
/// The server requested channel binding but the client is unable to provide it.
pub fn unsupported() -> ChannelBinding {
ChannelBinding(ChannelBindingInner::Unsupported)
}
/// The server requested channel binding and the client will use the `tls-server-end-point`
/// method.
pub fn tls_server_end_point(signature: Vec<u8>) -> ChannelBinding {
ChannelBinding(ChannelBindingInner::TlsServerEndPoint(signature))
}
fn gs2_header(&self) -> &'static str {
match self.0 {
ChannelBindingInner::Unrequested => "y,,",
ChannelBindingInner::Unsupported => "n,,",
ChannelBindingInner::TlsServerEndPoint(_) => "p=tls-server-end-point,,",
}
}
fn cbind_data(&self) -> &[u8] {
match self.0 {
ChannelBindingInner::Unrequested | ChannelBindingInner::Unsupported => &[],
ChannelBindingInner::TlsServerEndPoint(ref buf) => buf,
}
}
}
/// A pair of keys for the SCRAM-SHA-256 mechanism.
/// See <https://datatracker.ietf.org/doc/html/rfc5802#section-3> for details.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ScramKeys<const N: usize> {
/// Used by server to authenticate client.
pub client_key: [u8; N],
/// Used by client to verify server's signature.
pub server_key: [u8; N],
}
/// Password or keys which were derived from it.
enum Credentials<const N: usize> {
/// A regular password as a vector of bytes.
Password(Vec<u8>),
/// A precomputed pair of keys.
Keys(Box<ScramKeys<N>>),
}
enum State {
Update {
nonce: String,
password: Credentials<32>,
channel_binding: ChannelBinding,
},
Finish {
server_key: [u8; 32],
auth_message: String,
},
Done,
}
/// A type which handles the client side of the SCRAM-SHA-256/SCRAM-SHA-256-PLUS authentication
/// process.
///
/// During the authentication process, if the backend sends an `AuthenticationSASL` message which
/// includes `SCRAM-SHA-256` as an authentication mechanism, this type can be used.
///
/// After a `ScramSha256` is constructed, the buffer returned by the `message()` method should be
/// sent to the backend in a `SASLInitialResponse` message along with the mechanism name.
///
/// The server will reply with an `AuthenticationSASLContinue` message. Its contents should be
/// passed to the `update()` method, after which the buffer returned by the `message()` method
/// should be sent to the backend in a `SASLResponse` message.
///
/// The server will reply with an `AuthenticationSASLFinal` message. Its contents should be passed
/// to the `finish()` method, after which the authentication process is complete.
pub struct ScramSha256 {
message: String,
state: State,
}
fn nonce() -> String {
// rand 0.5's ThreadRng is cryptographically secure
let mut rng = rand::thread_rng();
(0..NONCE_LENGTH)
.map(|_| {
let mut v = rng.gen_range(0x21u8..0x7e);
if v == 0x2c {
v = 0x7e
}
v as char
})
.collect()
}
impl ScramSha256 {
/// Constructs a new instance which will use the provided password for authentication.
pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 {
let password = Credentials::Password(normalize(password));
ScramSha256::new_inner(password, channel_binding, nonce())
}
/// Constructs a new instance which will use the provided key pair for authentication.
pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 {
let password = Credentials::Keys(keys.into());
ScramSha256::new_inner(password, channel_binding, nonce())
}
fn new_inner(
password: Credentials<32>,
channel_binding: ChannelBinding,
nonce: String,
) -> ScramSha256 {
ScramSha256 {
message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce),
state: State::Update {
nonce,
password,
channel_binding,
},
}
}
/// Returns the message which should be sent to the backend in an `SASLResponse` message.
pub fn message(&self) -> &[u8] {
if let State::Done = self.state {
panic!("invalid SCRAM state");
}
self.message.as_bytes()
}
/// Updates the state machine with the response from the backend.
///
/// This should be called when an `AuthenticationSASLContinue` message is received.
pub async fn update(&mut self, message: &[u8]) -> io::Result<()> {
let (client_nonce, password, channel_binding) =
match mem::replace(&mut self.state, State::Done) {
State::Update {
nonce,
password,
channel_binding,
} => (nonce, password, channel_binding),
_ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
};
let message =
str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
let parsed = Parser::new(message).server_first_message()?;
if !parsed.nonce.starts_with(&client_nonce) {
return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce"));
}
let (client_key, server_key) = match password {
Credentials::Password(password) => {
let salt = match base64::decode(parsed.salt) {
Ok(salt) => salt,
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
};
let salted_password = hi(&password, &salt, parsed.iteration_count).await;
let make_key = |name| {
let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
.expect("HMAC is able to accept all key sizes");
hmac.update(name);
let mut key = [0u8; 32];
key.copy_from_slice(hmac.finalize().into_bytes().as_slice());
key
};
(make_key(b"Client Key"), make_key(b"Server Key"))
}
Credentials::Keys(keys) => (keys.client_key, keys.server_key),
};
let mut hash = Sha256::default();
hash.update(client_key);
let stored_key = hash.finalize_fixed();
let mut cbind_input = vec![];
cbind_input.extend(channel_binding.gs2_header().as_bytes());
cbind_input.extend(channel_binding.cbind_data());
let cbind_input = base64::encode(&cbind_input);
self.message.clear();
write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap();
let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message);
let mut hmac = Hmac::<Sha256>::new_from_slice(&stored_key)
.expect("HMAC is able to accept all key sizes");
hmac.update(auth_message.as_bytes());
let client_signature = hmac.finalize().into_bytes();
let mut client_proof = client_key;
for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
*proof ^= signature;
}
write!(&mut self.message, ",p={}", base64::encode(client_proof)).unwrap();
self.state = State::Finish {
server_key,
auth_message,
};
Ok(())
}
/// Finalizes the authentication process.
///
/// This should be called when the backend sends an `AuthenticationSASLFinal` message.
/// Authentication has only succeeded if this method returns `Ok(())`.
pub fn finish(&mut self, message: &[u8]) -> io::Result<()> {
let (server_key, auth_message) = match mem::replace(&mut self.state, State::Done) {
State::Finish {
server_key,
auth_message,
} => (server_key, auth_message),
_ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
};
let message =
str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
let parsed = Parser::new(message).server_final_message()?;
let verifier = match parsed {
ServerFinalMessage::Error(e) => {
return Err(io::Error::new(
io::ErrorKind::Other,
format!("SCRAM error: {}", e),
));
}
ServerFinalMessage::Verifier(verifier) => verifier,
};
let verifier = match base64::decode(verifier) {
Ok(verifier) => verifier,
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
};
let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key)
.expect("HMAC is able to accept all key sizes");
hmac.update(auth_message.as_bytes());
hmac.verify_slice(&verifier)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error"))
}
}
struct Parser<'a> {
s: &'a str,
it: iter::Peekable<str::CharIndices<'a>>,
}
impl<'a> Parser<'a> {
fn new(s: &'a str) -> Parser<'a> {
Parser {
s,
it: s.char_indices().peekable(),
}
}
fn eat(&mut self, target: char) -> io::Result<()> {
match self.it.next() {
Some((_, c)) if c == target => Ok(()),
Some((i, c)) => {
let m = format!(
"unexpected character at byte {}: expected `{}` but got `{}",
i, target, c
);
Err(io::Error::new(io::ErrorKind::InvalidInput, m))
}
None => Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected EOF",
)),
}
}
fn take_while<F>(&mut self, f: F) -> io::Result<&'a str>
where
F: Fn(char) -> bool,
{
let start = match self.it.peek() {
Some(&(i, _)) => i,
None => return Ok(""),
};
loop {
match self.it.peek() {
Some(&(_, c)) if f(c) => {
self.it.next();
}
Some(&(i, _)) => return Ok(&self.s[start..i]),
None => return Ok(&self.s[start..]),
}
}
}
fn printable(&mut self) -> io::Result<&'a str> {
self.take_while(|c| matches!(c, '\x21'..='\x2b' | '\x2d'..='\x7e'))
}
fn nonce(&mut self) -> io::Result<&'a str> {
self.eat('r')?;
self.eat('=')?;
self.printable()
}
fn base64(&mut self) -> io::Result<&'a str> {
self.take_while(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '/' | '+' | '='))
}
fn salt(&mut self) -> io::Result<&'a str> {
self.eat('s')?;
self.eat('=')?;
self.base64()
}
fn posit_number(&mut self) -> io::Result<u32> {
let n = self.take_while(|c| c.is_ascii_digit())?;
n.parse()
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
}
fn iteration_count(&mut self) -> io::Result<u32> {
self.eat('i')?;
self.eat('=')?;
self.posit_number()
}
fn eof(&mut self) -> io::Result<()> {
match self.it.peek() {
Some(&(i, _)) => Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("unexpected trailing data at byte {}", i),
)),
None => Ok(()),
}
}
fn server_first_message(&mut self) -> io::Result<ServerFirstMessage<'a>> {
let nonce = self.nonce()?;
self.eat(',')?;
let salt = self.salt()?;
self.eat(',')?;
let iteration_count = self.iteration_count()?;
self.eof()?;
Ok(ServerFirstMessage {
nonce,
salt,
iteration_count,
})
}
fn value(&mut self) -> io::Result<&'a str> {
self.take_while(|c| matches!(c, '\0' | '=' | ','))
}
fn server_error(&mut self) -> io::Result<Option<&'a str>> {
match self.it.peek() {
Some(&(_, 'e')) => {}
_ => return Ok(None),
}
self.eat('e')?;
self.eat('=')?;
self.value().map(Some)
}
fn verifier(&mut self) -> io::Result<&'a str> {
self.eat('v')?;
self.eat('=')?;
self.base64()
}
fn server_final_message(&mut self) -> io::Result<ServerFinalMessage<'a>> {
let message = match self.server_error()? {
Some(error) => ServerFinalMessage::Error(error),
None => ServerFinalMessage::Verifier(self.verifier()?),
};
self.eof()?;
Ok(message)
}
}
struct ServerFirstMessage<'a> {
nonce: &'a str,
salt: &'a str,
iteration_count: u32,
}
enum ServerFinalMessage<'a> {
Error(&'a str),
Verifier(&'a str),
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn parse_server_first_message() {
let message = "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096";
let message = Parser::new(message).server_first_message().unwrap();
assert_eq!(message.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
assert_eq!(message.salt, "QSXCR+Q6sek8bf92");
assert_eq!(message.iteration_count, 4096);
}
// recorded auth exchange from psql
#[tokio::test]
async fn exchange() {
let password = "foobar";
let nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB";
let client_first = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB";
let server_first =
"r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i\
=4096";
let client_final =
"c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS3\
1NTlQYNs5BTeQjdHdk7lOflDo5re2an8=";
let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";
let mut scram = ScramSha256::new_inner(
Credentials::Password(normalize(password.as_bytes())),
ChannelBinding::unsupported(),
nonce.to_string(),
);
assert_eq!(str::from_utf8(scram.message()).unwrap(), client_first);
scram.update(server_first.as_bytes()).await.unwrap();
assert_eq!(str::from_utf8(scram.message()).unwrap(), client_final);
scram.finish(server_final.as_bytes()).unwrap();
}
}

View File

@@ -0,0 +1,93 @@
//! Provides functions for escaping literals and identifiers for use
//! in SQL queries.
//!
//! Prefer parameterized queries where possible. Do not escape
//! parameters in a parameterized query.
#[cfg(test)]
mod test;
/// Escape a literal and surround result with single quotes. Not
/// recommended in most cases.
///
/// If input contains backslashes, result will be of the form `
/// E'...'` so it is safe to use regardless of the setting of
/// standard_conforming_strings.
pub fn escape_literal(input: &str) -> String {
escape_internal(input, false)
}
/// Escape an identifier and surround result with double quotes.
pub fn escape_identifier(input: &str) -> String {
escape_internal(input, true)
}
// Translation of PostgreSQL libpq's PQescapeInternal(). Does not
// require a connection because input string is known to be valid
// UTF-8.
//
// Escape arbitrary strings. If as_ident is true, we escape the
// result as an identifier; if false, as a literal. The result is
// returned in a newly allocated buffer. If we fail due to an
// encoding violation or out of memory condition, we return NULL,
// storing an error message into conn.
fn escape_internal(input: &str, as_ident: bool) -> String {
let mut num_backslashes = 0;
let mut num_quotes = 0;
let quote_char = if as_ident { '"' } else { '\'' };
// Scan the string for characters that must be escaped.
for ch in input.chars() {
if ch == quote_char {
num_quotes += 1;
} else if ch == '\\' {
num_backslashes += 1;
}
}
// Allocate output String.
let mut result_size = input.len() + num_quotes + 3; // two quotes, plus a NUL
if !as_ident && num_backslashes > 0 {
result_size += num_backslashes + 2;
}
let mut output = String::with_capacity(result_size);
// If we are escaping a literal that contains backslashes, we use
// the escape string syntax so that the result is correct under
// either value of standard_conforming_strings. We also emit a
// leading space in this case, to guard against the possibility
// that the result might be interpolated immediately following an
// identifier.
if !as_ident && num_backslashes > 0 {
output.push(' ');
output.push('E');
}
// Opening quote.
output.push(quote_char);
// Use fast path if possible.
//
// We've already verified that the input string is well-formed in
// the current encoding. If it contains no quotes and, in the
// case of literal-escaping, no backslashes, then we can just copy
// it directly to the output buffer, adding the necessary quotes.
//
// If not, we must rescan the input and process each character
// individually.
if num_quotes == 0 && (num_backslashes == 0 || as_ident) {
output.push_str(input);
} else {
for ch in input.chars() {
if ch == quote_char || (!as_ident && ch == '\\') {
output.push(ch);
}
output.push(ch);
}
}
output.push(quote_char);
output
}

View File

@@ -0,0 +1,17 @@
use crate::escape::{escape_identifier, escape_literal};
#[test]
fn test_escape_idenifier() {
assert_eq!(escape_identifier("foo"), String::from("\"foo\""));
assert_eq!(escape_identifier("f\\oo"), String::from("\"f\\oo\""));
assert_eq!(escape_identifier("f'oo"), String::from("\"f'oo\""));
assert_eq!(escape_identifier("f\"oo"), String::from("\"f\"\"oo\""));
}
#[test]
fn test_escape_literal() {
assert_eq!(escape_literal("foo"), String::from("'foo'"));
assert_eq!(escape_literal("f\\oo"), String::from(" E'f\\\\oo'"));
assert_eq!(escape_literal("f'oo"), String::from("'f''oo'"));
assert_eq!(escape_literal("f\"oo"), String::from("'f\"oo'"));
}

View File

@@ -0,0 +1,78 @@
//! Low level Postgres protocol APIs.
//!
//! This crate implements the low level components of Postgres's communication
//! protocol, including message and value serialization and deserialization.
//! It is designed to be used as a building block by higher level APIs such as
//! `rust-postgres`, and should not typically be used directly.
//!
//! # Note
//!
//! This library assumes that the `client_encoding` backend parameter has been
//! set to `UTF8`. It will most likely not behave properly if that is not the case.
#![doc(html_root_url = "https://docs.rs/postgres-protocol/0.6")]
#![warn(missing_docs, rust_2018_idioms, clippy::all)]
use byteorder::{BigEndian, ByteOrder};
use bytes::{BufMut, BytesMut};
use std::io;
pub mod authentication;
pub mod escape;
pub mod message;
pub mod password;
pub mod types;
/// A Postgres OID.
pub type Oid = u32;
/// A Postgres Log Sequence Number (LSN).
pub type Lsn = u64;
/// An enum indicating if a value is `NULL` or not.
pub enum IsNull {
/// The value is `NULL`.
Yes,
/// The value is not `NULL`.
No,
}
fn write_nullable<F, E>(serializer: F, buf: &mut BytesMut) -> Result<(), E>
where
F: FnOnce(&mut BytesMut) -> Result<IsNull, E>,
E: From<io::Error>,
{
let base = buf.len();
buf.put_i32(0);
let size = match serializer(buf)? {
IsNull::No => i32::from_usize(buf.len() - base - 4)?,
IsNull::Yes => -1,
};
BigEndian::write_i32(&mut buf[base..], size);
Ok(())
}
trait FromUsize: Sized {
fn from_usize(x: usize) -> Result<Self, io::Error>;
}
macro_rules! from_usize {
($t:ty) => {
impl FromUsize for $t {
#[inline]
fn from_usize(x: usize) -> io::Result<$t> {
if x > <$t>::MAX as usize {
Err(io::Error::new(
io::ErrorKind::InvalidInput,
"value too large to transmit",
))
} else {
Ok(x as $t)
}
}
}
};
}
from_usize!(i16);
from_usize!(i32);

View File

@@ -0,0 +1,766 @@
#![allow(missing_docs)]
use byteorder::{BigEndian, ByteOrder, ReadBytesExt};
use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use memchr::memchr;
use std::cmp;
use std::io::{self, Read};
use std::ops::Range;
use std::str;
use crate::Oid;
// top-level message tags
const PARSE_COMPLETE_TAG: u8 = b'1';
const BIND_COMPLETE_TAG: u8 = b'2';
const CLOSE_COMPLETE_TAG: u8 = b'3';
pub const NOTIFICATION_RESPONSE_TAG: u8 = b'A';
const COPY_DONE_TAG: u8 = b'c';
const COMMAND_COMPLETE_TAG: u8 = b'C';
const COPY_DATA_TAG: u8 = b'd';
const DATA_ROW_TAG: u8 = b'D';
const ERROR_RESPONSE_TAG: u8 = b'E';
const COPY_IN_RESPONSE_TAG: u8 = b'G';
const COPY_OUT_RESPONSE_TAG: u8 = b'H';
const COPY_BOTH_RESPONSE_TAG: u8 = b'W';
const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I';
const BACKEND_KEY_DATA_TAG: u8 = b'K';
pub const NO_DATA_TAG: u8 = b'n';
pub const NOTICE_RESPONSE_TAG: u8 = b'N';
const AUTHENTICATION_TAG: u8 = b'R';
const PORTAL_SUSPENDED_TAG: u8 = b's';
pub const PARAMETER_STATUS_TAG: u8 = b'S';
const PARAMETER_DESCRIPTION_TAG: u8 = b't';
const ROW_DESCRIPTION_TAG: u8 = b'T';
pub const READY_FOR_QUERY_TAG: u8 = b'Z';
#[derive(Debug, Copy, Clone)]
pub struct Header {
tag: u8,
len: i32,
}
#[allow(clippy::len_without_is_empty)]
impl Header {
#[inline]
pub fn parse(buf: &[u8]) -> io::Result<Option<Header>> {
if buf.len() < 5 {
return Ok(None);
}
let tag = buf[0];
let len = BigEndian::read_i32(&buf[1..]);
if len < 4 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"invalid message length: header length < 4",
));
}
Ok(Some(Header { tag, len }))
}
#[inline]
pub fn tag(self) -> u8 {
self.tag
}
#[inline]
pub fn len(self) -> i32 {
self.len
}
}
/// An enum representing Postgres backend messages.
#[non_exhaustive]
pub enum Message {
AuthenticationCleartextPassword,
AuthenticationGss,
AuthenticationKerberosV5,
AuthenticationMd5Password(AuthenticationMd5PasswordBody),
AuthenticationOk,
AuthenticationScmCredential,
AuthenticationSspi,
AuthenticationGssContinue,
AuthenticationSasl(AuthenticationSaslBody),
AuthenticationSaslContinue(AuthenticationSaslContinueBody),
AuthenticationSaslFinal(AuthenticationSaslFinalBody),
BackendKeyData(BackendKeyDataBody),
BindComplete,
CloseComplete,
CommandComplete(CommandCompleteBody),
CopyData,
CopyDone,
CopyInResponse,
CopyOutResponse,
CopyBothResponse,
DataRow(DataRowBody),
EmptyQueryResponse,
ErrorResponse(ErrorResponseBody),
NoData,
NoticeResponse(NoticeResponseBody),
NotificationResponse(NotificationResponseBody),
ParameterDescription(ParameterDescriptionBody),
ParameterStatus(ParameterStatusBody),
ParseComplete,
PortalSuspended,
ReadyForQuery(ReadyForQueryBody),
RowDescription(RowDescriptionBody),
}
impl Message {
#[inline]
pub fn parse(buf: &mut BytesMut) -> io::Result<Option<Message>> {
if buf.len() < 5 {
let to_read = 5 - buf.len();
buf.reserve(to_read);
return Ok(None);
}
let tag = buf[0];
let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
if len < 4 {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length: parsing u32",
));
}
let total_len = len as usize + 1;
if buf.len() < total_len {
let to_read = total_len - buf.len();
buf.reserve(to_read);
return Ok(None);
}
let mut buf = Buffer {
bytes: buf.split_to(total_len).freeze(),
idx: 5,
};
let message = match tag {
PARSE_COMPLETE_TAG => Message::ParseComplete,
BIND_COMPLETE_TAG => Message::BindComplete,
CLOSE_COMPLETE_TAG => Message::CloseComplete,
NOTIFICATION_RESPONSE_TAG => {
let process_id = buf.read_i32::<BigEndian>()?;
let channel = buf.read_cstr()?;
let message = buf.read_cstr()?;
Message::NotificationResponse(NotificationResponseBody {
process_id,
channel,
message,
})
}
COPY_DONE_TAG => Message::CopyDone,
COMMAND_COMPLETE_TAG => {
let tag = buf.read_cstr()?;
Message::CommandComplete(CommandCompleteBody { tag })
}
COPY_DATA_TAG => Message::CopyData,
DATA_ROW_TAG => {
let len = buf.read_u16::<BigEndian>()?;
let storage = buf.read_all();
Message::DataRow(DataRowBody { storage, len })
}
ERROR_RESPONSE_TAG => {
let storage = buf.read_all();
Message::ErrorResponse(ErrorResponseBody { storage })
}
COPY_IN_RESPONSE_TAG => Message::CopyInResponse,
COPY_OUT_RESPONSE_TAG => Message::CopyOutResponse,
COPY_BOTH_RESPONSE_TAG => Message::CopyBothResponse,
EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse,
BACKEND_KEY_DATA_TAG => {
let process_id = buf.read_i32::<BigEndian>()?;
let secret_key = buf.read_i32::<BigEndian>()?;
Message::BackendKeyData(BackendKeyDataBody {
process_id,
secret_key,
})
}
NO_DATA_TAG => Message::NoData,
NOTICE_RESPONSE_TAG => {
let storage = buf.read_all();
Message::NoticeResponse(NoticeResponseBody { storage })
}
AUTHENTICATION_TAG => match buf.read_i32::<BigEndian>()? {
0 => Message::AuthenticationOk,
2 => Message::AuthenticationKerberosV5,
3 => Message::AuthenticationCleartextPassword,
5 => {
let mut salt = [0; 4];
buf.read_exact(&mut salt)?;
Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody { salt })
}
6 => Message::AuthenticationScmCredential,
7 => Message::AuthenticationGss,
8 => Message::AuthenticationGssContinue,
9 => Message::AuthenticationSspi,
10 => {
let storage = buf.read_all();
Message::AuthenticationSasl(AuthenticationSaslBody(storage))
}
11 => {
let storage = buf.read_all();
Message::AuthenticationSaslContinue(AuthenticationSaslContinueBody(storage))
}
12 => {
let storage = buf.read_all();
Message::AuthenticationSaslFinal(AuthenticationSaslFinalBody(storage))
}
tag => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("unknown authentication tag `{}`", tag),
));
}
},
PORTAL_SUSPENDED_TAG => Message::PortalSuspended,
PARAMETER_STATUS_TAG => {
let name = buf.read_cstr()?;
let value = buf.read_cstr()?;
Message::ParameterStatus(ParameterStatusBody { name, value })
}
PARAMETER_DESCRIPTION_TAG => {
let len = buf.read_u16::<BigEndian>()?;
let storage = buf.read_all();
Message::ParameterDescription(ParameterDescriptionBody { storage, len })
}
ROW_DESCRIPTION_TAG => {
let len = buf.read_u16::<BigEndian>()?;
let storage = buf.read_all();
Message::RowDescription(RowDescriptionBody { storage, len })
}
READY_FOR_QUERY_TAG => {
let status = buf.read_u8()?;
Message::ReadyForQuery(ReadyForQueryBody { status })
}
tag => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("unknown message tag `{}`", tag),
));
}
};
if !buf.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length: expected buffer to be empty",
));
}
Ok(Some(message))
}
}
struct Buffer {
bytes: Bytes,
idx: usize,
}
impl Buffer {
#[inline]
fn slice(&self) -> &[u8] {
&self.bytes[self.idx..]
}
#[inline]
fn is_empty(&self) -> bool {
self.slice().is_empty()
}
#[inline]
fn read_cstr(&mut self) -> io::Result<Bytes> {
match memchr(0, self.slice()) {
Some(pos) => {
let start = self.idx;
let end = start + pos;
let cstr = self.bytes.slice(start..end);
self.idx = end + 1;
Ok(cstr)
}
None => Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected EOF",
)),
}
}
#[inline]
fn read_all(&mut self) -> Bytes {
let buf = self.bytes.slice(self.idx..);
self.idx = self.bytes.len();
buf
}
}
impl Read for Buffer {
#[inline]
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let len = {
let slice = self.slice();
let len = cmp::min(slice.len(), buf.len());
buf[..len].copy_from_slice(&slice[..len]);
len
};
self.idx += len;
Ok(len)
}
}
pub struct AuthenticationMd5PasswordBody {
salt: [u8; 4],
}
impl AuthenticationMd5PasswordBody {
#[inline]
pub fn salt(&self) -> [u8; 4] {
self.salt
}
}
pub struct AuthenticationSaslBody(Bytes);
impl AuthenticationSaslBody {
#[inline]
pub fn mechanisms(&self) -> SaslMechanisms<'_> {
SaslMechanisms(&self.0)
}
}
pub struct SaslMechanisms<'a>(&'a [u8]);
impl<'a> FallibleIterator for SaslMechanisms<'a> {
type Item = &'a str;
type Error = io::Error;
#[inline]
fn next(&mut self) -> io::Result<Option<&'a str>> {
let value_end = find_null(self.0, 0)?;
if value_end == 0 {
if self.0.len() != 1 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"invalid message length: expected to be at end of iterator for sasl",
));
}
Ok(None)
} else {
let value = get_str(&self.0[..value_end])?;
self.0 = &self.0[value_end + 1..];
Ok(Some(value))
}
}
}
pub struct AuthenticationSaslContinueBody(Bytes);
impl AuthenticationSaslContinueBody {
#[inline]
pub fn data(&self) -> &[u8] {
&self.0
}
}
pub struct AuthenticationSaslFinalBody(Bytes);
impl AuthenticationSaslFinalBody {
#[inline]
pub fn data(&self) -> &[u8] {
&self.0
}
}
pub struct BackendKeyDataBody {
process_id: i32,
secret_key: i32,
}
impl BackendKeyDataBody {
#[inline]
pub fn process_id(&self) -> i32 {
self.process_id
}
#[inline]
pub fn secret_key(&self) -> i32 {
self.secret_key
}
}
pub struct CommandCompleteBody {
tag: Bytes,
}
impl CommandCompleteBody {
#[inline]
pub fn tag(&self) -> io::Result<&str> {
get_str(&self.tag)
}
}
#[derive(Debug)]
pub struct DataRowBody {
storage: Bytes,
len: u16,
}
impl DataRowBody {
#[inline]
pub fn ranges(&self) -> DataRowRanges<'_> {
DataRowRanges {
buf: &self.storage,
len: self.storage.len(),
remaining: self.len,
}
}
#[inline]
pub fn buffer(&self) -> &[u8] {
&self.storage
}
}
pub struct DataRowRanges<'a> {
buf: &'a [u8],
len: usize,
remaining: u16,
}
impl FallibleIterator for DataRowRanges<'_> {
type Item = Option<Range<usize>>;
type Error = io::Error;
#[inline]
fn next(&mut self) -> io::Result<Option<Option<Range<usize>>>> {
if self.remaining == 0 {
if self.buf.is_empty() {
return Ok(None);
} else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length: datarowrange is not empty",
));
}
}
self.remaining -= 1;
let len = self.buf.read_i32::<BigEndian>()?;
if len < 0 {
Ok(Some(None))
} else {
let len = len as usize;
if self.buf.len() < len {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected EOF",
));
}
let base = self.len - self.buf.len();
self.buf = &self.buf[len..];
Ok(Some(Some(base..base + len)))
}
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.remaining as usize;
(len, Some(len))
}
}
pub struct ErrorResponseBody {
storage: Bytes,
}
impl ErrorResponseBody {
#[inline]
pub fn fields(&self) -> ErrorFields<'_> {
ErrorFields { buf: &self.storage }
}
}
pub struct ErrorFields<'a> {
buf: &'a [u8],
}
impl<'a> FallibleIterator for ErrorFields<'a> {
type Item = ErrorField<'a>;
type Error = io::Error;
#[inline]
fn next(&mut self) -> io::Result<Option<ErrorField<'a>>> {
let type_ = self.buf.read_u8()?;
if type_ == 0 {
if self.buf.is_empty() {
return Ok(None);
} else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length: error fields is not drained",
));
}
}
let value_end = find_null(self.buf, 0)?;
let value = get_str(&self.buf[..value_end])?;
self.buf = &self.buf[value_end + 1..];
Ok(Some(ErrorField { type_, value }))
}
}
pub struct ErrorField<'a> {
type_: u8,
value: &'a str,
}
impl ErrorField<'_> {
#[inline]
pub fn type_(&self) -> u8 {
self.type_
}
#[inline]
pub fn value(&self) -> &str {
self.value
}
}
pub struct NoticeResponseBody {
storage: Bytes,
}
impl NoticeResponseBody {
#[inline]
pub fn fields(&self) -> ErrorFields<'_> {
ErrorFields { buf: &self.storage }
}
}
pub struct NotificationResponseBody {
process_id: i32,
channel: Bytes,
message: Bytes,
}
impl NotificationResponseBody {
#[inline]
pub fn process_id(&self) -> i32 {
self.process_id
}
#[inline]
pub fn channel(&self) -> io::Result<&str> {
get_str(&self.channel)
}
#[inline]
pub fn message(&self) -> io::Result<&str> {
get_str(&self.message)
}
}
pub struct ParameterDescriptionBody {
storage: Bytes,
len: u16,
}
impl ParameterDescriptionBody {
#[inline]
pub fn parameters(&self) -> Parameters<'_> {
Parameters {
buf: &self.storage,
remaining: self.len,
}
}
}
pub struct Parameters<'a> {
buf: &'a [u8],
remaining: u16,
}
impl FallibleIterator for Parameters<'_> {
type Item = Oid;
type Error = io::Error;
#[inline]
fn next(&mut self) -> io::Result<Option<Oid>> {
if self.remaining == 0 {
if self.buf.is_empty() {
return Ok(None);
} else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length: parameters is not drained",
));
}
}
self.remaining -= 1;
self.buf.read_u32::<BigEndian>().map(Some)
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.remaining as usize;
(len, Some(len))
}
}
pub struct ParameterStatusBody {
name: Bytes,
value: Bytes,
}
impl ParameterStatusBody {
#[inline]
pub fn name(&self) -> io::Result<&str> {
get_str(&self.name)
}
#[inline]
pub fn value(&self) -> io::Result<&str> {
get_str(&self.value)
}
}
pub struct ReadyForQueryBody {
status: u8,
}
impl ReadyForQueryBody {
#[inline]
pub fn status(&self) -> u8 {
self.status
}
}
pub struct RowDescriptionBody {
storage: Bytes,
len: u16,
}
impl RowDescriptionBody {
#[inline]
pub fn fields(&self) -> Fields<'_> {
Fields {
buf: &self.storage,
remaining: self.len,
}
}
}
pub struct Fields<'a> {
buf: &'a [u8],
remaining: u16,
}
impl<'a> FallibleIterator for Fields<'a> {
type Item = Field<'a>;
type Error = io::Error;
#[inline]
fn next(&mut self) -> io::Result<Option<Field<'a>>> {
if self.remaining == 0 {
if self.buf.is_empty() {
return Ok(None);
} else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length: field is not drained",
));
}
}
self.remaining -= 1;
let name_end = find_null(self.buf, 0)?;
let name = get_str(&self.buf[..name_end])?;
self.buf = &self.buf[name_end + 1..];
let table_oid = self.buf.read_u32::<BigEndian>()?;
let column_id = self.buf.read_i16::<BigEndian>()?;
let type_oid = self.buf.read_u32::<BigEndian>()?;
let type_size = self.buf.read_i16::<BigEndian>()?;
let type_modifier = self.buf.read_i32::<BigEndian>()?;
let format = self.buf.read_i16::<BigEndian>()?;
Ok(Some(Field {
name,
table_oid,
column_id,
type_oid,
type_size,
type_modifier,
format,
}))
}
}
pub struct Field<'a> {
name: &'a str,
table_oid: Oid,
column_id: i16,
type_oid: Oid,
type_size: i16,
type_modifier: i32,
format: i16,
}
impl<'a> Field<'a> {
#[inline]
pub fn name(&self) -> &'a str {
self.name
}
#[inline]
pub fn table_oid(&self) -> Oid {
self.table_oid
}
#[inline]
pub fn column_id(&self) -> i16 {
self.column_id
}
#[inline]
pub fn type_oid(&self) -> Oid {
self.type_oid
}
#[inline]
pub fn type_size(&self) -> i16 {
self.type_size
}
#[inline]
pub fn type_modifier(&self) -> i32 {
self.type_modifier
}
#[inline]
pub fn format(&self) -> i16 {
self.format
}
}
#[inline]
fn find_null(buf: &[u8], start: usize) -> io::Result<usize> {
match memchr(0, &buf[start..]) {
Some(pos) => Ok(pos + start),
None => Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected EOF",
)),
}
}
#[inline]
fn get_str(buf: &[u8]) -> io::Result<&str> {
str::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
}

View File

@@ -0,0 +1,297 @@
//! Frontend message serialization.
#![allow(missing_docs)]
use byteorder::{BigEndian, ByteOrder};
use bytes::{Buf, BufMut, BytesMut};
use std::convert::TryFrom;
use std::error::Error;
use std::io;
use std::marker;
use crate::{write_nullable, FromUsize, IsNull, Oid};
#[inline]
fn write_body<F, E>(buf: &mut BytesMut, f: F) -> Result<(), E>
where
F: FnOnce(&mut BytesMut) -> Result<(), E>,
E: From<io::Error>,
{
let base = buf.len();
buf.extend_from_slice(&[0; 4]);
f(buf)?;
let size = i32::from_usize(buf.len() - base)?;
BigEndian::write_i32(&mut buf[base..], size);
Ok(())
}
pub enum BindError {
Conversion(Box<dyn Error + marker::Sync + Send>),
Serialization(io::Error),
}
impl From<Box<dyn Error + marker::Sync + Send>> for BindError {
#[inline]
fn from(e: Box<dyn Error + marker::Sync + Send>) -> BindError {
BindError::Conversion(e)
}
}
impl From<io::Error> for BindError {
#[inline]
fn from(e: io::Error) -> BindError {
BindError::Serialization(e)
}
}
#[inline]
pub fn bind<I, J, F, T, K>(
portal: &str,
statement: &str,
formats: I,
values: J,
mut serializer: F,
result_formats: K,
buf: &mut BytesMut,
) -> Result<(), BindError>
where
I: IntoIterator<Item = i16>,
J: IntoIterator<Item = T>,
F: FnMut(T, &mut BytesMut) -> Result<IsNull, Box<dyn Error + marker::Sync + Send>>,
K: IntoIterator<Item = i16>,
{
buf.put_u8(b'B');
write_body(buf, |buf| {
write_cstr(portal.as_bytes(), buf)?;
write_cstr(statement.as_bytes(), buf)?;
write_counted(
formats,
|f, buf| {
buf.put_i16(f);
Ok::<_, io::Error>(())
},
buf,
)?;
write_counted(
values,
|v, buf| write_nullable(|buf| serializer(v, buf), buf),
buf,
)?;
write_counted(
result_formats,
|f, buf| {
buf.put_i16(f);
Ok::<_, io::Error>(())
},
buf,
)?;
Ok(())
})
}
#[inline]
fn write_counted<I, T, F, E>(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), E>
where
I: IntoIterator<Item = T>,
F: FnMut(T, &mut BytesMut) -> Result<(), E>,
E: From<io::Error>,
{
let base = buf.len();
buf.extend_from_slice(&[0; 2]);
let mut count = 0;
for item in items {
serializer(item, buf)?;
count += 1;
}
let count = i16::from_usize(count)?;
BigEndian::write_i16(&mut buf[base..], count);
Ok(())
}
#[inline]
pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut BytesMut) {
write_body(buf, |buf| {
buf.put_i32(80_877_102);
buf.put_i32(process_id);
buf.put_i32(secret_key);
Ok::<_, io::Error>(())
})
.unwrap();
}
#[inline]
pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'C');
write_body(buf, |buf| {
buf.put_u8(variant);
write_cstr(name.as_bytes(), buf)
})
}
pub struct CopyData<T> {
buf: T,
len: i32,
}
impl<T> CopyData<T>
where
T: Buf,
{
pub fn new(buf: T) -> io::Result<CopyData<T>> {
let len = buf
.remaining()
.checked_add(4)
.and_then(|l| i32::try_from(l).ok())
.ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "message length overflow")
})?;
Ok(CopyData { buf, len })
}
pub fn write(self, out: &mut BytesMut) {
out.put_u8(b'd');
out.put_i32(self.len);
out.put(self.buf);
}
}
#[inline]
pub fn copy_done(buf: &mut BytesMut) {
buf.put_u8(b'c');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}
#[inline]
pub fn copy_fail(message: &str, buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'f');
write_body(buf, |buf| write_cstr(message.as_bytes(), buf))
}
#[inline]
pub fn describe(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'D');
write_body(buf, |buf| {
buf.put_u8(variant);
write_cstr(name.as_bytes(), buf)
})
}
#[inline]
pub fn execute(portal: &str, max_rows: i32, buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'E');
write_body(buf, |buf| {
write_cstr(portal.as_bytes(), buf)?;
buf.put_i32(max_rows);
Ok(())
})
}
#[inline]
pub fn parse<I>(name: &str, query: &str, param_types: I, buf: &mut BytesMut) -> io::Result<()>
where
I: IntoIterator<Item = Oid>,
{
buf.put_u8(b'P');
write_body(buf, |buf| {
write_cstr(name.as_bytes(), buf)?;
write_cstr(query.as_bytes(), buf)?;
write_counted(
param_types,
|t, buf| {
buf.put_u32(t);
Ok::<_, io::Error>(())
},
buf,
)?;
Ok(())
})
}
#[inline]
pub fn password_message(password: &[u8], buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'p');
write_body(buf, |buf| write_cstr(password, buf))
}
#[inline]
pub fn query(query: &str, buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'Q');
write_body(buf, |buf| write_cstr(query.as_bytes(), buf))
}
#[inline]
pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'p');
write_body(buf, |buf| {
write_cstr(mechanism.as_bytes(), buf)?;
let len = i32::from_usize(data.len())?;
buf.put_i32(len);
buf.put_slice(data);
Ok(())
})
}
#[inline]
pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'p');
write_body(buf, |buf| {
buf.put_slice(data);
Ok(())
})
}
#[inline]
pub fn ssl_request(buf: &mut BytesMut) {
write_body(buf, |buf| {
buf.put_i32(80_877_103);
Ok::<_, io::Error>(())
})
.unwrap();
}
#[inline]
pub fn startup_message<'a, I>(parameters: I, buf: &mut BytesMut) -> io::Result<()>
where
I: IntoIterator<Item = (&'a str, &'a str)>,
{
write_body(buf, |buf| {
// postgres protocol version 3.0(196608) in bigger-endian
buf.put_i32(0x00_03_00_00);
for (key, value) in parameters {
write_cstr(key.as_bytes(), buf)?;
write_cstr(value.as_bytes(), buf)?;
}
buf.put_u8(0);
Ok(())
})
}
#[inline]
pub fn sync(buf: &mut BytesMut) {
buf.put_u8(b'S');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}
#[inline]
pub fn terminate(buf: &mut BytesMut) {
buf.put_u8(b'X');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}
#[inline]
fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
if s.contains(&0) {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"string contains embedded null",
));
}
buf.put_slice(s);
buf.put_u8(0);
Ok(())
}

View File

@@ -0,0 +1,8 @@
//! Postgres message protocol support.
//!
//! See [Postgres's documentation][docs] for more information on message flow.
//!
//! [docs]: https://www.postgresql.org/docs/9.5/static/protocol-flow.html
pub mod backend;
pub mod frontend;

View File

@@ -0,0 +1,107 @@
//! Functions to encrypt a password in the client.
//!
//! This is intended to be used by client applications that wish to
//! send commands like `ALTER USER joe PASSWORD 'pwd'`. The password
//! need not be sent in cleartext if it is encrypted on the client
//! side. This is good because it ensures the cleartext password won't
//! end up in logs pg_stat displays, etc.
use crate::authentication::sasl;
use hmac::{Hmac, Mac};
use md5::Md5;
use rand::RngCore;
use sha2::digest::FixedOutput;
use sha2::{Digest, Sha256};
#[cfg(test)]
mod test;
const SCRAM_DEFAULT_ITERATIONS: u32 = 4096;
const SCRAM_DEFAULT_SALT_LEN: usize = 16;
/// Hash password using SCRAM-SHA-256 with a randomly-generated
/// salt.
///
/// The client may assume the returned string doesn't contain any
/// special characters that would require escaping in an SQL command.
pub async fn scram_sha_256(password: &[u8]) -> String {
let mut salt: [u8; SCRAM_DEFAULT_SALT_LEN] = [0; SCRAM_DEFAULT_SALT_LEN];
let mut rng = rand::thread_rng();
rng.fill_bytes(&mut salt);
scram_sha_256_salt(password, salt).await
}
// Internal implementation of scram_sha_256 with a caller-provided
// salt. This is useful for testing.
pub(crate) async fn scram_sha_256_salt(
password: &[u8],
salt: [u8; SCRAM_DEFAULT_SALT_LEN],
) -> String {
// Prepare the password, per [RFC
// 4013](https://tools.ietf.org/html/rfc4013), if possible.
//
// Postgres treats passwords as byte strings (without embedded NUL
// bytes), but SASL expects passwords to be valid UTF-8.
//
// Follow the behavior of libpq's PQencryptPasswordConn(), and
// also the backend. If the password is not valid UTF-8, or if it
// contains prohibited characters (such as non-ASCII whitespace),
// just skip the SASLprep step and use the original byte
// sequence.
let prepared: Vec<u8> = match std::str::from_utf8(password) {
Ok(password_str) => {
match stringprep::saslprep(password_str) {
Ok(p) => p.into_owned().into_bytes(),
// contains invalid characters; skip saslprep
Err(_) => Vec::from(password),
}
}
// not valid UTF-8; skip saslprep
Err(_) => Vec::from(password),
};
// salt password
let salted_password = sasl::hi(&prepared, &salt, SCRAM_DEFAULT_ITERATIONS).await;
// client key
let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
.expect("HMAC is able to accept all key sizes");
hmac.update(b"Client Key");
let client_key = hmac.finalize().into_bytes();
// stored key
let mut hash = Sha256::default();
hash.update(client_key.as_slice());
let stored_key = hash.finalize_fixed();
// server key
let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
.expect("HMAC is able to accept all key sizes");
hmac.update(b"Server Key");
let server_key = hmac.finalize().into_bytes();
format!(
"SCRAM-SHA-256${}:{}${}:{}",
SCRAM_DEFAULT_ITERATIONS,
base64::encode(salt),
base64::encode(stored_key),
base64::encode(server_key)
)
}
/// **Not recommended, as MD5 is not considered to be secure.**
///
/// Hash password using MD5 with the username as the salt.
///
/// The client may assume the returned string doesn't contain any
/// special characters that would require escaping.
pub fn md5(password: &[u8], username: &str) -> String {
// salt password with username
let mut salted_password = Vec::from(password);
salted_password.extend_from_slice(username.as_bytes());
let mut hash = Md5::new();
hash.update(&salted_password);
let digest = hash.finalize();
format!("md5{:x}", digest)
}

View File

@@ -0,0 +1,19 @@
use crate::password;
#[tokio::test]
async fn test_encrypt_scram_sha_256() {
// Specify the salt to make the test deterministic. Any bytes will do.
let salt: [u8; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
assert_eq!(
password::scram_sha_256_salt(b"secret", salt).await,
"SCRAM-SHA-256$4096:AQIDBAUGBwgJCgsMDQ4PEA==$8rrDg00OqaiWXJ7p+sCgHEIaBSHY89ZJl3mfIsf32oY=:05L1f+yZbiN8O0AnO40Og85NNRhvzTS57naKRWCcsIA="
);
}
#[test]
fn test_encrypt_md5() {
assert_eq!(
password::md5(b"secret", "foo"),
"md54ab2c5d00339c4b2a4e921d2dc4edec7"
);
}

View File

@@ -0,0 +1,294 @@
//! Conversions to and from Postgres's binary format for various types.
use byteorder::{BigEndian, ReadBytesExt};
use bytes::{BufMut, BytesMut};
use fallible_iterator::FallibleIterator;
use std::boxed::Box as StdBox;
use std::error::Error;
use std::str;
use crate::Oid;
#[cfg(test)]
mod test;
/// Serializes a `TEXT`, `VARCHAR`, `CHAR(n)`, `NAME`, or `CITEXT` value.
#[inline]
pub fn text_to_sql(v: &str, buf: &mut BytesMut) {
buf.put_slice(v.as_bytes());
}
/// Deserializes a `TEXT`, `VARCHAR`, `CHAR(n)`, `NAME`, or `CITEXT` value.
#[inline]
pub fn text_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
Ok(str::from_utf8(buf)?)
}
/// Deserializes a `"char"` value.
#[inline]
pub fn char_from_sql(mut buf: &[u8]) -> Result<i8, StdBox<dyn Error + Sync + Send>> {
let v = buf.read_i8()?;
if !buf.is_empty() {
return Err("invalid buffer size".into());
}
Ok(v)
}
/// Serializes an `OID` value.
#[inline]
pub fn oid_to_sql(v: Oid, buf: &mut BytesMut) {
buf.put_u32(v);
}
/// Deserializes an `OID` value.
#[inline]
pub fn oid_from_sql(mut buf: &[u8]) -> Result<Oid, StdBox<dyn Error + Sync + Send>> {
let v = buf.read_u32::<BigEndian>()?;
if !buf.is_empty() {
return Err("invalid buffer size".into());
}
Ok(v)
}
/// A fallible iterator over `HSTORE` entries.
pub struct HstoreEntries<'a> {
remaining: i32,
buf: &'a [u8],
}
impl<'a> FallibleIterator for HstoreEntries<'a> {
type Item = (&'a str, Option<&'a str>);
type Error = StdBox<dyn Error + Sync + Send>;
#[inline]
#[allow(clippy::type_complexity)]
fn next(
&mut self,
) -> Result<Option<(&'a str, Option<&'a str>)>, StdBox<dyn Error + Sync + Send>> {
if self.remaining == 0 {
if !self.buf.is_empty() {
return Err("invalid buffer size".into());
}
return Ok(None);
}
self.remaining -= 1;
let key_len = self.buf.read_i32::<BigEndian>()?;
if key_len < 0 {
return Err("invalid key length".into());
}
let (key, buf) = self.buf.split_at(key_len as usize);
let key = str::from_utf8(key)?;
self.buf = buf;
let value_len = self.buf.read_i32::<BigEndian>()?;
let value = if value_len < 0 {
None
} else {
let (value, buf) = self.buf.split_at(value_len as usize);
let value = str::from_utf8(value)?;
self.buf = buf;
Some(value)
};
Ok(Some((key, value)))
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.remaining as usize;
(len, Some(len))
}
}
/// Deserializes an array value.
#[inline]
pub fn array_from_sql(mut buf: &[u8]) -> Result<Array<'_>, StdBox<dyn Error + Sync + Send>> {
let dimensions = buf.read_i32::<BigEndian>()?;
if dimensions < 0 {
return Err("invalid dimension count".into());
}
let mut r = buf;
let mut elements = 1i32;
for _ in 0..dimensions {
let len = r.read_i32::<BigEndian>()?;
if len < 0 {
return Err("invalid dimension size".into());
}
let _lower_bound = r.read_i32::<BigEndian>()?;
elements = match elements.checked_mul(len) {
Some(elements) => elements,
None => return Err("too many array elements".into()),
};
}
if dimensions == 0 {
elements = 0;
}
Ok(Array {
dimensions,
elements,
buf,
})
}
/// A Postgres array.
pub struct Array<'a> {
dimensions: i32,
elements: i32,
buf: &'a [u8],
}
impl<'a> Array<'a> {
/// Returns an iterator over the dimensions of the array.
#[inline]
pub fn dimensions(&self) -> ArrayDimensions<'a> {
ArrayDimensions(&self.buf[..self.dimensions as usize * 8])
}
/// Returns an iterator over the values of the array.
#[inline]
pub fn values(&self) -> ArrayValues<'a> {
ArrayValues {
remaining: self.elements,
buf: &self.buf[self.dimensions as usize * 8..],
}
}
}
/// An iterator over the dimensions of an array.
pub struct ArrayDimensions<'a>(&'a [u8]);
impl FallibleIterator for ArrayDimensions<'_> {
type Item = ArrayDimension;
type Error = StdBox<dyn Error + Sync + Send>;
#[inline]
fn next(&mut self) -> Result<Option<ArrayDimension>, StdBox<dyn Error + Sync + Send>> {
if self.0.is_empty() {
return Ok(None);
}
let len = self.0.read_i32::<BigEndian>()?;
let lower_bound = self.0.read_i32::<BigEndian>()?;
Ok(Some(ArrayDimension { len, lower_bound }))
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.0.len() / 8;
(len, Some(len))
}
}
/// Information about a dimension of an array.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct ArrayDimension {
/// The length of this dimension.
pub len: i32,
/// The base value used to index into this dimension.
pub lower_bound: i32,
}
/// An iterator over the values of an array, in row-major order.
pub struct ArrayValues<'a> {
remaining: i32,
buf: &'a [u8],
}
impl<'a> FallibleIterator for ArrayValues<'a> {
type Item = Option<&'a [u8]>;
type Error = StdBox<dyn Error + Sync + Send>;
#[inline]
fn next(&mut self) -> Result<Option<Option<&'a [u8]>>, StdBox<dyn Error + Sync + Send>> {
if self.remaining == 0 {
if !self.buf.is_empty() {
return Err("invalid message length: arrayvalue not drained".into());
}
return Ok(None);
}
self.remaining -= 1;
let len = self.buf.read_i32::<BigEndian>()?;
let val = if len < 0 {
None
} else {
if self.buf.len() < len as usize {
return Err("invalid value length".into());
}
let (val, buf) = self.buf.split_at(len as usize);
self.buf = buf;
Some(val)
};
Ok(Some(val))
}
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.remaining as usize;
(len, Some(len))
}
}
/// Serializes a Postgres ltree string
#[inline]
pub fn ltree_to_sql(v: &str, buf: &mut BytesMut) {
// A version number is prepended to an ltree string per spec
buf.put_u8(1);
// Append the rest of the query
buf.put_slice(v.as_bytes());
}
/// Deserialize a Postgres ltree string
#[inline]
pub fn ltree_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
match buf {
// Remove the version number from the front of the ltree per spec
[1u8, rest @ ..] => Ok(str::from_utf8(rest)?),
_ => Err("ltree version 1 only supported".into()),
}
}
/// Serializes a Postgres lquery string
#[inline]
pub fn lquery_to_sql(v: &str, buf: &mut BytesMut) {
// A version number is prepended to an lquery string per spec
buf.put_u8(1);
// Append the rest of the query
buf.put_slice(v.as_bytes());
}
/// Deserialize a Postgres lquery string
#[inline]
pub fn lquery_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
match buf {
// Remove the version number from the front of the lquery per spec
[1u8, rest @ ..] => Ok(str::from_utf8(rest)?),
_ => Err("lquery version 1 only supported".into()),
}
}
/// Serializes a Postgres ltxtquery string
#[inline]
pub fn ltxtquery_to_sql(v: &str, buf: &mut BytesMut) {
// A version number is prepended to an ltxtquery string per spec
buf.put_u8(1);
// Append the rest of the query
buf.put_slice(v.as_bytes());
}
/// Deserialize a Postgres ltxtquery string
#[inline]
pub fn ltxtquery_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
match buf {
// Remove the version number from the front of the ltxtquery per spec
[1u8, rest @ ..] => Ok(str::from_utf8(rest)?),
_ => Err("ltxtquery version 1 only supported".into()),
}
}

View File

@@ -0,0 +1,87 @@
use bytes::{Buf, BytesMut};
use super::*;
#[test]
fn ltree_sql() {
let mut query = vec![1u8];
query.extend_from_slice("A.B.C".as_bytes());
let mut buf = BytesMut::new();
ltree_to_sql("A.B.C", &mut buf);
assert_eq!(query.as_slice(), buf.chunk());
}
#[test]
fn ltree_str() {
let mut query = vec![1u8];
query.extend_from_slice("A.B.C".as_bytes());
assert!(ltree_from_sql(query.as_slice()).is_ok())
}
#[test]
fn ltree_wrong_version() {
let mut query = vec![2u8];
query.extend_from_slice("A.B.C".as_bytes());
assert!(ltree_from_sql(query.as_slice()).is_err())
}
#[test]
fn lquery_sql() {
let mut query = vec![1u8];
query.extend_from_slice("A.B.C".as_bytes());
let mut buf = BytesMut::new();
lquery_to_sql("A.B.C", &mut buf);
assert_eq!(query.as_slice(), buf.chunk());
}
#[test]
fn lquery_str() {
let mut query = vec![1u8];
query.extend_from_slice("A.B.C".as_bytes());
assert!(lquery_from_sql(query.as_slice()).is_ok())
}
#[test]
fn lquery_wrong_version() {
let mut query = vec![2u8];
query.extend_from_slice("A.B.C".as_bytes());
assert!(lquery_from_sql(query.as_slice()).is_err())
}
#[test]
fn ltxtquery_sql() {
let mut query = vec![1u8];
query.extend_from_slice("a & b*".as_bytes());
let mut buf = BytesMut::new();
ltree_to_sql("a & b*", &mut buf);
assert_eq!(query.as_slice(), buf.chunk());
}
#[test]
fn ltxtquery_str() {
let mut query = vec![1u8];
query.extend_from_slice("a & b*".as_bytes());
assert!(ltree_from_sql(query.as_slice()).is_ok())
}
#[test]
fn ltxtquery_wrong_version() {
let mut query = vec![2u8];
query.extend_from_slice("a & b*".as_bytes());
assert!(ltree_from_sql(query.as_slice()).is_err())
}

View File

@@ -0,0 +1,10 @@
[package]
name = "postgres-types2"
version = "0.1.0"
edition = "2018"
license = "MIT/Apache-2.0"
[dependencies]
bytes.workspace = true
fallible-iterator.workspace = true
postgres-protocol2 = { path = "../postgres-protocol2" }

View File

@@ -0,0 +1,477 @@
//! Conversions to and from Postgres types.
//!
//! This crate is used by the `tokio-postgres` and `postgres` crates. You normally don't need to depend directly on it
//! unless you want to define your own `ToSql` or `FromSql` definitions.
#![doc(html_root_url = "https://docs.rs/postgres-types/0.2")]
#![warn(clippy::all, rust_2018_idioms, missing_docs)]
use fallible_iterator::FallibleIterator;
use postgres_protocol2::types;
use std::any::type_name;
use std::error::Error;
use std::fmt;
use std::sync::Arc;
use crate::type_gen::{Inner, Other};
#[doc(inline)]
pub use postgres_protocol2::Oid;
use bytes::BytesMut;
/// Generates a simple implementation of `ToSql::accepts` which accepts the
/// types passed to it.
macro_rules! accepts {
($($expected:ident),+) => (
fn accepts(ty: &$crate::Type) -> bool {
matches!(*ty, $($crate::Type::$expected)|+)
}
)
}
/// Generates an implementation of `ToSql::to_sql_checked`.
///
/// All `ToSql` implementations should use this macro.
macro_rules! to_sql_checked {
() => {
fn to_sql_checked(
&self,
ty: &$crate::Type,
out: &mut $crate::private::BytesMut,
) -> ::std::result::Result<
$crate::IsNull,
Box<dyn ::std::error::Error + ::std::marker::Sync + ::std::marker::Send>,
> {
$crate::__to_sql_checked(self, ty, out)
}
};
}
// WARNING: this function is not considered part of this crate's public API.
// It is subject to change at any time.
#[doc(hidden)]
pub fn __to_sql_checked<T>(
v: &T,
ty: &Type,
out: &mut BytesMut,
) -> Result<IsNull, Box<dyn Error + Sync + Send>>
where
T: ToSql,
{
if !T::accepts(ty) {
return Err(Box::new(WrongType::new::<T>(ty.clone())));
}
v.to_sql(ty, out)
}
// mod pg_lsn;
#[doc(hidden)]
pub mod private;
// mod special;
mod type_gen;
/// A Postgres type.
#[derive(PartialEq, Eq, Clone, Hash)]
pub struct Type(Inner);
impl fmt::Debug for Type {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(&self.0, fmt)
}
}
impl fmt::Display for Type {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.schema() {
"public" | "pg_catalog" => {}
schema => write!(fmt, "{}.", schema)?,
}
fmt.write_str(self.name())
}
}
impl Type {
/// Creates a new `Type`.
pub fn new(name: String, oid: Oid, kind: Kind, schema: String) -> Type {
Type(Inner::Other(Arc::new(Other {
name,
oid,
kind,
schema,
})))
}
/// Returns the `Type` corresponding to the provided `Oid` if it
/// corresponds to a built-in type.
pub fn from_oid(oid: Oid) -> Option<Type> {
Inner::from_oid(oid).map(Type)
}
/// Returns the OID of the `Type`.
pub fn oid(&self) -> Oid {
self.0.oid()
}
/// Returns the kind of this type.
pub fn kind(&self) -> &Kind {
self.0.kind()
}
/// Returns the schema of this type.
pub fn schema(&self) -> &str {
match self.0 {
Inner::Other(ref u) => &u.schema,
_ => "pg_catalog",
}
}
/// Returns the name of this type.
pub fn name(&self) -> &str {
self.0.name()
}
}
/// Represents the kind of a Postgres type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum Kind {
/// A simple type like `VARCHAR` or `INTEGER`.
Simple,
/// An enumerated type along with its variants.
Enum(Vec<String>),
/// A pseudo-type.
Pseudo,
/// An array type along with the type of its elements.
Array(Type),
/// A range type along with the type of its elements.
Range(Type),
/// A multirange type along with the type of its elements.
Multirange(Type),
/// A domain type along with its underlying type.
Domain(Type),
/// A composite type along with information about its fields.
Composite(Vec<Field>),
}
/// Information about a field of a composite type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Field {
name: String,
type_: Type,
}
impl Field {
/// Creates a new `Field`.
pub fn new(name: String, type_: Type) -> Field {
Field { name, type_ }
}
/// Returns the name of the field.
pub fn name(&self) -> &str {
&self.name
}
/// Returns the type of the field.
pub fn type_(&self) -> &Type {
&self.type_
}
}
/// An error indicating that a `NULL` Postgres value was passed to a `FromSql`
/// implementation that does not support `NULL` values.
#[derive(Debug, Clone, Copy)]
pub struct WasNull;
impl fmt::Display for WasNull {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.write_str("a Postgres value was `NULL`")
}
}
impl Error for WasNull {}
/// An error indicating that a conversion was attempted between incompatible
/// Rust and Postgres types.
#[derive(Debug)]
pub struct WrongType {
postgres: Type,
rust: &'static str,
}
impl fmt::Display for WrongType {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
fmt,
"cannot convert between the Rust type `{}` and the Postgres type `{}`",
self.rust, self.postgres,
)
}
}
impl Error for WrongType {}
impl WrongType {
/// Creates a new `WrongType` error.
pub fn new<T>(ty: Type) -> WrongType {
WrongType {
postgres: ty,
rust: type_name::<T>(),
}
}
}
/// An error indicating that a as_text conversion was attempted on a binary
/// result.
#[derive(Debug)]
pub struct WrongFormat {}
impl Error for WrongFormat {}
impl fmt::Display for WrongFormat {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
fmt,
"cannot read column as text while it is in binary format"
)
}
}
/// A trait for types that can be created from a Postgres value.
pub trait FromSql<'a>: Sized {
/// Creates a new value of this type from a buffer of data of the specified
/// Postgres `Type` in its binary format.
///
/// The caller of this method is responsible for ensuring that this type
/// is compatible with the Postgres `Type`.
fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>>;
/// Creates a new value of this type from a `NULL` SQL value.
///
/// The caller of this method is responsible for ensuring that this type
/// is compatible with the Postgres `Type`.
///
/// The default implementation returns `Err(Box::new(WasNull))`.
#[allow(unused_variables)]
fn from_sql_null(ty: &Type) -> Result<Self, Box<dyn Error + Sync + Send>> {
Err(Box::new(WasNull))
}
/// A convenience function that delegates to `from_sql` and `from_sql_null` depending on the
/// value of `raw`.
fn from_sql_nullable(
ty: &Type,
raw: Option<&'a [u8]>,
) -> Result<Self, Box<dyn Error + Sync + Send>> {
match raw {
Some(raw) => Self::from_sql(ty, raw),
None => Self::from_sql_null(ty),
}
}
/// Determines if a value of this type can be created from the specified
/// Postgres `Type`.
fn accepts(ty: &Type) -> bool;
}
/// A trait for types which can be created from a Postgres value without borrowing any data.
///
/// This is primarily useful for trait bounds on functions.
pub trait FromSqlOwned: for<'a> FromSql<'a> {}
impl<T> FromSqlOwned for T where T: for<'a> FromSql<'a> {}
impl<'a, T: FromSql<'a>> FromSql<'a> for Option<T> {
fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Option<T>, Box<dyn Error + Sync + Send>> {
<T as FromSql>::from_sql(ty, raw).map(Some)
}
fn from_sql_null(_: &Type) -> Result<Option<T>, Box<dyn Error + Sync + Send>> {
Ok(None)
}
fn accepts(ty: &Type) -> bool {
<T as FromSql>::accepts(ty)
}
}
impl<'a, T: FromSql<'a>> FromSql<'a> for Vec<T> {
fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Vec<T>, Box<dyn Error + Sync + Send>> {
let member_type = match *ty.kind() {
Kind::Array(ref member) => member,
_ => panic!("expected array type"),
};
let array = types::array_from_sql(raw)?;
if array.dimensions().count()? > 1 {
return Err("array contains too many dimensions".into());
}
array
.values()
.map(|v| T::from_sql_nullable(member_type, v))
.collect()
}
fn accepts(ty: &Type) -> bool {
match *ty.kind() {
Kind::Array(ref inner) => T::accepts(inner),
_ => false,
}
}
}
impl<'a> FromSql<'a> for String {
fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<String, Box<dyn Error + Sync + Send>> {
<&str as FromSql>::from_sql(ty, raw).map(ToString::to_string)
}
fn accepts(ty: &Type) -> bool {
<&str as FromSql>::accepts(ty)
}
}
impl<'a> FromSql<'a> for &'a str {
fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<&'a str, Box<dyn Error + Sync + Send>> {
match *ty {
ref ty if ty.name() == "ltree" => types::ltree_from_sql(raw),
ref ty if ty.name() == "lquery" => types::lquery_from_sql(raw),
ref ty if ty.name() == "ltxtquery" => types::ltxtquery_from_sql(raw),
_ => types::text_from_sql(raw),
}
}
fn accepts(ty: &Type) -> bool {
match *ty {
Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true,
ref ty
if (ty.name() == "citext"
|| ty.name() == "ltree"
|| ty.name() == "lquery"
|| ty.name() == "ltxtquery") =>
{
true
}
_ => false,
}
}
}
macro_rules! simple_from {
($t:ty, $f:ident, $($expected:ident),+) => {
impl<'a> FromSql<'a> for $t {
fn from_sql(_: &Type, raw: &'a [u8]) -> Result<$t, Box<dyn Error + Sync + Send>> {
types::$f(raw)
}
accepts!($($expected),+);
}
}
}
simple_from!(i8, char_from_sql, CHAR);
simple_from!(u32, oid_from_sql, OID);
/// An enum representing the nullability of a Postgres value.
pub enum IsNull {
/// The value is NULL.
Yes,
/// The value is not NULL.
No,
}
/// A trait for types that can be converted into Postgres values.
pub trait ToSql: fmt::Debug {
/// Converts the value of `self` into the binary format of the specified
/// Postgres `Type`, appending it to `out`.
///
/// The caller of this method is responsible for ensuring that this type
/// is compatible with the Postgres `Type`.
///
/// The return value indicates if this value should be represented as
/// `NULL`. If this is the case, implementations **must not** write
/// anything to `out`.
fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>>
where
Self: Sized;
/// Determines if a value of this type can be converted to the specified
/// Postgres `Type`.
fn accepts(ty: &Type) -> bool
where
Self: Sized;
/// An adaptor method used internally by Rust-Postgres.
///
/// *All* implementations of this method should be generated by the
/// `to_sql_checked!()` macro.
fn to_sql_checked(
&self,
ty: &Type,
out: &mut BytesMut,
) -> Result<IsNull, Box<dyn Error + Sync + Send>>;
/// Specify the encode format
fn encode_format(&self, _ty: &Type) -> Format {
Format::Binary
}
}
/// Supported Postgres message format types
///
/// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8`
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Format {
/// Text format (UTF-8)
Text,
/// Compact, typed binary format
Binary,
}
impl ToSql for &str {
fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
match *ty {
ref ty if ty.name() == "ltree" => types::ltree_to_sql(self, w),
ref ty if ty.name() == "lquery" => types::lquery_to_sql(self, w),
ref ty if ty.name() == "ltxtquery" => types::ltxtquery_to_sql(self, w),
_ => types::text_to_sql(self, w),
}
Ok(IsNull::No)
}
fn accepts(ty: &Type) -> bool {
match *ty {
Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true,
ref ty
if (ty.name() == "citext"
|| ty.name() == "ltree"
|| ty.name() == "lquery"
|| ty.name() == "ltxtquery") =>
{
true
}
_ => false,
}
}
to_sql_checked!();
}
macro_rules! simple_to {
($t:ty, $f:ident, $($expected:ident),+) => {
impl ToSql for $t {
fn to_sql(&self,
_: &Type,
w: &mut BytesMut)
-> Result<IsNull, Box<dyn Error + Sync + Send>> {
types::$f(*self, w);
Ok(IsNull::No)
}
accepts!($($expected),+);
to_sql_checked!();
}
}
}
simple_to!(u32, oid_to_sql, OID);

View File

@@ -0,0 +1,34 @@
use crate::{FromSql, Type};
pub use bytes::BytesMut;
use std::error::Error;
pub fn read_be_i32(buf: &mut &[u8]) -> Result<i32, Box<dyn Error + Sync + Send>> {
if buf.len() < 4 {
return Err("invalid buffer size".into());
}
let mut bytes = [0; 4];
bytes.copy_from_slice(&buf[..4]);
*buf = &buf[4..];
Ok(i32::from_be_bytes(bytes))
}
pub fn read_value<'a, T>(
type_: &Type,
buf: &mut &'a [u8],
) -> Result<T, Box<dyn Error + Sync + Send>>
where
T: FromSql<'a>,
{
let len = read_be_i32(buf)?;
let value = if len < 0 {
None
} else {
if len as usize > buf.len() {
return Err("invalid buffer size".into());
}
let (head, tail) = buf.split_at(len as usize);
*buf = tail;
Some(head)
};
T::from_sql_nullable(type_, value)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,21 @@
[package]
name = "tokio-postgres2"
version = "0.1.0"
edition = "2018"
license = "MIT/Apache-2.0"
[dependencies]
async-trait.workspace = true
bytes.workspace = true
byteorder.workspace = true
fallible-iterator.workspace = true
futures-util = { workspace = true, features = ["sink"] }
log = "0.4"
parking_lot.workspace = true
percent-encoding = "2.0"
pin-project-lite.workspace = true
phf = "0.11"
postgres-protocol2 = { path = "../postgres-protocol2" }
postgres-types2 = { path = "../postgres-types2" }
tokio = { workspace = true, features = ["io-util", "time", "net"] }
tokio-util = { workspace = true, features = ["codec"] }

View File

@@ -0,0 +1,40 @@
use tokio::net::TcpStream;
use crate::client::SocketConfig;
use crate::config::{Host, SslMode};
use crate::tls::MakeTlsConnect;
use crate::{cancel_query_raw, connect_socket, Error};
use std::io;
pub(crate) async fn cancel_query<T>(
config: Option<SocketConfig>,
ssl_mode: SslMode,
mut tls: T,
process_id: i32,
secret_key: i32,
) -> Result<(), Error>
where
T: MakeTlsConnect<TcpStream>,
{
let config = match config {
Some(config) => config,
None => {
return Err(Error::connect(io::Error::new(
io::ErrorKind::InvalidInput,
"unknown host",
)))
}
};
let hostname = match &config.host {
Host::Tcp(host) => &**host,
};
let tls = tls
.make_tls_connect(hostname)
.map_err(|e| Error::tls(e.into()))?;
let socket =
connect_socket::connect_socket(&config.host, config.port, config.connect_timeout).await?;
cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, process_id, secret_key).await
}

View File

@@ -0,0 +1,29 @@
use crate::config::SslMode;
use crate::tls::TlsConnect;
use crate::{connect_tls, Error};
use bytes::BytesMut;
use postgres_protocol2::message::frontend;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
pub async fn cancel_query_raw<S, T>(
stream: S,
mode: SslMode,
tls: T,
process_id: i32,
secret_key: i32,
) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
let mut stream = connect_tls::connect_tls(stream, mode, tls).await?;
let mut buf = BytesMut::new();
frontend::cancel_request(process_id, secret_key, &mut buf);
stream.write_all(&buf).await.map_err(Error::io)?;
stream.flush().await.map_err(Error::io)?;
stream.shutdown().await.map_err(Error::io)?;
Ok(())
}

View File

@@ -0,0 +1,62 @@
use crate::config::SslMode;
use crate::tls::TlsConnect;
use crate::{cancel_query, client::SocketConfig, tls::MakeTlsConnect};
use crate::{cancel_query_raw, Error};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
/// The capability to request cancellation of in-progress queries on a
/// connection.
#[derive(Clone)]
pub struct CancelToken {
pub(crate) socket_config: Option<SocketConfig>,
pub(crate) ssl_mode: SslMode,
pub(crate) process_id: i32,
pub(crate) secret_key: i32,
}
impl CancelToken {
/// Attempts to cancel the in-progress query on the connection associated
/// with this `CancelToken`.
///
/// The server provides no information about whether a cancellation attempt was successful or not. An error will
/// only be returned if the client was unable to connect to the database.
///
/// Cancellation is inherently racy. There is no guarantee that the
/// cancellation request will reach the server before the query terminates
/// normally, or that the connection associated with this token is still
/// active.
///
/// Requires the `runtime` Cargo feature (enabled by default).
pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error>
where
T: MakeTlsConnect<TcpStream>,
{
cancel_query::cancel_query(
self.socket_config.clone(),
self.ssl_mode,
tls,
self.process_id,
self.secret_key,
)
.await
}
/// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new
/// connection itself.
pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
cancel_query_raw::cancel_query_raw(
stream,
self.ssl_mode,
tls,
self.process_id,
self.secret_key,
)
.await
}
}

View File

@@ -0,0 +1,439 @@
use crate::codec::{BackendMessages, FrontendMessage};
use crate::config::Host;
use crate::config::SslMode;
use crate::connection::{Request, RequestMessages};
use crate::query::RowStream;
use crate::simple_query::SimpleQueryStream;
use crate::types::{Oid, ToSql, Type};
use crate::{
prepare, query, simple_query, slice_iter, CancelToken, Error, ReadyForQueryStatus, Row,
SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder,
};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{future, ready, TryStreamExt};
use parking_lot::Mutex;
use postgres_protocol2::message::{backend::Message, frontend};
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::sync::mpsc;
use std::time::Duration;
pub struct Responses {
receiver: mpsc::Receiver<BackendMessages>,
cur: BackendMessages,
}
impl Responses {
pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
loop {
match self.cur.next().map_err(Error::parse)? {
Some(Message::ErrorResponse(body)) => return Poll::Ready(Err(Error::db(body))),
Some(message) => return Poll::Ready(Ok(message)),
None => {}
}
match ready!(self.receiver.poll_recv(cx)) {
Some(messages) => self.cur = messages,
None => return Poll::Ready(Err(Error::closed())),
}
}
}
pub async fn next(&mut self) -> Result<Message, Error> {
future::poll_fn(|cx| self.poll_next(cx)).await
}
}
/// A cache of type info and prepared statements for fetching type info
/// (corresponding to the queries in the [prepare] module).
#[derive(Default)]
struct CachedTypeInfo {
/// A statement for basic information for a type from its
/// OID. Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_QUERY) (or its
/// fallback).
typeinfo: Option<Statement>,
/// A statement for getting information for a composite type from its OID.
/// Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY).
typeinfo_composite: Option<Statement>,
/// A statement for getting information for a composite type from its OID.
/// Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY) (or
/// its fallback).
typeinfo_enum: Option<Statement>,
/// Cache of types already looked up.
types: HashMap<Oid, Type>,
}
pub struct InnerClient {
sender: mpsc::UnboundedSender<Request>,
cached_typeinfo: Mutex<CachedTypeInfo>,
/// A buffer to use when writing out postgres commands.
buffer: Mutex<BytesMut>,
}
impl InnerClient {
pub fn send(&self, messages: RequestMessages) -> Result<Responses, Error> {
let (sender, receiver) = mpsc::channel(1);
let request = Request { messages, sender };
self.sender.send(request).map_err(|_| Error::closed())?;
Ok(Responses {
receiver,
cur: BackendMessages::empty(),
})
}
pub fn typeinfo(&self) -> Option<Statement> {
self.cached_typeinfo.lock().typeinfo.clone()
}
pub fn set_typeinfo(&self, statement: &Statement) {
self.cached_typeinfo.lock().typeinfo = Some(statement.clone());
}
pub fn typeinfo_composite(&self) -> Option<Statement> {
self.cached_typeinfo.lock().typeinfo_composite.clone()
}
pub fn set_typeinfo_composite(&self, statement: &Statement) {
self.cached_typeinfo.lock().typeinfo_composite = Some(statement.clone());
}
pub fn typeinfo_enum(&self) -> Option<Statement> {
self.cached_typeinfo.lock().typeinfo_enum.clone()
}
pub fn set_typeinfo_enum(&self, statement: &Statement) {
self.cached_typeinfo.lock().typeinfo_enum = Some(statement.clone());
}
pub fn type_(&self, oid: Oid) -> Option<Type> {
self.cached_typeinfo.lock().types.get(&oid).cloned()
}
pub fn set_type(&self, oid: Oid, type_: &Type) {
self.cached_typeinfo.lock().types.insert(oid, type_.clone());
}
/// Call the given function with a buffer to be used when writing out
/// postgres commands.
pub fn with_buf<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut BytesMut) -> R,
{
let mut buffer = self.buffer.lock();
let r = f(&mut buffer);
buffer.clear();
r
}
}
#[derive(Clone)]
pub(crate) struct SocketConfig {
pub host: Host,
pub port: u16,
pub connect_timeout: Option<Duration>,
// pub keepalive: Option<KeepaliveConfig>,
}
/// An asynchronous PostgreSQL client.
///
/// The client is one half of what is returned when a connection is established. Users interact with the database
/// through this client object.
pub struct Client {
inner: Arc<InnerClient>,
socket_config: Option<SocketConfig>,
ssl_mode: SslMode,
process_id: i32,
secret_key: i32,
}
impl Client {
pub(crate) fn new(
sender: mpsc::UnboundedSender<Request>,
ssl_mode: SslMode,
process_id: i32,
secret_key: i32,
) -> Client {
Client {
inner: Arc::new(InnerClient {
sender,
cached_typeinfo: Default::default(),
buffer: Default::default(),
}),
socket_config: None,
ssl_mode,
process_id,
secret_key,
}
}
/// Returns process_id.
pub fn get_process_id(&self) -> i32 {
self.process_id
}
pub(crate) fn inner(&self) -> &Arc<InnerClient> {
&self.inner
}
pub(crate) fn set_socket_config(&mut self, socket_config: SocketConfig) {
self.socket_config = Some(socket_config);
}
/// Creates a new prepared statement.
///
/// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc),
/// which are set when executed. Prepared statements can only be used with the connection that created them.
pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
self.prepare_typed(query, &[]).await
}
/// Like `prepare`, but allows the types of query parameters to be explicitly specified.
///
/// The list of types may be smaller than the number of parameters - the types of the remaining parameters will be
/// inferred. For example, `client.prepare_typed(query, &[])` is equivalent to `client.prepare(query)`.
pub async fn prepare_typed(
&self,
query: &str,
parameter_types: &[Type],
) -> Result<Statement, Error> {
prepare::prepare(&self.inner, query, parameter_types).await
}
/// Executes a statement, returning a vector of the resulting rows.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
pub async fn query<T>(
&self,
statement: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<Vec<Row>, Error>
where
T: ?Sized + ToStatement,
{
self.query_raw(statement, slice_iter(params))
.await?
.try_collect()
.await
}
/// The maximally flexible version of [`query`].
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
///
/// [`query`]: #method.query
pub async fn query_raw<'a, T, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
{
let statement = statement.__convert().into_statement(self).await?;
query::query(&self.inner, statement, params).await
}
/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
/// to save a roundtrip
pub async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
{
query::query_txt(&self.inner, statement, params).await
}
/// Executes a statement, returning the number of rows modified.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
pub async fn execute<T>(
&self,
statement: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<u64, Error>
where
T: ?Sized + ToStatement,
{
self.execute_raw(statement, slice_iter(params)).await
}
/// The maximally flexible version of [`execute`].
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
///
/// [`execute`]: #method.execute
pub async fn execute_raw<'a, T, I>(&self, statement: &T, params: I) -> Result<u64, Error>
where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
{
let statement = statement.__convert().into_statement(self).await?;
query::execute(self.inner(), statement, params).await
}
/// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
///
/// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
/// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings,
/// so the associated row type doesn't work with the `FromSql` trait. Rather than simply returning a list of the
/// rows, this method returns a list of an enum which indicates either the completion of one of the commands,
/// or a row of data. This preserves the framing between the separate statements in the request.
///
/// # Warning
///
/// Prepared statements should be use for any query which contains user-specified data, as they provided the
/// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
/// them to this method!
pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
self.simple_query_raw(query).await?.try_collect().await
}
pub(crate) async fn simple_query_raw(&self, query: &str) -> Result<SimpleQueryStream, Error> {
simple_query::simple_query(self.inner(), query).await
}
/// Executes a sequence of SQL statements using the simple query protocol.
///
/// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
/// point. This is intended for use when, for example, initializing a database schema.
///
/// # Warning
///
/// Prepared statements should be use for any query which contains user-specified data, as they provided the
/// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
/// them to this method!
pub async fn batch_execute(&self, query: &str) -> Result<ReadyForQueryStatus, Error> {
simple_query::batch_execute(self.inner(), query).await
}
/// Begins a new database transaction.
///
/// The transaction will roll back by default - use the `commit` method to commit it.
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
struct RollbackIfNotDone<'me> {
client: &'me Client,
done: bool,
}
impl Drop for RollbackIfNotDone<'_> {
fn drop(&mut self) {
if self.done {
return;
}
let buf = self.client.inner().with_buf(|buf| {
frontend::query("ROLLBACK", buf).unwrap();
buf.split().freeze()
});
let _ = self
.client
.inner()
.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
}
}
// This is done, as `Future` created by this method can be dropped after
// `RequestMessages` is synchronously send to the `Connection` by
// `batch_execute()`, but before `Responses` is asynchronously polled to
// completion. In that case `Transaction` won't be created and thus
// won't be rolled back.
{
let mut cleaner = RollbackIfNotDone {
client: self,
done: false,
};
self.batch_execute("BEGIN").await?;
cleaner.done = true;
}
Ok(Transaction::new(self))
}
/// Returns a builder for a transaction with custom settings.
///
/// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other
/// attributes.
pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
TransactionBuilder::new(self)
}
/// Constructs a cancellation token that can later be used to request cancellation of a query running on the
/// connection associated with this client.
pub fn cancel_token(&self) -> CancelToken {
CancelToken {
socket_config: self.socket_config.clone(),
ssl_mode: self.ssl_mode,
process_id: self.process_id,
secret_key: self.secret_key,
}
}
/// Query for type information
pub async fn get_type(&self, oid: Oid) -> Result<Type, Error> {
crate::prepare::get_type(&self.inner, oid).await
}
/// Determines if the connection to the server has already closed.
///
/// In that case, all future queries will fail.
pub fn is_closed(&self) -> bool {
self.inner.sender.is_closed()
}
}
impl fmt::Debug for Client {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Client").finish()
}
}

View File

@@ -0,0 +1,109 @@
use bytes::{Buf, Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use postgres_protocol2::message::backend;
use postgres_protocol2::message::frontend::CopyData;
use std::io;
use tokio_util::codec::{Decoder, Encoder};
pub enum FrontendMessage {
Raw(Bytes),
CopyData(CopyData<Box<dyn Buf + Send>>),
}
pub enum BackendMessage {
Normal {
messages: BackendMessages,
request_complete: bool,
},
Async(backend::Message),
}
pub struct BackendMessages(BytesMut);
impl BackendMessages {
pub fn empty() -> BackendMessages {
BackendMessages(BytesMut::new())
}
}
impl FallibleIterator for BackendMessages {
type Item = backend::Message;
type Error = io::Error;
fn next(&mut self) -> io::Result<Option<backend::Message>> {
backend::Message::parse(&mut self.0)
}
}
pub struct PostgresCodec {
pub max_message_size: Option<usize>,
}
impl Encoder<FrontendMessage> for PostgresCodec {
type Error = io::Error;
fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> {
match item {
FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf),
FrontendMessage::CopyData(data) => data.write(dst),
}
Ok(())
}
}
impl Decoder for PostgresCodec {
type Item = BackendMessage;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BackendMessage>, io::Error> {
let mut idx = 0;
let mut request_complete = false;
while let Some(header) = backend::Header::parse(&src[idx..])? {
let len = header.len() as usize + 1;
if src[idx..].len() < len {
break;
}
if let Some(max) = self.max_message_size {
if len > max {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"message too large",
));
}
}
match header.tag() {
backend::NOTICE_RESPONSE_TAG
| backend::NOTIFICATION_RESPONSE_TAG
| backend::PARAMETER_STATUS_TAG => {
if idx == 0 {
let message = backend::Message::parse(src)?.unwrap();
return Ok(Some(BackendMessage::Async(message)));
} else {
break;
}
}
_ => {}
}
idx += len;
if header.tag() == backend::READY_FOR_QUERY_TAG {
request_complete = true;
break;
}
}
if idx == 0 {
Ok(None)
} else {
Ok(Some(BackendMessage::Normal {
messages: BackendMessages(src.split_to(idx)),
request_complete,
}))
}
}
}

View File

@@ -0,0 +1,897 @@
//! Connection configuration.
use crate::connect::connect;
use crate::connect_raw::connect_raw;
use crate::tls::MakeTlsConnect;
use crate::tls::TlsConnect;
use crate::{Client, Connection, Error};
use std::borrow::Cow;
use std::str;
use std::str::FromStr;
use std::time::Duration;
use std::{error, fmt, iter, mem};
use tokio::io::{AsyncRead, AsyncWrite};
pub use postgres_protocol2::authentication::sasl::ScramKeys;
use tokio::net::TcpStream;
/// Properties required of a session.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum TargetSessionAttrs {
/// No special properties are required.
Any,
/// The session must allow writes.
ReadWrite,
}
/// TLS configuration.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum SslMode {
/// Do not use TLS.
Disable,
/// Attempt to connect with TLS but allow sessions without.
Prefer,
/// Require the use of TLS.
Require,
}
/// Channel binding configuration.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ChannelBinding {
/// Do not use channel binding.
Disable,
/// Attempt to use channel binding but allow sessions without.
Prefer,
/// Require the use of channel binding.
Require,
}
/// Replication mode configuration.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ReplicationMode {
/// Physical replication.
Physical,
/// Logical replication.
Logical,
}
/// A host specification.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Host {
/// A TCP hostname.
Tcp(String),
}
/// Precomputed keys which may override password during auth.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthKeys {
/// A `ClientKey` & `ServerKey` pair for `SCRAM-SHA-256`.
ScramSha256(ScramKeys<32>),
}
/// Connection configuration.
///
/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats:
///
/// # Key-Value
///
/// This format consists of space-separated key-value pairs. Values which are either the empty string or contain
/// whitespace should be wrapped in `'`. `'` and `\` characters should be backslash-escaped.
///
/// ## Keys
///
/// * `user` - The username to authenticate with. Required.
/// * `password` - The password to authenticate with.
/// * `dbname` - The name of the database to connect to. Defaults to the username.
/// * `options` - Command line options used to configure the server.
/// * `application_name` - Sets the `application_name` parameter on the server.
/// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used
/// if available, but not used otherwise. If set to `require`, TLS will be forced to be used. Defaults to `prefer`.
/// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the
/// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts
/// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting
/// with the `connect` method.
/// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be
/// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if
/// omitted or the empty string.
/// * `connect_timeout` - The time limit in seconds applied to each socket-level connection attempt. Note that hostnames
/// can resolve to multiple IP addresses, and this limit is applied to each address. Defaults to no timeout.
/// * `target_session_attrs` - Specifies requirements of the session. If set to `read-write`, the client will check that
/// the `transaction_read_write` session parameter is set to `on`. This can be used to connect to the primary server
/// in a database cluster as opposed to the secondary read-only mirrors. Defaults to `all`.
/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel
/// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise.
/// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`.
///
/// ## Examples
///
/// ```not_rust
/// host=localhost user=postgres connect_timeout=10 keepalives=0
/// ```
///
/// ```not_rust
/// host=/var/lib/postgresql,localhost port=1234 user=postgres password='password with spaces'
/// ```
///
/// ```not_rust
/// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write
/// ```
///
/// # Url
///
/// This format resembles a URL with a scheme of either `postgres://` or `postgresql://`. All components are optional,
/// and the format accepts query parameters for all of the key-value pairs described in the section above. Multiple
/// host/port pairs can be comma-separated. Unix socket paths in the host section of the URL should be percent-encoded,
/// as the path component of the URL specifies the database name.
///
/// ## Examples
///
/// ```not_rust
/// postgresql://user@localhost
/// ```
///
/// ```not_rust
/// postgresql://user:password@%2Fvar%2Flib%2Fpostgresql/mydb?connect_timeout=10
/// ```
///
/// ```not_rust
/// postgresql://user@host1:1234,host2,host3:5678?target_session_attrs=read-write
/// ```
///
/// ```not_rust
/// postgresql:///mydb?user=user&host=/var/lib/postgresql
/// ```
#[derive(Clone, PartialEq, Eq)]
pub struct Config {
pub(crate) user: Option<String>,
pub(crate) password: Option<Vec<u8>>,
pub(crate) auth_keys: Option<Box<AuthKeys>>,
pub(crate) dbname: Option<String>,
pub(crate) options: Option<String>,
pub(crate) application_name: Option<String>,
pub(crate) ssl_mode: SslMode,
pub(crate) host: Vec<Host>,
pub(crate) port: Vec<u16>,
pub(crate) connect_timeout: Option<Duration>,
pub(crate) target_session_attrs: TargetSessionAttrs,
pub(crate) channel_binding: ChannelBinding,
pub(crate) replication_mode: Option<ReplicationMode>,
pub(crate) max_backend_message_size: Option<usize>,
}
impl Default for Config {
fn default() -> Config {
Config::new()
}
}
impl Config {
/// Creates a new configuration.
pub fn new() -> Config {
Config {
user: None,
password: None,
auth_keys: None,
dbname: None,
options: None,
application_name: None,
ssl_mode: SslMode::Prefer,
host: vec![],
port: vec![],
connect_timeout: None,
target_session_attrs: TargetSessionAttrs::Any,
channel_binding: ChannelBinding::Prefer,
replication_mode: None,
max_backend_message_size: None,
}
}
/// Sets the user to authenticate with.
///
/// Required.
pub fn user(&mut self, user: &str) -> &mut Config {
self.user = Some(user.to_string());
self
}
/// Gets the user to authenticate with, if one has been configured with
/// the `user` method.
pub fn get_user(&self) -> Option<&str> {
self.user.as_deref()
}
/// Sets the password to authenticate with.
pub fn password<T>(&mut self, password: T) -> &mut Config
where
T: AsRef<[u8]>,
{
self.password = Some(password.as_ref().to_vec());
self
}
/// Gets the password to authenticate with, if one has been configured with
/// the `password` method.
pub fn get_password(&self) -> Option<&[u8]> {
self.password.as_deref()
}
/// Sets precomputed protocol-specific keys to authenticate with.
/// When set, this option will override `password`.
/// See [`AuthKeys`] for more information.
pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config {
self.auth_keys = Some(Box::new(keys));
self
}
/// Gets precomputed protocol-specific keys to authenticate with.
/// if one has been configured with the `auth_keys` method.
pub fn get_auth_keys(&self) -> Option<AuthKeys> {
self.auth_keys.as_deref().copied()
}
/// Sets the name of the database to connect to.
///
/// Defaults to the user.
pub fn dbname(&mut self, dbname: &str) -> &mut Config {
self.dbname = Some(dbname.to_string());
self
}
/// Gets the name of the database to connect to, if one has been configured
/// with the `dbname` method.
pub fn get_dbname(&self) -> Option<&str> {
self.dbname.as_deref()
}
/// Sets command line options used to configure the server.
pub fn options(&mut self, options: &str) -> &mut Config {
self.options = Some(options.to_string());
self
}
/// Gets the command line options used to configure the server, if the
/// options have been set with the `options` method.
pub fn get_options(&self) -> Option<&str> {
self.options.as_deref()
}
/// Sets the value of the `application_name` runtime parameter.
pub fn application_name(&mut self, application_name: &str) -> &mut Config {
self.application_name = Some(application_name.to_string());
self
}
/// Gets the value of the `application_name` runtime parameter, if it has
/// been set with the `application_name` method.
pub fn get_application_name(&self) -> Option<&str> {
self.application_name.as_deref()
}
/// Sets the SSL configuration.
///
/// Defaults to `prefer`.
pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
self.ssl_mode = ssl_mode;
self
}
/// Gets the SSL configuration.
pub fn get_ssl_mode(&self) -> SslMode {
self.ssl_mode
}
/// Adds a host to the configuration.
///
/// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order.
pub fn host(&mut self, host: &str) -> &mut Config {
self.host.push(Host::Tcp(host.to_string()));
self
}
/// Gets the hosts that have been added to the configuration with `host`.
pub fn get_hosts(&self) -> &[Host] {
&self.host
}
/// Adds a port to the configuration.
///
/// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which
/// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports
/// as hosts.
pub fn port(&mut self, port: u16) -> &mut Config {
self.port.push(port);
self
}
/// Gets the ports that have been added to the configuration with `port`.
pub fn get_ports(&self) -> &[u16] {
&self.port
}
/// Sets the timeout applied to socket-level connection attempts.
///
/// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each
/// host separately. Defaults to no limit.
pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
self.connect_timeout = Some(connect_timeout);
self
}
/// Gets the connection timeout, if one has been set with the
/// `connect_timeout` method.
pub fn get_connect_timeout(&self) -> Option<&Duration> {
self.connect_timeout.as_ref()
}
/// Sets the requirements of the session.
///
/// This can be used to connect to the primary server in a clustered database rather than one of the read-only
/// secondary servers. Defaults to `Any`.
pub fn target_session_attrs(
&mut self,
target_session_attrs: TargetSessionAttrs,
) -> &mut Config {
self.target_session_attrs = target_session_attrs;
self
}
/// Gets the requirements of the session.
pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
self.target_session_attrs
}
/// Sets the channel binding behavior.
///
/// Defaults to `prefer`.
pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
self.channel_binding = channel_binding;
self
}
/// Gets the channel binding behavior.
pub fn get_channel_binding(&self) -> ChannelBinding {
self.channel_binding
}
/// Set replication mode.
pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config {
self.replication_mode = Some(replication_mode);
self
}
/// Get replication mode.
pub fn get_replication_mode(&self) -> Option<ReplicationMode> {
self.replication_mode
}
/// Set limit for backend messages size.
pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config {
self.max_backend_message_size = Some(max_backend_message_size);
self
}
/// Get limit for backend messages size.
pub fn get_max_backend_message_size(&self) -> Option<usize> {
self.max_backend_message_size
}
fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
match key {
"user" => {
self.user(value);
}
"password" => {
self.password(value);
}
"dbname" => {
self.dbname(value);
}
"options" => {
self.options(value);
}
"application_name" => {
self.application_name(value);
}
"sslmode" => {
let mode = match value {
"disable" => SslMode::Disable,
"prefer" => SslMode::Prefer,
"require" => SslMode::Require,
_ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))),
};
self.ssl_mode(mode);
}
"host" => {
for host in value.split(',') {
self.host(host);
}
}
"port" => {
for port in value.split(',') {
let port = if port.is_empty() {
5432
} else {
port.parse()
.map_err(|_| Error::config_parse(Box::new(InvalidValue("port"))))?
};
self.port(port);
}
}
"connect_timeout" => {
let timeout = value
.parse::<i64>()
.map_err(|_| Error::config_parse(Box::new(InvalidValue("connect_timeout"))))?;
if timeout > 0 {
self.connect_timeout(Duration::from_secs(timeout as u64));
}
}
"target_session_attrs" => {
let target_session_attrs = match value {
"any" => TargetSessionAttrs::Any,
"read-write" => TargetSessionAttrs::ReadWrite,
_ => {
return Err(Error::config_parse(Box::new(InvalidValue(
"target_session_attrs",
))));
}
};
self.target_session_attrs(target_session_attrs);
}
"channel_binding" => {
let channel_binding = match value {
"disable" => ChannelBinding::Disable,
"prefer" => ChannelBinding::Prefer,
"require" => ChannelBinding::Require,
_ => {
return Err(Error::config_parse(Box::new(InvalidValue(
"channel_binding",
))))
}
};
self.channel_binding(channel_binding);
}
"max_backend_message_size" => {
let limit = value.parse::<usize>().map_err(|_| {
Error::config_parse(Box::new(InvalidValue("max_backend_message_size")))
})?;
if limit > 0 {
self.max_backend_message_size(limit);
}
}
key => {
return Err(Error::config_parse(Box::new(UnknownOption(
key.to_string(),
))));
}
}
Ok(())
}
/// Opens a connection to a PostgreSQL database.
///
/// Requires the `runtime` Cargo feature (enabled by default).
pub async fn connect<T>(
&self,
tls: T,
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
where
T: MakeTlsConnect<TcpStream>,
{
connect(tls, self).await
}
/// Connects to a PostgreSQL database over an arbitrary stream.
///
/// All of the settings other than `user`, `password`, `dbname`, `options`, and `application_name` name are ignored.
pub async fn connect_raw<S, T>(
&self,
stream: S,
tls: T,
) -> Result<(Client, Connection<S, T::Stream>), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
connect_raw(stream, tls, self).await
}
}
impl FromStr for Config {
type Err = Error;
fn from_str(s: &str) -> Result<Config, Error> {
match UrlParser::parse(s)? {
Some(config) => Ok(config),
None => Parser::parse(s),
}
}
}
// Omit password from debug output
impl fmt::Debug for Config {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
struct Redaction {}
impl fmt::Debug for Redaction {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "_")
}
}
f.debug_struct("Config")
.field("user", &self.user)
.field("password", &self.password.as_ref().map(|_| Redaction {}))
.field("dbname", &self.dbname)
.field("options", &self.options)
.field("application_name", &self.application_name)
.field("ssl_mode", &self.ssl_mode)
.field("host", &self.host)
.field("port", &self.port)
.field("connect_timeout", &self.connect_timeout)
.field("target_session_attrs", &self.target_session_attrs)
.field("channel_binding", &self.channel_binding)
.field("replication", &self.replication_mode)
.finish()
}
}
#[derive(Debug)]
struct UnknownOption(String);
impl fmt::Display for UnknownOption {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(fmt, "unknown option `{}`", self.0)
}
}
impl error::Error for UnknownOption {}
#[derive(Debug)]
struct InvalidValue(&'static str);
impl fmt::Display for InvalidValue {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(fmt, "invalid value for option `{}`", self.0)
}
}
impl error::Error for InvalidValue {}
struct Parser<'a> {
s: &'a str,
it: iter::Peekable<str::CharIndices<'a>>,
}
impl<'a> Parser<'a> {
fn parse(s: &'a str) -> Result<Config, Error> {
let mut parser = Parser {
s,
it: s.char_indices().peekable(),
};
let mut config = Config::new();
while let Some((key, value)) = parser.parameter()? {
config.param(key, &value)?;
}
Ok(config)
}
fn skip_ws(&mut self) {
self.take_while(char::is_whitespace);
}
fn take_while<F>(&mut self, f: F) -> &'a str
where
F: Fn(char) -> bool,
{
let start = match self.it.peek() {
Some(&(i, _)) => i,
None => return "",
};
loop {
match self.it.peek() {
Some(&(_, c)) if f(c) => {
self.it.next();
}
Some(&(i, _)) => return &self.s[start..i],
None => return &self.s[start..],
}
}
}
fn eat(&mut self, target: char) -> Result<(), Error> {
match self.it.next() {
Some((_, c)) if c == target => Ok(()),
Some((i, c)) => {
let m = format!(
"unexpected character at byte {}: expected `{}` but got `{}`",
i, target, c
);
Err(Error::config_parse(m.into()))
}
None => Err(Error::config_parse("unexpected EOF".into())),
}
}
fn eat_if(&mut self, target: char) -> bool {
match self.it.peek() {
Some(&(_, c)) if c == target => {
self.it.next();
true
}
_ => false,
}
}
fn keyword(&mut self) -> Option<&'a str> {
let s = self.take_while(|c| match c {
c if c.is_whitespace() => false,
'=' => false,
_ => true,
});
if s.is_empty() {
None
} else {
Some(s)
}
}
fn value(&mut self) -> Result<String, Error> {
let value = if self.eat_if('\'') {
let value = self.quoted_value()?;
self.eat('\'')?;
value
} else {
self.simple_value()?
};
Ok(value)
}
fn simple_value(&mut self) -> Result<String, Error> {
let mut value = String::new();
while let Some(&(_, c)) = self.it.peek() {
if c.is_whitespace() {
break;
}
self.it.next();
if c == '\\' {
if let Some((_, c2)) = self.it.next() {
value.push(c2);
}
} else {
value.push(c);
}
}
if value.is_empty() {
return Err(Error::config_parse("unexpected EOF".into()));
}
Ok(value)
}
fn quoted_value(&mut self) -> Result<String, Error> {
let mut value = String::new();
while let Some(&(_, c)) = self.it.peek() {
if c == '\'' {
return Ok(value);
}
self.it.next();
if c == '\\' {
if let Some((_, c2)) = self.it.next() {
value.push(c2);
}
} else {
value.push(c);
}
}
Err(Error::config_parse(
"unterminated quoted connection parameter value".into(),
))
}
fn parameter(&mut self) -> Result<Option<(&'a str, String)>, Error> {
self.skip_ws();
let keyword = match self.keyword() {
Some(keyword) => keyword,
None => return Ok(None),
};
self.skip_ws();
self.eat('=')?;
self.skip_ws();
let value = self.value()?;
Ok(Some((keyword, value)))
}
}
// This is a pretty sloppy "URL" parser, but it matches the behavior of libpq, where things really aren't very strict
struct UrlParser<'a> {
s: &'a str,
config: Config,
}
impl<'a> UrlParser<'a> {
fn parse(s: &'a str) -> Result<Option<Config>, Error> {
let s = match Self::remove_url_prefix(s) {
Some(s) => s,
None => return Ok(None),
};
let mut parser = UrlParser {
s,
config: Config::new(),
};
parser.parse_credentials()?;
parser.parse_host()?;
parser.parse_path()?;
parser.parse_params()?;
Ok(Some(parser.config))
}
fn remove_url_prefix(s: &str) -> Option<&str> {
for prefix in &["postgres://", "postgresql://"] {
if let Some(stripped) = s.strip_prefix(prefix) {
return Some(stripped);
}
}
None
}
fn take_until(&mut self, end: &[char]) -> Option<&'a str> {
match self.s.find(end) {
Some(pos) => {
let (head, tail) = self.s.split_at(pos);
self.s = tail;
Some(head)
}
None => None,
}
}
fn take_all(&mut self) -> &'a str {
mem::take(&mut self.s)
}
fn eat_byte(&mut self) {
self.s = &self.s[1..];
}
fn parse_credentials(&mut self) -> Result<(), Error> {
let creds = match self.take_until(&['@']) {
Some(creds) => creds,
None => return Ok(()),
};
self.eat_byte();
let mut it = creds.splitn(2, ':');
let user = self.decode(it.next().unwrap())?;
self.config.user(&user);
if let Some(password) = it.next() {
let password = Cow::from(percent_encoding::percent_decode(password.as_bytes()));
self.config.password(password);
}
Ok(())
}
fn parse_host(&mut self) -> Result<(), Error> {
let host = match self.take_until(&['/', '?']) {
Some(host) => host,
None => self.take_all(),
};
if host.is_empty() {
return Ok(());
}
for chunk in host.split(',') {
let (host, port) = if chunk.starts_with('[') {
let idx = match chunk.find(']') {
Some(idx) => idx,
None => return Err(Error::config_parse(InvalidValue("host").into())),
};
let host = &chunk[1..idx];
let remaining = &chunk[idx + 1..];
let port = if let Some(port) = remaining.strip_prefix(':') {
Some(port)
} else if remaining.is_empty() {
None
} else {
return Err(Error::config_parse(InvalidValue("host").into()));
};
(host, port)
} else {
let mut it = chunk.splitn(2, ':');
(it.next().unwrap(), it.next())
};
self.host_param(host)?;
let port = self.decode(port.unwrap_or("5432"))?;
self.config.param("port", &port)?;
}
Ok(())
}
fn parse_path(&mut self) -> Result<(), Error> {
if !self.s.starts_with('/') {
return Ok(());
}
self.eat_byte();
let dbname = match self.take_until(&['?']) {
Some(dbname) => dbname,
None => self.take_all(),
};
if !dbname.is_empty() {
self.config.dbname(&self.decode(dbname)?);
}
Ok(())
}
fn parse_params(&mut self) -> Result<(), Error> {
if !self.s.starts_with('?') {
return Ok(());
}
self.eat_byte();
while !self.s.is_empty() {
let key = match self.take_until(&['=']) {
Some(key) => self.decode(key)?,
None => return Err(Error::config_parse("unterminated parameter".into())),
};
self.eat_byte();
let value = match self.take_until(&['&']) {
Some(value) => {
self.eat_byte();
value
}
None => self.take_all(),
};
if key == "host" {
self.host_param(value)?;
} else {
let value = self.decode(value)?;
self.config.param(&key, &value)?;
}
}
Ok(())
}
fn host_param(&mut self, s: &str) -> Result<(), Error> {
let s = self.decode(s)?;
self.config.param("host", &s)
}
fn decode(&self, s: &'a str) -> Result<Cow<'a, str>, Error> {
percent_encoding::percent_decode(s.as_bytes())
.decode_utf8()
.map_err(|e| Error::config_parse(e.into()))
}
}

View File

@@ -0,0 +1,112 @@
use crate::client::SocketConfig;
use crate::config::{Host, TargetSessionAttrs};
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::{Client, Config, Connection, Error, SimpleQueryMessage};
use futures_util::{future, pin_mut, Future, FutureExt, Stream};
use std::io;
use std::task::Poll;
use tokio::net::TcpStream;
pub async fn connect<T>(
mut tls: T,
config: &Config,
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
where
T: MakeTlsConnect<TcpStream>,
{
if config.host.is_empty() {
return Err(Error::config("host missing".into()));
}
if config.port.len() > 1 && config.port.len() != config.host.len() {
return Err(Error::config("invalid number of ports".into()));
}
let mut error = None;
for (i, host) in config.host.iter().enumerate() {
let port = config
.port
.get(i)
.or_else(|| config.port.first())
.copied()
.unwrap_or(5432);
let hostname = match host {
Host::Tcp(host) => host.as_str(),
};
let tls = tls
.make_tls_connect(hostname)
.map_err(|e| Error::tls(e.into()))?;
match connect_once(host, port, tls, config).await {
Ok((client, connection)) => return Ok((client, connection)),
Err(e) => error = Some(e),
}
}
Err(error.unwrap())
}
async fn connect_once<T>(
host: &Host,
port: u16,
tls: T,
config: &Config,
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
where
T: TlsConnect<TcpStream>,
{
let socket = connect_socket(host, port, config.connect_timeout).await?;
let (mut client, mut connection) = connect_raw(socket, tls, config).await?;
if let TargetSessionAttrs::ReadWrite = config.target_session_attrs {
let rows = client.simple_query_raw("SHOW transaction_read_only");
pin_mut!(rows);
let rows = future::poll_fn(|cx| {
if connection.poll_unpin(cx)?.is_ready() {
return Poll::Ready(Err(Error::closed()));
}
rows.as_mut().poll(cx)
})
.await?;
pin_mut!(rows);
loop {
let next = future::poll_fn(|cx| {
if connection.poll_unpin(cx)?.is_ready() {
return Poll::Ready(Some(Err(Error::closed())));
}
rows.as_mut().poll_next(cx)
});
match next.await.transpose()? {
Some(SimpleQueryMessage::Row(row)) => {
if row.try_get(0)? == Some("on") {
return Err(Error::connect(io::Error::new(
io::ErrorKind::PermissionDenied,
"database does not allow writes",
)));
} else {
break;
}
}
Some(_) => {}
None => return Err(Error::unexpected_message()),
}
}
}
client.set_socket_config(SocketConfig {
host: host.clone(),
port,
connect_timeout: config.connect_timeout,
});
Ok((client, connection))
}

View File

@@ -0,0 +1,359 @@
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::config::{self, AuthKeys, Config, ReplicationMode};
use crate::connect_tls::connect_tls;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::{TlsConnect, TlsStream};
use crate::{Client, Connection, Error};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Sink, SinkExt, Stream, TryStreamExt};
use postgres_protocol2::authentication;
use postgres_protocol2::authentication::sasl;
use postgres_protocol2::authentication::sasl::ScramSha256;
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message};
use postgres_protocol2::message::frontend;
use std::collections::{HashMap, VecDeque};
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::mpsc;
use tokio_util::codec::Framed;
pub struct StartupStream<S, T> {
inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
buf: BackendMessages,
delayed: VecDeque<BackendMessage>,
}
impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
type Error = io::Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_ready(cx)
}
fn start_send(mut self: Pin<&mut Self>, item: FrontendMessage) -> io::Result<()> {
Pin::new(&mut self.inner).start_send(item)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_close(cx)
}
}
impl<S, T> Stream for StartupStream<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
type Item = io::Result<Message>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<io::Result<Message>>> {
loop {
match self.buf.next() {
Ok(Some(message)) => return Poll::Ready(Some(Ok(message))),
Ok(None) => {}
Err(e) => return Poll::Ready(Some(Err(e))),
}
match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages,
Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))),
Some(Err(e)) => return Poll::Ready(Some(Err(e))),
None => return Poll::Ready(None),
}
}
}
}
pub async fn connect_raw<S, T>(
stream: S,
tls: T,
config: &Config,
) -> Result<(Client, Connection<S, T::Stream>), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
let stream = connect_tls(stream, config.ssl_mode, tls).await?;
let mut stream = StartupStream {
inner: Framed::new(
stream,
PostgresCodec {
max_message_size: config.max_backend_message_size,
},
),
buf: BackendMessages::empty(),
delayed: VecDeque::new(),
};
startup(&mut stream, config).await?;
authenticate(&mut stream, config).await?;
let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
let (sender, receiver) = mpsc::unbounded_channel();
let client = Client::new(sender, config.ssl_mode, process_id, secret_key);
let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver);
Ok((client, connection))
}
async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
let mut params = vec![("client_encoding", "UTF8")];
if let Some(user) = &config.user {
params.push(("user", &**user));
}
if let Some(dbname) = &config.dbname {
params.push(("database", &**dbname));
}
if let Some(options) = &config.options {
params.push(("options", &**options));
}
if let Some(application_name) = &config.application_name {
params.push(("application_name", &**application_name));
}
if let Some(replication_mode) = &config.replication_mode {
match replication_mode {
ReplicationMode::Physical => params.push(("replication", "true")),
ReplicationMode::Logical => params.push(("replication", "database")),
}
}
let mut buf = BytesMut::new();
frontend::startup_message(params, &mut buf).map_err(Error::encode)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)
}
async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsStream + Unpin,
{
match stream.try_next().await.map_err(Error::io)? {
Some(Message::AuthenticationOk) => {
can_skip_channel_binding(config)?;
return Ok(());
}
Some(Message::AuthenticationCleartextPassword) => {
can_skip_channel_binding(config)?;
let pass = config
.password
.as_ref()
.ok_or_else(|| Error::config("password missing".into()))?;
authenticate_password(stream, pass).await?;
}
Some(Message::AuthenticationMd5Password(body)) => {
can_skip_channel_binding(config)?;
let user = config
.user
.as_ref()
.ok_or_else(|| Error::config("user missing".into()))?;
let pass = config
.password
.as_ref()
.ok_or_else(|| Error::config("password missing".into()))?;
let output = authentication::md5_hash(user.as_bytes(), pass, body.salt());
authenticate_password(stream, output.as_bytes()).await?;
}
Some(Message::AuthenticationSasl(body)) => {
authenticate_sasl(stream, body, config).await?;
}
Some(Message::AuthenticationKerberosV5)
| Some(Message::AuthenticationScmCredential)
| Some(Message::AuthenticationGss)
| Some(Message::AuthenticationSspi) => {
return Err(Error::authentication(
"unsupported authentication method".into(),
))
}
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(_) => return Err(Error::unexpected_message()),
None => return Err(Error::closed()),
}
match stream.try_next().await.map_err(Error::io)? {
Some(Message::AuthenticationOk) => Ok(()),
Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
Some(_) => Err(Error::unexpected_message()),
None => Err(Error::closed()),
}
}
fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
match config.channel_binding {
config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
config::ChannelBinding::Require => Err(Error::authentication(
"server did not use channel binding".into(),
)),
}
}
async fn authenticate_password<S, T>(
stream: &mut StartupStream<S, T>,
password: &[u8],
) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
let mut buf = BytesMut::new();
frontend::password_message(password, &mut buf).map_err(Error::encode)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)
}
async fn authenticate_sasl<S, T>(
stream: &mut StartupStream<S, T>,
body: AuthenticationSaslBody,
config: &Config,
) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsStream + Unpin,
{
let mut has_scram = false;
let mut has_scram_plus = false;
let mut mechanisms = body.mechanisms();
while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
match mechanism {
sasl::SCRAM_SHA_256 => has_scram = true,
sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
_ => {}
}
}
let channel_binding = stream
.inner
.get_ref()
.channel_binding()
.tls_server_end_point
.filter(|_| config.channel_binding != config::ChannelBinding::Disable)
.map(sasl::ChannelBinding::tls_server_end_point);
let (channel_binding, mechanism) = if has_scram_plus {
match channel_binding {
Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
}
} else if has_scram {
match channel_binding {
Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
}
} else {
return Err(Error::authentication("unsupported SASL mechanism".into()));
};
if mechanism != sasl::SCRAM_SHA_256_PLUS {
can_skip_channel_binding(config)?;
}
let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() {
ScramSha256::new_with_keys(keys, channel_binding)
} else if let Some(password) = config.get_password() {
ScramSha256::new(password, channel_binding)
} else {
return Err(Error::config("password or auth keys missing".into()));
};
let mut buf = BytesMut::new();
frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)?;
let body = match stream.try_next().await.map_err(Error::io)? {
Some(Message::AuthenticationSaslContinue(body)) => body,
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(_) => return Err(Error::unexpected_message()),
None => return Err(Error::closed()),
};
scram
.update(body.data())
.await
.map_err(|e| Error::authentication(e.into()))?;
let mut buf = BytesMut::new();
frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)?;
let body = match stream.try_next().await.map_err(Error::io)? {
Some(Message::AuthenticationSaslFinal(body)) => body,
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(_) => return Err(Error::unexpected_message()),
None => return Err(Error::closed()),
};
scram
.finish(body.data())
.map_err(|e| Error::authentication(e.into()))?;
Ok(())
}
async fn read_info<S, T>(
stream: &mut StartupStream<S, T>,
) -> Result<(i32, i32, HashMap<String, String>), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
let mut process_id = 0;
let mut secret_key = 0;
let mut parameters = HashMap::new();
loop {
match stream.try_next().await.map_err(Error::io)? {
Some(Message::BackendKeyData(body)) => {
process_id = body.process_id();
secret_key = body.secret_key();
}
Some(Message::ParameterStatus(body)) => {
parameters.insert(
body.name().map_err(Error::parse)?.to_string(),
body.value().map_err(Error::parse)?.to_string(),
);
}
Some(msg @ Message::NoticeResponse(_)) => {
stream.delayed.push_back(BackendMessage::Async(msg))
}
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(_) => return Err(Error::unexpected_message()),
None => return Err(Error::closed()),
}
}
}

View File

@@ -0,0 +1,65 @@
use crate::config::Host;
use crate::Error;
use std::future::Future;
use std::io;
use std::time::Duration;
use tokio::net::{self, TcpStream};
use tokio::time;
pub(crate) async fn connect_socket(
host: &Host,
port: u16,
connect_timeout: Option<Duration>,
) -> Result<TcpStream, Error> {
match host {
Host::Tcp(host) => {
let addrs = net::lookup_host((&**host, port))
.await
.map_err(Error::connect)?;
let mut last_err = None;
for addr in addrs {
let stream =
match connect_with_timeout(TcpStream::connect(addr), connect_timeout).await {
Ok(stream) => stream,
Err(e) => {
last_err = Some(e);
continue;
}
};
stream.set_nodelay(true).map_err(Error::connect)?;
return Ok(stream);
}
Err(last_err.unwrap_or_else(|| {
Error::connect(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve any addresses",
))
}))
}
}
}
async fn connect_with_timeout<F, T>(connect: F, timeout: Option<Duration>) -> Result<T, Error>
where
F: Future<Output = io::Result<T>>,
{
match timeout {
Some(timeout) => match time::timeout(timeout, connect).await {
Ok(Ok(socket)) => Ok(socket),
Ok(Err(e)) => Err(Error::connect(e)),
Err(_) => Err(Error::connect(io::Error::new(
io::ErrorKind::TimedOut,
"connection timed out",
))),
},
None => match connect.await {
Ok(socket) => Ok(socket),
Err(e) => Err(Error::connect(e)),
},
}
}

View File

@@ -0,0 +1,48 @@
use crate::config::SslMode;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::private::ForcePrivateApi;
use crate::tls::TlsConnect;
use crate::Error;
use bytes::BytesMut;
use postgres_protocol2::message::frontend;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
pub async fn connect_tls<S, T>(
mut stream: S,
mode: SslMode,
tls: T,
) -> Result<MaybeTlsStream<S, T::Stream>, Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
match mode {
SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => {
return Ok(MaybeTlsStream::Raw(stream))
}
SslMode::Prefer | SslMode::Require => {}
}
let mut buf = BytesMut::new();
frontend::ssl_request(&mut buf);
stream.write_all(&buf).await.map_err(Error::io)?;
let mut buf = [0];
stream.read_exact(&mut buf).await.map_err(Error::io)?;
if buf[0] != b'S' {
if SslMode::Require == mode {
return Err(Error::tls("server does not support TLS".into()));
} else {
return Ok(MaybeTlsStream::Raw(stream));
}
}
let stream = tls
.connect(stream)
.await
.map_err(|e| Error::tls(e.into()))?;
Ok(MaybeTlsStream::Tls(stream))
}

View File

@@ -0,0 +1,323 @@
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::error::DbError;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::{AsyncMessage, Error, Notification};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Sink, Stream};
use log::{info, trace};
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::mpsc;
use tokio_util::codec::Framed;
use tokio_util::sync::PollSender;
pub enum RequestMessages {
Single(FrontendMessage),
}
pub struct Request {
pub messages: RequestMessages,
pub sender: mpsc::Sender<BackendMessages>,
}
pub struct Response {
sender: PollSender<BackendMessages>,
}
#[derive(PartialEq, Debug)]
enum State {
Active,
Terminating,
Closing,
}
/// A connection to a PostgreSQL database.
///
/// This is one half of what is returned when a new connection is established. It performs the actual IO with the
/// server, and should generally be spawned off onto an executor to run in the background.
///
/// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
/// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
#[must_use = "futures do nothing unless polled"]
pub struct Connection<S, T> {
/// HACK: we need this in the Neon Proxy.
pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
/// HACK: we need this in the Neon Proxy to forward params.
pub parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
pending_request: Option<RequestMessages>,
pending_responses: VecDeque<BackendMessage>,
responses: VecDeque<Response>,
state: State,
}
impl<S, T> Connection<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
pub(crate) fn new(
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pending_responses: VecDeque<BackendMessage>,
parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
) -> Connection<S, T> {
Connection {
stream,
parameters,
receiver,
pending_request: None,
pending_responses,
responses: VecDeque::new(),
state: State::Active,
}
}
fn poll_response(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<BackendMessage, Error>>> {
if let Some(message) = self.pending_responses.pop_front() {
trace!("retrying pending response");
return Poll::Ready(Some(Ok(message)));
}
Pin::new(&mut self.stream)
.poll_next(cx)
.map(|o| o.map(|r| r.map_err(Error::io)))
}
fn poll_read(&mut self, cx: &mut Context<'_>) -> Result<Option<AsyncMessage>, Error> {
if self.state != State::Active {
trace!("poll_read: done");
return Ok(None);
}
loop {
let message = match self.poll_response(cx)? {
Poll::Ready(Some(message)) => message,
Poll::Ready(None) => return Err(Error::closed()),
Poll::Pending => {
trace!("poll_read: waiting on response");
return Ok(None);
}
};
let (mut messages, request_complete) = match message {
BackendMessage::Async(Message::NoticeResponse(body)) => {
let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
return Ok(Some(AsyncMessage::Notice(error)));
}
BackendMessage::Async(Message::NotificationResponse(body)) => {
let notification = Notification {
process_id: body.process_id(),
channel: body.channel().map_err(Error::parse)?.to_string(),
payload: body.message().map_err(Error::parse)?.to_string(),
};
return Ok(Some(AsyncMessage::Notification(notification)));
}
BackendMessage::Async(Message::ParameterStatus(body)) => {
self.parameters.insert(
body.name().map_err(Error::parse)?.to_string(),
body.value().map_err(Error::parse)?.to_string(),
);
continue;
}
BackendMessage::Async(_) => unreachable!(),
BackendMessage::Normal {
messages,
request_complete,
} => (messages, request_complete),
};
let mut response = match self.responses.pop_front() {
Some(response) => response,
None => match messages.next().map_err(Error::parse)? {
Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
_ => return Err(Error::unexpected_message()),
},
};
match response.sender.poll_reserve(cx) {
Poll::Ready(Ok(())) => {
let _ = response.sender.send_item(messages);
if !request_complete {
self.responses.push_front(response);
}
}
Poll::Ready(Err(_)) => {
// we need to keep paging through the rest of the messages even if the receiver's hung up
if !request_complete {
self.responses.push_front(response);
}
}
Poll::Pending => {
self.responses.push_front(response);
self.pending_responses.push_back(BackendMessage::Normal {
messages,
request_complete,
});
trace!("poll_read: waiting on sender");
return Ok(None);
}
}
}
}
fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
if let Some(messages) = self.pending_request.take() {
trace!("retrying pending request");
return Poll::Ready(Some(messages));
}
if self.receiver.is_closed() {
return Poll::Ready(None);
}
match self.receiver.poll_recv(cx) {
Poll::Ready(Some(request)) => {
trace!("polled new request");
self.responses.push_back(Response {
sender: PollSender::new(request.sender),
});
Poll::Ready(Some(request.messages))
}
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
fn poll_write(&mut self, cx: &mut Context<'_>) -> Result<bool, Error> {
loop {
if self.state == State::Closing {
trace!("poll_write: done");
return Ok(false);
}
if Pin::new(&mut self.stream)
.poll_ready(cx)
.map_err(Error::io)?
.is_pending()
{
trace!("poll_write: waiting on socket");
return Ok(false);
}
let request = match self.poll_request(cx) {
Poll::Ready(Some(request)) => request,
Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
trace!("poll_write: at eof, terminating");
self.state = State::Terminating;
let mut request = BytesMut::new();
frontend::terminate(&mut request);
RequestMessages::Single(FrontendMessage::Raw(request.freeze()))
}
Poll::Ready(None) => {
trace!(
"poll_write: at eof, pending responses {}",
self.responses.len()
);
return Ok(true);
}
Poll::Pending => {
trace!("poll_write: waiting on request");
return Ok(true);
}
};
match request {
RequestMessages::Single(request) => {
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
if self.state == State::Terminating {
trace!("poll_write: sent eof, closing");
self.state = State::Closing;
}
}
}
}
}
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
match Pin::new(&mut self.stream)
.poll_flush(cx)
.map_err(Error::io)?
{
Poll::Ready(()) => trace!("poll_flush: flushed"),
Poll::Pending => trace!("poll_flush: waiting on socket"),
}
Ok(())
}
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
if self.state != State::Closing {
return Poll::Pending;
}
match Pin::new(&mut self.stream)
.poll_close(cx)
.map_err(Error::io)?
{
Poll::Ready(()) => {
trace!("poll_shutdown: complete");
Poll::Ready(Ok(()))
}
Poll::Pending => {
trace!("poll_shutdown: waiting on socket");
Poll::Pending
}
}
}
/// Returns the value of a runtime parameter for this connection.
pub fn parameter(&self, name: &str) -> Option<&str> {
self.parameters.get(name).map(|s| &**s)
}
/// Polls for asynchronous messages from the server.
///
/// The server can send notices as well as notifications asynchronously to the client. Applications that wish to
/// examine those messages should use this method to drive the connection rather than its `Future` implementation.
pub fn poll_message(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<AsyncMessage, Error>>> {
let message = self.poll_read(cx)?;
let want_flush = self.poll_write(cx)?;
if want_flush {
self.poll_flush(cx)?;
}
match message {
Some(message) => Poll::Ready(Some(Ok(message))),
None => match self.poll_shutdown(cx) {
Poll::Ready(Ok(())) => Poll::Ready(None),
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
Poll::Pending => Poll::Pending,
},
}
}
}
impl<S, T> Future for Connection<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
type Output = Result<(), Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
while let Some(message) = ready!(self.poll_message(cx)?) {
if let AsyncMessage::Notice(notice) = message {
info!("{}: {}", notice.severity(), notice.message());
}
}
Poll::Ready(Ok(()))
}
}

View File

@@ -0,0 +1,501 @@
//! Errors.
use fallible_iterator::FallibleIterator;
use postgres_protocol2::message::backend::{ErrorFields, ErrorResponseBody};
use std::error::{self, Error as _Error};
use std::fmt;
use std::io;
pub use self::sqlstate::*;
#[allow(clippy::unreadable_literal)]
mod sqlstate;
/// The severity of a Postgres error or notice.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Severity {
/// PANIC
Panic,
/// FATAL
Fatal,
/// ERROR
Error,
/// WARNING
Warning,
/// NOTICE
Notice,
/// DEBUG
Debug,
/// INFO
Info,
/// LOG
Log,
}
impl fmt::Display for Severity {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match *self {
Severity::Panic => "PANIC",
Severity::Fatal => "FATAL",
Severity::Error => "ERROR",
Severity::Warning => "WARNING",
Severity::Notice => "NOTICE",
Severity::Debug => "DEBUG",
Severity::Info => "INFO",
Severity::Log => "LOG",
};
fmt.write_str(s)
}
}
impl Severity {
fn from_str(s: &str) -> Option<Severity> {
match s {
"PANIC" => Some(Severity::Panic),
"FATAL" => Some(Severity::Fatal),
"ERROR" => Some(Severity::Error),
"WARNING" => Some(Severity::Warning),
"NOTICE" => Some(Severity::Notice),
"DEBUG" => Some(Severity::Debug),
"INFO" => Some(Severity::Info),
"LOG" => Some(Severity::Log),
_ => None,
}
}
}
/// A Postgres error or notice.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DbError {
severity: String,
parsed_severity: Option<Severity>,
code: SqlState,
message: String,
detail: Option<String>,
hint: Option<String>,
position: Option<ErrorPosition>,
where_: Option<String>,
schema: Option<String>,
table: Option<String>,
column: Option<String>,
datatype: Option<String>,
constraint: Option<String>,
file: Option<String>,
line: Option<u32>,
routine: Option<String>,
}
impl DbError {
pub(crate) fn parse(fields: &mut ErrorFields<'_>) -> io::Result<DbError> {
let mut severity = None;
let mut parsed_severity = None;
let mut code = None;
let mut message = None;
let mut detail = None;
let mut hint = None;
let mut normal_position = None;
let mut internal_position = None;
let mut internal_query = None;
let mut where_ = None;
let mut schema = None;
let mut table = None;
let mut column = None;
let mut datatype = None;
let mut constraint = None;
let mut file = None;
let mut line = None;
let mut routine = None;
while let Some(field) = fields.next()? {
match field.type_() {
b'S' => severity = Some(field.value().to_owned()),
b'C' => code = Some(SqlState::from_code(field.value())),
b'M' => message = Some(field.value().to_owned()),
b'D' => detail = Some(field.value().to_owned()),
b'H' => hint = Some(field.value().to_owned()),
b'P' => {
normal_position = Some(field.value().parse::<u32>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"`P` field did not contain an integer",
)
})?);
}
b'p' => {
internal_position = Some(field.value().parse::<u32>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"`p` field did not contain an integer",
)
})?);
}
b'q' => internal_query = Some(field.value().to_owned()),
b'W' => where_ = Some(field.value().to_owned()),
b's' => schema = Some(field.value().to_owned()),
b't' => table = Some(field.value().to_owned()),
b'c' => column = Some(field.value().to_owned()),
b'd' => datatype = Some(field.value().to_owned()),
b'n' => constraint = Some(field.value().to_owned()),
b'F' => file = Some(field.value().to_owned()),
b'L' => {
line = Some(field.value().parse::<u32>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"`L` field did not contain an integer",
)
})?);
}
b'R' => routine = Some(field.value().to_owned()),
b'V' => {
parsed_severity = Some(Severity::from_str(field.value()).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"`V` field contained an invalid value",
)
})?);
}
_ => {}
}
}
Ok(DbError {
severity: severity
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`S` field missing"))?,
parsed_severity,
code: code
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`C` field missing"))?,
message: message
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`M` field missing"))?,
detail,
hint,
position: match normal_position {
Some(position) => Some(ErrorPosition::Original(position)),
None => match internal_position {
Some(position) => Some(ErrorPosition::Internal {
position,
query: internal_query.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"`q` field missing but `p` field present",
)
})?,
}),
None => None,
},
},
where_,
schema,
table,
column,
datatype,
constraint,
file,
line,
routine,
})
}
/// The field contents are ERROR, FATAL, or PANIC (in an error message),
/// or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a
/// localized translation of one of these.
pub fn severity(&self) -> &str {
&self.severity
}
/// A parsed, nonlocalized version of `severity`. (PostgreSQL 9.6+)
pub fn parsed_severity(&self) -> Option<Severity> {
self.parsed_severity
}
/// The SQLSTATE code for the error.
pub fn code(&self) -> &SqlState {
&self.code
}
/// The primary human-readable error message.
///
/// This should be accurate but terse (typically one line).
pub fn message(&self) -> &str {
&self.message
}
/// An optional secondary error message carrying more detail about the
/// problem.
///
/// Might run to multiple lines.
pub fn detail(&self) -> Option<&str> {
self.detail.as_deref()
}
/// An optional suggestion what to do about the problem.
///
/// This is intended to differ from `detail` in that it offers advice
/// (potentially inappropriate) rather than hard facts. Might run to
/// multiple lines.
pub fn hint(&self) -> Option<&str> {
self.hint.as_deref()
}
/// An optional error cursor position into either the original query string
/// or an internally generated query.
pub fn position(&self) -> Option<&ErrorPosition> {
self.position.as_ref()
}
/// An indication of the context in which the error occurred.
///
/// Presently this includes a call stack traceback of active procedural
/// language functions and internally-generated queries. The trace is one
/// entry per line, most recent first.
pub fn where_(&self) -> Option<&str> {
self.where_.as_deref()
}
/// If the error was associated with a specific database object, the name
/// of the schema containing that object, if any. (PostgreSQL 9.3+)
pub fn schema(&self) -> Option<&str> {
self.schema.as_deref()
}
/// If the error was associated with a specific table, the name of the
/// table. (Refer to the schema name field for the name of the table's
/// schema.) (PostgreSQL 9.3+)
pub fn table(&self) -> Option<&str> {
self.table.as_deref()
}
/// If the error was associated with a specific table column, the name of
/// the column.
///
/// (Refer to the schema and table name fields to identify the table.)
/// (PostgreSQL 9.3+)
pub fn column(&self) -> Option<&str> {
self.column.as_deref()
}
/// If the error was associated with a specific data type, the name of the
/// data type. (Refer to the schema name field for the name of the data
/// type's schema.) (PostgreSQL 9.3+)
pub fn datatype(&self) -> Option<&str> {
self.datatype.as_deref()
}
/// If the error was associated with a specific constraint, the name of the
/// constraint.
///
/// Refer to fields listed above for the associated table or domain.
/// (For this purpose, indexes are treated as constraints, even if they
/// weren't created with constraint syntax.) (PostgreSQL 9.3+)
pub fn constraint(&self) -> Option<&str> {
self.constraint.as_deref()
}
/// The file name of the source-code location where the error was reported.
pub fn file(&self) -> Option<&str> {
self.file.as_deref()
}
/// The line number of the source-code location where the error was
/// reported.
pub fn line(&self) -> Option<u32> {
self.line
}
/// The name of the source-code routine reporting the error.
pub fn routine(&self) -> Option<&str> {
self.routine.as_deref()
}
}
impl fmt::Display for DbError {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(fmt, "{}: {}", self.severity, self.message)?;
if let Some(detail) = &self.detail {
write!(fmt, "\nDETAIL: {}", detail)?;
}
if let Some(hint) = &self.hint {
write!(fmt, "\nHINT: {}", hint)?;
}
Ok(())
}
}
impl error::Error for DbError {}
/// Represents the position of an error in a query.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum ErrorPosition {
/// A position in the original query.
Original(u32),
/// A position in an internally generated query.
Internal {
/// The byte position.
position: u32,
/// A query generated by the Postgres server.
query: String,
},
}
#[derive(Debug, PartialEq)]
enum Kind {
Io,
UnexpectedMessage,
Tls,
ToSql(usize),
FromSql(usize),
Column(String),
Closed,
Db,
Parse,
Encode,
Authentication,
ConfigParse,
Config,
Connect,
Timeout,
}
struct ErrorInner {
kind: Kind,
cause: Option<Box<dyn error::Error + Sync + Send>>,
}
/// An error communicating with the Postgres server.
pub struct Error(Box<ErrorInner>);
impl fmt::Debug for Error {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Error")
.field("kind", &self.0.kind)
.field("cause", &self.0.cause)
.finish()
}
}
impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.0.kind {
Kind::Io => fmt.write_str("error communicating with the server")?,
Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?,
Kind::Tls => fmt.write_str("error performing TLS handshake")?,
Kind::ToSql(idx) => write!(fmt, "error serializing parameter {}", idx)?,
Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?,
Kind::Column(column) => write!(fmt, "invalid column `{}`", column)?,
Kind::Closed => fmt.write_str("connection closed")?,
Kind::Db => fmt.write_str("db error")?,
Kind::Parse => fmt.write_str("error parsing response from server")?,
Kind::Encode => fmt.write_str("error encoding message to server")?,
Kind::Authentication => fmt.write_str("authentication error")?,
Kind::ConfigParse => fmt.write_str("invalid connection string")?,
Kind::Config => fmt.write_str("invalid configuration")?,
Kind::Connect => fmt.write_str("error connecting to server")?,
Kind::Timeout => fmt.write_str("timeout waiting for server")?,
};
if let Some(ref cause) = self.0.cause {
write!(fmt, ": {}", cause)?;
}
Ok(())
}
}
impl error::Error for Error {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
self.0.cause.as_ref().map(|e| &**e as _)
}
}
impl Error {
/// Consumes the error, returning its cause.
pub fn into_source(self) -> Option<Box<dyn error::Error + Sync + Send>> {
self.0.cause
}
/// Returns the source of this error if it was a `DbError`.
///
/// This is a simple convenience method.
pub fn as_db_error(&self) -> Option<&DbError> {
self.source().and_then(|e| e.downcast_ref::<DbError>())
}
/// Determines if the error was associated with closed connection.
pub fn is_closed(&self) -> bool {
self.0.kind == Kind::Closed
}
/// Returns the SQLSTATE error code associated with the error.
///
/// This is a convenience method that downcasts the cause to a `DbError` and returns its code.
pub fn code(&self) -> Option<&SqlState> {
self.as_db_error().map(DbError::code)
}
fn new(kind: Kind, cause: Option<Box<dyn error::Error + Sync + Send>>) -> Error {
Error(Box::new(ErrorInner { kind, cause }))
}
pub(crate) fn closed() -> Error {
Error::new(Kind::Closed, None)
}
pub(crate) fn unexpected_message() -> Error {
Error::new(Kind::UnexpectedMessage, None)
}
#[allow(clippy::needless_pass_by_value)]
pub(crate) fn db(error: ErrorResponseBody) -> Error {
match DbError::parse(&mut error.fields()) {
Ok(e) => Error::new(Kind::Db, Some(Box::new(e))),
Err(e) => Error::new(Kind::Parse, Some(Box::new(e))),
}
}
pub(crate) fn parse(e: io::Error) -> Error {
Error::new(Kind::Parse, Some(Box::new(e)))
}
pub(crate) fn encode(e: io::Error) -> Error {
Error::new(Kind::Encode, Some(Box::new(e)))
}
#[allow(clippy::wrong_self_convention)]
pub(crate) fn to_sql(e: Box<dyn error::Error + Sync + Send>, idx: usize) -> Error {
Error::new(Kind::ToSql(idx), Some(e))
}
pub(crate) fn from_sql(e: Box<dyn error::Error + Sync + Send>, idx: usize) -> Error {
Error::new(Kind::FromSql(idx), Some(e))
}
pub(crate) fn column(column: String) -> Error {
Error::new(Kind::Column(column), None)
}
pub(crate) fn tls(e: Box<dyn error::Error + Sync + Send>) -> Error {
Error::new(Kind::Tls, Some(e))
}
pub(crate) fn io(e: io::Error) -> Error {
Error::new(Kind::Io, Some(Box::new(e)))
}
pub(crate) fn authentication(e: Box<dyn error::Error + Sync + Send>) -> Error {
Error::new(Kind::Authentication, Some(e))
}
pub(crate) fn config_parse(e: Box<dyn error::Error + Sync + Send>) -> Error {
Error::new(Kind::ConfigParse, Some(e))
}
pub(crate) fn config(e: Box<dyn error::Error + Sync + Send>) -> Error {
Error::new(Kind::Config, Some(e))
}
pub(crate) fn connect(e: io::Error) -> Error {
Error::new(Kind::Connect, Some(Box::new(e)))
}
#[doc(hidden)]
pub fn __private_api_timeout() -> Error {
Error::new(Kind::Timeout, None)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,64 @@
use crate::query::RowStream;
use crate::types::Type;
use crate::{Client, Error, Transaction};
use async_trait::async_trait;
use postgres_protocol2::Oid;
mod private {
pub trait Sealed {}
}
/// A trait allowing abstraction over connections and transactions.
///
/// This trait is "sealed", and cannot be implemented outside of this crate.
#[async_trait]
pub trait GenericClient: private::Sealed {
/// Like `Client::query_raw_txt`.
async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send;
/// Query for type information
async fn get_type(&self, oid: Oid) -> Result<Type, Error>;
}
impl private::Sealed for Client {}
#[async_trait]
impl GenericClient for Client {
async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send,
{
self.query_raw_txt(statement, params).await
}
/// Query for type information
async fn get_type(&self, oid: Oid) -> Result<Type, Error> {
self.get_type(oid).await
}
}
impl private::Sealed for Transaction<'_> {}
#[async_trait]
#[allow(clippy::needless_lifetimes)]
impl GenericClient for Transaction<'_> {
async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send,
{
self.query_raw_txt(statement, params).await
}
/// Query for type information
async fn get_type(&self, oid: Oid) -> Result<Type, Error> {
self.client().get_type(oid).await
}
}

View File

@@ -0,0 +1,148 @@
//! An asynchronous, pipelined, PostgreSQL client.
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
pub use crate::cancel_token::CancelToken;
pub use crate::client::Client;
pub use crate::config::Config;
pub use crate::connection::Connection;
use crate::error::DbError;
pub use crate::error::Error;
pub use crate::generic_client::GenericClient;
pub use crate::query::RowStream;
pub use crate::row::{Row, SimpleQueryRow};
pub use crate::simple_query::SimpleQueryStream;
pub use crate::statement::{Column, Statement};
use crate::tls::MakeTlsConnect;
pub use crate::tls::NoTls;
pub use crate::to_statement::ToStatement;
pub use crate::transaction::Transaction;
pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder};
use crate::types::ToSql;
use postgres_protocol2::message::backend::ReadyForQueryBody;
use tokio::net::TcpStream;
/// After executing a query, the connection will be in one of these states
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(u8)]
pub enum ReadyForQueryStatus {
/// Connection state is unknown
Unknown,
/// Connection is idle (no transactions)
Idle = b'I',
/// Connection is in a transaction block
Transaction = b'T',
/// Connection is in a failed transaction block
FailedTransaction = b'E',
}
impl From<ReadyForQueryBody> for ReadyForQueryStatus {
fn from(value: ReadyForQueryBody) -> Self {
match value.status() {
b'I' => Self::Idle,
b'T' => Self::Transaction,
b'E' => Self::FailedTransaction,
_ => Self::Unknown,
}
}
}
mod cancel_query;
mod cancel_query_raw;
mod cancel_token;
mod client;
mod codec;
pub mod config;
mod connect;
mod connect_raw;
mod connect_socket;
mod connect_tls;
mod connection;
pub mod error;
mod generic_client;
pub mod maybe_tls_stream;
mod prepare;
mod query;
pub mod row;
mod simple_query;
mod statement;
pub mod tls;
mod to_statement;
mod transaction;
mod transaction_builder;
pub mod types;
/// A convenience function which parses a connection string and connects to the database.
///
/// See the documentation for [`Config`] for details on the connection string format.
///
/// Requires the `runtime` Cargo feature (enabled by default).
///
/// [`Config`]: config/struct.Config.html
pub async fn connect<T>(
config: &str,
tls: T,
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
where
T: MakeTlsConnect<TcpStream>,
{
let config = config.parse::<Config>()?;
config.connect(tls).await
}
/// An asynchronous notification.
#[derive(Clone, Debug)]
pub struct Notification {
process_id: i32,
channel: String,
payload: String,
}
impl Notification {
/// The process ID of the notifying backend process.
pub fn process_id(&self) -> i32 {
self.process_id
}
/// The name of the channel that the notify has been raised on.
pub fn channel(&self) -> &str {
&self.channel
}
/// The "payload" string passed from the notifying process.
pub fn payload(&self) -> &str {
&self.payload
}
}
/// An asynchronous message from the server.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum AsyncMessage {
/// A notice.
///
/// Notices use the same format as errors, but aren't "errors" per-se.
Notice(DbError),
/// A notification.
///
/// Connections can subscribe to notifications with the `LISTEN` command.
Notification(Notification),
}
/// Message returned by the `SimpleQuery` stream.
#[derive(Debug)]
#[non_exhaustive]
pub enum SimpleQueryMessage {
/// A row of data.
Row(SimpleQueryRow),
/// A statement in the query has completed.
///
/// The number of rows modified or selected is returned.
CommandComplete(u64),
}
fn slice_iter<'a>(
s: &'a [&'a (dyn ToSql + Sync)],
) -> impl ExactSizeIterator<Item = &'a (dyn ToSql + Sync)> + 'a {
s.iter().map(|s| *s as _)
}

View File

@@ -0,0 +1,77 @@
//! MaybeTlsStream.
//!
//! Represents a stream that may or may not be encrypted with TLS.
use crate::tls::{ChannelBinding, TlsStream};
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
/// A stream that may or may not be encrypted with TLS.
pub enum MaybeTlsStream<S, T> {
/// An unencrypted stream.
Raw(S),
/// An encrypted stream.
Tls(T),
}
impl<S, T> AsyncRead for MaybeTlsStream<S, T>
where
S: AsyncRead + Unpin,
T: AsyncRead + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match &mut *self {
MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf),
MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
}
}
}
impl<S, T> AsyncWrite for MaybeTlsStream<S, T>
where
S: AsyncWrite + Unpin,
T: AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match &mut *self {
MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf),
MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx),
MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
}
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx),
MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
}
}
}
impl<S, T> TlsStream for MaybeTlsStream<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsStream + Unpin,
{
fn channel_binding(&self) -> ChannelBinding {
match self {
MaybeTlsStream::Raw(_) => ChannelBinding::none(),
MaybeTlsStream::Tls(s) => s.channel_binding(),
}
}
}

View File

@@ -0,0 +1,262 @@
use crate::client::InnerClient;
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::error::SqlState;
use crate::types::{Field, Kind, Oid, Type};
use crate::{query, slice_iter};
use crate::{Column, Error, Statement};
use bytes::Bytes;
use fallible_iterator::FallibleIterator;
use futures_util::{pin_mut, TryStreamExt};
use log::debug;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
pub(crate) const TYPEINFO_QUERY: &str = "\
SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid
FROM pg_catalog.pg_type t
LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid
INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
WHERE t.oid = $1
";
// Range types weren't added until Postgres 9.2, so pg_range may not exist
const TYPEINFO_FALLBACK_QUERY: &str = "\
SELECT t.typname, t.typtype, t.typelem, NULL::OID, t.typbasetype, n.nspname, t.typrelid
FROM pg_catalog.pg_type t
INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
WHERE t.oid = $1
";
const TYPEINFO_ENUM_QUERY: &str = "\
SELECT enumlabel
FROM pg_catalog.pg_enum
WHERE enumtypid = $1
ORDER BY enumsortorder
";
// Postgres 9.0 didn't have enumsortorder
const TYPEINFO_ENUM_FALLBACK_QUERY: &str = "\
SELECT enumlabel
FROM pg_catalog.pg_enum
WHERE enumtypid = $1
ORDER BY oid
";
pub(crate) const TYPEINFO_COMPOSITE_QUERY: &str = "\
SELECT attname, atttypid
FROM pg_catalog.pg_attribute
WHERE attrelid = $1
AND NOT attisdropped
AND attnum > 0
ORDER BY attnum
";
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
pub async fn prepare(
client: &Arc<InnerClient>,
query: &str,
types: &[Type],
) -> Result<Statement, Error> {
let name = format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst));
let buf = encode(client, &name, query, types)?;
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? {
Message::ParseComplete => {}
_ => return Err(Error::unexpected_message()),
}
let parameter_description = match responses.next().await? {
Message::ParameterDescription(body) => body,
_ => return Err(Error::unexpected_message()),
};
let row_description = match responses.next().await? {
Message::RowDescription(body) => Some(body),
Message::NoData => None,
_ => return Err(Error::unexpected_message()),
};
let mut parameters = vec![];
let mut it = parameter_description.parameters();
while let Some(oid) = it.next().map_err(Error::parse)? {
let type_ = get_type(client, oid).await?;
parameters.push(type_);
}
let mut columns = vec![];
if let Some(row_description) = row_description {
let mut it = row_description.fields();
while let Some(field) = it.next().map_err(Error::parse)? {
let type_ = get_type(client, field.type_oid()).await?;
let column = Column::new(field.name().to_string(), type_, field);
columns.push(column);
}
}
Ok(Statement::new(client, name, parameters, columns))
}
fn prepare_rec<'a>(
client: &'a Arc<InnerClient>,
query: &'a str,
types: &'a [Type],
) -> Pin<Box<dyn Future<Output = Result<Statement, Error>> + 'a + Send>> {
Box::pin(prepare(client, query, types))
}
fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result<Bytes, Error> {
if types.is_empty() {
debug!("preparing query {}: {}", name, query);
} else {
debug!("preparing query {} with types {:?}: {}", name, types, query);
}
client.with_buf(|buf| {
frontend::parse(name, query, types.iter().map(Type::oid), buf).map_err(Error::encode)?;
frontend::describe(b'S', name, buf).map_err(Error::encode)?;
frontend::sync(buf);
Ok(buf.split().freeze())
})
}
pub async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
if let Some(type_) = Type::from_oid(oid) {
return Ok(type_);
}
if let Some(type_) = client.type_(oid) {
return Ok(type_);
}
let stmt = typeinfo_statement(client).await?;
let rows = query::query(client, stmt, slice_iter(&[&oid])).await?;
pin_mut!(rows);
let row = match rows.try_next().await? {
Some(row) => row,
None => return Err(Error::unexpected_message()),
};
let name: String = row.try_get(0)?;
let type_: i8 = row.try_get(1)?;
let elem_oid: Oid = row.try_get(2)?;
let rngsubtype: Option<Oid> = row.try_get(3)?;
let basetype: Oid = row.try_get(4)?;
let schema: String = row.try_get(5)?;
let relid: Oid = row.try_get(6)?;
let kind = if type_ == b'e' as i8 {
let variants = get_enum_variants(client, oid).await?;
Kind::Enum(variants)
} else if type_ == b'p' as i8 {
Kind::Pseudo
} else if basetype != 0 {
let type_ = get_type_rec(client, basetype).await?;
Kind::Domain(type_)
} else if elem_oid != 0 {
let type_ = get_type_rec(client, elem_oid).await?;
Kind::Array(type_)
} else if relid != 0 {
let fields = get_composite_fields(client, relid).await?;
Kind::Composite(fields)
} else if let Some(rngsubtype) = rngsubtype {
let type_ = get_type_rec(client, rngsubtype).await?;
Kind::Range(type_)
} else {
Kind::Simple
};
let type_ = Type::new(name, oid, kind, schema);
client.set_type(oid, &type_);
Ok(type_)
}
fn get_type_rec<'a>(
client: &'a Arc<InnerClient>,
oid: Oid,
) -> Pin<Box<dyn Future<Output = Result<Type, Error>> + Send + 'a>> {
Box::pin(get_type(client, oid))
}
async fn typeinfo_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> {
if let Some(stmt) = client.typeinfo() {
return Ok(stmt);
}
let stmt = match prepare_rec(client, TYPEINFO_QUERY, &[]).await {
Ok(stmt) => stmt,
Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_TABLE) => {
prepare_rec(client, TYPEINFO_FALLBACK_QUERY, &[]).await?
}
Err(e) => return Err(e),
};
client.set_typeinfo(&stmt);
Ok(stmt)
}
async fn get_enum_variants(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<String>, Error> {
let stmt = typeinfo_enum_statement(client).await?;
query::query(client, stmt, slice_iter(&[&oid]))
.await?
.and_then(|row| async move { row.try_get(0) })
.try_collect()
.await
}
async fn typeinfo_enum_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> {
if let Some(stmt) = client.typeinfo_enum() {
return Ok(stmt);
}
let stmt = match prepare_rec(client, TYPEINFO_ENUM_QUERY, &[]).await {
Ok(stmt) => stmt,
Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_COLUMN) => {
prepare_rec(client, TYPEINFO_ENUM_FALLBACK_QUERY, &[]).await?
}
Err(e) => return Err(e),
};
client.set_typeinfo_enum(&stmt);
Ok(stmt)
}
async fn get_composite_fields(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<Field>, Error> {
let stmt = typeinfo_composite_statement(client).await?;
let rows = query::query(client, stmt, slice_iter(&[&oid]))
.await?
.try_collect::<Vec<_>>()
.await?;
let mut fields = vec![];
for row in rows {
let name = row.try_get(0)?;
let oid = row.try_get(1)?;
let type_ = get_type_rec(client, oid).await?;
fields.push(Field::new(name, type_));
}
Ok(fields)
}
async fn typeinfo_composite_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> {
if let Some(stmt) = client.typeinfo_composite() {
return Ok(stmt);
}
let stmt = prepare_rec(client, TYPEINFO_COMPOSITE_QUERY, &[]).await?;
client.set_typeinfo_composite(&stmt);
Ok(stmt)
}

View File

@@ -0,0 +1,340 @@
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::types::IsNull;
use crate::{Column, Error, ReadyForQueryStatus, Row, Statement};
use bytes::{BufMut, Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Stream};
use log::{debug, log_enabled, Level};
use pin_project_lite::pin_project;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use postgres_types2::{Format, ToSql, Type};
use std::fmt;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
struct BorrowToSqlParamsDebug<'a>(&'a [&'a (dyn ToSql + Sync)]);
impl fmt::Debug for BorrowToSqlParamsDebug<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_list().entries(self.0.iter()).finish()
}
}
pub async fn query<'a, I>(
client: &InnerClient,
statement: Statement,
params: I,
) -> Result<RowStream, Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
{
let buf = if log_enabled!(Level::Debug) {
let params = params.into_iter().collect::<Vec<_>>();
debug!(
"executing statement {} with parameters: {:?}",
statement.name(),
BorrowToSqlParamsDebug(params.as_slice()),
);
encode(client, &statement, params)?
} else {
encode(client, &statement, params)?
};
let responses = start(client, buf).await?;
Ok(RowStream {
statement,
responses,
command_tag: None,
status: ReadyForQueryStatus::Unknown,
output_format: Format::Binary,
_p: PhantomPinned,
})
}
pub async fn query_txt<S, I>(
client: &Arc<InnerClient>,
query: &str,
params: I,
) -> Result<RowStream, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
{
let params = params.into_iter();
let buf = client.with_buf(|buf| {
frontend::parse(
"", // unnamed prepared statement
query, // query to parse
std::iter::empty(), // give no type info
buf,
)
.map_err(Error::encode)?;
frontend::describe(b'S', "", buf).map_err(Error::encode)?;
// Bind, pass params as text, retrieve as binary
match frontend::bind(
"", // empty string selects the unnamed portal
"", // unnamed prepared statement
std::iter::empty(), // all parameters use the default format (text)
params,
|param, buf| match param {
Some(param) => {
buf.put_slice(param.as_ref().as_bytes());
Ok(postgres_protocol2::IsNull::No)
}
None => Ok(postgres_protocol2::IsNull::Yes),
},
Some(0), // all text
buf,
) {
Ok(()) => Ok(()),
Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)),
Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
}?;
// Execute
frontend::execute("", 0, buf).map_err(Error::encode)?;
// Sync
frontend::sync(buf);
Ok(buf.split().freeze())
})?;
// now read the responses
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? {
Message::ParseComplete => {}
_ => return Err(Error::unexpected_message()),
}
let parameter_description = match responses.next().await? {
Message::ParameterDescription(body) => body,
_ => return Err(Error::unexpected_message()),
};
let row_description = match responses.next().await? {
Message::RowDescription(body) => Some(body),
Message::NoData => None,
_ => return Err(Error::unexpected_message()),
};
match responses.next().await? {
Message::BindComplete => {}
_ => return Err(Error::unexpected_message()),
}
let mut parameters = vec![];
let mut it = parameter_description.parameters();
while let Some(oid) = it.next().map_err(Error::parse)? {
let type_ = Type::from_oid(oid).unwrap_or(Type::UNKNOWN);
parameters.push(type_);
}
let mut columns = vec![];
if let Some(row_description) = row_description {
let mut it = row_description.fields();
while let Some(field) = it.next().map_err(Error::parse)? {
let type_ = Type::from_oid(field.type_oid()).unwrap_or(Type::UNKNOWN);
let column = Column::new(field.name().to_string(), type_, field);
columns.push(column);
}
}
Ok(RowStream {
statement: Statement::new_anonymous(parameters, columns),
responses,
command_tag: None,
status: ReadyForQueryStatus::Unknown,
output_format: Format::Text,
_p: PhantomPinned,
})
}
pub async fn execute<'a, I>(
client: &InnerClient,
statement: Statement,
params: I,
) -> Result<u64, Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
{
let buf = if log_enabled!(Level::Debug) {
let params = params.into_iter().collect::<Vec<_>>();
debug!(
"executing statement {} with parameters: {:?}",
statement.name(),
BorrowToSqlParamsDebug(params.as_slice()),
);
encode(client, &statement, params)?
} else {
encode(client, &statement, params)?
};
let mut responses = start(client, buf).await?;
let mut rows = 0;
loop {
match responses.next().await? {
Message::DataRow(_) => {}
Message::CommandComplete(body) => {
rows = body
.tag()
.map_err(Error::parse)?
.rsplit(' ')
.next()
.unwrap()
.parse()
.unwrap_or(0);
}
Message::EmptyQueryResponse => rows = 0,
Message::ReadyForQuery(_) => return Ok(rows),
_ => return Err(Error::unexpected_message()),
}
}
}
async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? {
Message::BindComplete => {}
_ => return Err(Error::unexpected_message()),
}
Ok(responses)
}
pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
{
client.with_buf(|buf| {
encode_bind(statement, params, "", buf)?;
frontend::execute("", 0, buf).map_err(Error::encode)?;
frontend::sync(buf);
Ok(buf.split().freeze())
})
}
pub fn encode_bind<'a, I>(
statement: &Statement,
params: I,
portal: &str,
buf: &mut BytesMut,
) -> Result<(), Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
{
let param_types = statement.params();
let params = params.into_iter();
assert!(
param_types.len() == params.len(),
"expected {} parameters but got {}",
param_types.len(),
params.len()
);
let (param_formats, params): (Vec<_>, Vec<_>) = params
.zip(param_types.iter())
.map(|(p, ty)| (p.encode_format(ty) as i16, p))
.unzip();
let params = params.into_iter();
let mut error_idx = 0;
let r = frontend::bind(
portal,
statement.name(),
param_formats,
params.zip(param_types).enumerate(),
|(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) {
Ok(IsNull::No) => Ok(postgres_protocol2::IsNull::No),
Ok(IsNull::Yes) => Ok(postgres_protocol2::IsNull::Yes),
Err(e) => {
error_idx = idx;
Err(e)
}
},
Some(1),
buf,
);
match r {
Ok(()) => Ok(()),
Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
}
}
pin_project! {
/// A stream of table rows.
pub struct RowStream {
statement: Statement,
responses: Responses,
command_tag: Option<String>,
output_format: Format,
status: ReadyForQueryStatus,
#[pin]
_p: PhantomPinned,
}
}
impl Stream for RowStream {
type Item = Result<Row, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
loop {
match ready!(this.responses.poll_next(cx)?) {
Message::DataRow(body) => {
return Poll::Ready(Some(Ok(Row::new(
this.statement.clone(),
body,
*this.output_format,
)?)))
}
Message::EmptyQueryResponse | Message::PortalSuspended => {}
Message::CommandComplete(body) => {
if let Ok(tag) = body.tag() {
*this.command_tag = Some(tag.to_string());
}
}
Message::ReadyForQuery(status) => {
*this.status = status.into();
return Poll::Ready(None);
}
_ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
}
}
}
}
impl RowStream {
/// Returns information about the columns of data in the row.
pub fn columns(&self) -> &[Column] {
self.statement.columns()
}
/// Returns the command tag of this query.
///
/// This is only available after the stream has been exhausted.
pub fn command_tag(&self) -> Option<String> {
self.command_tag.clone()
}
/// Returns if the connection is ready for querying, with the status of the connection.
///
/// This might be available only after the stream has been exhausted.
pub fn ready_status(&self) -> ReadyForQueryStatus {
self.status
}
}

View File

@@ -0,0 +1,300 @@
//! Rows.
use crate::row::sealed::{AsName, Sealed};
use crate::simple_query::SimpleColumn;
use crate::statement::Column;
use crate::types::{FromSql, Type, WrongType};
use crate::{Error, Statement};
use fallible_iterator::FallibleIterator;
use postgres_protocol2::message::backend::DataRowBody;
use postgres_types2::{Format, WrongFormat};
use std::fmt;
use std::ops::Range;
use std::str;
use std::sync::Arc;
mod sealed {
pub trait Sealed {}
pub trait AsName {
fn as_name(&self) -> &str;
}
}
impl AsName for Column {
fn as_name(&self) -> &str {
self.name()
}
}
impl AsName for String {
fn as_name(&self) -> &str {
self
}
}
/// A trait implemented by types that can index into columns of a row.
///
/// This cannot be implemented outside of this crate.
pub trait RowIndex: Sealed {
#[doc(hidden)]
fn __idx<T>(&self, columns: &[T]) -> Option<usize>
where
T: AsName;
}
impl Sealed for usize {}
impl RowIndex for usize {
#[inline]
fn __idx<T>(&self, columns: &[T]) -> Option<usize>
where
T: AsName,
{
if *self >= columns.len() {
None
} else {
Some(*self)
}
}
}
impl Sealed for str {}
impl RowIndex for str {
#[inline]
fn __idx<T>(&self, columns: &[T]) -> Option<usize>
where
T: AsName,
{
if let Some(idx) = columns.iter().position(|d| d.as_name() == self) {
return Some(idx);
};
// FIXME ASCII-only case insensitivity isn't really the right thing to
// do. Postgres itself uses a dubious wrapper around tolower and JDBC
// uses the US locale.
columns
.iter()
.position(|d| d.as_name().eq_ignore_ascii_case(self))
}
}
impl<T> Sealed for &T where T: ?Sized + Sealed {}
impl<T> RowIndex for &T
where
T: ?Sized + RowIndex,
{
#[inline]
fn __idx<U>(&self, columns: &[U]) -> Option<usize>
where
U: AsName,
{
T::__idx(*self, columns)
}
}
/// A row of data returned from the database by a query.
pub struct Row {
statement: Statement,
output_format: Format,
body: DataRowBody,
ranges: Vec<Option<Range<usize>>>,
}
impl fmt::Debug for Row {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Row")
.field("columns", &self.columns())
.finish()
}
}
impl Row {
pub(crate) fn new(
statement: Statement,
body: DataRowBody,
output_format: Format,
) -> Result<Row, Error> {
let ranges = body.ranges().collect().map_err(Error::parse)?;
Ok(Row {
statement,
body,
ranges,
output_format,
})
}
/// Returns information about the columns of data in the row.
pub fn columns(&self) -> &[Column] {
self.statement.columns()
}
/// Determines if the row contains no values.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Returns the number of values in the row.
pub fn len(&self) -> usize {
self.columns().len()
}
/// Deserializes a value from the row.
///
/// The value can be specified either by its numeric index in the row, or by its column name.
///
/// # Panics
///
/// Panics if the index is out of bounds or if the value cannot be converted to the specified type.
pub fn get<'a, I, T>(&'a self, idx: I) -> T
where
I: RowIndex + fmt::Display,
T: FromSql<'a>,
{
match self.get_inner(&idx) {
Ok(ok) => ok,
Err(err) => panic!("error retrieving column {}: {}", idx, err),
}
}
/// Like `Row::get`, but returns a `Result` rather than panicking.
pub fn try_get<'a, I, T>(&'a self, idx: I) -> Result<T, Error>
where
I: RowIndex + fmt::Display,
T: FromSql<'a>,
{
self.get_inner(&idx)
}
fn get_inner<'a, I, T>(&'a self, idx: &I) -> Result<T, Error>
where
I: RowIndex + fmt::Display,
T: FromSql<'a>,
{
let idx = match idx.__idx(self.columns()) {
Some(idx) => idx,
None => return Err(Error::column(idx.to_string())),
};
let ty = self.columns()[idx].type_();
if !T::accepts(ty) {
return Err(Error::from_sql(
Box::new(WrongType::new::<T>(ty.clone())),
idx,
));
}
FromSql::from_sql_nullable(ty, self.col_buffer(idx)).map_err(|e| Error::from_sql(e, idx))
}
/// Get the raw bytes for the column at the given index.
fn col_buffer(&self, idx: usize) -> Option<&[u8]> {
let range = self.ranges.get(idx)?.to_owned()?;
Some(&self.body.buffer()[range])
}
/// Interpret the column at the given index as text
///
/// Useful when using query_raw_txt() which sets text transfer mode
pub fn as_text(&self, idx: usize) -> Result<Option<&str>, Error> {
if self.output_format == Format::Text {
match self.col_buffer(idx) {
Some(raw) => {
FromSql::from_sql(&Type::TEXT, raw).map_err(|e| Error::from_sql(e, idx))
}
None => Ok(None),
}
} else {
Err(Error::from_sql(Box::new(WrongFormat {}), idx))
}
}
/// Row byte size
pub fn body_len(&self) -> usize {
self.body.buffer().len()
}
}
impl AsName for SimpleColumn {
fn as_name(&self) -> &str {
self.name()
}
}
/// A row of data returned from the database by a simple query.
#[derive(Debug)]
pub struct SimpleQueryRow {
columns: Arc<[SimpleColumn]>,
body: DataRowBody,
ranges: Vec<Option<Range<usize>>>,
}
impl SimpleQueryRow {
#[allow(clippy::new_ret_no_self)]
pub(crate) fn new(
columns: Arc<[SimpleColumn]>,
body: DataRowBody,
) -> Result<SimpleQueryRow, Error> {
let ranges = body.ranges().collect().map_err(Error::parse)?;
Ok(SimpleQueryRow {
columns,
body,
ranges,
})
}
/// Returns information about the columns of data in the row.
pub fn columns(&self) -> &[SimpleColumn] {
&self.columns
}
/// Determines if the row contains no values.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Returns the number of values in the row.
pub fn len(&self) -> usize {
self.columns.len()
}
/// Returns a value from the row.
///
/// The value can be specified either by its numeric index in the row, or by its column name.
///
/// # Panics
///
/// Panics if the index is out of bounds or if the value cannot be converted to the specified type.
pub fn get<I>(&self, idx: I) -> Option<&str>
where
I: RowIndex + fmt::Display,
{
match self.get_inner(&idx) {
Ok(ok) => ok,
Err(err) => panic!("error retrieving column {}: {}", idx, err),
}
}
/// Like `SimpleQueryRow::get`, but returns a `Result` rather than panicking.
pub fn try_get<I>(&self, idx: I) -> Result<Option<&str>, Error>
where
I: RowIndex + fmt::Display,
{
self.get_inner(&idx)
}
fn get_inner<I>(&self, idx: &I) -> Result<Option<&str>, Error>
where
I: RowIndex + fmt::Display,
{
let idx = match idx.__idx(&self.columns) {
Some(idx) => idx,
None => return Err(Error::column(idx.to_string())),
};
let buf = self.ranges[idx].clone().map(|r| &self.body.buffer()[r]);
FromSql::from_sql_nullable(&Type::TEXT, buf).map_err(|e| Error::from_sql(e, idx))
}
}

View File

@@ -0,0 +1,142 @@
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::{Error, ReadyForQueryStatus, SimpleQueryMessage, SimpleQueryRow};
use bytes::Bytes;
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Stream};
use log::debug;
use pin_project_lite::pin_project;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
/// Information about a column of a single query row.
#[derive(Debug)]
pub struct SimpleColumn {
name: String,
}
impl SimpleColumn {
pub(crate) fn new(name: String) -> SimpleColumn {
SimpleColumn { name }
}
/// Returns the name of the column.
pub fn name(&self) -> &str {
&self.name
}
}
pub async fn simple_query(client: &InnerClient, query: &str) -> Result<SimpleQueryStream, Error> {
debug!("executing simple query: {}", query);
let buf = encode(client, query)?;
let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
Ok(SimpleQueryStream {
responses,
columns: None,
status: ReadyForQueryStatus::Unknown,
_p: PhantomPinned,
})
}
pub async fn batch_execute(
client: &InnerClient,
query: &str,
) -> Result<ReadyForQueryStatus, Error> {
debug!("executing statement batch: {}", query);
let buf = encode(client, query)?;
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
loop {
match responses.next().await? {
Message::ReadyForQuery(status) => return Ok(status.into()),
Message::CommandComplete(_)
| Message::EmptyQueryResponse
| Message::RowDescription(_)
| Message::DataRow(_) => {}
_ => return Err(Error::unexpected_message()),
}
}
}
pub(crate) fn encode(client: &InnerClient, query: &str) -> Result<Bytes, Error> {
client.with_buf(|buf| {
frontend::query(query, buf).map_err(Error::encode)?;
Ok(buf.split().freeze())
})
}
pin_project! {
/// A stream of simple query results.
pub struct SimpleQueryStream {
responses: Responses,
columns: Option<Arc<[SimpleColumn]>>,
status: ReadyForQueryStatus,
#[pin]
_p: PhantomPinned,
}
}
impl SimpleQueryStream {
/// Returns if the connection is ready for querying, with the status of the connection.
///
/// This might be available only after the stream has been exhausted.
pub fn ready_status(&self) -> ReadyForQueryStatus {
self.status
}
}
impl Stream for SimpleQueryStream {
type Item = Result<SimpleQueryMessage, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
loop {
match ready!(this.responses.poll_next(cx)?) {
Message::CommandComplete(body) => {
let rows = body
.tag()
.map_err(Error::parse)?
.rsplit(' ')
.next()
.unwrap()
.parse()
.unwrap_or(0);
return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(rows))));
}
Message::EmptyQueryResponse => {
return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(0))));
}
Message::RowDescription(body) => {
let columns = body
.fields()
.map(|f| Ok(SimpleColumn::new(f.name().to_string())))
.collect::<Vec<_>>()
.map_err(Error::parse)?
.into();
*this.columns = Some(columns);
}
Message::DataRow(body) => {
let row = match &this.columns {
Some(columns) => SimpleQueryRow::new(columns.clone(), body)?,
None => return Poll::Ready(Some(Err(Error::unexpected_message()))),
};
return Poll::Ready(Some(Ok(SimpleQueryMessage::Row(row))));
}
Message::ReadyForQuery(s) => {
*this.status = s.into();
return Poll::Ready(None);
}
_ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
}
}
}
}

View File

@@ -0,0 +1,157 @@
use crate::client::InnerClient;
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::types::Type;
use postgres_protocol2::{
message::{backend::Field, frontend},
Oid,
};
use std::{
fmt,
sync::{Arc, Weak},
};
struct StatementInner {
client: Weak<InnerClient>,
name: String,
params: Vec<Type>,
columns: Vec<Column>,
}
impl Drop for StatementInner {
fn drop(&mut self) {
if let Some(client) = self.client.upgrade() {
let buf = client.with_buf(|buf| {
frontend::close(b'S', &self.name, buf).unwrap();
frontend::sync(buf);
buf.split().freeze()
});
let _ = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
}
}
}
/// A prepared statement.
///
/// Prepared statements can only be used with the connection that created them.
#[derive(Clone)]
pub struct Statement(Arc<StatementInner>);
impl Statement {
pub(crate) fn new(
inner: &Arc<InnerClient>,
name: String,
params: Vec<Type>,
columns: Vec<Column>,
) -> Statement {
Statement(Arc::new(StatementInner {
client: Arc::downgrade(inner),
name,
params,
columns,
}))
}
pub(crate) fn new_anonymous(params: Vec<Type>, columns: Vec<Column>) -> Statement {
Statement(Arc::new(StatementInner {
client: Weak::new(),
name: String::new(),
params,
columns,
}))
}
pub(crate) fn name(&self) -> &str {
&self.0.name
}
/// Returns the expected types of the statement's parameters.
pub fn params(&self) -> &[Type] {
&self.0.params
}
/// Returns information about the columns returned when the statement is queried.
pub fn columns(&self) -> &[Column] {
&self.0.columns
}
}
/// Information about a column of a query.
pub struct Column {
name: String,
type_: Type,
// raw fields from RowDescription
table_oid: Oid,
column_id: i16,
format: i16,
// that better be stored in self.type_, but that is more radical refactoring
type_oid: Oid,
type_size: i16,
type_modifier: i32,
}
impl Column {
pub(crate) fn new(name: String, type_: Type, raw_field: Field<'_>) -> Column {
Column {
name,
type_,
table_oid: raw_field.table_oid(),
column_id: raw_field.column_id(),
format: raw_field.format(),
type_oid: raw_field.type_oid(),
type_size: raw_field.type_size(),
type_modifier: raw_field.type_modifier(),
}
}
/// Returns the name of the column.
pub fn name(&self) -> &str {
&self.name
}
/// Returns the type of the column.
pub fn type_(&self) -> &Type {
&self.type_
}
/// Returns the table OID of the column.
pub fn table_oid(&self) -> Oid {
self.table_oid
}
/// Returns the column ID of the column.
pub fn column_id(&self) -> i16 {
self.column_id
}
/// Returns the format of the column.
pub fn format(&self) -> i16 {
self.format
}
/// Returns the type OID of the column.
pub fn type_oid(&self) -> Oid {
self.type_oid
}
/// Returns the type size of the column.
pub fn type_size(&self) -> i16 {
self.type_size
}
/// Returns the type modifier of the column.
pub fn type_modifier(&self) -> i32 {
self.type_modifier
}
}
impl fmt::Debug for Column {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Column")
.field("name", &self.name)
.field("type", &self.type_)
.finish()
}
}

View File

@@ -0,0 +1,162 @@
//! TLS support.
use std::error::Error;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, io};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
pub(crate) mod private {
pub struct ForcePrivateApi;
}
/// Channel binding information returned from a TLS handshake.
pub struct ChannelBinding {
pub(crate) tls_server_end_point: Option<Vec<u8>>,
}
impl ChannelBinding {
/// Creates a `ChannelBinding` containing no information.
pub fn none() -> ChannelBinding {
ChannelBinding {
tls_server_end_point: None,
}
}
/// Creates a `ChannelBinding` containing `tls-server-end-point` channel binding information.
pub fn tls_server_end_point(tls_server_end_point: Vec<u8>) -> ChannelBinding {
ChannelBinding {
tls_server_end_point: Some(tls_server_end_point),
}
}
}
/// A constructor of `TlsConnect`ors.
///
/// Requires the `runtime` Cargo feature (enabled by default).
pub trait MakeTlsConnect<S> {
/// The stream type created by the `TlsConnect` implementation.
type Stream: TlsStream + Unpin;
/// The `TlsConnect` implementation created by this type.
type TlsConnect: TlsConnect<S, Stream = Self::Stream>;
/// The error type returned by the `TlsConnect` implementation.
type Error: Into<Box<dyn Error + Sync + Send>>;
/// Creates a new `TlsConnect`or.
///
/// The domain name is provided for certificate verification and SNI.
fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
}
/// An asynchronous function wrapping a stream in a TLS session.
pub trait TlsConnect<S> {
/// The stream returned by the future.
type Stream: TlsStream + Unpin;
/// The error returned by the future.
type Error: Into<Box<dyn Error + Sync + Send>>;
/// The future returned by the connector.
type Future: Future<Output = Result<Self::Stream, Self::Error>>;
/// Returns a future performing a TLS handshake over the stream.
fn connect(self, stream: S) -> Self::Future;
#[doc(hidden)]
fn can_connect(&self, _: private::ForcePrivateApi) -> bool {
true
}
}
/// A TLS-wrapped connection to a PostgreSQL database.
pub trait TlsStream: AsyncRead + AsyncWrite {
/// Returns channel binding information for the session.
fn channel_binding(&self) -> ChannelBinding;
}
/// A `MakeTlsConnect` and `TlsConnect` implementation which simply returns an error.
///
/// This can be used when `sslmode` is `none` or `prefer`.
#[derive(Debug, Copy, Clone)]
pub struct NoTls;
impl<S> MakeTlsConnect<S> for NoTls {
type Stream = NoTlsStream;
type TlsConnect = NoTls;
type Error = NoTlsError;
fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
Ok(NoTls)
}
}
impl<S> TlsConnect<S> for NoTls {
type Stream = NoTlsStream;
type Error = NoTlsError;
type Future = NoTlsFuture;
fn connect(self, _: S) -> NoTlsFuture {
NoTlsFuture(())
}
fn can_connect(&self, _: private::ForcePrivateApi) -> bool {
false
}
}
/// The future returned by `NoTls`.
pub struct NoTlsFuture(());
impl Future for NoTlsFuture {
type Output = Result<NoTlsStream, NoTlsError>;
fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
Poll::Ready(Err(NoTlsError(())))
}
}
/// The TLS "stream" type produced by the `NoTls` connector.
///
/// Since `NoTls` doesn't support TLS, this type is uninhabited.
pub enum NoTlsStream {}
impl AsyncRead for NoTlsStream {
fn poll_read(
self: Pin<&mut Self>,
_: &mut Context<'_>,
_: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match *self {}
}
}
impl AsyncWrite for NoTlsStream {
fn poll_write(self: Pin<&mut Self>, _: &mut Context<'_>, _: &[u8]) -> Poll<io::Result<usize>> {
match *self {}
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
match *self {}
}
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
match *self {}
}
}
impl TlsStream for NoTlsStream {
fn channel_binding(&self) -> ChannelBinding {
match *self {}
}
}
/// The error returned by `NoTls`.
#[derive(Debug)]
pub struct NoTlsError(());
impl fmt::Display for NoTlsError {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.write_str("no TLS implementation configured")
}
}
impl Error for NoTlsError {}

View File

@@ -0,0 +1,57 @@
use crate::to_statement::private::{Sealed, ToStatementType};
use crate::Statement;
mod private {
use crate::{Client, Error, Statement};
pub trait Sealed {}
pub enum ToStatementType<'a> {
Statement(&'a Statement),
Query(&'a str),
}
impl<'a> ToStatementType<'a> {
pub async fn into_statement(self, client: &Client) -> Result<Statement, Error> {
match self {
ToStatementType::Statement(s) => Ok(s.clone()),
ToStatementType::Query(s) => client.prepare(s).await,
}
}
}
}
/// A trait abstracting over prepared and unprepared statements.
///
/// Many methods are generic over this bound, so that they support both a raw query string as well as a statement which
/// was prepared previously.
///
/// This trait is "sealed" and cannot be implemented by anything outside this crate.
pub trait ToStatement: Sealed {
#[doc(hidden)]
fn __convert(&self) -> ToStatementType<'_>;
}
impl ToStatement for Statement {
fn __convert(&self) -> ToStatementType<'_> {
ToStatementType::Statement(self)
}
}
impl Sealed for Statement {}
impl ToStatement for str {
fn __convert(&self) -> ToStatementType<'_> {
ToStatementType::Query(self)
}
}
impl Sealed for str {}
impl ToStatement for String {
fn __convert(&self) -> ToStatementType<'_> {
ToStatementType::Query(self)
}
}
impl Sealed for String {}

View File

@@ -0,0 +1,74 @@
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::query::RowStream;
use crate::{CancelToken, Client, Error, ReadyForQueryStatus};
use postgres_protocol2::message::frontend;
/// A representation of a PostgreSQL database transaction.
///
/// Transactions will implicitly roll back when dropped. Use the `commit` method to commit the changes made in the
/// transaction. Transactions can be nested, with inner transactions implemented via safepoints.
pub struct Transaction<'a> {
client: &'a mut Client,
done: bool,
}
impl Drop for Transaction<'_> {
fn drop(&mut self) {
if self.done {
return;
}
let buf = self.client.inner().with_buf(|buf| {
frontend::query("ROLLBACK", buf).unwrap();
buf.split().freeze()
});
let _ = self
.client
.inner()
.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
}
}
impl<'a> Transaction<'a> {
pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
Transaction {
client,
done: false,
}
}
/// Consumes the transaction, committing all changes made within it.
pub async fn commit(mut self) -> Result<ReadyForQueryStatus, Error> {
self.done = true;
self.client.batch_execute("COMMIT").await
}
/// Rolls the transaction back, discarding all changes made within it.
///
/// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
pub async fn rollback(mut self) -> Result<ReadyForQueryStatus, Error> {
self.done = true;
self.client.batch_execute("ROLLBACK").await
}
/// Like `Client::query_raw_txt`.
pub async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
{
self.client.query_raw_txt(statement, params).await
}
/// Like `Client::cancel_token`.
pub fn cancel_token(&self) -> CancelToken {
self.client.cancel_token()
}
/// Returns a reference to the underlying `Client`.
pub fn client(&self) -> &Client {
self.client
}
}

View File

@@ -0,0 +1,113 @@
use crate::{Client, Error, Transaction};
/// The isolation level of a database transaction.
#[derive(Debug, Copy, Clone)]
#[non_exhaustive]
pub enum IsolationLevel {
/// Equivalent to `ReadCommitted`.
ReadUncommitted,
/// An individual statement in the transaction will see rows committed before it began.
ReadCommitted,
/// All statements in the transaction will see the same view of rows committed before the first query in the
/// transaction.
RepeatableRead,
/// The reads and writes in this transaction must be able to be committed as an atomic "unit" with respect to reads
/// and writes of all other concurrent serializable transactions without interleaving.
Serializable,
}
/// A builder for database transactions.
pub struct TransactionBuilder<'a> {
client: &'a mut Client,
isolation_level: Option<IsolationLevel>,
read_only: Option<bool>,
deferrable: Option<bool>,
}
impl<'a> TransactionBuilder<'a> {
pub(crate) fn new(client: &'a mut Client) -> TransactionBuilder<'a> {
TransactionBuilder {
client,
isolation_level: None,
read_only: None,
deferrable: None,
}
}
/// Sets the isolation level of the transaction.
pub fn isolation_level(mut self, isolation_level: IsolationLevel) -> Self {
self.isolation_level = Some(isolation_level);
self
}
/// Sets the access mode of the transaction.
pub fn read_only(mut self, read_only: bool) -> Self {
self.read_only = Some(read_only);
self
}
/// Sets the deferrability of the transaction.
///
/// If the transaction is also serializable and read only, creation of the transaction may block, but when it
/// completes the transaction is able to run with less overhead and a guarantee that it will not be aborted due to
/// serialization failure.
pub fn deferrable(mut self, deferrable: bool) -> Self {
self.deferrable = Some(deferrable);
self
}
/// Begins the transaction.
///
/// The transaction will roll back by default - use the `commit` method to commit it.
pub async fn start(self) -> Result<Transaction<'a>, Error> {
let mut query = "START TRANSACTION".to_string();
let mut first = true;
if let Some(level) = self.isolation_level {
first = false;
query.push_str(" ISOLATION LEVEL ");
let level = match level {
IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
IsolationLevel::ReadCommitted => "READ COMMITTED",
IsolationLevel::RepeatableRead => "REPEATABLE READ",
IsolationLevel::Serializable => "SERIALIZABLE",
};
query.push_str(level);
}
if let Some(read_only) = self.read_only {
if !first {
query.push(',');
}
first = false;
let s = if read_only {
" READ ONLY"
} else {
" READ WRITE"
};
query.push_str(s);
}
if let Some(deferrable) = self.deferrable {
if !first {
query.push(',');
}
let s = if deferrable {
" DEFERRABLE"
} else {
" NOT DEFERRABLE"
};
query.push_str(s);
}
self.client.batch_execute(&query).await?;
Ok(Transaction::new(self.client))
}
}

View File

@@ -0,0 +1,6 @@
//! Types.
//!
//! This module is a reexport of the `postgres_types` crate.
#[doc(inline)]
pub use postgres_types2::*;

View File

@@ -220,6 +220,11 @@ impl AzureBlobStorage {
let started_at = ScopeGuard::into_inner(started_at);
let outcome = match &download {
Ok(_) => AttemptOutcome::Ok,
// At this level in the stack 404 and 304 responses do not indicate an error.
// There's expected cases when a blob may not exist or hasn't been modified since
// the last get (e.g. probing for timeline indices and heatmap downloads).
// Callers should handle errors if they are unexpected.
Err(DownloadError::NotFound | DownloadError::Unmodified) => AttemptOutcome::Ok,
Err(_) => AttemptOutcome::Err,
};
crate::metrics::BUCKET_METRICS

View File

@@ -176,7 +176,9 @@ pub(crate) struct BucketMetrics {
impl Default for BucketMetrics {
fn default() -> Self {
let buckets = [0.01, 0.10, 0.5, 1.0, 5.0, 10.0, 50.0, 100.0];
// first bucket 100 microseconds to count requests that do not need to wait at all
// and get a permit immediately
let buckets = [0.0001, 0.01, 0.10, 0.5, 1.0, 5.0, 10.0, 50.0, 100.0];
let req_seconds = register_histogram_vec!(
"remote_storage_s3_request_seconds",

View File

@@ -34,6 +34,7 @@ pprof.workspace = true
regex.workspace = true
routerify.workspace = true
serde.workspace = true
serde_with.workspace = true
serde_json.workspace = true
signal-hook.workspace = true
thiserror.workspace = true

View File

@@ -7,29 +7,88 @@ use postgres_connection::{parse_host_port, PgConnectionConfig};
use crate::id::TenantTimelineId;
#[derive(Copy, Clone, PartialEq, Eq, Debug, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum InterpretedFormat {
Bincode,
Protobuf,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum Compression {
Zstd { level: i8 },
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type", content = "args")]
#[serde(rename_all = "kebab-case")]
pub enum PostgresClientProtocol {
/// Usual Postgres replication protocol
Vanilla,
/// Custom shard-aware protocol that replicates interpreted records.
/// Used to send wal from safekeeper to pageserver.
Interpreted {
format: InterpretedFormat,
compression: Option<Compression>,
},
}
pub struct ConnectionConfigArgs<'a> {
pub protocol: PostgresClientProtocol,
pub ttid: TenantTimelineId,
pub shard_number: Option<u8>,
pub shard_count: Option<u8>,
pub shard_stripe_size: Option<u32>,
pub listen_pg_addr_str: &'a str,
pub auth_token: Option<&'a str>,
pub availability_zone: Option<&'a str>,
}
impl<'a> ConnectionConfigArgs<'a> {
fn options(&'a self) -> Vec<String> {
let mut options = vec![
"-c".to_owned(),
format!("timeline_id={}", self.ttid.timeline_id),
format!("tenant_id={}", self.ttid.tenant_id),
format!(
"protocol={}",
serde_json::to_string(&self.protocol).unwrap()
),
];
if self.shard_number.is_some() {
assert!(self.shard_count.is_some());
assert!(self.shard_stripe_size.is_some());
options.push(format!("shard_count={}", self.shard_count.unwrap()));
options.push(format!("shard_number={}", self.shard_number.unwrap()));
options.push(format!(
"shard_stripe_size={}",
self.shard_stripe_size.unwrap()
));
}
options
}
}
/// Create client config for fetching WAL from safekeeper on particular timeline.
/// listen_pg_addr_str is in form host:\[port\].
pub fn wal_stream_connection_config(
TenantTimelineId {
tenant_id,
timeline_id,
}: TenantTimelineId,
listen_pg_addr_str: &str,
auth_token: Option<&str>,
availability_zone: Option<&str>,
args: ConnectionConfigArgs,
) -> anyhow::Result<PgConnectionConfig> {
let (host, port) =
parse_host_port(listen_pg_addr_str).context("Unable to parse listen_pg_addr_str")?;
parse_host_port(args.listen_pg_addr_str).context("Unable to parse listen_pg_addr_str")?;
let port = port.unwrap_or(5432);
let mut connstr = PgConnectionConfig::new_host_port(host, port)
.extend_options([
"-c".to_owned(),
format!("timeline_id={}", timeline_id),
format!("tenant_id={}", tenant_id),
])
.set_password(auth_token.map(|s| s.to_owned()));
.extend_options(args.options())
.set_password(args.auth_token.map(|s| s.to_owned()));
if let Some(availability_zone) = availability_zone {
if let Some(availability_zone) = args.availability_zone {
connstr = connstr.extend_options([format!("availability_zone={}", availability_zone)]);
}

View File

@@ -218,7 +218,7 @@ impl MemoryStatus {
fn debug_slice(slice: &[Self]) -> impl '_ + Debug {
struct DS<'a>(&'a [MemoryStatus]);
impl<'a> Debug for DS<'a> {
impl Debug for DS<'_> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_struct("[MemoryStatus]")
.field(
@@ -233,7 +233,7 @@ impl MemoryStatus {
struct Fields<'a, F>(&'a [MemoryStatus], F);
impl<'a, F: Fn(&MemoryStatus) -> T, T: Debug> Debug for Fields<'a, F> {
impl<F: Fn(&MemoryStatus) -> T, T: Debug> Debug for Fields<'_, F> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_list().entries(self.0.iter().map(&self.1)).finish()
}

View File

@@ -8,11 +8,19 @@ license.workspace = true
testing = ["pageserver_api/testing"]
[dependencies]
async-compression.workspace = true
anyhow.workspace = true
bytes.workspace = true
pageserver_api.workspace = true
prost.workspace = true
postgres_ffi.workspace = true
serde.workspace = true
thiserror.workspace = true
tokio = { workspace = true, features = ["io-util"] }
tonic.workspace = true
tracing.workspace = true
utils.workspace = true
workspace_hack = { version = "0.1", path = "../../workspace_hack" }
[build-dependencies]
tonic-build.workspace = true

11
libs/wal_decoder/build.rs Normal file
View File

@@ -0,0 +1,11 @@
fn main() -> Result<(), Box<dyn std::error::Error>> {
// Generate rust code from .proto protobuf.
//
// Note: we previously tried to use deterministic location at proto/ for
// easy location, but apparently interference with cachepot sometimes fails
// the build then. Anyway, per cargo docs build script shouldn't output to
// anywhere but $OUT_DIR.
tonic_build::compile_protos("proto/interpreted_wal.proto")
.unwrap_or_else(|e| panic!("failed to compile protos {:?}", e));
Ok(())
}

View File

@@ -0,0 +1,43 @@
syntax = "proto3";
package interpreted_wal;
message InterpretedWalRecords {
repeated InterpretedWalRecord records = 1;
optional uint64 next_record_lsn = 2;
}
message InterpretedWalRecord {
optional bytes metadata_record = 1;
SerializedValueBatch batch = 2;
uint64 next_record_lsn = 3;
bool flush_uncommitted = 4;
uint32 xid = 5;
}
message SerializedValueBatch {
bytes raw = 1;
repeated ValueMeta metadata = 2;
uint64 max_lsn = 3;
uint64 len = 4;
}
enum ValueMetaType {
Serialized = 0;
Observed = 1;
}
message ValueMeta {
ValueMetaType type = 1;
CompactKey key = 2;
uint64 lsn = 3;
optional uint64 batch_offset = 4;
optional uint64 len = 5;
optional bool will_init = 6;
}
message CompactKey {
int64 high = 1;
int64 low = 2;
}

View File

@@ -4,6 +4,7 @@
use crate::models::*;
use crate::serialized_batch::SerializedValueBatch;
use bytes::{Buf, Bytes};
use pageserver_api::key::rel_block_to_key;
use pageserver_api::reltag::{RelTag, SlruKind};
use pageserver_api::shard::ShardIdentity;
use postgres_ffi::pg_constants;
@@ -32,7 +33,8 @@ impl InterpretedWalRecord {
FlushUncommittedRecords::No
};
let metadata_record = MetadataRecord::from_decoded(&decoded, next_record_lsn, pg_version)?;
let metadata_record =
MetadataRecord::from_decoded_filtered(&decoded, shard, next_record_lsn, pg_version)?;
let batch = SerializedValueBatch::from_decoded_filtered(
decoded,
shard,
@@ -51,8 +53,13 @@ impl InterpretedWalRecord {
}
impl MetadataRecord {
fn from_decoded(
/// Builds a metadata record for this WAL record, if any.
///
/// Only metadata records relevant for the given shard are emitted. Currently, most metadata
/// records are broadcast to all shards for simplicity, but this should be improved.
fn from_decoded_filtered(
decoded: &DecodedWALRecord,
shard: &ShardIdentity,
next_record_lsn: Lsn,
pg_version: u32,
) -> anyhow::Result<Option<MetadataRecord>> {
@@ -61,26 +68,27 @@ impl MetadataRecord {
let mut buf = decoded.record.clone();
buf.advance(decoded.main_data_offset);
match decoded.xl_rmid {
// First, generate metadata records from the decoded WAL record.
let mut metadata_record = match decoded.xl_rmid {
pg_constants::RM_HEAP_ID | pg_constants::RM_HEAP2_ID => {
Self::decode_heapam_record(&mut buf, decoded, pg_version)
Self::decode_heapam_record(&mut buf, decoded, pg_version)?
}
pg_constants::RM_NEON_ID => Self::decode_neonmgr_record(&mut buf, decoded, pg_version),
pg_constants::RM_NEON_ID => Self::decode_neonmgr_record(&mut buf, decoded, pg_version)?,
// Handle other special record types
pg_constants::RM_SMGR_ID => Self::decode_smgr_record(&mut buf, decoded),
pg_constants::RM_DBASE_ID => Self::decode_dbase_record(&mut buf, decoded, pg_version),
pg_constants::RM_SMGR_ID => Self::decode_smgr_record(&mut buf, decoded)?,
pg_constants::RM_DBASE_ID => Self::decode_dbase_record(&mut buf, decoded, pg_version)?,
pg_constants::RM_TBLSPC_ID => {
tracing::trace!("XLOG_TBLSPC_CREATE/DROP is not handled yet");
Ok(None)
None
}
pg_constants::RM_CLOG_ID => Self::decode_clog_record(&mut buf, decoded, pg_version),
pg_constants::RM_CLOG_ID => Self::decode_clog_record(&mut buf, decoded, pg_version)?,
pg_constants::RM_XACT_ID => {
Self::decode_xact_record(&mut buf, decoded, next_record_lsn)
Self::decode_xact_record(&mut buf, decoded, next_record_lsn)?
}
pg_constants::RM_MULTIXACT_ID => {
Self::decode_multixact_record(&mut buf, decoded, pg_version)
Self::decode_multixact_record(&mut buf, decoded, pg_version)?
}
pg_constants::RM_RELMAP_ID => Self::decode_relmap_record(&mut buf, decoded),
pg_constants::RM_RELMAP_ID => Self::decode_relmap_record(&mut buf, decoded)?,
// This is an odd duck. It needs to go to all shards.
// Since it uses the checkpoint image (that's initialized from CHECKPOINT_KEY
// in WalIngest::new), we have to send the whole DecodedWalRecord::record to
@@ -89,19 +97,48 @@ impl MetadataRecord {
// Alternatively, one can make the checkpoint part of the subscription protocol
// to the pageserver. This should work fine, but can be done at a later point.
pg_constants::RM_XLOG_ID => {
Self::decode_xlog_record(&mut buf, decoded, next_record_lsn)
Self::decode_xlog_record(&mut buf, decoded, next_record_lsn)?
}
pg_constants::RM_LOGICALMSG_ID => {
Self::decode_logical_message_record(&mut buf, decoded)
Self::decode_logical_message_record(&mut buf, decoded)?
}
pg_constants::RM_STANDBY_ID => Self::decode_standby_record(&mut buf, decoded),
pg_constants::RM_REPLORIGIN_ID => Self::decode_replorigin_record(&mut buf, decoded),
pg_constants::RM_STANDBY_ID => Self::decode_standby_record(&mut buf, decoded)?,
pg_constants::RM_REPLORIGIN_ID => Self::decode_replorigin_record(&mut buf, decoded)?,
_unexpected => {
// TODO: consider failing here instead of blindly doing something without
// understanding the protocol
Ok(None)
None
}
};
// Next, filter the metadata record by shard.
// Route VM page updates to the shards that own them. VM pages are stored in the VM fork
// of the main relation. These are sharded and managed just like regular relation pages.
// See: https://github.com/neondatabase/neon/issues/9855
if let Some(
MetadataRecord::Heapam(HeapamRecord::ClearVmBits(ref mut clear_vm_bits))
| MetadataRecord::Neonrmgr(NeonrmgrRecord::ClearVmBits(ref mut clear_vm_bits)),
) = metadata_record
{
let is_local_vm_page = |heap_blk| {
let vm_blk = pg_constants::HEAPBLK_TO_MAPBLOCK(heap_blk);
shard.is_key_local(&rel_block_to_key(clear_vm_bits.vm_rel, vm_blk))
};
// Send the old and new VM page updates to their respective shards.
clear_vm_bits.old_heap_blkno = clear_vm_bits
.old_heap_blkno
.filter(|&blkno| is_local_vm_page(blkno));
clear_vm_bits.new_heap_blkno = clear_vm_bits
.new_heap_blkno
.filter(|&blkno| is_local_vm_page(blkno));
// If neither VM page belongs to this shard, discard the record.
if clear_vm_bits.old_heap_blkno.is_none() && clear_vm_bits.new_heap_blkno.is_none() {
metadata_record = None
}
}
Ok(metadata_record)
}
fn decode_heapam_record(

View File

@@ -1,3 +1,4 @@
pub mod decoder;
pub mod models;
pub mod serialized_batch;
pub mod wire_format;

View File

@@ -37,12 +37,32 @@ use utils::lsn::Lsn;
use crate::serialized_batch::SerializedValueBatch;
// Code generated by protobuf.
pub mod proto {
// Tonic does derives as `#[derive(Clone, PartialEq, ::prost::Message)]`
// we don't use these types for anything but broker data transmission,
// so it's ok to ignore this one.
#![allow(clippy::derive_partial_eq_without_eq)]
// The generated ValueMeta has a `len` method generate for its `len` field.
#![allow(clippy::len_without_is_empty)]
tonic::include_proto!("interpreted_wal");
}
#[derive(Serialize, Deserialize)]
pub enum FlushUncommittedRecords {
Yes,
No,
}
/// A batch of interpreted WAL records
#[derive(Serialize, Deserialize)]
pub struct InterpretedWalRecords {
pub records: Vec<InterpretedWalRecord>,
// Start LSN of the next record after the batch.
// Note that said record may not belong to the current shard.
pub next_record_lsn: Option<Lsn>,
}
/// An interpreted Postgres WAL record, ready to be handled by the pageserver
#[derive(Serialize, Deserialize)]
pub struct InterpretedWalRecord {
@@ -65,6 +85,18 @@ pub struct InterpretedWalRecord {
pub xid: TransactionId,
}
impl InterpretedWalRecord {
/// Checks if the WAL record is empty
///
/// An empty interpreted WAL record has no data or metadata and does not have to be sent to the
/// pageserver.
pub fn is_empty(&self) -> bool {
self.batch.is_empty()
&& self.metadata_record.is_none()
&& matches!(self.flush_uncommitted, FlushUncommittedRecords::No)
}
}
/// The interpreted part of the Postgres WAL record which requires metadata
/// writes to the underlying storage engine.
#[derive(Serialize, Deserialize)]

View File

@@ -496,11 +496,16 @@ impl SerializedValueBatch {
}
}
/// Checks if the batch is empty
///
/// A batch is empty when it contains no serialized values.
/// Note that it may still contain observed values.
/// Checks if the batch contains any serialized or observed values
pub fn is_empty(&self) -> bool {
!self.has_data() && self.metadata.is_empty()
}
/// Checks if the batch contains data
///
/// Note that if this returns false, it may still contain observed values or
/// a metadata record.
pub fn has_data(&self) -> bool {
let empty = self.raw.is_empty();
if cfg!(debug_assertions) && empty {
@@ -510,7 +515,7 @@ impl SerializedValueBatch {
.all(|meta| matches!(meta, ValueMeta::Observed(_))));
}
empty
!empty
}
/// Returns the number of values serialized in the batch

View File

@@ -0,0 +1,356 @@
use bytes::{BufMut, Bytes, BytesMut};
use pageserver_api::key::CompactKey;
use prost::{DecodeError, EncodeError, Message};
use tokio::io::AsyncWriteExt;
use utils::bin_ser::{BeSer, DeserializeError, SerializeError};
use utils::lsn::Lsn;
use utils::postgres_client::{Compression, InterpretedFormat};
use crate::models::{
FlushUncommittedRecords, InterpretedWalRecord, InterpretedWalRecords, MetadataRecord,
};
use crate::serialized_batch::{
ObservedValueMeta, SerializedValueBatch, SerializedValueMeta, ValueMeta,
};
use crate::models::proto;
#[derive(Debug, thiserror::Error)]
pub enum ToWireFormatError {
#[error("{0}")]
Bincode(#[from] SerializeError),
#[error("{0}")]
Protobuf(#[from] ProtobufSerializeError),
#[error("{0}")]
Compression(#[from] std::io::Error),
}
#[derive(Debug, thiserror::Error)]
pub enum ProtobufSerializeError {
#[error("{0}")]
MetadataRecord(#[from] SerializeError),
#[error("{0}")]
Encode(#[from] EncodeError),
}
#[derive(Debug, thiserror::Error)]
pub enum FromWireFormatError {
#[error("{0}")]
Bincode(#[from] DeserializeError),
#[error("{0}")]
Protobuf(#[from] ProtobufDeserializeError),
#[error("{0}")]
Decompress(#[from] std::io::Error),
}
#[derive(Debug, thiserror::Error)]
pub enum ProtobufDeserializeError {
#[error("{0}")]
Transcode(#[from] TranscodeError),
#[error("{0}")]
Decode(#[from] DecodeError),
}
#[derive(Debug, thiserror::Error)]
pub enum TranscodeError {
#[error("{0}")]
BadInput(String),
#[error("{0}")]
MetadataRecord(#[from] DeserializeError),
}
pub trait ToWireFormat {
fn to_wire(
self,
format: InterpretedFormat,
compression: Option<Compression>,
) -> impl std::future::Future<Output = Result<Bytes, ToWireFormatError>> + Send;
}
pub trait FromWireFormat {
type T;
fn from_wire(
buf: &Bytes,
format: InterpretedFormat,
compression: Option<Compression>,
) -> impl std::future::Future<Output = Result<Self::T, FromWireFormatError>> + Send;
}
impl ToWireFormat for InterpretedWalRecords {
async fn to_wire(
self,
format: InterpretedFormat,
compression: Option<Compression>,
) -> Result<Bytes, ToWireFormatError> {
use async_compression::tokio::write::ZstdEncoder;
use async_compression::Level;
let encode_res: Result<Bytes, ToWireFormatError> = match format {
InterpretedFormat::Bincode => {
let buf = BytesMut::new();
let mut buf = buf.writer();
self.ser_into(&mut buf)?;
Ok(buf.into_inner().freeze())
}
InterpretedFormat::Protobuf => {
let proto: proto::InterpretedWalRecords = self.try_into()?;
let mut buf = BytesMut::new();
proto
.encode(&mut buf)
.map_err(|e| ToWireFormatError::Protobuf(e.into()))?;
Ok(buf.freeze())
}
};
let buf = encode_res?;
let compressed_buf = match compression {
Some(Compression::Zstd { level }) => {
let mut encoder = ZstdEncoder::with_quality(
Vec::with_capacity(buf.len() / 4),
Level::Precise(level as i32),
);
encoder.write_all(&buf).await?;
encoder.shutdown().await?;
Bytes::from(encoder.into_inner())
}
None => buf,
};
Ok(compressed_buf)
}
}
impl FromWireFormat for InterpretedWalRecords {
type T = Self;
async fn from_wire(
buf: &Bytes,
format: InterpretedFormat,
compression: Option<Compression>,
) -> Result<Self, FromWireFormatError> {
let decompressed_buf = match compression {
Some(Compression::Zstd { .. }) => {
use async_compression::tokio::write::ZstdDecoder;
let mut decoded_buf = Vec::with_capacity(buf.len());
let mut decoder = ZstdDecoder::new(&mut decoded_buf);
decoder.write_all(buf).await?;
decoder.flush().await?;
Bytes::from(decoded_buf)
}
None => buf.clone(),
};
match format {
InterpretedFormat::Bincode => {
InterpretedWalRecords::des(&decompressed_buf).map_err(FromWireFormatError::Bincode)
}
InterpretedFormat::Protobuf => {
let proto = proto::InterpretedWalRecords::decode(decompressed_buf)
.map_err(|e| FromWireFormatError::Protobuf(e.into()))?;
InterpretedWalRecords::try_from(proto)
.map_err(|e| FromWireFormatError::Protobuf(e.into()))
}
}
}
}
impl TryFrom<InterpretedWalRecords> for proto::InterpretedWalRecords {
type Error = SerializeError;
fn try_from(value: InterpretedWalRecords) -> Result<Self, Self::Error> {
let records = value
.records
.into_iter()
.map(proto::InterpretedWalRecord::try_from)
.collect::<Result<Vec<_>, _>>()?;
Ok(proto::InterpretedWalRecords {
records,
next_record_lsn: value.next_record_lsn.map(|l| l.0),
})
}
}
impl TryFrom<InterpretedWalRecord> for proto::InterpretedWalRecord {
type Error = SerializeError;
fn try_from(value: InterpretedWalRecord) -> Result<Self, Self::Error> {
let metadata_record = value
.metadata_record
.map(|meta_rec| -> Result<Vec<u8>, Self::Error> {
let mut buf = Vec::new();
meta_rec.ser_into(&mut buf)?;
Ok(buf)
})
.transpose()?;
Ok(proto::InterpretedWalRecord {
metadata_record,
batch: Some(proto::SerializedValueBatch::from(value.batch)),
next_record_lsn: value.next_record_lsn.0,
flush_uncommitted: matches!(value.flush_uncommitted, FlushUncommittedRecords::Yes),
xid: value.xid,
})
}
}
impl From<SerializedValueBatch> for proto::SerializedValueBatch {
fn from(value: SerializedValueBatch) -> Self {
proto::SerializedValueBatch {
raw: value.raw,
metadata: value
.metadata
.into_iter()
.map(proto::ValueMeta::from)
.collect(),
max_lsn: value.max_lsn.0,
len: value.len as u64,
}
}
}
impl From<ValueMeta> for proto::ValueMeta {
fn from(value: ValueMeta) -> Self {
match value {
ValueMeta::Observed(obs) => proto::ValueMeta {
r#type: proto::ValueMetaType::Observed.into(),
key: Some(proto::CompactKey::from(obs.key)),
lsn: obs.lsn.0,
batch_offset: None,
len: None,
will_init: None,
},
ValueMeta::Serialized(ser) => proto::ValueMeta {
r#type: proto::ValueMetaType::Serialized.into(),
key: Some(proto::CompactKey::from(ser.key)),
lsn: ser.lsn.0,
batch_offset: Some(ser.batch_offset),
len: Some(ser.len as u64),
will_init: Some(ser.will_init),
},
}
}
}
impl From<CompactKey> for proto::CompactKey {
fn from(value: CompactKey) -> Self {
proto::CompactKey {
high: (value.raw() >> 64) as i64,
low: value.raw() as i64,
}
}
}
impl TryFrom<proto::InterpretedWalRecords> for InterpretedWalRecords {
type Error = TranscodeError;
fn try_from(value: proto::InterpretedWalRecords) -> Result<Self, Self::Error> {
let records = value
.records
.into_iter()
.map(InterpretedWalRecord::try_from)
.collect::<Result<_, _>>()?;
Ok(InterpretedWalRecords {
records,
next_record_lsn: value.next_record_lsn.map(Lsn::from),
})
}
}
impl TryFrom<proto::InterpretedWalRecord> for InterpretedWalRecord {
type Error = TranscodeError;
fn try_from(value: proto::InterpretedWalRecord) -> Result<Self, Self::Error> {
let metadata_record = value
.metadata_record
.map(|mrec| -> Result<_, DeserializeError> { MetadataRecord::des(&mrec) })
.transpose()?;
let batch = {
let batch = value.batch.ok_or_else(|| {
TranscodeError::BadInput("InterpretedWalRecord::batch missing".to_string())
})?;
SerializedValueBatch::try_from(batch)?
};
Ok(InterpretedWalRecord {
metadata_record,
batch,
next_record_lsn: Lsn(value.next_record_lsn),
flush_uncommitted: if value.flush_uncommitted {
FlushUncommittedRecords::Yes
} else {
FlushUncommittedRecords::No
},
xid: value.xid,
})
}
}
impl TryFrom<proto::SerializedValueBatch> for SerializedValueBatch {
type Error = TranscodeError;
fn try_from(value: proto::SerializedValueBatch) -> Result<Self, Self::Error> {
let metadata = value
.metadata
.into_iter()
.map(ValueMeta::try_from)
.collect::<Result<Vec<_>, _>>()?;
Ok(SerializedValueBatch {
raw: value.raw,
metadata,
max_lsn: Lsn(value.max_lsn),
len: value.len as usize,
})
}
}
impl TryFrom<proto::ValueMeta> for ValueMeta {
type Error = TranscodeError;
fn try_from(value: proto::ValueMeta) -> Result<Self, Self::Error> {
match proto::ValueMetaType::try_from(value.r#type) {
Ok(proto::ValueMetaType::Serialized) => {
Ok(ValueMeta::Serialized(SerializedValueMeta {
key: value
.key
.ok_or_else(|| {
TranscodeError::BadInput("ValueMeta::key missing".to_string())
})?
.into(),
lsn: Lsn(value.lsn),
batch_offset: value.batch_offset.ok_or_else(|| {
TranscodeError::BadInput("ValueMeta::batch_offset missing".to_string())
})?,
len: value.len.ok_or_else(|| {
TranscodeError::BadInput("ValueMeta::len missing".to_string())
})? as usize,
will_init: value.will_init.ok_or_else(|| {
TranscodeError::BadInput("ValueMeta::will_init missing".to_string())
})?,
}))
}
Ok(proto::ValueMetaType::Observed) => Ok(ValueMeta::Observed(ObservedValueMeta {
key: value
.key
.ok_or_else(|| TranscodeError::BadInput("ValueMeta::key missing".to_string()))?
.into(),
lsn: Lsn(value.lsn),
})),
Err(_) => Err(TranscodeError::BadInput(format!(
"Unexpected ValueMeta::type {}",
value.r#type
))),
}
}
}
impl From<proto::CompactKey> for CompactKey {
fn from(value: proto::CompactKey) -> Self {
(((value.high as i128) << 64) | (value.low as i128)).into()
}
}

View File

@@ -126,6 +126,7 @@ fn main() -> anyhow::Result<()> {
// after setting up logging, log the effective IO engine choice and read path implementations
info!(?conf.virtual_file_io_engine, "starting with virtual_file IO engine");
info!(?conf.virtual_file_io_mode, "starting with virtual_file IO mode");
info!(?conf.wal_receiver_protocol, "starting with WAL receiver protocol");
// The tenants directory contains all the pageserver local disk state.
// Create if not exists and make sure all the contents are durable before proceeding.

View File

@@ -14,6 +14,7 @@ use remote_storage::{RemotePath, RemoteStorageConfig};
use std::env;
use storage_broker::Uri;
use utils::logging::SecretString;
use utils::postgres_client::PostgresClientProtocol;
use once_cell::sync::OnceCell;
use reqwest::Url;
@@ -187,6 +188,8 @@ pub struct PageServerConf {
/// Optionally disable disk syncs (unsafe!)
pub no_sync: bool,
pub wal_receiver_protocol: PostgresClientProtocol,
pub page_service_pipelining: pageserver_api::config::PageServicePipeliningConfig,
}
@@ -347,6 +350,7 @@ impl PageServerConf {
virtual_file_io_engine,
tenant_config,
no_sync,
wal_receiver_protocol,
page_service_pipelining,
} = config_toml;
@@ -390,6 +394,7 @@ impl PageServerConf {
import_pgdata_upcall_api,
import_pgdata_upcall_api_token: import_pgdata_upcall_api_token.map(SecretString::from),
import_pgdata_aws_endpoint_url,
wal_receiver_protocol,
page_service_pipelining,
// ------------------------------------------------------------

View File

@@ -1144,18 +1144,24 @@ pub(crate) mod mock {
rx: tokio::sync::mpsc::UnboundedReceiver<ListWriterQueueMessage>,
executor_rx: tokio::sync::mpsc::Receiver<DeleterMessage>,
cancel: CancellationToken,
executed: Arc<AtomicUsize>,
}
impl ConsumerState {
async fn consume(&mut self, remote_storage: &GenericRemoteStorage) -> usize {
let mut executed = 0;
async fn consume(&mut self, remote_storage: &GenericRemoteStorage) {
info!("Executing all pending deletions");
// Transform all executor messages to generic frontend messages
while let Ok(msg) = self.executor_rx.try_recv() {
loop {
use either::Either;
let msg = tokio::select! {
left = self.executor_rx.recv() => Either::Left(left),
right = self.rx.recv() => Either::Right(right),
};
match msg {
DeleterMessage::Delete(objects) => {
Either::Left(None) => break,
Either::Right(None) => break,
Either::Left(Some(DeleterMessage::Delete(objects))) => {
for path in objects {
match remote_storage.delete(&path, &self.cancel).await {
Ok(_) => {
@@ -1165,18 +1171,13 @@ pub(crate) mod mock {
error!("Failed to delete {path}, leaking object! ({e})");
}
}
executed += 1;
self.executed.fetch_add(1, Ordering::Relaxed);
}
}
DeleterMessage::Flush(flush_op) => {
Either::Left(Some(DeleterMessage::Flush(flush_op))) => {
flush_op.notify();
}
}
}
while let Ok(msg) = self.rx.try_recv() {
match msg {
ListWriterQueueMessage::Delete(op) => {
Either::Right(Some(ListWriterQueueMessage::Delete(op))) => {
let mut objects = op.objects;
for (layer, meta) in op.layers {
objects.push(remote_layer_path(
@@ -1198,33 +1199,27 @@ pub(crate) mod mock {
error!("Failed to delete {path}, leaking object! ({e})");
}
}
executed += 1;
self.executed.fetch_add(1, Ordering::Relaxed);
}
}
ListWriterQueueMessage::Flush(op) => {
Either::Right(Some(ListWriterQueueMessage::Flush(op))) => {
op.notify();
}
ListWriterQueueMessage::FlushExecute(op) => {
Either::Right(Some(ListWriterQueueMessage::FlushExecute(op))) => {
// We have already executed all prior deletions because mock does them inline
op.notify();
}
ListWriterQueueMessage::Recover(_) => {
Either::Right(Some(ListWriterQueueMessage::Recover(_))) => {
// no-op in mock
}
}
info!("All pending deletions have been executed");
}
executed
}
}
pub struct MockDeletionQueue {
tx: tokio::sync::mpsc::UnboundedSender<ListWriterQueueMessage>,
executor_tx: tokio::sync::mpsc::Sender<DeleterMessage>,
executed: Arc<AtomicUsize>,
remote_storage: Option<GenericRemoteStorage>,
consumer: std::sync::Mutex<ConsumerState>,
lsn_table: Arc<std::sync::RwLock<VisibleLsnUpdates>>,
}
@@ -1235,29 +1230,34 @@ pub(crate) mod mock {
let executed = Arc::new(AtomicUsize::new(0));
let mut consumer = ConsumerState {
rx,
executor_rx,
cancel: CancellationToken::new(),
executed: executed.clone(),
};
tokio::spawn(async move {
if let Some(remote_storage) = &remote_storage {
consumer.consume(remote_storage).await;
}
});
Self {
tx,
executor_tx,
executed,
remote_storage,
consumer: std::sync::Mutex::new(ConsumerState {
rx,
executor_rx,
cancel: CancellationToken::new(),
}),
lsn_table: Arc::new(std::sync::RwLock::new(VisibleLsnUpdates::new())),
}
}
#[allow(clippy::await_holding_lock)]
pub async fn pump(&self) {
if let Some(remote_storage) = &self.remote_storage {
// Permit holding mutex across await, because this is only ever
// called once at a time in tests.
let mut locked = self.consumer.lock().unwrap();
let count = locked.consume(remote_storage).await;
self.executed.fetch_add(count, Ordering::Relaxed);
}
let (tx, rx) = tokio::sync::oneshot::channel();
self.executor_tx
.send(DeleterMessage::Flush(FlushOp { tx }))
.await
.expect("Failed to send flush message");
rx.await.ok();
}
pub(crate) fn new_client(&self) -> DeletionQueueClient {

View File

@@ -3,7 +3,7 @@ use metrics::{
register_counter_vec, register_gauge_vec, register_histogram, register_histogram_vec,
register_int_counter, register_int_counter_pair_vec, register_int_counter_vec,
register_int_gauge, register_int_gauge_vec, register_uint_gauge, register_uint_gauge_vec,
Counter, CounterVec, GaugeVec, Histogram, HistogramVec, IntCounter, IntCounterPair,
Counter, CounterVec, Gauge, GaugeVec, Histogram, HistogramVec, IntCounter, IntCounterPair,
IntCounterPairVec, IntCounterVec, IntGauge, IntGaugeVec, UIntGauge, UIntGaugeVec,
};
use once_cell::sync::Lazy;
@@ -457,6 +457,15 @@ pub(crate) static WAIT_LSN_TIME: Lazy<Histogram> = Lazy::new(|| {
.expect("failed to define a metric")
});
static FLUSH_WAIT_UPLOAD_TIME: Lazy<GaugeVec> = Lazy::new(|| {
register_gauge_vec!(
"pageserver_flush_wait_upload_seconds",
"Time spent waiting for preceding uploads during layer flush",
&["tenant_id", "shard_id", "timeline_id"]
)
.expect("failed to define a metric")
});
static LAST_RECORD_LSN: Lazy<IntGaugeVec> = Lazy::new(|| {
register_int_gauge_vec!(
"pageserver_last_record_lsn",
@@ -653,6 +662,35 @@ pub(crate) static COMPRESSION_IMAGE_OUTPUT_BYTES: Lazy<IntCounter> = Lazy::new(|
.expect("failed to define a metric")
});
pub(crate) static RELSIZE_CACHE_ENTRIES: Lazy<UIntGauge> = Lazy::new(|| {
register_uint_gauge!(
"pageserver_relsize_cache_entries",
"Number of entries in the relation size cache",
)
.expect("failed to define a metric")
});
pub(crate) static RELSIZE_CACHE_HITS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!("pageserver_relsize_cache_hits", "Relation size cache hits",)
.expect("failed to define a metric")
});
pub(crate) static RELSIZE_CACHE_MISSES: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"pageserver_relsize_cache_misses",
"Relation size cache misses",
)
.expect("failed to define a metric")
});
pub(crate) static RELSIZE_CACHE_MISSES_OLD: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"pageserver_relsize_cache_misses_old",
"Relation size cache misses where the lookup LSN is older than the last relation update"
)
.expect("failed to define a metric")
});
pub(crate) mod initial_logical_size {
use metrics::{register_int_counter, register_int_counter_vec, IntCounter, IntCounterVec};
use once_cell::sync::Lazy;
@@ -2106,6 +2144,7 @@ pub(crate) struct WalIngestMetrics {
pub(crate) records_committed: IntCounter,
pub(crate) records_filtered: IntCounter,
pub(crate) gap_blocks_zeroed_on_rel_extend: IntCounter,
pub(crate) clear_vm_bits_unknown: IntCounterVec,
}
pub(crate) static WAL_INGEST: Lazy<WalIngestMetrics> = Lazy::new(|| WalIngestMetrics {
@@ -2134,6 +2173,12 @@ pub(crate) static WAL_INGEST: Lazy<WalIngestMetrics> = Lazy::new(|| WalIngestMet
"Total number of zero gap blocks written on relation extends"
)
.expect("failed to define a metric"),
clear_vm_bits_unknown: register_int_counter_vec!(
"pageserver_wal_ingest_clear_vm_bits_unknown",
"Number of ignored ClearVmBits operations due to unknown pages/relations",
&["entity"],
)
.expect("failed to define a metric"),
});
pub(crate) static WAL_REDO_TIME: Lazy<Histogram> = Lazy::new(|| {
@@ -2336,6 +2381,7 @@ pub(crate) struct TimelineMetrics {
shard_id: String,
timeline_id: String,
pub flush_time_histo: StorageTimeMetrics,
pub flush_wait_upload_time_gauge: Gauge,
pub compact_time_histo: StorageTimeMetrics,
pub create_images_time_histo: StorageTimeMetrics,
pub logical_size_histo: StorageTimeMetrics,
@@ -2379,6 +2425,9 @@ impl TimelineMetrics {
&shard_id,
&timeline_id,
);
let flush_wait_upload_time_gauge = FLUSH_WAIT_UPLOAD_TIME
.get_metric_with_label_values(&[&tenant_id, &shard_id, &timeline_id])
.unwrap();
let compact_time_histo = StorageTimeMetrics::new(
StorageTimeOperation::Compact,
&tenant_id,
@@ -2516,6 +2565,7 @@ impl TimelineMetrics {
shard_id,
timeline_id,
flush_time_histo,
flush_wait_upload_time_gauge,
compact_time_histo,
create_images_time_histo,
logical_size_histo,
@@ -2563,6 +2613,14 @@ impl TimelineMetrics {
self.resident_physical_size_gauge.get()
}
pub(crate) fn flush_wait_upload_time_gauge_add(&self, duration: f64) {
self.flush_wait_upload_time_gauge.add(duration);
crate::metrics::FLUSH_WAIT_UPLOAD_TIME
.get_metric_with_label_values(&[&self.tenant_id, &self.shard_id, &self.timeline_id])
.unwrap()
.add(duration);
}
pub(crate) fn shutdown(&self) {
let was_shutdown = self
.shutdown
@@ -2579,6 +2637,7 @@ impl TimelineMetrics {
let timeline_id = &self.timeline_id;
let shard_id = &self.shard_id;
let _ = LAST_RECORD_LSN.remove_label_values(&[tenant_id, shard_id, timeline_id]);
let _ = FLUSH_WAIT_UPLOAD_TIME.remove_label_values(&[tenant_id, shard_id, timeline_id]);
let _ = STANDBY_HORIZON.remove_label_values(&[tenant_id, shard_id, timeline_id]);
{
RESIDENT_PHYSICAL_SIZE_GLOBAL.sub(self.resident_physical_size_get());

View File

@@ -10,6 +10,9 @@ use super::tenant::{PageReconstructError, Timeline};
use crate::aux_file;
use crate::context::RequestContext;
use crate::keyspace::{KeySpace, KeySpaceAccum};
use crate::metrics::{
RELSIZE_CACHE_ENTRIES, RELSIZE_CACHE_HITS, RELSIZE_CACHE_MISSES, RELSIZE_CACHE_MISSES_OLD,
};
use crate::span::{
debug_assert_current_span_has_tenant_and_timeline_id,
debug_assert_current_span_has_tenant_and_timeline_id_no_shard_id,
@@ -389,7 +392,9 @@ impl Timeline {
result
}
// Get size of a database in blocks
/// Get size of a database in blocks. This is only accurate on shard 0. It will undercount on
/// other shards, by only accounting for relations the shard has pages for, and only accounting
/// for pages up to the highest page number it has stored.
pub(crate) async fn get_db_size(
&self,
spcnode: Oid,
@@ -408,7 +413,10 @@ impl Timeline {
Ok(total_blocks)
}
/// Get size of a relation file
/// Get size of a relation file. The relation must exist, otherwise an error is returned.
///
/// This is only accurate on shard 0. On other shards, it will return the size up to the highest
/// page number stored in the shard.
pub(crate) async fn get_rel_size(
&self,
tag: RelTag,
@@ -444,7 +452,10 @@ impl Timeline {
Ok(nblocks)
}
/// Does relation exist?
/// Does the relation exist?
///
/// Only shard 0 has a full view of the relations. Other shards only know about relations that
/// the shard stores pages for.
pub(crate) async fn get_rel_exists(
&self,
tag: RelTag,
@@ -478,6 +489,9 @@ impl Timeline {
/// Get a list of all existing relations in given tablespace and database.
///
/// Only shard 0 has a full view of the relations. Other shards only know about relations that
/// the shard stores pages for.
///
/// # Cancel-Safety
///
/// This method is cancellation-safe.
@@ -1129,9 +1143,12 @@ impl Timeline {
let rel_size_cache = self.rel_size_cache.read().unwrap();
if let Some((cached_lsn, nblocks)) = rel_size_cache.map.get(tag) {
if lsn >= *cached_lsn {
RELSIZE_CACHE_HITS.inc();
return Some(*nblocks);
}
RELSIZE_CACHE_MISSES_OLD.inc();
}
RELSIZE_CACHE_MISSES.inc();
None
}
@@ -1156,6 +1173,7 @@ impl Timeline {
}
hash_map::Entry::Vacant(entry) => {
entry.insert((lsn, nblocks));
RELSIZE_CACHE_ENTRIES.inc();
}
}
}
@@ -1163,13 +1181,17 @@ impl Timeline {
/// Store cached relation size
pub fn set_cached_rel_size(&self, tag: RelTag, lsn: Lsn, nblocks: BlockNumber) {
let mut rel_size_cache = self.rel_size_cache.write().unwrap();
rel_size_cache.map.insert(tag, (lsn, nblocks));
if rel_size_cache.map.insert(tag, (lsn, nblocks)).is_none() {
RELSIZE_CACHE_ENTRIES.inc();
}
}
/// Remove cached relation size
pub fn remove_cached_rel_size(&self, tag: &RelTag) {
let mut rel_size_cache = self.rel_size_cache.write().unwrap();
rel_size_cache.map.remove(tag);
if rel_size_cache.map.remove(tag).is_some() {
RELSIZE_CACHE_ENTRIES.dec();
}
}
}
@@ -1229,10 +1251,9 @@ impl<'a> DatadirModification<'a> {
}
pub(crate) fn has_dirty_data(&self) -> bool {
!self
.pending_data_batch
self.pending_data_batch
.as_ref()
.map_or(true, |b| b.is_empty())
.map_or(false, |b| b.has_data())
}
/// Set the current lsn
@@ -1408,7 +1429,7 @@ impl<'a> DatadirModification<'a> {
Some(pending_batch) => {
pending_batch.extend(batch);
}
None if !batch.is_empty() => {
None if batch.has_data() => {
self.pending_data_batch = Some(batch);
}
None => {

View File

@@ -3215,6 +3215,18 @@ impl Tenant {
}
}
if let ShutdownMode::Reload = shutdown_mode {
tracing::info!("Flushing deletion queue");
if let Err(e) = self.deletion_queue_client.flush().await {
match e {
DeletionQueueError::ShuttingDown => {
// This is the only error we expect for now. In the future, if more error
// variants are added, we should handle them here.
}
}
}
}
// We cancel the Tenant's cancellation token _after_ the timelines have all shut down. This permits
// them to continue to do work during their shutdown methods, e.g. flushing data.
tracing::debug!("Cancelling CancellationToken");
@@ -5344,6 +5356,7 @@ pub(crate) mod harness {
lsn_lease_length: Some(tenant_conf.lsn_lease_length),
lsn_lease_length_for_ts: Some(tenant_conf.lsn_lease_length_for_ts),
timeline_offloading: Some(tenant_conf.timeline_offloading),
wal_receiver_protocol_override: tenant_conf.wal_receiver_protocol_override,
}
}
}

View File

@@ -19,6 +19,7 @@ use serde_json::Value;
use std::num::NonZeroU64;
use std::time::Duration;
use utils::generation::Generation;
use utils::postgres_client::PostgresClientProtocol;
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) enum AttachmentMode {
@@ -353,6 +354,9 @@ pub struct TenantConfOpt {
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(default)]
pub timeline_offloading: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub wal_receiver_protocol_override: Option<PostgresClientProtocol>,
}
impl TenantConfOpt {
@@ -418,6 +422,9 @@ impl TenantConfOpt {
timeline_offloading: self
.lazy_slru_download
.unwrap_or(global_conf.timeline_offloading),
wal_receiver_protocol_override: self
.wal_receiver_protocol_override
.or(global_conf.wal_receiver_protocol_override),
}
}
}
@@ -472,6 +479,7 @@ impl From<TenantConfOpt> for models::TenantConfig {
lsn_lease_length: value.lsn_lease_length.map(humantime),
lsn_lease_length_for_ts: value.lsn_lease_length_for_ts.map(humantime),
timeline_offloading: value.timeline_offloading,
wal_receiver_protocol_override: value.wal_receiver_protocol_override,
}
}
}

View File

@@ -1960,7 +1960,7 @@ impl TenantManager {
attempt.before_reset_tenant();
let (_guard, progress) = utils::completion::channel();
match tenant.shutdown(progress, ShutdownMode::Flush).await {
match tenant.shutdown(progress, ShutdownMode::Reload).await {
Ok(()) => {
slot_guard.drop_old_value().expect("it was just shutdown");
}

View File

@@ -50,6 +50,7 @@ use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::{
fs_ext, pausable_failpoint,
postgres_client::PostgresClientProtocol,
sync::gate::{Gate, GateGuard},
};
use wal_decoder::serialized_batch::SerializedValueBatch;
@@ -893,10 +894,11 @@ pub(crate) enum ShutdownMode {
/// While we are flushing, we continue to accept read I/O for LSNs ingested before
/// the call to [`Timeline::shutdown`].
FreezeAndFlush,
/// Only flush the layers to the remote storage without freezing any open layers. This is the
/// mode used by ancestor detach and any other operations that reloads a tenant but not increasing
/// the generation number.
Flush,
/// Only flush the layers to the remote storage without freezing any open layers. Flush the deletion
/// queue. This is the mode used by ancestor detach and any other operations that reloads a tenant
/// but not increasing the generation number. Note that this mode cannot be used at tenant shutdown,
/// as flushing the deletion queue at that time will cause shutdown-in-progress errors.
Reload,
/// Shut down immediately, without waiting for any open layers to flush.
Hard,
}
@@ -1817,7 +1819,7 @@ impl Timeline {
}
}
if let ShutdownMode::Flush = mode {
if let ShutdownMode::Reload = mode {
// drain the upload queue
self.remote_client.shutdown().await;
if !self.remote_client.no_pending_work() {
@@ -2178,6 +2180,21 @@ impl Timeline {
)
}
/// Resolve the effective WAL receiver protocol to use for this tenant.
///
/// Priority order is:
/// 1. Tenant config override
/// 2. Default value for tenant config override
/// 3. Pageserver config override
/// 4. Pageserver config default
pub fn resolve_wal_receiver_protocol(&self) -> PostgresClientProtocol {
let tenant_conf = self.tenant_conf.load().tenant_conf.clone();
tenant_conf
.wal_receiver_protocol_override
.or(self.conf.default_tenant_conf.wal_receiver_protocol_override)
.unwrap_or(self.conf.wal_receiver_protocol)
}
pub(super) fn tenant_conf_updated(&self, new_conf: &AttachedTenantConf) {
// NB: Most tenant conf options are read by background loops, so,
// changes will automatically be picked up.
@@ -2470,6 +2487,7 @@ impl Timeline {
*guard = Some(WalReceiver::start(
Arc::clone(self),
WalReceiverConf {
protocol: self.resolve_wal_receiver_protocol(),
wal_connect_timeout,
lagging_wal_timeout,
max_lsn_wal_lag,
@@ -3829,7 +3847,8 @@ impl Timeline {
};
// Backpressure mechanism: wait with continuation of the flush loop until we have uploaded all layer files.
// This makes us refuse ingest until the new layers have been persisted to the remote.
// This makes us refuse ingest until the new layers have been persisted to the remote
let start = Instant::now();
self.remote_client
.wait_completion()
.await
@@ -3842,6 +3861,8 @@ impl Timeline {
FlushLayerError::Other(anyhow!(e).into())
}
})?;
let duration = start.elapsed().as_secs_f64();
self.metrics.flush_wait_upload_time_gauge_add(duration);
// FIXME: between create_delta_layer and the scheduling of the upload in `update_metadata_file`,
// a compaction can delete the file and then it won't be available for uploads any more.
@@ -5896,7 +5917,7 @@ impl<'a> TimelineWriter<'a> {
batch: SerializedValueBatch,
ctx: &RequestContext,
) -> anyhow::Result<()> {
if batch.is_empty() {
if !batch.has_data() {
return Ok(());
}

View File

@@ -58,7 +58,7 @@ pub(crate) async fn offload_timeline(
}
// Now that the Timeline is in Stopping state, request all the related tasks to shut down.
timeline.shutdown(super::ShutdownMode::Flush).await;
timeline.shutdown(super::ShutdownMode::Reload).await;
// TODO extend guard mechanism above with method
// to make deletions possible while offloading is in progress

View File

@@ -38,6 +38,7 @@ use storage_broker::BrokerClientChannel;
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::postgres_client::PostgresClientProtocol;
use self::connection_manager::ConnectionManagerStatus;
@@ -45,6 +46,7 @@ use super::Timeline;
#[derive(Clone)]
pub struct WalReceiverConf {
pub protocol: PostgresClientProtocol,
/// The timeout on the connection to safekeeper for WAL streaming.
pub wal_connect_timeout: Duration,
/// The timeout to use to determine when the current connection is "stale" and reconnect to the other one.

View File

@@ -36,7 +36,9 @@ use postgres_connection::PgConnectionConfig;
use utils::backoff::{
exponential_backoff, DEFAULT_BASE_BACKOFF_SECONDS, DEFAULT_MAX_BACKOFF_SECONDS,
};
use utils::postgres_client::wal_stream_connection_config;
use utils::postgres_client::{
wal_stream_connection_config, ConnectionConfigArgs, PostgresClientProtocol,
};
use utils::{
id::{NodeId, TenantTimelineId},
lsn::Lsn,
@@ -533,6 +535,7 @@ impl ConnectionManagerState {
let node_id = new_sk.safekeeper_id;
let connect_timeout = self.conf.wal_connect_timeout;
let ingest_batch_size = self.conf.ingest_batch_size;
let protocol = self.conf.protocol;
let timeline = Arc::clone(&self.timeline);
let ctx = ctx.detached_child(
TaskKind::WalReceiverConnectionHandler,
@@ -546,6 +549,7 @@ impl ConnectionManagerState {
let res = super::walreceiver_connection::handle_walreceiver_connection(
timeline,
protocol,
new_sk.wal_source_connconf,
events_sender,
cancellation.clone(),
@@ -984,15 +988,33 @@ impl ConnectionManagerState {
if info.safekeeper_connstr.is_empty() {
return None; // no connection string, ignore sk
}
match wal_stream_connection_config(
self.id,
info.safekeeper_connstr.as_ref(),
match &self.conf.auth_token {
None => None,
Some(x) => Some(x),
let (shard_number, shard_count, shard_stripe_size) = match self.conf.protocol {
PostgresClientProtocol::Vanilla => {
(None, None, None)
},
self.conf.availability_zone.as_deref(),
) {
PostgresClientProtocol::Interpreted { .. } => {
let shard_identity = self.timeline.get_shard_identity();
(
Some(shard_identity.number.0),
Some(shard_identity.count.0),
Some(shard_identity.stripe_size.0),
)
}
};
let connection_conf_args = ConnectionConfigArgs {
protocol: self.conf.protocol,
ttid: self.id,
shard_number,
shard_count,
shard_stripe_size,
listen_pg_addr_str: info.safekeeper_connstr.as_ref(),
auth_token: self.conf.auth_token.as_ref().map(|t| t.as_str()),
availability_zone: self.conf.availability_zone.as_deref()
};
match wal_stream_connection_config(connection_conf_args) {
Ok(connstr) => Some((*sk_id, info, connstr)),
Err(e) => {
error!("Failed to create wal receiver connection string from broker data of safekeeper node {}: {e:#}", sk_id);
@@ -1096,6 +1118,7 @@ impl ReconnectReason {
mod tests {
use super::*;
use crate::tenant::harness::{TenantHarness, TIMELINE_ID};
use pageserver_api::config::defaults::DEFAULT_WAL_RECEIVER_PROTOCOL;
use url::Host;
fn dummy_broker_sk_timeline(
@@ -1532,6 +1555,7 @@ mod tests {
timeline,
cancel: CancellationToken::new(),
conf: WalReceiverConf {
protocol: DEFAULT_WAL_RECEIVER_PROTOCOL,
wal_connect_timeout: Duration::from_secs(1),
lagging_wal_timeout: Duration::from_secs(1),
max_lsn_wal_lag: NonZeroU64::new(1024 * 1024).unwrap(),

View File

@@ -22,7 +22,10 @@ use tokio::{select, sync::watch, time};
use tokio_postgres::{replication::ReplicationStream, Client};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, trace, warn, Instrument};
use wal_decoder::models::{FlushUncommittedRecords, InterpretedWalRecord};
use wal_decoder::{
models::{FlushUncommittedRecords, InterpretedWalRecord, InterpretedWalRecords},
wire_format::FromWireFormat,
};
use super::TaskStateUpdate;
use crate::{
@@ -36,7 +39,7 @@ use crate::{
use postgres_backend::is_expected_io_error;
use postgres_connection::PgConnectionConfig;
use postgres_ffi::waldecoder::WalStreamDecoder;
use utils::{id::NodeId, lsn::Lsn};
use utils::{id::NodeId, lsn::Lsn, postgres_client::PostgresClientProtocol};
use utils::{pageserver_feedback::PageserverFeedback, sync::gate::GateError};
/// Status of the connection.
@@ -109,6 +112,7 @@ impl From<WalDecodeError> for WalReceiverError {
#[allow(clippy::too_many_arguments)]
pub(super) async fn handle_walreceiver_connection(
timeline: Arc<Timeline>,
protocol: PostgresClientProtocol,
wal_source_connconf: PgConnectionConfig,
events_sender: watch::Sender<TaskStateUpdate<WalConnectionStatus>>,
cancellation: CancellationToken,
@@ -260,6 +264,14 @@ pub(super) async fn handle_walreceiver_connection(
let mut walingest = WalIngest::new(timeline.as_ref(), startpoint, &ctx).await?;
let interpreted_proto_config = match protocol {
PostgresClientProtocol::Vanilla => None,
PostgresClientProtocol::Interpreted {
format,
compression,
} => Some((format, compression)),
};
while let Some(replication_message) = {
select! {
_ = cancellation.cancelled() => {
@@ -291,6 +303,15 @@ pub(super) async fn handle_walreceiver_connection(
connection_status.latest_connection_update = now;
connection_status.commit_lsn = Some(Lsn::from(keepalive.wal_end()));
}
ReplicationMessage::RawInterpretedWalRecords(raw) => {
connection_status.latest_connection_update = now;
if !raw.data().is_empty() {
connection_status.latest_wal_update = now;
}
connection_status.commit_lsn = Some(Lsn::from(raw.commit_lsn()));
connection_status.streaming_lsn = Some(Lsn::from(raw.streaming_lsn()));
}
&_ => {}
};
if let Err(e) = events_sender.send(TaskStateUpdate::Progress(connection_status)) {
@@ -298,7 +319,148 @@ pub(super) async fn handle_walreceiver_connection(
return Ok(());
}
async fn commit(
modification: &mut DatadirModification<'_>,
uncommitted: &mut u64,
filtered: &mut u64,
ctx: &RequestContext,
) -> anyhow::Result<()> {
WAL_INGEST
.records_committed
.inc_by(*uncommitted - *filtered);
modification.commit(ctx).await?;
*uncommitted = 0;
*filtered = 0;
Ok(())
}
let status_update = match replication_message {
ReplicationMessage::RawInterpretedWalRecords(raw) => {
WAL_INGEST.bytes_received.inc_by(raw.data().len() as u64);
let mut uncommitted_records = 0;
let mut filtered_records = 0;
// This is the end LSN of the raw WAL from which the records
// were interpreted.
let streaming_lsn = Lsn::from(raw.streaming_lsn());
let (format, compression) = interpreted_proto_config.unwrap();
let batch = InterpretedWalRecords::from_wire(raw.data(), format, compression)
.await
.with_context(|| {
anyhow::anyhow!(
"Failed to deserialize interpreted records ending at LSN {streaming_lsn}"
)
})?;
let InterpretedWalRecords {
records,
next_record_lsn,
} = batch;
tracing::debug!(
"Received WAL up to {} with next_record_lsn={:?}",
streaming_lsn,
next_record_lsn
);
// We start the modification at 0 because each interpreted record
// advances it to its end LSN. 0 is just an initialization placeholder.
let mut modification = timeline.begin_modification(Lsn(0));
for interpreted in records {
if matches!(interpreted.flush_uncommitted, FlushUncommittedRecords::Yes)
&& uncommitted_records > 0
{
commit(
&mut modification,
&mut uncommitted_records,
&mut filtered_records,
&ctx,
)
.await?;
}
let local_next_record_lsn = interpreted.next_record_lsn;
let ingested = walingest
.ingest_record(interpreted, &mut modification, &ctx)
.await
.with_context(|| {
format!("could not ingest record at {local_next_record_lsn}")
})?;
if !ingested {
tracing::debug!(
"ingest: filtered out record @ LSN {local_next_record_lsn}"
);
WAL_INGEST.records_filtered.inc();
filtered_records += 1;
}
uncommitted_records += 1;
// FIXME: this cannot be made pausable_failpoint without fixing the
// failpoint library; in tests, the added amount of debugging will cause us
// to timeout the tests.
fail_point!("walreceiver-after-ingest");
// Commit every ingest_batch_size records. Even if we filtered out
// all records, we still need to call commit to advance the LSN.
if uncommitted_records >= ingest_batch_size
|| modification.approx_pending_bytes()
> DatadirModification::MAX_PENDING_BYTES
{
commit(
&mut modification,
&mut uncommitted_records,
&mut filtered_records,
&ctx,
)
.await?;
}
}
// Records might have been filtered out on the safekeeper side, but we still
// need to advance last record LSN on all shards. If we've not ingested the latest
// record, then set the LSN of the modification past it. This way all shards
// advance their last record LSN at the same time.
let needs_last_record_lsn_advance = match next_record_lsn.map(Lsn::from) {
Some(lsn) if lsn > modification.get_lsn() => {
modification.set_lsn(lsn).unwrap();
true
}
_ => false,
};
if uncommitted_records > 0 || needs_last_record_lsn_advance {
// Commit any uncommitted records
commit(
&mut modification,
&mut uncommitted_records,
&mut filtered_records,
&ctx,
)
.await?;
}
if !caught_up && streaming_lsn >= end_of_wal {
info!("caught up at LSN {streaming_lsn}");
caught_up = true;
}
tracing::debug!(
"Ingested WAL up to {streaming_lsn}. Last record LSN is {}",
timeline.get_last_record_lsn()
);
if let Some(lsn) = next_record_lsn {
last_rec_lsn = lsn;
}
Some(streaming_lsn)
}
ReplicationMessage::XLogData(xlog_data) => {
// Pass the WAL data to the decoder, and see if we can decode
// more records as a result.
@@ -316,21 +478,6 @@ pub(super) async fn handle_walreceiver_connection(
let mut uncommitted_records = 0;
let mut filtered_records = 0;
async fn commit(
modification: &mut DatadirModification<'_>,
uncommitted: &mut u64,
filtered: &mut u64,
ctx: &RequestContext,
) -> anyhow::Result<()> {
WAL_INGEST
.records_committed
.inc_by(*uncommitted - *filtered);
modification.commit(ctx).await?;
*uncommitted = 0;
*filtered = 0;
Ok(())
}
while let Some((next_record_lsn, recdata)) = waldecoder.poll_decode()? {
// It is important to deal with the aligned records as lsn in getPage@LSN is
// aligned and can be several bytes bigger. Without this alignment we are

Some files were not shown because too many files have changed in this diff Show More