Skip to main content

auth/
common.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use common_base::secrets::SecretString;
18use digest::Digest;
19use pbkdf2::pbkdf2_hmac;
20use sha1::Sha1;
21use sha2::Sha256;
22use snafu::{OptionExt, ensure};
23
24use crate::error::{IllegalParamSnafu, InvalidConfigSnafu, Result, UserPasswordMismatchSnafu};
25use crate::user_info::DefaultUserInfo;
26use crate::user_provider::static_user_provider::{STATIC_USER_PROVIDER, StaticUserProvider};
27use crate::user_provider::watch_file_user_provider::{
28    WATCH_FILE_USER_PROVIDER, WatchFileUserProvider,
29};
30use crate::{UserInfoRef, UserProviderRef};
31
/// User name used by [`userinfo_by_name`] when no name is supplied.
pub(crate) const DEFAULT_USERNAME: &str = "greptime";

/// Default PBKDF2-HMAC-SHA256 iteration count.
pub const DEFAULT_PBKDF2_SHA256_ITERATIONS: u32 = 4096;
/// Largest iteration count accepted by [`format_pbkdf2_sha256_password_verifier`].
pub const MAX_PBKDF2_SHA256_ITERATIONS: u32 = 1_000_000;

/// Length in bytes of the PBKDF2-HMAC-SHA256 derived key written into verifiers.
pub const PBKDF2_SHA256_HASH_LEN: usize = 32;
/// Largest salt length in bytes accepted by [`format_pbkdf2_sha256_password_verifier`].
pub const MAX_PBKDF2_SHA256_SALT_LEN: usize = 1024;
37
38/// construct a [`UserInfo`](crate::user_info::UserInfo) impl with name
39/// use default username `greptime` if None is provided
40pub fn userinfo_by_name(username: Option<String>) -> UserInfoRef {
41    DefaultUserInfo::with_name(username.unwrap_or_else(|| DEFAULT_USERNAME.to_string()))
42}
43
44pub fn user_provider_from_option(opt: &str) -> Result<UserProviderRef> {
45    let (name, content) = opt.split_once(':').with_context(|| InvalidConfigSnafu {
46        value: opt.to_string(),
47        msg: "UserProviderOption must be in format `<option>:<value>`",
48    })?;
49    match name {
50        STATIC_USER_PROVIDER => {
51            let provider =
52                StaticUserProvider::new(content).map(|p| Arc::new(p) as UserProviderRef)?;
53            Ok(provider)
54        }
55        WATCH_FILE_USER_PROVIDER => {
56            WatchFileUserProvider::new(content).map(|p| Arc::new(p) as UserProviderRef)
57        }
58        _ => InvalidConfigSnafu {
59            value: name.to_string(),
60            msg: "Invalid UserProviderOption",
61        }
62        .fail(),
63    }
64}
65
66pub fn static_user_provider_from_option(opt: &str) -> Result<StaticUserProvider> {
67    let (name, content) = opt.split_once(':').with_context(|| InvalidConfigSnafu {
68        value: opt.to_string(),
69        msg: "UserProviderOption must be in format `<option>:<value>`",
70    })?;
71    match name {
72        STATIC_USER_PROVIDER => {
73            let provider = StaticUserProvider::new(content)?;
74            Ok(provider)
75        }
76        _ => InvalidConfigSnafu {
77            value: name.to_string(),
78            msg: format!("Invalid UserProviderOption, expect only {STATIC_USER_PROVIDER}"),
79        }
80        .fail(),
81    }
82}
83
/// Borrowed user name string.
type Username<'s> = &'s str;
/// Borrowed host name or IP address string.
type HostOrIp<'s> = &'s str;

/// Identity of a user: a name plus an optional host or IP address.
#[derive(Clone, Debug)]
pub enum Identity<'a> {
    /// User name, optionally qualified by a host name or IP address.
    UserId(Username<'a>, Option<HostOrIp<'a>>),
}
91
92pub type HashedPassword<'a> = &'a [u8];
93pub type Salt<'a> = &'a [u8];
94
95/// Authentication information sent by the client.
96pub enum Password<'a> {
97    PlainText(SecretString),
98    MysqlNativePassword(HashedPassword<'a>, Salt<'a>),
99    PgMD5(HashedPassword<'a>, Salt<'a>),
100}
101
102impl Password<'_> {
103    pub fn r#type(&self) -> &str {
104        match self {
105            Password::PlainText(_) => "plain_text",
106            Password::MysqlNativePassword(_, _) => "mysql_native_password",
107            Password::PgMD5(_, _) => "pg_md5",
108        }
109    }
110}
111
112pub fn auth_mysql(
113    auth_data: HashedPassword,
114    salt: Salt,
115    username: &str,
116    save_pwd: &[u8],
117) -> Result<()> {
118    let hash_stage_2 = mysql_native_password_hash(save_pwd);
119    auth_mysql_with_hash_stage_2(auth_data, salt, username, &hash_stage_2)
120}
121
122pub(crate) fn auth_mysql_with_hash_stage_2(
123    auth_data: HashedPassword,
124    salt: Salt,
125    username: &str,
126    hash_stage_2: &[u8],
127) -> Result<()> {
128    ensure!(
129        auth_data.len() == 20,
130        IllegalParamSnafu {
131            msg: "Illegal mysql password length"
132        }
133    );
134    ensure!(
135        hash_stage_2.len() == 20,
136        InvalidConfigSnafu {
137            value: hash_stage_2.len().to_string(),
138            msg: "Illegal mysql native password verifier length",
139        }
140    );
141    // ref: https://github.com/mysql/mysql-server/blob/a246bad76b9271cb4333634e954040a970222e0a/sql/auth/password.cc#L62
142    let tmp = sha1_two(salt, hash_stage_2);
143    // xor auth_data and tmp
144    let mut xor_result = [0u8; 20];
145    for i in 0..20 {
146        xor_result[i] = auth_data[i] ^ tmp[i];
147    }
148    let candidate_stage_2 = sha1_one(&xor_result);
149    if candidate_stage_2 == hash_stage_2 {
150        Ok(())
151    } else {
152        UserPasswordMismatchSnafu {
153            username: username.to_string(),
154        }
155        .fail()
156    }
157}
158
159pub fn mysql_native_password_hash(save_pwd: &[u8]) -> Vec<u8> {
160    double_sha1(save_pwd)
161}
162
163pub fn format_mysql_native_password_verifier(password: &[u8]) -> String {
164    format!(
165        "mysql_native_password:{}",
166        hex::encode(mysql_native_password_hash(password))
167    )
168}
169
170pub fn format_pbkdf2_sha256_password_verifier(
171    password: &[u8],
172    salt: &[u8],
173    iterations: u32,
174) -> Result<String> {
175    ensure!(
176        iterations > 0 && iterations <= MAX_PBKDF2_SHA256_ITERATIONS,
177        IllegalParamSnafu {
178            msg: format!(
179                "pbkdf2_sha256 iterations must be in 1..={}",
180                MAX_PBKDF2_SHA256_ITERATIONS
181            )
182        }
183    );
184    ensure!(
185        !salt.is_empty() && salt.len() <= MAX_PBKDF2_SHA256_SALT_LEN,
186        IllegalParamSnafu {
187            msg: format!(
188                "pbkdf2_sha256 salt length must be in 1..={}",
189                MAX_PBKDF2_SHA256_SALT_LEN
190            )
191        }
192    );
193
194    let mut hash = [0u8; PBKDF2_SHA256_HASH_LEN];
195    pbkdf2_hmac::<Sha256>(password, salt, iterations, &mut hash);
196    Ok(format!(
197        "pbkdf2_sha256:{iterations}:{}:{}",
198        hex::encode(salt),
199        hex::encode(hash)
200    ))
201}
202
203fn sha1_two(input_1: &[u8], input_2: &[u8]) -> Vec<u8> {
204    let mut hasher = Sha1::new();
205    hasher.update(input_1);
206    hasher.update(input_2);
207    hasher.finalize().to_vec()
208}
209
210fn sha1_one(data: &[u8]) -> Vec<u8> {
211    let mut hasher = Sha1::new();
212    hasher.update(data);
213    hasher.finalize().to_vec()
214}
215
216fn double_sha1(data: &[u8]) -> Vec<u8> {
217    sha1_one(&sha1_one(data))
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sha() {
        let sha_1_answer: Vec<u8> = vec![
            124, 74, 141, 9, 202, 55, 98, 175, 97, 229, 149, 32, 148, 61, 194, 100, 148, 248, 148,
            27,
        ];
        let sha_1 = sha1_one("123456".as_bytes());
        assert_eq!(sha_1, sha_1_answer);

        let double_sha1_answer: Vec<u8> = vec![
            107, 180, 131, 126, 183, 67, 41, 16, 94, 228, 86, 141, 218, 125, 198, 126, 210, 202,
            42, 217,
        ];
        let double_sha1 = double_sha1("123456".as_bytes());
        assert_eq!(double_sha1, double_sha1_answer);

        let sha1_2_answer: Vec<u8> = vec![
            132, 115, 215, 211, 99, 186, 164, 206, 168, 152, 217, 192, 117, 47, 240, 252, 142, 244,
            37, 204,
        ];
        let sha1_2 = sha1_two("123456".as_bytes(), "654321".as_bytes());
        assert_eq!(sha1_2, sha1_2_answer);
    }

    #[test]
    fn test_format_mysql_native_password_verifier() {
        let verifier = format_mysql_native_password_verifier("123456".as_bytes());
        assert_eq!(
            "mysql_native_password:6bb4837eb74329105ee4568dda7dc67ed2ca2ad9",
            verifier
        );
    }

    #[test]
    fn test_format_pbkdf2_sha256_password_verifier() {
        let verifier =
            format_pbkdf2_sha256_password_verifier("password".as_bytes(), b"salt", 4096).unwrap();
        assert_eq!(
            "pbkdf2_sha256:4096:73616c74:c5e478d59288c841aa530db6845c4c8d962893a001ce4e11a4963873aa98134a",
            verifier
        );

        assert!(format_pbkdf2_sha256_password_verifier(b"password", b"", 4096).is_err());
        assert!(format_pbkdf2_sha256_password_verifier(b"password", b"salt", 0).is_err());
        assert!(
            format_pbkdf2_sha256_password_verifier(
                b"password",
                b"salt",
                MAX_PBKDF2_SHA256_ITERATIONS + 1,
            )
            .is_err()
        );
    }

    /// Builds the scramble a MySQL client sends for `password` under `salt`:
    /// `SHA1(password) XOR SHA1(salt || SHA1(SHA1(password)))`.
    fn mysql_scramble(password: &[u8], salt: &[u8]) -> Vec<u8> {
        let stage_1 = sha1_one(password);
        let stage_2 = sha1_one(&stage_1);
        let mask = sha1_two(salt, &stage_2);
        stage_1.iter().zip(&mask).map(|(a, b)| a ^ b).collect()
    }

    #[test]
    fn test_auth_mysql() {
        let salt = b"01234567890123456789";
        let auth_data = mysql_scramble(b"123456", salt);

        assert!(auth_mysql(&auth_data, salt, "greptime", b"123456").is_ok());
        assert!(auth_mysql(&auth_data, salt, "greptime", b"654321").is_err());
        // A scramble computed under a different salt must not verify.
        assert!(auth_mysql(&auth_data, b"98765432109876543210", "greptime", b"123456").is_err());
        // A scramble of the wrong length is rejected.
        assert!(auth_mysql(&auth_data[..19], salt, "greptime", b"123456").is_err());

        let hash_stage_2 = mysql_native_password_hash(b"123456");
        assert!(auth_mysql_with_hash_stage_2(&auth_data, salt, "greptime", &hash_stage_2).is_ok());
        // A stored verifier of the wrong length is rejected.
        assert!(
            auth_mysql_with_hash_stage_2(&auth_data, salt, "greptime", &hash_stage_2[..19])
                .is_err()
        );
    }
}