diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index 522b65f5d1..f8e578c6f2 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -50,12 +50,17 @@ pub enum FeStartupPacket { }, } -#[derive(Debug)] +#[derive(Debug, Clone, Default)] pub struct StartupMessageParams { params: HashMap, } impl StartupMessageParams { + /// Set parameter's value by its name. + pub fn insert(&mut self, name: &str, value: &str) { + self.params.insert(name.to_owned(), value.to_owned()); + } + /// Get parameter's value by its name. pub fn get(&self, name: &str) -> Option<&str> { self.params.get(name).map(|s| s.as_str()) diff --git a/proxy/src/auth/backend/link.rs b/proxy/src/auth/backend/link.rs index 415a4b7d85..5932e1337c 100644 --- a/proxy/src/auth/backend/link.rs +++ b/proxy/src/auth/backend/link.rs @@ -100,6 +100,7 @@ pub(super) async fn authenticate( .dbname(&db_info.dbname) .user(&db_info.user); + ctx.set_dbname(db_info.dbname.into()); ctx.set_user(db_info.user.into()); ctx.set_project(db_info.aux.clone()); info!("woken up a compute node"); diff --git a/proxy/src/auth/credentials.rs b/proxy/src/auth/credentials.rs index 783a1a5a21..d06f5614f1 100644 --- a/proxy/src/auth/credentials.rs +++ b/proxy/src/auth/credentials.rs @@ -11,7 +11,6 @@ use crate::{ }; use itertools::Itertools; use pq_proto::StartupMessageParams; -use smol_str::SmolStr; use std::{collections::HashSet, net::IpAddr, str::FromStr}; use thiserror::Error; use tracing::{info, warn}; @@ -96,13 +95,6 @@ impl ComputeUserInfoMaybeEndpoint { let get_param = |key| params.get(key).ok_or(MissingKey(key)); let user: RoleName = get_param("user")?.into(); - // record the values if we have them - ctx.set_application(params.get("application_name").map(SmolStr::from)); - ctx.set_user(user.clone()); - if let Some(dbname) = params.get("database") { - ctx.set_dbname(dbname.into()); - } - // Project name might be passed via PG's command-line options. let endpoint_option = params .options_raw() diff --git a/proxy/src/context.rs b/proxy/src/context.rs index dfd3ef108e..ff79ba8275 100644 --- a/proxy/src/context.rs +++ b/proxy/src/context.rs @@ -2,6 +2,7 @@ use chrono::Utc; use once_cell::sync::OnceCell; +use pq_proto::StartupMessageParams; use smol_str::SmolStr; use std::net::IpAddr; use tokio::sync::mpsc; @@ -46,6 +47,7 @@ pub struct RequestMonitoring { pub(crate) auth_method: Option, success: bool, pub(crate) cold_start_info: ColdStartInfo, + pg_options: Option, // extra // This sender is here to keep the request monitoring channel open while requests are taking place. @@ -102,6 +104,7 @@ impl RequestMonitoring { success: false, rejected: None, cold_start_info: ColdStartInfo::Unknown, + pg_options: None, sender: LOG_CHAN.get().and_then(|tx| tx.upgrade()), disconnect_sender: LOG_CHAN_DISCONNECT.get().and_then(|tx| tx.upgrade()), @@ -132,6 +135,18 @@ impl RequestMonitoring { self.latency_timer.cold_start_info(info); } + pub fn set_db_options(&mut self, options: StartupMessageParams) { + self.set_application(options.get("application_name").map(SmolStr::from)); + if let Some(user) = options.get("user") { + self.set_user(user.into()); + } + if let Some(dbname) = options.get("database") { + self.set_dbname(dbname.into()); + } + + self.pg_options = Some(options); + } + pub fn set_project(&mut self, x: MetricsAuxInfo) { if self.endpoint_id.is_none() { self.set_endpoint_id(x.endpoint_id.as_str().into()) @@ -155,8 +170,10 @@ impl RequestMonitoring { } } - pub fn set_application(&mut self, app: Option) { - self.application = app.or_else(|| self.application.clone()); + fn set_application(&mut self, app: Option) { + if let Some(app) = app { + self.application = Some(app); + } } pub fn set_dbname(&mut self, dbname: DbName) { diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index a213a32ca4..1355b7e1d8 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -13,7 +13,9 @@ use parquet::{ }, record::RecordWriter, }; +use pq_proto::StartupMessageParams; use remote_storage::{GenericRemoteStorage, RemotePath, TimeoutOrCancel}; +use serde::ser::SerializeMap; use tokio::{sync::mpsc, time}; use tokio_util::sync::CancellationToken; use tracing::{debug, info, Span}; @@ -87,6 +89,7 @@ pub struct RequestData { database: Option, project: Option, branch: Option, + pg_options: Option, auth_method: Option<&'static str>, error: Option<&'static str>, /// Success is counted if we form a HTTP response with sql rows inside @@ -101,6 +104,23 @@ pub struct RequestData { disconnect_timestamp: Option, } +struct Options<'a> { + options: &'a StartupMessageParams, +} + +impl<'a> serde::Serialize for Options<'a> { + fn serialize(&self, s: S) -> Result + where + S: serde::Serializer, + { + let mut state = s.serialize_map(None)?; + for (k, v) in self.options.iter() { + state.serialize_entry(k, v)?; + } + state.end() + } +} + impl From<&RequestMonitoring> for RequestData { fn from(value: &RequestMonitoring) -> Self { Self { @@ -113,6 +133,10 @@ impl From<&RequestMonitoring> for RequestData { database: value.dbname.as_deref().map(String::from), project: value.project.as_deref().map(String::from), branch: value.branch.as_deref().map(String::from), + pg_options: value + .pg_options + .as_ref() + .and_then(|options| serde_json::to_string(&Options { options }).ok()), auth_method: value.auth_method.as_ref().map(|x| match x { super::AuthMethod::Web => "web", super::AuthMethod::ScramSha256 => "scram_sha_256", @@ -494,6 +518,7 @@ mod tests { database: Some(hex::encode(rng.gen::<[u8; 16]>())), project: Some(hex::encode(rng.gen::<[u8; 16]>())), branch: Some(hex::encode(rng.gen::<[u8; 16]>())), + pg_options: None, auth_method: None, protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)], region: "us-east-1", @@ -570,15 +595,15 @@ mod tests { assert_eq!( file_stats, [ - (1315314, 3, 6000), - (1315307, 3, 6000), - (1315367, 3, 6000), - (1315324, 3, 6000), - (1315454, 3, 6000), - (1315296, 3, 6000), - (1315088, 3, 6000), - (1315324, 3, 6000), - (438713, 1, 2000) + (1315874, 3, 6000), + (1315867, 3, 6000), + (1315927, 3, 6000), + (1315884, 3, 6000), + (1316014, 3, 6000), + (1315856, 3, 6000), + (1315648, 3, 6000), + (1315884, 3, 6000), + (438913, 1, 2000) ] ); @@ -608,11 +633,11 @@ mod tests { assert_eq!( file_stats, [ - (1222212, 5, 10000), - (1228362, 5, 10000), - (1230156, 5, 10000), - (1229518, 5, 10000), - (1220796, 5, 10000) + (1223214, 5, 10000), + (1229364, 5, 10000), + (1231158, 5, 10000), + (1230520, 5, 10000), + (1221798, 5, 10000) ] ); @@ -644,11 +669,11 @@ mod tests { assert_eq!( file_stats, [ - (1207859, 5, 10000), - (1207590, 5, 10000), - (1207883, 5, 10000), - (1207871, 5, 10000), - (1208126, 5, 10000) + (1208861, 5, 10000), + (1208592, 5, 10000), + (1208885, 5, 10000), + (1208873, 5, 10000), + (1209128, 5, 10000) ] ); @@ -673,15 +698,15 @@ mod tests { assert_eq!( file_stats, [ - (1315314, 3, 6000), - (1315307, 3, 6000), - (1315367, 3, 6000), - (1315324, 3, 6000), - (1315454, 3, 6000), - (1315296, 3, 6000), - (1315088, 3, 6000), - (1315324, 3, 6000), - (438713, 1, 2000) + (1315874, 3, 6000), + (1315867, 3, 6000), + (1315927, 3, 6000), + (1315884, 3, 6000), + (1316014, 3, 6000), + (1315856, 3, 6000), + (1315648, 3, 6000), + (1315884, 3, 6000), + (438913, 1, 2000) ] ); @@ -718,7 +743,7 @@ mod tests { // files are smaller than the size threshold, but they took too long to fill so were flushed early assert_eq!( file_stats, - [(659462, 2, 3001), (659176, 2, 3000), (658972, 2, 2999)] + [(659836, 2, 3001), (659550, 2, 3000), (659346, 2, 2999)] ); tmpdir.close().unwrap(); diff --git a/proxy/src/proxy.rs b/proxy/src/proxy.rs index 5824b70df9..95b46ae002 100644 --- a/proxy/src/proxy.rs +++ b/proxy/src/proxy.rs @@ -267,6 +267,8 @@ pub async fn handle_client( }; drop(pause); + ctx.set_db_options(params.clone()); + let hostname = mode.hostname(stream.get_ref()); let common_names = tls.map(|tls| &tls.common_names); diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 5376bddfd3..9a7cdc8577 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -17,6 +17,7 @@ use hyper1::http::HeaderValue; use hyper1::Response; use hyper1::StatusCode; use hyper1::{HeaderMap, Request}; +use pq_proto::StartupMessageParams; use serde_json::json; use serde_json::Value; use tokio::time; @@ -192,13 +193,13 @@ fn get_conn_info( let mut options = Option::None; + let mut params = StartupMessageParams::default(); + params.insert("user", &username); + params.insert("database", &dbname); for (key, value) in pairs { - match &*key { - "options" => { - options = Some(NeonOptions::parse_options_raw(&value)); - } - "application_name" => ctx.set_application(Some(value.into())), - _ => {} + params.insert(&key, &value); + if key == "options" { + options = Some(NeonOptions::parse_options_raw(&value)); } }