mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-08 22:12:56 +00:00
proxy: rename console -> control_plane, rename web -> console_redirect (#9266)
rename console -> control_plane rename web -> console_redirect I think these names are a little more representative.
This commit is contained in:
480
proxy/src/control_plane/messages.rs
Normal file
480
proxy/src/control_plane/messages.rs
Normal file
@@ -0,0 +1,480 @@
|
||||
use measured::FixedCardinalityLabel;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::{self, Display};
|
||||
|
||||
use crate::auth::IpPattern;
|
||||
|
||||
use crate::intern::{BranchIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
|
||||
use crate::proxy::retry::CouldRetry;
|
||||
|
||||
/// Generic error response with human-readable description.
|
||||
/// Note that we can't always present it to user as is.
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub(crate) struct ControlPlaneError {
|
||||
pub(crate) error: Box<str>,
|
||||
#[serde(skip)]
|
||||
pub(crate) http_status_code: http::StatusCode,
|
||||
pub(crate) status: Option<Status>,
|
||||
}
|
||||
|
||||
impl ControlPlaneError {
|
||||
pub(crate) fn get_reason(&self) -> Reason {
|
||||
self.status
|
||||
.as_ref()
|
||||
.and_then(|s| s.details.error_info.as_ref())
|
||||
.map_or(Reason::Unknown, |e| e.reason)
|
||||
}
|
||||
|
||||
pub(crate) fn get_user_facing_message(&self) -> String {
|
||||
use super::provider::errors::REQUEST_FAILED;
|
||||
self.status
|
||||
.as_ref()
|
||||
.and_then(|s| s.details.user_facing_message.as_ref())
|
||||
.map_or_else(|| {
|
||||
// Ask @neondatabase/control-plane for review before adding more.
|
||||
match self.http_status_code {
|
||||
http::StatusCode::NOT_FOUND => {
|
||||
// Status 404: failed to get a project-related resource.
|
||||
format!("{REQUEST_FAILED}: endpoint cannot be found")
|
||||
}
|
||||
http::StatusCode::NOT_ACCEPTABLE => {
|
||||
// Status 406: endpoint is disabled (we don't allow connections).
|
||||
format!("{REQUEST_FAILED}: endpoint is disabled")
|
||||
}
|
||||
http::StatusCode::LOCKED | http::StatusCode::UNPROCESSABLE_ENTITY => {
|
||||
// Status 423: project might be in maintenance mode (or bad state), or quotas exceeded.
|
||||
format!("{REQUEST_FAILED}: endpoint is temporarily unavailable. Check your quotas and/or contact our support.")
|
||||
}
|
||||
_ => REQUEST_FAILED.to_owned(),
|
||||
}
|
||||
}, |m| m.message.clone().into())
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for ControlPlaneError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let msg: &str = self
|
||||
.status
|
||||
.as_ref()
|
||||
.and_then(|s| s.details.user_facing_message.as_ref())
|
||||
.map_or_else(|| self.error.as_ref(), |m| m.message.as_ref());
|
||||
write!(f, "{msg}")
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for ControlPlaneError {
|
||||
fn could_retry(&self) -> bool {
|
||||
// If the error message does not have a status,
|
||||
// the error is unknown and probably should not retry automatically
|
||||
let Some(status) = &self.status else {
|
||||
return false;
|
||||
};
|
||||
|
||||
// retry if the retry info is set.
|
||||
if status.details.retry_info.is_some() {
|
||||
return true;
|
||||
}
|
||||
|
||||
// if no retry info set, attempt to use the error code to guess the retry state.
|
||||
let reason = status
|
||||
.details
|
||||
.error_info
|
||||
.map_or(Reason::Unknown, |e| e.reason);
|
||||
|
||||
reason.can_retry()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub(crate) struct Status {
|
||||
pub(crate) code: Box<str>,
|
||||
pub(crate) message: Box<str>,
|
||||
pub(crate) details: Details,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub(crate) struct Details {
|
||||
pub(crate) error_info: Option<ErrorInfo>,
|
||||
pub(crate) retry_info: Option<RetryInfo>,
|
||||
pub(crate) user_facing_message: Option<UserFacingMessage>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Deserialize)]
|
||||
pub(crate) struct ErrorInfo {
|
||||
pub(crate) reason: Reason,
|
||||
// Schema could also have `metadata` field, but it's not structured. Skip it for now.
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Deserialize, Default)]
|
||||
pub(crate) enum Reason {
|
||||
/// RoleProtected indicates that the role is protected and the attempted operation is not permitted on protected roles.
|
||||
#[serde(rename = "ROLE_PROTECTED")]
|
||||
RoleProtected,
|
||||
/// ResourceNotFound indicates that a resource (project, endpoint, branch, etc.) wasn't found,
|
||||
/// usually due to the provided ID not being correct or because the subject doesn't have enough permissions to
|
||||
/// access the requested resource.
|
||||
/// Prefer a more specific reason if possible, e.g., ProjectNotFound, EndpointNotFound, etc.
|
||||
#[serde(rename = "RESOURCE_NOT_FOUND")]
|
||||
ResourceNotFound,
|
||||
/// ProjectNotFound indicates that the project wasn't found, usually due to the provided ID not being correct,
|
||||
/// or that the subject doesn't have enough permissions to access the requested project.
|
||||
#[serde(rename = "PROJECT_NOT_FOUND")]
|
||||
ProjectNotFound,
|
||||
/// EndpointNotFound indicates that the endpoint wasn't found, usually due to the provided ID not being correct,
|
||||
/// or that the subject doesn't have enough permissions to access the requested endpoint.
|
||||
#[serde(rename = "ENDPOINT_NOT_FOUND")]
|
||||
EndpointNotFound,
|
||||
/// BranchNotFound indicates that the branch wasn't found, usually due to the provided ID not being correct,
|
||||
/// or that the subject doesn't have enough permissions to access the requested branch.
|
||||
#[serde(rename = "BRANCH_NOT_FOUND")]
|
||||
BranchNotFound,
|
||||
/// RateLimitExceeded indicates that the rate limit for the operation has been exceeded.
|
||||
#[serde(rename = "RATE_LIMIT_EXCEEDED")]
|
||||
RateLimitExceeded,
|
||||
/// NonDefaultBranchComputeTimeExceeded indicates that the compute time quota of non-default branches has been
|
||||
/// exceeded.
|
||||
#[serde(rename = "NON_PRIMARY_BRANCH_COMPUTE_TIME_EXCEEDED")]
|
||||
NonDefaultBranchComputeTimeExceeded,
|
||||
/// ActiveTimeQuotaExceeded indicates that the active time quota was exceeded.
|
||||
#[serde(rename = "ACTIVE_TIME_QUOTA_EXCEEDED")]
|
||||
ActiveTimeQuotaExceeded,
|
||||
/// ComputeTimeQuotaExceeded indicates that the compute time quota was exceeded.
|
||||
#[serde(rename = "COMPUTE_TIME_QUOTA_EXCEEDED")]
|
||||
ComputeTimeQuotaExceeded,
|
||||
/// WrittenDataQuotaExceeded indicates that the written data quota was exceeded.
|
||||
#[serde(rename = "WRITTEN_DATA_QUOTA_EXCEEDED")]
|
||||
WrittenDataQuotaExceeded,
|
||||
/// DataTransferQuotaExceeded indicates that the data transfer quota was exceeded.
|
||||
#[serde(rename = "DATA_TRANSFER_QUOTA_EXCEEDED")]
|
||||
DataTransferQuotaExceeded,
|
||||
/// LogicalSizeQuotaExceeded indicates that the logical size quota was exceeded.
|
||||
#[serde(rename = "LOGICAL_SIZE_QUOTA_EXCEEDED")]
|
||||
LogicalSizeQuotaExceeded,
|
||||
/// RunningOperations indicates that the project already has some running operations
|
||||
/// and scheduling of new ones is prohibited.
|
||||
#[serde(rename = "RUNNING_OPERATIONS")]
|
||||
RunningOperations,
|
||||
/// ConcurrencyLimitReached indicates that the concurrency limit for an action was reached.
|
||||
#[serde(rename = "CONCURRENCY_LIMIT_REACHED")]
|
||||
ConcurrencyLimitReached,
|
||||
/// LockAlreadyTaken indicates that the we attempted to take a lock that was already taken.
|
||||
#[serde(rename = "LOCK_ALREADY_TAKEN")]
|
||||
LockAlreadyTaken,
|
||||
#[default]
|
||||
#[serde(other)]
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl Reason {
|
||||
pub(crate) fn is_not_found(self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Reason::ResourceNotFound
|
||||
| Reason::ProjectNotFound
|
||||
| Reason::EndpointNotFound
|
||||
| Reason::BranchNotFound
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn can_retry(self) -> bool {
|
||||
match self {
|
||||
// do not retry role protected errors
|
||||
// not a transitive error
|
||||
Reason::RoleProtected => false,
|
||||
// on retry, it will still not be found
|
||||
Reason::ResourceNotFound
|
||||
| Reason::ProjectNotFound
|
||||
| Reason::EndpointNotFound
|
||||
| Reason::BranchNotFound => false,
|
||||
// we were asked to go away
|
||||
Reason::RateLimitExceeded
|
||||
| Reason::NonDefaultBranchComputeTimeExceeded
|
||||
| Reason::ActiveTimeQuotaExceeded
|
||||
| Reason::ComputeTimeQuotaExceeded
|
||||
| Reason::WrittenDataQuotaExceeded
|
||||
| Reason::DataTransferQuotaExceeded
|
||||
| Reason::LogicalSizeQuotaExceeded => false,
|
||||
// transitive error. control plane is currently busy
|
||||
// but might be ready soon
|
||||
Reason::RunningOperations
|
||||
| Reason::ConcurrencyLimitReached
|
||||
| Reason::LockAlreadyTaken => true,
|
||||
// unknown error. better not retry it.
|
||||
Reason::Unknown => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub(crate) struct RetryInfo {
|
||||
pub(crate) retry_delay_ms: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub(crate) struct UserFacingMessage {
|
||||
pub(crate) message: Box<str>,
|
||||
}
|
||||
|
||||
/// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`].
|
||||
/// Returned by the `/proxy_get_role_secret` API method.
|
||||
#[derive(Deserialize)]
|
||||
pub(crate) struct GetRoleSecret {
|
||||
pub(crate) role_secret: Box<str>,
|
||||
pub(crate) allowed_ips: Option<Vec<IpPattern>>,
|
||||
pub(crate) project_id: Option<ProjectIdInt>,
|
||||
}
|
||||
|
||||
// Manually implement debug to omit sensitive info.
|
||||
impl fmt::Debug for GetRoleSecret {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("GetRoleSecret").finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
/// Response which holds compute node's `host:port` pair.
|
||||
/// Returned by the `/proxy_wake_compute` API method.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct WakeCompute {
|
||||
pub(crate) address: Box<str>,
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
}
|
||||
|
||||
/// Async response which concludes the web auth flow.
|
||||
/// Also known as `kickResponse` in the console.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct KickSession<'a> {
|
||||
/// Session ID is assigned by the proxy.
|
||||
pub(crate) session_id: &'a str,
|
||||
|
||||
/// Compute node connection params.
|
||||
#[serde(deserialize_with = "KickSession::parse_db_info")]
|
||||
pub(crate) result: DatabaseInfo,
|
||||
}
|
||||
|
||||
impl KickSession<'_> {
|
||||
fn parse_db_info<'de, D>(des: D) -> Result<DatabaseInfo, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
#[derive(Deserialize)]
|
||||
enum Wrapper {
|
||||
// Currently, console only reports `Success`.
|
||||
// `Failure(String)` used to be here... RIP.
|
||||
Success(DatabaseInfo),
|
||||
}
|
||||
|
||||
Wrapper::deserialize(des).map(|x| match x {
|
||||
Wrapper::Success(info) => info,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute node connection params.
|
||||
#[derive(Deserialize)]
|
||||
pub(crate) struct DatabaseInfo {
|
||||
pub(crate) host: Box<str>,
|
||||
pub(crate) port: u16,
|
||||
pub(crate) dbname: Box<str>,
|
||||
pub(crate) user: Box<str>,
|
||||
/// Console always provides a password, but it might
|
||||
/// be inconvenient for debug with local PG instance.
|
||||
pub(crate) password: Option<Box<str>>,
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
#[serde(default)]
|
||||
pub(crate) allowed_ips: Option<Vec<IpPattern>>,
|
||||
}
|
||||
|
||||
// Manually implement debug to omit sensitive info.
|
||||
impl fmt::Debug for DatabaseInfo {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("DatabaseInfo")
|
||||
.field("host", &self.host)
|
||||
.field("port", &self.port)
|
||||
.field("dbname", &self.dbname)
|
||||
.field("user", &self.user)
|
||||
.field("allowed_ips", &self.allowed_ips)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
/// Various labels for prometheus metrics.
|
||||
/// Also known as `ProxyMetricsAuxInfo` in the console.
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub(crate) struct MetricsAuxInfo {
|
||||
pub(crate) endpoint_id: EndpointIdInt,
|
||||
pub(crate) project_id: ProjectIdInt,
|
||||
pub(crate) branch_id: BranchIdInt,
|
||||
#[serde(default)]
|
||||
pub(crate) cold_start_info: ColdStartInfo,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize, Clone, Copy, FixedCardinalityLabel)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ColdStartInfo {
|
||||
#[default]
|
||||
Unknown,
|
||||
/// Compute was already running
|
||||
Warm,
|
||||
#[serde(rename = "pool_hit")]
|
||||
#[label(rename = "pool_hit")]
|
||||
/// Compute was not running but there was an available VM
|
||||
VmPoolHit,
|
||||
#[serde(rename = "pool_miss")]
|
||||
#[label(rename = "pool_miss")]
|
||||
/// Compute was not running and there were no VMs available
|
||||
VmPoolMiss,
|
||||
|
||||
// not provided by control plane
|
||||
/// Connection available from HTTP pool
|
||||
HttpPoolHit,
|
||||
/// Cached connection info
|
||||
WarmCached,
|
||||
}
|
||||
|
||||
impl ColdStartInfo {
|
||||
pub(crate) fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
ColdStartInfo::Unknown => "unknown",
|
||||
ColdStartInfo::Warm => "warm",
|
||||
ColdStartInfo::VmPoolHit => "pool_hit",
|
||||
ColdStartInfo::VmPoolMiss => "pool_miss",
|
||||
ColdStartInfo::HttpPoolHit => "http_pool_hit",
|
||||
ColdStartInfo::WarmCached => "warm_cached",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct EndpointJwksResponse {
|
||||
pub jwks: Vec<JwksSettings>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct JwksSettings {
|
||||
pub id: String,
|
||||
pub jwks_url: url::Url,
|
||||
pub provider_name: String,
|
||||
pub jwt_audience: Option<String>,
|
||||
pub role_names: Vec<RoleNameInt>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A minimal `aux` object that deserializes into [`MetricsAuxInfo`].
    fn dummy_aux() -> serde_json::Value {
        json!({
            "endpoint_id": "endpoint",
            "project_id": "project",
            "branch_id": "branch",
            "cold_start_info": "unknown",
        })
    }

    #[test]
    fn parse_kick_session() -> anyhow::Result<()> {
        // This is what the console's kickResponse looks like.
        let payload = json!({
            "session_id": "deadbeef",
            "result": {
                "Success": {
                    "host": "localhost",
                    "port": 5432,
                    "dbname": "postgres",
                    "user": "john_doe",
                    "password": "password",
                    "aux": dummy_aux(),
                }
            }
        })
        .to_string();

        // `KickSession` borrows from its input, hence `from_str`.
        serde_json::from_str::<KickSession<'_>>(&payload)?;

        Ok(())
    }

    #[test]
    fn parse_db_info() -> anyhow::Result<()> {
        // with password
        let with_password = json!({
            "host": "localhost",
            "port": 5432,
            "dbname": "postgres",
            "user": "john_doe",
            "password": "password",
            "aux": dummy_aux(),
        });
        serde_json::from_value::<DatabaseInfo>(with_password)?;

        // without password
        let without_password = json!({
            "host": "localhost",
            "port": 5432,
            "dbname": "postgres",
            "user": "john_doe",
            "aux": dummy_aux(),
        });
        serde_json::from_value::<DatabaseInfo>(without_password)?;

        // new field (forward compatibility)
        let with_unknown_fields = json!({
            "host": "localhost",
            "port": 5432,
            "dbname": "postgres",
            "user": "john_doe",
            "project": "hello_world",
            "N.E.W": "forward compatibility check",
            "aux": dummy_aux(),
        });
        serde_json::from_value::<DatabaseInfo>(with_unknown_fields)?;

        // with allowed_ips
        let with_allowed_ips = json!({
            "host": "localhost",
            "port": 5432,
            "dbname": "postgres",
            "user": "john_doe",
            "password": "password",
            "aux": dummy_aux(),
            "allowed_ips": ["127.0.0.1"],
        });
        let dbinfo = serde_json::from_value::<DatabaseInfo>(with_allowed_ips)?;

        assert_eq!(
            dbinfo.allowed_ips,
            Some(vec![IpPattern::Single("127.0.0.1".parse()?)])
        );

        Ok(())
    }

    #[test]
    fn parse_wake_compute() -> anyhow::Result<()> {
        let payload = json!({
            "address": "0.0.0.0",
            "aux": dummy_aux(),
        })
        .to_string();
        serde_json::from_str::<WakeCompute>(&payload)?;
        Ok(())
    }

    #[test]
    fn parse_get_role_secret() -> anyhow::Result<()> {
        let payloads = [
            // Empty `allowed_ips` field.
            json!({
                "role_secret": "secret",
            }),
            json!({
                "role_secret": "secret",
                "allowed_ips": ["8.8.8.8"],
            }),
            json!({
                "role_secret": "secret",
                "allowed_ips": ["8.8.8.8"],
                "project_id": "project",
            }),
        ];

        for payload in payloads {
            serde_json::from_str::<GetRoleSecret>(&payload.to_string())?;
        }

        Ok(())
    }
}
|
||||
116
proxy/src/control_plane/mgmt.rs
Normal file
116
proxy/src/control_plane/mgmt.rs
Normal file
@@ -0,0 +1,116 @@
|
||||
use crate::{
|
||||
control_plane::messages::{DatabaseInfo, KickSession},
|
||||
waiters::{self, Waiter, Waiters},
|
||||
};
|
||||
use anyhow::Context;
|
||||
use once_cell::sync::Lazy;
|
||||
use postgres_backend::{AuthType, PostgresBackend, PostgresBackendTCP, QueryError};
|
||||
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
|
||||
use std::convert::Infallible;
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, info_span, Instrument};
|
||||
|
||||
static CPLANE_WAITERS: Lazy<Waiters<ComputeReady>> = Lazy::new(Default::default);
|
||||
|
||||
/// Give caller an opportunity to wait for the cloud's reply.
|
||||
pub(crate) fn get_waiter(
|
||||
psql_session_id: impl Into<String>,
|
||||
) -> Result<Waiter<'static, ComputeReady>, waiters::RegisterError> {
|
||||
CPLANE_WAITERS.register(psql_session_id.into())
|
||||
}
|
||||
|
||||
pub(crate) fn notify(psql_session_id: &str, msg: ComputeReady) -> Result<(), waiters::NotifyError> {
|
||||
CPLANE_WAITERS.notify(psql_session_id, msg)
|
||||
}
|
||||
|
||||
/// Console management API listener task.
|
||||
/// It spawns console response handlers needed for the web auth.
|
||||
pub async fn task_main(listener: TcpListener) -> anyhow::Result<Infallible> {
|
||||
scopeguard::defer! {
|
||||
info!("mgmt has shut down");
|
||||
}
|
||||
|
||||
loop {
|
||||
let (socket, peer_addr) = listener.accept().await?;
|
||||
info!("accepted connection from {peer_addr}");
|
||||
|
||||
socket
|
||||
.set_nodelay(true)
|
||||
.context("failed to set client socket option")?;
|
||||
|
||||
let span = info_span!("mgmt", peer = %peer_addr);
|
||||
|
||||
tokio::task::spawn(
|
||||
async move {
|
||||
info!("serving a new console management API connection");
|
||||
|
||||
// these might be long running connections, have a separate logging for cancelling
|
||||
// on shutdown and other ways of stopping.
|
||||
let cancelled = scopeguard::guard(tracing::Span::current(), |span| {
|
||||
let _e = span.entered();
|
||||
info!("console management API task cancelled");
|
||||
});
|
||||
|
||||
if let Err(e) = handle_connection(socket).await {
|
||||
error!("serving failed with an error: {e}");
|
||||
} else {
|
||||
info!("serving completed");
|
||||
}
|
||||
|
||||
// we can no longer get dropped
|
||||
scopeguard::ScopeGuard::into_inner(cancelled);
|
||||
}
|
||||
.instrument(span),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_connection(socket: TcpStream) -> Result<(), QueryError> {
|
||||
let pgbackend = PostgresBackend::new(socket, AuthType::Trust, None)?;
|
||||
pgbackend
|
||||
.run(&mut MgmtHandler, &CancellationToken::new())
|
||||
.await
|
||||
}
|
||||
|
||||
/// A message received by `mgmt` when a compute node is ready.
|
||||
pub(crate) type ComputeReady = DatabaseInfo;
|
||||
|
||||
/// Stateless query handler of the management API.
// TODO: replace with an http-based protocol.
struct MgmtHandler;
|
||||
|
||||
impl postgres_backend::Handler<tokio::net::TcpStream> for MgmtHandler {
|
||||
async fn process_query(
|
||||
&mut self,
|
||||
pgb: &mut PostgresBackendTCP,
|
||||
query: &str,
|
||||
) -> Result<(), QueryError> {
|
||||
try_process_query(pgb, query).map_err(|e| {
|
||||
error!("failed to process response: {e:?}");
|
||||
e
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn try_process_query(pgb: &mut PostgresBackendTCP, query: &str) -> Result<(), QueryError> {
|
||||
let resp: KickSession<'_> =
|
||||
serde_json::from_str(query).context("Failed to parse query as json")?;
|
||||
|
||||
let span = info_span!("event", session_id = resp.session_id);
|
||||
let _enter = span.enter();
|
||||
info!("got response: {:?}", resp.result);
|
||||
|
||||
match notify(resp.session_id, resp.result) {
|
||||
Ok(()) => {
|
||||
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
|
||||
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("failed to deliver response to per-client task");
|
||||
pgb.write_message_noflush(&BeMessage::ErrorResponse(&e.to_string(), None))?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
589
proxy/src/control_plane/provider.rs
Normal file
589
proxy/src/control_plane/provider.rs
Normal file
@@ -0,0 +1,589 @@
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
pub mod mock;
|
||||
pub mod neon;
|
||||
|
||||
use super::messages::{ControlPlaneError, MetricsAuxInfo};
|
||||
use crate::{
|
||||
auth::{
|
||||
backend::{
|
||||
jwt::{AuthRule, FetchAuthRules},
|
||||
ComputeCredentialKeys, ComputeUserInfo,
|
||||
},
|
||||
IpPattern,
|
||||
},
|
||||
cache::{endpoints::EndpointsCache, project_info::ProjectInfoCacheImpl, Cached, TimedLru},
|
||||
compute,
|
||||
config::{CacheOptions, EndpointCacheConfig, ProjectInfoCacheOptions},
|
||||
context::RequestMonitoring,
|
||||
error::ReportableError,
|
||||
intern::ProjectIdInt,
|
||||
metrics::ApiLockMetrics,
|
||||
rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token},
|
||||
scram, EndpointCacheKey, EndpointId,
|
||||
};
|
||||
use dashmap::DashMap;
|
||||
use std::{hash::Hash, sync::Arc, time::Duration};
|
||||
use tokio::time::Instant;
|
||||
use tracing::info;
|
||||
|
||||
pub(crate) mod errors {
|
||||
use crate::{
|
||||
control_plane::messages::{self, ControlPlaneError, Reason},
|
||||
error::{io_error, ErrorKind, ReportableError, UserFacingError},
|
||||
proxy::retry::CouldRetry,
|
||||
};
|
||||
use thiserror::Error;
|
||||
|
||||
use super::ApiLockError;
|
||||
|
||||
/// A go-to error message which doesn't leak any detail.
|
||||
pub(crate) const REQUEST_FAILED: &str = "Console request failed";
|
||||
|
||||
/// Common console API error.
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum ApiError {
|
||||
/// Error returned by the console itself.
|
||||
#[error("{REQUEST_FAILED} with {0}")]
|
||||
ControlPlane(ControlPlaneError),
|
||||
|
||||
/// Various IO errors like broken pipe or malformed payload.
|
||||
#[error("{REQUEST_FAILED}: {0}")]
|
||||
Transport(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
impl ApiError {
|
||||
/// Returns HTTP status code if it's the reason for failure.
|
||||
pub(crate) fn get_reason(&self) -> messages::Reason {
|
||||
match self {
|
||||
ApiError::ControlPlane(e) => e.get_reason(),
|
||||
ApiError::Transport(_) => messages::Reason::Unknown,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for ApiError {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self {
|
||||
// To minimize risks, only select errors are forwarded to users.
|
||||
ApiError::ControlPlane(c) => c.get_user_facing_message(),
|
||||
ApiError::Transport(_) => REQUEST_FAILED.to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for ApiError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ApiError::ControlPlane(e) => match e.get_reason() {
|
||||
Reason::RoleProtected => ErrorKind::User,
|
||||
Reason::ResourceNotFound => ErrorKind::User,
|
||||
Reason::ProjectNotFound => ErrorKind::User,
|
||||
Reason::EndpointNotFound => ErrorKind::User,
|
||||
Reason::BranchNotFound => ErrorKind::User,
|
||||
Reason::RateLimitExceeded => ErrorKind::ServiceRateLimit,
|
||||
Reason::NonDefaultBranchComputeTimeExceeded => ErrorKind::User,
|
||||
Reason::ActiveTimeQuotaExceeded => ErrorKind::User,
|
||||
Reason::ComputeTimeQuotaExceeded => ErrorKind::User,
|
||||
Reason::WrittenDataQuotaExceeded => ErrorKind::User,
|
||||
Reason::DataTransferQuotaExceeded => ErrorKind::User,
|
||||
Reason::LogicalSizeQuotaExceeded => ErrorKind::User,
|
||||
Reason::ConcurrencyLimitReached => ErrorKind::ControlPlane,
|
||||
Reason::LockAlreadyTaken => ErrorKind::ControlPlane,
|
||||
Reason::RunningOperations => ErrorKind::ControlPlane,
|
||||
Reason::Unknown => match &e {
|
||||
ControlPlaneError {
|
||||
http_status_code:
|
||||
http::StatusCode::NOT_FOUND | http::StatusCode::NOT_ACCEPTABLE,
|
||||
..
|
||||
} => crate::error::ErrorKind::User,
|
||||
ControlPlaneError {
|
||||
http_status_code: http::StatusCode::UNPROCESSABLE_ENTITY,
|
||||
error,
|
||||
..
|
||||
} if error
|
||||
.contains("compute time quota of non-primary branches is exceeded") =>
|
||||
{
|
||||
crate::error::ErrorKind::User
|
||||
}
|
||||
ControlPlaneError {
|
||||
http_status_code: http::StatusCode::LOCKED,
|
||||
error,
|
||||
..
|
||||
} if error.contains("quota exceeded")
|
||||
|| error.contains("the limit for current plan reached") =>
|
||||
{
|
||||
crate::error::ErrorKind::User
|
||||
}
|
||||
ControlPlaneError {
|
||||
http_status_code: http::StatusCode::TOO_MANY_REQUESTS,
|
||||
..
|
||||
} => crate::error::ErrorKind::ServiceRateLimit,
|
||||
ControlPlaneError { .. } => crate::error::ErrorKind::ControlPlane,
|
||||
},
|
||||
},
|
||||
ApiError::Transport(_) => crate::error::ErrorKind::ControlPlane,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for ApiError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
// retry some transport errors
|
||||
Self::Transport(io) => io.could_retry(),
|
||||
Self::ControlPlane(e) => e.could_retry(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reqwest::Error> for ApiError {
|
||||
fn from(e: reqwest::Error) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reqwest_middleware::Error> for ApiError {
|
||||
fn from(e: reqwest_middleware::Error) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum GetAuthInfoError {
|
||||
// We shouldn't include the actual secret here.
|
||||
#[error("Console responded with a malformed auth secret")]
|
||||
BadSecret,
|
||||
|
||||
#[error(transparent)]
|
||||
ApiError(ApiError),
|
||||
}
|
||||
|
||||
// This allows more useful interactions than `#[from]`.
|
||||
impl<E: Into<ApiError>> From<E> for GetAuthInfoError {
|
||||
fn from(e: E) -> Self {
|
||||
Self::ApiError(e.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for GetAuthInfoError {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self {
|
||||
// We absolutely should not leak any secrets!
|
||||
Self::BadSecret => REQUEST_FAILED.to_owned(),
|
||||
// However, API might return a meaningful error.
|
||||
Self::ApiError(e) => e.to_string_client(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for GetAuthInfoError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
Self::BadSecret => crate::error::ErrorKind::ControlPlane,
|
||||
Self::ApiError(_) => crate::error::ErrorKind::ControlPlane,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub(crate) enum WakeComputeError {
|
||||
#[error("Console responded with a malformed compute address: {0}")]
|
||||
BadComputeAddress(Box<str>),
|
||||
|
||||
#[error(transparent)]
|
||||
ApiError(ApiError),
|
||||
|
||||
#[error("Too many connections attempts")]
|
||||
TooManyConnections,
|
||||
|
||||
#[error("error acquiring resource permit: {0}")]
|
||||
TooManyConnectionAttempts(#[from] ApiLockError),
|
||||
}
|
||||
|
||||
// This allows more useful interactions than `#[from]`.
|
||||
impl<E: Into<ApiError>> From<E> for WakeComputeError {
|
||||
fn from(e: E) -> Self {
|
||||
Self::ApiError(e.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl UserFacingError for WakeComputeError {
|
||||
fn to_string_client(&self) -> String {
|
||||
match self {
|
||||
// We shouldn't show user the address even if it's broken.
|
||||
// Besides, user is unlikely to care about this detail.
|
||||
Self::BadComputeAddress(_) => REQUEST_FAILED.to_owned(),
|
||||
// However, API might return a meaningful error.
|
||||
Self::ApiError(e) => e.to_string_client(),
|
||||
|
||||
Self::TooManyConnections => self.to_string(),
|
||||
|
||||
Self::TooManyConnectionAttempts(_) => {
|
||||
"Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReportableError for WakeComputeError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
Self::BadComputeAddress(_) => crate::error::ErrorKind::ControlPlane,
|
||||
Self::ApiError(e) => e.get_error_kind(),
|
||||
Self::TooManyConnections => crate::error::ErrorKind::RateLimit,
|
||||
Self::TooManyConnectionAttempts(e) => e.get_error_kind(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CouldRetry for WakeComputeError {
|
||||
fn could_retry(&self) -> bool {
|
||||
match self {
|
||||
Self::BadComputeAddress(_) => false,
|
||||
Self::ApiError(e) => e.could_retry(),
|
||||
Self::TooManyConnections => false,
|
||||
Self::TooManyConnectionAttempts(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Auth secret which is managed by the cloud.
|
||||
#[derive(Clone, Eq, PartialEq, Debug)]
|
||||
pub(crate) enum AuthSecret {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
/// Md5 hash of user's password.
|
||||
Md5([u8; 16]),
|
||||
|
||||
/// [SCRAM](crate::scram) authentication info.
|
||||
Scram(scram::ServerSecret),
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub(crate) struct AuthInfo {
|
||||
pub(crate) secret: Option<AuthSecret>,
|
||||
/// List of IP addresses allowed for the autorization.
|
||||
pub(crate) allowed_ips: Vec<IpPattern>,
|
||||
/// Project ID. This is used for cache invalidation.
|
||||
pub(crate) project_id: Option<ProjectIdInt>,
|
||||
}
|
||||
|
||||
/// Info for establishing a connection to a compute node.
|
||||
/// This is what we get after auth succeeded, but not before!
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct NodeInfo {
|
||||
/// Compute node connection params.
|
||||
/// It's sad that we have to clone this, but this will improve
|
||||
/// once we migrate to a bespoke connection logic.
|
||||
pub(crate) config: compute::ConnCfg,
|
||||
|
||||
/// Labels for proxy's metrics.
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
|
||||
/// Whether we should accept self-signed certificates (for testing)
|
||||
pub(crate) allow_self_signed_compute: bool,
|
||||
}
|
||||
|
||||
impl NodeInfo {
|
||||
pub(crate) async fn connect(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
timeout: Duration,
|
||||
) -> Result<compute::PostgresConnection, compute::ConnectionError> {
|
||||
self.config
|
||||
.connect(
|
||||
ctx,
|
||||
self.allow_self_signed_compute,
|
||||
self.aux.clone(),
|
||||
timeout,
|
||||
)
|
||||
.await
|
||||
}
|
||||
pub(crate) fn reuse_settings(&mut self, other: Self) {
|
||||
self.allow_self_signed_compute = other.allow_self_signed_compute;
|
||||
self.config.reuse_password(other.config);
|
||||
}
|
||||
|
||||
pub(crate) fn set_keys(&mut self, keys: &ComputeCredentialKeys) {
|
||||
match keys {
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
ComputeCredentialKeys::Password(password) => self.config.password(password),
|
||||
ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys),
|
||||
ComputeCredentialKeys::None => &mut self.config,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) type NodeInfoCache =
|
||||
TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ControlPlaneError>>>;
|
||||
pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
|
||||
pub(crate) type CachedRoleSecret = Cached<&'static ProjectInfoCacheImpl, Option<AuthSecret>>;
|
||||
pub(crate) type CachedAllowedIps = Cached<&'static ProjectInfoCacheImpl, Arc<Vec<IpPattern>>>;
|
||||
|
||||
/// This will allocate per each call, but the http requests alone
|
||||
/// already require a few allocations, so it should be fine.
|
||||
pub(crate) trait Api {
|
||||
/// Get the client's auth secret for authentication.
|
||||
/// Returns option because user not found situation is special.
|
||||
/// We still have to mock the scram to avoid leaking information that user doesn't exist.
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, errors::GetAuthInfoError>;
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError>;
|
||||
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>>;
|
||||
|
||||
/// Wake up the compute node and return the corresponding connection info.
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, errors::WakeComputeError>;
|
||||
}
|
||||
|
||||
#[non_exhaustive]
|
||||
#[derive(Clone)]
|
||||
pub enum ControlPlaneBackend {
|
||||
/// Current Management API (V2).
|
||||
Management(neon::Api),
|
||||
/// Local mock control plane.
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
PostgresMock(mock::Api),
|
||||
/// Internal testing
|
||||
#[cfg(test)]
|
||||
#[allow(private_interfaces)]
|
||||
Test(Box<dyn crate::auth::backend::TestBackend>),
|
||||
}
|
||||
|
||||
impl Api for ControlPlaneBackend {
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
|
||||
match self {
|
||||
Self::Management(api) => api.get_role_secret(ctx, user_info).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.get_role_secret(ctx, user_info).await,
|
||||
#[cfg(test)]
|
||||
Self::Test(_) => {
|
||||
unreachable!("this function should never be called in the test backend")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
|
||||
match self {
|
||||
Self::Management(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
|
||||
#[cfg(test)]
|
||||
Self::Test(api) => api.get_allowed_ips_and_secret(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>> {
|
||||
match self {
|
||||
Self::Management(api) => api.get_endpoint_jwks(ctx, endpoint).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.get_endpoint_jwks(ctx, endpoint).await,
|
||||
#[cfg(test)]
|
||||
Self::Test(_api) => Ok(vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, errors::WakeComputeError> {
|
||||
match self {
|
||||
Self::Management(api) => api.wake_compute(ctx, user_info).await,
|
||||
#[cfg(any(test, feature = "testing"))]
|
||||
Self::PostgresMock(api) => api.wake_compute(ctx, user_info).await,
|
||||
#[cfg(test)]
|
||||
Self::Test(api) => api.wake_compute(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Various caches for [`control_plane`](super).
|
||||
pub struct ApiCaches {
|
||||
/// Cache for the `wake_compute` API method.
|
||||
pub(crate) node_info: NodeInfoCache,
|
||||
/// Cache which stores project_id -> endpoint_ids mapping.
|
||||
pub project_info: Arc<ProjectInfoCacheImpl>,
|
||||
/// List of all valid endpoints.
|
||||
pub endpoints_cache: Arc<EndpointsCache>,
|
||||
}
|
||||
|
||||
impl ApiCaches {
|
||||
pub fn new(
|
||||
wake_compute_cache_config: CacheOptions,
|
||||
project_info_cache_config: ProjectInfoCacheOptions,
|
||||
endpoint_cache_config: EndpointCacheConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
node_info: NodeInfoCache::new(
|
||||
"node_info_cache",
|
||||
wake_compute_cache_config.size,
|
||||
wake_compute_cache_config.ttl,
|
||||
true,
|
||||
),
|
||||
project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
|
||||
endpoints_cache: Arc::new(EndpointsCache::new(endpoint_cache_config)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Various caches for [`control_plane`](super).
|
||||
pub struct ApiLocks<K> {
|
||||
name: &'static str,
|
||||
node_locks: DashMap<K, Arc<DynamicLimiter>>,
|
||||
config: RateLimiterConfig,
|
||||
timeout: Duration,
|
||||
epoch: std::time::Duration,
|
||||
metrics: &'static ApiLockMetrics,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum ApiLockError {
|
||||
#[error("timeout acquiring resource permit")]
|
||||
TimeoutError(#[from] tokio::time::error::Elapsed),
|
||||
}
|
||||
|
||||
impl ReportableError for ApiLockError {
|
||||
fn get_error_kind(&self) -> crate::error::ErrorKind {
|
||||
match self {
|
||||
ApiLockError::TimeoutError(_) => crate::error::ErrorKind::RateLimit,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<K: Hash + Eq + Clone> ApiLocks<K> {
|
||||
pub fn new(
|
||||
name: &'static str,
|
||||
config: RateLimiterConfig,
|
||||
shards: usize,
|
||||
timeout: Duration,
|
||||
epoch: std::time::Duration,
|
||||
metrics: &'static ApiLockMetrics,
|
||||
) -> prometheus::Result<Self> {
|
||||
Ok(Self {
|
||||
name,
|
||||
node_locks: DashMap::with_shard_amount(shards),
|
||||
config,
|
||||
timeout,
|
||||
epoch,
|
||||
metrics,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn get_permit(&self, key: &K) -> Result<WakeComputePermit, ApiLockError> {
|
||||
if self.config.initial_limit == 0 {
|
||||
return Ok(WakeComputePermit {
|
||||
permit: Token::disabled(),
|
||||
});
|
||||
}
|
||||
let now = Instant::now();
|
||||
let semaphore = {
|
||||
// get fast path
|
||||
if let Some(semaphore) = self.node_locks.get(key) {
|
||||
semaphore.clone()
|
||||
} else {
|
||||
self.node_locks
|
||||
.entry(key.clone())
|
||||
.or_insert_with(|| {
|
||||
self.metrics.semaphores_registered.inc();
|
||||
DynamicLimiter::new(self.config)
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
};
|
||||
let permit = semaphore.acquire_timeout(self.timeout).await;
|
||||
|
||||
self.metrics
|
||||
.semaphore_acquire_seconds
|
||||
.observe(now.elapsed().as_secs_f64());
|
||||
info!("acquired permit {:?}", now.elapsed().as_secs_f64());
|
||||
Ok(WakeComputePermit { permit: permit? })
|
||||
}
|
||||
|
||||
pub async fn garbage_collect_worker(&self) {
|
||||
if self.config.initial_limit == 0 {
|
||||
return;
|
||||
}
|
||||
let mut interval =
|
||||
tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
|
||||
loop {
|
||||
for (i, shard) in self.node_locks.shards().iter().enumerate() {
|
||||
interval.tick().await;
|
||||
// temporary lock a single shard and then clear any semaphores that aren't currently checked out
|
||||
// race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
|
||||
// therefore releasing it is safe from race conditions
|
||||
info!(
|
||||
name = self.name,
|
||||
shard = i,
|
||||
"performing epoch reclamation on api lock"
|
||||
);
|
||||
let mut lock = shard.write();
|
||||
let timer = self.metrics.reclamation_lag_seconds.start_timer();
|
||||
let count = lock
|
||||
.extract_if(|_, semaphore| Arc::strong_count(semaphore.get_mut()) == 1)
|
||||
.count();
|
||||
drop(lock);
|
||||
self.metrics.semaphores_unregistered.inc_by(count as u64);
|
||||
timer.observe();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct WakeComputePermit {
|
||||
permit: Token,
|
||||
}
|
||||
|
||||
impl WakeComputePermit {
|
||||
pub(crate) fn should_check_cache(&self) -> bool {
|
||||
!self.permit.is_disabled()
|
||||
}
|
||||
pub(crate) fn release(self, outcome: Outcome) {
|
||||
self.permit.release(outcome);
|
||||
}
|
||||
pub(crate) fn release_result<T, E>(self, res: Result<T, E>) -> Result<T, E> {
|
||||
match res {
|
||||
Ok(_) => self.release(Outcome::Success),
|
||||
Err(_) => self.release(Outcome::Overload),
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl FetchAuthRules for ControlPlaneBackend {
|
||||
async fn fetch_auth_rules(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>> {
|
||||
self.get_endpoint_jwks(ctx, endpoint).await
|
||||
}
|
||||
}
|
||||
248
proxy/src/control_plane/provider/mock.rs
Normal file
248
proxy/src/control_plane/provider/mock.rs
Normal file
@@ -0,0 +1,248 @@
|
||||
//! Mock console backend which relies on a user-provided postgres instance.
|
||||
|
||||
use super::{
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
AuthInfo, AuthSecret, CachedNodeInfo, NodeInfo,
|
||||
};
|
||||
use crate::{
|
||||
auth::backend::jwt::AuthRule, context::RequestMonitoring, intern::RoleNameInt, RoleName,
|
||||
};
|
||||
use crate::{auth::backend::ComputeUserInfo, compute, error::io_error, scram, url::ApiUrl};
|
||||
use crate::{auth::IpPattern, cache::Cached};
|
||||
use crate::{
|
||||
control_plane::{
|
||||
messages::MetricsAuxInfo,
|
||||
provider::{CachedAllowedIps, CachedRoleSecret},
|
||||
},
|
||||
BranchId, EndpointId, ProjectId,
|
||||
};
|
||||
use futures::TryFutureExt;
|
||||
use std::{str::FromStr, sync::Arc};
|
||||
use thiserror::Error;
|
||||
use tokio_postgres::{config::SslMode, Client};
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
enum MockApiError {
|
||||
#[error("Failed to read password: {0}")]
|
||||
PasswordNotSet(tokio_postgres::Error),
|
||||
}
|
||||
|
||||
impl From<MockApiError> for ApiError {
|
||||
fn from(e: MockApiError) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tokio_postgres::Error> for ApiError {
|
||||
fn from(e: tokio_postgres::Error) -> Self {
|
||||
io_error(e).into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Api {
|
||||
endpoint: ApiUrl,
|
||||
ip_allowlist_check_enabled: bool,
|
||||
}
|
||||
|
||||
impl Api {
|
||||
pub fn new(endpoint: ApiUrl, ip_allowlist_check_enabled: bool) -> Self {
|
||||
Self {
|
||||
endpoint,
|
||||
ip_allowlist_check_enabled,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn url(&self) -> &str {
|
||||
self.endpoint.as_str()
|
||||
}
|
||||
|
||||
async fn do_get_auth_info(
|
||||
&self,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<AuthInfo, GetAuthInfoError> {
|
||||
let (secret, allowed_ips) = async {
|
||||
// Perhaps we could persist this connection, but then we'd have to
|
||||
// write more code for reopening it if it got closed, which doesn't
|
||||
// seem worth it.
|
||||
let (client, connection) =
|
||||
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
|
||||
|
||||
tokio::spawn(connection);
|
||||
|
||||
let secret = if let Some(entry) = get_execute_postgres_query(
|
||||
&client,
|
||||
"select rolpassword from pg_catalog.pg_authid where rolname = $1",
|
||||
&[&&*user_info.user],
|
||||
"rolpassword",
|
||||
)
|
||||
.await?
|
||||
{
|
||||
info!("got a secret: {entry}"); // safe since it's not a prod scenario
|
||||
let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
|
||||
secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
|
||||
} else {
|
||||
warn!("user '{}' does not exist", user_info.user);
|
||||
None
|
||||
};
|
||||
|
||||
let allowed_ips = if self.ip_allowlist_check_enabled {
|
||||
match get_execute_postgres_query(
|
||||
&client,
|
||||
"select allowed_ips from neon_control_plane.endpoints where endpoint_id = $1",
|
||||
&[&user_info.endpoint.as_str()],
|
||||
"allowed_ips",
|
||||
)
|
||||
.await?
|
||||
{
|
||||
Some(s) => {
|
||||
info!("got allowed_ips: {s}");
|
||||
s.split(',')
|
||||
.map(|s| IpPattern::from_str(s).unwrap())
|
||||
.collect()
|
||||
}
|
||||
None => vec![],
|
||||
}
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
Ok((secret, allowed_ips))
|
||||
}
|
||||
.map_err(crate::error::log_error::<GetAuthInfoError>)
|
||||
.instrument(info_span!("postgres", url = self.endpoint.as_str()))
|
||||
.await?;
|
||||
Ok(AuthInfo {
|
||||
secret,
|
||||
allowed_ips,
|
||||
project_id: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn do_get_endpoint_jwks(&self, endpoint: EndpointId) -> anyhow::Result<Vec<AuthRule>> {
|
||||
let (client, connection) =
|
||||
tokio_postgres::connect(self.endpoint.as_str(), tokio_postgres::NoTls).await?;
|
||||
|
||||
let connection = tokio::spawn(connection);
|
||||
|
||||
let res = client.query(
|
||||
"select id, jwks_url, audience, role_names from neon_control_plane.endpoint_jwks where endpoint_id = $1",
|
||||
&[&endpoint.as_str()],
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut rows = vec![];
|
||||
for row in res {
|
||||
rows.push(AuthRule {
|
||||
id: row.get("id"),
|
||||
jwks_url: url::Url::parse(row.get("jwks_url"))?,
|
||||
audience: row.get("audience"),
|
||||
role_names: row
|
||||
.get::<_, Vec<String>>("role_names")
|
||||
.into_iter()
|
||||
.map(RoleName::from)
|
||||
.map(|s| RoleNameInt::from(&s))
|
||||
.collect(),
|
||||
});
|
||||
}
|
||||
|
||||
drop(client);
|
||||
connection.await??;
|
||||
|
||||
Ok(rows)
|
||||
}
|
||||
|
||||
async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
|
||||
let mut config = compute::ConnCfg::new();
|
||||
config
|
||||
.host(self.endpoint.host_str().unwrap_or("localhost"))
|
||||
.port(self.endpoint.port().unwrap_or(5432))
|
||||
.ssl_mode(SslMode::Disable);
|
||||
|
||||
let node = NodeInfo {
|
||||
config,
|
||||
aux: MetricsAuxInfo {
|
||||
endpoint_id: (&EndpointId::from("endpoint")).into(),
|
||||
project_id: (&ProjectId::from("project")).into(),
|
||||
branch_id: (&BranchId::from("branch")).into(),
|
||||
cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
|
||||
},
|
||||
allow_self_signed_compute: false,
|
||||
};
|
||||
|
||||
Ok(node)
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_execute_postgres_query(
|
||||
client: &Client,
|
||||
query: &str,
|
||||
params: &[&(dyn tokio_postgres::types::ToSql + Sync)],
|
||||
idx: &str,
|
||||
) -> Result<Option<String>, GetAuthInfoError> {
|
||||
let rows = client.query(query, params).await?;
|
||||
|
||||
// We can get at most one row, because `rolname` is unique.
|
||||
let Some(row) = rows.first() else {
|
||||
// This means that the user doesn't exist, so there can be no secret.
|
||||
// However, this is still a *valid* outcome which is very similar
|
||||
// to getting `404 Not found` from the Neon console.
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let entry = row.try_get(idx).map_err(MockApiError::PasswordNotSet)?;
|
||||
Ok(Some(entry))
|
||||
}
|
||||
|
||||
impl super::Api for Api {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, GetAuthInfoError> {
|
||||
Ok(CachedRoleSecret::new_uncached(
|
||||
self.do_get_auth_info(user_info).await?.secret,
|
||||
))
|
||||
}
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
|
||||
Ok((
|
||||
Cached::new_uncached(Arc::new(
|
||||
self.do_get_auth_info(user_info).await?.allowed_ips,
|
||||
)),
|
||||
None,
|
||||
))
|
||||
}
|
||||
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>> {
|
||||
self.do_get_endpoint_jwks(endpoint).await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
_ctx: &RequestMonitoring,
|
||||
_user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
self.do_wake_compute().map_ok(Cached::new_uncached).await
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_md5(input: &str) -> Option<[u8; 16]> {
|
||||
let text = input.strip_prefix("md5")?;
|
||||
|
||||
let mut bytes = [0u8; 16];
|
||||
hex::decode_to_slice(text, &mut bytes).ok()?;
|
||||
|
||||
Some(bytes)
|
||||
}
|
||||
494
proxy/src/control_plane/provider/neon.rs
Normal file
494
proxy/src/control_plane/provider/neon.rs
Normal file
@@ -0,0 +1,494 @@
|
||||
//! Production console backend.
|
||||
|
||||
use super::{
|
||||
super::messages::{ControlPlaneError, GetRoleSecret, WakeCompute},
|
||||
errors::{ApiError, GetAuthInfoError, WakeComputeError},
|
||||
ApiCaches, ApiLocks, AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret,
|
||||
NodeInfo,
|
||||
};
|
||||
use crate::{
|
||||
auth::backend::{jwt::AuthRule, ComputeUserInfo},
|
||||
compute,
|
||||
control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason},
|
||||
http,
|
||||
metrics::{CacheOutcome, Metrics},
|
||||
rate_limiter::WakeComputeRateLimiter,
|
||||
scram, EndpointCacheKey, EndpointId,
|
||||
};
|
||||
use crate::{cache::Cached, context::RequestMonitoring};
|
||||
use ::http::{header::AUTHORIZATION, HeaderName};
|
||||
use anyhow::bail;
|
||||
use futures::TryFutureExt;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tokio::time::Instant;
|
||||
use tokio_postgres::config::SslMode;
|
||||
use tracing::{debug, error, info, info_span, warn, Instrument};
|
||||
|
||||
/// Header that carries the request id on requests to the control plane.
const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Api {
|
||||
endpoint: http::Endpoint,
|
||||
pub caches: &'static ApiCaches,
|
||||
pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
|
||||
pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
|
||||
// put in a shared ref so we don't copy secrets all over in memory
|
||||
jwt: Arc<str>,
|
||||
}
|
||||
|
||||
impl Api {
|
||||
/// Construct an API object containing the auth parameters.
|
||||
pub fn new(
|
||||
endpoint: http::Endpoint,
|
||||
caches: &'static ApiCaches,
|
||||
locks: &'static ApiLocks<EndpointCacheKey>,
|
||||
wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
|
||||
) -> Self {
|
||||
let jwt = std::env::var("NEON_PROXY_TO_CONTROLPLANE_TOKEN")
|
||||
.unwrap_or_default()
|
||||
.into();
|
||||
Self {
|
||||
endpoint,
|
||||
caches,
|
||||
locks,
|
||||
wake_compute_endpoint_rate_limiter,
|
||||
jwt,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn url(&self) -> &str {
|
||||
self.endpoint.url().as_str()
|
||||
}
|
||||
|
||||
async fn do_get_auth_info(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<AuthInfo, GetAuthInfoError> {
|
||||
if !self
|
||||
.caches
|
||||
.endpoints_cache
|
||||
.is_valid(ctx, &user_info.endpoint.normalize())
|
||||
.await
|
||||
{
|
||||
info!("endpoint is not valid, skipping the request");
|
||||
return Ok(AuthInfo::default());
|
||||
}
|
||||
let request_id = ctx.session_id().to_string();
|
||||
let application_name = ctx.console_application_name();
|
||||
async {
|
||||
let request = self
|
||||
.endpoint
|
||||
.get_path("proxy_get_role_secret")
|
||||
.header(X_REQUEST_ID, &request_id)
|
||||
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
|
||||
.query(&[("session_id", ctx.session_id())])
|
||||
.query(&[
|
||||
("application_name", application_name.as_str()),
|
||||
("project", user_info.endpoint.as_str()),
|
||||
("role", user_info.user.as_str()),
|
||||
])
|
||||
.build()?;
|
||||
|
||||
info!(url = request.url().as_str(), "sending http request");
|
||||
let start = Instant::now();
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
|
||||
let response = self.endpoint.execute(request).await?;
|
||||
drop(pause);
|
||||
info!(duration = ?start.elapsed(), "received http response");
|
||||
let body = match parse_body::<GetRoleSecret>(response).await {
|
||||
Ok(body) => body,
|
||||
// Error 404 is special: it's ok not to have a secret.
|
||||
// TODO(anna): retry
|
||||
Err(e) => {
|
||||
return if e.get_reason().is_not_found() {
|
||||
Ok(AuthInfo::default())
|
||||
} else {
|
||||
Err(e.into())
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let secret = if body.role_secret.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let secret = scram::ServerSecret::parse(&body.role_secret)
|
||||
.map(AuthSecret::Scram)
|
||||
.ok_or(GetAuthInfoError::BadSecret)?;
|
||||
Some(secret)
|
||||
};
|
||||
let allowed_ips = body.allowed_ips.unwrap_or_default();
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.allowed_ips_number
|
||||
.observe(allowed_ips.len() as f64);
|
||||
Ok(AuthInfo {
|
||||
secret,
|
||||
allowed_ips,
|
||||
project_id: body.project_id,
|
||||
})
|
||||
}
|
||||
.map_err(crate::error::log_error)
|
||||
.instrument(info_span!("http", id = request_id))
|
||||
.await
|
||||
}
|
||||
|
||||
async fn do_get_endpoint_jwks(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>> {
|
||||
if !self
|
||||
.caches
|
||||
.endpoints_cache
|
||||
.is_valid(ctx, &endpoint.normalize())
|
||||
.await
|
||||
{
|
||||
bail!("endpoint not found");
|
||||
}
|
||||
let request_id = ctx.session_id().to_string();
|
||||
async {
|
||||
let request = self
|
||||
.endpoint
|
||||
.get_with_url(|url| {
|
||||
url.path_segments_mut()
|
||||
.push("endpoints")
|
||||
.push(endpoint.as_str())
|
||||
.push("jwks");
|
||||
})
|
||||
.header(X_REQUEST_ID, &request_id)
|
||||
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
|
||||
.query(&[("session_id", ctx.session_id())])
|
||||
.build()?;
|
||||
|
||||
info!(url = request.url().as_str(), "sending http request");
|
||||
let start = Instant::now();
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
|
||||
let response = self.endpoint.execute(request).await?;
|
||||
drop(pause);
|
||||
info!(duration = ?start.elapsed(), "received http response");
|
||||
|
||||
let body = parse_body::<EndpointJwksResponse>(response).await?;
|
||||
|
||||
let rules = body
|
||||
.jwks
|
||||
.into_iter()
|
||||
.map(|jwks| AuthRule {
|
||||
id: jwks.id,
|
||||
jwks_url: jwks.jwks_url,
|
||||
audience: jwks.jwt_audience,
|
||||
role_names: jwks.role_names,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(rules)
|
||||
}
|
||||
.map_err(crate::error::log_error)
|
||||
.instrument(info_span!("http", id = request_id))
|
||||
.await
|
||||
}
|
||||
|
||||
async fn do_wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<NodeInfo, WakeComputeError> {
|
||||
let request_id = ctx.session_id().to_string();
|
||||
let application_name = ctx.console_application_name();
|
||||
async {
|
||||
let mut request_builder = self
|
||||
.endpoint
|
||||
.get_path("proxy_wake_compute")
|
||||
.header("X-Request-ID", &request_id)
|
||||
.header("Authorization", format!("Bearer {}", &self.jwt))
|
||||
.query(&[("session_id", ctx.session_id())])
|
||||
.query(&[
|
||||
("application_name", application_name.as_str()),
|
||||
("project", user_info.endpoint.as_str()),
|
||||
]);
|
||||
|
||||
let options = user_info.options.to_deep_object();
|
||||
if !options.is_empty() {
|
||||
request_builder = request_builder.query(&options);
|
||||
}
|
||||
|
||||
let request = request_builder.build()?;
|
||||
|
||||
info!(url = request.url().as_str(), "sending http request");
|
||||
let start = Instant::now();
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
|
||||
let response = self.endpoint.execute(request).await?;
|
||||
drop(pause);
|
||||
info!(duration = ?start.elapsed(), "received http response");
|
||||
let body = parse_body::<WakeCompute>(response).await?;
|
||||
|
||||
// Unfortunately, ownership won't let us use `Option::ok_or` here.
|
||||
let (host, port) = match parse_host_port(&body.address) {
|
||||
None => return Err(WakeComputeError::BadComputeAddress(body.address)),
|
||||
Some(x) => x,
|
||||
};
|
||||
|
||||
// Don't set anything but host and port! This config will be cached.
|
||||
// We'll set username and such later using the startup message.
|
||||
// TODO: add more type safety (in progress).
|
||||
let mut config = compute::ConnCfg::new();
|
||||
config.host(host).port(port).ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
|
||||
|
||||
let node = NodeInfo {
|
||||
config,
|
||||
aux: body.aux,
|
||||
allow_self_signed_compute: false,
|
||||
};
|
||||
|
||||
Ok(node)
|
||||
}
|
||||
.map_err(crate::error::log_error)
|
||||
.instrument(info_span!("http", id = request_id))
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl super::Api for Api {
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_role_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedRoleSecret, GetAuthInfoError> {
|
||||
let normalized_ep = &user_info.endpoint.normalize();
|
||||
let user = &user_info.user;
|
||||
if let Some(role_secret) = self
|
||||
.caches
|
||||
.project_info
|
||||
.get_role_secret(normalized_ep, user)
|
||||
{
|
||||
return Ok(role_secret);
|
||||
}
|
||||
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
|
||||
if let Some(project_id) = auth_info.project_id {
|
||||
let normalized_ep_int = normalized_ep.into();
|
||||
self.caches.project_info.insert_role_secret(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
user.into(),
|
||||
auth_info.secret.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_allowed_ips(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
Arc::new(auth_info.allowed_ips),
|
||||
);
|
||||
ctx.set_project_id(project_id);
|
||||
}
|
||||
// When we just got a secret, we don't need to invalidate it.
|
||||
Ok(Cached::new_uncached(auth_info.secret))
|
||||
}
|
||||
|
||||
async fn get_allowed_ips_and_secret(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
|
||||
let normalized_ep = &user_info.endpoint.normalize();
|
||||
if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.allowed_ips_cache_misses
|
||||
.inc(CacheOutcome::Hit);
|
||||
return Ok((allowed_ips, None));
|
||||
}
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.allowed_ips_cache_misses
|
||||
.inc(CacheOutcome::Miss);
|
||||
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
|
||||
let allowed_ips = Arc::new(auth_info.allowed_ips);
|
||||
let user = &user_info.user;
|
||||
if let Some(project_id) = auth_info.project_id {
|
||||
let normalized_ep_int = normalized_ep.into();
|
||||
self.caches.project_info.insert_role_secret(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
user.into(),
|
||||
auth_info.secret.clone(),
|
||||
);
|
||||
self.caches.project_info.insert_allowed_ips(
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
allowed_ips.clone(),
|
||||
);
|
||||
ctx.set_project_id(project_id);
|
||||
}
|
||||
Ok((
|
||||
Cached::new_uncached(allowed_ips),
|
||||
Some(Cached::new_uncached(auth_info.secret)),
|
||||
))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn get_endpoint_jwks(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
endpoint: EndpointId,
|
||||
) -> anyhow::Result<Vec<AuthRule>> {
|
||||
self.do_get_endpoint_jwks(ctx, endpoint).await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn wake_compute(
|
||||
&self,
|
||||
ctx: &RequestMonitoring,
|
||||
user_info: &ComputeUserInfo,
|
||||
) -> Result<CachedNodeInfo, WakeComputeError> {
|
||||
let key = user_info.endpoint_cache_key();
|
||||
|
||||
macro_rules! check_cache {
|
||||
() => {
|
||||
if let Some(cached) = self.caches.node_info.get(&key) {
|
||||
let (cached, info) = cached.take_value();
|
||||
let info = info.map_err(|c| {
|
||||
info!(key = &*key, "found cached wake_compute error");
|
||||
WakeComputeError::ApiError(ApiError::ControlPlane(*c))
|
||||
})?;
|
||||
|
||||
debug!(key = &*key, "found cached compute node info");
|
||||
ctx.set_project(info.aux.clone());
|
||||
return Ok(cached.map(|()| info));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Every time we do a wakeup http request, the compute node will stay up
|
||||
// for some time (highly depends on the console's scale-to-zero policy);
|
||||
// The connection info remains the same during that period of time,
|
||||
// which means that we might cache it to reduce the load and latency.
|
||||
check_cache!();
|
||||
|
||||
let permit = self.locks.get_permit(&key).await?;
|
||||
|
||||
// after getting back a permit - it's possible the cache was filled
|
||||
// double check
|
||||
if permit.should_check_cache() {
|
||||
check_cache!();
|
||||
}
|
||||
|
||||
// check rate limit
|
||||
if !self
|
||||
.wake_compute_endpoint_rate_limiter
|
||||
.check(user_info.endpoint.normalize_intern(), 1)
|
||||
{
|
||||
return Err(WakeComputeError::TooManyConnections);
|
||||
}
|
||||
|
||||
let node = permit.release_result(self.do_wake_compute(ctx, user_info).await);
|
||||
match node {
|
||||
Ok(node) => {
|
||||
ctx.set_project(node.aux.clone());
|
||||
debug!(key = &*key, "created a cache entry for woken compute node");
|
||||
|
||||
let mut stored_node = node.clone();
|
||||
// store the cached node as 'warm_cached'
|
||||
stored_node.aux.cold_start_info = ColdStartInfo::WarmCached;
|
||||
|
||||
let (_, cached) = self.caches.node_info.insert_unit(key, Ok(stored_node));
|
||||
|
||||
Ok(cached.map(|()| node))
|
||||
}
|
||||
Err(err) => match err {
|
||||
WakeComputeError::ApiError(ApiError::ControlPlane(err)) => {
|
||||
let Some(status) = &err.status else {
|
||||
return Err(WakeComputeError::ApiError(ApiError::ControlPlane(err)));
|
||||
};
|
||||
|
||||
let reason = status
|
||||
.details
|
||||
.error_info
|
||||
.map_or(Reason::Unknown, |x| x.reason);
|
||||
|
||||
// if we can retry this error, do not cache it.
|
||||
if reason.can_retry() {
|
||||
return Err(WakeComputeError::ApiError(ApiError::ControlPlane(err)));
|
||||
}
|
||||
|
||||
// at this point, we should only have quota errors.
|
||||
debug!(
|
||||
key = &*key,
|
||||
"created a cache entry for the wake compute error"
|
||||
);
|
||||
|
||||
self.caches.node_info.insert_ttl(
|
||||
key,
|
||||
Err(Box::new(err.clone())),
|
||||
Duration::from_secs(30),
|
||||
);
|
||||
|
||||
Err(WakeComputeError::ApiError(ApiError::ControlPlane(err)))
|
||||
}
|
||||
err => return Err(err),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse http response body, taking status code into account.
|
||||
async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
|
||||
response: http::Response,
|
||||
) -> Result<T, ApiError> {
|
||||
let status = response.status();
|
||||
if status.is_success() {
|
||||
// We shouldn't log raw body because it may contain secrets.
|
||||
info!("request succeeded, processing the body");
|
||||
return Ok(response.json().await?);
|
||||
}
|
||||
let s = response.bytes().await?;
|
||||
// Log plaintext to be able to detect, whether there are some cases not covered by the error struct.
|
||||
info!("response_error plaintext: {:?}", s);
|
||||
|
||||
// Don't throw an error here because it's not as important
|
||||
// as the fact that the request itself has failed.
|
||||
let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| {
|
||||
warn!("failed to parse error body: {e}");
|
||||
ControlPlaneError {
|
||||
error: "reason unclear (malformed error message)".into(),
|
||||
http_status_code: status,
|
||||
status: None,
|
||||
}
|
||||
});
|
||||
body.http_status_code = status;
|
||||
|
||||
error!("console responded with an error ({status}): {body:?}");
|
||||
Err(ApiError::ControlPlane(body))
|
||||
}
|
||||
|
||||
/// Split a `host:port` string at its last `:`.
///
/// Any `[` / `]` characters at either end of the host part are removed, so a
/// bracketed IPv6 literal such as `[2001:db8::1]:5432` yields the bare address.
///
/// Returns `None` when there is no `:` in `input` or when the text after the
/// last `:` is not a valid `u16`.
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
    let (raw_host, raw_port) = input.rsplit_once(':')?;
    let port = raw_port.parse::<u16>().ok()?;
    let host = raw_host.trim_matches(|c: char| c == '[' || c == ']');
    Some((host, port))
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // An IPv4 literal is returned unchanged alongside the numeric port.
    #[test]
    fn test_parse_host_port_v4() {
        let parsed = parse_host_port("127.0.0.1:5432").expect("failed to parse");
        assert_eq!(parsed, ("127.0.0.1", 5432));
    }

    // A bracketed IPv6 literal loses its surrounding brackets.
    #[test]
    fn test_parse_host_port_v6() {
        let parsed = parse_host_port("[2001:db8::1]:5432").expect("failed to parse");
        assert_eq!(parsed, ("2001:db8::1", 5432));
    }

    // A DNS-style host name is split at the final colon only.
    #[test]
    fn test_parse_host_port_url() {
        let input = "compute-foo-bar-1234.default.svc.cluster.local:5432";
        let parsed = parse_host_port(input).expect("failed to parse");
        assert_eq!(
            parsed,
            ("compute-foo-bar-1234.default.svc.cluster.local", 5432)
        );
    }
}
|
||||
Reference in New Issue
Block a user