Compare commits

...

2 Commits

Author SHA1 Message Date
Conrad Ludgate
db7b244fdb custom params fmt 2024-02-02 17:02:33 +00:00
Conrad Ludgate
e00127e84b less small allocs for startup params 2024-02-02 16:45:33 +00:00
6 changed files with 129 additions and 31 deletions

1
Cargo.lock generated
View File

@@ -3903,6 +3903,7 @@ dependencies = [
"pin-project-lite", "pin-project-lite",
"postgres-protocol", "postgres-protocol",
"rand 0.8.5", "rand 0.8.5",
"smallvec",
"thiserror", "thiserror",
"tokio", "tokio",
"tracing", "tracing",

View File

@@ -10,6 +10,7 @@ byteorder.workspace = true
pin-project-lite.workspace = true pin-project-lite.workspace = true
postgres-protocol.workspace = true postgres-protocol.workspace = true
rand.workspace = true rand.workspace = true
smallvec.workspace = true
tokio.workspace = true tokio.workspace = true
tracing.workspace = true tracing.workspace = true
thiserror.workspace = true thiserror.workspace = true

View File

@@ -7,7 +7,8 @@ pub mod framed;
use byteorder::{BigEndian, ReadBytesExt}; use byteorder::{BigEndian, ReadBytesExt};
use bytes::{Buf, BufMut, Bytes, BytesMut}; use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::{borrow::Cow, collections::HashMap, fmt, io, str}; use smallvec::SmallVec;
use std::{borrow::Cow, fmt, io, ops::Range, str};
// re-export for use in utils pageserver_feedback.rs // re-export for use in utils pageserver_feedback.rs
pub use postgres_protocol::PG_EPOCH; pub use postgres_protocol::PG_EPOCH;
@@ -49,29 +50,67 @@ pub enum FeStartupPacket {
}, },
} }
#[derive(Debug)]
pub struct StartupMessageParams { pub struct StartupMessageParams {
params: HashMap<String, String>, data: String,
pairs: SmallVec<[Range<u32>; 4]>,
// for easy access
user: Option<Range<u32>>,
database: Option<Range<u32>>,
options: Option<Range<u32>>,
replication: Option<Range<u32>>,
}
impl fmt::Debug for StartupMessageParams {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_map().entries(self.iter()).finish()
}
} }
impl StartupMessageParams { impl StartupMessageParams {
/// Get parameter's value by its name. /// Get parameter's value by its name.
pub fn get(&self, name: &str) -> Option<&str> { pub fn get(&self, name: &str) -> Option<&str> {
self.params.get(name).map(|s| s.as_str()) self.pairs
.iter()
.map(|r| &self.data[r.start as usize..r.end as usize])
.find_map(|pair| pair.strip_prefix(name).and_then(|x| x.strip_prefix('\0')))
}
pub fn user(&self) -> Option<&str> {
self.user
.clone()
.and_then(|r| self.data.get(r.start as usize..r.end as usize))
}
pub fn database(&self) -> Option<&str> {
self.database
.clone()
.and_then(|r| self.data.get(r.start as usize..r.end as usize))
}
pub(crate) fn options_str(&self) -> Option<&str> {
self.options
.clone()
.and_then(|r| self.data.get(r.start as usize..r.end as usize))
}
pub fn replication(&self) -> Option<&str> {
self.replication
.clone()
.and_then(|r| self.data.get(r.start as usize..r.end as usize))
} }
/// Split command-line options according to PostgreSQL's logic, /// Split command-line options according to PostgreSQL's logic,
/// taking into account all escape sequences but leaving them as-is. /// taking into account all escape sequences but leaving them as-is.
/// [`None`] means that there's no `options` in [`Self`]. /// [`None`] means that there's no `options` in [`Self`].
pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> { pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
self.get("options").map(Self::parse_options_raw) self.options_str().map(Self::parse_options_raw)
} }
/// Split command-line options according to PostgreSQL's logic, /// Split command-line options according to PostgreSQL's logic,
/// applying all escape sequences (using owned strings as needed). /// applying all escape sequences (using owned strings as needed).
/// [`None`] means that there's no `options` in [`Self`]. /// [`None`] means that there's no `options` in [`Self`].
pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> { pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> {
self.get("options").map(Self::parse_options_escaped) self.options_str().map(Self::parse_options_escaped)
} }
/// Split command-line options according to PostgreSQL's logic, /// Split command-line options according to PostgreSQL's logic,
@@ -111,15 +150,44 @@ impl StartupMessageParams {
/// Iterate through key-value pairs in an arbitrary order. /// Iterate through key-value pairs in an arbitrary order.
pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> { pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
self.params.iter().map(|(k, v)| (k.as_str(), v.as_str())) self.pairs
.iter()
.map(|r| &self.data[r.start as usize..r.end as usize])
.flat_map(|pair| pair.split_once('\0'))
} }
// This function is mostly useful in tests. // This function is mostly useful in tests.
#[doc(hidden)] #[doc(hidden)]
pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self { pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
Self { let mut this = Self {
params: pairs.map(|(k, v)| (k.to_owned(), v.to_owned())).into(), data: Default::default(),
pairs: Default::default(),
user: Default::default(),
database: Default::default(),
options: Default::default(),
replication: Default::default(),
};
for (k, v) in pairs {
let start = this.data.len();
this.data.push_str(k);
this.data.push('\0');
let value_offset = this.data.len();
this.data.push_str(v);
let end = this.data.len();
this.data.push('\0');
let range = start as u32..end as u32;
this.pairs.push(range);
let value_range = value_offset as u32..end as u32;
match k {
"user" => this.user = Some(value_range),
"database" => this.database = Some(value_range),
"options" => this.options = Some(value_range),
"replication" => this.replication = Some(value_range),
_ => {}
}
} }
this.data.push('\0');
this
} }
} }
@@ -346,33 +414,62 @@ impl FeStartupPacket {
// Parse pairs of null-terminated strings (key, value). // Parse pairs of null-terminated strings (key, value).
// See `postgres: ProcessStartupPacket, build_startup_packet`. // See `postgres: ProcessStartupPacket, build_startup_packet`.
let mut tokens = str::from_utf8(&msg) let data = str::from_utf8(&msg)
.map_err(|_e| { .map_err(|_e| {
ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned()) ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
})? })?
.strip_suffix('\0') // drop packet's own null .to_owned();
.ok_or_else(|| {
ProtocolError::Protocol( let mut params = StartupMessageParams {
data,
pairs: Default::default(),
user: Default::default(),
database: Default::default(),
options: Default::default(),
replication: Default::default(),
};
let mut offset = 0;
let mut rest = params.data.as_str();
loop {
let Some((key, rest1)) = rest.split_once('\0') else {
return Err(ProtocolError::Protocol(
"StartupMessage params: missing null terminator".to_string(), "StartupMessage params: missing null terminator".to_string(),
) ));
})? };
.split_terminator('\0'); // pairs terminated
if key.is_empty() {
params.data.truncate(offset + 1);
params.data.shrink_to_fit();
break;
}
let Some((value, rest2)) = rest1.split_once('\0') else {
return Err(ProtocolError::Protocol(
"StartupMessage params: missing null terminator".to_string(),
));
};
rest = rest2;
let mut params = HashMap::new(); let start = offset;
while let Some(name) = tokens.next() { let value_offset = offset + key.len() + 1;
let value = tokens.next().ok_or_else(|| { let end = value_offset + value.len();
ProtocolError::Protocol( offset = end + 1;
"StartupMessage params: key without value".to_string(),
)
})?;
params.insert(name.to_owned(), value.to_owned()); params.pairs.push(start as u32..end as u32);
let value_range = value_offset as u32..end as u32;
match key {
"user" => params.user = Some(value_range),
"database" => params.database = Some(value_range),
"options" => params.options = Some(value_range),
"replication" => params.replication = Some(value_range),
_ => {}
}
} }
FeStartupPacket::StartupMessage { FeStartupPacket::StartupMessage {
major_version, major_version,
minor_version, minor_version,
params: StartupMessageParams { params }, params,
} }
} }
}; };

View File

@@ -83,8 +83,7 @@ impl ComputeUserInfoMaybeEndpoint {
use ComputeUserInfoParseError::*; use ComputeUserInfoParseError::*;
// Some parameters are stored in the startup message. // Some parameters are stored in the startup message.
let get_param = |key| params.get(key).ok_or(MissingKey(key)); let user: RoleName = params.user().ok_or(MissingKey("user"))?.into();
let user: RoleName = get_param("user")?.into();
// record the values if we have them // record the values if we have them
ctx.set_application(params.get("application_name").map(SmolStr::from)); ctx.set_application(params.get("application_name").map(SmolStr::from));

View File

@@ -89,13 +89,13 @@ impl ConnCfg {
pub fn set_startup_params(&mut self, params: &StartupMessageParams) { pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
// Only set `user` if it's not present in the config. // Only set `user` if it's not present in the config.
// Link auth flow takes username from the console's response. // Link auth flow takes username from the console's response.
if let (None, Some(user)) = (self.get_user(), params.get("user")) { if let (None, Some(user)) = (self.get_user(), params.user()) {
self.user(user); self.user(user);
} }
// Only set `dbname` if it's not present in the config. // Only set `dbname` if it's not present in the config.
// Link auth flow takes dbname from the console's response. // Link auth flow takes dbname from the console's response.
if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) { if let (None, Some(dbname)) = (self.get_dbname(), params.database()) {
self.dbname(dbname); self.dbname(dbname);
} }
@@ -110,7 +110,7 @@ impl ConnCfg {
} }
// TODO: This is especially ugly... // TODO: This is especially ugly...
if let Some(replication) = params.get("replication") { if let Some(replication) = params.replication() {
use tokio_postgres::config::ReplicationMode; use tokio_postgres::config::ReplicationMode;
match replication { match replication {
"true" | "on" | "yes" | "1" => { "true" | "on" | "yes" | "1" => {

View File

@@ -237,7 +237,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
{ {
Ok(auth_result) => auth_result, Ok(auth_result) => auth_result,
Err(e) => { Err(e) => {
let db = params.get("database"); let db = params.database();
let app = params.get("application_name"); let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app); let params_span = tracing::info_span!("", ?user, ?db, ?app);