Compare commits

...

3 Commits

Author           SHA1        Message                      Date
Conrad Ludgate   462713802d  one more                     2024-04-20 15:06:36 +01:00
Conrad Ludgate   0b2e0d8af5  more serde                   2024-04-20 12:16:47 +01:00
Conrad Ludgate   7ac2179aeb  use serde for deser config   2024-04-20 12:10:54 +01:00
4 changed files with 245 additions and 144 deletions
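In short, these commits replace the proxy's hand-rolled key=value option parsers with serde: a small custom Deserializer (SimpleKVConfig) walks the comma-separated pairs, the option structs gain Deserialize impls (derived where possible, hand-written for EndpointCacheConfig), and two new dependencies (humantime-serde, serde_with) convert string values into Durations and numbers. ApiLocks is reworked alongside so that permits=0, meaning the lock is disabled, is represented as Option<ApiLocksInner> rather than a zero field.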

Cargo.lock (generated)

@@ -4314,6 +4314,7 @@ dependencies = [
  "http 1.1.0",
  "http-body-util",
  "humantime",
+ "humantime-serde",
  "hyper 0.14.26",
  "hyper 1.2.0",
  "hyper-tungstenite",
@@ -4355,6 +4356,7 @@ dependencies = [
  "scopeguard",
  "serde",
  "serde_json",
+ "serde_with",
  "sha2",
  "smallvec",
  "smol_str",

Cargo.toml

@@ -35,6 +35,7 @@ hmac.workspace = true
 hostname.workspace = true
 http.workspace = true
 humantime.workspace = true
+humantime-serde.workspace = true
 hyper-tungstenite.workspace = true
 hyper.workspace = true
 hyper1 = { package = "hyper", version = "1.2", features = ["server"] }
@@ -70,6 +71,7 @@ rustls.workspace = true
 scopeguard.workspace = true
 serde.workspace = true
 serde_json.workspace = true
+serde_with.workspace = true
 sha2 = { workspace = true, features = ["asm"] }
 smol_str.workspace = true
 smallvec.workspace = true
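The lockfile and the manifest both gain humantime-serde and serde_with. For a sense of what these two crates contribute, a minimal standalone sketch in a scratch crate (serde with the derive feature, serde_json, serde_with and humantime-serde assumed as dependencies; none of this code is in the diff):

use std::time::Duration;

use serde::Deserialize;
use serde_with::serde_as;

#[serde_as]
#[derive(Debug, Deserialize)]
struct Example {
    // DisplayFromStr: accept "1000" (a string) and run usize::from_str on it.
    #[serde_as(as = "serde_with::DisplayFromStr")]
    size: usize,
    // humantime_serde: accept human-readable durations such as "5m" or "1s".
    #[serde(with = "humantime_serde")]
    ttl: Duration,
}

fn main() {
    let e: Example = serde_json::from_str(r#"{"size": "1000", "ttl": "5m"}"#).unwrap();
    assert_eq!(e.size, 1000);
    assert_eq!(e.ttl, Duration::from_secs(300));
    println!("{e:?}");
}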


@@ -3,13 +3,19 @@ use crate::{
     rate_limiter::RateBucketInfo,
     serverless::GlobalConnPoolOptions,
 };
-use anyhow::{bail, ensure, Context, Ok};
+use anyhow::{ensure, Context};
+use humantime::parse_duration;
 use itertools::Itertools;
 use remote_storage::RemoteStorageConfig;
 use rustls::{
     crypto::ring::sign,
     pki_types::{CertificateDer, PrivateKeyDer},
 };
+use serde::{
+    de::{value::BorrowedStrDeserializer, MapAccess},
+    forward_to_deserialize_any, Deserialize, Deserializer,
+};
+use serde_with::serde_as;
 use sha2::{Digest, Sha256};
 use std::{
     collections::{HashMap, HashSet},
@@ -337,54 +343,97 @@ impl EndpointCacheConfig {
     /// Notice that by default the limiter is empty, which means that cache is disabled.
     pub const CACHE_DEFAULT_OPTIONS: &'static str =
         "initial_batch_size=1000,default_batch_size=10,xread_timeout=5m,stream_name=controlPlane,disable_cache=true,limiter_info=1000@1s,retry_interval=1s";
+}

-    /// Parse cache options passed via cmdline.
-    /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
-    fn parse(options: &str) -> anyhow::Result<Self> {
-        let mut initial_batch_size = None;
-        let mut default_batch_size = None;
-        let mut xread_timeout = None;
-        let mut stream_name = None;
-        let mut limiter_info = vec![];
-        let mut disable_cache = false;
-        let mut retry_interval = None;
-
-        for option in options.split(',') {
-            let (key, value) = option
-                .split_once('=')
-                .with_context(|| format!("bad key-value pair: {option}"))?;
-
-            match key {
-                "initial_batch_size" => initial_batch_size = Some(value.parse()?),
-                "default_batch_size" => default_batch_size = Some(value.parse()?),
-                "xread_timeout" => xread_timeout = Some(humantime::parse_duration(value)?),
-                "stream_name" => stream_name = Some(value.to_string()),
-                "limiter_info" => limiter_info.push(RateBucketInfo::from_str(value)?),
-                "disable_cache" => disable_cache = value.parse()?,
-                "retry_interval" => retry_interval = Some(humantime::parse_duration(value)?),
-                unknown => bail!("unknown key: {unknown}"),
-            }
-        }
-
-        RateBucketInfo::validate(&mut limiter_info)?;
-        Ok(Self {
-            initial_batch_size: initial_batch_size.context("missing `initial_batch_size`")?,
-            default_batch_size: default_batch_size.context("missing `default_batch_size`")?,
-            xread_timeout: xread_timeout.context("missing `xread_timeout`")?,
-            stream_name: stream_name.context("missing `stream_name`")?,
-            disable_cache,
-            limiter_info,
-            retry_interval: retry_interval.context("missing `retry_interval`")?,
-        })
-    }
-}
+impl<'de> serde::Deserialize<'de> for EndpointCacheConfig {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        struct Visitor;
+        impl<'de> serde::de::Visitor<'de> for Visitor {
+            type Value = EndpointCacheConfig;
+
+            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+                f.write_str("struct EndpointCacheConfig")
+            }
+
+            fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
+            where
+                A: serde::de::MapAccess<'de>,
+            {
+                fn e<E: serde::de::Error, T: std::fmt::Display>(t: T) -> E {
+                    E::custom(t)
+                }
+
+                let mut initial_batch_size: Option<usize> = None;
+                let mut default_batch_size: Option<usize> = None;
+                let mut xread_timeout: Option<Duration> = None;
+                let mut stream_name: Option<String> = None;
+                let mut limiter_info: Vec<RateBucketInfo> = vec![];
+                let mut disable_cache: bool = false;
+                let mut retry_interval: Option<Duration> = None;
+
+                while let Some((k, v)) = map.next_entry::<&str, &str>()? {
+                    match k {
+                        "initial_batch_size" => initial_batch_size = Some(v.parse().map_err(e)?),
+                        "default_batch_size" => default_batch_size = Some(v.parse().map_err(e)?),
+                        "xread_timeout" => {
+                            xread_timeout = Some(parse_duration(v).map_err(e)?);
+                        }
+                        "stream_name" => stream_name = Some(v.to_owned()),
+                        "limiter_info" => limiter_info.push(v.parse().map_err(e)?),
+                        "disable_cache" => disable_cache = v.parse().map_err(e)?,
+                        "retry_interval" => retry_interval = Some(parse_duration(v).map_err(e)?),
+                        x => {
+                            return Err(serde::de::Error::unknown_field(
+                                x,
+                                &[
+                                    "initial_batch_size",
+                                    "default_batch_size",
+                                    "xread_timeout",
+                                    "stream_name",
+                                    "limiter_info",
+                                    "disable_cache",
+                                    "retry_interval",
+                                ],
+                            ));
+                        }
+                    }
+                }
+
+                let initial_batch_size = initial_batch_size
+                    .ok_or_else(|| serde::de::Error::missing_field("initial_batch_size"))?;
+                let default_batch_size = default_batch_size
+                    .ok_or_else(|| serde::de::Error::missing_field("default_batch_size"))?;
+                let xread_timeout = xread_timeout
+                    .ok_or_else(|| serde::de::Error::missing_field("xread_timeout"))?;
+                let stream_name =
+                    stream_name.ok_or_else(|| serde::de::Error::missing_field("stream_name"))?;
+                let retry_interval = retry_interval
+                    .ok_or_else(|| serde::de::Error::missing_field("retry_interval"))?;
+
+                RateBucketInfo::validate(&mut limiter_info).map_err(e)?;
+
+                Ok(EndpointCacheConfig {
+                    initial_batch_size,
+                    default_batch_size,
+                    xread_timeout,
+                    stream_name,
+                    limiter_info,
+                    disable_cache,
+                    retry_interval,
+                })
+            }
+        }
+
+        serde::Deserializer::deserialize_map(deserializer, Visitor)
+    }
+}

 impl FromStr for EndpointCacheConfig {
     type Err = anyhow::Error;

     fn from_str(options: &str) -> Result<Self, Self::Err> {
         let error = || format!("failed to parse endpoint cache options '{options}'");
-        Self::parse(options).with_context(error)
+        Self::deserialize(SimpleKVConfig(options)).with_context(error)
     }
 }

 #[derive(Debug)]
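The impl above is hand-written rather than derived because the format allows limiter_info to be given several times, with each occurrence pushed onto a Vec, and lets disable_cache be omitted (defaulting to false). A standalone sketch of the same visitor pattern on a toy struct, using serde_json as the carrier format (illustrative only, none of this is in the diff):

use serde::de::{self, MapAccess, Visitor};

#[derive(Debug)]
struct Toy {
    tags: Vec<String>, // like limiter_info: a repeatable key, accumulated
}

impl<'de> serde::Deserialize<'de> for Toy {
    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
        struct V;
        impl<'de> Visitor<'de> for V {
            type Value = Toy;

            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                f.write_str("struct Toy")
            }

            fn visit_map<A: MapAccess<'de>>(self, mut map: A) -> Result<Toy, A::Error> {
                let mut tags = Vec::new();
                // Keys and values both arrive as borrowed strings, just like the
                // key=value pairs fed in by SimpleKVConfig.
                while let Some((k, v)) = map.next_entry::<&str, &str>()? {
                    match k {
                        "tag" => tags.push(v.to_owned()),
                        x => return Err(de::Error::unknown_field(x, &["tag"])),
                    }
                }
                Ok(Toy { tags })
            }
        }
        d.deserialize_map(V)
    }
}

fn main() {
    // serde_json streams duplicate keys straight to the visitor.
    let t: Toy = serde_json::from_str(r#"{"tag": "a", "tag": "b"}"#).unwrap();
    assert_eq!(t.tags, ["a", "b"]);
}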
@@ -403,11 +452,15 @@ pub fn remote_storage_from_toml(s: &str) -> anyhow::Result<OptRemoteStorageConfi
 }

 /// Helper for cmdline cache options parsing.
-#[derive(Debug)]
+#[serde_as]
+#[derive(Debug, Deserialize)]
 pub struct CacheOptions {
     /// Max number of entries.
+    #[serde_as(as = "serde_with::DisplayFromStr")]
     pub size: usize,
     /// Entry's time-to-live.
+    #[serde(with = "humantime_serde")]
+    #[serde(default)]
     pub ttl: Duration,
 }
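Two things to note about the new CacheOptions. The serde_with::DisplayFromStr adapter is what lets size: usize be populated from a string, since every value produced by the key=value deserializer introduced later in this file (SimpleKVConfig) is a string, so numeric fields must go through FromStr. And #[serde(default)] slightly relaxes the old behavior: the removed parser required ttl unless size=0, whereas a missing ttl now defaults to Duration::ZERO unconditionally.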
@@ -418,30 +471,7 @@ impl CacheOptions {
     /// Parse cache options passed via cmdline.
     /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
     fn parse(options: &str) -> anyhow::Result<Self> {
-        let mut size = None;
-        let mut ttl = None;
-
-        for option in options.split(',') {
-            let (key, value) = option
-                .split_once('=')
-                .with_context(|| format!("bad key-value pair: {option}"))?;
-
-            match key {
-                "size" => size = Some(value.parse()?),
-                "ttl" => ttl = Some(humantime::parse_duration(value)?),
-                unknown => bail!("unknown key: {unknown}"),
-            }
-        }
-
-        // TTL doesn't matter if cache is always empty.
-        if let Some(0) = size {
-            ttl.get_or_insert(Duration::default());
-        }
-
-        Ok(Self {
-            size: size.context("missing `size`")?,
-            ttl: ttl.context("missing `ttl`")?,
-        })
+        Ok(Self::deserialize(SimpleKVConfig(options))?)
     }
 }
@@ -455,15 +485,21 @@ impl FromStr for CacheOptions {
 }

 /// Helper for cmdline cache options parsing.
-#[derive(Debug)]
+#[serde_as]
+#[derive(Debug, Deserialize)]
 pub struct ProjectInfoCacheOptions {
     /// Max number of entries.
+    #[serde_as(as = "serde_with::DisplayFromStr")]
     pub size: usize,
     /// Entry's time-to-live.
+    #[serde(with = "humantime_serde")]
+    #[serde(default)]
     pub ttl: Duration,
     /// Max number of roles per endpoint.
+    #[serde_as(as = "serde_with::DisplayFromStr")]
     pub max_roles: usize,
     /// Gc interval.
+    #[serde(with = "humantime_serde")]
     pub gc_interval: Duration,
 }
@@ -475,36 +511,7 @@ impl ProjectInfoCacheOptions {
     /// Parse cache options passed via cmdline.
     /// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
     fn parse(options: &str) -> anyhow::Result<Self> {
-        let mut size = None;
-        let mut ttl = None;
-        let mut max_roles = None;
-        let mut gc_interval = None;
-
-        for option in options.split(',') {
-            let (key, value) = option
-                .split_once('=')
-                .with_context(|| format!("bad key-value pair: {option}"))?;
-
-            match key {
-                "size" => size = Some(value.parse()?),
-                "ttl" => ttl = Some(humantime::parse_duration(value)?),
-                "max_roles" => max_roles = Some(value.parse()?),
-                "gc_interval" => gc_interval = Some(humantime::parse_duration(value)?),
-                unknown => bail!("unknown key: {unknown}"),
-            }
-        }
-
-        // TTL doesn't matter if cache is always empty.
-        if let Some(0) = size {
-            ttl.get_or_insert(Duration::default());
-        }
-
-        Ok(Self {
-            size: size.context("missing `size`")?,
-            ttl: ttl.context("missing `ttl`")?,
-            max_roles: max_roles.context("missing `max_roles`")?,
-            gc_interval: gc_interval.context("missing `gc_interval`")?,
-        })
+        Ok(Self::deserialize(SimpleKVConfig(options))?)
     }
 }
@@ -518,14 +525,23 @@ impl FromStr for ProjectInfoCacheOptions {
 }

 /// Helper for cmdline cache options parsing.
+#[serde_as]
+#[derive(Deserialize)]
 pub struct WakeComputeLockOptions {
     /// The number of shards the lock map should have
+    #[serde_as(as = "serde_with::DisplayFromStr")]
+    #[serde(default)]
     pub shards: usize,
     /// The number of allowed concurrent requests for each endpoitn
+    #[serde_as(as = "serde_with::DisplayFromStr")]
     pub permits: usize,
     /// Garbage collection epoch
+    #[serde(with = "humantime_serde")]
+    #[serde(default)]
     pub epoch: Duration,
     /// Lock timeout
+    #[serde(with = "humantime_serde")]
+    #[serde(default)]
     pub timeout: Duration,
 }
@@ -538,44 +554,22 @@ impl WakeComputeLockOptions {
     /// Parse lock options passed via cmdline.
     /// Example: [`Self::DEFAULT_OPTIONS_WAKE_COMPUTE_LOCK`].
     fn parse(options: &str) -> anyhow::Result<Self> {
-        let mut shards = None;
-        let mut permits = None;
-        let mut epoch = None;
-        let mut timeout = None;
-
-        for option in options.split(',') {
-            let (key, value) = option
-                .split_once('=')
-                .with_context(|| format!("bad key-value pair: {option}"))?;
-
-            match key {
-                "shards" => shards = Some(value.parse()?),
-                "permits" => permits = Some(value.parse()?),
-                "epoch" => epoch = Some(humantime::parse_duration(value)?),
-                "timeout" => timeout = Some(humantime::parse_duration(value)?),
-                unknown => bail!("unknown key: {unknown}"),
-            }
-        }
-
-        // these dont matter if lock is disabled
-        if let Some(0) = permits {
-            timeout = Some(Duration::default());
-            epoch = Some(Duration::default());
-            shards = Some(2);
-        }
-
-        let out = Self {
-            shards: shards.context("missing `shards`")?,
-            permits: permits.context("missing `permits`")?,
-            epoch: epoch.context("missing `epoch`")?,
-            timeout: timeout.context("missing `timeout`")?,
-        };
-
-        ensure!(out.shards > 1, "shard count must be > 1");
-        ensure!(
-            out.shards.is_power_of_two(),
-            "shard count must be a power of two"
-        );
+        let out = Self::deserialize(SimpleKVConfig(options))?;
+        if out.permits != 0 {
+            ensure!(
+                out.timeout > Duration::ZERO,
+                "wake compute lock timeout should be non-zero"
+            );
+            ensure!(
+                out.epoch > Duration::ZERO,
+                "wake compute lock gc epoch should be non-zero"
+            );
+            ensure!(out.shards > 1, "shard count must be > 1");
+            ensure!(
+                out.shards.is_power_of_two(),
+                "shard count must be a power of two"
+            );
+        }

         Ok(out)
     }
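Behavioral note: permits=0 still means the lock is disabled, but instead of back-filling shards=2 and zeroed durations before building the struct, parse now keeps whatever serde produced (shards, epoch and timeout default to zero when omitted) and skips validation entirely; the non-zero checks on timeout and epoch are new. The test change further down reflects this.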
@@ -590,6 +584,100 @@ impl FromStr for WakeComputeLockOptions {
     }
 }

+struct SimpleKVConfig<'a>(&'a str);
+
+struct SimpleKVConfigMapAccess<'a> {
+    kv: std::str::Split<'a, char>,
+    val: Option<&'a str>,
+}
+
+#[derive(Debug)]
+struct SimpleKVConfigErr(String);
+
+impl std::fmt::Display for SimpleKVConfigErr {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(&self.0)
+    }
+}
+
+impl std::error::Error for SimpleKVConfigErr {}
+
+impl serde::de::Error for SimpleKVConfigErr {
+    fn custom<T>(msg: T) -> Self
+    where
+        T: std::fmt::Display,
+    {
+        Self(msg.to_string())
+    }
+}
+
+impl<'de> MapAccess<'de> for SimpleKVConfigMapAccess<'de> {
+    type Error = SimpleKVConfigErr;
+
+    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
+    where
+        K: serde::de::DeserializeSeed<'de>,
+    {
+        let Some(kv) = self.kv.next() else {
+            return Ok(None);
+        };
+        let (key, value) = kv
+            .split_once('=')
+            .ok_or_else(|| SimpleKVConfigErr("invalid kv pair".to_string()))?;
+        self.val = Some(value);
+        seed.deserialize(BorrowedStrDeserializer::new(key))
+            .map(Some)
+    }
+
+    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::DeserializeSeed<'de>,
+    {
+        seed.deserialize(BorrowedStrDeserializer::new(self.val.take().unwrap()))
+    }
+
+    fn next_entry_seed<K, V>(
+        &mut self,
+        kseed: K,
+        vseed: V,
+    ) -> Result<Option<(K::Value, V::Value)>, Self::Error>
+    where
+        K: serde::de::DeserializeSeed<'de>,
+        V: serde::de::DeserializeSeed<'de>,
+    {
+        let Some(kv) = self.kv.next() else {
+            return Ok(None);
+        };
+        let (key, value) = kv
+            .split_once('=')
+            .ok_or_else(|| SimpleKVConfigErr("invalid kv pair".to_string()))?;
+        let key = kseed.deserialize(BorrowedStrDeserializer::new(key))?;
+        let value = vseed.deserialize(BorrowedStrDeserializer::new(value))?;
+        Ok(Some((key, value)))
+    }
+}
+
+impl<'de> Deserializer<'de> for SimpleKVConfig<'de> {
+    type Error = SimpleKVConfigErr;
+
+    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        visitor.visit_map(SimpleKVConfigMapAccess {
+            kv: self.0.split(','),
+            val: None,
+        })
+    }
+
+    forward_to_deserialize_any! {
+        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
+        bytes byte_buf option unit struct unit_struct newtype_struct seq tuple
+        tuple_struct map enum identifier ignored_any
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
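SimpleKVConfig is deliberately minimal: deserialize_any hands the visitor a MapAccess over the comma-split pairs, and every key and value reaches the target type as a borrowed string, which is why the structs above need the DisplayFromStr and humantime_serde adapters. For comparison, a rough standalone equivalent built on serde's stock serde::de::value::MapDeserializer, a sketch only and not what the diff uses (the hand-rolled version keeps its own error type and splits lazily instead of collecting):

use serde::de::value::MapDeserializer;
use serde::de::Error as _;
use serde::Deserialize;

fn from_kv_str<'a, T: Deserialize<'a>>(s: &'a str) -> Result<T, serde::de::value::Error> {
    // Split "k=v,k=v" into pairs up front so malformed pairs fail early.
    let pairs: Vec<(&str, &str)> = s
        .split(',')
        .map(|kv| {
            kv.split_once('=')
                .ok_or_else(|| serde::de::value::Error::custom("invalid kv pair"))
        })
        .collect::<Result<_, _>>()?;
    // Every key and value is handed to the target type as a borrowed string,
    // exactly as SimpleKVConfig does with BorrowedStrDeserializer.
    T::deserialize(MapDeserializer::new(pairs.into_iter()))
}

#[derive(Debug, Deserialize)]
struct Toy {
    // String field only: a plain str value can't satisfy usize directly,
    // which is why the structs above lean on serde_with::DisplayFromStr.
    stream_name: String,
}

fn main() {
    let t: Toy = from_kv_str("stream_name=controlPlane").unwrap();
    assert_eq!(t.stream_name, "controlPlane");
}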
@@ -647,7 +735,7 @@ mod tests {
         } = "permits=0".parse()?;
         assert_eq!(epoch, Duration::ZERO);
         assert_eq!(timeout, Duration::ZERO);
-        assert_eq!(shards, 2);
+        assert_eq!(shards, 0);
         assert_eq!(permits, 0);

         Ok(())
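The expected shards drops from 2 to 0 because nothing back-fills it any more: with permits=0 the lock is disabled, #[serde(default)] leaves shards at zero, and the shard-count checks are skipped.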


@@ -17,7 +17,7 @@ use crate::{
     scram, EndpointCacheKey,
 };
 use dashmap::DashMap;
-use std::{sync::Arc, time::Duration};
+use std::{num::NonZeroUsize, sync::Arc, time::Duration};
 use tokio::sync::{OwnedSemaphorePermit, Semaphore};
 use tokio::time::Instant;
 use tracing::info;
@@ -449,13 +449,17 @@ impl ApiCaches {
 /// Various caches for [`console`](super).
 pub struct ApiLocks {
     name: &'static str,
-    node_locks: DashMap<EndpointCacheKey, Arc<Semaphore>>,
-    permits: usize,
+    inner: Option<ApiLocksInner>,
     timeout: Duration,
     epoch: std::time::Duration,
     metrics: &'static ApiLockMetrics,
 }

+struct ApiLocksInner {
+    permits: NonZeroUsize,
+    node_locks: DashMap<EndpointCacheKey, Arc<Semaphore>>,
+}
+
 impl ApiLocks {
     pub fn new(
         name: &'static str,
@@ -465,10 +469,14 @@ impl ApiLocks {
         epoch: std::time::Duration,
         metrics: &'static ApiLockMetrics,
     ) -> prometheus::Result<Self> {
+        let inner = NonZeroUsize::new(permits).map(|permits| ApiLocksInner {
+            permits,
+            node_locks: DashMap::with_shard_amount(shards),
+        });
+
         Ok(Self {
             name,
-            node_locks: DashMap::with_shard_amount(shards),
-            permits,
+            inner,
             timeout,
             epoch,
             metrics,
@@ -479,20 +487,21 @@
         &self,
         key: &EndpointCacheKey,
     ) -> Result<WakeComputePermit, errors::WakeComputeError> {
-        if self.permits == 0 {
+        let Some(inner) = &self.inner else {
             return Ok(WakeComputePermit { permit: None });
-        }
+        };
         let now = Instant::now();
         let semaphore = {
             // get fast path
-            if let Some(semaphore) = self.node_locks.get(key) {
+            if let Some(semaphore) = inner.node_locks.get(key) {
                 semaphore.clone()
             } else {
-                self.node_locks
+                inner
+                    .node_locks
                     .entry(key.clone())
                     .or_insert_with(|| {
                         self.metrics.semaphores_registered.inc();
-                        Arc::new(Semaphore::new(self.permits))
+                        Arc::new(Semaphore::new(inner.permits.get()))
                     })
                     .clone()
             }
@@ -509,13 +518,13 @@
     }

     pub async fn garbage_collect_worker(&self) {
-        if self.permits == 0 {
+        let Some(inner) = &self.inner else {
             return;
-        }
+        };
         let mut interval =
-            tokio::time::interval(self.epoch / (self.node_locks.shards().len()) as u32);
+            tokio::time::interval(self.epoch / (inner.node_locks.shards().len()) as u32);
         loop {
-            for (i, shard) in self.node_locks.shards().iter().enumerate() {
+            for (i, shard) in inner.node_locks.shards().iter().enumerate() {
                 interval.tick().await;
                 // temporary lock a single shard and then clear any semaphores that aren't currently checked out
                 // race conditions: if strong_count == 1, there's no way that it can increase while the shard is locked
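The ApiLocks rework mirrors the config change: rather than a permits: usize field where zero means disabled, the disabled state is collapsed into an Option at construction time, so every use site must handle it explicitly with a let-else. A toy sketch of the pattern (types invented here for illustration, not the proxy's):

use std::num::NonZeroUsize;

struct Limiter {
    inner: Option<LimiterInner>, // None == disabled (permits was 0)
}

struct LimiterInner {
    permits: NonZeroUsize,
}

impl Limiter {
    fn new(permits: usize) -> Self {
        // NonZeroUsize::new returns None for 0, collapsing the "disabled"
        // case into the Option at construction time.
        let inner = NonZeroUsize::new(permits).map(|permits| LimiterInner { permits });
        Self { inner }
    }

    fn acquire(&self) -> Option<usize> {
        let inner = self.inner.as_ref()?; // disabled: nothing to do
        Some(inner.permits.get())
    }
}

fn main() {
    assert_eq!(Limiter::new(0).acquire(), None);
    assert_eq!(Limiter::new(8).acquire(), Some(8));
}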