Compare commits

..

9 Commits

Author SHA1 Message Date
github-actions[bot]
abb7d6c2d8 Proxy release 2025-06-03 06:01 UTC 2025-06-03 06:01:25 +00:00
Alex Chi Z.
a650f7f5af fix(pageserver): only deserialize reldir key once during get_db_size (#12102)
## Problem

fix https://github.com/neondatabase/neon/issues/12101; this is a quick
hack and we need better API in the future.

In `get_db_size`, we call `get_reldir_size` for every relation. However,
we do the same deserializing the reldir directory thing for every
relation. This creates huge CPU overhead.

## Summary of changes

Get and deserialize the reldir v1 key once and use it across all
get_rel_size requests.

---------

Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-06-03 05:00:34 +00:00
Erik Grinaker
fc3994eb71 pageserver: initial gRPC page service implementation (#12094)
## Problem

We should expose the page service over gRPC.

Requires #12093.
Touches #11728.

## Summary of changes

This patch adds an initial page service implementation over gRPC. It
ties in with the existing `PageServerHandler` request logic, to avoid
the implementations drifting apart for the core read path.

This is just a bare-bones functional implementation. Several important
aspects have been omitted, and will be addressed in follow-up PRs:

* Limited observability: minimal tracing, no logging, limited metrics
and timing, etc.
* Rate limiting will currently block.
* No performance optimization.
* No cancellation handling.
* No tests.

I've only done rudimentary testing of this, but Pagebench passes at
least.
2025-06-02 17:15:18 +00:00
Conrad Ludgate
781bf4945d proxy: optimise future layout allocations (#12104)
A smaller version of #12066 that is somewhat easier to review.

Now that I've been using https://crates.io/crates/top-type-sizes I've
found a lot more of the low hanging fruit that can be tweaks to reduce
the memory usage.

Some context for the optimisations:

Rust's stack allocation in futures is quite naive. Stack variables, even
if moved, often still end up taking space in the future. Rearranging the
order in which variables are defined, and properly scoping them can go a
long way.

`async fn` and `async move {}` have a consequence that they always
duplicate the "upvars" (aka captures). All captures are permanently
allocated in the future, even if moved. We can be mindful when writing
futures to only capture as little as possible.

TlsStream is massive. Needs boxing so it doesn't contribute to the above
issue.

## Measurements from `top-type-sizes`:

### Before

```
10328 {async block@proxy::proxy::task_main::{closure#0}::{closure#0}} align=8
6120 {async fn body of proxy::proxy::handle_client<proxy::protocol2::ChainRW<tokio::net::TcpStream>>()} align=8
```

### After

```
4040 {async block@proxy::proxy::task_main::{closure#0}::{closure#0}}
4704 {async fn body of proxy::proxy::handle_client<proxy::protocol2::ChainRW<tokio::net::TcpStream>>()} align=8
```
2025-06-02 16:13:30 +00:00
Erik Grinaker
a21c1174ed pagebench: add gRPC support for get-page-latest-lsn (#12077)
## Problem

We need gRPC support in Pagebench to benchmark the new gRPC Pageserver
implementation.

Touches #11728.

## Summary of changes

Adds a `Client` trait to make the client transport swappable, and a gRPC
client via a `--protocol grpc` parameter. This must also specify the
connstring with the gRPC port:

```
pagebench get-page-latest-lsn --protocol grpc --page-service-connstring grpc://localhost:51051
```

The client is implemented using the raw Tonic-generated gRPC client, to
minimize client overhead.
2025-06-02 14:50:49 +00:00
Erik Grinaker
8d7ed2a4ee pageserver: add gRPC observability middleware (#12093)
## Problem

The page service logic asserts that a tracing span is present with
tenant/timeline/shard IDs. An initial gRPC page service implementation
thus requires a tracing span.

Touches https://github.com/neondatabase/neon/issues/11728.

## Summary of changes

Adds an `ObservabilityLayer` middleware that generates a tracing span
and decorates it with IDs from the gRPC metadata.

This is a minimal implementation to address the tracing span assertion.
It will be extended with additional observability in later PRs.
2025-06-02 11:46:50 +00:00
Vlad Lazar
5b62749c42 pageserver: reduce import memory utilization (#12086)
## Problem

Imports can end up allocating too much.

## Summary of Changes

Nerf them a bunch and add some logs.
2025-06-02 10:29:15 +00:00
Vlad Lazar
af5bb67f08 pageserver: more reactive wal receiver cancellation (#12076)
## Problem

If the wal receiver is cancelled, there's a 50% chance that it will
ingest yet more WAL.

## Summary of Changes

Always check cancellation first.
2025-06-02 08:59:21 +00:00
Conrad Ludgate
589bfdfd02 proxy: Changes to rate limits and GetEndpointAccessControl caches. (#12048)
Precursor to https://github.com/neondatabase/cloud/issues/28333.

We want per-endpoint configuration for rate limits, which will be
distributed via the `GetEndpointAccessControl` API. This lays some of
the ground work.

1. Allow the endpoint rate limiter to accept a custom leaky bucket
config on check.
2. Remove the unused auth rate limiter, as I don't want to think about
how it fits into this.
3. Refactor the caching of `GetEndpointAccessControl`, as it adds
friction for adding new cached data to the API.

That third one was rather large. I couldn't find any way to split it up.
The core idea is that there's now only 2 cache APIs.
`get_endpoint_access_controls` and `get_role_access_controls`.

I'm pretty sure the behaviour is unchanged, except I did a drive by
change to fix #8989 because it felt harmless. The change in question is
that when a password validation fails, we eagerly expire the role cache
if the role was cached for 5 minutes. This is to allow for edge cases
where a user tries to connect with a reset password, but the cache never
expires the entry due to some redis related quirk (lag, or
misconfiguration, or cplane error)
2025-06-02 08:38:35 +00:00
49 changed files with 1748 additions and 1660 deletions

6
Cargo.lock generated
View File

@@ -4236,6 +4236,7 @@ name = "pagebench"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"camino",
"clap",
"futures",
@@ -4244,12 +4245,15 @@ dependencies = [
"humantime-serde",
"pageserver_api",
"pageserver_client",
"pageserver_page_api",
"rand 0.8.5",
"reqwest",
"serde",
"serde_json",
"tokio",
"tokio-stream",
"tokio-util",
"tonic 0.13.1",
"tracing",
"utils",
"workspace_hack",
@@ -4305,6 +4309,7 @@ dependencies = [
"hashlink",
"hex",
"hex-literal",
"http 1.1.0",
"http-utils",
"humantime",
"humantime-serde",
@@ -4367,6 +4372,7 @@ dependencies = [
"toml_edit",
"tonic 0.13.1",
"tonic-reflection",
"tower 0.5.2",
"tracing",
"tracing-utils",
"twox-hash",

View File

@@ -107,7 +107,7 @@ impl<const N: usize> MetricType for HyperLogLogState<N> {
}
impl<const N: usize> HyperLogLogState<N> {
pub fn measure(&self, item: &impl Hash) {
pub fn measure(&self, item: &(impl Hash + ?Sized)) {
// changing the hasher will break compatibility with previous measurements.
self.record(BuildHasherDefault::<xxh3::Hash64>::default().hash_one(item));
}

View File

@@ -713,9 +713,9 @@ impl Default for ConfigToml {
enable_tls_page_service_api: false,
dev_mode: false,
timeline_import_config: TimelineImportConfig {
import_job_concurrency: NonZeroUsize::new(128).unwrap(),
import_job_soft_size_limit: NonZeroUsize::new(1024 * 1024 * 1024).unwrap(),
import_job_checkpoint_threshold: NonZeroUsize::new(128).unwrap(),
import_job_concurrency: NonZeroUsize::new(32).unwrap(),
import_job_soft_size_limit: NonZeroUsize::new(256 * 1024 * 1024).unwrap(),
import_job_checkpoint_threshold: NonZeroUsize::new(32).unwrap(),
},
basebackup_cache_config: None,
posthog_config: None,

View File

@@ -1934,7 +1934,7 @@ pub enum PagestreamFeMessage {
}
// Wrapped in libpq CopyData
#[derive(strum_macros::EnumProperty)]
#[derive(Debug, strum_macros::EnumProperty)]
pub enum PagestreamBeMessage {
Exists(PagestreamExistsResponse),
Nblocks(PagestreamNblocksResponse),
@@ -2045,7 +2045,7 @@ pub enum PagestreamProtocolVersion {
pub type RequestId = u64;
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct PagestreamRequest {
pub reqid: RequestId,
pub request_lsn: Lsn,
@@ -2064,7 +2064,7 @@ pub struct PagestreamNblocksRequest {
pub rel: RelTag,
}
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct PagestreamGetPageRequest {
pub hdr: PagestreamRequest,
pub rel: RelTag,

View File

@@ -24,7 +24,7 @@ use serde::{Deserialize, Serialize};
// FIXME: should move 'forknum' as last field to keep this consistent with Postgres.
// Then we could replace the custom Ord and PartialOrd implementations below with
// deriving them. This will require changes in walredoproc.c.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)]
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct RelTag {
pub forknum: u8,
pub spcnode: Oid,
@@ -184,12 +184,12 @@ pub enum SlruKind {
MultiXactOffsets,
}
impl SlruKind {
pub fn to_str(&self) -> &'static str {
impl fmt::Display for SlruKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Clog => "pg_xact",
Self::MultiXactMembers => "pg_multixact/members",
Self::MultiXactOffsets => "pg_multixact/offsets",
Self::Clog => write!(f, "pg_xact"),
Self::MultiXactMembers => write!(f, "pg_multixact/members"),
Self::MultiXactOffsets => write!(f, "pg_multixact/offsets"),
}
}
}

View File

@@ -28,6 +28,7 @@ use std::time::Duration;
use tokio::sync::Notify;
use tokio::time::Instant;
#[derive(Clone, Copy)]
pub struct LeakyBucketConfig {
/// This is the "time cost" of a single request unit.
/// Should loosely represent how long it takes to handle a request unit in active resource time.

View File

@@ -73,6 +73,7 @@ pub mod error;
/// async timeout helper
pub mod timeout;
pub mod span;
pub mod sync;
pub mod failpoint_support;

19
libs/utils/src/span.rs Normal file
View File

@@ -0,0 +1,19 @@
//! Tracing span helpers.
/// Records the given fields in the current span, as a single call. The fields must already have
/// been declared for the span (typically with empty values).
#[macro_export]
macro_rules! span_record {
($($tokens:tt)*) => {$crate::span_record_in!(::tracing::Span::current(), $($tokens)*)};
}
/// Records the given fields in the given span, as a single call. The fields must already have been
/// declared for the span (typically with empty values).
#[macro_export]
macro_rules! span_record_in {
($span:expr, $($tokens:tt)*) => {
if let Some(meta) = $span.metadata() {
$span.record_all(&tracing::valueset!(meta.fields(), $($tokens)*));
}
};
}

View File

@@ -34,6 +34,7 @@ fail.workspace = true
futures.workspace = true
hashlink.workspace = true
hex.workspace = true
http.workspace = true
http-utils.workspace = true
humantime-serde.workspace = true
humantime.workspace = true
@@ -93,6 +94,7 @@ tokio-util.workspace = true
toml_edit = { workspace = true, features = [ "serde" ] }
tonic.workspace = true
tonic-reflection.workspace = true
tower.workspace = true
tracing.workspace = true
tracing-utils.workspace = true
url.workspace = true

View File

@@ -10,6 +10,8 @@
//!
//! - Validate protocol invariants, via try_from() and try_into().
use std::fmt::Display;
use bytes::Bytes;
use postgres_ffi::Oid;
use smallvec::SmallVec;
@@ -48,7 +50,8 @@ pub struct ReadLsn {
pub request_lsn: Lsn,
/// If given, the caller guarantees that the page has not been modified since this LSN. Must be
/// smaller than or equal to request_lsn. This allows the Pageserver to serve an old page
/// without waiting for the request LSN to arrive. Valid for all request types.
/// without waiting for the request LSN to arrive. If not given, the request will read at the
/// request_lsn and wait for it to arrive if necessary. Valid for all request types.
///
/// It is undefined behaviour to make a request such that the page was, in fact, modified
/// between request_lsn and not_modified_since_lsn. The Pageserver might detect it and return an
@@ -58,6 +61,17 @@ pub struct ReadLsn {
pub not_modified_since_lsn: Option<Lsn>,
}
impl Display for ReadLsn {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let req_lsn = self.request_lsn;
if let Some(mod_lsn) = self.not_modified_since_lsn {
write!(f, "{req_lsn}>={mod_lsn}")
} else {
req_lsn.fmt(f)
}
}
}
impl ReadLsn {
/// Validates the ReadLsn.
pub fn validate(&self) -> Result<(), ProtocolError> {
@@ -584,6 +598,7 @@ impl TryFrom<GetSlruSegmentResponse> for proto::GetSlruSegmentResponse {
type Error = ProtocolError;
fn try_from(segment: GetSlruSegmentResponse) -> Result<Self, Self::Error> {
// TODO: can a segment legitimately be empty?
if segment.is_empty() {
return Err(ProtocolError::Missing("segment"));
}

View File

@@ -8,6 +8,7 @@ license.workspace = true
[dependencies]
anyhow.workspace = true
async-trait.workspace = true
camino.workspace = true
clap.workspace = true
futures.workspace = true
@@ -15,14 +16,17 @@ hdrhistogram.workspace = true
humantime.workspace = true
humantime-serde.workspace = true
rand.workspace = true
reqwest.workspace=true
reqwest.workspace = true
serde.workspace = true
serde_json.workspace = true
tracing.workspace = true
tokio.workspace = true
tokio-stream.workspace = true
tokio-util.workspace = true
tonic.workspace = true
pageserver_client.workspace = true
pageserver_api.workspace = true
pageserver_page_api.workspace = true
utils = { path = "../../libs/utils/" }
workspace_hack = { version = "0.1", path = "../../workspace_hack" }

View File

@@ -7,11 +7,15 @@ use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use anyhow::Context;
use async_trait::async_trait;
use camino::Utf8PathBuf;
use pageserver_api::key::Key;
use pageserver_api::keyspace::KeySpaceAccum;
use pageserver_api::models::{PagestreamGetPageRequest, PagestreamRequest};
use pageserver_api::models::{
PagestreamGetPageRequest, PagestreamGetPageResponse, PagestreamRequest,
};
use pageserver_api::shard::TenantShardId;
use pageserver_page_api::proto;
use rand::prelude::*;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
@@ -22,6 +26,12 @@ use utils::lsn::Lsn;
use crate::util::tokio_thread_local_stats::AllThreadLocalStats;
use crate::util::{request_stats, tokio_thread_local_stats};
#[derive(clap::ValueEnum, Clone, Debug)]
enum Protocol {
Libpq,
Grpc,
}
/// GetPage@LatestLSN, uniformly distributed across the compute-accessible keyspace.
#[derive(clap::Parser)]
pub(crate) struct Args {
@@ -35,6 +45,8 @@ pub(crate) struct Args {
num_clients: NonZeroUsize,
#[clap(long)]
runtime: Option<humantime::Duration>,
#[clap(long, value_enum, default_value = "libpq")]
protocol: Protocol,
/// Each client sends requests at the given rate.
///
/// If a request takes too long and we should be issuing a new request already,
@@ -303,7 +315,20 @@ async fn main_impl(
.unwrap();
Box::pin(async move {
client_libpq(args, worker_id, ss, cancel, rps_period, ranges, weights).await
let client: Box<dyn Client> = match args.protocol {
Protocol::Libpq => Box::new(
LibpqClient::new(args.page_service_connstring.clone(), worker_id.timeline)
.await
.unwrap(),
),
Protocol::Grpc => Box::new(
GrpcClient::new(args.page_service_connstring.clone(), worker_id.timeline)
.await
.unwrap(),
),
};
run_worker(args, client, ss, cancel, rps_period, ranges, weights).await
})
};
@@ -355,23 +380,15 @@ async fn main_impl(
anyhow::Ok(())
}
async fn client_libpq(
async fn run_worker(
args: &Args,
worker_id: WorkerId,
mut client: Box<dyn Client>,
shared_state: Arc<SharedState>,
cancel: CancellationToken,
rps_period: Option<Duration>,
ranges: Vec<KeyRange>,
weights: rand::distributions::weighted::WeightedIndex<i128>,
) {
let client = pageserver_client::page_service::Client::new(args.page_service_connstring.clone())
.await
.unwrap();
let mut client = client
.pagestream(worker_id.timeline.tenant_id, worker_id.timeline.timeline_id)
.await
.unwrap();
shared_state.start_work_barrier.wait().await;
let client_start = Instant::now();
let mut ticks_processed = 0;
@@ -415,12 +432,12 @@ async fn client_libpq(
blkno: block_no,
}
};
client.getpage_send(req).await.unwrap();
client.send_get_page(req).await.unwrap();
inflight.push_back(start);
}
let start = inflight.pop_front().unwrap();
client.getpage_recv().await.unwrap();
client.recv_get_page().await.unwrap();
let end = Instant::now();
shared_state.live_stats.request_done();
ticks_processed += 1;
@@ -442,3 +459,104 @@ async fn client_libpq(
}
}
}
/// A benchmark client, to allow switching out the transport protocol.
///
/// For simplicity, this just uses separate asynchronous send/recv methods. The send method could
/// return a future that resolves when the response is received, but we don't really need it.
#[async_trait]
trait Client: Send {
/// Sends an asynchronous GetPage request to the pageserver.
async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()>;
/// Receives the next GetPage response from the pageserver.
async fn recv_get_page(&mut self) -> anyhow::Result<PagestreamGetPageResponse>;
}
/// A libpq-based Pageserver client.
struct LibpqClient {
inner: pageserver_client::page_service::PagestreamClient,
}
impl LibpqClient {
async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result<Self> {
let inner = pageserver_client::page_service::Client::new(connstring)
.await?
.pagestream(ttid.tenant_id, ttid.timeline_id)
.await?;
Ok(Self { inner })
}
}
#[async_trait]
impl Client for LibpqClient {
async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()> {
self.inner.getpage_send(req).await
}
async fn recv_get_page(&mut self) -> anyhow::Result<PagestreamGetPageResponse> {
self.inner.getpage_recv().await
}
}
/// A gRPC client using the raw, no-frills gRPC client.
struct GrpcClient {
req_tx: tokio::sync::mpsc::Sender<proto::GetPageRequest>,
resp_rx: tonic::Streaming<proto::GetPageResponse>,
}
impl GrpcClient {
async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result<Self> {
let mut client = pageserver_page_api::proto::PageServiceClient::connect(connstring).await?;
// The channel has a buffer size of 1, since 0 is not allowed. It does not matter, since the
// benchmark will control the queue depth (i.e. in-flight requests) anyway, and requests are
// buffered by Tonic and the OS too.
let (req_tx, req_rx) = tokio::sync::mpsc::channel(1);
let req_stream = tokio_stream::wrappers::ReceiverStream::new(req_rx);
let mut req = tonic::Request::new(req_stream);
let metadata = req.metadata_mut();
metadata.insert("neon-tenant-id", ttid.tenant_id.to_string().try_into()?);
metadata.insert("neon-timeline-id", ttid.timeline_id.to_string().try_into()?);
metadata.insert("neon-shard-id", "0000".try_into()?);
let resp = client.get_pages(req).await?;
let resp_stream = resp.into_inner();
Ok(Self {
req_tx,
resp_rx: resp_stream,
})
}
}
#[async_trait]
impl Client for GrpcClient {
async fn send_get_page(&mut self, req: PagestreamGetPageRequest) -> anyhow::Result<()> {
let req = proto::GetPageRequest {
request_id: 0,
request_class: proto::GetPageClass::Normal as i32,
read_lsn: Some(proto::ReadLsn {
request_lsn: req.hdr.request_lsn.0,
not_modified_since_lsn: req.hdr.not_modified_since.0,
}),
rel: Some(req.rel.into()),
block_number: vec![req.blkno],
};
self.req_tx.send(req).await?;
Ok(())
}
async fn recv_get_page(&mut self) -> anyhow::Result<PagestreamGetPageResponse> {
let resp = self.resp_rx.message().await?.unwrap();
anyhow::ensure!(
resp.status_code == proto::GetPageStatusCode::Ok as i32,
"unexpected status code: {}",
resp.status_code
);
Ok(PagestreamGetPageResponse {
page: resp.page_image[0].clone(),
req: PagestreamGetPageRequest::default(), // dummy
})
}
}

View File

@@ -65,6 +65,30 @@ impl From<GetVectoredError> for BasebackupError {
}
}
impl From<BasebackupError> for postgres_backend::QueryError {
fn from(err: BasebackupError) -> Self {
use postgres_backend::QueryError;
use pq_proto::framed::ConnectionError;
match err {
BasebackupError::Client(err, _) => QueryError::Disconnected(ConnectionError::Io(err)),
BasebackupError::Server(err) => QueryError::Other(err),
BasebackupError::Shutdown => QueryError::Shutdown,
}
}
}
impl From<BasebackupError> for tonic::Status {
fn from(err: BasebackupError) -> Self {
use tonic::Code;
let code = match &err {
BasebackupError::Client(_, _) => Code::Cancelled,
BasebackupError::Server(_) => Code::Internal,
BasebackupError::Shutdown => Code::Unavailable,
};
tonic::Status::new(code, err.to_string())
}
}
/// Create basebackup with non-rel data in it.
/// Only include relational data if 'full_backup' is true.
///
@@ -248,7 +272,7 @@ where
async fn flush(&mut self) -> Result<(), BasebackupError> {
let nblocks = self.buf.len() / BLCKSZ as usize;
let (kind, segno) = self.current_segment.take().unwrap();
let segname = format!("{}/{:>04X}", kind.to_str(), segno);
let segname = format!("{kind}/{segno:>04X}");
let header = new_tar_header(&segname, self.buf.len() as u64)?;
self.ar
.append(&header, self.buf.as_slice())

View File

@@ -804,7 +804,7 @@ fn start_pageserver(
} else {
None
},
basebackup_cache.clone(),
basebackup_cache,
);
// Spawn a Pageserver gRPC server task. It will spawn separate tasks for
@@ -816,12 +816,10 @@ fn start_pageserver(
let mut page_service_grpc = None;
if let Some(grpc_listener) = grpc_listener {
page_service_grpc = Some(page_service::spawn_grpc(
conf,
tenant_manager.clone(),
grpc_auth,
otel_guard.as_ref().map(|g| g.dispatch.clone()),
grpc_listener,
basebackup_cache,
)?);
}

File diff suppressed because it is too large Load Diff

View File

@@ -471,8 +471,19 @@ impl Timeline {
let rels = self.list_rels(spcnode, dbnode, version, ctx).await?;
if rels.is_empty() {
return Ok(0);
}
// Pre-deserialize the rel directory to avoid duplicated work in `get_relsize_cached`.
let reldir_key = rel_dir_to_key(spcnode, dbnode);
let buf = version.get(self, reldir_key, ctx).await?;
let reldir = RelDirectory::des(&buf)?;
for rel in rels {
let n_blocks = self.get_rel_size(rel, version, ctx).await?;
let n_blocks = self
.get_rel_size_in_reldir(rel, version, Some((reldir_key, &reldir)), ctx)
.await?;
total_blocks += n_blocks as usize;
}
Ok(total_blocks)
@@ -487,6 +498,19 @@ impl Timeline {
tag: RelTag,
version: Version<'_>,
ctx: &RequestContext,
) -> Result<BlockNumber, PageReconstructError> {
self.get_rel_size_in_reldir(tag, version, None, ctx).await
}
/// Get size of a relation file. The relation must exist, otherwise an error is returned.
///
/// See [`Self::get_rel_exists_in_reldir`] on why we need `deserialized_reldir_v1`.
pub(crate) async fn get_rel_size_in_reldir(
&self,
tag: RelTag,
version: Version<'_>,
deserialized_reldir_v1: Option<(Key, &RelDirectory)>,
ctx: &RequestContext,
) -> Result<BlockNumber, PageReconstructError> {
if tag.relnode == 0 {
return Err(PageReconstructError::Other(
@@ -499,7 +523,9 @@ impl Timeline {
}
if (tag.forknum == FSM_FORKNUM || tag.forknum == VISIBILITYMAP_FORKNUM)
&& !self.get_rel_exists(tag, version, ctx).await?
&& !self
.get_rel_exists_in_reldir(tag, version, deserialized_reldir_v1, ctx)
.await?
{
// FIXME: Postgres sometimes calls smgrcreate() to create
// FSM, and smgrnblocks() on it immediately afterwards,
@@ -521,11 +547,28 @@ impl Timeline {
///
/// Only shard 0 has a full view of the relations. Other shards only know about relations that
/// the shard stores pages for.
///
pub(crate) async fn get_rel_exists(
&self,
tag: RelTag,
version: Version<'_>,
ctx: &RequestContext,
) -> Result<bool, PageReconstructError> {
self.get_rel_exists_in_reldir(tag, version, None, ctx).await
}
/// Does the relation exist? With a cached deserialized `RelDirectory`.
///
/// There are some cases where the caller loops across all relations. In that specific case,
/// the caller should obtain the deserialized `RelDirectory` first and then call this function
/// to avoid duplicated work of deserliazation. This is a hack and should be removed by introducing
/// a new API (e.g., `get_rel_exists_batched`).
pub(crate) async fn get_rel_exists_in_reldir(
&self,
tag: RelTag,
version: Version<'_>,
deserialized_reldir_v1: Option<(Key, &RelDirectory)>,
ctx: &RequestContext,
) -> Result<bool, PageReconstructError> {
if tag.relnode == 0 {
return Err(PageReconstructError::Other(
@@ -568,6 +611,17 @@ impl Timeline {
// fetch directory listing (old)
let key = rel_dir_to_key(tag.spcnode, tag.dbnode);
if let Some((cached_key, dir)) = deserialized_reldir_v1 {
if cached_key == key {
return Ok(dir.rels.contains(&(tag.relnode, tag.forknum)));
} else if cfg!(test) || cfg!(feature = "testing") {
panic!("cached reldir key mismatch: {cached_key} != {key}");
} else {
warn!("cached reldir key mismatch: {cached_key} != {key}");
}
// Fallback to reading the directory from the datadir.
}
let buf = version.get(self, key, ctx).await?;
let dir = RelDirectory::des(&buf)?;

View File

@@ -950,6 +950,18 @@ pub(crate) enum WaitLsnError {
Timeout(String),
}
impl From<WaitLsnError> for tonic::Status {
fn from(err: WaitLsnError) -> Self {
use tonic::Code;
let code = match &err {
WaitLsnError::Timeout(_) => Code::Internal,
WaitLsnError::BadState(_) => Code::Internal,
WaitLsnError::Shutdown => Code::Unavailable,
};
tonic::Status::new(code, err.to_string())
}
}
// The impls below achieve cancellation mapping for errors.
// Perhaps there's a way of achieving this with less cruft.

View File

@@ -106,6 +106,8 @@ pub async fn doit(
);
}
tracing::info!("Import plan executed. Flushing remote changes and notifying storcon");
timeline
.remote_client
.schedule_index_upload_for_file_changes()?;

View File

@@ -130,7 +130,15 @@ async fn run_v1(
pausable_failpoint!("import-timeline-pre-execute-pausable");
let jobs_count = import_progress.as_ref().map(|p| p.jobs);
let start_from_job_idx = import_progress.map(|progress| progress.completed);
tracing::info!(
start_from_job_idx=?start_from_job_idx,
jobs=?jobs_count,
"Executing import plan"
);
plan.execute(timeline, start_from_job_idx, plan_hash, &import_config, ctx)
.await
}
@@ -484,6 +492,8 @@ impl Plan {
anyhow::anyhow!("Shut down while putting timeline import status")
})?;
}
tracing::info!(last_completed_job_idx, jobs=%jobs_in_plan, "Checkpointing import status");
},
Some(Err(_)) => {
anyhow::bail!(
@@ -760,7 +770,7 @@ impl ImportTask for ImportRelBlocksTask {
layer_writer: &mut ImageLayerWriter,
ctx: &RequestContext,
) -> anyhow::Result<usize> {
const MAX_BYTE_RANGE_SIZE: usize = 128 * 1024 * 1024;
const MAX_BYTE_RANGE_SIZE: usize = 4 * 1024 * 1024;
debug!("Importing relation file");

View File

@@ -113,7 +113,7 @@ impl WalReceiver {
}
connection_manager_state.shutdown().await;
*loop_status.write().unwrap() = None;
debug!("task exits");
info!("task exits");
}
.instrument(info_span!(parent: None, "wal_connection_manager", tenant_id = %tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug(), timeline_id = %timeline_id))
});

View File

@@ -297,6 +297,7 @@ pub(super) async fn handle_walreceiver_connection(
let mut expected_wal_start = startpoint;
while let Some(replication_message) = {
select! {
biased;
_ = cancellation.cancelled() => {
debug!("walreceiver interrupted");
None

View File

@@ -25,19 +25,15 @@ pub(super) async fn authenticate(
}
AuthSecret::Scram(secret) => {
debug!("auth endpoint chooses SCRAM");
let scram = auth::Scram(&secret, ctx);
let auth_outcome = tokio::time::timeout(config.scram_protocol_timeout, async {
AuthFlow::new(client, scram)
.authenticate()
.await
.inspect_err(|error| {
warn!(?error, "error processing scram messages");
})
})
let auth_outcome = tokio::time::timeout(
config.scram_protocol_timeout,
AuthFlow::new(client, auth::Scram(&secret, ctx)).authenticate(),
)
.await
.inspect_err(|_| warn!("error processing scram messages error = authentication timed out, execution time exceeded {} seconds", config.scram_protocol_timeout.as_secs()))
.map_err(auth::AuthError::user_timeout)??;
.map_err(auth::AuthError::user_timeout)?
.inspect_err(|error| warn!(?error, "error processing scram messages"))?;
let client_key = match auth_outcome {
sasl::Outcome::Success(key) => key,

View File

@@ -4,38 +4,31 @@ mod hacks;
pub mod jwt;
pub mod local;
use std::net::IpAddr;
use std::sync::Arc;
pub use console_redirect::ConsoleRedirectBackend;
pub(crate) use console_redirect::ConsoleRedirectError;
use ipnet::{Ipv4Net, Ipv6Net};
use local::LocalBackend;
use postgres_client::config::AuthKeys;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info, warn};
use tracing::{debug, info};
use crate::auth::credentials::check_peer_addr_is_in_list;
use crate::auth::{
self, AuthError, ComputeUserInfoMaybeEndpoint, IpPattern, validate_password_and_exchange,
};
use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
use crate::cache::Cached;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ControlPlaneClient;
use crate::control_plane::errors::GetAuthInfoError;
use crate::control_plane::{
self, AccessBlockerFlags, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps,
CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, ControlPlaneApi,
self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl,
RoleAccessControl,
};
use crate::intern::EndpointIdInt;
use crate::metrics::Metrics;
use crate::pqproto::BeMessage;
use crate::protocol2::ConnectionInfoExtra;
use crate::proxy::NeonOptions;
use crate::proxy::connect_compute::ComputeConnectBackend;
use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter};
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::Stream;
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
use crate::{scram, stream};
@@ -201,78 +194,6 @@ impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
}
}
#[derive(PartialEq, PartialOrd, Hash, Eq, Ord, Debug, Copy, Clone)]
pub struct MaskedIp(IpAddr);
impl MaskedIp {
fn new(value: IpAddr, prefix: u8) -> Self {
match value {
IpAddr::V4(v4) => Self(IpAddr::V4(
Ipv4Net::new(v4, prefix).map_or(v4, |x| x.trunc().addr()),
)),
IpAddr::V6(v6) => Self(IpAddr::V6(
Ipv6Net::new(v6, prefix).map_or(v6, |x| x.trunc().addr()),
)),
}
}
}
// This can't be just per IP because that would limit some PaaS that share IP addresses
pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>;
impl AuthenticationConfig {
pub(crate) fn check_rate_limit(
&self,
ctx: &RequestContext,
secret: AuthSecret,
endpoint: &EndpointId,
is_cleartext: bool,
) -> auth::Result<AuthSecret> {
// we have validated the endpoint exists, so let's intern it.
let endpoint_int = EndpointIdInt::from(endpoint.normalize());
// only count the full hash count if password hack or websocket flow.
// in other words, if proxy needs to run the hashing
let password_weight = if is_cleartext {
match &secret {
#[cfg(any(test, feature = "testing"))]
AuthSecret::Md5(_) => 1,
AuthSecret::Scram(s) => s.iterations + 1,
}
} else {
// validating scram takes just 1 hmac_sha_256 operation.
1
};
let limit_not_exceeded = self.rate_limiter.check(
(
endpoint_int,
MaskedIp::new(ctx.peer_addr(), self.rate_limit_ip_subnet),
),
password_weight,
);
if !limit_not_exceeded {
warn!(
enabled = self.rate_limiter_enabled,
"rate limiting authentication"
);
Metrics::get().proxy.requests_auth_rate_limits_total.inc();
Metrics::get()
.proxy
.endpoints_auth_rate_limits
.get_metric()
.measure(endpoint);
if self.rate_limiter_enabled {
return Err(auth::AuthError::too_many_connections());
}
}
Ok(secret)
}
}
/// True to its name, this function encapsulates our current auth trade-offs.
/// Here, we choose the appropriate auth flow based on circumstances.
///
@@ -285,7 +206,7 @@ async fn auth_quirks(
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<(ComputeCredentials, Option<Vec<IpPattern>>)> {
) -> auth::Result<ComputeCredentials> {
// If there's no project so far, that entails that client doesn't
// support SNI or other means of passing the endpoint (project) name.
// We now expect to see a very specific payload in the place of password.
@@ -301,55 +222,27 @@ async fn auth_quirks(
debug!("fetching authentication info and allowlists");
// check allowed list
let allowed_ips = if config.ip_allowlist_check_enabled {
let allowed_ips = api.get_allowed_ips(ctx, &info).await?;
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
allowed_ips
} else {
Cached::new_uncached(Arc::new(vec![]))
};
let access_controls = api
.get_endpoint_access_control(ctx, &info.endpoint, &info.user)
.await?;
// check if a VPC endpoint ID is coming in and if yes, if it's allowed
let access_blocks = api.get_block_public_or_vpc_access(ctx, &info).await?;
if config.is_vpc_acccess_proxy {
if access_blocks.vpc_access_blocked {
return Err(AuthError::NetworkNotAllowed);
}
access_controls.check(
ctx,
config.ip_allowlist_check_enabled,
config.is_vpc_acccess_proxy,
)?;
let incoming_vpc_endpoint_id = match ctx.extra() {
None => return Err(AuthError::MissingEndpointName),
Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
};
let allowed_vpc_endpoint_ids = api.get_allowed_vpc_endpoint_ids(ctx, &info).await?;
// TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
if !allowed_vpc_endpoint_ids.is_empty()
&& !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id)
{
return Err(AuthError::vpc_endpoint_id_not_allowed(
incoming_vpc_endpoint_id,
));
}
} else if access_blocks.public_access_blocked {
return Err(AuthError::NetworkNotAllowed);
}
if !endpoint_rate_limiter.check(info.endpoint.clone().into(), 1) {
let endpoint = EndpointIdInt::from(&info.endpoint);
let rate_limit_config = None;
if !endpoint_rate_limiter.check(endpoint, rate_limit_config, 1) {
return Err(AuthError::too_many_connections());
}
let cached_secret = api.get_role_secret(ctx, &info).await?;
let (cached_entry, secret) = cached_secret.take_value();
let role_access = api
.get_role_access_control(ctx, &info.endpoint, &info.user)
.await?;
let secret = if let Some(secret) = secret {
config.check_rate_limit(
ctx,
secret,
&info.endpoint,
unauthenticated_password.is_some() || allow_cleartext,
)?
let secret = if let Some(secret) = role_access.secret {
secret
} else {
// If we don't have an authentication secret, we mock one to
// prevent malicious probing (possible due to missing protocol steps).
@@ -369,14 +262,8 @@ async fn auth_quirks(
)
.await
{
Ok(keys) => Ok((keys, Some(allowed_ips.as_ref().clone()))),
Err(e) => {
if e.is_password_failed() {
// The password could have been changed, so we invalidate the cache.
cached_entry.invalidate();
}
Err(e)
}
Ok(keys) => Ok(keys),
Err(e) => Err(e),
}
}
@@ -439,7 +326,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
allow_cleartext: bool,
config: &'static AuthenticationConfig,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> auth::Result<(Backend<'a, ComputeCredentials>, Option<Vec<IpPattern>>)> {
) -> auth::Result<Backend<'a, ComputeCredentials>> {
let res = match self {
Self::ControlPlane(api, user_info) => {
debug!(
@@ -448,17 +335,35 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
"performing authentication using the console"
);
let (credentials, ip_allowlist) = auth_quirks(
let auth_res = auth_quirks(
ctx,
&*api,
user_info,
user_info.clone(),
client,
allow_cleartext,
config,
endpoint_rate_limiter,
)
.await?;
Ok((Backend::ControlPlane(api, credentials), ip_allowlist))
.await;
match auth_res {
Ok(credentials) => Ok(Backend::ControlPlane(api, credentials)),
Err(e) => {
// The password could have been changed, so we invalidate the cache.
// We should only invalidate the cache if the TTL might have expired.
if e.is_password_failed() {
#[allow(irrefutable_let_patterns)]
if let ControlPlaneClient::ProxyV1(api) = &*api {
if let Some(ep) = &user_info.endpoint_id {
api.caches
.project_info
.maybe_invalidate_role_secret(ep, &user_info.user);
}
}
}
Err(e)
}
}
}
Self::Local(_) => {
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"));
@@ -475,44 +380,30 @@ impl Backend<'_, ComputeUserInfo> {
pub(crate) async fn get_role_secret(
&self,
ctx: &RequestContext,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
match self {
Self::ControlPlane(api, user_info) => api.get_role_secret(ctx, user_info).await,
Self::Local(_) => Ok(Cached::new_uncached(None)),
}
}
pub(crate) async fn get_allowed_ips(
&self,
ctx: &RequestContext,
) -> Result<CachedAllowedIps, GetAuthInfoError> {
match self {
Self::ControlPlane(api, user_info) => api.get_allowed_ips(ctx, user_info).await,
Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
}
}
pub(crate) async fn get_allowed_vpc_endpoint_ids(
&self,
ctx: &RequestContext,
) -> Result<CachedAllowedVpcEndpointIds, GetAuthInfoError> {
) -> Result<RoleAccessControl, GetAuthInfoError> {
match self {
Self::ControlPlane(api, user_info) => {
api.get_allowed_vpc_endpoint_ids(ctx, user_info).await
api.get_role_access_control(ctx, &user_info.endpoint, &user_info.user)
.await
}
Self::Local(_) => Ok(Cached::new_uncached(Arc::new(vec![]))),
Self::Local(_) => Ok(RoleAccessControl { secret: None }),
}
}
pub(crate) async fn get_block_public_or_vpc_access(
pub(crate) async fn get_endpoint_access_control(
&self,
ctx: &RequestContext,
) -> Result<CachedAccessBlockerFlags, GetAuthInfoError> {
) -> Result<EndpointAccessControl, GetAuthInfoError> {
match self {
Self::ControlPlane(api, user_info) => {
api.get_block_public_or_vpc_access(ctx, user_info).await
api.get_endpoint_access_control(ctx, &user_info.endpoint, &user_info.user)
.await
}
Self::Local(_) => Ok(Cached::new_uncached(AccessBlockerFlags::default())),
Self::Local(_) => Ok(EndpointAccessControl {
allowed_ips: Arc::new(vec![]),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
}),
}
}
}
@@ -541,9 +432,7 @@ impl ComputeConnectBackend for Backend<'_, ComputeCredentials> {
mod tests {
#![allow(clippy::unimplemented, clippy::unwrap_used)]
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use bytes::BytesMut;
use control_plane::AuthSecret;
@@ -554,18 +443,16 @@ mod tests {
use postgres_protocol::message::frontend;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use super::auth_quirks;
use super::jwt::JwkCache;
use super::{AuthRateLimiter, auth_quirks};
use crate::auth::backend::MaskedIp;
use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::{
self, AccessBlockerFlags, CachedAccessBlockerFlags, CachedAllowedIps,
CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret,
self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl,
};
use crate::proxy::NeonOptions;
use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo};
use crate::rate_limiter::EndpointRateLimiter;
use crate::scram::ServerSecret;
use crate::scram::threadpool::ThreadPool;
use crate::stream::{PqStream, Stream};
@@ -578,46 +465,34 @@ mod tests {
}
impl control_plane::ControlPlaneApi for Auth {
async fn get_role_secret(
async fn get_role_access_control(
&self,
_ctx: &RequestContext,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedRoleSecret, control_plane::errors::GetAuthInfoError> {
Ok(CachedRoleSecret::new_uncached(Some(self.secret.clone())))
_endpoint: &crate::types::EndpointId,
_role: &crate::types::RoleName,
) -> Result<RoleAccessControl, control_plane::errors::GetAuthInfoError> {
Ok(RoleAccessControl {
secret: Some(self.secret.clone()),
})
}
async fn get_allowed_ips(
async fn get_endpoint_access_control(
&self,
_ctx: &RequestContext,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedAllowedIps, control_plane::errors::GetAuthInfoError> {
Ok(CachedAllowedIps::new_uncached(Arc::new(self.ips.clone())))
}
async fn get_allowed_vpc_endpoint_ids(
&self,
_ctx: &RequestContext,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedAllowedVpcEndpointIds, control_plane::errors::GetAuthInfoError> {
Ok(CachedAllowedVpcEndpointIds::new_uncached(Arc::new(
self.vpc_endpoint_ids.clone(),
)))
}
async fn get_block_public_or_vpc_access(
&self,
_ctx: &RequestContext,
_user_info: &super::ComputeUserInfo,
) -> Result<CachedAccessBlockerFlags, control_plane::errors::GetAuthInfoError> {
Ok(CachedAccessBlockerFlags::new_uncached(
self.access_blocker_flags.clone(),
))
_endpoint: &crate::types::EndpointId,
_role: &crate::types::RoleName,
) -> Result<EndpointAccessControl, control_plane::errors::GetAuthInfoError> {
Ok(EndpointAccessControl {
allowed_ips: Arc::new(self.ips.clone()),
allowed_vpce: Arc::new(self.vpc_endpoint_ids.clone()),
flags: self.access_blocker_flags,
})
}
async fn get_endpoint_jwks(
&self,
_ctx: &RequestContext,
_endpoint: crate::types::EndpointId,
_endpoint: &crate::types::EndpointId,
) -> Result<Vec<super::jwt::AuthRule>, control_plane::errors::GetEndpointJwksError>
{
unimplemented!()
@@ -636,9 +511,6 @@ mod tests {
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(1),
scram_protocol_timeout: std::time::Duration::from_secs(5),
rate_limiter_enabled: true,
rate_limiter: AuthRateLimiter::new(&RateBucketInfo::DEFAULT_AUTH_SET),
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
is_vpc_acccess_proxy: false,
is_auth_broker: false,
@@ -655,51 +527,6 @@ mod tests {
}
}
#[test]
fn masked_ip() {
let ip_a = IpAddr::V4([127, 0, 0, 1].into());
let ip_b = IpAddr::V4([127, 0, 0, 2].into());
let ip_c = IpAddr::V4([192, 168, 1, 101].into());
let ip_d = IpAddr::V4([192, 168, 1, 102].into());
let ip_e = IpAddr::V6("abcd:abcd:abcd:abcd:abcd:abcd:abcd:abcd".parse().unwrap());
let ip_f = IpAddr::V6("abcd:abcd:abcd:abcd:1234:abcd:abcd:abcd".parse().unwrap());
assert_ne!(MaskedIp::new(ip_a, 64), MaskedIp::new(ip_b, 64));
assert_ne!(MaskedIp::new(ip_a, 32), MaskedIp::new(ip_b, 32));
assert_eq!(MaskedIp::new(ip_a, 30), MaskedIp::new(ip_b, 30));
assert_eq!(MaskedIp::new(ip_c, 30), MaskedIp::new(ip_d, 30));
assert_ne!(MaskedIp::new(ip_e, 128), MaskedIp::new(ip_f, 128));
assert_eq!(MaskedIp::new(ip_e, 64), MaskedIp::new(ip_f, 64));
}
#[test]
fn test_default_auth_rate_limit_set() {
// these values used to exceed u32::MAX
assert_eq!(
RateBucketInfo::DEFAULT_AUTH_SET,
[
RateBucketInfo {
interval: Duration::from_secs(1),
max_rpi: 1000 * 4096,
},
RateBucketInfo {
interval: Duration::from_secs(60),
max_rpi: 600 * 4096 * 60,
},
RateBucketInfo {
interval: Duration::from_secs(600),
max_rpi: 300 * 4096 * 600,
}
]
);
for x in RateBucketInfo::DEFAULT_AUTH_SET {
let y = x.to_string().parse().unwrap();
assert_eq!(x, y);
}
}
#[tokio::test]
async fn auth_quirks_scram() {
let (mut client, server) = tokio::io::duplex(1024);
@@ -888,7 +715,7 @@ mod tests {
.await
.unwrap();
assert_eq!(creds.0.info.endpoint, "my-endpoint");
assert_eq!(creds.info.endpoint, "my-endpoint");
handle.await.unwrap();
}

View File

@@ -32,9 +32,7 @@ use crate::ext::TaskExt;
use crate::http::health_server::AppMetrics;
use crate::intern::RoleNameInt;
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::rate_limiter::{
BucketRateLimiter, EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo,
};
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::cancel_set::CancelSet;
use crate::serverless::{self, GlobalConnPoolOptions};
@@ -69,15 +67,6 @@ struct LocalProxyCliArgs {
/// Can be given multiple times for different bucket sizes.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_ENDPOINT_SET)]
user_rps_limit: Vec<RateBucketInfo>,
/// Whether the auth rate limiter actually takes effect (for testing)
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
auth_rate_limit_enabled: bool,
/// Authentication rate limiter max number of hashes per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
auth_rate_limit: Vec<RateBucketInfo>,
/// The IP subnet to use when considering whether two IP addresses are considered the same.
#[clap(long, default_value_t = 64)]
auth_rate_limit_ip_subnet: u8,
/// Whether to retry the connection to the compute node
#[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)]
connect_to_compute_retry: String,
@@ -282,9 +271,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(0),
scram_protocol_timeout: Duration::from_secs(10),
rate_limiter_enabled: false,
rate_limiter: BucketRateLimiter::new(vec![]),
rate_limit_ip_subnet: 64,
ip_allowlist_check_enabled: true,
is_vpc_acccess_proxy: false,
is_auth_broker: false,

View File

@@ -20,7 +20,7 @@ use utils::sentry_init::init_sentry;
use utils::{project_build_tag, project_git_version};
use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::{AuthRateLimiter, ConsoleRedirectBackend, MaybeOwned};
use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned};
use crate::cancellation::{CancellationHandler, handle_cancel_messages};
use crate::config::{
self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions,
@@ -29,9 +29,7 @@ use crate::config::{
use crate::context::parquet::ParquetUploadArgs;
use crate::http::health_server::AppMetrics;
use crate::metrics::Metrics;
use crate::rate_limiter::{
EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo, WakeComputeRateLimiter,
};
use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::redis::kv_ops::RedisKVClient;
use crate::redis::{elasticache, notifications};
@@ -154,15 +152,6 @@ struct ProxyCliArgs {
/// Wake compute rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)]
wake_compute_limit: Vec<RateBucketInfo>,
/// Whether the auth rate limiter actually takes effect (for testing)
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
auth_rate_limit_enabled: bool,
/// Authentication rate limiter max number of hashes per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_AUTH_SET)]
auth_rate_limit: Vec<RateBucketInfo>,
/// The IP subnet to use when considering whether two IP addresses are considered the same.
#[clap(long, default_value_t = 64)]
auth_rate_limit_ip_subnet: u8,
/// Redis rate limiter max number of requests per second.
#[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)]
redis_rps_limit: Vec<RateBucketInfo>,
@@ -410,22 +399,9 @@ pub async fn run() -> anyhow::Result<()> {
Some(tx_cancel),
));
// bit of a hack - find the min rps and max rps supported and turn it into
// leaky bucket config instead
let max = args
.endpoint_rps_limit
.iter()
.map(|x| x.rps())
.max_by(f64::total_cmp)
.unwrap_or(EndpointRateLimiter::DEFAULT.max);
let rps = args
.endpoint_rps_limit
.iter()
.map(|x| x.rps())
.min_by(f64::total_cmp)
.unwrap_or(EndpointRateLimiter::DEFAULT.rps);
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
LeakyBucketConfig { rps, max },
RateBucketInfo::to_leaky_bucket(&args.endpoint_rps_limit)
.unwrap_or(EndpointRateLimiter::DEFAULT),
64,
));
@@ -678,9 +654,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
jwks_cache: JwkCache::default(),
thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout,
rate_limiter_enabled: args.auth_rate_limit_enabled,
rate_limiter: AuthRateLimiter::new(args.auth_rate_limit.clone()),
rate_limit_ip_subnet: args.auth_rate_limit_ip_subnet,
ip_allowlist_check_enabled: !args.is_private_access_proxy,
is_vpc_acccess_proxy: args.is_private_access_proxy,
is_auth_broker: args.is_auth_broker,

View File

@@ -1,30 +1,25 @@
use std::collections::HashSet;
use std::collections::{HashMap, HashSet, hash_map};
use std::convert::Infallible;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::time::Duration;
use async_trait::async_trait;
use clashmap::ClashMap;
use clashmap::mapref::one::Ref;
use rand::{Rng, thread_rng};
use smol_str::SmolStr;
use tokio::sync::Mutex;
use tokio::time::Instant;
use tracing::{debug, info};
use super::{Cache, Cached};
use crate::auth::IpPattern;
use crate::config::ProjectInfoCacheOptions;
use crate::control_plane::{AccessBlockerFlags, AuthSecret};
use crate::control_plane::{EndpointAccessControl, RoleAccessControl};
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::types::{EndpointId, RoleName};
#[async_trait]
pub(crate) trait ProjectInfoCache {
fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt);
fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec<ProjectIdInt>);
fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt);
fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt);
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt);
fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt);
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
async fn decrement_active_listeners(&self);
async fn increment_active_listeners(&self);
@@ -42,6 +37,10 @@ impl<T> Entry<T> {
value,
}
}
pub(crate) fn get(&self, valid_since: Instant) -> Option<&T> {
(valid_since < self.created_at).then_some(&self.value)
}
}
impl<T> From<T> for Entry<T> {
@@ -50,101 +49,32 @@ impl<T> From<T> for Entry<T> {
}
}
#[derive(Default)]
struct EndpointInfo {
secret: std::collections::HashMap<RoleNameInt, Entry<Option<AuthSecret>>>,
allowed_ips: Option<Entry<Arc<Vec<IpPattern>>>>,
block_public_or_vpc_access: Option<Entry<AccessBlockerFlags>>,
allowed_vpc_endpoint_ids: Option<Entry<Arc<Vec<String>>>>,
role_controls: HashMap<RoleNameInt, Entry<RoleAccessControl>>,
controls: Option<Entry<EndpointAccessControl>>,
}
impl EndpointInfo {
fn check_ignore_cache(ignore_cache_since: Option<Instant>, created_at: Instant) -> bool {
match ignore_cache_since {
None => false,
Some(t) => t < created_at,
}
}
pub(crate) fn get_role_secret(
&self,
role_name: RoleNameInt,
valid_since: Instant,
ignore_cache_since: Option<Instant>,
) -> Option<(Option<AuthSecret>, bool)> {
if let Some(secret) = self.secret.get(&role_name) {
if valid_since < secret.created_at {
return Some((
secret.value.clone(),
Self::check_ignore_cache(ignore_cache_since, secret.created_at),
));
}
}
None
) -> Option<RoleAccessControl> {
let controls = self.role_controls.get(&role_name)?;
controls.get(valid_since).cloned()
}
pub(crate) fn get_allowed_ips(
&self,
valid_since: Instant,
ignore_cache_since: Option<Instant>,
) -> Option<(Arc<Vec<IpPattern>>, bool)> {
if let Some(allowed_ips) = &self.allowed_ips {
if valid_since < allowed_ips.created_at {
return Some((
allowed_ips.value.clone(),
Self::check_ignore_cache(ignore_cache_since, allowed_ips.created_at),
));
}
}
None
}
pub(crate) fn get_allowed_vpc_endpoint_ids(
&self,
valid_since: Instant,
ignore_cache_since: Option<Instant>,
) -> Option<(Arc<Vec<String>>, bool)> {
if let Some(allowed_vpc_endpoint_ids) = &self.allowed_vpc_endpoint_ids {
if valid_since < allowed_vpc_endpoint_ids.created_at {
return Some((
allowed_vpc_endpoint_ids.value.clone(),
Self::check_ignore_cache(
ignore_cache_since,
allowed_vpc_endpoint_ids.created_at,
),
));
}
}
None
}
pub(crate) fn get_block_public_or_vpc_access(
&self,
valid_since: Instant,
ignore_cache_since: Option<Instant>,
) -> Option<(AccessBlockerFlags, bool)> {
if let Some(block_public_or_vpc_access) = &self.block_public_or_vpc_access {
if valid_since < block_public_or_vpc_access.created_at {
return Some((
block_public_or_vpc_access.value.clone(),
Self::check_ignore_cache(
ignore_cache_since,
block_public_or_vpc_access.created_at,
),
));
}
}
None
pub(crate) fn get_controls(&self, valid_since: Instant) -> Option<EndpointAccessControl> {
let controls = self.controls.as_ref()?;
controls.get(valid_since).cloned()
}
pub(crate) fn invalidate_allowed_ips(&mut self) {
self.allowed_ips = None;
}
pub(crate) fn invalidate_allowed_vpc_endpoint_ids(&mut self) {
self.allowed_vpc_endpoint_ids = None;
}
pub(crate) fn invalidate_block_public_or_vpc_access(&mut self) {
self.block_public_or_vpc_access = None;
pub(crate) fn invalidate_endpoint(&mut self) {
self.controls = None;
}
pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) {
self.secret.remove(&role_name);
self.role_controls.remove(&role_name);
}
}
@@ -170,34 +100,22 @@ pub struct ProjectInfoCacheImpl {
#[async_trait]
impl ProjectInfoCache for ProjectInfoCacheImpl {
fn invalidate_allowed_vpc_endpoint_ids_for_projects(&self, project_ids: Vec<ProjectIdInt>) {
info!(
"invalidating allowed vpc endpoint ids for projects `{}`",
project_ids
.iter()
.map(|id| id.to_string())
.collect::<Vec<_>>()
.join(", ")
);
for project_id in project_ids {
let endpoints = self
.project2ep
.get(&project_id)
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_allowed_vpc_endpoint_ids();
}
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
info!("invalidating endpoint access for project `{project_id}`");
let endpoints = self
.project2ep
.get(&project_id)
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_endpoint();
}
}
}
fn invalidate_allowed_vpc_endpoint_ids_for_org(&self, account_id: AccountIdInt) {
info!(
"invalidating allowed vpc endpoint ids for org `{}`",
account_id
);
fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) {
info!("invalidating endpoint access for org `{account_id}`");
let endpoints = self
.account2ep
.get(&account_id)
@@ -205,41 +123,11 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_allowed_vpc_endpoint_ids();
endpoint_info.invalidate_endpoint();
}
}
}
fn invalidate_block_public_or_vpc_access_for_project(&self, project_id: ProjectIdInt) {
info!(
"invalidating block public or vpc access for project `{}`",
project_id
);
let endpoints = self
.project2ep
.get(&project_id)
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_block_public_or_vpc_access();
}
}
}
fn invalidate_allowed_ips_for_project(&self, project_id: ProjectIdInt) {
info!("invalidating allowed ips for project `{}`", project_id);
let endpoints = self
.project2ep
.get(&project_id)
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_allowed_ips();
}
}
}
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) {
info!(
"invalidating role secret for project_id `{}` and role_name `{}`",
@@ -256,6 +144,7 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
}
}
}
async fn decrement_active_listeners(&self) {
let mut listeners_guard = self.active_listeners_lock.lock().await;
if *listeners_guard == 0 {
@@ -293,155 +182,71 @@ impl ProjectInfoCacheImpl {
}
}
fn get_endpoint_cache(
&self,
endpoint_id: &EndpointId,
) -> Option<Ref<'_, EndpointIdInt, EndpointInfo>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
self.cache.get(&endpoint_id)
}
pub(crate) fn get_role_secret(
&self,
endpoint_id: &EndpointId,
role_name: &RoleName,
) -> Option<Cached<&Self, Option<AuthSecret>>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
) -> Option<RoleAccessControl> {
let valid_since = self.get_cache_times();
let role_name = RoleNameInt::get(role_name)?;
let (valid_since, ignore_cache_since) = self.get_cache_times();
let endpoint_info = self.cache.get(&endpoint_id)?;
let (value, ignore_cache) =
endpoint_info.get_role_secret(role_name, valid_since, ignore_cache_since)?;
if !ignore_cache {
let cached = Cached {
token: Some((
self,
CachedLookupInfo::new_role_secret(endpoint_id, role_name),
)),
value,
};
return Some(cached);
}
Some(Cached::new_uncached(value))
}
pub(crate) fn get_allowed_ips(
&self,
endpoint_id: &EndpointId,
) -> Option<Cached<&Self, Arc<Vec<IpPattern>>>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
let (valid_since, ignore_cache_since) = self.get_cache_times();
let endpoint_info = self.cache.get(&endpoint_id)?;
let value = endpoint_info.get_allowed_ips(valid_since, ignore_cache_since);
let (value, ignore_cache) = value?;
if !ignore_cache {
let cached = Cached {
token: Some((self, CachedLookupInfo::new_allowed_ips(endpoint_id))),
value,
};
return Some(cached);
}
Some(Cached::new_uncached(value))
}
pub(crate) fn get_allowed_vpc_endpoint_ids(
&self,
endpoint_id: &EndpointId,
) -> Option<Cached<&Self, Arc<Vec<String>>>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
let (valid_since, ignore_cache_since) = self.get_cache_times();
let endpoint_info = self.cache.get(&endpoint_id)?;
let value = endpoint_info.get_allowed_vpc_endpoint_ids(valid_since, ignore_cache_since);
let (value, ignore_cache) = value?;
if !ignore_cache {
let cached = Cached {
token: Some((
self,
CachedLookupInfo::new_allowed_vpc_endpoint_ids(endpoint_id),
)),
value,
};
return Some(cached);
}
Some(Cached::new_uncached(value))
}
pub(crate) fn get_block_public_or_vpc_access(
&self,
endpoint_id: &EndpointId,
) -> Option<Cached<&Self, AccessBlockerFlags>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
let (valid_since, ignore_cache_since) = self.get_cache_times();
let endpoint_info = self.cache.get(&endpoint_id)?;
let value = endpoint_info.get_block_public_or_vpc_access(valid_since, ignore_cache_since);
let (value, ignore_cache) = value?;
if !ignore_cache {
let cached = Cached {
token: Some((
self,
CachedLookupInfo::new_block_public_or_vpc_access(endpoint_id),
)),
value,
};
return Some(cached);
}
Some(Cached::new_uncached(value))
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
endpoint_info.get_role_secret(role_name, valid_since)
}
pub(crate) fn insert_role_secret(
pub(crate) fn get_endpoint_access(
&self,
project_id: ProjectIdInt,
endpoint_id: EndpointIdInt,
role_name: RoleNameInt,
secret: Option<AuthSecret>,
) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
self.insert_project2endpoint(project_id, endpoint_id);
let mut entry = self.cache.entry(endpoint_id).or_default();
if entry.secret.len() < self.config.max_roles {
entry.secret.insert(role_name, secret.into());
}
endpoint_id: &EndpointId,
) -> Option<EndpointAccessControl> {
let valid_since = self.get_cache_times();
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
endpoint_info.get_controls(valid_since)
}
pub(crate) fn insert_allowed_ips(
&self,
project_id: ProjectIdInt,
endpoint_id: EndpointIdInt,
allowed_ips: Arc<Vec<IpPattern>>,
) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
self.insert_project2endpoint(project_id, endpoint_id);
self.cache.entry(endpoint_id).or_default().allowed_ips = Some(allowed_ips.into());
}
pub(crate) fn insert_allowed_vpc_endpoint_ids(
pub(crate) fn insert_endpoint_access(
&self,
account_id: Option<AccountIdInt>,
project_id: ProjectIdInt,
endpoint_id: EndpointIdInt,
allowed_vpc_endpoint_ids: Arc<Vec<String>>,
role_name: RoleNameInt,
controls: EndpointAccessControl,
role_controls: RoleAccessControl,
) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
if let Some(account_id) = account_id {
self.insert_account2endpoint(account_id, endpoint_id);
}
self.insert_project2endpoint(project_id, endpoint_id);
self.cache
.entry(endpoint_id)
.or_default()
.allowed_vpc_endpoint_ids = Some(allowed_vpc_endpoint_ids.into());
}
pub(crate) fn insert_block_public_or_vpc_access(
&self,
project_id: ProjectIdInt,
endpoint_id: EndpointIdInt,
access_blockers: AccessBlockerFlags,
) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
self.insert_project2endpoint(project_id, endpoint_id);
self.cache
.entry(endpoint_id)
.or_default()
.block_public_or_vpc_access = Some(access_blockers.into());
let controls = Entry::from(controls);
let role_controls = Entry::from(role_controls);
match self.cache.entry(endpoint_id) {
clashmap::Entry::Vacant(e) => {
e.insert(EndpointInfo {
role_controls: HashMap::from_iter([(role_name, role_controls)]),
controls: Some(controls),
});
}
clashmap::Entry::Occupied(mut e) => {
let ep = e.get_mut();
ep.controls = Some(controls);
if ep.role_controls.len() < self.config.max_roles {
ep.role_controls.insert(role_name, role_controls);
}
}
}
}
fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
@@ -452,6 +257,7 @@ impl ProjectInfoCacheImpl {
.insert(project_id, HashSet::from([endpoint_id]));
}
}
fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) {
if let Some(mut endpoints) = self.account2ep.get_mut(&account_id) {
endpoints.insert(endpoint_id);
@@ -460,21 +266,57 @@ impl ProjectInfoCacheImpl {
.insert(account_id, HashSet::from([endpoint_id]));
}
}
fn get_cache_times(&self) -> (Instant, Option<Instant>) {
let mut valid_since = Instant::now() - self.config.ttl;
// Only ignore cache if ttl is disabled.
fn ignore_ttl_since(&self) -> Option<Instant> {
let ttl_disabled_since_us = self
.ttl_disabled_since_us
.load(std::sync::atomic::Ordering::Relaxed);
let ignore_cache_since = if ttl_disabled_since_us == u64::MAX {
None
} else {
let ignore_cache_since = self.start_time + Duration::from_micros(ttl_disabled_since_us);
if ttl_disabled_since_us == u64::MAX {
return None;
}
Some(self.start_time + Duration::from_micros(ttl_disabled_since_us))
}
fn get_cache_times(&self) -> Instant {
let mut valid_since = Instant::now() - self.config.ttl;
if let Some(ignore_ttl_since) = self.ignore_ttl_since() {
// We are fine if entry is not older than ttl or was added before we are getting notifications.
valid_since = valid_since.min(ignore_cache_since);
Some(ignore_cache_since)
valid_since = valid_since.min(ignore_ttl_since);
}
valid_since
}
pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) {
let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else {
return;
};
(valid_since, ignore_cache_since)
let Some(role_name) = RoleNameInt::get(role_name) else {
return;
};
let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) else {
return;
};
let entry = endpoint_info.role_controls.entry(role_name);
let hash_map::Entry::Occupied(role_controls) = entry else {
return;
};
let created_at = role_controls.get().created_at;
let expire = match self.ignore_ttl_since() {
// if ignoring TTL, we should still try and roll the password if it's old
// and we the client gave an incorrect password. There could be some lag on the redis channel.
Some(_) => created_at + self.config.ttl < Instant::now(),
// edge case: redis is down, let's be generous and invalidate the cache immediately.
None => true,
};
if expire {
role_controls.remove();
}
}
pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
@@ -509,84 +351,12 @@ impl ProjectInfoCacheImpl {
}
}
/// Lookup info for project info cache.
/// This is used to invalidate cache entries.
pub(crate) struct CachedLookupInfo {
/// Search by this key.
endpoint_id: EndpointIdInt,
lookup_type: LookupType,
}
impl CachedLookupInfo {
pub(self) fn new_role_secret(endpoint_id: EndpointIdInt, role_name: RoleNameInt) -> Self {
Self {
endpoint_id,
lookup_type: LookupType::RoleSecret(role_name),
}
}
pub(self) fn new_allowed_ips(endpoint_id: EndpointIdInt) -> Self {
Self {
endpoint_id,
lookup_type: LookupType::AllowedIps,
}
}
pub(self) fn new_allowed_vpc_endpoint_ids(endpoint_id: EndpointIdInt) -> Self {
Self {
endpoint_id,
lookup_type: LookupType::AllowedVpcEndpointIds,
}
}
pub(self) fn new_block_public_or_vpc_access(endpoint_id: EndpointIdInt) -> Self {
Self {
endpoint_id,
lookup_type: LookupType::BlockPublicOrVpcAccess,
}
}
}
enum LookupType {
RoleSecret(RoleNameInt),
AllowedIps,
AllowedVpcEndpointIds,
BlockPublicOrVpcAccess,
}
impl Cache for ProjectInfoCacheImpl {
type Key = SmolStr;
// Value is not really used here, but we need to specify it.
type Value = SmolStr;
type LookupInfo<Key> = CachedLookupInfo;
fn invalidate(&self, key: &Self::LookupInfo<SmolStr>) {
match &key.lookup_type {
LookupType::RoleSecret(role_name) => {
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
endpoint_info.invalidate_role_secret(*role_name);
}
}
LookupType::AllowedIps => {
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
endpoint_info.invalidate_allowed_ips();
}
}
LookupType::AllowedVpcEndpointIds => {
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
endpoint_info.invalidate_allowed_vpc_endpoint_ids();
}
}
LookupType::BlockPublicOrVpcAccess => {
if let Some(mut endpoint_info) = self.cache.get_mut(&key.endpoint_id) {
endpoint_info.invalidate_block_public_or_vpc_access();
}
}
}
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::*;
use crate::control_plane::{AccessBlockerFlags, AuthSecret};
use crate::scram::ServerSecret;
use crate::types::ProjectId;
@@ -601,6 +371,8 @@ mod tests {
});
let project_id: ProjectId = "project".into();
let endpoint_id: EndpointId = "endpoint".into();
let account_id: Option<AccountIdInt> = None;
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
@@ -609,183 +381,73 @@ mod tests {
"127.0.0.1".parse().unwrap(),
"127.0.0.2".parse().unwrap(),
]);
cache.insert_role_secret(
cache.insert_endpoint_access(
account_id,
(&project_id).into(),
(&endpoint_id).into(),
(&user1).into(),
secret1.clone(),
EndpointAccessControl {
allowed_ips: allowed_ips.clone(),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
},
RoleAccessControl {
secret: secret1.clone(),
},
);
cache.insert_role_secret(
cache.insert_endpoint_access(
account_id,
(&project_id).into(),
(&endpoint_id).into(),
(&user2).into(),
secret2.clone(),
);
cache.insert_allowed_ips(
(&project_id).into(),
(&endpoint_id).into(),
allowed_ips.clone(),
EndpointAccessControl {
allowed_ips: allowed_ips.clone(),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
},
RoleAccessControl {
secret: secret2.clone(),
},
);
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
assert!(cached.cached());
assert_eq!(cached.value, secret1);
assert_eq!(cached.secret, secret1);
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
assert!(cached.cached());
assert_eq!(cached.value, secret2);
assert_eq!(cached.secret, secret2);
// Shouldn't add more than 2 roles.
let user3: RoleName = "user3".into();
let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32])));
cache.insert_role_secret(
cache.insert_endpoint_access(
account_id,
(&project_id).into(),
(&endpoint_id).into(),
(&user3).into(),
secret3.clone(),
EndpointAccessControl {
allowed_ips: allowed_ips.clone(),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
},
RoleAccessControl {
secret: secret3.clone(),
},
);
assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
assert!(cached.cached());
assert_eq!(cached.value, allowed_ips);
let cached = cache.get_endpoint_access(&endpoint_id).unwrap();
assert_eq!(cached.allowed_ips, allowed_ips);
tokio::time::advance(Duration::from_secs(2)).await;
let cached = cache.get_role_secret(&endpoint_id, &user1);
assert!(cached.is_none());
let cached = cache.get_role_secret(&endpoint_id, &user2);
assert!(cached.is_none());
let cached = cache.get_allowed_ips(&endpoint_id);
let cached = cache.get_endpoint_access(&endpoint_id);
assert!(cached.is_none());
}
#[tokio::test]
async fn test_project_info_cache_invalidations() {
tokio::time::pause();
let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
size: 2,
max_roles: 2,
ttl: Duration::from_secs(1),
gc_interval: Duration::from_secs(600),
}));
cache.clone().increment_active_listeners().await;
tokio::time::advance(Duration::from_secs(2)).await;
let project_id: ProjectId = "project".into();
let endpoint_id: EndpointId = "endpoint".into();
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32])));
let allowed_ips = Arc::new(vec![
"127.0.0.1".parse().unwrap(),
"127.0.0.2".parse().unwrap(),
]);
cache.insert_role_secret(
(&project_id).into(),
(&endpoint_id).into(),
(&user1).into(),
secret1.clone(),
);
cache.insert_role_secret(
(&project_id).into(),
(&endpoint_id).into(),
(&user2).into(),
secret2.clone(),
);
cache.insert_allowed_ips(
(&project_id).into(),
(&endpoint_id).into(),
allowed_ips.clone(),
);
tokio::time::advance(Duration::from_secs(2)).await;
// Nothing should be invalidated.
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
// TTL is disabled, so it should be impossible to invalidate this value.
assert!(!cached.cached());
assert_eq!(cached.value, secret1);
cached.invalidate(); // Shouldn't do anything.
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
assert_eq!(cached.value, secret1);
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
assert!(!cached.cached());
assert_eq!(cached.value, secret2);
// The only way to invalidate this value is to invalidate via the api.
cache.invalidate_role_secret_for_project((&project_id).into(), (&user2).into());
assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
assert!(!cached.cached());
assert_eq!(cached.value, allowed_ips);
}
#[tokio::test]
async fn test_increment_active_listeners_invalidate_added_before() {
tokio::time::pause();
let cache = Arc::new(ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
size: 2,
max_roles: 2,
ttl: Duration::from_secs(1),
gc_interval: Duration::from_secs(600),
}));
let project_id: ProjectId = "project".into();
let endpoint_id: EndpointId = "endpoint".into();
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
let secret1 = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
let secret2 = Some(AuthSecret::Scram(ServerSecret::mock([2; 32])));
let allowed_ips = Arc::new(vec![
"127.0.0.1".parse().unwrap(),
"127.0.0.2".parse().unwrap(),
]);
cache.insert_role_secret(
(&project_id).into(),
(&endpoint_id).into(),
(&user1).into(),
secret1.clone(),
);
cache.clone().increment_active_listeners().await;
tokio::time::advance(Duration::from_millis(100)).await;
cache.insert_role_secret(
(&project_id).into(),
(&endpoint_id).into(),
(&user2).into(),
secret2.clone(),
);
// Added before ttl was disabled + ttl should be still cached.
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
assert!(cached.cached());
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
assert!(cached.cached());
tokio::time::advance(Duration::from_secs(1)).await;
// Added before ttl was disabled + ttl should expire.
assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
// Added after ttl was disabled + ttl should not be cached.
cache.insert_allowed_ips(
(&project_id).into(),
(&endpoint_id).into(),
allowed_ips.clone(),
);
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
assert!(!cached.cached());
tokio::time::advance(Duration::from_secs(1)).await;
// Added before ttl was disabled + ttl still should expire.
assert!(cache.get_role_secret(&endpoint_id, &user1).is_none());
assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
// Shouldn't be invalidated.
let cached = cache.get_allowed_ips(&endpoint_id).unwrap();
assert!(!cached.cached());
assert_eq!(cached.value, allowed_ips);
}
}

View File

@@ -12,8 +12,8 @@ use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot};
use tracing::{debug, error, info, warn};
use crate::auth::AuthError;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::{AuthError, check_peer_addr_is_in_list};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::ControlPlaneApi;
@@ -21,7 +21,6 @@ use crate::error::ReportableError;
use crate::ext::LockExt;
use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind};
use crate::pqproto::CancelKeyData;
use crate::protocol2::ConnectionInfoExtra;
use crate::rate_limiter::LeakyBucketRateLimiter;
use crate::redis::keys::KeyPrefix;
use crate::redis::kv_ops::RedisKVClient;
@@ -272,13 +271,7 @@ pub(crate) enum CancelError {
#[error("rate limit exceeded")]
RateLimit,
#[error("IP is not allowed")]
IpNotAllowed,
#[error("VPC endpoint id is not allowed to connect")]
VpcEndpointIdNotAllowed,
#[error("Authentication backend error")]
#[error("Authentication error")]
AuthError(#[from] AuthError),
#[error("key not found")]
@@ -297,10 +290,7 @@ impl ReportableError for CancelError {
}
CancelError::Postgres(_) => crate::error::ErrorKind::Compute,
CancelError::RateLimit => crate::error::ErrorKind::RateLimit,
CancelError::IpNotAllowed
| CancelError::VpcEndpointIdNotAllowed
| CancelError::NotFound => crate::error::ErrorKind::User,
CancelError::AuthError(_) => crate::error::ErrorKind::ControlPlane,
CancelError::NotFound | CancelError::AuthError(_) => crate::error::ErrorKind::User,
CancelError::InternalError => crate::error::ErrorKind::Service,
}
}
@@ -422,7 +412,13 @@ impl CancellationHandler {
IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here
IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
};
if !self.limiter.lock_propagate_poison().check(subnet_key, 1) {
let allowed = {
let rate_limit_config = None;
let limiter = self.limiter.lock_propagate_poison();
limiter.check(subnet_key, rate_limit_config, 1)
};
if !allowed {
// log only the subnet part of the IP address to know which subnet is rate limited
tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
Metrics::get()
@@ -450,52 +446,13 @@ impl CancellationHandler {
return Err(CancelError::NotFound);
};
if check_ip_allowed {
let ip_allowlist = auth_backend
.get_allowed_ips(&ctx, &cancel_closure.user_info)
.await
.map_err(|e| CancelError::AuthError(e.into()))?;
if !check_peer_addr_is_in_list(&ctx.peer_addr(), &ip_allowlist) {
// log it here since cancel_session could be spawned in a task
tracing::warn!(
"IP is not allowed to cancel the query: {key}, address: {}",
ctx.peer_addr()
);
return Err(CancelError::IpNotAllowed);
}
}
// check if a VPC endpoint ID is coming in and if yes, if it's allowed
let access_blocks = auth_backend
.get_block_public_or_vpc_access(&ctx, &cancel_closure.user_info)
let info = &cancel_closure.user_info;
let access_controls = auth_backend
.get_endpoint_access_control(&ctx, &info.endpoint, &info.user)
.await
.map_err(|e| CancelError::AuthError(e.into()))?;
if check_vpc_allowed {
if access_blocks.vpc_access_blocked {
return Err(CancelError::AuthError(AuthError::NetworkNotAllowed));
}
let incoming_vpc_endpoint_id = match ctx.extra() {
None => return Err(CancelError::AuthError(AuthError::MissingVPCEndpointId)),
Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
};
let allowed_vpc_endpoint_ids = auth_backend
.get_allowed_vpc_endpoint_ids(&ctx, &cancel_closure.user_info)
.await
.map_err(|e| CancelError::AuthError(e.into()))?;
// TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
if !allowed_vpc_endpoint_ids.is_empty()
&& !allowed_vpc_endpoint_ids.contains(&incoming_vpc_endpoint_id)
{
return Err(CancelError::VpcEndpointIdNotAllowed);
}
} else if access_blocks.public_access_blocked {
return Err(CancelError::VpcEndpointIdNotAllowed);
}
access_controls.check(&ctx, check_ip_allowed, check_vpc_allowed)?;
Metrics::get()
.proxy

View File

@@ -7,7 +7,6 @@ use arc_swap::ArcSwapOption;
use clap::ValueEnum;
use remote_storage::RemoteStorageConfig;
use crate::auth::backend::AuthRateLimiter;
use crate::auth::backend::jwt::JwkCache;
use crate::control_plane::locks::ApiLocks;
use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig};
@@ -65,9 +64,6 @@ pub struct HttpConfig {
pub struct AuthenticationConfig {
pub thread_pool: Arc<ThreadPool>,
pub scram_protocol_timeout: tokio::time::Duration,
pub rate_limiter_enabled: bool,
pub rate_limiter: AuthRateLimiter,
pub rate_limit_ip_subnet: u8,
pub ip_allowlist_check_enabled: bool,
pub is_vpc_acccess_proxy: bool,
pub jwks_cache: JwkCache,

View File

@@ -159,7 +159,7 @@ pub async fn task_main(
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
config: &'static ProxyConfig,
backend: &'static ConsoleRedirectBackend,
ctx: &RequestContext,

View File

@@ -370,6 +370,18 @@ impl RequestContext {
}
}
pub(crate) fn latency_timer_pause_at(
&self,
at: tokio::time::Instant,
waiting_for: Waiting,
) -> LatencyTimerPause<'_> {
LatencyTimerPause {
ctx: self,
start: at,
waiting_for,
}
}
pub(crate) fn get_proxy_latency(&self) -> LatencyAccumulated {
self.0
.try_lock()

View File

@@ -7,7 +7,9 @@ use std::time::Duration;
use ::http::HeaderName;
use ::http::header::AUTHORIZATION;
use bytes::Bytes;
use futures::TryFutureExt;
use hyper::StatusCode;
use postgres_client::config::SslMode;
use tokio::time::Instant;
use tracing::{Instrument, debug, info, info_span, warn};
@@ -15,7 +17,6 @@ use tracing::{Instrument, debug, info, info_span, warn};
use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::cache::Cached;
use crate::context::RequestContext;
use crate::control_plane::caches::ApiCaches;
use crate::control_plane::errors::{
@@ -24,12 +25,12 @@ use crate::control_plane::errors::{
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason};
use crate::control_plane::{
AccessBlockerFlags, AuthInfo, AuthSecret, CachedAccessBlockerFlags, CachedAllowedIps,
CachedAllowedVpcEndpointIds, CachedNodeInfo, CachedRoleSecret, NodeInfo,
AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo,
RoleAccessControl,
};
use crate::metrics::{CacheOutcome, Metrics};
use crate::metrics::Metrics;
use crate::rate_limiter::WakeComputeRateLimiter;
use crate::types::{EndpointCacheKey, EndpointId};
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
use crate::{compute, http, scram};
pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
@@ -66,66 +67,41 @@ impl NeonControlPlaneClient {
self.endpoint.url().as_str()
}
async fn do_get_auth_info(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<AuthInfo, GetAuthInfoError> {
if !self
.caches
.endpoints_cache
.is_valid(ctx, &user_info.endpoint.normalize())
{
// TODO: refactor this because it's weird
// this is a failure to authenticate but we return Ok.
info!("endpoint is not valid, skipping the request");
return Ok(AuthInfo::default());
}
self.do_get_auth_req(user_info, &ctx.session_id(), Some(ctx))
.await
}
async fn do_get_auth_req(
&self,
user_info: &ComputeUserInfo,
session_id: &uuid::Uuid,
ctx: Option<&RequestContext>,
ctx: &RequestContext,
endpoint: &EndpointId,
role: &RoleName,
) -> Result<AuthInfo, GetAuthInfoError> {
let request_id: String = session_id.to_string();
let application_name = if let Some(ctx) = ctx {
ctx.console_application_name()
} else {
"auth_cancellation".to_string()
};
async {
let request = self
.endpoint
.get_path("get_endpoint_access_control")
.header(X_REQUEST_ID, &request_id)
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.query(&[("session_id", session_id)])
.query(&[
("application_name", application_name.as_str()),
("endpointish", user_info.endpoint.as_str()),
("role", user_info.user.as_str()),
])
.build()?;
let response = {
let request = self
.endpoint
.get_path("get_endpoint_access_control")
.header(X_REQUEST_ID, ctx.session_id().to_string())
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
.query(&[
("application_name", ctx.console_application_name().as_str()),
("endpointish", endpoint.as_str()),
("role", role.as_str()),
])
.build()?;
debug!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
let response = match ctx {
Some(ctx) => {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
let rsp = self.endpoint.execute(request).await;
drop(pause);
rsp?
}
None => self.endpoint.execute(request).await?,
debug!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
let _pause = ctx.latency_timer_pause_at(start, crate::metrics::Waiting::Cplane);
let response = self.endpoint.execute(request).await?;
info!(duration = ?start.elapsed(), "received http response");
response
};
info!(duration = ?start.elapsed(), "received http response");
let body = match parse_body::<GetEndpointAccessControl>(response).await {
let body = match parse_body::<GetEndpointAccessControl>(
response.status(),
response.bytes().await?,
) {
Ok(body) => body,
// Error 404 is special: it's ok not to have a secret.
// TODO(anna): retry
@@ -180,7 +156,7 @@ impl NeonControlPlaneClient {
async fn do_get_endpoint_jwks(
&self,
ctx: &RequestContext,
endpoint: EndpointId,
endpoint: &EndpointId,
) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
if !self
.caches
@@ -216,7 +192,10 @@ impl NeonControlPlaneClient {
drop(pause);
info!(duration = ?start.elapsed(), "received http response");
let body = parse_body::<EndpointJwksResponse>(response).await?;
let body = parse_body::<EndpointJwksResponse>(
response.status(),
response.bytes().await.map_err(ControlPlaneError::from)?,
)?;
let rules = body
.jwks
@@ -268,7 +247,7 @@ impl NeonControlPlaneClient {
let response = self.endpoint.execute(request).await?;
drop(pause);
info!(duration = ?start.elapsed(), "received http response");
let body = parse_body::<WakeCompute>(response).await?;
let body = parse_body::<WakeCompute>(response.status(), response.bytes().await?)?;
// Unfortunately, ownership won't let us use `Option::ok_or` here.
let (host, port) = match parse_host_port(&body.address) {
@@ -313,225 +292,104 @@ impl NeonControlPlaneClient {
impl super::ControlPlaneApi for NeonControlPlaneClient {
#[tracing::instrument(skip_all)]
async fn get_role_secret(
async fn get_role_access_control(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
let normalized_ep = &user_info.endpoint.normalize();
let user = &user_info.user;
if let Some(role_secret) = self
endpoint: &EndpointId,
role: &RoleName,
) -> Result<RoleAccessControl, crate::control_plane::errors::GetAuthInfoError> {
let normalized_ep = &endpoint.normalize();
if let Some(secret) = self
.caches
.project_info
.get_role_secret(normalized_ep, user)
.get_role_secret(normalized_ep, role)
{
return Ok(role_secret);
return Ok(secret);
}
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
let account_id = auth_info.account_id;
if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) {
info!("endpoint is not valid, skipping the request");
return Err(GetAuthInfoError::UnknownEndpoint);
}
let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;
let control = EndpointAccessControl {
allowed_ips: Arc::new(auth_info.allowed_ips),
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
flags: auth_info.access_blocker_flags,
};
let role_control = RoleAccessControl {
secret: auth_info.secret,
};
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();
self.caches.project_info.insert_role_secret(
self.caches.project_info.insert_endpoint_access(
auth_info.account_id,
project_id,
normalized_ep_int,
user.into(),
auth_info.secret.clone(),
);
self.caches.project_info.insert_allowed_ips(
project_id,
normalized_ep_int,
Arc::new(auth_info.allowed_ips),
);
self.caches.project_info.insert_allowed_vpc_endpoint_ids(
account_id,
project_id,
normalized_ep_int,
Arc::new(auth_info.allowed_vpc_endpoint_ids),
);
self.caches.project_info.insert_block_public_or_vpc_access(
project_id,
normalized_ep_int,
auth_info.access_blocker_flags,
role.into(),
control,
role_control.clone(),
);
ctx.set_project_id(project_id);
}
// When we just got a secret, we don't need to invalidate it.
Ok(Cached::new_uncached(auth_info.secret))
Ok(role_control)
}
async fn get_allowed_ips(
#[tracing::instrument(skip_all)]
async fn get_endpoint_access_control(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAllowedIps, GetAuthInfoError> {
let normalized_ep = &user_info.endpoint.normalize();
if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) {
Metrics::get()
.proxy
.allowed_ips_cache_misses // TODO SR: Should we rename this variable to something like allowed_ip_cache_stats?
.inc(CacheOutcome::Hit);
return Ok(allowed_ips);
endpoint: &EndpointId,
role: &RoleName,
) -> Result<EndpointAccessControl, GetAuthInfoError> {
let normalized_ep = &endpoint.normalize();
if let Some(control) = self.caches.project_info.get_endpoint_access(normalized_ep) {
return Ok(control);
}
Metrics::get()
.proxy
.allowed_ips_cache_misses
.inc(CacheOutcome::Miss);
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
let allowed_ips = Arc::new(auth_info.allowed_ips);
let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids);
let access_blocker_flags = auth_info.access_blocker_flags;
let user = &user_info.user;
let account_id = auth_info.account_id;
if !self.caches.endpoints_cache.is_valid(ctx, normalized_ep) {
info!("endpoint is not valid, skipping the request");
return Err(GetAuthInfoError::UnknownEndpoint);
}
let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;
let control = EndpointAccessControl {
allowed_ips: Arc::new(auth_info.allowed_ips),
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
flags: auth_info.access_blocker_flags,
};
let role_control = RoleAccessControl {
secret: auth_info.secret,
};
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();
self.caches.project_info.insert_role_secret(
self.caches.project_info.insert_endpoint_access(
auth_info.account_id,
project_id,
normalized_ep_int,
user.into(),
auth_info.secret.clone(),
);
self.caches.project_info.insert_allowed_ips(
project_id,
normalized_ep_int,
allowed_ips.clone(),
);
self.caches.project_info.insert_allowed_vpc_endpoint_ids(
account_id,
project_id,
normalized_ep_int,
allowed_vpc_endpoint_ids.clone(),
);
self.caches.project_info.insert_block_public_or_vpc_access(
project_id,
normalized_ep_int,
access_blocker_flags,
role.into(),
control.clone(),
role_control,
);
ctx.set_project_id(project_id);
}
Ok(Cached::new_uncached(allowed_ips))
}
async fn get_allowed_vpc_endpoint_ids(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAllowedVpcEndpointIds, GetAuthInfoError> {
let normalized_ep = &user_info.endpoint.normalize();
if let Some(allowed_vpc_endpoint_ids) = self
.caches
.project_info
.get_allowed_vpc_endpoint_ids(normalized_ep)
{
Metrics::get()
.proxy
.vpc_endpoint_id_cache_stats
.inc(CacheOutcome::Hit);
return Ok(allowed_vpc_endpoint_ids);
}
Metrics::get()
.proxy
.vpc_endpoint_id_cache_stats
.inc(CacheOutcome::Miss);
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
let allowed_ips = Arc::new(auth_info.allowed_ips);
let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids);
let access_blocker_flags = auth_info.access_blocker_flags;
let user = &user_info.user;
let account_id = auth_info.account_id;
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();
self.caches.project_info.insert_role_secret(
project_id,
normalized_ep_int,
user.into(),
auth_info.secret.clone(),
);
self.caches.project_info.insert_allowed_ips(
project_id,
normalized_ep_int,
allowed_ips.clone(),
);
self.caches.project_info.insert_allowed_vpc_endpoint_ids(
account_id,
project_id,
normalized_ep_int,
allowed_vpc_endpoint_ids.clone(),
);
self.caches.project_info.insert_block_public_or_vpc_access(
project_id,
normalized_ep_int,
access_blocker_flags,
);
ctx.set_project_id(project_id);
}
Ok(Cached::new_uncached(allowed_vpc_endpoint_ids))
}
async fn get_block_public_or_vpc_access(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAccessBlockerFlags, GetAuthInfoError> {
let normalized_ep = &user_info.endpoint.normalize();
if let Some(access_blocker_flags) = self
.caches
.project_info
.get_block_public_or_vpc_access(normalized_ep)
{
Metrics::get()
.proxy
.access_blocker_flags_cache_stats
.inc(CacheOutcome::Hit);
return Ok(access_blocker_flags);
}
Metrics::get()
.proxy
.access_blocker_flags_cache_stats
.inc(CacheOutcome::Miss);
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
let allowed_ips = Arc::new(auth_info.allowed_ips);
let allowed_vpc_endpoint_ids = Arc::new(auth_info.allowed_vpc_endpoint_ids);
let access_blocker_flags = auth_info.access_blocker_flags;
let user = &user_info.user;
let account_id = auth_info.account_id;
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();
self.caches.project_info.insert_role_secret(
project_id,
normalized_ep_int,
user.into(),
auth_info.secret.clone(),
);
self.caches.project_info.insert_allowed_ips(
project_id,
normalized_ep_int,
allowed_ips.clone(),
);
self.caches.project_info.insert_allowed_vpc_endpoint_ids(
account_id,
project_id,
normalized_ep_int,
allowed_vpc_endpoint_ids.clone(),
);
self.caches.project_info.insert_block_public_or_vpc_access(
project_id,
normalized_ep_int,
access_blocker_flags.clone(),
);
ctx.set_project_id(project_id);
}
Ok(Cached::new_uncached(access_blocker_flags))
Ok(control)
}
#[tracing::instrument(skip_all)]
async fn get_endpoint_jwks(
&self,
ctx: &RequestContext,
endpoint: EndpointId,
endpoint: &EndpointId,
) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
self.do_get_endpoint_jwks(ctx, endpoint).await
}
@@ -640,33 +498,33 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
}
/// Parse http response body, taking status code into account.
async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
response: http::Response,
fn parse_body<T: for<'a> serde::Deserialize<'a>>(
status: StatusCode,
body: Bytes,
) -> Result<T, ControlPlaneError> {
let status = response.status();
if status.is_success() {
// We shouldn't log raw body because it may contain secrets.
info!("request succeeded, processing the body");
return Ok(response.json().await?);
return Ok(serde_json::from_slice(&body).map_err(std::io::Error::other)?);
}
let s = response.bytes().await?;
// Log plaintext to be able to detect, whether there are some cases not covered by the error struct.
info!("response_error plaintext: {:?}", s);
info!("response_error plaintext: {:?}", body);
// Don't throw an error here because it's not as important
// as the fact that the request itself has failed.
let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| {
let mut body = serde_json::from_slice(&body).unwrap_or_else(|e| {
warn!("failed to parse error body: {e}");
ControlPlaneErrorMessage {
Box::new(ControlPlaneErrorMessage {
error: "reason unclear (malformed error message)".into(),
http_status_code: status,
status: None,
}
})
});
body.http_status_code = status;
warn!("console responded with an error ({status}): {body:?}");
Err(ControlPlaneError::Message(Box::new(body)))
Err(ControlPlaneError::Message(body))
}
fn parse_host_port(input: &str) -> Option<(&str, u16)> {

View File

@@ -15,14 +15,14 @@ use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::cache::Cached;
use crate::context::RequestContext;
use crate::control_plane::client::{
CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedRoleSecret,
};
use crate::control_plane::errors::{
ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
};
use crate::control_plane::messages::MetricsAuxInfo;
use crate::control_plane::{AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo};
use crate::control_plane::{
AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo,
RoleAccessControl,
};
use crate::intern::RoleNameInt;
use crate::types::{BranchId, EndpointId, ProjectId, RoleName};
use crate::url::ApiUrl;
@@ -66,7 +66,8 @@ impl MockControlPlane {
async fn do_get_auth_info(
&self,
user_info: &ComputeUserInfo,
endpoint: &EndpointId,
role: &RoleName,
) -> Result<AuthInfo, GetAuthInfoError> {
let (secret, allowed_ips) = async {
// Perhaps we could persist this connection, but then we'd have to
@@ -80,7 +81,7 @@ impl MockControlPlane {
let secret = if let Some(entry) = get_execute_postgres_query(
&client,
"select rolpassword from pg_catalog.pg_authid where rolname = $1",
&[&&*user_info.user],
&[&role.as_str()],
"rolpassword",
)
.await?
@@ -89,7 +90,7 @@ impl MockControlPlane {
let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
} else {
warn!("user '{}' does not exist", user_info.user);
warn!("user '{role}' does not exist");
None
};
@@ -97,7 +98,7 @@ impl MockControlPlane {
match get_execute_postgres_query(
&client,
"select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
&[&user_info.endpoint.as_str()],
&[&endpoint.as_str()],
"allowed_ips",
)
.await?
@@ -133,7 +134,7 @@ impl MockControlPlane {
async fn do_get_endpoint_jwks(
&self,
endpoint: EndpointId,
endpoint: &EndpointId,
) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
let (client, connection) =
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
@@ -222,53 +223,36 @@ async fn get_execute_postgres_query(
}
impl super::ControlPlaneApi for MockControlPlane {
#[tracing::instrument(skip_all)]
async fn get_role_secret(
async fn get_endpoint_access_control(
&self,
_ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
Ok(CachedRoleSecret::new_uncached(
self.do_get_auth_info(user_info).await?.secret,
))
endpoint: &EndpointId,
role: &RoleName,
) -> Result<EndpointAccessControl, GetAuthInfoError> {
let info = self.do_get_auth_info(endpoint, role).await?;
Ok(EndpointAccessControl {
allowed_ips: Arc::new(info.allowed_ips),
allowed_vpce: Arc::new(info.allowed_vpc_endpoint_ids),
flags: info.access_blocker_flags,
})
}
async fn get_allowed_ips(
async fn get_role_access_control(
&self,
_ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAllowedIps, GetAuthInfoError> {
Ok(Cached::new_uncached(Arc::new(
self.do_get_auth_info(user_info).await?.allowed_ips,
)))
}
async fn get_allowed_vpc_endpoint_ids(
&self,
_ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAllowedVpcEndpointIds, super::errors::GetAuthInfoError> {
Ok(Cached::new_uncached(Arc::new(
self.do_get_auth_info(user_info)
.await?
.allowed_vpc_endpoint_ids,
)))
}
async fn get_block_public_or_vpc_access(
&self,
_ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<super::CachedAccessBlockerFlags, super::errors::GetAuthInfoError> {
Ok(Cached::new_uncached(
self.do_get_auth_info(user_info).await?.access_blocker_flags,
))
endpoint: &EndpointId,
role: &RoleName,
) -> Result<RoleAccessControl, GetAuthInfoError> {
let info = self.do_get_auth_info(endpoint, role).await?;
Ok(RoleAccessControl {
secret: info.secret,
})
}
async fn get_endpoint_jwks(
&self,
_ctx: &RequestContext,
endpoint: EndpointId,
endpoint: &EndpointId,
) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
self.do_get_endpoint_jwks(endpoint).await
}

View File

@@ -16,15 +16,14 @@ use crate::cache::endpoints::EndpointsCache;
use crate::cache::project_info::ProjectInfoCacheImpl;
use crate::config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions};
use crate::context::RequestContext;
use crate::control_plane::{
CachedAccessBlockerFlags, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo,
CachedRoleSecret, ControlPlaneApi, NodeInfoCache, errors,
};
use crate::control_plane::{CachedNodeInfo, ControlPlaneApi, NodeInfoCache, errors};
use crate::error::ReportableError;
use crate::metrics::ApiLockMetrics;
use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
use crate::types::EndpointId;
use super::{EndpointAccessControl, RoleAccessControl};
#[non_exhaustive]
#[derive(Clone)]
pub enum ControlPlaneClient {
@@ -40,68 +39,42 @@ pub enum ControlPlaneClient {
}
impl ControlPlaneApi for ControlPlaneClient {
async fn get_role_secret(
async fn get_role_access_control(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
endpoint: &EndpointId,
role: &crate::types::RoleName,
) -> Result<RoleAccessControl, errors::GetAuthInfoError> {
match self {
Self::ProxyV1(api) => api.get_role_secret(ctx, user_info).await,
Self::ProxyV1(api) => api.get_role_access_control(ctx, endpoint, role).await,
#[cfg(any(test, feature = "testing"))]
Self::PostgresMock(api) => api.get_role_secret(ctx, user_info).await,
Self::PostgresMock(api) => api.get_role_access_control(ctx, endpoint, role).await,
#[cfg(test)]
Self::Test(_) => {
Self::Test(_api) => {
unreachable!("this function should never be called in the test backend")
}
}
}
async fn get_allowed_ips(
async fn get_endpoint_access_control(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAllowedIps, errors::GetAuthInfoError> {
endpoint: &EndpointId,
role: &crate::types::RoleName,
) -> Result<EndpointAccessControl, errors::GetAuthInfoError> {
match self {
Self::ProxyV1(api) => api.get_allowed_ips(ctx, user_info).await,
Self::ProxyV1(api) => api.get_endpoint_access_control(ctx, endpoint, role).await,
#[cfg(any(test, feature = "testing"))]
Self::PostgresMock(api) => api.get_allowed_ips(ctx, user_info).await,
Self::PostgresMock(api) => api.get_endpoint_access_control(ctx, endpoint, role).await,
#[cfg(test)]
Self::Test(api) => api.get_allowed_ips(),
}
}
async fn get_allowed_vpc_endpoint_ids(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAllowedVpcEndpointIds, errors::GetAuthInfoError> {
match self {
Self::ProxyV1(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await,
#[cfg(any(test, feature = "testing"))]
Self::PostgresMock(api) => api.get_allowed_vpc_endpoint_ids(ctx, user_info).await,
#[cfg(test)]
Self::Test(api) => api.get_allowed_vpc_endpoint_ids(),
}
}
async fn get_block_public_or_vpc_access(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAccessBlockerFlags, errors::GetAuthInfoError> {
match self {
Self::ProxyV1(api) => api.get_block_public_or_vpc_access(ctx, user_info).await,
#[cfg(any(test, feature = "testing"))]
Self::PostgresMock(api) => api.get_block_public_or_vpc_access(ctx, user_info).await,
#[cfg(test)]
Self::Test(api) => api.get_block_public_or_vpc_access(),
Self::Test(api) => api.get_access_control(),
}
}
async fn get_endpoint_jwks(
&self,
ctx: &RequestContext,
endpoint: EndpointId,
endpoint: &EndpointId,
) -> Result<Vec<AuthRule>, errors::GetEndpointJwksError> {
match self {
Self::ProxyV1(api) => api.get_endpoint_jwks(ctx, endpoint).await,
@@ -131,15 +104,7 @@ impl ControlPlaneApi for ControlPlaneClient {
pub(crate) trait TestControlPlaneClient: Send + Sync + 'static {
fn wake_compute(&self) -> Result<CachedNodeInfo, errors::WakeComputeError>;
fn get_allowed_ips(&self) -> Result<CachedAllowedIps, errors::GetAuthInfoError>;
fn get_allowed_vpc_endpoint_ids(
&self,
) -> Result<CachedAllowedVpcEndpointIds, errors::GetAuthInfoError>;
fn get_block_public_or_vpc_access(
&self,
) -> Result<CachedAccessBlockerFlags, errors::GetAuthInfoError>;
fn get_access_control(&self) -> Result<EndpointAccessControl, errors::GetAuthInfoError>;
fn dyn_clone(&self) -> Box<dyn TestControlPlaneClient>;
}
@@ -309,7 +274,7 @@ impl FetchAuthRules for ControlPlaneClient {
ctx: &RequestContext,
endpoint: EndpointId,
) -> Result<Vec<AuthRule>, FetchAuthRulesError> {
self.get_endpoint_jwks(ctx, endpoint)
self.get_endpoint_jwks(ctx, &endpoint)
.await
.map_err(FetchAuthRulesError::GetEndpointJwks)
}

View File

@@ -99,6 +99,10 @@ pub(crate) enum GetAuthInfoError {
#[error(transparent)]
ApiError(ControlPlaneError),
/// Proxy does not know about the endpoint in advanced
#[error("endpoint not found in endpoint cache")]
UnknownEndpoint,
}
// This allows more useful interactions than `#[from]`.
@@ -115,6 +119,8 @@ impl UserFacingError for GetAuthInfoError {
Self::BadSecret => REQUEST_FAILED.to_owned(),
// However, API might return a meaningful error.
Self::ApiError(e) => e.to_string_client(),
// pretend like control plane returned an error.
Self::UnknownEndpoint => REQUEST_FAILED.to_owned(),
}
}
}
@@ -124,6 +130,8 @@ impl ReportableError for GetAuthInfoError {
match self {
Self::BadSecret => crate::error::ErrorKind::ControlPlane,
Self::ApiError(_) => crate::error::ErrorKind::ControlPlane,
// we only apply endpoint filtering if control plane is under high load.
Self::UnknownEndpoint => crate::error::ErrorKind::ServiceRateLimit,
}
}
}

View File

@@ -11,16 +11,16 @@ pub(crate) mod errors;
use std::sync::Arc;
use crate::auth::IpPattern;
use crate::auth::backend::jwt::AuthRule;
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::cache::project_info::ProjectInfoCacheImpl;
use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list};
use crate::cache::{Cached, TimedLru};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
use crate::intern::{AccountIdInt, ProjectIdInt};
use crate::types::{EndpointCacheKey, EndpointId};
use crate::protocol2::ConnectionInfoExtra;
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
use crate::{compute, scram};
/// Various cache-related types.
@@ -101,7 +101,7 @@ impl NodeInfo {
}
}
#[derive(Clone, Default, Eq, PartialEq, Debug)]
#[derive(Copy, Clone, Default)]
pub(crate) struct AccessBlockerFlags {
pub public_access_blocked: bool,
pub vpc_access_blocked: bool,
@@ -110,47 +110,78 @@ pub(crate) struct AccessBlockerFlags {
pub(crate) type NodeInfoCache =
TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ControlPlaneErrorMessage>>>;
pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
pub(crate) type CachedAllowedVpcEndpointIds =
Cached<&'static ProjectInfoCacheImpl, Arc<Vec<String>>>;
pub(crate) type CachedAccessBlockerFlags =
Cached<&'static ProjectInfoCacheImpl, AccessBlockerFlags>;
#[derive(Clone)]
pub struct RoleAccessControl {
pub secret: Option<AuthSecret>,
}
#[derive(Clone)]
pub struct EndpointAccessControl {
pub allowed_ips: Arc<Vec<IpPattern>>,
pub allowed_vpce: Arc<Vec<String>>,
pub flags: AccessBlockerFlags,
}
impl EndpointAccessControl {
pub fn check(
&self,
ctx: &RequestContext,
check_ip_allowed: bool,
check_vpc_allowed: bool,
) -> Result<(), AuthError> {
if check_ip_allowed && !check_peer_addr_is_in_list(&ctx.peer_addr(), &self.allowed_ips) {
return Err(AuthError::IpAddressNotAllowed(ctx.peer_addr()));
}
// check if a VPC endpoint ID is coming in and if yes, if it's allowed
if check_vpc_allowed {
if self.flags.vpc_access_blocked {
return Err(AuthError::NetworkNotAllowed);
}
let incoming_vpc_endpoint_id = match ctx.extra() {
None => return Err(AuthError::MissingVPCEndpointId),
Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
};
let vpce = &self.allowed_vpce;
// TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
if !vpce.is_empty() && !vpce.contains(&incoming_vpc_endpoint_id) {
return Err(AuthError::vpc_endpoint_id_not_allowed(
incoming_vpc_endpoint_id,
));
}
} else if self.flags.public_access_blocked {
return Err(AuthError::NetworkNotAllowed);
}
Ok(())
}
}
/// This will allocate per each call, but the http requests alone
/// already require a few allocations, so it should be fine.
pub(crate) trait ControlPlaneApi {
/// Get the client's auth secret for authentication.
/// Returns option because user not found situation is special.
/// We still have to mock the scram to avoid leaking information that user doesn't exist.
async fn get_role_secret(
async fn get_role_access_control(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
endpoint: &EndpointId,
role: &RoleName,
) -> Result<RoleAccessControl, errors::GetAuthInfoError>;
async fn get_allowed_ips(
async fn get_endpoint_access_control(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAllowedIps, errors::GetAuthInfoError>;
async fn get_allowed_vpc_endpoint_ids(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAllowedVpcEndpointIds, errors::GetAuthInfoError>;
async fn get_block_public_or_vpc_access(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedAccessBlockerFlags, errors::GetAuthInfoError>;
endpoint: &EndpointId,
role: &RoleName,
) -> Result<EndpointAccessControl, errors::GetAuthInfoError>;
async fn get_endpoint_jwks(
&self,
ctx: &RequestContext,
endpoint: EndpointId,
endpoint: &EndpointId,
) -> Result<Vec<AuthRule>, errors::GetEndpointJwksError>;
/// Wake up the compute node and return the corresponding connection info.

View File

@@ -4,9 +4,10 @@
pub mod health_server;
use std::time::Duration;
use std::time::{Duration, Instant};
use bytes::Bytes;
use futures::FutureExt;
use http::Method;
use http_body_util::BodyExt;
use hyper::body::Body;
@@ -109,15 +110,31 @@ impl Endpoint {
}
/// Execute a [request](reqwest::Request).
pub(crate) async fn execute(&self, request: Request) -> Result<Response, Error> {
let _timer = Metrics::get()
pub(crate) fn execute(
&self,
request: Request,
) -> impl Future<Output = Result<Response, Error>> {
let metric = Metrics::get()
.proxy
.console_request_latency
.start_timer(ConsoleRequest {
.with_labels(ConsoleRequest {
request: request.url().path(),
});
self.client.execute(request).await
let req = self.client.execute(request).boxed();
async move {
let start = Instant::now();
scopeguard::defer!({
Metrics::get()
.proxy
.console_request_latency
.get_metric(metric)
.observe_duration_since(start);
});
req.await
}
}
}

View File

@@ -186,7 +186,7 @@ where
pub async fn read_message<'a, S>(
stream: &mut S,
buf: &'a mut Vec<u8>,
max: usize,
max: u32,
) -> io::Result<(u8, &'a mut [u8])>
where
S: AsyncRead + Unpin,
@@ -206,7 +206,7 @@ where
let header = read!(stream => Header);
// as described above, the length must be at least 4.
let Some(len) = (header.len.get() as usize).checked_sub(4) else {
let Some(len) = header.len.get().checked_sub(4) else {
return Err(io::Error::other(format!(
"invalid startup message length {}, must be at least 4.",
header.len,
@@ -222,7 +222,7 @@ where
}
// read in our entire message.
buf.resize(len, 0);
buf.resize(len as usize, 0);
stream.read_exact(buf).await?;
Ok((header.tag, buf))

View File

@@ -1,3 +1,4 @@
use futures::{FutureExt, TryFutureExt};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info, warn};
@@ -57,7 +58,7 @@ pub(crate) enum HandshakeData<S> {
/// It's easier to work with owned `stream` here as we need to upgrade it to TLS;
/// we also take an extra care of propagating only the select handshake errors to client.
#[tracing::instrument(skip_all)]
pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin + Send>(
ctx: &RequestContext,
stream: S,
mut tls: Option<&TlsConfig>,
@@ -108,7 +109,9 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
}
}
}
});
})
.map_ok(Box::new)
.boxed();
res?;
@@ -146,7 +149,7 @@ pub(crate) async fn handshake<S: AsyncRead + AsyncWrite + Unpin>(
tls.cert_resolver.resolve(conn_info.server_name());
let tls = Stream::Tls {
tls: Box::new(tls_stream),
tls: tls_stream,
tls_server_end_point,
};
(stream, msg) = PqStream::parse_startup(tls).await?;

View File

@@ -270,7 +270,7 @@ impl ReportableError for ClientRequestError {
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
config: &'static ProxyConfig,
auth_backend: &'static auth::Backend<'static, ()>,
ctx: &RequestContext,
@@ -345,7 +345,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
};
let user = user_info.get_user().to_owned();
let (user_info, _ip_allowlist) = match user_info
let user_info = match user_info
.authenticate(
ctx,
&mut stream,

View File

@@ -1,3 +1,4 @@
use futures::FutureExt;
use smol_str::SmolStr;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::debug;
@@ -89,6 +90,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
.compute
.cancel_closure
.try_cancel_query(compute_config)
.boxed()
.await
{
tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database");

View File

@@ -26,9 +26,7 @@ use crate::auth::backend::{
use crate::config::{ComputeConfig, RetryConfig};
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
use crate::control_plane::{
self, CachedAllowedIps, CachedAllowedVpcEndpointIds, CachedNodeInfo, NodeInfo, NodeInfoCache,
};
use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
use crate::error::ErrorKind;
use crate::tls::client_config::compute_client_config_with_certs;
use crate::tls::postgres_rustls::MakeRustlsConnect;
@@ -547,20 +545,9 @@ impl TestControlPlaneClient for TestConnectMechanism {
}
}
fn get_allowed_ips(&self) -> Result<CachedAllowedIps, control_plane::errors::GetAuthInfoError> {
unimplemented!("not used in tests")
}
fn get_allowed_vpc_endpoint_ids(
fn get_access_control(
&self,
) -> Result<CachedAllowedVpcEndpointIds, control_plane::errors::GetAuthInfoError> {
unimplemented!("not used in tests")
}
fn get_block_public_or_vpc_access(
&self,
) -> Result<control_plane::CachedAccessBlockerFlags, control_plane::errors::GetAuthInfoError>
{
) -> Result<control_plane::EndpointAccessControl, control_plane::errors::GetAuthInfoError> {
unimplemented!("not used in tests")
}

View File

@@ -15,7 +15,7 @@ pub type EndpointRateLimiter = LeakyBucketRateLimiter<EndpointIdInt>;
pub struct LeakyBucketRateLimiter<Key> {
map: ClashMap<Key, LeakyBucketState, RandomState>,
config: utils::leaky_bucket::LeakyBucketConfig,
default_config: utils::leaky_bucket::LeakyBucketConfig,
access_count: AtomicUsize,
}
@@ -28,15 +28,17 @@ impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
pub fn new_with_shards(config: LeakyBucketConfig, shards: usize) -> Self {
Self {
map: ClashMap::with_hasher_and_shard_amount(RandomState::new(), shards),
config: config.into(),
default_config: config.into(),
access_count: AtomicUsize::new(0),
}
}
/// Check that number of connections to the endpoint is below `max_rps` rps.
pub(crate) fn check(&self, key: K, n: u32) -> bool {
pub(crate) fn check(&self, key: K, config: Option<LeakyBucketConfig>, n: u32) -> bool {
let now = Instant::now();
let config = config.map_or(self.default_config, Into::into);
if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
self.do_gc(now);
}
@@ -46,7 +48,7 @@ impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
.entry(key)
.or_insert_with(|| LeakyBucketState { empty_at: now });
entry.add_tokens(&self.config, now, n as f64).is_ok()
entry.add_tokens(&config, now, n as f64).is_ok()
}
fn do_gc(&self, now: Instant) {

View File

@@ -15,6 +15,8 @@ use tracing::info;
use crate::ext::LockExt;
use crate::intern::EndpointIdInt;
use super::LeakyBucketConfig;
pub struct GlobalRateLimiter {
data: Vec<RateBucket>,
info: Vec<RateBucketInfo>,
@@ -144,19 +146,6 @@ impl RateBucketInfo {
Self::new(50_000, Duration::from_secs(10)),
];
/// All of these are per endpoint-maskedip pair.
/// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus).
///
/// First bucket: 1000mcpus total per endpoint-ip pair
/// * 4096000 requests per second with 1 hash rounds.
/// * 1000 requests per second with 4096 hash rounds.
/// * 6.8 requests per second with 600000 hash rounds.
pub const DEFAULT_AUTH_SET: [Self; 3] = [
Self::new(1000 * 4096, Duration::from_secs(1)),
Self::new(600 * 4096, Duration::from_secs(60)),
Self::new(300 * 4096, Duration::from_secs(600)),
];
pub fn rps(&self) -> f64 {
(self.max_rpi as f64) / self.interval.as_secs_f64()
}
@@ -184,6 +173,21 @@ impl RateBucketInfo {
max_rpi: ((max_rps as u64) * (interval.as_millis() as u64) / 1000) as u32,
}
}
pub fn to_leaky_bucket(this: &[Self]) -> Option<LeakyBucketConfig> {
// bit of a hack - find the min rps and max rps supported and turn it into
// leaky bucket config instead
let mut iter = this.iter().map(|info| info.rps());
let first = iter.next()?;
let (min, max) = (first, first);
let (min, max) = iter.fold((min, max), |(min, max), rps| {
(f64::min(min, rps), f64::max(max, rps))
});
Some(LeakyBucketConfig { rps: min, max })
}
}
impl<K: Hash + Eq> BucketRateLimiter<K> {

View File

@@ -8,4 +8,4 @@ pub(crate) use limit_algorithm::aimd::Aimd;
pub(crate) use limit_algorithm::{
DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token,
};
pub use limiter::{BucketRateLimiter, GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
pub use limiter::{GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter};

View File

@@ -233,29 +233,30 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
match msg {
Notification::AllowedIpsUpdate { allowed_ips_update } => {
cache.invalidate_allowed_ips_for_project(allowed_ips_update.project_id);
Notification::AllowedIpsUpdate {
allowed_ips_update: AllowedIpsUpdate { project_id },
}
Notification::BlockPublicOrVpcAccessUpdated {
block_public_or_vpc_access_updated,
} => cache.invalidate_block_public_or_vpc_access_for_project(
block_public_or_vpc_access_updated.project_id,
),
| Notification::BlockPublicOrVpcAccessUpdated {
block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated { project_id },
} => cache.invalidate_endpoint_access_for_project(project_id),
Notification::AllowedVpcEndpointsUpdatedForOrg {
allowed_vpc_endpoints_updated_for_org,
} => cache.invalidate_allowed_vpc_endpoint_ids_for_org(
allowed_vpc_endpoints_updated_for_org.account_id,
),
allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg { account_id },
} => cache.invalidate_endpoint_access_for_org(account_id),
Notification::AllowedVpcEndpointsUpdatedForProjects {
allowed_vpc_endpoints_updated_for_projects,
} => cache.invalidate_allowed_vpc_endpoint_ids_for_projects(
allowed_vpc_endpoints_updated_for_projects.project_ids,
),
Notification::PasswordUpdate { password_update } => cache
.invalidate_role_secret_for_project(
password_update.project_id,
password_update.role_name,
),
allowed_vpc_endpoints_updated_for_projects:
AllowedVpcEndpointsUpdatedForProjects { project_ids },
} => {
for project in project_ids {
cache.invalidate_endpoint_access_for_project(project);
}
}
Notification::PasswordUpdate {
password_update:
PasswordUpdate {
project_id,
role_name,
},
} => cache.invalidate_role_secret_for_project(project_id, role_name),
Notification::UnknownTopic => unreachable!(),
}
}

View File

@@ -30,52 +30,53 @@ where
F: FnOnce(&str) -> super::Result<M>,
M: Mechanism,
{
let sasl = {
let (mut mechanism, mut input) = {
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// Initial client message contains the chosen auth method's name.
let msg = stream.read_password_message().await?;
super::FirstMessage::parse(msg).ok_or(super::Error::BadClientMessage("bad sasl message"))?
let sasl = super::FirstMessage::parse(msg)
.ok_or(super::Error::BadClientMessage("bad sasl message"))?;
(mechanism(sasl.method)?, sasl.message)
};
let mut mechanism = mechanism(sasl.method)?;
let mut input = sasl.message;
loop {
let step = mechanism
.exchange(input)
.inspect_err(|error| tracing::info!(?error, "error during SASL exchange"))?;
match step {
Step::Continue(moved_mechanism, reply) => {
match mechanism.exchange(input) {
Ok(Step::Continue(moved_mechanism, reply)) => {
mechanism = moved_mechanism;
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// write reply
let sasl_msg = BeAuthenticationSaslMessage::Continue(reply.as_bytes());
stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
// get next input
stream.flush().await?;
let msg = stream.read_password_message().await?;
input = std::str::from_utf8(msg)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
drop(reply);
}
Step::Success(result, reply) => {
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
Ok(Step::Success(result, reply)) => {
// write reply
let sasl_msg = BeAuthenticationSaslMessage::Final(reply.as_bytes());
stream.write_message(BeMessage::AuthenticationSasl(sasl_msg));
stream.write_message(BeMessage::AuthenticationOk);
// exit with success
break Ok(Outcome::Success(result));
}
// exit with failure
Step::Failure(reason) => break Ok(Outcome::Failure(reason)),
Ok(Step::Failure(reason)) => break Ok(Outcome::Failure(reason)),
Err(error) => {
tracing::info!(?error, "error during SASL exchange");
return Err(error);
}
}
// pause the timer while we communicate with the client
let _paused = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
// get next input
stream.flush().await?;
let msg = stream.read_password_message().await?;
input = std::str::from_utf8(msg)
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "bad encoding"))?;
}
}

View File

@@ -22,7 +22,7 @@ use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client};
use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool};
use crate::auth::backend::local::StaticAuthRules;
use crate::auth::backend::{ComputeCredentials, ComputeUserInfo};
use crate::auth::{self, AuthError, check_peer_addr_is_in_list};
use crate::auth::{self, AuthError};
use crate::compute;
use crate::compute_ctl::{
ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest,
@@ -35,7 +35,6 @@ use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::control_plane::locks::ApiLocks;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::intern::EndpointIdInt;
use crate::protocol2::ConnectionInfoExtra;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
use crate::rate_limiter::EndpointRateLimiter;
@@ -63,63 +62,24 @@ impl PoolingBackend {
let user_info = user_info.clone();
let backend = self.auth_backend.as_ref().map(|()| user_info.clone());
let allowed_ips = backend.get_allowed_ips(ctx).await?;
let access_control = backend.get_endpoint_access_control(ctx).await?;
access_control.check(
ctx,
self.config.authentication_config.ip_allowlist_check_enabled,
self.config.authentication_config.is_vpc_acccess_proxy,
)?;
if self.config.authentication_config.ip_allowlist_check_enabled
&& !check_peer_addr_is_in_list(&ctx.peer_addr(), &allowed_ips)
{
return Err(AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
let access_blocker_flags = backend.get_block_public_or_vpc_access(ctx).await?;
if self.config.authentication_config.is_vpc_acccess_proxy {
if access_blocker_flags.vpc_access_blocked {
return Err(AuthError::NetworkNotAllowed);
}
let extra = ctx.extra();
let incoming_endpoint_id = match extra {
None => String::new(),
Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
};
if incoming_endpoint_id.is_empty() {
return Err(AuthError::MissingVPCEndpointId);
}
let allowed_vpc_endpoint_ids = backend.get_allowed_vpc_endpoint_ids(ctx).await?;
// TODO: For now an empty VPC endpoint ID list means all are allowed. We should replace that.
if !allowed_vpc_endpoint_ids.is_empty()
&& !allowed_vpc_endpoint_ids.contains(&incoming_endpoint_id)
{
return Err(AuthError::vpc_endpoint_id_not_allowed(incoming_endpoint_id));
}
} else if access_blocker_flags.public_access_blocked {
return Err(AuthError::NetworkNotAllowed);
}
if !self
.endpoint_rate_limiter
.check(user_info.endpoint.clone().into(), 1)
{
let ep = EndpointIdInt::from(&user_info.endpoint);
let rate_limit_config = None;
if !self.endpoint_rate_limiter.check(ep, rate_limit_config, 1) {
return Err(AuthError::too_many_connections());
}
let cached_secret = backend.get_role_secret(ctx).await?;
let secret = match cached_secret.value.clone() {
Some(secret) => self.config.authentication_config.check_rate_limit(
ctx,
secret,
&user_info.endpoint,
true,
)?,
None => {
// If we don't have an authentication secret, for the http flow we can just return an error.
info!("authentication info not found");
return Err(AuthError::password_failed(&*user_info.user));
}
let role_access = backend.get_role_secret(ctx).await?;
let Some(secret) = role_access.secret else {
// If we don't have an authentication secret, for the http flow we can just return an error.
info!("authentication info not found");
return Err(AuthError::password_failed(&*user_info.user));
};
let ep = EndpointIdInt::from(&user_info.endpoint);
let auth_outcome = crate::auth::validate_password_and_exchange(
&self.config.authentication_config.thread_pool,
ep,

View File

@@ -72,7 +72,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> PqStream<S> {
impl<S: AsyncRead + Unpin> PqStream<S> {
/// Read a raw postgres packet, which will respect the max length requested.
/// This is not cancel safe.
async fn read_raw_expect(&mut self, tag: u8, max: usize) -> io::Result<&mut [u8]> {
async fn read_raw_expect(&mut self, tag: u8, max: u32) -> io::Result<&mut [u8]> {
let (actual_tag, msg) = read_message(&mut self.stream, &mut self.read, max).await?;
if actual_tag != tag {
return Err(io::Error::other(format!(
@@ -89,7 +89,7 @@ impl<S: AsyncRead + Unpin> PqStream<S> {
// passwords are usually pretty short
// and SASL SCRAM messages are no longer than 256 bytes in my testing
// (a few hashes and random bytes, encoded into base64).
const MAX_PASSWORD_LENGTH: usize = 512;
const MAX_PASSWORD_LENGTH: u32 = 512;
self.read_raw_expect(FE_PASSWORD_MESSAGE, MAX_PASSWORD_LENGTH)
.await
}

View File

@@ -31,7 +31,9 @@ mod private {
type Output = io::Result<RustlsStream<S>>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream)
Pin::new(&mut self.inner)
.poll(cx)
.map_ok(|s| RustlsStream(Box::new(s)))
}
}
@@ -57,7 +59,7 @@ mod private {
}
}
pub struct RustlsStream<S>(TlsStream<S>);
pub struct RustlsStream<S>(Box<TlsStream<S>>);
impl<S> postgres_client::tls::TlsStream for RustlsStream<S>
where