Compare commits

..

21 Commits

Author SHA1 Message Date
Conrad Ludgate
101e770632 stash 2023-08-02 10:15:03 +01:00
Conrad Ludgate
747ffa50d6 recursive type queries 2023-08-02 09:27:27 +01:00
Conrad Ludgate
e8400d9d93 stash 2023-08-01 11:43:25 +01:00
Conrad Ludgate
17627e8023 stash2 2023-08-01 11:43:25 +01:00
Conrad Ludgate
4fb5cdbdb8 stash 2023-08-01 11:43:25 +01:00
Conrad Ludgate
70b503f83b refactor connections 2023-08-01 11:43:25 +01:00
Conrad Ludgate
21c15c4285 copy in tokio-postgres connection code 2023-08-01 11:43:10 +01:00
Conrad Ludgate
0a524e09a5 fmt 2023-08-01 11:42:24 +01:00
Conrad Ludgate
cbe24f7c35 more scram tests 2023-08-01 11:42:24 +01:00
Conrad Ludgate
64add503c8 do less 2023-08-01 11:42:24 +01:00
Konstantin Knizhnik
a98a80abc2 Deffine NEON_SMGR to make it possible for extensions to use Neon SMG API (#4840)
## Problem

See https://neondb.slack.com/archives/C036U0GRMRB/p1689148023067319

## Summary of changes

Define NEON_SMGR in smgr.h

## Checklist before requesting a review

- [ ] I have performed a self-review of my code.
- [ ] If it is a core feature, I have added thorough tests.
- [ ] Do we need to implement analytics? if so did you add the relevant
metrics to the dashboard?
- [ ] If this PR requires public announcement, mark it with
/release-notes label and add several sentences in this section.

## Checklist before merging

- [ ] Do not forget to reformat commit message to not include the above
checklist

---------

Co-authored-by: Konstantin Knizhnik <knizhnik@neon.tech>
2023-08-01 10:04:45 +03:00
Alex Chi Z
7b6c849456 support isolation level + read only for http batch sql (#4830)
We will retrieve `neon-batch-isolation-level` and `neon-batch-read-only`
from the http header, which sets the txn properties.
https://github.com/neondatabase/serverless/pull/38#issuecomment-1653130981

---------

Signed-off-by: Alex Chi Z <chi@neon.tech>
2023-08-01 02:59:11 +03:00
Joonas Koivunen
326189d950 consumption_metrics: send timeline_written_size_delta (#4822)
We want to have timeline_written_size_delta which is defined as
difference to the previously sent `timeline_written_size` from the
current `timeline_written_size`.

Solution is to send it. On the first round `disk_consistent_lsn` is used
which is captured during `load` time. After that an incremental "event"
is sent on every collection. Incremental "events" are not part of
deduplication.

I've added some infrastructure to allow somewhat typesafe
`EventType::Absolute` and `EventType::Incremental` factories per
metrics, now that we have our first `EventType::Incremental` usage.
2023-07-31 22:10:19 +03:00
bojanserafimov
ddbe170454 Prewarm compute nodes (#4828) 2023-07-31 14:13:32 -04:00
Alexander Bayandin
39e458f049 test_compatibility: fix pg_tenant_only_port port collision (#4850)
## Problem

Compatibility tests fail from time to time due to `pg_tenant_only_port`
port collision (added in https://github.com/neondatabase/neon/pull/4731)

## Summary of changes
- replace `pg_tenant_only_port` value in config with new port
- remove old logic, than we don't need anymore
- unify config overrides
2023-07-31 20:49:46 +03:00
Vadim Kharitonov
e1424647a0 Update pg_embedding to 0.3.1 version (#4811) 2023-07-31 20:23:18 +03:00
Yinnan Yao
705ae2dce9 Fix error message for listen_pg_addr_tenant_only binding (#4787)
## Problem

Wrong use of `conf.listen_pg_addr` in `error!()`.

## Summary of changes

Use `listen_pg_addr_tenant_only` instead of `conf.listen_pg_addr`.

Signed-off-by: yaoyinnan <35447132+yaoyinnan@users.noreply.github.com>
2023-07-31 14:40:52 +01:00
Conrad Ludgate
eb78603121 proxy: div by zero (#4845)
## Problem

1. In the CacheInvalid state loop, we weren't checking the
`num_retries`. If this managed to get up to `32`, the retry_after
procedure would compute 2^32 which would overflow to 0 and trigger a div
by zero
2. When fixing the above, I started working on a flow diagram for the
state machine logic and realised it was more complex than it had to be:
  a. We start in a `Cached` state
b. `Cached`: call `connect_once`. After the first connect_once error, we
always move to the `CacheInvalid` state, otherwise, we return the
connection.
c. `CacheInvalid`: we attempt to `wake_compute` and we either switch to
Cached or we retry this step (or we error).
d. `Cached`: call `connect_once`. We either retry this step or we have a
connection (or we error) - After num_retries > 1 we never switch back to
`CacheInvalid`.

## Summary of changes

1. Insert a `num_retries` check in the `handle_try_wake` procedure. Also
using floats in the retry_after procedure to prevent the overflow
entirely
2. Refactor connect_to_compute to be more linear in design.
2023-07-31 09:30:24 -04:00
John Spray
f0ad603693 pageserver: add unit test for deleted_at in IndexPart (#4844)
## Problem

Existing IndexPart unit tests only exercised the version 1 format (i.e.
without deleted_at set).

## Summary of changes

Add a test that sets version to 2, and sets a value for deleted_at.

Closes https://github.com/neondatabase/neon/issues/4162
2023-07-31 12:51:18 +01:00
Arpad Müller
e5183f85dc Make DiskBtreeReader::dump async (#4838)
## Problem

`DiskBtreeReader::dump` calls `read_blk` internally, which we want to
make async in the future. As it is currently relying on recursion, and
async doesn't like recursion, we want to find an alternative to that and
instead traverse the tree using a loop and a manual stack.

## Summary of changes

* Make `DiskBtreeReader::dump` and all the places calling it async
* Make `DiskBtreeReader::dump` non-recursive internally and use a stack
instead. It now deparses the node in each iteration, which isn't
optimal, but on the other hand it's hard to store the node as it is
referencing the buffer. Self referential data are hard in Rust. For a
dumping function, speed isn't a priority so we deparse the node multiple
times now (up to branching factor many times).

Part of https://github.com/neondatabase/neon/issues/4743

I have verified that output is unchanged by comparing the output of this
command both before and after this patch:
```
cargo test -p pageserver -- particular_data --nocapture 
```
2023-07-31 12:52:29 +02:00
Joonas Koivunen
89ee8f2028 fix: demote warnings, fix flakyness (#4837)
`WARN ... found future (image|delta) layer` are not actionable log
lines. They don't need to be warnings. `info!` is enough.

This also fixes some known but not tracked flakyness in
[`test_remote_timeline_client_calls_started_metric`][evidence].

[evidence]:
https://neon-github-public-dev.s3.amazonaws.com/reports/pr-4829/5683495367/index.html#/testresult/34fe79e24729618b

Closes #3369.
Closes #4473.
2023-07-31 07:43:12 +00:00
40 changed files with 1965 additions and 487 deletions

14
Cargo.lock generated
View File

@@ -2654,16 +2654,6 @@ dependencies = [
"windows-sys 0.45.0",
]
[[package]]
name = "pbkdf2"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0ca0b5a68607598bf3bad68f32227a8164f6254833f84eafaac409cd6746c31"
dependencies = [
"digest",
"hmac",
]
[[package]]
name = "peeking_take_while"
version = "0.1.2"
@@ -3040,6 +3030,7 @@ dependencies = [
"chrono",
"clap",
"consumption_metrics",
"fallible-iterator",
"futures",
"git-version",
"hashbrown 0.13.2",
@@ -3057,9 +3048,9 @@ dependencies = [
"once_cell",
"opentelemetry",
"parking_lot 0.12.1",
"pbkdf2",
"pin-project-lite",
"postgres-native-tls",
"postgres-protocol",
"postgres_backend",
"pq_proto",
"prometheus",
@@ -3083,6 +3074,7 @@ dependencies = [
"thiserror",
"tls-listener",
"tokio",
"tokio-native-tls",
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls 0.23.4",

View File

@@ -88,7 +88,6 @@ opentelemetry = "0.19.0"
opentelemetry-otlp = { version = "0.12.0", default_features=false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry-semantic-conventions = "0.11.0"
parking_lot = "0.12"
pbkdf2 = "0.12.1"
pin-project-lite = "0.2"
prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency
prost = "0.11"

View File

@@ -551,10 +551,8 @@ FROM build-deps AS pg-embedding-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
ENV PATH "/usr/local/pgsql/bin/:$PATH"
# eeb3ba7c3a60c95b2604dd543c64b2f1bb4a3703 made on 15/07/2023
# There is no release tag yet
RUN wget https://github.com/neondatabase/pg_embedding/archive/eeb3ba7c3a60c95b2604dd543c64b2f1bb4a3703.tar.gz -O pg_embedding.tar.gz && \
echo "030846df723652f99a8689ce63b66fa0c23477a7fd723533ab8a6b28ab70730f pg_embedding.tar.gz" | sha256sum --check && \
RUN wget https://github.com/neondatabase/pg_embedding/archive/refs/tags/0.3.1.tar.gz -O pg_embedding.tar.gz && \
echo "c4ae84eef36fa8ec5868f6e061f39812f19ee5ba3604d428d40935685c7be512 pg_embedding.tar.gz" | sha256sum --check && \
mkdir pg_embedding-src && cd pg_embedding-src && tar xvzf ../pg_embedding.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) && \
make -j $(getconf _NPROCESSORS_ONLN) install && \

View File

@@ -193,6 +193,13 @@ fn main() -> Result<()> {
if !spec_set {
// No spec provided, hang waiting for it.
info!("no compute spec provided, waiting");
// TODO this can stall startups in the unlikely event that we bind
// this compute node while it's busy prewarming. It's not too
// bad because it's just 100ms and unlikely, but it's an
// avoidable problem.
compute.prewarm_postgres()?;
let mut state = compute.state.lock().unwrap();
while state.status != ComputeStatus::ConfigurationPending {
state = compute.state_changed.wait(state).unwrap();

View File

@@ -532,6 +532,50 @@ impl ComputeNode {
Ok(())
}
/// Start and stop a postgres process to warm up the VM for startup.
pub fn prewarm_postgres(&self) -> Result<()> {
info!("prewarming");
// Create pgdata
let pgdata = &format!("{}.warmup", self.pgdata);
create_pgdata(pgdata)?;
// Run initdb to completion
info!("running initdb");
let initdb_bin = Path::new(&self.pgbin).parent().unwrap().join("initdb");
Command::new(initdb_bin)
.args(["-D", pgdata])
.output()
.expect("cannot start initdb process");
// Write conf
use std::io::Write;
let conf_path = Path::new(pgdata).join("postgresql.conf");
let mut file = std::fs::File::create(conf_path)?;
writeln!(file, "shared_buffers=65536")?;
writeln!(file, "port=51055")?; // Nobody should be connecting
writeln!(file, "shared_preload_libraries = 'neon'")?;
// Start postgres
info!("starting postgres");
let mut pg = Command::new(&self.pgbin)
.args(["-D", pgdata])
.spawn()
.expect("cannot start postgres process");
// Stop it when it's ready
info!("waiting for postgres");
wait_for_postgres(&mut pg, Path::new(pgdata))?;
pg.kill()?;
info!("sent kill signal");
pg.wait()?;
info!("done prewarming");
// clean up
let _ok = fs::remove_dir_all(pgdata);
Ok(())
}
/// Start Postgres as a child process and manage DBs/roles.
/// After that this will hang waiting on the postmaster process to exit.
#[instrument(skip_all)]

View File

@@ -5,7 +5,7 @@ use chrono::{DateTime, Utc};
use rand::Rng;
use serde::Serialize;
#[derive(Serialize, Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
#[derive(Serialize, Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
#[serde(tag = "type")]
pub enum EventType {
#[serde(rename = "absolute")]
@@ -17,6 +17,32 @@ pub enum EventType {
},
}
impl EventType {
pub fn absolute_time(&self) -> Option<&DateTime<Utc>> {
use EventType::*;
match self {
Absolute { time } => Some(time),
_ => None,
}
}
pub fn incremental_timerange(&self) -> Option<std::ops::Range<&DateTime<Utc>>> {
// these can most likely be thought of as Range or RangeFull
use EventType::*;
match self {
Incremental {
start_time,
stop_time,
} => Some(start_time..stop_time),
_ => None,
}
}
pub fn is_incremental(&self) -> bool {
matches!(self, EventType::Incremental { .. })
}
}
#[derive(Serialize, Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub struct Event<Extra> {
#[serde(flatten)]

View File

@@ -7,7 +7,7 @@ use crate::context::{DownloadBehavior, RequestContext};
use crate::task_mgr::{self, TaskKind, BACKGROUND_RUNTIME};
use crate::tenant::{mgr, LogicalSizeCalculationCause};
use anyhow;
use chrono::Utc;
use chrono::{DateTime, Utc};
use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_SIZE};
use pageserver_api::models::TenantState;
use reqwest::Url;
@@ -18,12 +18,6 @@ use std::time::Duration;
use tracing::*;
use utils::id::{NodeId, TenantId, TimelineId};
const WRITTEN_SIZE: &str = "written_size";
const SYNTHETIC_STORAGE_SIZE: &str = "synthetic_storage_size";
const RESIDENT_SIZE: &str = "resident_size";
const REMOTE_STORAGE_SIZE: &str = "remote_storage_size";
const TIMELINE_LOGICAL_SIZE: &str = "timeline_logical_size";
const DEFAULT_HTTP_REPORTING_TIMEOUT: Duration = Duration::from_secs(60);
#[serde_as]
@@ -44,6 +38,121 @@ pub struct PageserverConsumptionMetricsKey {
pub metric: &'static str,
}
impl PageserverConsumptionMetricsKey {
const fn absolute_values(self) -> AbsoluteValueFactory {
AbsoluteValueFactory(self)
}
const fn incremental_values(self) -> IncrementalValueFactory {
IncrementalValueFactory(self)
}
}
/// Helper type which each individual metric kind can return to produce only absolute values.
struct AbsoluteValueFactory(PageserverConsumptionMetricsKey);
impl AbsoluteValueFactory {
fn now(self, val: u64) -> (PageserverConsumptionMetricsKey, (EventType, u64)) {
let key = self.0;
let time = Utc::now();
(key, (EventType::Absolute { time }, val))
}
}
/// Helper type which each individual metric kind can return to produce only incremental values.
struct IncrementalValueFactory(PageserverConsumptionMetricsKey);
impl IncrementalValueFactory {
#[allow(clippy::wrong_self_convention)]
fn from_previous_up_to(
self,
prev_end: DateTime<Utc>,
up_to: DateTime<Utc>,
val: u64,
) -> (PageserverConsumptionMetricsKey, (EventType, u64)) {
let key = self.0;
// cannot assert prev_end < up_to because these are realtime clock based
(
key,
(
EventType::Incremental {
start_time: prev_end,
stop_time: up_to,
},
val,
),
)
}
fn key(&self) -> &PageserverConsumptionMetricsKey {
&self.0
}
}
// the static part of a PageserverConsumptionMetricsKey
impl PageserverConsumptionMetricsKey {
const fn written_size(tenant_id: TenantId, timeline_id: TimelineId) -> AbsoluteValueFactory {
PageserverConsumptionMetricsKey {
tenant_id,
timeline_id: Some(timeline_id),
metric: "written_size",
}
.absolute_values()
}
/// Values will be the difference of the latest written_size (last_record_lsn) to what we
/// previously sent.
const fn written_size_delta(
tenant_id: TenantId,
timeline_id: TimelineId,
) -> IncrementalValueFactory {
PageserverConsumptionMetricsKey {
tenant_id,
timeline_id: Some(timeline_id),
metric: "written_size_bytes_delta",
}
.incremental_values()
}
const fn timeline_logical_size(
tenant_id: TenantId,
timeline_id: TimelineId,
) -> AbsoluteValueFactory {
PageserverConsumptionMetricsKey {
tenant_id,
timeline_id: Some(timeline_id),
metric: "timeline_logical_size",
}
.absolute_values()
}
const fn remote_storage_size(tenant_id: TenantId) -> AbsoluteValueFactory {
PageserverConsumptionMetricsKey {
tenant_id,
timeline_id: None,
metric: "remote_storage_size",
}
.absolute_values()
}
const fn resident_size(tenant_id: TenantId) -> AbsoluteValueFactory {
PageserverConsumptionMetricsKey {
tenant_id,
timeline_id: None,
metric: "resident_size",
}
.absolute_values()
}
const fn synthetic_size(tenant_id: TenantId) -> AbsoluteValueFactory {
PageserverConsumptionMetricsKey {
tenant_id,
timeline_id: None,
metric: "synthetic_storage_size",
}
.absolute_values()
}
}
/// Main thread that serves metrics collection
pub async fn collect_metrics(
metric_collection_endpoint: &Url,
@@ -79,7 +188,7 @@ pub async fn collect_metrics(
.timeout(DEFAULT_HTTP_REPORTING_TIMEOUT)
.build()
.expect("Failed to create http client with timeout");
let mut cached_metrics: HashMap<PageserverConsumptionMetricsKey, u64> = HashMap::new();
let mut cached_metrics = HashMap::new();
let mut prev_iteration_time: std::time::Instant = std::time::Instant::now();
loop {
@@ -121,13 +230,13 @@ pub async fn collect_metrics(
/// - refactor this function (chunking+sending part) to reuse it in proxy module;
pub async fn collect_metrics_iteration(
client: &reqwest::Client,
cached_metrics: &mut HashMap<PageserverConsumptionMetricsKey, u64>,
cached_metrics: &mut HashMap<PageserverConsumptionMetricsKey, (EventType, u64)>,
metric_collection_endpoint: &reqwest::Url,
node_id: NodeId,
ctx: &RequestContext,
send_cached: bool,
) {
let mut current_metrics: Vec<(PageserverConsumptionMetricsKey, u64)> = Vec::new();
let mut current_metrics: Vec<(PageserverConsumptionMetricsKey, (EventType, u64))> = Vec::new();
trace!(
"starting collect_metrics_iteration. metric_collection_endpoint: {}",
metric_collection_endpoint
@@ -166,27 +275,80 @@ pub async fn collect_metrics_iteration(
if timeline.is_active() {
let timeline_written_size = u64::from(timeline.get_last_record_lsn());
current_metrics.push((
PageserverConsumptionMetricsKey {
tenant_id,
timeline_id: Some(timeline.timeline_id),
metric: WRITTEN_SIZE,
let (key, written_size_now) =
PageserverConsumptionMetricsKey::written_size(tenant_id, timeline.timeline_id)
.now(timeline_written_size);
// last_record_lsn can only go up, right now at least, TODO: #2592 or related
// features might change this.
let written_size_delta_key = PageserverConsumptionMetricsKey::written_size_delta(
tenant_id,
timeline.timeline_id,
);
// use this when available, because in a stream of incremental values, it will be
// accurate where as when last_record_lsn stops moving, we will only cache the last
// one of those.
let last_stop_time =
cached_metrics
.get(written_size_delta_key.key())
.map(|(until, _val)| {
until
.incremental_timerange()
.expect("never create EventType::Absolute for written_size_delta")
.end
});
// by default, use the last sent written_size as the basis for
// calculating the delta. if we don't yet have one, use the load time value.
let prev = cached_metrics
.get(&key)
.map(|(prev_at, prev)| {
// use the prev time from our last incremental update, or default to latest
// absolute update on the first round.
let prev_at = prev_at
.absolute_time()
.expect("never create EventType::Incremental for written_size");
let prev_at = last_stop_time.unwrap_or(prev_at);
(*prev_at, *prev)
})
.unwrap_or_else(|| {
// if we don't have a previous point of comparison, compare to the load time
// lsn.
let (disk_consistent_lsn, loaded_at) = &timeline.loaded_at;
(DateTime::from(*loaded_at), disk_consistent_lsn.0)
});
// written_size_delta_bytes
current_metrics.extend(
if let Some(delta) = written_size_now.1.checked_sub(prev.1) {
let up_to = written_size_now
.0
.absolute_time()
.expect("never create EventType::Incremental for written_size");
let key_value =
written_size_delta_key.from_previous_up_to(prev.0, *up_to, delta);
Some(key_value)
} else {
None
},
timeline_written_size,
));
);
// written_size
current_metrics.push((key, written_size_now));
let span = info_span!("collect_metrics_iteration", tenant_id = %timeline.tenant_id, timeline_id = %timeline.timeline_id);
match span.in_scope(|| timeline.get_current_logical_size(ctx)) {
// Only send timeline logical size when it is fully calculated.
Ok((size, is_exact)) if is_exact => {
current_metrics.push((
PageserverConsumptionMetricsKey {
current_metrics.push(
PageserverConsumptionMetricsKey::timeline_logical_size(
tenant_id,
timeline_id: Some(timeline.timeline_id),
metric: TIMELINE_LOGICAL_SIZE,
},
size,
));
timeline.timeline_id,
)
.now(size),
);
}
Ok((_, _)) => {}
Err(err) => {
@@ -205,14 +367,10 @@ pub async fn collect_metrics_iteration(
match tenant.get_remote_size().await {
Ok(tenant_remote_size) => {
current_metrics.push((
PageserverConsumptionMetricsKey {
tenant_id,
timeline_id: None,
metric: REMOTE_STORAGE_SIZE,
},
tenant_remote_size,
));
current_metrics.push(
PageserverConsumptionMetricsKey::remote_storage_size(tenant_id)
.now(tenant_remote_size),
);
}
Err(err) => {
error!(
@@ -222,14 +380,9 @@ pub async fn collect_metrics_iteration(
}
}
current_metrics.push((
PageserverConsumptionMetricsKey {
tenant_id,
timeline_id: None,
metric: RESIDENT_SIZE,
},
tenant_resident_size,
));
current_metrics.push(
PageserverConsumptionMetricsKey::resident_size(tenant_id).now(tenant_resident_size),
);
// Note that this metric is calculated in a separate bgworker
// Here we only use cached value, which may lag behind the real latest one
@@ -237,23 +390,27 @@ pub async fn collect_metrics_iteration(
if tenant_synthetic_size != 0 {
// only send non-zeroes because otherwise these show up as errors in logs
current_metrics.push((
PageserverConsumptionMetricsKey {
tenant_id,
timeline_id: None,
metric: SYNTHETIC_STORAGE_SIZE,
},
tenant_synthetic_size,
));
current_metrics.push(
PageserverConsumptionMetricsKey::synthetic_size(tenant_id)
.now(tenant_synthetic_size),
);
}
}
// Filter metrics, unless we want to send all metrics, including cached ones.
// See: https://github.com/neondatabase/neon/issues/3485
if !send_cached {
current_metrics.retain(|(curr_key, curr_val)| match cached_metrics.get(curr_key) {
Some(val) => val != curr_val,
None => true,
current_metrics.retain(|(curr_key, (kind, curr_val))| {
if kind.is_incremental() {
// incremental values (currently only written_size_delta) should not get any cache
// deduplication because they will be used by upstream for "is still alive."
true
} else {
match cached_metrics.get(curr_key) {
Some((_, val)) => val != curr_val,
None => true,
}
}
});
}
@@ -272,8 +429,8 @@ pub async fn collect_metrics_iteration(
chunk_to_send.clear();
// enrich metrics with type,timestamp and idempotency key before sending
chunk_to_send.extend(chunk.iter().map(|(curr_key, curr_val)| Event {
kind: EventType::Absolute { time: Utc::now() },
chunk_to_send.extend(chunk.iter().map(|(curr_key, (when, curr_val))| Event {
kind: *when,
metric: curr_key.metric,
idempotency_key: idempotency_key(node_id.to_string()),
value: *curr_val,

View File

@@ -390,39 +390,42 @@ where
}
#[allow(dead_code)]
pub fn dump(&self) -> Result<()> {
self.dump_recurse(self.root_blk, &[], 0)
}
pub async fn dump(&self) -> Result<()> {
let mut stack = Vec::new();
fn dump_recurse(&self, blknum: u32, path: &[u8], depth: usize) -> Result<()> {
let blk = self.reader.read_blk(self.start_blk + blknum)?;
let buf: &[u8] = blk.as_ref();
stack.push((self.root_blk, String::new(), 0, 0, 0));
let node = OnDiskNode::<L>::deparse(buf)?;
while let Some((blknum, path, depth, child_idx, key_off)) = stack.pop() {
let blk = self.reader.read_blk(self.start_blk + blknum)?;
let buf: &[u8] = blk.as_ref();
let node = OnDiskNode::<L>::deparse(buf)?;
print!("{:indent$}", "", indent = depth * 2);
println!(
"blk #{}: path {}: prefix {}, suffix_len {}",
blknum,
hex::encode(path),
hex::encode(node.prefix),
node.suffix_len
);
if child_idx == 0 {
print!("{:indent$}", "", indent = depth * 2);
let path_prefix = stack
.iter()
.map(|(_blknum, path, ..)| path.as_str())
.collect::<String>();
println!(
"blk #{blknum}: path {path_prefix}{path}: prefix {}, suffix_len {}",
hex::encode(node.prefix),
node.suffix_len
);
}
let mut idx = 0;
let mut key_off = 0;
while idx < node.num_children {
if child_idx + 1 < node.num_children {
let key_off = key_off + node.suffix_len as usize;
stack.push((blknum, path.clone(), depth, child_idx + 1, key_off));
}
let key = &node.keys[key_off..key_off + node.suffix_len as usize];
let val = node.value(idx as usize);
let val = node.value(child_idx as usize);
print!("{:indent$}", "", indent = depth * 2 + 2);
println!("{}: {}", hex::encode(key), hex::encode(val.0));
if node.level > 0 {
let child_path = [path, node.prefix].concat();
self.dump_recurse(val.to_blknum(), &child_path, depth + 1)?;
stack.push((val.to_blknum(), hex::encode(node.prefix), depth + 1, 0, 0));
}
idx += 1;
key_off += node.suffix_len as usize;
}
Ok(())
}
@@ -754,8 +757,8 @@ mod tests {
}
}
#[test]
fn basic() -> Result<()> {
#[tokio::test]
async fn basic() -> Result<()> {
let mut disk = TestDisk::new();
let mut writer = DiskBtreeBuilder::<_, 6>::new(&mut disk);
@@ -775,7 +778,7 @@ mod tests {
let reader = DiskBtreeReader::new(0, root_offset, disk);
reader.dump()?;
reader.dump().await?;
// Test the `get` function on all the keys.
for (key, val) in all_data.iter() {
@@ -835,8 +838,8 @@ mod tests {
Ok(())
}
#[test]
fn lots_of_keys() -> Result<()> {
#[tokio::test]
async fn lots_of_keys() -> Result<()> {
let mut disk = TestDisk::new();
let mut writer = DiskBtreeBuilder::<_, 8>::new(&mut disk);
@@ -856,7 +859,7 @@ mod tests {
let reader = DiskBtreeReader::new(0, root_offset, disk);
reader.dump()?;
reader.dump().await?;
use std::sync::Mutex;
@@ -994,8 +997,8 @@ mod tests {
///
/// This test contains a particular data set, see disk_btree_test_data.rs
///
#[test]
fn particular_data() -> Result<()> {
#[tokio::test]
async fn particular_data() -> Result<()> {
// Build a tree from it
let mut disk = TestDisk::new();
let mut writer = DiskBtreeBuilder::<_, 26>::new(&mut disk);
@@ -1022,7 +1025,7 @@ mod tests {
})?;
assert_eq!(count, disk_btree_test_data::TEST_DATA.len());
reader.dump()?;
reader.dump().await?;
Ok(())
}

View File

@@ -223,6 +223,45 @@ mod tests {
assert_eq!(part, expected);
}
#[test]
fn v2_indexpart_is_parsed_with_deleted_at() {
let example = r#"{
"version":2,
"timeline_layers":["000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9"],
"missing_layers":["This shouldn't fail deserialization"],
"layer_metadata":{
"000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9": { "file_size": 25600000 },
"000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__00000000016B59D8-00000000016B5A51": { "file_size": 9007199254741001 }
},
"disk_consistent_lsn":"0/16960E8",
"metadata_bytes":[112,11,159,210,0,54,0,4,0,0,0,0,1,105,96,232,1,0,0,0,0,1,105,96,112,0,0,0,0,0,0,0,0,0,0,0,0,0,1,105,96,112,0,0,0,0,1,105,96,112,0,0,0,14,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
"deleted_at": "2023-07-31T09:00:00.123"
}"#;
let expected = IndexPart {
// note this is not verified, could be anything, but exists for humans debugging.. could be the git version instead?
version: 2,
timeline_layers: HashSet::from(["000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9".parse().unwrap()]),
layer_metadata: HashMap::from([
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9".parse().unwrap(), IndexLayerMetadata {
file_size: 25600000,
}),
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__00000000016B59D8-00000000016B5A51".parse().unwrap(), IndexLayerMetadata {
// serde_json should always parse this but this might be a double with jq for
// example.
file_size: 9007199254741001,
})
]),
disk_consistent_lsn: "0/16960E8".parse::<Lsn>().unwrap(),
metadata_bytes: [112,11,159,210,0,54,0,4,0,0,0,0,1,105,96,232,1,0,0,0,0,1,105,96,112,0,0,0,0,0,0,0,0,0,0,0,0,0,1,105,96,112,0,0,0,0,1,105,96,112,0,0,0,14,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0].to_vec(),
deleted_at: Some(chrono::NaiveDateTime::parse_from_str(
"2023-07-31T09:00:00.123000000", "%Y-%m-%dT%H:%M:%S.%f").unwrap())
};
let part = serde_json::from_str::<IndexPart>(example).unwrap();
assert_eq!(part, expected);
}
#[test]
fn empty_layers_are_parsed() {
let empty_layers_json = r#"{

View File

@@ -256,7 +256,7 @@ impl Layer for DeltaLayer {
file,
);
tree_reader.dump()?;
tree_reader.dump().await?;
let mut cursor = file.block_cursor();

View File

@@ -175,7 +175,7 @@ impl Layer for ImageLayer {
let tree_reader =
DiskBtreeReader::<_, KEY_SIZE>::new(inner.index_start_blk, inner.index_root_blk, file);
tree_reader.dump()?;
tree_reader.dump().await?;
tree_reader.visit(&[0u8; KEY_SIZE], VisitDirection::Forwards, |key, value| {
println!("key: {} offset {}", hex::encode(key), value);

View File

@@ -294,6 +294,10 @@ pub struct Timeline {
/// Completion shared between all timelines loaded during startup; used to delay heavier
/// background tasks until some logical sizes have been calculated.
initial_logical_size_attempt: Mutex<Option<completion::Completion>>,
/// Load or creation time information about the disk_consistent_lsn and when the loading
/// happened. Used for consumption metrics.
pub(crate) loaded_at: (Lsn, SystemTime),
}
pub struct WalReceiverInfo {
@@ -1404,6 +1408,8 @@ impl Timeline {
last_freeze_at: AtomicLsn::new(disk_consistent_lsn.0),
last_freeze_ts: RwLock::new(Instant::now()),
loaded_at: (disk_consistent_lsn, SystemTime::now()),
ancestor_timeline: ancestor,
ancestor_lsn: metadata.ancestor_lsn(),
@@ -1600,7 +1606,7 @@ impl Timeline {
if let Some(imgfilename) = ImageFileName::parse_str(&fname) {
// create an ImageLayer struct for each image file.
if imgfilename.lsn > disk_consistent_lsn {
warn!(
info!(
"found future image layer {} on timeline {} disk_consistent_lsn is {}",
imgfilename, self.timeline_id, disk_consistent_lsn
);
@@ -1632,7 +1638,7 @@ impl Timeline {
// is 102, then it might not have been fully flushed to disk
// before crash.
if deltafilename.lsn_range.end > disk_consistent_lsn + 1 {
warn!(
info!(
"found future delta layer {} on timeline {} disk_consistent_lsn is {}",
deltafilename, self.timeline_id, disk_consistent_lsn
);
@@ -1774,7 +1780,7 @@ impl Timeline {
match remote_layer_name {
LayerFileName::Image(imgfilename) => {
if imgfilename.lsn > up_to_date_disk_consistent_lsn {
warn!(
info!(
"found future image layer {} on timeline {} remote_consistent_lsn is {}",
imgfilename, self.timeline_id, up_to_date_disk_consistent_lsn
);
@@ -1799,7 +1805,7 @@ impl Timeline {
// is 102, then it might not have been fully flushed to disk
// before crash.
if deltafilename.lsn_range.end > up_to_date_disk_consistent_lsn + 1 {
warn!(
info!(
"found future delta layer {} on timeline {} remote_consistent_lsn is {}",
deltafilename, self.timeline_id, up_to_date_disk_consistent_lsn
);

View File

@@ -88,50 +88,16 @@ static LWLockId lfc_lock;
static int lfc_max_size;
static int lfc_size_limit;
static int lfc_free_space_watermark;
static int lfc_free_memory_watermark;
static char* lfc_path;
static FileCacheControl* lfc_ctl;
static shmem_startup_hook_type prev_shmem_startup_hook;
#if PG_VERSION_NUM>=150000
static shmem_request_hook_type prev_shmem_request_hook;
#endif
static int lfc_shrinking_factor; /* power of two by which local cache size will be shrinked when lfc_free_space_watermark or lfc_free_memory_watermak are reached */
static int lfc_shrinking_factor; /* power of two by which local cache size will be shrinked when lfc_free_space_watermark is reached */
void FileCacheMonitorMain(Datum main_arg);
#ifdef __APPLE__
#include <sys/types.h>
#include <sys/sysctl.h>
static size_t
get_available_memory(void)
{
size_t total;
size_t sizeof_total = sizeof(total);
if (sysctlbyname("hw.memsize", &total, &sizeof_total, NULL, 0) < 0)
elog(ERROR, "Failed to get amount of RAM: %m");
return total;
}
#else
#include <sys/sysinfo.h>
static size_t
get_available_memory(void)
{
struct sysinfo si;
if (sysinfo(&si) < 0)
elog(ERROR, "Failed to get amount of RAM: %m");
return si.totalram*si.mem_unit;
}
#endif
static void
lfc_shmem_startup(void)
{
@@ -229,11 +195,10 @@ lfc_change_limit_hook(int newval, void *extra)
}
/*
* Local file system state monitor check available free space and memory.
* If available disk space is lower than lfc_free_space_watermark or
* available memory is lower than lfc_free_memory_watermark then we shrink size of local cache
* Local file system state monitor check available free space.
* If it is lower than lfc_free_space_watermark then we shrink size of local cache
* but throwing away least recently accessed chunks.
* First time the watermark is reached cache size is divided by two,
* First time low space watermark is reached cache size is divided by two,
* second time by four,... Finally we remove all chunks from local cache.
*
* Please notice that we are not changing lfc_cache_size: it is used to be adjusted by autoscaler.
@@ -263,27 +228,23 @@ FileCacheMonitorMain(Datum main_arg)
{
if (lfc_size_limit != 0)
{
bool shrink_cache = false;
if (lfc_free_space_watermark != 0)
struct statvfs sfs;
if (statvfs(lfc_path, &sfs) < 0)
{
struct statvfs sfs;
if (statvfs(lfc_path, &sfs) < 0)
elog(WARNING, "Failed to obtain status of %s: %m", lfc_path);
else
shrink_cache |= sfs.f_bavail*sfs.f_bsize < lfc_free_space_watermark*MB;
}
if (lfc_free_memory_watermark != 0)
shrink_cache |= get_available_memory() < lfc_free_memory_watermark*MB;
if (shrink_cache)
{
if (lfc_shrinking_factor < 31) {
lfc_shrinking_factor += 1;
}
lfc_change_limit_hook(lfc_size_limit >> lfc_shrinking_factor, NULL);
elog(WARNING, "Failed to obtain status of %s: %m", lfc_path);
}
else
lfc_shrinking_factor = 0; /* reset to initial value */
{
if (sfs.f_bavail*sfs.f_bsize < lfc_free_space_watermark*MB)
{
if (lfc_shrinking_factor < 31) {
lfc_shrinking_factor += 1;
}
lfc_change_limit_hook(lfc_size_limit >> lfc_shrinking_factor, NULL);
}
else
lfc_shrinking_factor = 0; /* reset to initial value */
}
}
pg_usleep(monitor_interval);
}
@@ -356,19 +317,6 @@ lfc_init(void)
NULL,
NULL);
DefineCustomIntVariable("neon.free_memory_watermark",
"Minimal free memory in system after reaching which local file cache will be truncated",
NULL,
&lfc_free_memory_watermark,
0, /* disabled by default, because iurt makes sense only when local file cache is located i tmpfs */
0,
INT_MAX,
PGC_SIGHUP,
GUC_UNIT_MB,
NULL,
NULL,
NULL);
DefineCustomStringVariable("neon.file_cache_path",
"Path to local file cache (can be raw device)",
NULL,

View File

@@ -29,9 +29,9 @@ metrics.workspace = true
once_cell.workspace = true
opentelemetry.workspace = true
parking_lot.workspace = true
pbkdf2.workspace = true
pin-project-lite.workspace = true
postgres_backend.workspace = true
postgres-protocol.workspace = true
pq_proto.workspace = true
prometheus.workspace = true
rand.workspace = true
@@ -65,10 +65,13 @@ webpki-roots.workspace = true
x509-parser.workspace = true
native-tls.workspace = true
postgres-native-tls.workspace = true
tokio-native-tls = "0.3.1"
workspace_hack.workspace = true
tokio-util.workspace = true
fallible-iterator = "0.2.0"
[dev-dependencies]
rcgen.workspace = true
rstest.workspace = true

View File

@@ -5,7 +5,7 @@ use crate::{
auth::{self, AuthFlow, ClientCredentials},
compute,
console::{self, AuthInfo, CachedNodeInfo, ConsoleReqExtra},
proxy::{try_wake, NUM_RETRIES_CONNECT},
proxy::handle_try_wake,
sasl, scram,
stream::PqStream,
};
@@ -51,14 +51,15 @@ pub(super) async fn authenticate(
}
};
info!("compute node's state has likely changed; requesting a wake-up");
let mut num_retries = 0;
let mut node = loop {
num_retries += 1;
match try_wake(api, extra, creds).await? {
let wake_res = api.wake_compute(extra, creds).await;
match handle_try_wake(wake_res, num_retries)? {
ControlFlow::Continue(_) => num_retries += 1,
ControlFlow::Break(n) => break n,
ControlFlow::Continue(_) if num_retries < NUM_RETRIES_CONNECT => continue,
ControlFlow::Continue(e) => return Err(e.into()),
}
info!(num_retries, "retrying wake compute");
};
if let Some(keys) = scram_keys {
use tokio_postgres::config::AuthKeys;

View File

@@ -6,7 +6,7 @@ use std::fmt;
use std::{collections::HashMap, sync::Arc};
use tokio::time;
use crate::{auth, console};
use crate::{auth, console, pg_client};
use crate::{compute, config};
use super::sql_over_http::MAX_RESPONSE_SIZE;
@@ -41,8 +41,10 @@ impl fmt::Display for ConnInfo {
}
}
type PgConn =
pg_client::connection::Connection<tokio_postgres::Socket, tokio_postgres::tls::NoTlsStream>;
struct ConnPoolEntry {
conn: tokio_postgres::Client,
conn: PgConn,
_last_access: std::time::Instant,
}
@@ -78,12 +80,8 @@ impl GlobalConnPool {
})
}
pub async fn get(
&self,
conn_info: &ConnInfo,
force_new: bool,
) -> anyhow::Result<tokio_postgres::Client> {
let mut client: Option<tokio_postgres::Client> = None;
pub async fn get(&self, conn_info: &ConnInfo, force_new: bool) -> anyhow::Result<PgConn> {
let mut client: Option<PgConn> = None;
if !force_new {
let pool = self.get_endpoint_pool(&conn_info.hostname).await;
@@ -114,11 +112,7 @@ impl GlobalConnPool {
}
}
pub async fn put(
&self,
conn_info: &ConnInfo,
client: tokio_postgres::Client,
) -> anyhow::Result<()> {
pub async fn put(&self, conn_info: &ConnInfo, client: PgConn) -> anyhow::Result<()> {
let pool = self.get_endpoint_pool(&conn_info.hostname).await;
// return connection to the pool
@@ -191,7 +185,7 @@ struct TokioMechanism<'a> {
#[async_trait]
impl ConnectMechanism for TokioMechanism<'_> {
type Connection = tokio_postgres::Client;
type Connection = PgConn;
type ConnectError = tokio_postgres::Error;
type Error = anyhow::Error;
@@ -213,7 +207,7 @@ impl ConnectMechanism for TokioMechanism<'_> {
async fn connect_to_compute(
config: &config::ProxyConfig,
conn_info: &ConnInfo,
) -> anyhow::Result<tokio_postgres::Client> {
) -> anyhow::Result<PgConn> {
let tls = config.tls_config.as_ref();
let common_names = tls.and_then(|tls| tls.common_names.clone());
@@ -251,7 +245,7 @@ async fn connect_to_compute_once(
node_info: &console::CachedNodeInfo,
conn_info: &ConnInfo,
timeout: time::Duration,
) -> Result<tokio_postgres::Client, tokio_postgres::Error> {
) -> Result<PgConn, tokio_postgres::Error> {
let mut config = (*node_info.config).clone();
let (client, connection) = config
@@ -263,11 +257,13 @@ async fn connect_to_compute_once(
.connect(tokio_postgres::NoTls)
.await?;
tokio::spawn(async move {
if let Err(e) = connection.await {
error!("connection error: {}", e);
}
});
let stream = connection.stream.into_inner();
Ok(client)
// tokio::spawn(async move {
// if let Err(e) = connection.await {
// error!("connection error: {}", e);
// }
// });
Ok(PgConn::new(stream))
}

View File

@@ -1,20 +1,38 @@
use std::io::ErrorKind;
use std::sync::Arc;
use anyhow::bail;
use bytes::BufMut;
use fallible_iterator::FallibleIterator;
use futures::pin_mut;
use futures::StreamExt;
use hashbrown::HashMap;
use hyper::body::HttpBody;
use hyper::http::HeaderName;
use hyper::http::HeaderValue;
use hyper::{Body, HeaderMap, Request};
use postgres_protocol::message::backend::DataRowBody;
use postgres_protocol::message::backend::ReadyForQueryBody;
use serde_json::json;
use serde_json::Map;
use serde_json::Value;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio_postgres::types::Kind;
use tokio_postgres::types::Type;
use tokio_postgres::GenericClient;
use tokio_postgres::IsolationLevel;
use tokio_postgres::Row;
use tokio_postgres::RowStream;
use tokio_postgres::Statement;
use url::Url;
use crate::pg_client;
use crate::pg_client::codec::FrontendMessage;
use crate::pg_client::connection;
use crate::pg_client::connection::RequestMessages;
use crate::pg_client::prepare::TypeinfoPreparedQueries;
use super::conn_pool::ConnInfo;
use super::conn_pool::GlobalConnPool;
@@ -37,6 +55,8 @@ const MAX_REQUEST_SIZE: u64 = 1024 * 1024; // 1 MB
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
@@ -170,7 +190,7 @@ pub async fn handle(
request: Request<Body>,
sni_hostname: Option<String>,
conn_pool: Arc<GlobalConnPool>,
) -> anyhow::Result<Value> {
) -> anyhow::Result<(Value, HashMap<HeaderName, HeaderValue>)> {
//
// Determine the destination and connection params
//
@@ -185,6 +205,23 @@ pub async fn handle(
// Allow connection pooling only if explicitly requested
let allow_pool = headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);
// isolation level and read only
let txn_isolation_level_raw = headers.get(&TXN_ISOLATION_LEVEL).cloned();
let txn_isolation_level = match txn_isolation_level_raw {
Some(ref x) => Some(match x.as_bytes() {
b"Serializable" => IsolationLevel::Serializable,
b"ReadUncommitted" => IsolationLevel::ReadUncommitted,
b"ReadCommitted" => IsolationLevel::ReadCommitted,
b"RepeatableRead" => IsolationLevel::RepeatableRead,
_ => bail!("invalid isolation level"),
}),
None => None,
};
let txn_read_only_raw = headers.get(&TXN_READ_ONLY).cloned();
let txn_read_only = txn_read_only_raw.as_ref() == Some(&HEADER_VALUE_TRUE);
let request_content_length = match request.body().size_hint().upper() {
Some(v) => v,
None => MAX_REQUEST_SIZE + 1,
@@ -208,26 +245,48 @@ pub async fn handle(
// Now execute the query and return the result
//
let result = match payload {
Payload::Single(query) => query_to_json(&client, query, raw_output, array_mode).await,
Payload::Single(query) => query_raw_txt_as_json(&mut client, query, raw_output, array_mode)
.await
.map(|x| (x, HashMap::default())),
Payload::Batch(queries) => {
let mut results = Vec::new();
let transaction = client.transaction().await?;
client
.start_tx(txn_isolation_level, Some(txn_read_only))
.await?;
for query in queries {
let result = query_to_json(&transaction, query, raw_output, array_mode).await;
let result =
query_raw_txt_as_json(&mut client, query, raw_output, array_mode).await;
match result {
Ok(r) => results.push(r),
// TODO: check this tag to see if the client has executed a commit during the non-interactive transactions...
Ok((r, _ready_tag)) => results.push(r),
Err(e) => {
transaction.rollback().await?;
let tag = client.rollback().await?;
if allow_pool && tag.status() == b'I' {
// return connection to the pool
tokio::task::spawn(async move {
let _ = conn_pool.put(&conn_info, client).await;
});
}
return Err(e);
}
}
}
transaction.commit().await?;
Ok(json!({ "results": results }))
let ready_tag = client.commit().await?;
let mut headers = HashMap::default();
headers.insert(
TXN_READ_ONLY.clone(),
HeaderValue::try_from(txn_read_only.to_string())?,
);
if let Some(txn_isolation_level_raw) = txn_isolation_level_raw {
headers.insert(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level_raw);
}
Ok(((json!({ "results": results }), ready_tag), headers))
}
};
if allow_pool {
if allow_pool && ready_tag.status() == b'I' {
// return connection to the pool
tokio::task::spawn(async move {
let _ = conn_pool.put(&conn_info, client).await;
@@ -312,6 +371,99 @@ async fn query_to_json<T: GenericClient>(
}))
}
async fn query_raw_txt_as_json<'a, St, T>(
conn: &mut connection::Connection<St, T>,
data: QueryData,
raw_output: bool,
array_mode: bool,
) -> anyhow::Result<(Value, ReadyForQueryBody)>
where
St: AsyncRead + AsyncWrite + Unpin + Send,
T: AsyncRead + AsyncWrite + Unpin + Send,
{
let params = json_to_pg_text(data.params)?;
let params = params.into_iter();
let stmt_name = conn.statement_name();
let row_description = conn.prepare(&stmt_name, &data.query).await?;
let mut fields = vec![];
let mut columns = vec![];
let mut it = row_description.fields();
while let Some(field) = it.next().map_err(pg_client::error::Error::parse)? {
fields.push(json!({
"name": Value::String(field.name().to_owned()),
"dataTypeID": Value::Number(field.type_oid().into()),
"tableID": field.table_oid(),
"columnID": field.column_id(),
"dataTypeSize": field.type_size(),
"dataTypeModifier": field.type_modifier(),
"format": "text",
}));
let type_ = match Type::from_oid(field.type_oid()) {
Some(t) => t,
None => TypeinfoPreparedQueries::get_type(conn, field.type_oid()).await?,
};
columns.push(Column {
name: field.name().to_string(),
type_,
});
}
conn.execute("", &stmt_name, params)?;
conn.sync().await?;
let mut rows = vec![];
let mut row_stream = conn.stream_query_results().await?;
let mut curret_size = 0;
while let Some(row) = row_stream.next().await.transpose()? {
// let row = row.map_err(Error::db)?;
curret_size += row.buffer().len();
if curret_size > MAX_RESPONSE_SIZE {
return Err(anyhow::anyhow!("response too large"));
}
rows.push(pg_text_row_to_json2(&row, &columns, raw_output, array_mode).unwrap());
}
let command_tag = row_stream.tag();
let command_tag = command_tag.tag()?;
let mut command_tag_split = command_tag.split(' ');
let command_tag_name = command_tag_split.next().unwrap_or_default();
let command_tag_count = if command_tag_name == "INSERT" {
// INSERT returns OID first and then number of rows
command_tag_split.nth(1)
} else {
// other commands return number of rows (if any)
command_tag_split.next()
}
.and_then(|s| s.parse::<i64>().ok());
let ready_tag = conn.wait_for_ready().await?;
// resulting JSON format is based on the format of node-postgres result
Ok((
json!({
"command": command_tag_name,
"rowCount": command_tag_count,
"rows": rows,
"fields": fields,
"rowAsArray": array_mode,
}),
ready_tag,
))
}
struct Column {
name: String,
type_: Type,
}
//
// Convert postgres row with text-encoded values to JSON object
//
@@ -331,7 +483,7 @@ pub fn pg_text_row_to_json(
} else {
pg_text_to_json(pg_value, column.type_())?
};
Ok((name.to_string(), json_value))
Ok((name, json_value))
});
if array_mode {
@@ -341,7 +493,55 @@ pub fn pg_text_row_to_json(
.collect::<Result<Vec<Value>, anyhow::Error>>()?;
Ok(Value::Array(arr))
} else {
let obj = iter.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
let obj = iter
.map(|r| r.map(|(key, val)| (key.to_owned(), val)))
.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
Ok(Value::Object(obj))
}
}
//
// Convert postgres row with text-encoded values to JSON object
//
fn pg_text_row_to_json2(
row: &DataRowBody,
columns: &[Column],
raw_output: bool,
array_mode: bool,
) -> Result<Value, anyhow::Error> {
let ranges: Vec<Option<std::ops::Range<usize>>> = row.ranges().collect()?;
let iter = std::iter::zip(ranges, columns)
.enumerate()
.map(|(i, (range, column))| {
let name = &column.name;
let pg_value = range
.map(|r| {
std::str::from_utf8(&row.buffer()[r])
.map_err(|e| pg_client::error::Error::from_sql(e.into(), i))
})
.transpose()?;
// let pg_value = row.as_text(i)?;
let json_value = if raw_output {
match pg_value {
Some(v) => Value::String(v.to_string()),
None => Value::Null,
}
} else {
pg_text_to_json(pg_value, &column.type_)?
};
Ok((name, json_value))
});
if array_mode {
// drop keys and aggregate into array
let arr = iter
.map(|r| r.map(|(_key, val)| val))
.collect::<Result<Vec<Value>, anyhow::Error>>()?;
Ok(Value::Array(arr))
} else {
let obj = iter
.map(|r| r.map(|(key, val)| (key.to_owned(), val)))
.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
Ok(Value::Object(obj))
}
}
@@ -352,16 +552,16 @@ pub fn pg_text_row_to_json(
pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, anyhow::Error> {
if let Some(val) = pg_value {
if let Kind::Array(elem_type) = pg_type.kind() {
return pg_array_parse(val, elem_type);
return pg_array_parse(val, &elem_type);
}
match *pg_type {
Type::BOOL => Ok(Value::Bool(val == "t")),
Type::INT2 | Type::INT4 => {
match pg_type {
&Type::BOOL => Ok(Value::Bool(val == "t")),
&Type::INT2 | &Type::INT4 => {
let val = val.parse::<i32>()?;
Ok(Value::Number(serde_json::Number::from(val)))
}
Type::FLOAT4 | Type::FLOAT8 => {
&Type::FLOAT4 | &Type::FLOAT8 => {
let fval = val.parse::<f64>()?;
let num = serde_json::Number::from_f64(fval);
if let Some(num) = num {
@@ -373,7 +573,7 @@ pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value,
Ok(Value::String(val.to_string()))
}
}
Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?),
&Type::JSON | &Type::JSONB => Ok(serde_json::from_str(val)?),
_ => Ok(Value::String(val.to_string())),
}
} else {

View File

@@ -6,6 +6,7 @@ use crate::{
};
use bytes::{Buf, Bytes};
use futures::{Sink, Stream, StreamExt};
use hashbrown::HashMap;
use hyper::{
server::{
accept,
@@ -205,7 +206,7 @@ async fn ws_handler(
Ok(_) => StatusCode::OK,
Err(_) => StatusCode::BAD_REQUEST,
};
let json = match result {
let (json, headers) = match result {
Ok(r) => r,
Err(e) => {
let message = format!("{:?}", e);
@@ -216,7 +217,10 @@ async fn ws_handler(
},
None => Value::Null,
};
json!({ "message": message, "code": code })
(
json!({ "message": message, "code": code }),
HashMap::default(),
)
}
};
json_response(status_code, json).map(|mut r| {
@@ -224,6 +228,9 @@ async fn ws_handler(
"Access-Control-Allow-Origin",
hyper::http::HeaderValue::from_static("*"),
);
for (k, v) in headers {
r.headers_mut().insert(k, v);
}
r
})
} else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {

View File

@@ -22,6 +22,7 @@ pub mod scram;
pub mod stream;
pub mod url;
pub mod waiters;
pub mod pg_client;
/// Handle unix signals appropriately.
pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<Infallible> {

View File

@@ -0,0 +1,43 @@
use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use postgres_protocol::message::backend::{self, Message};
use std::io;
use tokio_util::codec::{Decoder, Encoder};
pub struct FrontendMessage(pub Bytes);
pub struct BackendMessages(pub BytesMut);
impl BackendMessages {
pub fn empty() -> BackendMessages {
BackendMessages(BytesMut::new())
}
}
impl FallibleIterator for BackendMessages {
type Item = backend::Message;
type Error = io::Error;
fn next(&mut self) -> io::Result<Option<backend::Message>> {
backend::Message::parse(&mut self.0)
}
}
pub struct PostgresCodec;
impl Encoder<FrontendMessage> for PostgresCodec {
type Error = io::Error;
fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> {
dst.extend_from_slice(&item.0);
Ok(())
}
}
impl Decoder for PostgresCodec {
type Item = Message;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Message>, io::Error> {
Message::parse(src)
}
}

View File

@@ -0,0 +1,369 @@
use super::codec::{BackendMessages, FrontendMessage, PostgresCodec};
use super::error::Error;
use super::prepare::TypeinfoPreparedQueries;
use bytes::{BufMut, BytesMut};
use futures::channel::mpsc;
use futures::{Sink, StreamExt};
use futures::{SinkExt, Stream};
use hashbrown::HashMap;
use postgres_protocol::message::backend::{
BackendKeyDataBody, CommandCompleteBody, DataRowBody, ErrorResponseBody, Message,
ReadyForQueryBody, RowDescriptionBody,
};
use postgres_protocol::message::frontend;
use postgres_protocol::Oid;
use std::collections::VecDeque;
use std::future::poll_fn;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::maybe_tls_stream::MaybeTlsStream;
use tokio_postgres::types::Type;
use tokio_postgres::IsolationLevel;
use tokio_util::codec::Framed;
pub enum RequestMessages {
Single(FrontendMessage),
}
pub struct Request {
pub messages: RequestMessages,
pub sender: mpsc::Sender<BackendMessages>,
}
pub struct Response {
sender: mpsc::Sender<BackendMessages>,
}
/// A connection to a PostgreSQL database.
pub struct RawConnection<S, T> {
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pending_responses: VecDeque<Message>,
pub buf: BytesMut,
}
impl<S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin> RawConnection<S, T> {
pub fn new(
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
buf: BytesMut,
) -> RawConnection<S, T> {
RawConnection {
stream,
pending_responses: VecDeque::new(),
buf,
}
}
pub async fn send(&mut self) -> Result<(), Error> {
poll_fn(|cx| self.poll_send(cx)).await?;
let request = FrontendMessage(self.buf.split().freeze());
self.stream.start_send_unpin(request).map_err(Error::io)?;
poll_fn(|cx| self.poll_flush(cx)).await
}
pub async fn next_message(&mut self) -> Result<Message, Error> {
match self.pending_responses.pop_front() {
Some(message) => Ok(message),
None => poll_fn(|cx| self.poll_read(cx)).await,
}
}
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
let message = match ready!(self.stream.poll_next_unpin(cx)?) {
Some(message) => message,
None => return Poll::Ready(Err(Error::closed())),
};
Poll::Ready(Ok(message))
}
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
Pin::new(&mut self.stream).poll_close(cx).map_err(Error::io)
}
fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
if let Poll::Ready(msg) = self.poll_read(cx)? {
self.pending_responses.push_back(msg);
};
self.stream.poll_ready_unpin(cx).map_err(Error::io)
}
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
if let Poll::Ready(msg) = self.poll_read(cx)? {
self.pending_responses.push_back(msg);
};
self.stream.poll_flush_unpin(cx).map_err(Error::io)
}
}
pub struct Connection<S, T> {
stmt_counter: usize,
pub typeinfo: Option<TypeinfoPreparedQueries>,
pub typecache: HashMap<Oid, Type>,
pub raw: RawConnection<S, T>,
// key: BackendKeyDataBody,
}
impl<S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin> Connection<S, T> {
pub fn new(stream: MaybeTlsStream<S, T>) -> Connection<S, T> {
Connection {
stmt_counter: 0,
typeinfo: None,
typecache: HashMap::new(),
raw: RawConnection::new(Framed::new(stream, PostgresCodec), BytesMut::new()),
}
}
pub async fn start_tx(
&mut self,
isolation_level: Option<IsolationLevel>,
read_only: Option<bool>,
) -> Result<ReadyForQueryBody, Error> {
let mut query = "START TRANSACTION".to_string();
let mut first = true;
if let Some(level) = isolation_level {
first = false;
query.push_str(" ISOLATION LEVEL ");
let level = match level {
IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
IsolationLevel::ReadCommitted => "READ COMMITTED",
IsolationLevel::RepeatableRead => "REPEATABLE READ",
IsolationLevel::Serializable => "SERIALIZABLE",
_ => return Err(Error::unexpected_message()),
};
query.push_str(level);
}
if let Some(read_only) = read_only {
if !first {
query.push(',');
}
first = false;
let s = if read_only {
" READ ONLY"
} else {
" READ WRITE"
};
query.push_str(s);
}
self.execute_simple(&query).await
}
pub async fn rollback(&mut self) -> Result<ReadyForQueryBody, Error> {
self.execute_simple("ROLLBACK").await
}
pub async fn commit(&mut self) -> Result<ReadyForQueryBody, Error> {
self.execute_simple("COMMIT").await
}
// pub async fn auth_sasl_scram<'a, I>(
// mut raw: RawConnection<S, T>,
// params: I,
// password: &[u8],
// ) -> Result<Self, Error>
// where
// I: IntoIterator<Item = (&'a str, &'a str)>,
// {
// // send a startup message
// frontend::startup_message(params, &mut raw.buf).unwrap();
// raw.send().await?;
// // expect sasl authentication message
// let Message::AuthenticationSasl(body) = raw.next_message().await? else { return Err(Error::expecting("sasl authentication")) };
// // expect support for SCRAM_SHA_256
// if body
// .mechanisms()
// .find(|&x| Ok(x == authentication::sasl::SCRAM_SHA_256))?
// .is_none()
// {
// return Err(Error::expecting("SCRAM-SHA-256 auth"));
// }
// // initiate SCRAM_SHA_256 authentication without channel binding
// let auth = authentication::sasl::ChannelBinding::unrequested();
// let mut scram = authentication::sasl::ScramSha256::new(password, auth);
// frontend::sasl_initial_response(
// authentication::sasl::SCRAM_SHA_256,
// scram.message(),
// &mut raw.buf,
// )
// .unwrap();
// raw.send().await?;
// // expect sasl continue
// let Message::AuthenticationSaslContinue(b) = raw.next_message().await? else { return Err(Error::expecting("auth continue")) };
// scram.update(b.data()).unwrap();
// // continue sasl
// frontend::sasl_response(scram.message(), &mut raw.buf).unwrap();
// raw.send().await?;
// // expect sasl final
// let Message::AuthenticationSaslFinal(b) = raw.next_message().await? else { return Err(Error::expecting("auth final")) };
// scram.finish(b.data()).unwrap();
// // expect auth ok
// let Message::AuthenticationOk = raw.next_message().await? else { return Err(Error::expecting("auth ok")) };
// // expect connection accepted
// let key = loop {
// match raw.next_message().await? {
// Message::BackendKeyData(key) => break key,
// Message::ParameterStatus(_) => {}
// _ => return Err(Error::expecting("backend ready")),
// }
// };
// let Message::ReadyForQuery(b) = raw.next_message().await? else { return Err(Error::expecting("ready for query")) };
// // assert_eq!(b.status(), b'I');
// Ok(Self { raw, key })
// }
// pub fn prepare_and_execute(
// &mut self,
// portal: &str,
// name: &str,
// query: &str,
// params: impl IntoIterator<Item = Option<impl AsRef<str>>>,
// ) -> std::io::Result<()> {
// self.prepare(name, query)?;
// self.execute(portal, name, params)
// }
pub fn statement_name(&mut self) -> String {
self.stmt_counter += 1;
format!("s{}", self.stmt_counter)
}
async fn execute_simple(&mut self, query: &str) -> Result<ReadyForQueryBody, Error> {
frontend::query(query, &mut self.raw.buf)?;
self.raw.send().await?;
loop {
match self.raw.next_message().await? {
Message::ReadyForQuery(q) => return Ok(q),
Message::CommandComplete(_)
| Message::EmptyQueryResponse
| Message::RowDescription(_)
| Message::DataRow(_) => {}
_ => return Err(Error::unexpected_message()),
}
}
}
pub async fn prepare(&mut self, name: &str, query: &str) -> Result<RowDescriptionBody, Error> {
frontend::parse(name, query, std::iter::empty(), &mut self.raw.buf)?;
frontend::describe(b'S', name, &mut self.raw.buf)?;
self.sync().await?;
self.wait_for_prepare().await
}
pub fn execute(
&mut self,
portal: &str,
name: &str,
params: impl IntoIterator<Item = Option<impl AsRef<str>>>,
) -> std::io::Result<()> {
frontend::bind(
portal,
name,
std::iter::empty(), // all parameters use the default format (text)
params,
|param, buf| match param {
Some(param) => {
buf.put_slice(param.as_ref().as_bytes());
Ok(postgres_protocol::IsNull::No)
}
None => Ok(postgres_protocol::IsNull::Yes),
},
Some(0), // all text
&mut self.raw.buf,
)
.map_err(|e| match e {
frontend::BindError::Conversion(e) => std::io::Error::new(std::io::ErrorKind::Other, e),
frontend::BindError::Serialization(io) => io,
})?;
frontend::execute(portal, 0, &mut self.raw.buf)
}
pub async fn sync(&mut self) -> Result<(), Error> {
frontend::sync(&mut self.raw.buf);
self.raw.send().await
}
pub async fn wait_for_prepare(&mut self) -> Result<RowDescriptionBody, Error> {
let Message::ParseComplete = self.raw.next_message().await? else { return Err(Error::expecting("parse")) };
let Message::ParameterDescription(_) = self.raw.next_message().await? else { return Err(Error::expecting("param description")) };
let Message::RowDescription(desc) = self.raw.next_message().await? else { return Err(Error::expecting("row description")) };
self.wait_for_ready().await?;
Ok(desc)
}
pub async fn stream_query_results(&mut self) -> Result<RowStream<'_, S, T>, Error> {
// let Message::ParseComplete = self.raw.next_message().await? else { return Err(Error::expecting("parse")) };
let Message::BindComplete = self.raw.next_message().await? else { return Err(Error::expecting("bind")) };
Ok(RowStream::Stream(&mut self.raw))
}
pub async fn wait_for_ready(&mut self) -> Result<ReadyForQueryBody, Error> {
loop {
match self.raw.next_message().await.unwrap() {
Message::ReadyForQuery(b) => break Ok(b),
_ => continue,
}
}
}
}
pub enum RowStream<'a, S, T> {
Stream(&'a mut RawConnection<S, T>),
Complete(Option<CommandCompleteBody>),
}
impl<S, T> Unpin for RowStream<'_, S, T> {}
impl<S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin> Stream
for RowStream<'_, S, T>
{
// this is horrible - first result is for transport/protocol errors errors
// second result is for sql errors.
type Item = Result<Result<DataRowBody, ErrorResponseBody>, Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match &mut *self {
RowStream::Stream(raw) => match ready!(raw.poll_read(cx)?) {
Message::DataRow(row) => Poll::Ready(Some(Ok(Ok(row)))),
Message::CommandComplete(tag) => {
*self = Self::Complete(Some(tag));
Poll::Ready(None)
}
Message::EmptyQueryResponse | Message::PortalSuspended => {
*self = Self::Complete(None);
Poll::Ready(None)
}
Message::ErrorResponse(error) => {
*self = Self::Complete(None);
Poll::Ready(Some(Ok(Err(error))))
}
_ => Poll::Ready(Some(Err(Error::expecting("command completion")))),
},
RowStream::Complete(_) => Poll::Ready(None),
}
}
}
impl<S, T> RowStream<'_, S, T> {
pub fn tag(self) -> Option<CommandCompleteBody> {
match self {
RowStream::Stream(_) => panic!("should not get tag unless row stream is exhausted"),
RowStream::Complete(tag) => tag,
}
}
}

View File

@@ -0,0 +1,447 @@
use std::{error, fmt, io};
use fallible_iterator::FallibleIterator;
use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody};
use tokio_native_tls::native_tls;
use tokio_postgres::error::{ErrorPosition, SqlState};
#[derive(Debug, PartialEq)]
enum Kind {
Io,
Tls,
UnexpectedMessage,
FromSql(usize),
Closed,
Db,
Parse,
Encode,
}
struct ErrorInner {
kind: Kind,
cause: Option<Box<dyn error::Error + Sync + Send>>,
}
/// An error communicating with the Postgres server.
pub struct Error(ErrorInner);
impl fmt::Debug for Error {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Error")
.field("kind", &self.0.kind)
.field("cause", &self.0.cause)
.finish()
}
}
impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.0.kind {
Kind::Io => fmt.write_str("error communicating with the server")?,
Kind::Tls => fmt.write_str("error establishing tls")?,
Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?,
Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?,
Kind::Closed => fmt.write_str("connection closed")?,
Kind::Db => fmt.write_str("db error")?,
Kind::Parse => fmt.write_str("error parsing response from server")?,
Kind::Encode => fmt.write_str("error encoding message to server")?,
};
if let Some(ref cause) = self.0.cause {
write!(fmt, ": {}", cause)?;
}
Ok(())
}
}
impl error::Error for Error {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
self.0.cause.as_ref().map(|e| &**e as _)
}
}
impl From<io::Error> for Error {
fn from(value: io::Error) -> Self {
Self::io(value)
}
}
impl Error {
/// Consumes the error, returning its cause.
pub fn into_source(self) -> Option<Box<dyn error::Error + Sync + Send>> {
self.0.cause
}
/// Returns the source of this error if it was a `DbError`.
///
/// This is a simple convenience method.
pub fn as_db_error(&self) -> Option<&DbError> {
error::Error::source(self).and_then(|e| e.downcast_ref::<DbError>())
}
/// Determines if the error was associated with closed connection.
pub fn is_closed(&self) -> bool {
self.0.kind == Kind::Closed
}
/// Returns the SQLSTATE error code associated with the error.
///
/// This is a convenience method that downcasts the cause to a `DbError` and returns its code.
pub fn code(&self) -> Option<&SqlState> {
self.as_db_error().map(DbError::code)
}
fn new(kind: Kind, cause: Option<Box<dyn error::Error + Sync + Send>>) -> Error {
Error(ErrorInner { kind, cause })
}
#[allow(clippy::needless_pass_by_value)]
pub(crate) fn db(error: ErrorResponseBody) -> Error {
match DbError::parse(&mut error.fields()) {
Ok(e) => Error::new(Kind::Db, Some(Box::new(e))),
Err(e) => Error::new(Kind::Parse, Some(Box::new(e))),
}
}
pub(crate) fn from_sql(e: Box<dyn error::Error + Sync + Send>, idx: usize) -> Error {
Error::new(Kind::FromSql(idx), Some(e))
}
pub(crate) fn closed() -> Error {
Error::new(Kind::Closed, None)
}
pub(crate) fn unexpected_message() -> Error {
Error::new(Kind::UnexpectedMessage, None)
}
pub(crate) fn expecting(expected: &str) -> Error {
Error::new(Kind::UnexpectedMessage, Some(expected.into()))
}
pub(crate) fn parse(e: io::Error) -> Error {
Error::new(Kind::Parse, Some(Box::new(e)))
}
pub(crate) fn encode(e: io::Error) -> Error {
Error::new(Kind::Encode, Some(Box::new(e)))
}
pub(crate) fn io(e: io::Error) -> Error {
Error::new(Kind::Io, Some(Box::new(e)))
}
pub(crate) fn tls(e: native_tls::Error) -> Error {
Error::new(Kind::Tls, Some(Box::new(e)))
}
}
/// The severity of a Postgres error or notice.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Severity {
/// PANIC
Panic,
/// FATAL
Fatal,
/// ERROR
Error,
/// WARNING
Warning,
/// NOTICE
Notice,
/// DEBUG
Debug,
/// INFO
Info,
/// LOG
Log,
}
impl fmt::Display for Severity {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match *self {
Severity::Panic => "PANIC",
Severity::Fatal => "FATAL",
Severity::Error => "ERROR",
Severity::Warning => "WARNING",
Severity::Notice => "NOTICE",
Severity::Debug => "DEBUG",
Severity::Info => "INFO",
Severity::Log => "LOG",
};
fmt.write_str(s)
}
}
impl Severity {
fn from_str(s: &str) -> Option<Severity> {
match s {
"PANIC" => Some(Severity::Panic),
"FATAL" => Some(Severity::Fatal),
"ERROR" => Some(Severity::Error),
"WARNING" => Some(Severity::Warning),
"NOTICE" => Some(Severity::Notice),
"DEBUG" => Some(Severity::Debug),
"INFO" => Some(Severity::Info),
"LOG" => Some(Severity::Log),
_ => None,
}
}
}
/// A Postgres error or notice.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DbError {
severity: String,
parsed_severity: Option<Severity>,
code: SqlState,
message: String,
detail: Option<String>,
hint: Option<String>,
position: Option<ErrorPosition>,
where_: Option<String>,
schema: Option<String>,
table: Option<String>,
column: Option<String>,
datatype: Option<String>,
constraint: Option<String>,
file: Option<String>,
line: Option<u32>,
routine: Option<String>,
}
impl DbError {
pub(crate) fn parse(fields: &mut ErrorFields<'_>) -> io::Result<DbError> {
let mut severity = None;
let mut parsed_severity = None;
let mut code = None;
let mut message = None;
let mut detail = None;
let mut hint = None;
let mut normal_position = None;
let mut internal_position = None;
let mut internal_query = None;
let mut where_ = None;
let mut schema = None;
let mut table = None;
let mut column = None;
let mut datatype = None;
let mut constraint = None;
let mut file = None;
let mut line = None;
let mut routine = None;
while let Some(field) = fields.next()? {
match field.type_() {
b'S' => severity = Some(field.value().to_owned()),
b'C' => code = Some(SqlState::from_code(field.value())),
b'M' => message = Some(field.value().to_owned()),
b'D' => detail = Some(field.value().to_owned()),
b'H' => hint = Some(field.value().to_owned()),
b'P' => {
normal_position = Some(field.value().parse::<u32>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"`P` field did not contain an integer",
)
})?);
}
b'p' => {
internal_position = Some(field.value().parse::<u32>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"`p` field did not contain an integer",
)
})?);
}
b'q' => internal_query = Some(field.value().to_owned()),
b'W' => where_ = Some(field.value().to_owned()),
b's' => schema = Some(field.value().to_owned()),
b't' => table = Some(field.value().to_owned()),
b'c' => column = Some(field.value().to_owned()),
b'd' => datatype = Some(field.value().to_owned()),
b'n' => constraint = Some(field.value().to_owned()),
b'F' => file = Some(field.value().to_owned()),
b'L' => {
line = Some(field.value().parse::<u32>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"`L` field did not contain an integer",
)
})?);
}
b'R' => routine = Some(field.value().to_owned()),
b'V' => {
parsed_severity = Some(Severity::from_str(field.value()).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"`V` field contained an invalid value",
)
})?);
}
_ => {}
}
}
Ok(DbError {
severity: severity
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`S` field missing"))?,
parsed_severity,
code: code
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`C` field missing"))?,
message: message
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`M` field missing"))?,
detail,
hint,
position: match normal_position {
Some(position) => Some(ErrorPosition::Original(position)),
None => match internal_position {
Some(position) => Some(ErrorPosition::Internal {
position,
query: internal_query.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"`q` field missing but `p` field present",
)
})?,
}),
None => None,
},
},
where_,
schema,
table,
column,
datatype,
constraint,
file,
line,
routine,
})
}
/// The field contents are ERROR, FATAL, or PANIC (in an error message),
/// or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a
/// localized translation of one of these.
pub fn severity(&self) -> &str {
&self.severity
}
/// A parsed, nonlocalized version of `severity`. (PostgreSQL 9.6+)
pub fn parsed_severity(&self) -> Option<Severity> {
self.parsed_severity
}
/// The SQLSTATE code for the error.
pub fn code(&self) -> &SqlState {
&self.code
}
/// The primary human-readable error message.
///
/// This should be accurate but terse (typically one line).
pub fn message(&self) -> &str {
&self.message
}
/// An optional secondary error message carrying more detail about the
/// problem.
///
/// Might run to multiple lines.
pub fn detail(&self) -> Option<&str> {
self.detail.as_deref()
}
/// An optional suggestion what to do about the problem.
///
/// This is intended to differ from `detail` in that it offers advice
/// (potentially inappropriate) rather than hard facts. Might run to
/// multiple lines.
pub fn hint(&self) -> Option<&str> {
self.hint.as_deref()
}
/// An optional error cursor position into either the original query string
/// or an internally generated query.
pub fn position(&self) -> Option<&ErrorPosition> {
self.position.as_ref()
}
/// An indication of the context in which the error occurred.
///
/// Presently this includes a call stack traceback of active procedural
/// language functions and internally-generated queries. The trace is one
/// entry per line, most recent first.
pub fn where_(&self) -> Option<&str> {
self.where_.as_deref()
}
/// If the error was associated with a specific database object, the name
/// of the schema containing that object, if any. (PostgreSQL 9.3+)
pub fn schema(&self) -> Option<&str> {
self.schema.as_deref()
}
/// If the error was associated with a specific table, the name of the
/// table. (Refer to the schema name field for the name of the table's
/// schema.) (PostgreSQL 9.3+)
pub fn table(&self) -> Option<&str> {
self.table.as_deref()
}
/// If the error was associated with a specific table column, the name of
/// the column.
///
/// (Refer to the schema and table name fields to identify the table.)
/// (PostgreSQL 9.3+)
pub fn column(&self) -> Option<&str> {
self.column.as_deref()
}
/// If the error was associated with a specific data type, the name of the
/// data type. (Refer to the schema name field for the name of the data
/// type's schema.) (PostgreSQL 9.3+)
pub fn datatype(&self) -> Option<&str> {
self.datatype.as_deref()
}
/// If the error was associated with a specific constraint, the name of the
/// constraint.
///
/// Refer to fields listed above for the associated table or domain.
/// (For this purpose, indexes are treated as constraints, even if they
/// weren't created with constraint syntax.) (PostgreSQL 9.3+)
pub fn constraint(&self) -> Option<&str> {
self.constraint.as_deref()
}
/// The file name of the source-code location where the error was reported.
pub fn file(&self) -> Option<&str> {
self.file.as_deref()
}
/// The line number of the source-code location where the error was
/// reported.
pub fn line(&self) -> Option<u32> {
self.line
}
/// The name of the source-code routine reporting the error.
pub fn routine(&self) -> Option<&str> {
self.routine.as_deref()
}
}
impl fmt::Display for DbError {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(fmt, "{}: {}", self.severity, self.message)?;
if let Some(detail) = &self.detail {
write!(fmt, "\nDETAIL: {}", detail)?;
}
if let Some(hint) = &self.hint {
write!(fmt, "\nHINT: {}", hint)?;
}
Ok(())
}
}
impl error::Error for DbError {}

View File

@@ -0,0 +1,5 @@
pub mod codec;
pub mod connection;
pub mod error;
pub mod prepare;

View File

@@ -0,0 +1,293 @@
use fallible_iterator::FallibleIterator;
use futures::StreamExt;
use postgres_protocol::message::backend::{DataRowRanges, Message};
use postgres_protocol::message::frontend;
use std::future::Future;
use std::pin::Pin;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_postgres::types::{Field, Kind, Oid, ToSql, Type};
use super::connection::Connection;
use super::error::Error;
const TYPEINFO_QUERY: &str = "\
SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid
FROM pg_catalog.pg_type t
LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid
INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
WHERE t.oid = $1
";
const TYPEINFO_ENUM_QUERY: &str = "\
SELECT enumlabel
FROM pg_catalog.pg_enum
WHERE enumtypid = $1
ORDER BY enumsortorder
";
const TYPEINFO_COMPOSITE_QUERY: &str = "\
SELECT attname, atttypid
FROM pg_catalog.pg_attribute
WHERE attrelid = $1
AND NOT attisdropped
AND attnum > 0
ORDER BY attnum
";
#[derive(Clone)]
pub struct TypeinfoPreparedQueries {
query: String,
enum_query: String,
composite_query: String,
}
fn map_is_null(x: tokio_postgres::types::IsNull) -> postgres_protocol::IsNull {
match x {
tokio_postgres::types::IsNull::Yes => postgres_protocol::IsNull::Yes,
tokio_postgres::types::IsNull::No => postgres_protocol::IsNull::No,
}
}
fn read_column<'a, T: tokio_postgres::types::FromSql<'a>>(
buffer: &'a [u8],
type_: &Type,
ranges: &mut DataRowRanges<'a>,
) -> Result<T, Error> {
let range = ranges.next()?;
match range {
Some(range) => T::from_sql_nullable(type_, range.map(|r| &buffer[r])),
None => T::from_sql_null(type_),
}
.map_err(|e| Error::from_sql(e, 0))
}
impl TypeinfoPreparedQueries {
pub async fn new<
S: AsyncRead + AsyncWrite + Unpin + Send,
T: AsyncRead + AsyncWrite + Unpin + Send,
>(
c: &mut Connection<S, T>,
) -> Result<Self, Error> {
if let Some(ti) = &c.typeinfo {
return Ok(ti.clone());
}
let query = c.statement_name();
let enum_query = c.statement_name();
let composite_query = c.statement_name();
frontend::parse(&query, TYPEINFO_QUERY, [Type::OID.oid()], &mut c.raw.buf)?;
frontend::parse(
&enum_query,
TYPEINFO_ENUM_QUERY,
[Type::OID.oid()],
&mut c.raw.buf,
)?;
c.sync().await?;
frontend::parse(
&composite_query,
TYPEINFO_COMPOSITE_QUERY,
[Type::OID.oid()],
&mut c.raw.buf,
)?;
c.sync().await?;
let Message::ParseComplete = c.raw.next_message().await? else { return Err(Error::expecting("parse")) };
let Message::ParseComplete = c.raw.next_message().await? else { return Err(Error::expecting("parse")) };
let Message::ParseComplete = c.raw.next_message().await? else { return Err(Error::expecting("parse")) };
c.wait_for_ready().await?;
Ok(c.typeinfo
.insert(TypeinfoPreparedQueries {
query,
enum_query,
composite_query,
})
.clone())
}
fn get_type_rec<
S: AsyncRead + AsyncWrite + Unpin + Send,
T: AsyncRead + AsyncWrite + Unpin + Send,
>(
c: &mut Connection<S, T>,
oid: Oid,
) -> Pin<Box<dyn Future<Output = Result<Type, Error>> + Send + '_>> {
Box::pin(Self::get_type(c, oid))
}
pub async fn get_type<
S: AsyncRead + AsyncWrite + Unpin + Send,
T: AsyncRead + AsyncWrite + Unpin + Send,
>(
c: &mut Connection<S, T>,
oid: Oid,
) -> Result<Type, Error> {
if let Some(type_) = Type::from_oid(oid) {
return Ok(type_);
}
if let Some(type_) = c.typecache.get(&oid) {
return Ok(type_.clone());
}
let queries = Self::new(c).await?;
frontend::bind(
"",
&queries.query,
[1], // the only parameter is in binary format
[oid],
|param, buf| param.to_sql(&Type::OID, buf).map(map_is_null),
Some(1), // binary return type
&mut c.raw.buf,
)
.map_err(|e| match e {
frontend::BindError::Conversion(e) => std::io::Error::new(std::io::ErrorKind::Other, e),
frontend::BindError::Serialization(io) => io,
})?;
frontend::execute("", 0, &mut c.raw.buf)?;
c.sync().await?;
let mut stream = c.stream_query_results().await?;
let Some(row) = stream.next().await.transpose()? else {
todo!()
};
let row = row.map_err(Error::db)?;
let b = row.buffer();
let mut ranges = row.ranges();
let name: String = read_column(b, &Type::NAME, &mut ranges)?;
let type_: i8 = read_column(b, &Type::CHAR, &mut ranges)?;
let elem_oid: Oid = read_column(b, &Type::OID, &mut ranges)?;
let rngsubtype: Option<Oid> = read_column(b, &Type::OID, &mut ranges)?;
let basetype: Oid = read_column(b, &Type::OID, &mut ranges)?;
let schema: String = read_column(b, &Type::NAME, &mut ranges)?;
let relid: Oid = read_column(b, &Type::OID, &mut ranges)?;
{
// should be none
let None = stream.next().await.transpose()? else {
todo!()
};
drop(stream);
}
let kind = if type_ == b'e' as i8 {
let variants = Self::get_enum_variants(c, oid).await?;
Kind::Enum(variants)
} else if type_ == b'p' as i8 {
Kind::Pseudo
} else if basetype != 0 {
let type_ = Self::get_type_rec(c, basetype).await?;
Kind::Domain(type_)
} else if elem_oid != 0 {
let type_ = Self::get_type_rec(c, elem_oid).await?;
Kind::Array(type_)
} else if relid != 0 {
let fields = Self::get_composite_fields(c, relid).await?;
Kind::Composite(fields)
} else if let Some(rngsubtype) = rngsubtype {
let type_ = Self::get_type_rec(c, rngsubtype).await?;
Kind::Range(type_)
} else {
Kind::Simple
};
let type_ = Type::new(name, oid, kind, schema);
c.typecache.insert(oid, type_.clone());
Ok(type_)
}
async fn get_enum_variants<
S: AsyncRead + AsyncWrite + Unpin + Send,
T: AsyncRead + AsyncWrite + Unpin + Send,
>(
c: &mut Connection<S, T>,
oid: Oid,
) -> Result<Vec<String>, Error> {
let queries = Self::new(c).await?;
frontend::bind(
"",
&queries.enum_query,
[1], // the only parameter is in binary format
[oid],
|param, buf| param.to_sql(&Type::OID, buf).map(map_is_null),
Some(1), // binary return type
&mut c.raw.buf,
)
.map_err(|e| match e {
frontend::BindError::Conversion(e) => std::io::Error::new(std::io::ErrorKind::Other, e),
frontend::BindError::Serialization(io) => io,
})?;
frontend::execute("", 0, &mut c.raw.buf)?;
c.sync().await?;
let mut stream = c.stream_query_results().await?;
let mut variants = Vec::new();
while let Some(row) = stream.next().await.transpose()? {
let row = row.map_err(Error::db)?;
let variant: String = read_column(row.buffer(), &Type::NAME, &mut row.ranges())?;
variants.push(variant);
}
c.wait_for_ready().await?;
Ok(variants)
}
async fn get_composite_fields<
S: AsyncRead + AsyncWrite + Unpin + Send,
T: AsyncRead + AsyncWrite + Unpin + Send,
>(
c: &mut Connection<S, T>,
oid: Oid,
) -> Result<Vec<Field>, Error> {
let queries = Self::new(c).await?;
frontend::bind(
"",
&queries.composite_query,
[1], // the only parameter is in binary format
[oid],
|param, buf| param.to_sql(&Type::OID, buf).map(map_is_null),
Some(1), // binary return type
&mut c.raw.buf,
)
.map_err(|e| match e {
frontend::BindError::Conversion(e) => std::io::Error::new(std::io::ErrorKind::Other, e),
frontend::BindError::Serialization(io) => io,
})?;
frontend::execute("", 0, &mut c.raw.buf)?;
c.sync().await?;
let mut stream = c.stream_query_results().await?;
let mut fields = Vec::new();
while let Some(row) = stream.next().await.transpose()? {
let row = row.map_err(Error::db)?;
let mut ranges = row.ranges();
let name: String = read_column(row.buffer(), &Type::NAME, &mut ranges)?;
let oid: Oid = read_column(row.buffer(), &Type::OID, &mut ranges)?;
fields.push((name, oid));
}
c.wait_for_ready().await?;
let mut output_fields = Vec::with_capacity(fields.len());
for (name, oid) in fields {
let type_ = Self::get_type_rec(c, oid).await?;
output_fields.push(Field::new(name, type_))
}
Ok(output_fields)
}
}

View File

@@ -347,11 +347,6 @@ async fn connect_to_compute_once(
.await
}
enum ConnectionState<E> {
Cached(console::CachedNodeInfo),
Invalid(compute::ConnCfg, E),
}
#[async_trait]
pub trait ConnectMechanism {
type Connection;
@@ -407,70 +402,67 @@ where
mechanism.update_connect_config(&mut node_info.config);
let mut num_retries = 0;
let mut state = ConnectionState::<M::ConnectError>::Cached(node_info);
// try once
let (config, err) = match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
Ok(res) => return Ok(res),
Err(e) => {
error!(error = ?e, "could not connect to compute node");
(invalidate_cache(node_info), e)
}
};
loop {
match state {
ConnectionState::Invalid(config, err) => {
info!("compute node's state has likely changed; requesting a wake-up");
let mut num_retries = 1;
let wake_res = match creds {
auth::BackendType::Console(api, creds) => api.wake_compute(extra, creds).await,
auth::BackendType::Postgres(api, creds) => api.wake_compute(extra, creds).await,
// nothing to do?
auth::BackendType::Link(_) => return Err(err.into()),
// test backend
auth::BackendType::Test(x) => x.wake_compute(),
};
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
info!("compute node's state has likely changed; requesting a wake-up");
let node_info = loop {
let wake_res = match creds {
auth::BackendType::Console(api, creds) => api.wake_compute(extra, creds).await,
auth::BackendType::Postgres(api, creds) => api.wake_compute(extra, creds).await,
// nothing to do?
auth::BackendType::Link(_) => return Err(err.into()),
// test backend
auth::BackendType::Test(x) => x.wake_compute(),
};
match handle_try_wake(wake_res) {
// there was an error communicating with the control plane
Err(e) => return Err(e.into()),
// failed to wake up but we can continue to retry
Ok(ControlFlow::Continue(_)) => {
state = ConnectionState::Invalid(config, err);
let wait_duration = retry_after(num_retries);
num_retries += 1;
info!(num_retries, "retrying wake compute");
time::sleep(wait_duration).await;
continue;
}
// successfully woke up a compute node and can break the wakeup loop
Ok(ControlFlow::Break(mut node_info)) => {
node_info.config.reuse_password(&config);
mechanism.update_connect_config(&mut node_info.config);
state = ConnectionState::Cached(node_info)
}
}
match handle_try_wake(wake_res, num_retries)? {
// failed to wake up but we can continue to retry
ControlFlow::Continue(_) => {}
// successfully woke up a compute node and can break the wakeup loop
ControlFlow::Break(mut node_info) => {
node_info.config.reuse_password(&config);
mechanism.update_connect_config(&mut node_info.config);
break node_info;
}
ConnectionState::Cached(node_info) => {
match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
Ok(res) => return Ok(res),
Err(e) => {
error!(error = ?e, "could not connect to compute node");
if !e.should_retry(num_retries) {
return Err(e.into());
}
}
// after the first connect failure,
// we should invalidate the cache and wake up a new compute node
if num_retries == 0 {
state = ConnectionState::Invalid(invalidate_cache(node_info), e);
} else {
state = ConnectionState::Cached(node_info);
}
let wait_duration = retry_after(num_retries);
num_retries += 1;
let wait_duration = retry_after(num_retries);
num_retries += 1;
time::sleep(wait_duration).await;
info!(num_retries, "retrying wake compute");
};
info!(num_retries, "retrying wake compute");
time::sleep(wait_duration).await;
}
// now that we have a new node, try connect to it repeatedly.
// this can error for a few reasons, for instance:
// * DNS connection settings haven't quite propagated yet
info!("wake_compute success. attempting to connect");
loop {
match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
Ok(res) => return Ok(res),
Err(e) => {
error!(error = ?e, "could not connect to compute node");
if !e.should_retry(num_retries) {
return Err(e.into());
}
}
}
let wait_duration = retry_after(num_retries);
num_retries += 1;
time::sleep(wait_duration).await;
info!(num_retries, "retrying connect_once");
}
}
@@ -478,12 +470,15 @@ where
/// * Returns Ok(Continue(e)) if there was an error waking but retries are acceptable
/// * Returns Ok(Break(node)) if the wakeup succeeded
/// * Returns Err(e) if there was an error
fn handle_try_wake(
pub fn handle_try_wake(
result: Result<console::CachedNodeInfo, WakeComputeError>,
num_retries: u32,
) -> Result<ControlFlow<console::CachedNodeInfo, WakeComputeError>, WakeComputeError> {
match result {
Err(err) => match &err {
WakeComputeError::ApiError(api) if api.could_retry() => Ok(ControlFlow::Continue(err)),
WakeComputeError::ApiError(api) if api.should_retry(num_retries) => {
Ok(ControlFlow::Continue(err))
}
_ => Err(err),
},
// Ready to try again.
@@ -491,22 +486,10 @@ fn handle_try_wake(
}
}
/// Attempts to wake up the compute node.
pub async fn try_wake(
api: &impl console::Api,
extra: &console::ConsoleReqExtra<'_>,
creds: &auth::ClientCredentials<'_>,
) -> Result<ControlFlow<console::CachedNodeInfo, WakeComputeError>, WakeComputeError> {
info!("compute node's state has likely changed; requesting a wake-up");
handle_try_wake(api.wake_compute(extra, creds).await)
}
pub trait ShouldRetry {
fn could_retry(&self) -> bool;
fn should_retry(&self, num_retries: u32) -> bool {
match self {
// retry all errors at least once
_ if num_retries == 0 => true,
_ if num_retries >= NUM_RETRIES_CONNECT => false,
err => err.could_retry(),
}
@@ -558,14 +541,9 @@ impl ShouldRetry for compute::ConnectionError {
}
}
pub fn retry_after(num_retries: u32) -> time::Duration {
match num_retries {
0 => time::Duration::ZERO,
_ => {
// 3/2 = 1.5 which seems to be an ok growth factor heuristic
BASE_RETRY_WAIT_DURATION * 3_u32.pow(num_retries) / 2_u32.pow(num_retries)
}
}
fn retry_after(num_retries: u32) -> time::Duration {
// 1.5 seems to be an ok growth factor heuristic
BASE_RETRY_WAIT_DURATION.mul_f64(1.5_f64.powi(num_retries as i32))
}
/// Finish client connection initialization: confirm auth success, send params, etc.

View File

@@ -99,9 +99,8 @@ struct Scram(scram::ServerSecret);
impl Scram {
fn new(password: &str) -> anyhow::Result<Self> {
let salt = rand::random::<[u8; 16]>();
let secret = scram::ServerSecret::build(password, &salt, 256)
.context("failed to generate scram secret")?;
let secret =
scram::ServerSecret::build(password).context("failed to generate scram secret")?;
Ok(Scram(secret))
}
@@ -302,7 +301,7 @@ async fn scram_auth_mock() -> anyhow::Result<()> {
#[test]
fn connect_compute_total_wait() {
let mut total_wait = tokio::time::Duration::ZERO;
for num_retries in 0..10 {
for num_retries in 1..10 {
total_wait += retry_after(num_retries);
}
assert!(total_wait < tokio::time::Duration::from_secs(12));

View File

@@ -12,9 +12,6 @@ mod messages;
mod secret;
mod signature;
#[cfg(any(test, doc))]
mod password;
pub use exchange::Exchange;
pub use key::ScramKey;
pub use secret::ServerSecret;
@@ -57,27 +54,21 @@ fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
#[cfg(test)]
mod tests {
use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
use crate::sasl::{Mechanism, Step};
use super::{password::SaltedPassword, Exchange, ServerSecret};
use super::{Exchange, ServerSecret};
#[test]
fn happy_path() {
fn snapshot() {
let iterations = 4096;
let salt_base64 = "QSXCR+Q6sek8bf92";
let pw = SaltedPassword::new(
b"pencil",
base64::decode(salt_base64).unwrap().as_slice(),
iterations,
);
let salt = "QSXCR+Q6sek8bf92";
let stored_key = "FO+9jBb3MUukt6jJnzjPZOWc5ow/Pu6JtPyju0aqaE8=";
let server_key = "qxJ1SbmSAi5EcS0J5Ck/cKAm/+Ixa+Kwp63f4OHDgzo=";
let secret = format!("SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}",);
let secret = ServerSecret::parse(&secret).unwrap();
let secret = ServerSecret {
iterations,
salt_base64: salt_base64.to_owned(),
stored_key: pw.client_key().sha256(),
server_key: pw.server_key(),
doomed: false,
};
const NONCE: [u8; 18] = [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
];
@@ -115,4 +106,40 @@ mod tests {
]
);
}
fn run_round_trip_test(client_password: &str) {
let secret = ServerSecret::build("pencil").unwrap();
let mut exchange = Exchange::new(&secret, rand::random, None);
let mut client =
ScramSha256::new(client_password.as_bytes(), ChannelBinding::unsupported());
let client_first = std::str::from_utf8(client.message()).unwrap();
exchange = match exchange.exchange(client_first).unwrap() {
Step::Continue(exchange, message) => {
client.update(message.as_bytes()).unwrap();
exchange
}
Step::Success(_, _) => panic!("expected continue, got success"),
Step::Failure(f) => panic!("{f}"),
};
let client_final = std::str::from_utf8(client.message()).unwrap();
match exchange.exchange(client_final).unwrap() {
Step::Success(_, message) => client.finish(message.as_bytes()).unwrap(),
Step::Continue(_, _) => panic!("expected success, got continue"),
Step::Failure(f) => panic!("{f}"),
};
}
#[test]
fn round_trip() {
run_round_trip_test("pencil")
}
#[test]
#[should_panic(expected = "password doesn't match")]
fn failure() {
run_round_trip_test("eraser")
}
}

View File

@@ -3,7 +3,7 @@
/// Faithfully taken from PostgreSQL.
pub const SCRAM_KEY_LEN: usize = 32;
/// One of the keys derived from the [password](super::password::SaltedPassword).
/// One of the keys derived from the user's password.
/// We use the same structure for all keys, i.e.
/// `ClientKey`, `StoredKey`, and `ServerKey`.
#[derive(Default, PartialEq, Eq)]

View File

@@ -1,74 +0,0 @@
//! Password hashing routines.
use super::key::ScramKey;
pub const SALTED_PASSWORD_LEN: usize = 32;
/// Salted hashed password is essential for [key](super::key) derivation.
#[repr(transparent)]
pub struct SaltedPassword {
bytes: [u8; SALTED_PASSWORD_LEN],
}
impl SaltedPassword {
/// See `scram-common.c : scram_SaltedPassword` for details.
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2898> (see `PBKDF2`).
pub fn new(password: &[u8], salt: &[u8], iterations: u32) -> SaltedPassword {
pbkdf2::pbkdf2_hmac_array::<sha2::Sha256, 32>(password, salt, iterations).into()
}
/// Derive `ClientKey` from a salted hashed password.
pub fn client_key(&self) -> ScramKey {
super::hmac_sha256(&self.bytes, [b"Client Key".as_ref()]).into()
}
/// Derive `ServerKey` from a salted hashed password.
pub fn server_key(&self) -> ScramKey {
super::hmac_sha256(&self.bytes, [b"Server Key".as_ref()]).into()
}
}
impl From<[u8; SALTED_PASSWORD_LEN]> for SaltedPassword {
#[inline(always)]
fn from(bytes: [u8; SALTED_PASSWORD_LEN]) -> Self {
Self { bytes }
}
}
#[cfg(test)]
mod tests {
use super::SaltedPassword;
fn legacy_pbkdf2_impl(password: &[u8], salt: &[u8], iterations: u32) -> SaltedPassword {
let one = 1_u32.to_be_bytes(); // magic
let mut current = super::super::hmac_sha256(password, [salt, &one]);
let mut result = current;
for _ in 1..iterations {
current = super::super::hmac_sha256(password, [current.as_ref()]);
// TODO: result = current.zip(result).map(|(x, y)| x ^ y), issue #80094
for (i, x) in current.iter().enumerate() {
result[i] ^= x;
}
}
result.into()
}
#[test]
fn pbkdf2() {
let password = "a-very-secure-password";
let salt = "such-a-random-salt";
let iterations = 4096;
let output = [
203, 18, 206, 81, 4, 154, 193, 100, 147, 41, 211, 217, 177, 203, 69, 210, 194, 211,
101, 1, 248, 156, 96, 0, 8, 223, 30, 87, 158, 41, 20, 42,
];
let actual = SaltedPassword::new(password.as_bytes(), salt.as_bytes(), iterations);
let expected = legacy_pbkdf2_impl(password.as_bytes(), salt.as_bytes(), iterations);
assert_eq!(actual.bytes, output);
assert_eq!(actual.bytes, expected.bytes);
}
}

View File

@@ -3,7 +3,7 @@
use super::base64_decode_array;
use super::key::ScramKey;
/// Server secret is produced from [password](super::password::SaltedPassword)
/// Server secret is produced from user's password,
/// and is used throughout the authentication process.
pub struct ServerSecret {
/// Number of iterations for `PBKDF2` function.
@@ -58,21 +58,10 @@ impl ServerSecret {
/// Build a new server secret from the prerequisites.
/// XXX: We only use this function in tests.
#[cfg(test)]
pub fn build(password: &str, salt: &[u8], iterations: u32) -> Option<Self> {
// TODO: implement proper password normalization required by the RFC
if !password.is_ascii() {
return None;
}
let password = super::password::SaltedPassword::new(password.as_bytes(), salt, iterations);
Some(Self {
iterations,
salt_base64: base64::encode(salt),
stored_key: password.client_key().sha256(),
server_key: password.server_key(),
doomed: false,
})
pub fn build(password: &str) -> Option<Self> {
Self::parse(&postgres_protocol::password::scram_sha_256(
password.as_bytes(),
))
}
}
@@ -102,20 +91,4 @@ mod tests {
assert_eq!(base64::encode(parsed.stored_key), stored_key);
assert_eq!(base64::encode(parsed.server_key), server_key);
}
#[test]
fn build_scram_secret() {
let salt = b"salt";
let secret = ServerSecret::build("password", salt, 4096).unwrap();
assert_eq!(secret.iterations, 4096);
assert_eq!(secret.salt_base64, base64::encode(salt));
assert_eq!(
base64::encode(secret.stored_key.as_ref()),
"lF4cRm/Jky763CN4HtxdHnjV4Q8AWTNlKvGmEFFU8IQ="
);
assert_eq!(
base64::encode(secret.server_key.as_ref()),
"ub8OgRsftnk2ccDMOt7ffHXNcikRkQkq1lh4xaAqrSw="
);
}
}

View File

@@ -234,7 +234,10 @@ async fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> {
listen_pg_addr_tenant_only
);
let listener = tcp_listener::bind(listen_pg_addr_tenant_only.clone()).map_err(|e| {
error!("failed to bind to address {}: {}", conf.listen_pg_addr, e);
error!(
"failed to bind to address {}: {}",
listen_pg_addr_tenant_only, e
);
e
})?;
Some(listener)

View File

@@ -257,28 +257,15 @@ def prepare_snapshot(
shutil.rmtree(repo_dir / "pgdatadirs")
os.mkdir(repo_dir / "endpoints")
# Remove wal-redo temp directory if it exists. Newer pageserver versions don't create
# them anymore, but old versions did.
for tenant in (repo_dir / "tenants").glob("*"):
wal_redo_dir = tenant / "wal-redo-datadir.___temp"
if wal_redo_dir.exists() and wal_redo_dir.is_dir():
shutil.rmtree(wal_redo_dir)
# Update paths and ports in config files
pageserver_toml = repo_dir / "pageserver.toml"
pageserver_config = toml.load(pageserver_toml)
pageserver_config["remote_storage"]["local_path"] = str(repo_dir / "local_fs_remote_storage")
pageserver_config["listen_http_addr"] = port_distributor.replace_with_new_port(
pageserver_config["listen_http_addr"]
)
pageserver_config["listen_pg_addr"] = port_distributor.replace_with_new_port(
pageserver_config["listen_pg_addr"]
)
for param in ("listen_http_addr", "listen_pg_addr", "broker_endpoint"):
pageserver_config[param] = port_distributor.replace_with_new_port(pageserver_config[param])
# Older pageserver versions had just one `auth_type` setting. Now there
# are separate settings for pg and http ports. We don't use authentication
# in compatibility tests so just remove authentication related settings.
pageserver_config.pop("auth_type", None)
# We don't use authentication in compatibility tests
# so just remove authentication related settings.
pageserver_config.pop("pg_auth_type", None)
pageserver_config.pop("http_auth_type", None)
@@ -290,19 +277,16 @@ def prepare_snapshot(
snapshot_config_toml = repo_dir / "config"
snapshot_config = toml.load(snapshot_config_toml)
broker_listen_addr = f"127.0.0.1:{port_distributor.get_port()}"
snapshot_config["broker"] = {"listen_addr": broker_listen_addr}
snapshot_config["pageserver"]["listen_http_addr"] = port_distributor.replace_with_new_port(
snapshot_config["pageserver"]["listen_http_addr"]
)
snapshot_config["pageserver"]["listen_pg_addr"] = port_distributor.replace_with_new_port(
snapshot_config["pageserver"]["listen_pg_addr"]
for param in ("listen_http_addr", "listen_pg_addr"):
snapshot_config["pageserver"][param] = port_distributor.replace_with_new_port(
snapshot_config["pageserver"][param]
)
snapshot_config["broker"]["listen_addr"] = port_distributor.replace_with_new_port(
snapshot_config["broker"]["listen_addr"]
)
for sk in snapshot_config["safekeepers"]:
sk["http_port"] = port_distributor.replace_with_new_port(sk["http_port"])
sk["pg_port"] = port_distributor.replace_with_new_port(sk["pg_port"])
for param in ("http_port", "pg_port", "pg_tenant_only_port"):
sk[param] = port_distributor.replace_with_new_port(sk[param])
if pg_distrib_dir:
snapshot_config["pg_distrib_dir"] = str(pg_distrib_dir)

View File

@@ -14,10 +14,6 @@ from fixtures.neon_fixtures import NeonEnvBuilder, PgBin
def test_gc_cutoff(neon_env_builder: NeonEnvBuilder, pg_bin: PgBin):
env = neon_env_builder.init_start()
# These warnings are expected, when the pageserver is restarted abruptly
env.pageserver.allowed_errors.append(".*found future image layer.*")
env.pageserver.allowed_errors.append(".*found future delta layer.*")
pageserver_http = env.pageserver.http_client()
# Use aggressive GC and checkpoint settings, so that we also exercise GC during the test

View File

@@ -72,10 +72,6 @@ def test_pageserver_restart(neon_env_builder: NeonEnvBuilder):
def test_pageserver_chaos(neon_env_builder: NeonEnvBuilder):
env = neon_env_builder.init_start()
# These warnings are expected, when the pageserver is restarted abruptly
env.pageserver.allowed_errors.append(".*found future image layer.*")
env.pageserver.allowed_errors.append(".*found future delta layer.*")
# Use a tiny checkpoint distance, to create a lot of layers quickly.
# That allows us to stress the compaction and layer flushing logic more.
tenant, _ = env.neon_cli.create_tenant(

View File

@@ -265,18 +265,23 @@ def test_sql_over_http_output_options(static_proxy: NeonProxy):
def test_sql_over_http_batch(static_proxy: NeonProxy):
static_proxy.safe_psql("create role http with login password 'http' superuser")
def qq(queries: List[Tuple[str, Optional[List[Any]]]]) -> Any:
def qq(queries: List[Tuple[str, Optional[List[Any]]]], read_only: bool = False) -> Any:
connstr = f"postgresql://http:http@{static_proxy.domain}:{static_proxy.proxy_port}/postgres"
response = requests.post(
f"https://{static_proxy.domain}:{static_proxy.external_http_port}/sql",
data=json.dumps(list(map(lambda x: {"query": x[0], "params": x[1] or []}, queries))),
headers={"Content-Type": "application/sql", "Neon-Connection-String": connstr},
headers={
"Content-Type": "application/sql",
"Neon-Connection-String": connstr,
"Neon-Batch-Isolation-Level": "Serializable",
"Neon-Batch-Read-Only": "true" if read_only else "false",
},
verify=str(static_proxy.test_output_dir / "proxy.crt"),
)
assert response.status_code == 200
return response.json()["results"]
return response.json()["results"], response.headers
result = qq(
result, headers = qq(
[
("select 42 as answer", None),
("select $1 as answer", [42]),
@@ -291,6 +296,9 @@ def test_sql_over_http_batch(static_proxy: NeonProxy):
]
)
assert headers["Neon-Batch-Isolation-Level"] == "Serializable"
assert headers["Neon-Batch-Read-Only"] == "false"
assert result[0]["rows"] == [{"answer": 42}]
assert result[1]["rows"] == [{"answer": "42"}]
assert result[2]["rows"] == [{"answer": 42}]
@@ -311,3 +319,14 @@ def test_sql_over_http_batch(static_proxy: NeonProxy):
assert res["command"] == "DROP"
assert res["rowCount"] is None
assert len(result) == 10
result, headers = qq(
[
("select 42 as answer", None),
],
True,
)
assert headers["Neon-Batch-Isolation-Level"] == "Serializable"
assert headers["Neon-Batch-Read-Only"] == "true"
assert result[0]["rows"] == [{"answer": 42}]

View File

@@ -15,10 +15,6 @@ def test_pageserver_recovery(neon_env_builder: NeonEnvBuilder):
env = neon_env_builder.init_start()
env.pageserver.is_testing_enabled_or_skip()
# These warnings are expected, when the pageserver is restarted abruptly
env.pageserver.allowed_errors.append(".*found future delta layer.*")
env.pageserver.allowed_errors.append(".*found future image layer.*")
# Create a branch for us
env.neon_cli.create_branch("test_pageserver_recovery", "main")

View File

@@ -348,9 +348,6 @@ def test_remote_storage_upload_queue_retries(
# XXX: should vary this test to selectively fail just layer uploads, index uploads, deletions
# but how do we validate the result after restore?
# these are always possible when we do an immediate stop. perhaps something with compacting has changed since.
env.pageserver.allowed_errors.append(r".*found future (delta|image) layer.*")
env.pageserver.stop(immediate=True)
env.endpoints.stop_all()

View File

@@ -1,4 +1,4 @@
{
"postgres-v15": "1220c8a63f00101829f9222a5821fc084b4384c7",
"postgres-v14": "ebedb34d01c8ac9c31e8ea4628b9854103a1dc8f"
"postgres-v15": "770c6dffc5ef6aac05bf049693877fb377eea6fc",
"postgres-v14": "da3885c34db312afd555802be2ce985fafd1d8ad"
}