common_memory_manager/
policy.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::time::Duration;
16
17use humantime::{format_duration, parse_duration};
18use serde::{Deserialize, Serialize};
19
/// Timeout used by the default [`OnExhaustedPolicy`] when waiting for memory.
pub const DEFAULT_MEMORY_WAIT_TIMEOUT: Duration = Duration::from_secs(10);

/// What to do when a memory request cannot be satisfied right away.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OnExhaustedPolicy {
    /// Wait for enough memory to be released, bounded by `timeout`.
    Wait { timeout: Duration },

    /// Fail immediately if memory is not available.
    Fail,
}

impl Default for OnExhaustedPolicy {
    /// The default policy is [`OnExhaustedPolicy::Wait`] with
    /// [`DEFAULT_MEMORY_WAIT_TIMEOUT`].
    fn default() -> Self {
        let timeout = DEFAULT_MEMORY_WAIT_TIMEOUT;
        Self::Wait { timeout }
    }
}
40
41impl Serialize for OnExhaustedPolicy {
42    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
43    where
44        S: serde::Serializer,
45    {
46        let text = match self {
47            OnExhaustedPolicy::Fail => "fail".to_string(),
48            OnExhaustedPolicy::Wait { timeout } if *timeout == DEFAULT_MEMORY_WAIT_TIMEOUT => {
49                "wait".to_string()
50            }
51            OnExhaustedPolicy::Wait { timeout } => format!("wait({})", format_duration(*timeout)),
52        };
53        serializer.serialize_str(&text)
54    }
55}
56
57impl<'de> Deserialize<'de> for OnExhaustedPolicy {
58    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
59    where
60        D: serde::Deserializer<'de>,
61    {
62        let raw = String::deserialize(deserializer)?;
63        let lower = raw.to_ascii_lowercase();
64
65        // Accept both "skip" (legacy) and "fail".
66        if lower == "skip" || lower == "fail" {
67            return Ok(OnExhaustedPolicy::Fail);
68        }
69        if lower == "wait" {
70            return Ok(OnExhaustedPolicy::default());
71        }
72        if lower.starts_with("wait(") && lower.ends_with(')') {
73            let inner = &raw[5..raw.len() - 1];
74            let timeout = parse_duration(inner).map_err(serde::de::Error::custom)?;
75            return Ok(OnExhaustedPolicy::Wait { timeout });
76        }
77
78        Err(serde::de::Error::custom(format!(
79            "invalid memory policy: {}, expected wait, wait(<duration>), fail",
80            raw
81        )))
82    }
83}