Compare commits

..

1 Commits

Author SHA1 Message Date
Mikhail Kot
8de4636c2e initial 2025-06-02 12:36:31 +01:00
47 changed files with 1554 additions and 1549 deletions

187
Cargo.lock generated
View File

@@ -29,6 +29,41 @@ version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627"
[[package]]
name = "aead"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0"
dependencies = [
"crypto-common",
"generic-array",
]
[[package]]
name = "aes"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0"
dependencies = [
"cfg-if",
"cipher",
"cpufeatures",
]
[[package]]
name = "aes-gcm"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1"
dependencies = [
"aead",
"aes",
"cipher",
"ctr",
"ghash",
"subtle",
]
[[package]]
name = "ahash"
version = "0.8.11"
@@ -753,6 +788,7 @@ dependencies = [
"axum",
"axum-core",
"bytes",
"cookie",
"futures-util",
"headers",
"http 1.1.0",
@@ -1173,6 +1209,16 @@ dependencies = [
"half",
]
[[package]]
name = "cipher"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
dependencies = [
"crypto-common",
"inout",
]
[[package]]
name = "clang-sys"
version = "1.6.1"
@@ -1464,6 +1510,21 @@ dependencies = [
"workspace_hack",
]
[[package]]
name = "cookie"
version = "0.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747"
dependencies = [
"aes-gcm",
"base64 0.22.1",
"percent-encoding",
"rand 0.8.5",
"subtle",
"time",
"version_check",
]
[[package]]
name = "core-foundation"
version = "0.9.3"
@@ -1657,9 +1718,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
dependencies = [
"generic-array",
"rand_core 0.6.4",
"typenum",
]
[[package]]
name = "ctr"
version = "0.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835"
dependencies = [
"cipher",
]
[[package]]
name = "curve25519-dalek"
version = "4.1.3"
@@ -2510,6 +2581,16 @@ dependencies = [
"winapi",
]
[[package]]
name = "ghash"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1"
dependencies = [
"opaque-debug",
"polyval",
]
[[package]]
name = "gimli"
version = "0.31.1"
@@ -3281,6 +3362,15 @@ dependencies = [
"libc",
]
[[package]]
name = "inout"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01"
dependencies = [
"generic-array",
]
[[package]]
name = "instant"
version = "0.1.12"
@@ -3794,6 +3884,15 @@ version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
[[package]]
name = "nanoid"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ffa00dec017b5b1a8b7cf5e2c008bfda1aa7e0697ac1508b491fdf2622fb4d8"
dependencies = [
"rand 0.8.5",
]
[[package]]
name = "neon-shmem"
version = "0.1.0"
@@ -4066,6 +4165,12 @@ version = "11.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575"
[[package]]
name = "opaque-debug"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
[[package]]
name = "openssl-probe"
version = "0.1.5"
@@ -4236,7 +4341,6 @@ name = "pagebench"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"camino",
"clap",
"futures",
@@ -4245,15 +4349,12 @@ dependencies = [
"humantime-serde",
"pageserver_api",
"pageserver_client",
"pageserver_page_api",
"rand 0.8.5",
"reqwest",
"serde",
"serde_json",
"tokio",
"tokio-stream",
"tokio-util",
"tonic 0.13.1",
"tracing",
"utils",
"workspace_hack",
@@ -4309,7 +4410,6 @@ dependencies = [
"hashlink",
"hex",
"hex-literal",
"http 1.1.0",
"http-utils",
"humantime",
"humantime-serde",
@@ -4372,7 +4472,6 @@ dependencies = [
"toml_edit",
"tonic 0.13.1",
"tonic-reflection",
"tower 0.5.2",
"tracing",
"tracing-utils",
"twox-hash",
@@ -4469,6 +4568,7 @@ dependencies = [
"pageserver_api",
"postgres_ffi",
"prost 0.13.5",
"smallvec",
"thiserror 1.0.69",
"tonic 0.13.1",
"tonic-build",
@@ -4590,6 +4690,31 @@ version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c"
[[package]]
name = "paster"
version = "0.1.0"
dependencies = [
"anyhow",
"axum",
"axum-extra",
"base64 0.13.1",
"chrono",
"nanoid",
"rand 0.8.5",
"reqwest",
"rustls 0.23.27",
"rustls-native-certs 0.8.0",
"serde",
"serde_json",
"time",
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tracing",
"tracing-subscriber",
"workspace_hack",
]
[[package]]
name = "pbkdf2"
version = "0.12.2"
@@ -4762,6 +4887,18 @@ dependencies = [
"never-say-never",
]
[[package]]
name = "polyval"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25"
dependencies = [
"cfg-if",
"cpufeatures",
"opaque-debug",
"universal-hash",
]
[[package]]
name = "portable-atomic"
version = "1.10.0"
@@ -6564,6 +6701,32 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "shortener"
version = "0.1.0"
dependencies = [
"anyhow",
"axum",
"axum-extra",
"base64 0.13.1",
"chrono",
"cookie",
"nanoid",
"rand 0.8.5",
"reqwest",
"rustls 0.23.27",
"rustls-native-certs 0.8.0",
"serde",
"serde_json",
"time",
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tracing",
"tracing-subscriber",
"workspace_hack",
]
[[package]]
name = "signal-hook"
version = "0.3.15"
@@ -7932,6 +8095,16 @@ version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c"
[[package]]
name = "universal-hash"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea"
dependencies = [
"crypto-common",
"subtle",
]
[[package]]
name = "untrusted"
version = "0.9.0"
@@ -8564,6 +8737,7 @@ dependencies = [
"anyhow",
"axum",
"axum-core",
"axum-extra",
"base64 0.13.1",
"base64 0.21.7",
"base64ct",
@@ -8575,6 +8749,7 @@ dependencies = [
"clap_builder",
"const-oid",
"crypto-bigint 0.5.5",
"crypto-common",
"der 0.7.8",
"deranged",
"digest",

View File

@@ -13,6 +13,8 @@ members = [
"proxy",
"safekeeper",
"safekeeper/client",
"shortener",
"paster",
"storage_broker",
"storage_controller",
"storage_controller/client",

View File

@@ -1180,14 +1180,14 @@ RUN cd exts/rag && \
RUN cd exts/rag_bge_small_en_v15 && \
sed -i 's/pgrx = "0.14.1"/pgrx = { version = "0.14.1", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
ORT_LIB_LOCATION=/ext-src/onnxruntime-src/build/Linux \
REMOTE_ONNX_URL=http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/pgrag-data/bge_small_en_v15.onnx \
REMOTE_ONNX_URL=http://pg-ext-s3-gateway/pgrag-data/bge_small_en_v15.onnx \
cargo pgrx install --release --features remote_onnx && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/rag_bge_small_en_v15.control
RUN cd exts/rag_jina_reranker_v1_tiny_en && \
sed -i 's/pgrx = "0.14.1"/pgrx = { version = "0.14.1", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
ORT_LIB_LOCATION=/ext-src/onnxruntime-src/build/Linux \
REMOTE_ONNX_URL=http://pg-ext-s3-gateway.pg-ext-s3-gateway.svc.cluster.local/pgrag-data/jina_reranker_v1_tiny_en.onnx \
REMOTE_ONNX_URL=http://pg-ext-s3-gateway/pgrag-data/jina_reranker_v1_tiny_en.onnx \
cargo pgrx install --release --features remote_onnx && \
echo "trusted = true" >> /usr/local/pgsql/share/extension/rag_jina_reranker_v1_tiny_en.control

View File

@@ -181,7 +181,6 @@ pub struct ConfigToml {
pub virtual_file_io_engine: Option<crate::models::virtual_file::IoEngineKind>,
pub ingest_batch_size: u64,
pub max_vectored_read_bytes: MaxVectoredReadBytes,
pub max_get_vectored_keys: MaxGetVectoredKeys,
pub image_compression: ImageCompressionAlgorithm,
pub timeline_offloading: bool,
pub ephemeral_bytes_per_memory_kb: usize,
@@ -230,7 +229,7 @@ pub enum PageServicePipeliningConfig {
}
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct PageServicePipeliningConfigPipelined {
/// Failed config parsing and validation if larger than `max_get_vectored_keys`.
/// Causes runtime errors if larger than max get_vectored batch size.
pub max_batch_size: NonZeroUsize,
pub execution: PageServiceProtocolPipelinedExecutionStrategy,
// The default below is such that new versions of the software can start
@@ -404,16 +403,6 @@ impl Default for EvictionOrder {
#[serde(transparent)]
pub struct MaxVectoredReadBytes(pub NonZeroUsize);
#[derive(Copy, Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(transparent)]
pub struct MaxGetVectoredKeys(NonZeroUsize);
impl MaxGetVectoredKeys {
pub fn get(&self) -> usize {
self.0.get()
}
}
/// Tenant-level configuration values, used for various purposes.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(default)]
@@ -598,8 +587,6 @@ pub mod defaults {
/// That is, slightly above 128 kB.
pub const DEFAULT_MAX_VECTORED_READ_BYTES: usize = 130 * 1024; // 130 KiB
pub const DEFAULT_MAX_GET_VECTORED_KEYS: usize = 32;
pub const DEFAULT_IMAGE_COMPRESSION: ImageCompressionAlgorithm =
ImageCompressionAlgorithm::Zstd { level: Some(1) };
@@ -698,9 +685,6 @@ impl Default for ConfigToml {
max_vectored_read_bytes: (MaxVectoredReadBytes(
NonZeroUsize::new(DEFAULT_MAX_VECTORED_READ_BYTES).unwrap(),
)),
max_get_vectored_keys: (MaxGetVectoredKeys(
NonZeroUsize::new(DEFAULT_MAX_GET_VECTORED_KEYS).unwrap(),
)),
image_compression: (DEFAULT_IMAGE_COMPRESSION),
timeline_offloading: true,
ephemeral_bytes_per_memory_kb: (DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB),

View File

@@ -1934,7 +1934,7 @@ pub enum PagestreamFeMessage {
}
// Wrapped in libpq CopyData
#[derive(Debug, strum_macros::EnumProperty)]
#[derive(strum_macros::EnumProperty)]
pub enum PagestreamBeMessage {
Exists(PagestreamExistsResponse),
Nblocks(PagestreamNblocksResponse),
@@ -2045,7 +2045,7 @@ pub enum PagestreamProtocolVersion {
pub type RequestId = u64;
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct PagestreamRequest {
pub reqid: RequestId,
pub request_lsn: Lsn,
@@ -2064,7 +2064,7 @@ pub struct PagestreamNblocksRequest {
pub rel: RelTag,
}
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct PagestreamGetPageRequest {
pub hdr: PagestreamRequest,
pub rel: RelTag,

View File

@@ -24,7 +24,7 @@ use serde::{Deserialize, Serialize};
// FIXME: should move 'forknum' as last field to keep this consistent with Postgres.
// Then we could replace the custom Ord and PartialOrd implementations below with
// deriving them. This will require changes in walredoproc.c.
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)]
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct RelTag {
pub forknum: u8,
pub spcnode: Oid,
@@ -184,12 +184,12 @@ pub enum SlruKind {
MultiXactOffsets,
}
impl fmt::Display for SlruKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
impl SlruKind {
pub fn to_str(&self) -> &'static str {
match self {
Self::Clog => write!(f, "pg_xact"),
Self::MultiXactMembers => write!(f, "pg_multixact/members"),
Self::MultiXactOffsets => write!(f, "pg_multixact/offsets"),
Self::Clog => "pg_xact",
Self::MultiXactMembers => "pg_multixact/members",
Self::MultiXactOffsets => "pg_multixact/offsets",
}
}
}

View File

@@ -73,7 +73,6 @@ pub mod error;
/// async timeout helper
pub mod timeout;
pub mod span;
pub mod sync;
pub mod failpoint_support;

View File

@@ -1,19 +0,0 @@
//! Tracing span helpers.
/// Records the given fields in the current span, as a single call. The fields must already have
/// been declared for the span (typically with empty values).
#[macro_export]
macro_rules! span_record {
($($tokens:tt)*) => {$crate::span_record_in!(::tracing::Span::current(), $($tokens)*)};
}
/// Records the given fields in the given span, as a single call. The fields must already have been
/// declared for the span (typically with empty values).
#[macro_export]
macro_rules! span_record_in {
($span:expr, $($tokens:tt)*) => {
if let Some(meta) = $span.metadata() {
$span.record_all(&tracing::valueset!(meta.fields(), $($tokens)*));
}
};
}

View File

@@ -34,7 +34,6 @@ fail.workspace = true
futures.workspace = true
hashlink.workspace = true
hex.workspace = true
http.workspace = true
http-utils.workspace = true
humantime-serde.workspace = true
humantime.workspace = true
@@ -94,7 +93,6 @@ tokio-util.workspace = true
toml_edit = { workspace = true, features = [ "serde" ] }
tonic.workspace = true
tonic-reflection.workspace = true
tower.workspace = true
tracing.workspace = true
tracing-utils.workspace = true
url.workspace = true

View File

@@ -9,6 +9,7 @@ bytes.workspace = true
pageserver_api.workspace = true
postgres_ffi.workspace = true
prost.workspace = true
smallvec.workspace = true
thiserror.workspace = true
tonic.workspace = true
utils.workspace = true

View File

@@ -9,16 +9,10 @@
//! - Use more precise datatypes, e.g. Lsn and uints shorter than 32 bits.
//!
//! - Validate protocol invariants, via try_from() and try_into().
//!
//! Validation only happens on the receiver side, i.e. when converting from Protobuf to domain
//! types. This is where it matters -- the Protobuf types are less strict than the domain types, and
//! receivers should expect all sorts of junk from senders. This also allows the sender to use e.g.
//! stream combinators without dealing with errors, and avoids validating the same message twice.
use std::fmt::Display;
use bytes::Bytes;
use postgres_ffi::Oid;
use smallvec::SmallVec;
// TODO: split out Lsn, RelTag, SlruKind, Oid and other basic types to a separate crate, to avoid
// pulling in all of their other crate dependencies when building the client.
use utils::lsn::Lsn;
@@ -54,8 +48,7 @@ pub struct ReadLsn {
pub request_lsn: Lsn,
/// If given, the caller guarantees that the page has not been modified since this LSN. Must be
/// smaller than or equal to request_lsn. This allows the Pageserver to serve an old page
/// without waiting for the request LSN to arrive. If not given, the request will read at the
/// request_lsn and wait for it to arrive if necessary. Valid for all request types.
/// without waiting for the request LSN to arrive. Valid for all request types.
///
/// It is undefined behaviour to make a request such that the page was, in fact, modified
/// between request_lsn and not_modified_since_lsn. The Pageserver might detect it and return an
@@ -65,14 +58,19 @@ pub struct ReadLsn {
pub not_modified_since_lsn: Option<Lsn>,
}
impl Display for ReadLsn {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let req_lsn = self.request_lsn;
if let Some(mod_lsn) = self.not_modified_since_lsn {
write!(f, "{req_lsn}>={mod_lsn}")
} else {
req_lsn.fmt(f)
impl ReadLsn {
/// Validates the ReadLsn.
pub fn validate(&self) -> Result<(), ProtocolError> {
if self.request_lsn == Lsn::INVALID {
return Err(ProtocolError::invalid("request_lsn", self.request_lsn));
}
if self.not_modified_since_lsn > Some(self.request_lsn) {
return Err(ProtocolError::invalid(
"not_modified_since_lsn",
self.not_modified_since_lsn,
));
}
Ok(())
}
}
@@ -80,31 +78,27 @@ impl TryFrom<proto::ReadLsn> for ReadLsn {
type Error = ProtocolError;
fn try_from(pb: proto::ReadLsn) -> Result<Self, Self::Error> {
if pb.request_lsn == 0 {
return Err(ProtocolError::invalid("request_lsn", pb.request_lsn));
}
if pb.not_modified_since_lsn > pb.request_lsn {
return Err(ProtocolError::invalid(
"not_modified_since_lsn",
pb.not_modified_since_lsn,
));
}
Ok(Self {
let read_lsn = Self {
request_lsn: Lsn(pb.request_lsn),
not_modified_since_lsn: match pb.not_modified_since_lsn {
0 => None,
lsn => Some(Lsn(lsn)),
},
})
};
read_lsn.validate()?;
Ok(read_lsn)
}
}
impl From<ReadLsn> for proto::ReadLsn {
fn from(read_lsn: ReadLsn) -> Self {
Self {
impl TryFrom<ReadLsn> for proto::ReadLsn {
type Error = ProtocolError;
fn try_from(read_lsn: ReadLsn) -> Result<Self, Self::Error> {
read_lsn.validate()?;
Ok(Self {
request_lsn: read_lsn.request_lsn.0,
not_modified_since_lsn: read_lsn.not_modified_since_lsn.unwrap_or_default().0,
}
})
}
}
@@ -159,15 +153,6 @@ impl TryFrom<proto::CheckRelExistsRequest> for CheckRelExistsRequest {
}
}
impl From<CheckRelExistsRequest> for proto::CheckRelExistsRequest {
fn from(request: CheckRelExistsRequest) -> Self {
Self {
read_lsn: Some(request.read_lsn.into()),
rel: Some(request.rel.into()),
}
}
}
pub type CheckRelExistsResponse = bool;
impl From<proto::CheckRelExistsResponse> for CheckRelExistsResponse {
@@ -205,12 +190,14 @@ impl TryFrom<proto::GetBaseBackupRequest> for GetBaseBackupRequest {
}
}
impl From<GetBaseBackupRequest> for proto::GetBaseBackupRequest {
fn from(request: GetBaseBackupRequest) -> Self {
Self {
read_lsn: Some(request.read_lsn.into()),
impl TryFrom<GetBaseBackupRequest> for proto::GetBaseBackupRequest {
type Error = ProtocolError;
fn try_from(request: GetBaseBackupRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: Some(request.read_lsn.try_into()?),
replica: request.replica,
}
})
}
}
@@ -227,9 +214,14 @@ impl TryFrom<proto::GetBaseBackupResponseChunk> for GetBaseBackupResponseChunk {
}
}
impl From<GetBaseBackupResponseChunk> for proto::GetBaseBackupResponseChunk {
fn from(chunk: GetBaseBackupResponseChunk) -> Self {
Self { chunk }
impl TryFrom<GetBaseBackupResponseChunk> for proto::GetBaseBackupResponseChunk {
type Error = ProtocolError;
fn try_from(chunk: GetBaseBackupResponseChunk) -> Result<Self, Self::Error> {
if chunk.is_empty() {
return Err(ProtocolError::Missing("chunk"));
}
Ok(Self { chunk })
}
}
@@ -254,12 +246,14 @@ impl TryFrom<proto::GetDbSizeRequest> for GetDbSizeRequest {
}
}
impl From<GetDbSizeRequest> for proto::GetDbSizeRequest {
fn from(request: GetDbSizeRequest) -> Self {
Self {
read_lsn: Some(request.read_lsn.into()),
impl TryFrom<GetDbSizeRequest> for proto::GetDbSizeRequest {
type Error = ProtocolError;
fn try_from(request: GetDbSizeRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: Some(request.read_lsn.try_into()?),
db_oid: request.db_oid,
}
})
}
}
@@ -294,7 +288,7 @@ pub struct GetPageRequest {
/// Multiple pages will be executed as a single batch by the Pageserver, amortizing layer access
/// costs and parallelizing them. This may increase the latency of any individual request, but
/// improves the overall latency and throughput of the batch as a whole.
pub block_numbers: Vec<u32>,
pub block_numbers: SmallVec<[u32; 1]>,
}
impl TryFrom<proto::GetPageRequest> for GetPageRequest {
@@ -312,20 +306,25 @@ impl TryFrom<proto::GetPageRequest> for GetPageRequest {
.ok_or(ProtocolError::Missing("read_lsn"))?
.try_into()?,
rel: pb.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?,
block_numbers: pb.block_number,
block_numbers: pb.block_number.into(),
})
}
}
impl From<GetPageRequest> for proto::GetPageRequest {
fn from(request: GetPageRequest) -> Self {
Self {
impl TryFrom<GetPageRequest> for proto::GetPageRequest {
type Error = ProtocolError;
fn try_from(request: GetPageRequest) -> Result<Self, Self::Error> {
if request.block_numbers.is_empty() {
return Err(ProtocolError::Missing("block_number"));
}
Ok(Self {
request_id: request.request_id,
request_class: request.request_class.into(),
read_lsn: Some(request.read_lsn.into()),
read_lsn: Some(request.read_lsn.try_into()?),
rel: Some(request.rel.into()),
block_number: request.block_numbers,
}
block_number: request.block_numbers.into_vec(),
})
}
}
@@ -397,7 +396,7 @@ pub struct GetPageResponse {
/// A string describing the status, if any.
pub reason: Option<String>,
/// The 8KB page images, in the same order as the request. Empty if status != OK.
pub page_images: Vec<Bytes>,
pub page_images: SmallVec<[Bytes; 1]>,
}
impl From<proto::GetPageResponse> for GetPageResponse {
@@ -406,7 +405,7 @@ impl From<proto::GetPageResponse> for GetPageResponse {
request_id: pb.request_id,
status_code: pb.status_code.into(),
reason: Some(pb.reason).filter(|r| !r.is_empty()),
page_images: pb.page_image,
page_images: pb.page_image.into(),
}
}
}
@@ -417,7 +416,7 @@ impl From<GetPageResponse> for proto::GetPageResponse {
request_id: response.request_id,
status_code: response.status_code.into(),
reason: response.reason.unwrap_or_default(),
page_image: response.page_images,
page_image: response.page_images.into_vec(),
}
}
}
@@ -506,12 +505,14 @@ impl TryFrom<proto::GetRelSizeRequest> for GetRelSizeRequest {
}
}
impl From<GetRelSizeRequest> for proto::GetRelSizeRequest {
fn from(request: GetRelSizeRequest) -> Self {
Self {
read_lsn: Some(request.read_lsn.into()),
impl TryFrom<GetRelSizeRequest> for proto::GetRelSizeRequest {
type Error = ProtocolError;
fn try_from(request: GetRelSizeRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: Some(request.read_lsn.try_into()?),
rel: Some(request.rel.into()),
}
})
}
}
@@ -554,13 +555,15 @@ impl TryFrom<proto::GetSlruSegmentRequest> for GetSlruSegmentRequest {
}
}
impl From<GetSlruSegmentRequest> for proto::GetSlruSegmentRequest {
fn from(request: GetSlruSegmentRequest) -> Self {
Self {
read_lsn: Some(request.read_lsn.into()),
impl TryFrom<GetSlruSegmentRequest> for proto::GetSlruSegmentRequest {
type Error = ProtocolError;
fn try_from(request: GetSlruSegmentRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: Some(request.read_lsn.try_into()?),
kind: request.kind as u32,
segno: request.segno,
}
})
}
}
@@ -577,9 +580,14 @@ impl TryFrom<proto::GetSlruSegmentResponse> for GetSlruSegmentResponse {
}
}
impl From<GetSlruSegmentResponse> for proto::GetSlruSegmentResponse {
fn from(segment: GetSlruSegmentResponse) -> Self {
Self { segment }
impl TryFrom<GetSlruSegmentResponse> for proto::GetSlruSegmentResponse {
type Error = ProtocolError;
fn try_from(segment: GetSlruSegmentResponse) -> Result<Self, Self::Error> {
if segment.is_empty() {
return Err(ProtocolError::Missing("segment"));
}
Ok(Self { segment })
}
}

View File

@@ -8,7 +8,6 @@ license.workspace = true
[dependencies]
anyhow.workspace = true
async-trait.workspace = true
camino.workspace = true
clap.workspace = true
futures.workspace = true
@@ -16,17 +15,14 @@ hdrhistogram.workspace = true
humantime.workspace = true
humantime-serde.workspace = true
rand.workspace = true
reqwest.workspace = true
reqwest.workspace=true
serde.workspace = true
serde_json.workspace = true
tracing.workspace = true
tokio.workspace = true
tokio-stream.workspace = true
tokio-util.workspace = true
tonic.workspace = true
pageserver_client.workspace = true
pageserver_api.workspace = true
pageserver_page_api.workspace = true
utils = { path = "../../libs/utils/" }
workspace_hack = { version = "0.1", path = "../../workspace_hack" }

View File

@@ -7,15 +7,11 @@ use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use anyhow::Context;
use async_trait::async_trait;
use camino::Utf8PathBuf;
use pageserver_api::key::Key;
use pageserver_api::keyspace::KeySpaceAccum;
use pageserver_api::models::{
PagestreamGetPageRequest, PagestreamGetPageResponse, PagestreamRequest,
};
use pageserver_api::models::{PagestreamGetPageRequest, PagestreamRequest};
use pageserver_api::shard::TenantShardId;
use pageserver_page_api::proto;
use rand::prelude::*;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
@@ -26,12 +22,6 @@ use utils::lsn::Lsn;
use crate::util::tokio_thread_local_stats::AllThreadLocalStats;
use crate::util::{request_stats, tokio_thread_local_stats};
#[derive(clap::ValueEnum, Clone, Debug)]
enum Protocol {
Libpq,
Grpc,
}
/// GetPage@LatestLSN, uniformly distributed across the compute-accessible keyspace.
#[derive(clap::Parser)]
pub(crate) struct Args {
@@ -45,8 +35,6 @@ pub(crate) struct Args {
num_clients: NonZeroUsize,
#[clap(long)]
runtime: Option<humantime::Duration>,
#[clap(long, value_enum, default_value = "libpq")]
protocol: Protocol,
/// Each client sends requests at the given rate.
///
/// If a request takes too long and we should be issuing a new request already,
@@ -315,20 +303,7 @@ async fn main_impl(
.unwrap();
Box::pin(async move {
let client: Box<dyn Client> = match args.protocol {
Protocol::Libpq => Box::new(
LibpqClient::new(args.page_service_connstring.clone(), worker_id.timeline)
.await
.unwrap(),
),
Protocol::Grpc => Box::new(
GrpcClient::new(args.page_service_connstring.clone(), worker_id.timeline)
.await
.unwrap(),
),
};
run_worker(args, client, ss, cancel, rps_period, ranges, weights).await
client_libpq(args, worker_id, ss, cancel, rps_period, ranges, weights).await
})
};
@@ -380,15 +355,23 @@ async fn main_impl(
anyhow::Ok(())
}
async fn run_worker(
async fn client_libpq(
args: &Args,
mut client: Box<dyn Client>,
worker_id: WorkerId,
shared_state: Arc<SharedState>,
cancel: CancellationToken,
rps_period: Option<Duration>,
ranges: Vec<KeyRange>,
weights: rand::distributions::weighted::WeightedIndex<i128>,
) {
let client = pageserver_client::page_service::Client::new(args.page_service_connstring.clone())
.await
.unwrap();
let mut client = client
.pagestream(worker_id.timeline.tenant_id, worker_id.timeline.timeline_id)
.await
.unwrap();
shared_state.start_work_barrier.wait().await;
let client_start = Instant::now();
let mut ticks_processed = 0;
@@ -432,12 +415,12 @@ async fn run_worker(
blkno: block_no,
}
};
client.send_get_page(req).await.unwrap();
client.getpage_send(req).await.unwrap();
inflight.push_back(start);
}
let start = inflight.pop_front().unwrap();
client.recv_get_page().await.unwrap();
client.getpage_recv().await.unwrap();
let end = Instant::now();
shared_state.live_stats.request_done();
ticks_processed += 1;
@@ -459,104 +442,3 @@ async fn run_worker(
}
}
}
/// A benchmark client, to allow switching out the transport protocol.
///
/// For simplicity, this just uses separate asynchronous send/recv methods. The send method could
/// return a future that resolves when the response is received, but we don't really need it.
#[async_trait]
trait Client: Send {
/// Sends an asynchronous GetPage request to the pageserver.
async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()>;
/// Receives the next GetPage response from the pageserver.
async fn recv_get_page(&mut self) -> anyhow::Result<PagestreamGetPageResponse>;
}
/// A libpq-based Pageserver client.
struct LibpqClient {
inner: pageserver_client::page_service::PagestreamClient,
}
impl LibpqClient {
async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result<Self> {
let inner = pageserver_client::page_service::Client::new(connstring)
.await?
.pagestream(ttid.tenant_id, ttid.timeline_id)
.await?;
Ok(Self { inner })
}
}
#[async_trait]
impl Client for LibpqClient {
async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()> {
self.inner.getpage_send(req).await
}
async fn recv_get_page(&mut self) -> anyhow::Result<PagestreamGetPageResponse> {
self.inner.getpage_recv().await
}
}
/// A gRPC client using the raw, no-frills gRPC client.
struct GrpcClient {
req_tx: tokio::sync::mpsc::Sender<proto::GetPageRequest>,
resp_rx: tonic::Streaming<proto::GetPageResponse>,
}
impl GrpcClient {
async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result<Self> {
let mut client = pageserver_page_api::proto::PageServiceClient::connect(connstring).await?;
// The channel has a buffer size of 1, since 0 is not allowed. It does not matter, since the
// benchmark will control the queue depth (i.e. in-flight requests) anyway, and requests are
// buffered by Tonic and the OS too.
let (req_tx, req_rx) = tokio::sync::mpsc::channel(1);
let req_stream = tokio_stream::wrappers::ReceiverStream::new(req_rx);
let mut req = tonic::Request::new(req_stream);
let metadata = req.metadata_mut();
metadata.insert("neon-tenant-id", ttid.tenant_id.to_string().try_into()?);
metadata.insert("neon-timeline-id", ttid.timeline_id.to_string().try_into()?);
metadata.insert("neon-shard-id", "0000".try_into()?);
let resp = client.get_pages(req).await?;
let resp_stream = resp.into_inner();
Ok(Self {
req_tx,
resp_rx: resp_stream,
})
}
}
#[async_trait]
impl Client for GrpcClient {
async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()> {
let req = proto::GetPageRequest {
request_id: 0,
request_class: proto::GetPageClass::Normal as i32,
read_lsn: Some(proto::ReadLsn {
request_lsn: req.hdr.request_lsn.0,
not_modified_since_lsn: req.hdr.not_modified_since.0,
}),
rel: Some(req.rel.into()),
block_number: vec![req.blkno],
};
self.req_tx.send(req).await?;
Ok(())
}
async fn recv_get_page(&mut self) -> anyhow::Result<PagestreamGetPageResponse> {
let resp = self.resp_rx.message().await?.unwrap();
anyhow::ensure!(
resp.status_code == proto::GetPageStatusCode::Ok as i32,
"unexpected status code: {}",
resp.status_code
);
Ok(PagestreamGetPageResponse {
page: resp.page_image[0].clone(),
req: PagestreamGetPageRequest::default(), // dummy
})
}
}

View File

@@ -65,30 +65,6 @@ impl From<GetVectoredError> for BasebackupError {
}
}
impl From<BasebackupError> for postgres_backend::QueryError {
fn from(err: BasebackupError) -> Self {
use postgres_backend::QueryError;
use pq_proto::framed::ConnectionError;
match err {
BasebackupError::Client(err, _) => QueryError::Disconnected(ConnectionError::Io(err)),
BasebackupError::Server(err) => QueryError::Other(err),
BasebackupError::Shutdown => QueryError::Shutdown,
}
}
}
impl From<BasebackupError> for tonic::Status {
fn from(err: BasebackupError) -> Self {
use tonic::Code;
let code = match &err {
BasebackupError::Client(_, _) => Code::Cancelled,
BasebackupError::Server(_) => Code::Internal,
BasebackupError::Shutdown => Code::Unavailable,
};
tonic::Status::new(code, err.to_string())
}
}
/// Create basebackup with non-rel data in it.
/// Only include relational data if 'full_backup' is true.
///
@@ -272,7 +248,7 @@ where
async fn flush(&mut self) -> Result<(), BasebackupError> {
let nblocks = self.buf.len() / BLCKSZ as usize;
let (kind, segno) = self.current_segment.take().unwrap();
let segname = format!("{kind}/{segno:>04X}");
let segname = format!("{}/{:>04X}", kind.to_str(), segno);
let header = new_tar_header(&segname, self.buf.len() as u64)?;
self.ar
.append(&header, self.buf.as_slice())
@@ -371,7 +347,7 @@ where
.await?
.partition(
self.timeline.get_shard_identity(),
self.timeline.conf.max_get_vectored_keys.get() as u64 * BLCKSZ as u64,
Timeline::MAX_GET_VECTORED_KEYS * BLCKSZ as u64,
);
let mut slru_builder = SlruSegmentsBuilder::new(&mut self.ar);

View File

@@ -804,7 +804,7 @@ fn start_pageserver(
} else {
None
},
basebackup_cache,
basebackup_cache.clone(),
);
// Spawn a Pageserver gRPC server task. It will spawn separate tasks for
@@ -816,10 +816,12 @@ fn start_pageserver(
let mut page_service_grpc = None;
if let Some(grpc_listener) = grpc_listener {
page_service_grpc = Some(page_service::spawn_grpc(
conf,
tenant_manager.clone(),
grpc_auth,
otel_guard.as_ref().map(|g| g.dispatch.clone()),
grpc_listener,
basebackup_cache,
)?);
}

View File

@@ -14,10 +14,7 @@ use std::time::Duration;
use anyhow::{Context, bail, ensure};
use camino::{Utf8Path, Utf8PathBuf};
use once_cell::sync::OnceCell;
use pageserver_api::config::{
DiskUsageEvictionTaskConfig, MaxGetVectoredKeys, MaxVectoredReadBytes,
PageServicePipeliningConfig, PageServicePipeliningConfigPipelined, PostHogConfig,
};
use pageserver_api::config::{DiskUsageEvictionTaskConfig, MaxVectoredReadBytes, PostHogConfig};
use pageserver_api::models::ImageCompressionAlgorithm;
use pageserver_api::shard::TenantShardId;
use pem::Pem;
@@ -188,9 +185,6 @@ pub struct PageServerConf {
pub max_vectored_read_bytes: MaxVectoredReadBytes,
/// Maximum number of keys to be read in a single get_vectored call.
pub max_get_vectored_keys: MaxGetVectoredKeys,
pub image_compression: ImageCompressionAlgorithm,
/// Whether to offload archived timelines automatically
@@ -410,7 +404,6 @@ impl PageServerConf {
secondary_download_concurrency,
ingest_batch_size,
max_vectored_read_bytes,
max_get_vectored_keys,
image_compression,
timeline_offloading,
ephemeral_bytes_per_memory_kb,
@@ -477,7 +470,6 @@ impl PageServerConf {
secondary_download_concurrency,
ingest_batch_size,
max_vectored_read_bytes,
max_get_vectored_keys,
image_compression,
timeline_offloading,
ephemeral_bytes_per_memory_kb,
@@ -606,19 +598,6 @@ impl PageServerConf {
)
})?;
if let PageServicePipeliningConfig::Pipelined(PageServicePipeliningConfigPipelined {
max_batch_size,
..
}) = conf.page_service_pipelining
{
if max_batch_size.get() > conf.max_get_vectored_keys.get() {
return Err(anyhow::anyhow!(
"`max_batch_size` ({max_batch_size}) must be less than or equal to `max_get_vectored_keys` ({})",
conf.max_get_vectored_keys.get()
));
}
};
Ok(conf)
}
@@ -706,7 +685,6 @@ impl ConfigurableSemaphore {
mod tests {
use camino::Utf8PathBuf;
use rstest::rstest;
use utils::id::NodeId;
use super::PageServerConf;
@@ -746,28 +724,4 @@ mod tests {
PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir)
.expect_err("parse_and_validate should fail for endpoint without scheme");
}
#[rstest]
#[case(32, 32, true)]
#[case(64, 32, false)]
#[case(64, 64, true)]
#[case(128, 128, true)]
fn test_config_max_batch_size_is_valid(
#[case] max_batch_size: usize,
#[case] max_get_vectored_keys: usize,
#[case] is_valid: bool,
) {
let input = format!(
r#"
control_plane_api = "http://localhost:6666"
max_get_vectored_keys = {max_get_vectored_keys}
page_service_pipelining = {{ mode="pipelined", execution="concurrent-futures", max_batch_size={max_batch_size}, batching="uniform-lsn" }}
"#,
);
let config_toml = toml_edit::de::from_str::<pageserver_api::config::ConfigToml>(&input)
.expect("config has valid fields");
let workdir = Utf8PathBuf::from("/nonexistent");
let result = PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir);
assert_eq!(result.is_ok(), is_valid);
}
}

View File

@@ -15,7 +15,6 @@ use metrics::{
register_int_gauge, register_int_gauge_vec, register_uint_gauge, register_uint_gauge_vec,
};
use once_cell::sync::Lazy;
use pageserver_api::config::defaults::DEFAULT_MAX_GET_VECTORED_KEYS;
use pageserver_api::config::{
PageServicePipeliningConfig, PageServicePipeliningConfigPipelined,
PageServiceProtocolPipelinedBatchingStrategy, PageServiceProtocolPipelinedExecutionStrategy,
@@ -33,6 +32,7 @@ use crate::config::PageServerConf;
use crate::context::{PageContentKind, RequestContext};
use crate::pgdatadir_mapping::DatadirModificationStats;
use crate::task_mgr::TaskKind;
use crate::tenant::Timeline;
use crate::tenant::layer_map::LayerMap;
use crate::tenant::mgr::TenantSlot;
use crate::tenant::storage_layer::{InMemoryLayer, PersistentLayerDesc};
@@ -1939,7 +1939,7 @@ static SMGR_QUERY_TIME_GLOBAL: Lazy<HistogramVec> = Lazy::new(|| {
});
static PAGE_SERVICE_BATCH_SIZE_BUCKETS_GLOBAL: Lazy<Vec<f64>> = Lazy::new(|| {
(1..=u32::try_from(DEFAULT_MAX_GET_VECTORED_KEYS).unwrap())
(1..=u32::try_from(Timeline::MAX_GET_VECTORED_KEYS).unwrap())
.map(|v| v.into())
.collect()
});
@@ -1957,7 +1957,7 @@ static PAGE_SERVICE_BATCH_SIZE_BUCKETS_PER_TIMELINE: Lazy<Vec<f64>> = Lazy::new(
let mut buckets = Vec::new();
for i in 0.. {
let bucket = 1 << i;
if bucket > u32::try_from(DEFAULT_MAX_GET_VECTORED_KEYS).unwrap() {
if bucket > u32::try_from(Timeline::MAX_GET_VECTORED_KEYS).unwrap() {
break;
}
buckets.push(bucket.into());

File diff suppressed because it is too large Load Diff

View File

@@ -431,10 +431,10 @@ impl Timeline {
GetVectoredError::InvalidLsn(e) => {
Err(anyhow::anyhow!("invalid LSN: {e:?}").into())
}
// NB: this should never happen in practice because we limit batch size to be smaller than max_get_vectored_keys
// NB: this should never happen in practice because we limit MAX_GET_VECTORED_KEYS
// TODO: we can prevent this error class by moving this check into the type system
GetVectoredError::Oversized(err, max) => {
Err(anyhow::anyhow!("batching oversized: {err} > {max}").into())
GetVectoredError::Oversized(err) => {
Err(anyhow::anyhow!("batching oversized: {err:?}").into())
}
};
@@ -471,19 +471,8 @@ impl Timeline {
let rels = self.list_rels(spcnode, dbnode, version, ctx).await?;
if rels.is_empty() {
return Ok(0);
}
// Pre-deserialize the rel directory to avoid duplicated work in `get_relsize_cached`.
let reldir_key = rel_dir_to_key(spcnode, dbnode);
let buf = version.get(self, reldir_key, ctx).await?;
let reldir = RelDirectory::des(&buf)?;
for rel in rels {
let n_blocks = self
.get_rel_size_in_reldir(rel, version, Some((reldir_key, &reldir)), ctx)
.await?;
let n_blocks = self.get_rel_size(rel, version, ctx).await?;
total_blocks += n_blocks as usize;
}
Ok(total_blocks)
@@ -498,19 +487,6 @@ impl Timeline {
tag: RelTag,
version: Version<'_>,
ctx: &RequestContext,
) -> Result<BlockNumber, PageReconstructError> {
self.get_rel_size_in_reldir(tag, version, None, ctx).await
}
/// Get size of a relation file. The relation must exist, otherwise an error is returned.
///
/// See [`Self::get_rel_exists_in_reldir`] on why we need `deserialized_reldir_v1`.
pub(crate) async fn get_rel_size_in_reldir(
&self,
tag: RelTag,
version: Version<'_>,
deserialized_reldir_v1: Option<(Key, &RelDirectory)>,
ctx: &RequestContext,
) -> Result<BlockNumber, PageReconstructError> {
if tag.relnode == 0 {
return Err(PageReconstructError::Other(
@@ -523,9 +499,7 @@ impl Timeline {
}
if (tag.forknum == FSM_FORKNUM || tag.forknum == VISIBILITYMAP_FORKNUM)
&& !self
.get_rel_exists_in_reldir(tag, version, deserialized_reldir_v1, ctx)
.await?
&& !self.get_rel_exists(tag, version, ctx).await?
{
// FIXME: Postgres sometimes calls smgrcreate() to create
// FSM, and smgrnblocks() on it immediately afterwards,
@@ -547,28 +521,11 @@ impl Timeline {
///
/// Only shard 0 has a full view of the relations. Other shards only know about relations that
/// the shard stores pages for.
///
pub(crate) async fn get_rel_exists(
&self,
tag: RelTag,
version: Version<'_>,
ctx: &RequestContext,
) -> Result<bool, PageReconstructError> {
self.get_rel_exists_in_reldir(tag, version, None, ctx).await
}
/// Does the relation exist? With a cached deserialized `RelDirectory`.
///
/// There are some cases where the caller loops across all relations. In that specific case,
/// the caller should obtain the deserialized `RelDirectory` first and then call this function
/// to avoid duplicated work of deserliazation. This is a hack and should be removed by introducing
/// a new API (e.g., `get_rel_exists_batched`).
pub(crate) async fn get_rel_exists_in_reldir(
&self,
tag: RelTag,
version: Version<'_>,
deserialized_reldir_v1: Option<(Key, &RelDirectory)>,
ctx: &RequestContext,
) -> Result<bool, PageReconstructError> {
if tag.relnode == 0 {
return Err(PageReconstructError::Other(
@@ -611,17 +568,6 @@ impl Timeline {
// fetch directory listing (old)
let key = rel_dir_to_key(tag.spcnode, tag.dbnode);
if let Some((cached_key, dir)) = deserialized_reldir_v1 {
if cached_key == key {
return Ok(dir.rels.contains(&(tag.relnode, tag.forknum)));
} else if cfg!(test) || cfg!(feature = "testing") {
panic!("cached reldir key mismatch: {cached_key} != {key}");
} else {
warn!("cached reldir key mismatch: {cached_key} != {key}");
}
// Fallback to reading the directory from the datadir.
}
let buf = version.get(self, key, ctx).await?;
let dir = RelDirectory::des(&buf)?;
@@ -719,7 +665,7 @@ impl Timeline {
let batches = keyspace.partition(
self.get_shard_identity(),
self.conf.max_get_vectored_keys.get() as u64 * BLCKSZ as u64,
Timeline::MAX_GET_VECTORED_KEYS * BLCKSZ as u64,
);
let io_concurrency = IoConcurrency::spawn_from_conf(
@@ -959,7 +905,7 @@ impl Timeline {
let batches = keyspace.partition(
self.get_shard_identity(),
self.conf.max_get_vectored_keys.get() as u64 * BLCKSZ as u64,
Timeline::MAX_GET_VECTORED_KEYS * BLCKSZ as u64,
);
let io_concurrency = IoConcurrency::spawn_from_conf(

View File

@@ -7197,7 +7197,7 @@ mod tests {
let end = desc
.key_range
.start
.add(tenant.conf.max_get_vectored_keys.get() as u32);
.add(Timeline::MAX_GET_VECTORED_KEYS.try_into().unwrap());
reads.push(KeySpace {
ranges: vec![start..end],
});
@@ -11260,11 +11260,11 @@ mod tests {
let mut keyspaces_at_lsn: HashMap<Lsn, KeySpaceRandomAccum> = HashMap::default();
let mut used_keys: HashSet<Key> = HashSet::default();
while used_keys.len() < tenant.conf.max_get_vectored_keys.get() {
while used_keys.len() < Timeline::MAX_GET_VECTORED_KEYS as usize {
let selected_lsn = interesting_lsns.choose(&mut random).expect("not empty");
let mut selected_key = start_key.add(random.gen_range(0..KEY_DIMENSION_SIZE));
while used_keys.len() < tenant.conf.max_get_vectored_keys.get() {
while used_keys.len() < Timeline::MAX_GET_VECTORED_KEYS as usize {
if used_keys.contains(&selected_key)
|| selected_key >= start_key.add(KEY_DIMENSION_SIZE)
{

View File

@@ -817,8 +817,8 @@ pub(crate) enum GetVectoredError {
#[error("timeline shutting down")]
Cancelled,
#[error("requested too many keys: {0} > {1}")]
Oversized(u64, u64),
#[error("requested too many keys: {0} > {}", Timeline::MAX_GET_VECTORED_KEYS)]
Oversized(u64),
#[error("requested at invalid LSN: {0}")]
InvalidLsn(Lsn),
@@ -950,18 +950,6 @@ pub(crate) enum WaitLsnError {
Timeout(String),
}
impl From<WaitLsnError> for tonic::Status {
fn from(err: WaitLsnError) -> Self {
use tonic::Code;
let code = match &err {
WaitLsnError::Timeout(_) => Code::Internal,
WaitLsnError::BadState(_) => Code::Internal,
WaitLsnError::Shutdown => Code::Unavailable,
};
tonic::Status::new(code, err.to_string())
}
}
// The impls below achieve cancellation mapping for errors.
// Perhaps there's a way of achieving this with less cruft.
@@ -1019,7 +1007,7 @@ impl From<GetVectoredError> for PageReconstructError {
match e {
GetVectoredError::Cancelled => PageReconstructError::Cancelled,
GetVectoredError::InvalidLsn(_) => PageReconstructError::Other(anyhow!("Invalid LSN")),
err @ GetVectoredError::Oversized(_, _) => PageReconstructError::Other(err.into()),
err @ GetVectoredError::Oversized(_) => PageReconstructError::Other(err.into()),
GetVectoredError::MissingKey(err) => PageReconstructError::MissingKey(err),
GetVectoredError::GetReadyAncestorError(err) => PageReconstructError::from(err),
GetVectoredError::Other(err) => PageReconstructError::Other(err),
@@ -1199,6 +1187,7 @@ impl Timeline {
}
}
pub(crate) const MAX_GET_VECTORED_KEYS: u64 = 32;
pub(crate) const LAYERS_VISITED_WARN_THRESHOLD: u32 = 100;
/// Look up multiple page versions at a given LSN
@@ -1213,12 +1202,9 @@ impl Timeline {
) -> Result<BTreeMap<Key, Result<Bytes, PageReconstructError>>, GetVectoredError> {
let total_keyspace = query.total_keyspace();
let key_count = total_keyspace.total_raw_size();
if key_count > self.conf.max_get_vectored_keys.get() {
return Err(GetVectoredError::Oversized(
key_count as u64,
self.conf.max_get_vectored_keys.get() as u64,
));
let key_count = total_keyspace.total_raw_size().try_into().unwrap();
if key_count > Timeline::MAX_GET_VECTORED_KEYS {
return Err(GetVectoredError::Oversized(key_count));
}
for range in &total_keyspace.ranges {
@@ -5272,7 +5258,7 @@ impl Timeline {
key = key.next();
// Maybe flush `key_rest_accum`
if key_request_accum.raw_size() >= self.conf.max_get_vectored_keys.get() as u64
if key_request_accum.raw_size() >= Timeline::MAX_GET_VECTORED_KEYS
|| (last_key_in_range && key_request_accum.raw_size() > 0)
{
let query =

View File

@@ -201,8 +201,8 @@ async fn prepare_import(
.await;
match res {
Ok(_) => break,
Err(_err) => {
info!("indefinitely waiting for pgdata to finish");
Err(err) => {
info!(?err, "indefinitely waiting for pgdata to finish");
if tokio::time::timeout(std::time::Duration::from_secs(10), cancel.cancelled())
.await
.is_ok()

View File

@@ -471,8 +471,6 @@ impl Plan {
last_completed_job_idx = job_idx;
if last_completed_job_idx % checkpoint_every == 0 {
tracing::info!(last_completed_job_idx, jobs=%jobs_in_plan, "Checkpointing import status");
let progress = ShardImportProgressV1 {
jobs: jobs_in_plan,
completed: last_completed_job_idx,
@@ -494,6 +492,8 @@ impl Plan {
anyhow::anyhow!("Shut down while putting timeline import status")
})?;
}
tracing::info!(last_completed_job_idx, jobs=%jobs_in_plan, "Checkpointing import status");
},
Some(Err(_)) => {
anyhow::bail!(

25
paster/Cargo.toml Normal file
View File

@@ -0,0 +1,25 @@
[package]
name = "paster"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
anyhow.workspace = true
axum-extra = { workspace = true, features = ["cookie", "cookie-private"] }
axum.workspace = true
base64.workspace = true
chrono.workspace = true
nanoid = { version = "0.4.0", default-features = false }
rand.workspace = true
reqwest.workspace = true
rustls-native-certs.workspace = true
rustls.workspace = true
serde.workspace = true
serde_json.workspace = true
time = { version = "0.3.36", default-features = false }
tokio-postgres-rustls.workspace = true
tokio-postgres.workspace = true
tokio.workspace = true
tracing-subscriber.workspace = true
tracing.workspace = true
workspace_hack.workspace = true

View File

@@ -0,0 +1,18 @@
CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
sub VARCHAR(100) NOT NULL UNIQUE
);
CREATE TABLE IF NOT EXISTS sessions (
id SERIAL PRIMARY KEY,
user_id INT NOT NULL UNIQUE REFERENCES users(id),
session_id VARCHAR NOT NULL,
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
);
CREATE TABLE IF NOT EXISTS pastes (
id SERIAL PRIMARY KEY,
user_id INT NOT NULL REFERENCES users(id),
paste text NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
)

353
paster/src/main.rs Normal file
View File

@@ -0,0 +1,353 @@
//! Paster is a service to share logs or code snippets outside of
//! Slack, not relying on public services
use anyhow::Result;
use shortener::google_oauth_gate::{AuthRequest, State, UserId};
use axum::Form;
use axum::extract::{FromRef, FromRequestParts, Path, Query, State as AxumStateT};
use axum::http::StatusCode;
use axum::response::{Html, IntoResponse};
use axum::response::{Redirect, Response};
use axum::routing::get;
use axum_extra::extract::PrivateCookieJar;
use axum_extra::extract::cookie::{Cookie, Key};
use chrono::{Duration, Local, TimeZone, Utc};
use core::num::NonZeroI32;
use serde::Deserialize;
use std::env;
use std::sync::Arc;
use tracing::{error, info};
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
const SOCKET: &str = "127.0.0.1:12344";
const HOST: &str = "http://127.0.0.1:12344";
const ALLOWED_OAUTH_DOMAIN: &str = "neon.tech";
fn oauth_redirect_url() -> String {
format!("{HOST}{AUTHORIZED_ROUTE}")
}
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| format!("{}=info", env!("CARGO_CRATE_NAME")).into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
let oauth_client_id = env::var("OAUTH_CLIENT_ID").expect("Missing OAUTH_CLIENT_ID");
let oauth_client_secret = env::var("OAUTH_CLIENT_SECRET").expect("Missing OAUTH_CLIENT_SECRET");
let db_connstr = env::var("DB_CONNSTR").expect("Missing DB_CONNSTR");
let mut roots = rustls::RootCertStore::empty();
for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") {
roots.add(cert).unwrap();
}
let config = rustls::ClientConfig::builder()
.with_root_certificates(roots)
.with_no_client_auth();
let tls = tokio_postgres_rustls::MakeRustlsConnect::new(config);
info!("initialized TLS");
let (db_client, db_conn) = tokio_postgres::connect(&db_connstr, tls).await?;
tokio::spawn(async move {
if let Err(err) = db_conn.await {
error!(%err, "connecting to database");
std::process::exit(1);
}
});
info!("connected to database");
let state = InnerState {
db_client,
cookie_jar_key: Key::generate(),
oauth_client_id,
oauth_client_secret,
};
let router = axum::Router::new()
.route("/", get(index).post(paste))
.route("/authorize", get(authorize))
.route(AUTHORIZED_ROUTE, get(authorized))
.route("/{id}", get(view_paste))
.with_state(State { 0: Arc::new(state) });
let listener = tokio::net::TcpListener::bind(SOCKET)
.await
.expect("failed to bind TcpListener");
info!("listening on {SOCKET}");
axum::serve(listener, router).await.unwrap();
Ok(())
}
#[derive(Deserialize)]
pub struct UserId {
id: NonZeroI32,
}
impl axum::extract::OptionalFromRequestParts<State> for UserId {
type Rejection = Response;
async fn from_request_parts(
parts: &mut axum::http::request::Parts,
state: &State,
) -> Result<Option<Self>, Self::Rejection> {
let jar: PrivateCookieJar = PrivateCookieJar::from_request_parts(parts, state)
.await
.unwrap(); // infallible
let Some(session_id) = jar.get(COOKIE_SID).map(|cookie| cookie.value().to_owned()) else {
return Ok(None);
};
let client = &state.db_client;
let query = client
.query_opt(
"SELECT user_id FROM sessions WHERE session_id = $1",
&[&session_id],
)
.await;
let id = match query {
Ok(Some(row)) => row.get::<usize, i32>(0),
Ok(None) => return Ok(None),
Err(err) => {
error!(%err, "querying user session");
return Ok(None);
}
};
let id = NonZeroI32::new(id).unwrap(); // postgres id guaranteed not to be zero
Ok(Some(Self { id }))
}
}
#[derive(Deserialize)]
struct Paste {
paste: String,
}
fn paste_form() -> Html<String> {
Html(
r#"
<form method="post">
<textarea name="paste" style="width:100%;height:80%"></textarea>
<input type="submit" value="Paste" style="margin-top:10px">
</form>"#
.to_string(),
)
}
fn authorize_link(paste_id: i32) -> String {
format!("<a href=\"/authorize?paste_id={paste_id}\">Authorize</a>")
}
async fn index(user: Option<UserId>) -> Html<String> {
if user.is_some() {
return paste_form();
}
Html(authorize_link(0))
}
async fn paste(
state: AxumState,
user: Option<UserId>,
Form(Paste { paste }): Form<Paste>,
) -> Response {
let user_id = match user {
None => return StatusCode::FORBIDDEN.into_response(),
Some(user) => user.id,
};
if paste.is_empty() {
return paste_form().into_response();
}
let query = state
.db_client
.query_one(
"INSERT INTO pastes (user_id, paste) VALUES ($1, $2) RETURNING id",
&[&user_id.get(), &paste],
)
.await;
let id = match query {
Ok(row) => row.get::<usize, i32>(0),
Err(err) => {
error!(%err, "inserting paste");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
Redirect::to(&format!("/{id}")).into_response()
}
async fn view_paste(state: AxumState, user: Option<UserId>, Path(paste_id): Path<i32>) -> Response {
let user_id = match user {
None => return Html(authorize_link(paste_id)).into_response(),
Some(user) => user.id,
};
let query = state
.db_client
.query_opt("SELECT paste FROM pastes WHERE id = $1", &[&paste_id])
.await;
let row = match query {
Ok(None) => return StatusCode::NOT_FOUND.into_response(),
Ok(Some(row)) => row,
Err(err) => {
error!(%err, %paste_id, %user_id, "querying paste");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
row.get::<usize, String>(0).into_response()
}
#[derive(Deserialize)]
struct AuthRequest {
code: String,
}
#[derive(Deserialize)]
struct AuthResponse {
access_token: String,
id_token: String,
expires_in: u64,
}
#[derive(Deserialize)]
struct UserInfo {
hd: String,
sub: String,
}
fn decode_id_token(token: String) -> Option<UserInfo> {
let payload = token.split(".").skip(1).take(1).collect::<Vec<&str>>();
let decoded = base64::decode_config(payload.get(0)?, base64::STANDARD_NO_PAD).ok()?;
serde_json::from_slice::<UserInfo>(&decoded).ok()
}
#[derive(Deserialize)]
struct AuthorizeQuery {
paste_id: i32,
}
fn generate_csrf_token(num_bytes: u32) -> String {
use rand::{Rng, thread_rng};
let random_bytes: Vec<u8> = (0..num_bytes).map(|_| thread_rng().r#gen::<u8>()).collect();
base64::encode_config(&random_bytes, base64::URL_SAFE_NO_PAD)
}
async fn authorize(
state: AxumState,
jar: PrivateCookieJar,
Query(AuthorizeQuery { paste_id }): Query<AuthorizeQuery>,
) -> (PrivateCookieJar, Redirect) {
let csrf_token = generate_csrf_token(16);
let client_id = &state.oauth_client_id;
let redirect_uri = oauth_redirect_url();
let auth_url = format!(
"{OAUTH_BASE_URL}?response_type=code\
&client_id={client_id}\
&state={csrf_token}\
&redirect_uri={redirect_uri}\
&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email"
);
let redirect_cookie = Cookie::build((COOKIE_REDIRECT, paste_id.to_string()))
.path("/")
//.TODO secure(true) not true for localhost
//.domain(COOKIE_DOMAIN)
.secure(false)
.same_site(axum_extra::extract::cookie::SameSite::Lax)
.http_only(true)
.build();
let csrf_cookie = Cookie::build((COOKIE_CSRF, csrf_token))
.path("/")
//.TODO secure(true) not true for localhost
//.domain(COOKIE_DOMAIN)
.secure(false)
.same_site(axum_extra::extract::cookie::SameSite::Lax)
.http_only(true)
.build();
let jar = jar.add(redirect_cookie).add(csrf_cookie);
let url = Into::<String>::into(auth_url);
(jar, Redirect::to(&url))
}
async fn authorized(
state: AxumState,
jar: PrivateCookieJar,
Query(auth_request): Query<AuthRequest>,
) -> Result<(PrivateCookieJar, Redirect), Response> {
let params = [
("grant_type", "authorization_code"),
("redirect_uri", &oauth_redirect_url()),
("code", &auth_request.code),
("client_id", &state.oauth_client_id),
("client_secret", &state.oauth_client_secret),
];
let auth_response = reqwest::Client::new()
.post(OAUTH_TOKEN_URL)
.form(&params)
.send()
.await
.map_err(|err| {
error!(%err, "exchanging oauth code for token");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
})?
.json::<AuthResponse>()
.await
.map_err(|err| {
error!(%err, "deserializing access token response");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
})?;
let Some(UserInfo { hd, sub }) = decode_id_token(auth_response.id_token) else {
error!("Failed to decode response id token");
return Err(StatusCode::UNAUTHORIZED.into_response());
};
if hd != ALLOWED_OAUTH_DOMAIN {
error!(hd, "Domain doesn't match {ALLOWED_OAUTH_DOMAIN}");
return Err(StatusCode::UNAUTHORIZED.into_response());
}
let token_duration = Duration::try_seconds(auth_response.expires_in as i64).unwrap();
let expires_at = Utc.from_utc_datetime(&(Local::now().naive_local() + token_duration));
let cookie_max_age = time::Duration::new(token_duration.num_seconds(), 0);
let session_cookie = Cookie::build((COOKIE_SID, auth_response.access_token.clone()))
.path("/")
//.TODO secure(true) not true for localhost
//.domain(COOKIE_DOMAIN)
.secure(false)
.same_site(axum_extra::extract::cookie::SameSite::Lax)
.http_only(true)
.max_age(cookie_max_age)
.build();
state
.db_client
.query(
"WITH user_insert AS (\
INSERT INTO users (sub) VALUES ($1) \
ON CONFLICT (sub) DO UPDATE SET sub = excluded.sub RETURNING id)\
INSERT INTO sessions (user_id, session_id, expires_at) \
SELECT id, $2, $3 FROM user_insert \
ON CONFLICT (user_id) DO UPDATE SET \
session_id = excluded.session_id, \
expires_at = excluded.expires_at",
&[&sub, &auth_response.access_token, &expires_at],
)
.await
.map_err(|err| {
error!(%err, %sub, "updating session");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
})?;
let csrf_cookie = jar.get(COOKIE_CSRF).unwrap(); // set in authorize()
let jar = jar.remove(csrf_cookie).add(session_cookie);
match jar.get(COOKIE_REDIRECT) {
Some(redirect_cookie) => {
let mut value = redirect_cookie.value_trimmed();
if value == "0" {
value = "";
}
let redirect_url = format!("/{value}");
Ok((jar.remove(redirect_cookie), Redirect::to(&redirect_url)))
}
None => Ok((jar, Redirect::to("/"))),
}
}

View File

@@ -21,7 +21,6 @@
#include "access/xlog.h"
#include "funcapi.h"
#include "miscadmin.h"
#include "common/file_utils.h"
#include "common/hashfn.h"
#include "pgstat.h"
#include "port/pg_iovec.h"
@@ -65,7 +64,7 @@
*
* Cache is always reconstructed at node startup, so we do not need to save mapping somewhere and worry about
* its consistency.
*
*
* ## Holes
*
@@ -77,15 +76,13 @@
* fallocate(FALLOC_FL_PUNCH_HOLE) call. The nominal size of the file doesn't
* shrink, but the disk space it uses does.
*
* Each hole is tracked in a freelist. The freelist consists of two parts: a
* fixed-size array in shared memory, and a linked chain of on-disk
* blocks. When the in-memory array fills up, it's flushed to a new on-disk
* chunk. If the soft limit is raised again, we reuse the holes before
* extending the nominal size of the file.
*
* The in-memory freelist array is protected by 'lfc_lock', while the on-disk
* chain is protected by a separate 'lfc_freelist_lock'. Locking rule to
* avoid deadlocks: always acquire lfc_freelist_lock first, then lfc_lock.
* Each hole is tracked by a dummy FileCacheEntry, which are kept in the
* 'holes' linked list. They are entered into the chunk hash table, with a
* special key where the blockNumber is used to store the 'offset' of the
* hole, and all other fields are zero. Holes are never looked up in the hash
* table, we only enter them there to have a FileCacheEntry that we can keep
* in the linked list. If the soft limit is raised again, we reuse the holes
* before extending the nominal size of the file.
*/
/* Local file storage allocation chunk.
@@ -103,8 +100,6 @@
#define SIZE_MB_TO_CHUNKS(size) ((uint32)((size) * MB / BLCKSZ >> lfc_chunk_size_log))
#define BLOCK_TO_CHUNK_OFF(blkno) ((blkno) & (lfc_blocks_per_chunk-1))
#define INVALID_OFFSET (0xffffffff)
/*
* Blocks are read or written to LFC file outside LFC critical section.
* To synchronize access to such block, writer set state of such block to PENDING.
@@ -128,11 +123,11 @@ typedef struct FileCacheEntry
uint32 hash;
uint32 offset;
uint32 access_count;
dlist_node list_node; /* LRU list node */
dlist_node list_node; /* LRU/holes list node */
uint32 state[FLEXIBLE_ARRAY_MEMBER]; /* two bits per block */
} FileCacheEntry;
#define FILE_CACHE_ENTRY_SIZE MAXALIGN(offsetof(FileCacheEntry, state) + (lfc_blocks_per_chunk*2+31)/32*4)
#define FILE_CACHE_ENRTY_SIZE MAXALIGN(offsetof(FileCacheEntry, state) + (lfc_blocks_per_chunk*2+31)/32*4)
#define GET_STATE(entry, i) (((entry)->state[(i) / 16] >> ((i) % 16 * 2)) & 3)
#define SET_STATE(entry, i, new_state) (entry)->state[(i) / 16] = ((entry)->state[(i) / 16] & ~(3 << ((i) % 16 * 2))) | ((new_state) << ((i) % 16 * 2))
@@ -166,6 +161,7 @@ typedef struct FileCacheControl
uint64 evicted_pages; /* number of evicted pages */
dlist_head lru; /* double linked list for LRU replacement
* algorithm */
dlist_head holes; /* double linked list of punched holes */
HyperLogLogState wss_estimation; /* estimation of working set size */
ConditionVariable cv[N_COND_VARS]; /* turnstile of condition variables */
PrewarmWorkerState prewarm_workers[MAX_PREWARM_WORKERS];
@@ -176,35 +172,17 @@ typedef struct FileCacheControl
bool prewarm_active;
bool prewarm_canceled;
dsm_handle prewarm_lfc_state_handle;
/*
* Free list. This is large enough to hold one chunks worth of entries.
*/
uint32 freelist_size;
uint32 freelist_head;
uint32 num_free_pages;
uint32 free_pages[FLEXIBLE_ARRAY_MEMBER];
} FileCacheControl;
typedef struct FreeListChunk
{
uint32 next;
uint32 num_free_pages;
uint32 free_pages[FLEXIBLE_ARRAY_MEMBER];
} FreeListChunk;
#define FILE_CACHE_STATE_MAGIC 0xfcfcfcfc
#define FILE_CACHE_STATE_BITMAP(fcs) ((uint8*)&(fcs)->chunks[(fcs)->n_chunks])
#define FILE_CACHE_STATE_SIZE_FOR_CHUNKS(n_chunks) (sizeof(FileCacheState) + (n_chunks)*sizeof(BufferTag) + (((n_chunks) * lfc_blocks_per_chunk)+7)/8)
#define FILE_CACHE_STATE_SIZE(fcs) (sizeof(FileCacheState) + (fcs->n_chunks)*sizeof(BufferTag) + (((fcs->n_chunks) << fcs->chunk_size_log)+7)/8)
#define FREELIST_ENTRIES_PER_CHUNK(c) ((c) * BLCKSZ / sizeof(uint32) - 2)
static HTAB *lfc_hash;
static int lfc_desc = -1;
static LWLockId lfc_lock;
static LWLockId lfc_freelist_lock;
static int lfc_max_size;
static int lfc_size_limit;
static int lfc_prewarm_limit;
@@ -227,11 +205,6 @@ bool AmPrewarmWorker;
#define LFC_ENABLED() (lfc_ctl->limit != 0)
static bool freelist_push(uint32 offset);
static bool freelist_prepare_pop(void);
static uint32 freelist_pop(void);
static bool freelist_is_empty(void);
/*
* Close LFC file if opened.
* All backends should close their LFC files once LFC is disabled.
@@ -275,9 +248,7 @@ lfc_switch_off(void)
lfc_ctl->used_pages = 0;
lfc_ctl->limit = 0;
dlist_init(&lfc_ctl->lru);
lfc_ctl->freelist_head = INVALID_OFFSET;
lfc_ctl->num_free_pages = 0;
dlist_init(&lfc_ctl->holes);
/*
* We need to use unlink to to avoid races in LFC write, because it is not
@@ -346,7 +317,6 @@ lfc_ensure_opened(void)
static void
lfc_shmem_startup(void)
{
size_t size;
bool found;
static HASHCTL info;
@@ -357,19 +327,15 @@ lfc_shmem_startup(void)
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
size = offsetof(FileCacheControl, free_pages);
size += FREELIST_ENTRIES_PER_CHUNK(lfc_blocks_per_chunk) * sizeof(uint32);
lfc_ctl = (FileCacheControl *) ShmemInitStruct("lfc", size, &found);
lfc_ctl = (FileCacheControl *) ShmemInitStruct("lfc", sizeof(FileCacheControl), &found);
if (!found)
{
int fd;
uint32 n_chunks = SIZE_MB_TO_CHUNKS(lfc_max_size);
lfc_lock = (LWLockId) GetNamedLWLockTranche("lfc_lock");
lfc_freelist_lock = (LWLockId) GetNamedLWLockTranche("lfc_freelist_lock");
info.keysize = sizeof(BufferTag);
info.entrysize = FILE_CACHE_ENTRY_SIZE;
info.entrysize = FILE_CACHE_ENRTY_SIZE;
/*
* n_chunks+1 because we add new element to hash table before eviction
@@ -379,12 +345,9 @@ lfc_shmem_startup(void)
n_chunks + 1, n_chunks + 1,
&info,
HASH_ELEM | HASH_BLOBS);
memset(lfc_ctl, 0, offsetof(FileCacheControl, free_pages));
memset(lfc_ctl, 0, sizeof(FileCacheControl));
dlist_init(&lfc_ctl->lru);
lfc_ctl->freelist_size = FREELIST_ENTRIES_PER_CHUNK(lfc_blocks_per_chunk);
lfc_ctl->freelist_head = INVALID_OFFSET;
lfc_ctl->num_free_pages = 0;
dlist_init(&lfc_ctl->holes);
/* Initialize hyper-log-log structure for estimating working set size */
initSHLL(&lfc_ctl->wss_estimation);
@@ -413,20 +376,13 @@ lfc_shmem_startup(void)
static void
lfc_shmem_request(void)
{
size_t size;
#if PG_VERSION_NUM>=150000
if (prev_shmem_request_hook)
prev_shmem_request_hook();
#endif
size = offsetof(FileCacheControl, free_pages);
size += FREELIST_ENTRIES_PER_CHUNK(lfc_blocks_per_chunk) * sizeof(uint32);
size += hash_estimate_size(SIZE_MB_TO_CHUNKS(lfc_max_size) + 1, FILE_CACHE_ENTRY_SIZE);
RequestAddinShmemSpace(size);
RequestAddinShmemSpace(sizeof(FileCacheControl) + hash_estimate_size(SIZE_MB_TO_CHUNKS(lfc_max_size) + 1, FILE_CACHE_ENRTY_SIZE));
RequestNamedLWLockTranche("lfc_lock", 1);
RequestNamedLWLockTranche("lfc_freelist_lock", 2);
}
static bool
@@ -479,14 +435,12 @@ lfc_change_limit_hook(int newval, void *extra)
if (!lfc_ctl || !is_normal_backend())
return;
LWLockAcquire(lfc_freelist_lock, LW_EXCLUSIVE);
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
/* Open LFC file only if LFC was enabled or we are going to reenable it */
if (newval == 0 && !LFC_ENABLED())
{
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
/* File should be reopened if LFC is reenabled */
lfc_close_file();
return;
@@ -495,7 +449,6 @@ lfc_change_limit_hook(int newval, void *extra)
if (!lfc_ensure_opened())
{
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
return;
}
@@ -511,14 +464,18 @@ lfc_change_limit_hook(int newval, void *extra)
* returning their space to file system
*/
FileCacheEntry *victim = dlist_container(FileCacheEntry, list_node, dlist_pop_head_node(&lfc_ctl->lru));
FileCacheEntry *hole;
uint32 offset = victim->offset;
uint32 hash;
bool found;
BufferTag holetag;
CriticalAssert(victim->access_count == 0);
#ifdef FALLOC_FL_PUNCH_HOLE
if (fallocate(lfc_desc, FALLOC_FL_PUNCH_HOLE | FALLOC_FL_KEEP_SIZE, (off_t) victim->offset * lfc_blocks_per_chunk * BLCKSZ, lfc_blocks_per_chunk * BLCKSZ) < 0)
neon_log(LOG, "Failed to punch hole in file: %m");
#endif
/* We remove the entry, and enter a hole to the freelist */
/* We remove the old entry, and re-enter a hole to the hash table */
for (int i = 0; i < lfc_blocks_per_chunk; i++)
{
bool is_page_cached = GET_STATE(victim, i) == AVAILABLE;
@@ -527,14 +484,15 @@ lfc_change_limit_hook(int newval, void *extra)
}
hash_search_with_hash_value(lfc_hash, &victim->key, victim->hash, HASH_REMOVE, NULL);
if (!freelist_push(offset))
{
/* freelist_push already logged the error */
lfc_switch_off();
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
return;
}
memset(&holetag, 0, sizeof(holetag));
holetag.blockNum = offset;
hash = get_hash_value(lfc_hash, &holetag);
hole = hash_search_with_hash_value(lfc_hash, &holetag, hash, HASH_ENTER, &found);
hole->hash = hash;
hole->offset = offset;
hole->access_count = 0;
CriticalAssert(!found);
dlist_push_tail(&lfc_ctl->holes, &hole->list_node);
lfc_ctl->used -= 1;
}
@@ -546,7 +504,6 @@ lfc_change_limit_hook(int newval, void *extra)
neon_log(DEBUG1, "set local file cache limit to %d", new_size);
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
}
void
@@ -1423,7 +1380,7 @@ lfc_init_new_entry(FileCacheEntry* entry, uint32 hash)
* options, in order of preference:
*
* Unless there is no space available, we can:
* 1. Use an entry from the freelist, and
* 1. Use an entry from the `holes` list, and
* 2. Create a new entry.
* We can always, regardless of space in the LFC:
* 3. evict an entry from LRU, and
@@ -1431,10 +1388,17 @@ lfc_init_new_entry(FileCacheEntry* entry, uint32 hash)
*/
if (lfc_ctl->used < lfc_ctl->limit)
{
if (!freelist_is_empty())
if (!dlist_is_empty(&lfc_ctl->holes))
{
/* We can reuse a hole that was left behind when the LFC was shrunk previously */
uint32 offset = freelist_pop();
FileCacheEntry *hole = dlist_container(FileCacheEntry, list_node,
dlist_pop_head_node(&lfc_ctl->holes));
uint32 offset = hole->offset;
bool hole_found;
hash_search_with_hash_value(lfc_hash, &hole->key,
hole->hash, HASH_REMOVE, &hole_found);
CriticalAssert(hole_found);
lfc_ctl->used += 1;
entry->offset = offset; /* reuse the hole */
@@ -1548,7 +1512,6 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
hash = get_hash_value(lfc_hash, &tag);
cv = &lfc_ctl->cv[hash % N_COND_VARS];
retry:
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (!LFC_ENABLED() || !lfc_ensure_opened())
@@ -1557,9 +1520,6 @@ lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
return false;
}
if (!freelist_prepare_pop())
goto retry;
lwlsn = neon_get_lwlsn(rinfo, forknum, blkno);
if (lwlsn > lsn)
@@ -1693,7 +1653,6 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
CriticalAssert(BufTagGetRelNumber(&tag) != InvalidRelFileNumber);
retry:
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (!LFC_ENABLED() || !lfc_ensure_opened())
@@ -1703,9 +1662,6 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
}
generation = lfc_ctl->generation;
if (!freelist_prepare_pop())
goto retry;
/*
* For every chunk that has blocks we're interested in, we
* 1. get the chunk header
@@ -1867,140 +1823,6 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
LWLockRelease(lfc_lock);
}
/**** freelist management ****/
/*
* Prerequisites:
* - The caller is holding 'lfc_lock'. XXX
*/
static bool
freelist_prepare_pop(void)
{
/*
* If the in-memory freelist is empty, but there are more blocks available, load them.
*
* TODO: if there
*/
if (lfc_ctl->num_free_pages == 0 && lfc_ctl->freelist_head != INVALID_OFFSET)
{
uint32 freelist_head;
FreeListChunk *freelist_chunk;
size_t bytes_read;
LWLockRelease(lfc_lock);
LWLockAcquire(lfc_freelist_lock, LW_EXCLUSIVE);
if (!(lfc_ctl->num_free_pages == 0 && lfc_ctl->freelist_head != INVALID_OFFSET))
{
/* someone else did the work for us while we were not holding the lock */
LWLockRelease(lfc_freelist_lock);
return false;
}
freelist_head = lfc_ctl->freelist_head;
freelist_chunk = palloc(lfc_blocks_per_chunk * BLCKSZ);
bytes_read = 0;
while (bytes_read < lfc_blocks_per_chunk * BLCKSZ)
{
ssize_t rc;
rc = pread(lfc_desc, freelist_chunk, lfc_blocks_per_chunk * BLCKSZ - bytes_read, (off_t) freelist_head * lfc_blocks_per_chunk * BLCKSZ + bytes_read);
if (rc < 0)
{
lfc_disable("read freelist page");
return false;
}
bytes_read += rc;
}
LWLockAcquire(lfc_lock, LW_EXCLUSIVE);
if (lfc_generation != lfc_ctl->generation)
{
LWLockRelease(lfc_lock);
return false;
}
Assert(lfc_ctl->freelist_head == freelist_head);
Assert(lfc_ctl->num_free_pages == 0);
lfc_ctl->freelist_head = freelist_chunk->next;
lfc_ctl->num_free_pages = freelist_chunk->num_free_pages;
memcpy(lfc_ctl->free_pages, freelist_chunk->free_pages, lfc_ctl->num_free_pages * sizeof(uint32));
pfree(freelist_chunk);
LWLockRelease(lfc_lock);
LWLockRelease(lfc_freelist_lock);
return false;
}
return true;
}
/*
* Prerequisites:
* - The caller is holding 'lfc_lock' and 'lfc_freelist_lock'.
*
* Returns 'false' on error.
*/
static bool
freelist_push(uint32 offset)
{
Assert(lfc_ctl->freelist_size == FREELIST_ENTRIES_PER_CHUNK(lfc_blocks_per_chunk));
if (lfc_ctl->num_free_pages == lfc_ctl->freelist_size)
{
FreeListChunk *freelist_chunk;
struct iovec iov;
ssize_t rc;
freelist_chunk = palloc(lfc_blocks_per_chunk * BLCKSZ);
/* write the existing entries to the chunk on disk */
freelist_chunk->next = lfc_ctl->freelist_head;
freelist_chunk->num_free_pages = lfc_ctl->num_free_pages;
memcpy(freelist_chunk->free_pages, lfc_ctl->free_pages, lfc_ctl->num_free_pages * sizeof(uint32));
/* Use the passed-in offset to hold the freelist chunk itself */
iov.iov_base = freelist_chunk;
iov.iov_len = lfc_blocks_per_chunk * BLCKSZ;
rc = pg_pwritev_with_retry(lfc_desc, &iov, 1, (off_t) offset * lfc_blocks_per_chunk * BLCKSZ);
pfree(freelist_chunk);
if (rc < 0)
return false;
lfc_ctl->freelist_head = offset;
lfc_ctl->num_free_pages = 0;
}
else
{
lfc_ctl->free_pages[lfc_ctl->num_free_pages] = offset;
lfc_ctl->num_free_pages++;
}
return true;
}
static uint32
freelist_pop(void)
{
uint32 result;
/* The caller should've checked that the list is not empty */
Assert(lfc_ctl->num_free_pages > 0);
result = lfc_ctl->free_pages[lfc_ctl->num_free_pages - 1];
lfc_ctl->num_free_pages--;
return result;
}
static bool
freelist_is_empty(void)
{
return lfc_ctl->num_free_pages == 0;
}
typedef struct
{
TupleDesc tupdesc;
@@ -2227,8 +2049,12 @@ local_cache_pages(PG_FUNCTION_ARGS)
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
for (int i = 0; i < lfc_blocks_per_chunk; i++)
n_pages += GET_STATE(entry, i) == AVAILABLE;
/* Skip hole tags */
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
{
for (int i = 0; i < lfc_blocks_per_chunk; i++)
n_pages += GET_STATE(entry, i) == AVAILABLE;
}
}
}
}
@@ -2256,16 +2082,19 @@ local_cache_pages(PG_FUNCTION_ARGS)
{
for (int i = 0; i < lfc_blocks_per_chunk; i++)
{
if (GET_STATE(entry, i) == AVAILABLE)
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
{
fctx->record[n].pageoffs = entry->offset * lfc_blocks_per_chunk + i;
fctx->record[n].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].forknum = entry->key.forkNum;
fctx->record[n].blocknum = entry->key.blockNum + i;
fctx->record[n].accesscount = entry->access_count;
n += 1;
if (GET_STATE(entry, i) == AVAILABLE)
{
fctx->record[n].pageoffs = entry->offset * lfc_blocks_per_chunk + i;
fctx->record[n].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].forknum = entry->key.forkNum;
fctx->record[n].blocknum = entry->key.blockNum + i;
fctx->record[n].accesscount = entry->access_count;
n += 1;
}
}
}
}

View File

@@ -25,15 +25,19 @@ pub(super) async fn authenticate(
}
AuthSecret::Scram(secret) => {
debug!("auth endpoint chooses SCRAM");
let scram = auth::Scram(&secret, ctx);
let auth_outcome = tokio::time::timeout(
config.scram_protocol_timeout,
AuthFlow::new(client, auth::Scram(&secret, ctx)).authenticate(),
)
let auth_outcome = tokio::time::timeout(config.scram_protocol_timeout, async {
AuthFlow::new(client, scram)
.authenticate()
.await
.inspect_err(|error| {
warn!(?error, "error processing scram messages");
})
})
.await
.inspect_err(|_| warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs()))
.map_err(auth::AuthError::user_timeout)?
.inspect_err(|error| warn!(?error, "error processing scram messages"))?;
.map_err(auth::AuthError::user_timeout)??;
let client_key = match auth_outcome {
sasl::Outcome::Success(key) => key,

View File

@@ -159,7 +159,7 @@ pub async fn task_main(
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
backend: &'static ConsoleRedirectBackend,
ctx: &RequestContext,

View File

@@ -7,9 +7,7 @@ use std::time::Duration;
use ::http::HeaderName;
use ::http::header::AUTHORIZATION;
use bytes::Bytes;
use futures::TryFutureExt;
use hyper::StatusCode;
use postgres_client::config::SslMode;
use tokio::time::Instant;
use tracing::{Instrument, debug, info, info_span, warn};
@@ -74,34 +72,28 @@ impl NeonControlPlaneClient {
role: &RoleName,
) -> Result<AuthInfo, GetAuthInfoError> {
async {
let request = self
.endpoint
.get_path("get_endpoint_access_control")
.header(X_REQUEST_ID, ctx.session_id().to_string())
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
.query(&[
("application_name", ctx.console_application_name().as_str()),
("endpointish", endpoint.as_str()),
("role", role.as_str()),
])
.build()?;
debug!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
let response = {
let request = self
.endpoint
.get_path("get_endpoint_access_control")
.header(X_REQUEST_ID, ctx.session_id().to_string())
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
.query(&[
("application_name", ctx.console_application_name().as_str()),
("endpointish", endpoint.as_str()),
("role", role.as_str()),
])
.build()?;
debug!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane);
let response = self.endpoint.execute(request).await?;
info!(duration = ?start.elapsed(), "received http response");
response
self.endpoint.execute(request).await?
};
info!(duration = ?start.elapsed(), "received http response");
let body = match parse_body::<GetEndpointAccessControl>(
response.status(),
response.bytes().await?,
) {
let body = match parse_body::<GetEndpointAccessControl>(response).await {
Ok(body) => body,
// Error 404 is special: it's ok not to have a secret.
// TODO(anna): retry
@@ -192,10 +184,7 @@ impl NeonControlPlaneClient {
drop(pause);
info!(duration = ?start.elapsed(), "received http response");
let body = parse_body::<EndpointJwksResponse>(
response.status(),
response.bytes().await.map_err(ControlPlaneError::from)?,
)?;
let body = parse_body::<EndpointJwksResponse>(response).await?;
let rules = body
.jwks
@@ -247,7 +236,7 @@ impl NeonControlPlaneClient {
let response = self.endpoint.execute(request).await?;
drop(pause);
info!(duration = ?start.elapsed(), "received http response");
let body = parse_body::<WakeCompute>(response.status(), response.bytes().await?)?;
let body = parse_body::<WakeCompute>(response).await?;
// Unfortunately, ownership won't let us use `Option::ok_or` here.
let (host, port) = match parse_host_port(&body.address) {
@@ -498,33 +487,33 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
}
/// Parse http response body, taking status code into account.
fn parse_body<T: for<'a> serde::Deserialize<'a>>(
status: StatusCode,
body: Bytes,
async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
response: http::Response,
) -> Result<T, ControlPlaneError> {
let status = response.status();
if status.is_success() {
// We shouldn't log raw body because it may contain secrets.
info!("request succeeded, processing the body");
return Ok(serde_json::from_slice(&body).map_err(std::io::Error::other)?);
return Ok(response.json().await?);
}
let s = response.bytes().await?;
// Log plaintext to be able to detect, whether there are some cases not covered by the error struct.
info!("response_error plaintext: {:?}", body);
info!("response_error plaintext: {:?}", s);
// Don't throw an error here because it's not as important
// as the fact that the request itself has failed.
let mut body = serde_json::from_slice(&body).unwrap_or_else(|e| {
let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| {
warn!("failed to parse error body: {e}");
Box::new(ControlPlaneErrorMessage {
ControlPlaneErrorMessage {
error: "reason unclear (malformed error message)".into(),
http_status_code: status,
status: None,
})
}
});
body.http_status_code = status;
warn!("console responded with an error ({status}): {body:?}");
Err(ControlPlaneError::Message(body))
Err(ControlPlaneError::Message(Box::new(body)))
}
fn parse_host_port(input: &str) -> Option<(&str, u16)> {

View File

@@ -4,10 +4,9 @@
pub mod health_server;
use std::time::{Duration, Instant};
use std::time::Duration;
use bytes::Bytes;
use futures::FutureExt;
use http::Method;
use http_body_util::BodyExt;
use hyper::body::Body;
@@ -110,31 +109,15 @@ impl Endpoint {
}
/// Execute a [request](reqwest::Request).
pub(crate) fn execute(
&self,
request: Request,
) -> impl Future<Output = Result<Response, Error>> {
let metric = Metrics::get()
pub(crate) async fn execute(&self, request: Request) -> Result<Response, Error> {
let _timer = Metrics::get()
.proxy
.console_request_latency
.with_labels(ConsoleRequest {
.start_timer(ConsoleRequest {
request: request.url().path(),
});
let req = self.client.execute(request).boxed();
async move {
let start = Instant::now();
scopeguard::defer!({
Metrics::get()
.proxy
.console_request_latency
.get_metric(metric)
.observe_duration_since(start);
});
req.await
}
self.client.execute(request).await
}
}

View File

@@ -186,7 +186,7 @@ where
pub async fn read_message<'a, S>(
stream: &mut S,
buf: &'a mut Vec<u8>,
max: u32,
max: usize,
) -> io::Result<(u8, &'a mut [u8])>
where
S: AsyncRead + Unpin,
@@ -206,7 +206,7 @@ where
let header = read!(stream => Header);
// as described above, the length must be at least 4.
let Some(len) = header.len.get().checked_sub(4) else {
let Some(len) = (header.len.get() as usize).checked_sub(4) else {
return Err(io::Error::other(format!(
"invalid startup message length {}, must be at least 4.",
header.len,
@@ -222,7 +222,7 @@ where
}
// read in our entire message.
buf.resize(len as usize, 0);
buf.resize(len, 0);
stream.read_exact(buf).await?;
Ok((header.tag, buf))

View File

@@ -1,4 +1,3 @@
use futures::{FutureExt, TryFutureExt};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info, warn};
@@ -58,7 +57,7 @@ pub(crate) enum HandshakeData<S> {
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
/// we also take an extra care of propagating only the select handshake errors to client.
#[tracing::instrument(skip_all)]
pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin + Send>(
pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
ctx: &RequestContext,
stream: S,
mut tls: Option<&TlsConfig>,
@@ -109,9 +108,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin + Send>(
}
}
}
})
.map_ok(Box::new)
.boxed();
});
res?;
@@ -149,7 +146,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin + Send>(
tls.cert_resolver.resolve(conn_info.server_name());
let tls = Stream::Tls {
tls: tls_stream,
tls: Box::new(tls_stream),
tls_server_end_point,
};
(stream, msg) = PqStream::parse_startup(tls).await?;

View File

@@ -270,7 +270,7 @@ impl ReportableError for ClientRequestError {
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
ctx: &RequestContext,

View File

@@ -1,4 +1,3 @@
use futures::FutureExt;
use smol_str::SmolStr;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::debug;
@@ -90,7 +89,6 @@ impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
.compute
.cancel_closure
.try_cancel_query(compute_config)
.boxed()
.await
{
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");

View File

@@ -30,53 +30,52 @@ where
F: FnOnce(&str) -> super::Result<M>,
M: Mechanism,
{
let (mut mechanism, mut input) = {
let sasl = {
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// Initial client message contains the chosen auth method's name.
let msg = stream.read_password_message().await?;
let sasl = super::FirstMessage::parse(msg)
.ok_or(super::Error::BadClientMessage("bad sasl message"))?;
(mechanism(sasl.method)?, sasl.message)
super::FirstMessage::parse(msg).ok_or(super::Error::BadClientMessage("bad sasl message"))?
};
let mut mechanism = mechanism(sasl.method)?;
let mut input = sasl.message;
loop {
match mechanism.exchange(input) {
Ok(Step::Continue(moved_mechanism, reply)) => {
let step = mechanism
.exchange(input)
.inspect_err(|error| tracing::info!(?error, "error during SASL exchange"))?;
match step {
Step::Continue(moved_mechanism, reply) => {
mechanism = moved_mechanism;
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// write reply
let sasl_msg = BeAuthenticationSaslMessage::Continue(reply.as_bytes());
stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
drop(reply);
// get next input
stream.flush().await?;
let msg = stream.read_password_message().await?;
input = std::str::from_utf8(msg)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
}
Ok(Step::Success(result, reply)) => {
Step::Success(result, reply) => {
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// write reply
let sasl_msg = BeAuthenticationSaslMessage::Final(reply.as_bytes());
stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
stream.write_message(BeMessage::AuthenticationOk);
// exit with success
break Ok(Outcome::Success(result));
}
// exit with failure
Ok(Step::Failure(reason)) => break Ok(Outcome::Failure(reason)),
Err(error) => {
tracing::info!(?error, "error during SASL exchange");
return Err(error);
}
Step::Failure(reason) => break Ok(Outcome::Failure(reason)),
}
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// get next input
stream.flush().await?;
let msg = stream.read_password_message().await?;
input = std::str::from_utf8(msg)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
}
}

View File

@@ -72,7 +72,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> PqStream<S> {
impl<S: AsyncRead + Unpin> PqStream<S> {
/// Read a raw postgres packet, which will respect the max length requested.
/// This is not cancel safe.
async fn read_raw_expect(&mut self, tag: u8, max: u32) -> io::Result<&mut [u8]> {
async fn read_raw_expect(&mut self, tag: u8, max: usize) -> io::Result<&mut [u8]> {
let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
if actual_tag != tag {
return Err(io::Error::other(format!(
@@ -89,7 +89,7 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
// passwords are usually pretty short
// and SASL SCRAM messages are no longer than 256 bytes in my testing
// (a few hashes and random bytes, encoded into base64).
const MAX_PASSWORD_LENGTH: u32 = 512;
const MAX_PASSWORD_LENGTH: usize = 512;
self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
.await
}

View File

@@ -31,9 +31,7 @@ mod private {
type Output = io::Result<RustlsStream<S>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.inner)
.poll(cx)
.map_ok(|s| RustlsStream(Box::new(s)))
Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream)
}
}
@@ -59,7 +57,7 @@ mod private {
}
}
pub struct RustlsStream<S>(Box<TlsStream<S>>);
pub struct RustlsStream<S>(TlsStream<S>);
impl<S> postgres_client::tls::TlsStream for RustlsStream<S>
where

26
shortener/Cargo.toml Normal file
View File

@@ -0,0 +1,26 @@
[package]
name = "shortener"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
anyhow.workspace = true
axum-extra = { workspace = true, features = ["cookie", "cookie-private"] }
axum.workspace = true
base64.workspace = true
chrono.workspace = true
cookie = "0.18.1"
nanoid = { version = "0.4.0", default-features = false }
rand.workspace = true
reqwest.workspace = true
rustls-native-certs.workspace = true
rustls.workspace = true
serde.workspace = true
serde_json.workspace = true
time = { version = "0.3.36", default-features = false }
tokio-postgres-rustls.workspace = true
tokio-postgres.workspace = true
tokio.workspace = true
tracing-subscriber.workspace = true
tracing.workspace = true
workspace_hack.workspace = true

View File

@@ -0,0 +1,19 @@
CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
sub VARCHAR(100) NOT NULL UNIQUE
);
CREATE TABLE IF NOT EXISTS sessions (
id SERIAL PRIMARY KEY,
user_id INT NOT NULL UNIQUE REFERENCES users(id),
session_id VARCHAR NOT NULL,
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
);
CREATE TABLE IF NOT EXISTS urls (
id SERIAL PRIMARY KEY,
user_id INT NOT NULL REFERENCES users(id),
short_url VARCHAR(6) NOT NULL UNIQUE,
long_url VARCHAR NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
)

View File

@@ -0,0 +1,222 @@
//! Library to gate infrastructure behind Google Oauth for domain.
//!
//! Why not oauth-rs? Oauth .exchange_code() doesn't work with "request failed". Also, we can't get
//! id token from it, and I don't want to pull in whole openid library just for that.
//! Id token saves us a request to openid endpoint and one Oauth scope we don't use
use anyhow::{Context, Result, bail};
use axum::extract::{FromRef, FromRequestParts, Query, State as AxumState};
use axum::response::Redirect;
use axum_extra::extract::PrivateCookieJar;
use axum_extra::extract::cookie::{Cookie, Key};
use chrono::{Duration, Local, TimeZone, Utc};
use cookie::CookieBuilder;
use core::num::NonZeroI32;
use reqwest::StatusCode;
use serde::Deserialize;
use std::sync::Arc;
use tokio_postgres::Socket;
const OAUTH_BASE_URL: &str = "https://accounts.google.com/o/oauth2/v2/auth";
const OAUTH_TOKEN_URL: &str = "https://oauth2.googleapis.com/token";
const COOKIE_SID: &str = "sid";
const COOKIE_CSRF: &str = "csrf";
pub struct Config {
pub oauth_client_id: String,
pub oauth_client_secret: String,
pub oauth_redirect_url: String,
pub oauth_allowed_domain: String,
pub cookie_settings: fn(CookieBuilder) -> CookieBuilder,
}
pub struct InnerState {
config: Config,
cookie_jar_key: Key,
pub db_client: tokio_postgres::Client,
}
#[derive(Clone)]
pub struct State(Arc<InnerState>);
type DbConn = tokio_postgres::Connection<Socket, tokio_postgres_rustls::RustlsStream<Socket>>;
impl State {
pub async fn new(config: Config, db_connstr: &str) -> Result<(Self, DbConn)> {
let mut roots = rustls::RootCertStore::empty();
for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs")
{
roots.add(cert).unwrap();
}
let tls_config = rustls::ClientConfig::builder()
.with_root_certificates(roots)
.with_no_client_auth();
let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);
let (db_client, db_conn) = tokio_postgres::connect(&db_connstr, tls).await?;
let inner = InnerState {
config,
cookie_jar_key: Key::generate(),
db_client,
};
Ok((Self { 0: Arc::new(inner) }, db_conn))
}
}
impl std::ops::Deref for State {
type Target = InnerState;
fn deref(&self) -> &Self::Target {
&*self.0
}
}
impl FromRef<State> for Key {
fn from_ref(state: &State) -> Self {
state.cookie_jar_key.clone()
}
}
#[derive(Deserialize)]
pub struct UserId {
pub id: NonZeroI32,
}
#[derive(Deserialize)]
pub struct AuthRequest {
code: String,
}
#[derive(Deserialize)]
struct AuthResponse {
access_token: String,
id_token: String,
expires_in: u64,
}
#[derive(Deserialize)]
struct UserInfo {
hd: String,
sub: String,
}
impl axum::extract::OptionalFromRequestParts<State> for UserId {
type Rejection = StatusCode;
async fn from_request_parts(
parts: &mut axum::http::request::Parts,
state: &State,
) -> Result<Option<Self>, Self::Rejection> {
let jar: PrivateCookieJar = PrivateCookieJar::from_request_parts(parts, state)
.await
.unwrap(); // infallible
let Some(session_id) = jar.get(COOKIE_SID).map(|cookie| cookie.value().to_owned()) else {
return Ok(None);
};
let client = &state.db_client;
let query = client
.query_opt(
"SELECT user_id FROM sessions WHERE session_id = $1",
&[&session_id],
)
.await;
let id = match query {
Ok(Some(row)) => row.get::<usize, i32>(0),
Ok(None) => return Ok(None),
Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
};
let id = NonZeroI32::new(id).unwrap(); // postgres id guaranteed not to be zero
Ok(Some(Self { id }))
}
}
fn decode_id_token(token: String) -> Option<UserInfo> {
let payload = token.split(".").skip(1).take(1).collect::<Vec<&str>>();
let decoded = base64::decode_config(payload.get(0)?, base64::STANDARD_NO_PAD).ok()?;
serde_json::from_slice::<UserInfo>(&decoded).ok()
}
fn generate_csrf_token(num_bytes: u32) -> String {
use rand::{Rng, thread_rng};
let random_bytes: Vec<u8> = (0..num_bytes).map(|_| thread_rng().r#gen::<u8>()).collect();
base64::encode_config(&random_bytes, base64::URL_SAFE_NO_PAD)
}
pub async fn authorize(
state: AxumState<State>,
jar: PrivateCookieJar,
) -> (PrivateCookieJar, Redirect) {
let csrf_token = generate_csrf_token(16);
let client_id = &state.config.oauth_client_id;
let redirect_uri = &state.config.oauth_redirect_url;
let auth_url = format!(
"{OAUTH_BASE_URL}?response_type=code\
&client_id={client_id}\
&state={csrf_token}\
&redirect_uri={redirect_uri}\
&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email"
);
let csrf_cookie =
(state.config.cookie_settings)(Cookie::build((COOKIE_CSRF, csrf_token))).build();
let url = Into::<String>::into(auth_url);
(jar.add(csrf_cookie), Redirect::to(&url))
}
pub async fn authorized(
state: AxumState<State>,
jar: PrivateCookieJar,
Query(auth_request): Query<AuthRequest>,
) -> Result<PrivateCookieJar> {
let params = [
("grant_type", "authorization_code"),
("redirect_uri", &state.config.oauth_redirect_url),
("code", &auth_request.code),
("client_id", &state.config.oauth_client_id),
("client_secret", &state.config.oauth_client_secret),
];
let auth_response = reqwest::Client::new()
.post(OAUTH_TOKEN_URL)
.form(&params)
.send()
.await
.context("exchanging oauth code for token")?
.json::<AuthResponse>()
.await
.context("deserializing access_token response")?;
let Some(UserInfo { hd, sub }) = decode_id_token(auth_response.id_token) else {
bail!("failed to decode id token")
};
let allowed_domain = &state.config.oauth_allowed_domain;
if hd != *allowed_domain {
bail!("{hd} doesn't match {allowed_domain}")
}
let token_duration = Duration::try_seconds(auth_response.expires_in as i64).unwrap();
let expires_at = Utc.from_utc_datetime(&(Local::now().naive_local() + token_duration));
let cookie_max_age = time::Duration::new(token_duration.num_seconds(), 0);
let session_cookie = (state.config.cookie_settings)(Cookie::build((
COOKIE_SID,
auth_response.access_token.clone(),
)))
.max_age(cookie_max_age)
.build();
state
.db_client
.query(
"WITH user_insert AS (\
INSERT INTO users (sub) VALUES ($1) \
ON CONFLICT (sub) DO UPDATE SET sub = excluded.sub RETURNING id)\
INSERT INTO sessions (user_id, session_id, expires_at) \
SELECT id, $2, $3 FROM user_insert \
ON CONFLICT (user_id) DO UPDATE SET \
session_id = excluded.session_id, \
expires_at = excluded.expires_at",
&[&sub, &auth_response.access_token, &expires_at],
)
.await
.with_context(|| format!("updating session for {sub}"))?;
let csrf_cookie = jar.get(COOKIE_CSRF).unwrap(); // set in authorize()
Ok(jar.remove(csrf_cookie).add(session_cookie))
}

240
shortener/src/main.rs Normal file
View File

@@ -0,0 +1,240 @@
//! Shortener is a service to gate access to internal infrastructure
//! URLs behind team authorisation to expose less private information.
pub mod google_oauth_gate;
use crate::google_oauth_gate::{AuthRequest, State, UserId};
use anyhow::Result;
use axum::Form;
use axum::extract::State as AxumState;
use axum::extract::{Path, Query};
use axum::http::StatusCode;
use axum::response::{Html, IntoResponse};
use axum::response::{Redirect, Response};
use axum::routing::get;
use axum_extra::extract::PrivateCookieJar;
use axum_extra::extract::cookie::Cookie;
use cookie::CookieBuilder;
use google_oauth_gate::Config;
use serde::Deserialize;
use std::env;
use tracing::{error, info};
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
const SOCKET: &str = "127.0.0.1:12344";
const HOST: &str = "http://127.0.0.1:12344";
const COOKIE_REDIRECT: &str = "redirect";
const ALLOWED_OAUTH_DOMAIN: &str = "neon.tech";
const AUTHORIZED_ROUTE: &str = "/authorized";
const SHORT_URL_LEN: usize = 6;
fn cookie_settings(b: CookieBuilder) -> CookieBuilder {
if HOST.contains("127.0.0.1") {
b.path("/")
.secure(false)
.same_site(axum_extra::extract::cookie::SameSite::Lax)
.http_only(true)
} else {
b.path("/")
.domain(ALLOWED_OAUTH_DOMAIN)
.secure(true)
.http_only(false)
}
}
fn oauth_redirect_url() -> String {
format!("{HOST}{AUTHORIZED_ROUTE}")
}
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| format!("{}=info", env!("CARGO_CRATE_NAME")).into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
let oauth_client_id = env::var("OAUTH_CLIENT_ID").expect("Missing OAUTH_CLIENT_ID");
let oauth_client_secret = env::var("OAUTH_CLIENT_SECRET").expect("Missing OAUTH_CLIENT_SECRET");
let db_connstr = env::var("DB_CONNSTR").expect("Missing DB_CONNSTR");
let config = Config {
oauth_client_id,
oauth_client_secret,
oauth_redirect_url: oauth_redirect_url(),
oauth_allowed_domain: ALLOWED_OAUTH_DOMAIN.to_string(),
cookie_settings,
};
let (state, db_conn) = State::new(config, &db_connstr).await?;
tokio::spawn(async move {
if let Err(err) = db_conn.await {
error!(%err, "connecting to database");
std::process::exit(1);
}
});
let router = axum::Router::new()
.route("/", get(index).post(shorten))
.route("/authorize", get(authorize))
.route(AUTHORIZED_ROUTE, get(authorized))
.route("/{short_url}", get(redirect))
.with_state(state);
let listener = tokio::net::TcpListener::bind(SOCKET)
.await
.expect("failed to bind TcpListener");
info!("listening on {SOCKET}");
axum::serve(listener, router).await.unwrap();
Ok(())
}
#[derive(Deserialize)]
struct LongUrl {
url: String,
}
fn shorten_form(short_url: &str) -> Html<String> {
let mut form = r#"
<div style="margin:auto;width:50%;padding:10px">
<form method="post">
<input type="text" name="url" style="width:100%">
<input type="submit" value="Shorten" style="margin-top:10px">
</form>"#
.to_string();
if !short_url.is_empty() {
form += &format!(
r#"
<p>
<a id="short" href="{0}">{0}</a>
<button onclick="copy()">Copy</button>
</p>
<script>
function copy() {{
navigator.clipboard.writeText(document.querySelector("\#short").textContent);
}}
</script>"#,
short_url
);
}
form += "</div>";
Html(form)
}
fn authorize_link(short_url: &str) -> Html<String> {
Html(format!(
"<a href=\"/authorize?short_url={short_url}\">Authorize</a>"
))
}
async fn index(user: Option<UserId>) -> Html<String> {
if user.is_some() {
return shorten_form("");
}
authorize_link("")
}
async fn shorten(
state: AxumState<State>,
user: Option<UserId>,
Form(LongUrl { url }): Form<LongUrl>,
) -> Response {
let user_id = match user {
None => return StatusCode::FORBIDDEN.into_response(),
Some(user) => user.id.get(),
};
if url.is_empty() {
return shorten_form("").into_response();
}
let mut short_url = "".to_string();
for i in 0..20 {
short_url = nanoid::nanoid!(SHORT_URL_LEN);
let query = state
.db_client
.query_opt(
"INSERT INTO urls (user_id, short_url, long_url) VALUES ($1, $2, $3) \
ON CONFLICT (short_url) DO NOTHING \
RETURNING short_url",
&[&user_id, &short_url, &url],
)
.await;
match query {
Ok(Some(_)) => break,
Ok(None) => {
info!(short_url, "url clash, retry {i}");
continue;
}
Err(err) => {
error!(%err, "inserting shortened url");
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
};
}
shorten_form(&format!("{HOST}/{short_url}")).into_response()
}
async fn redirect(
state: AxumState<State>,
user: Option<UserId>,
Path(short_url): Path<String>,
) -> Response {
let user_id = match user {
None => return authorize_link(&short_url).into_response(),
Some(user) => user.id,
};
let query = state
.db_client
.query_opt(
"SELECT long_url FROM urls WHERE short_url = $1",
&[&short_url],
)
.await;
match query {
Ok(Some(row)) => Redirect::permanent(row.get(0)).into_response(),
Ok(None) => StatusCode::NOT_FOUND.into_response(),
Err(err) => {
error!(%err, %short_url, %user_id, "querying long url");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
}
#[derive(Deserialize)]
struct AuthorizeQuery {
short_url: String,
}
async fn authorize(
state: AxumState<State>,
jar: PrivateCookieJar,
Query(AuthorizeQuery { short_url }): Query<AuthorizeQuery>,
) -> (PrivateCookieJar, Redirect) {
let (jar, auth_redirect) = google_oauth_gate::authorize(state, jar).await;
let redirect_cookie = Cookie::build((COOKIE_REDIRECT, short_url))
.path("/")
//.TODO secure(true) not true for localhost
//.domain(COOKIE_DOMAIN)
.secure(false)
.same_site(axum_extra::extract::cookie::SameSite::Lax)
.http_only(true)
.build();
(jar.add(redirect_cookie), auth_redirect)
}
async fn authorized(
state: AxumState<State>,
jar: PrivateCookieJar,
query: Query<AuthRequest>,
) -> Result<(PrivateCookieJar, Redirect), Response> {
use google_oauth_gate::authorized;
let jar = authorized(state, jar, query).await.map_err(|err| {
error!(%err, "authorizing");
return StatusCode::UNAUTHORIZED.into_response();
})?;
let Some(redirect_cookie) = jar.get(COOKIE_REDIRECT) else {
return Ok((jar, Redirect::to("/")));
};
let redirect_url = Redirect::to(&format!("/{}", redirect_cookie.value_trimmed()));
Ok((jar.remove(redirect_cookie), redirect_url))
}

View File

@@ -20,6 +20,7 @@ anstream = { version = "0.6" }
anyhow = { version = "1", features = ["backtrace"] }
axum = { version = "0.8", features = ["ws"] }
axum-core = { version = "0.5", default-features = false, features = ["tracing"] }
axum-extra = { version = "0.10", features = ["cookie-private", "typed-header"] }
base64-594e8ee84c453af0 = { package = "base64", version = "0.13", features = ["alloc"] }
base64-647d43efb71741da = { package = "base64", version = "0.21" }
base64ct = { version = "1", default-features = false, features = ["std"] }
@@ -30,6 +31,7 @@ clap = { version = "4", features = ["derive", "env", "string"] }
clap_builder = { version = "4", default-features = false, features = ["color", "env", "help", "std", "string", "suggestions", "usage"] }
const-oid = { version = "0.9", default-features = false, features = ["db", "std"] }
crypto-bigint = { version = "0.5", features = ["generic-array", "zeroize"] }
crypto-common = { version = "0.1", default-features = false, features = ["getrandom", "std"] }
der = { version = "0.7", default-features = false, features = ["derive", "flagset", "oid", "pem", "std"] }
deranged = { version = "0.3", default-features = false, features = ["powerfmt", "serde", "std"] }
digest = { version = "0.10", features = ["mac", "oid", "std"] }