mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-20 14:40:37 +00:00
Compare commits
21 Commits
lfc_free_m
...
proxy/remo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
101e770632 | ||
|
|
747ffa50d6 | ||
|
|
e8400d9d93 | ||
|
|
17627e8023 | ||
|
|
4fb5cdbdb8 | ||
|
|
70b503f83b | ||
|
|
21c15c4285 | ||
|
|
0a524e09a5 | ||
|
|
cbe24f7c35 | ||
|
|
64add503c8 | ||
|
|
a98a80abc2 | ||
|
|
7b6c849456 | ||
|
|
326189d950 | ||
|
|
ddbe170454 | ||
|
|
39e458f049 | ||
|
|
e1424647a0 | ||
|
|
705ae2dce9 | ||
|
|
eb78603121 | ||
|
|
f0ad603693 | ||
|
|
e5183f85dc | ||
|
|
89ee8f2028 |
14
Cargo.lock
generated
14
Cargo.lock
generated
@@ -2654,16 +2654,6 @@ dependencies = [
|
||||
"windows-sys 0.45.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pbkdf2"
|
||||
version = "0.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f0ca0b5a68607598bf3bad68f32227a8164f6254833f84eafaac409cd6746c31"
|
||||
dependencies = [
|
||||
"digest",
|
||||
"hmac",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "peeking_take_while"
|
||||
version = "0.1.2"
|
||||
@@ -3040,6 +3030,7 @@ dependencies = [
|
||||
"chrono",
|
||||
"clap",
|
||||
"consumption_metrics",
|
||||
"fallible-iterator",
|
||||
"futures",
|
||||
"git-version",
|
||||
"hashbrown 0.13.2",
|
||||
@@ -3057,9 +3048,9 @@ dependencies = [
|
||||
"once_cell",
|
||||
"opentelemetry",
|
||||
"parking_lot 0.12.1",
|
||||
"pbkdf2",
|
||||
"pin-project-lite",
|
||||
"postgres-native-tls",
|
||||
"postgres-protocol",
|
||||
"postgres_backend",
|
||||
"pq_proto",
|
||||
"prometheus",
|
||||
@@ -3083,6 +3074,7 @@ dependencies = [
|
||||
"thiserror",
|
||||
"tls-listener",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tokio-postgres",
|
||||
"tokio-postgres-rustls",
|
||||
"tokio-rustls 0.23.4",
|
||||
|
||||
@@ -88,7 +88,6 @@ opentelemetry = "0.19.0"
|
||||
opentelemetry-otlp = { version = "0.12.0", default_features=false, features = ["http-proto", "trace", "http", "reqwest-client"] }
|
||||
opentelemetry-semantic-conventions = "0.11.0"
|
||||
parking_lot = "0.12"
|
||||
pbkdf2 = "0.12.1"
|
||||
pin-project-lite = "0.2"
|
||||
prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency
|
||||
prost = "0.11"
|
||||
|
||||
@@ -551,10 +551,8 @@ FROM build-deps AS pg-embedding-pg-build
|
||||
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
|
||||
ENV PATH "/usr/local/pgsql/bin/:$PATH"
|
||||
# eeb3ba7c3a60c95b2604dd543c64b2f1bb4a3703 made on 15/07/2023
|
||||
# There is no release tag yet
|
||||
RUN wget https://github.com/neondatabase/pg_embedding/archive/eeb3ba7c3a60c95b2604dd543c64b2f1bb4a3703.tar.gz -O pg_embedding.tar.gz && \
|
||||
echo "030846df723652f99a8689ce63b66fa0c23477a7fd723533ab8a6b28ab70730f pg_embedding.tar.gz" | sha256sum --check && \
|
||||
RUN wget https://github.com/neondatabase/pg_embedding/archive/refs/tags/0.3.1.tar.gz -O pg_embedding.tar.gz && \
|
||||
echo "c4ae84eef36fa8ec5868f6e061f39812f19ee5ba3604d428d40935685c7be512 pg_embedding.tar.gz" | sha256sum --check && \
|
||||
mkdir pg_embedding-src && cd pg_embedding-src && tar xvzf ../pg_embedding.tar.gz --strip-components=1 -C . && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) install && \
|
||||
|
||||
@@ -193,6 +193,13 @@ fn main() -> Result<()> {
|
||||
if !spec_set {
|
||||
// No spec provided, hang waiting for it.
|
||||
info!("no compute spec provided, waiting");
|
||||
|
||||
// TODO this can stall startups in the unlikely event that we bind
|
||||
// this compute node while it's busy prewarming. It's not too
|
||||
// bad because it's just 100ms and unlikely, but it's an
|
||||
// avoidable problem.
|
||||
compute.prewarm_postgres()?;
|
||||
|
||||
let mut state = compute.state.lock().unwrap();
|
||||
while state.status != ComputeStatus::ConfigurationPending {
|
||||
state = compute.state_changed.wait(state).unwrap();
|
||||
|
||||
@@ -532,6 +532,50 @@ impl ComputeNode {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Start and stop a postgres process to warm up the VM for startup.
|
||||
pub fn prewarm_postgres(&self) -> Result<()> {
|
||||
info!("prewarming");
|
||||
|
||||
// Create pgdata
|
||||
let pgdata = &format!("{}.warmup", self.pgdata);
|
||||
create_pgdata(pgdata)?;
|
||||
|
||||
// Run initdb to completion
|
||||
info!("running initdb");
|
||||
let initdb_bin = Path::new(&self.pgbin).parent().unwrap().join("initdb");
|
||||
Command::new(initdb_bin)
|
||||
.args(["-D", pgdata])
|
||||
.output()
|
||||
.expect("cannot start initdb process");
|
||||
|
||||
// Write conf
|
||||
use std::io::Write;
|
||||
let conf_path = Path::new(pgdata).join("postgresql.conf");
|
||||
let mut file = std::fs::File::create(conf_path)?;
|
||||
writeln!(file, "shared_buffers=65536")?;
|
||||
writeln!(file, "port=51055")?; // Nobody should be connecting
|
||||
writeln!(file, "shared_preload_libraries = 'neon'")?;
|
||||
|
||||
// Start postgres
|
||||
info!("starting postgres");
|
||||
let mut pg = Command::new(&self.pgbin)
|
||||
.args(["-D", pgdata])
|
||||
.spawn()
|
||||
.expect("cannot start postgres process");
|
||||
|
||||
// Stop it when it's ready
|
||||
info!("waiting for postgres");
|
||||
wait_for_postgres(&mut pg, Path::new(pgdata))?;
|
||||
pg.kill()?;
|
||||
info!("sent kill signal");
|
||||
pg.wait()?;
|
||||
info!("done prewarming");
|
||||
|
||||
// clean up
|
||||
let _ok = fs::remove_dir_all(pgdata);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Start Postgres as a child process and manage DBs/roles.
|
||||
/// After that this will hang waiting on the postmaster process to exit.
|
||||
#[instrument(skip_all)]
|
||||
|
||||
@@ -5,7 +5,7 @@ use chrono::{DateTime, Utc};
|
||||
use rand::Rng;
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Serialize, Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
|
||||
#[derive(Serialize, Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum EventType {
|
||||
#[serde(rename = "absolute")]
|
||||
@@ -17,6 +17,32 @@ pub enum EventType {
|
||||
},
|
||||
}
|
||||
|
||||
impl EventType {
|
||||
pub fn absolute_time(&self) -> Option<&DateTime<Utc>> {
|
||||
use EventType::*;
|
||||
match self {
|
||||
Absolute { time } => Some(time),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn incremental_timerange(&self) -> Option<std::ops::Range<&DateTime<Utc>>> {
|
||||
// these can most likely be thought of as Range or RangeFull
|
||||
use EventType::*;
|
||||
match self {
|
||||
Incremental {
|
||||
start_time,
|
||||
stop_time,
|
||||
} => Some(start_time..stop_time),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_incremental(&self) -> bool {
|
||||
matches!(self, EventType::Incremental { .. })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
|
||||
pub struct Event<Extra> {
|
||||
#[serde(flatten)]
|
||||
|
||||
@@ -7,7 +7,7 @@ use crate::context::{DownloadBehavior, RequestContext};
|
||||
use crate::task_mgr::{self, TaskKind, BACKGROUND_RUNTIME};
|
||||
use crate::tenant::{mgr, LogicalSizeCalculationCause};
|
||||
use anyhow;
|
||||
use chrono::Utc;
|
||||
use chrono::{DateTime, Utc};
|
||||
use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_SIZE};
|
||||
use pageserver_api::models::TenantState;
|
||||
use reqwest::Url;
|
||||
@@ -18,12 +18,6 @@ use std::time::Duration;
|
||||
use tracing::*;
|
||||
use utils::id::{NodeId, TenantId, TimelineId};
|
||||
|
||||
const WRITTEN_SIZE: &str = "written_size";
|
||||
const SYNTHETIC_STORAGE_SIZE: &str = "synthetic_storage_size";
|
||||
const RESIDENT_SIZE: &str = "resident_size";
|
||||
const REMOTE_STORAGE_SIZE: &str = "remote_storage_size";
|
||||
const TIMELINE_LOGICAL_SIZE: &str = "timeline_logical_size";
|
||||
|
||||
const DEFAULT_HTTP_REPORTING_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
|
||||
#[serde_as]
|
||||
@@ -44,6 +38,121 @@ pub struct PageserverConsumptionMetricsKey {
|
||||
pub metric: &'static str,
|
||||
}
|
||||
|
||||
impl PageserverConsumptionMetricsKey {
|
||||
const fn absolute_values(self) -> AbsoluteValueFactory {
|
||||
AbsoluteValueFactory(self)
|
||||
}
|
||||
const fn incremental_values(self) -> IncrementalValueFactory {
|
||||
IncrementalValueFactory(self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper type which each individual metric kind can return to produce only absolute values.
|
||||
struct AbsoluteValueFactory(PageserverConsumptionMetricsKey);
|
||||
|
||||
impl AbsoluteValueFactory {
|
||||
fn now(self, val: u64) -> (PageserverConsumptionMetricsKey, (EventType, u64)) {
|
||||
let key = self.0;
|
||||
let time = Utc::now();
|
||||
(key, (EventType::Absolute { time }, val))
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper type which each individual metric kind can return to produce only incremental values.
|
||||
struct IncrementalValueFactory(PageserverConsumptionMetricsKey);
|
||||
|
||||
impl IncrementalValueFactory {
|
||||
#[allow(clippy::wrong_self_convention)]
|
||||
fn from_previous_up_to(
|
||||
self,
|
||||
prev_end: DateTime<Utc>,
|
||||
up_to: DateTime<Utc>,
|
||||
val: u64,
|
||||
) -> (PageserverConsumptionMetricsKey, (EventType, u64)) {
|
||||
let key = self.0;
|
||||
// cannot assert prev_end < up_to because these are realtime clock based
|
||||
(
|
||||
key,
|
||||
(
|
||||
EventType::Incremental {
|
||||
start_time: prev_end,
|
||||
stop_time: up_to,
|
||||
},
|
||||
val,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
fn key(&self) -> &PageserverConsumptionMetricsKey {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
// the static part of a PageserverConsumptionMetricsKey
|
||||
impl PageserverConsumptionMetricsKey {
|
||||
const fn written_size(tenant_id: TenantId, timeline_id: TimelineId) -> AbsoluteValueFactory {
|
||||
PageserverConsumptionMetricsKey {
|
||||
tenant_id,
|
||||
timeline_id: Some(timeline_id),
|
||||
metric: "written_size",
|
||||
}
|
||||
.absolute_values()
|
||||
}
|
||||
|
||||
/// Values will be the difference of the latest written_size (last_record_lsn) to what we
|
||||
/// previously sent.
|
||||
const fn written_size_delta(
|
||||
tenant_id: TenantId,
|
||||
timeline_id: TimelineId,
|
||||
) -> IncrementalValueFactory {
|
||||
PageserverConsumptionMetricsKey {
|
||||
tenant_id,
|
||||
timeline_id: Some(timeline_id),
|
||||
metric: "written_size_bytes_delta",
|
||||
}
|
||||
.incremental_values()
|
||||
}
|
||||
|
||||
const fn timeline_logical_size(
|
||||
tenant_id: TenantId,
|
||||
timeline_id: TimelineId,
|
||||
) -> AbsoluteValueFactory {
|
||||
PageserverConsumptionMetricsKey {
|
||||
tenant_id,
|
||||
timeline_id: Some(timeline_id),
|
||||
metric: "timeline_logical_size",
|
||||
}
|
||||
.absolute_values()
|
||||
}
|
||||
|
||||
const fn remote_storage_size(tenant_id: TenantId) -> AbsoluteValueFactory {
|
||||
PageserverConsumptionMetricsKey {
|
||||
tenant_id,
|
||||
timeline_id: None,
|
||||
metric: "remote_storage_size",
|
||||
}
|
||||
.absolute_values()
|
||||
}
|
||||
|
||||
const fn resident_size(tenant_id: TenantId) -> AbsoluteValueFactory {
|
||||
PageserverConsumptionMetricsKey {
|
||||
tenant_id,
|
||||
timeline_id: None,
|
||||
metric: "resident_size",
|
||||
}
|
||||
.absolute_values()
|
||||
}
|
||||
|
||||
const fn synthetic_size(tenant_id: TenantId) -> AbsoluteValueFactory {
|
||||
PageserverConsumptionMetricsKey {
|
||||
tenant_id,
|
||||
timeline_id: None,
|
||||
metric: "synthetic_storage_size",
|
||||
}
|
||||
.absolute_values()
|
||||
}
|
||||
}
|
||||
|
||||
/// Main thread that serves metrics collection
|
||||
pub async fn collect_metrics(
|
||||
metric_collection_endpoint: &Url,
|
||||
@@ -79,7 +188,7 @@ pub async fn collect_metrics(
|
||||
.timeout(DEFAULT_HTTP_REPORTING_TIMEOUT)
|
||||
.build()
|
||||
.expect("Failed to create http client with timeout");
|
||||
let mut cached_metrics: HashMap<PageserverConsumptionMetricsKey, u64> = HashMap::new();
|
||||
let mut cached_metrics = HashMap::new();
|
||||
let mut prev_iteration_time: std::time::Instant = std::time::Instant::now();
|
||||
|
||||
loop {
|
||||
@@ -121,13 +230,13 @@ pub async fn collect_metrics(
|
||||
/// - refactor this function (chunking+sending part) to reuse it in proxy module;
|
||||
pub async fn collect_metrics_iteration(
|
||||
client: &reqwest::Client,
|
||||
cached_metrics: &mut HashMap<PageserverConsumptionMetricsKey, u64>,
|
||||
cached_metrics: &mut HashMap<PageserverConsumptionMetricsKey, (EventType, u64)>,
|
||||
metric_collection_endpoint: &reqwest::Url,
|
||||
node_id: NodeId,
|
||||
ctx: &RequestContext,
|
||||
send_cached: bool,
|
||||
) {
|
||||
let mut current_metrics: Vec<(PageserverConsumptionMetricsKey, u64)> = Vec::new();
|
||||
let mut current_metrics: Vec<(PageserverConsumptionMetricsKey, (EventType, u64))> = Vec::new();
|
||||
trace!(
|
||||
"starting collect_metrics_iteration. metric_collection_endpoint: {}",
|
||||
metric_collection_endpoint
|
||||
@@ -166,27 +275,80 @@ pub async fn collect_metrics_iteration(
|
||||
if timeline.is_active() {
|
||||
let timeline_written_size = u64::from(timeline.get_last_record_lsn());
|
||||
|
||||
current_metrics.push((
|
||||
PageserverConsumptionMetricsKey {
|
||||
tenant_id,
|
||||
timeline_id: Some(timeline.timeline_id),
|
||||
metric: WRITTEN_SIZE,
|
||||
let (key, written_size_now) =
|
||||
PageserverConsumptionMetricsKey::written_size(tenant_id, timeline.timeline_id)
|
||||
.now(timeline_written_size);
|
||||
|
||||
// last_record_lsn can only go up, right now at least, TODO: #2592 or related
|
||||
// features might change this.
|
||||
|
||||
let written_size_delta_key = PageserverConsumptionMetricsKey::written_size_delta(
|
||||
tenant_id,
|
||||
timeline.timeline_id,
|
||||
);
|
||||
|
||||
// use this when available, because in a stream of incremental values, it will be
|
||||
// accurate where as when last_record_lsn stops moving, we will only cache the last
|
||||
// one of those.
|
||||
let last_stop_time =
|
||||
cached_metrics
|
||||
.get(written_size_delta_key.key())
|
||||
.map(|(until, _val)| {
|
||||
until
|
||||
.incremental_timerange()
|
||||
.expect("never create EventType::Absolute for written_size_delta")
|
||||
.end
|
||||
});
|
||||
|
||||
// by default, use the last sent written_size as the basis for
|
||||
// calculating the delta. if we don't yet have one, use the load time value.
|
||||
let prev = cached_metrics
|
||||
.get(&key)
|
||||
.map(|(prev_at, prev)| {
|
||||
// use the prev time from our last incremental update, or default to latest
|
||||
// absolute update on the first round.
|
||||
let prev_at = prev_at
|
||||
.absolute_time()
|
||||
.expect("never create EventType::Incremental for written_size");
|
||||
let prev_at = last_stop_time.unwrap_or(prev_at);
|
||||
(*prev_at, *prev)
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
// if we don't have a previous point of comparison, compare to the load time
|
||||
// lsn.
|
||||
let (disk_consistent_lsn, loaded_at) = &timeline.loaded_at;
|
||||
(DateTime::from(*loaded_at), disk_consistent_lsn.0)
|
||||
});
|
||||
|
||||
// written_size_delta_bytes
|
||||
current_metrics.extend(
|
||||
if let Some(delta) = written_size_now.1.checked_sub(prev.1) {
|
||||
let up_to = written_size_now
|
||||
.0
|
||||
.absolute_time()
|
||||
.expect("never create EventType::Incremental for written_size");
|
||||
let key_value =
|
||||
written_size_delta_key.from_previous_up_to(prev.0, *up_to, delta);
|
||||
Some(key_value)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
timeline_written_size,
|
||||
));
|
||||
);
|
||||
|
||||
// written_size
|
||||
current_metrics.push((key, written_size_now));
|
||||
|
||||
let span = info_span!("collect_metrics_iteration", tenant_id = %timeline.tenant_id, timeline_id = %timeline.timeline_id);
|
||||
match span.in_scope(|| timeline.get_current_logical_size(ctx)) {
|
||||
// Only send timeline logical size when it is fully calculated.
|
||||
Ok((size, is_exact)) if is_exact => {
|
||||
current_metrics.push((
|
||||
PageserverConsumptionMetricsKey {
|
||||
current_metrics.push(
|
||||
PageserverConsumptionMetricsKey::timeline_logical_size(
|
||||
tenant_id,
|
||||
timeline_id: Some(timeline.timeline_id),
|
||||
metric: TIMELINE_LOGICAL_SIZE,
|
||||
},
|
||||
size,
|
||||
));
|
||||
timeline.timeline_id,
|
||||
)
|
||||
.now(size),
|
||||
);
|
||||
}
|
||||
Ok((_, _)) => {}
|
||||
Err(err) => {
|
||||
@@ -205,14 +367,10 @@ pub async fn collect_metrics_iteration(
|
||||
|
||||
match tenant.get_remote_size().await {
|
||||
Ok(tenant_remote_size) => {
|
||||
current_metrics.push((
|
||||
PageserverConsumptionMetricsKey {
|
||||
tenant_id,
|
||||
timeline_id: None,
|
||||
metric: REMOTE_STORAGE_SIZE,
|
||||
},
|
||||
tenant_remote_size,
|
||||
));
|
||||
current_metrics.push(
|
||||
PageserverConsumptionMetricsKey::remote_storage_size(tenant_id)
|
||||
.now(tenant_remote_size),
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
error!(
|
||||
@@ -222,14 +380,9 @@ pub async fn collect_metrics_iteration(
|
||||
}
|
||||
}
|
||||
|
||||
current_metrics.push((
|
||||
PageserverConsumptionMetricsKey {
|
||||
tenant_id,
|
||||
timeline_id: None,
|
||||
metric: RESIDENT_SIZE,
|
||||
},
|
||||
tenant_resident_size,
|
||||
));
|
||||
current_metrics.push(
|
||||
PageserverConsumptionMetricsKey::resident_size(tenant_id).now(tenant_resident_size),
|
||||
);
|
||||
|
||||
// Note that this metric is calculated in a separate bgworker
|
||||
// Here we only use cached value, which may lag behind the real latest one
|
||||
@@ -237,23 +390,27 @@ pub async fn collect_metrics_iteration(
|
||||
|
||||
if tenant_synthetic_size != 0 {
|
||||
// only send non-zeroes because otherwise these show up as errors in logs
|
||||
current_metrics.push((
|
||||
PageserverConsumptionMetricsKey {
|
||||
tenant_id,
|
||||
timeline_id: None,
|
||||
metric: SYNTHETIC_STORAGE_SIZE,
|
||||
},
|
||||
tenant_synthetic_size,
|
||||
));
|
||||
current_metrics.push(
|
||||
PageserverConsumptionMetricsKey::synthetic_size(tenant_id)
|
||||
.now(tenant_synthetic_size),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Filter metrics, unless we want to send all metrics, including cached ones.
|
||||
// See: https://github.com/neondatabase/neon/issues/3485
|
||||
if !send_cached {
|
||||
current_metrics.retain(|(curr_key, curr_val)| match cached_metrics.get(curr_key) {
|
||||
Some(val) => val != curr_val,
|
||||
None => true,
|
||||
current_metrics.retain(|(curr_key, (kind, curr_val))| {
|
||||
if kind.is_incremental() {
|
||||
// incremental values (currently only written_size_delta) should not get any cache
|
||||
// deduplication because they will be used by upstream for "is still alive."
|
||||
true
|
||||
} else {
|
||||
match cached_metrics.get(curr_key) {
|
||||
Some((_, val)) => val != curr_val,
|
||||
None => true,
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -272,8 +429,8 @@ pub async fn collect_metrics_iteration(
|
||||
chunk_to_send.clear();
|
||||
|
||||
// enrich metrics with type,timestamp and idempotency key before sending
|
||||
chunk_to_send.extend(chunk.iter().map(|(curr_key, curr_val)| Event {
|
||||
kind: EventType::Absolute { time: Utc::now() },
|
||||
chunk_to_send.extend(chunk.iter().map(|(curr_key, (when, curr_val))| Event {
|
||||
kind: *when,
|
||||
metric: curr_key.metric,
|
||||
idempotency_key: idempotency_key(node_id.to_string()),
|
||||
value: *curr_val,
|
||||
|
||||
@@ -390,39 +390,42 @@ where
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn dump(&self) -> Result<()> {
|
||||
self.dump_recurse(self.root_blk, &[], 0)
|
||||
}
|
||||
pub async fn dump(&self) -> Result<()> {
|
||||
let mut stack = Vec::new();
|
||||
|
||||
fn dump_recurse(&self, blknum: u32, path: &[u8], depth: usize) -> Result<()> {
|
||||
let blk = self.reader.read_blk(self.start_blk + blknum)?;
|
||||
let buf: &[u8] = blk.as_ref();
|
||||
stack.push((self.root_blk, String::new(), 0, 0, 0));
|
||||
|
||||
let node = OnDiskNode::<L>::deparse(buf)?;
|
||||
while let Some((blknum, path, depth, child_idx, key_off)) = stack.pop() {
|
||||
let blk = self.reader.read_blk(self.start_blk + blknum)?;
|
||||
let buf: &[u8] = blk.as_ref();
|
||||
let node = OnDiskNode::<L>::deparse(buf)?;
|
||||
|
||||
print!("{:indent$}", "", indent = depth * 2);
|
||||
println!(
|
||||
"blk #{}: path {}: prefix {}, suffix_len {}",
|
||||
blknum,
|
||||
hex::encode(path),
|
||||
hex::encode(node.prefix),
|
||||
node.suffix_len
|
||||
);
|
||||
if child_idx == 0 {
|
||||
print!("{:indent$}", "", indent = depth * 2);
|
||||
let path_prefix = stack
|
||||
.iter()
|
||||
.map(|(_blknum, path, ..)| path.as_str())
|
||||
.collect::<String>();
|
||||
println!(
|
||||
"blk #{blknum}: path {path_prefix}{path}: prefix {}, suffix_len {}",
|
||||
hex::encode(node.prefix),
|
||||
node.suffix_len
|
||||
);
|
||||
}
|
||||
|
||||
let mut idx = 0;
|
||||
let mut key_off = 0;
|
||||
while idx < node.num_children {
|
||||
if child_idx + 1 < node.num_children {
|
||||
let key_off = key_off + node.suffix_len as usize;
|
||||
stack.push((blknum, path.clone(), depth, child_idx + 1, key_off));
|
||||
}
|
||||
let key = &node.keys[key_off..key_off + node.suffix_len as usize];
|
||||
let val = node.value(idx as usize);
|
||||
let val = node.value(child_idx as usize);
|
||||
|
||||
print!("{:indent$}", "", indent = depth * 2 + 2);
|
||||
println!("{}: {}", hex::encode(key), hex::encode(val.0));
|
||||
|
||||
if node.level > 0 {
|
||||
let child_path = [path, node.prefix].concat();
|
||||
self.dump_recurse(val.to_blknum(), &child_path, depth + 1)?;
|
||||
stack.push((val.to_blknum(), hex::encode(node.prefix), depth + 1, 0, 0));
|
||||
}
|
||||
idx += 1;
|
||||
key_off += node.suffix_len as usize;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -754,8 +757,8 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn basic() -> Result<()> {
|
||||
#[tokio::test]
|
||||
async fn basic() -> Result<()> {
|
||||
let mut disk = TestDisk::new();
|
||||
let mut writer = DiskBtreeBuilder::<_, 6>::new(&mut disk);
|
||||
|
||||
@@ -775,7 +778,7 @@ mod tests {
|
||||
|
||||
let reader = DiskBtreeReader::new(0, root_offset, disk);
|
||||
|
||||
reader.dump()?;
|
||||
reader.dump().await?;
|
||||
|
||||
// Test the `get` function on all the keys.
|
||||
for (key, val) in all_data.iter() {
|
||||
@@ -835,8 +838,8 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lots_of_keys() -> Result<()> {
|
||||
#[tokio::test]
|
||||
async fn lots_of_keys() -> Result<()> {
|
||||
let mut disk = TestDisk::new();
|
||||
let mut writer = DiskBtreeBuilder::<_, 8>::new(&mut disk);
|
||||
|
||||
@@ -856,7 +859,7 @@ mod tests {
|
||||
|
||||
let reader = DiskBtreeReader::new(0, root_offset, disk);
|
||||
|
||||
reader.dump()?;
|
||||
reader.dump().await?;
|
||||
|
||||
use std::sync::Mutex;
|
||||
|
||||
@@ -994,8 +997,8 @@ mod tests {
|
||||
///
|
||||
/// This test contains a particular data set, see disk_btree_test_data.rs
|
||||
///
|
||||
#[test]
|
||||
fn particular_data() -> Result<()> {
|
||||
#[tokio::test]
|
||||
async fn particular_data() -> Result<()> {
|
||||
// Build a tree from it
|
||||
let mut disk = TestDisk::new();
|
||||
let mut writer = DiskBtreeBuilder::<_, 26>::new(&mut disk);
|
||||
@@ -1022,7 +1025,7 @@ mod tests {
|
||||
})?;
|
||||
assert_eq!(count, disk_btree_test_data::TEST_DATA.len());
|
||||
|
||||
reader.dump()?;
|
||||
reader.dump().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -223,6 +223,45 @@ mod tests {
|
||||
assert_eq!(part, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn v2_indexpart_is_parsed_with_deleted_at() {
|
||||
let example = r#"{
|
||||
"version":2,
|
||||
"timeline_layers":["000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9"],
|
||||
"missing_layers":["This shouldn't fail deserialization"],
|
||||
"layer_metadata":{
|
||||
"000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9": { "file_size": 25600000 },
|
||||
"000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__00000000016B59D8-00000000016B5A51": { "file_size": 9007199254741001 }
|
||||
},
|
||||
"disk_consistent_lsn":"0/16960E8",
|
||||
"metadata_bytes":[112,11,159,210,0,54,0,4,0,0,0,0,1,105,96,232,1,0,0,0,0,1,105,96,112,0,0,0,0,0,0,0,0,0,0,0,0,0,1,105,96,112,0,0,0,0,1,105,96,112,0,0,0,14,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
|
||||
"deleted_at": "2023-07-31T09:00:00.123"
|
||||
}"#;
|
||||
|
||||
let expected = IndexPart {
|
||||
// note this is not verified, could be anything, but exists for humans debugging.. could be the git version instead?
|
||||
version: 2,
|
||||
timeline_layers: HashSet::from(["000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9".parse().unwrap()]),
|
||||
layer_metadata: HashMap::from([
|
||||
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9".parse().unwrap(), IndexLayerMetadata {
|
||||
file_size: 25600000,
|
||||
}),
|
||||
("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__00000000016B59D8-00000000016B5A51".parse().unwrap(), IndexLayerMetadata {
|
||||
// serde_json should always parse this but this might be a double with jq for
|
||||
// example.
|
||||
file_size: 9007199254741001,
|
||||
})
|
||||
]),
|
||||
disk_consistent_lsn: "0/16960E8".parse::<Lsn>().unwrap(),
|
||||
metadata_bytes: [112,11,159,210,0,54,0,4,0,0,0,0,1,105,96,232,1,0,0,0,0,1,105,96,112,0,0,0,0,0,0,0,0,0,0,0,0,0,1,105,96,112,0,0,0,0,1,105,96,112,0,0,0,14,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0].to_vec(),
|
||||
deleted_at: Some(chrono::NaiveDateTime::parse_from_str(
|
||||
"2023-07-31T09:00:00.123000000", "%Y-%m-%dT%H:%M:%S.%f").unwrap())
|
||||
};
|
||||
|
||||
let part = serde_json::from_str::<IndexPart>(example).unwrap();
|
||||
assert_eq!(part, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_layers_are_parsed() {
|
||||
let empty_layers_json = r#"{
|
||||
|
||||
@@ -256,7 +256,7 @@ impl Layer for DeltaLayer {
|
||||
file,
|
||||
);
|
||||
|
||||
tree_reader.dump()?;
|
||||
tree_reader.dump().await?;
|
||||
|
||||
let mut cursor = file.block_cursor();
|
||||
|
||||
|
||||
@@ -175,7 +175,7 @@ impl Layer for ImageLayer {
|
||||
let tree_reader =
|
||||
DiskBtreeReader::<_, KEY_SIZE>::new(inner.index_start_blk, inner.index_root_blk, file);
|
||||
|
||||
tree_reader.dump()?;
|
||||
tree_reader.dump().await?;
|
||||
|
||||
tree_reader.visit(&[0u8; KEY_SIZE], VisitDirection::Forwards, |key, value| {
|
||||
println!("key: {} offset {}", hex::encode(key), value);
|
||||
|
||||
@@ -294,6 +294,10 @@ pub struct Timeline {
|
||||
/// Completion shared between all timelines loaded during startup; used to delay heavier
|
||||
/// background tasks until some logical sizes have been calculated.
|
||||
initial_logical_size_attempt: Mutex<Option<completion::Completion>>,
|
||||
|
||||
/// Load or creation time information about the disk_consistent_lsn and when the loading
|
||||
/// happened. Used for consumption metrics.
|
||||
pub(crate) loaded_at: (Lsn, SystemTime),
|
||||
}
|
||||
|
||||
pub struct WalReceiverInfo {
|
||||
@@ -1404,6 +1408,8 @@ impl Timeline {
|
||||
last_freeze_at: AtomicLsn::new(disk_consistent_lsn.0),
|
||||
last_freeze_ts: RwLock::new(Instant::now()),
|
||||
|
||||
loaded_at: (disk_consistent_lsn, SystemTime::now()),
|
||||
|
||||
ancestor_timeline: ancestor,
|
||||
ancestor_lsn: metadata.ancestor_lsn(),
|
||||
|
||||
@@ -1600,7 +1606,7 @@ impl Timeline {
|
||||
if let Some(imgfilename) = ImageFileName::parse_str(&fname) {
|
||||
// create an ImageLayer struct for each image file.
|
||||
if imgfilename.lsn > disk_consistent_lsn {
|
||||
warn!(
|
||||
info!(
|
||||
"found future image layer {} on timeline {} disk_consistent_lsn is {}",
|
||||
imgfilename, self.timeline_id, disk_consistent_lsn
|
||||
);
|
||||
@@ -1632,7 +1638,7 @@ impl Timeline {
|
||||
// is 102, then it might not have been fully flushed to disk
|
||||
// before crash.
|
||||
if deltafilename.lsn_range.end > disk_consistent_lsn + 1 {
|
||||
warn!(
|
||||
info!(
|
||||
"found future delta layer {} on timeline {} disk_consistent_lsn is {}",
|
||||
deltafilename, self.timeline_id, disk_consistent_lsn
|
||||
);
|
||||
@@ -1774,7 +1780,7 @@ impl Timeline {
|
||||
match remote_layer_name {
|
||||
LayerFileName::Image(imgfilename) => {
|
||||
if imgfilename.lsn > up_to_date_disk_consistent_lsn {
|
||||
warn!(
|
||||
info!(
|
||||
"found future image layer {} on timeline {} remote_consistent_lsn is {}",
|
||||
imgfilename, self.timeline_id, up_to_date_disk_consistent_lsn
|
||||
);
|
||||
@@ -1799,7 +1805,7 @@ impl Timeline {
|
||||
// is 102, then it might not have been fully flushed to disk
|
||||
// before crash.
|
||||
if deltafilename.lsn_range.end > up_to_date_disk_consistent_lsn + 1 {
|
||||
warn!(
|
||||
info!(
|
||||
"found future delta layer {} on timeline {} remote_consistent_lsn is {}",
|
||||
deltafilename, self.timeline_id, up_to_date_disk_consistent_lsn
|
||||
);
|
||||
|
||||
@@ -88,50 +88,16 @@ static LWLockId lfc_lock;
|
||||
static int lfc_max_size;
|
||||
static int lfc_size_limit;
|
||||
static int lfc_free_space_watermark;
|
||||
static int lfc_free_memory_watermark;
|
||||
static char* lfc_path;
|
||||
static FileCacheControl* lfc_ctl;
|
||||
static shmem_startup_hook_type prev_shmem_startup_hook;
|
||||
#if PG_VERSION_NUM>=150000
|
||||
static shmem_request_hook_type prev_shmem_request_hook;
|
||||
#endif
|
||||
static int lfc_shrinking_factor; /* power of two by which local cache size will be shrinked when lfc_free_space_watermark or lfc_free_memory_watermak are reached */
|
||||
static int lfc_shrinking_factor; /* power of two by which local cache size will be shrinked when lfc_free_space_watermark is reached */
|
||||
|
||||
void FileCacheMonitorMain(Datum main_arg);
|
||||
|
||||
#ifdef __APPLE__
|
||||
|
||||
#include <sys/types.h>
|
||||
#include <sys/sysctl.h>
|
||||
|
||||
static size_t
|
||||
get_available_memory(void)
|
||||
{
|
||||
size_t total;
|
||||
size_t sizeof_total = sizeof(total);
|
||||
if (sysctlbyname("hw.memsize", &total, &sizeof_total, NULL, 0) < 0)
|
||||
elog(ERROR, "Failed to get amount of RAM: %m");
|
||||
|
||||
return total;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
#include <sys/sysinfo.h>
|
||||
|
||||
static size_t
|
||||
get_available_memory(void)
|
||||
{
|
||||
struct sysinfo si;
|
||||
if (sysinfo(&si) < 0)
|
||||
elog(ERROR, "Failed to get amount of RAM: %m");
|
||||
|
||||
return si.totalram*si.mem_unit;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
static void
|
||||
lfc_shmem_startup(void)
|
||||
{
|
||||
@@ -229,11 +195,10 @@ lfc_change_limit_hook(int newval, void *extra)
|
||||
}
|
||||
|
||||
/*
|
||||
* Local file system state monitor check available free space and memory.
|
||||
* If available disk space is lower than lfc_free_space_watermark or
|
||||
* available memory is lower than lfc_free_memory_watermark then we shrink size of local cache
|
||||
* Local file system state monitor check available free space.
|
||||
* If it is lower than lfc_free_space_watermark then we shrink size of local cache
|
||||
* but throwing away least recently accessed chunks.
|
||||
* First time the watermark is reached cache size is divided by two,
|
||||
* First time low space watermark is reached cache size is divided by two,
|
||||
* second time by four,... Finally we remove all chunks from local cache.
|
||||
*
|
||||
* Please notice that we are not changing lfc_cache_size: it is used to be adjusted by autoscaler.
|
||||
@@ -263,27 +228,23 @@ FileCacheMonitorMain(Datum main_arg)
|
||||
{
|
||||
if (lfc_size_limit != 0)
|
||||
{
|
||||
bool shrink_cache = false;
|
||||
if (lfc_free_space_watermark != 0)
|
||||
struct statvfs sfs;
|
||||
if (statvfs(lfc_path, &sfs) < 0)
|
||||
{
|
||||
struct statvfs sfs;
|
||||
if (statvfs(lfc_path, &sfs) < 0)
|
||||
elog(WARNING, "Failed to obtain status of %s: %m", lfc_path);
|
||||
else
|
||||
shrink_cache |= sfs.f_bavail*sfs.f_bsize < lfc_free_space_watermark*MB;
|
||||
}
|
||||
if (lfc_free_memory_watermark != 0)
|
||||
shrink_cache |= get_available_memory() < lfc_free_memory_watermark*MB;
|
||||
|
||||
if (shrink_cache)
|
||||
{
|
||||
if (lfc_shrinking_factor < 31) {
|
||||
lfc_shrinking_factor += 1;
|
||||
}
|
||||
lfc_change_limit_hook(lfc_size_limit >> lfc_shrinking_factor, NULL);
|
||||
elog(WARNING, "Failed to obtain status of %s: %m", lfc_path);
|
||||
}
|
||||
else
|
||||
lfc_shrinking_factor = 0; /* reset to initial value */
|
||||
{
|
||||
if (sfs.f_bavail*sfs.f_bsize < lfc_free_space_watermark*MB)
|
||||
{
|
||||
if (lfc_shrinking_factor < 31) {
|
||||
lfc_shrinking_factor += 1;
|
||||
}
|
||||
lfc_change_limit_hook(lfc_size_limit >> lfc_shrinking_factor, NULL);
|
||||
}
|
||||
else
|
||||
lfc_shrinking_factor = 0; /* reset to initial value */
|
||||
}
|
||||
}
|
||||
pg_usleep(monitor_interval);
|
||||
}
|
||||
@@ -356,19 +317,6 @@ lfc_init(void)
|
||||
NULL,
|
||||
NULL);
|
||||
|
||||
DefineCustomIntVariable("neon.free_memory_watermark",
|
||||
"Minimal free memory in system after reaching which local file cache will be truncated",
|
||||
NULL,
|
||||
&lfc_free_memory_watermark,
|
||||
0, /* disabled by default, because iurt makes sense only when local file cache is located i tmpfs */
|
||||
0,
|
||||
INT_MAX,
|
||||
PGC_SIGHUP,
|
||||
GUC_UNIT_MB,
|
||||
NULL,
|
||||
NULL,
|
||||
NULL);
|
||||
|
||||
DefineCustomStringVariable("neon.file_cache_path",
|
||||
"Path to local file cache (can be raw device)",
|
||||
NULL,
|
||||
|
||||
@@ -29,9 +29,9 @@ metrics.workspace = true
|
||||
once_cell.workspace = true
|
||||
opentelemetry.workspace = true
|
||||
parking_lot.workspace = true
|
||||
pbkdf2.workspace = true
|
||||
pin-project-lite.workspace = true
|
||||
postgres_backend.workspace = true
|
||||
postgres-protocol.workspace = true
|
||||
pq_proto.workspace = true
|
||||
prometheus.workspace = true
|
||||
rand.workspace = true
|
||||
@@ -65,10 +65,13 @@ webpki-roots.workspace = true
|
||||
x509-parser.workspace = true
|
||||
native-tls.workspace = true
|
||||
postgres-native-tls.workspace = true
|
||||
tokio-native-tls = "0.3.1"
|
||||
|
||||
workspace_hack.workspace = true
|
||||
tokio-util.workspace = true
|
||||
|
||||
fallible-iterator = "0.2.0"
|
||||
|
||||
[dev-dependencies]
|
||||
rcgen.workspace = true
|
||||
rstest.workspace = true
|
||||
|
||||
@@ -5,7 +5,7 @@ use crate::{
|
||||
auth::{self, AuthFlow, ClientCredentials},
|
||||
compute,
|
||||
console::{self, AuthInfo, CachedNodeInfo, ConsoleReqExtra},
|
||||
proxy::{try_wake, NUM_RETRIES_CONNECT},
|
||||
proxy::handle_try_wake,
|
||||
sasl, scram,
|
||||
stream::PqStream,
|
||||
};
|
||||
@@ -51,14 +51,15 @@ pub(super) async fn authenticate(
|
||||
}
|
||||
};
|
||||
|
||||
info!("compute node's state has likely changed; requesting a wake-up");
|
||||
let mut num_retries = 0;
|
||||
let mut node = loop {
|
||||
num_retries += 1;
|
||||
match try_wake(api, extra, creds).await? {
|
||||
let wake_res = api.wake_compute(extra, creds).await;
|
||||
match handle_try_wake(wake_res, num_retries)? {
|
||||
ControlFlow::Continue(_) => num_retries += 1,
|
||||
ControlFlow::Break(n) => break n,
|
||||
ControlFlow::Continue(_) if num_retries < NUM_RETRIES_CONNECT => continue,
|
||||
ControlFlow::Continue(e) => return Err(e.into()),
|
||||
}
|
||||
info!(num_retries, "retrying wake compute");
|
||||
};
|
||||
if let Some(keys) = scram_keys {
|
||||
use tokio_postgres::config::AuthKeys;
|
||||
|
||||
@@ -6,7 +6,7 @@ use std::fmt;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use tokio::time;
|
||||
|
||||
use crate::{auth, console};
|
||||
use crate::{auth, console, pg_client};
|
||||
use crate::{compute, config};
|
||||
|
||||
use super::sql_over_http::MAX_RESPONSE_SIZE;
|
||||
@@ -41,8 +41,10 @@ impl fmt::Display for ConnInfo {
|
||||
}
|
||||
}
|
||||
|
||||
type PgConn =
|
||||
pg_client::connection::Connection<tokio_postgres::Socket, tokio_postgres::tls::NoTlsStream>;
|
||||
struct ConnPoolEntry {
|
||||
conn: tokio_postgres::Client,
|
||||
conn: PgConn,
|
||||
_last_access: std::time::Instant,
|
||||
}
|
||||
|
||||
@@ -78,12 +80,8 @@ impl GlobalConnPool {
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get(
|
||||
&self,
|
||||
conn_info: &ConnInfo,
|
||||
force_new: bool,
|
||||
) -> anyhow::Result<tokio_postgres::Client> {
|
||||
let mut client: Option<tokio_postgres::Client> = None;
|
||||
pub async fn get(&self, conn_info: &ConnInfo, force_new: bool) -> anyhow::Result<PgConn> {
|
||||
let mut client: Option<PgConn> = None;
|
||||
|
||||
if !force_new {
|
||||
let pool = self.get_endpoint_pool(&conn_info.hostname).await;
|
||||
@@ -114,11 +112,7 @@ impl GlobalConnPool {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn put(
|
||||
&self,
|
||||
conn_info: &ConnInfo,
|
||||
client: tokio_postgres::Client,
|
||||
) -> anyhow::Result<()> {
|
||||
pub async fn put(&self, conn_info: &ConnInfo, client: PgConn) -> anyhow::Result<()> {
|
||||
let pool = self.get_endpoint_pool(&conn_info.hostname).await;
|
||||
|
||||
// return connection to the pool
|
||||
@@ -191,7 +185,7 @@ struct TokioMechanism<'a> {
|
||||
|
||||
#[async_trait]
|
||||
impl ConnectMechanism for TokioMechanism<'_> {
|
||||
type Connection = tokio_postgres::Client;
|
||||
type Connection = PgConn;
|
||||
type ConnectError = tokio_postgres::Error;
|
||||
type Error = anyhow::Error;
|
||||
|
||||
@@ -213,7 +207,7 @@ impl ConnectMechanism for TokioMechanism<'_> {
|
||||
async fn connect_to_compute(
|
||||
config: &config::ProxyConfig,
|
||||
conn_info: &ConnInfo,
|
||||
) -> anyhow::Result<tokio_postgres::Client> {
|
||||
) -> anyhow::Result<PgConn> {
|
||||
let tls = config.tls_config.as_ref();
|
||||
let common_names = tls.and_then(|tls| tls.common_names.clone());
|
||||
|
||||
@@ -251,7 +245,7 @@ async fn connect_to_compute_once(
|
||||
node_info: &console::CachedNodeInfo,
|
||||
conn_info: &ConnInfo,
|
||||
timeout: time::Duration,
|
||||
) -> Result<tokio_postgres::Client, tokio_postgres::Error> {
|
||||
) -> Result<PgConn, tokio_postgres::Error> {
|
||||
let mut config = (*node_info.config).clone();
|
||||
|
||||
let (client, connection) = config
|
||||
@@ -263,11 +257,13 @@ async fn connect_to_compute_once(
|
||||
.connect(tokio_postgres::NoTls)
|
||||
.await?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = connection.await {
|
||||
error!("connection error: {}", e);
|
||||
}
|
||||
});
|
||||
let stream = connection.stream.into_inner();
|
||||
|
||||
Ok(client)
|
||||
// tokio::spawn(async move {
|
||||
// if let Err(e) = connection.await {
|
||||
// error!("connection error: {}", e);
|
||||
// }
|
||||
// });
|
||||
|
||||
Ok(PgConn::new(stream))
|
||||
}
|
||||
|
||||
@@ -1,20 +1,38 @@
|
||||
use std::io::ErrorKind;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::bail;
|
||||
use bytes::BufMut;
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures::pin_mut;
|
||||
use futures::StreamExt;
|
||||
use hashbrown::HashMap;
|
||||
use hyper::body::HttpBody;
|
||||
use hyper::http::HeaderName;
|
||||
use hyper::http::HeaderValue;
|
||||
use hyper::{Body, HeaderMap, Request};
|
||||
use postgres_protocol::message::backend::DataRowBody;
|
||||
use postgres_protocol::message::backend::ReadyForQueryBody;
|
||||
use serde_json::json;
|
||||
use serde_json::Map;
|
||||
use serde_json::Value;
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio_postgres::types::Kind;
|
||||
use tokio_postgres::types::Type;
|
||||
use tokio_postgres::GenericClient;
|
||||
use tokio_postgres::IsolationLevel;
|
||||
use tokio_postgres::Row;
|
||||
use tokio_postgres::RowStream;
|
||||
use tokio_postgres::Statement;
|
||||
use url::Url;
|
||||
|
||||
use crate::pg_client;
|
||||
use crate::pg_client::codec::FrontendMessage;
|
||||
use crate::pg_client::connection;
|
||||
use crate::pg_client::connection::RequestMessages;
|
||||
use crate::pg_client::prepare::TypeinfoPreparedQueries;
|
||||
|
||||
use super::conn_pool::ConnInfo;
|
||||
use super::conn_pool::GlobalConnPool;
|
||||
|
||||
@@ -37,6 +55,8 @@ const MAX_REQUEST_SIZE: u64 = 1024 * 1024; // 1 MB
|
||||
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
|
||||
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
|
||||
static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
|
||||
static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
|
||||
static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
|
||||
|
||||
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
|
||||
|
||||
@@ -170,7 +190,7 @@ pub async fn handle(
|
||||
request: Request<Body>,
|
||||
sni_hostname: Option<String>,
|
||||
conn_pool: Arc<GlobalConnPool>,
|
||||
) -> anyhow::Result<Value> {
|
||||
) -> anyhow::Result<(Value, HashMap<HeaderName, HeaderValue>)> {
|
||||
//
|
||||
// Determine the destination and connection params
|
||||
//
|
||||
@@ -185,6 +205,23 @@ pub async fn handle(
|
||||
// Allow connection pooling only if explicitly requested
|
||||
let allow_pool = headers.get(&ALLOW_POOL) == Some(&HEADER_VALUE_TRUE);
|
||||
|
||||
// isolation level and read only
|
||||
|
||||
let txn_isolation_level_raw = headers.get(&TXN_ISOLATION_LEVEL).cloned();
|
||||
let txn_isolation_level = match txn_isolation_level_raw {
|
||||
Some(ref x) => Some(match x.as_bytes() {
|
||||
b"Serializable" => IsolationLevel::Serializable,
|
||||
b"ReadUncommitted" => IsolationLevel::ReadUncommitted,
|
||||
b"ReadCommitted" => IsolationLevel::ReadCommitted,
|
||||
b"RepeatableRead" => IsolationLevel::RepeatableRead,
|
||||
_ => bail!("invalid isolation level"),
|
||||
}),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let txn_read_only_raw = headers.get(&TXN_READ_ONLY).cloned();
|
||||
let txn_read_only = txn_read_only_raw.as_ref() == Some(&HEADER_VALUE_TRUE);
|
||||
|
||||
let request_content_length = match request.body().size_hint().upper() {
|
||||
Some(v) => v,
|
||||
None => MAX_REQUEST_SIZE + 1,
|
||||
@@ -208,26 +245,48 @@ pub async fn handle(
|
||||
// Now execute the query and return the result
|
||||
//
|
||||
let result = match payload {
|
||||
Payload::Single(query) => query_to_json(&client, query, raw_output, array_mode).await,
|
||||
Payload::Single(query) => query_raw_txt_as_json(&mut client, query, raw_output, array_mode)
|
||||
.await
|
||||
.map(|x| (x, HashMap::default())),
|
||||
Payload::Batch(queries) => {
|
||||
let mut results = Vec::new();
|
||||
let transaction = client.transaction().await?;
|
||||
|
||||
client
|
||||
.start_tx(txn_isolation_level, Some(txn_read_only))
|
||||
.await?;
|
||||
|
||||
for query in queries {
|
||||
let result = query_to_json(&transaction, query, raw_output, array_mode).await;
|
||||
let result =
|
||||
query_raw_txt_as_json(&mut client, query, raw_output, array_mode).await;
|
||||
match result {
|
||||
Ok(r) => results.push(r),
|
||||
// TODO: check this tag to see if the client has executed a commit during the non-interactive transactions...
|
||||
Ok((r, _ready_tag)) => results.push(r),
|
||||
Err(e) => {
|
||||
transaction.rollback().await?;
|
||||
let tag = client.rollback().await?;
|
||||
if allow_pool && tag.status() == b'I' {
|
||||
// return connection to the pool
|
||||
tokio::task::spawn(async move {
|
||||
let _ = conn_pool.put(&conn_info, client).await;
|
||||
});
|
||||
}
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
transaction.commit().await?;
|
||||
Ok(json!({ "results": results }))
|
||||
let ready_tag = client.commit().await?;
|
||||
let mut headers = HashMap::default();
|
||||
headers.insert(
|
||||
TXN_READ_ONLY.clone(),
|
||||
HeaderValue::try_from(txn_read_only.to_string())?,
|
||||
);
|
||||
if let Some(txn_isolation_level_raw) = txn_isolation_level_raw {
|
||||
headers.insert(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level_raw);
|
||||
}
|
||||
Ok(((json!({ "results": results }), ready_tag), headers))
|
||||
}
|
||||
};
|
||||
|
||||
if allow_pool {
|
||||
if allow_pool && ready_tag.status() == b'I' {
|
||||
// return connection to the pool
|
||||
tokio::task::spawn(async move {
|
||||
let _ = conn_pool.put(&conn_info, client).await;
|
||||
@@ -312,6 +371,99 @@ async fn query_to_json<T: GenericClient>(
|
||||
}))
|
||||
}
|
||||
|
||||
async fn query_raw_txt_as_json<'a, St, T>(
|
||||
conn: &mut connection::Connection<St, T>,
|
||||
data: QueryData,
|
||||
raw_output: bool,
|
||||
array_mode: bool,
|
||||
) -> anyhow::Result<(Value, ReadyForQueryBody)>
|
||||
where
|
||||
St: AsyncRead + AsyncWrite + Unpin + Send,
|
||||
T: AsyncRead + AsyncWrite + Unpin + Send,
|
||||
{
|
||||
let params = json_to_pg_text(data.params)?;
|
||||
let params = params.into_iter();
|
||||
|
||||
let stmt_name = conn.statement_name();
|
||||
let row_description = conn.prepare(&stmt_name, &data.query).await?;
|
||||
|
||||
let mut fields = vec![];
|
||||
let mut columns = vec![];
|
||||
let mut it = row_description.fields();
|
||||
while let Some(field) = it.next().map_err(pg_client::error::Error::parse)? {
|
||||
fields.push(json!({
|
||||
"name": Value::String(field.name().to_owned()),
|
||||
"dataTypeID": Value::Number(field.type_oid().into()),
|
||||
"tableID": field.table_oid(),
|
||||
"columnID": field.column_id(),
|
||||
"dataTypeSize": field.type_size(),
|
||||
"dataTypeModifier": field.type_modifier(),
|
||||
"format": "text",
|
||||
}));
|
||||
|
||||
let type_ = match Type::from_oid(field.type_oid()) {
|
||||
Some(t) => t,
|
||||
None => TypeinfoPreparedQueries::get_type(conn, field.type_oid()).await?,
|
||||
};
|
||||
|
||||
columns.push(Column {
|
||||
name: field.name().to_string(),
|
||||
type_,
|
||||
});
|
||||
}
|
||||
|
||||
conn.execute("", &stmt_name, params)?;
|
||||
conn.sync().await?;
|
||||
|
||||
let mut rows = vec![];
|
||||
|
||||
let mut row_stream = conn.stream_query_results().await?;
|
||||
|
||||
let mut curret_size = 0;
|
||||
while let Some(row) = row_stream.next().await.transpose()? {
|
||||
// let row = row.map_err(Error::db)?;
|
||||
|
||||
curret_size += row.buffer().len();
|
||||
if curret_size > MAX_RESPONSE_SIZE {
|
||||
return Err(anyhow::anyhow!("response too large"));
|
||||
}
|
||||
|
||||
rows.push(pg_text_row_to_json2(&row, &columns, raw_output, array_mode).unwrap());
|
||||
}
|
||||
|
||||
let command_tag = row_stream.tag();
|
||||
let command_tag = command_tag.tag()?;
|
||||
let mut command_tag_split = command_tag.split(' ');
|
||||
let command_tag_name = command_tag_split.next().unwrap_or_default();
|
||||
let command_tag_count = if command_tag_name == "INSERT" {
|
||||
// INSERT returns OID first and then number of rows
|
||||
command_tag_split.nth(1)
|
||||
} else {
|
||||
// other commands return number of rows (if any)
|
||||
command_tag_split.next()
|
||||
}
|
||||
.and_then(|s| s.parse::<i64>().ok());
|
||||
|
||||
let ready_tag = conn.wait_for_ready().await?;
|
||||
|
||||
// resulting JSON format is based on the format of node-postgres result
|
||||
Ok((
|
||||
json!({
|
||||
"command": command_tag_name,
|
||||
"rowCount": command_tag_count,
|
||||
"rows": rows,
|
||||
"fields": fields,
|
||||
"rowAsArray": array_mode,
|
||||
}),
|
||||
ready_tag,
|
||||
))
|
||||
}
|
||||
|
||||
struct Column {
|
||||
name: String,
|
||||
type_: Type,
|
||||
}
|
||||
|
||||
//
|
||||
// Convert postgres row with text-encoded values to JSON object
|
||||
//
|
||||
@@ -331,7 +483,7 @@ pub fn pg_text_row_to_json(
|
||||
} else {
|
||||
pg_text_to_json(pg_value, column.type_())?
|
||||
};
|
||||
Ok((name.to_string(), json_value))
|
||||
Ok((name, json_value))
|
||||
});
|
||||
|
||||
if array_mode {
|
||||
@@ -341,7 +493,55 @@ pub fn pg_text_row_to_json(
|
||||
.collect::<Result<Vec<Value>, anyhow::Error>>()?;
|
||||
Ok(Value::Array(arr))
|
||||
} else {
|
||||
let obj = iter.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
|
||||
let obj = iter
|
||||
.map(|r| r.map(|(key, val)| (key.to_owned(), val)))
|
||||
.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
|
||||
Ok(Value::Object(obj))
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Convert postgres row with text-encoded values to JSON object
|
||||
//
|
||||
fn pg_text_row_to_json2(
|
||||
row: &DataRowBody,
|
||||
columns: &[Column],
|
||||
raw_output: bool,
|
||||
array_mode: bool,
|
||||
) -> Result<Value, anyhow::Error> {
|
||||
let ranges: Vec<Option<std::ops::Range<usize>>> = row.ranges().collect()?;
|
||||
let iter = std::iter::zip(ranges, columns)
|
||||
.enumerate()
|
||||
.map(|(i, (range, column))| {
|
||||
let name = &column.name;
|
||||
let pg_value = range
|
||||
.map(|r| {
|
||||
std::str::from_utf8(&row.buffer()[r])
|
||||
.map_err(|e| pg_client::error::Error::from_sql(e.into(), i))
|
||||
})
|
||||
.transpose()?;
|
||||
// let pg_value = row.as_text(i)?;
|
||||
let json_value = if raw_output {
|
||||
match pg_value {
|
||||
Some(v) => Value::String(v.to_string()),
|
||||
None => Value::Null,
|
||||
}
|
||||
} else {
|
||||
pg_text_to_json(pg_value, &column.type_)?
|
||||
};
|
||||
Ok((name, json_value))
|
||||
});
|
||||
|
||||
if array_mode {
|
||||
// drop keys and aggregate into array
|
||||
let arr = iter
|
||||
.map(|r| r.map(|(_key, val)| val))
|
||||
.collect::<Result<Vec<Value>, anyhow::Error>>()?;
|
||||
Ok(Value::Array(arr))
|
||||
} else {
|
||||
let obj = iter
|
||||
.map(|r| r.map(|(key, val)| (key.to_owned(), val)))
|
||||
.collect::<Result<Map<String, Value>, anyhow::Error>>()?;
|
||||
Ok(Value::Object(obj))
|
||||
}
|
||||
}
|
||||
@@ -352,16 +552,16 @@ pub fn pg_text_row_to_json(
|
||||
pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, anyhow::Error> {
|
||||
if let Some(val) = pg_value {
|
||||
if let Kind::Array(elem_type) = pg_type.kind() {
|
||||
return pg_array_parse(val, elem_type);
|
||||
return pg_array_parse(val, &elem_type);
|
||||
}
|
||||
|
||||
match *pg_type {
|
||||
Type::BOOL => Ok(Value::Bool(val == "t")),
|
||||
Type::INT2 | Type::INT4 => {
|
||||
match pg_type {
|
||||
&Type::BOOL => Ok(Value::Bool(val == "t")),
|
||||
&Type::INT2 | &Type::INT4 => {
|
||||
let val = val.parse::<i32>()?;
|
||||
Ok(Value::Number(serde_json::Number::from(val)))
|
||||
}
|
||||
Type::FLOAT4 | Type::FLOAT8 => {
|
||||
&Type::FLOAT4 | &Type::FLOAT8 => {
|
||||
let fval = val.parse::<f64>()?;
|
||||
let num = serde_json::Number::from_f64(fval);
|
||||
if let Some(num) = num {
|
||||
@@ -373,7 +573,7 @@ pub fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value,
|
||||
Ok(Value::String(val.to_string()))
|
||||
}
|
||||
}
|
||||
Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?),
|
||||
&Type::JSON | &Type::JSONB => Ok(serde_json::from_str(val)?),
|
||||
_ => Ok(Value::String(val.to_string())),
|
||||
}
|
||||
} else {
|
||||
|
||||
@@ -6,6 +6,7 @@ use crate::{
|
||||
};
|
||||
use bytes::{Buf, Bytes};
|
||||
use futures::{Sink, Stream, StreamExt};
|
||||
use hashbrown::HashMap;
|
||||
use hyper::{
|
||||
server::{
|
||||
accept,
|
||||
@@ -205,7 +206,7 @@ async fn ws_handler(
|
||||
Ok(_) => StatusCode::OK,
|
||||
Err(_) => StatusCode::BAD_REQUEST,
|
||||
};
|
||||
let json = match result {
|
||||
let (json, headers) = match result {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
let message = format!("{:?}", e);
|
||||
@@ -216,7 +217,10 @@ async fn ws_handler(
|
||||
},
|
||||
None => Value::Null,
|
||||
};
|
||||
json!({ "message": message, "code": code })
|
||||
(
|
||||
json!({ "message": message, "code": code }),
|
||||
HashMap::default(),
|
||||
)
|
||||
}
|
||||
};
|
||||
json_response(status_code, json).map(|mut r| {
|
||||
@@ -224,6 +228,9 @@ async fn ws_handler(
|
||||
"Access-Control-Allow-Origin",
|
||||
hyper::http::HeaderValue::from_static("*"),
|
||||
);
|
||||
for (k, v) in headers {
|
||||
r.headers_mut().insert(k, v);
|
||||
}
|
||||
r
|
||||
})
|
||||
} else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS {
|
||||
|
||||
@@ -22,6 +22,7 @@ pub mod scram;
|
||||
pub mod stream;
|
||||
pub mod url;
|
||||
pub mod waiters;
|
||||
pub mod pg_client;
|
||||
|
||||
/// Handle unix signals appropriately.
|
||||
pub async fn handle_signals(token: CancellationToken) -> anyhow::Result<Infallible> {
|
||||
|
||||
43
proxy/src/pg_client/codec.rs
Normal file
43
proxy/src/pg_client/codec.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use postgres_protocol::message::backend::{self, Message};
|
||||
use std::io;
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
pub struct FrontendMessage(pub Bytes);
|
||||
pub struct BackendMessages(pub BytesMut);
|
||||
|
||||
impl BackendMessages {
|
||||
pub fn empty() -> BackendMessages {
|
||||
BackendMessages(BytesMut::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl FallibleIterator for BackendMessages {
|
||||
type Item = backend::Message;
|
||||
type Error = io::Error;
|
||||
|
||||
fn next(&mut self) -> io::Result<Option<backend::Message>> {
|
||||
backend::Message::parse(&mut self.0)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresCodec;
|
||||
|
||||
impl Encoder<FrontendMessage> for PostgresCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> {
|
||||
dst.extend_from_slice(&item.0);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for PostgresCodec {
|
||||
type Item = Message;
|
||||
type Error = io::Error;
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Message>, io::Error> {
|
||||
Message::parse(src)
|
||||
}
|
||||
}
|
||||
369
proxy/src/pg_client/connection.rs
Normal file
369
proxy/src/pg_client/connection.rs
Normal file
@@ -0,0 +1,369 @@
|
||||
use super::codec::{BackendMessages, FrontendMessage, PostgresCodec};
|
||||
use super::error::Error;
|
||||
use super::prepare::TypeinfoPreparedQueries;
|
||||
use bytes::{BufMut, BytesMut};
|
||||
use futures::channel::mpsc;
|
||||
use futures::{Sink, StreamExt};
|
||||
use futures::{SinkExt, Stream};
|
||||
use hashbrown::HashMap;
|
||||
use postgres_protocol::message::backend::{
|
||||
BackendKeyDataBody, CommandCompleteBody, DataRowBody, ErrorResponseBody, Message,
|
||||
ReadyForQueryBody, RowDescriptionBody,
|
||||
};
|
||||
use postgres_protocol::message::frontend;
|
||||
use postgres_protocol::Oid;
|
||||
use std::collections::VecDeque;
|
||||
use std::future::poll_fn;
|
||||
use std::pin::Pin;
|
||||
use std::task::{ready, Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_postgres::maybe_tls_stream::MaybeTlsStream;
|
||||
use tokio_postgres::types::Type;
|
||||
use tokio_postgres::IsolationLevel;
|
||||
use tokio_util::codec::Framed;
|
||||
|
||||
pub enum RequestMessages {
|
||||
Single(FrontendMessage),
|
||||
}
|
||||
|
||||
pub struct Request {
|
||||
pub messages: RequestMessages,
|
||||
pub sender: mpsc::Sender<BackendMessages>,
|
||||
}
|
||||
|
||||
pub struct Response {
|
||||
sender: mpsc::Sender<BackendMessages>,
|
||||
}
|
||||
|
||||
/// A connection to a PostgreSQL database.
|
||||
pub struct RawConnection<S, T> {
|
||||
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
pending_responses: VecDeque<Message>,
|
||||
pub buf: BytesMut,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin> RawConnection<S, T> {
|
||||
pub fn new(
|
||||
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
|
||||
buf: BytesMut,
|
||||
) -> RawConnection<S, T> {
|
||||
RawConnection {
|
||||
stream,
|
||||
pending_responses: VecDeque::new(),
|
||||
buf,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send(&mut self) -> Result<(), Error> {
|
||||
poll_fn(|cx| self.poll_send(cx)).await?;
|
||||
let request = FrontendMessage(self.buf.split().freeze());
|
||||
self.stream.start_send_unpin(request).map_err(Error::io)?;
|
||||
poll_fn(|cx| self.poll_flush(cx)).await
|
||||
}
|
||||
|
||||
pub async fn next_message(&mut self) -> Result<Message, Error> {
|
||||
match self.pending_responses.pop_front() {
|
||||
Some(message) => Ok(message),
|
||||
None => poll_fn(|cx| self.poll_read(cx)).await,
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
|
||||
let message = match ready!(self.stream.poll_next_unpin(cx)?) {
|
||||
Some(message) => message,
|
||||
None => return Poll::Ready(Err(Error::closed())),
|
||||
};
|
||||
Poll::Ready(Ok(message))
|
||||
}
|
||||
|
||||
fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
Pin::new(&mut self.stream).poll_close(cx).map_err(Error::io)
|
||||
}
|
||||
|
||||
fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
if let Poll::Ready(msg) = self.poll_read(cx)? {
|
||||
self.pending_responses.push_back(msg);
|
||||
};
|
||||
self.stream.poll_ready_unpin(cx).map_err(Error::io)
|
||||
}
|
||||
|
||||
fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
|
||||
if let Poll::Ready(msg) = self.poll_read(cx)? {
|
||||
self.pending_responses.push_back(msg);
|
||||
};
|
||||
self.stream.poll_flush_unpin(cx).map_err(Error::io)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Connection<S, T> {
|
||||
stmt_counter: usize,
|
||||
pub typeinfo: Option<TypeinfoPreparedQueries>,
|
||||
pub typecache: HashMap<Oid, Type>,
|
||||
pub raw: RawConnection<S, T>,
|
||||
// key: BackendKeyDataBody,
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin> Connection<S, T> {
|
||||
pub fn new(stream: MaybeTlsStream<S, T>) -> Connection<S, T> {
|
||||
Connection {
|
||||
stmt_counter: 0,
|
||||
typeinfo: None,
|
||||
typecache: HashMap::new(),
|
||||
raw: RawConnection::new(Framed::new(stream, PostgresCodec), BytesMut::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_tx(
|
||||
&mut self,
|
||||
isolation_level: Option<IsolationLevel>,
|
||||
read_only: Option<bool>,
|
||||
) -> Result<ReadyForQueryBody, Error> {
|
||||
let mut query = "START TRANSACTION".to_string();
|
||||
let mut first = true;
|
||||
|
||||
if let Some(level) = isolation_level {
|
||||
first = false;
|
||||
|
||||
query.push_str(" ISOLATION LEVEL ");
|
||||
let level = match level {
|
||||
IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
|
||||
IsolationLevel::ReadCommitted => "READ COMMITTED",
|
||||
IsolationLevel::RepeatableRead => "REPEATABLE READ",
|
||||
IsolationLevel::Serializable => "SERIALIZABLE",
|
||||
_ => return Err(Error::unexpected_message()),
|
||||
};
|
||||
query.push_str(level);
|
||||
}
|
||||
|
||||
if let Some(read_only) = read_only {
|
||||
if !first {
|
||||
query.push(',');
|
||||
}
|
||||
first = false;
|
||||
|
||||
let s = if read_only {
|
||||
" READ ONLY"
|
||||
} else {
|
||||
" READ WRITE"
|
||||
};
|
||||
query.push_str(s);
|
||||
}
|
||||
|
||||
self.execute_simple(&query).await
|
||||
}
|
||||
|
||||
pub async fn rollback(&mut self) -> Result<ReadyForQueryBody, Error> {
|
||||
self.execute_simple("ROLLBACK").await
|
||||
}
|
||||
|
||||
pub async fn commit(&mut self) -> Result<ReadyForQueryBody, Error> {
|
||||
self.execute_simple("COMMIT").await
|
||||
}
|
||||
|
||||
// pub async fn auth_sasl_scram<'a, I>(
|
||||
// mut raw: RawConnection<S, T>,
|
||||
// params: I,
|
||||
// password: &[u8],
|
||||
// ) -> Result<Self, Error>
|
||||
// where
|
||||
// I: IntoIterator<Item = (&'a str, &'a str)>,
|
||||
// {
|
||||
// // send a startup message
|
||||
// frontend::startup_message(params, &mut raw.buf).unwrap();
|
||||
// raw.send().await?;
|
||||
|
||||
// // expect sasl authentication message
|
||||
// let Message::AuthenticationSasl(body) = raw.next_message().await? else { return Err(Error::expecting("sasl authentication")) };
|
||||
// // expect support for SCRAM_SHA_256
|
||||
// if body
|
||||
// .mechanisms()
|
||||
// .find(|&x| Ok(x == authentication::sasl::SCRAM_SHA_256))?
|
||||
// .is_none()
|
||||
// {
|
||||
// return Err(Error::expecting("SCRAM-SHA-256 auth"));
|
||||
// }
|
||||
|
||||
// // initiate SCRAM_SHA_256 authentication without channel binding
|
||||
// let auth = authentication::sasl::ChannelBinding::unrequested();
|
||||
// let mut scram = authentication::sasl::ScramSha256::new(password, auth);
|
||||
|
||||
// frontend::sasl_initial_response(
|
||||
// authentication::sasl::SCRAM_SHA_256,
|
||||
// scram.message(),
|
||||
// &mut raw.buf,
|
||||
// )
|
||||
// .unwrap();
|
||||
// raw.send().await?;
|
||||
|
||||
// // expect sasl continue
|
||||
// let Message::AuthenticationSaslContinue(b) = raw.next_message().await? else { return Err(Error::expecting("auth continue")) };
|
||||
// scram.update(b.data()).unwrap();
|
||||
|
||||
// // continue sasl
|
||||
// frontend::sasl_response(scram.message(), &mut raw.buf).unwrap();
|
||||
// raw.send().await?;
|
||||
|
||||
// // expect sasl final
|
||||
// let Message::AuthenticationSaslFinal(b) = raw.next_message().await? else { return Err(Error::expecting("auth final")) };
|
||||
// scram.finish(b.data()).unwrap();
|
||||
|
||||
// // expect auth ok
|
||||
// let Message::AuthenticationOk = raw.next_message().await? else { return Err(Error::expecting("auth ok")) };
|
||||
|
||||
// // expect connection accepted
|
||||
// let key = loop {
|
||||
// match raw.next_message().await? {
|
||||
// Message::BackendKeyData(key) => break key,
|
||||
// Message::ParameterStatus(_) => {}
|
||||
// _ => return Err(Error::expecting("backend ready")),
|
||||
// }
|
||||
// };
|
||||
|
||||
// let Message::ReadyForQuery(b) = raw.next_message().await? else { return Err(Error::expecting("ready for query")) };
|
||||
// // assert_eq!(b.status(), b'I');
|
||||
|
||||
// Ok(Self { raw, key })
|
||||
// }
|
||||
|
||||
// pub fn prepare_and_execute(
|
||||
// &mut self,
|
||||
// portal: &str,
|
||||
// name: &str,
|
||||
// query: &str,
|
||||
// params: impl IntoIterator<Item = Option<impl AsRef<str>>>,
|
||||
// ) -> std::io::Result<()> {
|
||||
// self.prepare(name, query)?;
|
||||
// self.execute(portal, name, params)
|
||||
// }
|
||||
|
||||
pub fn statement_name(&mut self) -> String {
|
||||
self.stmt_counter += 1;
|
||||
format!("s{}", self.stmt_counter)
|
||||
}
|
||||
|
||||
async fn execute_simple(&mut self, query: &str) -> Result<ReadyForQueryBody, Error> {
|
||||
frontend::query(query, &mut self.raw.buf)?;
|
||||
self.raw.send().await?;
|
||||
|
||||
loop {
|
||||
match self.raw.next_message().await? {
|
||||
Message::ReadyForQuery(q) => return Ok(q),
|
||||
Message::CommandComplete(_)
|
||||
| Message::EmptyQueryResponse
|
||||
| Message::RowDescription(_)
|
||||
| Message::DataRow(_) => {}
|
||||
_ => return Err(Error::unexpected_message()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn prepare(&mut self, name: &str, query: &str) -> Result<RowDescriptionBody, Error> {
|
||||
frontend::parse(name, query, std::iter::empty(), &mut self.raw.buf)?;
|
||||
frontend::describe(b'S', name, &mut self.raw.buf)?;
|
||||
self.sync().await?;
|
||||
self.wait_for_prepare().await
|
||||
}
|
||||
|
||||
pub fn execute(
|
||||
&mut self,
|
||||
portal: &str,
|
||||
name: &str,
|
||||
params: impl IntoIterator<Item = Option<impl AsRef<str>>>,
|
||||
) -> std::io::Result<()> {
|
||||
frontend::bind(
|
||||
portal,
|
||||
name,
|
||||
std::iter::empty(), // all parameters use the default format (text)
|
||||
params,
|
||||
|param, buf| match param {
|
||||
Some(param) => {
|
||||
buf.put_slice(param.as_ref().as_bytes());
|
||||
Ok(postgres_protocol::IsNull::No)
|
||||
}
|
||||
None => Ok(postgres_protocol::IsNull::Yes),
|
||||
},
|
||||
Some(0), // all text
|
||||
&mut self.raw.buf,
|
||||
)
|
||||
.map_err(|e| match e {
|
||||
frontend::BindError::Conversion(e) => std::io::Error::new(std::io::ErrorKind::Other, e),
|
||||
frontend::BindError::Serialization(io) => io,
|
||||
})?;
|
||||
frontend::execute(portal, 0, &mut self.raw.buf)
|
||||
}
|
||||
|
||||
pub async fn sync(&mut self) -> Result<(), Error> {
|
||||
frontend::sync(&mut self.raw.buf);
|
||||
self.raw.send().await
|
||||
}
|
||||
|
||||
pub async fn wait_for_prepare(&mut self) -> Result<RowDescriptionBody, Error> {
|
||||
let Message::ParseComplete = self.raw.next_message().await? else { return Err(Error::expecting("parse")) };
|
||||
let Message::ParameterDescription(_) = self.raw.next_message().await? else { return Err(Error::expecting("param description")) };
|
||||
let Message::RowDescription(desc) = self.raw.next_message().await? else { return Err(Error::expecting("row description")) };
|
||||
|
||||
self.wait_for_ready().await?;
|
||||
|
||||
Ok(desc)
|
||||
}
|
||||
|
||||
pub async fn stream_query_results(&mut self) -> Result<RowStream<'_, S, T>, Error> {
|
||||
// let Message::ParseComplete = self.raw.next_message().await? else { return Err(Error::expecting("parse")) };
|
||||
let Message::BindComplete = self.raw.next_message().await? else { return Err(Error::expecting("bind")) };
|
||||
Ok(RowStream::Stream(&mut self.raw))
|
||||
}
|
||||
|
||||
pub async fn wait_for_ready(&mut self) -> Result<ReadyForQueryBody, Error> {
|
||||
loop {
|
||||
match self.raw.next_message().await.unwrap() {
|
||||
Message::ReadyForQuery(b) => break Ok(b),
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum RowStream<'a, S, T> {
|
||||
Stream(&'a mut RawConnection<S, T>),
|
||||
Complete(Option<CommandCompleteBody>),
|
||||
}
|
||||
impl<S, T> Unpin for RowStream<'_, S, T> {}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin> Stream
|
||||
for RowStream<'_, S, T>
|
||||
{
|
||||
// this is horrible - first result is for transport/protocol errors errors
|
||||
// second result is for sql errors.
|
||||
type Item = Result<Result<DataRowBody, ErrorResponseBody>, Error>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
match &mut *self {
|
||||
RowStream::Stream(raw) => match ready!(raw.poll_read(cx)?) {
|
||||
Message::DataRow(row) => Poll::Ready(Some(Ok(Ok(row)))),
|
||||
Message::CommandComplete(tag) => {
|
||||
*self = Self::Complete(Some(tag));
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Message::EmptyQueryResponse | Message::PortalSuspended => {
|
||||
*self = Self::Complete(None);
|
||||
Poll::Ready(None)
|
||||
}
|
||||
Message::ErrorResponse(error) => {
|
||||
*self = Self::Complete(None);
|
||||
Poll::Ready(Some(Ok(Err(error))))
|
||||
}
|
||||
_ => Poll::Ready(Some(Err(Error::expecting("command completion")))),
|
||||
},
|
||||
RowStream::Complete(_) => Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S, T> RowStream<'_, S, T> {
|
||||
pub fn tag(self) -> Option<CommandCompleteBody> {
|
||||
match self {
|
||||
RowStream::Stream(_) => panic!("should not get tag unless row stream is exhausted"),
|
||||
RowStream::Complete(tag) => tag,
|
||||
}
|
||||
}
|
||||
}
|
||||
447
proxy/src/pg_client/error.rs
Normal file
447
proxy/src/pg_client/error.rs
Normal file
@@ -0,0 +1,447 @@
|
||||
use std::{error, fmt, io};
|
||||
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody};
|
||||
use tokio_native_tls::native_tls;
|
||||
use tokio_postgres::error::{ErrorPosition, SqlState};
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
enum Kind {
|
||||
Io,
|
||||
Tls,
|
||||
UnexpectedMessage,
|
||||
FromSql(usize),
|
||||
Closed,
|
||||
Db,
|
||||
Parse,
|
||||
Encode,
|
||||
}
|
||||
|
||||
struct ErrorInner {
|
||||
kind: Kind,
|
||||
cause: Option<Box<dyn error::Error + Sync + Send>>,
|
||||
}
|
||||
|
||||
/// An error communicating with the Postgres server.
|
||||
pub struct Error(ErrorInner);
|
||||
|
||||
impl fmt::Debug for Error {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
fmt.debug_struct("Error")
|
||||
.field("kind", &self.0.kind)
|
||||
.field("cause", &self.0.cause)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match &self.0.kind {
|
||||
Kind::Io => fmt.write_str("error communicating with the server")?,
|
||||
Kind::Tls => fmt.write_str("error establishing tls")?,
|
||||
Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?,
|
||||
Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?,
|
||||
Kind::Closed => fmt.write_str("connection closed")?,
|
||||
Kind::Db => fmt.write_str("db error")?,
|
||||
Kind::Parse => fmt.write_str("error parsing response from server")?,
|
||||
Kind::Encode => fmt.write_str("error encoding message to server")?,
|
||||
};
|
||||
if let Some(ref cause) = self.0.cause {
|
||||
write!(fmt, ": {}", cause)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl error::Error for Error {
|
||||
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
|
||||
self.0.cause.as_ref().map(|e| &**e as _)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<io::Error> for Error {
|
||||
fn from(value: io::Error) -> Self {
|
||||
Self::io(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl Error {
|
||||
/// Consumes the error, returning its cause.
|
||||
pub fn into_source(self) -> Option<Box<dyn error::Error + Sync + Send>> {
|
||||
self.0.cause
|
||||
}
|
||||
|
||||
/// Returns the source of this error if it was a `DbError`.
|
||||
///
|
||||
/// This is a simple convenience method.
|
||||
pub fn as_db_error(&self) -> Option<&DbError> {
|
||||
error::Error::source(self).and_then(|e| e.downcast_ref::<DbError>())
|
||||
}
|
||||
|
||||
/// Determines if the error was associated with closed connection.
|
||||
pub fn is_closed(&self) -> bool {
|
||||
self.0.kind == Kind::Closed
|
||||
}
|
||||
|
||||
/// Returns the SQLSTATE error code associated with the error.
|
||||
///
|
||||
/// This is a convenience method that downcasts the cause to a `DbError` and returns its code.
|
||||
pub fn code(&self) -> Option<&SqlState> {
|
||||
self.as_db_error().map(DbError::code)
|
||||
}
|
||||
|
||||
fn new(kind: Kind, cause: Option<Box<dyn error::Error + Sync + Send>>) -> Error {
|
||||
Error(ErrorInner { kind, cause })
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
pub(crate) fn db(error: ErrorResponseBody) -> Error {
|
||||
match DbError::parse(&mut error.fields()) {
|
||||
Ok(e) => Error::new(Kind::Db, Some(Box::new(e))),
|
||||
Err(e) => Error::new(Kind::Parse, Some(Box::new(e))),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn from_sql(e: Box<dyn error::Error + Sync + Send>, idx: usize) -> Error {
|
||||
Error::new(Kind::FromSql(idx), Some(e))
|
||||
}
|
||||
|
||||
pub(crate) fn closed() -> Error {
|
||||
Error::new(Kind::Closed, None)
|
||||
}
|
||||
|
||||
pub(crate) fn unexpected_message() -> Error {
|
||||
Error::new(Kind::UnexpectedMessage, None)
|
||||
}
|
||||
|
||||
pub(crate) fn expecting(expected: &str) -> Error {
|
||||
Error::new(Kind::UnexpectedMessage, Some(expected.into()))
|
||||
}
|
||||
|
||||
pub(crate) fn parse(e: io::Error) -> Error {
|
||||
Error::new(Kind::Parse, Some(Box::new(e)))
|
||||
}
|
||||
|
||||
pub(crate) fn encode(e: io::Error) -> Error {
|
||||
Error::new(Kind::Encode, Some(Box::new(e)))
|
||||
}
|
||||
|
||||
pub(crate) fn io(e: io::Error) -> Error {
|
||||
Error::new(Kind::Io, Some(Box::new(e)))
|
||||
}
|
||||
|
||||
pub(crate) fn tls(e: native_tls::Error) -> Error {
|
||||
Error::new(Kind::Tls, Some(Box::new(e)))
|
||||
}
|
||||
}
|
||||
|
||||
/// The severity of a Postgres error or notice.
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub enum Severity {
|
||||
/// PANIC
|
||||
Panic,
|
||||
/// FATAL
|
||||
Fatal,
|
||||
/// ERROR
|
||||
Error,
|
||||
/// WARNING
|
||||
Warning,
|
||||
/// NOTICE
|
||||
Notice,
|
||||
/// DEBUG
|
||||
Debug,
|
||||
/// INFO
|
||||
Info,
|
||||
/// LOG
|
||||
Log,
|
||||
}
|
||||
|
||||
impl fmt::Display for Severity {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let s = match *self {
|
||||
Severity::Panic => "PANIC",
|
||||
Severity::Fatal => "FATAL",
|
||||
Severity::Error => "ERROR",
|
||||
Severity::Warning => "WARNING",
|
||||
Severity::Notice => "NOTICE",
|
||||
Severity::Debug => "DEBUG",
|
||||
Severity::Info => "INFO",
|
||||
Severity::Log => "LOG",
|
||||
};
|
||||
fmt.write_str(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl Severity {
|
||||
fn from_str(s: &str) -> Option<Severity> {
|
||||
match s {
|
||||
"PANIC" => Some(Severity::Panic),
|
||||
"FATAL" => Some(Severity::Fatal),
|
||||
"ERROR" => Some(Severity::Error),
|
||||
"WARNING" => Some(Severity::Warning),
|
||||
"NOTICE" => Some(Severity::Notice),
|
||||
"DEBUG" => Some(Severity::Debug),
|
||||
"INFO" => Some(Severity::Info),
|
||||
"LOG" => Some(Severity::Log),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A Postgres error or notice.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct DbError {
|
||||
severity: String,
|
||||
parsed_severity: Option<Severity>,
|
||||
code: SqlState,
|
||||
message: String,
|
||||
detail: Option<String>,
|
||||
hint: Option<String>,
|
||||
position: Option<ErrorPosition>,
|
||||
where_: Option<String>,
|
||||
schema: Option<String>,
|
||||
table: Option<String>,
|
||||
column: Option<String>,
|
||||
datatype: Option<String>,
|
||||
constraint: Option<String>,
|
||||
file: Option<String>,
|
||||
line: Option<u32>,
|
||||
routine: Option<String>,
|
||||
}
|
||||
|
||||
impl DbError {
|
||||
pub(crate) fn parse(fields: &mut ErrorFields<'_>) -> io::Result<DbError> {
|
||||
let mut severity = None;
|
||||
let mut parsed_severity = None;
|
||||
let mut code = None;
|
||||
let mut message = None;
|
||||
let mut detail = None;
|
||||
let mut hint = None;
|
||||
let mut normal_position = None;
|
||||
let mut internal_position = None;
|
||||
let mut internal_query = None;
|
||||
let mut where_ = None;
|
||||
let mut schema = None;
|
||||
let mut table = None;
|
||||
let mut column = None;
|
||||
let mut datatype = None;
|
||||
let mut constraint = None;
|
||||
let mut file = None;
|
||||
let mut line = None;
|
||||
let mut routine = None;
|
||||
|
||||
while let Some(field) = fields.next()? {
|
||||
match field.type_() {
|
||||
b'S' => severity = Some(field.value().to_owned()),
|
||||
b'C' => code = Some(SqlState::from_code(field.value())),
|
||||
b'M' => message = Some(field.value().to_owned()),
|
||||
b'D' => detail = Some(field.value().to_owned()),
|
||||
b'H' => hint = Some(field.value().to_owned()),
|
||||
b'P' => {
|
||||
normal_position = Some(field.value().parse::<u32>().map_err(|_| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"`P` field did not contain an integer",
|
||||
)
|
||||
})?);
|
||||
}
|
||||
b'p' => {
|
||||
internal_position = Some(field.value().parse::<u32>().map_err(|_| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"`p` field did not contain an integer",
|
||||
)
|
||||
})?);
|
||||
}
|
||||
b'q' => internal_query = Some(field.value().to_owned()),
|
||||
b'W' => where_ = Some(field.value().to_owned()),
|
||||
b's' => schema = Some(field.value().to_owned()),
|
||||
b't' => table = Some(field.value().to_owned()),
|
||||
b'c' => column = Some(field.value().to_owned()),
|
||||
b'd' => datatype = Some(field.value().to_owned()),
|
||||
b'n' => constraint = Some(field.value().to_owned()),
|
||||
b'F' => file = Some(field.value().to_owned()),
|
||||
b'L' => {
|
||||
line = Some(field.value().parse::<u32>().map_err(|_| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"`L` field did not contain an integer",
|
||||
)
|
||||
})?);
|
||||
}
|
||||
b'R' => routine = Some(field.value().to_owned()),
|
||||
b'V' => {
|
||||
parsed_severity = Some(Severity::from_str(field.value()).ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"`V` field contained an invalid value",
|
||||
)
|
||||
})?);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(DbError {
|
||||
severity: severity
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`S` field missing"))?,
|
||||
parsed_severity,
|
||||
code: code
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`C` field missing"))?,
|
||||
message: message
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`M` field missing"))?,
|
||||
detail,
|
||||
hint,
|
||||
position: match normal_position {
|
||||
Some(position) => Some(ErrorPosition::Original(position)),
|
||||
None => match internal_position {
|
||||
Some(position) => Some(ErrorPosition::Internal {
|
||||
position,
|
||||
query: internal_query.ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"`q` field missing but `p` field present",
|
||||
)
|
||||
})?,
|
||||
}),
|
||||
None => None,
|
||||
},
|
||||
},
|
||||
where_,
|
||||
schema,
|
||||
table,
|
||||
column,
|
||||
datatype,
|
||||
constraint,
|
||||
file,
|
||||
line,
|
||||
routine,
|
||||
})
|
||||
}
|
||||
|
||||
/// The field contents are ERROR, FATAL, or PANIC (in an error message),
|
||||
/// or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a
|
||||
/// localized translation of one of these.
|
||||
pub fn severity(&self) -> &str {
|
||||
&self.severity
|
||||
}
|
||||
|
||||
/// A parsed, nonlocalized version of `severity`. (PostgreSQL 9.6+)
|
||||
pub fn parsed_severity(&self) -> Option<Severity> {
|
||||
self.parsed_severity
|
||||
}
|
||||
|
||||
/// The SQLSTATE code for the error.
|
||||
pub fn code(&self) -> &SqlState {
|
||||
&self.code
|
||||
}
|
||||
|
||||
/// The primary human-readable error message.
|
||||
///
|
||||
/// This should be accurate but terse (typically one line).
|
||||
pub fn message(&self) -> &str {
|
||||
&self.message
|
||||
}
|
||||
|
||||
/// An optional secondary error message carrying more detail about the
|
||||
/// problem.
|
||||
///
|
||||
/// Might run to multiple lines.
|
||||
pub fn detail(&self) -> Option<&str> {
|
||||
self.detail.as_deref()
|
||||
}
|
||||
|
||||
/// An optional suggestion what to do about the problem.
|
||||
///
|
||||
/// This is intended to differ from `detail` in that it offers advice
|
||||
/// (potentially inappropriate) rather than hard facts. Might run to
|
||||
/// multiple lines.
|
||||
pub fn hint(&self) -> Option<&str> {
|
||||
self.hint.as_deref()
|
||||
}
|
||||
|
||||
/// An optional error cursor position into either the original query string
|
||||
/// or an internally generated query.
|
||||
pub fn position(&self) -> Option<&ErrorPosition> {
|
||||
self.position.as_ref()
|
||||
}
|
||||
|
||||
/// An indication of the context in which the error occurred.
|
||||
///
|
||||
/// Presently this includes a call stack traceback of active procedural
|
||||
/// language functions and internally-generated queries. The trace is one
|
||||
/// entry per line, most recent first.
|
||||
pub fn where_(&self) -> Option<&str> {
|
||||
self.where_.as_deref()
|
||||
}
|
||||
|
||||
/// If the error was associated with a specific database object, the name
|
||||
/// of the schema containing that object, if any. (PostgreSQL 9.3+)
|
||||
pub fn schema(&self) -> Option<&str> {
|
||||
self.schema.as_deref()
|
||||
}
|
||||
|
||||
/// If the error was associated with a specific table, the name of the
|
||||
/// table. (Refer to the schema name field for the name of the table's
|
||||
/// schema.) (PostgreSQL 9.3+)
|
||||
pub fn table(&self) -> Option<&str> {
|
||||
self.table.as_deref()
|
||||
}
|
||||
|
||||
/// If the error was associated with a specific table column, the name of
|
||||
/// the column.
|
||||
///
|
||||
/// (Refer to the schema and table name fields to identify the table.)
|
||||
/// (PostgreSQL 9.3+)
|
||||
pub fn column(&self) -> Option<&str> {
|
||||
self.column.as_deref()
|
||||
}
|
||||
|
||||
/// If the error was associated with a specific data type, the name of the
|
||||
/// data type. (Refer to the schema name field for the name of the data
|
||||
/// type's schema.) (PostgreSQL 9.3+)
|
||||
pub fn datatype(&self) -> Option<&str> {
|
||||
self.datatype.as_deref()
|
||||
}
|
||||
|
||||
/// If the error was associated with a specific constraint, the name of the
|
||||
/// constraint.
|
||||
///
|
||||
/// Refer to fields listed above for the associated table or domain.
|
||||
/// (For this purpose, indexes are treated as constraints, even if they
|
||||
/// weren't created with constraint syntax.) (PostgreSQL 9.3+)
|
||||
pub fn constraint(&self) -> Option<&str> {
|
||||
self.constraint.as_deref()
|
||||
}
|
||||
|
||||
/// The file name of the source-code location where the error was reported.
|
||||
pub fn file(&self) -> Option<&str> {
|
||||
self.file.as_deref()
|
||||
}
|
||||
|
||||
/// The line number of the source-code location where the error was
|
||||
/// reported.
|
||||
pub fn line(&self) -> Option<u32> {
|
||||
self.line
|
||||
}
|
||||
|
||||
/// The name of the source-code routine reporting the error.
|
||||
pub fn routine(&self) -> Option<&str> {
|
||||
self.routine.as_deref()
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for DbError {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(fmt, "{}: {}", self.severity, self.message)?;
|
||||
if let Some(detail) = &self.detail {
|
||||
write!(fmt, "\nDETAIL: {}", detail)?;
|
||||
}
|
||||
if let Some(hint) = &self.hint {
|
||||
write!(fmt, "\nHINT: {}", hint)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl error::Error for DbError {}
|
||||
5
proxy/src/pg_client/mod.rs
Normal file
5
proxy/src/pg_client/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
|
||||
pub mod codec;
|
||||
pub mod connection;
|
||||
pub mod error;
|
||||
pub mod prepare;
|
||||
293
proxy/src/pg_client/prepare.rs
Normal file
293
proxy/src/pg_client/prepare.rs
Normal file
@@ -0,0 +1,293 @@
|
||||
use fallible_iterator::FallibleIterator;
|
||||
use futures::StreamExt;
|
||||
use postgres_protocol::message::backend::{DataRowRanges, Message};
|
||||
use postgres_protocol::message::frontend;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_postgres::types::{Field, Kind, Oid, ToSql, Type};
|
||||
|
||||
use super::connection::Connection;
|
||||
use super::error::Error;
|
||||
|
||||
const TYPEINFO_QUERY: &str = "\
|
||||
SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid
|
||||
FROM pg_catalog.pg_type t
|
||||
LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid
|
||||
INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
|
||||
WHERE t.oid = $1
|
||||
";
|
||||
|
||||
const TYPEINFO_ENUM_QUERY: &str = "\
|
||||
SELECT enumlabel
|
||||
FROM pg_catalog.pg_enum
|
||||
WHERE enumtypid = $1
|
||||
ORDER BY enumsortorder
|
||||
";
|
||||
|
||||
const TYPEINFO_COMPOSITE_QUERY: &str = "\
|
||||
SELECT attname, atttypid
|
||||
FROM pg_catalog.pg_attribute
|
||||
WHERE attrelid = $1
|
||||
AND NOT attisdropped
|
||||
AND attnum > 0
|
||||
ORDER BY attnum
|
||||
";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct TypeinfoPreparedQueries {
|
||||
query: String,
|
||||
enum_query: String,
|
||||
composite_query: String,
|
||||
}
|
||||
|
||||
fn map_is_null(x: tokio_postgres::types::IsNull) -> postgres_protocol::IsNull {
|
||||
match x {
|
||||
tokio_postgres::types::IsNull::Yes => postgres_protocol::IsNull::Yes,
|
||||
tokio_postgres::types::IsNull::No => postgres_protocol::IsNull::No,
|
||||
}
|
||||
}
|
||||
|
||||
fn read_column<'a, T: tokio_postgres::types::FromSql<'a>>(
|
||||
buffer: &'a [u8],
|
||||
type_: &Type,
|
||||
ranges: &mut DataRowRanges<'a>,
|
||||
) -> Result<T, Error> {
|
||||
let range = ranges.next()?;
|
||||
match range {
|
||||
Some(range) => T::from_sql_nullable(type_, range.map(|r| &buffer[r])),
|
||||
None => T::from_sql_null(type_),
|
||||
}
|
||||
.map_err(|e| Error::from_sql(e, 0))
|
||||
}
|
||||
|
||||
impl TypeinfoPreparedQueries {
|
||||
pub async fn new<
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send,
|
||||
T: AsyncRead + AsyncWrite + Unpin + Send,
|
||||
>(
|
||||
c: &mut Connection<S, T>,
|
||||
) -> Result<Self, Error> {
|
||||
if let Some(ti) = &c.typeinfo {
|
||||
return Ok(ti.clone());
|
||||
}
|
||||
|
||||
let query = c.statement_name();
|
||||
let enum_query = c.statement_name();
|
||||
let composite_query = c.statement_name();
|
||||
|
||||
frontend::parse(&query, TYPEINFO_QUERY, [Type::OID.oid()], &mut c.raw.buf)?;
|
||||
frontend::parse(
|
||||
&enum_query,
|
||||
TYPEINFO_ENUM_QUERY,
|
||||
[Type::OID.oid()],
|
||||
&mut c.raw.buf,
|
||||
)?;
|
||||
c.sync().await?;
|
||||
frontend::parse(
|
||||
&composite_query,
|
||||
TYPEINFO_COMPOSITE_QUERY,
|
||||
[Type::OID.oid()],
|
||||
&mut c.raw.buf,
|
||||
)?;
|
||||
c.sync().await?;
|
||||
|
||||
let Message::ParseComplete = c.raw.next_message().await? else { return Err(Error::expecting("parse")) };
|
||||
let Message::ParseComplete = c.raw.next_message().await? else { return Err(Error::expecting("parse")) };
|
||||
let Message::ParseComplete = c.raw.next_message().await? else { return Err(Error::expecting("parse")) };
|
||||
c.wait_for_ready().await?;
|
||||
|
||||
Ok(c.typeinfo
|
||||
.insert(TypeinfoPreparedQueries {
|
||||
query,
|
||||
enum_query,
|
||||
composite_query,
|
||||
})
|
||||
.clone())
|
||||
}
|
||||
|
||||
fn get_type_rec<
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send,
|
||||
T: AsyncRead + AsyncWrite + Unpin + Send,
|
||||
>(
|
||||
c: &mut Connection<S, T>,
|
||||
oid: Oid,
|
||||
) -> Pin<Box<dyn Future<Output = Result<Type, Error>> + Send + '_>> {
|
||||
Box::pin(Self::get_type(c, oid))
|
||||
}
|
||||
|
||||
pub async fn get_type<
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send,
|
||||
T: AsyncRead + AsyncWrite + Unpin + Send,
|
||||
>(
|
||||
c: &mut Connection<S, T>,
|
||||
oid: Oid,
|
||||
) -> Result<Type, Error> {
|
||||
if let Some(type_) = Type::from_oid(oid) {
|
||||
return Ok(type_);
|
||||
}
|
||||
|
||||
if let Some(type_) = c.typecache.get(&oid) {
|
||||
return Ok(type_.clone());
|
||||
}
|
||||
|
||||
let queries = Self::new(c).await?;
|
||||
|
||||
frontend::bind(
|
||||
"",
|
||||
&queries.query,
|
||||
[1], // the only parameter is in binary format
|
||||
[oid],
|
||||
|param, buf| param.to_sql(&Type::OID, buf).map(map_is_null),
|
||||
Some(1), // binary return type
|
||||
&mut c.raw.buf,
|
||||
)
|
||||
.map_err(|e| match e {
|
||||
frontend::BindError::Conversion(e) => std::io::Error::new(std::io::ErrorKind::Other, e),
|
||||
frontend::BindError::Serialization(io) => io,
|
||||
})?;
|
||||
frontend::execute("", 0, &mut c.raw.buf)?;
|
||||
|
||||
c.sync().await?;
|
||||
|
||||
let mut stream = c.stream_query_results().await?;
|
||||
|
||||
let Some(row) = stream.next().await.transpose()? else {
|
||||
todo!()
|
||||
};
|
||||
|
||||
let row = row.map_err(Error::db)?;
|
||||
let b = row.buffer();
|
||||
let mut ranges = row.ranges();
|
||||
|
||||
let name: String = read_column(b, &Type::NAME, &mut ranges)?;
|
||||
let type_: i8 = read_column(b, &Type::CHAR, &mut ranges)?;
|
||||
let elem_oid: Oid = read_column(b, &Type::OID, &mut ranges)?;
|
||||
let rngsubtype: Option<Oid> = read_column(b, &Type::OID, &mut ranges)?;
|
||||
let basetype: Oid = read_column(b, &Type::OID, &mut ranges)?;
|
||||
let schema: String = read_column(b, &Type::NAME, &mut ranges)?;
|
||||
let relid: Oid = read_column(b, &Type::OID, &mut ranges)?;
|
||||
|
||||
{
|
||||
// should be none
|
||||
let None = stream.next().await.transpose()? else {
|
||||
todo!()
|
||||
};
|
||||
drop(stream);
|
||||
}
|
||||
|
||||
let kind = if type_ == b'e' as i8 {
|
||||
let variants = Self::get_enum_variants(c, oid).await?;
|
||||
Kind::Enum(variants)
|
||||
} else if type_ == b'p' as i8 {
|
||||
Kind::Pseudo
|
||||
} else if basetype != 0 {
|
||||
let type_ = Self::get_type_rec(c, basetype).await?;
|
||||
Kind::Domain(type_)
|
||||
} else if elem_oid != 0 {
|
||||
let type_ = Self::get_type_rec(c, elem_oid).await?;
|
||||
Kind::Array(type_)
|
||||
} else if relid != 0 {
|
||||
let fields = Self::get_composite_fields(c, relid).await?;
|
||||
Kind::Composite(fields)
|
||||
} else if let Some(rngsubtype) = rngsubtype {
|
||||
let type_ = Self::get_type_rec(c, rngsubtype).await?;
|
||||
Kind::Range(type_)
|
||||
} else {
|
||||
Kind::Simple
|
||||
};
|
||||
|
||||
let type_ = Type::new(name, oid, kind, schema);
|
||||
c.typecache.insert(oid, type_.clone());
|
||||
|
||||
Ok(type_)
|
||||
}
|
||||
|
||||
async fn get_enum_variants<
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send,
|
||||
T: AsyncRead + AsyncWrite + Unpin + Send,
|
||||
>(
|
||||
c: &mut Connection<S, T>,
|
||||
oid: Oid,
|
||||
) -> Result<Vec<String>, Error> {
|
||||
let queries = Self::new(c).await?;
|
||||
|
||||
frontend::bind(
|
||||
"",
|
||||
&queries.enum_query,
|
||||
[1], // the only parameter is in binary format
|
||||
[oid],
|
||||
|param, buf| param.to_sql(&Type::OID, buf).map(map_is_null),
|
||||
Some(1), // binary return type
|
||||
&mut c.raw.buf,
|
||||
)
|
||||
.map_err(|e| match e {
|
||||
frontend::BindError::Conversion(e) => std::io::Error::new(std::io::ErrorKind::Other, e),
|
||||
frontend::BindError::Serialization(io) => io,
|
||||
})?;
|
||||
frontend::execute("", 0, &mut c.raw.buf)?;
|
||||
|
||||
c.sync().await?;
|
||||
|
||||
let mut stream = c.stream_query_results().await?;
|
||||
let mut variants = Vec::new();
|
||||
while let Some(row) = stream.next().await.transpose()? {
|
||||
let row = row.map_err(Error::db)?;
|
||||
|
||||
let variant: String = read_column(row.buffer(), &Type::NAME, &mut row.ranges())?;
|
||||
variants.push(variant);
|
||||
}
|
||||
|
||||
c.wait_for_ready().await?;
|
||||
|
||||
Ok(variants)
|
||||
}
|
||||
|
||||
async fn get_composite_fields<
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send,
|
||||
T: AsyncRead + AsyncWrite + Unpin + Send,
|
||||
>(
|
||||
c: &mut Connection<S, T>,
|
||||
oid: Oid,
|
||||
) -> Result<Vec<Field>, Error> {
|
||||
let queries = Self::new(c).await?;
|
||||
|
||||
frontend::bind(
|
||||
"",
|
||||
&queries.composite_query,
|
||||
[1], // the only parameter is in binary format
|
||||
[oid],
|
||||
|param, buf| param.to_sql(&Type::OID, buf).map(map_is_null),
|
||||
Some(1), // binary return type
|
||||
&mut c.raw.buf,
|
||||
)
|
||||
.map_err(|e| match e {
|
||||
frontend::BindError::Conversion(e) => std::io::Error::new(std::io::ErrorKind::Other, e),
|
||||
frontend::BindError::Serialization(io) => io,
|
||||
})?;
|
||||
frontend::execute("", 0, &mut c.raw.buf)?;
|
||||
|
||||
c.sync().await?;
|
||||
|
||||
let mut stream = c.stream_query_results().await?;
|
||||
let mut fields = Vec::new();
|
||||
while let Some(row) = stream.next().await.transpose()? {
|
||||
let row = row.map_err(Error::db)?;
|
||||
|
||||
let mut ranges = row.ranges();
|
||||
let name: String = read_column(row.buffer(), &Type::NAME, &mut ranges)?;
|
||||
let oid: Oid = read_column(row.buffer(), &Type::OID, &mut ranges)?;
|
||||
fields.push((name, oid));
|
||||
}
|
||||
|
||||
c.wait_for_ready().await?;
|
||||
|
||||
let mut output_fields = Vec::with_capacity(fields.len());
|
||||
for (name, oid) in fields {
|
||||
let type_ = Self::get_type_rec(c, oid).await?;
|
||||
output_fields.push(Field::new(name, type_))
|
||||
}
|
||||
|
||||
Ok(output_fields)
|
||||
}
|
||||
}
|
||||
@@ -347,11 +347,6 @@ async fn connect_to_compute_once(
|
||||
.await
|
||||
}
|
||||
|
||||
enum ConnectionState<E> {
|
||||
Cached(console::CachedNodeInfo),
|
||||
Invalid(compute::ConnCfg, E),
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait ConnectMechanism {
|
||||
type Connection;
|
||||
@@ -407,70 +402,67 @@ where
|
||||
|
||||
mechanism.update_connect_config(&mut node_info.config);
|
||||
|
||||
let mut num_retries = 0;
|
||||
let mut state = ConnectionState::<M::ConnectError>::Cached(node_info);
|
||||
// try once
|
||||
let (config, err) = match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
|
||||
Ok(res) => return Ok(res),
|
||||
Err(e) => {
|
||||
error!(error = ?e, "could not connect to compute node");
|
||||
(invalidate_cache(node_info), e)
|
||||
}
|
||||
};
|
||||
|
||||
loop {
|
||||
match state {
|
||||
ConnectionState::Invalid(config, err) => {
|
||||
info!("compute node's state has likely changed; requesting a wake-up");
|
||||
let mut num_retries = 1;
|
||||
|
||||
let wake_res = match creds {
|
||||
auth::BackendType::Console(api, creds) => api.wake_compute(extra, creds).await,
|
||||
auth::BackendType::Postgres(api, creds) => api.wake_compute(extra, creds).await,
|
||||
// nothing to do?
|
||||
auth::BackendType::Link(_) => return Err(err.into()),
|
||||
// test backend
|
||||
auth::BackendType::Test(x) => x.wake_compute(),
|
||||
};
|
||||
// if we failed to connect, it's likely that the compute node was suspended, wake a new compute node
|
||||
info!("compute node's state has likely changed; requesting a wake-up");
|
||||
let node_info = loop {
|
||||
let wake_res = match creds {
|
||||
auth::BackendType::Console(api, creds) => api.wake_compute(extra, creds).await,
|
||||
auth::BackendType::Postgres(api, creds) => api.wake_compute(extra, creds).await,
|
||||
// nothing to do?
|
||||
auth::BackendType::Link(_) => return Err(err.into()),
|
||||
// test backend
|
||||
auth::BackendType::Test(x) => x.wake_compute(),
|
||||
};
|
||||
|
||||
match handle_try_wake(wake_res) {
|
||||
// there was an error communicating with the control plane
|
||||
Err(e) => return Err(e.into()),
|
||||
// failed to wake up but we can continue to retry
|
||||
Ok(ControlFlow::Continue(_)) => {
|
||||
state = ConnectionState::Invalid(config, err);
|
||||
let wait_duration = retry_after(num_retries);
|
||||
num_retries += 1;
|
||||
|
||||
info!(num_retries, "retrying wake compute");
|
||||
time::sleep(wait_duration).await;
|
||||
continue;
|
||||
}
|
||||
// successfully woke up a compute node and can break the wakeup loop
|
||||
Ok(ControlFlow::Break(mut node_info)) => {
|
||||
node_info.config.reuse_password(&config);
|
||||
mechanism.update_connect_config(&mut node_info.config);
|
||||
state = ConnectionState::Cached(node_info)
|
||||
}
|
||||
}
|
||||
match handle_try_wake(wake_res, num_retries)? {
|
||||
// failed to wake up but we can continue to retry
|
||||
ControlFlow::Continue(_) => {}
|
||||
// successfully woke up a compute node and can break the wakeup loop
|
||||
ControlFlow::Break(mut node_info) => {
|
||||
node_info.config.reuse_password(&config);
|
||||
mechanism.update_connect_config(&mut node_info.config);
|
||||
break node_info;
|
||||
}
|
||||
ConnectionState::Cached(node_info) => {
|
||||
match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
|
||||
Ok(res) => return Ok(res),
|
||||
Err(e) => {
|
||||
error!(error = ?e, "could not connect to compute node");
|
||||
if !e.should_retry(num_retries) {
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
|
||||
// after the first connect failure,
|
||||
// we should invalidate the cache and wake up a new compute node
|
||||
if num_retries == 0 {
|
||||
state = ConnectionState::Invalid(invalidate_cache(node_info), e);
|
||||
} else {
|
||||
state = ConnectionState::Cached(node_info);
|
||||
}
|
||||
let wait_duration = retry_after(num_retries);
|
||||
num_retries += 1;
|
||||
|
||||
let wait_duration = retry_after(num_retries);
|
||||
num_retries += 1;
|
||||
time::sleep(wait_duration).await;
|
||||
info!(num_retries, "retrying wake compute");
|
||||
};
|
||||
|
||||
info!(num_retries, "retrying wake compute");
|
||||
time::sleep(wait_duration).await;
|
||||
}
|
||||
// now that we have a new node, try connect to it repeatedly.
|
||||
// this can error for a few reasons, for instance:
|
||||
// * DNS connection settings haven't quite propagated yet
|
||||
info!("wake_compute success. attempting to connect");
|
||||
loop {
|
||||
match mechanism.connect_once(&node_info, CONNECT_TIMEOUT).await {
|
||||
Ok(res) => return Ok(res),
|
||||
Err(e) => {
|
||||
error!(error = ?e, "could not connect to compute node");
|
||||
if !e.should_retry(num_retries) {
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let wait_duration = retry_after(num_retries);
|
||||
num_retries += 1;
|
||||
|
||||
time::sleep(wait_duration).await;
|
||||
info!(num_retries, "retrying connect_once");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -478,12 +470,15 @@ where
|
||||
/// * Returns Ok(Continue(e)) if there was an error waking but retries are acceptable
|
||||
/// * Returns Ok(Break(node)) if the wakeup succeeded
|
||||
/// * Returns Err(e) if there was an error
|
||||
fn handle_try_wake(
|
||||
pub fn handle_try_wake(
|
||||
result: Result<console::CachedNodeInfo, WakeComputeError>,
|
||||
num_retries: u32,
|
||||
) -> Result<ControlFlow<console::CachedNodeInfo, WakeComputeError>, WakeComputeError> {
|
||||
match result {
|
||||
Err(err) => match &err {
|
||||
WakeComputeError::ApiError(api) if api.could_retry() => Ok(ControlFlow::Continue(err)),
|
||||
WakeComputeError::ApiError(api) if api.should_retry(num_retries) => {
|
||||
Ok(ControlFlow::Continue(err))
|
||||
}
|
||||
_ => Err(err),
|
||||
},
|
||||
// Ready to try again.
|
||||
@@ -491,22 +486,10 @@ fn handle_try_wake(
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempts to wake up the compute node.
|
||||
pub async fn try_wake(
|
||||
api: &impl console::Api,
|
||||
extra: &console::ConsoleReqExtra<'_>,
|
||||
creds: &auth::ClientCredentials<'_>,
|
||||
) -> Result<ControlFlow<console::CachedNodeInfo, WakeComputeError>, WakeComputeError> {
|
||||
info!("compute node's state has likely changed; requesting a wake-up");
|
||||
handle_try_wake(api.wake_compute(extra, creds).await)
|
||||
}
|
||||
|
||||
pub trait ShouldRetry {
|
||||
fn could_retry(&self) -> bool;
|
||||
fn should_retry(&self, num_retries: u32) -> bool {
|
||||
match self {
|
||||
// retry all errors at least once
|
||||
_ if num_retries == 0 => true,
|
||||
_ if num_retries >= NUM_RETRIES_CONNECT => false,
|
||||
err => err.could_retry(),
|
||||
}
|
||||
@@ -558,14 +541,9 @@ impl ShouldRetry for compute::ConnectionError {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn retry_after(num_retries: u32) -> time::Duration {
|
||||
match num_retries {
|
||||
0 => time::Duration::ZERO,
|
||||
_ => {
|
||||
// 3/2 = 1.5 which seems to be an ok growth factor heuristic
|
||||
BASE_RETRY_WAIT_DURATION * 3_u32.pow(num_retries) / 2_u32.pow(num_retries)
|
||||
}
|
||||
}
|
||||
fn retry_after(num_retries: u32) -> time::Duration {
|
||||
// 1.5 seems to be an ok growth factor heuristic
|
||||
BASE_RETRY_WAIT_DURATION.mul_f64(1.5_f64.powi(num_retries as i32))
|
||||
}
|
||||
|
||||
/// Finish client connection initialization: confirm auth success, send params, etc.
|
||||
|
||||
@@ -99,9 +99,8 @@ struct Scram(scram::ServerSecret);
|
||||
|
||||
impl Scram {
|
||||
fn new(password: &str) -> anyhow::Result<Self> {
|
||||
let salt = rand::random::<[u8; 16]>();
|
||||
let secret = scram::ServerSecret::build(password, &salt, 256)
|
||||
.context("failed to generate scram secret")?;
|
||||
let secret =
|
||||
scram::ServerSecret::build(password).context("failed to generate scram secret")?;
|
||||
Ok(Scram(secret))
|
||||
}
|
||||
|
||||
@@ -302,7 +301,7 @@ async fn scram_auth_mock() -> anyhow::Result<()> {
|
||||
#[test]
|
||||
fn connect_compute_total_wait() {
|
||||
let mut total_wait = tokio::time::Duration::ZERO;
|
||||
for num_retries in 0..10 {
|
||||
for num_retries in 1..10 {
|
||||
total_wait += retry_after(num_retries);
|
||||
}
|
||||
assert!(total_wait < tokio::time::Duration::from_secs(12));
|
||||
|
||||
@@ -12,9 +12,6 @@ mod messages;
|
||||
mod secret;
|
||||
mod signature;
|
||||
|
||||
#[cfg(any(test, doc))]
|
||||
mod password;
|
||||
|
||||
pub use exchange::Exchange;
|
||||
pub use key::ScramKey;
|
||||
pub use secret::ServerSecret;
|
||||
@@ -57,27 +54,21 @@ fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
|
||||
|
||||
use crate::sasl::{Mechanism, Step};
|
||||
|
||||
use super::{password::SaltedPassword, Exchange, ServerSecret};
|
||||
use super::{Exchange, ServerSecret};
|
||||
|
||||
#[test]
|
||||
fn happy_path() {
|
||||
fn snapshot() {
|
||||
let iterations = 4096;
|
||||
let salt_base64 = "QSXCR+Q6sek8bf92";
|
||||
let pw = SaltedPassword::new(
|
||||
b"pencil",
|
||||
base64::decode(salt_base64).unwrap().as_slice(),
|
||||
iterations,
|
||||
);
|
||||
let salt = "QSXCR+Q6sek8bf92";
|
||||
let stored_key = "FO+9jBb3MUukt6jJnzjPZOWc5ow/Pu6JtPyju0aqaE8=";
|
||||
let server_key = "qxJ1SbmSAi5EcS0J5Ck/cKAm/+Ixa+Kwp63f4OHDgzo=";
|
||||
let secret = format!("SCRAM-SHA-256${iterations}:{salt}${stored_key}:{server_key}",);
|
||||
let secret = ServerSecret::parse(&secret).unwrap();
|
||||
|
||||
let secret = ServerSecret {
|
||||
iterations,
|
||||
salt_base64: salt_base64.to_owned(),
|
||||
stored_key: pw.client_key().sha256(),
|
||||
server_key: pw.server_key(),
|
||||
doomed: false,
|
||||
};
|
||||
const NONCE: [u8; 18] = [
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
|
||||
];
|
||||
@@ -115,4 +106,40 @@ mod tests {
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
fn run_round_trip_test(client_password: &str) {
|
||||
let secret = ServerSecret::build("pencil").unwrap();
|
||||
let mut exchange = Exchange::new(&secret, rand::random, None);
|
||||
|
||||
let mut client =
|
||||
ScramSha256::new(client_password.as_bytes(), ChannelBinding::unsupported());
|
||||
|
||||
let client_first = std::str::from_utf8(client.message()).unwrap();
|
||||
exchange = match exchange.exchange(client_first).unwrap() {
|
||||
Step::Continue(exchange, message) => {
|
||||
client.update(message.as_bytes()).unwrap();
|
||||
exchange
|
||||
}
|
||||
Step::Success(_, _) => panic!("expected continue, got success"),
|
||||
Step::Failure(f) => panic!("{f}"),
|
||||
};
|
||||
|
||||
let client_final = std::str::from_utf8(client.message()).unwrap();
|
||||
match exchange.exchange(client_final).unwrap() {
|
||||
Step::Success(_, message) => client.finish(message.as_bytes()).unwrap(),
|
||||
Step::Continue(_, _) => panic!("expected success, got continue"),
|
||||
Step::Failure(f) => panic!("{f}"),
|
||||
};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn round_trip() {
|
||||
run_round_trip_test("pencil")
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "password doesn't match")]
|
||||
fn failure() {
|
||||
run_round_trip_test("eraser")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
/// Faithfully taken from PostgreSQL.
|
||||
pub const SCRAM_KEY_LEN: usize = 32;
|
||||
|
||||
/// One of the keys derived from the [password](super::password::SaltedPassword).
|
||||
/// One of the keys derived from the user's password.
|
||||
/// We use the same structure for all keys, i.e.
|
||||
/// `ClientKey`, `StoredKey`, and `ServerKey`.
|
||||
#[derive(Default, PartialEq, Eq)]
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
//! Password hashing routines.
|
||||
|
||||
use super::key::ScramKey;
|
||||
|
||||
pub const SALTED_PASSWORD_LEN: usize = 32;
|
||||
|
||||
/// Salted hashed password is essential for [key](super::key) derivation.
|
||||
#[repr(transparent)]
|
||||
pub struct SaltedPassword {
|
||||
bytes: [u8; SALTED_PASSWORD_LEN],
|
||||
}
|
||||
|
||||
impl SaltedPassword {
|
||||
/// See `scram-common.c : scram_SaltedPassword` for details.
|
||||
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2898> (see `PBKDF2`).
|
||||
pub fn new(password: &[u8], salt: &[u8], iterations: u32) -> SaltedPassword {
|
||||
pbkdf2::pbkdf2_hmac_array::<sha2::Sha256, 32>(password, salt, iterations).into()
|
||||
}
|
||||
|
||||
/// Derive `ClientKey` from a salted hashed password.
|
||||
pub fn client_key(&self) -> ScramKey {
|
||||
super::hmac_sha256(&self.bytes, [b"Client Key".as_ref()]).into()
|
||||
}
|
||||
|
||||
/// Derive `ServerKey` from a salted hashed password.
|
||||
pub fn server_key(&self) -> ScramKey {
|
||||
super::hmac_sha256(&self.bytes, [b"Server Key".as_ref()]).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<[u8; SALTED_PASSWORD_LEN]> for SaltedPassword {
|
||||
#[inline(always)]
|
||||
fn from(bytes: [u8; SALTED_PASSWORD_LEN]) -> Self {
|
||||
Self { bytes }
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::SaltedPassword;
|
||||
|
||||
fn legacy_pbkdf2_impl(password: &[u8], salt: &[u8], iterations: u32) -> SaltedPassword {
|
||||
let one = 1_u32.to_be_bytes(); // magic
|
||||
|
||||
let mut current = super::super::hmac_sha256(password, [salt, &one]);
|
||||
let mut result = current;
|
||||
for _ in 1..iterations {
|
||||
current = super::super::hmac_sha256(password, [current.as_ref()]);
|
||||
// TODO: result = current.zip(result).map(|(x, y)| x ^ y), issue #80094
|
||||
for (i, x) in current.iter().enumerate() {
|
||||
result[i] ^= x;
|
||||
}
|
||||
}
|
||||
|
||||
result.into()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pbkdf2() {
|
||||
let password = "a-very-secure-password";
|
||||
let salt = "such-a-random-salt";
|
||||
let iterations = 4096;
|
||||
let output = [
|
||||
203, 18, 206, 81, 4, 154, 193, 100, 147, 41, 211, 217, 177, 203, 69, 210, 194, 211,
|
||||
101, 1, 248, 156, 96, 0, 8, 223, 30, 87, 158, 41, 20, 42,
|
||||
];
|
||||
|
||||
let actual = SaltedPassword::new(password.as_bytes(), salt.as_bytes(), iterations);
|
||||
let expected = legacy_pbkdf2_impl(password.as_bytes(), salt.as_bytes(), iterations);
|
||||
|
||||
assert_eq!(actual.bytes, output);
|
||||
assert_eq!(actual.bytes, expected.bytes);
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,7 @@
|
||||
use super::base64_decode_array;
|
||||
use super::key::ScramKey;
|
||||
|
||||
/// Server secret is produced from [password](super::password::SaltedPassword)
|
||||
/// Server secret is produced from user's password,
|
||||
/// and is used throughout the authentication process.
|
||||
pub struct ServerSecret {
|
||||
/// Number of iterations for `PBKDF2` function.
|
||||
@@ -58,21 +58,10 @@ impl ServerSecret {
|
||||
/// Build a new server secret from the prerequisites.
|
||||
/// XXX: We only use this function in tests.
|
||||
#[cfg(test)]
|
||||
pub fn build(password: &str, salt: &[u8], iterations: u32) -> Option<Self> {
|
||||
// TODO: implement proper password normalization required by the RFC
|
||||
if !password.is_ascii() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let password = super::password::SaltedPassword::new(password.as_bytes(), salt, iterations);
|
||||
|
||||
Some(Self {
|
||||
iterations,
|
||||
salt_base64: base64::encode(salt),
|
||||
stored_key: password.client_key().sha256(),
|
||||
server_key: password.server_key(),
|
||||
doomed: false,
|
||||
})
|
||||
pub fn build(password: &str) -> Option<Self> {
|
||||
Self::parse(&postgres_protocol::password::scram_sha_256(
|
||||
password.as_bytes(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,20 +91,4 @@ mod tests {
|
||||
assert_eq!(base64::encode(parsed.stored_key), stored_key);
|
||||
assert_eq!(base64::encode(parsed.server_key), server_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_scram_secret() {
|
||||
let salt = b"salt";
|
||||
let secret = ServerSecret::build("password", salt, 4096).unwrap();
|
||||
assert_eq!(secret.iterations, 4096);
|
||||
assert_eq!(secret.salt_base64, base64::encode(salt));
|
||||
assert_eq!(
|
||||
base64::encode(secret.stored_key.as_ref()),
|
||||
"lF4cRm/Jky763CN4HtxdHnjV4Q8AWTNlKvGmEFFU8IQ="
|
||||
);
|
||||
assert_eq!(
|
||||
base64::encode(secret.server_key.as_ref()),
|
||||
"ub8OgRsftnk2ccDMOt7ffHXNcikRkQkq1lh4xaAqrSw="
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -234,7 +234,10 @@ async fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> {
|
||||
listen_pg_addr_tenant_only
|
||||
);
|
||||
let listener = tcp_listener::bind(listen_pg_addr_tenant_only.clone()).map_err(|e| {
|
||||
error!("failed to bind to address {}: {}", conf.listen_pg_addr, e);
|
||||
error!(
|
||||
"failed to bind to address {}: {}",
|
||||
listen_pg_addr_tenant_only, e
|
||||
);
|
||||
e
|
||||
})?;
|
||||
Some(listener)
|
||||
|
||||
@@ -257,28 +257,15 @@ def prepare_snapshot(
|
||||
shutil.rmtree(repo_dir / "pgdatadirs")
|
||||
os.mkdir(repo_dir / "endpoints")
|
||||
|
||||
# Remove wal-redo temp directory if it exists. Newer pageserver versions don't create
|
||||
# them anymore, but old versions did.
|
||||
for tenant in (repo_dir / "tenants").glob("*"):
|
||||
wal_redo_dir = tenant / "wal-redo-datadir.___temp"
|
||||
if wal_redo_dir.exists() and wal_redo_dir.is_dir():
|
||||
shutil.rmtree(wal_redo_dir)
|
||||
|
||||
# Update paths and ports in config files
|
||||
pageserver_toml = repo_dir / "pageserver.toml"
|
||||
pageserver_config = toml.load(pageserver_toml)
|
||||
pageserver_config["remote_storage"]["local_path"] = str(repo_dir / "local_fs_remote_storage")
|
||||
pageserver_config["listen_http_addr"] = port_distributor.replace_with_new_port(
|
||||
pageserver_config["listen_http_addr"]
|
||||
)
|
||||
pageserver_config["listen_pg_addr"] = port_distributor.replace_with_new_port(
|
||||
pageserver_config["listen_pg_addr"]
|
||||
)
|
||||
for param in ("listen_http_addr", "listen_pg_addr", "broker_endpoint"):
|
||||
pageserver_config[param] = port_distributor.replace_with_new_port(pageserver_config[param])
|
||||
|
||||
# Older pageserver versions had just one `auth_type` setting. Now there
|
||||
# are separate settings for pg and http ports. We don't use authentication
|
||||
# in compatibility tests so just remove authentication related settings.
|
||||
pageserver_config.pop("auth_type", None)
|
||||
# We don't use authentication in compatibility tests
|
||||
# so just remove authentication related settings.
|
||||
pageserver_config.pop("pg_auth_type", None)
|
||||
pageserver_config.pop("http_auth_type", None)
|
||||
|
||||
@@ -290,19 +277,16 @@ def prepare_snapshot(
|
||||
|
||||
snapshot_config_toml = repo_dir / "config"
|
||||
snapshot_config = toml.load(snapshot_config_toml)
|
||||
|
||||
broker_listen_addr = f"127.0.0.1:{port_distributor.get_port()}"
|
||||
snapshot_config["broker"] = {"listen_addr": broker_listen_addr}
|
||||
|
||||
snapshot_config["pageserver"]["listen_http_addr"] = port_distributor.replace_with_new_port(
|
||||
snapshot_config["pageserver"]["listen_http_addr"]
|
||||
)
|
||||
snapshot_config["pageserver"]["listen_pg_addr"] = port_distributor.replace_with_new_port(
|
||||
snapshot_config["pageserver"]["listen_pg_addr"]
|
||||
for param in ("listen_http_addr", "listen_pg_addr"):
|
||||
snapshot_config["pageserver"][param] = port_distributor.replace_with_new_port(
|
||||
snapshot_config["pageserver"][param]
|
||||
)
|
||||
snapshot_config["broker"]["listen_addr"] = port_distributor.replace_with_new_port(
|
||||
snapshot_config["broker"]["listen_addr"]
|
||||
)
|
||||
for sk in snapshot_config["safekeepers"]:
|
||||
sk["http_port"] = port_distributor.replace_with_new_port(sk["http_port"])
|
||||
sk["pg_port"] = port_distributor.replace_with_new_port(sk["pg_port"])
|
||||
for param in ("http_port", "pg_port", "pg_tenant_only_port"):
|
||||
sk[param] = port_distributor.replace_with_new_port(sk[param])
|
||||
|
||||
if pg_distrib_dir:
|
||||
snapshot_config["pg_distrib_dir"] = str(pg_distrib_dir)
|
||||
|
||||
@@ -14,10 +14,6 @@ from fixtures.neon_fixtures import NeonEnvBuilder, PgBin
|
||||
def test_gc_cutoff(neon_env_builder: NeonEnvBuilder, pg_bin: PgBin):
|
||||
env = neon_env_builder.init_start()
|
||||
|
||||
# These warnings are expected, when the pageserver is restarted abruptly
|
||||
env.pageserver.allowed_errors.append(".*found future image layer.*")
|
||||
env.pageserver.allowed_errors.append(".*found future delta layer.*")
|
||||
|
||||
pageserver_http = env.pageserver.http_client()
|
||||
|
||||
# Use aggressive GC and checkpoint settings, so that we also exercise GC during the test
|
||||
|
||||
@@ -72,10 +72,6 @@ def test_pageserver_restart(neon_env_builder: NeonEnvBuilder):
|
||||
def test_pageserver_chaos(neon_env_builder: NeonEnvBuilder):
|
||||
env = neon_env_builder.init_start()
|
||||
|
||||
# These warnings are expected, when the pageserver is restarted abruptly
|
||||
env.pageserver.allowed_errors.append(".*found future image layer.*")
|
||||
env.pageserver.allowed_errors.append(".*found future delta layer.*")
|
||||
|
||||
# Use a tiny checkpoint distance, to create a lot of layers quickly.
|
||||
# That allows us to stress the compaction and layer flushing logic more.
|
||||
tenant, _ = env.neon_cli.create_tenant(
|
||||
|
||||
@@ -265,18 +265,23 @@ def test_sql_over_http_output_options(static_proxy: NeonProxy):
|
||||
def test_sql_over_http_batch(static_proxy: NeonProxy):
|
||||
static_proxy.safe_psql("create role http with login password 'http' superuser")
|
||||
|
||||
def qq(queries: List[Tuple[str, Optional[List[Any]]]]) -> Any:
|
||||
def qq(queries: List[Tuple[str, Optional[List[Any]]]], read_only: bool = False) -> Any:
|
||||
connstr = f"postgresql://http:http@{static_proxy.domain}:{static_proxy.proxy_port}/postgres"
|
||||
response = requests.post(
|
||||
f"https://{static_proxy.domain}:{static_proxy.external_http_port}/sql",
|
||||
data=json.dumps(list(map(lambda x: {"query": x[0], "params": x[1] or []}, queries))),
|
||||
headers={"Content-Type": "application/sql", "Neon-Connection-String": connstr},
|
||||
headers={
|
||||
"Content-Type": "application/sql",
|
||||
"Neon-Connection-String": connstr,
|
||||
"Neon-Batch-Isolation-Level": "Serializable",
|
||||
"Neon-Batch-Read-Only": "true" if read_only else "false",
|
||||
},
|
||||
verify=str(static_proxy.test_output_dir / "proxy.crt"),
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["results"]
|
||||
return response.json()["results"], response.headers
|
||||
|
||||
result = qq(
|
||||
result, headers = qq(
|
||||
[
|
||||
("select 42 as answer", None),
|
||||
("select $1 as answer", [42]),
|
||||
@@ -291,6 +296,9 @@ def test_sql_over_http_batch(static_proxy: NeonProxy):
|
||||
]
|
||||
)
|
||||
|
||||
assert headers["Neon-Batch-Isolation-Level"] == "Serializable"
|
||||
assert headers["Neon-Batch-Read-Only"] == "false"
|
||||
|
||||
assert result[0]["rows"] == [{"answer": 42}]
|
||||
assert result[1]["rows"] == [{"answer": "42"}]
|
||||
assert result[2]["rows"] == [{"answer": 42}]
|
||||
@@ -311,3 +319,14 @@ def test_sql_over_http_batch(static_proxy: NeonProxy):
|
||||
assert res["command"] == "DROP"
|
||||
assert res["rowCount"] is None
|
||||
assert len(result) == 10
|
||||
|
||||
result, headers = qq(
|
||||
[
|
||||
("select 42 as answer", None),
|
||||
],
|
||||
True,
|
||||
)
|
||||
assert headers["Neon-Batch-Isolation-Level"] == "Serializable"
|
||||
assert headers["Neon-Batch-Read-Only"] == "true"
|
||||
|
||||
assert result[0]["rows"] == [{"answer": 42}]
|
||||
|
||||
@@ -15,10 +15,6 @@ def test_pageserver_recovery(neon_env_builder: NeonEnvBuilder):
|
||||
env = neon_env_builder.init_start()
|
||||
env.pageserver.is_testing_enabled_or_skip()
|
||||
|
||||
# These warnings are expected, when the pageserver is restarted abruptly
|
||||
env.pageserver.allowed_errors.append(".*found future delta layer.*")
|
||||
env.pageserver.allowed_errors.append(".*found future image layer.*")
|
||||
|
||||
# Create a branch for us
|
||||
env.neon_cli.create_branch("test_pageserver_recovery", "main")
|
||||
|
||||
|
||||
@@ -348,9 +348,6 @@ def test_remote_storage_upload_queue_retries(
|
||||
# XXX: should vary this test to selectively fail just layer uploads, index uploads, deletions
|
||||
# but how do we validate the result after restore?
|
||||
|
||||
# these are always possible when we do an immediate stop. perhaps something with compacting has changed since.
|
||||
env.pageserver.allowed_errors.append(r".*found future (delta|image) layer.*")
|
||||
|
||||
env.pageserver.stop(immediate=True)
|
||||
env.endpoints.stop_all()
|
||||
|
||||
|
||||
2
vendor/postgres-v14
vendored
2
vendor/postgres-v14
vendored
Submodule vendor/postgres-v14 updated: ebedb34d01...da3885c34d
2
vendor/postgres-v15
vendored
2
vendor/postgres-v15
vendored
Submodule vendor/postgres-v15 updated: 1220c8a63f...770c6dffc5
4
vendor/revisions.json
vendored
4
vendor/revisions.json
vendored
@@ -1,4 +1,4 @@
|
||||
{
|
||||
"postgres-v15": "1220c8a63f00101829f9222a5821fc084b4384c7",
|
||||
"postgres-v14": "ebedb34d01c8ac9c31e8ea4628b9854103a1dc8f"
|
||||
"postgres-v15": "770c6dffc5ef6aac05bf049693877fb377eea6fc",
|
||||
"postgres-v14": "da3885c34db312afd555802be2ce985fafd1d8ad"
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user