From eef40ce9e2fa4ec6aaf5a08b1f16221aa8c497ac Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Mon, 13 May 2024 18:20:46 +0200 Subject: [PATCH] transform the deserializer manually --- Cargo.lock | 3 + libs/remote_storage/Cargo.toml | 2 +- libs/remote_storage/src/lib.rs | 325 ++++++++++++++++++++------------- 3 files changed, 200 insertions(+), 130 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cf8a0b3286..3696d6ceb5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1014,6 +1014,9 @@ name = "camino" version = "1.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c59e92b5a388f549b863a7bea62612c09f24c8393560709a54558a9abdfb3b9c" +dependencies = [ + "serde", +] [[package]] name = "camino-tempfile" diff --git a/libs/remote_storage/Cargo.toml b/libs/remote_storage/Cargo.toml index 78da01c9a0..5a5a82e066 100644 --- a/libs/remote_storage/Cargo.toml +++ b/libs/remote_storage/Cargo.toml @@ -14,7 +14,7 @@ aws-config.workspace = true aws-sdk-s3.workspace = true aws-credential-types.workspace = true bytes.workspace = true -camino.workspace = true +camino = { workspace = true, features = [ "serde1" ] } humantime.workspace = true hyper = { workspace = true, features = ["stream"] } futures.workspace = true diff --git a/libs/remote_storage/src/lib.rs b/libs/remote_storage/src/lib.rs index 72748e156c..25e7b6a8d1 100644 --- a/libs/remote_storage/src/lib.rs +++ b/libs/remote_storage/src/lib.rs @@ -27,7 +27,7 @@ use std::{ time::{Duration, SystemTime}, }; -use anyhow::{bail, Context}; +use anyhow::Context; use aws_sdk_s3::types::StorageClass; use camino::{Utf8Path, Utf8PathBuf}; @@ -36,7 +36,6 @@ use futures::stream::Stream; use serde::{Deserialize, Serialize}; use tokio::sync::Semaphore; use tokio_util::sync::CancellationToken; -use toml_edit::Item; use tracing::info; pub use self::{ @@ -621,58 +620,158 @@ impl Debug for AzureConfig { } } -impl RemoteStorageConfig { - pub const DEFAULT_TIMEOUT: Duration = std::time::Duration::from_secs(120); +struct RemoteStorageConfigDeserializeVisitor; - pub fn from_toml(toml: &toml_edit::Item) -> anyhow::Result> { - let local_path = toml.get("local_path"); - let bucket_name = toml.get("bucket_name"); - let bucket_region = toml.get("bucket_region"); - let container_name = toml.get("container_name"); - let container_region = toml.get("container_region"); +impl<'de> serde::de::Visitor<'de> for RemoteStorageConfigDeserializeVisitor { + type Value = RemoteStorageConfig; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a RemoteStorageConfig") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut local_path: Option = None; + let mut bucket_name = None; + let mut bucket_region = None; + let mut prefix_in_bucket = None; + let mut container_name = None; + let mut storage_account = None; + let mut container_region = None; + let mut prefix_in_container = None; + let mut concurrency_limit = None; + let mut max_keys_per_list_response = None; + let mut upload_storage_class: Option = None; + let mut endpoint = None; + let mut timeout: Option = None; + while let Some(key) = map.next_key::()? { + match key.as_str() { + "local_path" => { + if local_path.is_some() { + return Err(serde::de::Error::duplicate_field("local_path")); + } + local_path = Some(map.next_value()?); + } + "bucket_name" => { + if bucket_name.is_some() { + return Err(serde::de::Error::duplicate_field("bucket_name")); + } + bucket_name = Some(map.next_value()?); + } + "bucket_region" => { + if bucket_region.is_some() { + return Err(serde::de::Error::duplicate_field("bucket_region")); + } + bucket_region = Some(map.next_value()?); + } + "prefix_in_bucket" => { + if prefix_in_bucket.is_some() { + return Err(serde::de::Error::duplicate_field("prefix_in_bucket")); + } + prefix_in_bucket = Some(map.next_value()?); + } + "container_name" => { + if container_name.is_some() { + return Err(serde::de::Error::duplicate_field("container_name")); + } + container_name = Some(map.next_value()?); + } + "storage_account" => { + if storage_account.is_some() { + return Err(serde::de::Error::duplicate_field("storage_account")); + } + storage_account = map.next_value()?; + } + "container_region" => { + if container_region.is_some() { + return Err(serde::de::Error::duplicate_field("container_region")); + } + container_region = Some(map.next_value()?); + } + "prefix_in_container" => { + if prefix_in_container.is_some() { + return Err(serde::de::Error::duplicate_field("prefix_in_container")); + } + prefix_in_container = Some(map.next_value()?); + } + "concurrency_limit" => { + if concurrency_limit.is_some() { + return Err(serde::de::Error::duplicate_field("concurrency_limit")); + } + concurrency_limit = Some(map.next_value()?); + } + "max_keys_per_list_response" => { + if max_keys_per_list_response.is_some() { + return Err(serde::de::Error::duplicate_field( + "max_keys_per_list_response", + )); + } + max_keys_per_list_response = Some(map.next_value()?); + } + "upload_storage_class" => { + if upload_storage_class.is_some() { + return Err(serde::de::Error::duplicate_field("upload_storage_class")); + } + let s = map.next_value::()?; + let v = StorageClass::from_str(&s).expect("infallible"); + #[allow(deprecated)] + if matches!(v, StorageClass::Unknown(_)) { + let values = format!("{:?}", StorageClass::values()); + return Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Str(&s), + &values.as_str(), + )); + } + upload_storage_class = Some(v); + } + "endpoint" => { + if endpoint.is_some() { + return Err(serde::de::Error::duplicate_field("endpoint")); + } + endpoint = Some(map.next_value()?); + } + "timeout" => { + if timeout.is_some() { + return Err(serde::de::Error::duplicate_field("timeout")); + } + let s = map.next_value::()?; + let d = humantime::parse_duration(&s) + .map_err(|e| format!("invalid `timeout`: {e}")) + .map_err(serde::de::Error::custom)?; + timeout = Some(d); + } + field => { + return Err(serde::de::Error::custom(format!("unknown field {field:?}"))); + } + } + } let use_azure = container_name.is_some() && container_region.is_some(); - let default_concurrency_limit = if use_azure { - DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT - } else { - DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT + let concurrency_limit = { + let default = if use_azure { + DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT + } else { + DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT + }; + concurrency_limit + .unwrap_or(NonZeroUsize::new(default).expect("defaults should be valid")) }; - let concurrency_limit = NonZeroUsize::new( - parse_optional_integer("concurrency_limit", toml)?.unwrap_or(default_concurrency_limit), - ) - .context("Failed to parse 'concurrency_limit' as a positive integer")?; let max_keys_per_list_response = - parse_optional_integer::("max_keys_per_list_response", toml) - .context("Failed to parse 'max_keys_per_list_response' as a positive integer")? - .or(DEFAULT_MAX_KEYS_PER_LIST_RESPONSE); + max_keys_per_list_response.unwrap_or(DEFAULT_MAX_KEYS_PER_LIST_RESPONSE); - let endpoint = toml - .get("endpoint") - .map(|endpoint| parse_toml_string("endpoint", endpoint)) - .transpose()?; - - let timeout = toml - .get("timeout") - .map(|timeout| { - timeout - .as_str() - .ok_or_else(|| anyhow::Error::msg("timeout was not a string")) - }) - .transpose() - .and_then(|timeout| { - timeout - .map(humantime::parse_duration) - .transpose() - .map_err(anyhow::Error::new) - }) - .context("parse timeout")? - .unwrap_or(Self::DEFAULT_TIMEOUT); - - if timeout < Duration::from_secs(1) { - bail!("timeout was specified as {timeout:?} which is too low"); - } + let timeout = { + let timeout = timeout.unwrap_or(RemoteStorageConfig::DEFAULT_TIMEOUT); + if timeout < Duration::from_secs(1) { + return Err(serde::de::Error::custom(format!( + "timeout was specified as {timeout:?} which is too low" + ))); + } + timeout + }; let storage = match ( local_path, @@ -681,105 +780,73 @@ impl RemoteStorageConfig { container_name, container_region, ) { - // no 'local_path' nor 'bucket_name' options are provided, consider this remote storage disabled - (None, None, None, None, None) => return Ok(None), - (_, Some(_), None, ..) => { - bail!("'bucket_region' option is mandatory if 'bucket_name' is given ") - } - (_, None, Some(_), ..) => { - bail!("'bucket_name' option is mandatory if 'bucket_region' is given ") - } + (None, None, None, None, None) => Err(serde::de::Error::custom( + "one or more mandatory fields not specified", + )), + (_, Some(_), None, ..) => Err(serde::de::Error::custom( + "'bucket_region' option is mandatory if 'bucket_name' is given ", + )), + (_, None, Some(_), ..) => Err(serde::de::Error::custom( + "'bucket_name' option is mandatory if 'bucket_region' is given ", + )), (None, Some(bucket_name), Some(bucket_region), ..) => { - RemoteStorageKind::AwsS3(S3Config { - bucket_name: parse_toml_string("bucket_name", bucket_name)?, - bucket_region: parse_toml_string("bucket_region", bucket_region)?, - prefix_in_bucket: toml - .get("prefix_in_bucket") - .map(|prefix_in_bucket| { - parse_toml_string("prefix_in_bucket", prefix_in_bucket) - }) - .transpose()?, + Ok(RemoteStorageKind::AwsS3(S3Config { + bucket_name, + bucket_region, + prefix_in_bucket, endpoint, concurrency_limit, max_keys_per_list_response, - upload_storage_class: toml - .get("upload_storage_class") - .map(|prefix_in_bucket| -> anyhow::Result<_> { - let s = parse_toml_string("upload_storage_class", prefix_in_bucket)?; - let storage_class = StorageClass::from_str(&s).expect("infallible"); - #[allow(deprecated)] - if matches!(storage_class, StorageClass::Unknown(_)) { - bail!("Specified storage class unknown to SDK: '{s}'. Allowed values: {:?}", StorageClass::values()); - } - Ok(storage_class) - }) - .transpose()?, - }) - } - (_, _, _, Some(_), None) => { - bail!("'container_name' option is mandatory if 'container_region' is given ") - } - (_, _, _, None, Some(_)) => { - bail!("'container_name' option is mandatory if 'container_region' is given ") + upload_storage_class, + })) } + (_, _, _, Some(_), None) => Err(serde::de::Error::custom( + "'container_name' option is mandatory if 'container_region' is given ", + )), + (_, _, _, None, Some(_)) => Err(serde::de::Error::custom( + "'container_name' option is mandatory if 'container_region' is given ", + )), (None, None, None, Some(container_name), Some(container_region)) => { - RemoteStorageKind::AzureContainer(AzureConfig { - container_name: parse_toml_string("container_name", container_name)?, - storage_account: toml - .get("storage_account") - .map(|storage_account| { - parse_toml_string("storage_account", storage_account) - }) - .transpose()?, - container_region: parse_toml_string("container_region", container_region)?, - prefix_in_container: toml - .get("prefix_in_container") - .map(|prefix_in_container| { - parse_toml_string("prefix_in_container", prefix_in_container) - }) - .transpose()?, + Ok(RemoteStorageKind::AzureContainer(AzureConfig { + container_name, + storage_account, + container_region, + prefix_in_container, concurrency_limit, max_keys_per_list_response, - }) + })) } - (Some(local_path), None, None, None, None) => RemoteStorageKind::LocalFs( - Utf8PathBuf::from(parse_toml_string("local_path", local_path)?), - ), - (Some(_), Some(_), ..) => { - bail!("'local_path' and 'bucket_name' are mutually exclusive") + (Some(local_path), None, None, None, None) => { + Ok(RemoteStorageKind::LocalFs(local_path)) } - (Some(_), _, _, Some(_), Some(_)) => { - bail!("local_path and 'container_name' are mutually exclusive") - } - }; + (Some(_), Some(_), ..) => Err(serde::de::Error::custom( + "'local_path' and 'bucket_name' are mutually exclusive", + )), + (Some(_), _, _, Some(_), Some(_)) => Err(serde::de::Error::custom( + "local_path and 'container_name' are mutually exclusive", + )), + }?; - Ok(Some(RemoteStorageConfig { storage, timeout })) + Ok(RemoteStorageConfig { storage, timeout }) } } -// Helper functions to parse a toml Item -fn parse_optional_integer(name: &str, item: &toml_edit::Item) -> anyhow::Result> -where - I: TryFrom, - E: std::error::Error + Send + Sync + 'static, -{ - let toml_integer = match item.get(name) { - Some(item) => item - .as_integer() - .with_context(|| format!("configure option {name} is not an integer"))?, - None => return Ok(None), - }; - - I::try_from(toml_integer) - .map(Some) - .with_context(|| format!("configure option {name} is too large")) +impl<'de> serde::Deserialize<'de> for RemoteStorageConfig { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_map(RemoteStorageConfigDeserializeVisitor) + } } -fn parse_toml_string(name: &str, item: &Item) -> anyhow::Result { - let s = item - .as_str() - .with_context(|| format!("configure option {name} is not a string"))?; - Ok(s.to_string()) +impl RemoteStorageConfig { + pub const DEFAULT_TIMEOUT: Duration = std::time::Duration::from_secs(120); + + pub fn from_toml(toml: &toml_edit::Item) -> anyhow::Result> { + let toml = toml.to_string(); + Ok(toml_edit::de::from_str(&toml)?) + } } struct ConcurrencyLimiter {