Merge remote-tracking branch 'origin/main' into communicator-rewrite
@@ -108,11 +108,10 @@ pub enum PromoteState {
    Failed { error: String },
}

#[derive(Deserialize, Serialize, Default, Debug, Clone)]
#[derive(Deserialize, Default, Debug)]
#[serde(rename_all = "snake_case")]
/// Result of /safekeepers_lsn
pub struct SafekeepersLsn {
    pub safekeepers: String,
pub struct PromoteConfig {
    pub spec: ComputeSpec,
    pub wal_flush_lsn: utils::lsn::Lsn,
}

@@ -173,6 +172,11 @@ pub enum ComputeStatus {
    TerminationPendingImmediate,
    // Terminated Postgres
    Terminated,
    // A spec refresh is being requested
    RefreshConfigurationPending,
    // A spec refresh is being applied. We cannot refresh configuration again until the current
    // refresh is done, i.e., signal_refresh_configuration() will return 500 error.
    RefreshConfiguration,
}

#[derive(Deserialize, Serialize)]
@@ -185,6 +189,10 @@ impl Display for ComputeStatus {
        match self {
            ComputeStatus::Empty => f.write_str("empty"),
            ComputeStatus::ConfigurationPending => f.write_str("configuration-pending"),
            ComputeStatus::RefreshConfiguration => f.write_str("refresh-configuration"),
            ComputeStatus::RefreshConfigurationPending => {
                f.write_str("refresh-configuration-pending")
            }
            ComputeStatus::Init => f.write_str("init"),
            ComputeStatus::Running => f.write_str("running"),
            ComputeStatus::Configuration => f.write_str("configuration"),
@@ -12,9 +12,9 @@ use regex::Regex;
use remote_storage::RemotePath;
use serde::{Deserialize, Serialize};
use url::Url;
use utils::id::{TenantId, TimelineId};
use utils::id::{NodeId, TenantId, TimelineId};
use utils::lsn::Lsn;
use utils::shard::{ShardCount, ShardIndex};
use utils::shard::{ShardCount, ShardIndex, ShardNumber, ShardStripeSize};

use crate::responses::TlsConfig;

@@ -115,10 +115,18 @@ pub struct ComputeSpec {
    /// The goal is to use method 1. everywhere. But for backwards-compatibility with old
    /// versions of the control plane, `compute_ctl` will check 2. and 3. if the
    /// `pageserver_connection_info` field is missing.
    ///
    /// If both `pageserver_connection_info` and `pageserver_connstring`+`shard_stripe_size` are
    /// given, they must contain the same information.
    pub pageserver_connection_info: Option<PageserverConnectionInfo>,

    pub pageserver_connstring: Option<String>,

    /// Stripe size for pageserver sharding, in pages. This is set together with the legacy
    /// `pageserver_connstring` field. When the modern `pageserver_connection_info` field is used,
    /// the stripe size is stored in `pageserver_connection_info.stripe_size` instead.
    pub shard_stripe_size: Option<ShardStripeSize>,

    // More neon ids that we expose to the compute_ctl
    // and to postgres as neon extension GUCs.
    pub project_id: Option<String>,
@@ -151,10 +159,6 @@ pub struct ComputeSpec {

    pub pgbouncer_settings: Option<IndexMap<String, String>>,

    // Stripe size for pageserver sharding, in pages
    #[serde(default)]
    pub shard_stripe_size: Option<u32>,

    /// Local Proxy configuration used for JWT authentication
    #[serde(default)]
    pub local_proxy_config: Option<LocalProxySpec>,
@@ -205,6 +209,9 @@ pub struct ComputeSpec {
    ///
    /// We use this value to derive other values, such as the installed extensions metric.
    pub suspend_timeout_seconds: i64,

    // Databricks specific options for compute instance.
    pub databricks_settings: Option<DatabricksSettings>,
}

/// Feature flag to signal `compute_ctl` to enable certain experimental functionality.
@@ -232,14 +239,122 @@ pub struct PageserverConnectionInfo {
    pub shard_count: ShardCount,

    /// INVARIANT: null if shard_count is 0, otherwise non-null and immutable
    pub stripe_size: Option<u32>,
    pub stripe_size: Option<ShardStripeSize>,

    pub shards: HashMap<ShardIndex, PageserverShardInfo>,

    /// If the compute supports both protocols, this indicates which one it should use. The compute
    /// may use other available protocols too, if it doesn't support the preferred one. The URLs
    /// for the protocol specified here must be present for all shards, i.e. do not mark a protocol
    /// as preferred if it cannot actually be used with all the pageservers.
    #[serde(default)]
    pub prefer_protocol: PageserverProtocol,
}

/// Extract PageserverConnectionInfo from a comma-separated list of libpq connection strings.
///
/// This is used for backwards-compatibility, to parse the legacy
/// [ComputeSpec::pageserver_connstring] field, or the 'neon.pageserver_connstring' GUC. Nowadays,
/// the 'pageserver_connection_info' field should be used instead.
impl PageserverConnectionInfo {
    pub fn from_connstr(
        connstr: &str,
        stripe_size: Option<ShardStripeSize>,
    ) -> Result<PageserverConnectionInfo, anyhow::Error> {
        let shard_infos: Vec<_> = connstr
            .split(',')
            .map(|connstr| PageserverShardInfo {
                pageservers: vec![PageserverShardConnectionInfo {
                    id: None,
                    libpq_url: Some(connstr.to_string()),
                    grpc_url: None,
                }],
            })
            .collect();

        match shard_infos.len() {
            0 => anyhow::bail!("empty connection string"),
            1 => {
                // We assume that if there's only one connection string, it means "unsharded",
                // rather than a sharded system with just a single shard. The latter is
                // possible in principle, but we never do it.
                let shard_count = ShardCount::unsharded();
                let only_shard = shard_infos.first().unwrap().clone();
                let shards = vec![(ShardIndex::unsharded(), only_shard)];
                Ok(PageserverConnectionInfo {
                    shard_count,
                    stripe_size: None,
                    shards: shards.into_iter().collect(),
                    prefer_protocol: PageserverProtocol::Libpq,
                })
            }
            n => {
                if stripe_size.is_none() {
                    anyhow::bail!("{n} shards but no stripe_size");
                }
                let shard_count = ShardCount(n.try_into()?);
                let shards = shard_infos
                    .into_iter()
                    .enumerate()
                    .map(|(idx, shard_info)| {
                        (
                            ShardIndex {
                                shard_count,
                                shard_number: ShardNumber(
                                    idx.try_into().expect("shard number fits in u8"),
                                ),
                            },
                            shard_info,
                        )
                    })
                    .collect();
                Ok(PageserverConnectionInfo {
                    shard_count,
                    stripe_size,
                    shards,
                    prefer_protocol: PageserverProtocol::Libpq,
                })
            }
        }
    }

    /// Convenience routine to get the connection string for a shard.
    pub fn shard_url(
        &self,
        shard_number: ShardNumber,
        protocol: PageserverProtocol,
    ) -> anyhow::Result<&str> {
        let shard_index = ShardIndex {
            shard_number,
            shard_count: self.shard_count,
        };
        let shard = self.shards.get(&shard_index).ok_or_else(|| {
            anyhow::anyhow!("shard connection info missing for shard {}", shard_index)
        })?;

        // Just use the first pageserver in the list. That's good enough for this
        // convenience routine; if you need more control, like round robin policy or
        // failover support, roll your own. (As of this writing, we never have more than
        // one pageserver per shard anyway, but that will change in the future.)
        let pageserver = shard
            .pageservers
            .first()
            .ok_or(anyhow::anyhow!("must have at least one pageserver"))?;

        let result = match protocol {
            PageserverProtocol::Grpc => pageserver
                .grpc_url
                .as_ref()
                .ok_or(anyhow::anyhow!("no grpc_url for shard {shard_index}"))?,
            PageserverProtocol::Libpq => pageserver
                .libpq_url
                .as_ref()
                .ok_or(anyhow::anyhow!("no libpq_url for shard {shard_index}"))?,
        };
        Ok(result)
    }
}
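// Illustrative usage sketch of the two routines above (the connstring and
// stripe size are hypothetical values, not part of this commit):
//
//     let info = PageserverConnectionInfo::from_connstr(
//         "postgresql://ps1:6400,postgresql://ps2:6400",
//         Some(ShardStripeSize(2048)), // mandatory once there is more than one shard
//     )?;
//     assert_eq!(info.shard_count, ShardCount(2));
//     // Fetch shard 0's libpq URL; this errors if only a grpc_url were present.
//     let url = info.shard_url(ShardNumber(0), PageserverProtocol::Libpq)?;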
#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
pub struct PageserverShardInfo {
    pub pageservers: Vec<PageserverShardConnectionInfo>,
@@ -247,7 +362,7 @@ pub struct PageserverShardInfo {

#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
pub struct PageserverShardConnectionInfo {
    pub id: Option<String>,
    pub id: Option<NodeId>,
    pub libpq_url: Option<String>,
    pub grpc_url: Option<String>,
}
@@ -558,11 +558,11 @@ async fn add_request_id_header_to_response(
    mut res: Response<Body>,
    req_info: RequestInfo,
) -> Result<Response<Body>, ApiError> {
    if let Some(request_id) = req_info.context::<RequestId>() {
        if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) {
            res.headers_mut()
                .insert(&X_REQUEST_ID_HEADER, request_header_value);
    };
    if let Some(request_id) = req_info.context::<RequestId>()
        && let Ok(request_header_value) = HeaderValue::from_str(&request_id.0)
    {
        res.headers_mut()
            .insert(&X_REQUEST_ID_HEADER, request_header_value);
    };

    Ok(res)
@@ -72,10 +72,10 @@ impl Server {
        if err.is_incomplete_message() || err.is_closed() || err.is_timeout() {
            return true;
        }
        if let Some(inner) = err.source() {
            if let Some(io) = inner.downcast_ref::<std::io::Error>() {
                return suppress_io_error(io);
            }
        if let Some(inner) = err.source()
            && let Some(io) = inner.downcast_ref::<std::io::Error>()
        {
            return suppress_io_error(io);
        }
        false
    }
@@ -129,6 +129,12 @@ impl<L: LabelGroup> InfoMetric<L> {
    }
}

impl<L: LabelGroup + Default> Default for InfoMetric<L, GaugeState> {
    fn default() -> Self {
        InfoMetric::new(L::default())
    }
}

impl<L: LabelGroup, M: MetricType<Metadata = ()>> InfoMetric<L, M> {
    pub fn with_metric(label: L, metric: M) -> Self {
        Self {
@@ -394,7 +394,7 @@ where
    // TODO: An Iterator might be nicer. The communicator's clock algorithm needs to
    // _slowly_ iterate through all buckets with its clock hand, without holding a lock.
    // If we switch to an Iterator, it must not hold the lock.
    pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<(K, V)>> {
    pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<'_, (K, V)>> {
        let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
        if pos >= map.buckets.len() {
            return None;
@@ -1500,6 +1500,7 @@ pub struct TimelineArchivalConfigRequest {
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone)]
pub struct TimelinePatchIndexPartRequest {
    pub rel_size_migration: Option<RelSizeMigration>,
    pub rel_size_migrated_at: Option<Lsn>,
    pub gc_compaction_last_completed_lsn: Option<Lsn>,
    pub applied_gc_cutoff_lsn: Option<Lsn>,
    #[serde(default)]
@@ -1533,10 +1534,10 @@ pub enum RelSizeMigration {
    /// `None` is the same as `Some(RelSizeMigration::Legacy)`.
    Legacy,
    /// The tenant is migrating to the new rel_size format. Both old and new rel_size format are
    /// persisted in the index part. The read path will read both formats and merge them.
    /// persisted in the storage. The read path will read both formats and validate them.
    Migrating,
    /// The tenant has migrated to the new rel_size format. Only the new rel_size format is persisted
    /// in the index part, and the read path will not read the old format.
    /// in the storage, and the read path will not read the old format.
    Migrated,
}

@@ -1619,6 +1620,7 @@ pub struct TimelineInfo {

    /// The status of the rel_size migration.
    pub rel_size_migration: Option<RelSizeMigration>,
    pub rel_size_migrated_at: Option<Lsn>,

    /// Whether the timeline is invisible in synthetic size calculations.
    pub is_invisible: Option<bool>,
@@ -11,8 +11,6 @@ anyhow.workspace = true
crc32c.workspace = true
criterion.workspace = true
once_cell.workspace = true
log.workspace = true
memoffset.workspace = true
pprof.workspace = true
thiserror.workspace = true
serde.workspace = true
@@ -34,9 +34,8 @@ const SIZEOF_CONTROLDATA: usize = size_of::<ControlFileData>();
impl ControlFileData {
    /// Compute the offset of the `crc` field within the `ControlFileData` struct.
    /// Equivalent to offsetof(ControlFileData, crc) in C.
    // Someday this can be const when the right compiler features land.
    fn pg_control_crc_offset() -> usize {
        memoffset::offset_of!(ControlFileData, crc)
    const fn pg_control_crc_offset() -> usize {
        std::mem::offset_of!(ControlFileData, crc)
    }

    ///
@@ -4,12 +4,11 @@
use crate::pg_constants;
use crate::transaction_id_precedes;
use bytes::BytesMut;
use log::*;

use super::bindings::MultiXactId;

pub fn transaction_id_set_status(xid: u32, status: u8, page: &mut BytesMut) {
    trace!(
    tracing::trace!(
        "handle_apply_request for RM_XACT_ID-{} (1-commit, 2-abort, 3-sub_commit)",
        status
    );
@@ -14,7 +14,6 @@ use super::xlog_utils::*;
use crate::WAL_SEGMENT_SIZE;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use crc32c::*;
use log::*;
use std::cmp::min;
use std::num::NonZeroU32;
use utils::lsn::Lsn;
@@ -236,7 +235,7 @@ impl WalStreamDecoderHandler for WalStreamDecoder {
        // XLOG_SWITCH records are special. If we see one, we need to skip
        // to the next WAL segment.
        let next_lsn = if xlogrec.is_xlog_switch_record() {
            trace!("saw xlog switch record at {}", self.lsn);
            tracing::trace!("saw xlog switch record at {}", self.lsn);
            self.lsn + self.lsn.calc_padding(WAL_SEGMENT_SIZE as u64)
        } else {
            // Pad to an 8-byte boundary
@@ -23,8 +23,6 @@ use crate::{WAL_SEGMENT_SIZE, XLOG_BLCKSZ};
use bytes::BytesMut;
use bytes::{Buf, Bytes};

use log::*;

use serde::Serialize;
use std::ffi::{CString, OsStr};
use std::fs::File;
@@ -235,7 +233,7 @@ pub fn find_end_of_wal(
    let mut curr_lsn = start_lsn;
    let mut buf = [0u8; XLOG_BLCKSZ];
    let pg_version = MY_PGVERSION;
    debug!("find_end_of_wal PG_VERSION: {}", pg_version);
    tracing::debug!("find_end_of_wal PG_VERSION: {}", pg_version);

    let mut decoder = WalStreamDecoder::new(start_lsn, pg_version);

@@ -247,7 +245,7 @@ pub fn find_end_of_wal(
        match open_wal_segment(&seg_file_path)? {
            None => {
                // no more segments
                debug!(
                tracing::debug!(
                    "find_end_of_wal reached end at {:?}, segment {:?} doesn't exist",
                    result, seg_file_path
                );
@@ -260,7 +258,7 @@ pub fn find_end_of_wal(
        while curr_lsn.segment_number(wal_seg_size) == segno {
            let bytes_read = segment.read(&mut buf)?;
            if bytes_read == 0 {
                debug!(
                tracing::debug!(
                    "find_end_of_wal reached end at {:?}, EOF in segment {:?} at offset {}",
                    result,
                    seg_file_path,
@@ -276,7 +274,7 @@ pub fn find_end_of_wal(
            match decoder.poll_decode() {
                Ok(Some(record)) => result = record.0,
                Err(e) => {
                    debug!(
                    tracing::debug!(
                        "find_end_of_wal reached end at {:?}, decode error: {:?}",
                        result, e
                    );
libs/proxy/subzero_core/.gitignore (new file, vendored, 2 lines)
@@ -0,0 +1,2 @@
target
Cargo.lock

libs/proxy/subzero_core/Cargo.toml (new file, 12 lines)
@@ -0,0 +1,12 @@
# This is a stub for the subzero-core crate.
[package]
name = "subzero-core"
version = "3.0.1"
edition = "2024"
publish = false # "private"!

[features]
default = []
postgresql = []

[dependencies]

libs/proxy/subzero_core/src/lib.rs (new file, 1 line)
@@ -0,0 +1 @@
// This is a stub for the subzero-core crate.
@@ -15,6 +15,7 @@ use tokio::sync::mpsc;
use crate::cancel_token::RawCancelToken;
use crate::codec::{BackendMessages, FrontendMessage, RecordNotices};
use crate::config::{Host, SslMode};
use crate::connection::gc_bytesmut;
use crate::query::RowStream;
use crate::simple_query::SimpleQueryStream;
use crate::types::{Oid, Type};
@@ -95,20 +96,13 @@ impl InnerClient {
        Ok(PartialQuery(Some(self)))
    }

    // pub fn send_with_sync<F>(&mut self, f: F) -> Result<&mut Responses, Error>
    // where
    //     F: FnOnce(&mut BytesMut) -> Result<(), Error>,
    // {
    //     self.start()?.send_with_sync(f)
    // }

    pub fn send_simple_query(&mut self, query: &str) -> Result<&mut Responses, Error> {
        self.responses.waiting += 1;

        self.buffer.clear();
        // simple queries do not need sync.
        frontend::query(query, &mut self.buffer).map_err(Error::encode)?;
        let buf = self.buffer.split().freeze();
        let buf = self.buffer.split();
        self.send_message(FrontendMessage::Raw(buf))
    }

@@ -125,7 +119,7 @@ impl Drop for PartialQuery<'_> {
        if let Some(client) = self.0.take() {
            client.buffer.clear();
            frontend::sync(&mut client.buffer);
            let buf = client.buffer.split().freeze();
            let buf = client.buffer.split();
            let _ = client.send_message(FrontendMessage::Raw(buf));
        }
    }
@@ -141,7 +135,7 @@ impl<'a> PartialQuery<'a> {
        client.buffer.clear();
        f(&mut client.buffer)?;
        frontend::flush(&mut client.buffer);
        let buf = client.buffer.split().freeze();
        let buf = client.buffer.split();
        client.send_message(FrontendMessage::Raw(buf))
    }

@@ -154,7 +148,7 @@ impl<'a> PartialQuery<'a> {
        client.buffer.clear();
        f(&mut client.buffer)?;
        frontend::sync(&mut client.buffer);
        let buf = client.buffer.split().freeze();
        let buf = client.buffer.split();
        let _ = client.send_message(FrontendMessage::Raw(buf));

        Ok(&mut self.0.take().unwrap().responses)
@@ -191,6 +185,7 @@ impl Client {
        ssl_mode: SslMode,
        process_id: i32,
        secret_key: i32,
        write_buf: BytesMut,
    ) -> Client {
        Client {
            inner: InnerClient {
@@ -201,7 +196,7 @@ impl Client {
                    waiting: 0,
                    received: 0,
                },
                buffer: Default::default(),
                buffer: write_buf,
            },
            cached_typeinfo: Default::default(),

@@ -292,8 +287,35 @@ impl Client {
        simple_query::batch_execute(self.inner_mut(), query).await
    }

    pub async fn discard_all(&mut self) -> Result<ReadyForQueryStatus, Error> {
        self.batch_execute("discard all").await
    /// Similar to `discard_all`, but it does not clear any query plans.
    ///
    /// This runs in the background, so it can be executed without `await`ing.
    pub fn reset_session_background(&mut self) -> Result<(), Error> {
        // "CLOSE ALL": closes any cursors
        // "SET SESSION AUTHORIZATION DEFAULT": resets the current_user back to the session_user
        // "RESET ALL": resets any GUCs back to their session defaults.
        // "DEALLOCATE ALL": deallocates any prepared statements
        // "UNLISTEN *": stops listening on all channels
        // "SELECT pg_advisory_unlock_all();": unlocks all advisory locks
        // "DISCARD TEMP;": drops all temporary tables
        // "DISCARD SEQUENCES;": deallocates all cached sequence state

        let _responses = self.inner_mut().send_simple_query(
            "ROLLBACK;
CLOSE ALL;
SET SESSION AUTHORIZATION DEFAULT;
RESET ALL;
DEALLOCATE ALL;
UNLISTEN *;
SELECT pg_advisory_unlock_all();
DISCARD TEMP;
DISCARD SEQUENCES;",
        )?;

        // Clean up memory usage.
        gc_bytesmut(&mut self.inner_mut().buffer);

        Ok(())
    }
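    // Illustrative usage sketch (hypothetical pool code, not part of this
    // commit): because the method above only queues the simple query into the
    // shared write buffer, a pool can fire it while checking a connection back
    // in, with no `.await`; the Connection task writes it out and drains the
    // replies in the background.
    //
    //     fn check_in(client: &mut Client) -> Result<(), Error> {
    //         client.reset_session_background()
    //     }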
    /// Begins a new database transaction.
@@ -1,13 +1,13 @@
use std::io;

use bytes::{Bytes, BytesMut};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use postgres_protocol2::message::backend;
use tokio::sync::mpsc::UnboundedSender;
use tokio_util::codec::{Decoder, Encoder};

pub enum FrontendMessage {
    Raw(Bytes),
    Raw(BytesMut),
    RecordNotices(RecordNotices),
}

@@ -17,7 +17,10 @@ pub struct RecordNotices {
}

pub enum BackendMessage {
    Normal { messages: BackendMessages },
    Normal {
        messages: BackendMessages,
        ready: bool,
    },
    Async(backend::Message),
}

@@ -40,11 +43,11 @@ impl FallibleIterator for BackendMessages {

pub struct PostgresCodec;

impl Encoder<Bytes> for PostgresCodec {
impl Encoder<BytesMut> for PostgresCodec {
    type Error = io::Error;

    fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> io::Result<()> {
        dst.extend_from_slice(&item);
    fn encode(&mut self, item: BytesMut, dst: &mut BytesMut) -> io::Result<()> {
        dst.unsplit(item);
        Ok(())
    }
}
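// Why `unsplit` instead of `extend_from_slice` (illustrative sketch, my
// reading of the change rather than text from the commit): when `item` was
// `split()` off the same allocation as `dst`, `unsplit` merges the two halves
// back without copying.
//
//     let mut buf = BytesMut::with_capacity(16);
//     buf.extend_from_slice(b"hello world");
//     let tail = buf.split_off(5); // same allocation, contiguous
//     buf.unsplit(tail);           // O(1) merge, no copy
//     assert_eq!(&buf[..], b"hello world");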
@@ -56,6 +59,7 @@ impl Decoder for PostgresCodec {
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BackendMessage>, io::Error> {
        let mut idx = 0;

        let mut ready = false;
        while let Some(header) = backend::Header::parse(&src[idx..])? {
            let len = header.len() as usize + 1;
            if src[idx..].len() < len {
@@ -79,6 +83,7 @@ impl Decoder for PostgresCodec {
            idx += len;

            if header.tag() == backend::READY_FOR_QUERY_TAG {
                ready = true;
                break;
            }
        }
@@ -88,6 +93,7 @@ impl Decoder for PostgresCodec {
        } else {
            Ok(Some(BackendMessage::Normal {
                messages: BackendMessages(src.split_to(idx)),
                ready,
            }))
        }
    }
@@ -11,9 +11,8 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;

use crate::connect::connect;
use crate::connect_raw::{RawConnection, connect_raw};
use crate::connect_raw::{self, StartupStream};
use crate::connect_tls::connect_tls;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::{MakeTlsConnect, TlsConnect, TlsStream};
use crate::{Client, Connection, Error};

@@ -244,24 +243,27 @@ impl Config {
        &self,
        stream: S,
        tls: T,
    ) -> Result<RawConnection<S, T::Stream>, Error>
    ) -> Result<StartupStream<S, T::Stream>, Error>
    where
        S: AsyncRead + AsyncWrite + Unpin,
        T: TlsConnect<S>,
    {
        let stream = connect_tls(stream, self.ssl_mode, tls).await?;
        connect_raw(stream, self).await
        let mut stream = StartupStream::new(stream);
        connect_raw::authenticate(&mut stream, self).await?;

        Ok(stream)
    }

    pub async fn authenticate<S, T>(
    pub fn authenticate<S, T>(
        &self,
        stream: MaybeTlsStream<S, T>,
    ) -> Result<RawConnection<S, T>, Error>
        stream: &mut StartupStream<S, T>,
    ) -> impl Future<Output = Result<(), Error>>
    where
        S: AsyncRead + AsyncWrite + Unpin,
        T: TlsStream + Unpin,
    {
        connect_raw(stream, self).await
        connect_raw::authenticate(stream, self)
    }
}
@@ -1,15 +1,17 @@
use std::net::IpAddr;

use futures_util::TryStreamExt;
use postgres_protocol2::message::backend::Message;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::sync::mpsc;

use crate::client::SocketConfig;
use crate::config::Host;
use crate::connect_raw::connect_raw;
use crate::config::{Host, SslMode};
use crate::connect_raw::StartupStream;
use crate::connect_socket::connect_socket;
use crate::connect_tls::connect_tls;
use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::{Client, Config, Connection, Error, RawConnection};
use crate::{Client, Config, Connection, Error};

pub async fn connect<T>(
    tls: &T,
@@ -43,34 +45,78 @@ where
    T: TlsConnect<TcpStream>,
{
    let socket = connect_socket(host_addr, host, port, config.connect_timeout).await?;
    let stream = connect_tls(socket, config.ssl_mode, tls).await?;
    let RawConnection {
    let stream = config.tls_and_authenticate(socket, tls).await?;
    managed(
        stream,
        parameters: _,
        delayed_notice: _,
        process_id,
        secret_key,
    } = connect_raw(stream, config).await?;
        host_addr,
        host.clone(),
        port,
        config.ssl_mode,
        config.connect_timeout,
    )
    .await
}

pub async fn managed<TlsStream>(
    mut stream: StartupStream<TcpStream, TlsStream>,
    host_addr: Option<IpAddr>,
    host: Host,
    port: u16,
    ssl_mode: SslMode,
    connect_timeout: Option<std::time::Duration>,
) -> Result<(Client, Connection<TcpStream, TlsStream>), Error>
where
    TlsStream: AsyncRead + AsyncWrite + Unpin,
{
    let (process_id, secret_key) = wait_until_ready(&mut stream).await?;

    let socket_config = SocketConfig {
        host_addr,
        host: host.clone(),
        host,
        port,
        connect_timeout: config.connect_timeout,
        connect_timeout,
    };

    let mut stream = stream.into_framed();
    let write_buf = std::mem::take(stream.write_buffer_mut());

    let (client_tx, conn_rx) = mpsc::unbounded_channel();
    let (conn_tx, client_rx) = mpsc::channel(4);
    let client = Client::new(
        client_tx,
        client_rx,
        socket_config,
        config.ssl_mode,
        ssl_mode,
        process_id,
        secret_key,
        write_buf,
    );

    let connection = Connection::new(stream, conn_tx, conn_rx);

    Ok((client, connection))
}

async fn wait_until_ready<S, T>(stream: &mut StartupStream<S, T>) -> Result<(i32, i32), Error>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: AsyncRead + AsyncWrite + Unpin,
{
    let mut process_id = 0;
    let mut secret_key = 0;

    loop {
        match stream.try_next().await.map_err(Error::io)? {
            Some(Message::BackendKeyData(body)) => {
                process_id = body.process_id();
                secret_key = body.secret_key();
            }
            // These values are currently not used by `Client`/`Connection`. Ignore them.
            Some(Message::ParameterStatus(_)) | Some(Message::NoticeResponse(_)) => {}
            Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key)),
            Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
            Some(_) => return Err(Error::unexpected_message()),
            None => return Err(Error::closed()),
        }
    }
}
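// For context, an illustrative startup tail as seen by `wait_until_ready`
// (typical Postgres flow; the values are hypothetical, and message order can
// interleave):
//
//     S: ParameterStatus("server_version", "16.3")   -> ignored
//     S: BackendKeyData { process_id, secret_key }   -> recorded
//     S: ReadyForQuery('I')                          -> returns (process_id, secret_key)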
@@ -1,52 +1,27 @@
use std::collections::HashMap;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::task::{Context, Poll, ready};

use bytes::{Bytes, BytesMut};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready};
use futures_util::{SinkExt, Stream, TryStreamExt};
use postgres_protocol2::authentication::sasl;
use postgres_protocol2::authentication::sasl::ScramSha256;
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message};
use postgres_protocol2::message::frontend;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::codec::Framed;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_util::codec::{Framed, FramedParts};

use crate::Error;
use crate::codec::{BackendMessage, BackendMessages, PostgresCodec};
use crate::codec::PostgresCodec;
use crate::config::{self, AuthKeys, Config};
use crate::connection::{GC_THRESHOLD, INITIAL_CAPACITY};
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::tls::TlsStream;

pub struct StartupStream<S, T> {
    inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
    buf: BackendMessages,
    delayed_notice: Vec<NoticeResponseBody>,
}

impl<S, T> Sink<Bytes> for StartupStream<S, T>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Error = io::Error;

    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_ready(cx)
    }

    fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> io::Result<()> {
        Pin::new(&mut self.inner).start_send(item)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_close(cx)
    }
    read_buf: BytesMut,
}

impl<S, T> Stream for StartupStream<S, T>
@@ -56,78 +31,109 @@ where
{
    type Item = io::Result<Message>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<io::Result<Message>>> {
        loop {
            match self.buf.next() {
                Ok(Some(message)) => return Poll::Ready(Some(Ok(message))),
                Ok(None) => {}
                Err(e) => return Poll::Ready(Some(Err(e))),
            }
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // We don't use `self.inner.poll_next()` as that might over-read into the read buffer.

            match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
                Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages,
                Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))),
                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
                None => return Poll::Ready(None),
            }
        // read 1 byte tag, 4 bytes length.
        let header = ready!(self.as_mut().poll_fill_buf_exact(cx, 5)?);

        let len = u32::from_be_bytes(header[1..5].try_into().unwrap());
        if len < 4 {
            return Poll::Ready(Some(Err(std::io::Error::other(
                "postgres message too small",
            ))));
        }
        if len >= 65536 {
            return Poll::Ready(Some(Err(std::io::Error::other(
                "postgres message too large",
            ))));
        }

        // the tag is an additional byte.
        let _message = ready!(self.as_mut().poll_fill_buf_exact(cx, len as usize + 1)?);

        // Message::parse will remove all the bytes from the buffer.
        Poll::Ready(Message::parse(&mut self.read_buf).transpose())
    }
}
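// For reference, the framing that the hand-rolled `poll_next` above relies on
// (standard Postgres wire format): a backend message is a 1-byte tag plus a
// 4-byte big-endian length that counts itself but not the tag, so a full frame
// is `len + 1` bytes. An illustrative ReadyForQuery(idle) frame:
//
//     // tag   length = 5 (4 length bytes + 1 status byte)   status
//     let frame: [u8; 6] = [b'Z', 0, 0, 0, 5, b'I'];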
pub struct RawConnection<S, T> {
    pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
    pub parameters: HashMap<String, String>,
    pub delayed_notice: Vec<NoticeResponseBody>,
    pub process_id: i32,
    pub secret_key: i32,
}

pub async fn connect_raw<S, T>(
    stream: MaybeTlsStream<S, T>,
    config: &Config,
) -> Result<RawConnection<S, T>, Error>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: TlsStream + Unpin,
{
    let mut stream = StartupStream {
        inner: Framed::new(stream, PostgresCodec),
        buf: BackendMessages::empty(),
        delayed_notice: Vec::new(),
    };

    startup(&mut stream, config).await?;
    authenticate(&mut stream, config).await?;
    let (process_id, secret_key, parameters) = read_info(&mut stream).await?;

    Ok(RawConnection {
        stream: stream.inner,
        parameters,
        delayed_notice: stream.delayed_notice,
        process_id,
        secret_key,
    })
}

async fn startup<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
impl<S, T> StartupStream<S, T>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: AsyncRead + AsyncWrite + Unpin,
{
    let mut buf = BytesMut::new();
    frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
    /// Fill the buffer until it's the exact length provided. No additional data will be read from the socket.
    ///
    /// If the current buffer length is greater, nothing happens.
    fn poll_fill_buf_exact(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        len: usize,
    ) -> Poll<Result<&[u8], std::io::Error>> {
        let this = self.get_mut();
        let mut stream = Pin::new(this.inner.get_mut());

    stream.send(buf.freeze()).await.map_err(Error::io)
        let mut n = this.read_buf.len();
        while n < len {
            this.read_buf.resize(len, 0);

            let mut buf = ReadBuf::new(&mut this.read_buf[..]);
            buf.set_filled(n);

            if stream.as_mut().poll_read(cx, &mut buf)?.is_pending() {
                this.read_buf.truncate(n);
                return Poll::Pending;
            }

            if buf.filled().len() == n {
                return Poll::Ready(Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "early eof",
                )));
            }
            n = buf.filled().len();

            this.read_buf.truncate(n);
        }

        Poll::Ready(Ok(&this.read_buf[..len]))
    }

    pub fn into_framed(mut self) -> Framed<MaybeTlsStream<S, T>, PostgresCodec> {
        *self.inner.read_buffer_mut() = self.read_buf;
        self.inner
    }

    pub fn new(io: MaybeTlsStream<S, T>) -> Self {
        let mut parts = FramedParts::new(io, PostgresCodec);
        parts.write_buf = BytesMut::with_capacity(INITIAL_CAPACITY);

        let mut inner = Framed::from_parts(parts);

        // This is the default already, but nice to be explicit.
        // We divide by two because writes will overshoot the boundary.
        // We don't want constant overshoots to cause us to constantly re-shrink the buffer.
        inner.set_backpressure_boundary(GC_THRESHOLD / 2);

        Self {
            inner,
            read_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
        }
    }
}

async fn authenticate<S, T>(stream: &mut StartupStream<S, T>, config: &Config) -> Result<(), Error>
pub(crate) async fn authenticate<S, T>(
    stream: &mut StartupStream<S, T>,
    config: &Config,
) -> Result<(), Error>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: TlsStream + Unpin,
{
    frontend::startup_message(&config.server_params, stream.inner.write_buffer_mut())
        .map_err(Error::encode)?;

    stream.inner.flush().await.map_err(Error::io)?;
    match stream.try_next().await.map_err(Error::io)? {
        Some(Message::AuthenticationOk) => {
            can_skip_channel_binding(config)?;
@@ -141,7 +147,8 @@ where
                .as_ref()
                .ok_or_else(|| Error::config("password missing".into()))?;

            authenticate_password(stream, pass).await?;
            frontend::password_message(pass, stream.inner.write_buffer_mut())
                .map_err(Error::encode)?;
        }
        Some(Message::AuthenticationSasl(body)) => {
            authenticate_sasl(stream, body, config).await?;
@@ -160,6 +167,7 @@ where
        None => return Err(Error::closed()),
    }

    stream.inner.flush().await.map_err(Error::io)?;
    match stream.try_next().await.map_err(Error::io)? {
        Some(Message::AuthenticationOk) => Ok(()),
        Some(Message::ErrorResponse(body)) => Err(Error::db(body)),
@@ -177,20 +185,6 @@ fn can_skip_channel_binding(config: &Config) -> Result<(), Error> {
    }
}

async fn authenticate_password<S, T>(
    stream: &mut StartupStream<S, T>,
    password: &[u8],
) -> Result<(), Error>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: AsyncRead + AsyncWrite + Unpin,
{
    let mut buf = BytesMut::new();
    frontend::password_message(password, &mut buf).map_err(Error::encode)?;

    stream.send(buf.freeze()).await.map_err(Error::io)
}

async fn authenticate_sasl<S, T>(
    stream: &mut StartupStream<S, T>,
    body: AuthenticationSaslBody,
@@ -245,10 +239,10 @@ where
        return Err(Error::config("password or auth keys missing".into()));
    };

    let mut buf = BytesMut::new();
    frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
    stream.send(buf.freeze()).await.map_err(Error::io)?;
    frontend::sasl_initial_response(mechanism, scram.message(), stream.inner.write_buffer_mut())
        .map_err(Error::encode)?;

    stream.inner.flush().await.map_err(Error::io)?;
    let body = match stream.try_next().await.map_err(Error::io)? {
        Some(Message::AuthenticationSaslContinue(body)) => body,
        Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
@@ -261,10 +255,10 @@ where
        .await
        .map_err(|e| Error::authentication(e.into()))?;

    let mut buf = BytesMut::new();
    frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
    stream.send(buf.freeze()).await.map_err(Error::io)?;
    frontend::sasl_response(scram.message(), stream.inner.write_buffer_mut())
        .map_err(Error::encode)?;

    stream.inner.flush().await.map_err(Error::io)?;
    let body = match stream.try_next().await.map_err(Error::io)? {
        Some(Message::AuthenticationSaslFinal(body)) => body,
        Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
@@ -278,35 +272,3 @@ where

    Ok(())
}

async fn read_info<S, T>(
    stream: &mut StartupStream<S, T>,
) -> Result<(i32, i32, HashMap<String, String>), Error>
where
    S: AsyncRead + AsyncWrite + Unpin,
    T: AsyncRead + AsyncWrite + Unpin,
{
    let mut process_id = 0;
    let mut secret_key = 0;
    let mut parameters = HashMap::new();

    loop {
        match stream.try_next().await.map_err(Error::io)? {
            Some(Message::BackendKeyData(body)) => {
                process_id = body.process_id();
                secret_key = body.secret_key();
            }
            Some(Message::ParameterStatus(body)) => {
                parameters.insert(
                    body.name().map_err(Error::parse)?.to_string(),
                    body.value().map_err(Error::parse)?.to_string(),
                );
            }
            Some(Message::NoticeResponse(body)) => stream.delayed_notice.push(body),
            Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)),
            Some(Message::ErrorResponse(body)) => return Err(Error::db(body)),
            Some(_) => return Err(Error::unexpected_message()),
            None => return Err(Error::closed()),
        }
    }
}
@@ -44,6 +44,27 @@ pub struct Connection<S, T> {
    state: State,
}

pub const INITIAL_CAPACITY: usize = 2 * 1024;
pub const GC_THRESHOLD: usize = 16 * 1024;

/// Garbage collect the [`BytesMut`] if it has too much spare capacity.
pub fn gc_bytesmut(buf: &mut BytesMut) {
    // We use a different mode to shrink the buf when above the threshold.
    // When above the threshold, we only re-allocate when the buf has 2x spare capacity.
    let reclaim = GC_THRESHOLD.checked_sub(buf.len()).unwrap_or(buf.len());

    // `try_reclaim` tries to get the capacity from any shared `BytesMut`s,
    // before then comparing the length against the capacity.
    if buf.try_reclaim(reclaim) {
        let capacity = usize::max(buf.len(), INITIAL_CAPACITY);

        // Allocate a new `BytesMut` so that we deallocate the old version.
        let mut new = BytesMut::with_capacity(capacity);
        new.extend_from_slice(buf);
        *buf = new;
    }
}
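// A worked pass through the reclaim arithmetic above (illustrative numbers):
//
//     len =  4 KiB: reclaim = 16 KiB - 4 KiB = 12 KiB, so the buffer only
//                   shrinks if at least 12 KiB of spare capacity can be
//                   reclaimed, i.e. total capacity exceeds the 16 KiB
//                   GC_THRESHOLD.
//     len = 24 KiB: checked_sub underflows, so reclaim = len = 24 KiB, i.e.
//                   shrink only with 2x spare capacity (capacity >= 48 KiB).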
pub enum Never {}

impl<S, T> Connection<S, T>
@@ -86,7 +107,14 @@ where
                    continue;
                }
                BackendMessage::Async(_) => continue,
                BackendMessage::Normal { messages } => messages,
                BackendMessage::Normal { messages, ready } => {
                    // if we read a ReadyForQuery from postgres, let's try GC the read buffer.
                    if ready {
                        gc_bytesmut(self.stream.read_buffer_mut());
                    }

                    messages
                }
            }
        }
    };
@@ -177,12 +205,7 @@ where
            // Send a terminate message to postgres
            Poll::Ready(None) => {
                trace!("poll_write: at eof, terminating");
                let mut request = BytesMut::new();
                frontend::terminate(&mut request);

                Pin::new(&mut self.stream)
                    .start_send(request.freeze())
                    .map_err(Error::io)?;
                frontend::terminate(self.stream.write_buffer_mut());

                trace!("poll_write: sent eof, closing");
                trace!("poll_write: done");
@@ -205,6 +228,13 @@ where
        {
            Poll::Ready(()) => {
                trace!("poll_flush: flushed");

                // Since our codec prefers to share the buffer with the `Client`,
                // if we don't release our share, then the `Client` would have to re-alloc
                // the buffer when they next use it.
                debug_assert!(self.stream.write_buffer().is_empty());
                *self.stream.write_buffer_mut() = BytesMut::new();

                Poll::Ready(Ok(()))
            }
            Poll::Pending => {
@@ -452,16 +452,16 @@ impl Error {
        Error(Box::new(ErrorInner { kind, cause }))
    }

    pub(crate) fn closed() -> Error {
    pub fn closed() -> Error {
        Error::new(Kind::Closed, None)
    }

    pub(crate) fn unexpected_message() -> Error {
    pub fn unexpected_message() -> Error {
        Error::new(Kind::UnexpectedMessage, None)
    }

    #[allow(clippy::needless_pass_by_value)]
    pub(crate) fn db(error: ErrorResponseBody) -> Error {
    pub fn db(error: ErrorResponseBody) -> Error {
        match DbError::parse(&mut error.fields()) {
            Ok(e) => Error::new(Kind::Db, Some(Box::new(e))),
            Err(e) => Error::new(Kind::Parse, Some(Box::new(e))),
@@ -493,7 +493,7 @@ impl Error {
        Error::new(Kind::Tls, Some(e))
    }

    pub(crate) fn io(e: io::Error) -> Error {
    pub fn io(e: io::Error) -> Error {
        Error::new(Kind::Io, Some(Box::new(e)))
    }
@@ -6,7 +6,6 @@ use postgres_protocol2::message::backend::ReadyForQueryBody;
pub use crate::cancel_token::{CancelToken, RawCancelToken};
pub use crate::client::{Client, SocketConfig};
pub use crate::config::Config;
pub use crate::connect_raw::RawConnection;
pub use crate::connection::Connection;
pub use crate::error::Error;
pub use crate::generic_client::GenericClient;
@@ -49,8 +48,8 @@ mod cancel_token;
mod client;
mod codec;
pub mod config;
mod connect;
mod connect_raw;
pub mod connect;
pub mod connect_raw;
mod connect_socket;
mod connect_tls;
mod connection;
@@ -301,7 +301,12 @@ pub struct PullTimelineRequest {
    pub tenant_id: TenantId,
    pub timeline_id: TimelineId,
    pub http_hosts: Vec<String>,
    pub ignore_tombstone: Option<bool>,
    /// Membership configuration to switch to after pull.
    /// It guarantees that if pull_timeline returns successfully, the timeline will
    /// not be deleted by a request with an older generation.
    /// The storage controller always sets this field.
    /// None is only allowed for manual pull_timeline requests.
    pub mconf: Option<Configuration>,
}

#[derive(Debug, Serialize, Deserialize)]
@@ -49,7 +49,7 @@ impl PerfSpan {
        }
    }

    pub fn enter(&self) -> PerfSpanEntered {
    pub fn enter(&self) -> PerfSpanEntered<'_> {
        if let Some(ref id) = self.inner.id() {
            self.dispatch.enter(id);
        }
@@ -429,9 +429,11 @@ pub fn empty_shmem() -> crate::bindings::WalproposerShmemState {
    };

    let empty_wal_rate_limiter = crate::bindings::WalRateLimiter {
        effective_max_wal_bytes_per_second: crate::bindings::pg_atomic_uint32 { value: 0 },
        should_limit: crate::bindings::pg_atomic_uint32 { value: 0 },
        sent_bytes: 0,
        last_recorded_time_us: crate::bindings::pg_atomic_uint64 { value: 0 },
        batch_start_time_us: crate::bindings::pg_atomic_uint64 { value: 0 },
        batch_end_time_us: crate::bindings::pg_atomic_uint64 { value: 0 },
    };

    crate::bindings::WalproposerShmemState {