Compare commits

...

2 Commits

Author SHA1 Message Date
Conrad Ludgate
db7b244fdb custom params fmt 2024-02-02 17:02:33 +00:00
Conrad Ludgate
e00127e84b less small allocs for startup params 2024-02-02 16:45:33 +00:00
6 changed files with 129 additions and 31 deletions

1
Cargo.lock generated
View File

@@ -3903,6 +3903,7 @@ dependencies = [
"pin-project-lite",
"postgres-protocol",
"rand 0.8.5",
"smallvec",
"thiserror",
"tokio",
"tracing",

View File

@@ -10,6 +10,7 @@ byteorder.workspace = true
pin-project-lite.workspace = true
postgres-protocol.workspace = true
rand.workspace = true
smallvec.workspace = true
tokio.workspace = true
tracing.workspace = true
thiserror.workspace = true

View File

@@ -7,7 +7,8 @@ pub mod framed;
use byteorder::{BigEndian, ReadBytesExt};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::{borrow::Cow, collections::HashMap, fmt, io, str};
use smallvec::SmallVec;
use std::{borrow::Cow, fmt, io, ops::Range, str};
// re-export for use in utils pageserver_feedback.rs
pub use postgres_protocol::PG_EPOCH;
@@ -49,29 +50,67 @@ pub enum FeStartupPacket {
},
}
#[derive(Debug)]
pub struct StartupMessageParams {
params: HashMap<String, String>,
data: String,
pairs: SmallVec<[Range<u32>; 4]>,
// for easy access
user: Option<Range<u32>>,
database: Option<Range<u32>>,
options: Option<Range<u32>>,
replication: Option<Range<u32>>,
}
impl fmt::Debug for StartupMessageParams {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_map().entries(self.iter()).finish()
}
}
impl StartupMessageParams {
/// Get parameter's value by its name.
pub fn get(&self, name: &str) -> Option<&str> {
self.params.get(name).map(|s| s.as_str())
self.pairs
.iter()
.map(|r| &self.data[r.start as usize..r.end as usize])
.find_map(|pair| pair.strip_prefix(name).and_then(|x| x.strip_prefix('\0')))
}
pub fn user(&self) -> Option<&str> {
self.user
.clone()
.and_then(|r| self.data.get(r.start as usize..r.end as usize))
}
pub fn database(&self) -> Option<&str> {
self.database
.clone()
.and_then(|r| self.data.get(r.start as usize..r.end as usize))
}
pub(crate) fn options_str(&self) -> Option<&str> {
self.options
.clone()
.and_then(|r| self.data.get(r.start as usize..r.end as usize))
}
pub fn replication(&self) -> Option<&str> {
self.replication
.clone()
.and_then(|r| self.data.get(r.start as usize..r.end as usize))
}
/// Split command-line options according to PostgreSQL's logic,
/// taking into account all escape sequences but leaving them as-is.
/// [`None`] means that there's no `options` in [`Self`].
pub fn options_raw(&self) -> Option<impl Iterator<Item = &str>> {
self.get("options").map(Self::parse_options_raw)
self.options_str().map(Self::parse_options_raw)
}
/// Split command-line options according to PostgreSQL's logic,
/// applying all escape sequences (using owned strings as needed).
/// [`None`] means that there's no `options` in [`Self`].
pub fn options_escaped(&self) -> Option<impl Iterator<Item = Cow<'_, str>>> {
self.get("options").map(Self::parse_options_escaped)
self.options_str().map(Self::parse_options_escaped)
}
/// Split command-line options according to PostgreSQL's logic,
@@ -111,15 +150,44 @@ impl StartupMessageParams {
/// Iterate through key-value pairs in an arbitrary order.
pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
self.params.iter().map(|(k, v)| (k.as_str(), v.as_str()))
self.pairs
.iter()
.map(|r| &self.data[r.start as usize..r.end as usize])
.flat_map(|pair| pair.split_once('\0'))
}
// This function is mostly useful in tests.
#[doc(hidden)]
pub fn new<'a, const N: usize>(pairs: [(&'a str, &'a str); N]) -> Self {
Self {
params: pairs.map(|(k, v)| (k.to_owned(), v.to_owned())).into(),
let mut this = Self {
data: Default::default(),
pairs: Default::default(),
user: Default::default(),
database: Default::default(),
options: Default::default(),
replication: Default::default(),
};
for (k, v) in pairs {
let start = this.data.len();
this.data.push_str(k);
this.data.push('\0');
let value_offset = this.data.len();
this.data.push_str(v);
let end = this.data.len();
this.data.push('\0');
let range = start as u32..end as u32;
this.pairs.push(range);
let value_range = value_offset as u32..end as u32;
match k {
"user" => this.user = Some(value_range),
"database" => this.database = Some(value_range),
"options" => this.options = Some(value_range),
"replication" => this.replication = Some(value_range),
_ => {}
}
}
this.data.push('\0');
this
}
}
@@ -346,33 +414,62 @@ impl FeStartupPacket {
// Parse pairs of null-terminated strings (key, value).
// See `postgres: ProcessStartupPacket, build_startup_packet`.
let mut tokens = str::from_utf8(&msg)
let data = str::from_utf8(&msg)
.map_err(|_e| {
ProtocolError::BadMessage("StartupMessage params: invalid utf-8".to_owned())
})?
.strip_suffix('\0') // drop packet's own null
.ok_or_else(|| {
ProtocolError::Protocol(
.to_owned();
let mut params = StartupMessageParams {
data,
pairs: Default::default(),
user: Default::default(),
database: Default::default(),
options: Default::default(),
replication: Default::default(),
};
let mut offset = 0;
let mut rest = params.data.as_str();
loop {
let Some((key, rest1)) = rest.split_once('\0') else {
return Err(ProtocolError::Protocol(
"StartupMessage params: missing null terminator".to_string(),
)
})?
.split_terminator('\0');
));
};
// pairs terminated
if key.is_empty() {
params.data.truncate(offset + 1);
params.data.shrink_to_fit();
break;
}
let Some((value, rest2)) = rest1.split_once('\0') else {
return Err(ProtocolError::Protocol(
"StartupMessage params: missing null terminator".to_string(),
));
};
rest = rest2;
let mut params = HashMap::new();
while let Some(name) = tokens.next() {
let value = tokens.next().ok_or_else(|| {
ProtocolError::Protocol(
"StartupMessage params: key without value".to_string(),
)
})?;
let start = offset;
let value_offset = offset + key.len() + 1;
let end = value_offset + value.len();
offset = end + 1;
params.insert(name.to_owned(), value.to_owned());
params.pairs.push(start as u32..end as u32);
let value_range = value_offset as u32..end as u32;
match key {
"user" => params.user = Some(value_range),
"database" => params.database = Some(value_range),
"options" => params.options = Some(value_range),
"replication" => params.replication = Some(value_range),
_ => {}
}
}
FeStartupPacket::StartupMessage {
major_version,
minor_version,
params: StartupMessageParams { params },
params,
}
}
};

View File

@@ -83,8 +83,7 @@ impl ComputeUserInfoMaybeEndpoint {
use ComputeUserInfoParseError::*;
// Some parameters are stored in the startup message.
let get_param = |key| params.get(key).ok_or(MissingKey(key));
let user: RoleName = get_param("user")?.into();
let user: RoleName = params.user().ok_or(MissingKey("user"))?.into();
// record the values if we have them
ctx.set_application(params.get("application_name").map(SmolStr::from));

View File

@@ -89,13 +89,13 @@ impl ConnCfg {
pub fn set_startup_params(&mut self, params: &StartupMessageParams) {
// Only set `user` if it's not present in the config.
// Link auth flow takes username from the console's response.
if let (None, Some(user)) = (self.get_user(), params.get("user")) {
if let (None, Some(user)) = (self.get_user(), params.user()) {
self.user(user);
}
// Only set `dbname` if it's not present in the config.
// Link auth flow takes dbname from the console's response.
if let (None, Some(dbname)) = (self.get_dbname(), params.get("database")) {
if let (None, Some(dbname)) = (self.get_dbname(), params.database()) {
self.dbname(dbname);
}
@@ -110,7 +110,7 @@ impl ConnCfg {
}
// TODO: This is especially ugly...
if let Some(replication) = params.get("replication") {
if let Some(replication) = params.replication() {
use tokio_postgres::config::ReplicationMode;
match replication {
"true" | "on" | "yes" | "1" => {

View File

@@ -237,7 +237,7 @@ pub async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
{
Ok(auth_result) => auth_result,
Err(e) => {
let db = params.get("database");
let db = params.database();
let app = params.get("application_name");
let params_span = tracing::info_span!("", ?user, ?db, ?app);