1use ::auth::UserProviderRef;
16use api::v1::Basic;
17use axum::extract::{Request, State};
18use axum::http::{self, StatusCode};
19use axum::middleware::Next;
20use axum::response::{IntoResponse, Response};
21use base64::Engine;
22use base64::prelude::BASE64_STANDARD;
23use common_base::secrets::{ExposeSecret, SecretString};
24use common_catalog::consts::DEFAULT_SCHEMA_NAME;
25use common_catalog::parse_catalog_and_schema_from_db_string;
26use common_error::ext::ErrorExt;
27use common_telemetry::warn;
28use common_time::Timezone;
29use common_time::timezone::parse_timezone;
30use headers::Header;
31use session::context::QueryContextBuilder;
32use snafu::{OptionExt, ResultExt, ensure};
33
34use crate::error::{
35 self, InvalidAuthHeaderInvisibleASCIISnafu, InvalidAuthHeaderSnafu, InvalidParameterSnafu,
36 NotFoundAuthHeaderSnafu, NotFoundInfluxAuthSnafu, Result, UnsupportedAuthSchemeSnafu,
37 UrlDecodeSnafu,
38};
39use crate::http::header::{GREPTIME_TIMEZONE_HEADER_NAME, GreptimeDbName};
40use crate::http::result::error_result::ErrorResponse;
41use crate::http::splunk::is_splunk_request;
42use crate::http::{AUTHORIZATION_HEADER, HTTP_API_PREFIX, PUBLIC_API_PREFIX};
43use crate::influxdb::{is_influxdb_request, is_influxdb_v2_request};
44
45#[derive(Clone)]
48pub struct AuthState {
49 user_provider: Option<UserProviderRef>,
50}
51
52impl AuthState {
53 pub fn new(user_provider: Option<UserProviderRef>) -> Self {
54 Self { user_provider }
55 }
56}
57
58pub async fn inner_auth<B>(
59 user_provider: Option<UserProviderRef>,
60 mut req: Request<B>,
61) -> std::result::Result<Request<B>, Response> {
62 let (catalog, schema) = extract_catalog_and_schema(&req);
64 let timezone = extract_timezone(&req);
66 let query_ctx_builder = QueryContextBuilder::default()
67 .current_catalog(catalog.clone())
68 .current_schema(schema.clone())
69 .timezone(timezone);
70
71 let query_ctx = query_ctx_builder.build();
72 let need_auth = need_auth(&req);
73
74 let user_provider = if let Some(user_provider) = user_provider.filter(|_| need_auth) {
76 user_provider
77 } else {
78 query_ctx.set_current_user(auth::userinfo_by_name(None));
79 let _ = req.extensions_mut().insert(query_ctx);
80 return Ok(req);
81 };
82
83 let (username, password) = match extract_username_and_password(&req) {
85 Ok((username, password)) => (username, password),
86 Err(e) => {
87 warn!(e; "extract username and password failed");
88 crate::metrics::METRIC_AUTH_FAILURE
89 .with_label_values(&[e.status_code().as_ref()])
90 .inc();
91 if is_splunk_request(&req) {
92 let (status, code) = match &e {
94 error::Error::NotFoundAuthHeader { .. } => (StatusCode::UNAUTHORIZED, 2),
95 _ => (StatusCode::FORBIDDEN, 4),
96 };
97 return Err(splunk_hec_err(status, code));
98 }
99 return Err(err_response(e));
100 }
101 };
102
103 match user_provider
105 .auth(
106 auth::Identity::UserId(&username, None),
107 auth::Password::PlainText(password),
108 &catalog,
109 &schema,
110 )
111 .await
112 {
113 Ok(userinfo) => {
114 query_ctx.set_current_user(userinfo);
115 let _ = req.extensions_mut().insert(query_ctx);
116 Ok(req)
117 }
118 Err(e) => {
119 warn!(e; "authenticate failed");
120 crate::metrics::METRIC_AUTH_FAILURE
121 .with_label_values(&[e.status_code().as_ref()])
122 .inc();
123 if is_splunk_request(&req) {
125 return Err(splunk_hec_err(StatusCode::FORBIDDEN, 4));
126 }
127 Err(err_response(e))
128 }
129 }
130}
131
132pub async fn check_http_auth(
133 State(auth_state): State<AuthState>,
134 req: Request,
135 next: Next,
136) -> Response {
137 match inner_auth(auth_state.user_provider, req).await {
138 Ok(req) => next.run(req).await,
139 Err(resp) => resp,
140 }
141}
142
143fn splunk_hec_err(status: StatusCode, code: u32) -> Response {
145 let text = match code {
146 2 => "Token is required",
147 4 => "Invalid token",
148 _ => "Unauthorized",
149 };
150 (
151 status,
152 axum::Json(serde_json::json!({ "text": text, "code": code })),
153 )
154 .into_response()
155}
156
157fn err_response(err: impl ErrorExt) -> Response {
158 (StatusCode::UNAUTHORIZED, ErrorResponse::from_error(err)).into_response()
159}
160
161pub fn extract_catalog_and_schema<B>(request: &Request<B>) -> (String, String) {
162 let dbname = request
164 .headers()
165 .get(GreptimeDbName::name())
166 .and_then(|header| header.to_str().ok())
168 .or_else(|| {
169 let query = request.uri().query().unwrap_or_default();
170 if is_influxdb_v2_request(request) {
171 extract_db_from_query(query).or_else(|| extract_bucket_from_query(query))
172 } else {
173 extract_db_from_query(query)
174 }
175 })
176 .unwrap_or(DEFAULT_SCHEMA_NAME);
177
178 parse_catalog_and_schema_from_db_string(dbname)
179}
180
181fn extract_timezone<B>(request: &Request<B>) -> Timezone {
182 let timezone = request
184 .headers()
185 .get(&GREPTIME_TIMEZONE_HEADER_NAME)
186 .and_then(|header| header.to_str().ok())
188 .unwrap_or("");
189 parse_timezone(Some(timezone))
190}
191
192fn get_influxdb_credentials<B>(request: &Request<B>) -> Result<Option<(Username, Password)>> {
193 if let Some(header) = request.headers().get(http::header::AUTHORIZATION) {
195 let (auth_scheme, credential) = header
197 .to_str()
198 .context(InvalidAuthHeaderInvisibleASCIISnafu)?
199 .split_once(' ')
200 .context(InvalidAuthHeaderSnafu)?;
201
202 let (username, password) = match auth_scheme.to_lowercase().as_str() {
203 "token" => {
204 let (u, p) = credential.split_once(':').context(InvalidAuthHeaderSnafu)?;
205 (u.to_string(), p.to_string().into())
206 }
207 "basic" => decode_basic(credential)?,
208 _ => UnsupportedAuthSchemeSnafu { name: auth_scheme }.fail()?,
209 };
210
211 Ok(Some((username, password)))
212 } else {
213 let Some(query_str) = request.uri().query() else {
215 return Ok(None);
216 };
217
218 let query_str = urlencoding::decode(query_str).context(UrlDecodeSnafu)?;
219
220 match extract_influxdb_user_from_query(&query_str) {
221 (None, None) => Ok(None),
222 (Some(username), Some(password)) => {
223 Ok(Some((username.to_string(), password.to_string().into())))
224 }
225 _ => InvalidParameterSnafu {
226 reason: "influxdb auth: username and password must be provided together"
227 .to_string(),
228 }
229 .fail(),
230 }
231 }
232}
233
234fn get_splunk_credentials<B>(request: &Request<B>) -> Result<Option<(Username, Password)>> {
235 let Some(header) = request.headers().get(http::header::AUTHORIZATION) else {
236 return Ok(None);
237 };
238 let (auth_scheme, credential) = header
239 .to_str()
240 .context(InvalidAuthHeaderInvisibleASCIISnafu)?
241 .split_once(' ')
242 .context(InvalidAuthHeaderSnafu)?;
243
244 let (username, password) = match auth_scheme.to_lowercase().as_str() {
245 "splunk" => {
246 let (u, p) = credential.split_once(':').context(InvalidAuthHeaderSnafu)?;
247 (u.to_string(), p.to_string().into())
248 }
249 "basic" => decode_basic(credential)?,
250 _ => UnsupportedAuthSchemeSnafu { name: auth_scheme }.fail()?,
251 };
252 Ok(Some((username, password)))
253}
254
255pub fn extract_username_and_password<B>(request: &Request<B>) -> Result<(Username, Password)> {
256 Ok(if is_influxdb_request(request) {
257 get_influxdb_credentials(request)?.context(NotFoundInfluxAuthSnafu)?
259 } else if is_splunk_request(request) {
260 get_splunk_credentials(request)?.context(NotFoundAuthHeaderSnafu)?
261 } else {
262 let scheme = auth_header(request)?;
264 match scheme {
265 AuthScheme::Basic(username, password) => (username, password),
266 }
267 })
268}
269
270#[derive(Debug)]
271pub enum AuthScheme {
272 Basic(Username, Password),
273}
274
275type Username = String;
276type Password = SecretString;
277
278impl TryFrom<&str> for AuthScheme {
279 type Error = error::Error;
280
281 fn try_from(value: &str) -> Result<Self> {
282 let (scheme, encoded_credentials) =
283 value.split_once(' ').context(InvalidAuthHeaderSnafu)?;
284
285 ensure!(!encoded_credentials.contains(' '), InvalidAuthHeaderSnafu);
286
287 match scheme.to_lowercase().as_str() {
288 "basic" => decode_basic(encoded_credentials)
289 .map(|(username, password)| AuthScheme::Basic(username, password)),
290 other => UnsupportedAuthSchemeSnafu { name: other }.fail(),
291 }
292 }
293}
294
295impl From<AuthScheme> for api::v1::auth_header::AuthScheme {
296 fn from(value: AuthScheme) -> Self {
297 match value {
298 AuthScheme::Basic(username, password) => {
299 api::v1::auth_header::AuthScheme::Basic(Basic {
300 username,
301 password: password.expose_secret().clone(),
302 })
303 }
304 }
305 }
306}
307
308type Credential<'a> = &'a str;
309
310fn auth_header<B>(req: &Request<B>) -> Result<AuthScheme> {
311 let auth_header = req
312 .headers()
313 .get(AUTHORIZATION_HEADER)
314 .or_else(|| req.headers().get(http::header::AUTHORIZATION))
315 .context(error::NotFoundAuthHeaderSnafu)?
316 .to_str()
317 .context(InvalidAuthHeaderInvisibleASCIISnafu)?;
318
319 auth_header.try_into()
320}
321
322fn decode_basic(credential: Credential) -> Result<(Username, Password)> {
323 let decoded = BASE64_STANDARD
324 .decode(credential)
325 .context(error::InvalidBase64ValueSnafu)?;
326 let as_utf8 =
327 String::from_utf8(decoded).context(error::InvalidAuthHeaderInvalidUtf8ValueSnafu)?;
328
329 if let Some((user_id, password)) = as_utf8.split_once(':') {
330 return Ok((user_id.to_string(), password.to_string().into()));
331 }
332
333 InvalidAuthHeaderSnafu {}.fail()
334}
335
336fn need_auth<B>(req: &Request<B>) -> bool {
337 let path = req.uri().path();
338
339 for api in PUBLIC_API_PREFIX {
340 if path.starts_with(api) {
341 return false;
342 }
343 }
344
345 path.starts_with(HTTP_API_PREFIX)
346}
347
/// Returns the value of the first `param=value` pair in an `&`-separated query string.
///
/// Only the first pair whose text starts with `param=` is considered. If its value is empty,
/// `None` is returned and later pairs are not searched. Values are returned raw, without
/// percent-decoding.
fn extract_param_from_query<'a>(query: &'a str, param: &'a str) -> Option<&'a str> {
    query
        .split('&')
        // Stripping `param` then `=` is equivalent to stripping the prefix `"{param}="`.
        .find_map(|pair| pair.strip_prefix(param)?.strip_prefix('='))
        .filter(|value| !value.is_empty())
}
357
358fn extract_db_from_query(query: &str) -> Option<&str> {
359 extract_param_from_query(query, "db")
360}
361
362fn extract_bucket_from_query(query: &str) -> Option<&str> {
365 extract_param_from_query(query, "bucket")
366}
367
/// Scans an `&`-separated query string for the InfluxDB `u` (username) and `p` (password)
/// parameters.
///
/// Pairs with an empty value are ignored. When a parameter appears several times, the last
/// non-empty occurrence wins. Values are returned raw, without percent-decoding.
fn extract_influxdb_user_from_query(query: &str) -> (Option<&str>, Option<&str>) {
    let initial: (Option<&str>, Option<&str>) = (None, None);
    query.split('&').fold(initial, |(user, pass), pair| {
        if let Some(u) = pair.strip_prefix("u=").filter(|v| !v.is_empty()) {
            (Some(u), pass)
        } else if let Some(p) = pair.strip_prefix("p=").filter(|v| !v.is_empty()) {
            (user, Some(p))
        } else {
            (user, pass)
        }
    })
}
381
// Unit tests for the HTTP auth helpers: path gating, credential extraction (generic, Splunk,
// Basic), auth-header parsing, and catalog/schema/user parsing from headers and query strings.
#[cfg(test)]
mod tests {
    use std::assert_matches;

    use common_base::secrets::ExposeSecret;

    use super::*;

    #[test]
    fn test_need_auth() {
        // `ping` and `health` are expected not to require authentication.
        let req = Request::builder()
            .uri("http://127.0.0.1/v1/influxdb/ping")
            .body(())
            .unwrap();

        assert!(!need_auth(&req));

        let req = Request::builder()
            .uri("http://127.0.0.1/v1/influxdb/health")
            .body(())
            .unwrap();

        assert!(!need_auth(&req));

        // `write` is expected to require authentication.
        let req = Request::builder()
            .uri("http://127.0.0.1/v1/influxdb/write")
            .body(())
            .unwrap();

        assert!(need_auth(&req));
    }

    #[test]
    fn test_splunk_auth() {
        let splunk_uri = "http://127.0.0.1/v1/splunk/services/collector/event";
        // Builds a request to the Splunk collector path, optionally with an Authorization header.
        let splunk_req = |auth: Option<&str>| {
            let mut req = Request::builder().uri(splunk_uri);
            if let Some(auth) = auth {
                req = req.header(http::header::AUTHORIZATION, auth);
            }
            req.body(()).unwrap()
        };

        // The collector path is a Splunk request; InfluxDB and SQL paths are not.
        assert!(is_splunk_request(&splunk_req(None)));
        assert!(!is_splunk_request(
            &Request::builder()
                .uri("http://127.0.0.1/v1/influxdb/write")
                .body(())
                .unwrap()
        ));
        assert!(!is_splunk_request(
            &Request::builder()
                .uri("http://127.0.0.1/v1/sql")
                .body(())
                .unwrap()
        ));

        // `Splunk <user>:<password>` scheme.
        let (username, password) =
            get_splunk_credentials(&splunk_req(Some("Splunk teamA:secretA")))
                .unwrap()
                .unwrap();
        assert_eq!(username, "teamA");
        assert_eq!(password.expose_secret(), "secretA");

        // The Basic scheme is accepted too.
        let basic = basic_auth("u", "p");
        let (username, password) = get_splunk_credentials(&splunk_req(Some(&basic)))
            .unwrap()
            .unwrap();
        assert_eq!(username, "u");
        assert_eq!(password.expose_secret(), "p");

        // No header yields no credentials; a Splunk token without `:` is an error.
        assert!(get_splunk_credentials(&splunk_req(None)).unwrap().is_none());
        assert!(get_splunk_credentials(&splunk_req(Some("Splunk no_colon_token"))).is_err());

        // The generic entry point routes Splunk requests to the Splunk parser.
        let (username, password) =
            extract_username_and_password(&splunk_req(Some("Splunk teamA:secretA"))).unwrap();
        assert_eq!(username, "teamA");
        assert_eq!(password.expose_secret(), "secretA");
    }

    #[test]
    fn test_decode_basic() {
        let credential = basic_auth_credentials("username", "password");
        let (username, pwd) = decode_basic(&credential).unwrap();
        assert_eq!("username", username);
        assert_eq!("password", pwd.expose_secret());

        // Inserting a space into the base64 payload makes it undecodable.
        let wrong_credential = credential.replacen('c', "c ", 1);
        let result = decode_basic(&wrong_credential);
        assert_matches!(result.err(), Some(error::Error::InvalidBase64Value { .. }));
    }

    #[test]
    fn test_try_into_auth_scheme() {
        // A scheme name with no credential part is rejected.
        let auth_scheme_str = "basic";
        let re: Result<AuthScheme> = auth_scheme_str.try_into();
        assert!(re.is_err());

        let auth_scheme_str = basic_auth("test", "test");
        let scheme: AuthScheme = auth_scheme_str.as_str().try_into().unwrap();
        assert_matches!(scheme, AuthScheme::Basic(username, pwd) if username == "test" && pwd.expose_secret() == "test");

        // A bare `digest` value (no credentials) is rejected as well.
        let unsupported = "digest";
        let auth_scheme: Result<AuthScheme> = unsupported.try_into();
        assert!(auth_scheme.is_err());
    }

    #[test]
    fn test_inner_auth_assigns_remote_query_id() {
        // With no user provider the request passes through; the attached QueryContext is
        // expected to carry a remote query id.
        let req =
            mock_http_request(None, Some("http://127.0.0.1/v1/sql?db=greptime-public")).unwrap();
        let req = futures::executor::block_on(inner_auth::<()>(None, req)).unwrap();
        let query_ctx = req
            .extensions()
            .get::<session::context::QueryContext>()
            .unwrap();

        assert!(query_ctx.remote_query_id().is_some());
    }

    #[test]
    fn test_auth_header() {
        let header_value = basic_auth("username", "password");
        let req = mock_http_request(Some(&header_value), None).unwrap();

        let auth_scheme = auth_header(&req).unwrap();
        assert_matches!(auth_scheme, AuthScheme::Basic(username, pwd) if username == "username" && pwd.expose_secret() == "password");

        // The first 'c' is the one in "Basic", so this doubles the separator space. That
        // leaves a space in the credential part, which is rejected as InvalidAuthHeader.
        let wrong_auth_header = header_value.replacen('c', "c ", 1);
        let wrong_req = mock_http_request(Some(&wrong_auth_header), None).unwrap();
        let res = auth_header(&wrong_req);
        assert_matches!(res.err(), Some(error::Error::InvalidAuthHeader { .. }));

        // A non-Basic scheme with well-formed credentials is unsupported.
        let wrong_req = mock_http_request(
            Some(&format!(
                "Digest {}",
                basic_auth_credentials("username", "password")
            )),
            None,
        )
        .unwrap();
        let res = auth_header(&wrong_req);
        assert_matches!(res.err(), Some(error::Error::UnsupportedAuthScheme { .. }));
    }

    // Builds a full `Basic <base64>` header value.
    fn basic_auth(username: &str, password: &str) -> String {
        format!("Basic {}", basic_auth_credentials(username, password))
    }

    // Base64-encodes `username:password`.
    fn basic_auth_credentials(username: &str, password: &str) -> String {
        BASE64_STANDARD.encode(format!("{username}:{password}"))
    }

    // Builds a request to `uri` (default: the SQL endpoint), optionally with an
    // Authorization header.
    fn mock_http_request(auth_header: Option<&str>, uri: Option<&str>) -> Result<Request<()>> {
        let http_api_version = crate::http::HTTP_API_VERSION;
        let mut req = Request::builder()
            .uri(uri.unwrap_or(format!("http://localhost/{http_api_version}/sql").as_str()));
        if let Some(auth_header) = auth_header {
            req = req.header(http::header::AUTHORIZATION, auth_header);
        }

        Ok(req.body(()).unwrap())
    }

    #[test]
    fn test_db_name_header() {
        // A `catalog-schema` value in the db-name header is split into both parts.
        let http_api_version = crate::http::HTTP_API_VERSION;
        let req = Request::builder()
            .uri(format!("http://localhost/{http_api_version}/sql").as_str())
            .header(GreptimeDbName::name(), "greptime-tomcat")
            .body(())
            .unwrap();

        let db = extract_catalog_and_schema(&req);
        assert_eq!(db, ("greptime".to_string(), "tomcat".to_string()));
    }

    #[test]
    fn test_extract_db() {
        // Empty values count as absent; only the first matching key is considered.
        assert_matches!(extract_db_from_query(""), None);
        assert_matches!(extract_db_from_query("&"), None);
        assert_matches!(extract_db_from_query("db="), None);
        assert_matches!(extract_bucket_from_query("bucket="), None);
        assert_matches!(extract_bucket_from_query("db=foo"), None);
        assert_matches!(extract_db_from_query("db=foo"), Some("foo"));
        assert_matches!(extract_bucket_from_query("bucket=foo"), Some("foo"));
        assert_matches!(extract_db_from_query("name=bar"), None);
        assert_matches!(extract_db_from_query("db=&name=bar"), None);
        assert_matches!(extract_db_from_query("db=foo&name=bar"), Some("foo"));
        assert_matches!(extract_bucket_from_query("db=foo&bucket=bar"), Some("bar"));
        assert_matches!(extract_db_from_query("name=bar&db="), None);
        assert_matches!(extract_db_from_query("name=bar&db=foo"), Some("foo"));
        assert_matches!(extract_db_from_query("name=bar&db=&name=bar"), None);
        assert_matches!(
            extract_db_from_query("name=bar&db=foo&name=bar"),
            Some("foo")
        );
    }

    #[test]
    fn test_extract_user() {
        // `u`/`p` are picked independently; empty values are ignored.
        assert_matches!(extract_influxdb_user_from_query(""), (None, None));
        assert_matches!(extract_influxdb_user_from_query("u="), (None, None));
        assert_matches!(
            extract_influxdb_user_from_query("u=123"),
            (Some("123"), None)
        );
        assert_matches!(
            extract_influxdb_user_from_query("u=123&p="),
            (Some("123"), None)
        );
        assert_matches!(
            extract_influxdb_user_from_query("u=123&p=4"),
            (Some("123"), Some("4"))
        );
        assert_matches!(extract_influxdb_user_from_query("p="), (None, None));
        assert_matches!(extract_influxdb_user_from_query("p=4"), (None, Some("4")));
        assert_matches!(
            extract_influxdb_user_from_query("p=4&u="),
            (None, Some("4"))
        );
        assert_matches!(
            extract_influxdb_user_from_query("p=4&u=123"),
            (Some("123"), Some("4"))
        );
    }
}