1use std::sync::Arc;
16
17use common_base::secrets::SecretString;
18use digest::Digest;
19use pbkdf2::pbkdf2_hmac;
20use sha1::Sha1;
21use sha2::Sha256;
22use snafu::{OptionExt, ensure};
23
24use crate::error::{IllegalParamSnafu, InvalidConfigSnafu, Result, UserPasswordMismatchSnafu};
25use crate::user_info::DefaultUserInfo;
26use crate::user_provider::static_user_provider::{STATIC_USER_PROVIDER, StaticUserProvider};
27use crate::user_provider::watch_file_user_provider::{
28 WATCH_FILE_USER_PROVIDER, WatchFileUserProvider,
29};
30use crate::{UserInfoRef, UserProviderRef};
31
32pub(crate) const DEFAULT_USERNAME: &str = "greptime";
33pub const DEFAULT_PBKDF2_SHA256_ITERATIONS: u32 = 4096;
34pub const PBKDF2_SHA256_HASH_LEN: usize = 32;
35pub const MAX_PBKDF2_SHA256_ITERATIONS: u32 = 1_000_000;
36pub const MAX_PBKDF2_SHA256_SALT_LEN: usize = 1024;
37
38pub fn userinfo_by_name(username: Option<String>) -> UserInfoRef {
41 DefaultUserInfo::with_name(username.unwrap_or_else(|| DEFAULT_USERNAME.to_string()))
42}
43
44pub fn user_provider_from_option(opt: &str) -> Result<UserProviderRef> {
45 let (name, content) = opt.split_once(':').with_context(|| InvalidConfigSnafu {
46 value: opt.to_string(),
47 msg: "UserProviderOption must be in format `<option>:<value>`",
48 })?;
49 match name {
50 STATIC_USER_PROVIDER => {
51 let provider =
52 StaticUserProvider::new(content).map(|p| Arc::new(p) as UserProviderRef)?;
53 Ok(provider)
54 }
55 WATCH_FILE_USER_PROVIDER => {
56 WatchFileUserProvider::new(content).map(|p| Arc::new(p) as UserProviderRef)
57 }
58 _ => InvalidConfigSnafu {
59 value: name.to_string(),
60 msg: "Invalid UserProviderOption",
61 }
62 .fail(),
63 }
64}
65
66pub fn static_user_provider_from_option(opt: &str) -> Result<StaticUserProvider> {
67 let (name, content) = opt.split_once(':').with_context(|| InvalidConfigSnafu {
68 value: opt.to_string(),
69 msg: "UserProviderOption must be in format `<option>:<value>`",
70 })?;
71 match name {
72 STATIC_USER_PROVIDER => {
73 let provider = StaticUserProvider::new(content)?;
74 Ok(provider)
75 }
76 _ => InvalidConfigSnafu {
77 value: name.to_string(),
78 msg: format!("Invalid UserProviderOption, expect only {STATIC_USER_PROVIDER}"),
79 }
80 .fail(),
81 }
82}
83
/// Borrowed user name carried by an [`Identity`].
type Username<'a> = &'a str;
/// Borrowed host name or IP address carried by an [`Identity`].
type HostOrIp<'a> = &'a str;

/// Identity presented for authentication.
#[derive(Debug, Clone)]
pub enum Identity<'a> {
    /// A user name plus an optional host name / IP address.
    /// NOTE(review): presumably the host the client connects from — confirm against callers.
    UserId(Username<'a>, Option<HostOrIp<'a>>),
}
91
/// Borrowed bytes of a hashed/scrambled password as sent by a client.
pub type HashedPassword<'a> = &'a [u8];
/// Borrowed salt bytes that a hashed password was computed with.
pub type Salt<'a> = &'a [u8];

/// Credential supplied for authentication, in one of several encodings.
pub enum Password<'a> {
    /// Clear-text password, kept in a `SecretString`.
    PlainText(SecretString),
    /// MySQL `mysql_native_password` scramble and the salt it was computed with;
    /// verified by `auth_mysql` in this module.
    MysqlNativePassword(HashedPassword<'a>, Salt<'a>),
    /// PostgreSQL MD5 response and salt.
    /// NOTE(review): presumably the client's `md5` auth reply — verification is not in this file.
    PgMD5(HashedPassword<'a>, Salt<'a>),
}
101
102impl Password<'_> {
103 pub fn r#type(&self) -> &str {
104 match self {
105 Password::PlainText(_) => "plain_text",
106 Password::MysqlNativePassword(_, _) => "mysql_native_password",
107 Password::PgMD5(_, _) => "pg_md5",
108 }
109 }
110}
111
112pub fn auth_mysql(
113 auth_data: HashedPassword,
114 salt: Salt,
115 username: &str,
116 save_pwd: &[u8],
117) -> Result<()> {
118 let hash_stage_2 = mysql_native_password_hash(save_pwd);
119 auth_mysql_with_hash_stage_2(auth_data, salt, username, &hash_stage_2)
120}
121
122pub(crate) fn auth_mysql_with_hash_stage_2(
123 auth_data: HashedPassword,
124 salt: Salt,
125 username: &str,
126 hash_stage_2: &[u8],
127) -> Result<()> {
128 ensure!(
129 auth_data.len() == 20,
130 IllegalParamSnafu {
131 msg: "Illegal mysql password length"
132 }
133 );
134 ensure!(
135 hash_stage_2.len() == 20,
136 InvalidConfigSnafu {
137 value: hash_stage_2.len().to_string(),
138 msg: "Illegal mysql native password verifier length",
139 }
140 );
141 let tmp = sha1_two(salt, hash_stage_2);
143 let mut xor_result = [0u8; 20];
145 for i in 0..20 {
146 xor_result[i] = auth_data[i] ^ tmp[i];
147 }
148 let candidate_stage_2 = sha1_one(&xor_result);
149 if candidate_stage_2 == hash_stage_2 {
150 Ok(())
151 } else {
152 UserPasswordMismatchSnafu {
153 username: username.to_string(),
154 }
155 .fail()
156 }
157}
158
159pub fn mysql_native_password_hash(save_pwd: &[u8]) -> Vec<u8> {
160 double_sha1(save_pwd)
161}
162
163pub fn format_mysql_native_password_verifier(password: &[u8]) -> String {
164 format!(
165 "mysql_native_password:{}",
166 hex::encode(mysql_native_password_hash(password))
167 )
168}
169
170pub fn format_pbkdf2_sha256_password_verifier(
171 password: &[u8],
172 salt: &[u8],
173 iterations: u32,
174) -> Result<String> {
175 ensure!(
176 iterations > 0 && iterations <= MAX_PBKDF2_SHA256_ITERATIONS,
177 IllegalParamSnafu {
178 msg: format!(
179 "pbkdf2_sha256 iterations must be in 1..={}",
180 MAX_PBKDF2_SHA256_ITERATIONS
181 )
182 }
183 );
184 ensure!(
185 !salt.is_empty() && salt.len() <= MAX_PBKDF2_SHA256_SALT_LEN,
186 IllegalParamSnafu {
187 msg: format!(
188 "pbkdf2_sha256 salt length must be in 1..={}",
189 MAX_PBKDF2_SHA256_SALT_LEN
190 )
191 }
192 );
193
194 let mut hash = [0u8; PBKDF2_SHA256_HASH_LEN];
195 pbkdf2_hmac::<Sha256>(password, salt, iterations, &mut hash);
196 Ok(format!(
197 "pbkdf2_sha256:{iterations}:{}:{}",
198 hex::encode(salt),
199 hex::encode(hash)
200 ))
201}
202
203fn sha1_two(input_1: &[u8], input_2: &[u8]) -> Vec<u8> {
204 let mut hasher = Sha1::new();
205 hasher.update(input_1);
206 hasher.update(input_2);
207 hasher.finalize().to_vec()
208}
209
210fn sha1_one(data: &[u8]) -> Vec<u8> {
211 let mut hasher = Sha1::new();
212 hasher.update(data);
213 hasher.finalize().to_vec()
214}
215
216fn double_sha1(data: &[u8]) -> Vec<u8> {
217 sha1_one(&sha1_one(data))
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sha() {
        let sha_1_answer: Vec<u8> = vec![
            124, 74, 141, 9, 202, 55, 98, 175, 97, 229, 149, 32, 148, 61, 194, 100, 148, 248, 148,
            27,
        ];
        let sha_1 = sha1_one("123456".as_bytes());
        assert_eq!(sha_1, sha_1_answer);

        let double_sha1_answer: Vec<u8> = vec![
            107, 180, 131, 126, 183, 67, 41, 16, 94, 228, 86, 141, 218, 125, 198, 126, 210, 202,
            42, 217,
        ];
        let double_sha1 = double_sha1("123456".as_bytes());
        assert_eq!(double_sha1, double_sha1_answer);

        let sha1_2_answer: Vec<u8> = vec![
            132, 115, 215, 211, 99, 186, 164, 206, 168, 152, 217, 192, 117, 47, 240, 252, 142, 244,
            37, 204,
        ];
        let sha1_2 = sha1_two("123456".as_bytes(), "654321".as_bytes());
        assert_eq!(sha1_2, sha1_2_answer);
    }

    /// Round-trips a client-side `mysql_native_password` scramble through the
    /// server-side verification, and checks the rejection paths.
    #[test]
    fn test_auth_mysql() {
        let password = b"123456";
        let salt = b"01234567890123456789";

        // Client side: SHA1(password) XOR SHA1(salt ++ SHA1(SHA1(password))).
        let stage_1 = sha1_one(password);
        let stage_2 = sha1_one(&stage_1);
        let mask = sha1_two(salt, &stage_2);
        let auth_data: Vec<u8> = stage_1.iter().zip(&mask).map(|(a, b)| a ^ b).collect();

        assert!(auth_mysql(&auth_data, salt, "greptime", password).is_ok());
        assert!(auth_mysql_with_hash_stage_2(&auth_data, salt, "greptime", &stage_2).is_ok());

        // Wrong stored password.
        assert!(auth_mysql(&auth_data, salt, "greptime", b"654321").is_err());
        // Wrong salt.
        assert!(auth_mysql(&auth_data, b"98765432109876543210", "greptime", password).is_err());
        // Truncated client scramble.
        assert!(auth_mysql(&auth_data[..19], salt, "greptime", password).is_err());
        // Truncated stored verifier.
        assert!(
            auth_mysql_with_hash_stage_2(&auth_data, salt, "greptime", &stage_2[..19]).is_err()
        );
    }

    #[test]
    fn test_format_mysql_native_password_verifier() {
        let verifier = format_mysql_native_password_verifier("123456".as_bytes());
        assert_eq!(
            "mysql_native_password:6bb4837eb74329105ee4568dda7dc67ed2ca2ad9",
            verifier
        );
    }

    #[test]
    fn test_format_pbkdf2_sha256_password_verifier() {
        let verifier =
            format_pbkdf2_sha256_password_verifier("password".as_bytes(), b"salt", 4096).unwrap();
        assert_eq!(
            "pbkdf2_sha256:4096:73616c74:c5e478d59288c841aa530db6845c4c8d962893a001ce4e11a4963873aa98134a",
            verifier
        );

        assert!(format_pbkdf2_sha256_password_verifier(b"password", b"", 4096).is_err());
        assert!(format_pbkdf2_sha256_password_verifier(b"password", b"salt", 0).is_err());
        assert!(
            format_pbkdf2_sha256_password_verifier(
                b"password",
                b"salt",
                MAX_PBKDF2_SHA256_ITERATIONS + 1,
            )
            .is_err()
        );
    }
}