From 528fb1bd81ea4a71e903e1b1be3c3b08ced0ce73 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 28 Sep 2023 11:38:26 +0100 Subject: [PATCH] proxy: metrics2 (#5179) ## Problem We need to report usage metrics whenever a connection is open, not only when bytes were transferred during the interval. We also need to count bytes usage for HTTP. ## Summary of changes New structure for usage metrics: a `DashMap<Ids, Arc<MetricCounter>>`. If the `Arc` has 1 owner (the map), then I can conclude that no connections are open. If the counter's `opened_connections` is non-zero, then I can conclude a new connection was opened in the last interval and it should be reported on. Also, keep count of how many bytes are processed for HTTP and report it here. --- Cargo.lock | 1 + Cargo.toml | 1 + libs/consumption_metrics/src/lib.rs | 2 +- proxy/Cargo.toml | 1 + proxy/src/http/conn_pool.rs | 20 +- proxy/src/http/sql_over_http.rs | 102 ++++++-- proxy/src/http/websocket.rs | 43 +--- proxy/src/metrics.rs | 363 ++++++++++++++++++---------- proxy/src/proxy.rs | 7 + 9 files changed, 347 insertions(+), 193 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 55c80e30a7..b22f081bdc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3246,6 +3246,7 @@ dependencies = [ "reqwest-tracing", "routerify", "rstest", + "rustc-hash", "rustls", "rustls-pemfile", "scopeguard", diff --git a/Cargo.toml b/Cargo.toml index 4fe3069822..b0bcf69039 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -107,6 +107,7 @@ reqwest-middleware = "0.2.0" reqwest-retry = "0.2.2" routerify = "3" rpds = "0.13" +rustc-hash = "1.1.0" rustls = "0.21" rustls-pemfile = "1" rustls-split = "0.3" diff --git a/libs/consumption_metrics/src/lib.rs b/libs/consumption_metrics/src/lib.rs index 7b133c61af..9e89327e84 100644 --- a/libs/consumption_metrics/src/lib.rs +++ b/libs/consumption_metrics/src/lib.rs @@ -107,7 +107,7 @@ pub const CHUNK_SIZE: usize = 1000; // Just a wrapper around a slice of events // to serialize it as `{"events" : [ ] } -#[derive(serde::Serialize)] +#[derive(serde::Serialize, 
serde::Deserialize)] pub struct EventChunk<'a, T: Clone> { pub events: std::borrow::Cow<'a, [T]>, } diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index cbab0c6f07..92498d3ecd 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -42,6 +42,7 @@ reqwest-middleware.workspace = true reqwest-retry.workspace = true reqwest-tracing.workspace = true routerify.workspace = true +rustc-hash.workspace = true rustls-pemfile.workspace = true rustls.workspace = true scopeguard.workspace = true diff --git a/proxy/src/http/conn_pool.rs b/proxy/src/http/conn_pool.rs index e771e5d7ed..a7ef15d342 100644 --- a/proxy/src/http/conn_pool.rs +++ b/proxy/src/http/conn_pool.rs @@ -17,11 +17,12 @@ use std::{ use tokio::time; use tokio_postgres::AsyncMessage; -use crate::{auth, console}; +use crate::{ + auth, console, + metrics::{Ids, MetricCounter, USAGE_METRICS}, +}; use crate::{compute, config}; -use super::sql_over_http::MAX_RESPONSE_SIZE; - use crate::proxy::ConnectMechanism; use tracing::{error, warn}; @@ -400,7 +401,6 @@ async fn connect_to_compute_once( .user(&conn_info.username) .password(&conn_info.password) .dbname(&conn_info.dbname) - .max_backend_message_size(MAX_RESPONSE_SIZE) .connect_timeout(timeout) .connect(tokio_postgres::NoTls) .await?; @@ -412,6 +412,10 @@ async fn connect_to_compute_once( span.in_scope(|| { info!(%conn_info, %session, "new connection"); }); + let ids = Ids { + endpoint_id: node_info.aux.endpoint_id.to_string(), + branch_id: node_info.aux.branch_id.to_string(), + }; tokio::spawn( poll_fn(move |cx| { @@ -450,10 +454,18 @@ async fn connect_to_compute_once( Ok(Client { inner: client, session: tx, + ids, }) } pub struct Client { pub inner: tokio_postgres::Client, session: tokio::sync::watch::Sender, + ids: Ids, +} + +impl Client { + pub fn metrics(&self) -> Arc { + USAGE_METRICS.register(self.ids.clone()) + } } diff --git a/proxy/src/http/sql_over_http.rs b/proxy/src/http/sql_over_http.rs index fe57096105..b74b3e9646 100644 --- 
a/proxy/src/http/sql_over_http.rs +++ b/proxy/src/http/sql_over_http.rs @@ -3,10 +3,12 @@ use std::sync::Arc; use anyhow::bail; use futures::pin_mut; use futures::StreamExt; -use hashbrown::HashMap; use hyper::body::HttpBody; +use hyper::header; use hyper::http::HeaderName; use hyper::http::HeaderValue; +use hyper::Response; +use hyper::StatusCode; use hyper::{Body, HeaderMap, Request}; use serde_json::json; use serde_json::Map; @@ -16,7 +18,11 @@ use tokio_postgres::types::Type; use tokio_postgres::GenericClient; use tokio_postgres::IsolationLevel; use tokio_postgres::Row; +use tracing::error; +use tracing::instrument; use url::Url; +use utils::http::error::ApiError; +use utils::http::json::json_response; use super::conn_pool::ConnInfo; use super::conn_pool::GlobalConnPool; @@ -39,7 +45,6 @@ enum Payload { Batch(BatchQueryData), } -pub const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024; // 10 MB const MAX_REQUEST_SIZE: u64 = 1024 * 1024; // 1 MB static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output"); @@ -182,7 +187,45 @@ pub async fn handle( sni_hostname: Option, conn_pool: Arc, session_id: uuid::Uuid, -) -> anyhow::Result<(Value, HashMap)> { +) -> Result, ApiError> { + let result = handle_inner(request, sni_hostname, conn_pool, session_id).await; + + let mut response = match result { + Ok(r) => r, + Err(e) => { + let message = format!("{:?}", e); + let code = match e.downcast_ref::() { + Some(e) => match e.code() { + Some(e) => serde_json::to_value(e.code()).unwrap(), + None => Value::Null, + }, + None => Value::Null, + }; + error!( + ?code, + "sql-over-http per-client task finished with an error: {e:#}" + ); + // TODO: this shouldn't always be bad request. + json_response( + StatusCode::BAD_REQUEST, + json!({ "message": message, "code": code }), + )? 
+ } + }; + response.headers_mut().insert( + "Access-Control-Allow-Origin", + hyper::http::HeaderValue::from_static("*"), + ); + Ok(response) +} + +#[instrument(name = "sql-over-http", skip_all)] +async fn handle_inner( + request: Request, + sni_hostname: Option, + conn_pool: Arc, + session_id: uuid::Uuid, +) -> anyhow::Result> { // // Determine the destination and connection params // @@ -233,13 +276,18 @@ pub async fn handle( let mut client = conn_pool.get(&conn_info, !allow_pool, session_id).await?; + let mut response = Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "application/json"); + // // Now execute the query and return the result // + let mut size = 0; let result = match payload { - Payload::Single(query) => query_to_json(&client.inner, query, raw_output, array_mode) - .await - .map(|x| (x, HashMap::default())), + Payload::Single(query) => { + query_to_json(&client.inner, query, &mut size, raw_output, array_mode).await + } Payload::Batch(batch_query) => { let mut results = Vec::new(); let mut builder = client.inner.build_transaction(); @@ -254,7 +302,8 @@ pub async fn handle( } let transaction = builder.start().await?; for query in batch_query.queries { - let result = query_to_json(&transaction, query, raw_output, array_mode).await; + let result = + query_to_json(&transaction, query, &mut size, raw_output, array_mode).await; match result { Ok(r) => results.push(r), Err(e) => { @@ -264,26 +313,27 @@ pub async fn handle( } } transaction.commit().await?; - let mut headers = HashMap::default(); if txn_read_only { - headers.insert( + response = response.header( TXN_READ_ONLY.clone(), HeaderValue::try_from(txn_read_only.to_string())?, ); } if txn_deferrable { - headers.insert( + response = response.header( TXN_DEFERRABLE.clone(), HeaderValue::try_from(txn_deferrable.to_string())?, ); } if let Some(txn_isolation_level) = txn_isolation_level_raw { - headers.insert(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level); + response = 
response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level); } - Ok((json!({ "results": results }), headers)) + Ok(json!({ "results": results })) } }; + let metrics = client.metrics(); + if allow_pool { let current_span = tracing::Span::current(); // return connection to the pool @@ -293,12 +343,30 @@ pub async fn handle( }); } - result + match result { + Ok(value) => { + // how could this possibly fail + let body = serde_json::to_string(&value).expect("json serialization should not fail"); + let len = body.len(); + let response = response + .body(Body::from(body)) + // only fails if invalid status code or invalid header/values are given. + // these are not user configurable so it cannot fail dynamically + .expect("building response payload should not fail"); + + // count the egress bytes - we miss the TLS and header overhead but oh well... + // moving this later in the stack is going to be a lot of effort and ehhhh + metrics.record_egress(len as u64); + Ok(response) + } + Err(e) => Err(e), + } } async fn query_to_json( client: &T, data: QueryData, + current_size: &mut usize, raw_output: bool, array_mode: bool, ) -> anyhow::Result { @@ -312,16 +380,10 @@ async fn query_to_json( // big. 
pin_mut!(row_stream); let mut rows: Vec = Vec::new(); - let mut current_size = 0; while let Some(row) = row_stream.next().await { let row = row?; - current_size += row.body_len(); + *current_size += row.body_len(); rows.push(row); - if current_size > MAX_RESPONSE_SIZE { - return Err(anyhow::anyhow!( - "response is too large (max is {MAX_RESPONSE_SIZE} bytes)" - )); - } } // grab the command tag and number of rows affected diff --git a/proxy/src/http/websocket.rs b/proxy/src/http/websocket.rs index fa66df0469..994a7de764 100644 --- a/proxy/src/http/websocket.rs +++ b/proxy/src/http/websocket.rs @@ -7,7 +7,6 @@ use crate::{ }; use bytes::{Buf, Bytes}; use futures::{Sink, Stream, StreamExt}; -use hashbrown::HashMap; use hyper::{ server::{ accept, @@ -18,7 +17,6 @@ use hyper::{ }; use hyper_tungstenite::{tungstenite::Message, HyperWebsocket, WebSocketStream}; use pin_project_lite::pin_project; -use serde_json::{json, Value}; use std::{ convert::Infallible, @@ -204,44 +202,7 @@ async fn ws_handler( // TODO: that deserves a refactor as now this function also handles http json client besides websockets. // Right now I don't want to blow up sql-over-http patch with file renames and do that as a follow up instead. 
} else if request.uri().path() == "/sql" && request.method() == Method::POST { - let result = sql_over_http::handle(request, sni_hostname, conn_pool, session_id) - .instrument(info_span!("sql-over-http")) - .await; - let status_code = match result { - Ok(_) => StatusCode::OK, - Err(_) => StatusCode::BAD_REQUEST, - }; - let (json, headers) = match result { - Ok(r) => r, - Err(e) => { - let message = format!("{:?}", e); - let code = match e.downcast_ref::() { - Some(e) => match e.code() { - Some(e) => serde_json::to_value(e.code()).unwrap(), - None => Value::Null, - }, - None => Value::Null, - }; - error!( - ?code, - "sql-over-http per-client task finished with an error: {e:#}" - ); - ( - json!({ "message": message, "code": code }), - HashMap::default(), - ) - } - }; - json_response(status_code, json).map(|mut r| { - r.headers_mut().insert( - "Access-Control-Allow-Origin", - hyper::http::HeaderValue::from_static("*"), - ); - for (k, v) in headers { - r.headers_mut().insert(k, v); - } - r - }) + sql_over_http::handle(request, sni_hostname, conn_pool, session_id).await } else if request.uri().path() == "/sql" && request.method() == Method::OPTIONS { Response::builder() .header("Allow", "OPTIONS, POST") @@ -253,7 +214,7 @@ async fn ws_handler( .header("Access-Control-Max-Age", "86400" /* 24 hours */) .status(StatusCode::OK) // 204 is also valid, but see: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods/OPTIONS#status_code .body(Body::empty()) - .map_err(|e| ApiError::BadRequest(e.into())) + .map_err(|e| ApiError::InternalServerError(e.into())) } else { json_response(StatusCode::BAD_REQUEST, "query is not supported") } diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 9279002eb3..cfeec5622b 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -3,9 +3,18 @@ use crate::{config::MetricCollectionConfig, http}; use chrono::{DateTime, Utc}; use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_SIZE}; -use 
serde::Serialize; -use std::{collections::HashMap, convert::Infallible, time::Duration}; -use tracing::{error, info, instrument, trace, warn}; +use dashmap::{mapref::entry::Entry, DashMap}; +use once_cell::sync::Lazy; +use serde::{Deserialize, Serialize}; +use std::{ + convert::Infallible, + sync::{ + atomic::{AtomicU64, AtomicUsize, Ordering}, + Arc, + }, + time::Duration, +}; +use tracing::{error, info, instrument, trace}; const PROXY_IO_BYTES_PER_CLIENT: &str = "proxy_io_bytes_per_client"; @@ -18,12 +27,95 @@ const DEFAULT_HTTP_REPORTING_TIMEOUT: Duration = Duration::from_secs(60); /// Both the proxy and the ingestion endpoint will live in the same region (or cell) /// so while the project-id is unique across regions the whole pipeline will work correctly /// because we enrich the event with project_id in the control-plane endpoint. -#[derive(Eq, Hash, PartialEq, Serialize, Debug, Clone)] +#[derive(Eq, Hash, PartialEq, Serialize, Deserialize, Debug, Clone)] pub struct Ids { pub endpoint_id: String, pub branch_id: String, } +#[derive(Debug)] +pub struct MetricCounter { + transmitted: AtomicU64, + opened_connections: AtomicUsize, +} + +impl MetricCounter { + /// Record that some bytes were sent from the proxy to the client + pub fn record_egress(&self, bytes: u64) { + self.transmitted.fetch_add(bytes, Ordering::AcqRel); + } + + /// extract the value that should be reported + fn should_report(self: &Arc) -> Option { + // heuristic to see if the branch is still open + // if a clone happens while we are observing, the heuristic will be incorrect. + // + // Worst case is that we won't report an event for this endpoint. + // However, for the strong count to be 1 it must have occurred that at one instant + // all the endpoints were closed, so missing a report because the endpoints are closed is valid. 
+ let is_open = Arc::strong_count(self) > 1; + let opened = self.opened_connections.swap(0, Ordering::AcqRel); + + // update cached metrics eagerly, even if they can't get sent + // (to avoid sending the same metrics twice) + // see the relevant discussion on why to do so even if the status is not success: + // https://github.com/neondatabase/neon/pull/4563#discussion_r1246710956 + let value = self.transmitted.swap(0, Ordering::AcqRel); + + // Our only requirement is that we report in every interval if there was an open connection + // if there were no opened connections since, then we don't need to report + if value == 0 && !is_open && opened == 0 { + None + } else { + Some(value) + } + } + + /// Determine whether the counter should be cleared from the global map. + fn should_clear(self: &mut Arc) -> bool { + // we can't clear this entry if it's acquired elsewhere + let Some(counter) = Arc::get_mut(self) else { + return false; + }; + let opened = *counter.opened_connections.get_mut(); + let value = *counter.transmitted.get_mut(); + // clear if there's no data to report + value == 0 && opened == 0 + } +} + +// endpoint and branch IDs are not user generated so we don't run the risk of hash-dos +type FastHasher = std::hash::BuildHasherDefault; + +#[derive(Default)] +pub struct Metrics { + endpoints: DashMap, FastHasher>, +} + +impl Metrics { + /// Register a new byte metrics counter for this endpoint + pub fn register(&self, ids: Ids) -> Arc { + let entry = if let Some(entry) = self.endpoints.get(&ids) { + entry.clone() + } else { + self.endpoints + .entry(ids) + .or_insert_with(|| { + Arc::new(MetricCounter { + transmitted: AtomicU64::new(0), + opened_connections: AtomicUsize::new(0), + }) + }) + .clone() + }; + + entry.opened_connections.fetch_add(1, Ordering::AcqRel); + entry + } +} + +pub static USAGE_METRICS: Lazy = Lazy::new(Metrics::default); + pub async fn task_main(config: &MetricCollectionConfig) -> anyhow::Result { info!("metrics collector config: 
{config:?}"); scopeguard::defer! { @@ -31,145 +123,83 @@ pub async fn task_main(config: &MetricCollectionConfig) -> anyhow::Result)> = HashMap::new(); let hostname = hostname::get()?.as_os_str().to_string_lossy().into_owned(); + let mut prev = Utc::now(); let mut ticker = tokio::time::interval(config.interval); loop { ticker.tick().await; - let res = collect_metrics_iteration( + let now = Utc::now(); + collect_metrics_iteration( + &USAGE_METRICS, &http_client, - &mut cached_metrics, &config.endpoint, &hostname, + prev, + now, ) .await; - - match res { - Err(e) => error!("failed to send consumption metrics: {e} "), - Ok(_) => trace!("periodic metrics collection completed successfully"), - } + prev = now; } } -fn gather_proxy_io_bytes_per_client() -> Vec<(Ids, (u64, DateTime))> { - let mut current_metrics: Vec<(Ids, (u64, DateTime))> = Vec::new(); - let metrics = prometheus::default_registry().gather(); - - for m in metrics { - if m.get_name() == "proxy_io_bytes_per_client" { - for ms in m.get_metric() { - let direction = ms - .get_label() - .iter() - .find(|l| l.get_name() == "direction") - .unwrap() - .get_value(); - - // Only collect metric for outbound traffic - if direction == "tx" { - let endpoint_id = ms - .get_label() - .iter() - .find(|l| l.get_name() == "endpoint_id") - .unwrap() - .get_value(); - let branch_id = ms - .get_label() - .iter() - .find(|l| l.get_name() == "branch_id") - .unwrap() - .get_value(); - - let value = ms.get_counter().get_value() as u64; - - // Report if the metric value is suspiciously large - if value > (1u64 << 40) { - warn!( - "potentially abnormal counter value: branch_id {} endpoint_id {} val: {}", - branch_id, endpoint_id, value - ); - } - - current_metrics.push(( - Ids { - endpoint_id: endpoint_id.to_string(), - branch_id: branch_id.to_string(), - }, - (value, Utc::now()), - )); - } - } - } - } - - current_metrics -} - #[instrument(skip_all)] async fn collect_metrics_iteration( + metrics: &Metrics, client: 
&http::ClientWithMiddleware, - cached_metrics: &mut HashMap)>, metric_collection_endpoint: &reqwest::Url, hostname: &str, -) -> anyhow::Result<()> { + prev: DateTime, + now: DateTime, +) { info!( "starting collect_metrics_iteration. metric_collection_endpoint: {}", metric_collection_endpoint ); - let current_metrics = gather_proxy_io_bytes_per_client(); + let mut metrics_to_clear = Vec::new(); - let metrics_to_send: Vec> = current_metrics + let metrics_to_send: Vec<(Ids, u64)> = metrics + .endpoints .iter() - .filter_map(|(curr_key, (curr_val, curr_time))| { - let mut start_time = *curr_time; - let mut value = *curr_val; - - if let Some((prev_val, prev_time)) = cached_metrics.get(curr_key) { - // Only send metrics updates if the metric has increased - if curr_val > prev_val { - value = curr_val - prev_val; - start_time = *prev_time; - } else { - if curr_val < prev_val { - error!("proxy_io_bytes_per_client metric value decreased from {} to {} for key {:?}", - prev_val, curr_val, curr_key); - } - return None; - } + .filter_map(|counter| { + let key = counter.key().clone(); + let Some(value) = counter.should_report() else { + metrics_to_clear.push(key); + return None; }; - - Some(Event { - kind: EventType::Incremental { - start_time, - stop_time: *curr_time, - }, - metric: PROXY_IO_BYTES_PER_CLIENT, - idempotency_key: idempotency_key(hostname), - value, - extra: Ids { - endpoint_id: curr_key.endpoint_id.clone(), - branch_id: curr_key.branch_id.clone(), - }, - }) + Some((key, value)) }) .collect(); if metrics_to_send.is_empty() { trace!("no new metrics to send"); - return Ok(()); } // Send metrics. 
// Split into chunks of 1000 metrics to avoid exceeding the max request size for chunk in metrics_to_send.chunks(CHUNK_SIZE) { + let events = chunk + .iter() + .map(|(ids, value)| Event { + kind: EventType::Incremental { + start_time: prev, + stop_time: now, + }, + metric: PROXY_IO_BYTES_PER_CLIENT, + idempotency_key: idempotency_key(hostname), + value: *value, + extra: Ids { + endpoint_id: ids.endpoint_id.clone(), + branch_id: ids.branch_id.clone(), + }, + }) + .collect(); + let res = client .post(metric_collection_endpoint.clone()) - .json(&EventChunk { - events: chunk.into(), - }) + .json(&EventChunk { events }) .send() .await; @@ -183,34 +213,113 @@ async fn collect_metrics_iteration( if !res.status().is_success() { error!("metrics endpoint refused the sent metrics: {:?}", res); - for metric in chunk.iter().filter(|metric| metric.value > (1u64 << 40)) { + for metric in chunk.iter().filter(|(_, value)| *value > (1u64 << 40)) { // Report if the metric value is suspiciously large error!("potentially abnormal metric value: {:?}", metric); } } - // update cached metrics after they were sent - // (to avoid sending the same metrics twice) - // see the relevant discussion on why to do so even if the status is not success: - // https://github.com/neondatabase/neon/pull/4563#discussion_r1246710956 - for send_metric in chunk { - let stop_time = match send_metric.kind { - EventType::Incremental { stop_time, .. 
} => stop_time, - _ => unreachable!(), - }; + } - cached_metrics - .entry(Ids { - endpoint_id: send_metric.extra.endpoint_id.clone(), - branch_id: send_metric.extra.branch_id.clone(), - }) - // update cached value (add delta) and time - .and_modify(|e| { - e.0 = e.0.saturating_add(send_metric.value); - e.1 = stop_time - }) - // cache new metric - .or_insert((send_metric.value, stop_time)); + for metric in metrics_to_clear { + match metrics.endpoints.entry(metric) { + Entry::Occupied(mut counter) => { + if counter.get_mut().should_clear() { + counter.remove_entry(); + } + } + Entry::Vacant(_) => {} } } - Ok(()) +} + +#[cfg(test)] +mod tests { + use std::{ + net::TcpListener, + sync::{Arc, Mutex}, + }; + + use anyhow::Error; + use chrono::Utc; + use consumption_metrics::{Event, EventChunk}; + use hyper::{ + service::{make_service_fn, service_fn}, + Body, Response, + }; + use url::Url; + + use super::{collect_metrics_iteration, Ids, Metrics}; + use crate::http; + + #[tokio::test] + async fn metrics() { + let listener = TcpListener::bind("0.0.0.0:0").unwrap(); + + let reports = Arc::new(Mutex::new(vec![])); + let reports2 = reports.clone(); + + let server = hyper::server::Server::from_tcp(listener) + .unwrap() + .serve(make_service_fn(move |_| { + let reports = reports.clone(); + async move { + Ok::<_, Error>(service_fn(move |req| { + let reports = reports.clone(); + async move { + let bytes = hyper::body::to_bytes(req.into_body()).await?; + let events: EventChunk<'static, Event> = + serde_json::from_slice(&bytes)?; + reports.lock().unwrap().push(events); + Ok::<_, Error>(Response::new(Body::from(vec![]))) + } + })) + } + })); + let addr = server.local_addr(); + tokio::spawn(server); + + let metrics = Metrics::default(); + let client = http::new_client(); + let endpoint = Url::parse(&format!("http://{addr}")).unwrap(); + let now = Utc::now(); + + // no counters have been registered + collect_metrics_iteration(&metrics, &client, &endpoint, "foo", now, now).await; + let 
r = std::mem::take(&mut *reports2.lock().unwrap()); + assert!(r.is_empty()); + + // register a new counter + let counter = metrics.register(Ids { + endpoint_id: "e1".to_string(), + branch_id: "b1".to_string(), + }); + + // the counter should be observed despite 0 egress + collect_metrics_iteration(&metrics, &client, &endpoint, "foo", now, now).await; + let r = std::mem::take(&mut *reports2.lock().unwrap()); + assert_eq!(r.len(), 1); + assert_eq!(r[0].events.len(), 1); + assert_eq!(r[0].events[0].value, 0); + + // record egress + counter.record_egress(1); + + // egress should be observed + collect_metrics_iteration(&metrics, &client, &endpoint, "foo", now, now).await; + let r = std::mem::take(&mut *reports2.lock().unwrap()); + assert_eq!(r.len(), 1); + assert_eq!(r[0].events.len(), 1); + assert_eq!(r[0].events[0].value, 1); + + // release counter + drop(counter); + + // we do not observe the counter + collect_metrics_iteration(&metrics, &client, &endpoint, "foo", now, now).await; + let r = std::mem::take(&mut *reports2.lock().unwrap()); + assert!(r.is_empty()); + + // counter is unregistered + assert!(metrics.endpoints.is_empty()); + } } diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index f9da145859..c8f534b2b7 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -7,6 +7,7 @@ use crate::{ compute::{self, PostgresConnection}, config::{ProxyConfig, TlsConfig}, console::{self, errors::WakeComputeError, messages::MetricsAuxInfo, Api}, + metrics::{Ids, USAGE_METRICS}, protocol2::WithClientIp, stream::{PqStream, Stream}, }; @@ -602,6 +603,11 @@ pub async fn proxy_pass( compute: impl AsyncRead + AsyncWrite + Unpin, aux: &MetricsAuxInfo, ) -> anyhow::Result<()> { + let usage = USAGE_METRICS.register(Ids { + endpoint_id: aux.endpoint_id.to_string(), + branch_id: aux.branch_id.to_string(), + }); + let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&aux.traffic_labels("tx")); let mut client = MeasuredStream::new( client, @@ -609,6 +615,7 @@ pub async fn 
proxy_pass( |cnt| { // Number of bytes we sent to the client (outbound). m_sent.inc_by(cnt as u64); + usage.record_egress(cnt as u64); }, );