proxy: chore: replace strings with SmolStr (#5786)

## Problem

No user-facing problem; this is a small performance chore.

## Summary of changes

Replaces `Box<str>` with `SmolStr`, which is cheaper to clone. Mild perf
improvement.

We should probably look into other small-string optimisations too; they will
likely be even better. The longest endpoint name I was able to construct
is something like `ep-weathered-wildflower-12345678`, which is 32 bytes,
while most small-string optimisations top out at 23 bytes inline.
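
Below is a minimal sketch (not part of this diff) of the clone behaviour the change relies on, assuming the `smol_str` 0.2 API added to the workspace:

```rust
use smol_str::SmolStr;

fn main() {
    // Strings up to 23 bytes are stored inline, so cloning them copies a few
    // bytes and never touches the allocator.
    let branch: SmolStr = "br-main-1234".into();
    assert!(!branch.is_heap_allocated());

    // The 32-byte endpoint name above is too long for the inline form and
    // falls back to a shared heap allocation, so cloning it is a
    // reference-count bump rather than a fresh copy of the bytes (which is
    // what cloning a `Box<str>` costs).
    let endpoint: SmolStr = "ep-weathered-wildflower-12345678".into();
    assert!(endpoint.is_heap_allocated());

    // Either way, clone is O(1).
    let _ids = (endpoint.clone(), branch.clone());
}
```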
Author: Conrad Ludgate
Date: 2023-11-30 20:52:30 +00:00
Committed by: GitHub
Parent: b451e75dc6
Commit: f39fca0049
12 changed files with 48 additions and 33 deletions

Cargo.lock (generated)

@@ -3643,6 +3643,7 @@ dependencies = [
"serde",
"serde_json",
"sha2",
+"smol_str",
"socket2 0.5.3",
"sync_wrapper",
"task-local-extensions",
@@ -4709,6 +4710,15 @@ version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9"
+[[package]]
+name = "smol_str"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "74212e6bbe9a4352329b2f68ba3130c15a3f26fe88ff22dbdc6cdd58fa85e99c"
+dependencies = [
+"serde",
+]
[[package]]
name = "socket2"
version = "0.4.9"


@@ -132,6 +132,7 @@ serde_assert = "0.5.0"
sha2 = "0.10.2"
signal-hook = "0.3"
smallvec = "1.11"
+smol_str = { version = "0.2.0", features = ["serde"] }
socket2 = "0.5"
strum = "0.24"
strum_macros = "0.24"


@@ -69,6 +69,7 @@ webpki-roots.workspace = true
x509-parser.workspace = true
native-tls.workspace = true
postgres-native-tls.workspace = true
+smol_str.workspace = true
workspace_hack.workspace = true
tokio-util.workspace = true


@@ -106,7 +106,7 @@ pub(super) async fn authenticate(
reported_auth_ok: true,
value: NodeInfo {
config,
-aux: db_info.aux.into(),
+aux: db_info.aux,
allow_self_signed_compute: false, // caller may override
},
})


@@ -284,5 +284,5 @@ async fn handle_client(
let client = tokio::net::TcpStream::connect(destination).await?;
let metrics_aux: MetricsAuxInfo = Default::default();
-proxy::proxy::proxy_pass(tls_stream, client, &metrics_aux).await
+proxy::proxy::proxy_pass(tls_stream, client, metrics_aux).await
}


@@ -1,4 +1,5 @@
use serde::Deserialize;
+use smol_str::SmolStr;
use std::fmt;
/// Generic error response with human-readable description.
@@ -88,11 +89,11 @@ impl fmt::Debug for DatabaseInfo {
/// Various labels for prometheus metrics.
/// Also known as `ProxyMetricsAuxInfo` in the console.
-#[derive(Debug, Deserialize, Default)]
+#[derive(Debug, Deserialize, Clone, Default)]
pub struct MetricsAuxInfo {
-pub endpoint_id: Box<str>,
-pub project_id: Box<str>,
-pub branch_id: Box<str>,
+pub endpoint_id: SmolStr,
+pub project_id: SmolStr,
+pub branch_id: SmolStr,
}
impl MetricsAuxInfo {


@@ -229,7 +229,7 @@ pub struct NodeInfo {
pub config: compute::ConnCfg,
/// Labels for proxy's metrics.
-pub aux: Arc<MetricsAuxInfo>,
+pub aux: MetricsAuxInfo,
/// Whether we should accept self-signed certificates (for testing)
pub allow_self_signed_compute: bool,


@@ -144,7 +144,7 @@ impl Api {
let node = NodeInfo {
config,
-aux: body.aux.into(),
+aux: body.aux,
allow_self_signed_compute: false,
};


@@ -877,11 +877,11 @@ async fn prepare_client_connection(
pub async fn proxy_pass(
client: impl AsyncRead + AsyncWrite + Unpin,
compute: impl AsyncRead + AsyncWrite + Unpin,
-aux: &MetricsAuxInfo,
+aux: MetricsAuxInfo,
) -> anyhow::Result<()> {
let usage = USAGE_METRICS.register(Ids {
-endpoint_id: aux.endpoint_id.to_string(),
-branch_id: aux.branch_id.to_string(),
+endpoint_id: aux.endpoint_id.clone(),
+branch_id: aux.branch_id.clone(),
});
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]);
@@ -1032,7 +1032,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
// immediately after opening the connection.
let (stream, read_buf) = stream.into_inner();
node.stream.write_all(&read_buf).await?;
-proxy_pass(stream, node.stream, &aux).await
+proxy_pass(stream, node.stream, aux).await
}
}


@@ -8,6 +8,7 @@ use pbkdf2::{
Params, Pbkdf2,
};
use pq_proto::StartupMessageParams;
+use smol_str::SmolStr;
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
use std::{
fmt,
@@ -41,16 +42,16 @@ const MAX_CONNS_PER_ENDPOINT: usize = 20;
#[derive(Debug, Clone)]
pub struct ConnInfo {
-pub username: String,
-pub dbname: String,
-pub hostname: String,
-pub password: String,
-pub options: Option<String>,
+pub username: SmolStr,
+pub dbname: SmolStr,
+pub hostname: SmolStr,
+pub password: SmolStr,
+pub options: Option<SmolStr>,
}
impl ConnInfo {
// hm, change to hasher to avoid cloning?
-pub fn db_and_user(&self) -> (String, String) {
+pub fn db_and_user(&self) -> (SmolStr, SmolStr) {
(self.dbname.clone(), self.username.clone())
}
}
@@ -70,7 +71,7 @@ struct ConnPoolEntry {
// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
// Number of open connections is limited by the `max_conns_per_endpoint`.
pub struct EndpointConnPool {
-pools: HashMap<(String, String), DbUserConnPool>,
+pools: HashMap<(SmolStr, SmolStr), DbUserConnPool>,
total_conns: usize,
}
@@ -95,7 +96,7 @@ pub struct GlobalConnPool {
//
// That should be a fairly conteded map, so return reference to the per-endpoint
// pool as early as possible and release the lock.
-global_pool: DashMap<String, Arc<RwLock<EndpointConnPool>>>,
+global_pool: DashMap<SmolStr, Arc<RwLock<EndpointConnPool>>>,
/// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
/// That seems like far too much effort, so we're using a relaxed increment counter instead.
@@ -327,7 +328,7 @@ impl GlobalConnPool {
Ok(())
}
-fn get_or_create_endpoint_pool(&self, endpoint: &String) -> Arc<RwLock<EndpointConnPool>> {
+fn get_or_create_endpoint_pool(&self, endpoint: &SmolStr) -> Arc<RwLock<EndpointConnPool>> {
// fast path
if let Some(pool) = self.global_pool.get(endpoint) {
return pool.clone();
@@ -468,7 +469,7 @@ async fn connect_to_compute_once(
let (client, mut connection) = config
.user(&conn_info.username)
-.password(&conn_info.password)
+.password(&*conn_info.password)
.dbname(&conn_info.dbname)
.connect_timeout(timeout)
.connect(tokio_postgres::NoTls)
@@ -482,8 +483,8 @@ async fn connect_to_compute_once(
info!(%conn_info, %session, "new connection");
});
let ids = Ids {
-endpoint_id: node_info.aux.endpoint_id.to_string(),
-branch_id: node_info.aux.branch_id.to_string(),
+endpoint_id: node_info.aux.endpoint_id.clone(),
+branch_id: node_info.aux.branch_id.clone(),
};
tokio::spawn(


@@ -182,16 +182,16 @@ fn get_conn_info(
for (key, value) in pairs {
if key == "options" {
-options = Some(value.to_string());
+options = Some(value.into());
break;
}
}
Ok(ConnInfo {
-username: username.to_owned(),
-dbname: dbname.to_owned(),
-hostname: hostname.to_owned(),
-password: password.to_owned(),
+username: username.into(),
+dbname: dbname.into(),
+hostname: hostname.into(),
+password: password.into(),
options,
})
}


@@ -6,6 +6,7 @@ use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_S
use dashmap::{mapref::entry::Entry, DashMap};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
+use smol_str::SmolStr;
use std::{
convert::Infallible,
sync::{
@@ -29,8 +30,8 @@ const DEFAULT_HTTP_REPORTING_TIMEOUT: Duration = Duration::from_secs(60);
/// because we enrich the event with project_id in the control-plane endpoint.
#[derive(Eq, Hash, PartialEq, Serialize, Deserialize, Debug, Clone)]
pub struct Ids {
-pub endpoint_id: String,
-pub branch_id: String,
+pub endpoint_id: SmolStr,
+pub branch_id: SmolStr,
}
#[derive(Debug)]
@@ -290,8 +291,8 @@ mod tests {
// register a new counter
let counter = metrics.register(Ids {
-endpoint_id: "e1".to_string(),
-branch_id: "b1".to_string(),
+endpoint_id: "e1".into(),
+branch_id: "b1".into(),
});
// the counter should be observed despite 0 egress