feat(proxy): add direction and private link id to billing export (#10925)

ref: https://github.com/neondatabase/cloud/issues/23385

Adds a direction flag as well as private-link ID to the traffic
reporting pipeline. We do not yet actually count ingress, but we include
the flag anyway.

I have additionally moved vpce_id string parsing earlier, since we
expect it to be utf8 (ascii).
This commit is contained in:
Conrad Ludgate
2025-02-24 11:49:11 +00:00
committed by GitHub
parent a6f315c9c9
commit fb77f28326
11 changed files with 146 additions and 43 deletions

View File

@@ -308,10 +308,7 @@ async fn auth_quirks(
let incoming_vpc_endpoint_id = match ctx.extra() {
None => return Err(AuthError::MissingEndpointName),
Some(ConnectionInfoExtra::Aws { vpce_id }) => {
// Convert the vcpe_id to a string
String::from_utf8(vpce_id.to_vec()).unwrap_or_default()
}
Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
};
let allowed_vpc_endpoint_ids = api.get_allowed_vpc_endpoint_ids(ctx, &info).await?;
@@ -451,7 +448,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
Ok((Backend::ControlPlane(api, credentials), ip_allowlist))
}
Self::Local(_) => {
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"))
return Err(auth::AuthError::bad_auth_method("invalid for local proxy"));
}
};

View File

@@ -358,10 +358,7 @@ impl CancellationHandler {
let incoming_vpc_endpoint_id = match ctx.extra() {
None => return Err(CancelError::AuthError(AuthError::MissingVPCEndpointId)),
Some(ConnectionInfoExtra::Aws { vpce_id }) => {
// Convert the vcpe_id to a string
String::from_utf8(vpce_id.to_vec()).unwrap_or_default()
}
Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
};

View File

@@ -241,6 +241,7 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
Ok(Some(ProxyPassthrough {
client: stream,
aux: node.aux.clone(),
private_link_id: None,
compute: node,
session_id: ctx.session_id(),
cancel: session,

View File

@@ -9,6 +9,7 @@ use std::task::{Context, Poll};
use bytes::{Buf, Bytes, BytesMut};
use pin_project_lite::pin_project;
use smol_str::SmolStr;
use strum_macros::FromRepr;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
use zerocopy::{FromBytes, FromZeroes};
@@ -99,7 +100,7 @@ impl fmt::Display for ConnectionInfo {
#[derive(PartialEq, Eq, Clone, Debug)]
pub enum ConnectionInfoExtra {
Aws { vpce_id: Bytes },
Aws { vpce_id: SmolStr },
Azure { link_id: u32 },
}
@@ -193,7 +194,7 @@ fn process_proxy_payload(
return Err(io::Error::new(
io::ErrorKind::Other,
"invalid proxy protocol address family/transport protocol.",
))
));
}
};
@@ -207,9 +208,14 @@ fn process_proxy_payload(
}
let subtype = tlv.value.get_u8();
match Pp2AwsType::from_repr(subtype) {
Some(Pp2AwsType::VpceId) => {
extra = Some(ConnectionInfoExtra::Aws { vpce_id: tlv.value });
}
Some(Pp2AwsType::VpceId) => match std::str::from_utf8(&tlv.value) {
Ok(s) => {
extra = Some(ConnectionInfoExtra::Aws { vpce_id: s.into() });
}
Err(e) => {
tracing::warn!("invalid aws vpce id: {e}");
}
},
None => {
tracing::warn!("unknown aws tlv: subtype={subtype}");
}

View File

@@ -16,7 +16,7 @@ use once_cell::sync::OnceCell;
use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams};
use regex::Regex;
use serde::{Deserialize, Serialize};
use smol_str::{format_smolstr, SmolStr};
use smol_str::{format_smolstr, SmolStr, ToSmolStr};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;
@@ -29,7 +29,7 @@ use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::protocol2::{read_proxy_protocol, ConnectHeader, ConnectionInfo};
use crate::protocol2::{read_proxy_protocol, ConnectHeader, ConnectionInfo, ConnectionInfoExtra};
use crate::proxy::handshake::{handshake, HandshakeData};
use crate::rate_limiter::EndpointRateLimiter;
use crate::stream::{PqStream, Stream};
@@ -100,22 +100,34 @@ pub async fn task_main(
debug!("healthcheck received");
return;
}
Ok((_socket, ConnectHeader::Missing)) if config.proxy_protocol_v2 == ProxyProtocolV2::Required => {
Ok((_socket, ConnectHeader::Missing))
if config.proxy_protocol_v2 == ProxyProtocolV2::Required =>
{
warn!("missing required proxy protocol header");
return;
}
Ok((_socket, ConnectHeader::Proxy(_))) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => {
Ok((_socket, ConnectHeader::Proxy(_)))
if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected =>
{
warn!("proxy protocol header not supported");
return;
}
Ok((socket, ConnectHeader::Proxy(info))) => (socket, info),
Ok((socket, ConnectHeader::Missing)) => (socket, ConnectionInfo { addr: peer_addr, extra: None }),
Ok((socket, ConnectHeader::Missing)) => (
socket,
ConnectionInfo {
addr: peer_addr,
extra: None,
},
),
};
match socket.inner.set_nodelay(true) {
Ok(()) => {}
Err(e) => {
error!("per-client task finished with an error: failed to set socket option: {e:#}");
error!(
"per-client task finished with an error: failed to set socket option: {e:#}"
);
return;
}
}
@@ -156,10 +168,16 @@ pub async fn task_main(
match p.proxy_pass(&config.connect_to_compute).await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(?session_id, "per-client task finished with an IO error from the client: {e:#}");
warn!(
?session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(?session_id, "per-client task finished with an IO error from the compute: {e:#}");
error!(
?session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
}
}
@@ -374,9 +392,16 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
let private_link_id = match ctx.extra() {
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
None => None,
};
Ok(Some(ProxyPassthrough {
client: stream,
aux: node.aux.clone(),
private_link_id,
compute: node,
session_id: ctx.session_id(),
cancel: session,

View File

@@ -1,3 +1,4 @@
use smol_str::SmolStr;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::debug;
use utils::measured_stream::MeasuredStream;
@@ -9,7 +10,7 @@ use crate::config::ComputeConfig;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard};
use crate::stream::Stream;
use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS};
use crate::usage_metrics::{Ids, MetricCounterRecorder, TrafficDirection, USAGE_METRICS};
/// Forward bytes in both directions (client <-> compute).
#[tracing::instrument(skip_all)]
@@ -17,10 +18,14 @@ pub(crate) async fn proxy_pass(
client: impl AsyncRead + AsyncWrite + Unpin,
compute: impl AsyncRead + AsyncWrite + Unpin,
aux: MetricsAuxInfo,
private_link_id: Option<SmolStr>,
) -> Result<(), ErrorSource> {
let usage = USAGE_METRICS.register(Ids {
// we will report ingress at a later date
let usage_tx = USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id,
branch_id: aux.branch_id,
direction: TrafficDirection::Egress,
private_link_id,
});
let metrics = &Metrics::get().proxy.io_bytes;
@@ -31,7 +36,7 @@ pub(crate) async fn proxy_pass(
|cnt| {
// Number of bytes we sent to the client (outbound).
metrics.get_metric(m_sent).inc_by(cnt as u64);
usage.record_egress(cnt as u64);
usage_tx.record_egress(cnt as u64);
},
);
@@ -61,6 +66,7 @@ pub(crate) struct ProxyPassthrough<S> {
pub(crate) compute: PostgresConnection,
pub(crate) aux: MetricsAuxInfo,
pub(crate) session_id: uuid::Uuid,
pub(crate) private_link_id: Option<SmolStr>,
pub(crate) cancel: cancellation::Session,
pub(crate) _req: NumConnectionRequestsGuard<'static>,
@@ -72,7 +78,13 @@ impl<S: AsyncRead + AsyncWrite + Unpin> ProxyPassthrough<S> {
self,
compute_config: &ComputeConfig,
) -> Result<(), ErrorSource> {
let res = proxy_pass(self.client, self.compute.stream, self.aux).await;
let res = proxy_pass(
self.client,
self.compute.stream,
self.aux,
self.private_link_id,
)
.await;
if let Err(err) = self
.compute
.cancel_closure

View File

@@ -75,10 +75,7 @@ impl PoolingBackend {
let extra = ctx.extra();
let incoming_endpoint_id = match extra {
None => String::new(),
Some(ConnectionInfoExtra::Aws { vpce_id }) => {
// Convert the vcpe_id to a string
String::from_utf8(vpce_id.to_vec()).unwrap_or_default()
}
Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(),
Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(),
};

View File

@@ -9,6 +9,7 @@ use clashmap::ClashMap;
use parking_lot::RwLock;
use postgres_client::ReadyForQueryStatus;
use rand::Rng;
use smol_str::ToSmolStr;
use tracing::{debug, info, Span};
use super::backend::HttpConnError;
@@ -19,8 +20,9 @@ use crate::auth::backend::ComputeUserInfo;
use crate::context::RequestContext;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::protocol2::ConnectionInfoExtra;
use crate::types::{DbName, EndpointCacheKey, RoleName};
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::usage_metrics::{Ids, MetricCounter, TrafficDirection, USAGE_METRICS};
#[derive(Debug, Clone)]
pub(crate) struct ConnInfo {
@@ -473,7 +475,9 @@ where
.http_pool_opened_connections
.get_metric()
.dec_by(clients_removed as i64);
info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}");
info!(
"pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}"
);
}
let removed = current_len - new_len;
@@ -635,15 +639,28 @@ impl<C: ClientInnerExt> Client<C> {
(&mut inner.inner, Discard { conn_info, pool })
}
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
pub(crate) fn metrics(
&self,
direction: TrafficDirection,
ctx: &RequestContext,
) -> Arc<MetricCounter> {
let aux = &self
.inner
.as_ref()
.expect("client inner should not be removed")
.aux;
let private_link_id = match ctx.extra() {
None => None,
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
};
USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id,
branch_id: aux.branch_id,
direction,
private_link_id,
})
}
}
@@ -700,7 +717,9 @@ impl<C: ClientInnerExt> Discard<'_, C> {
pub(crate) fn discard(&mut self) {
let conn_info = &self.conn_info;
if std::mem::take(self.pool).strong_count() > 0 {
info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state");
info!(
"pool: throwing away connection '{conn_info}' because connection is potentially in a broken state"
);
}
}
}

View File

@@ -5,6 +5,7 @@ use std::sync::{Arc, Weak};
use hyper::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
use smol_str::ToSmolStr;
use tokio::net::TcpStream;
use tracing::{debug, error, info, info_span, Instrument};
@@ -16,8 +17,9 @@ use super::conn_pool_lib::{
use crate::context::RequestContext;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::protocol2::ConnectionInfoExtra;
use crate::types::EndpointCacheKey;
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
use crate::usage_metrics::{Ids, MetricCounter, TrafficDirection, USAGE_METRICS};
pub(crate) type Send = http2::SendRequest<hyper::body::Incoming>;
pub(crate) type Connect =
@@ -264,11 +266,24 @@ impl<C: ClientInnerExt + Clone> Client<C> {
Self { inner }
}
pub(crate) fn metrics(&self) -> Arc<MetricCounter> {
pub(crate) fn metrics(
&self,
direction: TrafficDirection,
ctx: &RequestContext,
) -> Arc<MetricCounter> {
let aux = &self.inner.aux;
let private_link_id = match ctx.extra() {
None => None,
Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()),
Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()),
};
USAGE_METRICS.register(Ids {
endpoint_id: aux.endpoint_id,
branch_id: aux.branch_id,
direction,
private_link_id,
})
}
}

View File

@@ -42,7 +42,7 @@ use crate::metrics::{HttpDirection, Metrics};
use crate::proxy::{run_until_cancelled, NeonOptions};
use crate::serverless::backend::HttpConnError;
use crate::types::{DbName, RoleName};
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder, TrafficDirection};
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
@@ -209,7 +209,7 @@ fn get_conn_info(
}
}
Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
return Err(ConnInfoError::MissingHostname)
return Err(ConnInfoError::MissingHostname);
}
};
ctx.set_endpoint_id(endpoint.clone());
@@ -745,7 +745,7 @@ async fn handle_db_inner(
}
};
let metrics = client.metrics();
let metrics = client.metrics(TrafficDirection::Egress, ctx);
let len = json_output.len();
let response = response
@@ -818,7 +818,7 @@ async fn handle_auth_broker_inner(
.expect("all headers and params received via hyper should be valid for request");
// todo: map body to count egress
let _metrics = client.metrics();
let _metrics = client.metrics(TrafficDirection::Egress, ctx);
Ok(client
.inner
@@ -1118,10 +1118,10 @@ enum Discard<'a> {
}
impl Client {
fn metrics(&self) -> Arc<MetricCounter> {
fn metrics(&self, direction: TrafficDirection, ctx: &RequestContext) -> Arc<MetricCounter> {
match self {
Client::Remote(client) => client.metrics(),
Client::Local(local_client) => local_client.metrics(),
Client::Remote(client) => client.metrics(direction, ctx),
Client::Local(local_client) => local_client.metrics(direction, ctx),
}
}

View File

@@ -16,6 +16,7 @@ use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_S
use once_cell::sync::Lazy;
use remote_storage::{GenericRemoteStorage, RemotePath, TimeoutOrCancel};
use serde::{Deserialize, Serialize};
use smol_str::SmolStr;
use tokio::io::AsyncWriteExt;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, instrument, trace, warn};
@@ -43,6 +44,37 @@ const HTTP_REPORTING_RETRY_DURATION: Duration = Duration::from_secs(60);
pub(crate) struct Ids {
pub(crate) endpoint_id: EndpointIdInt,
pub(crate) branch_id: BranchIdInt,
pub(crate) direction: TrafficDirection,
#[serde(with = "none_as_empty_string")]
pub(crate) private_link_id: Option<SmolStr>,
}
mod none_as_empty_string {
use serde::Deserialize;
use smol_str::SmolStr;
#[allow(clippy::ref_option)]
pub fn serialize<S: serde::Serializer>(t: &Option<SmolStr>, s: S) -> Result<S::Ok, S::Error> {
s.serialize_str(t.as_deref().unwrap_or(""))
}
pub fn deserialize<'de, D: serde::Deserializer<'de>>(
d: D,
) -> Result<Option<SmolStr>, D::Error> {
let s = SmolStr::deserialize(d)?;
if s.is_empty() {
Ok(None)
} else {
Ok(Some(s))
}
}
}
#[derive(Eq, Hash, PartialEq, Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub(crate) enum TrafficDirection {
Ingress,
Egress,
}
pub(crate) trait MetricCounterRecorder {
@@ -505,6 +537,8 @@ mod tests {
let counter = metrics.register(Ids {
endpoint_id: (&EndpointId::from("e1")).into(),
branch_id: (&BranchId::from("b1")).into(),
direction: TrafficDirection::Egress,
private_link_id: None,
});
// the counter should be observed despite 0 egress