From f39fca0049c36cff5ed2c4b890b08f86fa56c15b Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 30 Nov 2023 20:52:30 +0000 Subject: [PATCH] proxy: chore: replace strings with SmolStr (#5786) ## Problem no problem ## Summary of changes replaces boxstr with arcstr as it's cheaper to clone. mild perf improvement. probably should look into other smallstring optimisations tbh, they will likely be even better. The longest endpoint name I was able to construct is something like `ep-weathered-wildflower-12345678` which is 32 bytes. Most string optimisations top out at 23 bytes --- Cargo.lock | 10 ++++++++++ Cargo.toml | 1 + proxy/Cargo.toml | 1 + proxy/src/auth/backend/link.rs | 2 +- proxy/src/bin/pg_sni_router.rs | 2 +- proxy/src/console/messages.rs | 9 +++++---- proxy/src/console/provider.rs | 2 +- proxy/src/console/provider/neon.rs | 2 +- proxy/src/proxy.rs | 8 ++++---- proxy/src/serverless/conn_pool.rs | 25 +++++++++++++------------ proxy/src/serverless/sql_over_http.rs | 10 +++++----- proxy/src/usage_metrics.rs | 9 +++++---- 12 files changed, 48 insertions(+), 33 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6546590f6c..5639665758 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3643,6 +3643,7 @@ dependencies = [ "serde", "serde_json", "sha2", + "smol_str", "socket2 0.5.3", "sync_wrapper", "task-local-extensions", @@ -4709,6 +4710,15 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" +[[package]] +name = "smol_str" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74212e6bbe9a4352329b2f68ba3130c15a3f26fe88ff22dbdc6cdd58fa85e99c" +dependencies = [ + "serde", +] + [[package]] name = "socket2" version = "0.4.9" diff --git a/Cargo.toml b/Cargo.toml index cbcb25359d..ba8b49c0e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -132,6 +132,7 @@ serde_assert = "0.5.0" sha2 = "0.10.2" signal-hook = "0.3" 
smallvec = "1.11" +smol_str = { version = "0.2.0", features = ["serde"] } socket2 = "0.5" strum = "0.24" strum_macros = "0.24" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 0822718bae..48c8604d86 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -69,6 +69,7 @@ webpki-roots.workspace = true x509-parser.workspace = true native-tls.workspace = true postgres-native-tls.workspace = true +smol_str.workspace = true workspace_hack.workspace = true tokio-util.workspace = true diff --git a/proxy/src/auth/backend/link.rs b/proxy/src/auth/backend/link.rs index da43cf11c4..3a77d7e5ca 100644 --- a/proxy/src/auth/backend/link.rs +++ b/proxy/src/auth/backend/link.rs @@ -106,7 +106,7 @@ pub(super) async fn authenticate( reported_auth_ok: true, value: NodeInfo { config, - aux: db_info.aux.into(), + aux: db_info.aux, allow_self_signed_compute: false, // caller may override }, }) diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs index 2b859fc2db..bedbdbcc83 100644 --- a/proxy/src/bin/pg_sni_router.rs +++ b/proxy/src/bin/pg_sni_router.rs @@ -284,5 +284,5 @@ async fn handle_client( let client = tokio::net::TcpStream::connect(destination).await?; let metrics_aux: MetricsAuxInfo = Default::default(); - proxy::proxy::proxy_pass(tls_stream, client, &metrics_aux).await + proxy::proxy::proxy_pass(tls_stream, client, metrics_aux).await } diff --git a/proxy/src/console/messages.rs b/proxy/src/console/messages.rs index e5f1615b14..837379b21f 100644 --- a/proxy/src/console/messages.rs +++ b/proxy/src/console/messages.rs @@ -1,4 +1,5 @@ use serde::Deserialize; +use smol_str::SmolStr; use std::fmt; /// Generic error response with human-readable description. @@ -88,11 +89,11 @@ impl fmt::Debug for DatabaseInfo { /// Various labels for prometheus metrics. /// Also known as `ProxyMetricsAuxInfo` in the console. 
-#[derive(Debug, Deserialize, Default)] +#[derive(Debug, Deserialize, Clone, Default)] pub struct MetricsAuxInfo { - pub endpoint_id: Box, - pub project_id: Box, - pub branch_id: Box, + pub endpoint_id: SmolStr, + pub project_id: SmolStr, + pub branch_id: SmolStr, } impl MetricsAuxInfo { diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index a525de8e53..e735b9f66c 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -229,7 +229,7 @@ pub struct NodeInfo { pub config: compute::ConnCfg, /// Labels for proxy's metrics. - pub aux: Arc, + pub aux: MetricsAuxInfo, /// Whether we should accept self-signed certificates (for testing) pub allow_self_signed_compute: bool, diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index 117d0ec190..7828a7d7e4 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -144,7 +144,7 @@ impl Api { let node = NodeInfo { config, - aux: body.aux.into(), + aux: body.aux, allow_self_signed_compute: false, }; diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index c4bea13f7f..36d01f9acc 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -877,11 +877,11 @@ async fn prepare_client_connection( pub async fn proxy_pass( client: impl AsyncRead + AsyncWrite + Unpin, compute: impl AsyncRead + AsyncWrite + Unpin, - aux: &MetricsAuxInfo, + aux: MetricsAuxInfo, ) -> anyhow::Result<()> { let usage = USAGE_METRICS.register(Ids { - endpoint_id: aux.endpoint_id.to_string(), - branch_id: aux.branch_id.to_string(), + endpoint_id: aux.endpoint_id.clone(), + branch_id: aux.branch_id.clone(), }); let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]); @@ -1032,7 +1032,7 @@ impl Client<'_, S> { // immediately after opening the connection. 
let (stream, read_buf) = stream.into_inner(); node.stream.write_all(&read_buf).await?; - proxy_pass(stream, node.stream, &aux).await + proxy_pass(stream, node.stream, aux).await } } diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 2072cadc3a..ca7a9ad0a0 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -8,6 +8,7 @@ use pbkdf2::{ Params, Pbkdf2, }; use pq_proto::StartupMessageParams; +use smol_str::SmolStr; use std::{collections::HashMap, net::SocketAddr, sync::Arc}; use std::{ fmt, @@ -41,16 +42,16 @@ const MAX_CONNS_PER_ENDPOINT: usize = 20; #[derive(Debug, Clone)] pub struct ConnInfo { - pub username: String, - pub dbname: String, - pub hostname: String, - pub password: String, - pub options: Option, + pub username: SmolStr, + pub dbname: SmolStr, + pub hostname: SmolStr, + pub password: SmolStr, + pub options: Option, } impl ConnInfo { // hm, change to hasher to avoid cloning? - pub fn db_and_user(&self) -> (String, String) { + pub fn db_and_user(&self) -> (SmolStr, SmolStr) { (self.dbname.clone(), self.username.clone()) } } @@ -70,7 +71,7 @@ struct ConnPoolEntry { // Per-endpoint connection pool, (dbname, username) -> DbUserConnPool // Number of open connections is limited by the `max_conns_per_endpoint`. pub struct EndpointConnPool { - pools: HashMap<(String, String), DbUserConnPool>, + pools: HashMap<(SmolStr, SmolStr), DbUserConnPool>, total_conns: usize, } @@ -95,7 +96,7 @@ pub struct GlobalConnPool { // // That should be a fairly conteded map, so return reference to the per-endpoint // pool as early as possible and release the lock. - global_pool: DashMap>>, + global_pool: DashMap>>, /// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each. /// That seems like far too much effort, so we're using a relaxed increment counter instead. 
@@ -327,7 +328,7 @@ impl GlobalConnPool { Ok(()) } - fn get_or_create_endpoint_pool(&self, endpoint: &String) -> Arc> { + fn get_or_create_endpoint_pool(&self, endpoint: &SmolStr) -> Arc> { // fast path if let Some(pool) = self.global_pool.get(endpoint) { return pool.clone(); @@ -468,7 +469,7 @@ async fn connect_to_compute_once( let (client, mut connection) = config .user(&conn_info.username) - .password(&conn_info.password) + .password(&*conn_info.password) .dbname(&conn_info.dbname) .connect_timeout(timeout) .connect(tokio_postgres::NoTls) @@ -482,8 +483,8 @@ async fn connect_to_compute_once( info!(%conn_info, %session, "new connection"); }); let ids = Ids { - endpoint_id: node_info.aux.endpoint_id.to_string(), - branch_id: node_info.aux.branch_id.to_string(), + endpoint_id: node_info.aux.endpoint_id.clone(), + branch_id: node_info.aux.branch_id.clone(), }; tokio::spawn( diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 25b96668de..6c337a837c 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -182,16 +182,16 @@ fn get_conn_info( for (key, value) in pairs { if key == "options" { - options = Some(value.to_string()); + options = Some(value.into()); break; } } Ok(ConnInfo { - username: username.to_owned(), - dbname: dbname.to_owned(), - hostname: hostname.to_owned(), - password: password.to_owned(), + username: username.into(), + dbname: dbname.into(), + hostname: hostname.into(), + password: password.into(), options, }) } diff --git a/proxy/src/usage_metrics.rs b/proxy/src/usage_metrics.rs index 180b5f7199..789a4c680c 100644 --- a/proxy/src/usage_metrics.rs +++ b/proxy/src/usage_metrics.rs @@ -6,6 +6,7 @@ use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_S use dashmap::{mapref::entry::Entry, DashMap}; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; +use smol_str::SmolStr; use std::{ convert::Infallible, sync::{ @@ -29,8 +30,8 @@ 
const DEFAULT_HTTP_REPORTING_TIMEOUT: Duration = Duration::from_secs(60); /// because we enrich the event with project_id in the control-plane endpoint. #[derive(Eq, Hash, PartialEq, Serialize, Deserialize, Debug, Clone)] pub struct Ids { - pub endpoint_id: String, - pub branch_id: String, + pub endpoint_id: SmolStr, + pub branch_id: SmolStr, } #[derive(Debug)] @@ -290,8 +291,8 @@ mod tests { // register a new counter let counter = metrics.register(Ids { - endpoint_id: "e1".to_string(), - branch_id: "b1".to_string(), + endpoint_id: "e1".into(), + branch_id: "b1".into(), }); // the counter should be observed despite 0 egress