Merge remote-tracking branch 'origin/main' into problame/batching-sidecar-task
@@ -46,6 +46,9 @@ workspace-members = [
     "utils",
     "wal_craft",
     "walproposer",
+    "postgres-protocol2",
+    "postgres-types2",
+    "tokio-postgres2",
 ]

 # Write out exact versions rather than a semver range. (Defaults to false.)
17 .github/actions/run-python-test-set/action.yml vendored
@@ -36,8 +36,8 @@ inputs:
    description: 'Region name for real s3 tests'
    required: false
    default: ''
-  rerun_flaky:
-    description: 'Whether to rerun flaky tests'
+  rerun_failed:
+    description: 'Whether to rerun failed tests'
    required: false
    default: 'false'
  pg_version:
@@ -108,7 +108,7 @@ runs:
      COMPATIBILITY_SNAPSHOT_DIR: /tmp/compatibility_snapshot_pg${{ inputs.pg_version }}
      ALLOW_BACKWARD_COMPATIBILITY_BREAKAGE: contains(github.event.pull_request.labels.*.name, 'backward compatibility breakage')
      ALLOW_FORWARD_COMPATIBILITY_BREAKAGE: contains(github.event.pull_request.labels.*.name, 'forward compatibility breakage')
-      RERUN_FLAKY: ${{ inputs.rerun_flaky }}
+      RERUN_FAILED: ${{ inputs.rerun_failed }}
      PG_VERSION: ${{ inputs.pg_version }}
    shell: bash -euxo pipefail {0}
    run: |
@@ -154,15 +154,8 @@ runs:
        EXTRA_PARAMS="--out-dir $PERF_REPORT_DIR $EXTRA_PARAMS"
      fi

-      if [ "${RERUN_FLAKY}" == "true" ]; then
-        mkdir -p $TEST_OUTPUT
-        poetry run ./scripts/flaky_tests.py "${TEST_RESULT_CONNSTR}" \
-          --days 7 \
-          --output "$TEST_OUTPUT/flaky.json" \
-          --pg-version "${DEFAULT_PG_VERSION}" \
-          --build-type "${BUILD_TYPE}"
-
-        EXTRA_PARAMS="--flaky-tests-json $TEST_OUTPUT/flaky.json $EXTRA_PARAMS"
+      if [ "${RERUN_FAILED}" == "true" ]; then
+        EXTRA_PARAMS="--reruns 2 $EXTRA_PARAMS"
      fi

      # We use pytest-split plugin to run benchmarks in parallel on different CI runners
11 .github/workflows/_build-and-test-locally.yml vendored
@@ -19,8 +19,8 @@ on:
        description: 'debug or release'
        required: true
        type: string
-      pg-versions:
-        description: 'a json array of postgres versions to run regression tests on'
+      test-cfg:
+        description: 'a json object of postgres versions and lfc states to run regression tests on'
        required: true
        type: string

@@ -276,14 +276,14 @@ jobs:
      options: --init --shm-size=512mb --ulimit memlock=67108864:67108864
    strategy:
      fail-fast: false
-      matrix:
-        pg_version: ${{ fromJson(inputs.pg-versions) }}
+      matrix: ${{ fromJSON(format('{{"include":{0}}}', inputs.test-cfg)) }}
    steps:
      - uses: actions/checkout@v4
        with:
          submodules: true

      - name: Pytest regression tests
+        continue-on-error: ${{ matrix.lfc_state == 'with-lfc' }}
        uses: ./.github/actions/run-python-test-set
        timeout-minutes: 60
        with:
@@ -293,13 +293,14 @@ jobs:
          run_with_real_s3: true
          real_s3_bucket: neon-github-ci-tests
          real_s3_region: eu-central-1
-          rerun_flaky: true
+          rerun_failed: true
          pg_version: ${{ matrix.pg_version }}
        env:
          TEST_RESULT_CONNSTR: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }}
          CHECK_ONDISK_DATA_COMPATIBILITY: nonempty
          BUILD_TAG: ${{ inputs.build-tag }}
          PAGESERVER_VIRTUAL_FILE_IO_ENGINE: tokio-epoll-uring
+          USE_LFC: ${{ matrix.lfc_state == 'with-lfc' && 'true' || 'false' }}

      # Temporary disable this step until we figure out why it's so flaky
      # Ref https://github.com/neondatabase/neon/issues/4540
14 .github/workflows/benchmarking.yml vendored
@@ -541,7 +541,7 @@ jobs:

    runs-on: ${{ matrix.RUNNER }}
    container:
-      image: neondatabase/build-tools:pinned
+      image: neondatabase/build-tools:pinned-bookworm
      credentials:
        username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
        password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
@@ -558,12 +558,12 @@ jobs:
      arch=$(uname -m | sed 's/x86_64/amd64/g' | sed 's/aarch64/arm64/g')

      cd /home/nonroot
-      wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-17/libpq5_17.2-1.pgdg110+1_${arch}.deb"
-      wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-client-16_16.6-1.pgdg110+1_${arch}.deb"
-      wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-16_16.6-1.pgdg110+1_${arch}.deb"
-      dpkg -x libpq5_17.2-1.pgdg110+1_${arch}.deb pg
-      dpkg -x postgresql-16_16.6-1.pgdg110+1_${arch}.deb pg
-      dpkg -x postgresql-client-16_16.6-1.pgdg110+1_${arch}.deb pg
+      wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-17/libpq5_17.2-1.pgdg120+1_${arch}.deb"
+      wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-client-16_16.6-1.pgdg120+1_${arch}.deb"
+      wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-16_16.6-1.pgdg120+1_${arch}.deb"
+      dpkg -x libpq5_17.2-1.pgdg120+1_${arch}.deb pg
+      dpkg -x postgresql-16_16.6-1.pgdg120+1_${arch}.deb pg
+      dpkg -x postgresql-client-16_16.6-1.pgdg120+1_${arch}.deb pg

      mkdir -p /tmp/neon/pg_install/v16/bin
      ln -s /home/nonroot/pg/usr/lib/postgresql/16/bin/pgbench /tmp/neon/pg_install/v16/bin/pgbench
75 .github/workflows/build-build-tools-image.yml vendored
@@ -2,6 +2,17 @@ name: Build build-tools image

on:
  workflow_call:
+    inputs:
+      archs:
+        description: "Json array of architectures to build"
+        # Default values are set in `check-image` job, `set-variables` step
+        type: string
+        required: false
+      debians:
+        description: "Json array of Debian versions to build"
+        # Default values are set in `check-image` job, `set-variables` step
+        type: string
+        required: false
    outputs:
      image-tag:
        description: "build-tools tag"
@@ -32,25 +43,37 @@ jobs:
  check-image:
    runs-on: ubuntu-22.04
    outputs:
-      tag: ${{ steps.get-build-tools-tag.outputs.image-tag }}
-      found: ${{ steps.check-image.outputs.found }}
+      archs: ${{ steps.set-variables.outputs.archs }}
+      debians: ${{ steps.set-variables.outputs.debians }}
+      tag: ${{ steps.set-variables.outputs.image-tag }}
+      everything: ${{ steps.set-more-variables.outputs.everything }}
+      found: ${{ steps.set-more-variables.outputs.found }}

    steps:
      - uses: actions/checkout@v4

-      - name: Get build-tools image tag for the current commit
-        id: get-build-tools-tag
+      - name: Set variables
+        id: set-variables
        env:
+          ARCHS: ${{ inputs.archs || '["x64","arm64"]' }}
+          DEBIANS: ${{ inputs.debians || '["bullseye","bookworm"]' }}
          IMAGE_TAG: |
            ${{ hashFiles('build-tools.Dockerfile',
                          '.github/workflows/build-build-tools-image.yml') }}
        run: |
-          echo "image-tag=${IMAGE_TAG}" | tee -a $GITHUB_OUTPUT
+          echo "archs=${ARCHS}" | tee -a ${GITHUB_OUTPUT}
+          echo "debians=${DEBIANS}" | tee -a ${GITHUB_OUTPUT}
+          echo "image-tag=${IMAGE_TAG}" | tee -a ${GITHUB_OUTPUT}

-      - name: Check if such tag found in the registry
-        id: check-image
+      - name: Set more variables
+        id: set-more-variables
        env:
-          IMAGE_TAG: ${{ steps.get-build-tools-tag.outputs.image-tag }}
+          IMAGE_TAG: ${{ steps.set-variables.outputs.image-tag }}
+          EVERYTHING: |
+            ${{ contains(fromJson(steps.set-variables.outputs.archs), 'x64') &&
+                contains(fromJson(steps.set-variables.outputs.archs), 'arm64') &&
+                contains(fromJson(steps.set-variables.outputs.debians), 'bullseye') &&
+                contains(fromJson(steps.set-variables.outputs.debians), 'bookworm') }}
        run: |
          if docker manifest inspect neondatabase/build-tools:${IMAGE_TAG}; then
            found=true
@@ -58,8 +81,8 @@ jobs:
            found=false
          fi

-          echo "found=${found}" | tee -a $GITHUB_OUTPUT
+          echo "everything=${EVERYTHING}" | tee -a ${GITHUB_OUTPUT}
+          echo "found=${found}" | tee -a ${GITHUB_OUTPUT}

  build-image:
    needs: [ check-image ]
@@ -67,8 +90,8 @@ jobs:

    strategy:
      matrix:
-        debian-version: [ bullseye, bookworm ]
-        arch: [ x64, arm64 ]
+        arch: ${{ fromJson(needs.check-image.outputs.archs) }}
+        debian: ${{ fromJson(needs.check-image.outputs.debians) }}

    runs-on: ${{ fromJson(format('["self-hosted", "{0}"]', matrix.arch == 'arm64' && 'large-arm64' || 'large')) }}

@@ -99,11 +122,11 @@ jobs:
        push: true
        pull: true
        build-args: |
-          DEBIAN_VERSION=${{ matrix.debian-version }}
-        cache-from: type=registry,ref=cache.neon.build/build-tools:cache-${{ matrix.debian-version }}-${{ matrix.arch }}
-        cache-to: ${{ github.ref_name == 'main' && format('type=registry,ref=cache.neon.build/build-tools:cache-{0}-{1},mode=max', matrix.debian-version, matrix.arch) || '' }}
+          DEBIAN_VERSION=${{ matrix.debian }}
+        cache-from: type=registry,ref=cache.neon.build/build-tools:cache-${{ matrix.debian }}-${{ matrix.arch }}
+        cache-to: ${{ github.ref_name == 'main' && format('type=registry,ref=cache.neon.build/build-tools:cache-{0}-{1},mode=max', matrix.debian, matrix.arch) || '' }}
        tags: |
-          neondatabase/build-tools:${{ needs.check-image.outputs.tag }}-${{ matrix.debian-version }}-${{ matrix.arch }}
+          neondatabase/build-tools:${{ needs.check-image.outputs.tag }}-${{ matrix.debian }}-${{ matrix.arch }}

  merge-images:
    needs: [ check-image, build-image ]
@@ -117,16 +140,22 @@ jobs:

      - name: Create multi-arch image
        env:
-          DEFAULT_DEBIAN_VERSION: bullseye
+          DEFAULT_DEBIAN_VERSION: bookworm
+          ARCHS: ${{ join(fromJson(needs.check-image.outputs.archs), ' ') }}
+          DEBIANS: ${{ join(fromJson(needs.check-image.outputs.debians), ' ') }}
+          EVERYTHING: ${{ needs.check-image.outputs.everything }}
          IMAGE_TAG: ${{ needs.check-image.outputs.tag }}
        run: |
-          for debian_version in bullseye bookworm; do
-            tags=("-t" "neondatabase/build-tools:${IMAGE_TAG}-${debian_version}")
-            if [ "${debian_version}" == "${DEFAULT_DEBIAN_VERSION}" ]; then
+          for debian in ${DEBIANS}; do
+            tags=("-t" "neondatabase/build-tools:${IMAGE_TAG}-${debian}")
+
+            if [ "${EVERYTHING}" == "true" ] && [ "${debian}" == "${DEFAULT_DEBIAN_VERSION}" ]; then
              tags+=("-t" "neondatabase/build-tools:${IMAGE_TAG}")
            fi

-            docker buildx imagetools create "${tags[@]}" \
-              neondatabase/build-tools:${IMAGE_TAG}-${debian_version}-x64 \
-              neondatabase/build-tools:${IMAGE_TAG}-${debian_version}-arm64
+            for arch in ${ARCHS}; do
+              tags+=("neondatabase/build-tools:${IMAGE_TAG}-${debian}-${arch}")
+            done
+
+            docker buildx imagetools create "${tags[@]}"
          done
9 .github/workflows/build_and_test.yml vendored
@@ -253,7 +253,14 @@ jobs:
      build-tag: ${{ needs.tag.outputs.build-tag }}
      build-type: ${{ matrix.build-type }}
      # Run tests on all Postgres versions in release builds and only on the latest version in debug builds
-      pg-versions: ${{ matrix.build-type == 'release' && '["v14", "v15", "v16", "v17"]' || '["v17"]' }}
+      # run without LFC on v17 release only
+      test-cfg: |
+        ${{ matrix.build-type == 'release' && '[{"pg_version":"v14", "lfc_state": "without-lfc"},
+                                                {"pg_version":"v15", "lfc_state": "without-lfc"},
+                                                {"pg_version":"v16", "lfc_state": "without-lfc"},
+                                                {"pg_version":"v17", "lfc_state": "without-lfc"},
+                                                {"pg_version":"v17", "lfc_state": "with-lfc"}]'
+            || '[{"pg_version":"v17", "lfc_state": "without-lfc"}]' }}
    secrets: inherit

  # Keep `benchmarks` job outside of `build-and-test-locally` workflow to make job failures non-blocking
2 .github/workflows/periodic_pagebench.yml vendored
@@ -29,7 +29,7 @@ jobs:
  trigger_bench_on_ec2_machine_in_eu_central_1:
    runs-on: [ self-hosted, small ]
    container:
-      image: neondatabase/build-tools:pinned
+      image: neondatabase/build-tools:pinned-bookworm
      credentials:
        username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
        password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
2 .github/workflows/pin-build-tools-image.yml vendored
@@ -94,7 +94,7 @@ jobs:

      - name: Tag build-tools with `${{ env.TO_TAG }}` in Docker Hub, ECR, and ACR
        env:
-          DEFAULT_DEBIAN_VERSION: bullseye
+          DEFAULT_DEBIAN_VERSION: bookworm
        run: |
          for debian_version in bullseye bookworm; do
            tags=()
9 .github/workflows/pre-merge-checks.yml vendored
@@ -23,6 +23,8 @@ jobs:
        id: python-src
        with:
          files: |
+            .github/workflows/_check-codestyle-python.yml
+            .github/workflows/build-build-tools-image.yml
            .github/workflows/pre-merge-checks.yml
            **/**.py
            poetry.lock
@@ -38,6 +40,10 @@ jobs:
    if: needs.get-changed-files.outputs.python-changed == 'true'
    needs: [ get-changed-files ]
    uses: ./.github/workflows/build-build-tools-image.yml
+    with:
+      # Build only one combination to save time
+      archs: '["x64"]'
+      debians: '["bookworm"]'
    secrets: inherit

  check-codestyle-python:
@@ -45,7 +51,8 @@ jobs:
    needs: [ get-changed-files, build-build-tools-image ]
    uses: ./.github/workflows/_check-codestyle-python.yml
    with:
-      build-tools-image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm
+      # `-bookworm-x64` suffix should match the combination in `build-build-tools-image`
+      build-tools-image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm-x64
    secrets: inherit

  # To get items from the merge queue merged into main we need to satisfy "Status checks that are required".
121 Cargo.lock generated
@@ -2180,9 +2180,9 @@ dependencies = [

[[package]]
name = "futures-channel"
-version = "0.3.30"
+version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78"
+checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10"
dependencies = [
 "futures-core",
 "futures-sink",
@@ -2190,9 +2190,9 @@ dependencies = [

[[package]]
name = "futures-core"
-version = "0.3.30"
+version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d"
+checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"

[[package]]
name = "futures-executor"
@@ -2207,9 +2207,9 @@ dependencies = [

[[package]]
name = "futures-io"
-version = "0.3.30"
+version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"
+checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"

[[package]]
name = "futures-lite"
@@ -2228,9 +2228,9 @@ dependencies = [

[[package]]
name = "futures-macro"
-version = "0.3.30"
+version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
+checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
 "proc-macro2",
 "quote",
@@ -2239,15 +2239,15 @@ dependencies = [

[[package]]
name = "futures-sink"
-version = "0.3.30"
+version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5"
+checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7"

[[package]]
name = "futures-task"
-version = "0.3.30"
+version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
+checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"

[[package]]
name = "futures-timer"
@@ -2257,9 +2257,9 @@ checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c"

[[package]]
name = "futures-util"
-version = "0.3.30"
+version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
+checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [
 "futures-channel",
 "futures-core",
@@ -4139,7 +4139,7 @@ dependencies = [
[[package]]
name = "postgres"
version = "0.19.4"
-source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#a130197713830a0ea0004b539b1f51a66b4c3e18"
+source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#00940fcdb57a8e99e805297b75839e7c4c7b1796"
dependencies = [
 "bytes",
 "fallible-iterator",
@@ -4152,7 +4152,7 @@ dependencies = [
[[package]]
name = "postgres-protocol"
version = "0.6.4"
-source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#2a2a7c56930dd5ad60676ce6da92e1cbe6fb3ef5"
+source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#00940fcdb57a8e99e805297b75839e7c4c7b1796"
dependencies = [
 "base64 0.20.0",
 "byteorder",
@@ -4168,16 +4168,40 @@ dependencies = [
 "tokio",
]

+[[package]]
+name = "postgres-protocol2"
+version = "0.1.0"
+dependencies = [
+ "base64 0.20.0",
+ "byteorder",
+ "bytes",
+ "fallible-iterator",
+ "hmac",
+ "md-5",
+ "memchr",
+ "rand 0.8.5",
+ "sha2",
+ "stringprep",
+ "tokio",
+]
+
[[package]]
name = "postgres-types"
version = "0.2.4"
-source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#2a2a7c56930dd5ad60676ce6da92e1cbe6fb3ef5"
+source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#00940fcdb57a8e99e805297b75839e7c4c7b1796"
dependencies = [
 "bytes",
 "fallible-iterator",
 "postgres-protocol",
 "serde",
 "serde_json",
]

+[[package]]
+name = "postgres-types2"
+version = "0.1.0"
+dependencies = [
+ "bytes",
+ "fallible-iterator",
+ "postgres-protocol2",
+]
+
[[package]]
@@ -4188,7 +4212,7 @@ dependencies = [
 "bytes",
 "once_cell",
 "pq_proto",
- "rustls 0.23.16",
+ "rustls 0.23.18",
 "rustls-pemfile 2.1.1",
 "serde",
 "thiserror",
@@ -4507,7 +4531,7 @@ dependencies = [
 "parquet_derive",
 "pbkdf2",
 "pin-project-lite",
- "postgres-protocol",
+ "postgres-protocol2",
 "postgres_backend",
 "pq_proto",
 "prometheus",
@@ -4524,7 +4548,7 @@ dependencies = [
 "rsa",
 "rstest",
 "rustc-hash",
- "rustls 0.23.16",
+ "rustls 0.23.18",
 "rustls-native-certs 0.8.0",
 "rustls-pemfile 2.1.1",
 "scopeguard",
@@ -4542,8 +4566,7 @@ dependencies = [
 "tikv-jemalloc-ctl",
 "tikv-jemallocator",
 "tokio",
- "tokio-postgres",
- "tokio-postgres-rustls",
+ "tokio-postgres2",
 "tokio-rustls 0.26.0",
 "tokio-tungstenite",
 "tokio-util",
@@ -5237,9 +5260,9 @@ dependencies = [

[[package]]
name = "rustls"
-version = "0.23.16"
+version = "0.23.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "eee87ff5d9b36712a58574e12e9f0ea80f915a5b0ac518d322b24a465617925e"
+checksum = "9c9cc1d47e243d655ace55ed38201c19ae02c148ae56412ab8750e8f0166ab7f"
dependencies = [
 "log",
 "once_cell",
@@ -5370,6 +5393,7 @@ dependencies = [
 "itertools 0.10.5",
 "metrics",
 "once_cell",
+ "pageserver_api",
 "parking_lot 0.12.1",
 "postgres",
 "postgres-protocol",
@@ -5401,6 +5425,7 @@ dependencies = [
 "tracing-subscriber",
 "url",
 "utils",
+ "wal_decoder",
 "walproposer",
 "workspace_hack",
]
@@ -5954,7 +5979,7 @@ dependencies = [
 "once_cell",
 "parking_lot 0.12.1",
 "prost",
- "rustls 0.23.16",
+ "rustls 0.23.18",
 "tokio",
 "tonic",
 "tonic-build",
@@ -6037,7 +6062,7 @@ dependencies = [
 "postgres_ffi",
 "remote_storage",
 "reqwest 0.12.4",
- "rustls 0.23.16",
+ "rustls 0.23.18",
 "rustls-native-certs 0.8.0",
 "serde",
 "serde_json",
@@ -6425,6 +6450,7 @@ dependencies = [
 "libc",
 "mio",
 "num_cpus",
 "parking_lot 0.12.1",
 "pin-project-lite",
 "signal-hook-registry",
 "socket2",
@@ -6472,7 +6498,7 @@ dependencies = [
[[package]]
name = "tokio-postgres"
version = "0.7.7"
-source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#2a2a7c56930dd5ad60676ce6da92e1cbe6fb3ef5"
+source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#00940fcdb57a8e99e805297b75839e7c4c7b1796"
dependencies = [
 "async-trait",
 "byteorder",
@@ -6499,13 +6525,33 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04fb792ccd6bbcd4bba408eb8a292f70fc4a3589e5d793626f45190e6454b6ab"
dependencies = [
 "ring",
- "rustls 0.23.16",
+ "rustls 0.23.18",
 "tokio",
 "tokio-postgres",
 "tokio-rustls 0.26.0",
 "x509-certificate",
]

+[[package]]
+name = "tokio-postgres2"
+version = "0.1.0"
+dependencies = [
+ "async-trait",
+ "byteorder",
+ "bytes",
+ "fallible-iterator",
+ "futures-util",
+ "log",
+ "parking_lot 0.12.1",
+ "percent-encoding",
+ "phf",
+ "pin-project-lite",
+ "postgres-protocol2",
+ "postgres-types2",
+ "tokio",
+ "tokio-util",
+]
+
[[package]]
name = "tokio-rustls"
version = "0.24.0"
@@ -6533,7 +6579,7 @@ version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4"
dependencies = [
- "rustls 0.23.16",
+ "rustls 0.23.18",
 "rustls-pki-types",
 "tokio",
]
@@ -6942,7 +6988,7 @@ dependencies = [
 "base64 0.22.1",
 "log",
 "once_cell",
- "rustls 0.23.16",
+ "rustls 0.23.18",
 "rustls-pki-types",
 "url",
 "webpki-roots 0.26.1",
@@ -7028,6 +7074,7 @@ dependencies = [
 "serde_assert",
 "serde_json",
 "serde_path_to_error",
 "serde_with",
 "signal-hook",
 "strum",
 "strum_macros",
@@ -7124,10 +7171,16 @@ name = "wal_decoder"
version = "0.1.0"
dependencies = [
 "anyhow",
 "async-compression",
 "bytes",
 "pageserver_api",
 "postgres_ffi",
 "prost",
 "serde",
 "thiserror",
 "tokio",
 "tonic",
 "tonic-build",
 "tracing",
 "utils",
 "workspace_hack",
@@ -7595,7 +7648,6 @@ dependencies = [
 "num-traits",
 "once_cell",
 "parquet",
- "postgres-types",
 "prettyplease",
 "proc-macro2",
 "prost",
@@ -7605,7 +7657,7 @@ dependencies = [
 "regex-automata 0.4.3",
 "regex-syntax 0.8.2",
 "reqwest 0.12.4",
- "rustls 0.23.16",
+ "rustls 0.23.18",
 "scopeguard",
 "serde",
 "serde_json",
@@ -7620,7 +7672,6 @@ dependencies = [
 "time",
 "time-macros",
 "tokio",
- "tokio-postgres",
 "tokio-rustls 0.26.0",
 "tokio-stream",
 "tokio-util",
@@ -35,6 +35,9 @@ members = [
    "libs/walproposer",
    "libs/wal_decoder",
    "libs/postgres_initdb",
+    "libs/proxy/postgres-protocol2",
+    "libs/proxy/postgres-types2",
+    "libs/proxy/tokio-postgres2",
]

[workspace.package]
@@ -7,7 +7,7 @@ ARG IMAGE=build-tools
ARG TAG=pinned
ARG DEFAULT_PG_VERSION=17
ARG STABLE_PG_VERSION=16
-ARG DEBIAN_VERSION=bullseye
+ARG DEBIAN_VERSION=bookworm
ARG DEBIAN_FLAVOR=${DEBIAN_VERSION}-slim

# Build Postgres
3 Makefile
@@ -38,6 +38,7 @@ ifeq ($(UNAME_S),Linux)
	# Seccomp BPF is only available for Linux
	PG_CONFIGURE_OPTS += --with-libseccomp
else ifeq ($(UNAME_S),Darwin)
+	PG_CFLAGS += -DUSE_PREFETCH
	ifndef DISABLE_HOMEBREW
	# macOS with brew-installed openssl requires explicit paths
	# It can be configured with OPENSSL_PREFIX variable
@@ -146,6 +147,8 @@ postgres-%: postgres-configure-% \
	$(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pg_prewarm install
	+@echo "Compiling pg_buffercache $*"
	$(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pg_buffercache install
	+@echo "Compiling pg_visibility $*"
	$(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pg_visibility install
	+@echo "Compiling pageinspect $*"
	$(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pageinspect install
	+@echo "Compiling amcheck $*"
@@ -1,4 +1,4 @@
-ARG DEBIAN_VERSION=bullseye
+ARG DEBIAN_VERSION=bookworm

FROM debian:bookworm-slim AS pgcopydb_builder
ARG DEBIAN_VERSION
@@ -57,9 +57,9 @@ RUN mkdir -p /pgcopydb/bin && \
    mkdir -p /pgcopydb/lib && \
    chmod -R 755 /pgcopydb && \
    chown -R nonroot:nonroot /pgcopydb

-COPY --from=pgcopydb_builder /usr/lib/postgresql/16/bin/pgcopydb /pgcopydb/bin/pgcopydb
-COPY --from=pgcopydb_builder /pgcopydb/lib/libpq.so.5 /pgcopydb/lib/libpq.so.5
+COPY --from=pgcopydb_builder /usr/lib/postgresql/16/bin/pgcopydb /pgcopydb/bin/pgcopydb
+COPY --from=pgcopydb_builder /pgcopydb/lib/libpq.so.5 /pgcopydb/lib/libpq.so.5

# System deps
#
@@ -258,14 +258,14 @@ WORKDIR /home/nonroot

# Rust
# Please keep the version of llvm (installed above) in sync with rust llvm (`rustc --version --verbose | grep LLVM`)
-ENV RUSTC_VERSION=1.82.0
+ENV RUSTC_VERSION=1.83.0
ENV RUSTUP_HOME="/home/nonroot/.rustup"
ENV PATH="/home/nonroot/.cargo/bin:${PATH}"
ARG RUSTFILT_VERSION=0.2.1
-ARG CARGO_HAKARI_VERSION=0.9.30
-ARG CARGO_DENY_VERSION=0.16.1
-ARG CARGO_HACK_VERSION=0.6.31
-ARG CARGO_NEXTEST_VERSION=0.9.72
+ARG CARGO_HAKARI_VERSION=0.9.33
+ARG CARGO_DENY_VERSION=0.16.2
+ARG CARGO_HACK_VERSION=0.6.33
+ARG CARGO_NEXTEST_VERSION=0.9.85
RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux-gnu/rustup-init && whoami && \
    chmod +x rustup-init && \
    ./rustup-init -y --default-toolchain ${RUSTC_VERSION} && \
@@ -289,7 +289,7 @@ RUN whoami \
    && cargo --version --verbose \
    && rustup --version --verbose \
    && rustc --version --verbose \
-    && clang --version
+    && clang --version

RUN if [ "${DEBIAN_VERSION}" = "bookworm" ]; then \
        LD_LIBRARY_PATH=/pgcopydb/lib /pgcopydb/bin/pgcopydb --version; \
@@ -3,7 +3,7 @@ ARG REPOSITORY=neondatabase
ARG IMAGE=build-tools
ARG TAG=pinned
ARG BUILD_TAG
-ARG DEBIAN_VERSION=bullseye
+ARG DEBIAN_VERSION=bookworm
ARG DEBIAN_FLAVOR=${DEBIAN_VERSION}-slim

#########################################################################################
@@ -58,7 +58,7 @@ use compute_tools::compute::{
    forward_termination_signal, ComputeNode, ComputeState, ParsedSpec, PG_PID,
};
use compute_tools::configurator::launch_configurator;
-use compute_tools::extension_server::get_pg_version;
+use compute_tools::extension_server::get_pg_version_string;
use compute_tools::http::api::launch_http_server;
use compute_tools::logger::*;
use compute_tools::monitor::launch_monitor;
@@ -326,7 +326,7 @@ fn wait_spec(
        connstr: Url::parse(connstr).context("cannot parse connstr as a URL")?,
        pgdata: pgdata.to_string(),
        pgbin: pgbin.to_string(),
-        pgversion: get_pg_version(pgbin),
+        pgversion: get_pg_version_string(pgbin),
        live_config_allowed,
        state: Mutex::new(new_state),
        state_changed: Condvar::new(),
@@ -29,6 +29,7 @@ use anyhow::Context;
use aws_config::BehaviorVersion;
use camino::{Utf8Path, Utf8PathBuf};
use clap::Parser;
+use compute_tools::extension_server::{get_pg_version, PostgresMajorVersion};
use nix::unistd::Pid;
use tracing::{info, info_span, warn, Instrument};
use utils::fs_ext::is_directory_empty;
@@ -131,11 +132,17 @@ pub(crate) async fn main() -> anyhow::Result<()> {
    //
    // Initialize pgdata
    //
+    let pg_version = match get_pg_version(pg_bin_dir.as_str()) {
+        PostgresMajorVersion::V14 => 14,
+        PostgresMajorVersion::V15 => 15,
+        PostgresMajorVersion::V16 => 16,
+        PostgresMajorVersion::V17 => 17,
+    };
    let superuser = "cloud_admin"; // XXX: this shouldn't be hard-coded
    postgres_initdb::do_run_initdb(postgres_initdb::RunInitdbArgs {
        superuser,
        locale: "en_US.UTF-8", // XXX: this shouldn't be hard-coded,
-        pg_version: 140000, // XXX: this shouldn't be hard-coded but derived from which compute image we're running in
+        pg_version,
        initdb_bin: pg_bin_dir.join("initdb").as_ref(),
        library_search_path: &pg_lib_dir, // TODO: is this right? Prob works in compute image, not sure about neon_local.
        pgdata: &pgdata_dir,
@@ -1,4 +1,3 @@
|
||||
use compute_api::responses::CatalogObjects;
|
||||
use futures::Stream;
|
||||
use postgres::NoTls;
|
||||
use std::{path::Path, process::Stdio, result::Result, sync::Arc};
|
||||
@@ -13,7 +12,8 @@ use tokio_util::codec::{BytesCodec, FramedRead};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::compute::ComputeNode;
|
||||
use crate::pg_helpers::{get_existing_dbs_async, get_existing_roles_async};
|
||||
use crate::pg_helpers::{get_existing_dbs_async, get_existing_roles_async, postgres_conf_for_db};
|
||||
use compute_api::responses::CatalogObjects;
|
||||
|
||||
pub async fn get_dbs_and_roles(compute: &Arc<ComputeNode>) -> anyhow::Result<CatalogObjects> {
|
||||
let connstr = compute.connstr.clone();
|
||||
@@ -43,6 +43,8 @@ pub enum SchemaDumpError {
|
||||
DatabaseDoesNotExist,
|
||||
#[error("Failed to execute pg_dump.")]
|
||||
IO(#[from] std::io::Error),
|
||||
#[error("Unexpected error.")]
|
||||
Unexpected,
|
||||
}
|
||||
|
||||
// It uses the pg_dump utility to dump the schema of the specified database.
|
||||
@@ -60,11 +62,38 @@ pub async fn get_database_schema(
|
||||
let pgbin = &compute.pgbin;
|
||||
let basepath = Path::new(pgbin).parent().unwrap();
|
||||
let pgdump = basepath.join("pg_dump");
|
||||
let mut connstr = compute.connstr.clone();
|
||||
connstr.set_path(dbname);
|
||||
|
||||
// Replace the DB in the connection string and disable it to parts.
|
||||
// This is the only option to handle DBs with special characters.
|
||||
let conf =
|
||||
postgres_conf_for_db(&compute.connstr, dbname).map_err(|_| SchemaDumpError::Unexpected)?;
|
||||
let host = conf
|
||||
.get_hosts()
|
||||
.first()
|
||||
.ok_or(SchemaDumpError::Unexpected)?;
|
||||
let host = match host {
|
||||
tokio_postgres::config::Host::Tcp(ip) => ip.to_string(),
|
||||
#[cfg(unix)]
|
||||
tokio_postgres::config::Host::Unix(path) => path.to_string_lossy().to_string(),
|
||||
};
|
||||
let port = conf
|
||||
.get_ports()
|
||||
.first()
|
||||
.ok_or(SchemaDumpError::Unexpected)?;
|
||||
let user = conf.get_user().ok_or(SchemaDumpError::Unexpected)?;
|
||||
let dbname = conf.get_dbname().ok_or(SchemaDumpError::Unexpected)?;
|
||||
|
||||
let mut cmd = Command::new(pgdump)
|
||||
// XXX: this seems to be the only option to deal with DBs with `=` in the name
|
||||
// See <https://www.postgresql.org/message-id/flat/20151023003445.931.91267%40wrigleys.postgresql.org>
|
||||
.env("PGDATABASE", dbname)
|
||||
.arg("--host")
|
||||
.arg(host)
|
||||
.arg("--port")
|
||||
.arg(port.to_string())
|
||||
.arg("--username")
|
||||
.arg(user)
|
||||
.arg("--schema-only")
|
||||
.arg(connstr.as_str())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.kill_on_drop(true)
|
||||
|
||||
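For context, a minimal sketch of the approach this hunk moves to: parse the connection string into a `tokio_postgres::Config` once, set the database name verbatim, and read the already-decoded parts back out. The database name below is illustrative; spaces and `=` in it are exactly what breaks URL-path editing and pg_dump's dbname parsing.

use std::str::FromStr;

fn sketch() -> Result<(), tokio_postgres::Error> {
    // Hypothetical database name with characters that are awkward in a URL path.
    let dbname = "week=19 stats";

    // Parse the base connection string, then set the dbname verbatim:
    // no percent-encoding, no manual escaping.
    let mut conf = tokio_postgres::Config::from_str(
        "postgresql://cloud_admin@localhost:5432/postgres",
    )?;
    conf.dbname(dbname);

    // The parts come back out separately, ready to feed to pg_dump as
    // --host/--port/--username plus the PGDATABASE environment variable.
    assert_eq!(conf.get_dbname(), Some(dbname));
    assert_eq!(conf.get_ports().first(), Some(&5432));
    Ok(())
}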
@@ -34,9 +34,8 @@ use utils::measured_stream::MeasuredReader;
use nix::sys::signal::{kill, Signal};
use remote_storage::{DownloadError, RemotePath};
use tokio::spawn;
-use url::Url;

-use crate::installed_extensions::get_installed_extensions_sync;
+use crate::installed_extensions::get_installed_extensions;
use crate::local_proxy;
use crate::pg_helpers::*;
use crate::spec::*;
@@ -816,30 +815,32 @@ impl ComputeNode {
        Ok(())
    }

-    async fn get_maintenance_client(url: &Url) -> Result<tokio_postgres::Client> {
-        let mut connstr = url.clone();
+    async fn get_maintenance_client(
+        conf: &tokio_postgres::Config,
+    ) -> Result<tokio_postgres::Client> {
+        let mut conf = conf.clone();

-        connstr
-            .query_pairs_mut()
-            .append_pair("application_name", "apply_config");
+        conf.application_name("apply_config");

-        let (client, conn) = match tokio_postgres::connect(connstr.as_str(), NoTls).await {
+        let (client, conn) = match conf.connect(NoTls).await {
            // If connection fails, it may be the old node with `zenith_admin` superuser.
            //
            // In this case we need to connect with old `zenith_admin` name
            // and create new user. We cannot simply rename connected user,
            // but we can create a new one and grant it all privileges.
            Err(e) => match e.code() {
                Some(&SqlState::INVALID_PASSWORD)
                | Some(&SqlState::INVALID_AUTHORIZATION_SPECIFICATION) => {
-                    // connect with zenith_admin if cloud_admin could not authenticate
+                    // Connect with zenith_admin if cloud_admin could not authenticate
                    info!(
                        "cannot connect to postgres: {}, retrying with `zenith_admin` username",
                        e
                    );
-                    let mut zenith_admin_connstr = connstr.clone();
-
-                    zenith_admin_connstr
-                        .set_username("zenith_admin")
-                        .map_err(|_| anyhow::anyhow!("invalid connstr"))?;
+                    let mut zenith_admin_conf = postgres::config::Config::from(conf.clone());
+                    zenith_admin_conf.user("zenith_admin");

-                    let mut client =
-                        Client::connect(zenith_admin_connstr.as_str(), NoTls)
+                    let mut client = zenith_admin_conf.connect(NoTls)
                        .context("broken cloud_admin credential: tried connecting with cloud_admin but could not authenticate, and zenith_admin does not work either")?;

                    // Disable forwarding so that users don't get a cloud_admin role
@@ -853,8 +854,8 @@ impl ComputeNode {

                    drop(client);

-                    // reconnect with connstring with expected name
-                    tokio_postgres::connect(connstr.as_str(), NoTls).await?
+                    // Reconnect with connstring with expected name
+                    conf.connect(NoTls).await?
                }
                _ => return Err(e.into()),
            },
@@ -885,7 +886,7 @@ impl ComputeNode {
    pub fn apply_spec_sql(
        &self,
        spec: Arc<ComputeSpec>,
-        url: Arc<Url>,
+        conf: Arc<tokio_postgres::Config>,
        concurrency: usize,
    ) -> Result<()> {
        let rt = tokio::runtime::Builder::new_multi_thread()
@@ -897,7 +898,7 @@ impl ComputeNode {

        rt.block_on(async {
            // Proceed with post-startup configuration. Note, that order of operations is important.
-            let client = Self::get_maintenance_client(&url).await?;
+            let client = Self::get_maintenance_client(&conf).await?;
            let spec = spec.clone();

            let databases = get_existing_dbs_async(&client).await?;
@@ -931,7 +932,7 @@ impl ComputeNode {
                RenameAndDeleteDatabases,
                CreateAndAlterDatabases,
            ] {
-                debug!("Applying phase {:?}", &phase);
+                info!("Applying phase {:?}", &phase);
                apply_operations(
                    spec.clone(),
                    ctx.clone(),
@@ -942,6 +943,7 @@ impl ComputeNode {
                .await?;
            }

+            info!("Applying RunInEachDatabase phase");
            let concurrency_token = Arc::new(tokio::sync::Semaphore::new(concurrency));

            let db_processes = spec
@@ -955,7 +957,7 @@ impl ComputeNode {
                    let spec = spec.clone();
                    let ctx = ctx.clone();
                    let jwks_roles = jwks_roles.clone();
-                    let mut url = url.as_ref().clone();
+                    let mut conf = conf.as_ref().clone();
                    let concurrency_token = concurrency_token.clone();
                    let db = db.clone();

@@ -964,14 +966,14 @@ impl ComputeNode {
                    match &db {
                        DB::SystemDB => {}
                        DB::UserDB(db) => {
-                            url.set_path(db.name.as_str());
+                            conf.dbname(db.name.as_str());
                        }
                    }

-                    let url = Arc::new(url);
+                    let conf = Arc::new(conf);
                    let fut = Self::apply_spec_sql_db(
                        spec.clone(),
-                        url,
+                        conf,
                        ctx.clone(),
                        jwks_roles.clone(),
                        concurrency_token.clone(),
@@ -1017,7 +1019,7 @@ impl ComputeNode {
    /// semaphore. The caller has to make sure the semaphore isn't exhausted.
    async fn apply_spec_sql_db(
        spec: Arc<ComputeSpec>,
-        url: Arc<Url>,
+        conf: Arc<tokio_postgres::Config>,
        ctx: Arc<tokio::sync::RwLock<MutableApplyContext>>,
        jwks_roles: Arc<HashSet<String>>,
        concurrency_token: Arc<tokio::sync::Semaphore>,
@@ -1046,7 +1048,7 @@ impl ComputeNode {
            // that database.
            || async {
                if client_conn.is_none() {
-                    let db_client = Self::get_maintenance_client(&url).await?;
+                    let db_client = Self::get_maintenance_client(&conf).await?;
                    client_conn.replace(db_client);
                }
                let client = client_conn.as_ref().unwrap();
@@ -1061,34 +1063,16 @@ impl ComputeNode {
        Ok::<(), anyhow::Error>(())
    }

-    /// Do initial configuration of the already started Postgres.
-    #[instrument(skip_all)]
-    pub fn apply_config(&self, compute_state: &ComputeState) -> Result<()> {
-        // If connection fails,
-        // it may be the old node with `zenith_admin` superuser.
-        //
-        // In this case we need to connect with old `zenith_admin` name
-        // and create new user. We cannot simply rename connected user,
-        // but we can create a new one and grant it all privileges.
-        let mut url = self.connstr.clone();
-        url.query_pairs_mut()
-            .append_pair("application_name", "apply_config");
-
-        let url = Arc::new(url);
-        let spec = Arc::new(
-            compute_state
-                .pspec
-                .as_ref()
-                .expect("spec must be set")
-                .spec
-                .clone(),
-        );
-
-        // Choose how many concurrent connections to use for applying the spec changes.
-        // If the cluster is not currently Running we don't have to deal with user connections,
+    /// Choose how many concurrent connections to use for applying the spec changes.
+    pub fn max_service_connections(
+        &self,
+        compute_state: &ComputeState,
+        spec: &ComputeSpec,
+    ) -> usize {
+        // If the cluster is in Init state we don't have to deal with user connections,
        // and can thus use all `max_connections` connection slots. However, that's generally not
        // very efficient, so we generally still limit it to a smaller number.
-        let max_concurrent_connections = if compute_state.status != ComputeStatus::Running {
+        if compute_state.status == ComputeStatus::Init {
            // If the settings contain 'max_connections', use that as template
            if let Some(config) = spec.cluster.settings.find("max_connections") {
                config.parse::<usize>().ok()
@@ -1144,10 +1128,29 @@ impl ComputeNode {
                .map(|val| if val > 1 { val - 1 } else { 1 })
                .last()
                .unwrap_or(3)
-        };
+        }
    }

+    /// Do initial configuration of the already started Postgres.
+    #[instrument(skip_all)]
+    pub fn apply_config(&self, compute_state: &ComputeState) -> Result<()> {
+        let mut conf = tokio_postgres::Config::from_str(self.connstr.as_str()).unwrap();
+        conf.application_name("apply_config");
+
+        let conf = Arc::new(conf);
+        let spec = Arc::new(
+            compute_state
+                .pspec
+                .as_ref()
+                .expect("spec must be set")
+                .spec
+                .clone(),
+        );
+
+        let max_concurrent_connections = self.max_service_connections(compute_state, &spec);
+
        // Merge-apply spec & changes to PostgreSQL state.
-        self.apply_spec_sql(spec.clone(), url.clone(), max_concurrent_connections)?;
+        self.apply_spec_sql(spec.clone(), conf.clone(), max_concurrent_connections)?;

        if let Some(ref local_proxy) = &spec.clone().local_proxy_config {
            info!("configuring local_proxy");
@@ -1156,12 +1159,11 @@ impl ComputeNode {

        // Run migrations separately to not hold up cold starts
        thread::spawn(move || {
-            let mut connstr = url.as_ref().clone();
-            connstr
-                .query_pairs_mut()
-                .append_pair("application_name", "migrations");
+            let conf = conf.as_ref().clone();
+            let mut conf = postgres::config::Config::from(conf);
+            conf.application_name("migrations");

-            let mut client = Client::connect(connstr.as_str(), NoTls)?;
+            let mut client = conf.connect(NoTls)?;
            handle_migrations(&mut client).context("apply_config handle_migrations")
        });
@@ -1222,21 +1224,28 @@ impl ComputeNode {
        let pgdata_path = Path::new(&self.pgdata);
        let postgresql_conf_path = pgdata_path.join("postgresql.conf");
        config::write_postgres_conf(&postgresql_conf_path, &spec, None)?;
-        // temporarily reset max_cluster_size in config
+
+        // TODO(ololobus): We need a concurrency during reconfiguration as well,
+        // but DB is already running and used by user. We can easily get out of
+        // `max_connections` limit, and the current code won't handle that.
+        // let compute_state = self.state.lock().unwrap().clone();
+        // let max_concurrent_connections = self.max_service_connections(&compute_state, &spec);
+        let max_concurrent_connections = 1;
+
+        // Temporarily reset max_cluster_size in config
        // to avoid the possibility of hitting the limit, while we are reconfiguring:
-        // creating new extensions, roles, etc...
+        // creating new extensions, roles, etc.
        config::with_compute_ctl_tmp_override(pgdata_path, "neon.max_cluster_size=-1", || {
            self.pg_reload_conf()?;

            if spec.mode == ComputeMode::Primary {
-                let mut url = self.connstr.clone();
-                url.query_pairs_mut()
-                    .append_pair("application_name", "apply_config");
-                let url = Arc::new(url);
+                let mut conf = tokio_postgres::Config::from_str(self.connstr.as_str()).unwrap();
+                conf.application_name("apply_config");
+                let conf = Arc::new(conf);

                let spec = Arc::new(spec.clone());

-                self.apply_spec_sql(spec, url, 1)?;
+                self.apply_spec_sql(spec, conf, max_concurrent_connections)?;
            }

            Ok(())
@@ -1362,7 +1371,17 @@ impl ComputeNode {

        let connstr = self.connstr.clone();
        thread::spawn(move || {
-            get_installed_extensions_sync(connstr).context("get_installed_extensions")
+            let res = get_installed_extensions(&connstr);
+            match res {
+                Ok(extensions) => {
+                    info!(
+                        "[NEON_EXT_STAT] {}",
+                        serde_json::to_string(&extensions)
+                            .expect("failed to serialize extensions list")
+                    );
+                }
+                Err(err) => error!("could not get installed extensions: {err:?}"),
+            }
        });
    }
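The recurring pattern in this file — clone the base `Config`, tweak one field, connect — looks roughly like the following minimal sketch. The `connect_as` helper name is illustrative and not part of the diff; `anyhow` and `tracing` are assumed as dependencies, as they are elsewhere in this crate.

use tokio_postgres::{Config, NoTls};

async fn connect_as(base: &Config, app_name: &str) -> anyhow::Result<tokio_postgres::Client> {
    // Cloning a Config and mutating one field replaces the old
    // query_pairs_mut()/set_path() URL surgery.
    let mut conf = base.clone();
    conf.application_name(app_name);

    let (client, connection) = conf.connect(NoTls).await?;
    // The Connection object drives the socket I/O and must be polled,
    // so park it on a background task.
    tokio::spawn(async move {
        if let Err(e) = connection.await {
            tracing::error!("connection error: {e}");
        }
    });
    Ok(client)
}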
@@ -103,14 +103,33 @@ fn get_pg_config(argument: &str, pgbin: &str) -> String {
        .to_string()
}

-pub fn get_pg_version(pgbin: &str) -> String {
+pub fn get_pg_version(pgbin: &str) -> PostgresMajorVersion {
    // pg_config --version returns a (platform specific) human readable string
    // such as "PostgreSQL 15.4". We parse this to v14/v15/v16 etc.
    let human_version = get_pg_config("--version", pgbin);
-    parse_pg_version(&human_version).to_string()
+    parse_pg_version(&human_version)
}

-fn parse_pg_version(human_version: &str) -> &str {
+pub fn get_pg_version_string(pgbin: &str) -> String {
+    match get_pg_version(pgbin) {
+        PostgresMajorVersion::V14 => "v14",
+        PostgresMajorVersion::V15 => "v15",
+        PostgresMajorVersion::V16 => "v16",
+        PostgresMajorVersion::V17 => "v17",
+    }
+    .to_owned()
+}
+
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+pub enum PostgresMajorVersion {
+    V14,
+    V15,
+    V16,
+    V17,
+}
+
+fn parse_pg_version(human_version: &str) -> PostgresMajorVersion {
+    use PostgresMajorVersion::*;
    // Normal releases have version strings like "PostgreSQL 15.4". But there
    // are also pre-release versions like "PostgreSQL 17devel" or "PostgreSQL
    // 16beta2" or "PostgreSQL 17rc1". And with the --with-extra-version
@@ -121,10 +140,10 @@ fn parse_pg_version(human_version: &str) -> &str {
        .captures(human_version)
    {
        Some(captures) if captures.len() == 2 => match &captures["major"] {
-            "14" => return "v14",
-            "15" => return "v15",
-            "16" => return "v16",
-            "17" => return "v17",
+            "14" => return V14,
+            "15" => return V15,
+            "16" => return V16,
+            "17" => return V17,
            _ => {}
        },
        _ => {}
@@ -263,24 +282,25 @@ mod tests {

    #[test]
    fn test_parse_pg_version() {
-        assert_eq!(parse_pg_version("PostgreSQL 15.4"), "v15");
-        assert_eq!(parse_pg_version("PostgreSQL 15.14"), "v15");
+        use super::PostgresMajorVersion::*;
+        assert_eq!(parse_pg_version("PostgreSQL 15.4"), V15);
+        assert_eq!(parse_pg_version("PostgreSQL 15.14"), V15);
        assert_eq!(
            parse_pg_version("PostgreSQL 15.4 (Ubuntu 15.4-0ubuntu0.23.04.1)"),
-            "v15"
+            V15
        );

-        assert_eq!(parse_pg_version("PostgreSQL 14.15"), "v14");
-        assert_eq!(parse_pg_version("PostgreSQL 14.0"), "v14");
+        assert_eq!(parse_pg_version("PostgreSQL 14.15"), V14);
+        assert_eq!(parse_pg_version("PostgreSQL 14.0"), V14);
        assert_eq!(
            parse_pg_version("PostgreSQL 14.9 (Debian 14.9-1.pgdg120+1"),
-            "v14"
+            V14
        );

-        assert_eq!(parse_pg_version("PostgreSQL 16devel"), "v16");
-        assert_eq!(parse_pg_version("PostgreSQL 16beta1"), "v16");
-        assert_eq!(parse_pg_version("PostgreSQL 16rc2"), "v16");
-        assert_eq!(parse_pg_version("PostgreSQL 16extra"), "v16");
+        assert_eq!(parse_pg_version("PostgreSQL 16devel"), V16);
+        assert_eq!(parse_pg_version("PostgreSQL 16beta1"), V16);
+        assert_eq!(parse_pg_version("PostgreSQL 16rc2"), V16);
+        assert_eq!(parse_pg_version("PostgreSQL 16extra"), V16);
    }

    #[test]
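A self-contained sketch of the parsing rule described in the comment above. The regex pattern and function name are illustrative; the real code matches on a named `major` capture in the same way.

use regex::Regex;

fn major_of(human_version: &str) -> Option<u32> {
    // Pre-release strings like "PostgreSQL 17devel" or "PostgreSQL 16beta2"
    // mean only the leading digits after "PostgreSQL " are trustworthy.
    let re = Regex::new(r"PostgreSQL (?P<major>\d+)").ok()?;
    re.captures(human_version)?.name("major")?.as_str().parse().ok()
}

fn main() {
    assert_eq!(major_of("PostgreSQL 16beta1"), Some(16));
    assert_eq!(major_of("PostgreSQL 15.4 (Ubuntu 15.4-0ubuntu0.23.04.1)"), Some(15));
}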
@@ -296,7 +296,12 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
            }

            let connstr = compute.connstr.clone();
-            let res = crate::installed_extensions::get_installed_extensions(connstr).await;
+            let res = task::spawn_blocking(move || {
+                installed_extensions::get_installed_extensions(&connstr)
+            })
+            .await
+            .unwrap();

            match res {
                Ok(res) => render_json(Body::from(serde_json::to_string(&res).unwrap())),
                Err(e) => render_json_error(
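The `spawn_blocking` wrapper above follows the usual pattern for calling the now-synchronous `postgres` client from an async handler; a sketch, with illustrative handler and query function names:

use tokio::task;

async fn list_extensions(connstr: url::Url) -> anyhow::Result<String> {
    // Run the blocking libpq-style round-trips on the blocking-thread pool so
    // the async worker threads stay free. The outer `?` surfaces a panicked or
    // cancelled task (JoinError); the inner one surfaces the query error.
    let json = task::spawn_blocking(move || blocking_query(&connstr)).await??;
    Ok(json)
}

fn blocking_query(_connstr: &url::Url) -> anyhow::Result<String> {
    // stand-in for Client::connect(...) plus the catalog queries
    Ok("[]".to_string())
}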
@@ -2,17 +2,16 @@ use compute_api::responses::{InstalledExtension, InstalledExtensions};
use metrics::proto::MetricFamily;
use std::collections::HashMap;
use std::collections::HashSet;
-use tracing::info;
-use url::Url;

use anyhow::Result;
use postgres::{Client, NoTls};
-use tokio::task;

use metrics::core::Collector;
use metrics::{register_uint_gauge_vec, UIntGaugeVec};
use once_cell::sync::Lazy;

+use crate::pg_helpers::postgres_conf_for_db;
+
/// We don't reuse get_existing_dbs() just for code clarity
/// and to make database listing query here more explicit.
///
@@ -42,75 +41,51 @@ fn list_dbs(client: &mut Client) -> Result<Vec<String>> {
///
/// Same extension can be installed in multiple databases with different versions,
/// we only keep the highest and lowest version across all databases.
-pub async fn get_installed_extensions(connstr: Url) -> Result<InstalledExtensions> {
-    let mut connstr = connstr.clone();
+pub fn get_installed_extensions(connstr: &url::Url) -> Result<InstalledExtensions> {
+    let mut client = Client::connect(connstr.as_str(), NoTls)?;
+    let databases: Vec<String> = list_dbs(&mut client)?;

-    task::spawn_blocking(move || {
-        let mut client = Client::connect(connstr.as_str(), NoTls)?;
-        let databases: Vec<String> = list_dbs(&mut client)?;
+    let mut extensions_map: HashMap<String, InstalledExtension> = HashMap::new();
+    for db in databases.iter() {
+        let config = postgres_conf_for_db(connstr, db)?;
+        let mut db_client = config.connect(NoTls)?;
+        let extensions: Vec<(String, String)> = db_client
+            .query(
+                "SELECT extname, extversion FROM pg_catalog.pg_extension;",
+                &[],
+            )?
+            .iter()
+            .map(|row| (row.get("extname"), row.get("extversion")))
+            .collect();

-        let mut extensions_map: HashMap<String, InstalledExtension> = HashMap::new();
-        for db in databases.iter() {
-            connstr.set_path(db);
-            let mut db_client = Client::connect(connstr.as_str(), NoTls)?;
-            let extensions: Vec<(String, String)> = db_client
-                .query(
-                    "SELECT extname, extversion FROM pg_catalog.pg_extension;",
-                    &[],
-                )?
-                .iter()
-                .map(|row| (row.get("extname"), row.get("extversion")))
-                .collect();
+        for (extname, v) in extensions.iter() {
+            let version = v.to_string();

-            for (extname, v) in extensions.iter() {
-                let version = v.to_string();
+            // increment the number of databases where the version of extension is installed
+            INSTALLED_EXTENSIONS
+                .with_label_values(&[extname, &version])
+                .inc();

-                // increment the number of databases where the version of extension is installed
-                INSTALLED_EXTENSIONS
-                    .with_label_values(&[extname, &version])
-                    .inc();
+            extensions_map
+                .entry(extname.to_string())
+                .and_modify(|e| {
+                    e.versions.insert(version.clone());
+                    // count the number of databases where the extension is installed
+                    e.n_databases += 1;
+                })
+                .or_insert(InstalledExtension {
+                    extname: extname.to_string(),
+                    versions: HashSet::from([version.clone()]),
+                    n_databases: 1,
+                });
+        }
+    }

-                extensions_map
-                    .entry(extname.to_string())
-                    .and_modify(|e| {
-                        e.versions.insert(version.clone());
-                        // count the number of databases where the extension is installed
-                        e.n_databases += 1;
-                    })
-                    .or_insert(InstalledExtension {
-                        extname: extname.to_string(),
-                        versions: HashSet::from([version.clone()]),
-                        n_databases: 1,
-                    });
-            }
-        }
+    let res = InstalledExtensions {
+        extensions: extensions_map.values().cloned().collect(),
+    };

-        let res = InstalledExtensions {
-            extensions: extensions_map.values().cloned().collect(),
-        };
-
-        Ok(res)
-    })
-    .await?
-}
-
-// Gather info about installed extensions
-pub fn get_installed_extensions_sync(connstr: Url) -> Result<()> {
-    let rt = tokio::runtime::Builder::new_current_thread()
-        .enable_all()
-        .build()
-        .expect("failed to create runtime");
-    let result = rt
-        .block_on(crate::installed_extensions::get_installed_extensions(
-            connstr,
-        ))
-        .expect("failed to get installed extensions");
-
-    info!(
-        "[NEON_EXT_STAT] {}",
-        serde_json::to_string(&result).expect("failed to serialize extensions list")
-    );
-    Ok(())
+    Ok(res)
}

static INSTALLED_EXTENSIONS: Lazy<UIntGaugeVec> = Lazy::new(|| {
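The HashMap `entry` dance in the new function boils down to this standalone sketch (types simplified from `InstalledExtension` for illustration):

use std::collections::{HashMap, HashSet};

fn aggregate(rows: &[(String, String)]) -> HashMap<String, (HashSet<String>, u32)> {
    // Per extension name: the set of versions seen, and in how many
    // databases it is installed.
    let mut map: HashMap<String, (HashSet<String>, u32)> = HashMap::new();
    for (extname, version) in rows {
        map.entry(extname.clone())
            .and_modify(|(versions, n)| {
                versions.insert(version.clone());
                *n += 1;
            })
            .or_insert_with(|| (HashSet::from([version.clone()]), 1));
    }
    map
}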
@@ -6,6 +6,7 @@ use std::io::{BufRead, BufReader};
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use std::process::Child;
+use std::str::FromStr;
use std::thread::JoinHandle;
use std::time::{Duration, Instant};

@@ -13,8 +14,10 @@ use anyhow::{bail, Result};
use futures::StreamExt;
use ini::Ini;
use notify::{RecursiveMode, Watcher};
+use postgres::config::Config;
use tokio::io::AsyncBufReadExt;
use tokio::time::timeout;
+use tokio_postgres;
use tokio_postgres::NoTls;
use tracing::{debug, error, info, instrument};

@@ -542,3 +545,11 @@ async fn handle_postgres_logs_async(stderr: tokio::process::ChildStderr) -> Result<()> {

    Ok(())
}
+
+/// `Postgres::config::Config` handles database names with whitespaces
+/// and special characters properly.
+pub fn postgres_conf_for_db(connstr: &url::Url, dbname: &str) -> Result<Config> {
+    let mut conf = Config::from_str(connstr.as_str())?;
+    conf.dbname(dbname);
+    Ok(conf)
+}
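Hypothetical usage of the new helper, assuming it is in scope together with the synchronous `postgres` crate and `anyhow`:

use postgres::NoTls;

fn connect_to_db(base: &url::Url) -> anyhow::Result<postgres::Client> {
    // The dbname is stored verbatim in the Config, so names with spaces
    // or `=` survive the round trip without percent-encoding.
    let conf = postgres_conf_for_db(base, "name with spaces")?;
    Ok(conf.connect(NoTls)?)
}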
@@ -415,6 +415,11 @@ impl PageServerNode {
                .map(|x| x.parse::<bool>())
                .transpose()
                .context("Failed to parse 'timeline_offloading' as bool")?,
+            wal_receiver_protocol_override: settings
+                .remove("wal_receiver_protocol_override")
+                .map(serde_json::from_str)
+                .transpose()
+                .context("parse `wal_receiver_protocol_override` from json")?,
        };
        if !settings.is_empty() {
            bail!("Unrecognized tenant settings: {settings:?}")
@@ -33,7 +33,6 @@ reason = "the marvin attack only affects private key decryption, not public key

[licenses]
allow = [
    "Apache-2.0",
    "Artistic-2.0",
    "BSD-2-Clause",
    "BSD-3-Clause",
    "CC0-1.0",
@@ -67,7 +66,7 @@ registries = []
# More documentation about the 'bans' section can be found here:
# https://embarkstudios.github.io/cargo-deny/checks/bans/cfg.html
[bans]
-multiple-versions = "warn"
+multiple-versions = "allow"
wildcards = "allow"
highlight = "all"
workspace-default-features = "allow"
@@ -18,7 +18,7 @@ use std::{
    str::FromStr,
    time::Duration,
};
use utils::logging::LogFormat;
use utils::{logging::LogFormat, postgres_client::PostgresClientProtocol};

use crate::models::ImageCompressionAlgorithm;
use crate::models::LsnLease;
@@ -118,6 +118,7 @@ pub struct ConfigToml {
    pub virtual_file_io_mode: Option<crate::models::virtual_file::IoMode>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub no_sync: Option<bool>,
    pub wal_receiver_protocol: PostgresClientProtocol,
    pub page_service_pipelining: PageServicePipeliningConfig,
}

@@ -298,6 +299,8 @@ pub struct TenantConfigToml {
    /// Enable auto-offloading of timelines.
    /// (either this flag or the pageserver-global one need to be set)
    pub timeline_offloading: bool,

    pub wal_receiver_protocol_override: Option<PostgresClientProtocol>,
}

pub mod defaults {
@@ -349,6 +352,9 @@ pub mod defaults {
    pub const DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB: usize = 0;

    pub const DEFAULT_IO_BUFFER_ALIGNMENT: usize = 512;

    pub const DEFAULT_WAL_RECEIVER_PROTOCOL: utils::postgres_client::PostgresClientProtocol =
        utils::postgres_client::PostgresClientProtocol::Vanilla;
}

impl Default for ConfigToml {
@@ -435,6 +441,7 @@ impl Default for ConfigToml {
            virtual_file_io_mode: None,
            tenant_config: TenantConfigToml::default(),
            no_sync: None,
            wal_receiver_protocol: DEFAULT_WAL_RECEIVER_PROTOCOL,
            page_service_pipelining: PageServicePipeliningConfig::Pipelined(
                PageServicePipeliningConfigPipelined {
                    max_batch_size: NonZeroUsize::new(32).unwrap(),
@@ -528,6 +535,7 @@ impl Default for TenantConfigToml {
            lsn_lease_length: LsnLease::DEFAULT_LENGTH,
            lsn_lease_length_for_ts: LsnLease::DEFAULT_LENGTH_FOR_TS,
            timeline_offloading: false,
            wal_receiver_protocol_override: None,
        }
    }
}
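
(Illustrative, not part of the commit: a tiny check of the defaults introduced above, grounded in the Default impls in this hunk; assumes PostgresClientProtocol::Vanilla is a unit variant as used in the const.)

#[test]
fn defaults_use_vanilla_protocol() {
    let cfg = ConfigToml::default();
    // New pageservers speak the vanilla replication protocol unless configured otherwise.
    assert!(matches!(
        cfg.wal_receiver_protocol,
        utils::postgres_client::PostgresClientProtocol::Vanilla
    ));
    assert!(TenantConfigToml::default()
        .wal_receiver_protocol_override
        .is_none());
}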

@@ -229,6 +229,18 @@ impl Key {
    }
}

impl CompactKey {
    pub fn raw(&self) -> i128 {
        self.0
    }
}

impl From<i128> for CompactKey {
    fn from(value: i128) -> Self {
        Self(value)
    }
}

impl fmt::Display for Key {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(

@@ -23,6 +23,7 @@ use utils::{
    completion,
    id::{NodeId, TenantId, TimelineId},
    lsn::Lsn,
    postgres_client::PostgresClientProtocol,
    serde_system_time,
};

@@ -352,6 +353,7 @@ pub struct TenantConfig {
    pub lsn_lease_length: Option<String>,
    pub lsn_lease_length_for_ts: Option<String>,
    pub timeline_offloading: Option<bool>,
    pub wal_receiver_protocol_override: Option<PostgresClientProtocol>,
}

/// The policy for the aux file storage.

@@ -562,6 +562,9 @@ pub enum BeMessage<'a> {
        options: &'a [&'a str],
    },
    KeepAlive(WalSndKeepAlive),
    /// Batch of interpreted, shard filtered WAL records,
    /// ready for the pageserver to ingest
    InterpretedWalRecords(InterpretedWalRecordsBody<'a>),
}

/// Common shorthands.
@@ -672,6 +675,22 @@ pub struct WalSndKeepAlive {
    pub request_reply: bool,
}

/// Batch of interpreted WAL records used in the interpreted
/// safekeeper to pageserver protocol.
///
/// Note that the pageserver uses the RawInterpretedWalRecordsBody
/// counterpart of this from the neondatabase/rust-postgres repo.
/// If you're changing this struct, you likely need to change its
/// twin as well.
#[derive(Debug)]
pub struct InterpretedWalRecordsBody<'a> {
    /// End of raw WAL in [`Self::data`]
    pub streaming_lsn: u64,
    /// Current end of WAL on the server
    pub commit_lsn: u64,
    pub data: &'a [u8],
}

pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);

// single text column
@@ -996,6 +1015,19 @@ impl BeMessage<'_> {
                Ok(())
            })?
        }

        BeMessage::InterpretedWalRecords(rec) => {
            // We use the COPY_DATA_TAG for our custom message
            // since this tag is interpreted as raw bytes.
            buf.put_u8(b'd');
            write_body(buf, |buf| {
                buf.put_u8(b'0'); // matches INTERPRETED_WAL_RECORD_TAG in postgres-protocol
                                  // dependency
                buf.put_u64(rec.streaming_lsn);
                buf.put_u64(rec.commit_lsn);
                buf.put_slice(rec.data);
            });
        }
    }
    Ok(())
}
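
(Illustrative, not part of the commit: a minimal sketch of the matching decode step on the receiving side, assuming the CopyData payload has already been framed off the wire; the field order mirrors the encoder above, and bytes::Buf reads big-endian.)

use bytes::{Buf, Bytes};

fn read_interpreted_wal_records(mut body: Bytes) -> (u64, u64, Bytes) {
    assert_eq!(body.get_u8(), b'0'); // INTERPRETED_WAL_RECORD_TAG
    let streaming_lsn = body.get_u64(); // end of raw WAL covered by this batch
    let commit_lsn = body.get_u64(); // current end of WAL on the server
    (streaming_lsn, commit_lsn, body) // the rest is the record batch itself
}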

6
libs/proxy/README.md
Normal file
@@ -0,0 +1,6 @@
This directory contains libraries that are specific to the proxy.

Currently, it contains a significant fork/refactoring of rust-postgres that no longer reflects the API
of the original library. Since the changes were so significant, it made sense to promote it into its own set of libraries.

Proxy needs unique access to the protocol, which explains why such heavy modifications were necessary.
21
libs/proxy/postgres-protocol2/Cargo.toml
Normal file
@@ -0,0 +1,21 @@
[package]
name = "postgres-protocol2"
version = "0.1.0"
edition = "2018"
license = "MIT/Apache-2.0"

[dependencies]
base64 = "0.20"
byteorder.workspace = true
bytes.workspace = true
fallible-iterator.workspace = true
hmac.workspace = true
md-5 = "0.10"
memchr = "2.0"
rand.workspace = true
sha2.workspace = true
stringprep = "0.1"
tokio = { workspace = true, features = ["rt"] }

[dev-dependencies]
tokio = { workspace = true, features = ["full"] }
37
libs/proxy/postgres-protocol2/src/authentication/mod.rs
Normal file
@@ -0,0 +1,37 @@
//! Authentication protocol support.
use md5::{Digest, Md5};

pub mod sasl;

/// Hashes authentication information in a way suitable for use in response
/// to an `AuthenticationMd5Password` message.
///
/// The resulting string should be sent back to the database in a
/// `PasswordMessage` message.
#[inline]
pub fn md5_hash(username: &[u8], password: &[u8], salt: [u8; 4]) -> String {
    let mut md5 = Md5::new();
    md5.update(password);
    md5.update(username);
    let output = md5.finalize_reset();
    md5.update(format!("{:x}", output));
    md5.update(salt);
    format!("md5{:x}", md5.finalize())
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn md5() {
        let username = b"md5_user";
        let password = b"password";
        let salt = [0x2a, 0x3d, 0x8f, 0xe0];

        assert_eq!(
            md5_hash(username, password, salt),
            "md562af4dd09bbb41884907a838a3233294"
        );
    }
}
516
libs/proxy/postgres-protocol2/src/authentication/sasl.rs
Normal file
@@ -0,0 +1,516 @@
//! SASL-based authentication support.

use hmac::{Hmac, Mac};
use rand::{self, Rng};
use sha2::digest::FixedOutput;
use sha2::{Digest, Sha256};
use std::fmt::Write;
use std::io;
use std::iter;
use std::mem;
use std::str;
use tokio::task::yield_now;

const NONCE_LENGTH: usize = 24;

/// The identifier of the SCRAM-SHA-256 SASL authentication mechanism.
pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
/// The identifier of the SCRAM-SHA-256-PLUS SASL authentication mechanism.
pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";

// since postgres passwords are not required to exclude saslprep-prohibited
// characters or even be valid UTF8, we run saslprep if possible and otherwise
// return the raw password.
fn normalize(pass: &[u8]) -> Vec<u8> {
    let pass = match str::from_utf8(pass) {
        Ok(pass) => pass,
        Err(_) => return pass.to_vec(),
    };

    match stringprep::saslprep(pass) {
        Ok(pass) => pass.into_owned().into_bytes(),
        Err(_) => pass.as_bytes().to_vec(),
    }
}

pub(crate) async fn hi(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
    let mut hmac =
        Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
    hmac.update(salt);
    hmac.update(&[0, 0, 0, 1]);
    let mut prev = hmac.finalize().into_bytes();

    let mut hi = prev;

    for i in 1..iterations {
        let mut hmac = Hmac::<Sha256>::new_from_slice(str).expect("already checked above");
        hmac.update(&prev);
        prev = hmac.finalize().into_bytes();

        for (hi, prev) in hi.iter_mut().zip(prev) {
            *hi ^= prev;
        }
        // yield every ~250us
        // hopefully reduces tail latencies
        if i % 1024 == 0 {
            yield_now().await
        }
    }

    hi.into()
}

enum ChannelBindingInner {
    Unrequested,
    Unsupported,
    TlsServerEndPoint(Vec<u8>),
}

/// The channel binding configuration for a SCRAM authentication exchange.
pub struct ChannelBinding(ChannelBindingInner);

impl ChannelBinding {
    /// The server did not request channel binding.
    pub fn unrequested() -> ChannelBinding {
        ChannelBinding(ChannelBindingInner::Unrequested)
    }

    /// The server requested channel binding but the client is unable to provide it.
    pub fn unsupported() -> ChannelBinding {
        ChannelBinding(ChannelBindingInner::Unsupported)
    }

    /// The server requested channel binding and the client will use the `tls-server-end-point`
    /// method.
    pub fn tls_server_end_point(signature: Vec<u8>) -> ChannelBinding {
        ChannelBinding(ChannelBindingInner::TlsServerEndPoint(signature))
    }

    fn gs2_header(&self) -> &'static str {
        match self.0 {
            ChannelBindingInner::Unrequested => "y,,",
            ChannelBindingInner::Unsupported => "n,,",
            ChannelBindingInner::TlsServerEndPoint(_) => "p=tls-server-end-point,,",
        }
    }

    fn cbind_data(&self) -> &[u8] {
        match self.0 {
            ChannelBindingInner::Unrequested | ChannelBindingInner::Unsupported => &[],
            ChannelBindingInner::TlsServerEndPoint(ref buf) => buf,
        }
    }
}

/// A pair of keys for the SCRAM-SHA-256 mechanism.
/// See <https://datatracker.ietf.org/doc/html/rfc5802#section-3> for details.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ScramKeys<const N: usize> {
    /// Used by server to authenticate client.
    pub client_key: [u8; N],
    /// Used by client to verify server's signature.
    pub server_key: [u8; N],
}

/// Password or keys which were derived from it.
enum Credentials<const N: usize> {
    /// A regular password as a vector of bytes.
    Password(Vec<u8>),
    /// A precomputed pair of keys.
    Keys(Box<ScramKeys<N>>),
}

enum State {
    Update {
        nonce: String,
        password: Credentials<32>,
        channel_binding: ChannelBinding,
    },
    Finish {
        server_key: [u8; 32],
        auth_message: String,
    },
    Done,
}

/// A type which handles the client side of the SCRAM-SHA-256/SCRAM-SHA-256-PLUS authentication
/// process.
///
/// During the authentication process, if the backend sends an `AuthenticationSASL` message which
/// includes `SCRAM-SHA-256` as an authentication mechanism, this type can be used.
///
/// After a `ScramSha256` is constructed, the buffer returned by the `message()` method should be
/// sent to the backend in a `SASLInitialResponse` message along with the mechanism name.
///
/// The server will reply with an `AuthenticationSASLContinue` message. Its contents should be
/// passed to the `update()` method, after which the buffer returned by the `message()` method
/// should be sent to the backend in a `SASLResponse` message.
///
/// The server will reply with an `AuthenticationSASLFinal` message. Its contents should be passed
/// to the `finish()` method, after which the authentication process is complete.
pub struct ScramSha256 {
    message: String,
    state: State,
}

fn nonce() -> String {
    // rand 0.5's ThreadRng is cryptographically secure
    let mut rng = rand::thread_rng();
    (0..NONCE_LENGTH)
        .map(|_| {
            let mut v = rng.gen_range(0x21u8..0x7e);
            if v == 0x2c {
                v = 0x7e
            }
            v as char
        })
        .collect()
}

impl ScramSha256 {
    /// Constructs a new instance which will use the provided password for authentication.
    pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 {
        let password = Credentials::Password(normalize(password));
        ScramSha256::new_inner(password, channel_binding, nonce())
    }

    /// Constructs a new instance which will use the provided key pair for authentication.
    pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 {
        let password = Credentials::Keys(keys.into());
        ScramSha256::new_inner(password, channel_binding, nonce())
    }

    fn new_inner(
        password: Credentials<32>,
        channel_binding: ChannelBinding,
        nonce: String,
    ) -> ScramSha256 {
        ScramSha256 {
            message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce),
            state: State::Update {
                nonce,
                password,
                channel_binding,
            },
        }
    }

    /// Returns the message which should be sent to the backend in an `SASLResponse` message.
    pub fn message(&self) -> &[u8] {
        if let State::Done = self.state {
            panic!("invalid SCRAM state");
        }
        self.message.as_bytes()
    }

    /// Updates the state machine with the response from the backend.
    ///
    /// This should be called when an `AuthenticationSASLContinue` message is received.
    pub async fn update(&mut self, message: &[u8]) -> io::Result<()> {
        let (client_nonce, password, channel_binding) =
            match mem::replace(&mut self.state, State::Done) {
                State::Update {
                    nonce,
                    password,
                    channel_binding,
                } => (nonce, password, channel_binding),
                _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
            };

        let message =
            str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;

        let parsed = Parser::new(message).server_first_message()?;

        if !parsed.nonce.starts_with(&client_nonce) {
            return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce"));
        }

        let (client_key, server_key) = match password {
            Credentials::Password(password) => {
                let salt = match base64::decode(parsed.salt) {
                    Ok(salt) => salt,
                    Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
                };

                let salted_password = hi(&password, &salt, parsed.iteration_count).await;

                let make_key = |name| {
                    let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
                        .expect("HMAC is able to accept all key sizes");
                    hmac.update(name);

                    let mut key = [0u8; 32];
                    key.copy_from_slice(hmac.finalize().into_bytes().as_slice());
                    key
                };

                (make_key(b"Client Key"), make_key(b"Server Key"))
            }
            Credentials::Keys(keys) => (keys.client_key, keys.server_key),
        };

        let mut hash = Sha256::default();
        hash.update(client_key);
        let stored_key = hash.finalize_fixed();

        let mut cbind_input = vec![];
        cbind_input.extend(channel_binding.gs2_header().as_bytes());
        cbind_input.extend(channel_binding.cbind_data());
        let cbind_input = base64::encode(&cbind_input);

        self.message.clear();
        write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap();

        let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message);

        let mut hmac = Hmac::<Sha256>::new_from_slice(&stored_key)
            .expect("HMAC is able to accept all key sizes");
        hmac.update(auth_message.as_bytes());
        let client_signature = hmac.finalize().into_bytes();

        let mut client_proof = client_key;
        for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
            *proof ^= signature;
        }

        write!(&mut self.message, ",p={}", base64::encode(client_proof)).unwrap();

        self.state = State::Finish {
            server_key,
            auth_message,
        };
        Ok(())
    }

    /// Finalizes the authentication process.
    ///
    /// This should be called when the backend sends an `AuthenticationSASLFinal` message.
    /// Authentication has only succeeded if this method returns `Ok(())`.
    pub fn finish(&mut self, message: &[u8]) -> io::Result<()> {
        let (server_key, auth_message) = match mem::replace(&mut self.state, State::Done) {
            State::Finish {
                server_key,
                auth_message,
            } => (server_key, auth_message),
            _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
        };

        let message =
            str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;

        let parsed = Parser::new(message).server_final_message()?;

        let verifier = match parsed {
            ServerFinalMessage::Error(e) => {
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    format!("SCRAM error: {}", e),
                ));
            }
            ServerFinalMessage::Verifier(verifier) => verifier,
        };

        let verifier = match base64::decode(verifier) {
            Ok(verifier) => verifier,
            Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
        };

        let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key)
            .expect("HMAC is able to accept all key sizes");
        hmac.update(auth_message.as_bytes());
        hmac.verify_slice(&verifier)
            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error"))
    }
}
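
(Illustrative, not part of the commit: a hedged sketch of the message flow described in the doc comment above, with the server's payloads taken as arguments; real callers obtain them from AuthenticationSASLContinue/AuthenticationSASLFinal backend messages.)

async fn scram_exchange(server_first: &[u8], server_final: &[u8]) -> std::io::Result<()> {
    let mut scram = ScramSha256::new(b"password", ChannelBinding::unsupported());
    let _client_first = scram.message(); // send in SASLInitialResponse with mechanism SCRAM_SHA_256
    scram.update(server_first).await?; // feed the AuthenticationSASLContinue payload
    let _client_final = scram.message(); // send in a SASLResponse message
    scram.finish(server_final) // verify the AuthenticationSASLFinal payload
}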

struct Parser<'a> {
    s: &'a str,
    it: iter::Peekable<str::CharIndices<'a>>,
}

impl<'a> Parser<'a> {
    fn new(s: &'a str) -> Parser<'a> {
        Parser {
            s,
            it: s.char_indices().peekable(),
        }
    }

    fn eat(&mut self, target: char) -> io::Result<()> {
        match self.it.next() {
            Some((_, c)) if c == target => Ok(()),
            Some((i, c)) => {
                let m = format!(
                    "unexpected character at byte {}: expected `{}` but got `{}",
                    i, target, c
                );
                Err(io::Error::new(io::ErrorKind::InvalidInput, m))
            }
            None => Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "unexpected EOF",
            )),
        }
    }

    fn take_while<F>(&mut self, f: F) -> io::Result<&'a str>
    where
        F: Fn(char) -> bool,
    {
        let start = match self.it.peek() {
            Some(&(i, _)) => i,
            None => return Ok(""),
        };

        loop {
            match self.it.peek() {
                Some(&(_, c)) if f(c) => {
                    self.it.next();
                }
                Some(&(i, _)) => return Ok(&self.s[start..i]),
                None => return Ok(&self.s[start..]),
            }
        }
    }

    fn printable(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| matches!(c, '\x21'..='\x2b' | '\x2d'..='\x7e'))
    }

    fn nonce(&mut self) -> io::Result<&'a str> {
        self.eat('r')?;
        self.eat('=')?;
        self.printable()
    }

    fn base64(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '/' | '+' | '='))
    }

    fn salt(&mut self) -> io::Result<&'a str> {
        self.eat('s')?;
        self.eat('=')?;
        self.base64()
    }

    fn posit_number(&mut self) -> io::Result<u32> {
        let n = self.take_while(|c| c.is_ascii_digit())?;
        n.parse()
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
    }

    fn iteration_count(&mut self) -> io::Result<u32> {
        self.eat('i')?;
        self.eat('=')?;
        self.posit_number()
    }

    fn eof(&mut self) -> io::Result<()> {
        match self.it.peek() {
            Some(&(i, _)) => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("unexpected trailing data at byte {}", i),
            )),
            None => Ok(()),
        }
    }

    fn server_first_message(&mut self) -> io::Result<ServerFirstMessage<'a>> {
        let nonce = self.nonce()?;
        self.eat(',')?;
        let salt = self.salt()?;
        self.eat(',')?;
        let iteration_count = self.iteration_count()?;
        self.eof()?;

        Ok(ServerFirstMessage {
            nonce,
            salt,
            iteration_count,
        })
    }

    fn value(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| matches!(c, '\0' | '=' | ','))
    }

    fn server_error(&mut self) -> io::Result<Option<&'a str>> {
        match self.it.peek() {
            Some(&(_, 'e')) => {}
            _ => return Ok(None),
        }

        self.eat('e')?;
        self.eat('=')?;
        self.value().map(Some)
    }

    fn verifier(&mut self) -> io::Result<&'a str> {
        self.eat('v')?;
        self.eat('=')?;
        self.base64()
    }

    fn server_final_message(&mut self) -> io::Result<ServerFinalMessage<'a>> {
        let message = match self.server_error()? {
            Some(error) => ServerFinalMessage::Error(error),
            None => ServerFinalMessage::Verifier(self.verifier()?),
        };
        self.eof()?;
        Ok(message)
    }
}

struct ServerFirstMessage<'a> {
    nonce: &'a str,
    salt: &'a str,
    iteration_count: u32,
}

enum ServerFinalMessage<'a> {
    Error(&'a str),
    Verifier(&'a str),
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn parse_server_first_message() {
        let message = "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096";
        let message = Parser::new(message).server_first_message().unwrap();
        assert_eq!(message.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
        assert_eq!(message.salt, "QSXCR+Q6sek8bf92");
        assert_eq!(message.iteration_count, 4096);
    }

    // recorded auth exchange from psql
    #[tokio::test]
    async fn exchange() {
        let password = "foobar";
        let nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB";

        let client_first = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB";
        let server_first =
            "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i\
             =4096";
        let client_final =
            "c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS3\
             1NTlQYNs5BTeQjdHdk7lOflDo5re2an8=";
        let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";

        let mut scram = ScramSha256::new_inner(
            Credentials::Password(normalize(password.as_bytes())),
            ChannelBinding::unsupported(),
            nonce.to_string(),
        );
        assert_eq!(str::from_utf8(scram.message()).unwrap(), client_first);

        scram.update(server_first.as_bytes()).await.unwrap();
        assert_eq!(str::from_utf8(scram.message()).unwrap(), client_final);

        scram.finish(server_final.as_bytes()).unwrap();
    }
}
93
libs/proxy/postgres-protocol2/src/escape/mod.rs
Normal file
@@ -0,0 +1,93 @@
//! Provides functions for escaping literals and identifiers for use
//! in SQL queries.
//!
//! Prefer parameterized queries where possible. Do not escape
//! parameters in a parameterized query.

#[cfg(test)]
mod test;

/// Escape a literal and surround result with single quotes. Not
/// recommended in most cases.
///
/// If input contains backslashes, result will be of the form `
/// E'...'` so it is safe to use regardless of the setting of
/// standard_conforming_strings.
pub fn escape_literal(input: &str) -> String {
    escape_internal(input, false)
}

/// Escape an identifier and surround result with double quotes.
pub fn escape_identifier(input: &str) -> String {
    escape_internal(input, true)
}

// Translation of PostgreSQL libpq's PQescapeInternal(). Does not
// require a connection because input string is known to be valid
// UTF-8.
//
// Escape arbitrary strings. If as_ident is true, we escape the
// result as an identifier; if false, as a literal. The result is
// returned in a newly allocated buffer. If we fail due to an
// encoding violation or out of memory condition, we return NULL,
// storing an error message into conn.
fn escape_internal(input: &str, as_ident: bool) -> String {
    let mut num_backslashes = 0;
    let mut num_quotes = 0;
    let quote_char = if as_ident { '"' } else { '\'' };

    // Scan the string for characters that must be escaped.
    for ch in input.chars() {
        if ch == quote_char {
            num_quotes += 1;
        } else if ch == '\\' {
            num_backslashes += 1;
        }
    }

    // Allocate output String.
    let mut result_size = input.len() + num_quotes + 3; // two quotes, plus a NUL
    if !as_ident && num_backslashes > 0 {
        result_size += num_backslashes + 2;
    }

    let mut output = String::with_capacity(result_size);

    // If we are escaping a literal that contains backslashes, we use
    // the escape string syntax so that the result is correct under
    // either value of standard_conforming_strings. We also emit a
    // leading space in this case, to guard against the possibility
    // that the result might be interpolated immediately following an
    // identifier.
    if !as_ident && num_backslashes > 0 {
        output.push(' ');
        output.push('E');
    }

    // Opening quote.
    output.push(quote_char);

    // Use fast path if possible.
    //
    // We've already verified that the input string is well-formed in
    // the current encoding. If it contains no quotes and, in the
    // case of literal-escaping, no backslashes, then we can just copy
    // it directly to the output buffer, adding the necessary quotes.
    //
    // If not, we must rescan the input and process each character
    // individually.
    if num_quotes == 0 && (num_backslashes == 0 || as_ident) {
        output.push_str(input);
    } else {
        for ch in input.chars() {
            if ch == quote_char || (!as_ident && ch == '\\') {
                output.push(ch);
            }
            output.push(ch);
        }
    }

    output.push(quote_char);

    output
}
17
libs/proxy/postgres-protocol2/src/escape/test.rs
Normal file
@@ -0,0 +1,17 @@
use crate::escape::{escape_identifier, escape_literal};

#[test]
fn test_escape_identifier() {
    assert_eq!(escape_identifier("foo"), String::from("\"foo\""));
    assert_eq!(escape_identifier("f\\oo"), String::from("\"f\\oo\""));
    assert_eq!(escape_identifier("f'oo"), String::from("\"f'oo\""));
    assert_eq!(escape_identifier("f\"oo"), String::from("\"f\"\"oo\""));
}

#[test]
fn test_escape_literal() {
    assert_eq!(escape_literal("foo"), String::from("'foo'"));
    assert_eq!(escape_literal("f\\oo"), String::from(" E'f\\\\oo'"));
    assert_eq!(escape_literal("f'oo"), String::from("'f''oo'"));
    assert_eq!(escape_literal("f\"oo"), String::from("'f\"oo'"));
}
78
libs/proxy/postgres-protocol2/src/lib.rs
Normal file
@@ -0,0 +1,78 @@
//! Low level Postgres protocol APIs.
//!
//! This crate implements the low level components of Postgres's communication
//! protocol, including message and value serialization and deserialization.
//! It is designed to be used as a building block by higher level APIs such as
//! `rust-postgres`, and should not typically be used directly.
//!
//! # Note
//!
//! This library assumes that the `client_encoding` backend parameter has been
//! set to `UTF8`. It will most likely not behave properly if that is not the case.
#![doc(html_root_url = "https://docs.rs/postgres-protocol/0.6")]
#![warn(missing_docs, rust_2018_idioms, clippy::all)]

use byteorder::{BigEndian, ByteOrder};
use bytes::{BufMut, BytesMut};
use std::io;

pub mod authentication;
pub mod escape;
pub mod message;
pub mod password;
pub mod types;

/// A Postgres OID.
pub type Oid = u32;

/// A Postgres Log Sequence Number (LSN).
pub type Lsn = u64;

/// An enum indicating if a value is `NULL` or not.
pub enum IsNull {
    /// The value is `NULL`.
    Yes,
    /// The value is not `NULL`.
    No,
}

fn write_nullable<F, E>(serializer: F, buf: &mut BytesMut) -> Result<(), E>
where
    F: FnOnce(&mut BytesMut) -> Result<IsNull, E>,
    E: From<io::Error>,
{
    let base = buf.len();
    buf.put_i32(0);
    let size = match serializer(buf)? {
        IsNull::No => i32::from_usize(buf.len() - base - 4)?,
        IsNull::Yes => -1,
    };
    BigEndian::write_i32(&mut buf[base..], size);

    Ok(())
}

trait FromUsize: Sized {
    fn from_usize(x: usize) -> Result<Self, io::Error>;
}

macro_rules! from_usize {
    ($t:ty) => {
        impl FromUsize for $t {
            #[inline]
            fn from_usize(x: usize) -> io::Result<$t> {
                if x > <$t>::MAX as usize {
                    Err(io::Error::new(
                        io::ErrorKind::InvalidInput,
                        "value too large to transmit",
                    ))
                } else {
                    Ok(x as $t)
                }
            }
        }
    };
}

from_usize!(i16);
from_usize!(i32);
766
libs/proxy/postgres-protocol2/src/message/backend.rs
Normal file
@@ -0,0 +1,766 @@
#![allow(missing_docs)]

use byteorder::{BigEndian, ByteOrder, ReadBytesExt};
use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use memchr::memchr;
use std::cmp;
use std::io::{self, Read};
use std::ops::Range;
use std::str;

use crate::Oid;

// top-level message tags
const PARSE_COMPLETE_TAG: u8 = b'1';
const BIND_COMPLETE_TAG: u8 = b'2';
const CLOSE_COMPLETE_TAG: u8 = b'3';
pub const NOTIFICATION_RESPONSE_TAG: u8 = b'A';
const COPY_DONE_TAG: u8 = b'c';
const COMMAND_COMPLETE_TAG: u8 = b'C';
const COPY_DATA_TAG: u8 = b'd';
const DATA_ROW_TAG: u8 = b'D';
const ERROR_RESPONSE_TAG: u8 = b'E';
const COPY_IN_RESPONSE_TAG: u8 = b'G';
const COPY_OUT_RESPONSE_TAG: u8 = b'H';
const COPY_BOTH_RESPONSE_TAG: u8 = b'W';
const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I';
const BACKEND_KEY_DATA_TAG: u8 = b'K';
pub const NO_DATA_TAG: u8 = b'n';
pub const NOTICE_RESPONSE_TAG: u8 = b'N';
const AUTHENTICATION_TAG: u8 = b'R';
const PORTAL_SUSPENDED_TAG: u8 = b's';
pub const PARAMETER_STATUS_TAG: u8 = b'S';
const PARAMETER_DESCRIPTION_TAG: u8 = b't';
const ROW_DESCRIPTION_TAG: u8 = b'T';
pub const READY_FOR_QUERY_TAG: u8 = b'Z';

#[derive(Debug, Copy, Clone)]
pub struct Header {
    tag: u8,
    len: i32,
}

#[allow(clippy::len_without_is_empty)]
impl Header {
    #[inline]
    pub fn parse(buf: &[u8]) -> io::Result<Option<Header>> {
        if buf.len() < 5 {
            return Ok(None);
        }

        let tag = buf[0];
        let len = BigEndian::read_i32(&buf[1..]);

        if len < 4 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "invalid message length: header length < 4",
            ));
        }

        Ok(Some(Header { tag, len }))
    }

    #[inline]
    pub fn tag(self) -> u8 {
        self.tag
    }

    #[inline]
    pub fn len(self) -> i32 {
        self.len
    }
}

/// An enum representing Postgres backend messages.
#[non_exhaustive]
pub enum Message {
    AuthenticationCleartextPassword,
    AuthenticationGss,
    AuthenticationKerberosV5,
    AuthenticationMd5Password(AuthenticationMd5PasswordBody),
    AuthenticationOk,
    AuthenticationScmCredential,
    AuthenticationSspi,
    AuthenticationGssContinue,
    AuthenticationSasl(AuthenticationSaslBody),
    AuthenticationSaslContinue(AuthenticationSaslContinueBody),
    AuthenticationSaslFinal(AuthenticationSaslFinalBody),
    BackendKeyData(BackendKeyDataBody),
    BindComplete,
    CloseComplete,
    CommandComplete(CommandCompleteBody),
    CopyData,
    CopyDone,
    CopyInResponse,
    CopyOutResponse,
    CopyBothResponse,
    DataRow(DataRowBody),
    EmptyQueryResponse,
    ErrorResponse(ErrorResponseBody),
    NoData,
    NoticeResponse(NoticeResponseBody),
    NotificationResponse(NotificationResponseBody),
    ParameterDescription(ParameterDescriptionBody),
    ParameterStatus(ParameterStatusBody),
    ParseComplete,
    PortalSuspended,
    ReadyForQuery(ReadyForQueryBody),
    RowDescription(RowDescriptionBody),
}

impl Message {
    #[inline]
    pub fn parse(buf: &mut BytesMut) -> io::Result<Option<Message>> {
        if buf.len() < 5 {
            let to_read = 5 - buf.len();
            buf.reserve(to_read);
            return Ok(None);
        }

        let tag = buf[0];
        let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();

        if len < 4 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "invalid message length: parsing u32",
            ));
        }

        let total_len = len as usize + 1;
        if buf.len() < total_len {
            let to_read = total_len - buf.len();
            buf.reserve(to_read);
            return Ok(None);
        }

        let mut buf = Buffer {
            bytes: buf.split_to(total_len).freeze(),
            idx: 5,
        };

        let message = match tag {
            PARSE_COMPLETE_TAG => Message::ParseComplete,
            BIND_COMPLETE_TAG => Message::BindComplete,
            CLOSE_COMPLETE_TAG => Message::CloseComplete,
            NOTIFICATION_RESPONSE_TAG => {
                let process_id = buf.read_i32::<BigEndian>()?;
                let channel = buf.read_cstr()?;
                let message = buf.read_cstr()?;
                Message::NotificationResponse(NotificationResponseBody {
                    process_id,
                    channel,
                    message,
                })
            }
            COPY_DONE_TAG => Message::CopyDone,
            COMMAND_COMPLETE_TAG => {
                let tag = buf.read_cstr()?;
                Message::CommandComplete(CommandCompleteBody { tag })
            }
            COPY_DATA_TAG => Message::CopyData,
            DATA_ROW_TAG => {
                let len = buf.read_u16::<BigEndian>()?;
                let storage = buf.read_all();
                Message::DataRow(DataRowBody { storage, len })
            }
            ERROR_RESPONSE_TAG => {
                let storage = buf.read_all();
                Message::ErrorResponse(ErrorResponseBody { storage })
            }
            COPY_IN_RESPONSE_TAG => Message::CopyInResponse,
            COPY_OUT_RESPONSE_TAG => Message::CopyOutResponse,
            COPY_BOTH_RESPONSE_TAG => Message::CopyBothResponse,
            EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse,
            BACKEND_KEY_DATA_TAG => {
                let process_id = buf.read_i32::<BigEndian>()?;
                let secret_key = buf.read_i32::<BigEndian>()?;
                Message::BackendKeyData(BackendKeyDataBody {
                    process_id,
                    secret_key,
                })
            }
            NO_DATA_TAG => Message::NoData,
            NOTICE_RESPONSE_TAG => {
                let storage = buf.read_all();
                Message::NoticeResponse(NoticeResponseBody { storage })
            }
            AUTHENTICATION_TAG => match buf.read_i32::<BigEndian>()? {
                0 => Message::AuthenticationOk,
                2 => Message::AuthenticationKerberosV5,
                3 => Message::AuthenticationCleartextPassword,
                5 => {
                    let mut salt = [0; 4];
                    buf.read_exact(&mut salt)?;
                    Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody { salt })
                }
                6 => Message::AuthenticationScmCredential,
                7 => Message::AuthenticationGss,
                8 => Message::AuthenticationGssContinue,
                9 => Message::AuthenticationSspi,
                10 => {
                    let storage = buf.read_all();
                    Message::AuthenticationSasl(AuthenticationSaslBody(storage))
                }
                11 => {
                    let storage = buf.read_all();
                    Message::AuthenticationSaslContinue(AuthenticationSaslContinueBody(storage))
                }
                12 => {
                    let storage = buf.read_all();
                    Message::AuthenticationSaslFinal(AuthenticationSaslFinalBody(storage))
                }
                tag => {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidInput,
                        format!("unknown authentication tag `{}`", tag),
                    ));
                }
            },
            PORTAL_SUSPENDED_TAG => Message::PortalSuspended,
            PARAMETER_STATUS_TAG => {
                let name = buf.read_cstr()?;
                let value = buf.read_cstr()?;
                Message::ParameterStatus(ParameterStatusBody { name, value })
            }
            PARAMETER_DESCRIPTION_TAG => {
                let len = buf.read_u16::<BigEndian>()?;
                let storage = buf.read_all();
                Message::ParameterDescription(ParameterDescriptionBody { storage, len })
            }
            ROW_DESCRIPTION_TAG => {
                let len = buf.read_u16::<BigEndian>()?;
                let storage = buf.read_all();
                Message::RowDescription(RowDescriptionBody { storage, len })
            }
            READY_FOR_QUERY_TAG => {
                let status = buf.read_u8()?;
                Message::ReadyForQuery(ReadyForQueryBody { status })
            }
            tag => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("unknown message tag `{}`", tag),
                ));
            }
        };

        if !buf.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "invalid message length: expected buffer to be empty",
            ));
        }

        Ok(Some(message))
    }
}
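
(Illustrative, not part of the commit: a hedged sketch of the intended calling pattern. Message::parse returns Ok(None) when the buffer does not yet hold a complete message, after reserving the space it still needs, so a caller keeps reading until a full message is buffered.)

fn next_message(stream: &mut impl io::Read, buf: &mut BytesMut) -> io::Result<Message> {
    let mut chunk = [0u8; 4096];
    loop {
        if let Some(msg) = Message::parse(buf)? {
            return Ok(msg); // parse() split the message's bytes off the front of buf
        }
        // not enough buffered yet; append more bytes from the stream
        let n = stream.read(&mut chunk)?;
        if n == 0 {
            return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "unexpected EOF"));
        }
        buf.extend_from_slice(&chunk[..n]);
    }
}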

struct Buffer {
    bytes: Bytes,
    idx: usize,
}

impl Buffer {
    #[inline]
    fn slice(&self) -> &[u8] {
        &self.bytes[self.idx..]
    }

    #[inline]
    fn is_empty(&self) -> bool {
        self.slice().is_empty()
    }

    #[inline]
    fn read_cstr(&mut self) -> io::Result<Bytes> {
        match memchr(0, self.slice()) {
            Some(pos) => {
                let start = self.idx;
                let end = start + pos;
                let cstr = self.bytes.slice(start..end);
                self.idx = end + 1;
                Ok(cstr)
            }
            None => Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "unexpected EOF",
            )),
        }
    }

    #[inline]
    fn read_all(&mut self) -> Bytes {
        let buf = self.bytes.slice(self.idx..);
        self.idx = self.bytes.len();
        buf
    }
}

impl Read for Buffer {
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let len = {
            let slice = self.slice();
            let len = cmp::min(slice.len(), buf.len());
            buf[..len].copy_from_slice(&slice[..len]);
            len
        };
        self.idx += len;
        Ok(len)
    }
}

pub struct AuthenticationMd5PasswordBody {
    salt: [u8; 4],
}

impl AuthenticationMd5PasswordBody {
    #[inline]
    pub fn salt(&self) -> [u8; 4] {
        self.salt
    }
}

pub struct AuthenticationSaslBody(Bytes);

impl AuthenticationSaslBody {
    #[inline]
    pub fn mechanisms(&self) -> SaslMechanisms<'_> {
        SaslMechanisms(&self.0)
    }
}

pub struct SaslMechanisms<'a>(&'a [u8]);

impl<'a> FallibleIterator for SaslMechanisms<'a> {
    type Item = &'a str;
    type Error = io::Error;

    #[inline]
    fn next(&mut self) -> io::Result<Option<&'a str>> {
        let value_end = find_null(self.0, 0)?;
        if value_end == 0 {
            if self.0.len() != 1 {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "invalid message length: expected to be at end of iterator for sasl",
                ));
            }
            Ok(None)
        } else {
            let value = get_str(&self.0[..value_end])?;
            self.0 = &self.0[value_end + 1..];
            Ok(Some(value))
        }
    }
}

pub struct AuthenticationSaslContinueBody(Bytes);

impl AuthenticationSaslContinueBody {
    #[inline]
    pub fn data(&self) -> &[u8] {
        &self.0
    }
}

pub struct AuthenticationSaslFinalBody(Bytes);

impl AuthenticationSaslFinalBody {
    #[inline]
    pub fn data(&self) -> &[u8] {
        &self.0
    }
}

pub struct BackendKeyDataBody {
    process_id: i32,
    secret_key: i32,
}

impl BackendKeyDataBody {
    #[inline]
    pub fn process_id(&self) -> i32 {
        self.process_id
    }

    #[inline]
    pub fn secret_key(&self) -> i32 {
        self.secret_key
    }
}

pub struct CommandCompleteBody {
    tag: Bytes,
}

impl CommandCompleteBody {
    #[inline]
    pub fn tag(&self) -> io::Result<&str> {
        get_str(&self.tag)
    }
}

#[derive(Debug)]
pub struct DataRowBody {
    storage: Bytes,
    len: u16,
}

impl DataRowBody {
    #[inline]
    pub fn ranges(&self) -> DataRowRanges<'_> {
        DataRowRanges {
            buf: &self.storage,
            len: self.storage.len(),
            remaining: self.len,
        }
    }

    #[inline]
    pub fn buffer(&self) -> &[u8] {
        &self.storage
    }
}

pub struct DataRowRanges<'a> {
    buf: &'a [u8],
    len: usize,
    remaining: u16,
}

impl FallibleIterator for DataRowRanges<'_> {
    type Item = Option<Range<usize>>;
    type Error = io::Error;

    #[inline]
    fn next(&mut self) -> io::Result<Option<Option<Range<usize>>>> {
        if self.remaining == 0 {
            if self.buf.is_empty() {
                return Ok(None);
            } else {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "invalid message length: datarowrange is not empty",
                ));
            }
        }

        self.remaining -= 1;
        let len = self.buf.read_i32::<BigEndian>()?;
        if len < 0 {
            Ok(Some(None))
        } else {
            let len = len as usize;
            if self.buf.len() < len {
                return Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "unexpected EOF",
                ));
            }
            let base = self.len - self.buf.len();
            self.buf = &self.buf[len..];
            Ok(Some(Some(base..base + len)))
        }
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self.remaining as usize;
        (len, Some(len))
    }
}
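
(Illustrative, not part of the commit: a sketch of consuming a DataRowBody, grounded in the API above; each item is Some(range) indexing into buffer(), or None for a SQL NULL, and next() comes from the fallible_iterator trait already imported in this file.)

fn first_column(row: &DataRowBody) -> io::Result<Option<&[u8]>> {
    let mut ranges = row.ranges();
    match ranges.next()? {
        Some(Some(range)) => Ok(Some(&row.buffer()[range])), // raw column bytes
        Some(None) => Ok(None),                              // column is NULL
        None => Ok(None),                                    // row has no columns
    }
}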

pub struct ErrorResponseBody {
    storage: Bytes,
}

impl ErrorResponseBody {
    #[inline]
    pub fn fields(&self) -> ErrorFields<'_> {
        ErrorFields { buf: &self.storage }
    }
}

pub struct ErrorFields<'a> {
    buf: &'a [u8],
}

impl<'a> FallibleIterator for ErrorFields<'a> {
    type Item = ErrorField<'a>;
    type Error = io::Error;

    #[inline]
    fn next(&mut self) -> io::Result<Option<ErrorField<'a>>> {
        let type_ = self.buf.read_u8()?;
        if type_ == 0 {
            if self.buf.is_empty() {
                return Ok(None);
            } else {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "invalid message length: error fields is not drained",
                ));
            }
        }

        let value_end = find_null(self.buf, 0)?;
        let value = get_str(&self.buf[..value_end])?;
        self.buf = &self.buf[value_end + 1..];

        Ok(Some(ErrorField { type_, value }))
    }
}

pub struct ErrorField<'a> {
    type_: u8,
    value: &'a str,
}

impl ErrorField<'_> {
    #[inline]
    pub fn type_(&self) -> u8 {
        self.type_
    }

    #[inline]
    pub fn value(&self) -> &str {
        self.value
    }
}

pub struct NoticeResponseBody {
    storage: Bytes,
}

impl NoticeResponseBody {
    #[inline]
    pub fn fields(&self) -> ErrorFields<'_> {
        ErrorFields { buf: &self.storage }
    }
}

pub struct NotificationResponseBody {
    process_id: i32,
    channel: Bytes,
    message: Bytes,
}

impl NotificationResponseBody {
    #[inline]
    pub fn process_id(&self) -> i32 {
        self.process_id
    }

    #[inline]
    pub fn channel(&self) -> io::Result<&str> {
        get_str(&self.channel)
    }

    #[inline]
    pub fn message(&self) -> io::Result<&str> {
        get_str(&self.message)
    }
}

pub struct ParameterDescriptionBody {
    storage: Bytes,
    len: u16,
}

impl ParameterDescriptionBody {
    #[inline]
    pub fn parameters(&self) -> Parameters<'_> {
        Parameters {
            buf: &self.storage,
            remaining: self.len,
        }
    }
}

pub struct Parameters<'a> {
    buf: &'a [u8],
    remaining: u16,
}

impl FallibleIterator for Parameters<'_> {
    type Item = Oid;
    type Error = io::Error;

    #[inline]
    fn next(&mut self) -> io::Result<Option<Oid>> {
        if self.remaining == 0 {
            if self.buf.is_empty() {
                return Ok(None);
            } else {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "invalid message length: parameters is not drained",
                ));
            }
        }

        self.remaining -= 1;
        self.buf.read_u32::<BigEndian>().map(Some)
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self.remaining as usize;
        (len, Some(len))
    }
}

pub struct ParameterStatusBody {
    name: Bytes,
    value: Bytes,
}

impl ParameterStatusBody {
    #[inline]
    pub fn name(&self) -> io::Result<&str> {
        get_str(&self.name)
    }

    #[inline]
    pub fn value(&self) -> io::Result<&str> {
        get_str(&self.value)
    }
}

pub struct ReadyForQueryBody {
    status: u8,
}

impl ReadyForQueryBody {
    #[inline]
    pub fn status(&self) -> u8 {
        self.status
    }
}

pub struct RowDescriptionBody {
    storage: Bytes,
    len: u16,
}

impl RowDescriptionBody {
    #[inline]
    pub fn fields(&self) -> Fields<'_> {
        Fields {
            buf: &self.storage,
            remaining: self.len,
        }
    }
}

pub struct Fields<'a> {
    buf: &'a [u8],
    remaining: u16,
}

impl<'a> FallibleIterator for Fields<'a> {
    type Item = Field<'a>;
    type Error = io::Error;

    #[inline]
    fn next(&mut self) -> io::Result<Option<Field<'a>>> {
        if self.remaining == 0 {
            if self.buf.is_empty() {
                return Ok(None);
            } else {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "invalid message length: field is not drained",
                ));
            }
        }

        self.remaining -= 1;
        let name_end = find_null(self.buf, 0)?;
        let name = get_str(&self.buf[..name_end])?;
        self.buf = &self.buf[name_end + 1..];
        let table_oid = self.buf.read_u32::<BigEndian>()?;
        let column_id = self.buf.read_i16::<BigEndian>()?;
        let type_oid = self.buf.read_u32::<BigEndian>()?;
        let type_size = self.buf.read_i16::<BigEndian>()?;
        let type_modifier = self.buf.read_i32::<BigEndian>()?;
        let format = self.buf.read_i16::<BigEndian>()?;

        Ok(Some(Field {
            name,
            table_oid,
            column_id,
            type_oid,
            type_size,
            type_modifier,
            format,
        }))
    }
}

pub struct Field<'a> {
    name: &'a str,
    table_oid: Oid,
    column_id: i16,
    type_oid: Oid,
    type_size: i16,
    type_modifier: i32,
    format: i16,
}

impl<'a> Field<'a> {
    #[inline]
    pub fn name(&self) -> &'a str {
        self.name
    }

    #[inline]
    pub fn table_oid(&self) -> Oid {
        self.table_oid
    }

    #[inline]
    pub fn column_id(&self) -> i16 {
        self.column_id
    }

    #[inline]
    pub fn type_oid(&self) -> Oid {
        self.type_oid
    }

    #[inline]
    pub fn type_size(&self) -> i16 {
        self.type_size
    }

    #[inline]
    pub fn type_modifier(&self) -> i32 {
        self.type_modifier
    }

    #[inline]
    pub fn format(&self) -> i16 {
        self.format
    }
}

#[inline]
fn find_null(buf: &[u8], start: usize) -> io::Result<usize> {
    match memchr(0, &buf[start..]) {
        Some(pos) => Ok(pos + start),
        None => Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "unexpected EOF",
        )),
    }
}

#[inline]
fn get_str(buf: &[u8]) -> io::Result<&str> {
    str::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
}
297
libs/proxy/postgres-protocol2/src/message/frontend.rs
Normal file
@@ -0,0 +1,297 @@
//! Frontend message serialization.
#![allow(missing_docs)]

use byteorder::{BigEndian, ByteOrder};
use bytes::{Buf, BufMut, BytesMut};
use std::convert::TryFrom;
use std::error::Error;
use std::io;
use std::marker;

use crate::{write_nullable, FromUsize, IsNull, Oid};

#[inline]
fn write_body<F, E>(buf: &mut BytesMut, f: F) -> Result<(), E>
where
    F: FnOnce(&mut BytesMut) -> Result<(), E>,
    E: From<io::Error>,
{
    let base = buf.len();
    buf.extend_from_slice(&[0; 4]);

    f(buf)?;

    let size = i32::from_usize(buf.len() - base)?;
    BigEndian::write_i32(&mut buf[base..], size);
    Ok(())
}

pub enum BindError {
    Conversion(Box<dyn Error + marker::Sync + Send>),
    Serialization(io::Error),
}

impl From<Box<dyn Error + marker::Sync + Send>> for BindError {
    #[inline]
    fn from(e: Box<dyn Error + marker::Sync + Send>) -> BindError {
        BindError::Conversion(e)
    }
}

impl From<io::Error> for BindError {
    #[inline]
    fn from(e: io::Error) -> BindError {
        BindError::Serialization(e)
    }
}

#[inline]
pub fn bind<I, J, F, T, K>(
    portal: &str,
    statement: &str,
    formats: I,
    values: J,
    mut serializer: F,
    result_formats: K,
    buf: &mut BytesMut,
) -> Result<(), BindError>
where
    I: IntoIterator<Item = i16>,
    J: IntoIterator<Item = T>,
    F: FnMut(T, &mut BytesMut) -> Result<IsNull, Box<dyn Error + marker::Sync + Send>>,
    K: IntoIterator<Item = i16>,
{
    buf.put_u8(b'B');

    write_body(buf, |buf| {
        write_cstr(portal.as_bytes(), buf)?;
        write_cstr(statement.as_bytes(), buf)?;
        write_counted(
            formats,
            |f, buf| {
                buf.put_i16(f);
                Ok::<_, io::Error>(())
            },
            buf,
        )?;
        write_counted(
            values,
            |v, buf| write_nullable(|buf| serializer(v, buf), buf),
            buf,
        )?;
        write_counted(
            result_formats,
            |f, buf| {
                buf.put_i16(f);
                Ok::<_, io::Error>(())
            },
            buf,
        )?;

        Ok(())
    })
}

#[inline]
fn write_counted<I, T, F, E>(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), E>
where
    I: IntoIterator<Item = T>,
    F: FnMut(T, &mut BytesMut) -> Result<(), E>,
    E: From<io::Error>,
{
    let base = buf.len();
    buf.extend_from_slice(&[0; 2]);
    let mut count = 0;
    for item in items {
        serializer(item, buf)?;
        count += 1;
    }
    let count = i16::from_usize(count)?;
    BigEndian::write_i16(&mut buf[base..], count);

    Ok(())
}

#[inline]
pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut BytesMut) {
    write_body(buf, |buf| {
        buf.put_i32(80_877_102);
        buf.put_i32(process_id);
        buf.put_i32(secret_key);
        Ok::<_, io::Error>(())
    })
    .unwrap();
}

#[inline]
pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
    buf.put_u8(b'C');
    write_body(buf, |buf| {
        buf.put_u8(variant);
        write_cstr(name.as_bytes(), buf)
    })
}

pub struct CopyData<T> {
    buf: T,
    len: i32,
}

impl<T> CopyData<T>
where
    T: Buf,
{
    pub fn new(buf: T) -> io::Result<CopyData<T>> {
        let len = buf
            .remaining()
            .checked_add(4)
            .and_then(|l| i32::try_from(l).ok())
            .ok_or_else(|| {
                io::Error::new(io::ErrorKind::InvalidInput, "message length overflow")
            })?;

        Ok(CopyData { buf, len })
    }

    pub fn write(self, out: &mut BytesMut) {
        out.put_u8(b'd');
        out.put_i32(self.len);
        out.put(self.buf);
    }
}
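
(Illustrative, not part of the commit: a minimal sketch of framing a CopyData message; CopyData::new precomputes the length prefix, which counts itself plus the payload.)

fn frame_copy_data(payload: &[u8], out: &mut BytesMut) -> io::Result<()> {
    // out gains: b'd', a big-endian i32 of (4 + payload length), then the payload
    CopyData::new(payload)?.write(out);
    Ok(())
}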
|
||||
|
||||
#[inline]
|
||||
pub fn copy_done(buf: &mut BytesMut) {
|
||||
buf.put_u8(b'c');
|
||||
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn copy_fail(message: &str, buf: &mut BytesMut) -> io::Result<()> {
|
||||
buf.put_u8(b'f');
|
||||
write_body(buf, |buf| write_cstr(message.as_bytes(), buf))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn describe(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
|
||||
buf.put_u8(b'D');
|
||||
write_body(buf, |buf| {
|
||||
buf.put_u8(variant);
|
||||
write_cstr(name.as_bytes(), buf)
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn execute(portal: &str, max_rows: i32, buf: &mut BytesMut) -> io::Result<()> {
|
||||
buf.put_u8(b'E');
|
||||
write_body(buf, |buf| {
|
||||
write_cstr(portal.as_bytes(), buf)?;
|
||||
buf.put_i32(max_rows);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn parse<I>(name: &str, query: &str, param_types: I, buf: &mut BytesMut) -> io::Result<()>
|
||||
where
|
||||
I: IntoIterator<Item = Oid>,
|
||||
{
|
||||
buf.put_u8(b'P');
|
||||
write_body(buf, |buf| {
|
||||
write_cstr(name.as_bytes(), buf)?;
|
||||
write_cstr(query.as_bytes(), buf)?;
|
||||
write_counted(
|
||||
param_types,
|
||||
|t, buf| {
|
||||
buf.put_u32(t);
|
||||
Ok::<_, io::Error>(())
|
||||
},
|
||||
buf,
|
||||
)?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn password_message(password: &[u8], buf: &mut BytesMut) -> io::Result<()> {
|
||||
buf.put_u8(b'p');
|
||||
write_body(buf, |buf| write_cstr(password, buf))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn query(query: &str, buf: &mut BytesMut) -> io::Result<()> {
|
||||
buf.put_u8(b'Q');
|
||||
write_body(buf, |buf| write_cstr(query.as_bytes(), buf))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
|
||||
buf.put_u8(b'p');
|
||||
write_body(buf, |buf| {
|
||||
write_cstr(mechanism.as_bytes(), buf)?;
|
||||
let len = i32::from_usize(data.len())?;
|
||||
buf.put_i32(len);
|
||||
buf.put_slice(data);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
|
||||
buf.put_u8(b'p');
|
||||
write_body(buf, |buf| {
|
||||
buf.put_slice(data);
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn ssl_request(buf: &mut BytesMut) {
|
||||
write_body(buf, |buf| {
|
||||
buf.put_i32(80_877_103);
|
||||
Ok::<_, io::Error>(())
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn startup_message<'a, I>(parameters: I, buf: &mut BytesMut) -> io::Result<()>
|
||||
where
|
||||
I: IntoIterator<Item = (&'a str, &'a str)>,
|
||||
{
|
||||
write_body(buf, |buf| {
|
||||
// postgres protocol version 3.0(196608) in bigger-endian
|
||||
buf.put_i32(0x00_03_00_00);
|
||||
for (key, value) in parameters {
|
||||
write_cstr(key.as_bytes(), buf)?;
|
||||
write_cstr(value.as_bytes(), buf)?;
|
||||
}
|
||||
buf.put_u8(0);
|
||||
Ok(())
|
||||
})
|
||||
}
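
`startup_message` likewise has no type byte: the packet is a length, the protocol version word, and NUL-terminated key/value pairs closed by one extra NUL. A sketch with hypothetical parameters:

use bytes::BytesMut;
use postgres_protocol2::message::frontend;

let mut buf = BytesMut::new();
frontend::startup_message([("user", "joe"), ("database", "mydb")], &mut buf).unwrap();
assert_eq!(buf[4..8], 196_608i32.to_be_bytes()); // protocol 3.0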

#[inline]
pub fn sync(buf: &mut BytesMut) {
    buf.put_u8(b'S');
    write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}

#[inline]
pub fn terminate(buf: &mut BytesMut) {
    buf.put_u8(b'X');
    write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}

#[inline]
fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
    if s.contains(&0) {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "string contains embedded null",
        ));
    }
    buf.put_slice(s);
    buf.put_u8(0);
    Ok(())
}

8 libs/proxy/postgres-protocol2/src/message/mod.rs Normal file
@@ -0,0 +1,8 @@
//! Postgres message protocol support.
//!
//! See [Postgres's documentation][docs] for more information on message flow.
//!
//! [docs]: https://www.postgresql.org/docs/9.5/static/protocol-flow.html

pub mod backend;
pub mod frontend;

107 libs/proxy/postgres-protocol2/src/password/mod.rs Normal file
@@ -0,0 +1,107 @@
//! Functions to encrypt a password in the client.
//!
//! This is intended to be used by client applications that wish to
//! send commands like `ALTER USER joe PASSWORD 'pwd'`. The password
//! need not be sent in cleartext if it is encrypted on the client
//! side. This is good because it ensures the cleartext password won't
//! end up in logs, pg_stat displays, etc.

use crate::authentication::sasl;
use hmac::{Hmac, Mac};
use md5::Md5;
use rand::RngCore;
use sha2::digest::FixedOutput;
use sha2::{Digest, Sha256};

#[cfg(test)]
mod test;

const SCRAM_DEFAULT_ITERATIONS: u32 = 4096;
const SCRAM_DEFAULT_SALT_LEN: usize = 16;

/// Hash password using SCRAM-SHA-256 with a randomly-generated
/// salt.
///
/// The client may assume the returned string doesn't contain any
/// special characters that would require escaping in an SQL command.
pub async fn scram_sha_256(password: &[u8]) -> String {
    let mut salt: [u8; SCRAM_DEFAULT_SALT_LEN] = [0; SCRAM_DEFAULT_SALT_LEN];
    let mut rng = rand::thread_rng();
    rng.fill_bytes(&mut salt);
    scram_sha_256_salt(password, salt).await
}

// Internal implementation of scram_sha_256 with a caller-provided
// salt. This is useful for testing.
pub(crate) async fn scram_sha_256_salt(
    password: &[u8],
    salt: [u8; SCRAM_DEFAULT_SALT_LEN],
) -> String {
    // Prepare the password, per [RFC
    // 4013](https://tools.ietf.org/html/rfc4013), if possible.
    //
    // Postgres treats passwords as byte strings (without embedded NUL
    // bytes), but SASL expects passwords to be valid UTF-8.
    //
    // Follow the behavior of libpq's PQencryptPasswordConn(), and
    // also the backend. If the password is not valid UTF-8, or if it
    // contains prohibited characters (such as non-ASCII whitespace),
    // just skip the SASLprep step and use the original byte
    // sequence.
    let prepared: Vec<u8> = match std::str::from_utf8(password) {
        Ok(password_str) => {
            match stringprep::saslprep(password_str) {
                Ok(p) => p.into_owned().into_bytes(),
                // contains invalid characters; skip saslprep
                Err(_) => Vec::from(password),
            }
        }
        // not valid UTF-8; skip saslprep
        Err(_) => Vec::from(password),
    };

    // salt password
    let salted_password = sasl::hi(&prepared, &salt, SCRAM_DEFAULT_ITERATIONS).await;

    // client key
    let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
        .expect("HMAC is able to accept all key sizes");
    hmac.update(b"Client Key");
    let client_key = hmac.finalize().into_bytes();

    // stored key
    let mut hash = Sha256::default();
    hash.update(client_key.as_slice());
    let stored_key = hash.finalize_fixed();

    // server key
    let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
        .expect("HMAC is able to accept all key sizes");
    hmac.update(b"Server Key");
    let server_key = hmac.finalize().into_bytes();

    format!(
        "SCRAM-SHA-256${}:{}${}:{}",
        SCRAM_DEFAULT_ITERATIONS,
        base64::encode(salt),
        base64::encode(stored_key),
        base64::encode(server_key)
    )
}

/// **Not recommended, as MD5 is not considered to be secure.**
///
/// Hash password using MD5 with the username as the salt.
///
/// The client may assume the returned string doesn't contain any
/// special characters that would require escaping.
pub fn md5(password: &[u8], username: &str) -> String {
    // salt password with username
    let mut salted_password = Vec::from(password);
    salted_password.extend_from_slice(username.as_bytes());

    let mut hash = Md5::new();
    hash.update(&salted_password);
    let digest = hash.finalize();
    format!("md5{:x}", digest)
}
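
Both helpers return strings that are safe to splice into SQL. An illustrative use (role name and password are made up):

use postgres_protocol2::password;

// Inside an async context:
let verifier = password::scram_sha_256(b"correct horse").await;
let sql = format!("ALTER ROLE joe PASSWORD '{}'", verifier);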

19 libs/proxy/postgres-protocol2/src/password/test.rs Normal file
@@ -0,0 +1,19 @@
use crate::password;

#[tokio::test]
async fn test_encrypt_scram_sha_256() {
    // Specify the salt to make the test deterministic. Any bytes will do.
    let salt: [u8; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
    assert_eq!(
        password::scram_sha_256_salt(b"secret", salt).await,
        "SCRAM-SHA-256$4096:AQIDBAUGBwgJCgsMDQ4PEA==$8rrDg00OqaiWXJ7p+sCgHEIaBSHY89ZJl3mfIsf32oY=:05L1f+yZbiN8O0AnO40Og85NNRhvzTS57naKRWCcsIA="
    );
}

#[test]
fn test_encrypt_md5() {
    assert_eq!(
        password::md5(b"secret", "foo"),
        "md54ab2c5d00339c4b2a4e921d2dc4edec7"
    );
}

294 libs/proxy/postgres-protocol2/src/types/mod.rs Normal file
@@ -0,0 +1,294 @@
//! Conversions to and from Postgres's binary format for various types.
use byteorder::{BigEndian, ReadBytesExt};
use bytes::{BufMut, BytesMut};
use fallible_iterator::FallibleIterator;
use std::boxed::Box as StdBox;
use std::error::Error;
use std::str;

use crate::Oid;

#[cfg(test)]
mod test;

/// Serializes a `TEXT`, `VARCHAR`, `CHAR(n)`, `NAME`, or `CITEXT` value.
#[inline]
pub fn text_to_sql(v: &str, buf: &mut BytesMut) {
    buf.put_slice(v.as_bytes());
}

/// Deserializes a `TEXT`, `VARCHAR`, `CHAR(n)`, `NAME`, or `CITEXT` value.
#[inline]
pub fn text_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
    Ok(str::from_utf8(buf)?)
}

/// Deserializes a `"char"` value.
#[inline]
pub fn char_from_sql(mut buf: &[u8]) -> Result<i8, StdBox<dyn Error + Sync + Send>> {
    let v = buf.read_i8()?;
    if !buf.is_empty() {
        return Err("invalid buffer size".into());
    }
    Ok(v)
}

/// Serializes an `OID` value.
#[inline]
pub fn oid_to_sql(v: Oid, buf: &mut BytesMut) {
    buf.put_u32(v);
}

/// Deserializes an `OID` value.
#[inline]
pub fn oid_from_sql(mut buf: &[u8]) -> Result<Oid, StdBox<dyn Error + Sync + Send>> {
    let v = buf.read_u32::<BigEndian>()?;
    if !buf.is_empty() {
        return Err("invalid buffer size".into());
    }
    Ok(v)
}

/// A fallible iterator over `HSTORE` entries.
pub struct HstoreEntries<'a> {
    remaining: i32,
    buf: &'a [u8],
}

impl<'a> FallibleIterator for HstoreEntries<'a> {
    type Item = (&'a str, Option<&'a str>);
    type Error = StdBox<dyn Error + Sync + Send>;

    #[inline]
    #[allow(clippy::type_complexity)]
    fn next(
        &mut self,
    ) -> Result<Option<(&'a str, Option<&'a str>)>, StdBox<dyn Error + Sync + Send>> {
        if self.remaining == 0 {
            if !self.buf.is_empty() {
                return Err("invalid buffer size".into());
            }
            return Ok(None);
        }

        self.remaining -= 1;

        let key_len = self.buf.read_i32::<BigEndian>()?;
        if key_len < 0 {
            return Err("invalid key length".into());
        }
        let (key, buf) = self.buf.split_at(key_len as usize);
        let key = str::from_utf8(key)?;
        self.buf = buf;

        let value_len = self.buf.read_i32::<BigEndian>()?;
        let value = if value_len < 0 {
            None
        } else {
            let (value, buf) = self.buf.split_at(value_len as usize);
            let value = str::from_utf8(value)?;
            self.buf = buf;
            Some(value)
        };

        Ok(Some((key, value)))
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self.remaining as usize;
        (len, Some(len))
    }
}

/// Deserializes an array value.
#[inline]
pub fn array_from_sql(mut buf: &[u8]) -> Result<Array<'_>, StdBox<dyn Error + Sync + Send>> {
    let dimensions = buf.read_i32::<BigEndian>()?;
    if dimensions < 0 {
        return Err("invalid dimension count".into());
    }

    let mut r = buf;
    let mut elements = 1i32;
    for _ in 0..dimensions {
        let len = r.read_i32::<BigEndian>()?;
        if len < 0 {
            return Err("invalid dimension size".into());
        }
        let _lower_bound = r.read_i32::<BigEndian>()?;
        elements = match elements.checked_mul(len) {
            Some(elements) => elements,
            None => return Err("too many array elements".into()),
        };
    }

    if dimensions == 0 {
        elements = 0;
    }

    Ok(Array {
        dimensions,
        elements,
        buf,
    })
}

/// A Postgres array.
pub struct Array<'a> {
    dimensions: i32,
    elements: i32,
    buf: &'a [u8],
}

impl<'a> Array<'a> {
    /// Returns an iterator over the dimensions of the array.
    #[inline]
    pub fn dimensions(&self) -> ArrayDimensions<'a> {
        ArrayDimensions(&self.buf[..self.dimensions as usize * 8])
    }

    /// Returns an iterator over the values of the array.
    #[inline]
    pub fn values(&self) -> ArrayValues<'a> {
        ArrayValues {
            remaining: self.elements,
            buf: &self.buf[self.dimensions as usize * 8..],
        }
    }
}

/// An iterator over the dimensions of an array.
pub struct ArrayDimensions<'a>(&'a [u8]);

impl FallibleIterator for ArrayDimensions<'_> {
    type Item = ArrayDimension;
    type Error = StdBox<dyn Error + Sync + Send>;

    #[inline]
    fn next(&mut self) -> Result<Option<ArrayDimension>, StdBox<dyn Error + Sync + Send>> {
        if self.0.is_empty() {
            return Ok(None);
        }

        let len = self.0.read_i32::<BigEndian>()?;
        let lower_bound = self.0.read_i32::<BigEndian>()?;

        Ok(Some(ArrayDimension { len, lower_bound }))
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self.0.len() / 8;
        (len, Some(len))
    }
}

/// Information about a dimension of an array.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct ArrayDimension {
    /// The length of this dimension.
    pub len: i32,

    /// The base value used to index into this dimension.
    pub lower_bound: i32,
}

/// An iterator over the values of an array, in row-major order.
pub struct ArrayValues<'a> {
    remaining: i32,
    buf: &'a [u8],
}

impl<'a> FallibleIterator for ArrayValues<'a> {
    type Item = Option<&'a [u8]>;
    type Error = StdBox<dyn Error + Sync + Send>;

    #[inline]
    fn next(&mut self) -> Result<Option<Option<&'a [u8]>>, StdBox<dyn Error + Sync + Send>> {
        if self.remaining == 0 {
            if !self.buf.is_empty() {
                return Err("invalid message length: arrayvalue not drained".into());
            }
            return Ok(None);
        }
        self.remaining -= 1;

        let len = self.buf.read_i32::<BigEndian>()?;
        let val = if len < 0 {
            None
        } else {
            if self.buf.len() < len as usize {
                return Err("invalid value length".into());
            }

            let (val, buf) = self.buf.split_at(len as usize);
            self.buf = buf;
            Some(val)
        };

        Ok(Some(val))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let len = self.remaining as usize;
        (len, Some(len))
    }
}
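
The layout this parser expects is a dimension count, a (length, lower bound) pair per dimension, and then length-prefixed element values with -1 marking NULL. An illustrative round trip over a hand-built buffer:

use fallible_iterator::FallibleIterator;
use postgres_protocol2::types::array_from_sql;

fn main() -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
    let mut raw = Vec::new();
    raw.extend_from_slice(&1i32.to_be_bytes()); // one dimension
    raw.extend_from_slice(&2i32.to_be_bytes()); // of length 2
    raw.extend_from_slice(&1i32.to_be_bytes()); // lower bound 1
    raw.extend_from_slice(&4i32.to_be_bytes()); // first value: 4 bytes
    raw.extend_from_slice(&7i32.to_be_bytes());
    raw.extend_from_slice(&(-1i32).to_be_bytes()); // second value: NULL

    let array = array_from_sql(&raw)?;
    let values: Vec<Option<&[u8]>> = array.values().collect()?;
    assert_eq!(values[0], Some(&7i32.to_be_bytes()[..]));
    assert_eq!(values[1], None);
    Ok(())
}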

/// Serializes a Postgres ltree string
#[inline]
pub fn ltree_to_sql(v: &str, buf: &mut BytesMut) {
    // A version number is prepended to an ltree string per spec
    buf.put_u8(1);
    // Append the rest of the query
    buf.put_slice(v.as_bytes());
}

/// Deserialize a Postgres ltree string
#[inline]
pub fn ltree_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
    match buf {
        // Remove the version number from the front of the ltree per spec
        [1u8, rest @ ..] => Ok(str::from_utf8(rest)?),
        _ => Err("ltree version 1 only supported".into()),
    }
}

/// Serializes a Postgres lquery string
#[inline]
pub fn lquery_to_sql(v: &str, buf: &mut BytesMut) {
    // A version number is prepended to an lquery string per spec
    buf.put_u8(1);
    // Append the rest of the query
    buf.put_slice(v.as_bytes());
}

/// Deserialize a Postgres lquery string
#[inline]
pub fn lquery_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
    match buf {
        // Remove the version number from the front of the lquery per spec
        [1u8, rest @ ..] => Ok(str::from_utf8(rest)?),
        _ => Err("lquery version 1 only supported".into()),
    }
}

/// Serializes a Postgres ltxtquery string
#[inline]
pub fn ltxtquery_to_sql(v: &str, buf: &mut BytesMut) {
    // A version number is prepended to an ltxtquery string per spec
    buf.put_u8(1);
    // Append the rest of the query
    buf.put_slice(v.as_bytes());
}

/// Deserialize a Postgres ltxtquery string
#[inline]
pub fn ltxtquery_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
    match buf {
        // Remove the version number from the front of the ltxtquery per spec
        [1u8, rest @ ..] => Ok(str::from_utf8(rest)?),
        _ => Err("ltxtquery version 1 only supported".into()),
    }
}

87 libs/proxy/postgres-protocol2/src/types/test.rs Normal file
@@ -0,0 +1,87 @@
use bytes::{Buf, BytesMut};

use super::*;

#[test]
fn ltree_sql() {
    let mut query = vec![1u8];
    query.extend_from_slice("A.B.C".as_bytes());

    let mut buf = BytesMut::new();

    ltree_to_sql("A.B.C", &mut buf);

    assert_eq!(query.as_slice(), buf.chunk());
}

#[test]
fn ltree_str() {
    let mut query = vec![1u8];
    query.extend_from_slice("A.B.C".as_bytes());

    assert!(ltree_from_sql(query.as_slice()).is_ok())
}

#[test]
fn ltree_wrong_version() {
    let mut query = vec![2u8];
    query.extend_from_slice("A.B.C".as_bytes());

    assert!(ltree_from_sql(query.as_slice()).is_err())
}

#[test]
fn lquery_sql() {
    let mut query = vec![1u8];
    query.extend_from_slice("A.B.C".as_bytes());

    let mut buf = BytesMut::new();

    lquery_to_sql("A.B.C", &mut buf);

    assert_eq!(query.as_slice(), buf.chunk());
}

#[test]
fn lquery_str() {
    let mut query = vec![1u8];
    query.extend_from_slice("A.B.C".as_bytes());

    assert!(lquery_from_sql(query.as_slice()).is_ok())
}

#[test]
fn lquery_wrong_version() {
    let mut query = vec![2u8];
    query.extend_from_slice("A.B.C".as_bytes());

    assert!(lquery_from_sql(query.as_slice()).is_err())
}

#[test]
fn ltxtquery_sql() {
    let mut query = vec![1u8];
    query.extend_from_slice("a & b*".as_bytes());

    let mut buf = BytesMut::new();

    ltxtquery_to_sql("a & b*", &mut buf);

    assert_eq!(query.as_slice(), buf.chunk());
}

#[test]
fn ltxtquery_str() {
    let mut query = vec![1u8];
    query.extend_from_slice("a & b*".as_bytes());

    assert!(ltxtquery_from_sql(query.as_slice()).is_ok())
}

#[test]
fn ltxtquery_wrong_version() {
    let mut query = vec![2u8];
    query.extend_from_slice("a & b*".as_bytes());

    assert!(ltxtquery_from_sql(query.as_slice()).is_err())
}

10 libs/proxy/postgres-types2/Cargo.toml Normal file
@@ -0,0 +1,10 @@
[package]
name = "postgres-types2"
version = "0.1.0"
edition = "2018"
license = "MIT/Apache-2.0"

[dependencies]
bytes.workspace = true
fallible-iterator.workspace = true
postgres-protocol2 = { path = "../postgres-protocol2" }

477 libs/proxy/postgres-types2/src/lib.rs Normal file
@@ -0,0 +1,477 @@
//! Conversions to and from Postgres types.
//!
//! This crate is used by the `tokio-postgres` and `postgres` crates. You normally don't need to depend directly on it
//! unless you want to define your own `ToSql` or `FromSql` definitions.
#![doc(html_root_url = "https://docs.rs/postgres-types/0.2")]
#![warn(clippy::all, rust_2018_idioms, missing_docs)]

use fallible_iterator::FallibleIterator;
use postgres_protocol2::types;
use std::any::type_name;
use std::error::Error;
use std::fmt;
use std::sync::Arc;

use crate::type_gen::{Inner, Other};

#[doc(inline)]
pub use postgres_protocol2::Oid;

use bytes::BytesMut;

/// Generates a simple implementation of `ToSql::accepts` which accepts the
/// types passed to it.
macro_rules! accepts {
    ($($expected:ident),+) => (
        fn accepts(ty: &$crate::Type) -> bool {
            matches!(*ty, $($crate::Type::$expected)|+)
        }
    )
}

/// Generates an implementation of `ToSql::to_sql_checked`.
///
/// All `ToSql` implementations should use this macro.
macro_rules! to_sql_checked {
    () => {
        fn to_sql_checked(
            &self,
            ty: &$crate::Type,
            out: &mut $crate::private::BytesMut,
        ) -> ::std::result::Result<
            $crate::IsNull,
            Box<dyn ::std::error::Error + ::std::marker::Sync + ::std::marker::Send>,
        > {
            $crate::__to_sql_checked(self, ty, out)
        }
    };
}

// WARNING: this function is not considered part of this crate's public API.
// It is subject to change at any time.
#[doc(hidden)]
pub fn __to_sql_checked<T>(
    v: &T,
    ty: &Type,
    out: &mut BytesMut,
) -> Result<IsNull, Box<dyn Error + Sync + Send>>
where
    T: ToSql,
{
    if !T::accepts(ty) {
        return Err(Box::new(WrongType::new::<T>(ty.clone())));
    }
    v.to_sql(ty, out)
}

// mod pg_lsn;
#[doc(hidden)]
pub mod private;
// mod special;
mod type_gen;

/// A Postgres type.
#[derive(PartialEq, Eq, Clone, Hash)]
pub struct Type(Inner);

impl fmt::Debug for Type {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.0, fmt)
    }
}

impl fmt::Display for Type {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.schema() {
            "public" | "pg_catalog" => {}
            schema => write!(fmt, "{}.", schema)?,
        }
        fmt.write_str(self.name())
    }
}

impl Type {
    /// Creates a new `Type`.
    pub fn new(name: String, oid: Oid, kind: Kind, schema: String) -> Type {
        Type(Inner::Other(Arc::new(Other {
            name,
            oid,
            kind,
            schema,
        })))
    }

    /// Returns the `Type` corresponding to the provided `Oid` if it
    /// corresponds to a built-in type.
    pub fn from_oid(oid: Oid) -> Option<Type> {
        Inner::from_oid(oid).map(Type)
    }

    /// Returns the OID of the `Type`.
    pub fn oid(&self) -> Oid {
        self.0.oid()
    }

    /// Returns the kind of this type.
    pub fn kind(&self) -> &Kind {
        self.0.kind()
    }

    /// Returns the schema of this type.
    pub fn schema(&self) -> &str {
        match self.0 {
            Inner::Other(ref u) => &u.schema,
            _ => "pg_catalog",
        }
    }

    /// Returns the name of this type.
    pub fn name(&self) -> &str {
        self.0.name()
    }
}

/// Represents the kind of a Postgres type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum Kind {
    /// A simple type like `VARCHAR` or `INTEGER`.
    Simple,
    /// An enumerated type along with its variants.
    Enum(Vec<String>),
    /// A pseudo-type.
    Pseudo,
    /// An array type along with the type of its elements.
    Array(Type),
    /// A range type along with the type of its elements.
    Range(Type),
    /// A multirange type along with the type of its elements.
    Multirange(Type),
    /// A domain type along with its underlying type.
    Domain(Type),
    /// A composite type along with information about its fields.
    Composite(Vec<Field>),
}

/// Information about a field of a composite type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Field {
    name: String,
    type_: Type,
}

impl Field {
    /// Creates a new `Field`.
    pub fn new(name: String, type_: Type) -> Field {
        Field { name, type_ }
    }

    /// Returns the name of the field.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the type of the field.
    pub fn type_(&self) -> &Type {
        &self.type_
    }
}

/// An error indicating that a `NULL` Postgres value was passed to a `FromSql`
/// implementation that does not support `NULL` values.
#[derive(Debug, Clone, Copy)]
pub struct WasNull;

impl fmt::Display for WasNull {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.write_str("a Postgres value was `NULL`")
    }
}

impl Error for WasNull {}

/// An error indicating that a conversion was attempted between incompatible
/// Rust and Postgres types.
#[derive(Debug)]
pub struct WrongType {
    postgres: Type,
    rust: &'static str,
}

impl fmt::Display for WrongType {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            fmt,
            "cannot convert between the Rust type `{}` and the Postgres type `{}`",
            self.rust, self.postgres,
        )
    }
}

impl Error for WrongType {}

impl WrongType {
    /// Creates a new `WrongType` error.
    pub fn new<T>(ty: Type) -> WrongType {
        WrongType {
            postgres: ty,
            rust: type_name::<T>(),
        }
    }
}

/// An error indicating that an `as_text` conversion was attempted on a binary
/// result.
#[derive(Debug)]
pub struct WrongFormat {}

impl Error for WrongFormat {}

impl fmt::Display for WrongFormat {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            fmt,
            "cannot read column as text while it is in binary format"
        )
    }
}

/// A trait for types that can be created from a Postgres value.
pub trait FromSql<'a>: Sized {
    /// Creates a new value of this type from a buffer of data of the specified
    /// Postgres `Type` in its binary format.
    ///
    /// The caller of this method is responsible for ensuring that this type
    /// is compatible with the Postgres `Type`.
    fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>>;

    /// Creates a new value of this type from a `NULL` SQL value.
    ///
    /// The caller of this method is responsible for ensuring that this type
    /// is compatible with the Postgres `Type`.
    ///
    /// The default implementation returns `Err(Box::new(WasNull))`.
    #[allow(unused_variables)]
    fn from_sql_null(ty: &Type) -> Result<Self, Box<dyn Error + Sync + Send>> {
        Err(Box::new(WasNull))
    }

    /// A convenience function that delegates to `from_sql` and `from_sql_null` depending on the
    /// value of `raw`.
    fn from_sql_nullable(
        ty: &Type,
        raw: Option<&'a [u8]>,
    ) -> Result<Self, Box<dyn Error + Sync + Send>> {
        match raw {
            Some(raw) => Self::from_sql(ty, raw),
            None => Self::from_sql_null(ty),
        }
    }

    /// Determines if a value of this type can be created from the specified
    /// Postgres `Type`.
    fn accepts(ty: &Type) -> bool;
}
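
Implementing the trait for a custom type mostly means delegating to an existing impl. A hypothetical sketch (the `Upper` wrapper is made up for illustration):

use postgres_types2::{FromSql, Type};
use std::error::Error;

/// Decodes any text-like column, uppercasing it on the way in.
#[derive(Debug)]
struct Upper(String);

impl<'a> FromSql<'a> for Upper {
    fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Upper, Box<dyn Error + Sync + Send>> {
        // Delegate to the existing &str impl, then post-process.
        let s = <&str as FromSql>::from_sql(ty, raw)?;
        Ok(Upper(s.to_uppercase()))
    }

    fn accepts(ty: &Type) -> bool {
        <&str as FromSql>::accepts(ty)
    }
}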

/// A trait for types which can be created from a Postgres value without borrowing any data.
///
/// This is primarily useful for trait bounds on functions.
pub trait FromSqlOwned: for<'a> FromSql<'a> {}

impl<T> FromSqlOwned for T where T: for<'a> FromSql<'a> {}

impl<'a, T: FromSql<'a>> FromSql<'a> for Option<T> {
    fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Option<T>, Box<dyn Error + Sync + Send>> {
        <T as FromSql>::from_sql(ty, raw).map(Some)
    }

    fn from_sql_null(_: &Type) -> Result<Option<T>, Box<dyn Error + Sync + Send>> {
        Ok(None)
    }

    fn accepts(ty: &Type) -> bool {
        <T as FromSql>::accepts(ty)
    }
}

impl<'a, T: FromSql<'a>> FromSql<'a> for Vec<T> {
    fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Vec<T>, Box<dyn Error + Sync + Send>> {
        let member_type = match *ty.kind() {
            Kind::Array(ref member) => member,
            _ => panic!("expected array type"),
        };

        let array = types::array_from_sql(raw)?;
        if array.dimensions().count()? > 1 {
            return Err("array contains too many dimensions".into());
        }

        array
            .values()
            .map(|v| T::from_sql_nullable(member_type, v))
            .collect()
    }

    fn accepts(ty: &Type) -> bool {
        match *ty.kind() {
            Kind::Array(ref inner) => T::accepts(inner),
            _ => false,
        }
    }
}

impl<'a> FromSql<'a> for String {
    fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<String, Box<dyn Error + Sync + Send>> {
        <&str as FromSql>::from_sql(ty, raw).map(ToString::to_string)
    }

    fn accepts(ty: &Type) -> bool {
        <&str as FromSql>::accepts(ty)
    }
}

impl<'a> FromSql<'a> for &'a str {
    fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<&'a str, Box<dyn Error + Sync + Send>> {
        match *ty {
            ref ty if ty.name() == "ltree" => types::ltree_from_sql(raw),
            ref ty if ty.name() == "lquery" => types::lquery_from_sql(raw),
            ref ty if ty.name() == "ltxtquery" => types::ltxtquery_from_sql(raw),
            _ => types::text_from_sql(raw),
        }
    }

    fn accepts(ty: &Type) -> bool {
        match *ty {
            Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true,
            ref ty
                if (ty.name() == "citext"
                    || ty.name() == "ltree"
                    || ty.name() == "lquery"
                    || ty.name() == "ltxtquery") =>
            {
                true
            }
            _ => false,
        }
    }
}

macro_rules! simple_from {
    ($t:ty, $f:ident, $($expected:ident),+) => {
        impl<'a> FromSql<'a> for $t {
            fn from_sql(_: &Type, raw: &'a [u8]) -> Result<$t, Box<dyn Error + Sync + Send>> {
                types::$f(raw)
            }

            accepts!($($expected),+);
        }
    }
}

simple_from!(i8, char_from_sql, CHAR);
simple_from!(u32, oid_from_sql, OID);

/// An enum representing the nullability of a Postgres value.
pub enum IsNull {
    /// The value is NULL.
    Yes,
    /// The value is not NULL.
    No,
}

/// A trait for types that can be converted into Postgres values.
pub trait ToSql: fmt::Debug {
    /// Converts the value of `self` into the binary format of the specified
    /// Postgres `Type`, appending it to `out`.
    ///
    /// The caller of this method is responsible for ensuring that this type
    /// is compatible with the Postgres `Type`.
    ///
    /// The return value indicates if this value should be represented as
    /// `NULL`. If this is the case, implementations **must not** write
    /// anything to `out`.
    fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>>
    where
        Self: Sized;

    /// Determines if a value of this type can be converted to the specified
    /// Postgres `Type`.
    fn accepts(ty: &Type) -> bool
    where
        Self: Sized;

    /// An adaptor method used internally by Rust-Postgres.
    ///
    /// *All* implementations of this method should be generated by the
    /// `to_sql_checked!()` macro.
    fn to_sql_checked(
        &self,
        ty: &Type,
        out: &mut BytesMut,
    ) -> Result<IsNull, Box<dyn Error + Sync + Send>>;

    /// Specify the encode format
    fn encode_format(&self, _ty: &Type) -> Format {
        Format::Binary
    }
}

/// Supported Postgres message format types
///
/// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8`
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Format {
    /// Text format (UTF-8)
    Text,
    /// Compact, typed binary format
    Binary,
}

impl ToSql for &str {
    fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
        match *ty {
            ref ty if ty.name() == "ltree" => types::ltree_to_sql(self, w),
            ref ty if ty.name() == "lquery" => types::lquery_to_sql(self, w),
            ref ty if ty.name() == "ltxtquery" => types::ltxtquery_to_sql(self, w),
            _ => types::text_to_sql(self, w),
        }
        Ok(IsNull::No)
    }

    fn accepts(ty: &Type) -> bool {
        match *ty {
            Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true,
            ref ty
                if (ty.name() == "citext"
                    || ty.name() == "ltree"
                    || ty.name() == "lquery"
                    || ty.name() == "ltxtquery") =>
            {
                true
            }
            _ => false,
        }
    }

    to_sql_checked!();
}

macro_rules! simple_to {
    ($t:ty, $f:ident, $($expected:ident),+) => {
        impl ToSql for $t {
            fn to_sql(&self,
                      _: &Type,
                      w: &mut BytesMut)
                      -> Result<IsNull, Box<dyn Error + Sync + Send>> {
                types::$f(*self, w);
                Ok(IsNull::No)
            }

            accepts!($($expected),+);

            to_sql_checked!();
        }
    }
}

simple_to!(u32, oid_to_sql, OID);
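
Because `to_sql_checked!` is not exported, a downstream implementor would spell the adaptor out through the doc-hidden `__to_sql_checked`. A hypothetical sketch:

use bytes::BytesMut;
use postgres_types2::{IsNull, ToSql, Type, __to_sql_checked};
use std::error::Error;

/// Sends its contents wherever a string is accepted.
#[derive(Debug)]
struct Tag(String);

impl ToSql for Tag {
    fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
        <&str as ToSql>::to_sql(&self.0.as_str(), ty, out)
    }

    fn accepts(ty: &Type) -> bool {
        <&str as ToSql>::accepts(ty)
    }

    fn to_sql_checked(
        &self,
        ty: &Type,
        out: &mut BytesMut,
    ) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
        // Same type check the macro would generate.
        __to_sql_checked(self, ty, out)
    }
}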

34 libs/proxy/postgres-types2/src/private.rs Normal file
@@ -0,0 +1,34 @@
use crate::{FromSql, Type};
pub use bytes::BytesMut;
use std::error::Error;

pub fn read_be_i32(buf: &mut &[u8]) -> Result<i32, Box<dyn Error + Sync + Send>> {
    if buf.len() < 4 {
        return Err("invalid buffer size".into());
    }
    let mut bytes = [0; 4];
    bytes.copy_from_slice(&buf[..4]);
    *buf = &buf[4..];
    Ok(i32::from_be_bytes(bytes))
}

pub fn read_value<'a, T>(
    type_: &Type,
    buf: &mut &'a [u8],
) -> Result<T, Box<dyn Error + Sync + Send>>
where
    T: FromSql<'a>,
{
    let len = read_be_i32(buf)?;
    let value = if len < 0 {
        None
    } else {
        if len as usize > buf.len() {
            return Err("invalid buffer size".into());
        }
        let (head, tail) = buf.split_at(len as usize);
        *buf = tail;
        Some(head)
    };
    T::from_sql_nullable(type_, value)
}
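
`read_value` consumes one length-prefixed field from a DataRow-style buffer, advancing the slice as it goes. An illustrative decode (this module is doc-hidden, so this is only a sketch of the framing):

use postgres_types2::private::read_value;
use postgres_types2::Type;

fn main() -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
    let mut raw: Vec<u8> = Vec::new();
    raw.extend_from_slice(&2i32.to_be_bytes()); // field of length 2
    raw.extend_from_slice(b"hi");
    raw.extend_from_slice(&(-1i32).to_be_bytes()); // NULL field

    let mut buf = raw.as_slice();
    let first: &str = read_value(&Type::TEXT, &mut buf)?;
    let second: Option<&str> = read_value(&Type::TEXT, &mut buf)?;
    assert_eq!(first, "hi");
    assert_eq!(second, None);
    Ok(())
}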

1524 libs/proxy/postgres-types2/src/type_gen.rs Normal file
(File diff suppressed because it is too large)

21 libs/proxy/tokio-postgres2/Cargo.toml Normal file
@@ -0,0 +1,21 @@
[package]
name = "tokio-postgres2"
version = "0.1.0"
edition = "2018"
license = "MIT/Apache-2.0"

[dependencies]
async-trait.workspace = true
bytes.workspace = true
byteorder.workspace = true
fallible-iterator.workspace = true
futures-util = { workspace = true, features = ["sink"] }
log = "0.4"
parking_lot.workspace = true
percent-encoding = "2.0"
pin-project-lite.workspace = true
phf = "0.11"
postgres-protocol2 = { path = "../postgres-protocol2" }
postgres-types2 = { path = "../postgres-types2" }
tokio = { workspace = true, features = ["io-util", "time", "net"] }
tokio-util = { workspace = true, features = ["codec"] }

40 libs/proxy/tokio-postgres2/src/cancel_query.rs Normal file
@@ -0,0 +1,40 @@
use tokio::net::TcpStream;

use crate::client::SocketConfig;
use crate::config::{Host, SslMode};
use crate::tls::MakeTlsConnect;
use crate::{cancel_query_raw, connect_socket, Error};
use std::io;

pub(crate) async fn cancel_query<T>(
    config: Option<SocketConfig>,
    ssl_mode: SslMode,
    mut tls: T,
    process_id: i32,
    secret_key: i32,
) -> Result<(), Error>
where
    T: MakeTlsConnect<TcpStream>,
{
    let config = match config {
        Some(config) => config,
        None => {
            return Err(Error::connect(io::Error::new(
                io::ErrorKind::InvalidInput,
                "unknown host",
            )))
        }
    };

    let hostname = match &config.host {
        Host::Tcp(host) => &**host,
    };
    let tls = tls
        .make_tls_connect(hostname)
        .map_err(|e| Error::tls(e.into()))?;

    let socket =
        connect_socket::connect_socket(&config.host, config.port, config.connect_timeout).await?;

    cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, process_id, secret_key).await
}

29 libs/proxy/tokio-postgres2/src/cancel_query_raw.rs Normal file
@@ -0,0 +1,29 @@
use crate::config::SslMode;
use crate::tls::TlsConnect;
use crate::{connect_tls, Error};
use bytes::BytesMut;
use postgres_protocol2::message::frontend;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};

pub async fn cancel_query_raw<S, T>(
    stream: S,
    mode: SslMode,
    tls: T,
    process_id: i32,
    secret_key: i32,
) -> Result<(), Error>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: TlsConnect<S>,
{
    let mut stream = connect_tls::connect_tls(stream, mode, tls).await?;

    let mut buf = BytesMut::new();
    frontend::cancel_request(process_id, secret_key, &mut buf);

    stream.write_all(&buf).await.map_err(Error::io)?;
    stream.flush().await.map_err(Error::io)?;
    stream.shutdown().await.map_err(Error::io)?;

    Ok(())
}

62 libs/proxy/tokio-postgres2/src/cancel_token.rs Normal file
@@ -0,0 +1,62 @@
use crate::config::SslMode;
use crate::tls::TlsConnect;

use crate::{cancel_query, client::SocketConfig, tls::MakeTlsConnect};
use crate::{cancel_query_raw, Error};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;

/// The capability to request cancellation of in-progress queries on a
/// connection.
#[derive(Clone)]
pub struct CancelToken {
    pub(crate) socket_config: Option<SocketConfig>,
    pub(crate) ssl_mode: SslMode,
    pub(crate) process_id: i32,
    pub(crate) secret_key: i32,
}

impl CancelToken {
    /// Attempts to cancel the in-progress query on the connection associated
    /// with this `CancelToken`.
    ///
    /// The server provides no information about whether a cancellation attempt was successful or not. An error will
    /// only be returned if the client was unable to connect to the database.
    ///
    /// Cancellation is inherently racy. There is no guarantee that the
    /// cancellation request will reach the server before the query terminates
    /// normally, or that the connection associated with this token is still
    /// active.
    ///
    /// Requires the `runtime` Cargo feature (enabled by default).
    pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error>
    where
        T: MakeTlsConnect<TcpStream>,
    {
        cancel_query::cancel_query(
            self.socket_config.clone(),
            self.ssl_mode,
            tls,
            self.process_id,
            self.secret_key,
        )
        .await
    }

    /// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new
    /// connection itself.
    pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
    where
        S: AsyncRead + AsyncWrite + Unpin,
        T: TlsConnect<S>,
    {
        cancel_query_raw::cancel_query_raw(
            stream,
            self.ssl_mode,
            tls,
            self.process_id,
            self.secret_key,
        )
        .await
    }
}
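
Typical use is to clone the token out of a client and fire the cancel from another task. A hypothetical sketch, assuming a `client` built elsewhere and a `NoTls` connector as in upstream tokio-postgres:

let token = client.cancel_token();
tokio::spawn(async move {
    tokio::time::sleep(std::time::Duration::from_secs(5)).await;
    // Best-effort: Ok only means the cancel request was delivered.
    let _ = token.cancel_query(NoTls).await;
});
client.batch_execute("SELECT pg_sleep(60)").await?;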

439 libs/proxy/tokio-postgres2/src/client.rs Normal file
@@ -0,0 +1,439 @@
use crate::codec::{BackendMessages, FrontendMessage};

use crate::config::Host;
use crate::config::SslMode;
use crate::connection::{Request, RequestMessages};

use crate::query::RowStream;
use crate::simple_query::SimpleQueryStream;

use crate::types::{Oid, ToSql, Type};

use crate::{
    prepare, query, simple_query, slice_iter, CancelToken, Error, ReadyForQueryStatus, Row,
    SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder,
};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{future, ready, TryStreamExt};
use parking_lot::Mutex;
use postgres_protocol2::message::{backend::Message, frontend};
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::sync::mpsc;

use std::time::Duration;

pub struct Responses {
    receiver: mpsc::Receiver<BackendMessages>,
    cur: BackendMessages,
}

impl Responses {
    pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
        loop {
            match self.cur.next().map_err(Error::parse)? {
                Some(Message::ErrorResponse(body)) => return Poll::Ready(Err(Error::db(body))),
                Some(message) => return Poll::Ready(Ok(message)),
                None => {}
            }

            match ready!(self.receiver.poll_recv(cx)) {
                Some(messages) => self.cur = messages,
                None => return Poll::Ready(Err(Error::closed())),
            }
        }
    }

    pub async fn next(&mut self) -> Result<Message, Error> {
        future::poll_fn(|cx| self.poll_next(cx)).await
    }
}

/// A cache of type info and prepared statements for fetching type info
/// (corresponding to the queries in the [prepare] module).
#[derive(Default)]
struct CachedTypeInfo {
    /// A statement for basic information for a type from its
    /// OID. Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_QUERY) (or its
    /// fallback).
    typeinfo: Option<Statement>,
    /// A statement for getting information for a composite type from its OID.
    /// Corresponds to [TYPEINFO_COMPOSITE_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY).
    typeinfo_composite: Option<Statement>,
    /// A statement for getting information for an enum type from its OID.
    /// Corresponds to [TYPEINFO_ENUM_QUERY](prepare::TYPEINFO_ENUM_QUERY) (or
    /// its fallback).
    typeinfo_enum: Option<Statement>,

    /// Cache of types already looked up.
    types: HashMap<Oid, Type>,
}

pub struct InnerClient {
    sender: mpsc::UnboundedSender<Request>,
    cached_typeinfo: Mutex<CachedTypeInfo>,

    /// A buffer to use when writing out postgres commands.
    buffer: Mutex<BytesMut>,
}

impl InnerClient {
    pub fn send(&self, messages: RequestMessages) -> Result<Responses, Error> {
        let (sender, receiver) = mpsc::channel(1);
        let request = Request { messages, sender };
        self.sender.send(request).map_err(|_| Error::closed())?;

        Ok(Responses {
            receiver,
            cur: BackendMessages::empty(),
        })
    }

    pub fn typeinfo(&self) -> Option<Statement> {
        self.cached_typeinfo.lock().typeinfo.clone()
    }

    pub fn set_typeinfo(&self, statement: &Statement) {
        self.cached_typeinfo.lock().typeinfo = Some(statement.clone());
    }

    pub fn typeinfo_composite(&self) -> Option<Statement> {
        self.cached_typeinfo.lock().typeinfo_composite.clone()
    }

    pub fn set_typeinfo_composite(&self, statement: &Statement) {
        self.cached_typeinfo.lock().typeinfo_composite = Some(statement.clone());
    }

    pub fn typeinfo_enum(&self) -> Option<Statement> {
        self.cached_typeinfo.lock().typeinfo_enum.clone()
    }

    pub fn set_typeinfo_enum(&self, statement: &Statement) {
        self.cached_typeinfo.lock().typeinfo_enum = Some(statement.clone());
    }

    pub fn type_(&self, oid: Oid) -> Option<Type> {
        self.cached_typeinfo.lock().types.get(&oid).cloned()
    }

    pub fn set_type(&self, oid: Oid, type_: &Type) {
        self.cached_typeinfo.lock().types.insert(oid, type_.clone());
    }

    /// Call the given function with a buffer to be used when writing out
    /// postgres commands.
    pub fn with_buf<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut BytesMut) -> R,
    {
        let mut buffer = self.buffer.lock();
        let r = f(&mut buffer);
        buffer.clear();
        r
    }
}

#[derive(Clone)]
pub(crate) struct SocketConfig {
    pub host: Host,
    pub port: u16,
    pub connect_timeout: Option<Duration>,
    // pub keepalive: Option<KeepaliveConfig>,
}

/// An asynchronous PostgreSQL client.
///
/// The client is one half of what is returned when a connection is established. Users interact with the database
/// through this client object.
pub struct Client {
    inner: Arc<InnerClient>,

    socket_config: Option<SocketConfig>,
    ssl_mode: SslMode,
    process_id: i32,
    secret_key: i32,
}

impl Client {
    pub(crate) fn new(
        sender: mpsc::UnboundedSender<Request>,
        ssl_mode: SslMode,
        process_id: i32,
        secret_key: i32,
    ) -> Client {
        Client {
            inner: Arc::new(InnerClient {
                sender,
                cached_typeinfo: Default::default(),
                buffer: Default::default(),
            }),

            socket_config: None,
            ssl_mode,
            process_id,
            secret_key,
        }
    }

    /// Returns the `process_id` of the server backend for this connection.
    pub fn get_process_id(&self) -> i32 {
        self.process_id
    }

    pub(crate) fn inner(&self) -> &Arc<InnerClient> {
        &self.inner
    }

    pub(crate) fn set_socket_config(&mut self, socket_config: SocketConfig) {
        self.socket_config = Some(socket_config);
    }

    /// Creates a new prepared statement.
    ///
    /// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc),
    /// which are set when executed. Prepared statements can only be used with the connection that created them.
    pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
        self.prepare_typed(query, &[]).await
    }

    /// Like `prepare`, but allows the types of query parameters to be explicitly specified.
    ///
    /// The list of types may be smaller than the number of parameters - the types of the remaining parameters will be
    /// inferred. For example, `client.prepare_typed(query, &[])` is equivalent to `client.prepare(query)`.
    pub async fn prepare_typed(
        &self,
        query: &str,
        parameter_types: &[Type],
    ) -> Result<Statement, Error> {
        prepare::prepare(&self.inner, query, parameter_types).await
    }

    /// Executes a statement, returning a vector of the resulting rows.
    ///
    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
    /// provided, 1-indexed.
    ///
    /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
    /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
    /// with the `prepare` method.
    ///
    /// # Panics
    ///
    /// Panics if the number of parameters provided does not match the number expected.
    pub async fn query<T>(
        &self,
        statement: &T,
        params: &[&(dyn ToSql + Sync)],
    ) -> Result<Vec<Row>, Error>
    where
        T: ?Sized + ToStatement,
    {
        self.query_raw(statement, slice_iter(params))
            .await?
            .try_collect()
            .await
    }

    /// The maximally flexible version of [`query`].
    ///
    /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
    /// provided, 1-indexed.
    ///
    /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
    /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
    /// with the `prepare` method.
    ///
    /// # Panics
    ///
    /// Panics if the number of parameters provided does not match the number expected.
    ///
    /// [`query`]: #method.query
    pub async fn query_raw<'a, T, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
    where
        T: ?Sized + ToStatement,
        I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
        I::IntoIter: ExactSizeIterator,
    {
        let statement = statement.__convert().into_statement(self).await?;
        query::query(&self.inner, statement, params).await
    }

    /// Passes text parameters directly to the Postgres backend, letting it resolve their types
    /// itself and saving a round trip.
    pub async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
    where
        S: AsRef<str>,
        I: IntoIterator<Item = Option<S>>,
        I::IntoIter: ExactSizeIterator,
    {
        query::query_txt(&self.inner, statement, params).await
    }
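
A hypothetical call site (with `TryStreamExt` in scope for `try_collect`); every parameter travels as text or NULL, so no type OIDs are needed up front:

let rows = client
    .query_raw_txt("SELECT $1::int4 + 1, $2", [Some("41"), None::<&str>])
    .await?
    .try_collect::<Vec<_>>()
    .await?;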
|
||||
|
||||
/// Executes a statement, returning the number of rows modified.
|
||||
///
|
||||
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
|
||||
/// provided, 1-indexed.
|
||||
///
|
||||
/// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
|
||||
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
|
||||
/// with the `prepare` method.
|
||||
///
|
||||
/// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the number of parameters provided does not match the number expected.
|
||||
pub async fn execute<T>(
|
||||
&self,
|
||||
statement: &T,
|
||||
params: &[&(dyn ToSql + Sync)],
|
||||
) -> Result<u64, Error>
|
||||
where
|
||||
T: ?Sized + ToStatement,
|
||||
{
|
||||
self.execute_raw(statement, slice_iter(params)).await
|
||||
}
|
||||
|
||||
/// The maximally flexible version of [`execute`].
|
||||
///
|
||||
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
|
||||
/// provided, 1-indexed.
|
||||
///
|
||||
/// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
|
||||
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
|
||||
/// with the `prepare` method.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the number of parameters provided does not match the number expected.
|
||||
///
|
||||
/// [`execute`]: #method.execute
|
||||
pub async fn execute_raw<'a, T, I>(&self, statement: &T, params: I) -> Result<u64, Error>
|
||||
where
|
||||
T: ?Sized + ToStatement,
|
||||
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
|
||||
I::IntoIter: ExactSizeIterator,
|
||||
{
|
||||
let statement = statement.__convert().into_statement(self).await?;
|
||||
query::execute(self.inner(), statement, params).await
|
||||
}
|
||||
|
||||
/// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
|
||||
///
|
||||
/// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
|
||||
/// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings,
|
||||
/// so the associated row type doesn't work with the `FromSql` trait. Rather than simply returning a list of the
|
||||
/// rows, this method returns a list of an enum which indicates either the completion of one of the commands,
|
||||
/// or a row of data. This preserves the framing between the separate statements in the request.
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// Prepared statements should be use for any query which contains user-specified data, as they provided the
|
||||
/// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
|
||||
/// them to this method!
|
||||
pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
|
||||
self.simple_query_raw(query).await?.try_collect().await
|
||||
}
|
||||
|
||||
pub(crate) async fn simple_query_raw(&self, query: &str) -> Result<SimpleQueryStream, Error> {
|
||||
simple_query::simple_query(self.inner(), query).await
|
||||
}
|
||||
|
||||
/// Executes a sequence of SQL statements using the simple query protocol.
|
||||
///
|
||||
/// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
|
||||
/// point. This is intended for use when, for example, initializing a database schema.
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// Prepared statements should be use for any query which contains user-specified data, as they provided the
|
||||
/// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
|
||||
/// them to this method!
|
||||
pub async fn batch_execute(&self, query: &str) -> Result<ReadyForQueryStatus, Error> {
|
||||
simple_query::batch_execute(self.inner(), query).await
|
||||
}

    /// Begins a new database transaction.
    ///
    /// The transaction will roll back by default - use the `commit` method to commit it.
    pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
        struct RollbackIfNotDone<'me> {
            client: &'me Client,
            done: bool,
        }

        impl Drop for RollbackIfNotDone<'_> {
            fn drop(&mut self) {
                if self.done {
                    return;
                }

                let buf = self.client.inner().with_buf(|buf| {
                    frontend::query("ROLLBACK", buf).unwrap();
                    buf.split().freeze()
                });
                let _ = self
                    .client
                    .inner()
                    .send(RequestMessages::Single(FrontendMessage::Raw(buf)));
            }
        }

        // This guard is needed because the `Future` created by this method can be
        // dropped after `RequestMessages` is synchronously sent to the `Connection`
        // by `batch_execute()`, but before `Responses` is asynchronously polled to
        // completion. In that case the `Transaction` won't be created and thus
        // won't be rolled back.
        {
            let mut cleaner = RollbackIfNotDone {
                client: self,
                done: false,
            };
            self.batch_execute("BEGIN").await?;
            cleaner.done = true;
        }

        Ok(Transaction::new(self))
    }
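The drop guard above ensures a `BEGIN` is always paired with a `ROLLBACK` unless `commit` runs. Usage follows the familiar tokio-postgres shape; a sketch, assuming the fork keeps the usual `Transaction::execute`/`commit` methods:

```rust
// Hypothetical usage sketch: uncommitted transactions roll back on drop.
let transaction = client.transaction().await?;
transaction
    .execute("UPDATE accounts SET balance = balance - 10 WHERE id = $1", &[&1i32])
    .await?;
// Without this call, dropping `transaction` queues a ROLLBACK instead.
transaction.commit().await?;
```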

    /// Returns a builder for a transaction with custom settings.
    ///
    /// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other
    /// attributes.
    pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
        TransactionBuilder::new(self)
    }
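A sketch of the builder, assuming it keeps tokio-postgres's usual `isolation_level` and `start` methods (an assumption; the fork may trim this API):

```rust
// Hypothetical usage sketch for the transaction builder.
let transaction = client
    .build_transaction()
    .isolation_level(IsolationLevel::Serializable)
    .start()
    .await?;
// ... queries ...
transaction.commit().await?;
```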

    /// Constructs a cancellation token that can later be used to request cancellation of a query running on the
    /// connection associated with this client.
    pub fn cancel_token(&self) -> CancelToken {
        CancelToken {
            socket_config: self.socket_config.clone(),
            ssl_mode: self.ssl_mode,
            process_id: self.process_id,
            secret_key: self.secret_key,
        }
    }
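The token captures everything needed to open a separate cancellation connection later; a sketch (the `cancel_query` call mirrors upstream tokio-postgres and is an assumption here):

```rust
// Hypothetical usage sketch: the token can outlive the borrow of `client`.
let cancel_token = client.cancel_token();
// Later, from any task (assumed API, following upstream tokio-postgres):
// cancel_token.cancel_query(tls).await?;
```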

    /// Queries for type information.
    pub async fn get_type(&self, oid: Oid) -> Result<Type, Error> {
        crate::prepare::get_type(&self.inner, oid).await
    }

    /// Determines if the connection to the server has already closed.
    ///
    /// In that case, all future queries will fail.
    pub fn is_closed(&self) -> bool {
        self.inner.sender.is_closed()
    }
}

impl fmt::Debug for Client {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Client").finish()
    }
}
109
libs/proxy/tokio-postgres2/src/codec.rs
Normal file
@@ -0,0 +1,109 @@
use bytes::{Buf, Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use postgres_protocol2::message::backend;
use postgres_protocol2::message::frontend::CopyData;
use std::io;
use tokio_util::codec::{Decoder, Encoder};

pub enum FrontendMessage {
    Raw(Bytes),
    CopyData(CopyData<Box<dyn Buf + Send>>),
}

pub enum BackendMessage {
    Normal {
        messages: BackendMessages,
        request_complete: bool,
    },
    Async(backend::Message),
}

pub struct BackendMessages(BytesMut);

impl BackendMessages {
    pub fn empty() -> BackendMessages {
        BackendMessages(BytesMut::new())
    }
}

impl FallibleIterator for BackendMessages {
    type Item = backend::Message;
    type Error = io::Error;

    fn next(&mut self) -> io::Result<Option<backend::Message>> {
        backend::Message::parse(&mut self.0)
    }
}

pub struct PostgresCodec {
    pub max_message_size: Option<usize>,
}

impl Encoder<FrontendMessage> for PostgresCodec {
    type Error = io::Error;

    fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> {
        match item {
            FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf),
            FrontendMessage::CopyData(data) => data.write(dst),
        }

        Ok(())
    }
}

impl Decoder for PostgresCodec {
    type Item = BackendMessage;
    type Error = io::Error;

    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BackendMessage>, io::Error> {
        let mut idx = 0;
        let mut request_complete = false;

        while let Some(header) = backend::Header::parse(&src[idx..])? {
            let len = header.len() as usize + 1;
            if src[idx..].len() < len {
                break;
            }

            if let Some(max) = self.max_message_size {
                if len > max {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidInput,
                        "message too large",
                    ));
                }
            }

            match header.tag() {
                backend::NOTICE_RESPONSE_TAG
                | backend::NOTIFICATION_RESPONSE_TAG
                | backend::PARAMETER_STATUS_TAG => {
                    if idx == 0 {
                        let message = backend::Message::parse(src)?.unwrap();
                        return Ok(Some(BackendMessage::Async(message)));
                    } else {
                        break;
                    }
                }
                _ => {}
            }

            idx += len;

            if header.tag() == backend::READY_FOR_QUERY_TAG {
                request_complete = true;
                break;
            }
        }

        if idx == 0 {
            Ok(None)
        } else {
            Ok(Some(BackendMessage::Normal {
                messages: BackendMessages(src.split_to(idx)),
                request_complete,
            }))
        }
    }
}
897
libs/proxy/tokio-postgres2/src/config.rs
Normal file
@@ -0,0 +1,897 @@
//! Connection configuration.

use crate::connect::connect;
use crate::connect_raw::connect_raw;
use crate::tls::MakeTlsConnect;
use crate::tls::TlsConnect;
use crate::{Client, Connection, Error};
use std::borrow::Cow;
use std::str;
use std::str::FromStr;
use std::time::Duration;
use std::{error, fmt, iter, mem};
use tokio::io::{AsyncRead, AsyncWrite};

pub use postgres_protocol2::authentication::sasl::ScramKeys;
use tokio::net::TcpStream;

/// Properties required of a session.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum TargetSessionAttrs {
    /// No special properties are required.
    Any,
    /// The session must allow writes.
    ReadWrite,
}

/// TLS configuration.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum SslMode {
    /// Do not use TLS.
    Disable,
    /// Attempt to connect with TLS but allow sessions without.
    Prefer,
    /// Require the use of TLS.
    Require,
}

/// Channel binding configuration.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ChannelBinding {
    /// Do not use channel binding.
    Disable,
    /// Attempt to use channel binding but allow sessions without.
    Prefer,
    /// Require the use of channel binding.
    Require,
}

/// Replication mode configuration.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ReplicationMode {
    /// Physical replication.
    Physical,
    /// Logical replication.
    Logical,
}

/// A host specification.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Host {
    /// A TCP hostname.
    Tcp(String),
}

/// Precomputed keys which may override password during auth.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthKeys {
    /// A `ClientKey` & `ServerKey` pair for `SCRAM-SHA-256`.
    ScramSha256(ScramKeys<32>),
}

/// Connection configuration.
///
/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats:
///
/// # Key-Value
///
/// This format consists of space-separated key-value pairs. Values which are either the empty string or contain
/// whitespace should be wrapped in `'`. `'` and `\` characters should be backslash-escaped.
///
/// ## Keys
///
/// * `user` - The username to authenticate with. Required.
/// * `password` - The password to authenticate with.
/// * `dbname` - The name of the database to connect to. Defaults to the username.
/// * `options` - Command line options used to configure the server.
/// * `application_name` - Sets the `application_name` parameter on the server.
/// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used
/// if available, but not used otherwise. If set to `require`, TLS will be forced to be used. Defaults to `prefer`.
/// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the
/// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts
/// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting
/// with the `connect` method.
/// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be
/// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if
/// omitted or the empty string.
/// * `connect_timeout` - The time limit in seconds applied to each socket-level connection attempt. Note that hostnames
/// can resolve to multiple IP addresses, and this limit is applied to each address. Defaults to no timeout.
/// * `target_session_attrs` - Specifies requirements of the session. If set to `read-write`, the client will check that
/// the `transaction_read_only` session parameter is `off`. This can be used to connect to the primary server
/// in a database cluster as opposed to the secondary read-only mirrors. Defaults to `any`.
/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel
/// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise.
/// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`.
///
/// ## Examples
///
/// ```not_rust
/// host=localhost user=postgres connect_timeout=10
/// ```
///
/// ```not_rust
/// host=/var/lib/postgresql,localhost port=1234 user=postgres password='password with spaces'
/// ```
///
/// ```not_rust
/// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write
/// ```
///
/// # Url
///
/// This format resembles a URL with a scheme of either `postgres://` or `postgresql://`. All components are optional,
/// and the format accepts query parameters for all of the key-value pairs described in the section above. Multiple
/// host/port pairs can be comma-separated. Unix socket paths in the host section of the URL should be percent-encoded,
/// as the path component of the URL specifies the database name.
///
/// ## Examples
///
/// ```not_rust
/// postgresql://user@localhost
/// ```
///
/// ```not_rust
/// postgresql://user:password@%2Fvar%2Flib%2Fpostgresql/mydb?connect_timeout=10
/// ```
///
/// ```not_rust
/// postgresql://user@host1:1234,host2,host3:5678?target_session_attrs=read-write
/// ```
///
/// ```not_rust
/// postgresql:///mydb?user=user&host=/var/lib/postgresql
/// ```
#[derive(Clone, PartialEq, Eq)]
pub struct Config {
    pub(crate) user: Option<String>,
    pub(crate) password: Option<Vec<u8>>,
    pub(crate) auth_keys: Option<Box<AuthKeys>>,
    pub(crate) dbname: Option<String>,
    pub(crate) options: Option<String>,
    pub(crate) application_name: Option<String>,
    pub(crate) ssl_mode: SslMode,
    pub(crate) host: Vec<Host>,
    pub(crate) port: Vec<u16>,
    pub(crate) connect_timeout: Option<Duration>,
    pub(crate) target_session_attrs: TargetSessionAttrs,
    pub(crate) channel_binding: ChannelBinding,
    pub(crate) replication_mode: Option<ReplicationMode>,
    pub(crate) max_backend_message_size: Option<usize>,
}

impl Default for Config {
    fn default() -> Config {
        Config::new()
    }
}

impl Config {
    /// Creates a new configuration.
    pub fn new() -> Config {
        Config {
            user: None,
            password: None,
            auth_keys: None,
            dbname: None,
            options: None,
            application_name: None,
            ssl_mode: SslMode::Prefer,
            host: vec![],
            port: vec![],
            connect_timeout: None,
            target_session_attrs: TargetSessionAttrs::Any,
            channel_binding: ChannelBinding::Prefer,
            replication_mode: None,
            max_backend_message_size: None,
        }
    }

    /// Sets the user to authenticate with.
    ///
    /// Required.
    pub fn user(&mut self, user: &str) -> &mut Config {
        self.user = Some(user.to_string());
        self
    }

    /// Gets the user to authenticate with, if one has been configured with
    /// the `user` method.
    pub fn get_user(&self) -> Option<&str> {
        self.user.as_deref()
    }

    /// Sets the password to authenticate with.
    pub fn password<T>(&mut self, password: T) -> &mut Config
    where
        T: AsRef<[u8]>,
    {
        self.password = Some(password.as_ref().to_vec());
        self
    }

    /// Gets the password to authenticate with, if one has been configured with
    /// the `password` method.
    pub fn get_password(&self) -> Option<&[u8]> {
        self.password.as_deref()
    }

    /// Sets precomputed protocol-specific keys to authenticate with.
    /// When set, this option will override `password`.
    /// See [`AuthKeys`] for more information.
    pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config {
        self.auth_keys = Some(Box::new(keys));
        self
    }

    /// Gets the precomputed protocol-specific keys to authenticate with,
    /// if any have been configured with the `auth_keys` method.
    pub fn get_auth_keys(&self) -> Option<AuthKeys> {
        self.auth_keys.as_deref().copied()
    }

    /// Sets the name of the database to connect to.
    ///
    /// Defaults to the user.
    pub fn dbname(&mut self, dbname: &str) -> &mut Config {
        self.dbname = Some(dbname.to_string());
        self
    }

    /// Gets the name of the database to connect to, if one has been configured
    /// with the `dbname` method.
    pub fn get_dbname(&self) -> Option<&str> {
        self.dbname.as_deref()
    }

    /// Sets command line options used to configure the server.
    pub fn options(&mut self, options: &str) -> &mut Config {
        self.options = Some(options.to_string());
        self
    }

    /// Gets the command line options used to configure the server, if the
    /// options have been set with the `options` method.
    pub fn get_options(&self) -> Option<&str> {
        self.options.as_deref()
    }

    /// Sets the value of the `application_name` runtime parameter.
    pub fn application_name(&mut self, application_name: &str) -> &mut Config {
        self.application_name = Some(application_name.to_string());
        self
    }

    /// Gets the value of the `application_name` runtime parameter, if it has
    /// been set with the `application_name` method.
    pub fn get_application_name(&self) -> Option<&str> {
        self.application_name.as_deref()
    }

    /// Sets the SSL configuration.
    ///
    /// Defaults to `prefer`.
    pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
        self.ssl_mode = ssl_mode;
        self
    }

    /// Gets the SSL configuration.
    pub fn get_ssl_mode(&self) -> SslMode {
        self.ssl_mode
    }

    /// Adds a host to the configuration.
    ///
    /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order.
    pub fn host(&mut self, host: &str) -> &mut Config {
        self.host.push(Host::Tcp(host.to_string()));
        self
    }

    /// Gets the hosts that have been added to the configuration with `host`.
    pub fn get_hosts(&self) -> &[Host] {
        &self.host
    }

    /// Adds a port to the configuration.
    ///
    /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which
    /// case the default of 5432 is used, a single port, in which case it is used for all hosts, or the same number of
    /// ports as hosts.
    pub fn port(&mut self, port: u16) -> &mut Config {
        self.port.push(port);
        self
    }

    /// Gets the ports that have been added to the configuration with `port`.
    pub fn get_ports(&self) -> &[u16] {
        &self.port
    }

    /// Sets the timeout applied to socket-level connection attempts.
    ///
    /// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each
    /// host separately. Defaults to no limit.
    pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
        self.connect_timeout = Some(connect_timeout);
        self
    }

    /// Gets the connection timeout, if one has been set with the
    /// `connect_timeout` method.
    pub fn get_connect_timeout(&self) -> Option<&Duration> {
        self.connect_timeout.as_ref()
    }

    /// Sets the requirements of the session.
    ///
    /// This can be used to connect to the primary server in a clustered database rather than one of the read-only
    /// secondary servers. Defaults to `Any`.
    pub fn target_session_attrs(
        &mut self,
        target_session_attrs: TargetSessionAttrs,
    ) -> &mut Config {
        self.target_session_attrs = target_session_attrs;
        self
    }

    /// Gets the requirements of the session.
    pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
        self.target_session_attrs
    }

    /// Sets the channel binding behavior.
    ///
    /// Defaults to `prefer`.
    pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
        self.channel_binding = channel_binding;
        self
    }

    /// Gets the channel binding behavior.
    pub fn get_channel_binding(&self) -> ChannelBinding {
        self.channel_binding
    }

    /// Sets the replication mode.
    pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config {
        self.replication_mode = Some(replication_mode);
        self
    }

    /// Gets the replication mode.
    pub fn get_replication_mode(&self) -> Option<ReplicationMode> {
        self.replication_mode
    }

    /// Sets the limit on the size of backend messages.
    pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config {
        self.max_backend_message_size = Some(max_backend_message_size);
        self
    }

    /// Gets the limit on the size of backend messages.
    pub fn get_max_backend_message_size(&self) -> Option<usize> {
        self.max_backend_message_size
    }

    fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
        match key {
            "user" => {
                self.user(value);
            }
            "password" => {
                self.password(value);
            }
            "dbname" => {
                self.dbname(value);
            }
            "options" => {
                self.options(value);
            }
            "application_name" => {
                self.application_name(value);
            }
            "sslmode" => {
                let mode = match value {
                    "disable" => SslMode::Disable,
                    "prefer" => SslMode::Prefer,
                    "require" => SslMode::Require,
                    _ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))),
                };
                self.ssl_mode(mode);
            }
            "host" => {
                for host in value.split(',') {
                    self.host(host);
                }
            }
            "port" => {
                for port in value.split(',') {
                    let port = if port.is_empty() {
                        5432
                    } else {
                        port.parse()
                            .map_err(|_| Error::config_parse(Box::new(InvalidValue("port"))))?
                    };
                    self.port(port);
                }
            }
            "connect_timeout" => {
                let timeout = value
                    .parse::<i64>()
                    .map_err(|_| Error::config_parse(Box::new(InvalidValue("connect_timeout"))))?;
                if timeout > 0 {
                    self.connect_timeout(Duration::from_secs(timeout as u64));
                }
            }
            "target_session_attrs" => {
                let target_session_attrs = match value {
                    "any" => TargetSessionAttrs::Any,
                    "read-write" => TargetSessionAttrs::ReadWrite,
                    _ => {
                        return Err(Error::config_parse(Box::new(InvalidValue(
                            "target_session_attrs",
                        ))));
                    }
                };
                self.target_session_attrs(target_session_attrs);
            }
            "channel_binding" => {
                let channel_binding = match value {
                    "disable" => ChannelBinding::Disable,
                    "prefer" => ChannelBinding::Prefer,
                    "require" => ChannelBinding::Require,
                    _ => {
                        return Err(Error::config_parse(Box::new(InvalidValue(
                            "channel_binding",
                        ))))
                    }
                };
                self.channel_binding(channel_binding);
            }
            "max_backend_message_size" => {
                let limit = value.parse::<usize>().map_err(|_| {
                    Error::config_parse(Box::new(InvalidValue("max_backend_message_size")))
                })?;
                if limit > 0 {
                    self.max_backend_message_size(limit);
                }
            }
            key => {
                return Err(Error::config_parse(Box::new(UnknownOption(
                    key.to_string(),
                ))));
            }
        }

        Ok(())
    }

    /// Opens a connection to a PostgreSQL database.
    ///
    /// Requires the `runtime` Cargo feature (enabled by default).
    pub async fn connect<T>(
        &self,
        tls: T,
    ) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
    where
        T: MakeTlsConnect<TcpStream>,
    {
        connect(tls, self).await
    }

    /// Connects to a PostgreSQL database over an arbitrary stream.
    ///
    /// All of the settings other than `user`, `password`, `dbname`, `options`, and `application_name` are ignored.
    pub async fn connect_raw<S, T>(
        &self,
        stream: S,
        tls: T,
    ) -> Result<(Client, Connection<S, T::Stream>), Error>
    where
        S: AsyncRead + AsyncWrite + Unpin,
        T: TlsConnect<S>,
    {
        connect_raw(stream, tls, self).await
    }
}

impl FromStr for Config {
    type Err = Error;

    fn from_str(s: &str) -> Result<Config, Error> {
        match UrlParser::parse(s)? {
            Some(config) => Ok(config),
            None => Parser::parse(s),
        }
    }
}
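Since `Config` implements `FromStr`, both connection-string formats described above parse directly. A minimal sketch with placeholder host and credentials:

```rust
use std::str::FromStr;

// Hypothetical values; both the URL and key-value forms go through the
// same `param` machinery shown above.
let from_url = Config::from_str("postgresql://app_user@db.example.com:5433/appdb")?;
assert_eq!(from_url.get_dbname(), Some("appdb"));

let from_kv: Config = "host=db.example.com port=5433 user=app_user".parse()?;
assert_eq!(from_kv.get_ports(), &[5433]);
```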

// Omit password from debug output
impl fmt::Debug for Config {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        struct Redaction {}
        impl fmt::Debug for Redaction {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(f, "_")
            }
        }

        f.debug_struct("Config")
            .field("user", &self.user)
            .field("password", &self.password.as_ref().map(|_| Redaction {}))
            .field("dbname", &self.dbname)
            .field("options", &self.options)
            .field("application_name", &self.application_name)
            .field("ssl_mode", &self.ssl_mode)
            .field("host", &self.host)
            .field("port", &self.port)
            .field("connect_timeout", &self.connect_timeout)
            .field("target_session_attrs", &self.target_session_attrs)
            .field("channel_binding", &self.channel_binding)
            .field("replication", &self.replication_mode)
            .finish()
    }
}

#[derive(Debug)]
struct UnknownOption(String);

impl fmt::Display for UnknownOption {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(fmt, "unknown option `{}`", self.0)
    }
}

impl error::Error for UnknownOption {}

#[derive(Debug)]
struct InvalidValue(&'static str);

impl fmt::Display for InvalidValue {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(fmt, "invalid value for option `{}`", self.0)
    }
}

impl error::Error for InvalidValue {}

struct Parser<'a> {
    s: &'a str,
    it: iter::Peekable<str::CharIndices<'a>>,
}

impl<'a> Parser<'a> {
    fn parse(s: &'a str) -> Result<Config, Error> {
        let mut parser = Parser {
            s,
            it: s.char_indices().peekable(),
        };

        let mut config = Config::new();

        while let Some((key, value)) = parser.parameter()? {
            config.param(key, &value)?;
        }

        Ok(config)
    }

    fn skip_ws(&mut self) {
        self.take_while(char::is_whitespace);
    }

    fn take_while<F>(&mut self, f: F) -> &'a str
    where
        F: Fn(char) -> bool,
    {
        let start = match self.it.peek() {
            Some(&(i, _)) => i,
            None => return "",
        };

        loop {
            match self.it.peek() {
                Some(&(_, c)) if f(c) => {
                    self.it.next();
                }
                Some(&(i, _)) => return &self.s[start..i],
                None => return &self.s[start..],
            }
        }
    }

    fn eat(&mut self, target: char) -> Result<(), Error> {
        match self.it.next() {
            Some((_, c)) if c == target => Ok(()),
            Some((i, c)) => {
                let m = format!(
                    "unexpected character at byte {}: expected `{}` but got `{}`",
                    i, target, c
                );
                Err(Error::config_parse(m.into()))
            }
            None => Err(Error::config_parse("unexpected EOF".into())),
        }
    }

    fn eat_if(&mut self, target: char) -> bool {
        match self.it.peek() {
            Some(&(_, c)) if c == target => {
                self.it.next();
                true
            }
            _ => false,
        }
    }

    fn keyword(&mut self) -> Option<&'a str> {
        let s = self.take_while(|c| match c {
            c if c.is_whitespace() => false,
            '=' => false,
            _ => true,
        });

        if s.is_empty() {
            None
        } else {
            Some(s)
        }
    }

    fn value(&mut self) -> Result<String, Error> {
        let value = if self.eat_if('\'') {
            let value = self.quoted_value()?;
            self.eat('\'')?;
            value
        } else {
            self.simple_value()?
        };

        Ok(value)
    }

    fn simple_value(&mut self) -> Result<String, Error> {
        let mut value = String::new();

        while let Some(&(_, c)) = self.it.peek() {
            if c.is_whitespace() {
                break;
            }

            self.it.next();
            if c == '\\' {
                if let Some((_, c2)) = self.it.next() {
                    value.push(c2);
                }
            } else {
                value.push(c);
            }
        }

        if value.is_empty() {
            return Err(Error::config_parse("unexpected EOF".into()));
        }

        Ok(value)
    }

    fn quoted_value(&mut self) -> Result<String, Error> {
        let mut value = String::new();

        while let Some(&(_, c)) = self.it.peek() {
            if c == '\'' {
                return Ok(value);
            }

            self.it.next();
            if c == '\\' {
                if let Some((_, c2)) = self.it.next() {
                    value.push(c2);
                }
            } else {
                value.push(c);
            }
        }

        Err(Error::config_parse(
            "unterminated quoted connection parameter value".into(),
        ))
    }

    fn parameter(&mut self) -> Result<Option<(&'a str, String)>, Error> {
        self.skip_ws();
        let keyword = match self.keyword() {
            Some(keyword) => keyword,
            None => return Ok(None),
        };
        self.skip_ws();
        self.eat('=')?;
        self.skip_ws();
        let value = self.value()?;

        Ok(Some((keyword, value)))
    }
}

// This is a pretty sloppy "URL" parser, but it matches the behavior of libpq, where things really aren't very strict
struct UrlParser<'a> {
    s: &'a str,
    config: Config,
}

impl<'a> UrlParser<'a> {
    fn parse(s: &'a str) -> Result<Option<Config>, Error> {
        let s = match Self::remove_url_prefix(s) {
            Some(s) => s,
            None => return Ok(None),
        };

        let mut parser = UrlParser {
            s,
            config: Config::new(),
        };

        parser.parse_credentials()?;
        parser.parse_host()?;
        parser.parse_path()?;
        parser.parse_params()?;

        Ok(Some(parser.config))
    }

    fn remove_url_prefix(s: &str) -> Option<&str> {
        for prefix in &["postgres://", "postgresql://"] {
            if let Some(stripped) = s.strip_prefix(prefix) {
                return Some(stripped);
            }
        }

        None
    }

    fn take_until(&mut self, end: &[char]) -> Option<&'a str> {
        match self.s.find(end) {
            Some(pos) => {
                let (head, tail) = self.s.split_at(pos);
                self.s = tail;
                Some(head)
            }
            None => None,
        }
    }

    fn take_all(&mut self) -> &'a str {
        mem::take(&mut self.s)
    }

    fn eat_byte(&mut self) {
        self.s = &self.s[1..];
    }

    fn parse_credentials(&mut self) -> Result<(), Error> {
        let creds = match self.take_until(&['@']) {
            Some(creds) => creds,
            None => return Ok(()),
        };
        self.eat_byte();

        let mut it = creds.splitn(2, ':');
        let user = self.decode(it.next().unwrap())?;
        self.config.user(&user);

        if let Some(password) = it.next() {
            let password = Cow::from(percent_encoding::percent_decode(password.as_bytes()));
            self.config.password(password);
        }

        Ok(())
    }

    fn parse_host(&mut self) -> Result<(), Error> {
        let host = match self.take_until(&['/', '?']) {
            Some(host) => host,
            None => self.take_all(),
        };

        if host.is_empty() {
            return Ok(());
        }

        for chunk in host.split(',') {
            let (host, port) = if chunk.starts_with('[') {
                let idx = match chunk.find(']') {
                    Some(idx) => idx,
                    None => return Err(Error::config_parse(InvalidValue("host").into())),
                };

                let host = &chunk[1..idx];
                let remaining = &chunk[idx + 1..];
                let port = if let Some(port) = remaining.strip_prefix(':') {
                    Some(port)
                } else if remaining.is_empty() {
                    None
                } else {
                    return Err(Error::config_parse(InvalidValue("host").into()));
                };

                (host, port)
            } else {
                let mut it = chunk.splitn(2, ':');
                (it.next().unwrap(), it.next())
            };

            self.host_param(host)?;
            let port = self.decode(port.unwrap_or("5432"))?;
            self.config.param("port", &port)?;
        }

        Ok(())
    }

    fn parse_path(&mut self) -> Result<(), Error> {
        if !self.s.starts_with('/') {
            return Ok(());
        }
        self.eat_byte();

        let dbname = match self.take_until(&['?']) {
            Some(dbname) => dbname,
            None => self.take_all(),
        };

        if !dbname.is_empty() {
            self.config.dbname(&self.decode(dbname)?);
        }

        Ok(())
    }

    fn parse_params(&mut self) -> Result<(), Error> {
        if !self.s.starts_with('?') {
            return Ok(());
        }
        self.eat_byte();

        while !self.s.is_empty() {
            let key = match self.take_until(&['=']) {
                Some(key) => self.decode(key)?,
                None => return Err(Error::config_parse("unterminated parameter".into())),
            };
            self.eat_byte();

            let value = match self.take_until(&['&']) {
                Some(value) => {
                    self.eat_byte();
                    value
                }
                None => self.take_all(),
            };

            if key == "host" {
                self.host_param(value)?;
            } else {
                let value = self.decode(value)?;
                self.config.param(&key, &value)?;
            }
        }

        Ok(())
    }

    fn host_param(&mut self, s: &str) -> Result<(), Error> {
        let s = self.decode(s)?;
        self.config.param("host", &s)
    }

    fn decode(&self, s: &'a str) -> Result<Cow<'a, str>, Error> {
        percent_encoding::percent_decode(s.as_bytes())
            .decode_utf8()
            .map_err(|e| Error::config_parse(e.into()))
    }
}
112
libs/proxy/tokio-postgres2/src/connect.rs
Normal file
@@ -0,0 +1,112 @@
use crate::client::SocketConfig;
use crate::config::{Host, TargetSessionAttrs};
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::{Client, Config, Connection, Error, SimpleQueryMessage};
use futures_util::{future, pin_mut, Future, FutureExt, Stream};
use std::io;
use std::task::Poll;
use tokio::net::TcpStream;

pub async fn connect<T>(
    mut tls: T,
    config: &Config,
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
where
    T: MakeTlsConnect<TcpStream>,
{
    if config.host.is_empty() {
        return Err(Error::config("host missing".into()));
    }

    if config.port.len() > 1 && config.port.len() != config.host.len() {
        return Err(Error::config("invalid number of ports".into()));
    }

    let mut error = None;
    for (i, host) in config.host.iter().enumerate() {
        let port = config
            .port
            .get(i)
            .or_else(|| config.port.first())
            .copied()
            .unwrap_or(5432);

        let hostname = match host {
            Host::Tcp(host) => host.as_str(),
        };

        let tls = tls
            .make_tls_connect(hostname)
            .map_err(|e| Error::tls(e.into()))?;

        match connect_once(host, port, tls, config).await {
            Ok((client, connection)) => return Ok((client, connection)),
            Err(e) => error = Some(e),
        }
    }

    Err(error.unwrap())
}

async fn connect_once<T>(
    host: &Host,
    port: u16,
    tls: T,
    config: &Config,
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
where
    T: TlsConnect<TcpStream>,
{
    let socket = connect_socket(host, port, config.connect_timeout).await?;
    let (mut client, mut connection) = connect_raw(socket, tls, config).await?;

    if let TargetSessionAttrs::ReadWrite = config.target_session_attrs {
        let rows = client.simple_query_raw("SHOW transaction_read_only");
        pin_mut!(rows);

        let rows = future::poll_fn(|cx| {
            if connection.poll_unpin(cx)?.is_ready() {
                return Poll::Ready(Err(Error::closed()));
            }

            rows.as_mut().poll(cx)
        })
        .await?;
        pin_mut!(rows);

        loop {
            let next = future::poll_fn(|cx| {
                if connection.poll_unpin(cx)?.is_ready() {
                    return Poll::Ready(Some(Err(Error::closed())));
                }

                rows.as_mut().poll_next(cx)
            });

            match next.await.transpose()? {
                Some(SimpleQueryMessage::Row(row)) => {
                    if row.try_get(0)? == Some("on") {
                        return Err(Error::connect(io::Error::new(
                            io::ErrorKind::PermissionDenied,
                            "database does not allow writes",
                        )));
                    } else {
                        break;
                    }
                }
                Some(_) => {}
                None => return Err(Error::unexpected_message()),
            }
        }
    }

    client.set_socket_config(SocketConfig {
        host: host.clone(),
        port,
        connect_timeout: config.connect_timeout,
    });

    Ok((client, connection))
}
359
libs/proxy/tokio-postgres2/src/connect_raw.rs
Normal file
@@ -0,0 +1,359 @@
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::config::{self, AuthKeys, Config, ReplicationMode};
use crate::connect_tls::connect_tls;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::{TlsConnect, TlsStream};
use crate::{Client, Connection, Error};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Sink, SinkExt, Stream, TryStreamExt};
use postgres_protocol2::authentication;
use postgres_protocol2::authentication::sasl;
use postgres_protocol2::authentication::sasl::ScramSha256;
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message};
use postgres_protocol2::message::frontend;
use std::collections::{HashMap, VecDeque};
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::mpsc;
use tokio_util::codec::Framed;

pub struct StartupStream<S, T> {
    inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
    buf: BackendMessages,
    delayed: VecDeque<BackendMessage>,
}

impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Error = io::Error;

    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_ready(cx)
    }

    fn start_send(mut self: Pin<&mut Self>, item: FrontendMessage) -> io::Result<()> {
        Pin::new(&mut self.inner).start_send(item)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_close(cx)
    }
}

impl<S, T> Stream for StartupStream<S, T>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Item = io::Result<Message>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<io::Result<Message>>> {
        loop {
            match self.buf.next() {
                Ok(Some(message)) => return Poll::Ready(Some(Ok(message))),
                Ok(None) => {}
                Err(e) => return Poll::Ready(Some(Err(e))),
            }

            match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
                Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages,
                Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))),
                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
                None => return Poll::Ready(None),
            }
        }
    }
}

pub async fn connect_raw<S, T>(
    stream: S,
    tls: T,
    config: &Config,
) -> Result<(Client, Connection<S, T::Stream>), Error>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: TlsConnect<S>,
{
    let stream = connect_tls(stream, config.ssl_mode, tls).await?;

    let mut stream = StartupStream {
        inner: Framed::new(
            stream,
            PostgresCodec {
                max_message_size: config.max_backend_message_size,
            },
        ),
        buf: BackendMessages::empty(),
        delayed: VecDeque::new(),
    };

    startup(&mut stream, config).await?;
    authenticate(&mut stream, config).await?;
    let (process_id, secret_key, parameters) = read_info(&mut stream).await?;

    let (sender, receiver) = mpsc::unbounded_channel();
    let client = Client::new(sender, config.ssl_mode, process_id, secret_key);
    let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver);

    Ok((client, connection))
}

async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: AsyncRead + AsyncWrite + Unpin,
{
    let mut params = vec![("client_encoding", "UTF8")];
    if let Some(user) = &config.user {
        params.push(("user", &**user));
    }
    if let Some(dbname) = &config.dbname {
        params.push(("database", &**dbname));
    }
    if let Some(options) = &config.options {
        params.push(("options", &**options));
    }
    if let Some(application_name) = &config.application_name {
        params.push(("application_name", &**application_name));
    }
    if let Some(replication_mode) = &config.replication_mode {
        match replication_mode {
            ReplicationMode::Physical => params.push(("replication", "true")),
            ReplicationMode::Logical => params.push(("replication", "database")),
        }
    }

    let mut buf = BytesMut::new();
    frontend::startup_message(params, &mut buf).map_err(Error::encode)?;

    stream
        .send(FrontendMessage::Raw(buf.freeze()))
        .await
        .map_err(Error::io)
}

async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: TlsStream + Unpin,
{
    match stream.try_next().await.map_err(Error::io)? {
        Some(Message::AuthenticationOk) => {
            can_skip_channel_binding(config)?;
            return Ok(());
        }
        Some(Message::AuthenticationCleartextPassword) => {
            can_skip_channel_binding(config)?;

            let pass = config
                .password
                .as_ref()
                .ok_or_else(|| Error::config("password missing".into()))?;

            authenticate_password(stream, pass).await?;
        }
        Some(Message::AuthenticationMd5Password(body)) => {
            can_skip_channel_binding(config)?;

            let user = config
                .user
                .as_ref()
                .ok_or_else(|| Error::config("user missing".into()))?;
            let pass = config
                .password
                .as_ref()
                .ok_or_else(|| Error::config("password missing".into()))?;

            let output = authentication::md5_hash(user.as_bytes(), pass, body.salt());
            authenticate_password(stream, output.as_bytes()).await?;
        }
        Some(Message::AuthenticationSasl(body)) => {
            authenticate_sasl(stream, body, config).await?;
        }
        Some(Message::AuthenticationKerberosV5)
        | Some(Message::AuthenticationScmCredential)
        | Some(Message::AuthenticationGss)
        | Some(Message::AuthenticationSspi) => {
            return Err(Error::authentication(
                "unsupported authentication method".into(),
            ))
        }
        Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
        Some(_) => return Err(Error::unexpected_message()),
        None => return Err(Error::closed()),
    }

    match stream.try_next().await.map_err(Error::io)? {
        Some(Message::AuthenticationOk) => Ok(()),
        Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
        Some(_) => Err(Error::unexpected_message()),
        None => Err(Error::closed()),
    }
}

fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
    match config.channel_binding {
        config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
        config::ChannelBinding::Require => Err(Error::authentication(
            "server did not use channel binding".into(),
        )),
    }
}

async fn authenticate_password<S, T>(
    stream: &mut StartupStream<S, T>,
    password: &[u8],
) -> Result<(), Error>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: AsyncRead + AsyncWrite + Unpin,
{
    let mut buf = BytesMut::new();
    frontend::password_message(password, &mut buf).map_err(Error::encode)?;

    stream
        .send(FrontendMessage::Raw(buf.freeze()))
        .await
        .map_err(Error::io)
}

async fn authenticate_sasl<S, T>(
    stream: &mut StartupStream<S, T>,
    body: AuthenticationSaslBody,
    config: &Config,
) -> Result<(), Error>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: TlsStream + Unpin,
{
    let mut has_scram = false;
    let mut has_scram_plus = false;
    let mut mechanisms = body.mechanisms();
    while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
        match mechanism {
            sasl::SCRAM_SHA_256 => has_scram = true,
            sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
            _ => {}
        }
    }

    let channel_binding = stream
        .inner
        .get_ref()
        .channel_binding()
        .tls_server_end_point
        .filter(|_| config.channel_binding != config::ChannelBinding::Disable)
        .map(sasl::ChannelBinding::tls_server_end_point);

    let (channel_binding, mechanism) = if has_scram_plus {
        match channel_binding {
            Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
            None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
        }
    } else if has_scram {
        match channel_binding {
            Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
            None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
        }
    } else {
        return Err(Error::authentication("unsupported SASL mechanism".into()));
    };

    if mechanism != sasl::SCRAM_SHA_256_PLUS {
        can_skip_channel_binding(config)?;
    }

    let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() {
        ScramSha256::new_with_keys(keys, channel_binding)
    } else if let Some(password) = config.get_password() {
        ScramSha256::new(password, channel_binding)
    } else {
        return Err(Error::config("password or auth keys missing".into()));
    };

    let mut buf = BytesMut::new();
    frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
    stream
        .send(FrontendMessage::Raw(buf.freeze()))
        .await
        .map_err(Error::io)?;

    let body = match stream.try_next().await.map_err(Error::io)? {
        Some(Message::AuthenticationSaslContinue(body)) => body,
        Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
        Some(_) => return Err(Error::unexpected_message()),
        None => return Err(Error::closed()),
    };

    scram
        .update(body.data())
        .await
        .map_err(|e| Error::authentication(e.into()))?;

    let mut buf = BytesMut::new();
    frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
    stream
        .send(FrontendMessage::Raw(buf.freeze()))
        .await
        .map_err(Error::io)?;

    let body = match stream.try_next().await.map_err(Error::io)? {
        Some(Message::AuthenticationSaslFinal(body)) => body,
        Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
        Some(_) => return Err(Error::unexpected_message()),
        None => return Err(Error::closed()),
    };

    scram
        .finish(body.data())
        .map_err(|e| Error::authentication(e.into()))?;

    Ok(())
}

async fn read_info<S, T>(
    stream: &mut StartupStream<S, T>,
) -> Result<(i32, i32, HashMap<String, String>), Error>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: AsyncRead + AsyncWrite + Unpin,
{
    let mut process_id = 0;
    let mut secret_key = 0;
    let mut parameters = HashMap::new();

    loop {
        match stream.try_next().await.map_err(Error::io)? {
            Some(Message::BackendKeyData(body)) => {
                process_id = body.process_id();
                secret_key = body.secret_key();
            }
            Some(Message::ParameterStatus(body)) => {
                parameters.insert(
                    body.name().map_err(Error::parse)?.to_string(),
                    body.value().map_err(Error::parse)?.to_string(),
                );
            }
            Some(msg @ Message::NoticeResponse(_)) => {
                stream.delayed.push_back(BackendMessage::Async(msg))
            }
            Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
            Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
            Some(_) => return Err(Error::unexpected_message()),
            None => return Err(Error::closed()),
        }
    }
}
65
libs/proxy/tokio-postgres2/src/connect_socket.rs
Normal file
@@ -0,0 +1,65 @@
use crate::config::Host;
use crate::Error;
use std::future::Future;
use std::io;
use std::time::Duration;
use tokio::net::{self, TcpStream};
use tokio::time;

pub(crate) async fn connect_socket(
    host: &Host,
    port: u16,
    connect_timeout: Option<Duration>,
) -> Result<TcpStream, Error> {
    match host {
        Host::Tcp(host) => {
            let addrs = net::lookup_host((&**host, port))
                .await
                .map_err(Error::connect)?;

            let mut last_err = None;

            for addr in addrs {
                let stream =
                    match connect_with_timeout(TcpStream::connect(addr), connect_timeout).await {
                        Ok(stream) => stream,
                        Err(e) => {
                            last_err = Some(e);
                            continue;
                        }
                    };

                stream.set_nodelay(true).map_err(Error::connect)?;

                return Ok(stream);
            }

            Err(last_err.unwrap_or_else(|| {
                Error::connect(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "could not resolve any addresses",
                ))
            }))
        }
    }
}

async fn connect_with_timeout<F, T>(connect: F, timeout: Option<Duration>) -> Result<T, Error>
where
    F: Future<Output = io::Result<T>>,
{
    match timeout {
        Some(timeout) => match time::timeout(timeout, connect).await {
            Ok(Ok(socket)) => Ok(socket),
            Ok(Err(e)) => Err(Error::connect(e)),
            Err(_) => Err(Error::connect(io::Error::new(
                io::ErrorKind::TimedOut,
                "connection timed out",
            ))),
        },
        None => match connect.await {
            Ok(socket) => Ok(socket),
            Err(e) => Err(Error::connect(e)),
        },
    }
}
48
libs/proxy/tokio-postgres2/src/connect_tls.rs
Normal file
@@ -0,0 +1,48 @@
use crate::config::SslMode;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::private::ForcePrivateApi;
use crate::tls::TlsConnect;
use crate::Error;
use bytes::BytesMut;
use postgres_protocol2::message::frontend;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

pub async fn connect_tls<S, T>(
    mut stream: S,
    mode: SslMode,
    tls: T,
) -> Result<MaybeTlsStream<S, T::Stream>, Error>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: TlsConnect<S>,
{
    match mode {
        SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
        SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => {
            return Ok(MaybeTlsStream::Raw(stream))
        }
        SslMode::Prefer | SslMode::Require => {}
    }

    let mut buf = BytesMut::new();
    frontend::ssl_request(&mut buf);
    stream.write_all(&buf).await.map_err(Error::io)?;

    let mut buf = [0];
    stream.read_exact(&mut buf).await.map_err(Error::io)?;

    if buf[0] != b'S' {
        if SslMode::Require == mode {
            return Err(Error::tls("server does not support TLS".into()));
        } else {
            return Ok(MaybeTlsStream::Raw(stream));
        }
    }

    let stream = tls
        .connect(stream)
        .await
        .map_err(|e| Error::tls(e.into()))?;

    Ok(MaybeTlsStream::Tls(stream))
}
323
libs/proxy/tokio-postgres2/src/connection.rs
Normal file
@@ -0,0 +1,323 @@
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::error::DbError;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::{AsyncMessage, Error, Notification};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Sink, Stream};
use log::{info, trace};
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::mpsc;
use tokio_util::codec::Framed;
use tokio_util::sync::PollSender;

pub enum RequestMessages {
    Single(FrontendMessage),
}

pub struct Request {
    pub messages: RequestMessages,
    pub sender: mpsc::Sender<BackendMessages>,
}

pub struct Response {
    sender: PollSender<BackendMessages>,
}

#[derive(PartialEq, Debug)]
enum State {
    Active,
    Terminating,
    Closing,
}

/// A connection to a PostgreSQL database.
///
/// This is one half of what is returned when a new connection is established. It performs the actual IO with the
/// server, and should generally be spawned off onto an executor to run in the background.
///
/// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
/// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
#[must_use = "futures do nothing unless polled"]
pub struct Connection<S, T> {
    /// HACK: we need this in the Neon Proxy.
    pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
    /// HACK: we need this in the Neon Proxy to forward params.
    pub parameters: HashMap<String, String>,
    receiver: mpsc::UnboundedReceiver<Request>,
    pending_request: Option<RequestMessages>,
    pending_responses: VecDeque<BackendMessage>,
    responses: VecDeque<Response>,
    state: State,
}
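As the doc comment says, the `Connection` half must be polled for the `Client` to make progress; the conventional pattern is to spawn it. A minimal sketch, assuming a tokio runtime and a TLS connector compatible with this crate:

```rust
// Hypothetical usage sketch: split client and connection, then drive the
// connection in the background so client queries can complete.
let (client, connection) = config.connect(tls).await?;
tokio::spawn(async move {
    // Resolves only when the connection closes; surface fatal errors here.
    if let Err(e) = connection.await {
        eprintln!("connection error: {e}");
    }
});
```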

impl<S, T> Connection<S, T>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: AsyncRead + AsyncWrite + Unpin,
{
    pub(crate) fn new(
        stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
        pending_responses: VecDeque<BackendMessage>,
        parameters: HashMap<String, String>,
        receiver: mpsc::UnboundedReceiver<Request>,
    ) -> Connection<S, T> {
        Connection {
            stream,
            parameters,
            receiver,
            pending_request: None,
            pending_responses,
            responses: VecDeque::new(),
            state: State::Active,
        }
    }

    fn poll_response(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<BackendMessage, Error>>> {
        if let Some(message) = self.pending_responses.pop_front() {
            trace!("retrying pending response");
            return Poll::Ready(Some(Ok(message)));
        }

        Pin::new(&mut self.stream)
            .poll_next(cx)
            .map(|o| o.map(|r| r.map_err(Error::io)))
    }

    fn poll_read(&mut self, cx: &mut Context<'_>) -> Result<Option<AsyncMessage>, Error> {
        if self.state != State::Active {
            trace!("poll_read: done");
            return Ok(None);
        }

        loop {
            let message = match self.poll_response(cx)? {
                Poll::Ready(Some(message)) => message,
                Poll::Ready(None) => return Err(Error::closed()),
                Poll::Pending => {
                    trace!("poll_read: waiting on response");
                    return Ok(None);
                }
            };

            let (mut messages, request_complete) = match message {
                BackendMessage::Async(Message::NoticeResponse(body)) => {
                    let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
                    return Ok(Some(AsyncMessage::Notice(error)));
                }
                BackendMessage::Async(Message::NotificationResponse(body)) => {
                    let notification = Notification {
                        process_id: body.process_id(),
                        channel: body.channel().map_err(Error::parse)?.to_string(),
                        payload: body.message().map_err(Error::parse)?.to_string(),
                    };
                    return Ok(Some(AsyncMessage::Notification(notification)));
                }
                BackendMessage::Async(Message::ParameterStatus(body)) => {
                    self.parameters.insert(
                        body.name().map_err(Error::parse)?.to_string(),
                        body.value().map_err(Error::parse)?.to_string(),
                    );
                    continue;
                }
                BackendMessage::Async(_) => unreachable!(),
                BackendMessage::Normal {
                    messages,
                    request_complete,
                } => (messages, request_complete),
            };

            let mut response = match self.responses.pop_front() {
                Some(response) => response,
                None => match messages.next().map_err(Error::parse)? {
                    Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
                    _ => return Err(Error::unexpected_message()),
                },
            };

            match response.sender.poll_reserve(cx) {
                Poll::Ready(Ok(())) => {
                    let _ = response.sender.send_item(messages);
                    if !request_complete {
                        self.responses.push_front(response);
                    }
                }
                Poll::Ready(Err(_)) => {
                    // we need to keep paging through the rest of the messages even if the receiver's hung up
                    if !request_complete {
                        self.responses.push_front(response);
                    }
                }
                Poll::Pending => {
                    self.responses.push_front(response);
                    self.pending_responses.push_back(BackendMessage::Normal {
                        messages,
                        request_complete,
                    });
                    trace!("poll_read: waiting on sender");
                    return Ok(None);
                }
            }
        }
    }

    fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
        if let Some(messages) = self.pending_request.take() {
            trace!("retrying pending request");
            return Poll::Ready(Some(messages));
        }

        if self.receiver.is_closed() {
            return Poll::Ready(None);
        }

        match self.receiver.poll_recv(cx) {
            Poll::Ready(Some(request)) => {
                trace!("polled new request");
                self.responses.push_back(Response {
                    sender: PollSender::new(request.sender),
                });
                Poll::Ready(Some(request.messages))
            }
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }

    fn poll_write(&mut self, cx: &mut Context<'_>) -> Result<bool, Error> {
||||
loop {
|
||||
if self.state == State::Closing {
|
||||
trace!("poll_write: done");
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
if Pin::new(&mut self.stream)
|
||||
.poll_ready(cx)
|
||||
.map_err(Error::io)?
|
||||
.is_pending()
|
||||
{
|
||||
trace!("poll_write: waiting on socket");
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let request = match self.poll_request(cx) {
|
||||
Poll::Ready(Some(request)) => request,
|
||||
Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
|
||||
trace!("poll_write: at eof, terminating");
|
||||
self.state = State::Terminating;
|
||||
let mut request = BytesMut::new();
|
||||
frontend::terminate(&mut request);
|
||||
RequestMessages::Single(FrontendMessage::Raw(request.freeze()))
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
trace!(
|
||||
"poll_write: at eof, pending responses {}",
|
||||
self.responses.len()
|
||||
);
|
||||
return Ok(true);
|
||||
}
|
||||
Poll::Pending => {
|
||||
trace!("poll_write: waiting on request");
|
||||
return Ok(true);
|
||||
}
|
||||
};
|
||||
|
||||
match request {
|
||||
RequestMessages::Single(request) => {
|
||||
Pin::new(&mut self.stream)
|
||||
.start_send(request)
|
||||
.map_err(Error::io)?;
|
||||
if self.state == State::Terminating {
|
||||
trace!("poll_write: sent eof, closing");
|
||||
self.state = State::Closing;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
|
||||
match Pin::new(&mut self.stream)
|
||||
.poll_flush(cx)
|
||||
.map_err(Error::io)?
|
||||
{
|
||||
Poll::Ready(()) => trace!("poll_flush: flushed"),
|
||||
Poll::Pending => trace!("poll_flush: waiting on socket"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
if self.state != State::Closing {
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
match Pin::new(&mut self.stream)
|
||||
.poll_close(cx)
|
||||
.map_err(Error::io)?
|
||||
{
|
||||
Poll::Ready(()) => {
|
||||
trace!("poll_shutdown: complete");
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
Poll::Pending => {
|
||||
trace!("poll_shutdown: waiting on socket");
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the value of a runtime parameter for this connection.
|
||||
pub fn parameter(&self, name: &str) -> Option<&str> {
|
||||
self.parameters.get(name).map(|s| &**s)
|
||||
}
|
||||
|
||||
/// Polls for asynchronous messages from the server.
|
||||
///
|
||||
/// The server can send notices as well as notifications asynchronously to the client. Applications that wish to
|
||||
/// examine those messages should use this method to drive the connection rather than its `Future` implementation.
|
||||
pub fn poll_message(
|
||||
&mut self,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<AsyncMessage, Error>>> {
|
||||
let message = self.poll_read(cx)?;
|
||||
let want_flush = self.poll_write(cx)?;
|
||||
if want_flush {
|
||||
self.poll_flush(cx)?;
|
||||
}
|
||||
match message {
|
||||
Some(message) => Poll::Ready(Some(Ok(message))),
|
||||
None => match self.poll_shutdown(cx) {
|
||||
Poll::Ready(Ok(())) => Poll::Ready(None),
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
|
||||
Poll::Pending => Poll::Pending,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T> Future for Connection<S, T>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
type Output = Result<(), Error>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
while let Some(message) = ready!(self.poll_message(cx)?) {
|
||||
if let AsyncMessage::Notice(notice) = message {
|
||||
info!("{}: {}", notice.severity(), notice.message());
|
||||
}
|
||||
}
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
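
The doc comment on `Connection` says it should generally be spawned onto an executor while its `Client` half issues queries. A minimal sketch of that split, assuming the `connect` helper this crate exports from `lib.rs`; the connection string is a placeholder:

use tokio_postgres2::{connect, NoTls};

async fn run() -> Result<(), tokio_postgres2::Error> {
    let (client, connection) = connect("host=localhost user=postgres", NoTls).await?;
    // The spawned task owns the Connection and performs all socket IO; the
    // future resolves only after `client` is dropped and outstanding work
    // has completed, so a fatal error is logged from the task.
    tokio::spawn(async move {
        if let Err(e) = connection.await {
            eprintln!("connection error: {}", e);
        }
    });
    // Queries would go through `client` here while the spawned task pumps IO.
    drop(client);
    Ok(())
}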
501
libs/proxy/tokio-postgres2/src/error/mod.rs
Normal file
501
libs/proxy/tokio-postgres2/src/error/mod.rs
Normal file
@@ -0,0 +1,501 @@
//! Errors.

use fallible_iterator::FallibleIterator;
use postgres_protocol2::message::backend::{ErrorFields, ErrorResponseBody};
use std::error::{self, Error as _Error};
use std::fmt;
use std::io;

pub use self::sqlstate::*;

#[allow(clippy::unreadable_literal)]
mod sqlstate;

/// The severity of a Postgres error or notice.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Severity {
    /// PANIC
    Panic,
    /// FATAL
    Fatal,
    /// ERROR
    Error,
    /// WARNING
    Warning,
    /// NOTICE
    Notice,
    /// DEBUG
    Debug,
    /// INFO
    Info,
    /// LOG
    Log,
}

impl fmt::Display for Severity {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let s = match *self {
            Severity::Panic => "PANIC",
            Severity::Fatal => "FATAL",
            Severity::Error => "ERROR",
            Severity::Warning => "WARNING",
            Severity::Notice => "NOTICE",
            Severity::Debug => "DEBUG",
            Severity::Info => "INFO",
            Severity::Log => "LOG",
        };
        fmt.write_str(s)
    }
}

impl Severity {
    fn from_str(s: &str) -> Option<Severity> {
        match s {
            "PANIC" => Some(Severity::Panic),
            "FATAL" => Some(Severity::Fatal),
            "ERROR" => Some(Severity::Error),
            "WARNING" => Some(Severity::Warning),
            "NOTICE" => Some(Severity::Notice),
            "DEBUG" => Some(Severity::Debug),
            "INFO" => Some(Severity::Info),
            "LOG" => Some(Severity::Log),
            _ => None,
        }
    }
}

/// A Postgres error or notice.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DbError {
    severity: String,
    parsed_severity: Option<Severity>,
    code: SqlState,
    message: String,
    detail: Option<String>,
    hint: Option<String>,
    position: Option<ErrorPosition>,
    where_: Option<String>,
    schema: Option<String>,
    table: Option<String>,
    column: Option<String>,
    datatype: Option<String>,
    constraint: Option<String>,
    file: Option<String>,
    line: Option<u32>,
    routine: Option<String>,
}

impl DbError {
    pub(crate) fn parse(fields: &mut ErrorFields<'_>) -> io::Result<DbError> {
        let mut severity = None;
        let mut parsed_severity = None;
        let mut code = None;
        let mut message = None;
        let mut detail = None;
        let mut hint = None;
        let mut normal_position = None;
        let mut internal_position = None;
        let mut internal_query = None;
        let mut where_ = None;
        let mut schema = None;
        let mut table = None;
        let mut column = None;
        let mut datatype = None;
        let mut constraint = None;
        let mut file = None;
        let mut line = None;
        let mut routine = None;

        while let Some(field) = fields.next()? {
            match field.type_() {
                b'S' => severity = Some(field.value().to_owned()),
                b'C' => code = Some(SqlState::from_code(field.value())),
                b'M' => message = Some(field.value().to_owned()),
                b'D' => detail = Some(field.value().to_owned()),
                b'H' => hint = Some(field.value().to_owned()),
                b'P' => {
                    normal_position = Some(field.value().parse::<u32>().map_err(|_| {
                        io::Error::new(
                            io::ErrorKind::InvalidInput,
                            "`P` field did not contain an integer",
                        )
                    })?);
                }
                b'p' => {
                    internal_position = Some(field.value().parse::<u32>().map_err(|_| {
                        io::Error::new(
                            io::ErrorKind::InvalidInput,
                            "`p` field did not contain an integer",
                        )
                    })?);
                }
                b'q' => internal_query = Some(field.value().to_owned()),
                b'W' => where_ = Some(field.value().to_owned()),
                b's' => schema = Some(field.value().to_owned()),
                b't' => table = Some(field.value().to_owned()),
                b'c' => column = Some(field.value().to_owned()),
                b'd' => datatype = Some(field.value().to_owned()),
                b'n' => constraint = Some(field.value().to_owned()),
                b'F' => file = Some(field.value().to_owned()),
                b'L' => {
                    line = Some(field.value().parse::<u32>().map_err(|_| {
                        io::Error::new(
                            io::ErrorKind::InvalidInput,
                            "`L` field did not contain an integer",
                        )
                    })?);
                }
                b'R' => routine = Some(field.value().to_owned()),
                b'V' => {
                    parsed_severity = Some(Severity::from_str(field.value()).ok_or_else(|| {
                        io::Error::new(
                            io::ErrorKind::InvalidInput,
                            "`V` field contained an invalid value",
                        )
                    })?);
                }
                _ => {}
            }
        }

        Ok(DbError {
            severity: severity
                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`S` field missing"))?,
            parsed_severity,
            code: code
                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`C` field missing"))?,
            message: message
                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`M` field missing"))?,
            detail,
            hint,
            position: match normal_position {
                Some(position) => Some(ErrorPosition::Original(position)),
                None => match internal_position {
                    Some(position) => Some(ErrorPosition::Internal {
                        position,
                        query: internal_query.ok_or_else(|| {
                            io::Error::new(
                                io::ErrorKind::InvalidInput,
                                "`q` field missing but `p` field present",
                            )
                        })?,
                    }),
                    None => None,
                },
            },
            where_,
            schema,
            table,
            column,
            datatype,
            constraint,
            file,
            line,
            routine,
        })
    }

    /// The field contents are ERROR, FATAL, or PANIC (in an error message),
    /// or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a
    /// localized translation of one of these.
    pub fn severity(&self) -> &str {
        &self.severity
    }

    /// A parsed, nonlocalized version of `severity`. (PostgreSQL 9.6+)
    pub fn parsed_severity(&self) -> Option<Severity> {
        self.parsed_severity
    }

    /// The SQLSTATE code for the error.
    pub fn code(&self) -> &SqlState {
        &self.code
    }

    /// The primary human-readable error message.
    ///
    /// This should be accurate but terse (typically one line).
    pub fn message(&self) -> &str {
        &self.message
    }

    /// An optional secondary error message carrying more detail about the
    /// problem.
    ///
    /// Might run to multiple lines.
    pub fn detail(&self) -> Option<&str> {
        self.detail.as_deref()
    }

    /// An optional suggestion of what to do about the problem.
    ///
    /// This is intended to differ from `detail` in that it offers advice
    /// (potentially inappropriate) rather than hard facts. Might run to
    /// multiple lines.
    pub fn hint(&self) -> Option<&str> {
        self.hint.as_deref()
    }

    /// An optional error cursor position into either the original query string
    /// or an internally generated query.
    pub fn position(&self) -> Option<&ErrorPosition> {
        self.position.as_ref()
    }

    /// An indication of the context in which the error occurred.
    ///
    /// Presently this includes a call stack traceback of active procedural
    /// language functions and internally-generated queries. The trace is one
    /// entry per line, most recent first.
    pub fn where_(&self) -> Option<&str> {
        self.where_.as_deref()
    }

    /// If the error was associated with a specific database object, the name
    /// of the schema containing that object, if any. (PostgreSQL 9.3+)
    pub fn schema(&self) -> Option<&str> {
        self.schema.as_deref()
    }

    /// If the error was associated with a specific table, the name of the
    /// table. (Refer to the schema name field for the name of the table's
    /// schema.) (PostgreSQL 9.3+)
    pub fn table(&self) -> Option<&str> {
        self.table.as_deref()
    }

    /// If the error was associated with a specific table column, the name of
    /// the column.
    ///
    /// (Refer to the schema and table name fields to identify the table.)
    /// (PostgreSQL 9.3+)
    pub fn column(&self) -> Option<&str> {
        self.column.as_deref()
    }

    /// If the error was associated with a specific data type, the name of the
    /// data type. (Refer to the schema name field for the name of the data
    /// type's schema.) (PostgreSQL 9.3+)
    pub fn datatype(&self) -> Option<&str> {
        self.datatype.as_deref()
    }

    /// If the error was associated with a specific constraint, the name of the
    /// constraint.
    ///
    /// Refer to fields listed above for the associated table or domain.
    /// (For this purpose, indexes are treated as constraints, even if they
    /// weren't created with constraint syntax.) (PostgreSQL 9.3+)
    pub fn constraint(&self) -> Option<&str> {
        self.constraint.as_deref()
    }

    /// The file name of the source-code location where the error was reported.
    pub fn file(&self) -> Option<&str> {
        self.file.as_deref()
    }

    /// The line number of the source-code location where the error was
    /// reported.
    pub fn line(&self) -> Option<u32> {
        self.line
    }

    /// The name of the source-code routine reporting the error.
    pub fn routine(&self) -> Option<&str> {
        self.routine.as_deref()
    }
}

impl fmt::Display for DbError {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(fmt, "{}: {}", self.severity, self.message)?;
        if let Some(detail) = &self.detail {
            write!(fmt, "\nDETAIL: {}", detail)?;
        }
        if let Some(hint) = &self.hint {
            write!(fmt, "\nHINT: {}", hint)?;
        }
        Ok(())
    }
}

impl error::Error for DbError {}

/// Represents the position of an error in a query.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum ErrorPosition {
    /// A position in the original query.
    Original(u32),
    /// A position in an internally generated query.
    Internal {
        /// The byte position.
        position: u32,
        /// A query generated by the Postgres server.
        query: String,
    },
}

#[derive(Debug, PartialEq)]
enum Kind {
    Io,
    UnexpectedMessage,
    Tls,
    ToSql(usize),
    FromSql(usize),
    Column(String),
    Closed,
    Db,
    Parse,
    Encode,
    Authentication,
    ConfigParse,
    Config,
    Connect,
    Timeout,
}

struct ErrorInner {
    kind: Kind,
    cause: Option<Box<dyn error::Error + Sync + Send>>,
}

/// An error communicating with the Postgres server.
pub struct Error(Box<ErrorInner>);

impl fmt::Debug for Error {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("Error")
            .field("kind", &self.0.kind)
            .field("cause", &self.0.cause)
            .finish()
    }
}

impl fmt::Display for Error {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.0.kind {
            Kind::Io => fmt.write_str("error communicating with the server")?,
            Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?,
            Kind::Tls => fmt.write_str("error performing TLS handshake")?,
            Kind::ToSql(idx) => write!(fmt, "error serializing parameter {}", idx)?,
            Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?,
            Kind::Column(column) => write!(fmt, "invalid column `{}`", column)?,
            Kind::Closed => fmt.write_str("connection closed")?,
            Kind::Db => fmt.write_str("db error")?,
            Kind::Parse => fmt.write_str("error parsing response from server")?,
            Kind::Encode => fmt.write_str("error encoding message to server")?,
            Kind::Authentication => fmt.write_str("authentication error")?,
            Kind::ConfigParse => fmt.write_str("invalid connection string")?,
            Kind::Config => fmt.write_str("invalid configuration")?,
            Kind::Connect => fmt.write_str("error connecting to server")?,
            Kind::Timeout => fmt.write_str("timeout waiting for server")?,
        };
        if let Some(ref cause) = self.0.cause {
            write!(fmt, ": {}", cause)?;
        }
        Ok(())
    }
}

impl error::Error for Error {
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        self.0.cause.as_ref().map(|e| &**e as _)
    }
}

impl Error {
    /// Consumes the error, returning its cause.
    pub fn into_source(self) -> Option<Box<dyn error::Error + Sync + Send>> {
        self.0.cause
    }

    /// Returns the source of this error if it was a `DbError`.
    ///
    /// This is a simple convenience method.
    pub fn as_db_error(&self) -> Option<&DbError> {
        self.source().and_then(|e| e.downcast_ref::<DbError>())
    }

    /// Determines if the error was associated with a closed connection.
    pub fn is_closed(&self) -> bool {
        self.0.kind == Kind::Closed
    }

    /// Returns the SQLSTATE error code associated with the error.
    ///
    /// This is a convenience method that downcasts the cause to a `DbError` and returns its code.
    pub fn code(&self) -> Option<&SqlState> {
        self.as_db_error().map(DbError::code)
    }

    fn new(kind: Kind, cause: Option<Box<dyn error::Error + Sync + Send>>) -> Error {
        Error(Box::new(ErrorInner { kind, cause }))
    }

    pub(crate) fn closed() -> Error {
        Error::new(Kind::Closed, None)
    }

    pub(crate) fn unexpected_message() -> Error {
        Error::new(Kind::UnexpectedMessage, None)
    }

    #[allow(clippy::needless_pass_by_value)]
    pub(crate) fn db(error: ErrorResponseBody) -> Error {
        match DbError::parse(&mut error.fields()) {
            Ok(e) => Error::new(Kind::Db, Some(Box::new(e))),
            Err(e) => Error::new(Kind::Parse, Some(Box::new(e))),
        }
    }

    pub(crate) fn parse(e: io::Error) -> Error {
        Error::new(Kind::Parse, Some(Box::new(e)))
    }

    pub(crate) fn encode(e: io::Error) -> Error {
        Error::new(Kind::Encode, Some(Box::new(e)))
    }

    #[allow(clippy::wrong_self_convention)]
    pub(crate) fn to_sql(e: Box<dyn error::Error + Sync + Send>, idx: usize) -> Error {
        Error::new(Kind::ToSql(idx), Some(e))
    }

    pub(crate) fn from_sql(e: Box<dyn error::Error + Sync + Send>, idx: usize) -> Error {
        Error::new(Kind::FromSql(idx), Some(e))
    }

    pub(crate) fn column(column: String) -> Error {
        Error::new(Kind::Column(column), None)
    }

    pub(crate) fn tls(e: Box<dyn error::Error + Sync + Send>) -> Error {
        Error::new(Kind::Tls, Some(e))
    }

    pub(crate) fn io(e: io::Error) -> Error {
        Error::new(Kind::Io, Some(Box::new(e)))
    }

    pub(crate) fn authentication(e: Box<dyn error::Error + Sync + Send>) -> Error {
        Error::new(Kind::Authentication, Some(e))
    }

    pub(crate) fn config_parse(e: Box<dyn error::Error + Sync + Send>) -> Error {
        Error::new(Kind::ConfigParse, Some(e))
    }

    pub(crate) fn config(e: Box<dyn error::Error + Sync + Send>) -> Error {
        Error::new(Kind::Config, Some(e))
    }

    pub(crate) fn connect(e: io::Error) -> Error {
        Error::new(Kind::Connect, Some(Box::new(e)))
    }

    #[doc(hidden)]
    pub fn __private_api_timeout() -> Error {
        Error::new(Kind::Timeout, None)
    }
}
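
A short sketch of how callers inspect these errors; `prepare.rs` below uses exactly this `code()` comparison to fall back to a pre-9.2 typeinfo query. The `{:?}` formatting of `SqlState` assumes the suppressed `sqlstate.rs` derives `Debug`:

use tokio_postgres2::error::{Error, SqlState};

// Mirrors the fallback test in prepare.rs: compare the SQLSTATE directly.
fn is_missing_table(err: &Error) -> bool {
    err.code() == Some(&SqlState::UNDEFINED_TABLE)
}

// Downcast to the DbError cause for the structured server-side fields.
fn log_db_error(err: &Error) {
    if let Some(db) = err.as_db_error() {
        eprintln!("{}: {} ({:?})", db.severity(), db.message(), db.code());
    }
}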
1670
libs/proxy/tokio-postgres2/src/error/sqlstate.rs
Normal file
1670
libs/proxy/tokio-postgres2/src/error/sqlstate.rs
Normal file
File diff suppressed because it is too large
64
libs/proxy/tokio-postgres2/src/generic_client.rs
Normal file
64
libs/proxy/tokio-postgres2/src/generic_client.rs
Normal file
@@ -0,0 +1,64 @@
use crate::query::RowStream;
use crate::types::Type;
use crate::{Client, Error, Transaction};
use async_trait::async_trait;
use postgres_protocol2::Oid;

mod private {
    pub trait Sealed {}
}

/// A trait allowing abstraction over connections and transactions.
///
/// This trait is "sealed", and cannot be implemented outside of this crate.
#[async_trait]
pub trait GenericClient: private::Sealed {
    /// Like `Client::query_raw_txt`.
    async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
    where
        S: AsRef<str> + Sync + Send,
        I: IntoIterator<Item = Option<S>> + Sync + Send,
        I::IntoIter: ExactSizeIterator + Sync + Send;

    /// Query for type information
    async fn get_type(&self, oid: Oid) -> Result<Type, Error>;
}

impl private::Sealed for Client {}

#[async_trait]
impl GenericClient for Client {
    async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
    where
        S: AsRef<str> + Sync + Send,
        I: IntoIterator<Item = Option<S>> + Sync + Send,
        I::IntoIter: ExactSizeIterator + Sync + Send,
    {
        self.query_raw_txt(statement, params).await
    }

    /// Query for type information
    async fn get_type(&self, oid: Oid) -> Result<Type, Error> {
        self.get_type(oid).await
    }
}

impl private::Sealed for Transaction<'_> {}

#[async_trait]
#[allow(clippy::needless_lifetimes)]
impl GenericClient for Transaction<'_> {
    async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
    where
        S: AsRef<str> + Sync + Send,
        I: IntoIterator<Item = Option<S>> + Sync + Send,
        I::IntoIter: ExactSizeIterator + Sync + Send,
    {
        self.query_raw_txt(statement, params).await
    }

    /// Query for type information
    async fn get_type(&self, oid: Oid) -> Result<Type, Error> {
        self.client().get_type(oid).await
    }
}
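
Because the trait is implemented for both `Client` and `Transaction<'_>`, one helper can serve either. A hedged sketch (the table name is a placeholder; `RowStream` is `!Unpin`, hence the `pin_mut!`):

use futures_util::{pin_mut, TryStreamExt};
use tokio_postgres2::{Error, GenericClient};

async fn first_name<C: GenericClient>(conn: &C) -> Result<Option<String>, Error> {
    let rows = conn
        .query_raw_txt("SELECT name FROM users LIMIT 1", std::iter::empty::<Option<String>>())
        .await?;
    pin_mut!(rows);
    match rows.try_next().await? {
        // query_raw_txt transfers results in text format, so as_text applies.
        Some(row) => Ok(row.as_text(0)?.map(str::to_owned)),
        None => Ok(None),
    }
}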
148
libs/proxy/tokio-postgres2/src/lib.rs
Normal file
148
libs/proxy/tokio-postgres2/src/lib.rs
Normal file
@@ -0,0 +1,148 @@
//! An asynchronous, pipelined, PostgreSQL client.
#![warn(rust_2018_idioms, clippy::all, missing_docs)]

pub use crate::cancel_token::CancelToken;
pub use crate::client::Client;
pub use crate::config::Config;
pub use crate::connection::Connection;
use crate::error::DbError;
pub use crate::error::Error;
pub use crate::generic_client::GenericClient;
pub use crate::query::RowStream;
pub use crate::row::{Row, SimpleQueryRow};
pub use crate::simple_query::SimpleQueryStream;
pub use crate::statement::{Column, Statement};
use crate::tls::MakeTlsConnect;
pub use crate::tls::NoTls;
pub use crate::to_statement::ToStatement;
pub use crate::transaction::Transaction;
pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder};
use crate::types::ToSql;
use postgres_protocol2::message::backend::ReadyForQueryBody;
use tokio::net::TcpStream;

/// After executing a query, the connection will be in one of these states
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(u8)]
pub enum ReadyForQueryStatus {
    /// Connection state is unknown
    Unknown,
    /// Connection is idle (no transactions)
    Idle = b'I',
    /// Connection is in a transaction block
    Transaction = b'T',
    /// Connection is in a failed transaction block
    FailedTransaction = b'E',
}

impl From<ReadyForQueryBody> for ReadyForQueryStatus {
    fn from(value: ReadyForQueryBody) -> Self {
        match value.status() {
            b'I' => Self::Idle,
            b'T' => Self::Transaction,
            b'E' => Self::FailedTransaction,
            _ => Self::Unknown,
        }
    }
}

mod cancel_query;
mod cancel_query_raw;
mod cancel_token;
mod client;
mod codec;
pub mod config;
mod connect;
mod connect_raw;
mod connect_socket;
mod connect_tls;
mod connection;
pub mod error;
mod generic_client;
pub mod maybe_tls_stream;
mod prepare;
mod query;
pub mod row;
mod simple_query;
mod statement;
pub mod tls;
mod to_statement;
mod transaction;
mod transaction_builder;
pub mod types;

/// A convenience function which parses a connection string and connects to the database.
///
/// See the documentation for [`Config`] for details on the connection string format.
///
/// Requires the `runtime` Cargo feature (enabled by default).
///
/// [`Config`]: config/struct.Config.html
pub async fn connect<T>(
    config: &str,
    tls: T,
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
where
    T: MakeTlsConnect<TcpStream>,
{
    let config = config.parse::<Config>()?;
    config.connect(tls).await
}

/// An asynchronous notification.
#[derive(Clone, Debug)]
pub struct Notification {
    process_id: i32,
    channel: String,
    payload: String,
}

impl Notification {
    /// The process ID of the notifying backend process.
    pub fn process_id(&self) -> i32 {
        self.process_id
    }

    /// The name of the channel that the notify has been raised on.
    pub fn channel(&self) -> &str {
        &self.channel
    }

    /// The "payload" string passed from the notifying process.
    pub fn payload(&self) -> &str {
        &self.payload
    }
}

/// An asynchronous message from the server.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum AsyncMessage {
    /// A notice.
    ///
    /// Notices use the same format as errors, but aren't "errors" per se.
    Notice(DbError),
    /// A notification.
    ///
    /// Connections can subscribe to notifications with the `LISTEN` command.
    Notification(Notification),
}

/// Message returned by the `SimpleQuery` stream.
#[derive(Debug)]
#[non_exhaustive]
pub enum SimpleQueryMessage {
    /// A row of data.
    Row(SimpleQueryRow),
    /// A statement in the query has completed.
    ///
    /// The number of rows modified or selected is returned.
    CommandComplete(u64),
}

fn slice_iter<'a>(
    s: &'a [&'a (dyn ToSql + Sync)],
) -> impl ExactSizeIterator<Item = &'a (dyn ToSql + Sync)> + 'a {
    s.iter().map(|s| *s as _)
}
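
The `Connection` docs note that applications wanting to observe notices and notifications should drive the connection through `poll_message` instead of its `Future` impl. A sketch of that loop, assuming the connection came from `connect` and channels were subscribed with `LISTEN`:

use futures_util::future;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres2::{AsyncMessage, Connection};

async fn drain_messages<S, T>(mut connection: Connection<S, T>)
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: AsyncRead + AsyncWrite + Unpin,
{
    while let Some(msg) = future::poll_fn(|cx| connection.poll_message(cx)).await {
        match msg {
            Ok(AsyncMessage::Notification(n)) => println!("{}: {}", n.channel(), n.payload()),
            Ok(AsyncMessage::Notice(notice)) => eprintln!("{}", notice),
            Ok(_) => {} // AsyncMessage is #[non_exhaustive]
            Err(e) => {
                eprintln!("connection error: {}", e);
                break;
            }
        }
    }
}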
77
libs/proxy/tokio-postgres2/src/maybe_tls_stream.rs
Normal file
77
libs/proxy/tokio-postgres2/src/maybe_tls_stream.rs
Normal file
@@ -0,0 +1,77 @@
//! MaybeTlsStream.
//!
//! Represents a stream that may or may not be encrypted with TLS.
use crate::tls::{ChannelBinding, TlsStream};
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

/// A stream that may or may not be encrypted with TLS.
pub enum MaybeTlsStream<S, T> {
    /// An unencrypted stream.
    Raw(S),
    /// An encrypted stream.
    Tls(T),
}

impl<S, T> AsyncRead for MaybeTlsStream<S, T>
where
    S: AsyncRead + Unpin,
    T: AsyncRead + Unpin,
{
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        match &mut *self {
            MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf),
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
        }
    }
}

impl<S, T> AsyncWrite for MaybeTlsStream<S, T>
where
    S: AsyncWrite + Unpin,
    T: AsyncWrite + Unpin,
{
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        match &mut *self {
            MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf),
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
        }
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match &mut *self {
            MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx),
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
        }
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match &mut *self {
            MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx),
            MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
        }
    }
}

impl<S, T> TlsStream for MaybeTlsStream<S, T>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: TlsStream + Unpin,
{
    fn channel_binding(&self) -> ChannelBinding {
        match self {
            MaybeTlsStream::Raw(_) => ChannelBinding::none(),
            MaybeTlsStream::Tls(s) => s.channel_binding(),
        }
    }
}
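
The enum gives static dispatch over the raw and TLS cases without a boxed trait object. A small sketch using an in-memory `DuplexStream` to stand in for both type parameters:

use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt, DuplexStream};
use tokio_postgres2::maybe_tls_stream::MaybeTlsStream;

async fn demo() -> std::io::Result<()> {
    let (a, mut b) = duplex(64);
    // A Raw stream behaves exactly like the wrapped stream.
    let mut s: MaybeTlsStream<DuplexStream, DuplexStream> = MaybeTlsStream::Raw(a);
    s.write_all(b"startup").await?;
    let mut buf = [0u8; 7];
    b.read_exact(&mut buf).await?;
    assert_eq!(&buf, b"startup");
    Ok(())
}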
262
libs/proxy/tokio-postgres2/src/prepare.rs
Normal file
262
libs/proxy/tokio-postgres2/src/prepare.rs
Normal file
@@ -0,0 +1,262 @@
use crate::client::InnerClient;
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::error::SqlState;
use crate::types::{Field, Kind, Oid, Type};
use crate::{query, slice_iter};
use crate::{Column, Error, Statement};
use bytes::Bytes;
use fallible_iterator::FallibleIterator;
use futures_util::{pin_mut, TryStreamExt};
use log::debug;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

pub(crate) const TYPEINFO_QUERY: &str = "\
SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid
FROM pg_catalog.pg_type t
LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid
INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
WHERE t.oid = $1
";

// Range types weren't added until Postgres 9.2, so pg_range may not exist
const TYPEINFO_FALLBACK_QUERY: &str = "\
SELECT t.typname, t.typtype, t.typelem, NULL::OID, t.typbasetype, n.nspname, t.typrelid
FROM pg_catalog.pg_type t
INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
WHERE t.oid = $1
";

const TYPEINFO_ENUM_QUERY: &str = "\
SELECT enumlabel
FROM pg_catalog.pg_enum
WHERE enumtypid = $1
ORDER BY enumsortorder
";

// Postgres 9.0 didn't have enumsortorder
const TYPEINFO_ENUM_FALLBACK_QUERY: &str = "\
SELECT enumlabel
FROM pg_catalog.pg_enum
WHERE enumtypid = $1
ORDER BY oid
";

pub(crate) const TYPEINFO_COMPOSITE_QUERY: &str = "\
SELECT attname, atttypid
FROM pg_catalog.pg_attribute
WHERE attrelid = $1
AND NOT attisdropped
AND attnum > 0
ORDER BY attnum
";

static NEXT_ID: AtomicUsize = AtomicUsize::new(0);

pub async fn prepare(
    client: &Arc<InnerClient>,
    query: &str,
    types: &[Type],
) -> Result<Statement, Error> {
    let name = format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst));
    let buf = encode(client, &name, query, types)?;
    let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;

    match responses.next().await? {
        Message::ParseComplete => {}
        _ => return Err(Error::unexpected_message()),
    }

    let parameter_description = match responses.next().await? {
        Message::ParameterDescription(body) => body,
        _ => return Err(Error::unexpected_message()),
    };

    let row_description = match responses.next().await? {
        Message::RowDescription(body) => Some(body),
        Message::NoData => None,
        _ => return Err(Error::unexpected_message()),
    };

    let mut parameters = vec![];
    let mut it = parameter_description.parameters();
    while let Some(oid) = it.next().map_err(Error::parse)? {
        let type_ = get_type(client, oid).await?;
        parameters.push(type_);
    }

    let mut columns = vec![];
    if let Some(row_description) = row_description {
        let mut it = row_description.fields();
        while let Some(field) = it.next().map_err(Error::parse)? {
            let type_ = get_type(client, field.type_oid()).await?;
            let column = Column::new(field.name().to_string(), type_, field);
            columns.push(column);
        }
    }

    Ok(Statement::new(client, name, parameters, columns))
}

fn prepare_rec<'a>(
    client: &'a Arc<InnerClient>,
    query: &'a str,
    types: &'a [Type],
) -> Pin<Box<dyn Future<Output = Result<Statement, Error>> + 'a + Send>> {
    Box::pin(prepare(client, query, types))
}

fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result<Bytes, Error> {
    if types.is_empty() {
        debug!("preparing query {}: {}", name, query);
    } else {
        debug!("preparing query {} with types {:?}: {}", name, types, query);
    }

    client.with_buf(|buf| {
        frontend::parse(name, query, types.iter().map(Type::oid), buf).map_err(Error::encode)?;
        frontend::describe(b'S', name, buf).map_err(Error::encode)?;
        frontend::sync(buf);
        Ok(buf.split().freeze())
    })
}

pub async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
    if let Some(type_) = Type::from_oid(oid) {
        return Ok(type_);
    }

    if let Some(type_) = client.type_(oid) {
        return Ok(type_);
    }

    let stmt = typeinfo_statement(client).await?;

    let rows = query::query(client, stmt, slice_iter(&[&oid])).await?;
    pin_mut!(rows);

    let row = match rows.try_next().await? {
        Some(row) => row,
        None => return Err(Error::unexpected_message()),
    };

    let name: String = row.try_get(0)?;
    let type_: i8 = row.try_get(1)?;
    let elem_oid: Oid = row.try_get(2)?;
    let rngsubtype: Option<Oid> = row.try_get(3)?;
    let basetype: Oid = row.try_get(4)?;
    let schema: String = row.try_get(5)?;
    let relid: Oid = row.try_get(6)?;

    let kind = if type_ == b'e' as i8 {
        let variants = get_enum_variants(client, oid).await?;
        Kind::Enum(variants)
    } else if type_ == b'p' as i8 {
        Kind::Pseudo
    } else if basetype != 0 {
        let type_ = get_type_rec(client, basetype).await?;
        Kind::Domain(type_)
    } else if elem_oid != 0 {
        let type_ = get_type_rec(client, elem_oid).await?;
        Kind::Array(type_)
    } else if relid != 0 {
        let fields = get_composite_fields(client, relid).await?;
        Kind::Composite(fields)
    } else if let Some(rngsubtype) = rngsubtype {
        let type_ = get_type_rec(client, rngsubtype).await?;
        Kind::Range(type_)
    } else {
        Kind::Simple
    };

    let type_ = Type::new(name, oid, kind, schema);
    client.set_type(oid, &type_);

    Ok(type_)
}

fn get_type_rec<'a>(
    client: &'a Arc<InnerClient>,
    oid: Oid,
) -> Pin<Box<dyn Future<Output = Result<Type, Error>> + Send + 'a>> {
    Box::pin(get_type(client, oid))
}

async fn typeinfo_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> {
    if let Some(stmt) = client.typeinfo() {
        return Ok(stmt);
    }

    let stmt = match prepare_rec(client, TYPEINFO_QUERY, &[]).await {
        Ok(stmt) => stmt,
        Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_TABLE) => {
            prepare_rec(client, TYPEINFO_FALLBACK_QUERY, &[]).await?
        }
        Err(e) => return Err(e),
    };

    client.set_typeinfo(&stmt);
    Ok(stmt)
}

async fn get_enum_variants(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<String>, Error> {
    let stmt = typeinfo_enum_statement(client).await?;

    query::query(client, stmt, slice_iter(&[&oid]))
        .await?
        .and_then(|row| async move { row.try_get(0) })
        .try_collect()
        .await
}

async fn typeinfo_enum_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> {
    if let Some(stmt) = client.typeinfo_enum() {
        return Ok(stmt);
    }

    let stmt = match prepare_rec(client, TYPEINFO_ENUM_QUERY, &[]).await {
        Ok(stmt) => stmt,
        Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_COLUMN) => {
            prepare_rec(client, TYPEINFO_ENUM_FALLBACK_QUERY, &[]).await?
        }
        Err(e) => return Err(e),
    };

    client.set_typeinfo_enum(&stmt);
    Ok(stmt)
}

async fn get_composite_fields(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<Field>, Error> {
    let stmt = typeinfo_composite_statement(client).await?;

    let rows = query::query(client, stmt, slice_iter(&[&oid]))
        .await?
        .try_collect::<Vec<_>>()
        .await?;

    let mut fields = vec![];
    for row in rows {
        let name = row.try_get(0)?;
        let oid = row.try_get(1)?;
        let type_ = get_type_rec(client, oid).await?;
        fields.push(Field::new(name, type_));
    }

    Ok(fields)
}

async fn typeinfo_composite_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> {
    if let Some(stmt) = client.typeinfo_composite() {
        return Ok(stmt);
    }

    let stmt = prepare_rec(client, TYPEINFO_COMPOSITE_QUERY, &[]).await?;

    client.set_typeinfo_composite(&stmt);
    Ok(stmt)
}
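
`prepare_rec` and `get_type_rec` exist because `get_type` is indirectly recursive: a domain's base type, an array's element type, a range's subtype, and a composite's fields can each require another lookup. An `async fn` that awaits itself would have an infinitely-sized future type, so the recursive entry points return a boxed future. The same pattern in a minimal standalone sketch:

use std::future::Future;
use std::pin::Pin;

// Boxing gives the recursive future a finite size; without it the compiler
// rejects the self-referential `async fn`.
fn countdown(n: u32) -> Pin<Box<dyn Future<Output = ()> + Send>> {
    Box::pin(async move {
        if n > 0 {
            countdown(n - 1).await;
        }
    })
}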
340
libs/proxy/tokio-postgres2/src/query.rs
Normal file
340
libs/proxy/tokio-postgres2/src/query.rs
Normal file
@@ -0,0 +1,340 @@
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::types::IsNull;
use crate::{Column, Error, ReadyForQueryStatus, Row, Statement};
use bytes::{BufMut, Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Stream};
use log::{debug, log_enabled, Level};
use pin_project_lite::pin_project;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use postgres_types2::{Format, ToSql, Type};
use std::fmt;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

struct BorrowToSqlParamsDebug<'a>(&'a [&'a (dyn ToSql + Sync)]);

impl fmt::Debug for BorrowToSqlParamsDebug<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_list().entries(self.0.iter()).finish()
    }
}

pub async fn query<'a, I>(
    client: &InnerClient,
    statement: Statement,
    params: I,
) -> Result<RowStream, Error>
where
    I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
    I::IntoIter: ExactSizeIterator,
{
    let buf = if log_enabled!(Level::Debug) {
        let params = params.into_iter().collect::<Vec<_>>();
        debug!(
            "executing statement {} with parameters: {:?}",
            statement.name(),
            BorrowToSqlParamsDebug(params.as_slice()),
        );
        encode(client, &statement, params)?
    } else {
        encode(client, &statement, params)?
    };
    let responses = start(client, buf).await?;
    Ok(RowStream {
        statement,
        responses,
        command_tag: None,
        status: ReadyForQueryStatus::Unknown,
        output_format: Format::Binary,
        _p: PhantomPinned,
    })
}

pub async fn query_txt<S, I>(
    client: &Arc<InnerClient>,
    query: &str,
    params: I,
) -> Result<RowStream, Error>
where
    S: AsRef<str>,
    I: IntoIterator<Item = Option<S>>,
    I::IntoIter: ExactSizeIterator,
{
    let params = params.into_iter();

    let buf = client.with_buf(|buf| {
        frontend::parse(
            "",                 // unnamed prepared statement
            query,              // query to parse
            std::iter::empty(), // give no type info
            buf,
        )
        .map_err(Error::encode)?;
        frontend::describe(b'S', "", buf).map_err(Error::encode)?;
        // Bind, pass params as text, retrieve as binary
        match frontend::bind(
            "",                 // empty string selects the unnamed portal
            "",                 // unnamed prepared statement
            std::iter::empty(), // all parameters use the default format (text)
            params,
            |param, buf| match param {
                Some(param) => {
                    buf.put_slice(param.as_ref().as_bytes());
                    Ok(postgres_protocol2::IsNull::No)
                }
                None => Ok(postgres_protocol2::IsNull::Yes),
            },
            Some(0), // all text
            buf,
        ) {
            Ok(()) => Ok(()),
            Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)),
            Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
        }?;

        // Execute
        frontend::execute("", 0, buf).map_err(Error::encode)?;
        // Sync
        frontend::sync(buf);

        Ok(buf.split().freeze())
    })?;

    // now read the responses
    let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;

    match responses.next().await? {
        Message::ParseComplete => {}
        _ => return Err(Error::unexpected_message()),
    }

    let parameter_description = match responses.next().await? {
        Message::ParameterDescription(body) => body,
        _ => return Err(Error::unexpected_message()),
    };

    let row_description = match responses.next().await? {
        Message::RowDescription(body) => Some(body),
        Message::NoData => None,
        _ => return Err(Error::unexpected_message()),
    };

    match responses.next().await? {
        Message::BindComplete => {}
        _ => return Err(Error::unexpected_message()),
    }

    let mut parameters = vec![];
    let mut it = parameter_description.parameters();
    while let Some(oid) = it.next().map_err(Error::parse)? {
        let type_ = Type::from_oid(oid).unwrap_or(Type::UNKNOWN);
        parameters.push(type_);
    }

    let mut columns = vec![];
    if let Some(row_description) = row_description {
        let mut it = row_description.fields();
        while let Some(field) = it.next().map_err(Error::parse)? {
            let type_ = Type::from_oid(field.type_oid()).unwrap_or(Type::UNKNOWN);
            let column = Column::new(field.name().to_string(), type_, field);
            columns.push(column);
        }
    }

    Ok(RowStream {
        statement: Statement::new_anonymous(parameters, columns),
        responses,
        command_tag: None,
        status: ReadyForQueryStatus::Unknown,
        output_format: Format::Text,
        _p: PhantomPinned,
    })
}

pub async fn execute<'a, I>(
    client: &InnerClient,
    statement: Statement,
    params: I,
) -> Result<u64, Error>
where
    I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
    I::IntoIter: ExactSizeIterator,
{
    let buf = if log_enabled!(Level::Debug) {
        let params = params.into_iter().collect::<Vec<_>>();
        debug!(
            "executing statement {} with parameters: {:?}",
            statement.name(),
            BorrowToSqlParamsDebug(params.as_slice()),
        );
        encode(client, &statement, params)?
    } else {
        encode(client, &statement, params)?
    };
    let mut responses = start(client, buf).await?;

    let mut rows = 0;
    loop {
        match responses.next().await? {
            Message::DataRow(_) => {}
            Message::CommandComplete(body) => {
                rows = body
                    .tag()
                    .map_err(Error::parse)?
                    .rsplit(' ')
                    .next()
                    .unwrap()
                    .parse()
                    .unwrap_or(0);
            }
            Message::EmptyQueryResponse => rows = 0,
            Message::ReadyForQuery(_) => return Ok(rows),
            _ => return Err(Error::unexpected_message()),
        }
    }
}

async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
    let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;

    match responses.next().await? {
        Message::BindComplete => {}
        _ => return Err(Error::unexpected_message()),
    }

    Ok(responses)
}

pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
where
    I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
    I::IntoIter: ExactSizeIterator,
{
    client.with_buf(|buf| {
        encode_bind(statement, params, "", buf)?;
        frontend::execute("", 0, buf).map_err(Error::encode)?;
        frontend::sync(buf);
        Ok(buf.split().freeze())
    })
}

pub fn encode_bind<'a, I>(
    statement: &Statement,
    params: I,
    portal: &str,
    buf: &mut BytesMut,
) -> Result<(), Error>
where
    I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
    I::IntoIter: ExactSizeIterator,
{
    let param_types = statement.params();
    let params = params.into_iter();

    assert!(
        param_types.len() == params.len(),
        "expected {} parameters but got {}",
        param_types.len(),
        params.len()
    );

    let (param_formats, params): (Vec<_>, Vec<_>) = params
        .zip(param_types.iter())
        .map(|(p, ty)| (p.encode_format(ty) as i16, p))
        .unzip();

    let params = params.into_iter();

    let mut error_idx = 0;
    let r = frontend::bind(
        portal,
        statement.name(),
        param_formats,
        params.zip(param_types).enumerate(),
        |(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) {
            Ok(IsNull::No) => Ok(postgres_protocol2::IsNull::No),
            Ok(IsNull::Yes) => Ok(postgres_protocol2::IsNull::Yes),
            Err(e) => {
                error_idx = idx;
                Err(e)
            }
        },
        Some(1),
        buf,
    );
    match r {
        Ok(()) => Ok(()),
        Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
        Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
    }
}

pin_project! {
    /// A stream of table rows.
    pub struct RowStream {
        statement: Statement,
        responses: Responses,
        command_tag: Option<String>,
        output_format: Format,
        status: ReadyForQueryStatus,
        #[pin]
        _p: PhantomPinned,
    }
}

impl Stream for RowStream {
    type Item = Result<Row, Error>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        loop {
            match ready!(this.responses.poll_next(cx)?) {
                Message::DataRow(body) => {
                    return Poll::Ready(Some(Ok(Row::new(
                        this.statement.clone(),
                        body,
                        *this.output_format,
                    )?)))
                }
                Message::EmptyQueryResponse | Message::PortalSuspended => {}
                Message::CommandComplete(body) => {
                    if let Ok(tag) = body.tag() {
                        *this.command_tag = Some(tag.to_string());
                    }
                }
                Message::ReadyForQuery(status) => {
                    *this.status = status.into();
                    return Poll::Ready(None);
                }
                _ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
            }
        }
    }
}

impl RowStream {
    /// Returns information about the columns of data in the row.
    pub fn columns(&self) -> &[Column] {
        self.statement.columns()
    }

    /// Returns the command tag of this query.
    ///
    /// This is only available after the stream has been exhausted.
    pub fn command_tag(&self) -> Option<String> {
        self.command_tag.clone()
    }

    /// Returns if the connection is ready for querying, with the status of the connection.
    ///
    /// This might be available only after the stream has been exhausted.
    pub fn ready_status(&self) -> ReadyForQueryStatus {
        self.status
    }
}
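
`execute` extracts the affected-row count from the CommandComplete tag by taking the last space-separated token, which covers both the two-token `UPDATE 5` form and the three-token `INSERT 0 5` form; tags without a trailing count fall back to 0. The same parse in isolation:

// Mirrors the tag parse in `execute` above.
fn rows_from_tag(tag: &str) -> u64 {
    tag.rsplit(' ').next().unwrap().parse().unwrap_or(0)
}

#[test]
fn command_tags() {
    assert_eq!(rows_from_tag("INSERT 0 5"), 5);   // last token is the count
    assert_eq!(rows_from_tag("UPDATE 12"), 12);
    assert_eq!(rows_from_tag("CREATE TABLE"), 0); // "TABLE" fails to parse -> 0
}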
300
libs/proxy/tokio-postgres2/src/row.rs
Normal file
300
libs/proxy/tokio-postgres2/src/row.rs
Normal file
@@ -0,0 +1,300 @@
|
||||
//! Rows.

use crate::row::sealed::{AsName, Sealed};
use crate::simple_query::SimpleColumn;
use crate::statement::Column;
use crate::types::{FromSql, Type, WrongType};
use crate::{Error, Statement};
use fallible_iterator::FallibleIterator;
use postgres_protocol2::message::backend::DataRowBody;
use postgres_types2::{Format, WrongFormat};
use std::fmt;
use std::ops::Range;
use std::str;
use std::sync::Arc;

mod sealed {
    pub trait Sealed {}

    pub trait AsName {
        fn as_name(&self) -> &str;
    }
}

impl AsName for Column {
    fn as_name(&self) -> &str {
        self.name()
    }
}

impl AsName for String {
    fn as_name(&self) -> &str {
        self
    }
}

/// A trait implemented by types that can index into columns of a row.
///
/// This cannot be implemented outside of this crate.
pub trait RowIndex: Sealed {
    #[doc(hidden)]
    fn __idx<T>(&self, columns: &[T]) -> Option<usize>
    where
        T: AsName;
}

impl Sealed for usize {}

impl RowIndex for usize {
    #[inline]
    fn __idx<T>(&self, columns: &[T]) -> Option<usize>
    where
        T: AsName,
    {
        if *self >= columns.len() {
            None
        } else {
            Some(*self)
        }
    }
}

impl Sealed for str {}

impl RowIndex for str {
    #[inline]
    fn __idx<T>(&self, columns: &[T]) -> Option<usize>
    where
        T: AsName,
    {
        if let Some(idx) = columns.iter().position(|d| d.as_name() == self) {
            return Some(idx);
        };

        // FIXME ASCII-only case insensitivity isn't really the right thing to
        // do. Postgres itself uses a dubious wrapper around tolower and JDBC
        // uses the US locale.
        columns
            .iter()
            .position(|d| d.as_name().eq_ignore_ascii_case(self))
    }
}

impl<T> Sealed for &T where T: ?Sized + Sealed {}

impl<T> RowIndex for &T
where
    T: ?Sized + RowIndex,
{
    #[inline]
    fn __idx<U>(&self, columns: &[U]) -> Option<usize>
    where
        U: AsName,
    {
        T::__idx(*self, columns)
    }
}

/// A row of data returned from the database by a query.
pub struct Row {
    statement: Statement,
    output_format: Format,
    body: DataRowBody,
    ranges: Vec<Option<Range<usize>>>,
}

impl fmt::Debug for Row {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Row")
            .field("columns", &self.columns())
            .finish()
    }
}

impl Row {
    pub(crate) fn new(
        statement: Statement,
        body: DataRowBody,
        output_format: Format,
    ) -> Result<Row, Error> {
        let ranges = body.ranges().collect().map_err(Error::parse)?;
        Ok(Row {
            statement,
            body,
            ranges,
            output_format,
        })
    }

    /// Returns information about the columns of data in the row.
    pub fn columns(&self) -> &[Column] {
        self.statement.columns()
    }

    /// Determines if the row contains no values.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the number of values in the row.
    pub fn len(&self) -> usize {
        self.columns().len()
    }

    /// Deserializes a value from the row.
    ///
    /// The value can be specified either by its numeric index in the row, or by its column name.
    ///
    /// # Panics
    ///
    /// Panics if the index is out of bounds or if the value cannot be converted to the specified type.
    pub fn get<'a, I, T>(&'a self, idx: I) -> T
    where
        I: RowIndex + fmt::Display,
        T: FromSql<'a>,
    {
        match self.get_inner(&idx) {
            Ok(ok) => ok,
            Err(err) => panic!("error retrieving column {}: {}", idx, err),
        }
    }

    /// Like `Row::get`, but returns a `Result` rather than panicking.
    pub fn try_get<'a, I, T>(&'a self, idx: I) -> Result<T, Error>
    where
        I: RowIndex + fmt::Display,
        T: FromSql<'a>,
    {
        self.get_inner(&idx)
    }

    fn get_inner<'a, I, T>(&'a self, idx: &I) -> Result<T, Error>
    where
        I: RowIndex + fmt::Display,
        T: FromSql<'a>,
    {
        let idx = match idx.__idx(self.columns()) {
            Some(idx) => idx,
            None => return Err(Error::column(idx.to_string())),
        };

        let ty = self.columns()[idx].type_();
        if !T::accepts(ty) {
            return Err(Error::from_sql(
                Box::new(WrongType::new::<T>(ty.clone())),
                idx,
            ));
        }

        FromSql::from_sql_nullable(ty, self.col_buffer(idx)).map_err(|e| Error::from_sql(e, idx))
    }

    /// Get the raw bytes for the column at the given index.
    fn col_buffer(&self, idx: usize) -> Option<&[u8]> {
        let range = self.ranges.get(idx)?.to_owned()?;
        Some(&self.body.buffer()[range])
    }

    /// Interpret the column at the given index as text.
    ///
    /// Useful when using query_raw_txt(), which sets text transfer mode.
    pub fn as_text(&self, idx: usize) -> Result<Option<&str>, Error> {
        if self.output_format == Format::Text {
            match self.col_buffer(idx) {
                Some(raw) => {
                    FromSql::from_sql(&Type::TEXT, raw).map_err(|e| Error::from_sql(e, idx))
                }
                None => Ok(None),
            }
        } else {
            Err(Error::from_sql(Box::new(WrongFormat {}), idx))
        }
    }

    /// Row byte size.
    pub fn body_len(&self) -> usize {
        self.body.buffer().len()
    }
}

impl AsName for SimpleColumn {
    fn as_name(&self) -> &str {
        self.name()
    }
}

/// A row of data returned from the database by a simple query.
#[derive(Debug)]
pub struct SimpleQueryRow {
    columns: Arc<[SimpleColumn]>,
    body: DataRowBody,
    ranges: Vec<Option<Range<usize>>>,
}

impl SimpleQueryRow {
    #[allow(clippy::new_ret_no_self)]
    pub(crate) fn new(
        columns: Arc<[SimpleColumn]>,
        body: DataRowBody,
    ) -> Result<SimpleQueryRow, Error> {
        let ranges = body.ranges().collect().map_err(Error::parse)?;
        Ok(SimpleQueryRow {
            columns,
            body,
            ranges,
        })
    }

    /// Returns information about the columns of data in the row.
    pub fn columns(&self) -> &[SimpleColumn] {
        &self.columns
    }

    /// Determines if the row contains no values.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the number of values in the row.
    pub fn len(&self) -> usize {
        self.columns.len()
    }

    /// Returns a value from the row.
    ///
    /// The value can be specified either by its numeric index in the row, or by its column name.
    ///
    /// # Panics
    ///
    /// Panics if the index is out of bounds or if the value cannot be converted to the specified type.
    pub fn get<I>(&self, idx: I) -> Option<&str>
    where
        I: RowIndex + fmt::Display,
    {
        match self.get_inner(&idx) {
            Ok(ok) => ok,
            Err(err) => panic!("error retrieving column {}: {}", idx, err),
        }
    }

    /// Like `SimpleQueryRow::get`, but returns a `Result` rather than panicking.
    pub fn try_get<I>(&self, idx: I) -> Result<Option<&str>, Error>
    where
        I: RowIndex + fmt::Display,
    {
        self.get_inner(&idx)
    }

    fn get_inner<I>(&self, idx: &I) -> Result<Option<&str>, Error>
    where
        I: RowIndex + fmt::Display,
    {
        let idx = match idx.__idx(&self.columns) {
            Some(idx) => idx,
            None => return Err(Error::column(idx.to_string())),
        };

        let buf = self.ranges[idx].clone().map(|r| &self.body.buffer()[r]);
        FromSql::from_sql_nullable(&Type::TEXT, buf).map_err(|e| Error::from_sql(e, idx))
    }
}
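A minimal usage sketch of the accessors above. The setup is hypothetical (how `row` is obtained is outside this file); only the `Row` methods shown are defined here:

    // `row` is a `Row` fetched in text output format, e.g. via query_raw_txt().
    let id: Option<&str> = row.as_text(0)?;        // text-mode access by index
    // Typed access panics on a bad index or type mismatch...
    let name: String = row.get("name");
    // ...while try_get reports the same failure as an Error instead.
    let name: Result<String, Error> = row.try_get("name");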
142
libs/proxy/tokio-postgres2/src/simple_query.rs
Normal file
@@ -0,0 +1,142 @@
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::{Error, ReadyForQueryStatus, SimpleQueryMessage, SimpleQueryRow};
use bytes::Bytes;
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Stream};
use log::debug;
use pin_project_lite::pin_project;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

/// Information about a column of a single query row.
#[derive(Debug)]
pub struct SimpleColumn {
    name: String,
}

impl SimpleColumn {
    pub(crate) fn new(name: String) -> SimpleColumn {
        SimpleColumn { name }
    }

    /// Returns the name of the column.
    pub fn name(&self) -> &str {
        &self.name
    }
}

pub async fn simple_query(client: &InnerClient, query: &str) -> Result<SimpleQueryStream, Error> {
    debug!("executing simple query: {}", query);

    let buf = encode(client, query)?;
    let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;

    Ok(SimpleQueryStream {
        responses,
        columns: None,
        status: ReadyForQueryStatus::Unknown,
        _p: PhantomPinned,
    })
}

pub async fn batch_execute(
    client: &InnerClient,
    query: &str,
) -> Result<ReadyForQueryStatus, Error> {
    debug!("executing statement batch: {}", query);

    let buf = encode(client, query)?;
    let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;

    loop {
        match responses.next().await? {
            Message::ReadyForQuery(status) => return Ok(status.into()),
            Message::CommandComplete(_)
            | Message::EmptyQueryResponse
            | Message::RowDescription(_)
            | Message::DataRow(_) => {}
            _ => return Err(Error::unexpected_message()),
        }
    }
}

pub(crate) fn encode(client: &InnerClient, query: &str) -> Result<Bytes, Error> {
    client.with_buf(|buf| {
        frontend::query(query, buf).map_err(Error::encode)?;
        Ok(buf.split().freeze())
    })
}

pin_project! {
    /// A stream of simple query results.
    pub struct SimpleQueryStream {
        responses: Responses,
        columns: Option<Arc<[SimpleColumn]>>,
        status: ReadyForQueryStatus,
        #[pin]
        _p: PhantomPinned,
    }
}

impl SimpleQueryStream {
    /// Returns whether the connection is ready for querying, along with the status of the connection.
    ///
    /// This might be available only after the stream has been exhausted.
    pub fn ready_status(&self) -> ReadyForQueryStatus {
        self.status
    }
}

impl Stream for SimpleQueryStream {
    type Item = Result<SimpleQueryMessage, Error>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        loop {
            match ready!(this.responses.poll_next(cx)?) {
                Message::CommandComplete(body) => {
                    let rows = body
                        .tag()
                        .map_err(Error::parse)?
                        .rsplit(' ')
                        .next()
                        .unwrap()
                        .parse()
                        .unwrap_or(0);
                    return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(rows))));
                }
                Message::EmptyQueryResponse => {
                    return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(0))));
                }
                Message::RowDescription(body) => {
                    let columns = body
                        .fields()
                        .map(|f| Ok(SimpleColumn::new(f.name().to_string())))
                        .collect::<Vec<_>>()
                        .map_err(Error::parse)?
                        .into();

                    *this.columns = Some(columns);
                }
                Message::DataRow(body) => {
                    let row = match &this.columns {
                        Some(columns) => SimpleQueryRow::new(columns.clone(), body)?,
                        None => return Poll::Ready(Some(Err(Error::unexpected_message()))),
                    };
                    return Poll::Ready(Some(Ok(SimpleQueryMessage::Row(row))));
                }
                Message::ReadyForQuery(s) => {
                    *this.status = s.into();
                    return Poll::Ready(None);
                }
                _ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
            }
        }
    }
}
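The `CommandComplete` arm above derives the affected-row count from the command tag's last space-separated token. A small standalone sketch of that parsing rule, with illustrative tags:

    fn rows_from_tag(tag: &str) -> u64 {
        tag.rsplit(' ').next().unwrap().parse().unwrap_or(0)
    }

    assert_eq!(rows_from_tag("INSERT 0 5"), 5);
    assert_eq!(rows_from_tag("SELECT 3"), 3);
    assert_eq!(rows_from_tag("CREATE TABLE"), 0); // no numeric suffix: defaults to 0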
157
libs/proxy/tokio-postgres2/src/statement.rs
Normal file
@@ -0,0 +1,157 @@
use crate::client::InnerClient;
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::types::Type;
use postgres_protocol2::{
    message::{backend::Field, frontend},
    Oid,
};
use std::{
    fmt,
    sync::{Arc, Weak},
};

struct StatementInner {
    client: Weak<InnerClient>,
    name: String,
    params: Vec<Type>,
    columns: Vec<Column>,
}

impl Drop for StatementInner {
    fn drop(&mut self) {
        if let Some(client) = self.client.upgrade() {
            let buf = client.with_buf(|buf| {
                frontend::close(b'S', &self.name, buf).unwrap();
                frontend::sync(buf);
                buf.split().freeze()
            });
            let _ = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
        }
    }
}

/// A prepared statement.
///
/// Prepared statements can only be used with the connection that created them.
#[derive(Clone)]
pub struct Statement(Arc<StatementInner>);

impl Statement {
    pub(crate) fn new(
        inner: &Arc<InnerClient>,
        name: String,
        params: Vec<Type>,
        columns: Vec<Column>,
    ) -> Statement {
        Statement(Arc::new(StatementInner {
            client: Arc::downgrade(inner),
            name,
            params,
            columns,
        }))
    }

    pub(crate) fn new_anonymous(params: Vec<Type>, columns: Vec<Column>) -> Statement {
        Statement(Arc::new(StatementInner {
            client: Weak::new(),
            name: String::new(),
            params,
            columns,
        }))
    }

    pub(crate) fn name(&self) -> &str {
        &self.0.name
    }

    /// Returns the expected types of the statement's parameters.
    pub fn params(&self) -> &[Type] {
        &self.0.params
    }

    /// Returns information about the columns returned when the statement is queried.
    pub fn columns(&self) -> &[Column] {
        &self.0.columns
    }
}

/// Information about a column of a query.
pub struct Column {
    name: String,
    type_: Type,

    // raw fields from RowDescription
    table_oid: Oid,
    column_id: i16,
    format: i16,

    // These would be better stored in self.type_, but that is a more radical refactoring.
    type_oid: Oid,
    type_size: i16,
    type_modifier: i32,
}

impl Column {
    pub(crate) fn new(name: String, type_: Type, raw_field: Field<'_>) -> Column {
        Column {
            name,
            type_,
            table_oid: raw_field.table_oid(),
            column_id: raw_field.column_id(),
            format: raw_field.format(),
            type_oid: raw_field.type_oid(),
            type_size: raw_field.type_size(),
            type_modifier: raw_field.type_modifier(),
        }
    }

    /// Returns the name of the column.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the type of the column.
    pub fn type_(&self) -> &Type {
        &self.type_
    }

    /// Returns the table OID of the column.
    pub fn table_oid(&self) -> Oid {
        self.table_oid
    }

    /// Returns the column ID of the column.
    pub fn column_id(&self) -> i16 {
        self.column_id
    }

    /// Returns the format of the column.
    pub fn format(&self) -> i16 {
        self.format
    }

    /// Returns the type OID of the column.
    pub fn type_oid(&self) -> Oid {
        self.type_oid
    }

    /// Returns the type size of the column.
    pub fn type_size(&self) -> i16 {
        self.type_size
    }

    /// Returns the type modifier of the column.
    pub fn type_modifier(&self) -> i32 {
        self.type_modifier
    }
}

impl fmt::Debug for Column {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("Column")
            .field("name", &self.name)
            .field("type", &self.type_)
            .finish()
    }
}
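A sketch of reading the column metadata exposed above, assuming `stmt` is a `Statement` obtained elsewhere (e.g. from a hypothetical `client.prepare(...)` call):

    for col in stmt.columns() {
        println!(
            "{} {:?} (table_oid={}, column_id={}, typmod={})",
            col.name(),
            col.type_(),
            col.table_oid(),
            col.column_id(),
            col.type_modifier(),
        );
    }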
162
libs/proxy/tokio-postgres2/src/tls.rs
Normal file
@@ -0,0 +1,162 @@
//! TLS support.

use std::error::Error;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, io};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

pub(crate) mod private {
    pub struct ForcePrivateApi;
}

/// Channel binding information returned from a TLS handshake.
pub struct ChannelBinding {
    pub(crate) tls_server_end_point: Option<Vec<u8>>,
}

impl ChannelBinding {
    /// Creates a `ChannelBinding` containing no information.
    pub fn none() -> ChannelBinding {
        ChannelBinding {
            tls_server_end_point: None,
        }
    }

    /// Creates a `ChannelBinding` containing `tls-server-end-point` channel binding information.
    pub fn tls_server_end_point(tls_server_end_point: Vec<u8>) -> ChannelBinding {
        ChannelBinding {
            tls_server_end_point: Some(tls_server_end_point),
        }
    }
}

/// A constructor of `TlsConnect`ors.
///
/// Requires the `runtime` Cargo feature (enabled by default).
pub trait MakeTlsConnect<S> {
    /// The stream type created by the `TlsConnect` implementation.
    type Stream: TlsStream + Unpin;
    /// The `TlsConnect` implementation created by this type.
    type TlsConnect: TlsConnect<S, Stream = Self::Stream>;
    /// The error type returned by the `TlsConnect` implementation.
    type Error: Into<Box<dyn Error + Sync + Send>>;

    /// Creates a new `TlsConnect`or.
    ///
    /// The domain name is provided for certificate verification and SNI.
    fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
}

/// An asynchronous function wrapping a stream in a TLS session.
pub trait TlsConnect<S> {
    /// The stream returned by the future.
    type Stream: TlsStream + Unpin;
    /// The error returned by the future.
    type Error: Into<Box<dyn Error + Sync + Send>>;
    /// The future returned by the connector.
    type Future: Future<Output = Result<Self::Stream, Self::Error>>;

    /// Returns a future performing a TLS handshake over the stream.
    fn connect(self, stream: S) -> Self::Future;

    #[doc(hidden)]
    fn can_connect(&self, _: private::ForcePrivateApi) -> bool {
        true
    }
}

/// A TLS-wrapped connection to a PostgreSQL database.
pub trait TlsStream: AsyncRead + AsyncWrite {
    /// Returns channel binding information for the session.
    fn channel_binding(&self) -> ChannelBinding;
}

/// A `MakeTlsConnect` and `TlsConnect` implementation which simply returns an error.
///
/// This can be used when `sslmode` is `none` or `prefer`.
#[derive(Debug, Copy, Clone)]
pub struct NoTls;

impl<S> MakeTlsConnect<S> for NoTls {
    type Stream = NoTlsStream;
    type TlsConnect = NoTls;
    type Error = NoTlsError;

    fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
        Ok(NoTls)
    }
}

impl<S> TlsConnect<S> for NoTls {
    type Stream = NoTlsStream;
    type Error = NoTlsError;
    type Future = NoTlsFuture;

    fn connect(self, _: S) -> NoTlsFuture {
        NoTlsFuture(())
    }

    fn can_connect(&self, _: private::ForcePrivateApi) -> bool {
        false
    }
}

/// The future returned by `NoTls`.
pub struct NoTlsFuture(());

impl Future for NoTlsFuture {
    type Output = Result<NoTlsStream, NoTlsError>;

    fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
        Poll::Ready(Err(NoTlsError(())))
    }
}

/// The TLS "stream" type produced by the `NoTls` connector.
///
/// Since `NoTls` doesn't support TLS, this type is uninhabited.
pub enum NoTlsStream {}

impl AsyncRead for NoTlsStream {
    fn poll_read(
        self: Pin<&mut Self>,
        _: &mut Context<'_>,
        _: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        match *self {}
    }
}

impl AsyncWrite for NoTlsStream {
    fn poll_write(self: Pin<&mut Self>, _: &mut Context<'_>, _: &[u8]) -> Poll<io::Result<usize>> {
        match *self {}
    }

    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
        match *self {}
    }

    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
        match *self {}
    }
}

impl TlsStream for NoTlsStream {
    fn channel_binding(&self) -> ChannelBinding {
        match *self {}
    }
}

/// The error returned by `NoTls`.
#[derive(Debug)]
pub struct NoTlsError(());

impl fmt::Display for NoTlsError {
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.write_str("no TLS implementation configured")
    }
}

impl Error for NoTlsError {}
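`NoTlsStream` relies on the uninhabited-enum trick: an enum with no variants can never be constructed, so `match *self {}` type-checks against any return type and the I/O methods above are statically unreachable. A minimal standalone illustration of the same pattern:

    enum Never {}

    fn absurd<T>(n: Never) -> T {
        match n {} // no arms needed: no value can inhabit `n`
    }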
57
libs/proxy/tokio-postgres2/src/to_statement.rs
Normal file
@@ -0,0 +1,57 @@
use crate::to_statement::private::{Sealed, ToStatementType};
use crate::Statement;

mod private {
    use crate::{Client, Error, Statement};

    pub trait Sealed {}

    pub enum ToStatementType<'a> {
        Statement(&'a Statement),
        Query(&'a str),
    }

    impl<'a> ToStatementType<'a> {
        pub async fn into_statement(self, client: &Client) -> Result<Statement, Error> {
            match self {
                ToStatementType::Statement(s) => Ok(s.clone()),
                ToStatementType::Query(s) => client.prepare(s).await,
            }
        }
    }
}

/// A trait abstracting over prepared and unprepared statements.
///
/// Many methods are generic over this bound, so that they support both a raw query string as well as a statement which
/// was prepared previously.
///
/// This trait is "sealed" and cannot be implemented by anything outside this crate.
pub trait ToStatement: Sealed {
    #[doc(hidden)]
    fn __convert(&self) -> ToStatementType<'_>;
}

impl ToStatement for Statement {
    fn __convert(&self) -> ToStatementType<'_> {
        ToStatementType::Statement(self)
    }
}

impl Sealed for Statement {}

impl ToStatement for str {
    fn __convert(&self) -> ToStatementType<'_> {
        ToStatementType::Query(self)
    }
}

impl Sealed for str {}

impl ToStatement for String {
    fn __convert(&self) -> ToStatementType<'_> {
        ToStatementType::Query(self)
    }
}

impl Sealed for String {}
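A sketch of the call pattern the sealed trait enables. The `run` helper is hypothetical; `__convert` and `into_statement` are the crate-internal hooks defined above:

    async fn run<T>(client: &Client, stmt: &T) -> Result<Statement, Error>
    where
        T: ToStatement + ?Sized,
    {
        // Clones a prepared statement, or prepares a raw SQL string.
        stmt.__convert().into_statement(client).await
    }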
74
libs/proxy/tokio-postgres2/src/transaction.rs
Normal file
@@ -0,0 +1,74 @@
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::query::RowStream;
use crate::{CancelToken, Client, Error, ReadyForQueryStatus};
use postgres_protocol2::message::frontend;

/// A representation of a PostgreSQL database transaction.
///
/// Transactions will implicitly roll back when dropped. Use the `commit` method to commit the changes made in the
/// transaction. Transactions can be nested, with inner transactions implemented via savepoints.
pub struct Transaction<'a> {
    client: &'a mut Client,
    done: bool,
}

impl Drop for Transaction<'_> {
    fn drop(&mut self) {
        if self.done {
            return;
        }

        let buf = self.client.inner().with_buf(|buf| {
            frontend::query("ROLLBACK", buf).unwrap();
            buf.split().freeze()
        });
        let _ = self
            .client
            .inner()
            .send(RequestMessages::Single(FrontendMessage::Raw(buf)));
    }
}

impl<'a> Transaction<'a> {
    pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
        Transaction {
            client,
            done: false,
        }
    }

    /// Consumes the transaction, committing all changes made within it.
    pub async fn commit(mut self) -> Result<ReadyForQueryStatus, Error> {
        self.done = true;
        self.client.batch_execute("COMMIT").await
    }

    /// Rolls the transaction back, discarding all changes made within it.
    ///
    /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
    pub async fn rollback(mut self) -> Result<ReadyForQueryStatus, Error> {
        self.done = true;
        self.client.batch_execute("ROLLBACK").await
    }

    /// Like `Client::query_raw_txt`.
    pub async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
    where
        S: AsRef<str>,
        I: IntoIterator<Item = Option<S>>,
        I::IntoIter: ExactSizeIterator,
    {
        self.client.query_raw_txt(statement, params).await
    }

    /// Like `Client::cancel_token`.
    pub fn cancel_token(&self) -> CancelToken {
        self.client.cancel_token()
    }

    /// Returns a reference to the underlying `Client`.
    pub fn client(&self) -> &Client {
        self.client
    }
}
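A sketch of the drop-to-rollback contract, assuming a hypothetical `client.build_transaction().start()` entry point:

    let txn = client.build_transaction().start().await?;
    txn.query_raw_txt("INSERT INTO t VALUES ($1)", [Some("1")]).await?;
    // If this scope unwinds before the commit below, Drop queues a ROLLBACK.
    txn.commit().await?;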
113
libs/proxy/tokio-postgres2/src/transaction_builder.rs
Normal file
@@ -0,0 +1,113 @@
use crate::{Client, Error, Transaction};

/// The isolation level of a database transaction.
#[derive(Debug, Copy, Clone)]
#[non_exhaustive]
pub enum IsolationLevel {
    /// Equivalent to `ReadCommitted`.
    ReadUncommitted,

    /// An individual statement in the transaction will see rows committed before it began.
    ReadCommitted,

    /// All statements in the transaction will see the same view of rows committed before the first query in the
    /// transaction.
    RepeatableRead,

    /// The reads and writes in this transaction must be able to be committed as an atomic "unit" with respect to reads
    /// and writes of all other concurrent serializable transactions without interleaving.
    Serializable,
}

/// A builder for database transactions.
pub struct TransactionBuilder<'a> {
    client: &'a mut Client,
    isolation_level: Option<IsolationLevel>,
    read_only: Option<bool>,
    deferrable: Option<bool>,
}

impl<'a> TransactionBuilder<'a> {
    pub(crate) fn new(client: &'a mut Client) -> TransactionBuilder<'a> {
        TransactionBuilder {
            client,
            isolation_level: None,
            read_only: None,
            deferrable: None,
        }
    }

    /// Sets the isolation level of the transaction.
    pub fn isolation_level(mut self, isolation_level: IsolationLevel) -> Self {
        self.isolation_level = Some(isolation_level);
        self
    }

    /// Sets the access mode of the transaction.
    pub fn read_only(mut self, read_only: bool) -> Self {
        self.read_only = Some(read_only);
        self
    }

    /// Sets the deferrability of the transaction.
    ///
    /// If the transaction is also serializable and read only, creation of the transaction may block, but when it
    /// completes the transaction is able to run with less overhead and a guarantee that it will not be aborted due to
    /// serialization failure.
    pub fn deferrable(mut self, deferrable: bool) -> Self {
        self.deferrable = Some(deferrable);
        self
    }

    /// Begins the transaction.
    ///
    /// The transaction will roll back by default - use the `commit` method to commit it.
    pub async fn start(self) -> Result<Transaction<'a>, Error> {
        let mut query = "START TRANSACTION".to_string();
        let mut first = true;

        if let Some(level) = self.isolation_level {
            first = false;

            query.push_str(" ISOLATION LEVEL ");
            let level = match level {
                IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
                IsolationLevel::ReadCommitted => "READ COMMITTED",
                IsolationLevel::RepeatableRead => "REPEATABLE READ",
                IsolationLevel::Serializable => "SERIALIZABLE",
            };
            query.push_str(level);
        }

        if let Some(read_only) = self.read_only {
            if !first {
                query.push(',');
            }
            first = false;

            let s = if read_only {
                " READ ONLY"
            } else {
                " READ WRITE"
            };
            query.push_str(s);
        }

        if let Some(deferrable) = self.deferrable {
            if !first {
                query.push(',');
            }

            let s = if deferrable {
                " DEFERRABLE"
            } else {
                " NOT DEFERRABLE"
            };
            query.push_str(s);
        }

        self.client.batch_execute(&query).await?;

        Ok(Transaction::new(self.client))
    }
}
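For reference, the SQL that `start` assembles when every knob is set, following the string-building logic above (the `build_transaction` entry point is assumed):

    // TransactionBuilder with isolation_level(Serializable), read_only(true),
    // and deferrable(true) executes:
    //   START TRANSACTION ISOLATION LEVEL SERIALIZABLE, READ ONLY, DEFERRABLE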
6
libs/proxy/tokio-postgres2/src/types.rs
Normal file
@@ -0,0 +1,6 @@
//! Types.
//!
//! This module is a reexport of the `postgres_types2` crate.

#[doc(inline)]
pub use postgres_types2::*;
@@ -220,6 +220,11 @@ impl AzureBlobStorage {
        let started_at = ScopeGuard::into_inner(started_at);
        let outcome = match &download {
            Ok(_) => AttemptOutcome::Ok,
            // At this level in the stack, 404 and 304 responses do not indicate an error.
            // There are expected cases when a blob may not exist or hasn't been modified since
            // the last get (e.g. probing for timeline indices and heatmap downloads).
            // Callers should handle errors if they are unexpected.
            Err(DownloadError::NotFound | DownloadError::Unmodified) => AttemptOutcome::Ok,
            Err(_) => AttemptOutcome::Err,
        };
        crate::metrics::BUCKET_METRICS
@@ -176,7 +176,9 @@ pub(crate) struct BucketMetrics {

impl Default for BucketMetrics {
    fn default() -> Self {
        let buckets = [0.01, 0.10, 0.5, 1.0, 5.0, 10.0, 50.0, 100.0];
        // first bucket 100 microseconds to count requests that do not need to wait at all
        // and get a permit immediately
        let buckets = [0.0001, 0.01, 0.10, 0.5, 1.0, 5.0, 10.0, 50.0, 100.0];

        let req_seconds = register_histogram_vec!(
            "remote_storage_s3_request_seconds",
@@ -34,6 +34,7 @@ pprof.workspace = true
regex.workspace = true
routerify.workspace = true
serde.workspace = true
serde_with.workspace = true
serde_json.workspace = true
signal-hook.workspace = true
thiserror.workspace = true
@@ -7,29 +7,88 @@ use postgres_connection::{parse_host_port, PgConnectionConfig};

use crate::id::TenantTimelineId;

#[derive(Copy, Clone, PartialEq, Eq, Debug, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum InterpretedFormat {
    Bincode,
    Protobuf,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum Compression {
    Zstd { level: i8 },
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "type", content = "args")]
#[serde(rename_all = "kebab-case")]
pub enum PostgresClientProtocol {
    /// Usual Postgres replication protocol
    Vanilla,
    /// Custom shard-aware protocol that replicates interpreted records.
    /// Used to send WAL from safekeeper to pageserver.
    Interpreted {
        format: InterpretedFormat,
        compression: Option<Compression>,
    },
}

pub struct ConnectionConfigArgs<'a> {
    pub protocol: PostgresClientProtocol,

    pub ttid: TenantTimelineId,
    pub shard_number: Option<u8>,
    pub shard_count: Option<u8>,
    pub shard_stripe_size: Option<u32>,

    pub listen_pg_addr_str: &'a str,

    pub auth_token: Option<&'a str>,
    pub availability_zone: Option<&'a str>,
}

impl<'a> ConnectionConfigArgs<'a> {
    fn options(&'a self) -> Vec<String> {
        let mut options = vec![
            "-c".to_owned(),
            format!("timeline_id={}", self.ttid.timeline_id),
            format!("tenant_id={}", self.ttid.tenant_id),
            format!(
                "protocol={}",
                serde_json::to_string(&self.protocol).unwrap()
            ),
        ];

        if self.shard_number.is_some() {
            assert!(self.shard_count.is_some());
            assert!(self.shard_stripe_size.is_some());

            options.push(format!("shard_count={}", self.shard_count.unwrap()));
            options.push(format!("shard_number={}", self.shard_number.unwrap()));
            options.push(format!(
                "shard_stripe_size={}",
                self.shard_stripe_size.unwrap()
            ));
        }

        options
    }
}

/// Create client config for fetching WAL from safekeeper on a particular timeline.
/// listen_pg_addr_str is in form host:\[port\].
pub fn wal_stream_connection_config(
    TenantTimelineId {
        tenant_id,
        timeline_id,
    }: TenantTimelineId,
    listen_pg_addr_str: &str,
    auth_token: Option<&str>,
    availability_zone: Option<&str>,
    args: ConnectionConfigArgs,
) -> anyhow::Result<PgConnectionConfig> {
    let (host, port) =
        parse_host_port(listen_pg_addr_str).context("Unable to parse listen_pg_addr_str")?;
        parse_host_port(args.listen_pg_addr_str).context("Unable to parse listen_pg_addr_str")?;
    let port = port.unwrap_or(5432);
    let mut connstr = PgConnectionConfig::new_host_port(host, port)
        .extend_options([
            "-c".to_owned(),
            format!("timeline_id={}", timeline_id),
            format!("tenant_id={}", tenant_id),
        ])
        .set_password(auth_token.map(|s| s.to_owned()));
        .extend_options(args.options())
        .set_password(args.auth_token.map(|s| s.to_owned()));

    if let Some(availability_zone) = availability_zone {
    if let Some(availability_zone) = args.availability_zone {
        connstr = connstr.extend_options([format!("availability_zone={}", availability_zone)]);
    }
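Given the serde attributes on `PostgresClientProtocol` (adjacent tagging via `tag = "type"`, `content = "args"`, kebab-case renaming), the `protocol=` option embedded in the connection string should serialize roughly as sketched below (illustrative, not taken from the crate's tests):

    // Vanilla:
    //   {"type":"vanilla"}
    // Interpreted { format: Protobuf, compression: Some(Zstd { level: 1 }) }:
    //   {"type":"interpreted","args":{"format":"protobuf","compression":{"zstd":{"level":1}}}}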
@@ -218,7 +218,7 @@ impl MemoryStatus {
    fn debug_slice(slice: &[Self]) -> impl '_ + Debug {
        struct DS<'a>(&'a [MemoryStatus]);

        impl<'a> Debug for DS<'a> {
        impl Debug for DS<'_> {
            fn fmt(&self, f: &mut Formatter) -> fmt::Result {
                f.debug_struct("[MemoryStatus]")
                    .field(
@@ -233,7 +233,7 @@ impl MemoryStatus {

        struct Fields<'a, F>(&'a [MemoryStatus], F);

        impl<'a, F: Fn(&MemoryStatus) -> T, T: Debug> Debug for Fields<'a, F> {
        impl<F: Fn(&MemoryStatus) -> T, T: Debug> Debug for Fields<'_, F> {
            fn fmt(&self, f: &mut Formatter) -> fmt::Result {
                f.debug_list().entries(self.0.iter().map(&self.1)).finish()
            }
@@ -8,11 +8,19 @@ license.workspace = true
testing = ["pageserver_api/testing"]

[dependencies]
async-compression.workspace = true
anyhow.workspace = true
bytes.workspace = true
pageserver_api.workspace = true
prost.workspace = true
postgres_ffi.workspace = true
serde.workspace = true
thiserror.workspace = true
tokio = { workspace = true, features = ["io-util"] }
tonic.workspace = true
tracing.workspace = true
utils.workspace = true
workspace_hack = { version = "0.1", path = "../../workspace_hack" }

[build-dependencies]
tonic-build.workspace = true
11
libs/wal_decoder/build.rs
Normal file
@@ -0,0 +1,11 @@
fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Generate Rust code from the .proto protobuf.
    //
    // Note: we previously tried to use a deterministic location at proto/ for
    // easy discovery, but interference with cachepot sometimes failed the
    // build. In any case, per the cargo docs, a build script shouldn't output
    // anywhere but $OUT_DIR.
    tonic_build::compile_protos("proto/interpreted_wal.proto")
        .unwrap_or_else(|e| panic!("failed to compile protos {:?}", e));
    Ok(())
}
43
libs/wal_decoder/proto/interpreted_wal.proto
Normal file
@@ -0,0 +1,43 @@
syntax = "proto3";
|
||||
|
||||
package interpreted_wal;
|
||||
|
||||
message InterpretedWalRecords {
|
||||
repeated InterpretedWalRecord records = 1;
|
||||
optional uint64 next_record_lsn = 2;
|
||||
}
|
||||
|
||||
message InterpretedWalRecord {
|
||||
optional bytes metadata_record = 1;
|
||||
SerializedValueBatch batch = 2;
|
||||
uint64 next_record_lsn = 3;
|
||||
bool flush_uncommitted = 4;
|
||||
uint32 xid = 5;
|
||||
}
|
||||
|
||||
message SerializedValueBatch {
|
||||
bytes raw = 1;
|
||||
repeated ValueMeta metadata = 2;
|
||||
uint64 max_lsn = 3;
|
||||
uint64 len = 4;
|
||||
}
|
||||
|
||||
enum ValueMetaType {
|
||||
Serialized = 0;
|
||||
Observed = 1;
|
||||
}
|
||||
|
||||
message ValueMeta {
|
||||
ValueMetaType type = 1;
|
||||
CompactKey key = 2;
|
||||
uint64 lsn = 3;
|
||||
optional uint64 batch_offset = 4;
|
||||
optional uint64 len = 5;
|
||||
optional bool will_init = 6;
|
||||
}
|
||||
|
||||
message CompactKey {
|
||||
int64 high = 1;
|
||||
int64 low = 2;
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
use crate::models::*;
use crate::serialized_batch::SerializedValueBatch;
use bytes::{Buf, Bytes};
use pageserver_api::key::rel_block_to_key;
use pageserver_api::reltag::{RelTag, SlruKind};
use pageserver_api::shard::ShardIdentity;
use postgres_ffi::pg_constants;
@@ -32,7 +33,8 @@ impl InterpretedWalRecord {
            FlushUncommittedRecords::No
        };

        let metadata_record = MetadataRecord::from_decoded(&decoded, next_record_lsn, pg_version)?;
        let metadata_record =
            MetadataRecord::from_decoded_filtered(&decoded, shard, next_record_lsn, pg_version)?;
        let batch = SerializedValueBatch::from_decoded_filtered(
            decoded,
            shard,
@@ -51,8 +53,13 @@ impl InterpretedWalRecord {
}

impl MetadataRecord {
    fn from_decoded(
    /// Builds a metadata record for this WAL record, if any.
    ///
    /// Only metadata records relevant for the given shard are emitted. Currently, most metadata
    /// records are broadcast to all shards for simplicity, but this should be improved.
    fn from_decoded_filtered(
        decoded: &DecodedWALRecord,
        shard: &ShardIdentity,
        next_record_lsn: Lsn,
        pg_version: u32,
    ) -> anyhow::Result<Option<MetadataRecord>> {
@@ -61,26 +68,27 @@ impl MetadataRecord {
        let mut buf = decoded.record.clone();
        buf.advance(decoded.main_data_offset);

        match decoded.xl_rmid {
        // First, generate metadata records from the decoded WAL record.
        let mut metadata_record = match decoded.xl_rmid {
            pg_constants::RM_HEAP_ID | pg_constants::RM_HEAP2_ID => {
                Self::decode_heapam_record(&mut buf, decoded, pg_version)
                Self::decode_heapam_record(&mut buf, decoded, pg_version)?
            }
            pg_constants::RM_NEON_ID => Self::decode_neonmgr_record(&mut buf, decoded, pg_version),
            pg_constants::RM_NEON_ID => Self::decode_neonmgr_record(&mut buf, decoded, pg_version)?,
            // Handle other special record types
            pg_constants::RM_SMGR_ID => Self::decode_smgr_record(&mut buf, decoded),
            pg_constants::RM_DBASE_ID => Self::decode_dbase_record(&mut buf, decoded, pg_version),
            pg_constants::RM_SMGR_ID => Self::decode_smgr_record(&mut buf, decoded)?,
            pg_constants::RM_DBASE_ID => Self::decode_dbase_record(&mut buf, decoded, pg_version)?,
            pg_constants::RM_TBLSPC_ID => {
                tracing::trace!("XLOG_TBLSPC_CREATE/DROP is not handled yet");
                Ok(None)
                None
            }
            pg_constants::RM_CLOG_ID => Self::decode_clog_record(&mut buf, decoded, pg_version),
            pg_constants::RM_CLOG_ID => Self::decode_clog_record(&mut buf, decoded, pg_version)?,
            pg_constants::RM_XACT_ID => {
                Self::decode_xact_record(&mut buf, decoded, next_record_lsn)
                Self::decode_xact_record(&mut buf, decoded, next_record_lsn)?
            }
            pg_constants::RM_MULTIXACT_ID => {
                Self::decode_multixact_record(&mut buf, decoded, pg_version)
                Self::decode_multixact_record(&mut buf, decoded, pg_version)?
            }
            pg_constants::RM_RELMAP_ID => Self::decode_relmap_record(&mut buf, decoded),
            pg_constants::RM_RELMAP_ID => Self::decode_relmap_record(&mut buf, decoded)?,
            // This is an odd duck. It needs to go to all shards.
            // Since it uses the checkpoint image (that's initialized from CHECKPOINT_KEY
            // in WalIngest::new), we have to send the whole DecodedWalRecord::record to
@@ -89,19 +97,48 @@ impl MetadataRecord {
            // Alternatively, one can make the checkpoint part of the subscription protocol
            // to the pageserver. This should work fine, but can be done at a later point.
            pg_constants::RM_XLOG_ID => {
                Self::decode_xlog_record(&mut buf, decoded, next_record_lsn)
                Self::decode_xlog_record(&mut buf, decoded, next_record_lsn)?
            }
            pg_constants::RM_LOGICALMSG_ID => {
                Self::decode_logical_message_record(&mut buf, decoded)
                Self::decode_logical_message_record(&mut buf, decoded)?
            }
            pg_constants::RM_STANDBY_ID => Self::decode_standby_record(&mut buf, decoded),
            pg_constants::RM_REPLORIGIN_ID => Self::decode_replorigin_record(&mut buf, decoded),
            pg_constants::RM_STANDBY_ID => Self::decode_standby_record(&mut buf, decoded)?,
            pg_constants::RM_REPLORIGIN_ID => Self::decode_replorigin_record(&mut buf, decoded)?,
            _unexpected => {
                // TODO: consider failing here instead of blindly doing something without
                // understanding the protocol
                Ok(None)
                None
            }
        };

        // Next, filter the metadata record by shard.

        // Route VM page updates to the shards that own them. VM pages are stored in the VM fork
        // of the main relation. These are sharded and managed just like regular relation pages.
        // See: https://github.com/neondatabase/neon/issues/9855
        if let Some(
            MetadataRecord::Heapam(HeapamRecord::ClearVmBits(ref mut clear_vm_bits))
            | MetadataRecord::Neonrmgr(NeonrmgrRecord::ClearVmBits(ref mut clear_vm_bits)),
        ) = metadata_record
        {
            let is_local_vm_page = |heap_blk| {
                let vm_blk = pg_constants::HEAPBLK_TO_MAPBLOCK(heap_blk);
                shard.is_key_local(&rel_block_to_key(clear_vm_bits.vm_rel, vm_blk))
            };
            // Send the old and new VM page updates to their respective shards.
            clear_vm_bits.old_heap_blkno = clear_vm_bits
                .old_heap_blkno
                .filter(|&blkno| is_local_vm_page(blkno));
            clear_vm_bits.new_heap_blkno = clear_vm_bits
                .new_heap_blkno
                .filter(|&blkno| is_local_vm_page(blkno));
            // If neither VM page belongs to this shard, discard the record.
            if clear_vm_bits.old_heap_blkno.is_none() && clear_vm_bits.new_heap_blkno.is_none() {
                metadata_record = None
            }
        }

        Ok(metadata_record)
    }

    fn decode_heapam_record(
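The shard routing above leans on `Option::filter`: a block number survives only when the predicate holds, so non-local VM page updates are dropped per shard. A standalone illustration:

    let blkno: Option<u32> = Some(7);
    assert_eq!(blkno.filter(|&b| b % 2 == 0), None);    // predicate false: dropped
    assert_eq!(blkno.filter(|&b| b % 2 == 1), Some(7)); // predicate true: kept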
@@ -1,3 +1,4 @@
pub mod decoder;
pub mod models;
pub mod serialized_batch;
pub mod wire_format;
@@ -37,12 +37,32 @@ use utils::lsn::Lsn;

use crate::serialized_batch::SerializedValueBatch;

// Code generated by protobuf.
pub mod proto {
    // Tonic does derives as `#[derive(Clone, PartialEq, ::prost::Message)]`;
    // we don't use these types for anything but broker data transmission,
    // so it's ok to ignore this one.
    #![allow(clippy::derive_partial_eq_without_eq)]
    // The generated ValueMeta has a `len` method generated for its `len` field.
    #![allow(clippy::len_without_is_empty)]
    tonic::include_proto!("interpreted_wal");
}

#[derive(Serialize, Deserialize)]
pub enum FlushUncommittedRecords {
    Yes,
    No,
}

/// A batch of interpreted WAL records
#[derive(Serialize, Deserialize)]
pub struct InterpretedWalRecords {
    pub records: Vec<InterpretedWalRecord>,
    // Start LSN of the next record after the batch.
    // Note that said record may not belong to the current shard.
    pub next_record_lsn: Option<Lsn>,
}

/// An interpreted Postgres WAL record, ready to be handled by the pageserver
#[derive(Serialize, Deserialize)]
pub struct InterpretedWalRecord {
@@ -65,6 +85,18 @@ pub struct InterpretedWalRecord {
    pub xid: TransactionId,
}

impl InterpretedWalRecord {
    /// Checks if the WAL record is empty.
    ///
    /// An empty interpreted WAL record has no data or metadata and does not have to be sent to the
    /// pageserver.
    pub fn is_empty(&self) -> bool {
        self.batch.is_empty()
            && self.metadata_record.is_none()
            && matches!(self.flush_uncommitted, FlushUncommittedRecords::No)
    }
}

/// The interpreted part of the Postgres WAL record which requires metadata
/// writes to the underlying storage engine.
#[derive(Serialize, Deserialize)]
@@ -496,11 +496,16 @@ impl SerializedValueBatch {
        }
    }

    /// Checks if the batch is empty
    ///
    /// A batch is empty when it contains no serialized values.
    /// Note that it may still contain observed values.
    /// Checks if the batch contains any serialized or observed values
    pub fn is_empty(&self) -> bool {
        !self.has_data() && self.metadata.is_empty()
    }

    /// Checks if the batch contains data
    ///
    /// Note that if this returns false, it may still contain observed values or
    /// a metadata record.
    pub fn has_data(&self) -> bool {
        let empty = self.raw.is_empty();

        if cfg!(debug_assertions) && empty {
@@ -510,7 +515,7 @@ impl SerializedValueBatch {
            .all(|meta| matches!(meta, ValueMeta::Observed(_))));
        }

        empty
        !empty
    }

    /// Returns the number of values serialized in the batch
356
libs/proxy/tokio-postgres2/src/wire_format.rs
Normal file
@@ -0,0 +1,356 @@
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use pageserver_api::key::CompactKey;
|
||||
use prost::{DecodeError, EncodeError, Message};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use utils::bin_ser::{BeSer, DeserializeError, SerializeError};
|
||||
use utils::lsn::Lsn;
|
||||
use utils::postgres_client::{Compression, InterpretedFormat};
|
||||
|
||||
use crate::models::{
|
||||
FlushUncommittedRecords, InterpretedWalRecord, InterpretedWalRecords, MetadataRecord,
|
||||
};
|
||||
|
||||
use crate::serialized_batch::{
|
||||
ObservedValueMeta, SerializedValueBatch, SerializedValueMeta, ValueMeta,
|
||||
};
|
||||
|
||||
use crate::models::proto;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ToWireFormatError {
|
||||
#[error("{0}")]
|
||||
Bincode(#[from] SerializeError),
|
||||
#[error("{0}")]
|
||||
Protobuf(#[from] ProtobufSerializeError),
|
||||
#[error("{0}")]
|
||||
Compression(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ProtobufSerializeError {
|
||||
#[error("{0}")]
|
||||
MetadataRecord(#[from] SerializeError),
|
||||
#[error("{0}")]
|
||||
Encode(#[from] EncodeError),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum FromWireFormatError {
|
||||
#[error("{0}")]
|
||||
Bincode(#[from] DeserializeError),
|
||||
#[error("{0}")]
|
||||
Protobuf(#[from] ProtobufDeserializeError),
|
||||
#[error("{0}")]
|
||||
Decompress(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ProtobufDeserializeError {
|
||||
#[error("{0}")]
|
||||
Transcode(#[from] TranscodeError),
|
||||
#[error("{0}")]
|
||||
Decode(#[from] DecodeError),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum TranscodeError {
|
||||
#[error("{0}")]
|
||||
BadInput(String),
|
||||
#[error("{0}")]
|
||||
MetadataRecord(#[from] DeserializeError),
|
||||
}
|
||||
|
||||
pub trait ToWireFormat {
|
||||
fn to_wire(
|
||||
self,
|
||||
format: InterpretedFormat,
|
||||
compression: Option<Compression>,
|
||||
) -> impl std::future::Future<Output = Result<Bytes, ToWireFormatError>> + Send;
|
||||
}
|
||||
|
||||
pub trait FromWireFormat {
|
||||
type T;
|
||||
fn from_wire(
|
||||
buf: &Bytes,
|
||||
format: InterpretedFormat,
|
||||
compression: Option<Compression>,
|
||||
) -> impl std::future::Future<Output = Result<Self::T, FromWireFormatError>> + Send;
|
||||
}
|
||||
|
||||
impl ToWireFormat for InterpretedWalRecords {
|
||||
async fn to_wire(
|
||||
self,
|
||||
format: InterpretedFormat,
|
||||
compression: Option<Compression>,
|
||||
) -> Result<Bytes, ToWireFormatError> {
|
||||
use async_compression::tokio::write::ZstdEncoder;
|
||||
use async_compression::Level;
|
||||
|
||||
let encode_res: Result<Bytes, ToWireFormatError> = match format {
|
||||
InterpretedFormat::Bincode => {
|
||||
let buf = BytesMut::new();
|
||||
let mut buf = buf.writer();
|
||||
self.ser_into(&mut buf)?;
|
||||
Ok(buf.into_inner().freeze())
|
||||
}
|
||||
InterpretedFormat::Protobuf => {
|
||||
let proto: proto::InterpretedWalRecords = self.try_into()?;
|
||||
let mut buf = BytesMut::new();
|
||||
proto
|
||||
.encode(&mut buf)
|
||||
.map_err(|e| ToWireFormatError::Protobuf(e.into()))?;
|
||||
|
||||
Ok(buf.freeze())
|
||||
}
|
||||
};
|
||||
|
||||
let buf = encode_res?;
|
||||
let compressed_buf = match compression {
|
||||
Some(Compression::Zstd { level }) => {
|
||||
let mut encoder = ZstdEncoder::with_quality(
|
||||
Vec::with_capacity(buf.len() / 4),
|
||||
Level::Precise(level as i32),
|
||||
);
|
||||
encoder.write_all(&buf).await?;
|
||||
encoder.shutdown().await?;
|
||||
Bytes::from(encoder.into_inner())
|
||||
}
|
||||
None => buf,
|
||||
};
|
||||
|
||||
Ok(compressed_buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromWireFormat for InterpretedWalRecords {
|
||||
type T = Self;
|
||||
|
||||
async fn from_wire(
|
||||
buf: &Bytes,
|
||||
format: InterpretedFormat,
|
||||
compression: Option<Compression>,
|
||||
) -> Result<Self, FromWireFormatError> {
|
||||
let decompressed_buf = match compression {
|
||||
Some(Compression::Zstd { .. }) => {
|
||||
use async_compression::tokio::write::ZstdDecoder;
|
||||
let mut decoded_buf = Vec::with_capacity(buf.len());
|
||||
let mut decoder = ZstdDecoder::new(&mut decoded_buf);
|
||||
decoder.write_all(buf).await?;
|
||||
decoder.flush().await?;
|
||||
Bytes::from(decoded_buf)
|
||||
}
|
||||
None => buf.clone(),
|
||||
};
|
||||
|
||||
match format {
|
||||
InterpretedFormat::Bincode => {
|
||||
InterpretedWalRecords::des(&decompressed_buf).map_err(FromWireFormatError::Bincode)
|
||||
}
|
||||
InterpretedFormat::Protobuf => {
|
||||
let proto = proto::InterpretedWalRecords::decode(decompressed_buf)
|
||||
.map_err(|e| FromWireFormatError::Protobuf(e.into()))?;
|
||||
InterpretedWalRecords::try_from(proto)
|
||||
.map_err(|e| FromWireFormatError::Protobuf(e.into()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<InterpretedWalRecords> for proto::InterpretedWalRecords {
|
||||
type Error = SerializeError;
|
||||
|
||||
fn try_from(value: InterpretedWalRecords) -> Result<Self, Self::Error> {
|
||||
let records = value
|
||||
.records
|
||||
.into_iter()
|
||||
.map(proto::InterpretedWalRecord::try_from)
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
Ok(proto::InterpretedWalRecords {
|
||||
records,
|
||||
next_record_lsn: value.next_record_lsn.map(|l| l.0),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<InterpretedWalRecord> for proto::InterpretedWalRecord {
|
||||
type Error = SerializeError;
|
||||
|
||||
fn try_from(value: InterpretedWalRecord) -> Result<Self, Self::Error> {
|
||||
let metadata_record = value
|
||||
.metadata_record
|
||||
.map(|meta_rec| -> Result<Vec<u8>, Self::Error> {
|
||||
let mut buf = Vec::new();
|
||||
meta_rec.ser_into(&mut buf)?;
|
||||
Ok(buf)
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
Ok(proto::InterpretedWalRecord {
|
||||
metadata_record,
|
||||
batch: Some(proto::SerializedValueBatch::from(value.batch)),
|
||||
next_record_lsn: value.next_record_lsn.0,
|
||||
flush_uncommitted: matches!(value.flush_uncommitted, FlushUncommittedRecords::Yes),
|
||||
xid: value.xid,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SerializedValueBatch> for proto::SerializedValueBatch {
|
||||
fn from(value: SerializedValueBatch) -> Self {
|
||||
proto::SerializedValueBatch {
|
||||
raw: value.raw,
|
||||
metadata: value
|
||||
.metadata
|
||||
.into_iter()
|
||||
.map(proto::ValueMeta::from)
|
||||
.collect(),
|
||||
max_lsn: value.max_lsn.0,
|
||||
len: value.len as u64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ValueMeta> for proto::ValueMeta {
|
||||
fn from(value: ValueMeta) -> Self {
|
||||
match value {
|
||||
ValueMeta::Observed(obs) => proto::ValueMeta {
|
||||
r#type: proto::ValueMetaType::Observed.into(),
|
||||
key: Some(proto::CompactKey::from(obs.key)),
|
||||
lsn: obs.lsn.0,
|
||||
batch_offset: None,
|
||||
len: None,
|
||||
will_init: None,
|
||||
},
|
||||
ValueMeta::Serialized(ser) => proto::ValueMeta {
|
||||
r#type: proto::ValueMetaType::Serialized.into(),
|
||||
key: Some(proto::CompactKey::from(ser.key)),
|
||||
lsn: ser.lsn.0,
|
||||
                batch_offset: Some(ser.batch_offset),
                len: Some(ser.len as u64),
                will_init: Some(ser.will_init),
            },
        }
    }
}

impl From<CompactKey> for proto::CompactKey {
    fn from(value: CompactKey) -> Self {
        proto::CompactKey {
            high: (value.raw() >> 64) as i64,
            low: value.raw() as i64,
        }
    }
}

impl TryFrom<proto::InterpretedWalRecords> for InterpretedWalRecords {
    type Error = TranscodeError;

    fn try_from(value: proto::InterpretedWalRecords) -> Result<Self, Self::Error> {
        let records = value
            .records
            .into_iter()
            .map(InterpretedWalRecord::try_from)
            .collect::<Result<_, _>>()?;

        Ok(InterpretedWalRecords {
            records,
            next_record_lsn: value.next_record_lsn.map(Lsn::from),
        })
    }
}

impl TryFrom<proto::InterpretedWalRecord> for InterpretedWalRecord {
    type Error = TranscodeError;

    fn try_from(value: proto::InterpretedWalRecord) -> Result<Self, Self::Error> {
        let metadata_record = value
            .metadata_record
            .map(|mrec| -> Result<_, DeserializeError> { MetadataRecord::des(&mrec) })
            .transpose()?;

        let batch = {
            let batch = value.batch.ok_or_else(|| {
                TranscodeError::BadInput("InterpretedWalRecord::batch missing".to_string())
            })?;

            SerializedValueBatch::try_from(batch)?
        };

        Ok(InterpretedWalRecord {
            metadata_record,
            batch,
            next_record_lsn: Lsn(value.next_record_lsn),
            flush_uncommitted: if value.flush_uncommitted {
                FlushUncommittedRecords::Yes
            } else {
                FlushUncommittedRecords::No
            },
            xid: value.xid,
        })
    }
}

impl TryFrom<proto::SerializedValueBatch> for SerializedValueBatch {
    type Error = TranscodeError;

    fn try_from(value: proto::SerializedValueBatch) -> Result<Self, Self::Error> {
        let metadata = value
            .metadata
            .into_iter()
            .map(ValueMeta::try_from)
            .collect::<Result<Vec<_>, _>>()?;

        Ok(SerializedValueBatch {
            raw: value.raw,
            metadata,
            max_lsn: Lsn(value.max_lsn),
            len: value.len as usize,
        })
    }
}

impl TryFrom<proto::ValueMeta> for ValueMeta {
    type Error = TranscodeError;

    fn try_from(value: proto::ValueMeta) -> Result<Self, Self::Error> {
        match proto::ValueMetaType::try_from(value.r#type) {
            Ok(proto::ValueMetaType::Serialized) => {
                Ok(ValueMeta::Serialized(SerializedValueMeta {
                    key: value
                        .key
                        .ok_or_else(|| {
                            TranscodeError::BadInput("ValueMeta::key missing".to_string())
                        })?
                        .into(),
                    lsn: Lsn(value.lsn),
                    batch_offset: value.batch_offset.ok_or_else(|| {
                        TranscodeError::BadInput("ValueMeta::batch_offset missing".to_string())
                    })?,
                    len: value.len.ok_or_else(|| {
                        TranscodeError::BadInput("ValueMeta::len missing".to_string())
                    })? as usize,
                    will_init: value.will_init.ok_or_else(|| {
                        TranscodeError::BadInput("ValueMeta::will_init missing".to_string())
                    })?,
                }))
            }
            Ok(proto::ValueMetaType::Observed) => Ok(ValueMeta::Observed(ObservedValueMeta {
                key: value
                    .key
                    .ok_or_else(|| TranscodeError::BadInput("ValueMeta::key missing".to_string()))?
                    .into(),
                lsn: Lsn(value.lsn),
            })),
            Err(_) => Err(TranscodeError::BadInput(format!(
                "Unexpected ValueMeta::type {}",
                value.r#type
            ))),
        }
    }
}

impl From<proto::CompactKey> for CompactKey {
    fn from(value: proto::CompactKey) -> Self {
        (((value.high as i128) << 64) | (value.low as i128)).into()
    }
}
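
The CompactKey conversions above pack a 128-bit key into two signed 64-bit protobuf fields. A minimal standalone sketch of the split-and-rejoin arithmetic follows; it uses plain integers rather than the upstream types, and masks the low half through u64 on the way back so that a set sign bit cannot sign-extend over the high half:

    fn split(raw: i128) -> (i64, i64) {
        // High half: shift keeps the top 64 bits; low half: truncation keeps the bottom 64.
        ((raw >> 64) as i64, raw as i64)
    }

    fn join(high: i64, low: i64) -> i128 {
        // The `as u64` mask zero-extends the low half, so OR-ing cannot
        // smear sign bits into the high word.
        ((high as i128) << 64) | (low as u64 as i128)
    }

    fn main() {
        let raw = 0x0123_4567_89ab_cdef_8000_0000_0000_0001_u128 as i128;
        let (high, low) = split(raw);
        assert_eq!(join(high, low), raw);
    }
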
@@ -126,6 +126,7 @@ fn main() -> anyhow::Result<()> {
    // after setting up logging, log the effective IO engine choice and read path implementations
    info!(?conf.virtual_file_io_engine, "starting with virtual_file IO engine");
    info!(?conf.virtual_file_io_mode, "starting with virtual_file IO mode");
    info!(?conf.wal_receiver_protocol, "starting with WAL receiver protocol");

    // The tenants directory contains all the pageserver local disk state.
    // Create if not exists and make sure all the contents are durable before proceeding.

@@ -14,6 +14,7 @@ use remote_storage::{RemotePath, RemoteStorageConfig};
use std::env;
use storage_broker::Uri;
use utils::logging::SecretString;
use utils::postgres_client::PostgresClientProtocol;

use once_cell::sync::OnceCell;
use reqwest::Url;
@@ -187,6 +188,8 @@ pub struct PageServerConf {
    /// Optionally disable disk syncs (unsafe!)
    pub no_sync: bool,

    pub wal_receiver_protocol: PostgresClientProtocol,

    pub page_service_pipelining: pageserver_api::config::PageServicePipeliningConfig,
}

@@ -347,6 +350,7 @@ impl PageServerConf {
            virtual_file_io_engine,
            tenant_config,
            no_sync,
            wal_receiver_protocol,
            page_service_pipelining,
        } = config_toml;

@@ -390,6 +394,7 @@ impl PageServerConf {
            import_pgdata_upcall_api,
            import_pgdata_upcall_api_token: import_pgdata_upcall_api_token.map(SecretString::from),
            import_pgdata_aws_endpoint_url,
            wal_receiver_protocol,
            page_service_pipelining,

            // ------------------------------------------------------------

@@ -1144,18 +1144,24 @@ pub(crate) mod mock {
        rx: tokio::sync::mpsc::UnboundedReceiver<ListWriterQueueMessage>,
        executor_rx: tokio::sync::mpsc::Receiver<DeleterMessage>,
        cancel: CancellationToken,
        executed: Arc<AtomicUsize>,
    }

    impl ConsumerState {
        async fn consume(&mut self, remote_storage: &GenericRemoteStorage) -> usize {
            let mut executed = 0;

        async fn consume(&mut self, remote_storage: &GenericRemoteStorage) {
            info!("Executing all pending deletions");

            // Transform all executor messages to generic frontend messages
            while let Ok(msg) = self.executor_rx.try_recv() {
            loop {
                use either::Either;
                let msg = tokio::select! {
                    left = self.executor_rx.recv() => Either::Left(left),
                    right = self.rx.recv() => Either::Right(right),
                };
                match msg {
                    DeleterMessage::Delete(objects) => {
                    Either::Left(None) => break,
                    Either::Right(None) => break,
                    Either::Left(Some(DeleterMessage::Delete(objects))) => {
                        for path in objects {
                            match remote_storage.delete(&path, &self.cancel).await {
                                Ok(_) => {
@@ -1165,18 +1171,13 @@ pub(crate) mod mock {
                                    error!("Failed to delete {path}, leaking object! ({e})");
                                }
                            }
                            executed += 1;
                            self.executed.fetch_add(1, Ordering::Relaxed);
                        }
                    }
                    DeleterMessage::Flush(flush_op) => {
                    Either::Left(Some(DeleterMessage::Flush(flush_op))) => {
                        flush_op.notify();
                    }
                }
            }

            while let Ok(msg) = self.rx.try_recv() {
                match msg {
                    ListWriterQueueMessage::Delete(op) => {
                    Either::Right(Some(ListWriterQueueMessage::Delete(op))) => {
                        let mut objects = op.objects;
                        for (layer, meta) in op.layers {
                            objects.push(remote_layer_path(
@@ -1198,33 +1199,27 @@ pub(crate) mod mock {
                                    error!("Failed to delete {path}, leaking object! ({e})");
                                }
                            }
                            executed += 1;
                            self.executed.fetch_add(1, Ordering::Relaxed);
                        }
                    }
                    ListWriterQueueMessage::Flush(op) => {
                    Either::Right(Some(ListWriterQueueMessage::Flush(op))) => {
                        op.notify();
                    }
                    ListWriterQueueMessage::FlushExecute(op) => {
                    Either::Right(Some(ListWriterQueueMessage::FlushExecute(op))) => {
                        // We have already executed all prior deletions because mock does them inline
                        op.notify();
                    }
                    ListWriterQueueMessage::Recover(_) => {
                    Either::Right(Some(ListWriterQueueMessage::Recover(_))) => {
                        // no-op in mock
                    }
                }
                info!("All pending deletions have been executed");
            }

            executed
        }
    }

    pub struct MockDeletionQueue {
        tx: tokio::sync::mpsc::UnboundedSender<ListWriterQueueMessage>,
        executor_tx: tokio::sync::mpsc::Sender<DeleterMessage>,
        executed: Arc<AtomicUsize>,
        remote_storage: Option<GenericRemoteStorage>,
        consumer: std::sync::Mutex<ConsumerState>,
        lsn_table: Arc<std::sync::RwLock<VisibleLsnUpdates>>,
    }

@@ -1235,29 +1230,34 @@ pub(crate) mod mock {

            let executed = Arc::new(AtomicUsize::new(0));

            let mut consumer = ConsumerState {
                rx,
                executor_rx,
                cancel: CancellationToken::new(),
                executed: executed.clone(),
            };

            tokio::spawn(async move {
                if let Some(remote_storage) = &remote_storage {
                    consumer.consume(remote_storage).await;
                }
            });

            Self {
                tx,
                executor_tx,
                executed,
                remote_storage,
                consumer: std::sync::Mutex::new(ConsumerState {
                    rx,
                    executor_rx,
                    cancel: CancellationToken::new(),
                }),
                lsn_table: Arc::new(std::sync::RwLock::new(VisibleLsnUpdates::new())),
            }
        }

        #[allow(clippy::await_holding_lock)]
        pub async fn pump(&self) {
            if let Some(remote_storage) = &self.remote_storage {
                // Permit holding mutex across await, because this is only ever
                // called once at a time in tests.
                let mut locked = self.consumer.lock().unwrap();
                let count = locked.consume(remote_storage).await;
                self.executed.fetch_add(count, Ordering::Relaxed);
            }
            let (tx, rx) = tokio::sync::oneshot::channel();
            self.executor_tx
                .send(DeleterMessage::Flush(FlushOp { tx }))
                .await
                .expect("Failed to send flush message");
            rx.await.ok();
        }

        pub(crate) fn new_client(&self) -> DeletionQueueClient {
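
For reference, the shape of the consumer loop the mock switches to: instead of draining each channel separately with try_recv, it awaits both receivers at once and folds the results into an either::Either, ending when a channel closes. A self-contained sketch with stand-in String messages (hypothetical channels, not the Neon types):

    use either::Either;
    use tokio::sync::mpsc;

    #[tokio::main]
    async fn main() {
        let (exec_tx, mut exec_rx) = mpsc::channel::<String>(8);
        let (front_tx, mut front_rx) = mpsc::unbounded_channel::<String>();

        // Feed one message into each channel, then drop the senders so both
        // channels report closure (recv() returning None) once drained.
        exec_tx.send("delete a/b/c".into()).await.unwrap();
        front_tx.send("flush".into()).unwrap();
        drop(exec_tx);
        drop(front_tx);

        loop {
            // Wait on both receivers at once; whichever is ready first wins.
            let msg = tokio::select! {
                left = exec_rx.recv() => Either::Left(left),
                right = front_rx.recv() => Either::Right(right),
            };
            match msg {
                // As in the mock above, a close on either side ends the loop,
                // even if the other channel still has buffered messages.
                Either::Left(None) | Either::Right(None) => break,
                Either::Left(Some(m)) => println!("executor: {m}"),
                Either::Right(Some(m)) => println!("frontend: {m}"),
            }
        }
    }
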
@@ -3,7 +3,7 @@ use metrics::{
    register_counter_vec, register_gauge_vec, register_histogram, register_histogram_vec,
    register_int_counter, register_int_counter_pair_vec, register_int_counter_vec,
    register_int_gauge, register_int_gauge_vec, register_uint_gauge, register_uint_gauge_vec,
    Counter, CounterVec, GaugeVec, Histogram, HistogramVec, IntCounter, IntCounterPair,
    Counter, CounterVec, Gauge, GaugeVec, Histogram, HistogramVec, IntCounter, IntCounterPair,
    IntCounterPairVec, IntCounterVec, IntGauge, IntGaugeVec, UIntGauge, UIntGaugeVec,
};
use once_cell::sync::Lazy;
@@ -457,6 +457,15 @@ pub(crate) static WAIT_LSN_TIME: Lazy<Histogram> = Lazy::new(|| {
    .expect("failed to define a metric")
});

static FLUSH_WAIT_UPLOAD_TIME: Lazy<GaugeVec> = Lazy::new(|| {
    register_gauge_vec!(
        "pageserver_flush_wait_upload_seconds",
        "Time spent waiting for preceding uploads during layer flush",
        &["tenant_id", "shard_id", "timeline_id"]
    )
    .expect("failed to define a metric")
});

static LAST_RECORD_LSN: Lazy<IntGaugeVec> = Lazy::new(|| {
    register_int_gauge_vec!(
        "pageserver_last_record_lsn",
@@ -653,6 +662,35 @@ pub(crate) static COMPRESSION_IMAGE_OUTPUT_BYTES: Lazy<IntCounter> = Lazy::new(|
    .expect("failed to define a metric")
});

pub(crate) static RELSIZE_CACHE_ENTRIES: Lazy<UIntGauge> = Lazy::new(|| {
    register_uint_gauge!(
        "pageserver_relsize_cache_entries",
        "Number of entries in the relation size cache",
    )
    .expect("failed to define a metric")
});

pub(crate) static RELSIZE_CACHE_HITS: Lazy<IntCounter> = Lazy::new(|| {
    register_int_counter!("pageserver_relsize_cache_hits", "Relation size cache hits",)
        .expect("failed to define a metric")
});

pub(crate) static RELSIZE_CACHE_MISSES: Lazy<IntCounter> = Lazy::new(|| {
    register_int_counter!(
        "pageserver_relsize_cache_misses",
        "Relation size cache misses",
    )
    .expect("failed to define a metric")
});

pub(crate) static RELSIZE_CACHE_MISSES_OLD: Lazy<IntCounter> = Lazy::new(|| {
    register_int_counter!(
        "pageserver_relsize_cache_misses_old",
        "Relation size cache misses where the lookup LSN is older than the last relation update"
    )
    .expect("failed to define a metric")
});

pub(crate) mod initial_logical_size {
    use metrics::{register_int_counter, register_int_counter_vec, IntCounter, IntCounterVec};
    use once_cell::sync::Lazy;
@@ -2106,6 +2144,7 @@ pub(crate) struct WalIngestMetrics {
    pub(crate) records_committed: IntCounter,
    pub(crate) records_filtered: IntCounter,
    pub(crate) gap_blocks_zeroed_on_rel_extend: IntCounter,
    pub(crate) clear_vm_bits_unknown: IntCounterVec,
}

pub(crate) static WAL_INGEST: Lazy<WalIngestMetrics> = Lazy::new(|| WalIngestMetrics {
@@ -2134,6 +2173,12 @@ pub(crate) static WAL_INGEST: Lazy<WalIngestMetrics> = Lazy::new(|| WalIngestMet
        "Total number of zero gap blocks written on relation extends"
    )
    .expect("failed to define a metric"),
    clear_vm_bits_unknown: register_int_counter_vec!(
        "pageserver_wal_ingest_clear_vm_bits_unknown",
        "Number of ignored ClearVmBits operations due to unknown pages/relations",
        &["entity"],
    )
    .expect("failed to define a metric"),
});

pub(crate) static WAL_REDO_TIME: Lazy<Histogram> = Lazy::new(|| {
@@ -2336,6 +2381,7 @@ pub(crate) struct TimelineMetrics {
    shard_id: String,
    timeline_id: String,
    pub flush_time_histo: StorageTimeMetrics,
    pub flush_wait_upload_time_gauge: Gauge,
    pub compact_time_histo: StorageTimeMetrics,
    pub create_images_time_histo: StorageTimeMetrics,
    pub logical_size_histo: StorageTimeMetrics,
@@ -2379,6 +2425,9 @@ impl TimelineMetrics {
            &shard_id,
            &timeline_id,
        );
        let flush_wait_upload_time_gauge = FLUSH_WAIT_UPLOAD_TIME
            .get_metric_with_label_values(&[&tenant_id, &shard_id, &timeline_id])
            .unwrap();
        let compact_time_histo = StorageTimeMetrics::new(
            StorageTimeOperation::Compact,
            &tenant_id,
@@ -2516,6 +2565,7 @@ impl TimelineMetrics {
            shard_id,
            timeline_id,
            flush_time_histo,
            flush_wait_upload_time_gauge,
            compact_time_histo,
            create_images_time_histo,
            logical_size_histo,
@@ -2563,6 +2613,14 @@ impl TimelineMetrics {
        self.resident_physical_size_gauge.get()
    }

    pub(crate) fn flush_wait_upload_time_gauge_add(&self, duration: f64) {
        self.flush_wait_upload_time_gauge.add(duration);
        crate::metrics::FLUSH_WAIT_UPLOAD_TIME
            .get_metric_with_label_values(&[&self.tenant_id, &self.shard_id, &self.timeline_id])
            .unwrap()
            .add(duration);
    }

    pub(crate) fn shutdown(&self) {
        let was_shutdown = self
            .shutdown
@@ -2579,6 +2637,7 @@ impl TimelineMetrics {
        let timeline_id = &self.timeline_id;
        let shard_id = &self.shard_id;
        let _ = LAST_RECORD_LSN.remove_label_values(&[tenant_id, shard_id, timeline_id]);
        let _ = FLUSH_WAIT_UPLOAD_TIME.remove_label_values(&[tenant_id, shard_id, timeline_id]);
        let _ = STANDBY_HORIZON.remove_label_values(&[tenant_id, shard_id, timeline_id]);
        {
            RESIDENT_PHYSICAL_SIZE_GLOBAL.sub(self.resident_physical_size_get());

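
The per-timeline gauge added above follows the usual prometheus-crate lifecycle (the register_* macros and vec methods used here are the standard prometheus API, which the metrics crate in this repo re-exports): resolve the labeled child once at construction, update it on the hot path, and remove the label set at shutdown so the stale series stops being exported. A minimal sketch with illustrative metric and label names:

    use once_cell::sync::Lazy;
    use prometheus::{register_gauge_vec, GaugeVec};

    static FLUSH_WAIT: Lazy<GaugeVec> = Lazy::new(|| {
        register_gauge_vec!(
            "flush_wait_upload_seconds",
            "Time spent waiting for uploads during layer flush",
            &["tenant_id", "shard_id", "timeline_id"]
        )
        .unwrap()
    });

    fn main() {
        let labels = ["t1", "0001", "tl1"];
        // Construction: resolve the labeled child once and keep it around.
        let gauge = FLUSH_WAIT.get_metric_with_label_values(&labels).unwrap();
        // Hot path: accumulate wait time without another label lookup.
        gauge.add(0.25);
        // Shutdown: drop the label set so the series is no longer exported.
        let _ = FLUSH_WAIT.remove_label_values(&labels);
    }
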
@@ -10,6 +10,9 @@ use super::tenant::{PageReconstructError, Timeline};
use crate::aux_file;
use crate::context::RequestContext;
use crate::keyspace::{KeySpace, KeySpaceAccum};
use crate::metrics::{
    RELSIZE_CACHE_ENTRIES, RELSIZE_CACHE_HITS, RELSIZE_CACHE_MISSES, RELSIZE_CACHE_MISSES_OLD,
};
use crate::span::{
    debug_assert_current_span_has_tenant_and_timeline_id,
    debug_assert_current_span_has_tenant_and_timeline_id_no_shard_id,
@@ -389,7 +392,9 @@ impl Timeline {
        result
    }

    // Get size of a database in blocks
    /// Get size of a database in blocks. This is only accurate on shard 0. It will undercount on
    /// other shards, by only accounting for relations the shard has pages for, and only accounting
    /// for pages up to the highest page number it has stored.
    pub(crate) async fn get_db_size(
        &self,
        spcnode: Oid,
@@ -408,7 +413,10 @@ impl Timeline {
        Ok(total_blocks)
    }

    /// Get size of a relation file
    /// Get size of a relation file. The relation must exist, otherwise an error is returned.
    ///
    /// This is only accurate on shard 0. On other shards, it will return the size up to the highest
    /// page number stored in the shard.
    pub(crate) async fn get_rel_size(
        &self,
        tag: RelTag,
@@ -444,7 +452,10 @@ impl Timeline {
        Ok(nblocks)
    }

    /// Does relation exist?
    /// Does the relation exist?
    ///
    /// Only shard 0 has a full view of the relations. Other shards only know about relations that
    /// the shard stores pages for.
    pub(crate) async fn get_rel_exists(
        &self,
        tag: RelTag,
@@ -478,6 +489,9 @@ impl Timeline {

    /// Get a list of all existing relations in given tablespace and database.
    ///
    /// Only shard 0 has a full view of the relations. Other shards only know about relations that
    /// the shard stores pages for.
    ///
    /// # Cancel-Safety
    ///
    /// This method is cancellation-safe.
@@ -1129,9 +1143,12 @@ impl Timeline {
        let rel_size_cache = self.rel_size_cache.read().unwrap();
        if let Some((cached_lsn, nblocks)) = rel_size_cache.map.get(tag) {
            if lsn >= *cached_lsn {
                RELSIZE_CACHE_HITS.inc();
                return Some(*nblocks);
            }
            RELSIZE_CACHE_MISSES_OLD.inc();
        }
        RELSIZE_CACHE_MISSES.inc();
        None
    }

@@ -1156,6 +1173,7 @@ impl Timeline {
            }
            hash_map::Entry::Vacant(entry) => {
                entry.insert((lsn, nblocks));
                RELSIZE_CACHE_ENTRIES.inc();
            }
        }
    }
@@ -1163,13 +1181,17 @@ impl Timeline {
    /// Store cached relation size
    pub fn set_cached_rel_size(&self, tag: RelTag, lsn: Lsn, nblocks: BlockNumber) {
        let mut rel_size_cache = self.rel_size_cache.write().unwrap();
        rel_size_cache.map.insert(tag, (lsn, nblocks));
        if rel_size_cache.map.insert(tag, (lsn, nblocks)).is_none() {
            RELSIZE_CACHE_ENTRIES.inc();
        }
    }

    /// Remove cached relation size
    pub fn remove_cached_rel_size(&self, tag: &RelTag) {
        let mut rel_size_cache = self.rel_size_cache.write().unwrap();
        rel_size_cache.map.remove(tag);
        if rel_size_cache.map.remove(tag).is_some() {
            RELSIZE_CACHE_ENTRIES.dec();
        }
    }
}
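
The cache accounting above distinguishes three outcomes: a hit (the cached LSN is at or before the requested LSN), a lookup older than the cached entry (counted as both a miss and a miss-old), and a plain miss. A toy version with plain integer counters standing in for the Prometheus statics:

    use std::collections::HashMap;

    struct RelSizeCache {
        map: HashMap<u32, (u64 /* lsn */, u32 /* nblocks */)>,
        hits: u64,
        misses: u64,
        misses_old: u64,
    }

    impl RelSizeCache {
        fn get(&mut self, tag: u32, lsn: u64) -> Option<u32> {
            if let Some((cached_lsn, nblocks)) = self.map.get(&tag) {
                if lsn >= *cached_lsn {
                    self.hits += 1;
                    return Some(*nblocks);
                }
                // Entry exists but was cached at a newer LSN than requested.
                self.misses_old += 1;
            }
            self.misses += 1;
            None
        }
    }

    fn main() {
        let mut cache = RelSizeCache { map: HashMap::new(), hits: 0, misses: 0, misses_old: 0 };
        cache.map.insert(7, (100, 42));
        assert_eq!(cache.get(7, 150), Some(42)); // hit: lsn >= cached_lsn
        assert_eq!(cache.get(7, 50), None);      // entry exists but is too new
        assert_eq!(cache.get(9, 50), None);      // no entry at all
        assert_eq!((cache.hits, cache.misses, cache.misses_old), (1, 2, 1));
    }
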
@@ -1229,10 +1251,9 @@ impl<'a> DatadirModification<'a> {
    }

    pub(crate) fn has_dirty_data(&self) -> bool {
        !self
            .pending_data_batch
        self.pending_data_batch
            .as_ref()
            .map_or(true, |b| b.is_empty())
            .map_or(false, |b| b.has_data())
    }

    /// Set the current lsn
@@ -1408,7 +1429,7 @@ impl<'a> DatadirModification<'a> {
            Some(pending_batch) => {
                pending_batch.extend(batch);
            }
            None if !batch.is_empty() => {
            None if batch.has_data() => {
                self.pending_data_batch = Some(batch);
            }
            None => {

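
The rewrite of has_dirty_data above would be a pure simplification only if has_data were the exact negation of is_empty; the point of the change is presumably that the two diverge for SerializedValueBatch (a batch can be non-empty, e.g. carrying observed-value metadata or an LSN advance, while holding no raw data to write). The toy below checks the two formulations under the strict-negation assumption, with a hypothetical Batch type:

    struct Batch {
        len: usize,
    }

    impl Batch {
        fn is_empty(&self) -> bool {
            self.len == 0
        }
        // Strict negation of is_empty; for the real batch type this need not hold.
        fn has_data(&self) -> bool {
            self.len > 0
        }
    }

    fn old_form(b: &Option<Batch>) -> bool {
        !b.as_ref().map_or(true, |b| b.is_empty())
    }

    fn new_form(b: &Option<Batch>) -> bool {
        b.as_ref().map_or(false, |b| b.has_data())
    }

    fn main() {
        for b in [None, Some(Batch { len: 0 }), Some(Batch { len: 3 })] {
            assert_eq!(old_form(&b), new_form(&b));
        }
    }
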
@@ -3215,6 +3215,18 @@ impl Tenant {
            }
        }

        if let ShutdownMode::Reload = shutdown_mode {
            tracing::info!("Flushing deletion queue");
            if let Err(e) = self.deletion_queue_client.flush().await {
                match e {
                    DeletionQueueError::ShuttingDown => {
                        // This is the only error we expect for now. In the future, if more error
                        // variants are added, we should handle them here.
                    }
                }
            }
        }

        // We cancel the Tenant's cancellation token _after_ the timelines have all shut down. This permits
        // them to continue to do work during their shutdown methods, e.g. flushing data.
        tracing::debug!("Cancelling CancellationToken");
@@ -5344,6 +5356,7 @@ pub(crate) mod harness {
            lsn_lease_length: Some(tenant_conf.lsn_lease_length),
            lsn_lease_length_for_ts: Some(tenant_conf.lsn_lease_length_for_ts),
            timeline_offloading: Some(tenant_conf.timeline_offloading),
            wal_receiver_protocol_override: tenant_conf.wal_receiver_protocol_override,
        }
    }
}

@@ -19,6 +19,7 @@ use serde_json::Value;
use std::num::NonZeroU64;
use std::time::Duration;
use utils::generation::Generation;
use utils::postgres_client::PostgresClientProtocol;

#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub(crate) enum AttachmentMode {
@@ -353,6 +354,9 @@ pub struct TenantConfOpt {
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default)]
    pub timeline_offloading: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub wal_receiver_protocol_override: Option<PostgresClientProtocol>,
}

impl TenantConfOpt {
@@ -418,6 +422,9 @@ impl TenantConfOpt {
            timeline_offloading: self
                .lazy_slru_download
                .unwrap_or(global_conf.timeline_offloading),
            wal_receiver_protocol_override: self
                .wal_receiver_protocol_override
                .or(global_conf.wal_receiver_protocol_override),
        }
    }
}
@@ -472,6 +479,7 @@ impl From<TenantConfOpt> for models::TenantConfig {
            lsn_lease_length: value.lsn_lease_length.map(humantime),
            lsn_lease_length_for_ts: value.lsn_lease_length_for_ts.map(humantime),
            timeline_offloading: value.timeline_offloading,
            wal_receiver_protocol_override: value.wal_receiver_protocol_override,
        }
    }
}

@@ -1960,7 +1960,7 @@ impl TenantManager {
        attempt.before_reset_tenant();

        let (_guard, progress) = utils::completion::channel();
        match tenant.shutdown(progress, ShutdownMode::Flush).await {
        match tenant.shutdown(progress, ShutdownMode::Reload).await {
            Ok(()) => {
                slot_guard.drop_old_value().expect("it was just shutdown");
            }

@@ -50,6 +50,7 @@ use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::{
    fs_ext, pausable_failpoint,
    postgres_client::PostgresClientProtocol,
    sync::gate::{Gate, GateGuard},
};
use wal_decoder::serialized_batch::SerializedValueBatch;
@@ -893,10 +894,11 @@ pub(crate) enum ShutdownMode {
    /// While we are flushing, we continue to accept read I/O for LSNs ingested before
    /// the call to [`Timeline::shutdown`].
    FreezeAndFlush,
    /// Only flush the layers to the remote storage without freezing any open layers. This is the
    /// mode used by ancestor detach and any other operation that reloads a tenant without increasing
    /// the generation number.
    Flush,
    /// Only flush the layers to the remote storage without freezing any open layers. Flush the deletion
    /// queue. This is the mode used by ancestor detach and any other operation that reloads a tenant
    /// without increasing the generation number. Note that this mode cannot be used at tenant shutdown,
    /// as flushing the deletion queue at that time will cause shutdown-in-progress errors.
    Reload,
    /// Shut down immediately, without waiting for any open layers to flush.
    Hard,
}
@@ -1817,7 +1819,7 @@ impl Timeline {
            }
        }

        if let ShutdownMode::Flush = mode {
        if let ShutdownMode::Reload = mode {
            // drain the upload queue
            self.remote_client.shutdown().await;
            if !self.remote_client.no_pending_work() {
@@ -2178,6 +2180,21 @@ impl Timeline {
        )
    }

    /// Resolve the effective WAL receiver protocol to use for this tenant.
    ///
    /// Priority order is:
    /// 1. Tenant config override
    /// 2. Default value for tenant config override
    /// 3. Pageserver config override
    /// 4. Pageserver config default
    pub fn resolve_wal_receiver_protocol(&self) -> PostgresClientProtocol {
        let tenant_conf = self.tenant_conf.load().tenant_conf.clone();
        tenant_conf
            .wal_receiver_protocol_override
            .or(self.conf.default_tenant_conf.wal_receiver_protocol_override)
            .unwrap_or(self.conf.wal_receiver_protocol)
    }

    pub(super) fn tenant_conf_updated(&self, new_conf: &AttachedTenantConf) {
        // NB: Most tenant conf options are read by background loops, so,
        // changes will automatically be picked up.
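
The resolution above is just an Option fallback chain: Option::or takes the first populated override and unwrap_or supplies the pageserver-wide default. A tiny standalone illustration with string stand-ins for the protocol enum:

    fn resolve(
        tenant_override: Option<&'static str>,
        default_tenant_override: Option<&'static str>,
        pageserver_conf: &'static str,
    ) -> &'static str {
        // First populated override wins; otherwise fall back to the global setting.
        tenant_override
            .or(default_tenant_override)
            .unwrap_or(pageserver_conf)
    }

    fn main() {
        assert_eq!(resolve(Some("interpreted"), None, "vanilla"), "interpreted");
        assert_eq!(resolve(None, Some("interpreted"), "vanilla"), "interpreted");
        assert_eq!(resolve(None, None, "vanilla"), "vanilla");
    }
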
@@ -2470,6 +2487,7 @@ impl Timeline {
        *guard = Some(WalReceiver::start(
            Arc::clone(self),
            WalReceiverConf {
                protocol: self.resolve_wal_receiver_protocol(),
                wal_connect_timeout,
                lagging_wal_timeout,
                max_lsn_wal_lag,
@@ -3829,7 +3847,8 @@ impl Timeline {
        };

        // Backpressure mechanism: wait with continuation of the flush loop until we have uploaded all layer files.
        // This makes us refuse ingest until the new layers have been persisted to the remote.
        // This makes us refuse ingest until the new layers have been persisted to the remote
        let start = Instant::now();
        self.remote_client
            .wait_completion()
            .await
@@ -3842,6 +3861,8 @@ impl Timeline {
                FlushLayerError::Other(anyhow!(e).into())
            }
        })?;
        let duration = start.elapsed().as_secs_f64();
        self.metrics.flush_wait_upload_time_gauge_add(duration);

        // FIXME: between create_delta_layer and the scheduling of the upload in `update_metadata_file`,
        // a compaction can delete the file and then it won't be available for uploads any more.
@@ -5896,7 +5917,7 @@ impl<'a> TimelineWriter<'a> {
        batch: SerializedValueBatch,
        ctx: &RequestContext,
    ) -> anyhow::Result<()> {
        if batch.is_empty() {
        if !batch.has_data() {
            return Ok(());
        }

@@ -58,7 +58,7 @@ pub(crate) async fn offload_timeline(
    }

    // Now that the Timeline is in Stopping state, request all the related tasks to shut down.
    timeline.shutdown(super::ShutdownMode::Flush).await;
    timeline.shutdown(super::ShutdownMode::Reload).await;

    // TODO extend guard mechanism above with method
    // to make deletions possible while offloading is in progress

@@ -38,6 +38,7 @@ use storage_broker::BrokerClientChannel;
use tokio::sync::watch;
use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::postgres_client::PostgresClientProtocol;

use self::connection_manager::ConnectionManagerStatus;

@@ -45,6 +46,7 @@ use super::Timeline;

#[derive(Clone)]
pub struct WalReceiverConf {
    pub protocol: PostgresClientProtocol,
    /// The timeout on the connection to safekeeper for WAL streaming.
    pub wal_connect_timeout: Duration,
    /// The timeout to use to determine when the current connection is "stale" and reconnect to the other one.

@@ -36,7 +36,9 @@ use postgres_connection::PgConnectionConfig;
use utils::backoff::{
    exponential_backoff, DEFAULT_BASE_BACKOFF_SECONDS, DEFAULT_MAX_BACKOFF_SECONDS,
};
use utils::postgres_client::wal_stream_connection_config;
use utils::postgres_client::{
    wal_stream_connection_config, ConnectionConfigArgs, PostgresClientProtocol,
};
use utils::{
    id::{NodeId, TenantTimelineId},
    lsn::Lsn,
@@ -533,6 +535,7 @@ impl ConnectionManagerState {
        let node_id = new_sk.safekeeper_id;
        let connect_timeout = self.conf.wal_connect_timeout;
        let ingest_batch_size = self.conf.ingest_batch_size;
        let protocol = self.conf.protocol;
        let timeline = Arc::clone(&self.timeline);
        let ctx = ctx.detached_child(
            TaskKind::WalReceiverConnectionHandler,
@@ -546,6 +549,7 @@ impl ConnectionManagerState {

            let res = super::walreceiver_connection::handle_walreceiver_connection(
                timeline,
                protocol,
                new_sk.wal_source_connconf,
                events_sender,
                cancellation.clone(),
@@ -984,15 +988,33 @@ impl ConnectionManagerState {
        if info.safekeeper_connstr.is_empty() {
            return None; // no connection string, ignore sk
        }
        match wal_stream_connection_config(
            self.id,
            info.safekeeper_connstr.as_ref(),
            match &self.conf.auth_token {
                None => None,
                Some(x) => Some(x),

        let (shard_number, shard_count, shard_stripe_size) = match self.conf.protocol {
            PostgresClientProtocol::Vanilla => {
                (None, None, None)
            },
            self.conf.availability_zone.as_deref(),
        ) {
            PostgresClientProtocol::Interpreted { .. } => {
                let shard_identity = self.timeline.get_shard_identity();
                (
                    Some(shard_identity.number.0),
                    Some(shard_identity.count.0),
                    Some(shard_identity.stripe_size.0),
                )
            }
        };

        let connection_conf_args = ConnectionConfigArgs {
            protocol: self.conf.protocol,
            ttid: self.id,
            shard_number,
            shard_count,
            shard_stripe_size,
            listen_pg_addr_str: info.safekeeper_connstr.as_ref(),
            auth_token: self.conf.auth_token.as_ref().map(|t| t.as_str()),
            availability_zone: self.conf.availability_zone.as_deref()
        };

        match wal_stream_connection_config(connection_conf_args) {
            Ok(connstr) => Some((*sk_id, info, connstr)),
            Err(e) => {
                error!("Failed to create wal receiver connection string from broker data of safekeeper node {}: {e:#}", sk_id);
@@ -1096,6 +1118,7 @@ impl ReconnectReason {
mod tests {
    use super::*;
    use crate::tenant::harness::{TenantHarness, TIMELINE_ID};
    use pageserver_api::config::defaults::DEFAULT_WAL_RECEIVER_PROTOCOL;
    use url::Host;

    fn dummy_broker_sk_timeline(
@@ -1532,6 +1555,7 @@ mod tests {
            timeline,
            cancel: CancellationToken::new(),
            conf: WalReceiverConf {
                protocol: DEFAULT_WAL_RECEIVER_PROTOCOL,
                wal_connect_timeout: Duration::from_secs(1),
                lagging_wal_timeout: Duration::from_secs(1),
                max_lsn_wal_lag: NonZeroU64::new(1024 * 1024).unwrap(),

@@ -22,7 +22,10 @@ use tokio::{select, sync::watch, time};
use tokio_postgres::{replication::ReplicationStream, Client};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, trace, warn, Instrument};
use wal_decoder::models::{FlushUncommittedRecords, InterpretedWalRecord};
use wal_decoder::{
    models::{FlushUncommittedRecords, InterpretedWalRecord, InterpretedWalRecords},
    wire_format::FromWireFormat,
};

use super::TaskStateUpdate;
use crate::{
@@ -36,7 +39,7 @@ use crate::{
use postgres_backend::is_expected_io_error;
use postgres_connection::PgConnectionConfig;
use postgres_ffi::waldecoder::WalStreamDecoder;
use utils::{id::NodeId, lsn::Lsn};
use utils::{id::NodeId, lsn::Lsn, postgres_client::PostgresClientProtocol};
use utils::{pageserver_feedback::PageserverFeedback, sync::gate::GateError};

/// Status of the connection.
@@ -109,6 +112,7 @@ impl From<WalDecodeError> for WalReceiverError {
#[allow(clippy::too_many_arguments)]
pub(super) async fn handle_walreceiver_connection(
    timeline: Arc<Timeline>,
    protocol: PostgresClientProtocol,
    wal_source_connconf: PgConnectionConfig,
    events_sender: watch::Sender<TaskStateUpdate<WalConnectionStatus>>,
    cancellation: CancellationToken,
@@ -260,6 +264,14 @@ pub(super) async fn handle_walreceiver_connection(

    let mut walingest = WalIngest::new(timeline.as_ref(), startpoint, &ctx).await?;

    let interpreted_proto_config = match protocol {
        PostgresClientProtocol::Vanilla => None,
        PostgresClientProtocol::Interpreted {
            format,
            compression,
        } => Some((format, compression)),
    };

    while let Some(replication_message) = {
        select! {
            _ = cancellation.cancelled() => {
@@ -291,6 +303,15 @@ pub(super) async fn handle_walreceiver_connection(
                connection_status.latest_connection_update = now;
                connection_status.commit_lsn = Some(Lsn::from(keepalive.wal_end()));
            }
            ReplicationMessage::RawInterpretedWalRecords(raw) => {
                connection_status.latest_connection_update = now;
                if !raw.data().is_empty() {
                    connection_status.latest_wal_update = now;
                }

                connection_status.commit_lsn = Some(Lsn::from(raw.commit_lsn()));
                connection_status.streaming_lsn = Some(Lsn::from(raw.streaming_lsn()));
            }
            &_ => {}
        };
        if let Err(e) = events_sender.send(TaskStateUpdate::Progress(connection_status)) {
@@ -298,7 +319,148 @@ pub(super) async fn handle_walreceiver_connection(
            return Ok(());
        }

        async fn commit(
            modification: &mut DatadirModification<'_>,
            uncommitted: &mut u64,
            filtered: &mut u64,
            ctx: &RequestContext,
        ) -> anyhow::Result<()> {
            WAL_INGEST
                .records_committed
                .inc_by(*uncommitted - *filtered);
            modification.commit(ctx).await?;
            *uncommitted = 0;
            *filtered = 0;
            Ok(())
        }

        let status_update = match replication_message {
            ReplicationMessage::RawInterpretedWalRecords(raw) => {
                WAL_INGEST.bytes_received.inc_by(raw.data().len() as u64);

                let mut uncommitted_records = 0;
                let mut filtered_records = 0;

                // This is the end LSN of the raw WAL from which the records
                // were interpreted.
                let streaming_lsn = Lsn::from(raw.streaming_lsn());

                let (format, compression) = interpreted_proto_config.unwrap();
                let batch = InterpretedWalRecords::from_wire(raw.data(), format, compression)
                    .await
                    .with_context(|| {
                        anyhow::anyhow!(
                            "Failed to deserialize interpreted records ending at LSN {streaming_lsn}"
                        )
                    })?;

                let InterpretedWalRecords {
                    records,
                    next_record_lsn,
                } = batch;

                tracing::debug!(
                    "Received WAL up to {} with next_record_lsn={:?}",
                    streaming_lsn,
                    next_record_lsn
                );

                // We start the modification at 0 because each interpreted record
                // advances it to its end LSN. 0 is just an initialization placeholder.
                let mut modification = timeline.begin_modification(Lsn(0));

                for interpreted in records {
                    if matches!(interpreted.flush_uncommitted, FlushUncommittedRecords::Yes)
                        && uncommitted_records > 0
                    {
                        commit(
                            &mut modification,
                            &mut uncommitted_records,
                            &mut filtered_records,
                            &ctx,
                        )
                        .await?;
                    }

                    let local_next_record_lsn = interpreted.next_record_lsn;
                    let ingested = walingest
                        .ingest_record(interpreted, &mut modification, &ctx)
                        .await
                        .with_context(|| {
                            format!("could not ingest record at {local_next_record_lsn}")
                        })?;

                    if !ingested {
                        tracing::debug!(
                            "ingest: filtered out record @ LSN {local_next_record_lsn}"
                        );
                        WAL_INGEST.records_filtered.inc();
                        filtered_records += 1;
                    }

                    uncommitted_records += 1;

                    // FIXME: this cannot be made pausable_failpoint without fixing the
                    // failpoint library; in tests, the added amount of debugging will cause us
                    // to timeout the tests.
                    fail_point!("walreceiver-after-ingest");

                    // Commit every ingest_batch_size records. Even if we filtered out
                    // all records, we still need to call commit to advance the LSN.
                    if uncommitted_records >= ingest_batch_size
                        || modification.approx_pending_bytes()
                            > DatadirModification::MAX_PENDING_BYTES
                    {
                        commit(
                            &mut modification,
                            &mut uncommitted_records,
                            &mut filtered_records,
                            &ctx,
                        )
                        .await?;
                    }
                }

                // Records might have been filtered out on the safekeeper side, but we still
                // need to advance last record LSN on all shards. If we've not ingested the latest
                // record, then set the LSN of the modification past it. This way all shards
                // advance their last record LSN at the same time.
                let needs_last_record_lsn_advance = match next_record_lsn.map(Lsn::from) {
                    Some(lsn) if lsn > modification.get_lsn() => {
                        modification.set_lsn(lsn).unwrap();
                        true
                    }
                    _ => false,
                };

                if uncommitted_records > 0 || needs_last_record_lsn_advance {
                    // Commit any uncommitted records
                    commit(
                        &mut modification,
                        &mut uncommitted_records,
                        &mut filtered_records,
                        &ctx,
                    )
                    .await?;
                }

                if !caught_up && streaming_lsn >= end_of_wal {
                    info!("caught up at LSN {streaming_lsn}");
                    caught_up = true;
                }

                tracing::debug!(
                    "Ingested WAL up to {streaming_lsn}. Last record LSN is {}",
                    timeline.get_last_record_lsn()
                );

                if let Some(lsn) = next_record_lsn {
                    last_rec_lsn = lsn;
                }

                Some(streaming_lsn)
            }

            ReplicationMessage::XLogData(xlog_data) => {
                // Pass the WAL data to the decoder, and see if we can decode
                // more records as a result.
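
The ingest loop above commits on three triggers: a flush-uncommitted marker, a per-batch record count or pending-byte cap, and once at the end so the last-record LSN always advances. A standalone sketch of just that batching policy, with made-up threshold values and a stand-in Modification type:

    struct Modification {
        pending_records: u64,
        pending_bytes: usize,
        committed: u64,
    }

    impl Modification {
        fn commit(&mut self) {
            // Flush everything buffered so far and reset the counters,
            // mirroring the commit() helper in the loop above.
            self.committed += self.pending_records;
            self.pending_records = 0;
            self.pending_bytes = 0;
        }
    }

    fn main() {
        const BATCH_SIZE: u64 = 100;            // stand-in for ingest_batch_size
        const MAX_PENDING_BYTES: usize = 128 * 1024; // stand-in for MAX_PENDING_BYTES

        let mut m = Modification { pending_records: 0, pending_bytes: 0, committed: 0 };
        for record_len in [300usize; 250] {
            m.pending_records += 1;
            m.pending_bytes += record_len;
            if m.pending_records >= BATCH_SIZE || m.pending_bytes > MAX_PENDING_BYTES {
                m.commit();
            }
        }
        if m.pending_records > 0 {
            m.commit(); // final commit advances the last-record LSN
        }
        assert_eq!(m.committed, 250);
    }
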
@@ -316,21 +478,6 @@ pub(super) async fn handle_walreceiver_connection(
                let mut uncommitted_records = 0;
                let mut filtered_records = 0;

                async fn commit(
                    modification: &mut DatadirModification<'_>,
                    uncommitted: &mut u64,
                    filtered: &mut u64,
                    ctx: &RequestContext,
                ) -> anyhow::Result<()> {
                    WAL_INGEST
                        .records_committed
                        .inc_by(*uncommitted - *filtered);
                    modification.commit(ctx).await?;
                    *uncommitted = 0;
                    *filtered = 0;
                    Ok(())
                }

                while let Some((next_record_lsn, recdata)) = waldecoder.poll_decode()? {
                    // It is important to deal with the aligned records as lsn in getPage@LSN is
                    // aligned and can be several bytes bigger. Without this alignment we are
Some files were not shown because too many files have changed in this diff.