Compare commits

..

12 Commits

Author SHA1 Message Date
Konstantin Knizhnik
de649f856c Fix documentation format issues 2024-04-13 22:37:39 +03:00
Konstantin Knizhnik
de3fdf9860 Add more comments 2024-04-13 21:47:01 +03:00
Konstantin Knizhnik
1b2cfc0259 Provide comment for NeonRequest struct 2024-04-11 17:24:39 +03:00
Konstantin Knizhnik
165a1d7bf1 Make ruff happy 2024-04-11 09:15:35 +03:00
Konstantin Knizhnik
f07c33186a Add neon.protocol_version GUC 2024-04-11 09:15:35 +03:00
Konstantin Knizhnik
15c0e1351a Fix message tags in PS serialize 2024-04-11 09:15:35 +03:00
Konstantin Knizhnik
ccbf95e9dc Use tags starting from 10 for commands of the new protocol 2024-04-11 09:15:34 +03:00
Konstantin Knizhnik
93e6046005 Send LSN range in getpage request 2024-04-11 09:15:31 +03:00
Konstantin Knizhnik
d47e4a2a41 Remember last written LSN when it is first requested (#7343)
## Problem

See https://neondb.slack.com/archives/C03QLRH7PPD/p1712529369520409

In the case of statements like CREATE TABLE AS SELECT... or INSERT FROM SELECT...
we fetch data from a source table and store it in a destination
table. This causes problems with prefetch when the last-written LSN is not known for the
pages of the source table
(which, for example, happens after a compute restart). In this case we get
the global value of the last-written LSN, which changes frequently while
we are writing pages of the destination table. As a result, the request LSN
for the prefetch and the request LSN when this page is actually needed are
different, and we get an expired prefetch request. So it effectively disarms
prefetch.


## Summary of changes

The proposed simple patch stores the last-written LSN for the page when it is
not found. So the next time we request the last-written LSN for this page,
we will get the same value (provided the page was not changed).
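
As a sketch of this behavior (hypothetical `LastWrittenLsnCache` and `PageKey` names, not the actual pageserver types):

```rust
use std::collections::HashMap;

/// Hypothetical page key (relation id + block number), for illustration only.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
struct PageKey {
    rel: u32,
    blkno: u32,
}

struct LastWrittenLsnCache {
    per_page: HashMap<PageKey, u64>,
    global: u64, // advances frequently while the destination table is written
}

impl LastWrittenLsnCache {
    /// Before the patch, a cache miss returned the moving global LSN, so the
    /// prefetch-time and use-time lookups could disagree. With the patch, the
    /// first miss pins the current global value for that page, so repeated
    /// lookups return the same LSN until the page is written again.
    fn get_last_written_lsn(&mut self, key: PageKey) -> u64 {
        let global = self.global;
        *self.per_page.entry(key).or_insert(global)
    }
}
```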

## Checklist before requesting a review

- [ ] I have performed a self-review of my code.
- [ ] If it is a core feature, I have added thorough tests.
- [ ] Do we need to implement analytics? if so did you add the relevant
metrics to the dashboard?
- [ ] If this PR requires public announcement, mark it with
/release-notes label and add several sentences in this section.

## Checklist before merging

- [ ] Do not forget to reformat commit message to not include the above
checklist

---------

Co-authored-by: Konstantin Knizhnik <knizhnik@neon.tech>
2024-04-11 07:47:45 +03:00
Em Sharnoff
f86845f64b compute_ctl: Auto-set dynamic_shared_memory_type (#7348)
Part of neondatabase/cloud#12047.

The basic idea is that for our VMs, we want to enable swap and disable
Linux memory overcommit. Alongside these, we should set postgres'
dynamic_shared_memory_type to mmap, but we want to avoid setting it to
mmap if swap is not enabled.

Implementing this in the control plane would be fiddly, but it's
relatively straightforward to add to compute_ctl.
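
As an illustration of that check (a simplified standalone sketch; the actual change to `write_postgres_conf` is in the diff below):

```rust
use std::fs;

/// Returns the extra postgresql.conf line to append, if any.
/// `/proc/sys/vm/overcommit_memory == 2` means Linux memory overcommit is
/// disabled, which in our VMs implies the control plane has enabled swap.
fn dynamic_shared_memory_setting() -> Option<&'static str> {
    if !cfg!(target_os = "linux") {
        return None;
    }
    // Ignore read errors - the file may legitimately be unreadable or absent.
    let contents = fs::read_to_string("/proc/sys/vm/overcommit_memory").unwrap_or_default();
    (contents.trim() == "2").then_some("dynamic_shared_memory_type = 'mmap'\n")
}
```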
2024-04-10 13:13:48 +00:00
Anna Khanova
0bb04ebe19 Revert "Proxy read ids from redis (#7205)" (#7350)
This reverts commit dbac2d2c47.

## Problem

Proxy pods fail to install in k8s clusters, blocking the cplane release.

## Summary of changes

Revert
2024-04-10 10:12:55 +00:00
Anna Khanova
5efe95a008 proxy: fix credentials cache lookup (#7349)
## Problem

Incorrect processing of `-pooler` connections.

## Summary of changes

Fix

TODO: add e2e tests for caching
2024-04-10 08:30:09 +00:00
33 changed files with 428 additions and 605 deletions

View File

@@ -6,8 +6,8 @@ use std::path::Path;
use anyhow::Result;
use crate::pg_helpers::escape_conf_value;
use crate::pg_helpers::PgOptionsSerialize;
use compute_api::spec::{ComputeMode, ComputeSpec};
use crate::pg_helpers::{GenericOptionExt, PgOptionsSerialize};
use compute_api::spec::{ComputeMode, ComputeSpec, GenericOption};
/// Check that `line` is inside a text file and put it there if it is not.
/// Create file if it doesn't exist.
@@ -92,6 +92,27 @@ pub fn write_postgres_conf(
}
}
if cfg!(target_os = "linux") {
// Check /proc/sys/vm/overcommit_memory -- if it equals 2 (i.e. Linux memory overcommit is
// disabled), then the control plane has enabled swap and we should set
// dynamic_shared_memory_type = 'mmap'.
//
// This is (maybe?) temporary - for more, see https://github.com/neondatabase/cloud/issues/12047.
let overcommit_memory_contents = std::fs::read_to_string("/proc/sys/vm/overcommit_memory")
// ignore any errors - they may be expected to occur under certain situations (e.g. when
// not running in Linux).
.unwrap_or_else(|_| String::new());
if overcommit_memory_contents.trim() == "2" {
let opt = GenericOption {
name: "dynamic_shared_memory_type".to_owned(),
value: Some("mmap".to_owned()),
vartype: "enum".to_owned(),
};
write!(file, "{}", opt.to_pg_setting())?;
}
}
// If there are any extra options in the 'settings' field, append those
if spec.cluster.settings.is_some() {
writeln!(file, "# Managed by compute_ctl: begin")?;

View File

@@ -44,7 +44,7 @@ pub fn escape_conf_value(s: &str) -> String {
format!("'{}'", res)
}
trait GenericOptionExt {
pub trait GenericOptionExt {
fn to_pg_option(&self) -> String;
fn to_pg_setting(&self) -> String;
}

View File

@@ -841,21 +841,21 @@ impl TryFrom<u8> for PagestreamBeMessageTag {
#[derive(Debug, PartialEq, Eq)]
pub struct PagestreamExistsRequest {
pub latest: bool,
pub horizon: Lsn,
pub lsn: Lsn,
pub rel: RelTag,
}
#[derive(Debug, PartialEq, Eq)]
pub struct PagestreamNblocksRequest {
pub latest: bool,
pub horizon: Lsn,
pub lsn: Lsn,
pub rel: RelTag,
}
#[derive(Debug, PartialEq, Eq)]
pub struct PagestreamGetPageRequest {
pub latest: bool,
pub horizon: Lsn,
pub lsn: Lsn,
pub rel: RelTag,
pub blkno: u32,
@@ -863,14 +863,14 @@ pub struct PagestreamGetPageRequest {
#[derive(Debug, PartialEq, Eq)]
pub struct PagestreamDbSizeRequest {
pub latest: bool,
pub horizon: Lsn,
pub lsn: Lsn,
pub dbnode: u32,
}
#[derive(Debug, PartialEq, Eq)]
pub struct PagestreamGetSlruSegmentRequest {
pub latest: bool,
pub horizon: Lsn,
pub lsn: Lsn,
pub kind: u8,
pub segno: u32,
@@ -923,8 +923,8 @@ impl PagestreamFeMessage {
match self {
Self::Exists(req) => {
bytes.put_u8(0);
bytes.put_u8(u8::from(req.latest));
bytes.put_u8(10);
bytes.put_u64(req.horizon.0);
bytes.put_u64(req.lsn.0);
bytes.put_u32(req.rel.spcnode);
bytes.put_u32(req.rel.dbnode);
@@ -933,8 +933,8 @@ impl PagestreamFeMessage {
}
Self::Nblocks(req) => {
bytes.put_u8(1);
bytes.put_u8(u8::from(req.latest));
bytes.put_u8(11);
bytes.put_u64(req.horizon.0);
bytes.put_u64(req.lsn.0);
bytes.put_u32(req.rel.spcnode);
bytes.put_u32(req.rel.dbnode);
@@ -943,8 +943,8 @@ impl PagestreamFeMessage {
}
Self::GetPage(req) => {
bytes.put_u8(2);
bytes.put_u8(u8::from(req.latest));
bytes.put_u8(12);
bytes.put_u64(req.horizon.0);
bytes.put_u64(req.lsn.0);
bytes.put_u32(req.rel.spcnode);
bytes.put_u32(req.rel.dbnode);
@@ -954,15 +954,15 @@ impl PagestreamFeMessage {
}
Self::DbSize(req) => {
bytes.put_u8(3);
bytes.put_u8(u8::from(req.latest));
bytes.put_u8(13);
bytes.put_u64(req.horizon.0);
bytes.put_u64(req.lsn.0);
bytes.put_u32(req.dbnode);
}
Self::GetSlruSegment(req) => {
bytes.put_u8(4);
bytes.put_u8(u8::from(req.latest));
bytes.put_u8(14);
bytes.put_u64(req.horizon.0);
bytes.put_u64(req.lsn.0);
bytes.put_u8(req.kind);
bytes.put_u32(req.segno);
@@ -979,11 +979,32 @@ impl PagestreamFeMessage {
//
// TODO: consider using protobuf or serde bincode for less error prone
// serialization.
let msg_tag = body.read_u8()?;
let mut msg_tag = body.read_u8()?;
//
// The old protocol version uses commands with tags starting from 0 and containing a `latest` flag.
// The new protocol version shifts command tags by 10 and passes an LSN range instead of the `latest` flag.
// The server should be able to handle both protocol versions. As we are currently not passing the
// protocol version from client to server, we make the decision based on the tag range.
// So this code actually provides backward compatibility.
//
let horizon = if msg_tag >= 10 {
// new protocol
msg_tag -= 10; // command tags in the new protocol start from 10
Lsn::from(body.read_u64::<BigEndian>()?)
} else {
// old protocol
let latest = body.read_u8()? != 0;
if latest {
Lsn::MAX // get latest version
} else {
Lsn::INVALID // get version on specified LSN
}
};
let lsn = Lsn::from(body.read_u64::<BigEndian>()?);
match msg_tag {
0 => Ok(PagestreamFeMessage::Exists(PagestreamExistsRequest {
latest: body.read_u8()? != 0,
lsn: Lsn::from(body.read_u64::<BigEndian>()?),
horizon,
lsn,
rel: RelTag {
spcnode: body.read_u32::<BigEndian>()?,
dbnode: body.read_u32::<BigEndian>()?,
@@ -992,8 +1013,8 @@ impl PagestreamFeMessage {
},
})),
1 => Ok(PagestreamFeMessage::Nblocks(PagestreamNblocksRequest {
latest: body.read_u8()? != 0,
lsn: Lsn::from(body.read_u64::<BigEndian>()?),
horizon,
lsn,
rel: RelTag {
spcnode: body.read_u32::<BigEndian>()?,
dbnode: body.read_u32::<BigEndian>()?,
@@ -1002,8 +1023,8 @@ impl PagestreamFeMessage {
},
})),
2 => Ok(PagestreamFeMessage::GetPage(PagestreamGetPageRequest {
latest: body.read_u8()? != 0,
lsn: Lsn::from(body.read_u64::<BigEndian>()?),
horizon,
lsn,
rel: RelTag {
spcnode: body.read_u32::<BigEndian>()?,
dbnode: body.read_u32::<BigEndian>()?,
@@ -1013,14 +1034,14 @@ impl PagestreamFeMessage {
blkno: body.read_u32::<BigEndian>()?,
})),
3 => Ok(PagestreamFeMessage::DbSize(PagestreamDbSizeRequest {
latest: body.read_u8()? != 0,
lsn: Lsn::from(body.read_u64::<BigEndian>()?),
horizon,
lsn,
dbnode: body.read_u32::<BigEndian>()?,
})),
4 => Ok(PagestreamFeMessage::GetSlruSegment(
PagestreamGetSlruSegmentRequest {
latest: body.read_u8()? != 0,
lsn: Lsn::from(body.read_u64::<BigEndian>()?),
horizon,
lsn,
kind: body.read_u8()?,
segno: body.read_u32::<BigEndian>()?,
},
@@ -1148,7 +1169,7 @@ mod tests {
// Test serialization/deserialization of PagestreamFeMessage
let messages = vec![
PagestreamFeMessage::Exists(PagestreamExistsRequest {
latest: true,
horizon: Lsn::MAX,
lsn: Lsn(4),
rel: RelTag {
forknum: 1,
@@ -1158,7 +1179,7 @@ mod tests {
},
}),
PagestreamFeMessage::Nblocks(PagestreamNblocksRequest {
latest: false,
horizon: Lsn::INVALID,
lsn: Lsn(4),
rel: RelTag {
forknum: 1,
@@ -1168,8 +1189,8 @@ mod tests {
},
}),
PagestreamFeMessage::GetPage(PagestreamGetPageRequest {
latest: true,
lsn: Lsn(4),
horizon: Lsn::MAX,
lsn: Lsn::INVALID,
rel: RelTag {
forknum: 1,
spcnode: 2,
@@ -1179,7 +1200,7 @@ mod tests {
blkno: 7,
}),
PagestreamFeMessage::DbSize(PagestreamDbSizeRequest {
latest: true,
horizon: Lsn::MAX,
lsn: Lsn(4),
dbnode: 7,
}),
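
To summarize the framing difference exercised by these tests, here is an illustrative sketch (not the real serializer) of how a GetPage request header is encoded in each protocol version; the shared tail (rel tag, fork number, block number) is identical in both:

```rust
use bytes::{BufMut, BytesMut};

/// Old protocol: tag 2, then a one-byte `latest` flag, then the request LSN.
/// New protocol: tag 12 (= 2 + 10), then the horizon (upper LSN boundary),
/// then the request LSN (lower boundary). All integers are big-endian.
fn frame_get_page_header(new_protocol: bool, latest: bool, horizon: u64, lsn: u64) -> BytesMut {
    let mut buf = BytesMut::new();
    if new_protocol {
        buf.put_u8(12); // GetPage tag, shifted by 10
        buf.put_u64(horizon); // upper boundary of the LSN range
    } else {
        buf.put_u8(2); // original GetPage tag
        buf.put_u8(u8::from(latest)); // 'latest' flag
    }
    buf.put_u64(lsn); // request LSN
    buf
}
```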

View File

@@ -312,7 +312,11 @@ async fn main_impl(
let (rel_tag, block_no) =
key_to_rel_block(key).expect("we filter non-rel-block keys out above");
PagestreamGetPageRequest {
latest: rng.gen_bool(args.req_latest_probability),
horizon: if rng.gen_bool(args.req_latest_probability) {
Lsn::MAX
} else {
r.timeline_lsn
},
lsn: r.timeline_lsn,
rel: rel_tag,
blkno: block_no,

View File

@@ -361,9 +361,10 @@ where
/// Add contents of relfilenode `src`, naming it as `dst`.
async fn add_rel(&mut self, src: RelTag, dst: RelTag) -> anyhow::Result<()> {
let horizon = self.lsn; // we do not need the latest version
let nblocks = self
.timeline
.get_rel_size(src, Version::Lsn(self.lsn), false, self.ctx)
.get_rel_size(src, Version::Lsn(self.lsn), horizon, self.ctx)
.await?;
// If the relation is empty, create an empty file
@@ -384,7 +385,7 @@ where
for blknum in startblk..endblk {
let img = self
.timeline
.get_rel_page_at_lsn(src, blknum, Version::Lsn(self.lsn), false, self.ctx)
.get_rel_page_at_lsn(src, blknum, Version::Lsn(self.lsn), horizon, self.ctx)
.await?;
segment_data.extend_from_slice(&img[..]);
}

View File

@@ -847,69 +847,66 @@ impl PageServerHandler {
/// In either case, if the page server hasn't received the WAL up to the
/// requested LSN yet, we will wait for it to arrive. The return value is
/// the LSN that should be used to look up the page versions.
///
/// The compute needs to specify:
/// 1. The "desired" LSN - the LSN the compute expects to be acceptable
/// 2. An upper-boundary LSN - the PS should not send a page with a greater LSN, to preserve consistency
///
/// In the case of a primary node, the upper boundary is always +inf: nobody except this node can produce a more recent version of the page.
/// For a replica this is not true: a replica can lag behind both the primary node and the PS, and should not receive pages newer than its last_replay_lsn.
/// But it is not good to always request pages at `last_replay_lsn`, because the replica can be ahead of the PS and would have to wait
/// until the PS catches up (while for this particular page that is not needed).
///
/// We actually need to handle just three cases:
/// \[page_last_written_lsn, +inf\] - primary node
/// \[page_last_written_lsn, last_replay_lsn\] - hot-standby replica (receiving WAL from the primary)
/// \[snapshot_lsn, snapshot_lsn\] - static RO replica (not receiving WAL from the primary)
///
/// The case \[0, lsn\] is not actually needed and is added mostly for convenience, as an alias for \[lsn, lsn\]
async fn wait_or_get_last_lsn(
timeline: &Timeline,
mut lsn: Lsn,
latest: bool,
lsn: Lsn,
horizon: Lsn,
latest_gc_cutoff_lsn: &RcuReadGuard<Lsn>,
ctx: &RequestContext,
) -> Result<Lsn, PageStreamError> {
if latest {
// Latest page version was requested. If LSN is given, it is a hint
// to the page server that there have been no modifications to the
// page after that LSN. If we haven't received WAL up to that point,
// wait until it arrives.
let last_record_lsn = timeline.get_last_record_lsn();
// Note: this covers the special case that lsn == Lsn(0). That
// special case means "return the latest version whatever it is",
// and it's used for bootstrapping purposes, when the page server is
// connected directly to the compute node. That is needed because
// when you connect to the compute node, to receive the WAL, the
// walsender process will do a look up in the pg_authid catalog
// table for authentication. That poses a deadlock problem: the
// catalog table lookup will send a GetPage request, but the GetPage
// request will block in the page server because the recent WAL
// hasn't been received yet, and it cannot be received until the
// walsender completes the authentication and starts streaming the
// WAL.
if lsn <= last_record_lsn {
lsn = last_record_lsn;
} else {
timeline
.wait_lsn(
lsn,
crate::tenant::timeline::WaitLsnWaiter::PageService,
ctx,
)
.await?;
// Since we waited for 'lsn' to arrive, that is now the last
// record LSN. (Or close enough for our purposes; the
// last-record LSN can advance immediately after we return
// anyway)
}
let last_record_lsn = timeline.get_last_record_lsn();
// Horizon = 0 (INVALID) is treated as the LSN interval degenerated to the point [lsn, lsn].
// This is done mostly for convenience (because such get_page commands are widely used in tests) and
// also seems logical: Lsn::MAX moves the upper boundary of the LSN interval up to last_record_lsn, and
// Lsn(0) moves the upper boundary down to the lower boundary.
let request_horizon = if horizon == Lsn::INVALID {
lsn
} else {
if lsn == Lsn(0) {
return Err(PageStreamError::BadRequest(
"invalid LSN(0) in request".into(),
));
}
horizon
};
let effective_lsn = Lsn::max(lsn, Lsn::min(request_horizon, last_record_lsn));
if effective_lsn > last_record_lsn {
timeline
.wait_lsn(
lsn,
effective_lsn,
crate::tenant::timeline::WaitLsnWaiter::PageService,
ctx,
)
.await?;
// Since we waited for 'effective_lsn' to arrive, that is now the last
// record LSN. (Or close enough for our purposes; the
// last-record LSN can advance immediately after we return
// anyway)
} else if effective_lsn == Lsn(0) {
return Err(PageStreamError::BadRequest(
"invalid LSN(0) in request".into(),
));
}
if lsn < **latest_gc_cutoff_lsn {
if effective_lsn < **latest_gc_cutoff_lsn {
return Err(PageStreamError::BadRequest(format!(
"tried to request a page version that was garbage collected. requested at {} gc cutoff {}",
lsn, **latest_gc_cutoff_lsn
effective_lsn, **latest_gc_cutoff_lsn
).into()));
}
Ok(lsn)
Ok(effective_lsn)
}
#[instrument(skip_all, fields(shard_id))]
@@ -927,11 +924,11 @@ impl PageServerHandler {
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
let lsn =
Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn, ctx)
Self::wait_or_get_last_lsn(timeline, req.lsn, req.horizon, &latest_gc_cutoff_lsn, ctx)
.await?;
let exists = timeline
.get_rel_exists(req.rel, Version::Lsn(lsn), req.latest, ctx)
.get_rel_exists(req.rel, Version::Lsn(lsn), req.horizon, ctx)
.await?;
Ok(PagestreamBeMessage::Exists(PagestreamExistsResponse {
@@ -955,11 +952,11 @@ impl PageServerHandler {
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
let lsn =
Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn, ctx)
Self::wait_or_get_last_lsn(timeline, req.lsn, req.horizon, &latest_gc_cutoff_lsn, ctx)
.await?;
let n_blocks = timeline
.get_rel_size(req.rel, Version::Lsn(lsn), req.latest, ctx)
.get_rel_size(req.rel, Version::Lsn(lsn), req.horizon, ctx)
.await?;
Ok(PagestreamBeMessage::Nblocks(PagestreamNblocksResponse {
@@ -983,7 +980,7 @@ impl PageServerHandler {
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
let lsn =
Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn, ctx)
Self::wait_or_get_last_lsn(timeline, req.lsn, req.horizon, &latest_gc_cutoff_lsn, ctx)
.await?;
let total_blocks = timeline
@@ -991,7 +988,7 @@ impl PageServerHandler {
DEFAULTTABLESPACE_OID,
req.dbnode,
Version::Lsn(lsn),
req.latest,
req.horizon,
ctx,
)
.await?;
@@ -1161,11 +1158,11 @@ impl PageServerHandler {
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
let lsn =
Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn, ctx)
Self::wait_or_get_last_lsn(timeline, req.lsn, req.horizon, &latest_gc_cutoff_lsn, ctx)
.await?;
let page = timeline
.get_rel_page_at_lsn(req.rel, req.blkno, Version::Lsn(lsn), req.latest, ctx)
.get_rel_page_at_lsn(req.rel, req.blkno, Version::Lsn(lsn), req.horizon, ctx)
.await?;
Ok(PagestreamBeMessage::GetPage(PagestreamGetPageResponse {
@@ -1189,7 +1186,7 @@ impl PageServerHandler {
let latest_gc_cutoff_lsn = timeline.get_latest_gc_cutoff_lsn();
let lsn =
Self::wait_or_get_last_lsn(timeline, req.lsn, req.latest, &latest_gc_cutoff_lsn, ctx)
Self::wait_or_get_last_lsn(timeline, req.lsn, req.horizon, &latest_gc_cutoff_lsn, ctx)
.await?;
let kind = SlruKind::from_repr(req.kind)
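
The interval clamping in `wait_or_get_last_lsn` above condenses to `max(lsn, min(horizon, last_record_lsn))`; a standalone sketch of the three cases, using plain `u64` in place of `Lsn` for illustration:

```rust
/// Illustration of the clamping in wait_or_get_last_lsn (u64 instead of Lsn).
/// horizon == 0 degenerates the interval to the point [lsn, lsn]; if the
/// result exceeds last_record_lsn, the real code waits for the WAL to arrive.
fn effective_lsn(lsn: u64, horizon: u64, last_record_lsn: u64) -> u64 {
    let request_horizon = if horizon == 0 { lsn } else { horizon };
    lsn.max(request_horizon.min(last_record_lsn))
}

fn main() {
    // Primary: [last_written, +inf] clamps to the last record LSN.
    assert_eq!(effective_lsn(100, u64::MAX, 500), 500);
    // Hot standby: [last_written, last_replay], replica behind the pageserver.
    assert_eq!(effective_lsn(100, 300, 500), 300);
    // Static RO replica: [snapshot, snapshot].
    assert_eq!(effective_lsn(200, 0, 500), 200);
}
```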

View File

@@ -175,7 +175,7 @@ impl Timeline {
tag: RelTag,
blknum: BlockNumber,
version: Version<'_>,
latest: bool,
horizon: Lsn,
ctx: &RequestContext,
) -> Result<Bytes, PageReconstructError> {
if tag.relnode == 0 {
@@ -184,7 +184,7 @@ impl Timeline {
));
}
let nblocks = self.get_rel_size(tag, version, latest, ctx).await?;
let nblocks = self.get_rel_size(tag, version, horizon, ctx).await?;
if blknum >= nblocks {
debug!(
"read beyond EOF at {} blk {} at {}, size is {}: returning all-zeros page",
@@ -206,7 +206,7 @@ impl Timeline {
spcnode: Oid,
dbnode: Oid,
version: Version<'_>,
latest: bool,
horizon: Lsn,
ctx: &RequestContext,
) -> Result<usize, PageReconstructError> {
let mut total_blocks = 0;
@@ -214,7 +214,7 @@ impl Timeline {
let rels = self.list_rels(spcnode, dbnode, version, ctx).await?;
for rel in rels {
let n_blocks = self.get_rel_size(rel, version, latest, ctx).await?;
let n_blocks = self.get_rel_size(rel, version, horizon, ctx).await?;
total_blocks += n_blocks as usize;
}
Ok(total_blocks)
@@ -225,7 +225,7 @@ impl Timeline {
&self,
tag: RelTag,
version: Version<'_>,
latest: bool,
horizon: Lsn,
ctx: &RequestContext,
) -> Result<BlockNumber, PageReconstructError> {
if tag.relnode == 0 {
@@ -239,7 +239,7 @@ impl Timeline {
}
if (tag.forknum == FSM_FORKNUM || tag.forknum == VISIBILITYMAP_FORKNUM)
&& !self.get_rel_exists(tag, version, latest, ctx).await?
&& !self.get_rel_exists(tag, version, horizon, ctx).await?
{
// FIXME: Postgres sometimes calls smgrcreate() to create
// FSM, and smgrnblocks() on it immediately afterwards,
@@ -252,14 +252,8 @@ impl Timeline {
let mut buf = version.get(self, key, ctx).await?;
let nblocks = buf.get_u32_le();
if latest {
// Update relation size cache only if "latest" flag is set.
// This flag is set by compute when it is working with most recent version of relation.
// Typically master compute node always set latest=true.
// Please notice, that even if compute node "by mistake" specifies old LSN but set
// latest=true, then it can not cause cache corruption, because with latest=true
// pageserver choose max(request_lsn, last_written_lsn) and so cached value will be
// associated with most recent value of LSN.
if horizon == Lsn::MAX {
// Update the relation size cache only if the latest version is requested.
self.update_cached_rel_size(tag, version.get_lsn(), nblocks);
}
Ok(nblocks)
@@ -270,7 +264,7 @@ impl Timeline {
&self,
tag: RelTag,
version: Version<'_>,
_latest: bool,
_horizon: Lsn,
ctx: &RequestContext,
) -> Result<bool, PageReconstructError> {
if tag.relnode == 0 {
@@ -1088,7 +1082,7 @@ impl<'a> DatadirModification<'a> {
) -> anyhow::Result<()> {
let total_blocks = self
.tline
.get_db_size(spcnode, dbnode, Version::Modified(self), true, ctx)
.get_db_size(spcnode, dbnode, Version::Modified(self), Lsn::MAX, ctx)
.await?;
// Remove entry from dbdir
@@ -1187,7 +1181,7 @@ impl<'a> DatadirModification<'a> {
anyhow::ensure!(rel.relnode != 0, RelationError::InvalidRelnode);
if self
.tline
.get_rel_exists(rel, Version::Modified(self), true, ctx)
.get_rel_exists(rel, Version::Modified(self), Lsn::MAX, ctx)
.await?
{
let size_key = rel_size_to_key(rel);

View File

@@ -1034,7 +1034,7 @@ impl WalIngest {
let nblocks = modification
.tline
.get_rel_size(src_rel, Version::Modified(modification), true, ctx)
.get_rel_size(src_rel, Version::Modified(modification), Lsn::MAX, ctx)
.await?;
let dst_rel = RelTag {
spcnode: tablespace_id,
@@ -1072,7 +1072,7 @@ impl WalIngest {
src_rel,
blknum,
Version::Modified(modification),
true,
Lsn::MAX,
ctx,
)
.await?;
@@ -1242,7 +1242,7 @@ impl WalIngest {
};
if modification
.tline
.get_rel_exists(rel, Version::Modified(modification), true, ctx)
.get_rel_exists(rel, Version::Modified(modification), Lsn::MAX, ctx)
.await?
{
self.put_rel_drop(modification, rel, ctx).await?;
@@ -1541,7 +1541,7 @@ impl WalIngest {
nblocks
} else if !modification
.tline
.get_rel_exists(rel, Version::Modified(modification), true, ctx)
.get_rel_exists(rel, Version::Modified(modification), Lsn::MAX, ctx)
.await?
{
// create it with 0 size initially, the logic below will extend it
@@ -1553,7 +1553,7 @@ impl WalIngest {
} else {
modification
.tline
.get_rel_size(rel, Version::Modified(modification), true, ctx)
.get_rel_size(rel, Version::Modified(modification), Lsn::MAX, ctx)
.await?
};
@@ -1650,14 +1650,14 @@ async fn get_relsize(
) -> anyhow::Result<BlockNumber> {
let nblocks = if !modification
.tline
.get_rel_exists(rel, Version::Modified(modification), true, ctx)
.get_rel_exists(rel, Version::Modified(modification), Lsn::MAX, ctx)
.await?
{
0
} else {
modification
.tline
.get_rel_size(rel, Version::Modified(modification), true, ctx)
.get_rel_size(rel, Version::Modified(modification), Lsn::MAX, ctx)
.await?
};
Ok(nblocks)
@@ -1732,29 +1732,29 @@ mod tests {
// The relation was created at LSN 2, not visible at LSN 1 yet.
assert_eq!(
tline
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x10)), false, &ctx)
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x10)), Lsn::INVALID, &ctx)
.await?,
false
);
assert!(tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x10)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x10)), Lsn::INVALID, &ctx)
.await
.is_err());
assert_eq!(
tline
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x20)), false, &ctx)
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x20)), Lsn::INVALID, &ctx)
.await?,
true
);
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x20)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x20)), Lsn::INVALID, &ctx)
.await?,
1
);
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x50)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x50)), Lsn::INVALID, &ctx)
.await?,
3
);
@@ -1762,46 +1762,46 @@ mod tests {
// Check page contents at each LSN
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x20)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x20)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 0 at 2")
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x30)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x30)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 0 at 3")
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x40)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x40)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 0 at 3")
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 1, Version::Lsn(Lsn(0x40)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 1, Version::Lsn(Lsn(0x40)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 1 at 4")
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x50)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x50)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 0 at 3")
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 1, Version::Lsn(Lsn(0x50)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 1, Version::Lsn(Lsn(0x50)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 1 at 4")
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 2, Version::Lsn(Lsn(0x50)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 2, Version::Lsn(Lsn(0x50)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 2 at 5")
);
@@ -1817,19 +1817,19 @@ mod tests {
// Check reported size and contents after truncation
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x60)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x60)), Lsn::INVALID, &ctx)
.await?,
2
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x60)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x60)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 0 at 3")
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 1, Version::Lsn(Lsn(0x60)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 1, Version::Lsn(Lsn(0x60)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 1 at 4")
);
@@ -1837,13 +1837,13 @@ mod tests {
// should still see the truncated block with older LSN
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x50)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x50)), Lsn::INVALID, &ctx)
.await?,
3
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 2, Version::Lsn(Lsn(0x50)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 2, Version::Lsn(Lsn(0x50)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 2 at 5")
);
@@ -1856,7 +1856,7 @@ mod tests {
m.commit(&ctx).await?;
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x68)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x68)), Lsn::INVALID, &ctx)
.await?,
0
);
@@ -1869,19 +1869,19 @@ mod tests {
m.commit(&ctx).await?;
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x70)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x70)), Lsn::INVALID, &ctx)
.await?,
2
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x70)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 0, Version::Lsn(Lsn(0x70)), Lsn::INVALID, &ctx)
.await?,
ZERO_PAGE
);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 1, Version::Lsn(Lsn(0x70)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 1, Version::Lsn(Lsn(0x70)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 1")
);
@@ -1894,21 +1894,27 @@ mod tests {
m.commit(&ctx).await?;
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x80)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x80)), Lsn::INVALID, &ctx)
.await?,
1501
);
for blk in 2..1500 {
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, blk, Version::Lsn(Lsn(0x80)), false, &ctx)
.get_rel_page_at_lsn(
TESTREL_A,
blk,
Version::Lsn(Lsn(0x80)),
Lsn::INVALID,
&ctx
)
.await?,
ZERO_PAGE
);
}
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, 1500, Version::Lsn(Lsn(0x80)), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, 1500, Version::Lsn(Lsn(0x80)), Lsn::INVALID, &ctx)
.await?,
test_img("foo blk 1500")
);
@@ -1935,13 +1941,13 @@ mod tests {
// Check that rel exists and size is correct
assert_eq!(
tline
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x20)), false, &ctx)
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x20)), Lsn::INVALID, &ctx)
.await?,
true
);
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x20)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x20)), Lsn::INVALID, &ctx)
.await?,
1
);
@@ -1954,7 +1960,7 @@ mod tests {
// Check that rel is not visible anymore
assert_eq!(
tline
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x30)), false, &ctx)
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x30)), Lsn::INVALID, &ctx)
.await?,
false
);
@@ -1972,13 +1978,13 @@ mod tests {
// Check that rel exists and size is correct
assert_eq!(
tline
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x40)), false, &ctx)
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x40)), Lsn::INVALID, &ctx)
.await?,
true
);
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x40)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x40)), Lsn::INVALID, &ctx)
.await?,
1
);
@@ -2011,24 +2017,24 @@ mod tests {
// The relation was created at LSN 20, not visible at LSN 1 yet.
assert_eq!(
tline
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x10)), false, &ctx)
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x10)), Lsn::INVALID, &ctx)
.await?,
false
);
assert!(tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x10)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x10)), Lsn::INVALID, &ctx)
.await
.is_err());
assert_eq!(
tline
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x20)), false, &ctx)
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x20)), Lsn::INVALID, &ctx)
.await?,
true
);
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x20)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x20)), Lsn::INVALID, &ctx)
.await?,
relsize
);
@@ -2039,7 +2045,7 @@ mod tests {
let data = format!("foo blk {} at {}", blkno, lsn);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, blkno, Version::Lsn(lsn), false, &ctx)
.get_rel_page_at_lsn(TESTREL_A, blkno, Version::Lsn(lsn), Lsn::INVALID, &ctx)
.await?,
test_img(&data)
);
@@ -2056,7 +2062,7 @@ mod tests {
// Check reported size and contents after truncation
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x60)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x60)), Lsn::INVALID, &ctx)
.await?,
1
);
@@ -2066,7 +2072,13 @@ mod tests {
let data = format!("foo blk {} at {}", blkno, lsn);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, blkno, Version::Lsn(Lsn(0x60)), false, &ctx)
.get_rel_page_at_lsn(
TESTREL_A,
blkno,
Version::Lsn(Lsn(0x60)),
Lsn::INVALID,
&ctx
)
.await?,
test_img(&data)
);
@@ -2075,7 +2087,7 @@ mod tests {
// should still see all blocks with older LSN
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x50)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x50)), Lsn::INVALID, &ctx)
.await?,
relsize
);
@@ -2084,7 +2096,13 @@ mod tests {
let data = format!("foo blk {} at {}", blkno, lsn);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, blkno, Version::Lsn(Lsn(0x50)), false, &ctx)
.get_rel_page_at_lsn(
TESTREL_A,
blkno,
Version::Lsn(Lsn(0x50)),
Lsn::INVALID,
&ctx
)
.await?,
test_img(&data)
);
@@ -2104,13 +2122,13 @@ mod tests {
assert_eq!(
tline
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x80)), false, &ctx)
.get_rel_exists(TESTREL_A, Version::Lsn(Lsn(0x80)), Lsn::INVALID, &ctx)
.await?,
true
);
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x80)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(0x80)), Lsn::INVALID, &ctx)
.await?,
relsize
);
@@ -2120,7 +2138,13 @@ mod tests {
let data = format!("foo blk {} at {}", blkno, lsn);
assert_eq!(
tline
.get_rel_page_at_lsn(TESTREL_A, blkno, Version::Lsn(Lsn(0x80)), false, &ctx)
.get_rel_page_at_lsn(
TESTREL_A,
blkno,
Version::Lsn(Lsn(0x80)),
Lsn::INVALID,
&ctx
)
.await?,
test_img(&data)
);
@@ -2154,7 +2178,7 @@ mod tests {
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(lsn)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(lsn)), Lsn::INVALID, &ctx)
.await?,
RELSEG_SIZE + 1
);
@@ -2168,7 +2192,7 @@ mod tests {
m.commit(&ctx).await?;
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(lsn)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(lsn)), Lsn::INVALID, &ctx)
.await?,
RELSEG_SIZE
);
@@ -2183,7 +2207,7 @@ mod tests {
m.commit(&ctx).await?;
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(lsn)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(lsn)), Lsn::INVALID, &ctx)
.await?,
RELSEG_SIZE - 1
);
@@ -2201,7 +2225,7 @@ mod tests {
m.commit(&ctx).await?;
assert_eq!(
tline
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(lsn)), false, &ctx)
.get_rel_size(TESTREL_A, Version::Lsn(Lsn(lsn)), Lsn::INVALID, &ctx)
.await?,
size as BlockNumber
);

View File

@@ -49,6 +49,8 @@ char *neon_auth_token;
int readahead_buffer_size = 128;
int flush_every_n_requests = 8;
int neon_protocol_version;
static int n_reconnect_attempts = 0;
static int max_reconnect_attempts = 60;
static int stripe_size;
@@ -844,6 +846,14 @@ pg_init_libpagestore(void)
PGC_USERSET,
0, /* no flags required */
NULL, (GucIntAssignHook) &readahead_buffer_resize, NULL);
DefineCustomIntVariable("neon.protocol_version",
"Version of compute<->page server protocol",
NULL,
&neon_protocol_version,
NEON_PROTOCOL_VERSION, 1, INT_MAX,
PGC_USERSET,
0, /* no flags required */
NULL, NULL, NULL);
relsize_hash_init();

View File

@@ -28,10 +28,17 @@
#define MAX_SHARDS 128
#define MAX_PAGESERVER_CONNSTRING_SIZE 256
/*
* Right now the protocol version is not sent to the server.
* So it is critical that the format of existing commands is not changed.
* New protocol versions can just add new commands.
*/
#define NEON_PROTOCOL_VERSION 2
typedef enum
{
/* pagestore_client -> pagestore */
T_NeonExistsRequest = 0,
T_NeonExistsRequest = 10, /* new protocol message tags start from 10 */
T_NeonNblocksRequest,
T_NeonGetPageRequest,
T_NeonDbSizeRequest,
@@ -72,14 +79,20 @@ typedef enum {
/*
* supertype of all the Neon*Request structs below
*
* If 'latest' is true, we are requesting the latest page version, and 'lsn'
* In the old version of Neon we had a 'latest' flag indicating that we are requesting the latest page version, and 'lsn'
* is just a hint to the server that we know there are no versions of the page
* (or relation size, for exists/nblocks requests) later than the 'lsn'.
*
* But this doesn't work for a hot-standby replica, because it may not be at the latest LSN position.
* So we need to be able to specify an upper boundary for the LSN which the page server can send to us.
* This is why the 'latest' flag is replaced with 'horizon'. A 'horizon' of MAX_LSN=~0 means that we are requesting the latest version.
* If we need the version at an exact LSN (for static RO replicas), 'horizon' should be set to 0: in this case the range [lsn, lsn] is used by the page server.
* Otherwise, for a hot-standby replica, we specify the current replay position in 'horizon'.
*/
typedef struct
{
NeonMessageTag tag;
bool latest; /* if true, request latest page version */
XLogRecPtr horizon; /* upper boundary for page LSN */
XLogRecPtr lsn; /* request page version @ this LSN */
} NeonRequest;
@@ -193,6 +206,7 @@ extern int readahead_buffer_size;
extern char *neon_timeline;
extern char *neon_tenant;
extern int32 max_cluster_size;
extern int neon_protocol_version;
extern shardno_t get_shard_number(BufferTag* tag);

View File

@@ -110,6 +110,20 @@ static UnloggedBuildPhase unlogged_build_phase = UNLOGGED_BUILD_NOT_IN_PROGRESS;
static bool neon_redo_read_buffer_filter(XLogReaderState *record, uint8 block_id);
static bool (*old_redo_read_buffer_filter) (XLogReaderState *record, uint8 block_id) = NULL;
#define MAX_LSN ((XLogRecPtr)~0)
/*
* There are three kinds of get_page requests:
* 1. Master compute: get the latest page not older than the specified LSN (horizon=MAX_LSN)
* 2. RO replica: get the latest page not newer than the WAL position the replica has already applied (horizon=GetXLogReplayRecPtr(NULL))
* 3. Snapshot: get the latest page not newer than the specified LSN (horizon=request_lsn)
*/
static XLogRecPtr
neon_get_horizon(bool latest)
{
return latest ? MAX_LSN : RecoveryInProgress() ? GetXLogReplayRecPtr(NULL) : InvalidXLogRecPtr; /* horizon=InvalidXLogRecPtr is replaced with request_lsn at the PS */
}
/*
* Prefetch implementation:
*
@@ -687,9 +701,10 @@ static void
prefetch_do_request(PrefetchRequest *slot, bool *force_latest, XLogRecPtr *force_lsn)
{
bool found;
bool latest;
NeonGetPageRequest request = {
.req.tag = T_NeonGetPageRequest,
.req.latest = false,
.req.horizon = 0,
.req.lsn = 0,
.rinfo = BufTagGetNRelFileInfo(slot->buftag),
.forknum = slot->buftag.forkNum,
@@ -699,13 +714,13 @@ prefetch_do_request(PrefetchRequest *slot, bool *force_latest, XLogRecPtr *force
if (force_lsn && force_latest)
{
request.req.lsn = *force_lsn;
request.req.latest = *force_latest;
latest = *force_latest;
slot->actual_request_lsn = slot->effective_request_lsn = *force_lsn;
}
else
{
XLogRecPtr lsn = neon_get_request_lsn(
&request.req.latest,
&latest,
BufTagGetNRelFileInfo(slot->buftag),
slot->buftag.forkNum,
slot->buftag.blockNum
@@ -733,6 +748,7 @@ prefetch_do_request(PrefetchRequest *slot, bool *force_latest, XLogRecPtr *force
prefetch_lsn = Max(prefetch_lsn, lsn);
slot->effective_request_lsn = prefetch_lsn;
}
request.req.horizon = neon_get_horizon(latest);
Assert(slot->response == NULL);
Assert(slot->my_ring_index == MyPState->ring_unused);
@@ -997,7 +1013,19 @@ nm_pack_request(NeonRequest *msg)
StringInfoData s;
initStringInfo(&s);
pq_sendbyte(&s, msg->tag);
if (neon_protocol_version >= 2)
{
pq_sendbyte(&s, msg->tag);
pq_sendint64(&s, msg->horizon);
}
else
{
/* Old protocol with latest flag */
pq_sendbyte(&s, msg->tag - T_NeonExistsRequest); /* old protocol command tags start from zero */
pq_sendbyte(&s, msg->horizon == MAX_LSN);
}
pq_sendint64(&s, msg->lsn);
switch (messageTag(msg))
{
@@ -1006,8 +1034,6 @@ nm_pack_request(NeonRequest *msg)
{
NeonExistsRequest *msg_req = (NeonExistsRequest *) msg;
pq_sendbyte(&s, msg_req->req.latest);
pq_sendint64(&s, msg_req->req.lsn);
pq_sendint32(&s, NInfoGetSpcOid(msg_req->rinfo));
pq_sendint32(&s, NInfoGetDbOid(msg_req->rinfo));
pq_sendint32(&s, NInfoGetRelNumber(msg_req->rinfo));
@@ -1019,8 +1045,6 @@ nm_pack_request(NeonRequest *msg)
{
NeonNblocksRequest *msg_req = (NeonNblocksRequest *) msg;
pq_sendbyte(&s, msg_req->req.latest);
pq_sendint64(&s, msg_req->req.lsn);
pq_sendint32(&s, NInfoGetSpcOid(msg_req->rinfo));
pq_sendint32(&s, NInfoGetDbOid(msg_req->rinfo));
pq_sendint32(&s, NInfoGetRelNumber(msg_req->rinfo));
@@ -1032,8 +1056,6 @@ nm_pack_request(NeonRequest *msg)
{
NeonDbSizeRequest *msg_req = (NeonDbSizeRequest *) msg;
pq_sendbyte(&s, msg_req->req.latest);
pq_sendint64(&s, msg_req->req.lsn);
pq_sendint32(&s, msg_req->dbNode);
break;
@@ -1042,8 +1064,6 @@ nm_pack_request(NeonRequest *msg)
{
NeonGetPageRequest *msg_req = (NeonGetPageRequest *) msg;
pq_sendbyte(&s, msg_req->req.latest);
pq_sendint64(&s, msg_req->req.lsn);
pq_sendint32(&s, NInfoGetSpcOid(msg_req->rinfo));
pq_sendint32(&s, NInfoGetDbOid(msg_req->rinfo));
pq_sendint32(&s, NInfoGetRelNumber(msg_req->rinfo));
@@ -1057,8 +1077,6 @@ nm_pack_request(NeonRequest *msg)
{
NeonGetSlruSegmentRequest *msg_req = (NeonGetSlruSegmentRequest *) msg;
pq_sendbyte(&s, msg_req->req.latest);
pq_sendint64(&s, msg_req->req.lsn);
pq_sendbyte(&s, msg_req->kind);
pq_sendint32(&s, msg_req->segno);
@@ -1209,7 +1227,7 @@ nm_to_string(NeonMessage *msg)
appendStringInfo(&s, ", \"rinfo\": \"%u/%u/%u\"", RelFileInfoFmt(msg_req->rinfo));
appendStringInfo(&s, ", \"forknum\": %d", msg_req->forknum);
appendStringInfo(&s, ", \"lsn\": \"%X/%X\"", LSN_FORMAT_ARGS(msg_req->req.lsn));
appendStringInfo(&s, ", \"latest\": %d", msg_req->req.latest);
appendStringInfo(&s, ", \"horizon\": \"%X/%X\"", LSN_FORMAT_ARGS(msg_req->req.horizon));
appendStringInfoChar(&s, '}');
break;
}
@@ -1222,7 +1240,7 @@ nm_to_string(NeonMessage *msg)
appendStringInfo(&s, ", \"rinfo\": \"%u/%u/%u\"", RelFileInfoFmt(msg_req->rinfo));
appendStringInfo(&s, ", \"forknum\": %d", msg_req->forknum);
appendStringInfo(&s, ", \"lsn\": \"%X/%X\"", LSN_FORMAT_ARGS(msg_req->req.lsn));
appendStringInfo(&s, ", \"latest\": %d", msg_req->req.latest);
appendStringInfo(&s, ", \"horizon\": \"%X/%X\"", LSN_FORMAT_ARGS(msg_req->req.horizon));
appendStringInfoChar(&s, '}');
break;
}
@@ -1236,7 +1254,7 @@ nm_to_string(NeonMessage *msg)
appendStringInfo(&s, ", \"forknum\": %d", msg_req->forknum);
appendStringInfo(&s, ", \"blkno\": %u", msg_req->blkno);
appendStringInfo(&s, ", \"lsn\": \"%X/%X\"", LSN_FORMAT_ARGS(msg_req->req.lsn));
appendStringInfo(&s, ", \"latest\": %d", msg_req->req.latest);
appendStringInfo(&s, ", \"horizon\": \"%X/%X\"", LSN_FORMAT_ARGS(msg_req->req.horizon));
appendStringInfoChar(&s, '}');
break;
}
@@ -1247,7 +1265,7 @@ nm_to_string(NeonMessage *msg)
appendStringInfoString(&s, "{\"type\": \"NeonDbSizeRequest\"");
appendStringInfo(&s, ", \"dbnode\": \"%u\"", msg_req->dbNode);
appendStringInfo(&s, ", \"lsn\": \"%X/%X\"", LSN_FORMAT_ARGS(msg_req->req.lsn));
appendStringInfo(&s, ", \"latest\": %d", msg_req->req.latest);
appendStringInfo(&s, ", \"horizon\": \"%X/%X\"", LSN_FORMAT_ARGS(msg_req->req.horizon));
appendStringInfoChar(&s, '}');
break;
}
@@ -1259,7 +1277,7 @@ nm_to_string(NeonMessage *msg)
appendStringInfo(&s, ", \"kind\": %u", msg_req->kind);
appendStringInfo(&s, ", \"segno\": %u", msg_req->segno);
appendStringInfo(&s, ", \"lsn\": \"%X/%X\"", LSN_FORMAT_ARGS(msg_req->req.lsn));
appendStringInfo(&s, ", \"latest\": %d", msg_req->req.latest);
appendStringInfo(&s, ", \"horizon\": \"%X/%X\"", LSN_FORMAT_ARGS(msg_req->req.horizon));
appendStringInfoChar(&s, '}');
break;
}
@@ -1664,7 +1682,7 @@ neon_exists(SMgrRelation reln, ForkNumber forkNum)
{
NeonExistsRequest request = {
.req.tag = T_NeonExistsRequest,
.req.latest = latest,
.req.horizon = neon_get_horizon(latest),
.req.lsn = request_lsn,
.rinfo = InfoFromSMgrRel(reln),
.forknum = forkNum};
@@ -2474,7 +2492,7 @@ neon_nblocks(SMgrRelation reln, ForkNumber forknum)
{
NeonNblocksRequest request = {
.req.tag = T_NeonNblocksRequest,
.req.latest = latest,
.req.horizon = neon_get_horizon(latest),
.req.lsn = request_lsn,
.rinfo = InfoFromSMgrRel(reln),
.forknum = forknum,
@@ -2531,7 +2549,7 @@ neon_dbsize(Oid dbNode)
{
NeonDbSizeRequest request = {
.req.tag = T_NeonDbSizeRequest,
.req.latest = latest,
.req.horizon = neon_get_horizon(latest),
.req.lsn = request_lsn,
.dbNode = dbNode,
};
@@ -2827,7 +2845,7 @@ neon_read_slru_segment(SMgrRelation reln, const char* path, int segno, void* buf
NeonResponse *resp;
NeonGetSlruSegmentRequest request = {
.req.tag = T_NeonGetSlruSegmentRequest,
.req.latest = false,
.req.horizon = InvalidXLogRecPtr,
.req.lsn = request_lsn,
.kind = kind,
@@ -2980,7 +2998,7 @@ neon_extend_rel_size(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno,
NeonNblocksRequest request = {
.req = (NeonRequest) {
.lsn = end_recptr,
.latest = false,
.horizon = neon_get_horizon(false),
.tag = T_NeonNblocksRequest,
},
.rinfo = rinfo,

View File

@@ -27,7 +27,7 @@ use crate::{
},
stream, url,
};
use crate::{scram, EndpointCacheKey, EndpointId, Normalize, RoleName};
use crate::{scram, EndpointCacheKey, EndpointId, RoleName};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{info, warn};
@@ -186,7 +186,7 @@ impl AuthenticationConfig {
is_cleartext: bool,
) -> auth::Result<AuthSecret> {
// we have validated the endpoint exists, so let's intern it.
let endpoint_int = EndpointIdInt::from(endpoint.normalize());
let endpoint_int = EndpointIdInt::from(endpoint);
// only count the full hash count if password hack or websocket flow.
// in other words, if proxy needs to run the hashing

View File

@@ -189,9 +189,7 @@ struct ProxyCliArgs {
/// cache for `project_info` (use `size=0` to disable)
#[clap(long, default_value = config::ProjectInfoCacheOptions::CACHE_DEFAULT_OPTIONS)]
project_info_cache: String,
/// cache for all valid endpoints
#[clap(long, default_value = config::EndpointCacheConfig::CACHE_DEFAULT_OPTIONS)]
endpoint_cache_config: String,
#[clap(flatten)]
parquet_upload: ParquetUploadArgs,
@@ -403,7 +401,6 @@ async fn main() -> anyhow::Result<()> {
if let auth::BackendType::Console(api, _) = &config.auth_backend {
if let proxy::console::provider::ConsoleBackend::Console(api) = &**api {
maintenance_tasks.spawn(api.locks.garbage_collect_worker());
if let Some(redis_notifications_client) = redis_notifications_client {
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(
@@ -413,9 +410,6 @@ async fn main() -> anyhow::Result<()> {
args.region.clone(),
));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
let cache = api.caches.endpoints_cache.clone();
let con = redis_notifications_client.clone();
maintenance_tasks.spawn(async move { cache.do_read(con).await });
}
}
}
@@ -495,18 +489,14 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
let project_info_cache_config: ProjectInfoCacheOptions =
args.project_info_cache.parse()?;
let endpoint_cache_config: config::EndpointCacheConfig =
args.endpoint_cache_config.parse()?;
info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
info!(
"Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
);
info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}");
let caches = Box::leak(Box::new(console::caches::ApiCaches::new(
wake_compute_cache_config,
project_info_cache_config,
endpoint_cache_config,
)));
let config::WakeComputeLockOptions {
@@ -517,9 +507,10 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
} = args.wake_compute_lock.parse()?;
info!(permits, shards, ?epoch, "Using NodeLocks (wake_compute)");
let locks = Box::leak(Box::new(
console::locks::ApiLocks::new("wake_compute_lock", permits, shards, timeout, epoch)
console::locks::ApiLocks::new("wake_compute_lock", permits, shards, timeout)
.unwrap(),
));
tokio::spawn(locks.garbage_collect_worker(epoch));
let url = args.auth_endpoint.parse()?;
let endpoint = http::Endpoint::new(url, http::new_client(rate_limiter_config));

View File

@@ -1,5 +1,4 @@
pub mod common;
pub mod endpoints;
pub mod project_info;
mod timed_lru;

View File

@@ -1,191 +0,0 @@
use std::{
convert::Infallible,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use dashmap::DashSet;
use redis::{
streams::{StreamReadOptions, StreamReadReply},
AsyncCommands, FromRedisValue, Value,
};
use serde::Deserialize;
use tokio::sync::Mutex;
use crate::{
config::EndpointCacheConfig,
context::RequestMonitoring,
intern::{BranchIdInt, EndpointIdInt, ProjectIdInt},
metrics::REDIS_BROKEN_MESSAGES,
rate_limiter::GlobalRateLimiter,
redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider,
EndpointId, Normalize,
};
#[derive(Deserialize, Debug, Clone)]
#[serde(rename_all(deserialize = "snake_case"))]
pub enum ControlPlaneEventKey {
EndpointCreated,
BranchCreated,
ProjectCreated,
}
pub struct EndpointsCache {
config: EndpointCacheConfig,
endpoints: DashSet<EndpointIdInt>,
branches: DashSet<BranchIdInt>,
projects: DashSet<ProjectIdInt>,
ready: AtomicBool,
limiter: Arc<Mutex<GlobalRateLimiter>>,
}
impl EndpointsCache {
pub fn new(config: EndpointCacheConfig) -> Self {
Self {
limiter: Arc::new(Mutex::new(GlobalRateLimiter::new(
config.limiter_info.clone(),
))),
config,
endpoints: DashSet::new(),
branches: DashSet::new(),
projects: DashSet::new(),
ready: AtomicBool::new(false),
}
}
pub async fn is_valid(&self, ctx: &mut RequestMonitoring, endpoint: &EndpointId) -> bool {
if !self.ready.load(Ordering::Acquire) {
return true;
}
// If cache is disabled, just collect the metrics and return.
if self.config.disable_cache {
ctx.set_rejected(self.should_reject(endpoint));
return true;
}
// If the limiter allows, we don't need to check the cache.
if self.limiter.lock().await.check() {
return true;
}
let rejected = self.should_reject(endpoint);
ctx.set_rejected(rejected);
!rejected
}
fn should_reject(&self, endpoint: &EndpointId) -> bool {
let endpoint = endpoint.normalize();
if endpoint.is_endpoint() {
!self.endpoints.contains(&EndpointIdInt::from(&endpoint))
} else if endpoint.is_branch() {
!self
.branches
.contains(&BranchIdInt::from(&endpoint.as_branch()))
} else {
!self
.projects
.contains(&ProjectIdInt::from(&endpoint.as_project()))
}
}
fn insert_event(&self, key: ControlPlaneEventKey, value: String) {
// Do not do normalization here, we expect the events to be normalized.
match key {
ControlPlaneEventKey::EndpointCreated => {
self.endpoints.insert(EndpointIdInt::from(&value.into()));
}
ControlPlaneEventKey::BranchCreated => {
self.branches.insert(BranchIdInt::from(&value.into()));
}
ControlPlaneEventKey::ProjectCreated => {
self.projects.insert(ProjectIdInt::from(&value.into()));
}
}
}
pub async fn do_read(
&self,
mut con: ConnectionWithCredentialsProvider,
) -> anyhow::Result<Infallible> {
let mut last_id = "0-0".to_string();
loop {
self.ready.store(false, Ordering::Release);
if let Err(e) = con.connect().await {
tracing::error!("error connecting to redis: {:?}", e);
continue;
}
if let Err(e) = self.read_from_stream(&mut con, &mut last_id).await {
tracing::error!("error reading from redis: {:?}", e);
}
}
}
async fn read_from_stream(
&self,
con: &mut ConnectionWithCredentialsProvider,
last_id: &mut String,
) -> anyhow::Result<()> {
tracing::info!("reading endpoints/branches/projects from redis");
self.batch_read(
con,
StreamReadOptions::default().count(self.config.initial_batch_size),
last_id,
true,
)
.await?;
tracing::info!("ready to filter user requests");
self.ready.store(true, Ordering::Release);
self.batch_read(
con,
StreamReadOptions::default()
.count(self.config.initial_batch_size)
.block(self.config.xread_timeout.as_millis() as usize),
last_id,
false,
)
.await
}
fn parse_key_value(key: &str, value: &Value) -> anyhow::Result<(ControlPlaneEventKey, String)> {
Ok((serde_json::from_str(key)?, String::from_redis_value(value)?))
}
async fn batch_read(
&self,
conn: &mut ConnectionWithCredentialsProvider,
opts: StreamReadOptions,
last_id: &mut String,
return_when_finish: bool,
) -> anyhow::Result<()> {
let mut total: usize = 0;
loop {
let mut res: StreamReadReply = conn
.xread_options(&[&self.config.stream_name], &[last_id.as_str()], &opts)
.await?;
if res.keys.len() != 1 {
anyhow::bail!("Cannot read from redis stream {}", self.config.stream_name);
}
let res = res.keys.pop().expect("Checked length above");
if return_when_finish && res.ids.len() <= self.config.default_batch_size {
break;
}
for x in res.ids {
total += 1;
for (k, v) in x.map {
let (key, value) = match Self::parse_key_value(&k, &v) {
Ok(x) => x,
Err(e) => {
REDIS_BROKEN_MESSAGES
.with_label_values(&[&self.config.stream_name])
.inc();
tracing::error!("error parsing key-value {k}-{v:?}: {e:?}");
continue;
}
};
self.insert_event(key, value);
}
if total.is_power_of_two() {
tracing::debug!("endpoints read {}", total);
}
*last_id = x.id;
}
}
tracing::info!("read {} endpoints/branches/projects from redis", total);
Ok(())
}
}

View File

@@ -16,7 +16,7 @@ use crate::{
config::ProjectInfoCacheOptions,
console::AuthSecret,
intern::{EndpointIdInt, ProjectIdInt, RoleNameInt},
EndpointCacheKey, EndpointId, RoleName,
EndpointId, RoleName,
};
use super::{Cache, Cached};
@@ -196,7 +196,7 @@ impl ProjectInfoCacheImpl {
}
pub fn get_allowed_ips(
&self,
endpoint_id: &EndpointCacheKey,
endpoint_id: &EndpointId,
) -> Option<Cached<&Self, Arc<Vec<IpPattern>>>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
let (valid_since, ignore_cache_since) = self.get_cache_times();

View File

@@ -313,75 +313,6 @@ impl CertResolver {
}
}
#[derive(Debug)]
pub struct EndpointCacheConfig {
/// Batch size to receive all endpoints on the startup.
pub initial_batch_size: usize,
/// Batch size to receive endpoints.
pub default_batch_size: usize,
/// Timeouts for the stream read operation.
pub xread_timeout: Duration,
/// Stream name to read from.
pub stream_name: String,
/// Limiter info (to distinguish when to enable cache).
pub limiter_info: Vec<RateBucketInfo>,
/// Disable cache.
/// If true, cache is ignored, but reports all statistics.
pub disable_cache: bool,
}
impl EndpointCacheConfig {
/// Default options for [`crate::console::provider::NodeInfoCache`].
/// Notice that by default the limiter is empty, which means that cache is disabled.
pub const CACHE_DEFAULT_OPTIONS: &'static str =
"initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s";
/// Parse cache options passed via cmdline.
/// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
fn parse(options: &str) -> anyhow::Result<Self> {
let mut initial_batch_size = None;
let mut default_batch_size = None;
let mut xread_timeout = None;
let mut stream_name = None;
let mut limiter_info = vec![];
let mut disable_cache = false;
for option in options.split(',') {
let (key, value) = option
.split_once('=')
.with_context(|| format!("bad key-value pair: {option}"))?;
match key {
"initial_batch_size" => initial_batch_size = Some(value.parse()?),
"default_batch_size" => default_batch_size = Some(value.parse()?),
"xread_timeout" => xread_timeout = Some(humantime::parse_duration(value)?),
"stream_name" => stream_name = Some(value.to_string()),
"limiter_info" => limiter_info.push(RateBucketInfo::from_str(value)?),
"disable_cache" => disable_cache = value.parse()?,
unknown => bail!("unknown key: {unknown}"),
}
}
RateBucketInfo::validate(&mut limiter_info)?;
Ok(Self {
initial_batch_size: initial_batch_size.context("missing `initial_batch_size`")?,
default_batch_size: default_batch_size.context("missing `default_batch_size`")?,
xread_timeout: xread_timeout.context("missing `xread_timeout`")?,
stream_name: stream_name.context("missing `stream_name`")?,
disable_cache,
limiter_info,
})
}
}
impl FromStr for EndpointCacheConfig {
type Err = anyhow::Error;
fn from_str(options: &str) -> Result<Self, Self::Err> {
let error = || format!("failed to parse endpoint cache options '{options}'");
Self::parse(options).with_context(error)
}
}
#[derive(Debug)]
pub struct MetricBackupCollectionConfig {
pub interval: Duration,

View File

@@ -8,15 +8,15 @@ use crate::{
backend::{ComputeCredentialKeys, ComputeUserInfo},
IpPattern,
},
cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
cache::{project_info::ProjectInfoCacheImpl, Cached, TimedLru},
compute,
config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions},
config::{CacheOptions, ProjectInfoCacheOptions},
context::RequestMonitoring,
intern::ProjectIdInt,
scram, EndpointCacheKey,
};
use dashmap::DashMap;
use std::{convert::Infallible, sync::Arc, time::Duration};
use std::{sync::Arc, time::Duration};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio::time::Instant;
use tracing::info;
@@ -416,15 +416,12 @@ pub struct ApiCaches {
pub node_info: NodeInfoCache,
/// Cache which stores project_id -> endpoint_ids mapping.
pub project_info: Arc<ProjectInfoCacheImpl>,
/// List of all valid endpoints.
pub endpoints_cache: Arc<EndpointsCache>,
}
impl ApiCaches {
pub fn new(
wake_compute_cache_config: CacheOptions,
project_info_cache_config: ProjectInfoCacheOptions,
endpoint_cache_config: EndpointCacheConfig,
) -> Self {
Self {
node_info: NodeInfoCache::new(
@@ -434,7 +431,6 @@ impl ApiCaches {
true,
),
project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
}
}
}
@@ -445,7 +441,6 @@ pub struct ApiLocks {
node_locks: DashMap<EndpointCacheKey, Arc<Semaphore>>,
permits: usize,
timeout: Duration,
epoch: std::time::Duration,
registered: prometheus::IntCounter,
unregistered: prometheus::IntCounter,
reclamation_lag: prometheus::Histogram,
@@ -458,7 +453,6 @@ impl ApiLocks {
permits: usize,
shards: usize,
timeout: Duration,
epoch: std::time::Duration,
) -> prometheus::Result<Self> {
let registered = prometheus::IntCounter::with_opts(
prometheus::Opts::new(
@@ -503,7 +497,6 @@ impl ApiLocks {
node_locks: DashMap::with_shard_amount(shards),
permits,
timeout,
epoch,
lock_acquire_lag,
registered,
unregistered,
@@ -543,9 +536,12 @@ impl ApiLocks {
})
}
pub async fn garbage_collect_worker(&self) -> anyhow::Result<Infallible> {
let mut interval =
tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
pub async fn garbage_collect_worker(&self, epoch: std::time::Duration) {
if self.permits == 0 {
return;
}
let mut interval = tokio::time::interval(epoch / (self.node_locks.shards().len()) as u32);
loop {
for (i, shard) in self.node_locks.shards().iter().enumerate() {
interval.tick().await;

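After this change the worker takes its GC epoch from the caller and returns immediately when permits == 0. A sketch of a possible spawn site, assuming `locks` is one of the &'static ApiLocks references the crate already hands around (the 600s epoch is illustrative, not a crate default):

// Hypothetical spawn site; the caller now owns the GC epoch.
tokio::spawn(locks.garbage_collect_worker(std::time::Duration::from_secs(600)));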
View File

@@ -8,7 +8,6 @@ use super::{
};
use crate::{
auth::backend::ComputeUserInfo, compute, console::messages::ColdStartInfo, http, scram,
Normalize,
};
use crate::{
cache::Cached,
@@ -24,7 +23,7 @@ use tracing::{error, info, info_span, warn, Instrument};
pub struct Api {
endpoint: http::Endpoint,
pub caches: &'static ApiCaches,
pub locks: &'static ApiLocks,
locks: &'static ApiLocks,
jwt: String,
}
@@ -56,19 +55,10 @@ impl Api {
ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<AuthInfo, GetAuthInfoError> {
if !self
.caches
.endpoints_cache
.is_valid(ctx, &user_info.endpoint)
.await
{
info!("endpoint is not valid, skipping the request");
return Ok(AuthInfo::default());
}
let request_id = ctx.session_id.to_string();
let application_name = ctx.console_application_name();
async {
let mut request_builder = self
let request = self
.endpoint
.get("proxy_get_role_secret")
.header("X-Request-ID", &request_id)
@@ -78,14 +68,8 @@ impl Api {
("application_name", application_name.as_str()),
("project", user_info.endpoint.as_str()),
("role", user_info.user.as_str()),
]);
let options = user_info.options.to_deep_object();
if !options.is_empty() {
request_builder = request_builder.query(&options);
}
let request = request_builder.build()?;
])
.build()?;
info!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
@@ -97,9 +81,7 @@ impl Api {
Ok(body) => body,
// Error 404 is special: it's ok not to have a secret.
Err(e) => match e.http_status_code() {
Some(http::StatusCode::NOT_FOUND) => {
return Ok(AuthInfo::default());
}
Some(http::StatusCode::NOT_FOUND) => return Ok(AuthInfo::default()),
_otherwise => return Err(e.into()),
},
};
@@ -199,7 +181,7 @@ impl super::Api for Api {
}
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
if let Some(project_id) = auth_info.project_id {
let ep_int = ep.normalize().into();
let ep_int = ep.into();
self.caches.project_info.insert_role_secret(
project_id,
ep_int,
@@ -222,8 +204,8 @@ impl super::Api for Api {
ctx: &mut RequestMonitoring,
user_info: &ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
let cache_key = user_info.endpoint_cache_key();
if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(&cache_key) {
let ep = &user_info.endpoint;
if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(ep) {
ALLOWED_IPS_BY_CACHE_OUTCOME
.with_label_values(&["hit"])
.inc();
@@ -236,7 +218,7 @@ impl super::Api for Api {
let allowed_ips = Arc::new(auth_info.allowed_ips);
let user = &user_info.user;
if let Some(project_id) = auth_info.project_id {
let ep_int = cache_key.normalize().into();
let ep_int = ep.into();
self.caches.project_info.insert_role_secret(
project_id,
ep_int,

View File

@@ -12,9 +12,7 @@ use crate::{
console::messages::{ColdStartInfo, MetricsAuxInfo},
error::ErrorKind,
intern::{BranchIdInt, ProjectIdInt},
metrics::{
bool_to_str, LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND, NUM_INVALID_ENDPOINTS,
},
metrics::{LatencyTimer, ENDPOINT_ERRORS_BY_KIND, ERROR_BY_KIND},
DbName, EndpointId, RoleName,
};
@@ -52,8 +50,6 @@ pub struct RequestMonitoring {
// This sender is here to keep the request monitoring channel open while requests are taking place.
sender: Option<mpsc::UnboundedSender<RequestData>>,
pub latency_timer: LatencyTimer,
// Whether the proxy decided that this is not a valid endpoint and rejected it before going to cplane.
rejected: bool,
}
#[derive(Clone, Debug)]
@@ -97,7 +93,6 @@ impl RequestMonitoring {
error_kind: None,
auth_method: None,
success: false,
rejected: false,
cold_start_info: ColdStartInfo::Unknown,
sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()),
@@ -118,10 +113,6 @@ impl RequestMonitoring {
)
}
pub fn set_rejected(&mut self, rejected: bool) {
self.rejected = rejected;
}
pub fn set_cold_start_info(&mut self, info: ColdStartInfo) {
self.cold_start_info = info;
self.latency_timer.cold_start_info(info);
@@ -187,10 +178,6 @@ impl RequestMonitoring {
impl Drop for RequestMonitoring {
fn drop(&mut self) {
let outcome = if self.success { "success" } else { "failure" };
NUM_INVALID_ENDPOINTS
.with_label_values(&[self.protocol, bool_to_str(self.rejected), outcome])
.inc();
if let Some(tx) = self.sender.take() {
let _: Result<(), _> = tx.send(RequestData::from(&*self));
}

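The sender field above comes from upgrading the globally stored weak handle (LOG_CHAN), so per-request clones keep the log channel open while its owner can still shut it down. A minimal sketch of that pattern with a simplified payload type and a hypothetical helper name, not the crate's exact field wiring:

use std::sync::{Arc, OnceLock, Weak};
use tokio::sync::mpsc;

// The owning side holds the Arc; requests only ever see the Weak.
static LOG_CHAN: OnceLock<Weak<mpsc::UnboundedSender<String>>> = OnceLock::new();

fn sender_for_request() -> Option<Arc<mpsc::UnboundedSender<String>>> {
    // Upgrading fails once the owning Arc is dropped (logging shut down),
    // so late requests simply get no sender instead of a dead channel.
    LOG_CHAN.get().and_then(Weak::upgrade)
}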
View File

@@ -5,7 +5,7 @@ use std::{
use lasso::{Capacity, MemoryLimits, Spur, ThreadedRodeo};
use rustc_hash::FxHasher;
use crate::{BranchId, EndpointCacheKey, EndpointId, ProjectId, RoleName};
use crate::{BranchId, EndpointId, ProjectId, RoleName};
pub trait InternId: Sized + 'static {
fn get_interner() -> &'static StringInterner<Self>;
@@ -160,16 +160,6 @@ impl From<&EndpointId> for EndpointIdInt {
EndpointIdTag::get_interner().get_or_intern(value)
}
}
impl From<EndpointId> for EndpointIdInt {
fn from(value: EndpointId) -> Self {
EndpointIdTag::get_interner().get_or_intern(&value)
}
}
impl From<EndpointCacheKey> for EndpointIdInt {
fn from(value: EndpointCacheKey) -> Self {
EndpointIdTag::get_interner().get_or_intern(&value)
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct BranchIdTag;
@@ -185,11 +175,6 @@ impl From<&BranchId> for BranchIdInt {
BranchIdTag::get_interner().get_or_intern(value)
}
}
impl From<BranchId> for BranchIdInt {
fn from(value: BranchId) -> Self {
BranchIdTag::get_interner().get_or_intern(&value)
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct ProjectIdTag;
@@ -205,11 +190,6 @@ impl From<&ProjectId> for ProjectIdInt {
ProjectIdTag::get_interner().get_or_intern(value)
}
}
impl From<ProjectId> for ProjectIdInt {
fn from(value: ProjectId) -> Self {
ProjectIdTag::get_interner().get_or_intern(&value)
}
}
#[cfg(test)]
mod tests {

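All of the *IdInt wrappers in this file rely on the same interning idea: a shared lasso::ThreadedRodeo hands out one small Spur key per distinct string, so equal strings compare as integers. A freestanding illustration using lasso directly rather than this crate's wrapper types (the endpoint id is made up):

use lasso::ThreadedRodeo;

fn main() {
    let rodeo: ThreadedRodeo = ThreadedRodeo::default();
    let a = rodeo.get_or_intern("ep-steep-bush-12345678");
    let b = rodeo.get_or_intern("ep-steep-bush-12345678");
    assert_eq!(a, b); // equal strings intern to the same key
    assert_eq!(rodeo.resolve(&a), "ep-steep-bush-12345678");
}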
View File

@@ -127,24 +127,6 @@ macro_rules! smol_str_wrapper {
};
}
const POOLER_SUFFIX: &str = "-pooler";
pub trait Normalize {
fn normalize(&self) -> Self;
}
impl<S: Clone + AsRef<str> + From<String>> Normalize for S {
fn normalize(&self) -> Self {
if self.as_ref().ends_with(POOLER_SUFFIX) {
let mut s = self.as_ref().to_string();
s.truncate(s.len() - POOLER_SUFFIX.len());
s.into()
} else {
self.clone()
}
}
}
// 90% of role name strings are 20 characters or less.
smol_str_wrapper!(RoleName);
// 50% of endpoint strings are 23 characters or less.
@@ -158,22 +140,3 @@ smol_str_wrapper!(ProjectId);
smol_str_wrapper!(EndpointCacheKey);
smol_str_wrapper!(DbName);
// Endpoints are a bit tricky. Rarely, they might be branches or projects.
impl EndpointId {
pub fn is_endpoint(&self) -> bool {
self.0.starts_with("ep-")
}
pub fn is_branch(&self) -> bool {
self.0.starts_with("br-")
}
pub fn is_project(&self) -> bool {
!self.is_endpoint() && !self.is_branch()
}
pub fn as_branch(&self) -> BranchId {
BranchId(self.0.clone())
}
pub fn as_project(&self) -> ProjectId {
ProjectId(self.0.clone())
}
}
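For reference, the deleted Normalize helper only stripped a trailing "-pooler" from endpoint-like strings. The same behavior as a minimal self-contained function, without the generic trait machinery (endpoint ids are made up):

const POOLER_SUFFIX: &str = "-pooler";

fn normalize(s: &str) -> &str {
    s.strip_suffix(POOLER_SUFFIX).unwrap_or(s)
}

fn main() {
    assert_eq!(normalize("ep-calm-sea-123456-pooler"), "ep-calm-sea-123456");
    assert_eq!(normalize("ep-calm-sea-123456"), "ep-calm-sea-123456");
}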

View File

@@ -169,18 +169,6 @@ pub static NUM_CANCELLATION_REQUESTS: Lazy<IntCounterVec> = Lazy::new(|| {
.unwrap()
});
pub static NUM_INVALID_ENDPOINTS: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_invalid_endpoints_total",
"Number of invalid endpoints (per protocol, per rejected).",
// http/ws/tcp, true/false, success/failure
// TODO(anna): the last dimension is just a proxy to what we actually want to measure.
// We need to measure whether the endpoint was found by cplane or not.
&["protocol", "rejected", "outcome"],
)
.unwrap()
});
pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_CLIENT: &str = "from_client";
pub const NUM_CANCELLATION_REQUESTS_SOURCE_FROM_REDIS: &str = "from_redis";

View File

@@ -20,7 +20,7 @@ use crate::{
proxy::handshake::{handshake, HandshakeData},
rate_limiter::EndpointRateLimiter,
stream::{PqStream, Stream},
EndpointCacheKey, Normalize,
EndpointCacheKey,
};
use futures::TryFutureExt;
use itertools::Itertools;
@@ -280,7 +280,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
// check rate limit
if let Some(ep) = user_info.get_endpoint() {
if !endpoint_rate_limiter.check(ep.normalize(), 1) {
if !endpoint_rate_limiter.check(ep, 1) {
return stream
.throw_error(auth::AuthError::too_many_connections())
.await?;

View File

@@ -4,4 +4,4 @@ mod limiter;
pub use aimd::Aimd;
pub use limit_algorithm::{AimdConfig, Fixed, RateLimitAlgorithm, RateLimiterConfig};
pub use limiter::Limiter;
pub use limiter::{AuthRateLimiter, EndpointRateLimiter, GlobalRateLimiter, RateBucketInfo};
pub use limiter::{AuthRateLimiter, EndpointRateLimiter, RateBucketInfo, RedisRateLimiter};

View File

@@ -24,13 +24,13 @@ use super::{
RateLimiterConfig,
};
pub struct GlobalRateLimiter {
pub struct RedisRateLimiter {
data: Vec<RateBucket>,
info: Vec<RateBucketInfo>,
info: &'static [RateBucketInfo],
}
impl GlobalRateLimiter {
pub fn new(info: Vec<RateBucketInfo>) -> Self {
impl RedisRateLimiter {
pub fn new(info: &'static [RateBucketInfo]) -> Self {
Self {
data: vec![
RateBucket {
@@ -50,7 +50,7 @@ impl GlobalRateLimiter {
let should_allow_request = self
.data
.iter_mut()
.zip(&self.info)
.zip(self.info)
.all(|(bucket, info)| bucket.should_allow_request(info, now, 1));
if should_allow_request {

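The check above admits a request only when every (RateBucket, RateBucketInfo) pair has room in its window. A simplified freestanding sketch of that all-windows-must-pass rule, using check-then-commit so a full window never charges the others (field names assumed, burst handling omitted):

use std::time::{Duration, Instant};

struct Bucket {
    window: Duration,
    max: u32,
    start: Instant,
    count: u32,
}

impl Bucket {
    // Reset the counter once the fixed window has rolled over.
    fn roll(&mut self, now: Instant) {
        if now.duration_since(self.start) >= self.window {
            self.start = now;
            self.count = 0;
        }
    }
    fn has_room(&self) -> bool {
        self.count < self.max
    }
}

fn check(buckets: &mut [Bucket], now: Instant) -> bool {
    buckets.iter_mut().for_each(|b| b.roll(now));
    let allowed = buckets.iter().all(Bucket::has_room);
    if allowed {
        // Charge every window only after all of them agreed.
        buckets.iter_mut().for_each(|b| b.count += 1);
    }
    allowed
}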
View File

@@ -5,7 +5,7 @@ use redis::AsyncCommands;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo};
use crate::rate_limiter::{RateBucketInfo, RedisRateLimiter};
use super::{
connection_with_credentials_provider::ConnectionWithCredentialsProvider,
@@ -80,7 +80,7 @@ impl<P: CancellationPublisherMut> CancellationPublisher for Arc<Mutex<P>> {
pub struct RedisPublisherClient {
client: ConnectionWithCredentialsProvider,
region_id: String,
limiter: GlobalRateLimiter,
limiter: RedisRateLimiter,
}
impl RedisPublisherClient {
@@ -92,7 +92,7 @@ impl RedisPublisherClient {
Ok(Self {
client,
region_id,
limiter: GlobalRateLimiter::new(info.into()),
limiter: RedisRateLimiter::new(info),
})
}

View File

@@ -0,0 +1,9 @@
from fixtures.neon_fixtures import NeonEnv
def test_protocol_version(neon_simple_env: NeonEnv):
env = neon_simple_env
endpoint = env.endpoints.create_start("main", config_lines=["neon.protocol_version=1"])
cur = endpoint.connect().cursor()
cur.execute("show neon.protocol_version")
assert cur.fetchone() == ("1",)

View File

@@ -0,0 +1,84 @@
import asyncio
import time
from pathlib import Path
from typing import Iterator
import pytest
from fixtures.neon_fixtures import (
PSQL,
NeonProxy,
)
from fixtures.port_distributor import PortDistributor
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.response import Response
def waiting_handler(status_code: int) -> Response:
# Sleep longer than the timeout to make sure that both (two) connections are open at once.
# It would be better to use a barrier here, but I don't know how to do that together with pytest-httpserver.
time.sleep(2)
return Response(status=status_code)
@pytest.fixture(scope="function")
def proxy_with_rate_limit(
port_distributor: PortDistributor,
neon_binpath: Path,
httpserver_listen_address,
test_output_dir: Path,
) -> Iterator[NeonProxy]:
"""Neon proxy that routes directly to vanilla postgres."""
proxy_port = port_distributor.get_port()
mgmt_port = port_distributor.get_port()
http_port = port_distributor.get_port()
external_http_port = port_distributor.get_port()
(host, port) = httpserver_listen_address
endpoint = f"http://{host}:{port}/billing/api/v1/usage_events"
with NeonProxy(
neon_binpath=neon_binpath,
test_output_dir=test_output_dir,
proxy_port=proxy_port,
http_port=http_port,
mgmt_port=mgmt_port,
external_http_port=external_http_port,
auth_backend=NeonProxy.Console(endpoint, fixed_rate_limit=5),
) as proxy:
proxy.start()
yield proxy
@pytest.mark.asyncio
async def test_proxy_rate_limit(
httpserver: HTTPServer,
proxy_with_rate_limit: NeonProxy,
):
uri = "/billing/api/v1/usage_events/proxy_get_role_secret"
# mock control plane service
httpserver.expect_ordered_request(uri, method="GET").respond_with_handler(
lambda _: Response(status=200)
)
httpserver.expect_ordered_request(uri, method="GET").respond_with_handler(
lambda _: waiting_handler(429)
)
httpserver.expect_ordered_request(uri, method="GET").respond_with_handler(
lambda _: waiting_handler(500)
)
psql = PSQL(host=proxy_with_rate_limit.host, port=proxy_with_rate_limit.proxy_port)
f = await psql.run("select 42;")
await proxy_with_rate_limit.find_auth_link(uri, f)
# Limit should be 2.
# Run two queries in parallel.
f1, f2 = await asyncio.gather(psql.run("select 42;"), psql.run("select 42;"))
await proxy_with_rate_limit.find_auth_link(uri, f1)
await proxy_with_rate_limit.find_auth_link(uri, f2)
# Now limit should be 0.
f = await psql.run("select 42;")
await proxy_with_rate_limit.find_auth_link(uri, f)
# The last query shouldn't reach the http-server.
assert httpserver.assertions == []

View File

@@ -1,5 +1,5 @@
{
"postgres-v16": "3946b2e2ea71d07af092099cb5bcae76a69b90d6",
"postgres-v15": "64b8c7bccc6b77e04795e2d4cf6ad82dc8d987ed",
"postgres-v14": "a7b4c66156bce00afa60e5592d4284ba9e40b4cf"
"postgres-v16": "261497dd63ace434045058b1453bcbaaa83f23e5",
"postgres-v15": "85d809c124a898847a97d66a211f7d5ef4f8e0cb",
"postgres-v14": "d9149dc59abcbeeb26293707509aef51752db28f"
}