From fb77f28326492b1dff44d2623a81ce822a45ef9e Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Mon, 24 Feb 2025 11:49:11 +0000 Subject: [PATCH] feat(proxy): add direction and private link id to billing export (#10925) ref: https://github.com/neondatabase/cloud/issues/23385 Adds a direction flag as well as private-link ID to the traffic reporting pipeline. We do not yet actually count ingress, but we include the flag anyway. I have additionally moved vpce_id string parsing earlier, since we expect it to be utf8 (ascii). --- proxy/src/auth/backend/mod.rs | 7 ++--- proxy/src/cancellation.rs | 5 +--- proxy/src/console_redirect_proxy.rs | 1 + proxy/src/protocol2.rs | 16 ++++++---- proxy/src/proxy/mod.rs | 41 +++++++++++++++++++++----- proxy/src/proxy/passthrough.rs | 20 ++++++++++--- proxy/src/serverless/backend.rs | 5 +--- proxy/src/serverless/conn_pool_lib.rs | 27 ++++++++++++++--- proxy/src/serverless/http_conn_pool.rs | 19 ++++++++++-- proxy/src/serverless/sql_over_http.rs | 14 ++++----- proxy/src/usage_metrics.rs | 34 +++++++++++++++++++++ 11 files changed, 146 insertions(+), 43 deletions(-) diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index dc595844c5..8f1625278f 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -308,10 +308,7 @@ async fn auth_quirks( let incoming_vpc_endpoint_id = match ctx.extra() { None => return Err(AuthError::MissingEndpointName), - Some(ConnectionInfoExtra::Aws { vpce_id }) => { - // Convert the vcpe_id to a string - String::from_utf8(vpce_id.to_vec()).unwrap_or_default() - } + Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), }; let allowed_vpc_endpoint_ids = api.get_allowed_vpc_endpoint_ids(ctx, &info).await?; @@ -451,7 +448,7 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> { Ok((Backend::ControlPlane(api, credentials), ip_allowlist)) } Self::Local(_) => { - return Err(auth::AuthError::bad_auth_method("invalid for local proxy")) + return Err(auth::AuthError::bad_auth_method("invalid for local proxy")); } }; diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index 1f9c8a48b7..422e6f741d 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -358,10 +358,7 @@ impl CancellationHandler { let incoming_vpc_endpoint_id = match ctx.extra() { None => return Err(CancelError::AuthError(AuthError::MissingVPCEndpointId)), - Some(ConnectionInfoExtra::Aws { vpce_id }) => { - // Convert the vcpe_id to a string - String::from_utf8(vpce_id.to_vec()).unwrap_or_default() - } + Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), }; diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 1044f5f8e2..a2e7299d39 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -241,6 +241,7 @@ pub(crate) async fn handle_client( Ok(Some(ProxyPassthrough { client: stream, aux: node.aux.clone(), + private_link_id: None, compute: node, session_id: ctx.session_id(), cancel: session, diff --git a/proxy/src/protocol2.rs b/proxy/src/protocol2.rs index 74a15d9bf4..99d645878f 100644 --- a/proxy/src/protocol2.rs +++ b/proxy/src/protocol2.rs @@ -9,6 +9,7 @@ use std::task::{Context, Poll}; use bytes::{Buf, Bytes, BytesMut}; use pin_project_lite::pin_project; +use smol_str::SmolStr; use strum_macros::FromRepr; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf}; use zerocopy::{FromBytes, FromZeroes}; @@ -99,7 +100,7 @@ impl fmt::Display for ConnectionInfo { #[derive(PartialEq, Eq, Clone, Debug)] pub enum ConnectionInfoExtra { - Aws { vpce_id: Bytes }, + Aws { vpce_id: SmolStr }, Azure { link_id: u32 }, } @@ -193,7 +194,7 @@ fn process_proxy_payload( return Err(io::Error::new( io::ErrorKind::Other, "invalid proxy protocol address family/transport protocol.", - )) + )); } }; @@ -207,9 +208,14 @@ fn process_proxy_payload( } let subtype = tlv.value.get_u8(); match Pp2AwsType::from_repr(subtype) { - Some(Pp2AwsType::VpceId) => { - extra = Some(ConnectionInfoExtra::Aws { vpce_id: tlv.value }); - } + Some(Pp2AwsType::VpceId) => match std::str::from_utf8(&tlv.value) { + Ok(s) => { + extra = Some(ConnectionInfoExtra::Aws { vpce_id: s.into() }); + } + Err(e) => { + tracing::warn!("invalid aws vpce id: {e}"); + } + }, None => { tracing::warn!("unknown aws tlv: subtype={subtype}"); } diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 2a406fcb34..49566e5172 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -16,7 +16,7 @@ use once_cell::sync::OnceCell; use pq_proto::{BeMessage as Be, CancelKeyData, StartupMessageParams}; use regex::Regex; use serde::{Deserialize, Serialize}; -use smol_str::{format_smolstr, SmolStr}; +use smol_str::{format_smolstr, SmolStr, ToSmolStr}; use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio_util::sync::CancellationToken; @@ -29,7 +29,7 @@ use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig}; use crate::context::RequestContext; use crate::error::ReportableError; use crate::metrics::{Metrics, NumClientConnectionsGuard}; -use crate::protocol2::{read_proxy_protocol, ConnectHeader, ConnectionInfo}; +use crate::protocol2::{read_proxy_protocol, ConnectHeader, ConnectionInfo, ConnectionInfoExtra}; use crate::proxy::handshake::{handshake, HandshakeData}; use crate::rate_limiter::EndpointRateLimiter; use crate::stream::{PqStream, Stream}; @@ -100,22 +100,34 @@ pub async fn task_main( debug!("healthcheck received"); return; } - Ok((_socket, ConnectHeader::Missing)) if config.proxy_protocol_v2 == ProxyProtocolV2::Required => { + Ok((_socket, ConnectHeader::Missing)) + if config.proxy_protocol_v2 == ProxyProtocolV2::Required => + { warn!("missing required proxy protocol header"); return; } - Ok((_socket, ConnectHeader::Proxy(_))) if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => { + Ok((_socket, ConnectHeader::Proxy(_))) + if config.proxy_protocol_v2 == ProxyProtocolV2::Rejected => + { warn!("proxy protocol header not supported"); return; } Ok((socket, ConnectHeader::Proxy(info))) => (socket, info), - Ok((socket, ConnectHeader::Missing)) => (socket, ConnectionInfo { addr: peer_addr, extra: None }), + Ok((socket, ConnectHeader::Missing)) => ( + socket, + ConnectionInfo { + addr: peer_addr, + extra: None, + }, + ), }; match socket.inner.set_nodelay(true) { Ok(()) => {} Err(e) => { - error!("per-client task finished with an error: failed to set socket option: {e:#}"); + error!( + "per-client task finished with an error: failed to set socket option: {e:#}" + ); return; } } @@ -156,10 +168,16 @@ pub async fn task_main( match p.proxy_pass(&config.connect_to_compute).await { Ok(()) => {} Err(ErrorSource::Client(e)) => { - warn!(?session_id, "per-client task finished with an IO error from the client: {e:#}"); + warn!( + ?session_id, + "per-client task finished with an IO error from the client: {e:#}" + ); } Err(ErrorSource::Compute(e)) => { - error!(?session_id, "per-client task finished with an IO error from the compute: {e:#}"); + error!( + ?session_id, + "per-client task finished with an IO error from the compute: {e:#}" + ); } } } @@ -374,9 +392,16 @@ pub(crate) async fn handle_client( let (stream, read_buf) = stream.into_inner(); node.stream.write_all(&read_buf).await?; + let private_link_id = match ctx.extra() { + Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()), + Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()), + None => None, + }; + Ok(Some(ProxyPassthrough { client: stream, aux: node.aux.clone(), + private_link_id, compute: node, session_id: ctx.session_id(), cancel: session, diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index 08871380d6..23b9897155 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -1,3 +1,4 @@ +use smol_str::SmolStr; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::debug; use utils::measured_stream::MeasuredStream; @@ -9,7 +10,7 @@ use crate::config::ComputeConfig; use crate::control_plane::messages::MetricsAuxInfo; use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard}; use crate::stream::Stream; -use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS}; +use crate::usage_metrics::{Ids, MetricCounterRecorder, TrafficDirection, USAGE_METRICS}; /// Forward bytes in both directions (client <-> compute). #[tracing::instrument(skip_all)] @@ -17,10 +18,14 @@ pub(crate) async fn proxy_pass( client: impl AsyncRead + AsyncWrite + Unpin, compute: impl AsyncRead + AsyncWrite + Unpin, aux: MetricsAuxInfo, + private_link_id: Option, ) -> Result<(), ErrorSource> { - let usage = USAGE_METRICS.register(Ids { + // we will report ingress at a later date + let usage_tx = USAGE_METRICS.register(Ids { endpoint_id: aux.endpoint_id, branch_id: aux.branch_id, + direction: TrafficDirection::Egress, + private_link_id, }); let metrics = &Metrics::get().proxy.io_bytes; @@ -31,7 +36,7 @@ pub(crate) async fn proxy_pass( |cnt| { // Number of bytes we sent to the client (outbound). metrics.get_metric(m_sent).inc_by(cnt as u64); - usage.record_egress(cnt as u64); + usage_tx.record_egress(cnt as u64); }, ); @@ -61,6 +66,7 @@ pub(crate) struct ProxyPassthrough { pub(crate) compute: PostgresConnection, pub(crate) aux: MetricsAuxInfo, pub(crate) session_id: uuid::Uuid, + pub(crate) private_link_id: Option, pub(crate) cancel: cancellation::Session, pub(crate) _req: NumConnectionRequestsGuard<'static>, @@ -72,7 +78,13 @@ impl ProxyPassthrough { self, compute_config: &ComputeConfig, ) -> Result<(), ErrorSource> { - let res = proxy_pass(self.client, self.compute.stream, self.aux).await; + let res = proxy_pass( + self.client, + self.compute.stream, + self.aux, + self.private_link_id, + ) + .await; if let Err(err) = self .compute .cancel_closure diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index f35c375ba2..70dd7bc0e7 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -75,10 +75,7 @@ impl PoolingBackend { let extra = ctx.extra(); let incoming_endpoint_id = match extra { None => String::new(), - Some(ConnectionInfoExtra::Aws { vpce_id }) => { - // Convert the vcpe_id to a string - String::from_utf8(vpce_id.to_vec()).unwrap_or_default() - } + Some(ConnectionInfoExtra::Aws { vpce_id }) => vpce_id.to_string(), Some(ConnectionInfoExtra::Azure { link_id }) => link_id.to_string(), }; diff --git a/proxy/src/serverless/conn_pool_lib.rs b/proxy/src/serverless/conn_pool_lib.rs index a300198de4..9e21491655 100644 --- a/proxy/src/serverless/conn_pool_lib.rs +++ b/proxy/src/serverless/conn_pool_lib.rs @@ -9,6 +9,7 @@ use clashmap::ClashMap; use parking_lot::RwLock; use postgres_client::ReadyForQueryStatus; use rand::Rng; +use smol_str::ToSmolStr; use tracing::{debug, info, Span}; use super::backend::HttpConnError; @@ -19,8 +20,9 @@ use crate::auth::backend::ComputeUserInfo; use crate::context::RequestContext; use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo}; use crate::metrics::{HttpEndpointPoolsGuard, Metrics}; +use crate::protocol2::ConnectionInfoExtra; use crate::types::{DbName, EndpointCacheKey, RoleName}; -use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; +use crate::usage_metrics::{Ids, MetricCounter, TrafficDirection, USAGE_METRICS}; #[derive(Debug, Clone)] pub(crate) struct ConnInfo { @@ -473,7 +475,9 @@ where .http_pool_opened_connections .get_metric() .dec_by(clients_removed as i64); - info!("pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}"); + info!( + "pool: performed global pool gc. removed {clients_removed} clients, total number of clients in pool is {size}" + ); } let removed = current_len - new_len; @@ -635,15 +639,28 @@ impl Client { (&mut inner.inner, Discard { conn_info, pool }) } - pub(crate) fn metrics(&self) -> Arc { + pub(crate) fn metrics( + &self, + direction: TrafficDirection, + ctx: &RequestContext, + ) -> Arc { let aux = &self .inner .as_ref() .expect("client inner should not be removed") .aux; + + let private_link_id = match ctx.extra() { + None => None, + Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()), + Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()), + }; + USAGE_METRICS.register(Ids { endpoint_id: aux.endpoint_id, branch_id: aux.branch_id, + direction, + private_link_id, }) } } @@ -700,7 +717,9 @@ impl Discard<'_, C> { pub(crate) fn discard(&mut self) { let conn_info = &self.conn_info; if std::mem::take(self.pool).strong_count() > 0 { - info!("pool: throwing away connection '{conn_info}' because connection is potentially in a broken state"); + info!( + "pool: throwing away connection '{conn_info}' because connection is potentially in a broken state" + ); } } } diff --git a/proxy/src/serverless/http_conn_pool.rs b/proxy/src/serverless/http_conn_pool.rs index fde38d0de3..fa21f24a1c 100644 --- a/proxy/src/serverless/http_conn_pool.rs +++ b/proxy/src/serverless/http_conn_pool.rs @@ -5,6 +5,7 @@ use std::sync::{Arc, Weak}; use hyper::client::conn::http2; use hyper_util::rt::{TokioExecutor, TokioIo}; use parking_lot::RwLock; +use smol_str::ToSmolStr; use tokio::net::TcpStream; use tracing::{debug, error, info, info_span, Instrument}; @@ -16,8 +17,9 @@ use super::conn_pool_lib::{ use crate::context::RequestContext; use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo}; use crate::metrics::{HttpEndpointPoolsGuard, Metrics}; +use crate::protocol2::ConnectionInfoExtra; use crate::types::EndpointCacheKey; -use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; +use crate::usage_metrics::{Ids, MetricCounter, TrafficDirection, USAGE_METRICS}; pub(crate) type Send = http2::SendRequest; pub(crate) type Connect = @@ -264,11 +266,24 @@ impl Client { Self { inner } } - pub(crate) fn metrics(&self) -> Arc { + pub(crate) fn metrics( + &self, + direction: TrafficDirection, + ctx: &RequestContext, + ) -> Arc { let aux = &self.inner.aux; + + let private_link_id = match ctx.extra() { + None => None, + Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()), + Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()), + }; + USAGE_METRICS.register(Ids { endpoint_id: aux.endpoint_id, branch_id: aux.branch_id, + direction, + private_link_id, }) } } diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 5982fe225d..7c21d90ed8 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -42,7 +42,7 @@ use crate::metrics::{HttpDirection, Metrics}; use crate::proxy::{run_until_cancelled, NeonOptions}; use crate::serverless::backend::HttpConnError; use crate::types::{DbName, RoleName}; -use crate::usage_metrics::{MetricCounter, MetricCounterRecorder}; +use crate::usage_metrics::{MetricCounter, MetricCounterRecorder, TrafficDirection}; #[derive(serde::Deserialize)] #[serde(rename_all = "camelCase")] @@ -209,7 +209,7 @@ fn get_conn_info( } } Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => { - return Err(ConnInfoError::MissingHostname) + return Err(ConnInfoError::MissingHostname); } }; ctx.set_endpoint_id(endpoint.clone()); @@ -745,7 +745,7 @@ async fn handle_db_inner( } }; - let metrics = client.metrics(); + let metrics = client.metrics(TrafficDirection::Egress, ctx); let len = json_output.len(); let response = response @@ -818,7 +818,7 @@ async fn handle_auth_broker_inner( .expect("all headers and params received via hyper should be valid for request"); // todo: map body to count egress - let _metrics = client.metrics(); + let _metrics = client.metrics(TrafficDirection::Egress, ctx); Ok(client .inner @@ -1118,10 +1118,10 @@ enum Discard<'a> { } impl Client { - fn metrics(&self) -> Arc { + fn metrics(&self, direction: TrafficDirection, ctx: &RequestContext) -> Arc { match self { - Client::Remote(client) => client.metrics(), - Client::Local(local_client) => local_client.metrics(), + Client::Remote(client) => client.metrics(direction, ctx), + Client::Local(local_client) => local_client.metrics(direction, ctx), } } diff --git a/proxy/src/usage_metrics.rs b/proxy/src/usage_metrics.rs index d369e3742f..6a23f0e129 100644 --- a/proxy/src/usage_metrics.rs +++ b/proxy/src/usage_metrics.rs @@ -16,6 +16,7 @@ use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_S use once_cell::sync::Lazy; use remote_storage::{GenericRemoteStorage, RemotePath, TimeoutOrCancel}; use serde::{Deserialize, Serialize}; +use smol_str::SmolStr; use tokio::io::AsyncWriteExt; use tokio_util::sync::CancellationToken; use tracing::{error, info, instrument, trace, warn}; @@ -43,6 +44,37 @@ const HTTP_REPORTING_RETRY_DURATION: Duration = Duration::from_secs(60); pub(crate) struct Ids { pub(crate) endpoint_id: EndpointIdInt, pub(crate) branch_id: BranchIdInt, + pub(crate) direction: TrafficDirection, + #[serde(with = "none_as_empty_string")] + pub(crate) private_link_id: Option, +} + +mod none_as_empty_string { + use serde::Deserialize; + use smol_str::SmolStr; + + #[allow(clippy::ref_option)] + pub fn serialize(t: &Option, s: S) -> Result { + s.serialize_str(t.as_deref().unwrap_or("")) + } + + pub fn deserialize<'de, D: serde::Deserializer<'de>>( + d: D, + ) -> Result, D::Error> { + let s = SmolStr::deserialize(d)?; + if s.is_empty() { + Ok(None) + } else { + Ok(Some(s)) + } + } +} + +#[derive(Eq, Hash, PartialEq, Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "lowercase")] +pub(crate) enum TrafficDirection { + Ingress, + Egress, } pub(crate) trait MetricCounterRecorder { @@ -505,6 +537,8 @@ mod tests { let counter = metrics.register(Ids { endpoint_id: (&EndpointId::from("e1")).into(), branch_id: (&BranchId::from("b1")).into(), + direction: TrafficDirection::Egress, + private_link_id: None, }); // the counter should be observed despite 0 egress