From f16f58266ec9861306f680bfb92f3ab97b48a722 Mon Sep 17 00:00:00 2001 From: shuiyisong <113876041+shuiyisong@users.noreply.github.com> Date: Fri, 25 Aug 2023 17:36:33 +0800 Subject: [PATCH] refactor: query_ctx from http middleware (#2253) * chore: change userinfo to query_ctx in http handler * chore: minor change * chore: move prometheus http to http mod * chore: fix uni test: * chore: add back schema check * chore: minor change * chore: remove clone --- src/frontend/src/instance.rs | 2 +- src/servers/src/grpc.rs | 2 +- src/servers/src/grpc/prom_query_gateway.rs | 5 +- src/servers/src/http.rs | 43 ++----------- src/servers/src/http/admin.rs | 13 ++-- src/servers/src/http/authorize.rs | 39 ++++-------- src/servers/src/http/handler.rs | 70 ++++++++++++--------- src/servers/src/http/influxdb.rs | 19 ++---- src/servers/src/http/opentsdb.rs | 8 +-- src/servers/src/http/otlp.rs | 15 ++--- src/servers/src/http/prom_store.rs | 16 ++--- src/servers/src/{ => http}/prometheus.rs | 31 +++------ src/servers/src/lib.rs | 2 +- src/servers/src/prometheus_handler.rs | 36 +++++++++++ src/servers/tests/http/authorize.rs | 14 +++-- src/servers/tests/http/http_handler_test.rs | 14 ++++- tests-integration/tests/grpc.rs | 2 +- tests-integration/tests/http.rs | 2 +- 18 files changed, 153 insertions(+), 180 deletions(-) rename src/servers/src/{ => http}/prometheus.rs (97%) create mode 100644 src/servers/src/prometheus_handler.rs diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 44cbf6e98b..12ee2c01f9 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -64,7 +64,7 @@ use servers::error::{AuthSnafu, ExecuteQuerySnafu, ParsePromQLSnafu}; use servers::interceptor::{ PromQueryInterceptor, PromQueryInterceptorRef, SqlQueryInterceptor, SqlQueryInterceptorRef, }; -use servers::prometheus::PrometheusHandler; +use servers::prometheus_handler::PrometheusHandler; use servers::query_handler::grpc::{GrpcQueryHandler, GrpcQueryHandlerRef}; use servers::query_handler::sql::SqlQueryHandler; use servers::query_handler::{ diff --git a/src/servers/src/grpc.rs b/src/servers/src/grpc.rs index 4c6fa68306..c3beda0f70 100644 --- a/src/servers/src/grpc.rs +++ b/src/servers/src/grpc.rs @@ -51,7 +51,7 @@ use self::region_server::{RegionServerHandlerRef, RegionServerRequestHandler}; use crate::error::{AlreadyStartedSnafu, InternalSnafu, Result, StartGrpcSnafu, TcpBindSnafu}; use crate::grpc::database::DatabaseService; use crate::grpc::greptime_handler::GreptimeRequestHandler; -use crate::prometheus::PrometheusHandlerRef; +use crate::prometheus_handler::PrometheusHandlerRef; use crate::query_handler::grpc::ServerGrpcQueryHandlerRef; use crate::server::Server; diff --git a/src/servers/src/grpc/prom_query_gateway.rs b/src/servers/src/grpc/prom_query_gateway.rs index 02d74839c4..e41960d68b 100644 --- a/src/servers/src/grpc/prom_query_gateway.rs +++ b/src/servers/src/grpc/prom_query_gateway.rs @@ -35,9 +35,8 @@ use tonic::{Request, Response}; use crate::error::InvalidQuerySnafu; use crate::grpc::greptime_handler::{auth, create_query_context}; use crate::grpc::TonicResult; -use crate::prometheus::{ - retrieve_metric_name_and_result_type, PrometheusHandlerRef, PrometheusJsonResponse, -}; +use crate::http::prometheus::{retrieve_metric_name_and_result_type, PrometheusJsonResponse}; +use crate::prometheus_handler::PrometheusHandlerRef; pub struct PrometheusGatewayService { handler: PrometheusHandlerRef, diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index ece3dda6c6..0b0f11db15 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -22,13 +22,13 @@ pub mod opentsdb; pub mod otlp; mod pprof; pub mod prom_store; +pub mod prometheus; pub mod script; #[cfg(feature = "dashboard")] mod dashboard; use std::net::SocketAddr; -use std::sync::Arc; use std::time::{Duration, Instant}; use aide::axum::{routing as apirouting, ApiRouter, IntoApiResponse}; @@ -43,8 +43,6 @@ use axum::middleware::{self, Next}; use axum::response::{Html, IntoResponse, Json}; use axum::{routing, BoxError, Extension, Router}; use common_base::readable_size::ReadableSize; -use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; -use common_catalog::parse_catalog_and_schema_from_db_string; use common_error::ext::ErrorExt; use common_error::status_code::StatusCode; use common_query::Output; @@ -55,7 +53,6 @@ use futures::FutureExt; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::Value; -use session::context::QueryContext; use snafu::{ensure, ResultExt}; use tokio::sync::oneshot::{self, Sender}; use tokio::sync::Mutex; @@ -69,15 +66,15 @@ use self::influxdb::{influxdb_health, influxdb_ping, influxdb_write_v1, influxdb use crate::configurator::ConfiguratorRef; use crate::error::{AlreadyStartedSnafu, Result, StartHttpSnafu}; use crate::http::admin::{compact, flush}; +use crate::http::prometheus::{ + instant_query, label_values_query, labels_query, range_query, series_query, +}; use crate::metrics::{ METRIC_CODE_LABEL, METRIC_HTTP_REQUESTS_ELAPSED, METRIC_HTTP_REQUESTS_TOTAL, METRIC_METHOD_LABEL, METRIC_PATH_LABEL, }; use crate::metrics_handler::MetricsHandler; -use crate::prometheus::{ - instant_query, label_values_query, labels_query, range_query, series_query, - PrometheusHandlerRef, -}; +use crate::prometheus_handler::PrometheusHandlerRef; use crate::query_handler::grpc::ServerGrpcQueryHandlerRef; use crate::query_handler::sql::ServerSqlQueryHandlerRef; use crate::query_handler::{ @@ -86,36 +83,6 @@ use crate::query_handler::{ }; use crate::server::Server; -/// create query context from database name information, catalog and schema are -/// resolved from the name -pub(crate) async fn query_context_from_db( - query_handler: ServerSqlQueryHandlerRef, - db: Option, -) -> std::result::Result, JsonResponse> { - let (catalog, schema) = if let Some(db) = &db { - let (catalog, schema) = parse_catalog_and_schema_from_db_string(db); - - match query_handler.is_valid_schema(catalog, schema).await { - Ok(true) => (catalog, schema), - Ok(false) => { - return Err(JsonResponse::with_error( - format!("Database not found: {db}"), - StatusCode::DatabaseNotFound, - )) - } - Err(e) => { - return Err(JsonResponse::with_error( - format!("Error checking database: {db}, {e}"), - StatusCode::Internal, - )) - } - } - } else { - (DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME) - }; - Ok(QueryContext::with(catalog, schema)) -} - pub const HTTP_API_VERSION: &str = "v1"; pub const HTTP_API_PREFIX: &str = "/v1/"; /// Default http body limit (64M). diff --git a/src/servers/src/http/admin.rs b/src/servers/src/http/admin.rs index 76adf8271b..258ed5778c 100644 --- a/src/servers/src/http/admin.rs +++ b/src/servers/src/http/admin.rs @@ -19,8 +19,9 @@ use api::v1::greptime_request::Request; use api::v1::{CompactTableExpr, DdlRequest, FlushTableExpr}; use axum::extract::{Query, RawBody, State}; use axum::http::StatusCode; +use axum::Extension; use common_catalog::consts::DEFAULT_CATALOG_NAME; -use session::context::QueryContext; +use session::context::QueryContextRef; use snafu::OptionExt; use crate::error; @@ -31,6 +32,7 @@ use crate::query_handler::grpc::ServerGrpcQueryHandlerRef; pub async fn flush( State(grpc_handler): State, Query(params): Query>, + Extension(query_ctx): Extension, RawBody(_): RawBody, ) -> Result<(StatusCode, ())> { let catalog_name = params @@ -64,9 +66,7 @@ pub async fn flush( })), }); - grpc_handler - .do_query(request, QueryContext::with(&catalog_name, &schema_name)) - .await?; + grpc_handler.do_query(request, query_ctx).await?; Ok((StatusCode::NO_CONTENT, ())) } @@ -74,6 +74,7 @@ pub async fn flush( pub async fn compact( State(grpc_handler): State, Query(params): Query>, + Extension(query_ctx): Extension, RawBody(_): RawBody, ) -> Result<(StatusCode, ())> { let catalog_name = params @@ -106,8 +107,6 @@ pub async fn compact( })), }); - grpc_handler - .do_query(request, QueryContext::with(&catalog_name, &schema_name)) - .await?; + grpc_handler.do_query(request, query_ctx).await?; Ok((StatusCode::NO_CONTENT, ())) } diff --git a/src/servers/src/http/authorize.rs b/src/servers/src/http/authorize.rs index 54cda788d3..03c7bb735f 100644 --- a/src/servers/src/http/authorize.rs +++ b/src/servers/src/http/authorize.rs @@ -17,6 +17,7 @@ use std::marker::PhantomData; use ::auth::UserProviderRef; use axum::http::{self, Request, StatusCode}; use axum::response::Response; +use common_catalog::consts::DEFAULT_SCHEMA_NAME; use common_catalog::parse_catalog_and_schema_from_db_string; use common_error::ext::ErrorExt; use common_telemetry::warn; @@ -25,6 +26,7 @@ use headers::Header; use http_body::Body; use metrics::increment_counter; use secrecy::SecretString; +use session::context::QueryContext; use snafu::{ensure, OptionExt, ResultExt}; use tower_http::auth::AsyncAuthorizeRequest; @@ -71,14 +73,15 @@ where fn authorize(&mut self, mut request: Request) -> Self::Future { let user_provider = self.user_provider.clone(); Box::pin(async move { + let (catalog, schema) = extract_catalog_and_schema(&request); + let query_ctx = QueryContext::with(catalog, schema); let need_auth = need_auth(&request); let user_provider = if let Some(user_provider) = user_provider.filter(|_| need_auth) { user_provider } else { - let _ = request - .extensions_mut() - .insert(auth::userinfo_by_name(None)); + query_ctx.set_current_user(Some(auth::userinfo_by_name(None))); + let _ = request.extensions_mut().insert(query_ctx); return Ok(request); }; @@ -97,21 +100,6 @@ where } }; - let (catalog, schema) = match extract_catalog_and_schema(&request) { - Ok((catalog, schema)) => (catalog, schema), - Err(e) => { - warn!("extract catalog and schema failed: {}", e); - increment_counter!( - crate::metrics::METRIC_AUTH_FAILURE, - &[( - crate::metrics::METRIC_CODE_LABEL, - format!("{}", e.status_code()) - )] - ); - return Err(unauthorized_resp()); - } - }; - match user_provider .auth( ::auth::Identity::UserId(username.as_str(), None), @@ -122,7 +110,8 @@ where .await { Ok(userinfo) => { - let _ = request.extensions_mut().insert(userinfo); + query_ctx.set_current_user(Some(userinfo)); + let _ = request.extensions_mut().insert(query_ctx); Ok(request) } Err(e) => { @@ -141,9 +130,7 @@ where } } -fn extract_catalog_and_schema( - request: &Request, -) -> Result<(&str, &str)> { +fn extract_catalog_and_schema(request: &Request) -> (&str, &str) { // parse database from header let dbname = request .headers() @@ -154,11 +141,9 @@ fn extract_catalog_and_schema( let query = request.uri().query().unwrap_or_default(); extract_db_from_query(query) }) - .context(InvalidParameterSnafu { - reason: "`db` must be provided in query string", - })?; + .unwrap_or(DEFAULT_SCHEMA_NAME); - Ok(parse_catalog_and_schema_from_db_string(dbname)) + parse_catalog_and_schema_from_db_string(dbname) } fn get_influxdb_credentials( @@ -413,7 +398,7 @@ mod tests { .body(()) .unwrap(); - let db = extract_catalog_and_schema(&req).unwrap(); + let db = extract_catalog_and_schema(&req); assert_eq!(db, ("greptime", "tomcat")); } diff --git a/src/servers/src/http/handler.rs b/src/servers/src/http/handler.rs index 66b6e97804..90fff97635 100644 --- a/src/servers/src/http/handler.rs +++ b/src/servers/src/http/handler.rs @@ -17,7 +17,6 @@ use std::env; use std::time::Instant; use aide::transform::TransformOperation; -use auth::UserInfoRef; use axum::extract::{Json, Query, State}; use axum::response::{IntoResponse, Response}; use axum::{Extension, Form}; @@ -26,9 +25,11 @@ use common_telemetry::timer; use query::parser::PromQuery; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use session::context::QueryContextRef; use crate::http::{ApiState, GreptimeOptionsConfigState, JsonResponse}; use crate::metrics_handler::MetricsHandler; +use crate::query_handler::sql::ServerSqlQueryHandlerRef; #[derive(Debug, Default, Serialize, Deserialize, JsonSchema)] pub struct SqlQuery { @@ -41,31 +42,25 @@ pub struct SqlQuery { pub async fn sql( State(state): State, Query(query_params): Query, - // TODO(fys): pass _user_info into query context - user_info: Extension, + Extension(query_ctx): Extension, Form(form_params): Form, ) -> Json { let sql_handler = &state.sql_handler; let start = Instant::now(); let sql = query_params.sql.or(form_params.sql); - let db = query_params.db.or(form_params.db); + let db = query_ctx.get_db_string(); let _timer = timer!( crate::metrics::METRIC_HTTP_SQL_ELAPSED, - &[( - crate::metrics::METRIC_DB_LABEL, - db.clone().unwrap_or_default() - )] + &[(crate::metrics::METRIC_DB_LABEL, db)] ); let resp = if let Some(sql) = &sql { - match crate::http::query_context_from_db(sql_handler.clone(), db).await { - Ok(query_ctx) => { - query_ctx.set_current_user(Some(user_info.0)); - JsonResponse::from_output(sql_handler.do_query(sql, query_ctx).await).await - } - Err(resp) => resp, + if let Some(resp) = validate_schema(sql_handler.clone(), query_ctx.clone()).await { + return Json(resp); } + + JsonResponse::from_output(sql_handler.do_query(sql, query_ctx).await).await } else { JsonResponse::with_error( "sql parameter is required.".to_string(), @@ -101,29 +96,23 @@ impl From for PromQuery { pub async fn promql( State(state): State, Query(params): Query, - // TODO(fys): pass _user_info into query context - user_info: Extension, + Extension(query_ctx): Extension, ) -> Json { let sql_handler = &state.sql_handler; let exec_start = Instant::now(); - let db = params.db.clone(); + let db = query_ctx.get_db_string(); let _timer = timer!( crate::metrics::METRIC_HTTP_PROMQL_ELAPSED, - &[( - crate::metrics::METRIC_DB_LABEL, - db.clone().unwrap_or_default() - )] + &[(crate::metrics::METRIC_DB_LABEL, db)] ); + if let Some(resp) = validate_schema(sql_handler.clone(), query_ctx.clone()).await { + return Json(resp); + } + let prom_query = params.into(); - let resp = match super::query_context_from_db(sql_handler.clone(), db).await { - Ok(query_ctx) => { - query_ctx.set_current_user(Some(user_info.0)); - JsonResponse::from_output(sql_handler.do_promql_query(&prom_query, query_ctx).await) - .await - } - Err(resp) => resp, - }; + let resp = + JsonResponse::from_output(sql_handler.do_promql_query(&prom_query, query_ctx).await).await; Json(resp.with_execution_time(exec_start.elapsed().as_millis())) } @@ -196,3 +185,26 @@ pub async fn status() -> Json> { pub async fn config(State(state): State) -> Response { (axum::http::StatusCode::OK, state.greptime_config_options).into_response() } + +async fn validate_schema( + sql_handler: ServerSqlQueryHandlerRef, + query_ctx: QueryContextRef, +) -> Option { + match sql_handler + .is_valid_schema(query_ctx.current_catalog(), query_ctx.current_schema()) + .await + { + Ok(false) => Some(JsonResponse::with_error( + format!("Database not found: {}", query_ctx.get_db_string()), + StatusCode::DatabaseNotFound, + )), + Err(e) => Some(JsonResponse::with_error( + format!( + "Error checking database: {}, {e}", + query_ctx.get_db_string() + ), + StatusCode::Internal, + )), + _ => None, + } +} diff --git a/src/servers/src/http/influxdb.rs b/src/servers/src/http/influxdb.rs index 9e3bec8558..4dab6d00ef 100644 --- a/src/servers/src/http/influxdb.rs +++ b/src/servers/src/http/influxdb.rs @@ -14,16 +14,14 @@ use std::collections::HashMap; -use auth::UserInfoRef; use axum::extract::{Query, State}; use axum::http::StatusCode; use axum::response::IntoResponse; use axum::Extension; use common_catalog::consts::DEFAULT_SCHEMA_NAME; -use common_catalog::parse_catalog_and_schema_from_db_string; use common_grpc::writer::Precision; use common_telemetry::timer; -use session::context::QueryContext; +use session::context::QueryContextRef; use crate::error::{Result, TimePrecisionSnafu}; use crate::influxdb::InfluxdbRequest; @@ -45,7 +43,7 @@ pub async fn influxdb_health() -> Result { pub async fn influxdb_write_v1( State(handler): State, Query(mut params): Query>, - user_info: Extension, + Extension(query_ctx): Extension, lines: String, ) -> Result { let db = params @@ -57,14 +55,14 @@ pub async fn influxdb_write_v1( .map(|val| parse_time_precision(val)) .transpose()?; - influxdb_write(&db, precision, lines, handler, user_info.0).await + influxdb_write(&db, precision, lines, handler, query_ctx).await } #[axum_macros::debug_handler] pub async fn influxdb_write_v2( State(handler): State, Query(mut params): Query>, - user_info: Extension, + Extension(query_ctx): Extension, lines: String, ) -> Result { let db = params @@ -76,7 +74,7 @@ pub async fn influxdb_write_v2( .map(|val| parse_time_precision(val)) .transpose()?; - influxdb_write(&db, precision, lines, handler, user_info.0).await + influxdb_write(&db, precision, lines, handler, query_ctx).await } pub async fn influxdb_write( @@ -84,19 +82,14 @@ pub async fn influxdb_write( precision: Option, lines: String, handler: InfluxdbLineProtocolHandlerRef, - user_info: UserInfoRef, + ctx: QueryContextRef, ) -> Result { let _timer = timer!( crate::metrics::METRIC_HTTP_INFLUXDB_WRITE_ELAPSED, &[(crate::metrics::METRIC_DB_LABEL, db.to_string())] ); - let (catalog, schema) = parse_catalog_and_schema_from_db_string(db); - let ctx = QueryContext::with(catalog, schema); - ctx.set_current_user(Some(user_info)); - let request = InfluxdbRequest { precision, lines }; - handler.exec(request, ctx).await?; Ok((StatusCode::NO_CONTENT, ())) diff --git a/src/servers/src/http/opentsdb.rs b/src/servers/src/http/opentsdb.rs index 3c2a82cd81..182c8db8d7 100644 --- a/src/servers/src/http/opentsdb.rs +++ b/src/servers/src/http/opentsdb.rs @@ -14,13 +14,12 @@ use std::collections::HashMap; -use auth::UserInfoRef; use axum::extract::{Query, RawBody, State}; use axum::http::StatusCode as HttpStatusCode; use axum::{Extension, Json}; use hyper::Body; use serde::{Deserialize, Serialize}; -use session::context::QueryContext; +use session::context::QueryContextRef; use snafu::ResultExt; use crate::error::{self, Error, Result}; @@ -78,15 +77,12 @@ pub enum OpentsdbPutResponse { pub async fn put( State(opentsdb_handler): State, Query(params): Query>, - user_info: Extension, + Extension(ctx): Extension, RawBody(body): RawBody, ) -> Result<(HttpStatusCode, Json)> { let summary = params.contains_key("summary"); let details = params.contains_key("details"); - let ctx = QueryContext::with_db_name(params.get("db")); - ctx.set_current_user(Some(user_info.0)); - let data_points = parse_data_points(body).await?; let response = if !summary && !details { diff --git a/src/servers/src/http/otlp.rs b/src/servers/src/http/otlp.rs index fc2a3fa6d4..7d797d440f 100644 --- a/src/servers/src/http/otlp.rs +++ b/src/servers/src/http/otlp.rs @@ -12,39 +12,34 @@ // See the License for the specific language governing permissions and // limitations under the License. -use auth::UserInfoRef; use axum::extract::{RawBody, State}; use axum::http::header; use axum::response::IntoResponse; -use axum::{Extension, TypedHeader}; +use axum::Extension; use common_telemetry::timer; use hyper::Body; use opentelemetry_proto::tonic::collector::metrics::v1::{ ExportMetricsServiceRequest, ExportMetricsServiceResponse, }; use prost::Message; -use session::context::QueryContext; +use session::context::QueryContextRef; use snafu::prelude::*; use crate::error::{self, Result}; -use crate::http::header::GreptimeDbName; use crate::query_handler::OpenTelemetryProtocolHandlerRef; #[axum_macros::debug_handler] pub async fn metrics( State(handler): State, - TypedHeader(db): TypedHeader, - user_info: Extension, + Extension(query_ctx): Extension, RawBody(body): RawBody, ) -> Result { - let ctx = QueryContext::with_db_name(db.value()); - ctx.set_current_user(Some(user_info.0)); let _timer = timer!( crate::metrics::METRIC_HTTP_OPENTELEMETRY_ELAPSED, - &[(crate::metrics::METRIC_DB_LABEL, ctx.get_db_string())] + &[(crate::metrics::METRIC_DB_LABEL, query_ctx.get_db_string())] ); let request = parse_body(body).await?; - handler.metrics(request, ctx).await.map(OtlpResponse) + handler.metrics(request, query_ctx).await.map(OtlpResponse) } async fn parse_body(body: Body) -> Result { diff --git a/src/servers/src/http/prom_store.rs b/src/servers/src/http/prom_store.rs index 0eb56f575a..897e9e703e 100644 --- a/src/servers/src/http/prom_store.rs +++ b/src/servers/src/http/prom_store.rs @@ -13,7 +13,6 @@ // limitations under the License. use api::prom_store::remote::{ReadRequest, WriteRequest}; -use auth::UserInfoRef; use axum::extract::{Query, RawBody, State}; use axum::http::{header, StatusCode}; use axum::response::IntoResponse; @@ -24,7 +23,7 @@ use hyper::Body; use prost::Message; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use session::context::QueryContext; +use session::context::QueryContextRef; use snafu::prelude::*; use crate::error::{self, Result}; @@ -48,7 +47,7 @@ impl Default for DatabaseQuery { pub async fn remote_write( State(handler): State, Query(params): Query, - user_info: Extension, + Extension(query_ctx): Extension, RawBody(body): RawBody, ) -> Result<(StatusCode, ())> { let request = decode_remote_write_request(body).await?; @@ -61,9 +60,7 @@ pub async fn remote_write( )] ); - let ctx = QueryContext::with_db_name(params.db.as_ref()); - ctx.set_current_user(Some(user_info.0)); - handler.write(request, ctx).await?; + handler.write(request, query_ctx).await?; Ok((StatusCode::NO_CONTENT, ())) } @@ -84,7 +81,7 @@ impl IntoResponse for PromStoreResponse { pub async fn remote_read( State(handler): State, Query(params): Query, - user_info: Extension, + Extension(query_ctx): Extension, RawBody(body): RawBody, ) -> Result { let request = decode_remote_read_request(body).await?; @@ -96,10 +93,7 @@ pub async fn remote_read( params.db.clone().unwrap_or_default() )] ); - - let ctx = QueryContext::with_db_name(params.db.as_ref()); - ctx.set_current_user(Some(user_info.0)); - handler.read(request, ctx).await + handler.read(request, query_ctx).await } async fn decode_remote_write_request(body: Body) -> Result { diff --git a/src/servers/src/prometheus.rs b/src/servers/src/http/prometheus.rs similarity index 97% rename from src/servers/src/prometheus.rs rename to src/servers/src/http/prometheus.rs index e7261a5615..9f8e652cc5 100644 --- a/src/servers/src/prometheus.rs +++ b/src/servers/src/http/prometheus.rs @@ -14,11 +14,9 @@ //! prom supply the prometheus HTTP API Server compliance use std::collections::{BTreeMap, HashMap, HashSet}; -use std::sync::Arc; -use async_trait::async_trait; use axum::extract::{Path, Query, State}; -use axum::{Form, Json}; +use axum::{Extension, Form, Json}; use catalog::CatalogManagerRef; use common_catalog::consts::DEFAULT_SCHEMA_NAME; use common_catalog::parse_catalog_and_schema_from_db_string; @@ -40,24 +38,14 @@ use query::parser::{PromQuery, DEFAULT_LOOKBACK_STRING}; use schemars::JsonSchema; use serde::de::{self, MapAccess, Visitor}; use serde::{Deserialize, Serialize}; -use session::context::{QueryContext, QueryContextRef}; +use session::context::QueryContextRef; use snafu::{Location, OptionExt, ResultExt}; use crate::error::{ CollectRecordbatchSnafu, Error, InternalSnafu, InvalidQuerySnafu, Result, UnexpectedResultSnafu, }; use crate::prom_store::{FIELD_COLUMN_NAME, METRIC_NAME_LABEL, TIMESTAMP_COLUMN_NAME}; - -pub const PROMETHEUS_API_VERSION: &str = "v1"; - -pub type PrometheusHandlerRef = Arc; - -#[async_trait] -pub trait PrometheusHandler { - async fn do_query(&self, query: &PromQuery, query_ctx: QueryContextRef) -> Result; - - fn catalog_manager(&self) -> CatalogManagerRef; -} +use crate::prometheus_handler::PrometheusHandlerRef; #[derive(Debug, Default, Serialize, Deserialize, JsonSchema, PartialEq)] pub struct PromSeries { @@ -315,6 +303,7 @@ pub struct InstantQuery { pub async fn instant_query( State(handler): State, Query(params): Query, + Extension(query_ctx): Extension, Form(form_params): Form, ) -> Json { let _timer = timer!(crate::metrics::METRIC_HTTP_PROMQL_INSTANT_QUERY_ELAPSED); @@ -330,8 +319,6 @@ pub async fn instant_query( step: "1s".to_string(), }; - let query_ctx = QueryContext::with_db_name(params.db.as_ref()); - let result = handler.do_query(&prom_query, query_ctx).await; let (metric_name, result_type) = match retrieve_metric_name_and_result_type(&prom_query.query) { Ok((metric_name, result_type)) => (metric_name.unwrap_or_default(), result_type), @@ -356,6 +343,7 @@ pub struct RangeQuery { pub async fn range_query( State(handler): State, Query(params): Query, + Extension(query_ctx): Extension, Form(form_params): Form, ) -> Json { let _timer = timer!(crate::metrics::METRIC_HTTP_PROMQL_RANGE_QUERY_ELAPSED); @@ -366,8 +354,6 @@ pub async fn range_query( step: params.step.or(form_params.step).unwrap_or_default(), }; - let query_ctx = QueryContext::with_db_name(params.db.as_ref()); - let result = handler.do_query(&prom_query, query_ctx).await; let metric_name = match retrieve_metric_name_and_result_type(&prom_query.query) { Err(err) => { @@ -426,13 +412,13 @@ impl<'de> Deserialize<'de> for Matches { pub async fn labels_query( State(handler): State, Query(params): Query, + Extension(query_ctx): Extension, Form(form_params): Form, ) -> Json { let _timer = timer!(crate::metrics::METRIC_HTTP_PROMQL_LABEL_QUERY_ELAPSED); let db = ¶ms.db.unwrap_or(DEFAULT_SCHEMA_NAME.to_string()); let (catalog, schema) = parse_catalog_and_schema_from_db_string(db); - let query_ctx = QueryContext::with(catalog, schema); let mut queries = params.matches.0; if queries.is_empty() { @@ -692,6 +678,7 @@ pub struct LabelValueQuery { pub async fn label_values_query( State(handler): State, Path(label_name): Path, + Extension(query_ctx): Extension, Query(params): Query, ) -> Json { let _timer = timer!(crate::metrics::METRIC_HTTP_PROMQL_LABEL_VALUE_QUERY_ELAPSED); @@ -717,7 +704,6 @@ pub async fn label_values_query( let start = params.start.unwrap_or_else(yesterday_rfc3339); let end = params.end.unwrap_or_else(current_time_rfc3339); - let query_ctx = QueryContext::with(catalog, schema); let mut label_values = HashSet::new(); @@ -818,6 +804,7 @@ pub struct SeriesQuery { pub async fn series_query( State(handler): State, Query(params): Query, + Extension(query_ctx): Extension, Form(form_params): Form, ) -> Json { let _timer = timer!(crate::metrics::METRIC_HTTP_PROMQL_SERIES_QUERY_ELAPSED); @@ -837,8 +824,6 @@ pub async fn series_query( .or(form_params.end) .unwrap_or_else(current_time_rfc3339); - let query_ctx = QueryContext::with_db_name(params.db.as_ref()); - let mut series = Vec::new(); for query in queries { let table_name = query.clone(); diff --git a/src/servers/src/lib.rs b/src/servers/src/lib.rs index 2e62224084..acfd12c6cc 100644 --- a/src/servers/src/lib.rs +++ b/src/servers/src/lib.rs @@ -35,7 +35,7 @@ pub mod opentsdb; pub mod otlp; pub mod postgres; pub mod prom_store; -pub mod prometheus; +pub mod prometheus_handler; pub mod query_handler; pub mod server; mod shutdown; diff --git a/src/servers/src/prometheus_handler.rs b/src/servers/src/prometheus_handler.rs new file mode 100644 index 0000000000..e6d1359edd --- /dev/null +++ b/src/servers/src/prometheus_handler.rs @@ -0,0 +1,36 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! prom supply the prometheus HTTP API Server compliance + +use std::sync::Arc; + +use async_trait::async_trait; +use catalog::CatalogManagerRef; +use common_query::Output; +use query::parser::PromQuery; +use session::context::QueryContextRef; + +use crate::error::Result; + +pub const PROMETHEUS_API_VERSION: &str = "v1"; + +pub type PrometheusHandlerRef = Arc; + +#[async_trait] +pub trait PrometheusHandler { + async fn do_query(&self, query: &PromQuery, query_ctx: QueryContextRef) -> Result; + + fn catalog_manager(&self) -> CatalogManagerRef; +} diff --git a/src/servers/tests/http/authorize.rs b/src/servers/tests/http/authorize.rs index 9ca41a7096..a575ffb6df 100644 --- a/src/servers/tests/http/authorize.rs +++ b/src/servers/tests/http/authorize.rs @@ -15,11 +15,12 @@ use std::sync::Arc; use auth::tests::MockUserProvider; -use auth::{UserInfoRef, UserProvider}; +use auth::UserProvider; use axum::body::BoxBody; use axum::http; use hyper::Request; use servers::http::authorize::HttpAuth; +use session::context::QueryContextRef; use tower_http::auth::AsyncAuthorizeRequest; #[tokio::test] @@ -28,8 +29,9 @@ async fn test_http_auth() { // base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ=" let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap(); - let auth_res = http_auth.authorize(req).await.unwrap(); - let user_info: &UserInfoRef = auth_res.extensions().get().unwrap(); + let req = http_auth.authorize(req).await.unwrap(); + let ctx: &QueryContextRef = req.extensions().get().unwrap(); + let user_info = ctx.current_user().unwrap(); let default = auth::userinfo_by_name(None); assert_eq!(default.username(), user_info.username()); @@ -40,7 +42,8 @@ async fn test_http_auth() { // base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU=" let req = mock_http_request(Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), None).unwrap(); let req = http_auth.authorize(req).await.unwrap(); - let user_info: &UserInfoRef = req.extensions().get().unwrap(); + let ctx: &QueryContextRef = req.extensions().get().unwrap(); + let user_info = ctx.current_user().unwrap(); let default = auth::userinfo_by_name(None); assert_eq!(default.username(), user_info.username()); @@ -70,7 +73,8 @@ async fn test_schema_validating() { ) .unwrap(); let req = http_auth.authorize(req).await.unwrap(); - let user_info: &UserInfoRef = req.extensions().get().unwrap(); + let ctx: &QueryContextRef = req.extensions().get().unwrap(); + let user_info = ctx.current_user().unwrap(); let default = auth::userinfo_by_name(None); assert_eq!(default.username(), user_info.username()); diff --git a/src/servers/tests/http/http_handler_test.rs b/src/servers/tests/http/http_handler_test.rs index e682e87ad0..f8d5a567f0 100644 --- a/src/servers/tests/http/http_handler_test.rs +++ b/src/servers/tests/http/http_handler_test.rs @@ -26,6 +26,7 @@ use servers::http::{ JsonOutput, }; use servers::metrics_handler::MetricsHandler; +use session::context::QueryContext; use table::test_util::MemTable; use crate::{ @@ -36,13 +37,15 @@ use crate::{ #[tokio::test] async fn test_sql_not_provided() { let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table()); + let ctx = QueryContext::arc(); + ctx.set_current_user(Some(auth::userinfo_by_name(None))); let Json(json) = http_handler::sql( State(ApiState { sql_handler, script_handler: None, }), Query(http_handler::SqlQuery::default()), - axum::Extension(auth::userinfo_by_name(None)), + axum::Extension(ctx), Form(http_handler::SqlQuery::default()), ) .await; @@ -61,13 +64,15 @@ async fn test_sql_output_rows() { let query = create_query(); let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table()); + let ctx = QueryContext::arc(); + ctx.set_current_user(Some(auth::userinfo_by_name(None))); let Json(json) = http_handler::sql( State(ApiState { sql_handler, script_handler: None, }), query, - axum::Extension(auth::userinfo_by_name(None)), + axum::Extension(ctx), Form(http_handler::SqlQuery::default()), ) .await; @@ -107,13 +112,16 @@ async fn test_sql_form() { let form = create_form(); let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table()); + let ctx = QueryContext::arc(); + ctx.set_current_user(Some(auth::userinfo_by_name(None))); + let Json(json) = http_handler::sql( State(ApiState { sql_handler, script_handler: None, }), Query(http_handler::SqlQuery::default()), - axum::Extension(auth::userinfo_by_name(None)), + axum::Extension(ctx), form, ) .await; diff --git a/tests-integration/tests/grpc.rs b/tests-integration/tests/grpc.rs index 5abe02f086..bc3f8abf12 100644 --- a/tests-integration/tests/grpc.rs +++ b/tests-integration/tests/grpc.rs @@ -24,7 +24,7 @@ use client::{Client, Database, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_catalog::consts::{MIN_USER_TABLE_ID, MITO_ENGINE}; use common_query::Output; use common_recordbatch::RecordBatches; -use servers::prometheus::{PromData, PromSeries, PrometheusJsonResponse, PrometheusResponse}; +use servers::http::prometheus::{PromData, PromSeries, PrometheusJsonResponse, PrometheusResponse}; use servers::server::Server; use tests_integration::test_util::{ setup_grpc_server, setup_grpc_server_with_user_provider, StorageType, diff --git a/tests-integration/tests/http.rs b/tests-integration/tests/http.rs index 72c77a2429..8079f41e4d 100644 --- a/tests-integration/tests/http.rs +++ b/tests-integration/tests/http.rs @@ -20,8 +20,8 @@ use axum_test_helper::TestClient; use common_error::status_code::StatusCode as ErrorCode; use serde_json::json; use servers::http::handler::HealthResponse; +use servers::http::prometheus::{PrometheusJsonResponse, PrometheusResponse}; use servers::http::{JsonOutput, JsonResponse}; -use servers::prometheus::{PrometheusJsonResponse, PrometheusResponse}; use tests_integration::test_util::{ setup_test_http_app, setup_test_http_app_with_frontend, setup_test_http_app_with_frontend_and_user_provider, setup_test_prom_app_with_frontend,