[proxy] Add more labels to the pricing metrics

This commit is contained in:
Dmitry Ivanov
2022-12-26 22:10:28 +03:00
parent 7c7d225d98
commit c700c7db2e
12 changed files with 249 additions and 177 deletions

View File

@@ -8,7 +8,9 @@ pub use console::{GetAuthInfoError, WakeComputeError};
use crate::{
auth::{self, AuthFlow, ClientCredentials},
compute, http, mgmt, stream, url,
compute,
console::messages::MetricsAuxInfo,
http, mgmt, stream, url,
waiters::{self, Waiter, Waiters},
};
use once_cell::sync::Lazy;
@@ -126,25 +128,13 @@ pub struct AuthSuccess<T> {
pub value: T,
}
impl<T> AuthSuccess<T> {
/// Very similar to [`std::option::Option::map`].
/// Maps [`AuthSuccess<T>`] to [`AuthSuccess<R>`] by applying
/// a function to a contained value.
pub fn map<R>(self, f: impl FnOnce(T) -> R) -> AuthSuccess<R> {
AuthSuccess {
reported_auth_ok: self.reported_auth_ok,
value: f(self.value),
}
}
}
/// Info for establishing a connection to a compute node.
/// This is what we get after auth succeeded, but not before!
pub struct NodeInfo {
/// Project from [`auth::ClientCredentials`].
pub project: String,
/// Compute node connection params.
pub config: compute::ConnCfg,
/// Labels for proxy's metrics.
pub aux: MetricsAuxInfo,
}
impl BackendType<'_, ClientCredentials<'_>> {
@@ -172,37 +162,34 @@ impl BackendType<'_, ClientCredentials<'_>> {
};
// TODO: find a proper way to merge those very similar blocks.
let (mut config, payload) = match self {
let (mut node, payload) = match self {
Console(endpoint, creds) if creds.project.is_none() => {
let payload = fetch_magic_payload.await?;
let mut creds = creds.as_ref();
creds.project = Some(payload.project.as_str().into());
let config = console::Api::new(endpoint, extra, &creds)
let node = console::Api::new(endpoint, extra, &creds)
.wake_compute()
.await?;
(config, payload)
(node, payload)
}
Postgres(endpoint, creds) if creds.project.is_none() => {
let payload = fetch_magic_payload.await?;
let mut creds = creds.as_ref();
creds.project = Some(payload.project.as_str().into());
let config = postgres::Api::new(endpoint, &creds).wake_compute().await?;
let node = postgres::Api::new(endpoint, &creds).wake_compute().await?;
(config, payload)
(node, payload)
}
_ => return Ok(None),
};
config.password(payload.password);
node.config.password(payload.password);
Ok(Some(AuthSuccess {
reported_auth_ok: false,
value: NodeInfo {
project: payload.project,
config,
},
value: node,
}))
}
@@ -233,10 +220,6 @@ impl BackendType<'_, ClientCredentials<'_>> {
console::Api::new(&endpoint, extra, &creds)
.handle_user(client)
.await?
.map(|config| NodeInfo {
project: creds.project.unwrap().into_owned(),
config,
})
}
Postgres(endpoint, creds) => {
info!("performing mock authentication using a local postgres instance");
@@ -245,10 +228,6 @@ impl BackendType<'_, ClientCredentials<'_>> {
postgres::Api::new(&endpoint, &creds)
.handle_user(client)
.await?
.map(|config| NodeInfo {
project: creds.project.unwrap().into_owned(),
config,
})
}
// NOTE: this auth backend doesn't use client credentials.
Link(url) => {

View File

@@ -1,16 +1,16 @@
//! Cloud API V2.
use super::{AuthSuccess, ConsoleReqExtra};
use super::{AuthSuccess, ConsoleReqExtra, NodeInfo};
use crate::{
auth::{self, AuthFlow, ClientCredentials},
compute,
console::messages::{ConsoleError, GetRoleSecret, WakeCompute},
error::{io_error, UserFacingError},
http, sasl, scram,
stream::PqStream,
};
use futures::TryFutureExt;
use reqwest::StatusCode as HttpStatusCode;
use serde::Deserialize;
use std::future::Future;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite};
@@ -136,24 +136,6 @@ impl UserFacingError for WakeComputeError {
}
}
/// Console's response which holds client's auth secret.
#[derive(Deserialize, Debug)]
struct GetRoleSecret {
role_secret: Box<str>,
}
/// Console's response which holds compute node's `host:port` pair.
#[derive(Deserialize, Debug)]
struct WakeCompute {
address: Box<str>,
}
/// Console's error response with human-readable description.
#[derive(Deserialize, Debug)]
struct ConsoleError {
error: Box<str>,
}
/// Auth secret which is managed by the cloud.
pub enum AuthInfo {
/// Md5 hash of user's password.
@@ -194,7 +176,7 @@ impl<'a> Api<'a> {
pub(super) async fn handle_user(
&'a self,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> auth::Result<AuthSuccess<compute::ConnCfg>> {
) -> auth::Result<AuthSuccess<NodeInfo>> {
handle_user(client, self, Self::get_auth_info, Self::wake_compute).await
}
}
@@ -238,7 +220,7 @@ impl Api<'_> {
}
/// Wake up the compute node and return the corresponding connection info.
pub async fn wake_compute(&self) -> Result<compute::ConnCfg, WakeComputeError> {
pub async fn wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
let request_id = uuid::Uuid::new_v4().to_string();
async {
let request = self
@@ -269,7 +251,10 @@ impl Api<'_> {
.dbname(self.creds.dbname)
.user(self.creds.user);
Ok(config)
Ok(NodeInfo {
config,
aux: body.aux,
})
}
.map_err(crate::error::log_error)
.instrument(info_span!("wake_compute", id = request_id))
@@ -284,11 +269,11 @@ pub(super) async fn handle_user<'a, Endpoint, GetAuthInfo, WakeCompute>(
endpoint: &'a Endpoint,
get_auth_info: impl FnOnce(&'a Endpoint) -> GetAuthInfo,
wake_compute: impl FnOnce(&'a Endpoint) -> WakeCompute,
) -> auth::Result<AuthSuccess<compute::ConnCfg>>
) -> auth::Result<AuthSuccess<NodeInfo>>
where
Endpoint: AsRef<ClientCredentials<'a>>,
GetAuthInfo: Future<Output = Result<Option<AuthInfo>, GetAuthInfoError>>,
WakeCompute: Future<Output = Result<compute::ConnCfg, WakeComputeError>>,
WakeCompute: Future<Output = Result<NodeInfo, WakeComputeError>>,
{
let creds = endpoint.as_ref();
@@ -325,19 +310,20 @@ where
}
};
let mut config = wake_compute(endpoint).await?;
let mut node = wake_compute(endpoint).await?;
if let Some(keys) = scram_keys {
config.auth_keys(tokio_postgres::config::AuthKeys::ScramSha256(keys));
use tokio_postgres::config::AuthKeys;
node.config.auth_keys(AuthKeys::ScramSha256(keys));
}
Ok(AuthSuccess {
reported_auth_ok: false,
value: config,
value: node,
})
}
/// Parse http response body, taking status code into account.
async fn parse_body<T: for<'a> Deserialize<'a>>(
async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
response: reqwest::Response,
) -> Result<T, ApiError> {
let status = response.status();

View File

@@ -86,8 +86,8 @@ pub async fn handle_user(
Ok(AuthSuccess {
reported_auth_ok: true,
value: NodeInfo {
project: db_info.project,
config,
aux: db_info.aux,
},
})
}

View File

@@ -2,7 +2,7 @@
use super::{
console::{self, AuthInfo, GetAuthInfoError, WakeComputeError},
AuthSuccess,
AuthSuccess, NodeInfo,
};
use crate::{
auth::{self, ClientCredentials},
@@ -57,7 +57,7 @@ impl<'a> Api<'a> {
pub(super) async fn handle_user(
&'a self,
client: &mut PqStream<impl AsyncRead + AsyncWrite + Unpin + Send>,
) -> auth::Result<AuthSuccess<compute::ConnCfg>> {
) -> auth::Result<AuthSuccess<NodeInfo>> {
// We reuse user handling logic from a production module.
console::handle_user(client, self, Self::get_auth_info, Self::wake_compute).await
}
@@ -103,7 +103,7 @@ impl Api<'_> {
}
/// We don't need to wake anything locally, so we just return the connection info.
pub async fn wake_compute(&self) -> Result<compute::ConnCfg, WakeComputeError> {
pub async fn wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
let mut config = compute::ConnCfg::new();
config
.host(self.endpoint.host_str().unwrap_or("localhost"))
@@ -111,7 +111,10 @@ impl Api<'_> {
.dbname(self.creds.dbname)
.user(self.creds.user);
Ok(config)
Ok(NodeInfo {
config,
aux: Default::default(),
})
}
}

View File

@@ -43,7 +43,7 @@ pub type ScramKeys = tokio_postgres::config::ScramKeys<32>;
/// Eventually, `tokio_postgres` will be replaced with something better.
/// Newtype allows us to implement methods on top of it.
#[repr(transparent)]
pub struct ConnCfg(pub tokio_postgres::Config);
pub struct ConnCfg(Box<tokio_postgres::Config>);
impl ConnCfg {
/// Construct a new connection config.

5
proxy/src/console.rs Normal file
View File

@@ -0,0 +1,5 @@
///! Various stuff for dealing with the Neon Console.
///! Later we might move some API wrappers here.
/// Payloads used in the console's APIs.
pub mod messages;

View File

@@ -0,0 +1,190 @@
use serde::Deserialize;
use std::fmt;
/// Generic error response with human-readable description.
/// Note that we can't always present it to user as is.
#[derive(Debug, Deserialize)]
pub struct ConsoleError {
pub error: Box<str>,
}
/// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`].
/// Returned by the `/proxy_get_role_secret` API method.
#[derive(Deserialize)]
pub struct GetRoleSecret {
pub role_secret: Box<str>,
}
// Manually implement debug to omit sensitive info.
impl fmt::Debug for GetRoleSecret {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("GetRoleSecret").finish_non_exhaustive()
}
}
/// Response which holds compute node's `host:port` pair.
/// Returned by the `/proxy_wake_compute` API method.
#[derive(Debug, Deserialize)]
pub struct WakeCompute {
pub address: Box<str>,
pub aux: MetricsAuxInfo,
}
/// Async response which concludes the link auth flow.
/// Also known as `kickResponse` in the console.
#[derive(Debug, Deserialize)]
pub struct KickSession<'a> {
/// Session ID is assigned by the proxy.
pub session_id: &'a str,
/// Compute node connection params.
#[serde(deserialize_with = "KickSession::parse_db_info")]
pub result: DatabaseInfo,
}
impl KickSession<'_> {
fn parse_db_info<'de, D>(des: D) -> Result<DatabaseInfo, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize)]
enum Wrapper {
// Currently, console only reports `Success`.
// `Failure(String)` used to be here... RIP.
Success(DatabaseInfo),
}
Wrapper::deserialize(des).map(|x| match x {
Wrapper::Success(info) => info,
})
}
}
/// Compute node connection params.
#[derive(Deserialize)]
pub struct DatabaseInfo {
pub host: String,
pub port: u16,
pub dbname: String,
pub user: String,
/// Console always provides a password, but it might
/// be inconvenient for debug with local PG instance.
pub password: Option<String>,
pub aux: MetricsAuxInfo,
}
// Manually implement debug to omit sensitive info.
impl fmt::Debug for DatabaseInfo {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("DatabaseInfo")
.field("host", &self.host)
.field("port", &self.port)
.field("dbname", &self.dbname)
.field("user", &self.user)
.finish_non_exhaustive()
}
}
/// Various labels for prometheus metrics.
/// Also known as `ProxyMetricsAuxInfo` in the console.
#[derive(Debug, Deserialize, Default)]
pub struct MetricsAuxInfo {
pub endpoint_id: Box<str>,
pub project_id: Box<str>,
pub branch_id: Box<str>,
}
impl MetricsAuxInfo {
/// Definitions of labels for traffic metric.
pub const TRAFFIC_LABELS: &'static [&'static str] = &[
// Received (rx) / sent (tx).
"direction",
// ID of a project.
"project_id",
// ID of an endpoint within a project.
"endpoint_id",
// ID of a branch within a project (snapshot).
"branch_id",
];
/// Values of labels for traffic metric.
// TODO: add more type safety (validate arity & positions).
pub fn traffic_labels(&self, direction: &'static str) -> [&str; 4] {
[
direction,
&self.project_id,
&self.endpoint_id,
&self.branch_id,
]
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
fn dummy_aux() -> serde_json::Value {
json!({
"endpoint_id": "endpoint",
"project_id": "project",
"branch_id": "branch",
})
}
#[test]
fn parse_kick_session() -> anyhow::Result<()> {
// This is what the console's kickResponse looks like.
let json = json!({
"session_id": "deadbeef",
"result": {
"Success": {
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
"password": "password",
"aux": dummy_aux(),
}
}
});
let _: KickSession = serde_json::from_str(&json.to_string())?;
Ok(())
}
#[test]
fn parse_db_info() -> anyhow::Result<()> {
// with password
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
"password": "password",
"aux": dummy_aux(),
}))?;
// without password
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
"aux": dummy_aux(),
}))?;
// new field (forward compatibility)
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
"project": "hello_world",
"N.E.W": "forward compatibility check",
"aux": dummy_aux(),
}))?;
Ok(())
}
}

View File

@@ -8,6 +8,7 @@ mod auth;
mod cancellation;
mod compute;
mod config;
mod console;
mod error;
mod http;
mod mgmt;

View File

@@ -1,7 +1,9 @@
use crate::auth;
use crate::{
auth,
console::messages::{DatabaseInfo, KickSession},
};
use anyhow::Context;
use pq_proto::{BeMessage, SINGLE_COL_ROWDESC};
use serde::Deserialize;
use std::{
net::{TcpListener, TcpStream},
thread,
@@ -50,59 +52,9 @@ fn handle_connection(socket: TcpStream) -> anyhow::Result<()> {
pgbackend.run(&mut MgmtHandler)
}
/// Known as `kickResponse` in the console.
#[derive(Debug, Deserialize)]
struct PsqlSessionResponse {
session_id: String,
result: PsqlSessionResult,
}
#[derive(Debug, Deserialize)]
enum PsqlSessionResult {
Success(DatabaseInfo),
Failure(String),
}
/// A message received by `mgmt` when a compute node is ready.
pub type ComputeReady = Result<DatabaseInfo, String>;
impl PsqlSessionResult {
fn into_compute_ready(self) -> ComputeReady {
match self {
Self::Success(db_info) => Ok(db_info),
Self::Failure(message) => Err(message),
}
}
}
/// Compute node connection params provided by the console.
/// This struct and its parents are mgmt API implementation
/// detail and thus should remain in this module.
// TODO: restore deserialization tests from git history.
#[derive(Deserialize)]
pub struct DatabaseInfo {
pub host: String,
pub port: u16,
pub dbname: String,
pub user: String,
/// Console always provides a password, but it might
/// be inconvenient for debug with local PG instance.
pub password: Option<String>,
pub project: String,
}
// Manually implement debug to omit sensitive info.
impl std::fmt::Debug for DatabaseInfo {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
fmt.debug_struct("DatabaseInfo")
.field("host", &self.host)
.field("port", &self.port)
.field("dbname", &self.dbname)
.field("user", &self.user)
.finish_non_exhaustive()
}
}
// TODO: replace with an http-based protocol.
struct MgmtHandler;
impl postgres_backend::Handler for MgmtHandler {
@@ -115,13 +67,13 @@ impl postgres_backend::Handler for MgmtHandler {
}
fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> anyhow::Result<()> {
let resp: PsqlSessionResponse = serde_json::from_str(query)?;
let resp: KickSession = serde_json::from_str(query)?;
let span = info_span!("event", session_id = resp.session_id);
let _enter = span.enter();
info!("got response: {:?}", resp.result);
match auth::backend::notify(&resp.session_id, resp.result.into_compute_ready()) {
match auth::backend::notify(resp.session_id, Ok(resp.result)) {
Ok(()) => {
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
@@ -135,43 +87,3 @@ fn try_process_query(pgb: &mut PostgresBackend, query: &str) -> anyhow::Result<(
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn parse_db_info() -> anyhow::Result<()> {
// with password
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
"password": "password",
"project": "hello_world",
}))?;
// without password
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
"project": "hello_world",
}))?;
// new field (forward compatibility)
let _: DatabaseInfo = serde_json::from_value(json!({
"host": "localhost",
"port": 5432,
"dbname": "postgres",
"user": "john_doe",
"project": "hello_world",
"N.E.W": "forward compatibility check",
}))?;
Ok(())
}
}

View File

@@ -11,7 +11,7 @@ use anyhow::{bail, Context};
use futures::TryFutureExt;
use metrics::{register_int_counter, register_int_counter_vec, IntCounter, IntCounterVec};
use once_cell::sync::Lazy;
use pq_proto::{BeMessage as Be, *};
use pq_proto::{BeMessage as Be, FeStartupPacket, StartupMessageParams};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{error, info, info_span, Instrument};
@@ -39,12 +39,7 @@ static NUM_BYTES_PROXIED_COUNTER: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"proxy_io_bytes_per_client",
"Number of bytes sent/received between client and backend.",
&[
// Received (rx) / sent (tx).
"direction",
// Proxy can keep calling it `project` internally.
"endpoint_id"
]
crate::console::messages::MetricsAuxInfo::TRAFFIC_LABELS,
)
.unwrap()
});
@@ -271,19 +266,16 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Client<'_, S> {
stream
.write_message_noflush(&Be::BackendKeyData(cancel_key_data))?
.write_message(&BeMessage::ReadyForQuery)
.write_message(&Be::ReadyForQuery)
.await?;
// TODO: add more identifiers.
let metric_id = node.project;
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx", &metric_id]);
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("tx"));
let mut client = MeasuredStream::new(stream.into_inner(), |cnt| {
// Number of bytes we sent to the client (outbound).
m_sent.inc_by(cnt as u64);
});
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["rx", &metric_id]);
let m_recv = NUM_BYTES_PROXIED_COUNTER.with_label_values(&node.aux.traffic_labels("rx"));
let mut db = MeasuredStream::new(db.stream, |cnt| {
// Number of bytes the client sent to the compute node (inbound).
m_recv.inc_by(cnt as u64);

View File

@@ -140,7 +140,7 @@ async fn dummy_proxy(
stream
.write_message_noflush(&Be::AuthenticationOk)?
.write_message_noflush(&Be::CLIENT_ENCODING)?
.write_message(&BeMessage::ReadyForQuery)
.write_message(&Be::ReadyForQuery)
.await?;
Ok(())