transform the deserializer manually

This commit is contained in:
Christian Schwarz
2024-05-13 18:20:46 +02:00
parent fd0b22f5cd
commit eef40ce9e2
3 changed files with 200 additions and 130 deletions

3
Cargo.lock generated
View File

@@ -1014,6 +1014,9 @@ name = "camino"
version = "1.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c59e92b5a388f549b863a7bea62612c09f24c8393560709a54558a9abdfb3b9c"
dependencies = [
"serde",
]
[[package]]
name = "camino-tempfile"

View File

@@ -14,7 +14,7 @@ aws-config.workspace = true
aws-sdk-s3.workspace = true
aws-credential-types.workspace = true
bytes.workspace = true
camino.workspace = true
camino = { workspace = true, features = [ "serde1" ] }
humantime.workspace = true
hyper = { workspace = true, features = ["stream"] }
futures.workspace = true

View File

@@ -27,7 +27,7 @@ use std::{
time::{Duration, SystemTime},
};
use anyhow::{bail, Context};
use anyhow::Context;
use aws_sdk_s3::types::StorageClass;
use camino::{Utf8Path, Utf8PathBuf};
@@ -36,7 +36,6 @@ use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use tokio::sync::Semaphore;
use tokio_util::sync::CancellationToken;
use toml_edit::Item;
use tracing::info;
pub use self::{
@@ -621,58 +620,158 @@ impl Debug for AzureConfig {
}
}
impl RemoteStorageConfig {
pub const DEFAULT_TIMEOUT: Duration = std::time::Duration::from_secs(120);
/// Serde visitor that builds a [`RemoteStorageConfig`] from a map of
/// configuration keys (see its `Visitor` implementation).
struct RemoteStorageConfigDeserializeVisitor;
pub fn from_toml(toml: &toml_edit::Item) -> anyhow::Result<Option<RemoteStorageConfig>> {
let local_path = toml.get("local_path");
let bucket_name = toml.get("bucket_name");
let bucket_region = toml.get("bucket_region");
let container_name = toml.get("container_name");
let container_region = toml.get("container_region");
/// Hand-written map visitor: collects every recognised key into an optional
/// slot, then validates the combination of keys and builds the config.
impl<'de> serde::de::Visitor<'de> for RemoteStorageConfigDeserializeVisitor {
type Value = RemoteStorageConfig;
/// Names the expected input; serde uses this text in its error messages.
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a RemoteStorageConfig")
}
/// Reads all entries of `map` and assembles a [`RemoteStorageConfig`].
///
/// # Errors
///
/// Returns a deserialization error when a key is repeated, a key is not
/// recognised, `upload_storage_class` is not known to the SDK, `timeout`
/// does not parse or is below one second, or the backend-selecting keys
/// are missing or combined inconsistently.
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'de>,
{
// One slot per recognised key; `None` means "not seen yet", which is
// also what the duplicate-key checks below rely on.
let mut local_path: Option<Utf8PathBuf> = None;
let mut bucket_name = None;
let mut bucket_region = None;
let mut prefix_in_bucket = None;
let mut container_name = None;
let mut storage_account = None;
let mut container_region = None;
let mut prefix_in_container = None;
let mut concurrency_limit = None;
let mut max_keys_per_list_response = None;
let mut upload_storage_class: Option<StorageClass> = None;
let mut endpoint = None;
let mut timeout: Option<Duration> = None;
while let Some(key) = map.next_key::<String>()? {
match key.as_str() {
"local_path" => {
if local_path.is_some() {
return Err(serde::de::Error::duplicate_field("local_path"));
}
local_path = Some(map.next_value()?);
}
"bucket_name" => {
if bucket_name.is_some() {
return Err(serde::de::Error::duplicate_field("bucket_name"));
}
bucket_name = Some(map.next_value()?);
}
"bucket_region" => {
if bucket_region.is_some() {
return Err(serde::de::Error::duplicate_field("bucket_region"));
}
bucket_region = Some(map.next_value()?);
}
"prefix_in_bucket" => {
if prefix_in_bucket.is_some() {
return Err(serde::de::Error::duplicate_field("prefix_in_bucket"));
}
prefix_in_bucket = Some(map.next_value()?);
}
"container_name" => {
if container_name.is_some() {
return Err(serde::de::Error::duplicate_field("container_name"));
}
container_name = Some(map.next_value()?);
}
"storage_account" => {
if storage_account.is_some() {
return Err(serde::de::Error::duplicate_field("storage_account"));
}
storage_account = map.next_value()?;
}
"container_region" => {
if container_region.is_some() {
return Err(serde::de::Error::duplicate_field("container_region"));
}
container_region = Some(map.next_value()?);
}
"prefix_in_container" => {
if prefix_in_container.is_some() {
return Err(serde::de::Error::duplicate_field("prefix_in_container"));
}
prefix_in_container = Some(map.next_value()?);
}
"concurrency_limit" => {
if concurrency_limit.is_some() {
return Err(serde::de::Error::duplicate_field("concurrency_limit"));
}
concurrency_limit = Some(map.next_value()?);
}
"max_keys_per_list_response" => {
if max_keys_per_list_response.is_some() {
return Err(serde::de::Error::duplicate_field(
"max_keys_per_list_response",
));
}
max_keys_per_list_response = Some(map.next_value()?);
}
"upload_storage_class" => {
if upload_storage_class.is_some() {
return Err(serde::de::Error::duplicate_field("upload_storage_class"));
}
// Read the raw string first: parsing never fails, unrecognised
// names come back as `StorageClass::Unknown` and are rejected here.
let s = map.next_value::<String>()?;
let v = StorageClass::from_str(&s).expect("infallible");
#[allow(deprecated)]
if matches!(v, StorageClass::Unknown(_)) {
let values = format!("{:?}", StorageClass::values());
return Err(serde::de::Error::invalid_value(
serde::de::Unexpected::Str(&s),
&values.as_str(),
));
}
upload_storage_class = Some(v);
}
"endpoint" => {
if endpoint.is_some() {
return Err(serde::de::Error::duplicate_field("endpoint"));
}
endpoint = Some(map.next_value()?);
}
"timeout" => {
if timeout.is_some() {
return Err(serde::de::Error::duplicate_field("timeout"));
}
// The duration is given as a string and parsed with `humantime`.
let s = map.next_value::<String>()?;
let d = humantime::parse_duration(&s)
.map_err(|e| format!("invalid `timeout`: {e}"))
.map_err(serde::de::Error::custom)?;
timeout = Some(d);
}
// Any key not listed above is rejected rather than ignored.
field => {
return Err(serde::de::Error::custom(format!("unknown field {field:?}")));
}
}
}
// The Azure default concurrency limit applies only when both container
// options are present; every other combination uses the S3 default.
let use_azure = container_name.is_some() && container_region.is_some();
let default_concurrency_limit = if use_azure {
DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT
} else {
DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT
let concurrency_limit = {
let default = if use_azure {
DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT
} else {
DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT
};
concurrency_limit
.unwrap_or(NonZeroUsize::new(default).expect("defaults should be valid"))
};
let concurrency_limit = NonZeroUsize::new(
parse_optional_integer("concurrency_limit", toml)?.unwrap_or(default_concurrency_limit),
)
.context("Failed to parse 'concurrency_limit' as a positive integer")?;
let max_keys_per_list_response =
parse_optional_integer::<i32, _>("max_keys_per_list_response", toml)
.context("Failed to parse 'max_keys_per_list_response' as a positive integer")?
.or(DEFAULT_MAX_KEYS_PER_LIST_RESPONSE);
max_keys_per_list_response.unwrap_or(DEFAULT_MAX_KEYS_PER_LIST_RESPONSE);
let endpoint = toml
.get("endpoint")
.map(|endpoint| parse_toml_string("endpoint", endpoint))
.transpose()?;
let timeout = toml
.get("timeout")
.map(|timeout| {
timeout
.as_str()
.ok_or_else(|| anyhow::Error::msg("timeout was not a string"))
})
.transpose()
.and_then(|timeout| {
timeout
.map(humantime::parse_duration)
.transpose()
.map_err(anyhow::Error::new)
})
.context("parse timeout")?
.unwrap_or(Self::DEFAULT_TIMEOUT);
if timeout < Duration::from_secs(1) {
bail!("timeout was specified as {timeout:?} which is too low");
}
// Fall back to the default timeout, then reject anything below one second.
let timeout = {
let timeout = timeout.unwrap_or(RemoteStorageConfig::DEFAULT_TIMEOUT);
if timeout < Duration::from_secs(1) {
return Err(serde::de::Error::custom(format!(
"timeout was specified as {timeout:?} which is too low"
)));
}
timeout
};
let storage = match (
local_path,
@@ -681,105 +780,73 @@ impl RemoteStorageConfig {
container_name,
container_region,
) {
// no 'local_path' nor 'bucket_name' options are provided, consider this remote storage disabled
(None, None, None, None, None) => return Ok(None),
(_, Some(_), None, ..) => {
bail!("'bucket_region' option is mandatory if 'bucket_name' is given ")
}
(_, None, Some(_), ..) => {
bail!("'bucket_name' option is mandatory if 'bucket_region' is given ")
}
(None, None, None, None, None) => Err(serde::de::Error::custom(
"one or more mandatory fields not specified",
)),
(_, Some(_), None, ..) => Err(serde::de::Error::custom(
"'bucket_region' option is mandatory if 'bucket_name' is given ",
)),
(_, None, Some(_), ..) => Err(serde::de::Error::custom(
"'bucket_name' option is mandatory if 'bucket_region' is given ",
)),
(None, Some(bucket_name), Some(bucket_region), ..) => {
RemoteStorageKind::AwsS3(S3Config {
bucket_name: parse_toml_string("bucket_name", bucket_name)?,
bucket_region: parse_toml_string("bucket_region", bucket_region)?,
prefix_in_bucket: toml
.get("prefix_in_bucket")
.map(|prefix_in_bucket| {
parse_toml_string("prefix_in_bucket", prefix_in_bucket)
})
.transpose()?,
Ok(RemoteStorageKind::AwsS3(S3Config {
bucket_name,
bucket_region,
prefix_in_bucket,
endpoint,
concurrency_limit,
max_keys_per_list_response,
upload_storage_class: toml
.get("upload_storage_class")
.map(|prefix_in_bucket| -> anyhow::Result<_> {
let s = parse_toml_string("upload_storage_class", prefix_in_bucket)?;
let storage_class = StorageClass::from_str(&s).expect("infallible");
#[allow(deprecated)]
if matches!(storage_class, StorageClass::Unknown(_)) {
bail!("Specified storage class unknown to SDK: '{s}'. Allowed values: {:?}", StorageClass::values());
}
Ok(storage_class)
})
.transpose()?,
})
}
(_, _, _, Some(_), None) => {
bail!("'container_name' option is mandatory if 'container_region' is given ")
}
(_, _, _, None, Some(_)) => {
bail!("'container_name' option is mandatory if 'container_region' is given ")
upload_storage_class,
}))
}
// `container_name` is present but `container_region` is missing, so the
// message must name `container_region` as the missing option.
(_, _, _, Some(_), None) => Err(serde::de::Error::custom(
"'container_region' option is mandatory if 'container_name' is given ",
)),
(_, _, _, None, Some(_)) => Err(serde::de::Error::custom(
"'container_name' option is mandatory if 'container_region' is given ",
)),
(None, None, None, Some(container_name), Some(container_region)) => {
RemoteStorageKind::AzureContainer(AzureConfig {
container_name: parse_toml_string("container_name", container_name)?,
storage_account: toml
.get("storage_account")
.map(|storage_account| {
parse_toml_string("storage_account", storage_account)
})
.transpose()?,
container_region: parse_toml_string("container_region", container_region)?,
prefix_in_container: toml
.get("prefix_in_container")
.map(|prefix_in_container| {
parse_toml_string("prefix_in_container", prefix_in_container)
})
.transpose()?,
Ok(RemoteStorageKind::AzureContainer(AzureConfig {
container_name,
storage_account,
container_region,
prefix_in_container,
concurrency_limit,
max_keys_per_list_response,
})
}))
}
(Some(local_path), None, None, None, None) => RemoteStorageKind::LocalFs(
Utf8PathBuf::from(parse_toml_string("local_path", local_path)?),
),
(Some(_), Some(_), ..) => {
bail!("'local_path' and 'bucket_name' are mutually exclusive")
(Some(local_path), None, None, None, None) => {
Ok(RemoteStorageKind::LocalFs(local_path))
}
(Some(_), _, _, Some(_), Some(_)) => {
bail!("local_path and 'container_name' are mutually exclusive")
}
};
(Some(_), Some(_), ..) => Err(serde::de::Error::custom(
"'local_path' and 'bucket_name' are mutually exclusive",
)),
(Some(_), _, _, Some(_), Some(_)) => Err(serde::de::Error::custom(
"local_path and 'container_name' are mutually exclusive",
)),
}?;
Ok(Some(RemoteStorageConfig { storage, timeout }))
Ok(RemoteStorageConfig { storage, timeout })
}
}
// Helper functions to parse a toml Item
fn parse_optional_integer<I, E>(name: &str, item: &toml_edit::Item) -> anyhow::Result<Option<I>>
where
I: TryFrom<i64, Error = E>,
E: std::error::Error + Send + Sync + 'static,
{
let toml_integer = match item.get(name) {
Some(item) => item
.as_integer()
.with_context(|| format!("configure option {name} is not an integer"))?,
None => return Ok(None),
};
I::try_from(toml_integer)
.map(Some)
.with_context(|| format!("configure option {name} is too large"))
/// Deserializes a `RemoteStorageConfig` from a map by handing the input to
/// `RemoteStorageConfigDeserializeVisitor`.
impl<'de> serde::Deserialize<'de> for RemoteStorageConfig {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
// All key handling and validation lives in the visitor's `visit_map`.
deserializer.deserialize_map(RemoteStorageConfigDeserializeVisitor)
}
}
fn parse_toml_string(name: &str, item: &Item) -> anyhow::Result<String> {
let s = item
.as_str()
.with_context(|| format!("configure option {name} is not a string"))?;
Ok(s.to_string())
impl RemoteStorageConfig {
    /// Timeout used when the configuration does not specify `timeout`.
    pub const DEFAULT_TIMEOUT: Duration = std::time::Duration::from_secs(120);

    /// Parses the remote storage section of a TOML configuration.
    ///
    /// Returns `Ok(None)` when none of the backend-selecting options
    /// (`local_path`, `bucket_name`, `bucket_region`, `container_name`,
    /// `container_region`) is present, which means remote storage is
    /// disabled. Otherwise the item is rendered to a string and run through
    /// the `Deserialize` implementation of `RemoteStorageConfig`.
    ///
    /// # Errors
    ///
    /// Returns an error when the section selects a backend but fails
    /// deserialization (unknown key, bad value, inconsistent key combination).
    pub fn from_toml(toml: &toml_edit::Item) -> anyhow::Result<Option<RemoteStorageConfig>> {
        // Without this check the deserializer reports "one or more mandatory
        // fields not specified" and this function could never return `None`,
        // although its signature promises callers that outcome.
        const BACKEND_KEYS: [&str; 5] = [
            "local_path",
            "bucket_name",
            "bucket_region",
            "container_name",
            "container_region",
        ];
        if BACKEND_KEYS.iter().all(|&key| toml.get(key).is_none()) {
            return Ok(None);
        }

        let rendered = toml.to_string();
        let config: RemoteStorageConfig = toml_edit::de::from_str(&rendered)?;
        Ok(Some(config))
    }
}
struct ConcurrencyLimiter {