mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-17 10:22:56 +00:00
proxy: chore: replace strings with SmolStr (#5786)
## Problem no problem ## Summary of changes replaces boxstr with arcstr as it's cheaper to clone. mild perf improvement. probably should look into other smallstring optimsations tbh, they will likely be even better. The longest endpoint name I was able to construct is something like `ep-weathered-wildflower-12345678` which is 32 bytes. Most string optimisations top out at 23 bytes
This commit is contained in:
10
Cargo.lock
generated
10
Cargo.lock
generated
@@ -3643,6 +3643,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"smol_str",
|
||||
"socket2 0.5.3",
|
||||
"sync_wrapper",
|
||||
"task-local-extensions",
|
||||
@@ -4709,6 +4710,15 @@ version = "1.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9"
|
||||
|
||||
[[package]]
|
||||
name = "smol_str"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "74212e6bbe9a4352329b2f68ba3130c15a3f26fe88ff22dbdc6cdd58fa85e99c"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "socket2"
|
||||
version = "0.4.9"
|
||||
|
||||
@@ -132,6 +132,7 @@ serde_assert = "0.5.0"
|
||||
sha2 = "0.10.2"
|
||||
signal-hook = "0.3"
|
||||
smallvec = "1.11"
|
||||
smol_str = { version = "0.2.0", features = ["serde"] }
|
||||
socket2 = "0.5"
|
||||
strum = "0.24"
|
||||
strum_macros = "0.24"
|
||||
|
||||
@@ -69,6 +69,7 @@ webpki-roots.workspace = true
|
||||
x509-parser.workspace = true
|
||||
native-tls.workspace = true
|
||||
postgres-native-tls.workspace = true
|
||||
smol_str.workspace = true
|
||||
|
||||
workspace_hack.workspace = true
|
||||
tokio-util.workspace = true
|
||||
|
||||
@@ -106,7 +106,7 @@ pub(super) async fn authenticate(
|
||||
reported_auth_ok: true,
|
||||
value: NodeInfo {
|
||||
config,
|
||||
aux: db_info.aux.into(),
|
||||
aux: db_info.aux,
|
||||
allow_self_signed_compute: false, // caller may override
|
||||
},
|
||||
})
|
||||
|
||||
@@ -284,5 +284,5 @@ async fn handle_client(
|
||||
let client = tokio::net::TcpStream::connect(destination).await?;
|
||||
|
||||
let metrics_aux: MetricsAuxInfo = Default::default();
|
||||
proxy::proxy::proxy_pass(tls_stream, client, &metrics_aux).await
|
||||
proxy::proxy::proxy_pass(tls_stream, client, metrics_aux).await
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use serde::Deserialize;
|
||||
use smol_str::SmolStr;
|
||||
use std::fmt;
|
||||
|
||||
/// Generic error response with human-readable description.
|
||||
@@ -88,11 +89,11 @@ impl fmt::Debug for DatabaseInfo {
|
||||
|
||||
/// Various labels for prometheus metrics.
|
||||
/// Also known as `ProxyMetricsAuxInfo` in the console.
|
||||
#[derive(Debug, Deserialize, Default)]
|
||||
#[derive(Debug, Deserialize, Clone, Default)]
|
||||
pub struct MetricsAuxInfo {
|
||||
pub endpoint_id: Box<str>,
|
||||
pub project_id: Box<str>,
|
||||
pub branch_id: Box<str>,
|
||||
pub endpoint_id: SmolStr,
|
||||
pub project_id: SmolStr,
|
||||
pub branch_id: SmolStr,
|
||||
}
|
||||
|
||||
impl MetricsAuxInfo {
|
||||
|
||||
@@ -229,7 +229,7 @@ pub struct NodeInfo {
|
||||
pub config: compute::ConnCfg,
|
||||
|
||||
/// Labels for proxy's metrics.
|
||||
pub aux: Arc<MetricsAuxInfo>,
|
||||
pub aux: MetricsAuxInfo,
|
||||
|
||||
/// Whether we should accept self-signed certificates (for testing)
|
||||
pub allow_self_signed_compute: bool,
|
||||
|
||||
@@ -144,7 +144,7 @@ impl Api {
|
||||
|
||||
let node = NodeInfo {
|
||||
config,
|
||||
aux: body.aux.into(),
|
||||
aux: body.aux,
|
||||
allow_self_signed_compute: false,
|
||||
};
|
||||
|
||||
|
||||
@@ -877,11 +877,11 @@ async fn prepare_client_connection(
|
||||
pub async fn proxy_pass(
|
||||
client: impl AsyncRead + AsyncWrite + Unpin,
|
||||
compute: impl AsyncRead + AsyncWrite + Unpin,
|
||||
aux: &MetricsAuxInfo,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> anyhow::Result<()> {
|
||||
let usage = USAGE_METRICS.register(Ids {
|
||||
endpoint_id: aux.endpoint_id.to_string(),
|
||||
branch_id: aux.branch_id.to_string(),
|
||||
endpoint_id: aux.endpoint_id.clone(),
|
||||
branch_id: aux.branch_id.clone(),
|
||||
});
|
||||
|
||||
let m_sent = NUM_BYTES_PROXIED_COUNTER.with_label_values(&["tx"]);
|
||||
@@ -1032,7 +1032,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> Client<'_, S> {
|
||||
// immediately after opening the connection.
|
||||
let (stream, read_buf) = stream.into_inner();
|
||||
node.stream.write_all(&read_buf).await?;
|
||||
proxy_pass(stream, node.stream, &aux).await
|
||||
proxy_pass(stream, node.stream, aux).await
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ use pbkdf2::{
|
||||
Params, Pbkdf2,
|
||||
};
|
||||
use pq_proto::StartupMessageParams;
|
||||
use smol_str::SmolStr;
|
||||
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
|
||||
use std::{
|
||||
fmt,
|
||||
@@ -41,16 +42,16 @@ const MAX_CONNS_PER_ENDPOINT: usize = 20;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConnInfo {
|
||||
pub username: String,
|
||||
pub dbname: String,
|
||||
pub hostname: String,
|
||||
pub password: String,
|
||||
pub options: Option<String>,
|
||||
pub username: SmolStr,
|
||||
pub dbname: SmolStr,
|
||||
pub hostname: SmolStr,
|
||||
pub password: SmolStr,
|
||||
pub options: Option<SmolStr>,
|
||||
}
|
||||
|
||||
impl ConnInfo {
|
||||
// hm, change to hasher to avoid cloning?
|
||||
pub fn db_and_user(&self) -> (String, String) {
|
||||
pub fn db_and_user(&self) -> (SmolStr, SmolStr) {
|
||||
(self.dbname.clone(), self.username.clone())
|
||||
}
|
||||
}
|
||||
@@ -70,7 +71,7 @@ struct ConnPoolEntry {
|
||||
// Per-endpoint connection pool, (dbname, username) -> DbUserConnPool
|
||||
// Number of open connections is limited by the `max_conns_per_endpoint`.
|
||||
pub struct EndpointConnPool {
|
||||
pools: HashMap<(String, String), DbUserConnPool>,
|
||||
pools: HashMap<(SmolStr, SmolStr), DbUserConnPool>,
|
||||
total_conns: usize,
|
||||
}
|
||||
|
||||
@@ -95,7 +96,7 @@ pub struct GlobalConnPool {
|
||||
//
|
||||
// That should be a fairly conteded map, so return reference to the per-endpoint
|
||||
// pool as early as possible and release the lock.
|
||||
global_pool: DashMap<String, Arc<RwLock<EndpointConnPool>>>,
|
||||
global_pool: DashMap<SmolStr, Arc<RwLock<EndpointConnPool>>>,
|
||||
|
||||
/// [`DashMap::len`] iterates over all inner pools and acquires a read lock on each.
|
||||
/// That seems like far too much effort, so we're using a relaxed increment counter instead.
|
||||
@@ -327,7 +328,7 @@ impl GlobalConnPool {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_or_create_endpoint_pool(&self, endpoint: &String) -> Arc<RwLock<EndpointConnPool>> {
|
||||
fn get_or_create_endpoint_pool(&self, endpoint: &SmolStr) -> Arc<RwLock<EndpointConnPool>> {
|
||||
// fast path
|
||||
if let Some(pool) = self.global_pool.get(endpoint) {
|
||||
return pool.clone();
|
||||
@@ -468,7 +469,7 @@ async fn connect_to_compute_once(
|
||||
|
||||
let (client, mut connection) = config
|
||||
.user(&conn_info.username)
|
||||
.password(&conn_info.password)
|
||||
.password(&*conn_info.password)
|
||||
.dbname(&conn_info.dbname)
|
||||
.connect_timeout(timeout)
|
||||
.connect(tokio_postgres::NoTls)
|
||||
@@ -482,8 +483,8 @@ async fn connect_to_compute_once(
|
||||
info!(%conn_info, %session, "new connection");
|
||||
});
|
||||
let ids = Ids {
|
||||
endpoint_id: node_info.aux.endpoint_id.to_string(),
|
||||
branch_id: node_info.aux.branch_id.to_string(),
|
||||
endpoint_id: node_info.aux.endpoint_id.clone(),
|
||||
branch_id: node_info.aux.branch_id.clone(),
|
||||
};
|
||||
|
||||
tokio::spawn(
|
||||
|
||||
@@ -182,16 +182,16 @@ fn get_conn_info(
|
||||
|
||||
for (key, value) in pairs {
|
||||
if key == "options" {
|
||||
options = Some(value.to_string());
|
||||
options = Some(value.into());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ConnInfo {
|
||||
username: username.to_owned(),
|
||||
dbname: dbname.to_owned(),
|
||||
hostname: hostname.to_owned(),
|
||||
password: password.to_owned(),
|
||||
username: username.into(),
|
||||
dbname: dbname.into(),
|
||||
hostname: hostname.into(),
|
||||
password: password.into(),
|
||||
options,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ use consumption_metrics::{idempotency_key, Event, EventChunk, EventType, CHUNK_S
|
||||
use dashmap::{mapref::entry::Entry, DashMap};
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use smol_str::SmolStr;
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
sync::{
|
||||
@@ -29,8 +30,8 @@ const DEFAULT_HTTP_REPORTING_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
/// because we enrich the event with project_id in the control-plane endpoint.
|
||||
#[derive(Eq, Hash, PartialEq, Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct Ids {
|
||||
pub endpoint_id: String,
|
||||
pub branch_id: String,
|
||||
pub endpoint_id: SmolStr,
|
||||
pub branch_id: SmolStr,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -290,8 +291,8 @@ mod tests {
|
||||
|
||||
// register a new counter
|
||||
let counter = metrics.register(Ids {
|
||||
endpoint_id: "e1".to_string(),
|
||||
branch_id: "b1".to_string(),
|
||||
endpoint_id: "e1".into(),
|
||||
branch_id: "b1".into(),
|
||||
});
|
||||
|
||||
// the counter should be observed despite 0 egress
|
||||
|
||||
Reference in New Issue
Block a user