Compare commits

..

5 Commits

Author               SHA1        Message                       Date
Konstantin Knizhnik  5ecc4b249f  Bump Postgres version         2024-11-25 17:06:01 +02:00
Konstantin Knizhnik  48025c988f  Bump Postgres version         2024-11-25 10:12:06 +02:00
Konstantin Knizhnik  424ba47c58  Bupm postgres version         2024-11-24 21:52:06 +02:00
Konstantin Knizhnik  c424fa60ca  Bupm postgres version         2024-11-24 21:31:51 +02:00
Konstantin Knizhnik  74d5129a0d  Fix seqscan prefetch in pg17  2024-11-24 13:52:48 +02:00
351 changed files with 4114 additions and 23483 deletions

View File

@@ -46,9 +46,6 @@ workspace-members = [
"utils",
"wal_craft",
"walproposer",
"postgres-protocol2",
"postgres-types2",
"tokio-postgres2",
]
# Write out exact versions rather than a semver range. (Defaults to false.)

View File

@@ -43,8 +43,7 @@ runs:
PR_NUMBER=$(jq --raw-output .pull_request.number "$GITHUB_EVENT_PATH" || true)
if [ "${PR_NUMBER}" != "null" ]; then
BRANCH_OR_PR=pr-${PR_NUMBER}
elif [ "${GITHUB_REF_NAME}" = "main" ] || [ "${GITHUB_REF_NAME}" = "release" ] || \
[ "${GITHUB_REF_NAME}" = "release-proxy" ] || [ "${GITHUB_REF_NAME}" = "release-compute" ]; then
elif [ "${GITHUB_REF_NAME}" = "main" ] || [ "${GITHUB_REF_NAME}" = "release" ] || [ "${GITHUB_REF_NAME}" = "release-proxy" ]; then
# Shortcut for special branches
BRANCH_OR_PR=${GITHUB_REF_NAME}
else

View File

@@ -23,8 +23,7 @@ runs:
PR_NUMBER=$(jq --raw-output .pull_request.number "$GITHUB_EVENT_PATH" || true)
if [ "${PR_NUMBER}" != "null" ]; then
BRANCH_OR_PR=pr-${PR_NUMBER}
elif [ "${GITHUB_REF_NAME}" = "main" ] || [ "${GITHUB_REF_NAME}" = "release" ] || \
[ "${GITHUB_REF_NAME}" = "release-proxy" ] || [ "${GITHUB_REF_NAME}" = "release-compute" ]; then
elif [ "${GITHUB_REF_NAME}" = "main" ] || [ "${GITHUB_REF_NAME}" = "release" ] || [ "${GITHUB_REF_NAME}" = "release-proxy" ]; then
# Shortcut for special branches
BRANCH_OR_PR=${GITHUB_REF_NAME}
else

View File

@@ -36,8 +36,8 @@ inputs:
description: 'Region name for real s3 tests'
required: false
default: ''
rerun_failed:
description: 'Whether to rerun failed tests'
rerun_flaky:
description: 'Whether to rerun flaky tests'
required: false
default: 'false'
pg_version:
@@ -108,7 +108,7 @@ runs:
COMPATIBILITY_SNAPSHOT_DIR: /tmp/compatibility_snapshot_pg${{ inputs.pg_version }}
ALLOW_BACKWARD_COMPATIBILITY_BREAKAGE: contains(github.event.pull_request.labels.*.name, 'backward compatibility breakage')
ALLOW_FORWARD_COMPATIBILITY_BREAKAGE: contains(github.event.pull_request.labels.*.name, 'forward compatibility breakage')
RERUN_FAILED: ${{ inputs.rerun_failed }}
RERUN_FLAKY: ${{ inputs.rerun_flaky }}
PG_VERSION: ${{ inputs.pg_version }}
shell: bash -euxo pipefail {0}
run: |
@@ -154,8 +154,15 @@ runs:
EXTRA_PARAMS="--out-dir $PERF_REPORT_DIR $EXTRA_PARAMS"
fi
if [ "${RERUN_FAILED}" == "true" ]; then
EXTRA_PARAMS="--reruns 2 $EXTRA_PARAMS"
if [ "${RERUN_FLAKY}" == "true" ]; then
mkdir -p $TEST_OUTPUT
poetry run ./scripts/flaky_tests.py "${TEST_RESULT_CONNSTR}" \
--days 7 \
--output "$TEST_OUTPUT/flaky.json" \
--pg-version "${DEFAULT_PG_VERSION}" \
--build-type "${BUILD_TYPE}"
EXTRA_PARAMS="--flaky-tests-json $TEST_OUTPUT/flaky.json $EXTRA_PARAMS"
fi
# We use pytest-split plugin to run benchmarks in parallel on different CI runners
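For orientation, the two retry strategies this hunk swaps between reduce to the following (a sketch using only the flags visible above; `--reruns` is the stock pytest-rerunfailures flag, while `--flaky-tests-json` appears to be an option of this repo's own test harness, and EXTRA_PARAMS is the string the action later appends to its pytest invocation):

    # rerun_failed variant: blanket retries, every failed test is rerun up to 2 times
    EXTRA_PARAMS="--reruns 2 $EXTRA_PARAMS"

    # rerun_flaky variant: fetch the list of recently flaky tests first,
    # then let the harness rerun only those
    poetry run ./scripts/flaky_tests.py "${TEST_RESULT_CONNSTR}" \
      --days 7 \
      --output "$TEST_OUTPUT/flaky.json" \
      --pg-version "${DEFAULT_PG_VERSION}" \
      --build-type "${BUILD_TYPE}"
    EXTRA_PARAMS="--flaky-tests-json $TEST_OUTPUT/flaky.json $EXTRA_PARAMS"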

View File

@@ -19,8 +19,8 @@ on:
description: 'debug or release'
required: true
type: string
test-cfg:
description: 'a json object of postgres versions and lfc states to run regression tests on'
pg-versions:
description: 'a json array of postgres versions to run regression tests on'
required: true
type: string
@@ -276,14 +276,14 @@ jobs:
options: --init --shm-size=512mb --ulimit memlock=67108864:67108864
strategy:
fail-fast: false
matrix: ${{ fromJSON(format('{{"include":{0}}}', inputs.test-cfg)) }}
matrix:
pg_version: ${{ fromJson(inputs.pg-versions) }}
steps:
- uses: actions/checkout@v4
with:
submodules: true
- name: Pytest regression tests
continue-on-error: ${{ matrix.lfc_state == 'with-lfc' }}
uses: ./.github/actions/run-python-test-set
timeout-minutes: 60
with:
@@ -293,14 +293,13 @@ jobs:
run_with_real_s3: true
real_s3_bucket: neon-github-ci-tests
real_s3_region: eu-central-1
rerun_failed: true
rerun_flaky: true
pg_version: ${{ matrix.pg_version }}
env:
TEST_RESULT_CONNSTR: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }}
CHECK_ONDISK_DATA_COMPATIBILITY: nonempty
BUILD_TAG: ${{ inputs.build-tag }}
PAGESERVER_VIRTUAL_FILE_IO_ENGINE: tokio-epoll-uring
USE_LFC: ${{ matrix.lfc_state == 'with-lfc' && 'true' || 'false' }}
# Temporary disable this step until we figure out why it's so flaky
# Ref https://github.com/neondatabase/neon/issues/4540
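A side note on the two matrix shapes compared in this file: test-cfg is a JSON array of objects that gets wrapped into an `include` matrix, while pg-versions is a flat array expanded into a single pg_version axis. A rough visualization of what each input turns into (jq is used here only to show the JSON; the values are the ones quoted elsewhere in this diff):

    # test-cfg style: array of objects, wrapped as {"include": [...]}
    echo '[{"pg_version":"v17","lfc_state":"with-lfc"}]' | jq '{include: .}'

    # pg-versions style: flat array, one matrix axis
    echo '["v14","v15","v16","v17"]' | jq '{pg_version: .}'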

View File

@@ -21,7 +21,7 @@ defaults:
shell: bash -euo pipefail {0}
jobs:
create-release-branch:
create-storage-release-branch:
runs-on: ubuntu-22.04
permissions:

View File

@@ -249,7 +249,7 @@ jobs:
# Post both success and failure to the Slack channel
- name: Post to a Slack channel
if: ${{ github.event.schedule && !cancelled() }}
if: ${{ github.event.schedule }}
uses: slackapi/slack-github-action@v1
with:
channel-id: "C06T9AMNDQQ" # on-call-compute-staging-stream
@@ -541,7 +541,7 @@ jobs:
runs-on: ${{ matrix.RUNNER }}
container:
image: neondatabase/build-tools:pinned-bookworm
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
@@ -558,12 +558,12 @@ jobs:
arch=$(uname -m | sed 's/x86_64/amd64/g' | sed 's/aarch64/arm64/g')
cd /home/nonroot
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-17/libpq5_17.2-1.pgdg120+1_${arch}.deb"
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-client-16_16.6-1.pgdg120+1_${arch}.deb"
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-16_16.6-1.pgdg120+1_${arch}.deb"
dpkg -x libpq5_17.2-1.pgdg120+1_${arch}.deb pg
dpkg -x postgresql-16_16.6-1.pgdg120+1_${arch}.deb pg
dpkg -x postgresql-client-16_16.6-1.pgdg120+1_${arch}.deb pg
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-17/libpq5_17.2-1.pgdg110+1_${arch}.deb"
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-client-16_16.6-1.pgdg110+1_${arch}.deb"
wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-16_16.6-1.pgdg110+1_${arch}.deb"
dpkg -x libpq5_17.2-1.pgdg110+1_${arch}.deb pg
dpkg -x postgresql-16_16.6-1.pgdg110+1_${arch}.deb pg
dpkg -x postgresql-client-16_16.6-1.pgdg110+1_${arch}.deb pg
mkdir -p /tmp/neon/pg_install/v16/bin
ln -s /home/nonroot/pg/usr/lib/postgresql/16/bin/pgbench /tmp/neon/pg_install/v16/bin/pgbench
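Both sides of this hunk do the same thing and differ only in the package suffix (pgdg120 vs pgdg110, i.e. packages built for Debian 12 vs Debian 11): unpack the .deb files without installing them and symlink pgbench into the expected test layout. Condensed to one architecture as a sketch:

    arch=amd64
    dpkg -x "libpq5_17.2-1.pgdg120+1_${arch}.deb" pg
    dpkg -x "postgresql-16_16.6-1.pgdg120+1_${arch}.deb" pg
    dpkg -x "postgresql-client-16_16.6-1.pgdg120+1_${arch}.deb" pg
    mkdir -p /tmp/neon/pg_install/v16/bin
    ln -s "$PWD/pg/usr/lib/postgresql/16/bin/pgbench" /tmp/neon/pg_install/v16/bin/pgbench
    /tmp/neon/pg_install/v16/bin/pgbench --version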

View File

@@ -2,17 +2,6 @@ name: Build build-tools image
on:
workflow_call:
inputs:
archs:
description: "Json array of architectures to build"
# Default values are set in `check-image` job, `set-variables` step
type: string
required: false
debians:
description: "Json array of Debian versions to build"
# Default values are set in `check-image` job, `set-variables` step
type: string
required: false
outputs:
image-tag:
description: "build-tools tag"
@@ -43,37 +32,25 @@ jobs:
check-image:
runs-on: ubuntu-22.04
outputs:
archs: ${{ steps.set-variables.outputs.archs }}
debians: ${{ steps.set-variables.outputs.debians }}
tag: ${{ steps.set-variables.outputs.image-tag }}
everything: ${{ steps.set-more-variables.outputs.everything }}
found: ${{ steps.set-more-variables.outputs.found }}
tag: ${{ steps.get-build-tools-tag.outputs.image-tag }}
found: ${{ steps.check-image.outputs.found }}
steps:
- uses: actions/checkout@v4
- name: Set variables
id: set-variables
- name: Get build-tools image tag for the current commit
id: get-build-tools-tag
env:
ARCHS: ${{ inputs.archs || '["x64","arm64"]' }}
DEBIANS: ${{ inputs.debians || '["bullseye","bookworm"]' }}
IMAGE_TAG: |
${{ hashFiles('build-tools.Dockerfile',
'.github/workflows/build-build-tools-image.yml') }}
run: |
echo "archs=${ARCHS}" | tee -a ${GITHUB_OUTPUT}
echo "debians=${DEBIANS}" | tee -a ${GITHUB_OUTPUT}
echo "image-tag=${IMAGE_TAG}" | tee -a ${GITHUB_OUTPUT}
echo "image-tag=${IMAGE_TAG}" | tee -a $GITHUB_OUTPUT
- name: Set more variables
id: set-more-variables
- name: Check if such tag found in the registry
id: check-image
env:
IMAGE_TAG: ${{ steps.set-variables.outputs.image-tag }}
EVERYTHING: |
${{ contains(fromJson(steps.set-variables.outputs.archs), 'x64') &&
contains(fromJson(steps.set-variables.outputs.archs), 'arm64') &&
contains(fromJson(steps.set-variables.outputs.debians), 'bullseye') &&
contains(fromJson(steps.set-variables.outputs.debians), 'bookworm') }}
IMAGE_TAG: ${{ steps.get-build-tools-tag.outputs.image-tag }}
run: |
if docker manifest inspect neondatabase/build-tools:${IMAGE_TAG}; then
found=true
@@ -81,8 +58,8 @@ jobs:
found=false
fi
echo "everything=${EVERYTHING}" | tee -a ${GITHUB_OUTPUT}
echo "found=${found}" | tee -a ${GITHUB_OUTPUT}
echo "found=${found}" | tee -a $GITHUB_OUTPUT
build-image:
needs: [ check-image ]
@@ -90,8 +67,8 @@ jobs:
strategy:
matrix:
arch: ${{ fromJson(needs.check-image.outputs.archs) }}
debian: ${{ fromJson(needs.check-image.outputs.debians) }}
debian-version: [ bullseye, bookworm ]
arch: [ x64, arm64 ]
runs-on: ${{ fromJson(format('["self-hosted", "{0}"]', matrix.arch == 'arm64' && 'large-arm64' || 'large')) }}
@@ -122,11 +99,11 @@ jobs:
push: true
pull: true
build-args: |
DEBIAN_VERSION=${{ matrix.debian }}
cache-from: type=registry,ref=cache.neon.build/build-tools:cache-${{ matrix.debian }}-${{ matrix.arch }}
cache-to: ${{ github.ref_name == 'main' && format('type=registry,ref=cache.neon.build/build-tools:cache-{0}-{1},mode=max', matrix.debian, matrix.arch) || '' }}
DEBIAN_VERSION=${{ matrix.debian-version }}
cache-from: type=registry,ref=cache.neon.build/build-tools:cache-${{ matrix.debian-version }}-${{ matrix.arch }}
cache-to: ${{ github.ref_name == 'main' && format('type=registry,ref=cache.neon.build/build-tools:cache-{0}-{1},mode=max', matrix.debian-version, matrix.arch) || '' }}
tags: |
neondatabase/build-tools:${{ needs.check-image.outputs.tag }}-${{ matrix.debian }}-${{ matrix.arch }}
neondatabase/build-tools:${{ needs.check-image.outputs.tag }}-${{ matrix.debian-version }}-${{ matrix.arch }}
merge-images:
needs: [ check-image, build-image ]
@@ -140,22 +117,16 @@ jobs:
- name: Create multi-arch image
env:
DEFAULT_DEBIAN_VERSION: bookworm
ARCHS: ${{ join(fromJson(needs.check-image.outputs.archs), ' ') }}
DEBIANS: ${{ join(fromJson(needs.check-image.outputs.debians), ' ') }}
EVERYTHING: ${{ needs.check-image.outputs.everything }}
DEFAULT_DEBIAN_VERSION: bullseye
IMAGE_TAG: ${{ needs.check-image.outputs.tag }}
run: |
for debian in ${DEBIANS}; do
tags=("-t" "neondatabase/build-tools:${IMAGE_TAG}-${debian}")
if [ "${EVERYTHING}" == "true" ] && [ "${debian}" == "${DEFAULT_DEBIAN_VERSION}" ]; then
for debian_version in bullseye bookworm; do
tags=("-t" "neondatabase/build-tools:${IMAGE_TAG}-${debian_version}")
if [ "${debian_version}" == "${DEFAULT_DEBIAN_VERSION}" ]; then
tags+=("-t" "neondatabase/build-tools:${IMAGE_TAG}")
fi
for arch in ${ARCHS}; do
tags+=("neondatabase/build-tools:${IMAGE_TAG}-${debian}-${arch}")
done
docker buildx imagetools create "${tags[@]}"
docker buildx imagetools create "${tags[@]}" \
neondatabase/build-tools:${IMAGE_TAG}-${debian_version}-x64 \
neondatabase/build-tools:${IMAGE_TAG}-${debian_version}-arm64
done
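Both versions of the loop end in the same primitive: docker buildx imagetools create stitches the per-arch tags into one multi-arch manifest, with the bare IMAGE_TAG alias added only for the default Debian release. Reduced to a single iteration (tag value is a placeholder):

    IMAGE_TAG=example-tag
    debian_version=bookworm   # equals DEFAULT_DEBIAN_VERSION here, hence the extra alias tag
    docker buildx imagetools create \
      -t "neondatabase/build-tools:${IMAGE_TAG}-${debian_version}" \
      -t "neondatabase/build-tools:${IMAGE_TAG}" \
      "neondatabase/build-tools:${IMAGE_TAG}-${debian_version}-x64" \
      "neondatabase/build-tools:${IMAGE_TAG}-${debian_version}-arm64"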

View File

@@ -6,7 +6,6 @@ on:
- main
- release
- release-proxy
- release-compute
pull_request:
defaults:
@@ -71,10 +70,8 @@ jobs:
echo "tag=release-$(git rev-list --count HEAD)" >> $GITHUB_OUTPUT
elif [[ "$GITHUB_REF_NAME" == "release-proxy" ]]; then
echo "tag=release-proxy-$(git rev-list --count HEAD)" >> $GITHUB_OUTPUT
elif [[ "$GITHUB_REF_NAME" == "release-compute" ]]; then
echo "tag=release-compute-$(git rev-list --count HEAD)" >> $GITHUB_OUTPUT
else
echo "GITHUB_REF_NAME (value '$GITHUB_REF_NAME') is not set to either 'main' or 'release', 'release-proxy', 'release-compute'"
echo "GITHUB_REF_NAME (value '$GITHUB_REF_NAME') is not set to either 'main' or 'release'"
echo "tag=$GITHUB_RUN_ID" >> $GITHUB_OUTPUT
fi
shell: bash
@@ -256,14 +253,7 @@ jobs:
build-tag: ${{ needs.tag.outputs.build-tag }}
build-type: ${{ matrix.build-type }}
# Run tests on all Postgres versions in release builds and only on the latest version in debug builds
# run without LFC on v17 release only
test-cfg: |
${{ matrix.build-type == 'release' && '[{"pg_version":"v14", "lfc_state": "without-lfc"},
{"pg_version":"v15", "lfc_state": "without-lfc"},
{"pg_version":"v16", "lfc_state": "without-lfc"},
{"pg_version":"v17", "lfc_state": "without-lfc"},
{"pg_version":"v17", "lfc_state": "with-lfc"}]'
|| '[{"pg_version":"v17", "lfc_state": "without-lfc"}]' }}
pg-versions: ${{ matrix.build-type == 'release' && '["v14", "v15", "v16", "v17"]' || '["v17"]' }}
secrets: inherit
# Keep `benchmarks` job outside of `build-and-test-locally` workflow to make job failures non-blocking
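The `cond && A || B` expression used for the version list above is GitHub Actions' ternary idiom; in shell terms the release/debug split is roughly this sketch:

    BUILD_TYPE=release
    if [ "$BUILD_TYPE" = "release" ]; then
      PG_VERSIONS='["v14", "v15", "v16", "v17"]'   # release builds: all supported versions
    else
      PG_VERSIONS='["v17"]'                        # debug builds: latest version only
    fi
    echo "pg-versions=$PG_VERSIONS"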
@@ -516,7 +506,7 @@ jobs:
})
trigger-e2e-tests:
if: ${{ !github.event.pull_request.draft || contains( github.event.pull_request.labels.*.name, 'run-e2e-tests-in-draft') || github.ref_name == 'main' || github.ref_name == 'release' || github.ref_name == 'release-proxy' || github.ref_name == 'release-compute' }}
if: ${{ !github.event.pull_request.draft || contains( github.event.pull_request.labels.*.name, 'run-e2e-tests-in-draft') || github.ref_name == 'main' || github.ref_name == 'release' || github.ref_name == 'release-proxy' }}
needs: [ check-permissions, promote-images, tag ]
uses: ./.github/workflows/trigger-e2e-tests.yml
secrets: inherit
@@ -672,7 +662,7 @@ jobs:
neondatabase/compute-node-${{ matrix.version.pg }}:${{ needs.tag.outputs.build-tag }}-${{ matrix.version.debian }}-${{ matrix.arch }}
- name: Build neon extensions test image
if: matrix.version.pg >= 'v16'
if: matrix.version.pg == 'v16'
uses: docker/build-push-action@v6
with:
context: .
@@ -687,7 +677,8 @@ jobs:
pull: true
file: compute/compute-node.Dockerfile
target: neon-pg-ext-test
cache-from: type=registry,ref=cache.neon.build/compute-node-${{ matrix.version.pg }}:cache-${{ matrix.version.debian }}-${{ matrix.arch }}
cache-from: type=registry,ref=cache.neon.build/neon-test-extensions-${{ matrix.version.pg }}:cache-${{ matrix.version.debian }}-${{ matrix.arch }}
cache-to: ${{ github.ref_name == 'main' && format('type=registry,ref=cache.neon.build/neon-test-extensions-{0}:cache-{1}-{2},mode=max', matrix.version.pg, matrix.version.debian, matrix.arch) || '' }}
tags: |
neondatabase/neon-test-extensions-${{ matrix.version.pg }}:${{needs.tag.outputs.build-tag}}-${{ matrix.version.debian }}-${{ matrix.arch }}
@@ -710,7 +701,7 @@ jobs:
push: true
pull: true
file: compute/compute-node.Dockerfile
cache-from: type=registry,ref=cache.neon.build/compute-node-${{ matrix.version.pg }}:cache-${{ matrix.version.debian }}-${{ matrix.arch }}
cache-from: type=registry,ref=cache.neon.build/neon-test-extensions-${{ matrix.version.pg }}:cache-${{ matrix.version.debian }}-${{ matrix.arch }}
cache-to: ${{ github.ref_name == 'main' && format('type=registry,ref=cache.neon.build/compute-tools-{0}:cache-{1}-{2},mode=max', matrix.version.pg, matrix.version.debian, matrix.arch) || '' }}
tags: |
neondatabase/compute-tools:${{ needs.tag.outputs.build-tag }}-${{ matrix.version.debian }}-${{ matrix.arch }}
@@ -746,7 +737,7 @@ jobs:
neondatabase/compute-node-${{ matrix.version.pg }}:${{ needs.tag.outputs.build-tag }}-${{ matrix.version.debian }}-arm64
- name: Create multi-arch neon-test-extensions image
if: matrix.version.pg >= 'v16'
if: matrix.version.pg == 'v16'
run: |
docker buildx imagetools create -t neondatabase/neon-test-extensions-${{ matrix.version.pg }}:${{ needs.tag.outputs.build-tag }} \
-t neondatabase/neon-test-extensions-${{ matrix.version.pg }}:${{ needs.tag.outputs.build-tag }}-${{ matrix.version.debian }} \
@@ -835,7 +826,6 @@ jobs:
fail-fast: false
matrix:
arch: [ x64, arm64 ]
pg_version: [v16, v17]
runs-on: ${{ fromJson(format('["self-hosted", "{0}"]', matrix.arch == 'arm64' && 'small-arm64' || 'small')) }}
@@ -874,10 +864,7 @@ jobs:
- name: Verify docker-compose example and test extensions
timeout-minutes: 20
env:
TAG: ${{needs.tag.outputs.build-tag}}
TEST_VERSION_ONLY: ${{ matrix.pg_version }}
run: ./docker-compose/docker_compose_test.sh
run: env TAG=${{needs.tag.outputs.build-tag}} ./docker-compose/docker_compose_test.sh
- name: Print logs and clean up
if: always()
@@ -937,7 +924,7 @@ jobs:
neondatabase/neon-test-extensions-v16:${{ needs.tag.outputs.build-tag }}
- name: Configure AWS-prod credentials
if: github.ref_name == 'release'|| github.ref_name == 'release-proxy' || github.ref_name == 'release-compute'
if: github.ref_name == 'release'|| github.ref_name == 'release-proxy'
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
@@ -946,12 +933,12 @@ jobs:
- name: Login to prod ECR
uses: docker/login-action@v3
if: github.ref_name == 'release'|| github.ref_name == 'release-proxy' || github.ref_name == 'release-compute'
if: github.ref_name == 'release'|| github.ref_name == 'release-proxy'
with:
registry: 093970136003.dkr.ecr.eu-central-1.amazonaws.com
- name: Copy all images to prod ECR
if: github.ref_name == 'release' || github.ref_name == 'release-proxy' || github.ref_name == 'release-compute'
if: github.ref_name == 'release'|| github.ref_name == 'release-proxy'
run: |
for image in neon compute-tools {vm-,}compute-node-{v14,v15,v16,v17}; do
docker buildx imagetools create -t 093970136003.dkr.ecr.eu-central-1.amazonaws.com/${image}:${{ needs.tag.outputs.build-tag }} \
@@ -971,7 +958,7 @@ jobs:
tenant_id: ${{ vars.AZURE_TENANT_ID }}
push-to-acr-prod:
if: github.ref_name == 'release' || github.ref_name == 'release-proxy' || github.ref_name == 'release-compute'
if: github.ref_name == 'release'|| github.ref_name == 'release-proxy'
needs: [ tag, promote-images ]
uses: ./.github/workflows/_push-to-acr.yml
with:
@@ -1059,7 +1046,7 @@ jobs:
deploy:
needs: [ check-permissions, promote-images, tag, build-and-test-locally, trigger-custom-extensions-build-and-wait, push-to-acr-dev, push-to-acr-prod ]
# `!failure() && !cancelled()` is required because the workflow depends on the job that can be skipped: `push-to-acr-dev` and `push-to-acr-prod`
if: (github.ref_name == 'main' || github.ref_name == 'release' || github.ref_name == 'release-proxy' || github.ref_name == 'release-compute') && !failure() && !cancelled()
if: (github.ref_name == 'main' || github.ref_name == 'release' || github.ref_name == 'release-proxy') && !failure() && !cancelled()
runs-on: [ self-hosted, small ]
container: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/ansible:latest
@@ -1108,15 +1095,13 @@ jobs:
-f deployProxyAuthBroker=true \
-f branch=main \
-f dockerTag=${{needs.tag.outputs.build-tag}}
elif [[ "$GITHUB_REF_NAME" == "release-compute" ]]; then
gh workflow --repo neondatabase/infra run deploy-compute-dev.yml --ref main -f dockerTag=${{needs.tag.outputs.build-tag}}
else
echo "GITHUB_REF_NAME (value '$GITHUB_REF_NAME') is not set to either 'main', 'release', 'release-proxy' or 'release-compute'"
echo "GITHUB_REF_NAME (value '$GITHUB_REF_NAME') is not set to either 'main' or 'release'"
exit 1
fi
- name: Create git tag
if: github.ref_name == 'release' || github.ref_name == 'release-proxy' || github.ref_name == 'release-compute'
if: github.ref_name == 'release' || github.ref_name == 'release-proxy'
uses: actions/github-script@v7
with:
# Retry script for 5XX server errors: https://github.com/actions/github-script#retries

View File

@@ -26,7 +26,6 @@ concurrency:
jobs:
ingest:
strategy:
fail-fast: false # allow other variants to continue even if one fails
matrix:
target_project: [new_empty_project, large_existing_project]
permissions:

View File

@@ -29,7 +29,7 @@ jobs:
trigger_bench_on_ec2_machine_in_eu_central_1:
runs-on: [ self-hosted, small ]
container:
image: neondatabase/build-tools:pinned-bookworm
image: neondatabase/build-tools:pinned
credentials:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}

View File

@@ -94,7 +94,7 @@ jobs:
- name: Tag build-tools with `${{ env.TO_TAG }}` in Docker Hub, ECR, and ACR
env:
DEFAULT_DEBIAN_VERSION: bookworm
DEFAULT_DEBIAN_VERSION: bullseye
run: |
for debian_version in bullseye bookworm; do
tags=()

View File

@@ -23,8 +23,6 @@ jobs:
id: python-src
with:
files: |
.github/workflows/_check-codestyle-python.yml
.github/workflows/build-build-tools-image.yml
.github/workflows/pre-merge-checks.yml
**/**.py
poetry.lock
@@ -40,10 +38,6 @@ jobs:
if: needs.get-changed-files.outputs.python-changed == 'true'
needs: [ get-changed-files ]
uses: ./.github/workflows/build-build-tools-image.yml
with:
# Build only one combination to save time
archs: '["x64"]'
debians: '["bookworm"]'
secrets: inherit
check-codestyle-python:
@@ -51,8 +45,7 @@ jobs:
needs: [ get-changed-files, build-build-tools-image ]
uses: ./.github/workflows/_check-codestyle-python.yml
with:
# `-bookworm-x64` suffix should match the combination in `build-build-tools-image`
build-tools-image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm-x64
build-tools-image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm
secrets: inherit
# To get items from the merge queue merged into main we need to satisfy "Status checks that are required".

View File

@@ -15,10 +15,6 @@ on:
type: boolean
description: 'Create Proxy release PR'
required: false
create-compute-release-branch:
type: boolean
description: 'Create Compute release PR'
required: false
# No permission for GITHUB_TOKEN by default; the **minimal required** set of permissions should be granted in each job.
permissions: {}
@@ -29,20 +25,20 @@ defaults:
jobs:
create-storage-release-branch:
if: ${{ github.event.schedule == '0 6 * * MON' || inputs.create-storage-release-branch }}
if: ${{ github.event.schedule == '0 6 * * MON' || format('{0}', inputs.create-storage-release-branch) == 'true' }}
permissions:
contents: write
uses: ./.github/workflows/_create-release-pr.yml
with:
component-name: 'Storage'
component-name: 'Storage & Compute'
release-branch: 'release'
secrets:
ci-access-token: ${{ secrets.CI_ACCESS_TOKEN }}
create-proxy-release-branch:
if: ${{ github.event.schedule == '0 6 * * THU' || inputs.create-proxy-release-branch }}
if: ${{ github.event.schedule == '0 6 * * THU' || format('{0}', inputs.create-proxy-release-branch) == 'true' }}
permissions:
contents: write
@@ -53,16 +49,3 @@ jobs:
release-branch: 'release-proxy'
secrets:
ci-access-token: ${{ secrets.CI_ACCESS_TOKEN }}
create-compute-release-branch:
if: inputs.create-compute-release-branch
permissions:
contents: write
uses: ./.github/workflows/_create-release-pr.yml
with:
component-name: 'Compute'
release-branch: 'release-compute'
secrets:
ci-access-token: ${{ secrets.CI_ACCESS_TOKEN }}

View File

@@ -51,8 +51,6 @@ jobs:
echo "tag=release-$(git rev-list --count HEAD)" | tee -a $GITHUB_OUTPUT
elif [[ "$GITHUB_REF_NAME" == "release-proxy" ]]; then
echo "tag=release-proxy-$(git rev-list --count HEAD)" >> $GITHUB_OUTPUT
elif [[ "$GITHUB_REF_NAME" == "release-compute" ]]; then
echo "tag=release-compute-$(git rev-list --count HEAD)" >> $GITHUB_OUTPUT
else
echo "GITHUB_REF_NAME (value '$GITHUB_REF_NAME') is not set to either 'main' or 'release'"
BUILD_AND_TEST_RUN_ID=$(gh run list -b $CURRENT_BRANCH -c $CURRENT_SHA -w 'Build and Test' -L 1 --json databaseId --jq '.[].databaseId')

View File

@@ -2,7 +2,6 @@
/compute_tools/ @neondatabase/control-plane @neondatabase/compute
/libs/pageserver_api/ @neondatabase/storage
/libs/postgres_ffi/ @neondatabase/compute @neondatabase/storage
/libs/proxy/ @neondatabase/proxy
/libs/remote_storage/ @neondatabase/storage
/libs/safekeeper_api/ @neondatabase/storage
/libs/vm_monitor/ @neondatabase/autoscaling

Cargo.lock (generated, 447 changed lines)
File diff suppressed because it is too large

View File

@@ -35,9 +35,6 @@ members = [
"libs/walproposer",
"libs/wal_decoder",
"libs/postgres_initdb",
"libs/proxy/postgres-protocol2",
"libs/proxy/postgres-types2",
"libs/proxy/tokio-postgres2",
]
[workspace.package]
@@ -74,7 +71,7 @@ bindgen = "0.70"
bit_field = "0.10.2"
bstr = "1.0"
byteorder = "1.4"
bytes = "1.9"
bytes = "1.0"
camino = "1.1.6"
cfg-if = "1.0.0"
chrono = { version = "0.4", default-features = false, features = ["clock"] }
@@ -83,7 +80,6 @@ comfy-table = "7.1"
const_format = "0.2"
crc32c = "0.6"
dashmap = { version = "5.5.0", features = ["raw-api"] }
diatomic-waker = { version = "0.2.3" }
either = "1.8"
enum-map = "2.4.2"
enumset = "1.0.12"
@@ -115,7 +111,6 @@ indoc = "2"
ipnet = "2.10.0"
itertools = "0.10"
itoa = "1.0.11"
jemalloc_pprof = "0.6"
jsonwebtoken = "9"
lasso = "0.7"
libc = "0.2"
@@ -128,10 +123,10 @@ notify = "6.0.0"
num_cpus = "1.15"
num-traits = "0.2.15"
once_cell = "1.13"
opentelemetry = "0.26"
opentelemetry_sdk = "0.26"
opentelemetry-otlp = { version = "0.26", default-features=false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry-semantic-conventions = "0.26"
opentelemetry = "0.24"
opentelemetry_sdk = "0.24"
opentelemetry-otlp = { version = "0.17", default-features=false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry-semantic-conventions = "0.16"
parking_lot = "0.12"
parquet = { version = "53", default-features = false, features = ["zstd"] }
parquet_derive = "53"
@@ -145,9 +140,9 @@ rand = "0.8"
redis = { version = "0.25.2", features = ["tokio-rustls-comp", "keep-alive"] }
regex = "1.10.2"
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] }
reqwest-tracing = { version = "0.5", features = ["opentelemetry_0_26"] }
reqwest-middleware = "0.4"
reqwest-retry = "0.7"
reqwest-tracing = { version = "0.5", features = ["opentelemetry_0_24"] }
reqwest-middleware = "0.3.0"
reqwest-retry = "0.5"
routerify = "3"
rpds = "0.13"
rustc-hash = "1.1.0"
@@ -176,7 +171,7 @@ sync_wrapper = "0.1.2"
tar = "0.4"
test-context = "0.3"
thiserror = "1.0"
tikv-jemallocator = { version = "0.6", features = ["profiling", "stats", "unprefixed_malloc_on_supported_platforms"] }
tikv-jemallocator = { version = "0.6", features = ["stats"] }
tikv-jemalloc-ctl = { version = "0.6", features = ["stats"] }
tokio = { version = "1.17", features = ["macros"] }
tokio-epoll-uring = { git = "https://github.com/neondatabase/tokio-epoll-uring.git" , branch = "main" }
@@ -192,7 +187,7 @@ tonic = {version = "0.12.3", features = ["tls", "tls-roots"]}
tower-service = "0.3.2"
tracing = "0.1"
tracing-error = "0.2"
tracing-opentelemetry = "0.27"
tracing-opentelemetry = "0.25"
tracing-subscriber = { version = "0.3", default-features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] }
try-lock = "0.2.5"
twox-hash = { version = "1.6.3", default-features = false }
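Since this hunk is purely version pins, a quick way to check which side a given checkout actually resolves is to ask cargo from the workspace root (sketch):

    cargo tree -i opentelemetry | head        # which workspace crates pull it in, and at which version
    cargo tree -i reqwest-middleware | head
    grep -nE '^(opentelemetry|tracing-opentelemetry|bytes) ' Cargo.toml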

View File

@@ -7,7 +7,7 @@ ARG IMAGE=build-tools
ARG TAG=pinned
ARG DEFAULT_PG_VERSION=17
ARG STABLE_PG_VERSION=16
ARG DEBIAN_VERSION=bookworm
ARG DEBIAN_VERSION=bullseye
ARG DEBIAN_FLAVOR=${DEBIAN_VERSION}-slim
# Build Postgres

View File

@@ -38,7 +38,6 @@ ifeq ($(UNAME_S),Linux)
# Seccomp BPF is only available for Linux
PG_CONFIGURE_OPTS += --with-libseccomp
else ifeq ($(UNAME_S),Darwin)
PG_CFLAGS += -DUSE_PREFETCH
ifndef DISABLE_HOMEBREW
# macOS with brew-installed openssl requires explicit paths
# It can be configured with OPENSSL_PREFIX variable
@@ -147,8 +146,6 @@ postgres-%: postgres-configure-% \
$(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pg_prewarm install
+@echo "Compiling pg_buffercache $*"
$(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pg_buffercache install
+@echo "Compiling pg_visibility $*"
$(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pg_visibility install
+@echo "Compiling pageinspect $*"
$(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pageinspect install
+@echo "Compiling amcheck $*"

View File

@@ -1,4 +1,4 @@
ARG DEBIAN_VERSION=bookworm
ARG DEBIAN_VERSION=bullseye
FROM debian:bookworm-slim AS pgcopydb_builder
ARG DEBIAN_VERSION
@@ -57,9 +57,9 @@ RUN mkdir -p /pgcopydb/bin && \
mkdir -p /pgcopydb/lib && \
chmod -R 755 /pgcopydb && \
chown -R nonroot:nonroot /pgcopydb
COPY --from=pgcopydb_builder /usr/lib/postgresql/16/bin/pgcopydb /pgcopydb/bin/pgcopydb
COPY --from=pgcopydb_builder /pgcopydb/lib/libpq.so.5 /pgcopydb/lib/libpq.so.5
COPY --from=pgcopydb_builder /usr/lib/postgresql/16/bin/pgcopydb /pgcopydb/bin/pgcopydb
COPY --from=pgcopydb_builder /pgcopydb/lib/libpq.so.5 /pgcopydb/lib/libpq.so.5
# System deps
#
@@ -258,14 +258,14 @@ WORKDIR /home/nonroot
# Rust
# Please keep the version of llvm (installed above) in sync with rust llvm (`rustc --version --verbose | grep LLVM`)
ENV RUSTC_VERSION=1.83.0
ENV RUSTC_VERSION=1.82.0
ENV RUSTUP_HOME="/home/nonroot/.rustup"
ENV PATH="/home/nonroot/.cargo/bin:${PATH}"
ARG RUSTFILT_VERSION=0.2.1
ARG CARGO_HAKARI_VERSION=0.9.33
ARG CARGO_DENY_VERSION=0.16.2
ARG CARGO_HACK_VERSION=0.6.33
ARG CARGO_NEXTEST_VERSION=0.9.85
ARG CARGO_HAKARI_VERSION=0.9.30
ARG CARGO_DENY_VERSION=0.16.1
ARG CARGO_HACK_VERSION=0.6.31
ARG CARGO_NEXTEST_VERSION=0.9.72
RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux-gnu/rustup-init && whoami && \
chmod +x rustup-init && \
./rustup-init -y --default-toolchain ${RUSTC_VERSION} && \
@@ -289,7 +289,7 @@ RUN whoami \
&& cargo --version --verbose \
&& rustup --version --verbose \
&& rustc --version --verbose \
&& clang --version
&& clang --version
RUN if [ "${DEBIAN_VERSION}" = "bookworm" ]; then \
LD_LIBRARY_PATH=/pgcopydb/lib /pgcopydb/bin/pgcopydb --version; \

View File

@@ -3,7 +3,7 @@ ARG REPOSITORY=neondatabase
ARG IMAGE=build-tools
ARG TAG=pinned
ARG BUILD_TAG
ARG DEBIAN_VERSION=bookworm
ARG DEBIAN_VERSION=bullseye
ARG DEBIAN_FLAVOR=${DEBIAN_VERSION}-slim
#########################################################################################
@@ -14,9 +14,6 @@ ARG DEBIAN_FLAVOR=${DEBIAN_VERSION}-slim
FROM debian:$DEBIAN_FLAVOR AS build-deps
ARG DEBIAN_VERSION
# Use strict mode for bash to catch errors early
SHELL ["/bin/bash", "-euo", "pipefail", "-c"]
RUN case $DEBIAN_VERSION in \
# Version-specific installs for Bullseye (PG14-PG16):
# The h3_pg extension needs a cmake 3.20+, but Debian bullseye has 3.18.
@@ -109,7 +106,6 @@ RUN cd postgres && \
#
#########################################################################################
FROM build-deps AS postgis-build
ARG DEBIAN_VERSION
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
RUN apt update && \
@@ -126,12 +122,12 @@ RUN apt update && \
# and also we must check backward compatibility with older versions of PostGIS.
#
# Use new version only for v17
RUN case "${DEBIAN_VERSION}" in \
"bookworm") \
RUN case "${PG_VERSION}" in \
"v17") \
export SFCGAL_VERSION=1.4.1 \
export SFCGAL_CHECKSUM=1800c8a26241588f11cddcf433049e9b9aea902e923414d2ecef33a3295626c3 \
;; \
"bullseye") \
"v14" | "v15" | "v16") \
export SFCGAL_VERSION=1.3.10 \
export SFCGAL_CHECKSUM=4e39b3b2adada6254a7bdba6d297bb28e1a9835a9f879b74f37e2dab70203232 \
;; \
@@ -232,8 +228,6 @@ FROM build-deps AS plv8-build
ARG PG_VERSION
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY compute/patches/plv8-3.1.10.patch /plv8-3.1.10.patch
RUN apt update && \
apt install --no-install-recommends -y ninja-build python3-dev libncurses5 binutils clang
@@ -245,6 +239,8 @@ RUN apt update && \
#
# Use new version only for v17
# because since v3.2, plv8 doesn't include plcoffee and plls extensions
ENV PLV8_TAG=v3.2.3
RUN case "${PG_VERSION}" in \
"v17") \
export PLV8_TAG=v3.2.3 \
@@ -259,9 +255,8 @@ RUN case "${PG_VERSION}" in \
git clone --recurse-submodules --depth 1 --branch ${PLV8_TAG} https://github.com/plv8/plv8.git plv8-src && \
tar -czf plv8.tar.gz --exclude .git plv8-src && \
cd plv8-src && \
if [[ "${PG_VERSION}" < "v17" ]]; then patch -p1 < /plv8-3.1.10.patch; fi && \
# generate and copy upgrade scripts
mkdir -p upgrade && ./generate_upgrade.sh ${PLV8_TAG#v} && \
mkdir -p upgrade && ./generate_upgrade.sh 3.1.10 && \
cp upgrade/* /usr/local/pgsql/share/extension/ && \
export PATH="/usr/local/pgsql/bin:$PATH" && \
make DOCKER=1 -j $(getconf _NPROCESSORS_ONLN) install && \
@@ -358,10 +353,10 @@ COPY compute/patches/pgvector.patch /pgvector.patch
# because we build the images on different machines than where we run them.
# Pass OPTFLAGS="" to remove it.
#
# vector >0.7.4 supports v17
# last release v0.8.0 - Oct 30, 2024
RUN wget https://github.com/pgvector/pgvector/archive/refs/tags/v0.8.0.tar.gz -O pgvector.tar.gz && \
echo "867a2c328d4928a5a9d6f052cd3bc78c7d60228a9b914ad32aa3db88e9de27b0 pgvector.tar.gz" | sha256sum --check && \
# vector 0.7.4 supports v17
# last release v0.7.4 - Aug 5, 2024
RUN wget https://github.com/pgvector/pgvector/archive/refs/tags/v0.7.4.tar.gz -O pgvector.tar.gz && \
echo "0341edf89b1924ae0d552f617e14fb7f8867c0194ed775bcc44fa40288642583 pgvector.tar.gz" | sha256sum --check && \
mkdir pgvector-src && cd pgvector-src && tar xzf ../pgvector.tar.gz --strip-components=1 -C . && \
patch -p1 < /pgvector.patch && \
make -j $(getconf _NPROCESSORS_ONLN) OPTFLAGS="" PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
@@ -1367,12 +1362,15 @@ RUN make PG_VERSION="${PG_VERSION}" -C compute
FROM neon-pg-ext-build AS neon-pg-ext-test
ARG PG_VERSION
RUN mkdir /ext-src
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
mkdir /ext-src
#COPY --from=postgis-build /postgis.tar.gz /ext-src/
#COPY --from=postgis-build /sfcgal/* /usr
COPY --from=plv8-build /plv8.tar.gz /ext-src/
#COPY --from=h3-pg-build /h3-pg.tar.gz /ext-src/
COPY --from=h3-pg-build /h3-pg.tar.gz /ext-src/
COPY --from=unit-pg-build /postgresql-unit.tar.gz /ext-src/
COPY --from=vector-pg-build /pgvector.tar.gz /ext-src/
COPY --from=vector-pg-build /pgvector.patch /ext-src/
@@ -1392,7 +1390,7 @@ COPY --from=hll-pg-build /hll.tar.gz /ext-src
COPY --from=plpgsql-check-pg-build /plpgsql_check.tar.gz /ext-src
#COPY --from=timescaledb-pg-build /timescaledb.tar.gz /ext-src
COPY --from=pg-hint-plan-pg-build /pg_hint_plan.tar.gz /ext-src
COPY compute/patches/pg_hint_plan_${PG_VERSION}.patch /ext-src
COPY compute/patches/pg_hint_plan.patch /ext-src
COPY --from=pg-cron-pg-build /pg_cron.tar.gz /ext-src
COPY compute/patches/pg_cron.patch /ext-src
#COPY --from=pg-pgx-ulid-build /home/nonroot/pgx_ulid.tar.gz /ext-src
@@ -1402,23 +1400,38 @@ COPY --from=pg-roaringbitmap-pg-build /pg_roaringbitmap.tar.gz /ext-src
COPY --from=pg-semver-pg-build /pg_semver.tar.gz /ext-src
#COPY --from=pg-embedding-pg-build /home/nonroot/pg_embedding-src/ /ext-src
#COPY --from=wal2json-pg-build /wal2json_2_5.tar.gz /ext-src
#pg_anon is not supported yet for pg v17 so, don't fail if nothing found
COPY --from=pg-anon-pg-build /pg_anon.tar.g? /ext-src
COPY --from=pg-anon-pg-build /pg_anon.tar.gz /ext-src
COPY compute/patches/pg_anon.patch /ext-src
COPY --from=pg-ivm-build /pg_ivm.tar.gz /ext-src
COPY --from=pg-partman-build /pg_partman.tar.gz /ext-src
RUN cd /ext-src/ && for f in *.tar.gz; \
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
cd /ext-src/ && for f in *.tar.gz; \
do echo $f; dname=$(echo $f | sed 's/\.tar.*//')-src; \
rm -rf $dname; mkdir $dname; tar xzf $f --strip-components=1 -C $dname \
|| exit 1; rm -f $f; done
RUN cd /ext-src/rum-src && patch -p1 <../rum.patch
RUN cd /ext-src/pgvector-src && patch -p1 <../pgvector.patch
RUN cd /ext-src/pg_hint_plan-src && patch -p1 < /ext-src/pg_hint_plan_${PG_VERSION}.patch
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
cd /ext-src/rum-src && patch -p1 <../rum.patch
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
cd /ext-src/pgvector-src && patch -p1 <../pgvector.patch
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
cd /ext-src/pg_hint_plan-src && patch -p1 < /ext-src/pg_hint_plan.patch
COPY --chmod=755 docker-compose/run-tests.sh /run-tests.sh
RUN case "${PG_VERSION}" in "v17") \
echo "postgresql_anonymizer does not yet support PG17" && exit 0;; \
esac && patch -p1 </ext-src/pg_anon.patch
RUN patch -p1 </ext-src/pg_cron.patch
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
patch -p1 </ext-src/pg_anon.patch
RUN case "${PG_VERSION}" in "v17") \
echo "v17 extensions are not supported yet. Quit" && exit 0;; \
esac && \
patch -p1 </ext-src/pg_cron.patch
ENV PATH=/usr/local/pgsql/bin:$PATH
ENV PGHOST=compute
ENV PGPORT=55433
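One detail behind the repeated v17 guards in the hunks above: each Docker RUN step executes in its own shell, so an exit 0 inside the case arm only skips that single step, which is why the same guard has to be restated before every unpack/patch command instead of once for the whole stage. The pattern in isolation (sketch):

    PG_VERSION=v17
    case "${PG_VERSION}" in "v17")
        echo "v17 extensions are not supported yet. Quit" && exit 0;;
    esac
    patch -p1 </ext-src/pg_cron.patch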

View File

@@ -6,7 +6,6 @@
import 'sql_exporter/compute_backpressure_throttling_seconds.libsonnet',
import 'sql_exporter/compute_current_lsn.libsonnet',
import 'sql_exporter/compute_logical_snapshot_files.libsonnet',
import 'sql_exporter/compute_logical_snapshots_bytes.libsonnet',
import 'sql_exporter/compute_max_connections.libsonnet',
import 'sql_exporter/compute_receive_lsn.libsonnet',
import 'sql_exporter/compute_subscriptions_count.libsonnet',

View File

@@ -1,9 +1,5 @@
[databases]
;; pgbouncer propagates application_name (if it's specified) to the server, but some
;; clients don't set it. We set default application_name=pgbouncer to make it
;; easier to identify pgbouncer connections in Postgres. If client sets
;; application_name, it will be used instead.
*=host=localhost port=5432 auth_user=cloud_admin application_name=pgbouncer
*=host=localhost port=5432 auth_user=cloud_admin
[pgbouncer]
listen_port=6432
listen_addr=0.0.0.0
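The comment that differs here spells out the intent: with application_name=pgbouncer in the server connection string, pooled connections are easy to tell apart on the Postgres side unless a client sets its own name. For example (sketch, reusing the host/port/user from this ini):

    psql -h localhost -p 5432 -U cloud_admin -d postgres -Atc \
      "SELECT application_name, count(*) FROM pg_stat_activity GROUP BY 1 ORDER BY 2 DESC;"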

View File

@@ -1,7 +0,0 @@
SELECT
(SELECT current_setting('neon.timeline_id')) AS timeline_id,
-- Postgres creates temporary snapshot files of the form %X-%X.snap.%d.tmp.
-- These temporary snapshot files are renamed to the actual snapshot files
-- after they are completely built. We only WAL-log the completely built
-- snapshot files
(SELECT COALESCE(sum(size), 0) FROM pg_ls_logicalsnapdir() WHERE name LIKE '%.snap') AS logical_snapshots_bytes;
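The query in this removed file relies on pg_ls_logicalsnapdir(), which exists from Postgres 15 onward (the pg_ls_dir variant further down covers older versions, as the libsonnet's version switch shows). Run by hand, the metric it computed looks like this sketch:

    psql -Atc "SELECT COALESCE(sum(size), 0) FROM pg_ls_logicalsnapdir() WHERE name LIKE '%.snap';"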

View File

@@ -1,17 +0,0 @@
local neon = import 'neon.libsonnet';
local pg_ls_logicalsnapdir = importstr 'sql_exporter/compute_logical_snapshots_bytes.15.sql';
local pg_ls_dir = importstr 'sql_exporter/compute_logical_snapshots_bytes.sql';
{
metric_name: 'compute_logical_snapshots_bytes',
type: 'gauge',
help: 'Size of the pg_logical/snapshots directory, not including temporary files',
key_labels: [
'timeline_id',
],
values: [
'logical_snapshots_bytes',
],
query: if neon.PG_MAJORVERSION_NUM < 15 then pg_ls_dir else pg_ls_logicalsnapdir,
}

View File

@@ -1,9 +0,0 @@
SELECT
(SELECT setting FROM pg_settings WHERE name = 'neon.timeline_id') AS timeline_id,
-- Postgres creates temporary snapshot files of the form %X-%X.snap.%d.tmp.
-- These temporary snapshot files are renamed to the actual snapshot files
-- after they are completely built. We only WAL-log the completely built
-- snapshot files
(SELECT COALESCE(sum((pg_stat_file('pg_logical/snapshots/' || name, missing_ok => true)).size), 0)
FROM (SELECT * FROM pg_ls_dir('pg_logical/snapshots') WHERE pg_ls_dir LIKE '%.snap') AS name
) AS logical_snapshots_bytes;

View File

@@ -1,174 +0,0 @@
diff --git a/expected/ut-A.out b/expected/ut-A.out
index e7d68a1..65a056c 100644
--- a/expected/ut-A.out
+++ b/expected/ut-A.out
@@ -9,13 +9,16 @@ SET search_path TO public;
----
-- No.A-1-1-3
CREATE EXTENSION pg_hint_plan;
+LOG: Sending request to compute_ctl: http://localhost:3080/extension_server/pg_hint_plan
-- No.A-1-2-3
DROP EXTENSION pg_hint_plan;
-- No.A-1-1-4
CREATE SCHEMA other_schema;
CREATE EXTENSION pg_hint_plan SCHEMA other_schema;
+LOG: Sending request to compute_ctl: http://localhost:3080/extension_server/pg_hint_plan
ERROR: extension "pg_hint_plan" must be installed in schema "hint_plan"
CREATE EXTENSION pg_hint_plan;
+LOG: Sending request to compute_ctl: http://localhost:3080/extension_server/pg_hint_plan
DROP SCHEMA other_schema;
----
---- No. A-5-1 comment pattern
diff --git a/expected/ut-J.out b/expected/ut-J.out
index 2fa3c70..314e929 100644
--- a/expected/ut-J.out
+++ b/expected/ut-J.out
@@ -789,38 +789,6 @@ NestLoop(st1 st2)
MergeJoin(t1 t2)
not used hint:
duplication hint:
-error hint:
-
-LOG: pg_hint_plan:
-used hint:
-not used hint:
-NestLoop(st1 st2)
-MergeJoin(t1 t2)
-duplication hint:
-error hint:
-
-LOG: pg_hint_plan:
-used hint:
-not used hint:
-NestLoop(st1 st2)
-MergeJoin(t1 t2)
-duplication hint:
-error hint:
-
-LOG: pg_hint_plan:
-used hint:
-not used hint:
-NestLoop(st1 st2)
-MergeJoin(t1 t2)
-duplication hint:
-error hint:
-
-LOG: pg_hint_plan:
-used hint:
-not used hint:
-NestLoop(st1 st2)
-MergeJoin(t1 t2)
-duplication hint:
error hint:
explain_filter
diff --git a/expected/ut-S.out b/expected/ut-S.out
index 0bfcfb8..e75f581 100644
--- a/expected/ut-S.out
+++ b/expected/ut-S.out
@@ -4415,34 +4415,6 @@ used hint:
IndexScan(ti1 ti1_pred)
not used hint:
duplication hint:
-error hint:
-
-LOG: pg_hint_plan:
-used hint:
-not used hint:
-IndexScan(ti1 ti1_pred)
-duplication hint:
-error hint:
-
-LOG: pg_hint_plan:
-used hint:
-not used hint:
-IndexScan(ti1 ti1_pred)
-duplication hint:
-error hint:
-
-LOG: pg_hint_plan:
-used hint:
-not used hint:
-IndexScan(ti1 ti1_pred)
-duplication hint:
-error hint:
-
-LOG: pg_hint_plan:
-used hint:
-not used hint:
-IndexScan(ti1 ti1_pred)
-duplication hint:
error hint:
explain_filter
diff --git a/expected/ut-W.out b/expected/ut-W.out
index a09bd34..0ad227c 100644
--- a/expected/ut-W.out
+++ b/expected/ut-W.out
@@ -1341,54 +1341,6 @@ IndexScan(ft1)
IndexScan(t)
Parallel(s1 3 hard)
duplication hint:
-error hint:
-
-LOG: pg_hint_plan:
-used hint:
-not used hint:
-IndexScan(*VALUES*)
-SeqScan(cte1)
-IndexScan(ft1)
-IndexScan(t)
-Parallel(p1 5 hard)
-Parallel(s1 3 hard)
-duplication hint:
-error hint:
-
-LOG: pg_hint_plan:
-used hint:
-not used hint:
-IndexScan(*VALUES*)
-SeqScan(cte1)
-IndexScan(ft1)
-IndexScan(t)
-Parallel(p1 5 hard)
-Parallel(s1 3 hard)
-duplication hint:
-error hint:
-
-LOG: pg_hint_plan:
-used hint:
-not used hint:
-IndexScan(*VALUES*)
-SeqScan(cte1)
-IndexScan(ft1)
-IndexScan(t)
-Parallel(p1 5 hard)
-Parallel(s1 3 hard)
-duplication hint:
-error hint:
-
-LOG: pg_hint_plan:
-used hint:
-not used hint:
-IndexScan(*VALUES*)
-SeqScan(cte1)
-IndexScan(ft1)
-IndexScan(t)
-Parallel(p1 5 hard)
-Parallel(s1 3 hard)
-duplication hint:
error hint:
explain_filter
diff --git a/expected/ut-fdw.out b/expected/ut-fdw.out
index 017fa4b..98d989b 100644
--- a/expected/ut-fdw.out
+++ b/expected/ut-fdw.out
@@ -7,6 +7,7 @@ SET pg_hint_plan.debug_print TO on;
SET client_min_messages TO LOG;
SET pg_hint_plan.enable_hint TO on;
CREATE EXTENSION file_fdw;
+LOG: Sending request to compute_ctl: http://localhost:3080/extension_server/file_fdw
CREATE SERVER file_server FOREIGN DATA WRAPPER file_fdw;
CREATE USER MAPPING FOR PUBLIC SERVER file_server;
CREATE FOREIGN TABLE ft1 (id int, val int) SERVER file_server OPTIONS (format 'csv', filename :'filename');

View File

@@ -1,42 +0,0 @@
commit 46b38d3e46f9cd6c70d9b189dd6ff4abaa17cf5e
Author: Alexander Bayandin <alexander@neon.tech>
Date: Sat Nov 30 18:29:32 2024 +0000
Fix v8 9.7.37 compilation on Debian 12
diff --git a/patches/code/84cf3230a9680aac3b73c410c2b758760b6d3066.patch b/patches/code/84cf3230a9680aac3b73c410c2b758760b6d3066.patch
new file mode 100644
index 0000000..f0a5dc7
--- /dev/null
+++ b/patches/code/84cf3230a9680aac3b73c410c2b758760b6d3066.patch
@@ -0,0 +1,30 @@
+From 84cf3230a9680aac3b73c410c2b758760b6d3066 Mon Sep 17 00:00:00 2001
+From: Michael Lippautz <mlippautz@chromium.org>
+Date: Thu, 27 Jan 2022 14:14:11 +0100
+Subject: [PATCH] cppgc: Fix include
+
+Add <utility> to cover for std::exchange.
+
+Bug: v8:12585
+Change-Id: Ida65144e93e466be8914527d0e646f348c136bcb
+Reviewed-on: https://chromium-review.googlesource.com/c/v8/v8/+/3420309
+Auto-Submit: Michael Lippautz <mlippautz@chromium.org>
+Reviewed-by: Omer Katz <omerkatz@chromium.org>
+Commit-Queue: Michael Lippautz <mlippautz@chromium.org>
+Cr-Commit-Position: refs/heads/main@{#78820}
+---
+ src/heap/cppgc/prefinalizer-handler.h | 1 +
+ 1 file changed, 1 insertion(+)
+
+diff --git a/src/heap/cppgc/prefinalizer-handler.h b/src/heap/cppgc/prefinalizer-handler.h
+index bc17c99b1838..c82c91ff5a45 100644
+--- a/src/heap/cppgc/prefinalizer-handler.h
++++ b/src/heap/cppgc/prefinalizer-handler.h
+@@ -5,6 +5,7 @@
+ #ifndef V8_HEAP_CPPGC_PREFINALIZER_HANDLER_H_
+ #define V8_HEAP_CPPGC_PREFINALIZER_HANDLER_H_
+
++#include <utility>
+ #include <vector>
+
+ #include "include/cppgc/prefinalizer.h"

View File

@@ -37,7 +37,6 @@ use std::collections::HashMap;
use std::fs::File;
use std::path::Path;
use std::process::exit;
use std::str::FromStr;
use std::sync::atomic::Ordering;
use std::sync::{mpsc, Arc, Condvar, Mutex, RwLock};
use std::{thread, time::Duration};
@@ -59,7 +58,7 @@ use compute_tools::compute::{
forward_termination_signal, ComputeNode, ComputeState, ParsedSpec, PG_PID,
};
use compute_tools::configurator::launch_configurator;
use compute_tools::extension_server::get_pg_version_string;
use compute_tools::extension_server::get_pg_version;
use compute_tools::http::api::launch_http_server;
use compute_tools::logger::*;
use compute_tools::monitor::launch_monitor;
@@ -323,19 +322,11 @@ fn wait_spec(
} else {
spec_set = false;
}
let connstr = Url::parse(connstr).context("cannot parse connstr as a URL")?;
let conn_conf = postgres::config::Config::from_str(connstr.as_str())
.context("cannot build postgres config from connstr")?;
let tokio_conn_conf = tokio_postgres::config::Config::from_str(connstr.as_str())
.context("cannot build tokio postgres config from connstr")?;
let compute_node = ComputeNode {
connstr,
conn_conf,
tokio_conn_conf,
connstr: Url::parse(connstr).context("cannot parse connstr as a URL")?,
pgdata: pgdata.to_string(),
pgbin: pgbin.to_string(),
pgversion: get_pg_version_string(pgbin),
http_port,
pgversion: get_pg_version(pgbin),
live_config_allowed,
state: Mutex::new(new_state),
state_changed: Condvar::new(),
@@ -390,6 +381,7 @@ fn wait_spec(
Ok(WaitSpecResult {
compute,
http_port,
resize_swap_on_bind,
set_disk_quota_for_fs: set_disk_quota_for_fs.cloned(),
})
@@ -397,6 +389,8 @@ fn wait_spec(
struct WaitSpecResult {
compute: Arc<ComputeNode>,
// passed through from ProcessCliResult
http_port: u16,
resize_swap_on_bind: bool,
set_disk_quota_for_fs: Option<String>,
}
@@ -406,6 +400,7 @@ fn start_postgres(
#[allow(unused_variables)] matches: &clap::ArgMatches,
WaitSpecResult {
compute,
http_port,
resize_swap_on_bind,
set_disk_quota_for_fs,
}: WaitSpecResult,
@@ -478,10 +473,12 @@ fn start_postgres(
}
}
let extension_server_port: u16 = http_port;
// Start Postgres
let mut pg = None;
if !prestartup_failed {
pg = match compute.start_compute() {
pg = match compute.start_compute(extension_server_port) {
Ok(pg) => Some(pg),
Err(err) => {
error!("could not start the compute node: {:#}", err);

View File

@@ -21,7 +21,7 @@
//! - Build the image with the following command:
//!
//! ```bash
//! docker buildx build --platform linux/amd64 --build-arg DEBIAN_VERSION=bullseye --build-arg GIT_VERSION=local --build-arg PG_VERSION=v14 --build-arg BUILD_TAG="$(date --iso-8601=s -u)" -t localhost:3030/localregistry/compute-node-v14:latest -f compute/compute-node.Dockerfile .
//! docker buildx build --build-arg DEBIAN_FLAVOR=bullseye-slim --build-arg GIT_VERSION=local --build-arg PG_VERSION=v14 --build-arg BUILD_TAG="$(date --iso-8601=s -u)" -t localhost:3030/localregistry/compute-node-v14:latest -f compute/Dockerfile.com
//! docker push localhost:3030/localregistry/compute-node-v14:latest
//! ```
@@ -29,7 +29,6 @@ use anyhow::Context;
use aws_config::BehaviorVersion;
use camino::{Utf8Path, Utf8PathBuf};
use clap::Parser;
use compute_tools::extension_server::{get_pg_version, PostgresMajorVersion};
use nix::unistd::Pid;
use tracing::{info, info_span, warn, Instrument};
use utils::fs_ext::is_directory_empty;
@@ -132,18 +131,11 @@ pub(crate) async fn main() -> anyhow::Result<()> {
//
// Initialize pgdata
//
let pgbin = pg_bin_dir.join("postgres");
let pg_version = match get_pg_version(pgbin.as_ref()) {
PostgresMajorVersion::V14 => 14,
PostgresMajorVersion::V15 => 15,
PostgresMajorVersion::V16 => 16,
PostgresMajorVersion::V17 => 17,
};
let superuser = "cloud_admin"; // XXX: this shouldn't be hard-coded
postgres_initdb::do_run_initdb(postgres_initdb::RunInitdbArgs {
superuser,
locale: "en_US.UTF-8", // XXX: this shouldn't be hard-coded,
pg_version,
pg_version: 140000, // XXX: this shouldn't be hard-coded but derived from which compute image we're running in
initdb_bin: pg_bin_dir.join("initdb").as_ref(),
library_search_path: &pg_lib_dir, // TODO: is this right? Prob works in compute image, not sure about neon_local.
pgdata: &pgdata_dir,
@@ -156,7 +148,7 @@ pub(crate) async fn main() -> anyhow::Result<()> {
//
// Launch postgres process
//
let mut postgres_proc = tokio::process::Command::new(pgbin)
let mut postgres_proc = tokio::process::Command::new(pg_bin_dir.join("postgres"))
.arg("-D")
.arg(&pgdata_dir)
.args(["-c", "wal_level=minimal"])

View File

@@ -1,3 +1,4 @@
use compute_api::responses::CatalogObjects;
use futures::Stream;
use postgres::NoTls;
use std::{path::Path, process::Stdio, result::Result, sync::Arc};
@@ -6,17 +7,19 @@ use tokio::{
process::Command,
spawn,
};
use tokio_postgres::connect;
use tokio_stream::{self as stream, StreamExt};
use tokio_util::codec::{BytesCodec, FramedRead};
use tracing::warn;
use crate::compute::ComputeNode;
use crate::pg_helpers::{get_existing_dbs_async, get_existing_roles_async, postgres_conf_for_db};
use compute_api::responses::CatalogObjects;
use crate::pg_helpers::{get_existing_dbs_async, get_existing_roles_async};
pub async fn get_dbs_and_roles(compute: &Arc<ComputeNode>) -> anyhow::Result<CatalogObjects> {
let conf = compute.get_tokio_conn_conf(Some("compute_ctl:get_dbs_and_roles"));
let (client, connection): (tokio_postgres::Client, _) = conf.connect(NoTls).await?;
let connstr = compute.connstr.clone();
let (client, connection): (tokio_postgres::Client, _) =
connect(connstr.as_str(), NoTls).await?;
spawn(async move {
if let Err(e) = connection.await {
@@ -40,8 +43,6 @@ pub enum SchemaDumpError {
DatabaseDoesNotExist,
#[error("Failed to execute pg_dump.")]
IO(#[from] std::io::Error),
#[error("Unexpected error.")]
Unexpected,
}
// It uses the pg_dump utility to dump the schema of the specified database.
@@ -59,38 +60,11 @@ pub async fn get_database_schema(
let pgbin = &compute.pgbin;
let basepath = Path::new(pgbin).parent().unwrap();
let pgdump = basepath.join("pg_dump");
// Replace the DB in the connection string and disable it to parts.
// This is the only option to handle DBs with special characters.
let conf =
postgres_conf_for_db(&compute.connstr, dbname).map_err(|_| SchemaDumpError::Unexpected)?;
let host = conf
.get_hosts()
.first()
.ok_or(SchemaDumpError::Unexpected)?;
let host = match host {
tokio_postgres::config::Host::Tcp(ip) => ip.to_string(),
#[cfg(unix)]
tokio_postgres::config::Host::Unix(path) => path.to_string_lossy().to_string(),
};
let port = conf
.get_ports()
.first()
.ok_or(SchemaDumpError::Unexpected)?;
let user = conf.get_user().ok_or(SchemaDumpError::Unexpected)?;
let dbname = conf.get_dbname().ok_or(SchemaDumpError::Unexpected)?;
let mut connstr = compute.connstr.clone();
connstr.set_path(dbname);
let mut cmd = Command::new(pgdump)
// XXX: this seems to be the only option to deal with DBs with `=` in the name
// See <https://www.postgresql.org/message-id/flat/20151023003445.931.91267%40wrigleys.postgresql.org>
.env("PGDATABASE", dbname)
.arg("--host")
.arg(host)
.arg("--port")
.arg(port.to_string())
.arg("--username")
.arg(user)
.arg("--schema-only")
.arg(connstr.as_str())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.kill_on_drop(true)

View File

@@ -9,8 +9,7 @@ use crate::compute::ComputeNode;
#[instrument(skip_all)]
pub async fn check_writability(compute: &ComputeNode) -> Result<()> {
// Connect to the database.
let conf = compute.get_tokio_conn_conf(Some("compute_ctl:availability_checker"));
let (client, connection) = conf.connect(NoTls).await?;
let (client, connection) = tokio_postgres::connect(compute.connstr.as_str(), NoTls).await?;
if client.is_closed() {
return Err(anyhow!("connection to postgres closed"));
}

View File

@@ -20,9 +20,8 @@ use futures::future::join_all;
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use nix::unistd::Pid;
use postgres;
use postgres::error::SqlState;
use postgres::NoTls;
use postgres::{Client, NoTls};
use tracing::{debug, error, info, instrument, warn};
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
@@ -35,8 +34,9 @@ use utils::measured_stream::MeasuredReader;
use nix::sys::signal::{kill, Signal};
use remote_storage::{DownloadError, RemotePath};
use tokio::spawn;
use url::Url;
use crate::installed_extensions::get_installed_extensions;
use crate::installed_extensions::get_installed_extensions_sync;
use crate::local_proxy;
use crate::pg_helpers::*;
use crate::spec::*;
@@ -59,10 +59,6 @@ pub static PG_PID: AtomicU32 = AtomicU32::new(0);
pub struct ComputeNode {
// Url type maintains proper escaping
pub connstr: url::Url,
// We connect to Postgres from many different places, so build configs once
// and reuse them where needed.
pub conn_conf: postgres::config::Config,
pub tokio_conn_conf: tokio_postgres::config::Config,
pub pgdata: String,
pub pgbin: String,
pub pgversion: String,
@@ -79,8 +75,6 @@ pub struct ComputeNode {
/// - we push spec and it does configuration
/// - but then it is restarted without any spec again
pub live_config_allowed: bool,
/// The port that the compute's HTTP server listens on
pub http_port: u16,
/// Volatile part of the `ComputeNode`, which should be used under `Mutex`.
/// To allow HTTP API server to serving status requests, while configuration
/// is in progress, lock should be held only for short periods of time to do
@@ -613,7 +607,11 @@ impl ComputeNode {
/// Do all the preparations like PGDATA directory creation, configuration,
/// safekeepers sync, basebackup, etc.
#[instrument(skip_all)]
pub fn prepare_pgdata(&self, compute_state: &ComputeState) -> Result<()> {
pub fn prepare_pgdata(
&self,
compute_state: &ComputeState,
extension_server_port: u16,
) -> Result<()> {
let pspec = compute_state.pspec.as_ref().expect("spec must be set");
let spec = &pspec.spec;
let pgdata_path = Path::new(&self.pgdata);
@@ -623,7 +621,7 @@ impl ComputeNode {
config::write_postgres_conf(
&pgdata_path.join("postgresql.conf"),
&pspec.spec,
self.http_port,
Some(extension_server_port),
)?;
// Syncing safekeepers is only safe with primary nodes: if a primary
@@ -803,10 +801,10 @@ impl ComputeNode {
/// version. In the future, it may upgrade all 3rd-party extensions.
#[instrument(skip_all)]
pub fn post_apply_config(&self) -> Result<()> {
let conf = self.get_conn_conf(Some("compute_ctl:post_apply_config"));
let connstr = self.connstr.clone();
thread::spawn(move || {
let func = || {
let mut client = conf.connect(NoTls)?;
let mut client = Client::connect(connstr.as_str(), NoTls)?;
handle_neon_extension_upgrade(&mut client)
.context("handle_neon_extension_upgrade")?;
Ok::<_, anyhow::Error>(())
@@ -818,48 +816,30 @@ impl ComputeNode {
Ok(())
}
pub fn get_conn_conf(&self, application_name: Option<&str>) -> postgres::Config {
let mut conf = self.conn_conf.clone();
if let Some(application_name) = application_name {
conf.application_name(application_name);
}
conf
}
async fn get_maintenance_client(url: &Url) -> Result<tokio_postgres::Client> {
let mut connstr = url.clone();
pub fn get_tokio_conn_conf(&self, application_name: Option<&str>) -> tokio_postgres::Config {
let mut conf = self.tokio_conn_conf.clone();
if let Some(application_name) = application_name {
conf.application_name(application_name);
}
conf
}
connstr
.query_pairs_mut()
.append_pair("application_name", "apply_config");
async fn get_maintenance_client(
conf: &tokio_postgres::Config,
) -> Result<tokio_postgres::Client> {
let mut conf = conf.clone();
conf.application_name("compute_ctl:apply_config");
let (client, conn) = match conf.connect(NoTls).await {
// If connection fails, it may be the old node with `zenith_admin` superuser.
//
// In this case we need to connect with old `zenith_admin` name
// and create new user. We cannot simply rename connected user,
// but we can create a new one and grant it all privileges.
let (client, conn) = match tokio_postgres::connect(connstr.as_str(), NoTls).await {
Err(e) => match e.code() {
Some(&SqlState::INVALID_PASSWORD)
| Some(&SqlState::INVALID_AUTHORIZATION_SPECIFICATION) => {
// Connect with zenith_admin if cloud_admin could not authenticate
// connect with zenith_admin if cloud_admin could not authenticate
info!(
"cannot connect to postgres: {}, retrying with `zenith_admin` username",
e
);
let mut zenith_admin_conf = postgres::config::Config::from(conf.clone());
zenith_admin_conf.application_name("compute_ctl:apply_config");
zenith_admin_conf.user("zenith_admin");
let mut zenith_admin_connstr = connstr.clone();
zenith_admin_connstr
.set_username("zenith_admin")
.map_err(|_| anyhow::anyhow!("invalid connstr"))?;
let mut client =
zenith_admin_conf.connect(NoTls)
Client::connect(zenith_admin_connstr.as_str(), NoTls)
.context("broken cloud_admin credential: tried connecting with cloud_admin but could not authenticate, and zenith_admin does not work either")?;
// Disable forwarding so that users don't get a cloud_admin role
@@ -873,8 +853,8 @@ impl ComputeNode {
drop(client);
// Reconnect with connstring with expected name
conf.connect(NoTls).await?
// reconnect with connstring with expected name
tokio_postgres::connect(connstr.as_str(), NoTls).await?
}
_ => return Err(e.into()),
},
@@ -905,7 +885,7 @@ impl ComputeNode {
pub fn apply_spec_sql(
&self,
spec: Arc<ComputeSpec>,
conf: Arc<tokio_postgres::Config>,
url: Arc<Url>,
concurrency: usize,
) -> Result<()> {
let rt = tokio::runtime::Builder::new_multi_thread()
@@ -917,7 +897,7 @@ impl ComputeNode {
rt.block_on(async {
// Proceed with post-startup configuration. Note that the order of operations is important.
let client = Self::get_maintenance_client(&conf).await?;
let client = Self::get_maintenance_client(&url).await?;
let spec = spec.clone();
let databases = get_existing_dbs_async(&client).await?;
@@ -951,7 +931,7 @@ impl ComputeNode {
RenameAndDeleteDatabases,
CreateAndAlterDatabases,
] {
info!("Applying phase {:?}", &phase);
debug!("Applying phase {:?}", &phase);
apply_operations(
spec.clone(),
ctx.clone(),
@@ -962,7 +942,6 @@ impl ComputeNode {
.await?;
}
info!("Applying RunInEachDatabase phase");
let concurrency_token = Arc::new(tokio::sync::Semaphore::new(concurrency));
let db_processes = spec
@@ -976,7 +955,7 @@ impl ComputeNode {
let spec = spec.clone();
let ctx = ctx.clone();
let jwks_roles = jwks_roles.clone();
let mut conf = conf.as_ref().clone();
let mut url = url.as_ref().clone();
let concurrency_token = concurrency_token.clone();
let db = db.clone();
@@ -985,14 +964,14 @@ impl ComputeNode {
match &db {
DB::SystemDB => {}
DB::UserDB(db) => {
conf.dbname(db.name.as_str());
url.set_path(db.name.as_str());
}
}
let conf = Arc::new(conf);
let url = Arc::new(url);
let fut = Self::apply_spec_sql_db(
spec.clone(),
conf,
url,
ctx.clone(),
jwks_roles.clone(),
concurrency_token.clone(),
@@ -1038,7 +1017,7 @@ impl ComputeNode {
/// semaphore. The caller has to make sure the semaphore isn't exhausted.
async fn apply_spec_sql_db(
spec: Arc<ComputeSpec>,
conf: Arc<tokio_postgres::Config>,
url: Arc<Url>,
ctx: Arc<tokio::sync::RwLock<MutableApplyContext>>,
jwks_roles: Arc<HashSet<String>>,
concurrency_token: Arc<tokio::sync::Semaphore>,
@@ -1067,7 +1046,7 @@ impl ComputeNode {
// that database.
|| async {
if client_conn.is_none() {
let db_client = Self::get_maintenance_client(&conf).await?;
let db_client = Self::get_maintenance_client(&url).await?;
client_conn.replace(db_client);
}
let client = client_conn.as_ref().unwrap();
@@ -1082,16 +1061,34 @@ impl ComputeNode {
Ok::<(), anyhow::Error>(())
}
/// Choose how many concurrent connections to use for applying the spec changes.
pub fn max_service_connections(
&self,
compute_state: &ComputeState,
spec: &ComputeSpec,
) -> usize {
// If the cluster is in Init state we don't have to deal with user connections,
/// Do initial configuration of the already started Postgres.
#[instrument(skip_all)]
pub fn apply_config(&self, compute_state: &ComputeState) -> Result<()> {
// If connection fails,
// it may be the old node with `zenith_admin` superuser.
//
// In this case we need to connect with the old `zenith_admin` name
// and create a new user. We cannot simply rename the connected user,
// but we can create a new one and grant it all privileges.
let mut url = self.connstr.clone();
url.query_pairs_mut()
.append_pair("application_name", "apply_config");
let url = Arc::new(url);
let spec = Arc::new(
compute_state
.pspec
.as_ref()
.expect("spec must be set")
.spec
.clone(),
);
// Choose how many concurrent connections to use for applying the spec changes.
// If the cluster is not currently Running we don't have to deal with user connections,
// and can thus use all `max_connections` connection slots. However, that's generally not
// very efficient, so we generally still limit it to a smaller number.
if compute_state.status == ComputeStatus::Init {
let max_concurrent_connections = if compute_state.status != ComputeStatus::Running {
// If the settings contain 'max_connections', use that as template
if let Some(config) = spec.cluster.settings.find("max_connections") {
config.parse::<usize>().ok()
@@ -1147,28 +1144,10 @@ impl ComputeNode {
.map(|val| if val > 1 { val - 1 } else { 1 })
.last()
.unwrap_or(3)
}
}
/// Do initial configuration of the already started Postgres.
#[instrument(skip_all)]
pub fn apply_config(&self, compute_state: &ComputeState) -> Result<()> {
let conf = self.get_tokio_conn_conf(Some("compute_ctl:apply_config"));
let conf = Arc::new(conf);
let spec = Arc::new(
compute_state
.pspec
.as_ref()
.expect("spec must be set")
.spec
.clone(),
);
let max_concurrent_connections = self.max_service_connections(compute_state, &spec);
};
// Merge-apply spec & changes to PostgreSQL state.
self.apply_spec_sql(spec.clone(), conf.clone(), max_concurrent_connections)?;
self.apply_spec_sql(spec.clone(), url.clone(), max_concurrent_connections)?;
if let Some(ref local_proxy) = &spec.clone().local_proxy_config {
info!("configuring local_proxy");
@@ -1177,11 +1156,12 @@ impl ComputeNode {
// Run migrations separately to not hold up cold starts
thread::spawn(move || {
let conf = conf.as_ref().clone();
let mut conf = postgres::config::Config::from(conf);
conf.application_name("compute_ctl:migrations");
let mut connstr = url.as_ref().clone();
connstr
.query_pairs_mut()
.append_pair("application_name", "migrations");
let mut client = conf.connect(NoTls)?;
let mut client = Client::connect(connstr.as_str(), NoTls)?;
handle_migrations(&mut client).context("apply_config handle_migrations")
});
@@ -1241,24 +1221,22 @@ impl ComputeNode {
// Write new config
let pgdata_path = Path::new(&self.pgdata);
let postgresql_conf_path = pgdata_path.join("postgresql.conf");
config::write_postgres_conf(&postgresql_conf_path, &spec, self.http_port)?;
let max_concurrent_connections = spec.reconfigure_concurrency;
// Temporarily reset max_cluster_size in config
config::write_postgres_conf(&postgresql_conf_path, &spec, None)?;
// temporarily reset max_cluster_size in config
// to avoid the possibility of hitting the limit, while we are reconfiguring:
// creating new extensions, roles, etc.
// creating new extensions, roles, etc...
config::with_compute_ctl_tmp_override(pgdata_path, "neon.max_cluster_size=-1", || {
self.pg_reload_conf()?;
if spec.mode == ComputeMode::Primary {
let mut conf = tokio_postgres::Config::from_str(self.connstr.as_str()).unwrap();
conf.application_name("apply_config");
let conf = Arc::new(conf);
let mut url = self.connstr.clone();
url.query_pairs_mut()
.append_pair("application_name", "apply_config");
let url = Arc::new(url);
let spec = Arc::new(spec.clone());
self.apply_spec_sql(spec, conf, max_concurrent_connections)?;
self.apply_spec_sql(spec, url, 1)?;
}
Ok(())
@@ -1277,7 +1255,10 @@ impl ComputeNode {
}
#[instrument(skip_all)]
pub fn start_compute(&self) -> Result<(std::process::Child, std::thread::JoinHandle<()>)> {
pub fn start_compute(
&self,
extension_server_port: u16,
) -> Result<(std::process::Child, std::thread::JoinHandle<()>)> {
let compute_state = self.state.lock().unwrap().clone();
let pspec = compute_state.pspec.as_ref().expect("spec must be set");
info!(
@@ -1352,7 +1333,7 @@ impl ComputeNode {
info!("{:?}", remote_ext_metrics);
}
self.prepare_pgdata(&compute_state)?;
self.prepare_pgdata(&compute_state, extension_server_port)?;
let start_time = Utc::now();
let pg_process = self.start_postgres(pspec.storage_auth_token.clone())?;
@@ -1379,19 +1360,9 @@ impl ComputeNode {
}
self.post_apply_config()?;
let conf = self.get_conn_conf(None);
let connstr = self.connstr.clone();
thread::spawn(move || {
let res = get_installed_extensions(conf);
match res {
Ok(extensions) => {
info!(
"[NEON_EXT_STAT] {}",
serde_json::to_string(&extensions)
.expect("failed to serialize extensions list")
);
}
Err(err) => error!("could not get installed extensions: {err:?}"),
}
get_installed_extensions_sync(connstr).context("get_installed_extensions")
});
}
@@ -1520,8 +1491,7 @@ impl ComputeNode {
/// Select `pg_stat_statements` data and return it as a stringified JSON
pub async fn collect_insights(&self) -> String {
let mut result_rows: Vec<String> = Vec::new();
let conf = self.get_tokio_conn_conf(Some("compute_ctl:collect_insights"));
let connect_result = conf.connect(NoTls).await;
let connect_result = tokio_postgres::connect(self.connstr.as_str(), NoTls).await;
let (client, connection) = connect_result.unwrap();
tokio::spawn(async move {
if let Err(e) = connection.await {
@@ -1647,9 +1617,10 @@ LIMIT 100",
privileges: &[Privilege],
role_name: &PgIdent,
) -> Result<()> {
use tokio_postgres::config::Config;
use tokio_postgres::NoTls;
let mut conf = self.get_tokio_conn_conf(Some("compute_ctl:set_role_grants"));
let mut conf = Config::from_str(self.connstr.as_str()).unwrap();
conf.dbname(db_name);
let (db_client, conn) = conf
@@ -1686,9 +1657,10 @@ LIMIT 100",
db_name: &PgIdent,
ext_version: ExtVersion,
) -> Result<ExtVersion> {
use tokio_postgres::config::Config;
use tokio_postgres::NoTls;
let mut conf = self.get_tokio_conn_conf(Some("compute_ctl:install_extension"));
let mut conf = Config::from_str(self.connstr.as_str()).unwrap();
conf.dbname(db_name);
let (db_client, conn) = conf

View File

@@ -37,7 +37,7 @@ pub fn line_in_file(path: &Path, line: &str) -> Result<bool> {
pub fn write_postgres_conf(
path: &Path,
spec: &ComputeSpec,
extension_server_port: u16,
extension_server_port: Option<u16>,
) -> Result<()> {
// File::create() destroys the file content if it exists.
let mut file = File::create(path)?;
@@ -127,7 +127,9 @@ pub fn write_postgres_conf(
writeln!(file, "# Managed by compute_ctl: end")?;
}
writeln!(file, "neon.extension_server_port={}", extension_server_port)?;
if let Some(port) = extension_server_port {
writeln!(file, "neon.extension_server_port={}", port)?;
}
// This is essential to keep this line at the end of the file,
// because it is intended to override any settings above.
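With the port now optional, the caller decides whether the `neon.extension_server_port` GUC is written at all. A minimal caller sketch, assuming the signature in this hunk (the path and variables are placeholders, not code from this change):

// Hypothetical caller; `pgdata_path` and `spec` stand in for values the compute code already has.
let conf_path = pgdata_path.join("postgresql.conf");
// Extension server is running: the `neon.extension_server_port=<port>` line is emitted.
write_postgres_conf(&conf_path, &spec, Some(extension_server_port))?;
// No extension server port known (e.g. plain reconfiguration): the GUC line is skipped.
write_postgres_conf(&conf_path, &spec, None)?;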

View File

@@ -103,33 +103,14 @@ fn get_pg_config(argument: &str, pgbin: &str) -> String {
.to_string()
}
pub fn get_pg_version(pgbin: &str) -> PostgresMajorVersion {
pub fn get_pg_version(pgbin: &str) -> String {
// pg_config --version returns a (platform specific) human readable string
// such as "PostgreSQL 15.4". We parse this to v14/v15/v16 etc.
let human_version = get_pg_config("--version", pgbin);
parse_pg_version(&human_version)
parse_pg_version(&human_version).to_string()
}
pub fn get_pg_version_string(pgbin: &str) -> String {
match get_pg_version(pgbin) {
PostgresMajorVersion::V14 => "v14",
PostgresMajorVersion::V15 => "v15",
PostgresMajorVersion::V16 => "v16",
PostgresMajorVersion::V17 => "v17",
}
.to_owned()
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum PostgresMajorVersion {
V14,
V15,
V16,
V17,
}
fn parse_pg_version(human_version: &str) -> PostgresMajorVersion {
use PostgresMajorVersion::*;
fn parse_pg_version(human_version: &str) -> &str {
// Normal releases have version strings like "PostgreSQL 15.4". But there
// are also pre-release versions like "PostgreSQL 17devel" or "PostgreSQL
// 16beta2" or "PostgreSQL 17rc1". And with the --with-extra-version
@@ -140,10 +121,10 @@ fn parse_pg_version(human_version: &str) -> PostgresMajorVersion {
.captures(human_version)
{
Some(captures) if captures.len() == 2 => match &captures["major"] {
"14" => return V14,
"15" => return V15,
"16" => return V16,
"17" => return V17,
"14" => return "v14",
"15" => return "v15",
"16" => return "v16",
"17" => return "v17",
_ => {}
},
_ => {}
@@ -282,25 +263,24 @@ mod tests {
#[test]
fn test_parse_pg_version() {
use super::PostgresMajorVersion::*;
assert_eq!(parse_pg_version("PostgreSQL 15.4"), V15);
assert_eq!(parse_pg_version("PostgreSQL 15.14"), V15);
assert_eq!(parse_pg_version("PostgreSQL 15.4"), "v15");
assert_eq!(parse_pg_version("PostgreSQL 15.14"), "v15");
assert_eq!(
parse_pg_version("PostgreSQL 15.4 (Ubuntu 15.4-0ubuntu0.23.04.1)"),
V15
"v15"
);
assert_eq!(parse_pg_version("PostgreSQL 14.15"), V14);
assert_eq!(parse_pg_version("PostgreSQL 14.0"), V14);
assert_eq!(parse_pg_version("PostgreSQL 14.15"), "v14");
assert_eq!(parse_pg_version("PostgreSQL 14.0"), "v14");
assert_eq!(
parse_pg_version("PostgreSQL 14.9 (Debian 14.9-1.pgdg120+1"),
V14
"v14"
);
assert_eq!(parse_pg_version("PostgreSQL 16devel"), V16);
assert_eq!(parse_pg_version("PostgreSQL 16beta1"), V16);
assert_eq!(parse_pg_version("PostgreSQL 16rc2"), V16);
assert_eq!(parse_pg_version("PostgreSQL 16extra"), V16);
assert_eq!(parse_pg_version("PostgreSQL 16devel"), "v16");
assert_eq!(parse_pg_version("PostgreSQL 16beta1"), "v16");
assert_eq!(parse_pg_version("PostgreSQL 16rc2"), "v16");
assert_eq!(parse_pg_version("PostgreSQL 16extra"), "v16");
}
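For orientation, a self-contained sketch of how the visible `.captures(human_version)` call fits together, assuming a once_cell `Regex` with a `major` capture group (the real pattern lives in the unchanged part of the file):

use once_cell::sync::Lazy;
use regex::Regex;

// Assumed shape of the version regex; only the `major` capture group is relied on below.
static VERSION_RE: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"PostgreSQL (?P<major>\d+)").expect("valid regex"));

fn parse_pg_version_sketch(human_version: &str) -> &str {
    if let Some(captures) = VERSION_RE.captures(human_version) {
        match &captures["major"] {
            "14" => return "v14",
            "15" => return "v15",
            "16" => return "v16",
            "17" => return "v17",
            _ => {}
        }
    }
    panic!("unsupported PostgreSQL version string: {human_version}")
}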
#[test]

View File

@@ -295,12 +295,8 @@ async fn routes(req: Request<Body>, compute: &Arc<ComputeNode>) -> Response<Body
return Response::new(Body::from(msg));
}
let conf = compute.get_conn_conf(None);
let res =
task::spawn_blocking(move || installed_extensions::get_installed_extensions(conf))
.await
.unwrap();
let connstr = compute.connstr.clone();
let res = crate::installed_extensions::get_installed_extensions(connstr).await;
match res {
Ok(res) => render_json(Body::from(serde_json::to_string(&res).unwrap())),
Err(e) => render_json_error(

View File

@@ -2,9 +2,12 @@ use compute_api::responses::{InstalledExtension, InstalledExtensions};
use metrics::proto::MetricFamily;
use std::collections::HashMap;
use std::collections::HashSet;
use tracing::info;
use url::Url;
use anyhow::Result;
use postgres::{Client, NoTls};
use tokio::task;
use metrics::core::Collector;
use metrics::{register_uint_gauge_vec, UIntGaugeVec};
@@ -39,53 +42,75 @@ fn list_dbs(client: &mut Client) -> Result<Vec<String>> {
///
/// The same extension can be installed in multiple databases with different versions;
/// we only keep the highest and lowest version across all databases.
pub fn get_installed_extensions(mut conf: postgres::config::Config) -> Result<InstalledExtensions> {
conf.application_name("compute_ctl:get_installed_extensions");
let mut client = conf.connect(NoTls)?;
pub async fn get_installed_extensions(connstr: Url) -> Result<InstalledExtensions> {
let mut connstr = connstr.clone();
let databases: Vec<String> = list_dbs(&mut client)?;
task::spawn_blocking(move || {
let mut client = Client::connect(connstr.as_str(), NoTls)?;
let databases: Vec<String> = list_dbs(&mut client)?;
let mut extensions_map: HashMap<String, InstalledExtension> = HashMap::new();
for db in databases.iter() {
conf.dbname(db);
let mut db_client = conf.connect(NoTls)?;
let extensions: Vec<(String, String)> = db_client
.query(
"SELECT extname, extversion FROM pg_catalog.pg_extension;",
&[],
)?
.iter()
.map(|row| (row.get("extname"), row.get("extversion")))
.collect();
let mut extensions_map: HashMap<String, InstalledExtension> = HashMap::new();
for db in databases.iter() {
connstr.set_path(db);
let mut db_client = Client::connect(connstr.as_str(), NoTls)?;
let extensions: Vec<(String, String)> = db_client
.query(
"SELECT extname, extversion FROM pg_catalog.pg_extension;",
&[],
)?
.iter()
.map(|row| (row.get("extname"), row.get("extversion")))
.collect();
for (extname, v) in extensions.iter() {
let version = v.to_string();
for (extname, v) in extensions.iter() {
let version = v.to_string();
// increment the number of databases where the version of extension is installed
INSTALLED_EXTENSIONS
.with_label_values(&[extname, &version])
.inc();
// increment the number of databases where the version of extension is installed
INSTALLED_EXTENSIONS
.with_label_values(&[extname, &version])
.inc();
extensions_map
.entry(extname.to_string())
.and_modify(|e| {
e.versions.insert(version.clone());
// count the number of databases where the extension is installed
e.n_databases += 1;
})
.or_insert(InstalledExtension {
extname: extname.to_string(),
versions: HashSet::from([version.clone()]),
n_databases: 1,
});
extensions_map
.entry(extname.to_string())
.and_modify(|e| {
e.versions.insert(version.clone());
// count the number of databases where the extension is installed
e.n_databases += 1;
})
.or_insert(InstalledExtension {
extname: extname.to_string(),
versions: HashSet::from([version.clone()]),
n_databases: 1,
});
}
}
}
let res = InstalledExtensions {
extensions: extensions_map.into_values().collect(),
};
let res = InstalledExtensions {
extensions: extensions_map.values().cloned().collect(),
};
Ok(res)
Ok(res)
})
.await?
}
// Gather info about installed extensions
pub fn get_installed_extensions_sync(connstr: Url) -> Result<()> {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("failed to create runtime");
let result = rt
.block_on(crate::installed_extensions::get_installed_extensions(
connstr,
))
.expect("failed to get installed extensions");
info!(
"[NEON_EXT_STAT] {}",
serde_json::to_string(&result).expect("failed to serialize extensions list")
);
Ok(())
}
static INSTALLED_EXTENSIONS: Lazy<UIntGaugeVec> = Lazy::new(|| {

View File

@@ -17,8 +17,11 @@ const MONITOR_CHECK_INTERVAL: Duration = Duration::from_millis(500);
// should be handled gracefully.
fn watch_compute_activity(compute: &ComputeNode) {
// Suppose that `connstr` doesn't change
let connstr = compute.connstr.clone();
let conf = compute.get_conn_conf(Some("compute_ctl:activity_monitor"));
let mut connstr = compute.connstr.clone();
connstr
.query_pairs_mut()
.append_pair("application_name", "compute_activity_monitor");
let connstr = connstr.as_str();
// During startup and configuration we connect to every Postgres database,
// but we don't want to count this as some user activity. So wait until
@@ -26,7 +29,7 @@ fn watch_compute_activity(compute: &ComputeNode) {
wait_for_postgres_start(compute);
// Define `client` outside of the loop to reuse existing connection if it's active.
let mut client = conf.connect(NoTls);
let mut client = Client::connect(connstr, NoTls);
let mut sleep = false;
let mut prev_active_time: Option<f64> = None;
@@ -54,7 +57,7 @@ fn watch_compute_activity(compute: &ComputeNode) {
info!("connection to Postgres is closed, trying to reconnect");
// Connection is closed, reconnect and try again.
client = conf.connect(NoTls);
client = Client::connect(connstr, NoTls);
continue;
}
@@ -193,7 +196,7 @@ fn watch_compute_activity(compute: &ComputeNode) {
debug!("could not connect to Postgres: {}, retrying", e);
// Establish a new connection and try again.
client = conf.connect(NoTls);
client = Client::connect(connstr, NoTls);
}
}
}

View File

@@ -6,7 +6,6 @@ use std::io::{BufRead, BufReader};
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use std::process::Child;
use std::str::FromStr;
use std::thread::JoinHandle;
use std::time::{Duration, Instant};
@@ -14,10 +13,8 @@ use anyhow::{bail, Result};
use futures::StreamExt;
use ini::Ini;
use notify::{RecursiveMode, Watcher};
use postgres::config::Config;
use tokio::io::AsyncBufReadExt;
use tokio::time::timeout;
use tokio_postgres;
use tokio_postgres::NoTls;
use tracing::{debug, error, info, instrument};
@@ -545,11 +542,3 @@ async fn handle_postgres_logs_async(stderr: tokio::process::ChildStderr) -> Resu
Ok(())
}
/// `Postgres::config::Config` handles database names with whitespaces
/// and special characters properly.
pub fn postgres_conf_for_db(connstr: &url::Url, dbname: &str) -> Result<Config> {
let mut conf = Config::from_str(connstr.as_str())?;
conf.dbname(dbname);
Ok(conf)
}

View File

@@ -53,7 +53,6 @@ use compute_api::spec::Role;
use nix::sys::signal::kill;
use nix::sys::signal::Signal;
use pageserver_api::shard::ShardStripeSize;
use reqwest::header::CONTENT_TYPE;
use serde::{Deserialize, Serialize};
use url::Host;
use utils::id::{NodeId, TenantId, TimelineId};
@@ -311,10 +310,6 @@ impl Endpoint {
conf.append("wal_log_hints", "off");
conf.append("max_replication_slots", "10");
conf.append("hot_standby", "on");
// Set to 1MB to both exercise getPage requests/LFC, and still have enough room for
// Postgres to operate. Everything smaller might not be enough for Postgres under load,
// and can cause errors like 'no unpinned buffers available', see
// <https://github.com/neondatabase/neon/issues/9956>
conf.append("shared_buffers", "1MB");
conf.append("fsync", "off");
conf.append("max_connections", "100");
@@ -619,7 +614,6 @@ impl Endpoint {
pgbouncer_settings: None,
shard_stripe_size: Some(shard_stripe_size),
local_proxy_config: None,
reconfigure_concurrency: 1,
};
let spec_path = self.endpoint_path().join("spec.json");
std::fs::write(spec_path, serde_json::to_string_pretty(&spec)?)?;
@@ -819,7 +813,6 @@ impl Endpoint {
self.http_address.ip(),
self.http_address.port()
))
.header(CONTENT_TYPE.as_str(), "application/json")
.body(format!(
"{{\"spec\":{}}}",
serde_json::to_string_pretty(&spec)?

View File

@@ -415,11 +415,6 @@ impl PageServerNode {
.map(|x| x.parse::<bool>())
.transpose()
.context("Failed to parse 'timeline_offloading' as bool")?,
wal_receiver_protocol_override: settings
.remove("wal_receiver_protocol_override")
.map(serde_json::from_str)
.transpose()
.context("parse `wal_receiver_protocol_override` from json")?,
};
if !settings.is_empty() {
bail!("Unrecognized tenant settings: {settings:?}")

View File

@@ -5,7 +5,6 @@
//! ```text
//! .neon/safekeepers/<safekeeper id>
//! ```
use std::error::Error as _;
use std::future::Future;
use std::io::Write;
use std::path::PathBuf;
@@ -27,7 +26,7 @@ use crate::{
#[derive(Error, Debug)]
pub enum SafekeeperHttpError {
#[error("request error: {0}{}", .0.source().map(|e| format!(": {e}")).unwrap_or_default())]
#[error("Reqwest error: {0}")]
Transport(#[from] reqwest::Error),
#[error("Error: {0}")]

View File

@@ -560,26 +560,14 @@ async fn main() -> anyhow::Result<()> {
.await?;
}
Command::TenantDescribe { tenant_id } => {
let TenantDescribeResponse {
tenant_id,
shards,
stripe_size,
policy,
config,
} = storcon_client
let describe_response = storcon_client
.dispatch::<(), TenantDescribeResponse>(
Method::GET,
format!("control/v1/tenant/{tenant_id}"),
None,
)
.await?;
println!("Tenant {tenant_id}");
let mut table = comfy_table::Table::new();
table.add_row(["Policy", &format!("{:?}", policy)]);
table.add_row(["Stripe size", &format!("{:?}", stripe_size)]);
table.add_row(["Config", &serde_json::to_string_pretty(&config).unwrap()]);
println!("{table}");
println!("Shards:");
let shards = describe_response.shards;
let mut table = comfy_table::Table::new();
table.set_header(["Shard", "Attached", "Secondary", "Last error", "status"]);
for shard in shards {

View File

@@ -33,6 +33,7 @@ reason = "the marvin attack only affects private key decryption, not public key
[licenses]
allow = [
"Apache-2.0",
"Artistic-2.0",
"BSD-2-Clause",
"BSD-3-Clause",
"CC0-1.0",
@@ -66,7 +67,7 @@ registries = []
# More documentation about the 'bans' section can be found here:
# https://embarkstudios.github.io/cargo-deny/checks/bans/cfg.html
[bans]
multiple-versions = "allow"
multiple-versions = "warn"
wildcards = "allow"
highlight = "all"
workspace-default-features = "allow"

View File

@@ -4,16 +4,14 @@ ARG TAG=latest
FROM $REPOSITORY/${COMPUTE_IMAGE}:$TAG
ARG COMPUTE_IMAGE
USER root
RUN apt-get update && \
apt-get install -y curl \
jq \
python3-pip \
netcat-openbsd
netcat
#Faker is required for the pg_anon test
RUN case $COMPUTE_IMAGE in compute-node-v17) OPT="--break-system-packages";; *) OPT= ;; esac && pip3 install $OPT Faker
RUN pip3 install Faker
#This is required for the pg_hintplan test
RUN mkdir -p /ext-src/pg_hint_plan-src && chown postgres /ext-src/pg_hint_plan-src

View File

@@ -30,17 +30,10 @@ cleanup() {
docker compose --profile test-extensions -f $COMPOSE_FILE down
}
for pg_version in ${TEST_VERSION_ONLY-14 15 16 17}; do
pg_version=${pg_version/v/}
for pg_version in 14 15 16; do
echo "clean up containers if exists"
cleanup
PG_TEST_VERSION=$((pg_version < 16 ? 16 : pg_version))
# Support for pg_anon has not yet been added to PG17, so we have to remove the corresponding option
if [ $pg_version -eq 17 ]; then
SPEC_PATH="compute_wrapper/var/db/postgres/specs"
mv $SPEC_PATH/spec.json $SPEC_PATH/spec.bak
jq 'del(.cluster.settings[] | select (.name == "session_preload_libraries"))' $SPEC_PATH/spec.bak > $SPEC_PATH/spec.json
fi
PG_TEST_VERSION=$(($pg_version < 16 ? 16 : $pg_version))
PG_VERSION=$pg_version PG_TEST_VERSION=$PG_TEST_VERSION docker compose --profile test-extensions -f $COMPOSE_FILE up --build -d
echo "wait until the compute is ready. timeout after 60s. "
@@ -61,7 +54,8 @@ for pg_version in ${TEST_VERSION_ONLY-14 15 16 17}; do
fi
done
if [ $pg_version -ge 16 ]; then
if [ $pg_version -ge 16 ]
then
echo Enabling trust connection
docker exec $COMPUTE_CONTAINER_NAME bash -c "sed -i '\$d' /var/db/postgres/compute/pg_hba.conf && echo -e 'host\t all\t all\t all\t trust' >> /var/db/postgres/compute/pg_hba.conf && psql $PSQL_OPTION -c 'select pg_reload_conf()' "
echo Adding postgres role
@@ -74,13 +68,10 @@ for pg_version in ${TEST_VERSION_ONLY-14 15 16 17}; do
# The test assumes that it is running on the same host with the postgres engine.
# In our case it's not true, that's why we are copying files to the compute node
TMPDIR=$(mktemp -d)
# Add support for pg_anon for pg_v16
if [ $pg_version -ne 17 ]; then
docker cp $TEST_CONTAINER_NAME:/ext-src/pg_anon-src/data $TMPDIR/data
echo -e '1\t too \t many \t tabs' > $TMPDIR/data/bad.csv
docker cp $TMPDIR/data $COMPUTE_CONTAINER_NAME:/tmp/tmp_anon_alternate_data
docker cp $TEST_CONTAINER_NAME:/ext-src/pg_anon-src/data $TMPDIR/data
echo -e '1\t too \t many \t tabs' > $TMPDIR/data/bad.csv
docker cp $TMPDIR/data $COMPUTE_CONTAINER_NAME:/tmp/tmp_anon_alternate_data
rm -rf $TMPDIR
fi
TMPDIR=$(mktemp -d)
# The following block does the same for the pg_hintplan test
docker cp $TEST_CONTAINER_NAME:/ext-src/pg_hint_plan-src/data $TMPDIR/data
@@ -106,8 +97,4 @@ for pg_version in ${TEST_VERSION_ONLY-14 15 16 17}; do
fi
fi
cleanup
# Support for pg_anon has not yet been added to PG17, so we have to remove the corresponding option
if [ $pg_version -eq 17 ]; then
mv $SPEC_PATH/spec.bak $SPEC_PATH/spec.json
fi
done

View File

@@ -19,10 +19,6 @@ pub type PgIdent = String;
/// String type alias representing Postgres extension version
pub type ExtVersion = String;
fn default_reconfigure_concurrency() -> usize {
1
}
/// Cluster spec or configuration represented as an optional number of
/// delta operations + final cluster state description.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
@@ -71,7 +67,7 @@ pub struct ComputeSpec {
pub cluster: Cluster,
pub delta_operations: Option<Vec<DeltaOp>>,
/// An optional hint that can be passed to speed up startup time if we know
/// An optinal hint that can be passed to speed up startup time if we know
/// that no pg catalog mutations (like role creation, database creation,
/// extension creation) need to be done on the actual database to start.
#[serde(default)] // Default false
@@ -90,7 +86,9 @@ pub struct ComputeSpec {
// etc. GUCs in cluster.settings. TODO: Once the control plane has been
// updated to fill these fields, we can make these non optional.
pub tenant_id: Option<TenantId>,
pub timeline_id: Option<TimelineId>,
pub pageserver_connstring: Option<String>,
#[serde(default)]
@@ -115,20 +113,6 @@ pub struct ComputeSpec {
/// Local Proxy configuration used for JWT authentication
#[serde(default)]
pub local_proxy_config: Option<LocalProxySpec>,
/// Number of concurrent connections during the parallel RunInEachDatabase
/// phase of the apply config process.
///
/// We need higher concurrency during reconfiguration in case of many DBs, but the
/// instance is already running and used by clients, so we can easily hit the
/// `max_connections` limit, and the current code won't handle that.
///
/// The default is 1, but we also allow the control plane to override this value for
/// specific projects. It's also recommended to bump `superuser_reserved_connections` +=
/// `reconfigure_concurrency` for such projects to ensure that we always have
/// enough spare connections for the reconfiguration process to succeed.
#[serde(default = "default_reconfigure_concurrency")]
pub reconfigure_concurrency: usize,
}
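The `#[serde(default = "...")]` attribute is what keeps older specs that omit the field deserializable; a standalone sketch of the same mechanism, using a simplified stand-in struct rather than the real `ComputeSpec`:

use serde::Deserialize;

fn default_reconfigure_concurrency() -> usize {
    1
}

#[derive(Deserialize)]
struct SpecSketch {
    #[serde(default = "default_reconfigure_concurrency")]
    reconfigure_concurrency: usize,
}

fn main() {
    // Field absent in the JSON: serde falls back to the default fn, so we get 1.
    let spec: SpecSketch = serde_json::from_str("{}").unwrap();
    assert_eq!(spec.reconfigure_concurrency, 1);
    // Field present: the provided value wins.
    let spec: SpecSketch = serde_json::from_str(r#"{"reconfigure_concurrency": 4}"#).unwrap();
    assert_eq!(spec.reconfigure_concurrency, 4);
}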
/// Feature flag to signal `compute_ctl` to enable certain experimental functionality.
@@ -331,9 +315,6 @@ mod tests {
// Features list defaults to empty vector.
assert!(spec.features.is_empty());
// Reconfigure concurrency defaults to 1.
assert_eq!(spec.reconfigure_concurrency, 1);
}
#[test]

View File

@@ -103,12 +103,11 @@ impl<'a> IdempotencyKey<'a> {
}
}
/// Split into chunks of 1000 metrics to avoid exceeding the max request size.
pub const CHUNK_SIZE: usize = 1000;
// Just a wrapper around a slice of events,
// to serialize it as `{ "events": [ ... ] }`
#[derive(Debug, serde::Serialize, serde::Deserialize, PartialEq)]
pub struct EventChunk<'a, T: Clone + PartialEq> {
#[derive(serde::Serialize, Deserialize)]
pub struct EventChunk<'a, T: Clone> {
pub events: std::borrow::Cow<'a, [T]>,
}
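For illustration, how the wrapper and `CHUNK_SIZE` combine on the sending side; a sketch, not the actual consumption-metrics client code:

use std::borrow::Cow;

// Serializing one chunk yields the `{"events": [...]}` shape the comment above describes.
fn chunks_to_json<T: Clone + PartialEq + serde::Serialize>(events: &[T]) -> Vec<String> {
    events
        .chunks(CHUNK_SIZE) // stay under the max request size
        .map(|chunk| {
            let body = EventChunk { events: Cow::Borrowed(chunk) };
            serde_json::to_string(&body).expect("serializable events")
        })
        .collect()
}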

View File

@@ -2,28 +2,14 @@
// This module has heavy inspiration from the prometheus crate's `process_collector.rs`.
use once_cell::sync::Lazy;
use prometheus::Gauge;
use crate::UIntGauge;
pub struct Collector {
descs: Vec<prometheus::core::Desc>,
vmlck: crate::UIntGauge,
cpu_seconds_highres: Gauge,
}
const NMETRICS: usize = 2;
static CLK_TCK_F64: Lazy<f64> = Lazy::new(|| {
let long = unsafe { libc::sysconf(libc::_SC_CLK_TCK) };
if long == -1 {
panic!("sysconf(_SC_CLK_TCK) failed");
}
let convertible_to_f64: i32 =
i32::try_from(long).expect("sysconf(_SC_CLK_TCK) is larger than i32");
convertible_to_f64 as f64
});
const NMETRICS: usize = 1;
impl prometheus::core::Collector for Collector {
fn desc(&self) -> Vec<&prometheus::core::Desc> {
@@ -41,12 +27,6 @@ impl prometheus::core::Collector for Collector {
mfs.extend(self.vmlck.collect())
}
}
if let Ok(stat) = myself.stat() {
let cpu_seconds = stat.utime + stat.stime;
self.cpu_seconds_highres
.set(cpu_seconds as f64 / *CLK_TCK_F64);
mfs.extend(self.cpu_seconds_highres.collect());
}
mfs
}
}
@@ -63,23 +43,7 @@ impl Collector {
.cloned(),
);
let cpu_seconds_highres = Gauge::new(
"libmetrics_process_cpu_seconds_highres",
"Total user and system CPU time spent in seconds.\
Sub-second resolution, hence better than `process_cpu_seconds_total`.",
)
.unwrap();
descs.extend(
prometheus::core::Collector::desc(&cpu_seconds_highres)
.into_iter()
.cloned(),
);
Self {
descs,
vmlck,
cpu_seconds_highres,
}
Self { descs, vmlck }
}
}
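The removed gauge converted scheduler ticks from `/proc/self/stat` into seconds; the arithmetic as a standalone sketch:

// utime + stime are in clock ticks; dividing by sysconf(_SC_CLK_TCK) yields seconds.
fn cpu_seconds(utime_ticks: u64, stime_ticks: u64, clk_tck: f64) -> f64 {
    (utime_ticks + stime_ticks) as f64 / clk_tck
}

// With the common CLK_TCK of 100, e.g. 12_345 ticks -> 123.45 s, which is the
// sub-second resolution the removed help text advertises over `process_cpu_seconds_total`.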

View File

@@ -18,7 +18,7 @@ use std::{
str::FromStr,
time::Duration,
};
use utils::{logging::LogFormat, postgres_client::PostgresClientProtocol};
use utils::logging::LogFormat;
use crate::models::ImageCompressionAlgorithm;
use crate::models::LsnLease;
@@ -118,8 +118,8 @@ pub struct ConfigToml {
pub virtual_file_io_mode: Option<crate::models::virtual_file::IoMode>,
#[serde(skip_serializing_if = "Option::is_none")]
pub no_sync: Option<bool>,
pub wal_receiver_protocol: PostgresClientProtocol,
pub page_service_pipelining: PageServicePipeliningConfig,
#[serde(with = "humantime_serde")]
pub server_side_batch_timeout: Option<Duration>,
}
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
@@ -136,28 +136,6 @@ pub struct DiskUsageEvictionTaskConfig {
pub eviction_order: EvictionOrder,
}
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "mode", rename_all = "kebab-case")]
#[serde(deny_unknown_fields)]
pub enum PageServicePipeliningConfig {
Serial,
Pipelined(PageServicePipeliningConfigPipelined),
}
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(deny_unknown_fields)]
pub struct PageServicePipeliningConfigPipelined {
/// Causes runtime errors if larger than max get_vectored batch size.
pub max_batch_size: NonZeroUsize,
pub execution: PageServiceProtocolPipelinedExecutionStrategy,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum PageServiceProtocolPipelinedExecutionStrategy {
ConcurrentFutures,
Tasks,
}
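With `tag = "mode"` and kebab-case renaming, the removed enum reads from the pageserver TOML roughly as sketched below (a hedged example of the wire format these serde attributes imply, not an excerpt from a real config file):

fn pipelining_config_example() {
    // `mode = "serial"` selects the unit variant.
    let serial: PageServicePipeliningConfig = toml::from_str(r#"mode = "serial""#).unwrap();
    assert!(matches!(serial, PageServicePipeliningConfig::Serial));

    // The pipelined variant carries its fields in the same table.
    let pipelined: PageServicePipeliningConfig = toml::from_str(
        r#"
        mode = "pipelined"
        max_batch_size = 32
        execution = "concurrent-futures"
        "#,
    )
    .unwrap();
    assert!(matches!(pipelined, PageServicePipeliningConfig::Pipelined(_)));
}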
pub mod statvfs {
pub mod mock {
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
@@ -299,8 +277,6 @@ pub struct TenantConfigToml {
/// Enable auto-offloading of timelines.
/// (either this flag or the pageserver-global one need to be set)
pub timeline_offloading: bool,
pub wal_receiver_protocol_override: Option<PostgresClientProtocol>,
}
pub mod defaults {
@@ -353,8 +329,7 @@ pub mod defaults {
pub const DEFAULT_IO_BUFFER_ALIGNMENT: usize = 512;
pub const DEFAULT_WAL_RECEIVER_PROTOCOL: utils::postgres_client::PostgresClientProtocol =
utils::postgres_client::PostgresClientProtocol::Vanilla;
pub const DEFAULT_SERVER_SIDE_BATCH_TIMEOUT: Option<&str> = None;
}
impl Default for ConfigToml {
@@ -439,17 +414,10 @@ impl Default for ConfigToml {
ephemeral_bytes_per_memory_kb: (DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB),
l0_flush: None,
virtual_file_io_mode: None,
server_side_batch_timeout: DEFAULT_SERVER_SIDE_BATCH_TIMEOUT
.map(|duration| humantime::parse_duration(duration).unwrap()),
tenant_config: TenantConfigToml::default(),
no_sync: None,
wal_receiver_protocol: DEFAULT_WAL_RECEIVER_PROTOCOL,
page_service_pipelining: if !cfg!(test) {
PageServicePipeliningConfig::Serial
} else {
PageServicePipeliningConfig::Pipelined(PageServicePipeliningConfigPipelined {
max_batch_size: NonZeroUsize::new(32).unwrap(),
execution: PageServiceProtocolPipelinedExecutionStrategy::ConcurrentFutures,
})
},
}
}
}
@@ -537,7 +505,6 @@ impl Default for TenantConfigToml {
lsn_lease_length: LsnLease::DEFAULT_LENGTH,
lsn_lease_length_for_ts: LsnLease::DEFAULT_LENGTH_FOR_TS,
timeline_offloading: false,
wal_receiver_protocol_override: None,
}
}
}

View File

@@ -48,7 +48,7 @@ pub struct TenantCreateResponse {
pub shards: Vec<TenantCreateResponseShard>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct NodeRegisterRequest {
pub node_id: NodeId,
@@ -75,7 +75,7 @@ pub struct TenantPolicyRequest {
pub scheduling: Option<ShardSchedulingPolicy>,
}
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Debug)]
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct AvailabilityZone(pub String);
impl Display for AvailabilityZone {

View File

@@ -229,18 +229,6 @@ impl Key {
}
}
impl CompactKey {
pub fn raw(&self) -> i128 {
self.0
}
}
impl From<i128> for CompactKey {
fn from(value: i128) -> Self {
Self(value)
}
}
impl fmt::Display for Key {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
@@ -770,11 +758,6 @@ impl Key {
&& self.field6 == 1
}
#[inline(always)]
pub fn is_aux_file_key(&self) -> bool {
self.field1 == AUX_KEY_PREFIX
}
/// Guaranteed to return `Ok()` if [`Self::is_rel_block_key`] returns `true` for `key`.
#[inline(always)]
pub fn to_rel_block(self) -> anyhow::Result<(RelTag, BlockNumber)> {

View File

@@ -23,7 +23,6 @@ use utils::{
completion,
id::{NodeId, TenantId, TimelineId},
lsn::Lsn,
postgres_client::PostgresClientProtocol,
serde_system_time,
};
@@ -353,7 +352,6 @@ pub struct TenantConfig {
pub lsn_lease_length: Option<String>,
pub lsn_lease_length_for_ts: Option<String>,
pub timeline_offloading: Option<bool>,
pub wal_receiver_protocol_override: Option<PostgresClientProtocol>,
}
/// The policy for the aux file storage.
@@ -501,9 +499,7 @@ pub struct EvictionPolicyLayerAccessThreshold {
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct ThrottleConfig {
/// See [`ThrottleConfigTaskKinds`] for why we do the serde `rename`.
#[serde(rename = "task_kinds")]
pub enabled: ThrottleConfigTaskKinds,
pub task_kinds: Vec<String>, // TaskKind
pub initial: u32,
#[serde(with = "humantime_serde")]
pub refill_interval: Duration,
@@ -511,38 +507,10 @@ pub struct ThrottleConfig {
pub max: u32,
}
/// Before <https://github.com/neondatabase/neon/pull/9962>
/// the throttle was applied per `Timeline::get`/`Timeline::get_vectored` call.
/// The `task_kinds` field controlled which Pageserver "Task Kind"s
/// were subject to the throttle.
///
/// After that PR, the throttle is applied at the pagestream request level,
/// and the `task_kinds` field no longer applies, since the only task kind
/// that is subject to the throttle is that of the page service.
///
/// However, we don't want to make a breaking config change right now
/// because it means we have to migrate all the tenant configs.
/// This will be done in a future PR.
///
/// In the meantime, we use emptiness / non-emptiness of the `task_kinds`
/// field to determine whether the throttle is enabled or not.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(transparent)]
pub struct ThrottleConfigTaskKinds(Vec<String>);
impl ThrottleConfigTaskKinds {
pub fn disabled() -> Self {
Self(vec![])
}
pub fn is_enabled(&self) -> bool {
!self.0.is_empty()
}
}
impl ThrottleConfig {
pub fn disabled() -> Self {
Self {
enabled: ThrottleConfigTaskKinds::disabled(),
task_kinds: vec![], // effectively disables the throttle
// other values don't matter with empty `task_kinds`.
initial: 0,
refill_interval: Duration::from_millis(1),
@@ -556,30 +524,6 @@ impl ThrottleConfig {
}
}
#[cfg(test)]
mod throttle_config_tests {
use super::*;
#[test]
fn test_disabled_is_disabled() {
let config = ThrottleConfig::disabled();
assert!(!config.enabled.is_enabled());
}
#[test]
fn test_enabled_backwards_compat() {
let input = serde_json::json!({
"task_kinds": ["PageRequestHandler"],
"initial": 40000,
"refill_interval": "50ms",
"refill_amount": 1000,
"max": 40000,
"fair": true
});
let config: ThrottleConfig = serde_json::from_value(input).unwrap();
assert!(config.enabled.is_enabled());
}
}
/// A flattened analog of a `pagesever::tenant::LocationMode`, which
/// lists out all possible states (and the virtual "Detached" state)
/// in a flat form rather than using rust-style enums.

View File

@@ -158,8 +158,7 @@ impl ShardIdentity {
key_to_shard_number(self.count, self.stripe_size, key)
}
/// Return true if the key is stored only on this shard. This does not include
/// global keys, see is_key_global().
/// Return true if the key should be ingested by this shard
///
/// Shards must ingest _at least_ keys which return true from this check.
pub fn is_key_local(&self, key: &Key) -> bool {
@@ -171,37 +170,19 @@ impl ShardIdentity {
}
}
/// Return true if the key should be stored on all shards, not just one.
pub fn is_key_global(&self, key: &Key) -> bool {
if key.is_slru_block_key() || key.is_slru_segment_size_key() || key.is_aux_file_key() {
// Special keys that are only stored on shard 0
false
} else if key.is_rel_block_key() {
// Ordinary relation blocks are distributed across shards
false
} else if key.is_rel_size_key() {
// All shards maintain rel size keys (although only shard 0 is responsible for
// keeping it strictly accurate, other shards just reflect the highest block they've ingested)
true
} else {
// For everything else, we assume it must be kept everywhere, because ingest code
// might assume this -- this covers functionality where the ingest code has
// not (yet) been made fully shard aware.
true
}
}
/// Return true if the key should be discarded if found in this shard's
/// data store, e.g. during compaction after a split.
///
/// Shards _may_ drop keys which return false here, but are not obliged to.
pub fn is_key_disposable(&self, key: &Key) -> bool {
if self.count < ShardCount(2) {
// Fast path: unsharded tenant doesn't dispose of anything
return false;
}
if self.is_key_global(key) {
if key_is_shard0(key) {
// Q: Why can't we dispose of shard0 content if we're not shard 0?
// A1: because the WAL ingestion logic currently ingests some shard 0
// content on all shards, even though it's only read on shard 0. If we
// dropped it, then subsequent WAL ingest to these keys would encounter
// an error.
// A2: because key_is_shard0 also covers relation size keys, which are written
// on all shards even though they're only maintained accurately on shard 0.
false
} else {
!self.is_key_local(key)

View File

@@ -100,7 +100,7 @@ impl StartupMessageParamsBuilder {
#[derive(Debug, Clone, Default)]
pub struct StartupMessageParams {
pub params: Bytes,
params: Bytes,
}
impl StartupMessageParams {
@@ -562,11 +562,6 @@ pub enum BeMessage<'a> {
options: &'a [&'a str],
},
KeepAlive(WalSndKeepAlive),
/// Batch of interpreted, shard filtered WAL records,
/// ready for the pageserver to ingest
InterpretedWalRecords(InterpretedWalRecordsBody<'a>),
Raw(u8, &'a [u8]),
}
/// Common shorthands.
@@ -677,22 +672,6 @@ pub struct WalSndKeepAlive {
pub request_reply: bool,
}
/// Batch of interpreted WAL records used in the interpreted
/// safekeeper to pageserver protocol.
///
/// Note that the pageserver uses the RawInterpretedWalRecordsBody
/// counterpart of this from the neondatabase/rust-postgres repo.
/// If you're changing this struct, you likely need to change its
/// twin as well.
#[derive(Debug)]
pub struct InterpretedWalRecordsBody<'a> {
/// End of raw WAL in [`Self::data`]
pub streaming_lsn: u64,
/// Current end of WAL on the server
pub commit_lsn: u64,
pub data: &'a [u8],
}
pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]);
// single text column
@@ -756,10 +735,6 @@ impl BeMessage<'_> {
/// one more buffer.
pub fn write(buf: &mut BytesMut, message: &BeMessage) -> Result<(), ProtocolError> {
match message {
BeMessage::Raw(code, data) => {
buf.put_u8(*code);
write_body(buf, |b| b.put_slice(data))
}
BeMessage::AuthenticationOk => {
buf.put_u8(b'R');
write_body(buf, |buf| {
@@ -1021,19 +996,6 @@ impl BeMessage<'_> {
Ok(())
})?
}
BeMessage::InterpretedWalRecords(rec) => {
// We use the COPY_DATA_TAG for our custom message
// since this tag is interpreted as raw bytes.
buf.put_u8(b'd');
write_body(buf, |buf| {
buf.put_u8(b'0'); // matches INTERPRETED_WAL_RECORD_TAG in postgres-protocol
// dependency
buf.put_u64(rec.streaming_lsn);
buf.put_u64(rec.commit_lsn);
buf.put_slice(rec.data);
});
}
}
Ok(())
}
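The removed writer packs the batch into a CopyData ('d') message; a hedged sketch of what a reader on the other side would do with the body, after the usual 4-byte length frame has already been stripped and assuming the `'0'` sub-tag shown above:

use bytes::{Buf, Bytes};

struct InterpretedBatch {
    streaming_lsn: u64,
    commit_lsn: u64,
    data: Bytes,
}

// `body` is the payload of a CopyData message produced by the writer above.
fn parse_interpreted_batch(mut body: Bytes) -> Option<InterpretedBatch> {
    if body.remaining() < 1 + 8 + 8 || body.get_u8() != b'0' {
        return None; // not an interpreted-WAL-records message
    }
    let streaming_lsn = body.get_u64(); // big-endian, matching put_u64 on the write side
    let commit_lsn = body.get_u64();
    Some(InterpretedBatch { streaming_lsn, commit_lsn, data: body })
}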

View File

@@ -1,6 +0,0 @@
This directory contains libraries that are specific to the proxy.
Currently, it contains a significant fork/refactoring of rust-postgres that no longer reflects the API
of the original library. Since it was so significant, it made sense to promote it to its own set of libraries.
The proxy needs unique access to the protocol, which explains why such heavy modifications were necessary.

View File

@@ -1,20 +0,0 @@
[package]
name = "postgres-protocol2"
version = "0.1.0"
edition = "2018"
license = "MIT/Apache-2.0"
[dependencies]
base64 = "0.20"
byteorder.workspace = true
bytes.workspace = true
fallible-iterator.workspace = true
hmac.workspace = true
memchr = "2.0"
rand.workspace = true
sha2.workspace = true
stringprep = "0.1"
tokio = { workspace = true, features = ["rt"] }
[dev-dependencies]
tokio = { workspace = true, features = ["full"] }

View File

@@ -1,2 +0,0 @@
//! Authentication protocol support.
pub mod sasl;

View File

@@ -1,516 +0,0 @@
//! SASL-based authentication support.
use hmac::{Hmac, Mac};
use rand::{self, Rng};
use sha2::digest::FixedOutput;
use sha2::{Digest, Sha256};
use std::fmt::Write;
use std::io;
use std::iter;
use std::mem;
use std::str;
use tokio::task::yield_now;
const NONCE_LENGTH: usize = 24;
/// The identifier of the SCRAM-SHA-256 SASL authentication mechanism.
pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
/// The identifier of the SCRAM-SHA-256-PLUS SASL authentication mechanism.
pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
// since postgres passwords are not required to exclude saslprep-prohibited
// characters or even be valid UTF8, we run saslprep if possible and otherwise
// return the raw password.
fn normalize(pass: &[u8]) -> Vec<u8> {
let pass = match str::from_utf8(pass) {
Ok(pass) => pass,
Err(_) => return pass.to_vec(),
};
match stringprep::saslprep(pass) {
Ok(pass) => pass.into_owned().into_bytes(),
Err(_) => pass.as_bytes().to_vec(),
}
}
pub(crate) async fn hi(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
let mut hmac =
Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
hmac.update(salt);
hmac.update(&[0, 0, 0, 1]);
let mut prev = hmac.finalize().into_bytes();
let mut hi = prev;
for i in 1..iterations {
let mut hmac = Hmac::<Sha256>::new_from_slice(str).expect("already checked above");
hmac.update(&prev);
prev = hmac.finalize().into_bytes();
for (hi, prev) in hi.iter_mut().zip(prev) {
*hi ^= prev;
}
// yield every ~250us
// hopefully reduces tail latencies
if i % 1024 == 0 {
yield_now().await
}
}
hi.into()
}
enum ChannelBindingInner {
Unrequested,
Unsupported,
TlsServerEndPoint(Vec<u8>),
}
/// The channel binding configuration for a SCRAM authentication exchange.
pub struct ChannelBinding(ChannelBindingInner);
impl ChannelBinding {
/// The server did not request channel binding.
pub fn unrequested() -> ChannelBinding {
ChannelBinding(ChannelBindingInner::Unrequested)
}
/// The server requested channel binding but the client is unable to provide it.
pub fn unsupported() -> ChannelBinding {
ChannelBinding(ChannelBindingInner::Unsupported)
}
/// The server requested channel binding and the client will use the `tls-server-end-point`
/// method.
pub fn tls_server_end_point(signature: Vec<u8>) -> ChannelBinding {
ChannelBinding(ChannelBindingInner::TlsServerEndPoint(signature))
}
fn gs2_header(&self) -> &'static str {
match self.0 {
ChannelBindingInner::Unrequested => "y,,",
ChannelBindingInner::Unsupported => "n,,",
ChannelBindingInner::TlsServerEndPoint(_) => "p=tls-server-end-point,,",
}
}
fn cbind_data(&self) -> &[u8] {
match self.0 {
ChannelBindingInner::Unrequested | ChannelBindingInner::Unsupported => &[],
ChannelBindingInner::TlsServerEndPoint(ref buf) => buf,
}
}
}
/// A pair of keys for the SCRAM-SHA-256 mechanism.
/// See <https://datatracker.ietf.org/doc/html/rfc5802#section-3> for details.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ScramKeys<const N: usize> {
/// Used by server to authenticate client.
pub client_key: [u8; N],
/// Used by client to verify server's signature.
pub server_key: [u8; N],
}
/// Password or keys which were derived from it.
enum Credentials<const N: usize> {
/// A regular password as a vector of bytes.
Password(Vec<u8>),
/// A precomputed pair of keys.
Keys(ScramKeys<N>),
}
enum State {
Update {
nonce: String,
password: Credentials<32>,
channel_binding: ChannelBinding,
},
Finish {
server_key: [u8; 32],
auth_message: String,
},
Done,
}
/// A type which handles the client side of the SCRAM-SHA-256/SCRAM-SHA-256-PLUS authentication
/// process.
///
/// During the authentication process, if the backend sends an `AuthenticationSASL` message which
/// includes `SCRAM-SHA-256` as an authentication mechanism, this type can be used.
///
/// After a `ScramSha256` is constructed, the buffer returned by the `message()` method should be
/// sent to the backend in a `SASLInitialResponse` message along with the mechanism name.
///
/// The server will reply with an `AuthenticationSASLContinue` message. Its contents should be
/// passed to the `update()` method, after which the buffer returned by the `message()` method
/// should be sent to the backend in a `SASLResponse` message.
///
/// The server will reply with an `AuthenticationSASLFinal` message. Its contents should be passed
/// to the `finish()` method, after which the authentication process is complete.
pub struct ScramSha256 {
message: String,
state: State,
}
fn nonce() -> String {
// rand 0.5's ThreadRng is cryptographically secure
let mut rng = rand::thread_rng();
(0..NONCE_LENGTH)
.map(|_| {
let mut v = rng.gen_range(0x21u8..0x7e);
if v == 0x2c {
v = 0x7e
}
v as char
})
.collect()
}
impl ScramSha256 {
/// Constructs a new instance which will use the provided password for authentication.
pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 {
let password = Credentials::Password(normalize(password));
ScramSha256::new_inner(password, channel_binding, nonce())
}
/// Constructs a new instance which will use the provided key pair for authentication.
pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 {
let password = Credentials::Keys(keys);
ScramSha256::new_inner(password, channel_binding, nonce())
}
fn new_inner(
password: Credentials<32>,
channel_binding: ChannelBinding,
nonce: String,
) -> ScramSha256 {
ScramSha256 {
message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce),
state: State::Update {
nonce,
password,
channel_binding,
},
}
}
/// Returns the message which should be sent to the backend in an `SASLResponse` message.
pub fn message(&self) -> &[u8] {
if let State::Done = self.state {
panic!("invalid SCRAM state");
}
self.message.as_bytes()
}
/// Updates the state machine with the response from the backend.
///
/// This should be called when an `AuthenticationSASLContinue` message is received.
pub async fn update(&mut self, message: &[u8]) -> io::Result<()> {
let (client_nonce, password, channel_binding) =
match mem::replace(&mut self.state, State::Done) {
State::Update {
nonce,
password,
channel_binding,
} => (nonce, password, channel_binding),
_ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
};
let message =
str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
let parsed = Parser::new(message).server_first_message()?;
if !parsed.nonce.starts_with(&client_nonce) {
return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce"));
}
let (client_key, server_key) = match password {
Credentials::Password(password) => {
let salt = match base64::decode(parsed.salt) {
Ok(salt) => salt,
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
};
let salted_password = hi(&password, &salt, parsed.iteration_count).await;
let make_key = |name| {
let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
.expect("HMAC is able to accept all key sizes");
hmac.update(name);
let mut key = [0u8; 32];
key.copy_from_slice(hmac.finalize().into_bytes().as_slice());
key
};
(make_key(b"Client Key"), make_key(b"Server Key"))
}
Credentials::Keys(keys) => (keys.client_key, keys.server_key),
};
let mut hash = Sha256::default();
hash.update(client_key);
let stored_key = hash.finalize_fixed();
let mut cbind_input = vec![];
cbind_input.extend(channel_binding.gs2_header().as_bytes());
cbind_input.extend(channel_binding.cbind_data());
let cbind_input = base64::encode(&cbind_input);
self.message.clear();
write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap();
let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message);
let mut hmac = Hmac::<Sha256>::new_from_slice(&stored_key)
.expect("HMAC is able to accept all key sizes");
hmac.update(auth_message.as_bytes());
let client_signature = hmac.finalize().into_bytes();
let mut client_proof = client_key;
for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
*proof ^= signature;
}
write!(&mut self.message, ",p={}", base64::encode(client_proof)).unwrap();
self.state = State::Finish {
server_key,
auth_message,
};
Ok(())
}
/// Finalizes the authentication process.
///
/// This should be called when the backend sends an `AuthenticationSASLFinal` message.
/// Authentication has only succeeded if this method returns `Ok(())`.
pub fn finish(&mut self, message: &[u8]) -> io::Result<()> {
let (server_key, auth_message) = match mem::replace(&mut self.state, State::Done) {
State::Finish {
server_key,
auth_message,
} => (server_key, auth_message),
_ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")),
};
let message =
str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
let parsed = Parser::new(message).server_final_message()?;
let verifier = match parsed {
ServerFinalMessage::Error(e) => {
return Err(io::Error::new(
io::ErrorKind::Other,
format!("SCRAM error: {}", e),
));
}
ServerFinalMessage::Verifier(verifier) => verifier,
};
let verifier = match base64::decode(verifier) {
Ok(verifier) => verifier,
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
};
let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key)
.expect("HMAC is able to accept all key sizes");
hmac.update(auth_message.as_bytes());
hmac.verify_slice(&verifier)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error"))
}
}
struct Parser<'a> {
s: &'a str,
it: iter::Peekable<str::CharIndices<'a>>,
}
impl<'a> Parser<'a> {
fn new(s: &'a str) -> Parser<'a> {
Parser {
s,
it: s.char_indices().peekable(),
}
}
fn eat(&mut self, target: char) -> io::Result<()> {
match self.it.next() {
Some((_, c)) if c == target => Ok(()),
Some((i, c)) => {
let m = format!(
"unexpected character at byte {}: expected `{}` but got `{}",
i, target, c
);
Err(io::Error::new(io::ErrorKind::InvalidInput, m))
}
None => Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected EOF",
)),
}
}
fn take_while<F>(&mut self, f: F) -> io::Result<&'a str>
where
F: Fn(char) -> bool,
{
let start = match self.it.peek() {
Some(&(i, _)) => i,
None => return Ok(""),
};
loop {
match self.it.peek() {
Some(&(_, c)) if f(c) => {
self.it.next();
}
Some(&(i, _)) => return Ok(&self.s[start..i]),
None => return Ok(&self.s[start..]),
}
}
}
fn printable(&mut self) -> io::Result<&'a str> {
self.take_while(|c| matches!(c, '\x21'..='\x2b' | '\x2d'..='\x7e'))
}
fn nonce(&mut self) -> io::Result<&'a str> {
self.eat('r')?;
self.eat('=')?;
self.printable()
}
fn base64(&mut self) -> io::Result<&'a str> {
self.take_while(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '/' | '+' | '='))
}
fn salt(&mut self) -> io::Result<&'a str> {
self.eat('s')?;
self.eat('=')?;
self.base64()
}
fn posit_number(&mut self) -> io::Result<u32> {
let n = self.take_while(|c| c.is_ascii_digit())?;
n.parse()
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
}
fn iteration_count(&mut self) -> io::Result<u32> {
self.eat('i')?;
self.eat('=')?;
self.posit_number()
}
fn eof(&mut self) -> io::Result<()> {
match self.it.peek() {
Some(&(i, _)) => Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("unexpected trailing data at byte {}", i),
)),
None => Ok(()),
}
}
fn server_first_message(&mut self) -> io::Result<ServerFirstMessage<'a>> {
let nonce = self.nonce()?;
self.eat(',')?;
let salt = self.salt()?;
self.eat(',')?;
let iteration_count = self.iteration_count()?;
self.eof()?;
Ok(ServerFirstMessage {
nonce,
salt,
iteration_count,
})
}
fn value(&mut self) -> io::Result<&'a str> {
self.take_while(|c| matches!(c, '\0' | '=' | ','))
}
fn server_error(&mut self) -> io::Result<Option<&'a str>> {
match self.it.peek() {
Some(&(_, 'e')) => {}
_ => return Ok(None),
}
self.eat('e')?;
self.eat('=')?;
self.value().map(Some)
}
fn verifier(&mut self) -> io::Result<&'a str> {
self.eat('v')?;
self.eat('=')?;
self.base64()
}
fn server_final_message(&mut self) -> io::Result<ServerFinalMessage<'a>> {
let message = match self.server_error()? {
Some(error) => ServerFinalMessage::Error(error),
None => ServerFinalMessage::Verifier(self.verifier()?),
};
self.eof()?;
Ok(message)
}
}
struct ServerFirstMessage<'a> {
nonce: &'a str,
salt: &'a str,
iteration_count: u32,
}
enum ServerFinalMessage<'a> {
Error(&'a str),
Verifier(&'a str),
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn parse_server_first_message() {
let message = "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096";
let message = Parser::new(message).server_first_message().unwrap();
assert_eq!(message.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
assert_eq!(message.salt, "QSXCR+Q6sek8bf92");
assert_eq!(message.iteration_count, 4096);
}
// recorded auth exchange from psql
#[tokio::test]
async fn exchange() {
let password = "foobar";
let nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB";
let client_first = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB";
let server_first =
"r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i\
=4096";
let client_final =
"c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS3\
1NTlQYNs5BTeQjdHdk7lOflDo5re2an8=";
let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";
let mut scram = ScramSha256::new_inner(
Credentials::Password(normalize(password.as_bytes())),
ChannelBinding::unsupported(),
nonce.to_string(),
);
assert_eq!(str::from_utf8(scram.message()).unwrap(), client_first);
scram.update(server_first.as_bytes()).await.unwrap();
assert_eq!(str::from_utf8(scram.message()).unwrap(), client_final);
scram.finish(server_final.as_bytes()).unwrap();
}
}
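The doc comment above describes the call order; a compact usage sketch of that client-side flow in this module's context (the server payloads come from the AuthenticationSASL* messages and are elided here):

async fn scram_client_flow(
    password: &[u8],
    server_first: &[u8],
    server_final: &[u8],
) -> std::io::Result<Vec<u8>> {
    // 1. Send message() in SASLInitialResponse together with the SCRAM-SHA-256 mechanism name.
    let mut scram = ScramSha256::new(password, ChannelBinding::unsupported());
    let client_first = scram.message().to_vec();

    // 2. Feed the AuthenticationSASLContinue payload in, then send message() as SASLResponse.
    scram.update(server_first).await?;
    let _client_final = scram.message().to_vec();

    // 3. Verify the AuthenticationSASLFinal payload; Ok(()) means the server is authentic.
    scram.finish(server_final)?;
    Ok(client_first)
}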

View File

@@ -1,93 +0,0 @@
//! Provides functions for escaping literals and identifiers for use
//! in SQL queries.
//!
//! Prefer parameterized queries where possible. Do not escape
//! parameters in a parameterized query.
#[cfg(test)]
mod test;
/// Escape a literal and surround result with single quotes. Not
/// recommended in most cases.
///
/// If input contains backslashes, result will be of the form `
/// E'...'` so it is safe to use regardless of the setting of
/// standard_conforming_strings.
pub fn escape_literal(input: &str) -> String {
escape_internal(input, false)
}
/// Escape an identifier and surround result with double quotes.
pub fn escape_identifier(input: &str) -> String {
escape_internal(input, true)
}
// Translation of PostgreSQL libpq's PQescapeInternal(). Does not
// require a connection because input string is known to be valid
// UTF-8.
//
// Escape arbitrary strings. If as_ident is true, we escape the
// result as an identifier; if false, as a literal. The result is
// returned in a newly allocated String. Unlike libpq, this cannot
// fail: the input is already known to be valid UTF-8.
fn escape_internal(input: &str, as_ident: bool) -> String {
let mut num_backslashes = 0;
let mut num_quotes = 0;
let quote_char = if as_ident { '"' } else { '\'' };
// Scan the string for characters that must be escaped.
for ch in input.chars() {
if ch == quote_char {
num_quotes += 1;
} else if ch == '\\' {
num_backslashes += 1;
}
}
// Allocate output String.
let mut result_size = input.len() + num_quotes + 3; // two quotes, plus a NUL
if !as_ident && num_backslashes > 0 {
result_size += num_backslashes + 2;
}
let mut output = String::with_capacity(result_size);
// If we are escaping a literal that contains backslashes, we use
// the escape string syntax so that the result is correct under
// either value of standard_conforming_strings. We also emit a
// leading space in this case, to guard against the possibility
// that the result might be interpolated immediately following an
// identifier.
if !as_ident && num_backslashes > 0 {
output.push(' ');
output.push('E');
}
// Opening quote.
output.push(quote_char);
// Use fast path if possible.
//
// We've already verified that the input string is well-formed in
// the current encoding. If it contains no quotes and, in the
// case of literal-escaping, no backslashes, then we can just copy
// it directly to the output buffer, adding the necessary quotes.
//
// If not, we must rescan the input and process each character
// individually.
if num_quotes == 0 && (num_backslashes == 0 || as_ident) {
output.push_str(input);
} else {
for ch in input.chars() {
if ch == quote_char || (!as_ident && ch == '\\') {
output.push(ch);
}
output.push(ch);
}
}
output.push(quote_char);
output
}
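A rough usage sketch, assuming the crate path used in this tree (postgres_protocol2::escape); the table name and comment are invented. Identifiers cannot be bound as query parameters, so dynamic DDL is the main case where these helpers are justified.
// Sketch: escape a caller-supplied identifier and literal for dynamic DDL.
use postgres_protocol2::escape::{escape_identifier, escape_literal};

fn create_table_sql(table: &str, comment: &str) -> String {
    format!(
        "CREATE TABLE {t} (id BIGINT PRIMARY KEY); COMMENT ON TABLE {t} IS {c};",
        t = escape_identifier(table),
        c = escape_literal(comment),
    )
}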

View File

@@ -1,17 +0,0 @@
use crate::escape::{escape_identifier, escape_literal};
#[test]
fn test_escape_identifier() {
assert_eq!(escape_identifier("foo"), String::from("\"foo\""));
assert_eq!(escape_identifier("f\\oo"), String::from("\"f\\oo\""));
assert_eq!(escape_identifier("f'oo"), String::from("\"f'oo\""));
assert_eq!(escape_identifier("f\"oo"), String::from("\"f\"\"oo\""));
}
#[test]
fn test_escape_literal() {
assert_eq!(escape_literal("foo"), String::from("'foo'"));
assert_eq!(escape_literal("f\\oo"), String::from(" E'f\\\\oo'"));
assert_eq!(escape_literal("f'oo"), String::from("'f''oo'"));
assert_eq!(escape_literal("f\"oo"), String::from("'f\"oo'"));
}

View File

@@ -1,78 +0,0 @@
//! Low level Postgres protocol APIs.
//!
//! This crate implements the low level components of Postgres's communication
//! protocol, including message and value serialization and deserialization.
//! It is designed to be used as a building block by higher level APIs such as
//! `rust-postgres`, and should not typically be used directly.
//!
//! # Note
//!
//! This library assumes that the `client_encoding` backend parameter has been
//! set to `UTF8`. It will most likely not behave properly if that is not the case.
#![doc(html_root_url = "https://docs.rs/postgres-protocol/0.6")]
#![warn(missing_docs, rust_2018_idioms, clippy::all)]
use byteorder::{BigEndian, ByteOrder};
use bytes::{BufMut, BytesMut};
use std::io;
pub mod authentication;
pub mod escape;
pub mod message;
pub mod password;
pub mod types;
/// A Postgres OID.
pub type Oid = u32;
/// A Postgres Log Sequence Number (LSN).
pub type Lsn = u64;
/// An enum indicating if a value is `NULL` or not.
pub enum IsNull {
/// The value is `NULL`.
Yes,
/// The value is not `NULL`.
No,
}
fn write_nullable<F, E>(serializer: F, buf: &mut BytesMut) -> Result<(), E>
where
F: FnOnce(&mut BytesMut) -> Result<IsNull, E>,
E: From<io::Error>,
{
let base = buf.len();
buf.put_i32(0);
let size = match serializer(buf)? {
IsNull::No => i32::from_usize(buf.len() - base - 4)?,
IsNull::Yes => -1,
};
BigEndian::write_i32(&mut buf[base..], size);
Ok(())
}
trait FromUsize: Sized {
fn from_usize(x: usize) -> Result<Self, io::Error>;
}
macro_rules! from_usize {
($t:ty) => {
impl FromUsize for $t {
#[inline]
fn from_usize(x: usize) -> io::Result<$t> {
if x > <$t>::MAX as usize {
Err(io::Error::new(
io::ErrorKind::InvalidInput,
"value too large to transmit",
))
} else {
Ok(x as $t)
}
}
}
};
}
from_usize!(i16);
from_usize!(i32);
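write_nullable and FromUsize together implement the usual length-prefix framing: reserve a 4-byte slot, serialize the value, then backfill the byte count, with -1 standing in for NULL. Since write_nullable is crate-private, the sketch below only restates the pattern with the same bytes/byteorder crates; it is illustrative, not part of the crate's API.
// Standalone sketch of the framing pattern used by write_nullable above.
use byteorder::{BigEndian, ByteOrder};
use bytes::{BufMut, BytesMut};

fn put_nullable_text(value: Option<&str>, buf: &mut BytesMut) {
    let base = buf.len();
    buf.put_i32(0); // placeholder for the length
    let size = match value {
        Some(v) => {
            buf.put_slice(v.as_bytes());
            (buf.len() - base - 4) as i32
        }
        None => -1, // NULL is encoded as length -1 with no payload
    };
    BigEndian::write_i32(&mut buf[base..], size);
}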

View File

@@ -1,766 +0,0 @@
#![allow(missing_docs)]
use byteorder::{BigEndian, ByteOrder, ReadBytesExt};
use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use memchr::memchr;
use std::cmp;
use std::io::{self, Read};
use std::ops::Range;
use std::str;
use crate::Oid;
// top-level message tags
const PARSE_COMPLETE_TAG: u8 = b'1';
const BIND_COMPLETE_TAG: u8 = b'2';
const CLOSE_COMPLETE_TAG: u8 = b'3';
pub const NOTIFICATION_RESPONSE_TAG: u8 = b'A';
const COPY_DONE_TAG: u8 = b'c';
const COMMAND_COMPLETE_TAG: u8 = b'C';
const COPY_DATA_TAG: u8 = b'd';
const DATA_ROW_TAG: u8 = b'D';
const ERROR_RESPONSE_TAG: u8 = b'E';
const COPY_IN_RESPONSE_TAG: u8 = b'G';
const COPY_OUT_RESPONSE_TAG: u8 = b'H';
const COPY_BOTH_RESPONSE_TAG: u8 = b'W';
const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I';
const BACKEND_KEY_DATA_TAG: u8 = b'K';
pub const NO_DATA_TAG: u8 = b'n';
pub const NOTICE_RESPONSE_TAG: u8 = b'N';
const AUTHENTICATION_TAG: u8 = b'R';
const PORTAL_SUSPENDED_TAG: u8 = b's';
pub const PARAMETER_STATUS_TAG: u8 = b'S';
const PARAMETER_DESCRIPTION_TAG: u8 = b't';
const ROW_DESCRIPTION_TAG: u8 = b'T';
pub const READY_FOR_QUERY_TAG: u8 = b'Z';
#[derive(Debug, Copy, Clone)]
pub struct Header {
tag: u8,
len: i32,
}
#[allow(clippy::len_without_is_empty)]
impl Header {
#[inline]
pub fn parse(buf: &[u8]) -> io::Result<Option<Header>> {
if buf.len() < 5 {
return Ok(None);
}
let tag = buf[0];
let len = BigEndian::read_i32(&buf[1..]);
if len < 4 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"invalid message length: header length < 4",
));
}
Ok(Some(Header { tag, len }))
}
#[inline]
pub fn tag(self) -> u8 {
self.tag
}
#[inline]
pub fn len(self) -> i32 {
self.len
}
}
/// An enum representing Postgres backend messages.
#[non_exhaustive]
pub enum Message {
AuthenticationCleartextPassword,
AuthenticationGss,
AuthenticationKerberosV5,
AuthenticationMd5Password,
AuthenticationOk,
AuthenticationScmCredential,
AuthenticationSspi,
AuthenticationGssContinue,
AuthenticationSasl(AuthenticationSaslBody),
AuthenticationSaslContinue(AuthenticationSaslContinueBody),
AuthenticationSaslFinal(AuthenticationSaslFinalBody),
BackendKeyData(BackendKeyDataBody),
BindComplete,
CloseComplete,
CommandComplete(CommandCompleteBody),
CopyData,
CopyDone,
CopyInResponse,
CopyOutResponse,
CopyBothResponse,
DataRow(DataRowBody),
EmptyQueryResponse,
ErrorResponse(ErrorResponseBody),
NoData,
NoticeResponse(NoticeResponseBody),
NotificationResponse(NotificationResponseBody),
ParameterDescription(ParameterDescriptionBody),
ParameterStatus(ParameterStatusBody),
ParseComplete,
PortalSuspended,
ReadyForQuery(ReadyForQueryBody),
RowDescription(RowDescriptionBody),
}
impl Message {
#[inline]
pub fn parse(buf: &mut BytesMut) -> io::Result<Option<Message>> {
if buf.len() < 5 {
let to_read = 5 - buf.len();
buf.reserve(to_read);
return Ok(None);
}
let tag = buf[0];
let len = (&buf[1..5]).read_u32::<BigEndian>().unwrap();
if len < 4 {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length: parsing u32",
));
}
let total_len = len as usize + 1;
if buf.len() < total_len {
let to_read = total_len - buf.len();
buf.reserve(to_read);
return Ok(None);
}
let mut buf = Buffer {
bytes: buf.split_to(total_len).freeze(),
idx: 5,
};
let message = match tag {
PARSE_COMPLETE_TAG => Message::ParseComplete,
BIND_COMPLETE_TAG => Message::BindComplete,
CLOSE_COMPLETE_TAG => Message::CloseComplete,
NOTIFICATION_RESPONSE_TAG => {
let process_id = buf.read_i32::<BigEndian>()?;
let channel = buf.read_cstr()?;
let message = buf.read_cstr()?;
Message::NotificationResponse(NotificationResponseBody {
process_id,
channel,
message,
})
}
COPY_DONE_TAG => Message::CopyDone,
COMMAND_COMPLETE_TAG => {
let tag = buf.read_cstr()?;
Message::CommandComplete(CommandCompleteBody { tag })
}
COPY_DATA_TAG => Message::CopyData,
DATA_ROW_TAG => {
let len = buf.read_u16::<BigEndian>()?;
let storage = buf.read_all();
Message::DataRow(DataRowBody { storage, len })
}
ERROR_RESPONSE_TAG => {
let storage = buf.read_all();
Message::ErrorResponse(ErrorResponseBody { storage })
}
COPY_IN_RESPONSE_TAG => Message::CopyInResponse,
COPY_OUT_RESPONSE_TAG => Message::CopyOutResponse,
COPY_BOTH_RESPONSE_TAG => Message::CopyBothResponse,
EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse,
BACKEND_KEY_DATA_TAG => {
let process_id = buf.read_i32::<BigEndian>()?;
let secret_key = buf.read_i32::<BigEndian>()?;
Message::BackendKeyData(BackendKeyDataBody {
process_id,
secret_key,
})
}
NO_DATA_TAG => Message::NoData,
NOTICE_RESPONSE_TAG => {
let storage = buf.read_all();
Message::NoticeResponse(NoticeResponseBody { storage })
}
AUTHENTICATION_TAG => match buf.read_i32::<BigEndian>()? {
0 => Message::AuthenticationOk,
2 => Message::AuthenticationKerberosV5,
3 => Message::AuthenticationCleartextPassword,
5 => Message::AuthenticationMd5Password,
6 => Message::AuthenticationScmCredential,
7 => Message::AuthenticationGss,
8 => Message::AuthenticationGssContinue,
9 => Message::AuthenticationSspi,
10 => {
let storage = buf.read_all();
Message::AuthenticationSasl(AuthenticationSaslBody(storage))
}
11 => {
let storage = buf.read_all();
Message::AuthenticationSaslContinue(AuthenticationSaslContinueBody(storage))
}
12 => {
let storage = buf.read_all();
Message::AuthenticationSaslFinal(AuthenticationSaslFinalBody(storage))
}
tag => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("unknown authentication tag `{}`", tag),
));
}
},
PORTAL_SUSPENDED_TAG => Message::PortalSuspended,
PARAMETER_STATUS_TAG => {
let name = buf.read_cstr()?;
let value = buf.read_cstr()?;
Message::ParameterStatus(ParameterStatusBody { name, value })
}
PARAMETER_DESCRIPTION_TAG => {
let len = buf.read_u16::<BigEndian>()?;
let storage = buf.read_all();
Message::ParameterDescription(ParameterDescriptionBody { storage, len })
}
ROW_DESCRIPTION_TAG => {
let len = buf.read_u16::<BigEndian>()?;
let storage = buf.read_all();
Message::RowDescription(RowDescriptionBody { storage, len })
}
READY_FOR_QUERY_TAG => {
let status = buf.read_u8()?;
Message::ReadyForQuery(ReadyForQueryBody { status })
}
tag => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("unknown message tag `{}`", tag),
));
}
};
if !buf.is_empty() {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length: expected buffer to be empty",
));
}
Ok(Some(message))
}
}
struct Buffer {
bytes: Bytes,
idx: usize,
}
impl Buffer {
#[inline]
fn slice(&self) -> &[u8] {
&self.bytes[self.idx..]
}
#[inline]
fn is_empty(&self) -> bool {
self.slice().is_empty()
}
#[inline]
fn read_cstr(&mut self) -> io::Result<Bytes> {
match memchr(0, self.slice()) {
Some(pos) => {
let start = self.idx;
let end = start + pos;
let cstr = self.bytes.slice(start..end);
self.idx = end + 1;
Ok(cstr)
}
None => Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected EOF",
)),
}
}
#[inline]
fn read_all(&mut self) -> Bytes {
let buf = self.bytes.slice(self.idx..);
self.idx = self.bytes.len();
buf
}
}
impl Read for Buffer {
#[inline]
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let len = {
let slice = self.slice();
let len = cmp::min(slice.len(), buf.len());
buf[..len].copy_from_slice(&slice[..len]);
len
};
self.idx += len;
Ok(len)
}
}
pub struct AuthenticationMd5PasswordBody {
salt: [u8; 4],
}
impl AuthenticationMd5PasswordBody {
#[inline]
pub fn salt(&self) -> [u8; 4] {
self.salt
}
}
pub struct AuthenticationSaslBody(Bytes);
impl AuthenticationSaslBody {
#[inline]
pub fn mechanisms(&self) -> SaslMechanisms<'_> {
SaslMechanisms(&self.0)
}
}
pub struct SaslMechanisms<'a>(&'a [u8]);
impl<'a> FallibleIterator for SaslMechanisms<'a> {
type Item = &'a str;
type Error = io::Error;
#[inline]
fn next(&mut self) -> io::Result<Option<&'a str>> {
let value_end = find_null(self.0, 0)?;
if value_end == 0 {
if self.0.len() != 1 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"invalid message length: expected to be at end of iterator for sasl",
));
}
Ok(None)
} else {
let value = get_str(&self.0[..value_end])?;
self.0 = &self.0[value_end + 1..];
Ok(Some(value))
}
}
}
pub struct AuthenticationSaslContinueBody(Bytes);
impl AuthenticationSaslContinueBody {
#[inline]
pub fn data(&self) -> &[u8] {
&self.0
}
}
pub struct AuthenticationSaslFinalBody(Bytes);
impl AuthenticationSaslFinalBody {
#[inline]
pub fn data(&self) -> &[u8] {
&self.0
}
}
pub struct BackendKeyDataBody {
process_id: i32,
secret_key: i32,
}
impl BackendKeyDataBody {
#[inline]
pub fn process_id(&self) -> i32 {
self.process_id
}
#[inline]
pub fn secret_key(&self) -> i32 {
self.secret_key
}
}
pub struct CommandCompleteBody {
tag: Bytes,
}
impl CommandCompleteBody {
#[inline]
pub fn tag(&self) -> io::Result<&str> {
get_str(&self.tag)
}
}
#[derive(Debug)]
pub struct DataRowBody {
storage: Bytes,
len: u16,
}
impl DataRowBody {
#[inline]
pub fn ranges(&self) -> DataRowRanges<'_> {
DataRowRanges {
buf: &self.storage,
len: self.storage.len(),
remaining: self.len,
}
}
#[inline]
pub fn buffer(&self) -> &[u8] {
&self.storage
}
}
pub struct DataRowRanges<'a> {
buf: &'a [u8],
len: usize,
remaining: u16,
}
impl FallibleIterator for DataRowRanges<'_> {
type Item = Option<Range<usize>>;
type Error = io::Error;
#[inline]
fn next(&mut self) -> io::Result<Option<Option<Range<usize>>>> {
if self.remaining == 0 {
if self.buf.is_empty() {
return Ok(None);
} else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length: datarowrange is not empty",
));
}
}
self.remaining -= 1;
let len = self.buf.read_i32::<BigEndian>()?;
if len < 0 {
Ok(Some(None))
} else {
let len = len as usize;
if self.buf.len() < len {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected EOF",
));
}
let base = self.len - self.buf.len();
self.buf = &self.buf[len..];
Ok(Some(Some(base..base + len)))
}
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.remaining as usize;
(len, Some(len))
}
}
pub struct ErrorResponseBody {
storage: Bytes,
}
impl ErrorResponseBody {
#[inline]
pub fn fields(&self) -> ErrorFields<'_> {
ErrorFields { buf: &self.storage }
}
}
pub struct ErrorFields<'a> {
buf: &'a [u8],
}
impl<'a> FallibleIterator for ErrorFields<'a> {
type Item = ErrorField<'a>;
type Error = io::Error;
#[inline]
fn next(&mut self) -> io::Result<Option<ErrorField<'a>>> {
let type_ = self.buf.read_u8()?;
if type_ == 0 {
if self.buf.is_empty() {
return Ok(None);
} else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length: error fields is not drained",
));
}
}
let value_end = find_null(self.buf, 0)?;
let value = get_str(&self.buf[..value_end])?;
self.buf = &self.buf[value_end + 1..];
Ok(Some(ErrorField { type_, value }))
}
}
pub struct ErrorField<'a> {
type_: u8,
value: &'a str,
}
impl ErrorField<'_> {
#[inline]
pub fn type_(&self) -> u8 {
self.type_
}
#[inline]
pub fn value(&self) -> &str {
self.value
}
}
pub struct NoticeResponseBody {
storage: Bytes,
}
impl NoticeResponseBody {
#[inline]
pub fn fields(&self) -> ErrorFields<'_> {
ErrorFields { buf: &self.storage }
}
pub fn as_bytes(&self) -> &[u8] {
&self.storage
}
}
pub struct NotificationResponseBody {
process_id: i32,
channel: Bytes,
message: Bytes,
}
impl NotificationResponseBody {
#[inline]
pub fn process_id(&self) -> i32 {
self.process_id
}
#[inline]
pub fn channel(&self) -> io::Result<&str> {
get_str(&self.channel)
}
#[inline]
pub fn message(&self) -> io::Result<&str> {
get_str(&self.message)
}
}
pub struct ParameterDescriptionBody {
storage: Bytes,
len: u16,
}
impl ParameterDescriptionBody {
#[inline]
pub fn parameters(&self) -> Parameters<'_> {
Parameters {
buf: &self.storage,
remaining: self.len,
}
}
}
pub struct Parameters<'a> {
buf: &'a [u8],
remaining: u16,
}
impl FallibleIterator for Parameters<'_> {
type Item = Oid;
type Error = io::Error;
#[inline]
fn next(&mut self) -> io::Result<Option<Oid>> {
if self.remaining == 0 {
if self.buf.is_empty() {
return Ok(None);
} else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length: parameters is not drained",
));
}
}
self.remaining -= 1;
self.buf.read_u32::<BigEndian>().map(Some)
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.remaining as usize;
(len, Some(len))
}
}
pub struct ParameterStatusBody {
name: Bytes,
value: Bytes,
}
impl ParameterStatusBody {
#[inline]
pub fn name(&self) -> io::Result<&str> {
get_str(&self.name)
}
#[inline]
pub fn value(&self) -> io::Result<&str> {
get_str(&self.value)
}
}
pub struct ReadyForQueryBody {
status: u8,
}
impl ReadyForQueryBody {
#[inline]
pub fn status(&self) -> u8 {
self.status
}
}
pub struct RowDescriptionBody {
storage: Bytes,
len: u16,
}
impl RowDescriptionBody {
#[inline]
pub fn fields(&self) -> Fields<'_> {
Fields {
buf: &self.storage,
remaining: self.len,
}
}
}
pub struct Fields<'a> {
buf: &'a [u8],
remaining: u16,
}
impl<'a> FallibleIterator for Fields<'a> {
type Item = Field<'a>;
type Error = io::Error;
#[inline]
fn next(&mut self) -> io::Result<Option<Field<'a>>> {
if self.remaining == 0 {
if self.buf.is_empty() {
return Ok(None);
} else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid message length: field is not drained",
));
}
}
self.remaining -= 1;
let name_end = find_null(self.buf, 0)?;
let name = get_str(&self.buf[..name_end])?;
self.buf = &self.buf[name_end + 1..];
let table_oid = self.buf.read_u32::<BigEndian>()?;
let column_id = self.buf.read_i16::<BigEndian>()?;
let type_oid = self.buf.read_u32::<BigEndian>()?;
let type_size = self.buf.read_i16::<BigEndian>()?;
let type_modifier = self.buf.read_i32::<BigEndian>()?;
let format = self.buf.read_i16::<BigEndian>()?;
Ok(Some(Field {
name,
table_oid,
column_id,
type_oid,
type_size,
type_modifier,
format,
}))
}
}
pub struct Field<'a> {
name: &'a str,
table_oid: Oid,
column_id: i16,
type_oid: Oid,
type_size: i16,
type_modifier: i32,
format: i16,
}
impl<'a> Field<'a> {
#[inline]
pub fn name(&self) -> &'a str {
self.name
}
#[inline]
pub fn table_oid(&self) -> Oid {
self.table_oid
}
#[inline]
pub fn column_id(&self) -> i16 {
self.column_id
}
#[inline]
pub fn type_oid(&self) -> Oid {
self.type_oid
}
#[inline]
pub fn type_size(&self) -> i16 {
self.type_size
}
#[inline]
pub fn type_modifier(&self) -> i32 {
self.type_modifier
}
#[inline]
pub fn format(&self) -> i16 {
self.format
}
}
#[inline]
fn find_null(buf: &[u8], start: usize) -> io::Result<usize> {
match memchr(0, &buf[start..]) {
Some(pos) => Ok(pos + start),
None => Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected EOF",
)),
}
}
#[inline]
fn get_str(buf: &[u8]) -> io::Result<&str> {
str::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
}
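A minimal sketch of driving Message::parse from the outside, assuming the crate path in this tree (postgres_protocol2::message::backend). The bytes are a hand-built ReadyForQuery message: tag b'Z', a length of 5 covering the length field plus one status byte, then b'I' for idle. With fewer bytes in the buffer the call returns Ok(None) and reserves room for the rest.
use bytes::{BufMut, BytesMut};
use postgres_protocol2::message::backend::Message;

fn read_ready_for_query() -> std::io::Result<()> {
    let mut buf = BytesMut::new();
    buf.put_u8(b'Z'); // ReadyForQuery tag
    buf.put_i32(5);   // length: 4-byte length field + 1 status byte
    buf.put_u8(b'I'); // transaction status: idle
    match Message::parse(&mut buf)? {
        Some(Message::ReadyForQuery(body)) => assert_eq!(body.status(), b'I'),
        _ => unreachable!("buffer held exactly one complete message"),
    }
    Ok(())
}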

View File

@@ -1,309 +0,0 @@
//! Frontend message serialization.
#![allow(missing_docs)]
use byteorder::{BigEndian, ByteOrder};
use bytes::{Buf, BufMut, BytesMut};
use std::convert::TryFrom;
use std::error::Error;
use std::io;
use std::marker;
use crate::{write_nullable, FromUsize, IsNull, Oid};
#[inline]
fn write_body<F, E>(buf: &mut BytesMut, f: F) -> Result<(), E>
where
F: FnOnce(&mut BytesMut) -> Result<(), E>,
E: From<io::Error>,
{
let base = buf.len();
buf.extend_from_slice(&[0; 4]);
f(buf)?;
let size = i32::from_usize(buf.len() - base)?;
BigEndian::write_i32(&mut buf[base..], size);
Ok(())
}
pub enum BindError {
Conversion(Box<dyn Error + marker::Sync + Send>),
Serialization(io::Error),
}
impl From<Box<dyn Error + marker::Sync + Send>> for BindError {
#[inline]
fn from(e: Box<dyn Error + marker::Sync + Send>) -> BindError {
BindError::Conversion(e)
}
}
impl From<io::Error> for BindError {
#[inline]
fn from(e: io::Error) -> BindError {
BindError::Serialization(e)
}
}
#[inline]
pub fn bind<I, J, F, T, K>(
portal: &str,
statement: &str,
formats: I,
values: J,
mut serializer: F,
result_formats: K,
buf: &mut BytesMut,
) -> Result<(), BindError>
where
I: IntoIterator<Item = i16>,
J: IntoIterator<Item = T>,
F: FnMut(T, &mut BytesMut) -> Result<IsNull, Box<dyn Error + marker::Sync + Send>>,
K: IntoIterator<Item = i16>,
{
buf.put_u8(b'B');
write_body(buf, |buf| {
write_cstr(portal.as_bytes(), buf)?;
write_cstr(statement.as_bytes(), buf)?;
write_counted(
formats,
|f, buf| {
buf.put_i16(f);
Ok::<_, io::Error>(())
},
buf,
)?;
write_counted(
values,
|v, buf| write_nullable(|buf| serializer(v, buf), buf),
buf,
)?;
write_counted(
result_formats,
|f, buf| {
buf.put_i16(f);
Ok::<_, io::Error>(())
},
buf,
)?;
Ok(())
})
}
#[inline]
fn write_counted<I, T, F, E>(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), E>
where
I: IntoIterator<Item = T>,
F: FnMut(T, &mut BytesMut) -> Result<(), E>,
E: From<io::Error>,
{
let base = buf.len();
buf.extend_from_slice(&[0; 2]);
let mut count = 0;
for item in items {
serializer(item, buf)?;
count += 1;
}
let count = i16::from_usize(count)?;
BigEndian::write_i16(&mut buf[base..], count);
Ok(())
}
#[inline]
pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut BytesMut) {
write_body(buf, |buf| {
buf.put_i32(80_877_102);
buf.put_i32(process_id);
buf.put_i32(secret_key);
Ok::<_, io::Error>(())
})
.unwrap();
}
#[inline]
pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'C');
write_body(buf, |buf| {
buf.put_u8(variant);
write_cstr(name.as_bytes(), buf)
})
}
pub struct CopyData<T> {
buf: T,
len: i32,
}
impl<T> CopyData<T>
where
T: Buf,
{
pub fn new(buf: T) -> io::Result<CopyData<T>> {
let len = buf
.remaining()
.checked_add(4)
.and_then(|l| i32::try_from(l).ok())
.ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "message length overflow")
})?;
Ok(CopyData { buf, len })
}
pub fn write(self, out: &mut BytesMut) {
out.put_u8(b'd');
out.put_i32(self.len);
out.put(self.buf);
}
}
#[inline]
pub fn copy_done(buf: &mut BytesMut) {
buf.put_u8(b'c');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}
#[inline]
pub fn copy_fail(message: &str, buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'f');
write_body(buf, |buf| write_cstr(message.as_bytes(), buf))
}
#[inline]
pub fn describe(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'D');
write_body(buf, |buf| {
buf.put_u8(variant);
write_cstr(name.as_bytes(), buf)
})
}
#[inline]
pub fn execute(portal: &str, max_rows: i32, buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'E');
write_body(buf, |buf| {
write_cstr(portal.as_bytes(), buf)?;
buf.put_i32(max_rows);
Ok(())
})
}
#[inline]
pub fn parse<I>(name: &str, query: &str, param_types: I, buf: &mut BytesMut) -> io::Result<()>
where
I: IntoIterator<Item = Oid>,
{
buf.put_u8(b'P');
write_body(buf, |buf| {
write_cstr(name.as_bytes(), buf)?;
write_cstr(query.as_bytes(), buf)?;
write_counted(
param_types,
|t, buf| {
buf.put_u32(t);
Ok::<_, io::Error>(())
},
buf,
)?;
Ok(())
})
}
#[inline]
pub fn password_message(password: &[u8], buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'p');
write_body(buf, |buf| write_cstr(password, buf))
}
#[inline]
pub fn query(query: &str, buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'Q');
write_body(buf, |buf| write_cstr(query.as_bytes(), buf))
}
#[inline]
pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'p');
write_body(buf, |buf| {
write_cstr(mechanism.as_bytes(), buf)?;
let len = i32::from_usize(data.len())?;
buf.put_i32(len);
buf.put_slice(data);
Ok(())
})
}
#[inline]
pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
buf.put_u8(b'p');
write_body(buf, |buf| {
buf.put_slice(data);
Ok(())
})
}
#[inline]
pub fn ssl_request(buf: &mut BytesMut) {
write_body(buf, |buf| {
buf.put_i32(80_877_103);
Ok::<_, io::Error>(())
})
.unwrap();
}
#[inline]
pub fn startup_message(parameters: &StartupMessageParams, buf: &mut BytesMut) -> io::Result<()> {
write_body(buf, |buf| {
// postgres protocol version 3.0 (196608), in big-endian
buf.put_i32(0x00_03_00_00);
buf.put_slice(&parameters.params);
buf.put_u8(0);
Ok(())
})
}
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StartupMessageParams {
pub params: BytesMut,
}
impl StartupMessageParams {
/// Set parameter's value by its name.
pub fn insert(&mut self, name: &str, value: &str) {
if name.contains('\0') || value.contains('\0') {
panic!("startup parameter name or value contained a null")
}
self.params.put_slice(name.as_bytes());
self.params.put_u8(0);
self.params.put_slice(value.as_bytes());
self.params.put_u8(0);
}
}
#[inline]
pub fn sync(buf: &mut BytesMut) {
buf.put_u8(b'S');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}
#[inline]
pub fn terminate(buf: &mut BytesMut) {
buf.put_u8(b'X');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}
#[inline]
fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
if s.contains(&0) {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"string contains embedded null",
));
}
buf.put_slice(s);
buf.put_u8(0);
Ok(())
}
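Every helper here appends one tagged, length-prefixed message to a caller-supplied BytesMut, so composing traffic is just a series of calls into the same buffer. A small sketch, assuming the crate path in this tree; the user and database names are invented.
use bytes::BytesMut;
use postgres_protocol2::message::frontend::{self, StartupMessageParams};

fn startup(buf: &mut BytesMut) -> std::io::Result<()> {
    // Protocol version 3.0 followed by NUL-terminated name/value pairs.
    let mut params = StartupMessageParams::default();
    params.insert("user", "app_user");
    params.insert("database", "app_db");
    frontend::startup_message(&params, buf)
}

fn simple_query(sql: &str, buf: &mut BytesMut) -> std::io::Result<()> {
    // Same framing: tag b'Q', a self-describing length, then the body.
    frontend::query(sql, buf)
}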

View File

@@ -1,8 +0,0 @@
//! Postgres message protocol support.
//!
//! See [Postgres's documentation][docs] for more information on message flow.
//!
//! [docs]: https://www.postgresql.org/docs/9.5/static/protocol-flow.html
pub mod backend;
pub mod frontend;

View File

@@ -1,89 +0,0 @@
//! Functions to encrypt a password in the client.
//!
//! This is intended to be used by client applications that wish to
//! send commands like `ALTER USER joe PASSWORD 'pwd'`. The password
//! need not be sent in cleartext if it is encrypted on the client
//! side. This is good because it ensures the cleartext password won't
//! end up in logs, pg_stat displays, etc.
use crate::authentication::sasl;
use hmac::{Hmac, Mac};
use rand::RngCore;
use sha2::digest::FixedOutput;
use sha2::{Digest, Sha256};
#[cfg(test)]
mod test;
const SCRAM_DEFAULT_ITERATIONS: u32 = 4096;
const SCRAM_DEFAULT_SALT_LEN: usize = 16;
/// Hash password using SCRAM-SHA-256 with a randomly-generated
/// salt.
///
/// The client may assume the returned string doesn't contain any
/// special characters that would require escaping in an SQL command.
pub async fn scram_sha_256(password: &[u8]) -> String {
let mut salt: [u8; SCRAM_DEFAULT_SALT_LEN] = [0; SCRAM_DEFAULT_SALT_LEN];
let mut rng = rand::thread_rng();
rng.fill_bytes(&mut salt);
scram_sha_256_salt(password, salt).await
}
// Internal implementation of scram_sha_256 with a caller-provided
// salt. This is useful for testing.
pub(crate) async fn scram_sha_256_salt(
password: &[u8],
salt: [u8; SCRAM_DEFAULT_SALT_LEN],
) -> String {
// Prepare the password, per [RFC
// 4013](https://tools.ietf.org/html/rfc4013), if possible.
//
// Postgres treats passwords as byte strings (without embedded NUL
// bytes), but SASL expects passwords to be valid UTF-8.
//
// Follow the behavior of libpq's PQencryptPasswordConn(), and
// also the backend. If the password is not valid UTF-8, or if it
// contains prohibited characters (such as non-ASCII whitespace),
// just skip the SASLprep step and use the original byte
// sequence.
let prepared: Vec<u8> = match std::str::from_utf8(password) {
Ok(password_str) => {
match stringprep::saslprep(password_str) {
Ok(p) => p.into_owned().into_bytes(),
// contains invalid characters; skip saslprep
Err(_) => Vec::from(password),
}
}
// not valid UTF-8; skip saslprep
Err(_) => Vec::from(password),
};
// salt password
let salted_password = sasl::hi(&prepared, &salt, SCRAM_DEFAULT_ITERATIONS).await;
// client key
let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
.expect("HMAC is able to accept all key sizes");
hmac.update(b"Client Key");
let client_key = hmac.finalize().into_bytes();
// stored key
let mut hash = Sha256::default();
hash.update(client_key.as_slice());
let stored_key = hash.finalize_fixed();
// server key
let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
.expect("HMAC is able to accept all key sizes");
hmac.update(b"Server Key");
let server_key = hmac.finalize().into_bytes();
format!(
"SCRAM-SHA-256${}:{}${}:{}",
SCRAM_DEFAULT_ITERATIONS,
base64::encode(salt),
base64::encode(stored_key),
base64::encode(server_key)
)
}
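A hedged sketch of the use case the module doc describes: building an ALTER ROLE statement on the client so that only the SCRAM verifier crosses the wire. The role and password are invented; escape_identifier comes from the escape module above, and per the doc comment the verifier itself needs no further quoting inside the single-quoted literal.
use postgres_protocol2::escape::escape_identifier;
use postgres_protocol2::password::scram_sha_256;

async fn alter_role_sql(role: &str, password: &str) -> String {
    // The verifier is plain base64/digits/colons, safe inside the literal.
    let verifier = scram_sha_256(password.as_bytes()).await;
    format!("ALTER ROLE {} PASSWORD '{}'", escape_identifier(role), verifier)
}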

View File

@@ -1,11 +0,0 @@
use crate::password;
#[tokio::test]
async fn test_encrypt_scram_sha_256() {
// Specify the salt to make the test deterministic. Any bytes will do.
let salt: [u8; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
assert_eq!(
password::scram_sha_256_salt(b"secret", salt).await,
"SCRAM-SHA-256$4096:AQIDBAUGBwgJCgsMDQ4PEA==$8rrDg00OqaiWXJ7p+sCgHEIaBSHY89ZJl3mfIsf32oY=:05L1f+yZbiN8O0AnO40Og85NNRhvzTS57naKRWCcsIA="
);
}

View File

@@ -1,294 +0,0 @@
//! Conversions to and from Postgres's binary format for various types.
use byteorder::{BigEndian, ReadBytesExt};
use bytes::{BufMut, BytesMut};
use fallible_iterator::FallibleIterator;
use std::boxed::Box as StdBox;
use std::error::Error;
use std::str;
use crate::Oid;
#[cfg(test)]
mod test;
/// Serializes a `TEXT`, `VARCHAR`, `CHAR(n)`, `NAME`, or `CITEXT` value.
#[inline]
pub fn text_to_sql(v: &str, buf: &mut BytesMut) {
buf.put_slice(v.as_bytes());
}
/// Deserializes a `TEXT`, `VARCHAR`, `CHAR(n)`, `NAME`, or `CITEXT` value.
#[inline]
pub fn text_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
Ok(str::from_utf8(buf)?)
}
/// Deserializes a `"char"` value.
#[inline]
pub fn char_from_sql(mut buf: &[u8]) -> Result<i8, StdBox<dyn Error + Sync + Send>> {
let v = buf.read_i8()?;
if !buf.is_empty() {
return Err("invalid buffer size".into());
}
Ok(v)
}
/// Serializes an `OID` value.
#[inline]
pub fn oid_to_sql(v: Oid, buf: &mut BytesMut) {
buf.put_u32(v);
}
/// Deserializes an `OID` value.
#[inline]
pub fn oid_from_sql(mut buf: &[u8]) -> Result<Oid, StdBox<dyn Error + Sync + Send>> {
let v = buf.read_u32::<BigEndian>()?;
if !buf.is_empty() {
return Err("invalid buffer size".into());
}
Ok(v)
}
/// A fallible iterator over `HSTORE` entries.
pub struct HstoreEntries<'a> {
remaining: i32,
buf: &'a [u8],
}
impl<'a> FallibleIterator for HstoreEntries<'a> {
type Item = (&'a str, Option<&'a str>);
type Error = StdBox<dyn Error + Sync + Send>;
#[inline]
#[allow(clippy::type_complexity)]
fn next(
&mut self,
) -> Result<Option<(&'a str, Option<&'a str>)>, StdBox<dyn Error + Sync + Send>> {
if self.remaining == 0 {
if !self.buf.is_empty() {
return Err("invalid buffer size".into());
}
return Ok(None);
}
self.remaining -= 1;
let key_len = self.buf.read_i32::<BigEndian>()?;
if key_len < 0 {
return Err("invalid key length".into());
}
let (key, buf) = self.buf.split_at(key_len as usize);
let key = str::from_utf8(key)?;
self.buf = buf;
let value_len = self.buf.read_i32::<BigEndian>()?;
let value = if value_len < 0 {
None
} else {
let (value, buf) = self.buf.split_at(value_len as usize);
let value = str::from_utf8(value)?;
self.buf = buf;
Some(value)
};
Ok(Some((key, value)))
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.remaining as usize;
(len, Some(len))
}
}
/// Deserializes an array value.
#[inline]
pub fn array_from_sql(mut buf: &[u8]) -> Result<Array<'_>, StdBox<dyn Error + Sync + Send>> {
let dimensions = buf.read_i32::<BigEndian>()?;
if dimensions < 0 {
return Err("invalid dimension count".into());
}
let mut r = buf;
let mut elements = 1i32;
for _ in 0..dimensions {
let len = r.read_i32::<BigEndian>()?;
if len < 0 {
return Err("invalid dimension size".into());
}
let _lower_bound = r.read_i32::<BigEndian>()?;
elements = match elements.checked_mul(len) {
Some(elements) => elements,
None => return Err("too many array elements".into()),
};
}
if dimensions == 0 {
elements = 0;
}
Ok(Array {
dimensions,
elements,
buf,
})
}
/// A Postgres array.
pub struct Array<'a> {
dimensions: i32,
elements: i32,
buf: &'a [u8],
}
impl<'a> Array<'a> {
/// Returns an iterator over the dimensions of the array.
#[inline]
pub fn dimensions(&self) -> ArrayDimensions<'a> {
ArrayDimensions(&self.buf[..self.dimensions as usize * 8])
}
/// Returns an iterator over the values of the array.
#[inline]
pub fn values(&self) -> ArrayValues<'a> {
ArrayValues {
remaining: self.elements,
buf: &self.buf[self.dimensions as usize * 8..],
}
}
}
/// An iterator over the dimensions of an array.
pub struct ArrayDimensions<'a>(&'a [u8]);
impl FallibleIterator for ArrayDimensions<'_> {
type Item = ArrayDimension;
type Error = StdBox<dyn Error + Sync + Send>;
#[inline]
fn next(&mut self) -> Result<Option<ArrayDimension>, StdBox<dyn Error + Sync + Send>> {
if self.0.is_empty() {
return Ok(None);
}
let len = self.0.read_i32::<BigEndian>()?;
let lower_bound = self.0.read_i32::<BigEndian>()?;
Ok(Some(ArrayDimension { len, lower_bound }))
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.0.len() / 8;
(len, Some(len))
}
}
/// Information about a dimension of an array.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct ArrayDimension {
/// The length of this dimension.
pub len: i32,
/// The base value used to index into this dimension.
pub lower_bound: i32,
}
/// An iterator over the values of an array, in row-major order.
pub struct ArrayValues<'a> {
remaining: i32,
buf: &'a [u8],
}
impl<'a> FallibleIterator for ArrayValues<'a> {
type Item = Option<&'a [u8]>;
type Error = StdBox<dyn Error + Sync + Send>;
#[inline]
fn next(&mut self) -> Result<Option<Option<&'a [u8]>>, StdBox<dyn Error + Sync + Send>> {
if self.remaining == 0 {
if !self.buf.is_empty() {
return Err("invalid message length: arrayvalue not drained".into());
}
return Ok(None);
}
self.remaining -= 1;
let len = self.buf.read_i32::<BigEndian>()?;
let val = if len < 0 {
None
} else {
if self.buf.len() < len as usize {
return Err("invalid value length".into());
}
let (val, buf) = self.buf.split_at(len as usize);
self.buf = buf;
Some(val)
};
Ok(Some(val))
}
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.remaining as usize;
(len, Some(len))
}
}
/// Serializes a Postgres ltree string
#[inline]
pub fn ltree_to_sql(v: &str, buf: &mut BytesMut) {
// A version number is prepended to an ltree string per spec
buf.put_u8(1);
// Append the rest of the query
buf.put_slice(v.as_bytes());
}
/// Deserialize a Postgres ltree string
#[inline]
pub fn ltree_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
match buf {
// Remove the version number from the front of the ltree per spec
[1u8, rest @ ..] => Ok(str::from_utf8(rest)?),
_ => Err("ltree version 1 only supported".into()),
}
}
/// Serializes a Postgres lquery string
#[inline]
pub fn lquery_to_sql(v: &str, buf: &mut BytesMut) {
// A version number is prepended to an lquery string per spec
buf.put_u8(1);
// Append the rest of the query
buf.put_slice(v.as_bytes());
}
/// Deserialize a Postgres lquery string
#[inline]
pub fn lquery_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
match buf {
// Remove the version number from the front of the lquery per spec
[1u8, rest @ ..] => Ok(str::from_utf8(rest)?),
_ => Err("lquery version 1 only supported".into()),
}
}
/// Serializes a Postgres ltxtquery string
#[inline]
pub fn ltxtquery_to_sql(v: &str, buf: &mut BytesMut) {
// A version number is prepended to an ltxtquery string per spec
buf.put_u8(1);
// Append the rest of the query
buf.put_slice(v.as_bytes());
}
/// Deserialize a Postgres ltxtquery string
#[inline]
pub fn ltxtquery_from_sql(buf: &[u8]) -> Result<&str, StdBox<dyn Error + Sync + Send>> {
match buf {
// Remove the version number from the front of the ltxtquery per spec
[1u8, rest @ ..] => Ok(str::from_utf8(rest)?),
_ => Err("ltxtquery version 1 only supported".into()),
}
}
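For the fixed-width values the pattern is symmetric: the *_to_sql helpers append raw big-endian bytes and the *_from_sql helpers reject trailing data. A tiny round-trip sketch for OID values, assuming the crate path used in this tree:
use bytes::BytesMut;
use postgres_protocol2::types::{oid_from_sql, oid_to_sql};

fn roundtrip_oid() -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
    let mut buf = BytesMut::new();
    oid_to_sql(42, &mut buf); // writes exactly 4 big-endian bytes
    assert_eq!(oid_from_sql(&buf)?, 42); // errors if any bytes are left over
    Ok(())
}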

View File

@@ -1,87 +0,0 @@
use bytes::{Buf, BytesMut};
use super::*;
#[test]
fn ltree_sql() {
let mut query = vec![1u8];
query.extend_from_slice("A.B.C".as_bytes());
let mut buf = BytesMut::new();
ltree_to_sql("A.B.C", &mut buf);
assert_eq!(query.as_slice(), buf.chunk());
}
#[test]
fn ltree_str() {
let mut query = vec![1u8];
query.extend_from_slice("A.B.C".as_bytes());
assert!(ltree_from_sql(query.as_slice()).is_ok())
}
#[test]
fn ltree_wrong_version() {
let mut query = vec![2u8];
query.extend_from_slice("A.B.C".as_bytes());
assert!(ltree_from_sql(query.as_slice()).is_err())
}
#[test]
fn lquery_sql() {
let mut query = vec![1u8];
query.extend_from_slice("A.B.C".as_bytes());
let mut buf = BytesMut::new();
lquery_to_sql("A.B.C", &mut buf);
assert_eq!(query.as_slice(), buf.chunk());
}
#[test]
fn lquery_str() {
let mut query = vec![1u8];
query.extend_from_slice("A.B.C".as_bytes());
assert!(lquery_from_sql(query.as_slice()).is_ok())
}
#[test]
fn lquery_wrong_version() {
let mut query = vec![2u8];
query.extend_from_slice("A.B.C".as_bytes());
assert!(lquery_from_sql(query.as_slice()).is_err())
}
#[test]
fn ltxtquery_sql() {
let mut query = vec![1u8];
query.extend_from_slice("a & b*".as_bytes());
let mut buf = BytesMut::new();
ltree_to_sql("a & b*", &mut buf);
assert_eq!(query.as_slice(), buf.chunk());
}
#[test]
fn ltxtquery_str() {
let mut query = vec![1u8];
query.extend_from_slice("a & b*".as_bytes());
assert!(ltree_from_sql(query.as_slice()).is_ok())
}
#[test]
fn ltxtquery_wrong_version() {
let mut query = vec![2u8];
query.extend_from_slice("a & b*".as_bytes());
assert!(ltree_from_sql(query.as_slice()).is_err())
}

View File

@@ -1,10 +0,0 @@
[package]
name = "postgres-types2"
version = "0.1.0"
edition = "2018"
license = "MIT/Apache-2.0"
[dependencies]
bytes.workspace = true
fallible-iterator.workspace = true
postgres-protocol2 = { path = "../postgres-protocol2" }

View File

@@ -1,477 +0,0 @@
//! Conversions to and from Postgres types.
//!
//! This crate is used by the `tokio-postgres` and `postgres` crates. You normally don't need to depend directly on it
//! unless you want to define your own `ToSql` or `FromSql` definitions.
#![doc(html_root_url = "https://docs.rs/postgres-types/0.2")]
#![warn(clippy::all, rust_2018_idioms, missing_docs)]
use fallible_iterator::FallibleIterator;
use postgres_protocol2::types;
use std::any::type_name;
use std::error::Error;
use std::fmt;
use std::sync::Arc;
use crate::type_gen::{Inner, Other};
#[doc(inline)]
pub use postgres_protocol2::Oid;
use bytes::BytesMut;
/// Generates a simple implementation of `ToSql::accepts` which accepts the
/// types passed to it.
macro_rules! accepts {
($($expected:ident),+) => (
fn accepts(ty: &$crate::Type) -> bool {
matches!(*ty, $($crate::Type::$expected)|+)
}
)
}
/// Generates an implementation of `ToSql::to_sql_checked`.
///
/// All `ToSql` implementations should use this macro.
macro_rules! to_sql_checked {
() => {
fn to_sql_checked(
&self,
ty: &$crate::Type,
out: &mut $crate::private::BytesMut,
) -> ::std::result::Result<
$crate::IsNull,
Box<dyn ::std::error::Error + ::std::marker::Sync + ::std::marker::Send>,
> {
$crate::__to_sql_checked(self, ty, out)
}
};
}
// WARNING: this function is not considered part of this crate's public API.
// It is subject to change at any time.
#[doc(hidden)]
pub fn __to_sql_checked<T>(
v: &T,
ty: &Type,
out: &mut BytesMut,
) -> Result<IsNull, Box<dyn Error + Sync + Send>>
where
T: ToSql,
{
if !T::accepts(ty) {
return Err(Box::new(WrongType::new::<T>(ty.clone())));
}
v.to_sql(ty, out)
}
// mod pg_lsn;
#[doc(hidden)]
pub mod private;
// mod special;
mod type_gen;
/// A Postgres type.
#[derive(PartialEq, Eq, Clone, Hash)]
pub struct Type(Inner);
impl fmt::Debug for Type {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(&self.0, fmt)
}
}
impl fmt::Display for Type {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.schema() {
"public" | "pg_catalog" => {}
schema => write!(fmt, "{}.", schema)?,
}
fmt.write_str(self.name())
}
}
impl Type {
/// Creates a new `Type`.
pub fn new(name: String, oid: Oid, kind: Kind, schema: String) -> Type {
Type(Inner::Other(Arc::new(Other {
name,
oid,
kind,
schema,
})))
}
/// Returns the `Type` corresponding to the provided `Oid` if it
/// corresponds to a built-in type.
pub fn from_oid(oid: Oid) -> Option<Type> {
Inner::from_oid(oid).map(Type)
}
/// Returns the OID of the `Type`.
pub fn oid(&self) -> Oid {
self.0.oid()
}
/// Returns the kind of this type.
pub fn kind(&self) -> &Kind {
self.0.kind()
}
/// Returns the schema of this type.
pub fn schema(&self) -> &str {
match self.0 {
Inner::Other(ref u) => &u.schema,
_ => "pg_catalog",
}
}
/// Returns the name of this type.
pub fn name(&self) -> &str {
self.0.name()
}
}
/// Represents the kind of a Postgres type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum Kind {
/// A simple type like `VARCHAR` or `INTEGER`.
Simple,
/// An enumerated type along with its variants.
Enum(Vec<String>),
/// A pseudo-type.
Pseudo,
/// An array type along with the type of its elements.
Array(Type),
/// A range type along with the type of its elements.
Range(Type),
/// A multirange type along with the type of its elements.
Multirange(Type),
/// A domain type along with its underlying type.
Domain(Type),
/// A composite type along with information about its fields.
Composite(Vec<Field>),
}
/// Information about a field of a composite type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Field {
name: String,
type_: Type,
}
impl Field {
/// Creates a new `Field`.
pub fn new(name: String, type_: Type) -> Field {
Field { name, type_ }
}
/// Returns the name of the field.
pub fn name(&self) -> &str {
&self.name
}
/// Returns the type of the field.
pub fn type_(&self) -> &Type {
&self.type_
}
}
/// An error indicating that a `NULL` Postgres value was passed to a `FromSql`
/// implementation that does not support `NULL` values.
#[derive(Debug, Clone, Copy)]
pub struct WasNull;
impl fmt::Display for WasNull {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.write_str("a Postgres value was `NULL`")
}
}
impl Error for WasNull {}
/// An error indicating that a conversion was attempted between incompatible
/// Rust and Postgres types.
#[derive(Debug)]
pub struct WrongType {
postgres: Type,
rust: &'static str,
}
impl fmt::Display for WrongType {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
fmt,
"cannot convert between the Rust type `{}` and the Postgres type `{}`",
self.rust, self.postgres,
)
}
}
impl Error for WrongType {}
impl WrongType {
/// Creates a new `WrongType` error.
pub fn new<T>(ty: Type) -> WrongType {
WrongType {
postgres: ty,
rust: type_name::<T>(),
}
}
}
/// An error indicating that an `as_text` conversion was attempted on a binary
/// result.
#[derive(Debug)]
pub struct WrongFormat {}
impl Error for WrongFormat {}
impl fmt::Display for WrongFormat {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
fmt,
"cannot read column as text while it is in binary format"
)
}
}
/// A trait for types that can be created from a Postgres value.
pub trait FromSql<'a>: Sized {
/// Creates a new value of this type from a buffer of data of the specified
/// Postgres `Type` in its binary format.
///
/// The caller of this method is responsible for ensuring that this type
/// is compatible with the Postgres `Type`.
fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>>;
/// Creates a new value of this type from a `NULL` SQL value.
///
/// The caller of this method is responsible for ensuring that this type
/// is compatible with the Postgres `Type`.
///
/// The default implementation returns `Err(Box::new(WasNull))`.
#[allow(unused_variables)]
fn from_sql_null(ty: &Type) -> Result<Self, Box<dyn Error + Sync + Send>> {
Err(Box::new(WasNull))
}
/// A convenience function that delegates to `from_sql` and `from_sql_null` depending on the
/// value of `raw`.
fn from_sql_nullable(
ty: &Type,
raw: Option<&'a [u8]>,
) -> Result<Self, Box<dyn Error + Sync + Send>> {
match raw {
Some(raw) => Self::from_sql(ty, raw),
None => Self::from_sql_null(ty),
}
}
/// Determines if a value of this type can be created from the specified
/// Postgres `Type`.
fn accepts(ty: &Type) -> bool;
}
/// A trait for types which can be created from a Postgres value without borrowing any data.
///
/// This is primarily useful for trait bounds on functions.
pub trait FromSqlOwned: for<'a> FromSql<'a> {}
impl<T> FromSqlOwned for T where T: for<'a> FromSql<'a> {}
impl<'a, T: FromSql<'a>> FromSql<'a> for Option<T> {
fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Option<T>, Box<dyn Error + Sync + Send>> {
<T as FromSql>::from_sql(ty, raw).map(Some)
}
fn from_sql_null(_: &Type) -> Result<Option<T>, Box<dyn Error + Sync + Send>> {
Ok(None)
}
fn accepts(ty: &Type) -> bool {
<T as FromSql>::accepts(ty)
}
}
impl<'a, T: FromSql<'a>> FromSql<'a> for Vec<T> {
fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Vec<T>, Box<dyn Error + Sync + Send>> {
let member_type = match *ty.kind() {
Kind::Array(ref member) => member,
_ => panic!("expected array type"),
};
let array = types::array_from_sql(raw)?;
if array.dimensions().count()? > 1 {
return Err("array contains too many dimensions".into());
}
array
.values()
.map(|v| T::from_sql_nullable(member_type, v))
.collect()
}
fn accepts(ty: &Type) -> bool {
match *ty.kind() {
Kind::Array(ref inner) => T::accepts(inner),
_ => false,
}
}
}
impl<'a> FromSql<'a> for String {
fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<String, Box<dyn Error + Sync + Send>> {
<&str as FromSql>::from_sql(ty, raw).map(ToString::to_string)
}
fn accepts(ty: &Type) -> bool {
<&str as FromSql>::accepts(ty)
}
}
impl<'a> FromSql<'a> for &'a str {
fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<&'a str, Box<dyn Error + Sync + Send>> {
match *ty {
ref ty if ty.name() == "ltree" => types::ltree_from_sql(raw),
ref ty if ty.name() == "lquery" => types::lquery_from_sql(raw),
ref ty if ty.name() == "ltxtquery" => types::ltxtquery_from_sql(raw),
_ => types::text_from_sql(raw),
}
}
fn accepts(ty: &Type) -> bool {
match *ty {
Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true,
ref ty
if (ty.name() == "citext"
|| ty.name() == "ltree"
|| ty.name() == "lquery"
|| ty.name() == "ltxtquery") =>
{
true
}
_ => false,
}
}
}
macro_rules! simple_from {
($t:ty, $f:ident, $($expected:ident),+) => {
impl<'a> FromSql<'a> for $t {
fn from_sql(_: &Type, raw: &'a [u8]) -> Result<$t, Box<dyn Error + Sync + Send>> {
types::$f(raw)
}
accepts!($($expected),+);
}
}
}
simple_from!(i8, char_from_sql, CHAR);
simple_from!(u32, oid_from_sql, OID);
/// An enum representing the nullability of a Postgres value.
pub enum IsNull {
/// The value is NULL.
Yes,
/// The value is not NULL.
No,
}
/// A trait for types that can be converted into Postgres values.
pub trait ToSql: fmt::Debug {
/// Converts the value of `self` into the binary format of the specified
/// Postgres `Type`, appending it to `out`.
///
/// The caller of this method is responsible for ensuring that this type
/// is compatible with the Postgres `Type`.
///
/// The return value indicates if this value should be represented as
/// `NULL`. If this is the case, implementations **must not** write
/// anything to `out`.
fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>>
where
Self: Sized;
/// Determines if a value of this type can be converted to the specified
/// Postgres `Type`.
fn accepts(ty: &Type) -> bool
where
Self: Sized;
/// An adaptor method used internally by Rust-Postgres.
///
/// *All* implementations of this method should be generated by the
/// `to_sql_checked!()` macro.
fn to_sql_checked(
&self,
ty: &Type,
out: &mut BytesMut,
) -> Result<IsNull, Box<dyn Error + Sync + Send>>;
/// Specify the encode format
fn encode_format(&self, _ty: &Type) -> Format {
Format::Binary
}
}
/// Supported Postgres message format types
///
/// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8`
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Format {
/// Text format (UTF-8)
Text,
/// Compact, typed binary format
Binary,
}
impl ToSql for &str {
fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
match *ty {
ref ty if ty.name() == "ltree" => types::ltree_to_sql(self, w),
ref ty if ty.name() == "lquery" => types::lquery_to_sql(self, w),
ref ty if ty.name() == "ltxtquery" => types::ltxtquery_to_sql(self, w),
_ => types::text_to_sql(self, w),
}
Ok(IsNull::No)
}
fn accepts(ty: &Type) -> bool {
match *ty {
Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true,
ref ty
if (ty.name() == "citext"
|| ty.name() == "ltree"
|| ty.name() == "lquery"
|| ty.name() == "ltxtquery") =>
{
true
}
_ => false,
}
}
to_sql_checked!();
}
macro_rules! simple_to {
($t:ty, $f:ident, $($expected:ident),+) => {
impl ToSql for $t {
fn to_sql(&self,
_: &Type,
w: &mut BytesMut)
-> Result<IsNull, Box<dyn Error + Sync + Send>> {
types::$f(*self, w);
Ok(IsNull::No)
}
accepts!($($expected),+);
to_sql_checked!();
}
}
}
simple_to!(u32, oid_to_sql, OID);
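A small round-trip sketch for the &str impls above. Type::TEXT is assumed to be one of the built-in constants generated by the type_gen module (it is matched in accepts()), and the crate is referenced by the name used in this tree (postgres_types2).
use bytes::BytesMut;
use postgres_types2::{FromSql, IsNull, ToSql, Type};

fn roundtrip_text() -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
    let mut buf = BytesMut::new();
    let wrote = "hello".to_sql(&Type::TEXT, &mut buf)?;
    assert!(matches!(wrote, IsNull::No));

    // The binary form of TEXT is just the UTF-8 bytes, so this reads back as-is.
    let read: &str = FromSql::from_sql(&Type::TEXT, &buf)?;
    assert_eq!(read, "hello");
    Ok(())
}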

View File

@@ -1,34 +0,0 @@
use crate::{FromSql, Type};
pub use bytes::BytesMut;
use std::error::Error;
pub fn read_be_i32(buf: &mut &[u8]) -> Result<i32, Box<dyn Error + Sync + Send>> {
if buf.len() < 4 {
return Err("invalid buffer size".into());
}
let mut bytes = [0; 4];
bytes.copy_from_slice(&buf[..4]);
*buf = &buf[4..];
Ok(i32::from_be_bytes(bytes))
}
pub fn read_value<'a, T>(
type_: &Type,
buf: &mut &'a [u8],
) -> Result<T, Box<dyn Error + Sync + Send>>
where
T: FromSql<'a>,
{
let len = read_be_i32(buf)?;
let value = if len < 0 {
None
} else {
if len as usize > buf.len() {
return Err("invalid buffer size".into());
}
let (head, tail) = buf.split_at(len as usize);
*buf = tail;
Some(head)
};
T::from_sql_nullable(type_, value)
}
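read_value is the building block higher layers use to pull one column out of a DataRow: a 4-byte big-endian length, -1 meaning NULL, then the payload handed to FromSql. A sketch with a hand-built TEXT value (Type::TEXT assumed as above):
use postgres_types2::private::read_value;
use postgres_types2::Type;

fn decode_column() -> Result<(), Box<dyn std::error::Error + Sync + Send>> {
    // Length 5, then the UTF-8 payload "hello".
    let raw: &[u8] = &[0, 0, 0, 5, b'h', b'e', b'l', b'l', b'o'];
    let mut buf = raw;
    let value: &str = read_value(&Type::TEXT, &mut buf)?;
    assert_eq!(value, "hello");
    assert!(buf.is_empty()); // the cursor advanced past the value
    Ok(())
}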

File diff suppressed because it is too large

View File

@@ -1,21 +0,0 @@
[package]
name = "tokio-postgres2"
version = "0.1.0"
edition = "2018"
license = "MIT/Apache-2.0"
[dependencies]
async-trait.workspace = true
bytes.workspace = true
byteorder.workspace = true
fallible-iterator.workspace = true
futures-util = { workspace = true, features = ["sink"] }
log = "0.4"
parking_lot.workspace = true
percent-encoding = "2.0"
pin-project-lite.workspace = true
phf = "0.11"
postgres-protocol2 = { path = "../postgres-protocol2" }
postgres-types2 = { path = "../postgres-types2" }
tokio = { workspace = true, features = ["io-util", "time", "net"] }
tokio-util = { workspace = true, features = ["codec"] }

View File

@@ -1,40 +0,0 @@
use tokio::net::TcpStream;
use crate::client::SocketConfig;
use crate::config::{Host, SslMode};
use crate::tls::MakeTlsConnect;
use crate::{cancel_query_raw, connect_socket, Error};
use std::io;
pub(crate) async fn cancel_query<T>(
config: Option<SocketConfig>,
ssl_mode: SslMode,
mut tls: T,
process_id: i32,
secret_key: i32,
) -> Result<(), Error>
where
T: MakeTlsConnect<TcpStream>,
{
let config = match config {
Some(config) => config,
None => {
return Err(Error::connect(io::Error::new(
io::ErrorKind::InvalidInput,
"unknown host",
)))
}
};
let hostname = match &config.host {
Host::Tcp(host) => &**host,
};
let tls = tls
.make_tls_connect(hostname)
.map_err(|e| Error::tls(e.into()))?;
let socket =
connect_socket::connect_socket(&config.host, config.port, config.connect_timeout).await?;
cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, process_id, secret_key).await
}

View File

@@ -1,29 +0,0 @@
use crate::config::SslMode;
use crate::tls::TlsConnect;
use crate::{connect_tls, Error};
use bytes::BytesMut;
use postgres_protocol2::message::frontend;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
pub async fn cancel_query_raw<S, T>(
stream: S,
mode: SslMode,
tls: T,
process_id: i32,
secret_key: i32,
) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
let mut stream = connect_tls::connect_tls(stream, mode, tls).await?;
let mut buf = BytesMut::new();
frontend::cancel_request(process_id, secret_key, &mut buf);
stream.write_all(&buf).await.map_err(Error::io)?;
stream.flush().await.map_err(Error::io)?;
stream.shutdown().await.map_err(Error::io)?;
Ok(())
}

View File

@@ -1,62 +0,0 @@
use crate::config::SslMode;
use crate::tls::TlsConnect;
use crate::{cancel_query, client::SocketConfig, tls::MakeTlsConnect};
use crate::{cancel_query_raw, Error};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
/// The capability to request cancellation of in-progress queries on a
/// connection.
#[derive(Clone)]
pub struct CancelToken {
pub socket_config: Option<SocketConfig>,
pub ssl_mode: SslMode,
pub process_id: i32,
pub secret_key: i32,
}
impl CancelToken {
/// Attempts to cancel the in-progress query on the connection associated
/// with this `CancelToken`.
///
/// The server provides no information about whether a cancellation attempt was successful or not. An error will
/// only be returned if the client was unable to connect to the database.
///
/// Cancellation is inherently racy. There is no guarantee that the
/// cancellation request will reach the server before the query terminates
/// normally, or that the connection associated with this token is still
/// active.
pub async fn cancel_query<T>(&self, tls: T) -> Result<(), Error>
where
T: MakeTlsConnect<TcpStream>,
{
cancel_query::cancel_query(
self.socket_config.clone(),
self.ssl_mode,
tls,
self.process_id,
self.secret_key,
)
.await
}
/// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new
/// connection itself.
pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
cancel_query_raw::cancel_query_raw(
stream,
self.ssl_mode,
tls,
self.process_id,
self.secret_key,
)
.await
}
}

View File

@@ -1,271 +0,0 @@
use crate::codec::{BackendMessages, FrontendMessage};
use crate::config::Host;
use crate::config::SslMode;
use crate::connection::{Request, RequestMessages};
use crate::types::{Oid, Type};
use crate::{
simple_query, CancelToken, Error, ReadyForQueryStatus, Statement, Transaction,
TransactionBuilder,
};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{future, ready};
use postgres_protocol2::message::{backend::Message, frontend};
use std::collections::HashMap;
use std::fmt;
use std::task::{Context, Poll};
use tokio::sync::mpsc;
use std::time::Duration;
pub struct Responses {
receiver: mpsc::Receiver<BackendMessages>,
cur: BackendMessages,
}
impl Responses {
pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
loop {
match self.cur.next().map_err(Error::parse)? {
Some(Message::ErrorResponse(body)) => return Poll::Ready(Err(Error::db(body))),
Some(message) => return Poll::Ready(Ok(message)),
None => {}
}
match ready!(self.receiver.poll_recv(cx)) {
Some(messages) => self.cur = messages,
None => return Poll::Ready(Err(Error::closed())),
}
}
}
pub async fn next(&mut self) -> Result<Message, Error> {
future::poll_fn(|cx| self.poll_next(cx)).await
}
}
/// A cache of type info and prepared statements for fetching type info
/// (corresponding to the queries in the [prepare] module).
#[derive(Default)]
pub(crate) struct CachedTypeInfo {
/// A statement for basic information for a type from its
/// OID. Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_QUERY) (or its
/// fallback).
typeinfo: Option<Statement>,
/// A statement for getting information for a composite type from its OID.
/// Corresponds to [TYPEINFO_COMPOSITE_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY).
typeinfo_composite: Option<Statement>,
/// A statement for getting information for an enum type from its OID.
/// Corresponds to [TYPEINFO_ENUM_QUERY](prepare::TYPEINFO_ENUM_QUERY) (or
/// its fallback).
typeinfo_enum: Option<Statement>,
/// Cache of types already looked up.
types: HashMap<Oid, Type>,
}
impl CachedTypeInfo {
pub(crate) fn typeinfo(&mut self) -> Option<&Statement> {
self.typeinfo.as_ref()
}
pub(crate) fn set_typeinfo(&mut self, statement: Statement) -> &Statement {
self.typeinfo.insert(statement)
}
pub(crate) fn typeinfo_composite(&mut self) -> Option<&Statement> {
self.typeinfo_composite.as_ref()
}
pub(crate) fn set_typeinfo_composite(&mut self, statement: Statement) -> &Statement {
self.typeinfo_composite.insert(statement)
}
pub(crate) fn typeinfo_enum(&mut self) -> Option<&Statement> {
self.typeinfo_enum.as_ref()
}
pub(crate) fn set_typeinfo_enum(&mut self, statement: Statement) -> &Statement {
self.typeinfo_enum.insert(statement)
}
pub(crate) fn type_(&mut self, oid: Oid) -> Option<Type> {
self.types.get(&oid).cloned()
}
pub(crate) fn set_type(&mut self, oid: Oid, type_: &Type) {
self.types.insert(oid, type_.clone());
}
}
pub struct InnerClient {
sender: mpsc::UnboundedSender<Request>,
/// A buffer to use when writing out postgres commands.
buffer: BytesMut,
}
impl InnerClient {
pub fn send(&self, messages: RequestMessages) -> Result<Responses, Error> {
let (sender, receiver) = mpsc::channel(1);
let request = Request { messages, sender };
self.sender.send(request).map_err(|_| Error::closed())?;
Ok(Responses {
receiver,
cur: BackendMessages::empty(),
})
}
/// Call the given function with a buffer to be used when writing out
/// postgres commands.
pub fn with_buf<F, R>(&mut self, f: F) -> R
where
F: FnOnce(&mut BytesMut) -> R,
{
let r = f(&mut self.buffer);
self.buffer.clear();
r
}
}
#[derive(Clone)]
pub struct SocketConfig {
pub host: Host,
pub port: u16,
pub connect_timeout: Option<Duration>,
// pub keepalive: Option<KeepaliveConfig>,
}
/// An asynchronous PostgreSQL client.
///
/// The client is one half of what is returned when a connection is established. Users interact with the database
/// through this client object.
pub struct Client {
pub(crate) inner: InnerClient,
pub(crate) cached_typeinfo: CachedTypeInfo,
socket_config: SocketConfig,
ssl_mode: SslMode,
process_id: i32,
secret_key: i32,
}
impl Client {
pub(crate) fn new(
sender: mpsc::UnboundedSender<Request>,
socket_config: SocketConfig,
ssl_mode: SslMode,
process_id: i32,
secret_key: i32,
) -> Client {
Client {
inner: InnerClient {
sender,
buffer: Default::default(),
},
cached_typeinfo: Default::default(),
socket_config,
ssl_mode,
process_id,
secret_key,
}
}
/// Returns process_id.
pub fn get_process_id(&self) -> i32 {
self.process_id
}
/// Executes a sequence of SQL statements using the simple query protocol.
///
/// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
/// point. This is intended for use when, for example, initializing a database schema.
///
/// # Warning
///
    /// Prepared statements should be used for any query which contains user-specified data, as they provide the
    /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
/// them to this method!
pub async fn batch_execute(&mut self, query: &str) -> Result<ReadyForQueryStatus, Error> {
simple_query::batch_execute(&mut self.inner, query).await
}
/// Begins a new database transaction.
///
/// The transaction will roll back by default - use the `commit` method to commit it.
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
struct RollbackIfNotDone<'me> {
client: &'me mut Client,
done: bool,
}
impl Drop for RollbackIfNotDone<'_> {
fn drop(&mut self) {
if self.done {
return;
}
let buf = self.client.inner.with_buf(|buf| {
frontend::query("ROLLBACK", buf).unwrap();
buf.split().freeze()
});
let _ = self
.client
.inner
.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
}
}
        // This guard is needed because the `Future` created by this method can be
        // dropped after the `RequestMessages` is synchronously sent to the
        // `Connection` by `batch_execute()`, but before the `Responses` is polled to
        // completion. In that case the `Transaction` would never be created and thus
        // would never be rolled back.
{
let mut cleaner = RollbackIfNotDone {
client: self,
done: false,
};
cleaner.client.batch_execute("BEGIN").await?;
cleaner.done = true;
}
Ok(Transaction::new(self))
}
/// Returns a builder for a transaction with custom settings.
///
/// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other
/// attributes.
pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
TransactionBuilder::new(self)
}
/// Constructs a cancellation token that can later be used to request cancellation of a query running on the
/// connection associated with this client.
pub fn cancel_token(&self) -> CancelToken {
CancelToken {
socket_config: Some(self.socket_config.clone()),
ssl_mode: self.ssl_mode,
process_id: self.process_id,
secret_key: self.secret_key,
}
}
/// Determines if the connection to the server has already closed.
///
/// In that case, all future queries will fail.
pub fn is_closed(&self) -> bool {
self.inner.sender.is_closed()
}
}
impl fmt::Debug for Client {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Client").finish()
}
}
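Illustrative sketch (not from the deleted file) of the client surface above: a simple-protocol batch followed by the drop-safe transaction guard. The statements are placeholders, and `Transaction::commit` is assumed to behave as in upstream tokio-postgres.
use crate::{Client, Error};
async fn run_setup(client: &mut Client) -> Result<(), Error> {
    // Simple query protocol: several statements separated by semicolons.
    client
        .batch_execute("CREATE TABLE IF NOT EXISTS t (id int4); DELETE FROM t;")
        .await?;
    // BEGIN is issued eagerly; the transaction rolls back on drop unless committed.
    let tx = client.transaction().await?;
    tx.commit().await?;
    Ok(())
}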

View File

@@ -1,98 +0,0 @@
use bytes::{Buf, Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use postgres_protocol2::message::backend;
use postgres_protocol2::message::frontend::CopyData;
use std::io;
use tokio_util::codec::{Decoder, Encoder};
pub enum FrontendMessage {
Raw(Bytes),
CopyData(CopyData<Box<dyn Buf + Send>>),
}
pub enum BackendMessage {
Normal {
messages: BackendMessages,
request_complete: bool,
},
Async(backend::Message),
}
pub struct BackendMessages(BytesMut);
impl BackendMessages {
pub fn empty() -> BackendMessages {
BackendMessages(BytesMut::new())
}
}
impl FallibleIterator for BackendMessages {
type Item = backend::Message;
type Error = io::Error;
fn next(&mut self) -> io::Result<Option<backend::Message>> {
backend::Message::parse(&mut self.0)
}
}
pub struct PostgresCodec;
impl Encoder<FrontendMessage> for PostgresCodec {
type Error = io::Error;
fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> {
match item {
FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf),
FrontendMessage::CopyData(data) => data.write(dst),
}
Ok(())
}
}
impl Decoder for PostgresCodec {
type Item = BackendMessage;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BackendMessage>, io::Error> {
let mut idx = 0;
let mut request_complete = false;
while let Some(header) = backend::Header::parse(&src[idx..])? {
let len = header.len() as usize + 1;
if src[idx..].len() < len {
break;
}
match header.tag() {
backend::NOTICE_RESPONSE_TAG
| backend::NOTIFICATION_RESPONSE_TAG
| backend::PARAMETER_STATUS_TAG => {
if idx == 0 {
let message = backend::Message::parse(src)?.unwrap();
return Ok(Some(BackendMessage::Async(message)));
} else {
break;
}
}
_ => {}
}
idx += len;
if header.tag() == backend::READY_FOR_QUERY_TAG {
request_complete = true;
break;
}
}
if idx == 0 {
Ok(None)
} else {
Ok(Some(BackendMessage::Normal {
messages: BackendMessages(src.split_to(idx)),
request_complete,
}))
}
}
}

View File

@@ -1,264 +0,0 @@
//! Connection configuration.
use crate::connect::connect;
use crate::connect_raw::connect_raw;
use crate::connect_raw::RawConnection;
use crate::tls::MakeTlsConnect;
use crate::tls::TlsConnect;
use crate::{Client, Connection, Error};
use postgres_protocol2::message::frontend::StartupMessageParams;
use std::fmt;
use std::str;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
pub use postgres_protocol2::authentication::sasl::ScramKeys;
use tokio::net::TcpStream;
/// TLS configuration.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum SslMode {
/// Do not use TLS.
Disable,
/// Attempt to connect with TLS but allow sessions without.
Prefer,
/// Require the use of TLS.
Require,
}
/// Channel binding configuration.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ChannelBinding {
/// Do not use channel binding.
Disable,
/// Attempt to use channel binding but allow sessions without.
Prefer,
/// Require the use of channel binding.
Require,
}
/// Replication mode configuration.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ReplicationMode {
/// Physical replication.
Physical,
/// Logical replication.
Logical,
}
/// A host specification.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Host {
/// A TCP hostname.
Tcp(String),
}
/// Precomputed keys which may override password during auth.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthKeys {
/// A `ClientKey` & `ServerKey` pair for `SCRAM-SHA-256`.
ScramSha256(ScramKeys<32>),
}
/// Connection configuration.
#[derive(Clone, PartialEq, Eq)]
pub struct Config {
pub(crate) host: Host,
pub(crate) port: u16,
pub(crate) password: Option<Vec<u8>>,
pub(crate) auth_keys: Option<Box<AuthKeys>>,
pub(crate) ssl_mode: SslMode,
pub(crate) connect_timeout: Option<Duration>,
pub(crate) channel_binding: ChannelBinding,
pub(crate) server_params: StartupMessageParams,
database: bool,
username: bool,
}
impl Config {
/// Creates a new configuration.
pub fn new(host: String, port: u16) -> Config {
Config {
host: Host::Tcp(host),
port,
password: None,
auth_keys: None,
ssl_mode: SslMode::Prefer,
connect_timeout: None,
channel_binding: ChannelBinding::Prefer,
server_params: StartupMessageParams::default(),
database: false,
username: false,
}
}
/// Sets the user to authenticate with.
///
/// Required.
pub fn user(&mut self, user: &str) -> &mut Config {
self.set_param("user", user)
}
/// Gets the user to authenticate with, if one has been configured with
/// the `user` method.
pub fn user_is_set(&self) -> bool {
self.username
}
/// Sets the password to authenticate with.
pub fn password<T>(&mut self, password: T) -> &mut Config
where
T: AsRef<[u8]>,
{
self.password = Some(password.as_ref().to_vec());
self
}
/// Gets the password to authenticate with, if one has been configured with
/// the `password` method.
pub fn get_password(&self) -> Option<&[u8]> {
self.password.as_deref()
}
/// Sets precomputed protocol-specific keys to authenticate with.
/// When set, this option will override `password`.
/// See [`AuthKeys`] for more information.
pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config {
self.auth_keys = Some(Box::new(keys));
self
}
    /// Gets the precomputed protocol-specific keys to authenticate with,
    /// if any have been configured with the `auth_keys` method.
pub fn get_auth_keys(&self) -> Option<AuthKeys> {
self.auth_keys.as_deref().copied()
}
/// Sets the name of the database to connect to.
///
/// Defaults to the user.
pub fn dbname(&mut self, dbname: &str) -> &mut Config {
self.set_param("database", dbname)
}
/// Gets the name of the database to connect to, if one has been configured
/// with the `dbname` method.
pub fn db_is_set(&self) -> bool {
self.database
}
pub fn set_param(&mut self, name: &str, value: &str) -> &mut Config {
if name == "database" {
self.database = true;
} else if name == "user" {
self.username = true;
}
self.server_params.insert(name, value);
self
}
/// Sets the SSL configuration.
///
/// Defaults to `prefer`.
pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
self.ssl_mode = ssl_mode;
self
}
/// Gets the SSL configuration.
pub fn get_ssl_mode(&self) -> SslMode {
self.ssl_mode
}
    /// Gets the host this configuration will connect to.
pub fn get_host(&self) -> &Host {
&self.host
}
    /// Gets the port this configuration will connect to.
pub fn get_port(&self) -> u16 {
self.port
}
/// Sets the timeout applied to socket-level connection attempts.
///
    /// Note that a hostname can resolve to multiple IP addresses, and this timeout applies to each resolved
    /// address separately. Defaults to no limit.
pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
self.connect_timeout = Some(connect_timeout);
self
}
/// Gets the connection timeout, if one has been set with the
/// `connect_timeout` method.
pub fn get_connect_timeout(&self) -> Option<&Duration> {
self.connect_timeout.as_ref()
}
/// Sets the channel binding behavior.
///
/// Defaults to `prefer`.
pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
self.channel_binding = channel_binding;
self
}
/// Gets the channel binding behavior.
pub fn get_channel_binding(&self) -> ChannelBinding {
self.channel_binding
}
/// Opens a connection to a PostgreSQL database.
///
/// Requires the `runtime` Cargo feature (enabled by default).
pub async fn connect<T>(
&self,
tls: T,
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
where
T: MakeTlsConnect<TcpStream>,
{
connect(tls, self).await
}
pub async fn connect_raw<S, T>(
&self,
stream: S,
tls: T,
) -> Result<RawConnection<S, T::Stream>, Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
connect_raw(stream, tls, self).await
}
}
// Omit password from debug output
impl fmt::Debug for Config {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
struct Redaction {}
impl fmt::Debug for Redaction {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "_")
}
}
f.debug_struct("Config")
.field("password", &self.password.as_ref().map(|_| Redaction {}))
.field("ssl_mode", &self.ssl_mode)
.field("host", &self.host)
.field("port", &self.port)
.field("connect_timeout", &self.connect_timeout)
.field("channel_binding", &self.channel_binding)
.field("server_params", &self.server_params)
.finish()
}
}
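Hypothetical end-to-end sketch for the builder above; the host, credentials, and timeout are placeholder values, and `NoTls` is assumed to satisfy `MakeTlsConnect<TcpStream>` as it does upstream.
use std::time::Duration;
use crate::config::{Config, SslMode};
use crate::{Client, Error, NoTls};
async fn connect_example() -> Result<Client, Error> {
    let mut config = Config::new("localhost".to_string(), 5432);
    config
        .user("postgres")
        .dbname("postgres")
        .password("secret")
        .ssl_mode(SslMode::Prefer)
        .connect_timeout(Duration::from_secs(5));
    let (client, connection) = config.connect(NoTls).await?;
    // The connection half performs all the I/O; it must be driven for the client to make progress.
    tokio::spawn(async move {
        if let Err(e) = connection.await {
            eprintln!("connection error: {}", e);
        }
    });
    Ok(client)
}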

View File

@@ -1,75 +0,0 @@
use crate::client::SocketConfig;
use crate::codec::BackendMessage;
use crate::config::Host;
use crate::connect_raw::connect_raw;
use crate::connect_socket::connect_socket;
use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::{Client, Config, Connection, Error, RawConnection};
use postgres_protocol2::message::backend::Message;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
pub async fn connect<T>(
mut tls: T,
config: &Config,
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
where
T: MakeTlsConnect<TcpStream>,
{
let hostname = match &config.host {
Host::Tcp(host) => host.as_str(),
};
let tls = tls
.make_tls_connect(hostname)
.map_err(|e| Error::tls(e.into()))?;
match connect_once(&config.host, config.port, tls, config).await {
Ok((client, connection)) => Ok((client, connection)),
Err(e) => Err(e),
}
}
async fn connect_once<T>(
host: &Host,
port: u16,
tls: T,
config: &Config,
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
where
T: TlsConnect<TcpStream>,
{
let socket = connect_socket(host, port, config.connect_timeout).await?;
let RawConnection {
stream,
parameters,
delayed_notice,
process_id,
secret_key,
} = connect_raw(socket, tls, config).await?;
let socket_config = SocketConfig {
host: host.clone(),
port,
connect_timeout: config.connect_timeout,
};
let (sender, receiver) = mpsc::unbounded_channel();
let client = Client::new(
sender,
socket_config,
config.ssl_mode,
process_id,
secret_key,
);
// delayed notices are always sent as "Async" messages.
let delayed = delayed_notice
.into_iter()
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
.collect();
let connection = Connection::new(stream, delayed, parameters, receiver);
Ok((client, connection))
}

View File

@@ -1,326 +0,0 @@
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::config::{self, AuthKeys, Config};
use crate::connect_tls::connect_tls;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::{TlsConnect, TlsStream};
use crate::Error;
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Sink, SinkExt, Stream, TryStreamExt};
use postgres_protocol2::authentication::sasl;
use postgres_protocol2::authentication::sasl::ScramSha256;
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
use postgres_protocol2::message::frontend;
use std::collections::HashMap;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::Framed;
pub struct StartupStream<S, T> {
inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
buf: BackendMessages,
delayed_notice: Vec<NoticeResponseBody>,
}
impl<S, T> Sink<FrontendMessage> for StartupStream<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
type Error = io::Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_ready(cx)
}
fn start_send(mut self: Pin<&mut Self>, item: FrontendMessage) -> io::Result<()> {
Pin::new(&mut self.inner).start_send(item)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_close(cx)
}
}
impl<S, T> Stream for StartupStream<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
type Item = io::Result<Message>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<io::Result<Message>>> {
loop {
match self.buf.next() {
Ok(Some(message)) => return Poll::Ready(Some(Ok(message))),
Ok(None) => {}
Err(e) => return Poll::Ready(Some(Err(e))),
}
match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages,
Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))),
Some(Err(e)) => return Poll::Ready(Some(Err(e))),
None => return Poll::Ready(None),
}
}
}
}
pub struct RawConnection<S, T> {
pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pub parameters: HashMap<String, String>,
pub delayed_notice: Vec<NoticeResponseBody>,
pub process_id: i32,
pub secret_key: i32,
}
pub async fn connect_raw<S, T>(
stream: S,
tls: T,
config: &Config,
) -> Result<RawConnection<S, T::Stream>, Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
let stream = connect_tls(stream, config.ssl_mode, tls).await?;
let mut stream = StartupStream {
inner: Framed::new(stream, PostgresCodec),
buf: BackendMessages::empty(),
delayed_notice: Vec::new(),
};
startup(&mut stream, config).await?;
authenticate(&mut stream, config).await?;
let (process_id, secret_key, parameters) = read_info(&mut stream).await?;
Ok(RawConnection {
stream: stream.inner,
parameters,
delayed_notice: stream.delayed_notice,
process_id,
secret_key,
})
}
async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
let mut buf = BytesMut::new();
frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)
}
async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsStream + Unpin,
{
match stream.try_next().await.map_err(Error::io)? {
Some(Message::AuthenticationOk) => {
can_skip_channel_binding(config)?;
return Ok(());
}
Some(Message::AuthenticationCleartextPassword) => {
can_skip_channel_binding(config)?;
let pass = config
.password
.as_ref()
.ok_or_else(|| Error::config("password missing".into()))?;
authenticate_password(stream, pass).await?;
}
Some(Message::AuthenticationSasl(body)) => {
authenticate_sasl(stream, body, config).await?;
}
Some(Message::AuthenticationMd5Password)
| Some(Message::AuthenticationKerberosV5)
| Some(Message::AuthenticationScmCredential)
| Some(Message::AuthenticationGss)
| Some(Message::AuthenticationSspi) => {
return Err(Error::authentication(
"unsupported authentication method".into(),
))
}
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(_) => return Err(Error::unexpected_message()),
None => return Err(Error::closed()),
}
match stream.try_next().await.map_err(Error::io)? {
Some(Message::AuthenticationOk) => Ok(()),
Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
Some(_) => Err(Error::unexpected_message()),
None => Err(Error::closed()),
}
}
fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
match config.channel_binding {
config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()),
config::ChannelBinding::Require => Err(Error::authentication(
"server did not use channel binding".into(),
)),
}
}
async fn authenticate_password<S, T>(
stream: &mut StartupStream<S, T>,
password: &[u8],
) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
let mut buf = BytesMut::new();
frontend::password_message(password, &mut buf).map_err(Error::encode)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)
}
async fn authenticate_sasl<S, T>(
stream: &mut StartupStream<S, T>,
body: AuthenticationSaslBody,
config: &Config,
) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsStream + Unpin,
{
let mut has_scram = false;
let mut has_scram_plus = false;
let mut mechanisms = body.mechanisms();
while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
match mechanism {
sasl::SCRAM_SHA_256 => has_scram = true,
sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
_ => {}
}
}
let channel_binding = stream
.inner
.get_ref()
.channel_binding()
.tls_server_end_point
.filter(|_| config.channel_binding != config::ChannelBinding::Disable)
.map(sasl::ChannelBinding::tls_server_end_point);
let (channel_binding, mechanism) = if has_scram_plus {
match channel_binding {
Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS),
None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
}
} else if has_scram {
match channel_binding {
Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256),
None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256),
}
} else {
return Err(Error::authentication("unsupported SASL mechanism".into()));
};
if mechanism != sasl::SCRAM_SHA_256_PLUS {
can_skip_channel_binding(config)?;
}
let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() {
ScramSha256::new_with_keys(keys, channel_binding)
} else if let Some(password) = config.get_password() {
ScramSha256::new(password, channel_binding)
} else {
return Err(Error::config("password or auth keys missing".into()));
};
let mut buf = BytesMut::new();
frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)?;
let body = match stream.try_next().await.map_err(Error::io)? {
Some(Message::AuthenticationSaslContinue(body)) => body,
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(_) => return Err(Error::unexpected_message()),
None => return Err(Error::closed()),
};
scram
.update(body.data())
.await
.map_err(|e| Error::authentication(e.into()))?;
let mut buf = BytesMut::new();
frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
.map_err(Error::io)?;
let body = match stream.try_next().await.map_err(Error::io)? {
Some(Message::AuthenticationSaslFinal(body)) => body,
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(_) => return Err(Error::unexpected_message()),
None => return Err(Error::closed()),
};
scram
.finish(body.data())
.map_err(|e| Error::authentication(e.into()))?;
Ok(())
}
async fn read_info<S, T>(
stream: &mut StartupStream<S, T>,
) -> Result<(i32, i32, HashMap<String, String>), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
let mut process_id = 0;
let mut secret_key = 0;
let mut parameters = HashMap::new();
loop {
match stream.try_next().await.map_err(Error::io)? {
Some(Message::BackendKeyData(body)) => {
process_id = body.process_id();
secret_key = body.secret_key();
}
Some(Message::ParameterStatus(body)) => {
parameters.insert(
body.name().map_err(Error::parse)?.to_string(),
body.value().map_err(Error::parse)?.to_string(),
);
}
Some(Message::NoticeResponse(body)) => stream.delayed_notice.push(body),
Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
Some(_) => return Err(Error::unexpected_message()),
None => return Err(Error::closed()),
}
}
}
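A hedged sketch (not part of the deleted file) of running only the raw handshake over a socket the caller already owns; the address is a placeholder and `NoTls` is assumed to implement `TlsConnect` for any stream, as upstream.
use tokio::net::TcpStream;
use crate::{Config, NoTls};
async fn handshake_on_existing_socket(
    config: &Config,
) -> Result<(), Box<dyn std::error::Error>> {
    // Bring your own socket (for example one accepted by a proxy) and run only
    // the startup/authentication handshake over it.
    let socket = TcpStream::connect(("localhost", 5432)).await?;
    let raw = config.connect_raw(socket, NoTls).await?;
    println!(
        "backend pid {}, {} server parameters",
        raw.process_id,
        raw.parameters.len()
    );
    Ok(())
}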

View File

@@ -1,65 +0,0 @@
use crate::config::Host;
use crate::Error;
use std::future::Future;
use std::io;
use std::time::Duration;
use tokio::net::{self, TcpStream};
use tokio::time;
pub(crate) async fn connect_socket(
host: &Host,
port: u16,
connect_timeout: Option<Duration>,
) -> Result<TcpStream, Error> {
match host {
Host::Tcp(host) => {
let addrs = net::lookup_host((&**host, port))
.await
.map_err(Error::connect)?;
let mut last_err = None;
for addr in addrs {
let stream =
match connect_with_timeout(TcpStream::connect(addr), connect_timeout).await {
Ok(stream) => stream,
Err(e) => {
last_err = Some(e);
continue;
}
};
stream.set_nodelay(true).map_err(Error::connect)?;
return Ok(stream);
}
Err(last_err.unwrap_or_else(|| {
Error::connect(io::Error::new(
io::ErrorKind::InvalidInput,
"could not resolve any addresses",
))
}))
}
}
}
async fn connect_with_timeout<F, T>(connect: F, timeout: Option<Duration>) -> Result<T, Error>
where
F: Future<Output = io::Result<T>>,
{
match timeout {
Some(timeout) => match time::timeout(timeout, connect).await {
Ok(Ok(socket)) => Ok(socket),
Ok(Err(e)) => Err(Error::connect(e)),
Err(_) => Err(Error::connect(io::Error::new(
io::ErrorKind::TimedOut,
"connection timed out",
))),
},
None => match connect.await {
Ok(socket) => Ok(socket),
Err(e) => Err(Error::connect(e)),
},
}
}

View File

@@ -1,48 +0,0 @@
use crate::config::SslMode;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::private::ForcePrivateApi;
use crate::tls::TlsConnect;
use crate::Error;
use bytes::BytesMut;
use postgres_protocol2::message::frontend;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
pub async fn connect_tls<S, T>(
mut stream: S,
mode: SslMode,
tls: T,
) -> Result<MaybeTlsStream<S, T::Stream>, Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
match mode {
SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => {
return Ok(MaybeTlsStream::Raw(stream))
}
SslMode::Prefer | SslMode::Require => {}
}
let mut buf = BytesMut::new();
frontend::ssl_request(&mut buf);
stream.write_all(&buf).await.map_err(Error::io)?;
let mut buf = [0];
stream.read_exact(&mut buf).await.map_err(Error::io)?;
if buf[0] != b'S' {
if SslMode::Require == mode {
return Err(Error::tls("server does not support TLS".into()));
} else {
return Ok(MaybeTlsStream::Raw(stream));
}
}
let stream = tls
.connect(stream)
.await
.map_err(|e| Error::tls(e.into()))?;
Ok(MaybeTlsStream::Tls(stream))
}

View File

@@ -1,323 +0,0 @@
use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
use crate::error::DbError;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::{AsyncMessage, Error, Notification};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Sink, Stream};
use log::{info, trace};
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::sync::mpsc;
use tokio_util::codec::Framed;
use tokio_util::sync::PollSender;
pub enum RequestMessages {
Single(FrontendMessage),
}
pub struct Request {
pub messages: RequestMessages,
pub sender: mpsc::Sender<BackendMessages>,
}
pub struct Response {
sender: PollSender<BackendMessages>,
}
#[derive(PartialEq, Debug)]
enum State {
Active,
Terminating,
Closing,
}
/// A connection to a PostgreSQL database.
///
/// This is one half of what is returned when a new connection is established. It performs the actual IO with the
/// server, and should generally be spawned off onto an executor to run in the background.
///
/// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has
/// occurred, or because its associated `Client` has dropped and all outstanding work has completed.
#[must_use = "futures do nothing unless polled"]
pub struct Connection<S, T> {
/// HACK: we need this in the Neon Proxy.
pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
/// HACK: we need this in the Neon Proxy to forward params.
pub parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
pending_request: Option<RequestMessages>,
pending_responses: VecDeque<BackendMessage>,
responses: VecDeque<Response>,
state: State,
}
impl<S, T> Connection<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
pub(crate) fn new(
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pending_responses: VecDeque<BackendMessage>,
parameters: HashMap<String, String>,
receiver: mpsc::UnboundedReceiver<Request>,
) -> Connection<S, T> {
Connection {
stream,
parameters,
receiver,
pending_request: None,
pending_responses,
responses: VecDeque::new(),
state: State::Active,
}
}
fn poll_response(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<BackendMessage, Error>>> {
if let Some(message) = self.pending_responses.pop_front() {
trace!("retrying pending response");
return Poll::Ready(Some(Ok(message)));
}
Pin::new(&mut self.stream)
.poll_next(cx)
.map(|o| o.map(|r| r.map_err(Error::io)))
}
fn poll_read(&mut self, cx: &mut Context<'_>) -> Result<Option<AsyncMessage>, Error> {
if self.state != State::Active {
trace!("poll_read: done");
return Ok(None);
}
loop {
let message = match self.poll_response(cx)? {
Poll::Ready(Some(message)) => message,
Poll::Ready(None) => return Err(Error::closed()),
Poll::Pending => {
trace!("poll_read: waiting on response");
return Ok(None);
}
};
let (mut messages, request_complete) = match message {
BackendMessage::Async(Message::NoticeResponse(body)) => {
let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
return Ok(Some(AsyncMessage::Notice(error)));
}
BackendMessage::Async(Message::NotificationResponse(body)) => {
let notification = Notification {
process_id: body.process_id(),
channel: body.channel().map_err(Error::parse)?.to_string(),
payload: body.message().map_err(Error::parse)?.to_string(),
};
return Ok(Some(AsyncMessage::Notification(notification)));
}
BackendMessage::Async(Message::ParameterStatus(body)) => {
self.parameters.insert(
body.name().map_err(Error::parse)?.to_string(),
body.value().map_err(Error::parse)?.to_string(),
);
continue;
}
BackendMessage::Async(_) => unreachable!(),
BackendMessage::Normal {
messages,
request_complete,
} => (messages, request_complete),
};
let mut response = match self.responses.pop_front() {
Some(response) => response,
None => match messages.next().map_err(Error::parse)? {
Some(Message::ErrorResponse(error)) => return Err(Error::db(error)),
_ => return Err(Error::unexpected_message()),
},
};
match response.sender.poll_reserve(cx) {
Poll::Ready(Ok(())) => {
let _ = response.sender.send_item(messages);
if !request_complete {
self.responses.push_front(response);
}
}
Poll::Ready(Err(_)) => {
// we need to keep paging through the rest of the messages even if the receiver's hung up
if !request_complete {
self.responses.push_front(response);
}
}
Poll::Pending => {
self.responses.push_front(response);
self.pending_responses.push_back(BackendMessage::Normal {
messages,
request_complete,
});
trace!("poll_read: waiting on sender");
return Ok(None);
}
}
}
}
fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
if let Some(messages) = self.pending_request.take() {
trace!("retrying pending request");
return Poll::Ready(Some(messages));
}
if self.receiver.is_closed() {
return Poll::Ready(None);
}
match self.receiver.poll_recv(cx) {
Poll::Ready(Some(request)) => {
trace!("polled new request");
self.responses.push_back(Response {
sender: PollSender::new(request.sender),
});
Poll::Ready(Some(request.messages))
}
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
fn poll_write(&mut self, cx: &mut Context<'_>) -> Result<bool, Error> {
loop {
if self.state == State::Closing {
trace!("poll_write: done");
return Ok(false);
}
if Pin::new(&mut self.stream)
.poll_ready(cx)
.map_err(Error::io)?
.is_pending()
{
trace!("poll_write: waiting on socket");
return Ok(false);
}
let request = match self.poll_request(cx) {
Poll::Ready(Some(request)) => request,
Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => {
trace!("poll_write: at eof, terminating");
self.state = State::Terminating;
let mut request = BytesMut::new();
frontend::terminate(&mut request);
RequestMessages::Single(FrontendMessage::Raw(request.freeze()))
}
Poll::Ready(None) => {
trace!(
"poll_write: at eof, pending responses {}",
self.responses.len()
);
return Ok(true);
}
Poll::Pending => {
trace!("poll_write: waiting on request");
return Ok(true);
}
};
match request {
RequestMessages::Single(request) => {
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
if self.state == State::Terminating {
trace!("poll_write: sent eof, closing");
self.state = State::Closing;
}
}
}
}
}
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> {
match Pin::new(&mut self.stream)
.poll_flush(cx)
.map_err(Error::io)?
{
Poll::Ready(()) => trace!("poll_flush: flushed"),
Poll::Pending => trace!("poll_flush: waiting on socket"),
}
Ok(())
}
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
if self.state != State::Closing {
return Poll::Pending;
}
match Pin::new(&mut self.stream)
.poll_close(cx)
.map_err(Error::io)?
{
Poll::Ready(()) => {
trace!("poll_shutdown: complete");
Poll::Ready(Ok(()))
}
Poll::Pending => {
trace!("poll_shutdown: waiting on socket");
Poll::Pending
}
}
}
/// Returns the value of a runtime parameter for this connection.
pub fn parameter(&self, name: &str) -> Option<&str> {
self.parameters.get(name).map(|s| &**s)
}
/// Polls for asynchronous messages from the server.
///
/// The server can send notices as well as notifications asynchronously to the client. Applications that wish to
/// examine those messages should use this method to drive the connection rather than its `Future` implementation.
pub fn poll_message(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<AsyncMessage, Error>>> {
let message = self.poll_read(cx)?;
let want_flush = self.poll_write(cx)?;
if want_flush {
self.poll_flush(cx)?;
}
match message {
Some(message) => Poll::Ready(Some(Ok(message))),
None => match self.poll_shutdown(cx) {
Poll::Ready(Ok(())) => Poll::Ready(None),
Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))),
Poll::Pending => Poll::Pending,
},
}
}
}
impl<S, T> Future for Connection<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
type Output = Result<(), Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
while let Some(message) = ready!(self.poll_message(cx)?) {
if let AsyncMessage::Notice(notice) = message {
info!("{}: {}", notice.severity(), notice.message());
}
}
Poll::Ready(Ok(()))
}
}
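An illustrative alternative (not from the deleted file) to awaiting the `Connection` future directly: drive it with `poll_message` so notices and notifications can be observed. The `log::info!` calls are just a placeholder sink.
use futures_util::future;
use log::info;
use tokio::io::{AsyncRead, AsyncWrite};
use crate::{AsyncMessage, Connection, Error};
async fn drive_connection<S, T>(mut connection: Connection<S, T>) -> Result<(), Error>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: AsyncRead + AsyncWrite + Unpin,
{
    // Each item is a notice or a notification; `None` means the connection has closed.
    while let Some(message) = future::poll_fn(|cx| connection.poll_message(cx)).await {
        match message? {
            AsyncMessage::Notice(notice) => info!("{}: {}", notice.severity(), notice.message()),
            AsyncMessage::Notification(n) => info!("notify on {}: {}", n.channel(), n.payload()),
        }
    }
    Ok(())
}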

View File

@@ -1,495 +0,0 @@
//! Errors.
use fallible_iterator::FallibleIterator;
use postgres_protocol2::message::backend::{ErrorFields, ErrorResponseBody};
use std::error::{self, Error as _Error};
use std::fmt;
use std::io;
pub use self::sqlstate::*;
#[allow(clippy::unreadable_literal)]
mod sqlstate;
/// The severity of a Postgres error or notice.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Severity {
/// PANIC
Panic,
/// FATAL
Fatal,
/// ERROR
Error,
/// WARNING
Warning,
/// NOTICE
Notice,
/// DEBUG
Debug,
/// INFO
Info,
/// LOG
Log,
}
impl fmt::Display for Severity {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match *self {
Severity::Panic => "PANIC",
Severity::Fatal => "FATAL",
Severity::Error => "ERROR",
Severity::Warning => "WARNING",
Severity::Notice => "NOTICE",
Severity::Debug => "DEBUG",
Severity::Info => "INFO",
Severity::Log => "LOG",
};
fmt.write_str(s)
}
}
impl Severity {
fn from_str(s: &str) -> Option<Severity> {
match s {
"PANIC" => Some(Severity::Panic),
"FATAL" => Some(Severity::Fatal),
"ERROR" => Some(Severity::Error),
"WARNING" => Some(Severity::Warning),
"NOTICE" => Some(Severity::Notice),
"DEBUG" => Some(Severity::Debug),
"INFO" => Some(Severity::Info),
"LOG" => Some(Severity::Log),
_ => None,
}
}
}
/// A Postgres error or notice.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DbError {
severity: String,
parsed_severity: Option<Severity>,
code: SqlState,
message: String,
detail: Option<String>,
hint: Option<String>,
position: Option<ErrorPosition>,
where_: Option<String>,
schema: Option<String>,
table: Option<String>,
column: Option<String>,
datatype: Option<String>,
constraint: Option<String>,
file: Option<String>,
line: Option<u32>,
routine: Option<String>,
}
impl DbError {
pub(crate) fn parse(fields: &mut ErrorFields<'_>) -> io::Result<DbError> {
let mut severity = None;
let mut parsed_severity = None;
let mut code = None;
let mut message = None;
let mut detail = None;
let mut hint = None;
let mut normal_position = None;
let mut internal_position = None;
let mut internal_query = None;
let mut where_ = None;
let mut schema = None;
let mut table = None;
let mut column = None;
let mut datatype = None;
let mut constraint = None;
let mut file = None;
let mut line = None;
let mut routine = None;
while let Some(field) = fields.next()? {
match field.type_() {
b'S' => severity = Some(field.value().to_owned()),
b'C' => code = Some(SqlState::from_code(field.value())),
b'M' => message = Some(field.value().to_owned()),
b'D' => detail = Some(field.value().to_owned()),
b'H' => hint = Some(field.value().to_owned()),
b'P' => {
normal_position = Some(field.value().parse::<u32>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"`P` field did not contain an integer",
)
})?);
}
b'p' => {
internal_position = Some(field.value().parse::<u32>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"`p` field did not contain an integer",
)
})?);
}
b'q' => internal_query = Some(field.value().to_owned()),
b'W' => where_ = Some(field.value().to_owned()),
b's' => schema = Some(field.value().to_owned()),
b't' => table = Some(field.value().to_owned()),
b'c' => column = Some(field.value().to_owned()),
b'd' => datatype = Some(field.value().to_owned()),
b'n' => constraint = Some(field.value().to_owned()),
b'F' => file = Some(field.value().to_owned()),
b'L' => {
line = Some(field.value().parse::<u32>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"`L` field did not contain an integer",
)
})?);
}
b'R' => routine = Some(field.value().to_owned()),
b'V' => {
parsed_severity = Some(Severity::from_str(field.value()).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"`V` field contained an invalid value",
)
})?);
}
_ => {}
}
}
Ok(DbError {
severity: severity
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`S` field missing"))?,
parsed_severity,
code: code
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`C` field missing"))?,
message: message
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`M` field missing"))?,
detail,
hint,
position: match normal_position {
Some(position) => Some(ErrorPosition::Original(position)),
None => match internal_position {
Some(position) => Some(ErrorPosition::Internal {
position,
query: internal_query.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"`q` field missing but `p` field present",
)
})?,
}),
None => None,
},
},
where_,
schema,
table,
column,
datatype,
constraint,
file,
line,
routine,
})
}
/// The field contents are ERROR, FATAL, or PANIC (in an error message),
/// or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a
/// localized translation of one of these.
pub fn severity(&self) -> &str {
&self.severity
}
/// A parsed, nonlocalized version of `severity`. (PostgreSQL 9.6+)
pub fn parsed_severity(&self) -> Option<Severity> {
self.parsed_severity
}
/// The SQLSTATE code for the error.
pub fn code(&self) -> &SqlState {
&self.code
}
/// The primary human-readable error message.
///
/// This should be accurate but terse (typically one line).
pub fn message(&self) -> &str {
&self.message
}
/// An optional secondary error message carrying more detail about the
/// problem.
///
/// Might run to multiple lines.
pub fn detail(&self) -> Option<&str> {
self.detail.as_deref()
}
/// An optional suggestion what to do about the problem.
///
/// This is intended to differ from `detail` in that it offers advice
/// (potentially inappropriate) rather than hard facts. Might run to
/// multiple lines.
pub fn hint(&self) -> Option<&str> {
self.hint.as_deref()
}
/// An optional error cursor position into either the original query string
/// or an internally generated query.
pub fn position(&self) -> Option<&ErrorPosition> {
self.position.as_ref()
}
/// An indication of the context in which the error occurred.
///
/// Presently this includes a call stack traceback of active procedural
/// language functions and internally-generated queries. The trace is one
/// entry per line, most recent first.
pub fn where_(&self) -> Option<&str> {
self.where_.as_deref()
}
/// If the error was associated with a specific database object, the name
/// of the schema containing that object, if any. (PostgreSQL 9.3+)
pub fn schema(&self) -> Option<&str> {
self.schema.as_deref()
}
/// If the error was associated with a specific table, the name of the
/// table. (Refer to the schema name field for the name of the table's
/// schema.) (PostgreSQL 9.3+)
pub fn table(&self) -> Option<&str> {
self.table.as_deref()
}
/// If the error was associated with a specific table column, the name of
/// the column.
///
/// (Refer to the schema and table name fields to identify the table.)
/// (PostgreSQL 9.3+)
pub fn column(&self) -> Option<&str> {
self.column.as_deref()
}
/// If the error was associated with a specific data type, the name of the
/// data type. (Refer to the schema name field for the name of the data
/// type's schema.) (PostgreSQL 9.3+)
pub fn datatype(&self) -> Option<&str> {
self.datatype.as_deref()
}
/// If the error was associated with a specific constraint, the name of the
/// constraint.
///
/// Refer to fields listed above for the associated table or domain.
/// (For this purpose, indexes are treated as constraints, even if they
/// weren't created with constraint syntax.) (PostgreSQL 9.3+)
pub fn constraint(&self) -> Option<&str> {
self.constraint.as_deref()
}
/// The file name of the source-code location where the error was reported.
pub fn file(&self) -> Option<&str> {
self.file.as_deref()
}
/// The line number of the source-code location where the error was
/// reported.
pub fn line(&self) -> Option<u32> {
self.line
}
/// The name of the source-code routine reporting the error.
pub fn routine(&self) -> Option<&str> {
self.routine.as_deref()
}
}
impl fmt::Display for DbError {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(fmt, "{}: {}", self.severity, self.message)?;
if let Some(detail) = &self.detail {
write!(fmt, "\nDETAIL: {}", detail)?;
}
if let Some(hint) = &self.hint {
write!(fmt, "\nHINT: {}", hint)?;
}
Ok(())
}
}
impl error::Error for DbError {}
/// Represents the position of an error in a query.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum ErrorPosition {
/// A position in the original query.
Original(u32),
/// A position in an internally generated query.
Internal {
/// The byte position.
position: u32,
/// A query generated by the Postgres server.
query: String,
},
}
#[derive(Debug, PartialEq)]
enum Kind {
Io,
UnexpectedMessage,
Tls,
ToSql(usize),
FromSql(usize),
Column(String),
Closed,
Db,
Parse,
Encode,
Authentication,
Config,
Connect,
Timeout,
}
struct ErrorInner {
kind: Kind,
cause: Option<Box<dyn error::Error + Sync + Send>>,
}
/// An error communicating with the Postgres server.
pub struct Error(Box<ErrorInner>);
impl fmt::Debug for Error {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Error")
.field("kind", &self.0.kind)
.field("cause", &self.0.cause)
.finish()
}
}
impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.0.kind {
Kind::Io => fmt.write_str("error communicating with the server")?,
Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?,
Kind::Tls => fmt.write_str("error performing TLS handshake")?,
Kind::ToSql(idx) => write!(fmt, "error serializing parameter {}", idx)?,
Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?,
Kind::Column(column) => write!(fmt, "invalid column `{}`", column)?,
Kind::Closed => fmt.write_str("connection closed")?,
Kind::Db => fmt.write_str("db error")?,
Kind::Parse => fmt.write_str("error parsing response from server")?,
Kind::Encode => fmt.write_str("error encoding message to server")?,
Kind::Authentication => fmt.write_str("authentication error")?,
Kind::Config => fmt.write_str("invalid configuration")?,
Kind::Connect => fmt.write_str("error connecting to server")?,
Kind::Timeout => fmt.write_str("timeout waiting for server")?,
};
if let Some(ref cause) = self.0.cause {
write!(fmt, ": {}", cause)?;
}
Ok(())
}
}
impl error::Error for Error {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
self.0.cause.as_ref().map(|e| &**e as _)
}
}
impl Error {
/// Consumes the error, returning its cause.
pub fn into_source(self) -> Option<Box<dyn error::Error + Sync + Send>> {
self.0.cause
}
/// Returns the source of this error if it was a `DbError`.
///
/// This is a simple convenience method.
pub fn as_db_error(&self) -> Option<&DbError> {
self.source().and_then(|e| e.downcast_ref::<DbError>())
}
/// Determines if the error was associated with closed connection.
pub fn is_closed(&self) -> bool {
self.0.kind == Kind::Closed
}
/// Returns the SQLSTATE error code associated with the error.
///
/// This is a convenience method that downcasts the cause to a `DbError` and returns its code.
pub fn code(&self) -> Option<&SqlState> {
self.as_db_error().map(DbError::code)
}
fn new(kind: Kind, cause: Option<Box<dyn error::Error + Sync + Send>>) -> Error {
Error(Box::new(ErrorInner { kind, cause }))
}
pub(crate) fn closed() -> Error {
Error::new(Kind::Closed, None)
}
pub(crate) fn unexpected_message() -> Error {
Error::new(Kind::UnexpectedMessage, None)
}
#[allow(clippy::needless_pass_by_value)]
pub(crate) fn db(error: ErrorResponseBody) -> Error {
match DbError::parse(&mut error.fields()) {
Ok(e) => Error::new(Kind::Db, Some(Box::new(e))),
Err(e) => Error::new(Kind::Parse, Some(Box::new(e))),
}
}
pub(crate) fn parse(e: io::Error) -> Error {
Error::new(Kind::Parse, Some(Box::new(e)))
}
pub(crate) fn encode(e: io::Error) -> Error {
Error::new(Kind::Encode, Some(Box::new(e)))
}
#[allow(clippy::wrong_self_convention)]
pub(crate) fn to_sql(e: Box<dyn error::Error + Sync + Send>, idx: usize) -> Error {
Error::new(Kind::ToSql(idx), Some(e))
}
pub(crate) fn from_sql(e: Box<dyn error::Error + Sync + Send>, idx: usize) -> Error {
Error::new(Kind::FromSql(idx), Some(e))
}
pub(crate) fn column(column: String) -> Error {
Error::new(Kind::Column(column), None)
}
pub(crate) fn tls(e: Box<dyn error::Error + Sync + Send>) -> Error {
Error::new(Kind::Tls, Some(e))
}
pub(crate) fn io(e: io::Error) -> Error {
Error::new(Kind::Io, Some(Box::new(e)))
}
pub(crate) fn authentication(e: Box<dyn error::Error + Sync + Send>) -> Error {
Error::new(Kind::Authentication, Some(e))
}
pub(crate) fn config(e: Box<dyn error::Error + Sync + Send>) -> Error {
Error::new(Kind::Config, Some(e))
}
pub(crate) fn connect(e: io::Error) -> Error {
Error::new(Kind::Connect, Some(Box::new(e)))
}
#[doc(hidden)]
pub fn __private_api_timeout() -> Error {
Error::new(Kind::Timeout, None)
}
}
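A small sketch (not from the deleted file) of unpacking an `Error` on the caller side; the formatting is arbitrary.
use crate::error::Error;
fn describe_error(e: &Error) -> String {
    // Prefer the server-reported details when the failure came from Postgres itself.
    if let Some(db) = e.as_db_error() {
        format!("db error {:?}: {}", db.code(), db.message())
    } else if e.is_closed() {
        "connection closed".to_string()
    } else {
        e.to_string()
    }
}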

File diff suppressed because it is too large

View File

@@ -1,64 +0,0 @@
use crate::query::{self, RowStream};
use crate::types::Type;
use crate::{Client, Error, Transaction};
use async_trait::async_trait;
use postgres_protocol2::Oid;
mod private {
pub trait Sealed {}
}
/// A trait allowing abstraction over connections and transactions.
///
/// This trait is "sealed", and cannot be implemented outside of this crate.
#[async_trait]
pub trait GenericClient: private::Sealed {
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send;
/// Query for type information
async fn get_type(&mut self, oid: Oid) -> Result<Type, Error>;
}
impl private::Sealed for Client {}
#[async_trait]
impl GenericClient for Client {
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send,
{
query::query_txt(&mut self.inner, statement, params).await
}
/// Query for type information
async fn get_type(&mut self, oid: Oid) -> Result<Type, Error> {
crate::prepare::get_type(&mut self.inner, &mut self.cached_typeinfo, oid).await
}
}
impl private::Sealed for Transaction<'_> {}
#[async_trait]
#[allow(clippy::needless_lifetimes)]
impl GenericClient for Transaction<'_> {
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send,
{
query::query_txt(&mut self.client().inner, statement, params).await
}
/// Query for type information
async fn get_type(&mut self, oid: Oid) -> Result<Type, Error> {
let client = self.client();
crate::prepare::get_type(&mut client.inner, &mut client.cached_typeinfo, oid).await
}
}
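A hedged sketch of the text-parameter query path above; it assumes `RowStream` implements `Stream<Item = Result<Row, Error>>` as in upstream tokio-postgres, and the statement is a placeholder.
use futures_util::{pin_mut, TryStreamExt};
use crate::{Error, GenericClient};
async fn count_rows<C: GenericClient>(client: &mut C) -> Result<usize, Error> {
    // Every parameter travels as text; `None` encodes SQL NULL.
    let params = vec![Some("42"), None];
    let rows = client
        .query_raw_txt("SELECT $1::int4, $2::text", params)
        .await?;
    pin_mut!(rows);
    let mut n = 0;
    while rows.try_next().await?.is_some() {
        n += 1;
    }
    Ok(n)
}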

View File

@@ -1,115 +0,0 @@
//! An asynchronous, pipelined, PostgreSQL client.
#![warn(rust_2018_idioms, clippy::all)]
pub use crate::cancel_token::CancelToken;
pub use crate::client::{Client, SocketConfig};
pub use crate::config::Config;
pub use crate::connect_raw::RawConnection;
pub use crate::connection::Connection;
use crate::error::DbError;
pub use crate::error::Error;
pub use crate::generic_client::GenericClient;
pub use crate::query::RowStream;
pub use crate::row::Row;
pub use crate::statement::{Column, Statement};
pub use crate::tls::NoTls;
// pub use crate::to_statement::ToStatement;
pub use crate::transaction::Transaction;
pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder};
use crate::types::ToSql;
use postgres_protocol2::message::backend::ReadyForQueryBody;
/// After executing a query, the connection will be in one of these states
#[derive(Clone, Copy, Debug, PartialEq)]
#[repr(u8)]
pub enum ReadyForQueryStatus {
/// Connection state is unknown
Unknown,
/// Connection is idle (no transactions)
Idle = b'I',
/// Connection is in a transaction block
Transaction = b'T',
/// Connection is in a failed transaction block
FailedTransaction = b'E',
}
impl From<ReadyForQueryBody> for ReadyForQueryStatus {
fn from(value: ReadyForQueryBody) -> Self {
match value.status() {
b'I' => Self::Idle,
b'T' => Self::Transaction,
b'E' => Self::FailedTransaction,
_ => Self::Unknown,
}
}
}
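A brief sketch (not from the deleted file): `batch_execute` reports the backend's transaction state via this enum, so a caller can check whether a batch left the session inside a transaction block.
use crate::{Client, Error, ReadyForQueryStatus};
async fn begin_and_check(client: &mut Client) -> Result<bool, Error> {
    // After a successful BEGIN the backend reports being inside a transaction block.
    let status = client.batch_execute("BEGIN").await?;
    Ok(status == ReadyForQueryStatus::Transaction)
}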
mod cancel_query;
mod cancel_query_raw;
mod cancel_token;
mod client;
mod codec;
pub mod config;
mod connect;
mod connect_raw;
mod connect_socket;
mod connect_tls;
mod connection;
pub mod error;
mod generic_client;
pub mod maybe_tls_stream;
mod prepare;
mod query;
pub mod row;
mod simple_query;
mod statement;
pub mod tls;
// mod to_statement;
mod transaction;
mod transaction_builder;
pub mod types;
/// An asynchronous notification.
#[derive(Clone, Debug)]
pub struct Notification {
process_id: i32,
channel: String,
payload: String,
}
impl Notification {
/// The process ID of the notifying backend process.
pub fn process_id(&self) -> i32 {
self.process_id
}
/// The name of the channel that the notify has been raised on.
pub fn channel(&self) -> &str {
&self.channel
}
/// The "payload" string passed from the notifying process.
pub fn payload(&self) -> &str {
&self.payload
}
}
/// An asynchronous message from the server.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone)]
pub enum AsyncMessage {
/// A notice.
///
    /// Notices use the same format as errors, but aren't "errors" per se.
Notice(DbError),
/// A notification.
///
/// Connections can subscribe to notifications with the `LISTEN` command.
Notification(Notification),
}
fn slice_iter<'a>(
s: &'a [&'a (dyn ToSql + Sync)],
) -> impl ExactSizeIterator<Item = &'a (dyn ToSql + Sync)> + 'a {
s.iter().map(|s| *s as _)
}

View File

@@ -1,77 +0,0 @@
//! MaybeTlsStream.
//!
//! Represents a stream that may or may not be encrypted with TLS.
use crate::tls::{ChannelBinding, TlsStream};
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
/// A stream that may or may not be encrypted with TLS.
pub enum MaybeTlsStream<S, T> {
/// An unencrypted stream.
Raw(S),
/// An encrypted stream.
Tls(T),
}
impl<S, T> AsyncRead for MaybeTlsStream<S, T>
where
S: AsyncRead + Unpin,
T: AsyncRead + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match &mut *self {
MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf),
MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
}
}
}
impl<S, T> AsyncWrite for MaybeTlsStream<S, T>
where
S: AsyncWrite + Unpin,
T: AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match &mut *self {
MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf),
MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx),
MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
}
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx),
MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
}
}
}
impl<S, T> TlsStream for MaybeTlsStream<S, T>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsStream + Unpin,
{
fn channel_binding(&self) -> ChannelBinding {
match self {
MaybeTlsStream::Raw(_) => ChannelBinding::none(),
MaybeTlsStream::Tls(s) => s.channel_binding(),
}
}
}

View File

@@ -1,295 +0,0 @@
use crate::client::{CachedTypeInfo, InnerClient};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::error::SqlState;
use crate::types::{Field, Kind, Oid, Type};
use crate::{query, slice_iter};
use crate::{Column, Error, Statement};
use bytes::Bytes;
use fallible_iterator::FallibleIterator;
use futures_util::{pin_mut, StreamExt, TryStreamExt};
use log::debug;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use std::future::Future;
use std::pin::{pin, Pin};
use std::sync::atomic::{AtomicUsize, Ordering};
pub(crate) const TYPEINFO_QUERY: &str = "\
SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid
FROM pg_catalog.pg_type t
LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid
INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
WHERE t.oid = $1
";
// Range types weren't added until Postgres 9.2, so pg_range may not exist
const TYPEINFO_FALLBACK_QUERY: &str = "\
SELECT t.typname, t.typtype, t.typelem, NULL::OID, t.typbasetype, n.nspname, t.typrelid
FROM pg_catalog.pg_type t
INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
WHERE t.oid = $1
";
const TYPEINFO_ENUM_QUERY: &str = "\
SELECT enumlabel
FROM pg_catalog.pg_enum
WHERE enumtypid = $1
ORDER BY enumsortorder
";
// Postgres 9.0 didn't have enumsortorder
const TYPEINFO_ENUM_FALLBACK_QUERY: &str = "\
SELECT enumlabel
FROM pg_catalog.pg_enum
WHERE enumtypid = $1
ORDER BY oid
";
pub(crate) const TYPEINFO_COMPOSITE_QUERY: &str = "\
SELECT attname, atttypid
FROM pg_catalog.pg_attribute
WHERE attrelid = $1
AND NOT attisdropped
AND attnum > 0
ORDER BY attnum
";
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
pub async fn prepare(
client: &mut InnerClient,
cache: &mut CachedTypeInfo,
query: &str,
types: &[Type],
) -> Result<Statement, Error> {
let name = format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst));
let buf = encode(client, &name, query, types)?;
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? {
Message::ParseComplete => {}
_ => return Err(Error::unexpected_message()),
}
let parameter_description = match responses.next().await? {
Message::ParameterDescription(body) => body,
_ => return Err(Error::unexpected_message()),
};
let row_description = match responses.next().await? {
Message::RowDescription(body) => Some(body),
Message::NoData => None,
_ => return Err(Error::unexpected_message()),
};
let mut parameters = vec![];
let mut it = parameter_description.parameters();
while let Some(oid) = it.next().map_err(Error::parse)? {
let type_ = get_type(client, cache, oid).await?;
parameters.push(type_);
}
let mut columns = vec![];
if let Some(row_description) = row_description {
let mut it = row_description.fields();
while let Some(field) = it.next().map_err(Error::parse)? {
let type_ = get_type(client, cache, field.type_oid()).await?;
let column = Column::new(field.name().to_string(), type_, field);
columns.push(column);
}
}
Ok(Statement::new(name, parameters, columns))
}
fn prepare_rec<'a>(
client: &'a mut InnerClient,
cache: &'a mut CachedTypeInfo,
query: &'a str,
types: &'a [Type],
) -> Pin<Box<dyn Future<Output = Result<Statement, Error>> + 'a + Send>> {
Box::pin(prepare(client, cache, query, types))
}
fn encode(
client: &mut InnerClient,
name: &str,
query: &str,
types: &[Type],
) -> Result<Bytes, Error> {
if types.is_empty() {
debug!("preparing query {}: {}", name, query);
} else {
debug!("preparing query {} with types {:?}: {}", name, types, query);
}
client.with_buf(|buf| {
frontend::parse(name, query, types.iter().map(Type::oid), buf).map_err(Error::encode)?;
frontend::describe(b'S', name, buf).map_err(Error::encode)?;
frontend::sync(buf);
Ok(buf.split().freeze())
})
}
pub async fn get_type(
client: &mut InnerClient,
cache: &mut CachedTypeInfo,
oid: Oid,
) -> Result<Type, Error> {
if let Some(type_) = Type::from_oid(oid) {
return Ok(type_);
}
if let Some(type_) = cache.type_(oid) {
return Ok(type_);
}
let stmt = typeinfo_statement(client, cache).await?;
let rows = query::query(client, stmt, slice_iter(&[&oid])).await?;
pin_mut!(rows);
let row = match rows.try_next().await? {
Some(row) => row,
None => return Err(Error::unexpected_message()),
};
let name: String = row.try_get(stmt.columns(), 0)?;
let type_: i8 = row.try_get(stmt.columns(), 1)?;
let elem_oid: Oid = row.try_get(stmt.columns(), 2)?;
let rngsubtype: Option<Oid> = row.try_get(stmt.columns(), 3)?;
let basetype: Oid = row.try_get(stmt.columns(), 4)?;
let schema: String = row.try_get(stmt.columns(), 5)?;
let relid: Oid = row.try_get(stmt.columns(), 6)?;
let kind = if type_ == b'e' as i8 {
let variants = get_enum_variants(client, cache, oid).await?;
Kind::Enum(variants)
} else if type_ == b'p' as i8 {
Kind::Pseudo
} else if basetype != 0 {
let type_ = get_type_rec(client, cache, basetype).await?;
Kind::Domain(type_)
} else if elem_oid != 0 {
let type_ = get_type_rec(client, cache, elem_oid).await?;
Kind::Array(type_)
} else if relid != 0 {
let fields = get_composite_fields(client, cache, relid).await?;
Kind::Composite(fields)
} else if let Some(rngsubtype) = rngsubtype {
let type_ = get_type_rec(client, cache, rngsubtype).await?;
Kind::Range(type_)
} else {
Kind::Simple
};
let type_ = Type::new(name, oid, kind, schema);
cache.set_type(oid, &type_);
Ok(type_)
}
fn get_type_rec<'a>(
client: &'a mut InnerClient,
cache: &'a mut CachedTypeInfo,
oid: Oid,
) -> Pin<Box<dyn Future<Output = Result<Type, Error>> + Send + 'a>> {
Box::pin(get_type(client, cache, oid))
}
async fn typeinfo_statement<'c>(
client: &mut InnerClient,
cache: &'c mut CachedTypeInfo,
) -> Result<&'c Statement, Error> {
if cache.typeinfo().is_some() {
// needed to get around a borrow checker limitation
return Ok(cache.typeinfo().unwrap());
}
let stmt = match prepare_rec(client, cache, TYPEINFO_QUERY, &[]).await {
Ok(stmt) => stmt,
Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_TABLE) => {
prepare_rec(client, cache, TYPEINFO_FALLBACK_QUERY, &[]).await?
}
Err(e) => return Err(e),
};
Ok(cache.set_typeinfo(stmt))
}
async fn get_enum_variants(
client: &mut InnerClient,
cache: &mut CachedTypeInfo,
oid: Oid,
) -> Result<Vec<String>, Error> {
let stmt = typeinfo_enum_statement(client, cache).await?;
let mut out = vec![];
let mut rows = pin!(query::query(client, stmt, slice_iter(&[&oid])).await?);
while let Some(row) = rows.next().await {
out.push(row?.try_get(stmt.columns(), 0)?)
}
Ok(out)
}
async fn typeinfo_enum_statement<'c>(
client: &mut InnerClient,
cache: &'c mut CachedTypeInfo,
) -> Result<&'c Statement, Error> {
if cache.typeinfo_enum().is_some() {
// needed to get around a borrow checker limitation
return Ok(cache.typeinfo_enum().unwrap());
}
let stmt = match prepare_rec(client, cache, TYPEINFO_ENUM_QUERY, &[]).await {
Ok(stmt) => stmt,
Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_COLUMN) => {
prepare_rec(client, cache, TYPEINFO_ENUM_FALLBACK_QUERY, &[]).await?
}
Err(e) => return Err(e),
};
Ok(cache.set_typeinfo_enum(stmt))
}
async fn get_composite_fields(
client: &mut InnerClient,
cache: &mut CachedTypeInfo,
oid: Oid,
) -> Result<Vec<Field>, Error> {
let stmt = typeinfo_composite_statement(client, cache).await?;
let mut rows = pin!(query::query(client, stmt, slice_iter(&[&oid])).await?);
let mut oids = vec![];
while let Some(row) = rows.next().await {
let row = row?;
let name = row.try_get(stmt.columns(), 0)?;
let oid = row.try_get(stmt.columns(), 1)?;
oids.push((name, oid));
}
let mut fields = vec![];
for (name, oid) in oids {
let type_ = get_type_rec(client, cache, oid).await?;
fields.push(Field::new(name, type_));
}
Ok(fields)
}
async fn typeinfo_composite_statement<'c>(
client: &mut InnerClient,
cache: &'c mut CachedTypeInfo,
) -> Result<&'c Statement, Error> {
if cache.typeinfo_composite().is_some() {
// needed to get around a borrow checker limitation
return Ok(cache.typeinfo_composite().unwrap());
}
let stmt = prepare_rec(client, cache, TYPEINFO_COMPOSITE_QUERY, &[]).await?;
Ok(cache.set_typeinfo_composite(stmt))
}
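
The branch order in `get_type` above is easy to lose in the recursion. Below is a minimal, self-contained sketch (not part of the crate; `KindSketch` and `classify` are illustrative names) of how one row returned by `TYPEINFO_QUERY` is classified; the real code then recurses via `get_type_rec` to resolve the referenced OIDs instead of merely reporting them.

```rust
// Illustrative sketch of get_type's classification of a pg_type row.
type Oid = u32;

#[derive(Debug, PartialEq)]
enum KindSketch {
    Enum,           // typtype = 'e': labels come from pg_enum
    Pseudo,         // typtype = 'p'
    Domain(Oid),    // typbasetype != 0: resolve the base type
    Array(Oid),     // typelem != 0: resolve the element type
    Composite(Oid), // typrelid != 0: fields come from pg_attribute
    Range(Oid),     // rngsubtype present: resolve the subtype
    Simple,
}

fn classify(
    typtype: i8,
    elem_oid: Oid,
    rngsubtype: Option<Oid>,
    basetype: Oid,
    relid: Oid,
) -> KindSketch {
    if typtype == b'e' as i8 {
        KindSketch::Enum
    } else if typtype == b'p' as i8 {
        KindSketch::Pseudo
    } else if basetype != 0 {
        KindSketch::Domain(basetype)
    } else if elem_oid != 0 {
        KindSketch::Array(elem_oid)
    } else if relid != 0 {
        KindSketch::Composite(relid)
    } else if let Some(sub) = rngsubtype {
        KindSketch::Range(sub)
    } else {
        KindSketch::Simple
    }
}

fn main() {
    // An array type: typelem points at the element's OID.
    assert_eq!(classify(b'b' as i8, 23, None, 0, 0), KindSketch::Array(23));
    // A plain base type with no references resolves to Simple.
    assert_eq!(classify(b'b' as i8, 0, None, 0, 0), KindSketch::Simple);
}
```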

View File

@@ -1,333 +0,0 @@
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::types::IsNull;
use crate::{Column, Error, ReadyForQueryStatus, Row, Statement};
use bytes::{BufMut, Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Stream};
use log::{debug, log_enabled, Level};
use pin_project_lite::pin_project;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use postgres_types2::{Format, ToSql, Type};
use std::fmt;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::task::{Context, Poll};
struct BorrowToSqlParamsDebug<'a>(&'a [&'a (dyn ToSql + Sync)]);
impl fmt::Debug for BorrowToSqlParamsDebug<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_list().entries(self.0.iter()).finish()
}
}
pub async fn query<'a, I>(
client: &mut InnerClient,
statement: &Statement,
params: I,
) -> Result<RawRowStream, Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
{
let buf = if log_enabled!(Level::Debug) {
let params = params.into_iter().collect::<Vec<_>>();
debug!(
"executing statement {} with parameters: {:?}",
statement.name(),
BorrowToSqlParamsDebug(params.as_slice()),
);
encode(client, statement, params)?
} else {
encode(client, statement, params)?
};
let responses = start(client, buf).await?;
Ok(RawRowStream {
responses,
command_tag: None,
status: ReadyForQueryStatus::Unknown,
output_format: Format::Binary,
_p: PhantomPinned,
})
}
pub async fn query_txt<S, I>(
client: &mut InnerClient,
query: &str,
params: I,
) -> Result<RowStream, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
{
let params = params.into_iter();
let buf = client.with_buf(|buf| {
frontend::parse(
"", // unnamed prepared statement
query, // query to parse
std::iter::empty(), // give no type info
buf,
)
.map_err(Error::encode)?;
frontend::describe(b'S', "", buf).map_err(Error::encode)?;
// Bind: params are passed as text, and results are requested in text format as well
match frontend::bind(
"", // empty string selects the unnamed portal
"", // unnamed prepared statement
std::iter::empty(), // all parameters use the default format (text)
params,
|param, buf| match param {
Some(param) => {
buf.put_slice(param.as_ref().as_bytes());
Ok(postgres_protocol2::IsNull::No)
}
None => Ok(postgres_protocol2::IsNull::Yes),
},
Some(0), // all text
buf,
) {
Ok(()) => Ok(()),
Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)),
Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
}?;
// Execute
frontend::execute("", 0, buf).map_err(Error::encode)?;
// Sync
frontend::sync(buf);
Ok(buf.split().freeze())
})?;
// now read the responses
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? {
Message::ParseComplete => {}
_ => return Err(Error::unexpected_message()),
}
let parameter_description = match responses.next().await? {
Message::ParameterDescription(body) => body,
_ => return Err(Error::unexpected_message()),
};
let row_description = match responses.next().await? {
Message::RowDescription(body) => Some(body),
Message::NoData => None,
_ => return Err(Error::unexpected_message()),
};
match responses.next().await? {
Message::BindComplete => {}
_ => return Err(Error::unexpected_message()),
}
let mut parameters = vec![];
let mut it = parameter_description.parameters();
while let Some(oid) = it.next().map_err(Error::parse)? {
let type_ = Type::from_oid(oid).unwrap_or(Type::UNKNOWN);
parameters.push(type_);
}
let mut columns = vec![];
if let Some(row_description) = row_description {
let mut it = row_description.fields();
while let Some(field) = it.next().map_err(Error::parse)? {
let type_ = Type::from_oid(field.type_oid()).unwrap_or(Type::UNKNOWN);
let column = Column::new(field.name().to_string(), type_, field);
columns.push(column);
}
}
Ok(RowStream {
statement: Statement::new_anonymous(parameters, columns),
responses,
command_tag: None,
status: ReadyForQueryStatus::Unknown,
output_format: Format::Text,
_p: PhantomPinned,
})
}
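
The closure handed to `frontend::bind` in `query_txt` above is the only place parameter values touch the wire: a present parameter is copied verbatim as text bytes, an absent one becomes SQL NULL. A minimal, self-contained sketch of that rule (illustrative names, not part of the crate):

```rust
// Illustrative sketch of the text-format parameter serializer in query_txt.
// Returns true when the parameter should be sent as SQL NULL.
fn serialize_text_param(param: Option<&str>, buf: &mut Vec<u8>) -> bool {
    match param {
        Some(p) => {
            // Present parameter: raw text bytes, no quoting or escaping needed
            // because the value travels as a separate Bind field.
            buf.extend_from_slice(p.as_bytes());
            false
        }
        // Absent parameter: SQL NULL.
        None => true,
    }
}

fn main() {
    let mut buf = Vec::new();
    assert!(!serialize_text_param(Some("42"), &mut buf));
    assert_eq!(buf, b"42");
    assert!(serialize_text_param(None, &mut buf));
}
```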
async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? {
Message::BindComplete => {}
_ => return Err(Error::unexpected_message()),
}
Ok(responses)
}
pub fn encode<'a, I>(
client: &mut InnerClient,
statement: &Statement,
params: I,
) -> Result<Bytes, Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
{
client.with_buf(|buf| {
encode_bind(statement, params, "", buf)?;
frontend::execute("", 0, buf).map_err(Error::encode)?;
frontend::sync(buf);
Ok(buf.split().freeze())
})
}
pub fn encode_bind<'a, I>(
statement: &Statement,
params: I,
portal: &str,
buf: &mut BytesMut,
) -> Result<(), Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
{
let param_types = statement.params();
let params = params.into_iter();
assert!(
param_types.len() == params.len(),
"expected {} parameters but got {}",
param_types.len(),
params.len()
);
let (param_formats, params): (Vec<_>, Vec<_>) = params
.zip(param_types.iter())
.map(|(p, ty)| (p.encode_format(ty) as i16, p))
.unzip();
let params = params.into_iter();
let mut error_idx = 0;
let r = frontend::bind(
portal,
statement.name(),
param_formats,
params.zip(param_types).enumerate(),
|(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) {
Ok(IsNull::No) => Ok(postgres_protocol2::IsNull::No),
Ok(IsNull::Yes) => Ok(postgres_protocol2::IsNull::Yes),
Err(e) => {
error_idx = idx;
Err(e)
}
},
Some(1),
buf,
);
match r {
Ok(()) => Ok(()),
Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
}
}
pin_project! {
/// A stream of table rows.
pub struct RowStream {
statement: Statement,
responses: Responses,
command_tag: Option<String>,
output_format: Format,
status: ReadyForQueryStatus,
#[pin]
_p: PhantomPinned,
}
}
impl Stream for RowStream {
type Item = Result<Row, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
loop {
match ready!(this.responses.poll_next(cx)?) {
Message::DataRow(body) => {
return Poll::Ready(Some(Ok(Row::new(body, *this.output_format)?)))
}
Message::EmptyQueryResponse | Message::PortalSuspended => {}
Message::CommandComplete(body) => {
if let Ok(tag) = body.tag() {
*this.command_tag = Some(tag.to_string());
}
}
Message::ReadyForQuery(status) => {
*this.status = status.into();
return Poll::Ready(None);
}
_ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
}
}
}
}
impl RowStream {
/// Returns information about the columns of data in the row.
pub fn columns(&self) -> &[Column] {
self.statement.columns()
}
/// Returns the command tag of this query.
///
/// This is only available after the stream has been exhausted.
pub fn command_tag(&self) -> Option<String> {
self.command_tag.clone()
}
/// Returns if the connection is ready for querying, with the status of the connection.
///
/// This might be available only after the stream has been exhausted.
pub fn ready_status(&self) -> ReadyForQueryStatus {
self.status
}
}
pin_project! {
/// A stream of table rows from an already-prepared statement, without column metadata.
pub struct RawRowStream {
responses: Responses,
command_tag: Option<String>,
output_format: Format,
status: ReadyForQueryStatus,
#[pin]
_p: PhantomPinned,
}
}
impl Stream for RawRowStream {
type Item = Result<Row, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
loop {
match ready!(this.responses.poll_next(cx)?) {
Message::DataRow(body) => {
return Poll::Ready(Some(Ok(Row::new(body, *this.output_format)?)))
}
Message::EmptyQueryResponse | Message::PortalSuspended => {}
Message::CommandComplete(body) => {
if let Ok(tag) = body.tag() {
*this.command_tag = Some(tag.to_string());
}
}
Message::ReadyForQuery(status) => {
*this.status = status.into();
return Poll::Ready(None);
}
_ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
}
}
}
}
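
Both row streams implement the same small state machine in `poll_next`. A self-contained sketch (with an illustrative `Reply` type, not the real `Message` enum) of how a reply sequence turns into rows plus a command tag:

```rust
// Illustrative sketch of the poll_next loop shared by RowStream and RawRowStream.
enum Reply {
    DataRow(&'static str),
    CommandComplete(&'static str),
    PortalSuspended,
    ReadyForQuery,
    Unexpected,
}

fn collect_rows(
    replies: Vec<Reply>,
) -> Result<(Vec<&'static str>, Option<&'static str>), &'static str> {
    let mut rows = Vec::new();
    let mut tag = None;
    for reply in replies {
        match reply {
            Reply::DataRow(row) => rows.push(row),           // yield a row
            Reply::CommandComplete(t) => tag = Some(t),      // remember the tag
            Reply::PortalSuspended => {}                     // ignored
            Reply::ReadyForQuery => return Ok((rows, tag)),  // end of stream
            Reply::Unexpected => return Err("unexpected message"),
        }
    }
    Err("stream ended before ReadyForQuery")
}

fn main() {
    let replies = vec![
        Reply::DataRow("row 1"),
        Reply::DataRow("row 2"),
        Reply::CommandComplete("SELECT 2"),
        Reply::ReadyForQuery,
    ];
    let (rows, tag) = collect_rows(replies).unwrap();
    assert_eq!(rows.len(), 2);
    assert_eq!(tag, Some("SELECT 2"));
}
```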

View File

@@ -1,84 +0,0 @@
//! Rows.
use crate::statement::Column;
use crate::types::{FromSql, Type, WrongType};
use crate::Error;
use fallible_iterator::FallibleIterator;
use postgres_protocol2::message::backend::DataRowBody;
use postgres_types2::{Format, WrongFormat};
use std::fmt;
use std::ops::Range;
use std::str;
/// A row of data returned from the database by a query.
pub struct Row {
output_format: Format,
body: DataRowBody,
ranges: Vec<Option<Range<usize>>>,
}
impl fmt::Debug for Row {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Row").finish()
}
}
impl Row {
pub(crate) fn new(
// statement: Statement,
body: DataRowBody,
output_format: Format,
) -> Result<Row, Error> {
let ranges = body.ranges().collect().map_err(Error::parse)?;
Ok(Row {
body,
ranges,
output_format,
})
}
pub(crate) fn try_get<'a, T>(&'a self, columns: &[Column], idx: usize) -> Result<T, Error>
where
T: FromSql<'a>,
{
let Some(column) = columns.get(idx) else {
return Err(Error::column(idx.to_string()));
};
let ty = column.type_();
if !T::accepts(ty) {
return Err(Error::from_sql(
Box::new(WrongType::new::<T>(ty.clone())),
idx,
));
}
FromSql::from_sql_nullable(ty, self.col_buffer(idx)).map_err(|e| Error::from_sql(e, idx))
}
/// Get the raw bytes for the column at the given index.
fn col_buffer(&self, idx: usize) -> Option<&[u8]> {
let range = self.ranges.get(idx)?.to_owned()?;
Some(&self.body.buffer()[range])
}
/// Interpret the column at the given index as text
///
/// Useful when using query_raw_txt(), which sets text transfer mode
pub fn as_text(&self, idx: usize) -> Result<Option<&str>, Error> {
if self.output_format == Format::Text {
match self.col_buffer(idx) {
Some(raw) => {
FromSql::from_sql(&Type::TEXT, raw).map_err(|e| Error::from_sql(e, idx))
}
None => Ok(None),
}
} else {
Err(Error::from_sql(Box::new(WrongFormat {}), idx))
}
}
/// Row byte size
pub fn body_len(&self) -> usize {
self.body.buffer().len()
}
}
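
`col_buffer` is the key to NULL handling in `Row`: a missing range means SQL NULL, a present range is a slice into the row body. A self-contained sketch of that lookup (illustrative `RowSketch` type, not part of the crate):

```rust
// Illustrative sketch of Row::col_buffer: ranges[idx] is None for SQL NULL,
// otherwise a byte range into the shared row body buffer.
use std::ops::Range;

struct RowSketch {
    buffer: Vec<u8>,
    ranges: Vec<Option<Range<usize>>>,
}

impl RowSketch {
    fn col_buffer(&self, idx: usize) -> Option<&[u8]> {
        let range = self.ranges.get(idx)?.clone()?;
        Some(&self.buffer[range])
    }
}

fn main() {
    let row = RowSketch {
        buffer: b"42hello".to_vec(),
        ranges: vec![Some(0..2), None, Some(2..7)],
    };
    assert_eq!(row.col_buffer(0), Some(&b"42"[..]));
    assert_eq!(row.col_buffer(1), None); // SQL NULL
    assert_eq!(row.col_buffer(2), Some(&b"hello"[..]));
    assert_eq!(row.col_buffer(3), None); // column index out of range
}
```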

View File

@@ -1,36 +0,0 @@
use crate::client::InnerClient;
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::{Error, ReadyForQueryStatus};
use bytes::Bytes;
use log::debug;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
pub async fn batch_execute(
client: &mut InnerClient,
query: &str,
) -> Result<ReadyForQueryStatus, Error> {
debug!("executing statement batch: {}", query);
let buf = encode(client, query)?;
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
loop {
match responses.next().await? {
Message::ReadyForQuery(status) => return Ok(status.into()),
Message::CommandComplete(_)
| Message::EmptyQueryResponse
| Message::RowDescription(_)
| Message::DataRow(_) => {}
_ => return Err(Error::unexpected_message()),
}
}
}
pub(crate) fn encode(client: &mut InnerClient, query: &str) -> Result<Bytes, Error> {
client.with_buf(|buf| {
frontend::query(query, buf).map_err(Error::encode)?;
Ok(buf.split().freeze())
})
}
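
`batch_execute` simply drains replies until `ReadyForQuery`, skipping the per-statement noise. A self-contained sketch of that drain loop with an illustrative `Msg` type (not the real `Message` enum):

```rust
// Illustrative sketch of batch_execute's reply loop: everything before
// ReadyForQuery is either a benign per-statement message (skipped) or a
// protocol violation (an error).
#[derive(Debug)]
enum Msg {
    CommandComplete,
    EmptyQueryResponse,
    RowDescription,
    DataRow,
    ReadyForQuery(char), // 'I', 'T' or 'E' transaction status
    Other,
}

fn drain(messages: Vec<Msg>) -> Result<char, &'static str> {
    for m in messages {
        match m {
            Msg::ReadyForQuery(status) => return Ok(status),
            Msg::CommandComplete
            | Msg::EmptyQueryResponse
            | Msg::RowDescription
            | Msg::DataRow => {}
            Msg::Other => return Err("unexpected message"),
        }
    }
    Err("connection closed before ReadyForQuery")
}

fn main() {
    let replies = vec![
        Msg::RowDescription,
        Msg::DataRow,
        Msg::CommandComplete,
        Msg::ReadyForQuery('I'),
    ];
    assert_eq!(drain(replies), Ok('I'));
}
```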

View File

@@ -1,126 +0,0 @@
use crate::types::Type;
use postgres_protocol2::{message::backend::Field, Oid};
use std::fmt;
struct StatementInner {
name: String,
params: Vec<Type>,
columns: Vec<Column>,
}
/// A prepared statement.
///
/// Prepared statements can only be used with the connection that created them.
pub struct Statement(StatementInner);
impl Statement {
pub(crate) fn new(name: String, params: Vec<Type>, columns: Vec<Column>) -> Statement {
Statement(StatementInner {
name,
params,
columns,
})
}
pub(crate) fn new_anonymous(params: Vec<Type>, columns: Vec<Column>) -> Statement {
Statement(StatementInner {
name: String::new(),
params,
columns,
})
}
pub(crate) fn name(&self) -> &str {
&self.0.name
}
/// Returns the expected types of the statement's parameters.
pub fn params(&self) -> &[Type] {
&self.0.params
}
/// Returns information about the columns returned when the statement is queried.
pub fn columns(&self) -> &[Column] {
&self.0.columns
}
}
/// Information about a column of a query.
pub struct Column {
name: String,
type_: Type,
// raw fields from RowDescription
table_oid: Oid,
column_id: i16,
format: i16,
// these would be better stored in self.type_, but that is a more radical refactoring
type_oid: Oid,
type_size: i16,
type_modifier: i32,
}
impl Column {
pub(crate) fn new(name: String, type_: Type, raw_field: Field<'_>) -> Column {
Column {
name,
type_,
table_oid: raw_field.table_oid(),
column_id: raw_field.column_id(),
format: raw_field.format(),
type_oid: raw_field.type_oid(),
type_size: raw_field.type_size(),
type_modifier: raw_field.type_modifier(),
}
}
/// Returns the name of the column.
pub fn name(&self) -> &str {
&self.name
}
/// Returns the type of the column.
pub fn type_(&self) -> &Type {
&self.type_
}
/// Returns the table OID of the column.
pub fn table_oid(&self) -> Oid {
self.table_oid
}
/// Returns the column ID of the column.
pub fn column_id(&self) -> i16 {
self.column_id
}
/// Returns the format of the column.
pub fn format(&self) -> i16 {
self.format
}
/// Returns the type OID of the column.
pub fn type_oid(&self) -> Oid {
self.type_oid
}
/// Returns the type size of the column.
pub fn type_size(&self) -> i16 {
self.type_size
}
/// Returns the type modifier of the column.
pub fn type_modifier(&self) -> i32 {
self.type_modifier
}
}
impl fmt::Debug for Column {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Column")
.field("name", &self.name)
.field("type", &self.type_)
.finish()
}
}

View File

@@ -1,162 +0,0 @@
//! TLS support.
use std::error::Error;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, io};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
pub(crate) mod private {
pub struct ForcePrivateApi;
}
/// Channel binding information returned from a TLS handshake.
pub struct ChannelBinding {
pub(crate) tls_server_end_point: Option<Vec<u8>>,
}
impl ChannelBinding {
/// Creates a `ChannelBinding` containing no information.
pub fn none() -> ChannelBinding {
ChannelBinding {
tls_server_end_point: None,
}
}
/// Creates a `ChannelBinding` containing `tls-server-end-point` channel binding information.
pub fn tls_server_end_point(tls_server_end_point: Vec<u8>) -> ChannelBinding {
ChannelBinding {
tls_server_end_point: Some(tls_server_end_point),
}
}
}
/// A constructor of `TlsConnect`ors.
///
/// Requires the `runtime` Cargo feature (enabled by default).
pub trait MakeTlsConnect<S> {
/// The stream type created by the `TlsConnect` implementation.
type Stream: TlsStream + Unpin;
/// The `TlsConnect` implementation created by this type.
type TlsConnect: TlsConnect<S, Stream = Self::Stream>;
/// The error type returned by the `TlsConnect` implementation.
type Error: Into<Box<dyn Error + Sync + Send>>;
/// Creates a new `TlsConnect`or.
///
/// The domain name is provided for certificate verification and SNI.
fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
}
/// An asynchronous function wrapping a stream in a TLS session.
pub trait TlsConnect<S> {
/// The stream returned by the future.
type Stream: TlsStream + Unpin;
/// The error returned by the future.
type Error: Into<Box<dyn Error + Sync + Send>>;
/// The future returned by the connector.
type Future: Future<Output = Result<Self::Stream, Self::Error>>;
/// Returns a future performing a TLS handshake over the stream.
fn connect(self, stream: S) -> Self::Future;
#[doc(hidden)]
fn can_connect(&self, _: private::ForcePrivateApi) -> bool {
true
}
}
/// A TLS-wrapped connection to a PostgreSQL database.
pub trait TlsStream: AsyncRead + AsyncWrite {
/// Returns channel binding information for the session.
fn channel_binding(&self) -> ChannelBinding;
}
/// A `MakeTlsConnect` and `TlsConnect` implementation which simply returns an error.
///
/// This can be used when `sslmode` is `none` or `prefer`.
#[derive(Debug, Copy, Clone)]
pub struct NoTls;
impl<S> MakeTlsConnect<S> for NoTls {
type Stream = NoTlsStream;
type TlsConnect = NoTls;
type Error = NoTlsError;
fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
Ok(NoTls)
}
}
impl<S> TlsConnect<S> for NoTls {
type Stream = NoTlsStream;
type Error = NoTlsError;
type Future = NoTlsFuture;
fn connect(self, _: S) -> NoTlsFuture {
NoTlsFuture(())
}
fn can_connect(&self, _: private::ForcePrivateApi) -> bool {
false
}
}
/// The future returned by `NoTls`.
pub struct NoTlsFuture(());
impl Future for NoTlsFuture {
type Output = Result<NoTlsStream, NoTlsError>;
fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
Poll::Ready(Err(NoTlsError(())))
}
}
/// The TLS "stream" type produced by the `NoTls` connector.
///
/// Since `NoTls` doesn't support TLS, this type is uninhabited.
pub enum NoTlsStream {}
impl AsyncRead for NoTlsStream {
fn poll_read(
self: Pin<&mut Self>,
_: &mut Context<'_>,
_: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match *self {}
}
}
impl AsyncWrite for NoTlsStream {
fn poll_write(self: Pin<&mut Self>, _: &mut Context<'_>, _: &[u8]) -> Poll<io::Result<usize>> {
match *self {}
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
match *self {}
}
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
match *self {}
}
}
impl TlsStream for NoTlsStream {
fn channel_binding(&self) -> ChannelBinding {
match *self {}
}
}
/// The error returned by `NoTls`.
#[derive(Debug)]
pub struct NoTlsError(());
impl fmt::Display for NoTlsError {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.write_str("no TLS implementation configured")
}
}
impl Error for NoTlsError {}
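
`NoTlsStream` relies on an uninhabited enum: since no value of the type can ever exist, every method body can be the empty `match *self {}`. A self-contained sketch of the same trick with an illustrative trait (not part of the crate):

```rust
// Illustrative sketch of the uninhabited-type trick used by NoTlsStream.
enum NeverStream {}

trait ByteSource {
    fn read_byte(&mut self) -> u8;
}

impl ByteSource for NeverStream {
    fn read_byte(&mut self) -> u8 {
        // No variants exist, so there is nothing to handle; the empty match
        // satisfies the return type while being provably unreachable.
        match *self {}
    }
}

fn main() {
    // A NeverStream can never be constructed, so read_byte can never run;
    // the impl exists only to satisfy trait bounds, exactly like
    // NoTlsStream's AsyncRead/AsyncWrite/TlsStream impls above.
    println!("NeverStream has no values; its impls are vacuously correct");
}
```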

View File

@@ -1,63 +0,0 @@
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::{CancelToken, Client, Error, ReadyForQueryStatus};
use postgres_protocol2::message::frontend;
/// A representation of a PostgreSQL database transaction.
///
/// Transactions will implicitly roll back when dropped. Use the `commit` method to commit the changes made in the
/// transaction. Transactions can be nested, with inner transactions implemented via savepoints.
pub struct Transaction<'a> {
client: &'a mut Client,
done: bool,
}
impl Drop for Transaction<'_> {
fn drop(&mut self) {
if self.done {
return;
}
let buf = self.client.inner.with_buf(|buf| {
frontend::query("ROLLBACK", buf).unwrap();
buf.split().freeze()
});
let _ = self
.client
.inner
.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
}
}
impl<'a> Transaction<'a> {
pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> {
Transaction {
client,
done: false,
}
}
/// Consumes the transaction, committing all changes made within it.
pub async fn commit(mut self) -> Result<ReadyForQueryStatus, Error> {
self.done = true;
self.client.batch_execute("COMMIT").await
}
/// Rolls the transaction back, discarding all changes made within it.
///
/// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller.
pub async fn rollback(mut self) -> Result<ReadyForQueryStatus, Error> {
self.done = true;
self.client.batch_execute("ROLLBACK").await
}
/// Like `Client::cancel_token`.
pub fn cancel_token(&self) -> CancelToken {
self.client.cancel_token()
}
/// Returns a reference to the underlying `Client`.
pub fn client(&mut self) -> &mut Client {
self.client
}
}
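
`Transaction` is a classic drop guard: unless `commit` (or `rollback`) marks it done, dropping it queues a `ROLLBACK`. A self-contained sketch of the pattern with illustrative names, printing instead of sending protocol messages:

```rust
// Illustrative sketch of the drop-guard pattern used by Transaction.
struct TxGuard {
    done: bool,
}

impl TxGuard {
    fn begin() -> TxGuard {
        println!("BEGIN");
        TxGuard { done: false }
    }

    fn commit(mut self) {
        // Mark the guard as finished before it is dropped at the end of this
        // method, so Drop does not issue a rollback.
        self.done = true;
        println!("COMMIT");
    }
}

impl Drop for TxGuard {
    fn drop(&mut self) {
        if !self.done {
            // The real code enqueues a ROLLBACK message on the connection.
            println!("ROLLBACK");
        }
    }
}

fn main() {
    {
        let _tx = TxGuard::begin(); // dropped without commit -> ROLLBACK
    }
    let tx = TxGuard::begin();
    tx.commit(); // COMMIT; Drop then sees done == true and does nothing
}
```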

View File

@@ -1,113 +0,0 @@
use crate::{Client, Error, Transaction};
/// The isolation level of a database transaction.
#[derive(Debug, Copy, Clone)]
#[non_exhaustive]
pub enum IsolationLevel {
/// Equivalent to `ReadCommitted`.
ReadUncommitted,
/// An individual statement in the transaction will see rows committed before it began.
ReadCommitted,
/// All statements in the transaction will see the same view of rows committed before the first query in the
/// transaction.
RepeatableRead,
/// The reads and writes in this transaction must be able to be committed as an atomic "unit" with respect to reads
/// and writes of all other concurrent serializable transactions without interleaving.
Serializable,
}
/// A builder for database transactions.
pub struct TransactionBuilder<'a> {
client: &'a mut Client,
isolation_level: Option<IsolationLevel>,
read_only: Option<bool>,
deferrable: Option<bool>,
}
impl<'a> TransactionBuilder<'a> {
pub(crate) fn new(client: &'a mut Client) -> TransactionBuilder<'a> {
TransactionBuilder {
client,
isolation_level: None,
read_only: None,
deferrable: None,
}
}
/// Sets the isolation level of the transaction.
pub fn isolation_level(mut self, isolation_level: IsolationLevel) -> Self {
self.isolation_level = Some(isolation_level);
self
}
/// Sets the access mode of the transaction.
pub fn read_only(mut self, read_only: bool) -> Self {
self.read_only = Some(read_only);
self
}
/// Sets the deferrability of the transaction.
///
/// If the transaction is also serializable and read only, creation of the transaction may block, but when it
/// completes the transaction is able to run with less overhead and a guarantee that it will not be aborted due to
/// serialization failure.
pub fn deferrable(mut self, deferrable: bool) -> Self {
self.deferrable = Some(deferrable);
self
}
/// Begins the transaction.
///
/// The transaction will roll back by default - use the `commit` method to commit it.
pub async fn start(self) -> Result<Transaction<'a>, Error> {
let mut query = "START TRANSACTION".to_string();
let mut first = true;
if let Some(level) = self.isolation_level {
first = false;
query.push_str(" ISOLATION LEVEL ");
let level = match level {
IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
IsolationLevel::ReadCommitted => "READ COMMITTED",
IsolationLevel::RepeatableRead => "REPEATABLE READ",
IsolationLevel::Serializable => "SERIALIZABLE",
};
query.push_str(level);
}
if let Some(read_only) = self.read_only {
if !first {
query.push(',');
}
first = false;
let s = if read_only {
" READ ONLY"
} else {
" READ WRITE"
};
query.push_str(s);
}
if let Some(deferrable) = self.deferrable {
if !first {
query.push(',');
}
let s = if deferrable {
" DEFERRABLE"
} else {
" NOT DEFERRABLE"
};
query.push_str(s);
}
self.client.batch_execute(&query).await?;
Ok(Transaction::new(self.client))
}
}
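
`TransactionBuilder::start` assembles the `START TRANSACTION` statement from the selected options. A self-contained sketch (not part of the crate) that reproduces the string building, showing the exact SQL produced:

```rust
// Illustrative sketch of the START TRANSACTION string assembly in
// TransactionBuilder::start.
fn build_start(isolation: Option<&str>, read_only: Option<bool>, deferrable: Option<bool>) -> String {
    let mut query = "START TRANSACTION".to_string();
    let mut first = true;
    if let Some(level) = isolation {
        first = false;
        query.push_str(" ISOLATION LEVEL ");
        query.push_str(level);
    }
    if let Some(ro) = read_only {
        if !first {
            query.push(',');
        }
        first = false;
        query.push_str(if ro { " READ ONLY" } else { " READ WRITE" });
    }
    if let Some(d) = deferrable {
        if !first {
            query.push(',');
        }
        query.push_str(if d { " DEFERRABLE" } else { " NOT DEFERRABLE" });
    }
    query
}

fn main() {
    assert_eq!(
        build_start(Some("SERIALIZABLE"), Some(true), Some(true)),
        "START TRANSACTION ISOLATION LEVEL SERIALIZABLE, READ ONLY, DEFERRABLE"
    );
    assert_eq!(build_start(None, None, None), "START TRANSACTION");
}
```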

Some files were not shown because too many files have changed in this diff.