diff --git a/Cargo.lock b/Cargo.lock index 6546590f6c..5639665758 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3643,6 +3643,7 @@ dependencies = [ "serde", "serde_json", "sha2", + "smol_str", "socket2 0.5.3", "sync_wrapper", "task-local-extensions", @@ -4709,6 +4710,15 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" +[[package]] +name = "smol_str" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74212e6bbe9a4352329b2f68ba3130c15a3f26fe88ff22dbdc6cdd58fa85e99c" +dependencies = [ + "serde", +] + [[package]] name = "socket2" version = "0.4.9" diff --git a/Cargo.toml b/Cargo.toml index cbcb25359d..ba8b49c0e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -132,6 +132,7 @@ serde_assert = "0.5.0" sha2 = "0.10.2" signal-hook = "0.3" smallvec = "1.11" +smol_str = { version = "0.2.0", features = ["serde"] } socket2 = "0.5" strum = "0.24" strum_macros = "0.24" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 0822718bae..48c8604d86 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -69,6 +69,7 @@ webpki-roots.workspace = true x509-parser.workspace = true native-tls.workspace = true postgres-native-tls.workspace = true +smol_str.workspace = true workspace_hack.workspace = true tokio-util.workspace = true diff --git a/proxy/src/auth/backend/link.rs b/proxy/src/auth/backend/link.rs index da43cf11c4..3a77d7e5ca 100644 --- a/proxy/src/auth/backend/link.rs +++ b/proxy/src/auth/backend/link.rs @@ -106,7 +106,7 @@ pub(super) async fn authenticate( reported_auth_ok: true, value: NodeInfo { config, - aux: db_info.aux.into(), + aux: db_info.aux, allow_self_signed_compute: false, // caller may override }, }) diff --git a/proxy/src/bin/pg_sni_router.rs b/proxy/src/bin/pg_sni_router.rs index 2b859fc2db..bedbdbcc83 100644 --- a/proxy/src/bin/pg_sni_router.rs +++ b/proxy/src/bin/pg_sni_router.rs @@ -284,5 +284,5 @@ async fn handle_client( let client = tokio::net::TcpStream::connect(destination).await?; let metrics_aux: MetricsAuxInfo = Default::default(); - proxy::proxy::proxy_pass(tls_stream, client, &metrics_aux).await + proxy::proxy::proxy_pass(tls_stream, client, metrics_aux).await } diff --git a/proxy/src/console/messages.rs b/proxy/src/console/messages.rs index e5f1615b14..837379b21f 100644 --- a/proxy/src/console/messages.rs +++ b/proxy/src/console/messages.rs @@ -1,4 +1,5 @@ use serde::Deserialize; +use smol_str::SmolStr; use std::fmt; /// Generic error response with human-readable description. @@ -88,11 +89,11 @@ impl fmt::Debug for DatabaseInfo { /// Various labels for prometheus metrics. /// Also known as `ProxyMetricsAuxInfo` in the console. -#[derive(Debug, Deserialize, Default)] +#[derive(Debug, Deserialize, Clone, Default)] pub struct MetricsAuxInfo { - pub endpoint_id: Box, - pub project_id: Box, - pub branch_id: Box, + pub endpoint_id: SmolStr, + pub project_id: SmolStr, + pub branch_id: SmolStr, } impl MetricsAuxInfo { diff --git a/proxy/src/console/provider.rs b/proxy/src/console/provider.rs index a525de8e53..e735b9f66c 100644 --- a/proxy/src/console/provider.rs +++ b/proxy/src/console/provider.rs @@ -229,7 +229,7 @@ pub struct NodeInfo { pub config: compute::ConnCfg, /// Labels for proxy's metrics. - pub aux: Arc, + pub aux: MetricsAuxInfo, /// Whether we should accept self-signed certificates (for testing) pub allow_self_signed_compute: bool, diff --git a/proxy/src/console/provider/neon.rs b/proxy/src/console/provider/neon.rs index 117d0ec190..7828a7d7e4 100644 --- a/proxy/src/console/provider/neon.rs +++ b/proxy/src/console/provider/neon.rs @@ -144,7 +144,7 @@ impl Api { let node = NodeInfo { config, - aux: body.aux.into(), + aux: body.aux, allow_self_signed_compute: false, }; diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index c4bea13f7f..36d01f9acc 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -877,11 +877,11 @@ async fn prepare_client_connection( pub async fn proxy_pass( client: impl AsyncRead + AsyncWrite + Unpin, compute: impl AsyncRead + AsyncWrite + Unpin, - aux: &MetricsAuxInfo, + aux: MetricsAuxInfo, ) -> anyhow::Result<()> { let usage = USAGE_METRICS.register(Ids { - endpoint_id: aux.endpoint_id.to_string(), - branch_id: aux.branch_id.to_string(), + endpoint_id: aux.endpoint_id.clone(), + branch_id: aux.branch_id.clone(), }); let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]); @@ -1032,7 +1032,7 @@ impl Client<'_, S> { // immediately after opening the connection. let (stream, read_buf) = stream.into_inner(); node.stream.write_all(&read_buf).await?; - proxy_pass(stream, node.stream, &aux).await + proxy_pass(stream, node.stream, aux).await } } diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 2072cadc3a..ca7a9ad0a0 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -8,6 +8,7 @@ use pbkdf2::{ Params, Pbkdf2, }; use pq_proto::StartupMessageParams; +use smol_str::SmolStr; use std::{collections::HashMap, net::SocketAddr, sync::Arc}; use std::{ fmt, @@ -41,16 +42,16 @@ const MAX_CONNS_PER_ENDPOINT: usize = 20; #[derive(Debug, Clone)] pub struct ConnInfo { - pub username: String, - pub dbname: String, - pub hostname: String, - pub password: String, - pub options: Option, + pub username: SmolStr, + pub dbname: SmolStr, + pub hostname: SmolStr, + pub password: SmolStr, + pub options: Option, } impl ConnInfo { // hm, change to hasher to avoid cloning? - pub fn db_and_user(&self) -> (String, String) { + pub fn db_and_user(&self) -> (SmolStr, SmolStr) { (self.dbname.clone(), self.username.clone()) } } @@ -70,7 +71,7 @@ struct ConnPoolEntry { // Per-endpoint connection pool, (dbname, username) -> DbUserConnPool // Number of open connections is limited by the `max_conns_per_endpoint`. pub struct EndpointConnPool { - pools: HashMap<(String, String), DbUserConnPool>, + pools: HashMap<(SmolStr, SmolStr), DbUserConnPool>, total_conns: usize, } @@ -95,7 +96,7 @@ pub struct GlobalConnPool { // // That should be a fairly conteded map, so return reference to the per-endpoint // pool as early as possible and release the lock. - global_pool: DashMap>>, + global_pool: DashMap>>, /// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each. /// That seems like far too much effort, so we're using a relaxed increment counter instead. @@ -327,7 +328,7 @@ impl GlobalConnPool { Ok(()) } - fn get_or_create_endpoint_pool(&self, endpoint: &String) -> Arc> { + fn get_or_create_endpoint_pool(&self, endpoint: &SmolStr) -> Arc> { // fast path if let Some(pool) = self.global_pool.get(endpoint) { return pool.clone(); @@ -468,7 +469,7 @@ async fn connect_to_compute_once( let (client, mut connection) = config .user(&conn_info.username) - .password(&conn_info.password) + .password(&*conn_info.password) .dbname(&conn_info.dbname) .connect_timeout(timeout) .connect(tokio_postgres::NoTls) @@ -482,8 +483,8 @@ async fn connect_to_compute_once( info!(%conn_info, %session, "new connection"); }); let ids = Ids { - endpoint_id: node_info.aux.endpoint_id.to_string(), - branch_id: node_info.aux.branch_id.to_string(), + endpoint_id: node_info.aux.endpoint_id.clone(), + branch_id: node_info.aux.branch_id.clone(), }; tokio::spawn( diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 25b96668de..6c337a837c 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -182,16 +182,16 @@ fn get_conn_info( for (key, value) in pairs { if key == "options" { - options = Some(value.to_string()); + options = Some(value.into()); break; } } Ok(ConnInfo { - username: username.to_owned(), - dbname: dbname.to_owned(), - hostname: hostname.to_owned(), - password: password.to_owned(), + username: username.into(), + dbname: dbname.into(), + hostname: hostname.into(), + password: password.into(), options, }) } diff --git a/proxy/src/usage_metrics.rs b/proxy/src/usage_metrics.rs index 180b5f7199..789a4c680c 100644 --- a/proxy/src/usage_metrics.rs +++ b/proxy/src/usage_metrics.rs @@ -6,6 +6,7 @@ use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_S use dashmap::{mapref::entry::Entry, DashMap}; use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; +use smol_str::SmolStr; use std::{ convert::Infallible, sync::{ @@ -29,8 +30,8 @@ const DEFAULT_HTTP_REPORTING_TIMEOUT: Duration = Duration::from_secs(60); /// because we enrich the event with project_id in the control-plane endpoint. #[derive(Eq, Hash, PartialEq, Serialize, Deserialize, Debug, Clone)] pub struct Ids { - pub endpoint_id: String, - pub branch_id: String, + pub endpoint_id: SmolStr, + pub branch_id: SmolStr, } #[derive(Debug)] @@ -290,8 +291,8 @@ mod tests { // register a new counter let counter = metrics.register(Ids { - endpoint_id: "e1".to_string(), - branch_id: "b1".to_string(), + endpoint_id: "e1".into(), + branch_id: "b1".into(), }); // the counter should be observed despite 0 egress