feat: track channels with query context and w/rcu (#4448)

* feat: add source channel to meter recorders

* feat: provide channel for query context

* fix: testing and extension get for query context

* chore: revert cargo toml structure changes

* fix: querycontext modification for prometheus and pipeline

* chore: switch git dependency to main branches

* chore: remove TODO

* refactor: rename other to unknown

---------

Co-authored-by: shuiyisong <113876041+shuiyisong@users.noreply.github.com>
This commit is contained in:
Ning Sun
2024-07-31 15:30:50 +08:00
committed by GitHub
parent dd23d47743
commit b741a7181b
21 changed files with 193 additions and 96 deletions

7
Cargo.lock generated
View File

@@ -4234,7 +4234,7 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
[[package]]
name = "greptime-proto"
version = "0.1.0"
source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=7ca323090b3ae8faf2c15036b7f41b7c5225cf5f#7ca323090b3ae8faf2c15036b7f41b7c5225cf5f"
source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=c437b55725b7f5224fe9d46db21072b4a682ee4b#c437b55725b7f5224fe9d46db21072b4a682ee4b"
dependencies = [
"prost 0.12.6",
"serde",
@@ -6155,7 +6155,7 @@ dependencies = [
[[package]]
name = "meter-core"
version = "0.1.0"
source = "git+https://github.com/GreptimeTeam/greptime-meter.git?rev=80b72716dcde47ec4161478416a5c6c21343364d#80b72716dcde47ec4161478416a5c6c21343364d"
source = "git+https://github.com/GreptimeTeam/greptime-meter.git?rev=049171eb16cb4249d8099751a0c46750d1fe88e7#049171eb16cb4249d8099751a0c46750d1fe88e7"
dependencies = [
"anymap",
"once_cell",
@@ -6165,7 +6165,7 @@ dependencies = [
[[package]]
name = "meter-macros"
version = "0.1.0"
source = "git+https://github.com/GreptimeTeam/greptime-meter.git?rev=80b72716dcde47ec4161478416a5c6c21343364d#80b72716dcde47ec4161478416a5c6c21343364d"
source = "git+https://github.com/GreptimeTeam/greptime-meter.git?rev=049171eb16cb4249d8099751a0c46750d1fe88e7#049171eb16cb4249d8099751a0c46750d1fe88e7"
dependencies = [
"meter-core",
]
@@ -10326,6 +10326,7 @@ dependencies = [
"common-telemetry",
"common-time",
"derive_builder 0.12.0",
"meter-core",
"snafu 0.8.3",
"sql",
]

View File

@@ -119,12 +119,12 @@ etcd-client = { version = "0.13" }
fst = "0.4.7"
futures = "0.3"
futures-util = "0.3"
greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "7ca323090b3ae8faf2c15036b7f41b7c5225cf5f" }
greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "c437b55725b7f5224fe9d46db21072b4a682ee4b" }
humantime = "2.1"
humantime-serde = "1.1"
itertools = "0.10"
lazy_static = "1.4"
meter-core = { git = "https://github.com/GreptimeTeam/greptime-meter.git", rev = "80b72716dcde47ec4161478416a5c6c21343364d" }
meter-core = { git = "https://github.com/GreptimeTeam/greptime-meter.git", rev = "049171eb16cb4249d8099751a0c46750d1fe88e7" }
mockall = "0.11.4"
moka = "0.12"
notify = "6.1"
@@ -238,7 +238,7 @@ table = { path = "src/table" }
[workspace.dependencies.meter-macros]
git = "https://github.com/GreptimeTeam/greptime-meter.git"
rev = "80b72716dcde47ec4161478416a5c6c21343364d"
rev = "049171eb16cb4249d8099751a0c46750d1fe88e7"
[profile.release]
debug = 1

View File

@@ -1079,6 +1079,7 @@ pub struct QueryContext {
current_schema: String,
timezone: String,
extensions: HashMap<String, String>,
channel: u8,
}
impl From<QueryContextRef> for QueryContext {
@@ -1088,6 +1089,7 @@ impl From<QueryContextRef> for QueryContext {
current_schema: query_context.current_schema().to_string(),
timezone: query_context.timezone().to_string(),
extensions: query_context.extensions(),
channel: query_context.channel() as u8,
}
}
}
@@ -1099,6 +1101,7 @@ impl From<QueryContext> for PbQueryContext {
current_schema,
timezone,
extensions,
channel,
}: QueryContext,
) -> Self {
PbQueryContext {
@@ -1106,6 +1109,7 @@ impl From<QueryContext> for PbQueryContext {
current_schema,
timezone,
extensions,
channel: channel as u32,
}
}
}

View File

@@ -270,7 +270,12 @@ impl Inserter {
requests: RegionInsertRequests,
ctx: &QueryContextRef,
) -> Result<Output> {
let write_cost = write_meter!(ctx.current_catalog(), ctx.current_schema(), requests);
let write_cost = write_meter!(
ctx.current_catalog(),
ctx.current_schema(),
requests,
ctx.channel() as u8
);
let request_factory = RegionRequestFactory::new(RegionRequestHeader {
tracing_context: TracingContext::from_current_span().to_w3c(),
dbname: ctx.get_db_string(),

View File

@@ -194,6 +194,7 @@ impl MergeScanExec {
let tracing_context = TracingContext::from_json(context.session_id().as_str());
let current_catalog = self.query_ctx.current_catalog().to_string();
let current_schema = self.query_ctx.current_schema().to_string();
let current_channel = self.query_ctx.channel();
let timezone = self.query_ctx.timezone().to_string();
let extensions = self.query_ctx.extensions();
let target_partition = self.target_partition;
@@ -221,6 +222,7 @@ impl MergeScanExec {
current_schema: current_schema.clone(),
timezone: timezone.clone(),
extensions: extensions.clone(),
channel: current_channel as u32,
}),
}),
region_id,
@@ -271,7 +273,8 @@ impl MergeScanExec {
ReadItem {
cpu_time: metrics.elapsed_compute as u64,
table_scan: metrics.memory_usage as u64
}
},
current_channel as u8
);
metric.record_greptime_exec_cost(value as usize);

View File

@@ -14,12 +14,11 @@
use std::pin::Pin;
use std::result::Result as StdResult;
use std::sync::Arc;
use std::task::{Context, Poll};
use auth::UserProviderRef;
use hyper::Body;
use session::context::QueryContext;
use session::context::{Channel, QueryContext};
use tonic::body::BoxBody;
use tonic::server::NamedService;
use tower::{Layer, Service};
@@ -105,7 +104,7 @@ async fn do_auth<T>(
) -> Result<(), tonic::Status> {
let (catalog, schema) = extract_catalog_and_schema(req);
let query_ctx = Arc::new(QueryContext::with(&catalog, &schema));
let query_ctx = QueryContext::with_channel(&catalog, &schema, Channel::Grpc);
let Some(user_provider) = user_provider else {
query_ctx.set_current_user(auth::userinfo_by_name(None));
@@ -139,7 +138,7 @@ mod tests {
use base64::Engine;
use headers::Header;
use hyper::{Body, Request};
use session::context::QueryContextRef;
use session::context::QueryContext;
use crate::grpc::authorize::do_auth;
use crate::http::header::GreptimeDbName;
@@ -197,7 +196,7 @@ mod tests {
expected_schema: &str,
expected_user_name: &str,
) {
let ctx = req.extensions().get::<QueryContextRef>().unwrap();
let ctx = req.extensions().get::<QueryContext>().unwrap();
assert_eq!(expected_catalog, ctx.current_catalog());
assert_eq!(expected_schema, ctx.current_schema());

View File

@@ -13,6 +13,7 @@
// limitations under the License.
use std::result::Result as StdResult;
use std::sync::Arc;
use opentelemetry_proto::tonic::collector::metrics::v1::metrics_service_server::MetricsService;
use opentelemetry_proto::tonic::collector::metrics::v1::{
@@ -22,7 +23,7 @@ use opentelemetry_proto::tonic::collector::trace::v1::trace_service_server::Trac
use opentelemetry_proto::tonic::collector::trace::v1::{
ExportTraceServiceRequest, ExportTraceServiceResponse,
};
use session::context::QueryContextRef;
use session::context::{Channel, QueryContext};
use snafu::OptionExt;
use tonic::{Request, Response, Status};
@@ -47,10 +48,12 @@ impl TraceService for OtlpService {
) -> StdResult<Response<ExportTraceServiceResponse>, Status> {
let (_headers, extensions, req) = request.into_parts();
let ctx = extensions
.get::<QueryContextRef>()
let mut ctx = extensions
.get::<QueryContext>()
.cloned()
.context(error::MissingQueryContextSnafu)?;
ctx.set_channel(Channel::Otlp);
let ctx = Arc::new(ctx);
let _ = self.handler.traces(req, ctx).await?;
@@ -68,10 +71,12 @@ impl MetricsService for OtlpService {
) -> StdResult<Response<ExportMetricsServiceResponse>, Status> {
let (_headers, extensions, req) = request.into_parts();
let ctx = extensions
.get::<QueryContextRef>()
let mut ctx = extensions
.get::<QueryContext>()
.cloned()
.context(error::MissingQueryContextSnafu)?;
ctx.set_channel(Channel::Otlp);
let ctx = Arc::new(ctx);
let _ = self.handler.metrics(req, ctx).await?;

View File

@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use ::auth::UserProviderRef;
use axum::extract::State;
use axum::http::{self, Request, StatusCode};
@@ -68,7 +66,7 @@ pub async fn inner_auth<B>(
.current_schema(schema.clone())
.timezone(timezone);
let query_ctx = Arc::new(query_ctx_builder.build());
let query_ctx = query_ctx_builder.build();
let need_auth = need_auth(&req);
// 2. check if auth is needed

View File

@@ -32,7 +32,7 @@ use pipeline::{PipelineVersion, Value as PipelineValue};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::{Deserializer, Value};
use session::context::QueryContextRef;
use session::context::{Channel, QueryContext, QueryContextRef};
use snafu::{ensure, OptionExt, ResultExt};
use crate::error::{
@@ -107,7 +107,7 @@ where
pub async fn add_pipeline(
State(state): State<LogState>,
Path(pipeline_name): Path<String>,
Extension(query_ctx): Extension<QueryContextRef>,
Extension(mut query_ctx): Extension<QueryContext>,
PipelineContent(payload): PipelineContent,
) -> Result<GreptimedbManageResponse> {
let start = Instant::now();
@@ -126,6 +126,9 @@ pub async fn add_pipeline(
.build());
}
query_ctx.set_channel(Channel::Http);
let query_ctx = Arc::new(query_ctx);
let content_type = "yaml";
let result = handler
.insert_pipeline(&pipeline_name, content_type, &payload, query_ctx)
@@ -148,7 +151,7 @@ pub async fn add_pipeline(
#[axum_macros::debug_handler]
pub async fn delete_pipeline(
State(state): State<LogState>,
Extension(query_ctx): Extension<QueryContextRef>,
Extension(mut query_ctx): Extension<QueryContext>,
Query(query_params): Query<LogIngesterQueryParams>,
Path(pipeline_name): Path<String>,
) -> Result<GreptimedbManageResponse> {
@@ -167,6 +170,9 @@ pub async fn delete_pipeline(
let version = to_pipeline_version(Some(version_str.clone())).context(PipelineSnafu)?;
query_ctx.set_channel(Channel::Http);
let query_ctx = Arc::new(query_ctx);
handler
.delete_pipeline(&pipeline_name, version, query_ctx)
.await
@@ -231,7 +237,7 @@ fn transform_ndjson_array_factory(
pub async fn log_ingester(
State(log_state): State<LogState>,
Query(query_params): Query<LogIngesterQueryParams>,
Extension(query_ctx): Extension<QueryContextRef>,
Extension(mut query_ctx): Extension<QueryContext>,
TypedHeader(content_type): TypedHeader<ContentType>,
payload: String,
) -> Result<HttpResponse> {
@@ -256,6 +262,9 @@ pub async fn log_ingester(
let value = extract_pipeline_value_by_content_type(content_type, payload, ignore_errors)?;
query_ctx.set_channel(Channel::Http);
let query_ctx = Arc::new(query_ctx);
ingest_logs_inner(
handler,
pipeline_name,

View File

@@ -13,6 +13,7 @@
// limitations under the License.
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use aide::transform::TransformOperation;
@@ -29,7 +30,7 @@ use query::parser::{PromQuery, DEFAULT_LOOKBACK_STRING};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use session::context::QueryContextRef;
use session::context::{Channel, QueryContext, QueryContextRef};
use super::header::collect_plan_metrics;
use crate::http::arrow_result::ArrowResponse;
@@ -70,13 +71,16 @@ pub struct SqlQuery {
pub async fn sql(
State(state): State<ApiState>,
Query(query_params): Query<SqlQuery>,
Extension(query_ctx): Extension<QueryContextRef>,
Extension(mut query_ctx): Extension<QueryContext>,
Form(form_params): Form<SqlQuery>,
) -> HttpResponse {
let start = Instant::now();
let sql_handler = &state.sql_handler;
let db = query_ctx.get_db_string();
query_ctx.set_channel(Channel::Http);
let query_ctx = Arc::new(query_ctx);
let _timer = crate::metrics::METRIC_HTTP_SQL_ELAPSED
.with_label_values(&[db.as_str()])
.start_timer();
@@ -232,12 +236,15 @@ impl From<PromqlQuery> for PromQuery {
pub async fn promql(
State(state): State<ApiState>,
Query(params): Query<PromqlQuery>,
Extension(query_ctx): Extension<QueryContextRef>,
Extension(mut query_ctx): Extension<QueryContext>,
) -> Response {
let sql_handler = &state.sql_handler;
let exec_start = Instant::now();
let db = query_ctx.get_db_string();
query_ctx.set_channel(Channel::Http);
let query_ctx = Arc::new(query_ctx);
let _timer = crate::metrics::METRIC_HTTP_PROMQL_ELAPSED
.with_label_values(&[db.as_str()])
.start_timer();

View File

@@ -13,6 +13,7 @@
// limitations under the License.
use std::collections::HashMap;
use std::sync::Arc;
use axum::extract::{Query, State};
use axum::http::StatusCode;
@@ -21,7 +22,7 @@ use axum::Extension;
use common_catalog::consts::DEFAULT_SCHEMA_NAME;
use common_grpc::precision::Precision;
use common_telemetry::tracing;
use session::context::QueryContextRef;
use session::context::{Channel, QueryContext, QueryContextRef};
use super::header::write_cost_header_map;
use crate::error::{Result, TimePrecisionSnafu};
@@ -45,12 +46,14 @@ pub async fn influxdb_health() -> Result<impl IntoResponse> {
pub async fn influxdb_write_v1(
State(handler): State<InfluxdbLineProtocolHandlerRef>,
Query(mut params): Query<HashMap<String, String>>,
Extension(query_ctx): Extension<QueryContextRef>,
Extension(mut query_ctx): Extension<QueryContext>,
lines: String,
) -> Result<impl IntoResponse> {
let db = params
.remove("db")
.unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string());
query_ctx.set_channel(Channel::Influx);
let query_ctx = Arc::new(query_ctx);
let precision = params
.get("precision")
@@ -65,7 +68,7 @@ pub async fn influxdb_write_v1(
pub async fn influxdb_write_v2(
State(handler): State<InfluxdbLineProtocolHandlerRef>,
Query(mut params): Query<HashMap<String, String>>,
Extension(query_ctx): Extension<QueryContextRef>,
Extension(mut query_ctx): Extension<QueryContext>,
lines: String,
) -> Result<impl IntoResponse> {
let db = match (params.remove("db"), params.remove("bucket")) {
@@ -73,6 +76,8 @@ pub async fn influxdb_write_v2(
(Some(db), None) => db.clone(),
_ => DEFAULT_SCHEMA_NAME.to_string(),
};
query_ctx.set_channel(Channel::Influx);
let query_ctx = Arc::new(query_ctx);
let precision = params
.get("precision")

View File

@@ -13,6 +13,7 @@
// limitations under the License.
use std::collections::HashMap;
use std::sync::Arc;
use axum::extract::{Query, RawBody, State};
use axum::http::StatusCode as HttpStatusCode;
@@ -20,7 +21,7 @@ use axum::{Extension, Json};
use common_error::ext::ErrorExt;
use hyper::Body;
use serde::{Deserialize, Serialize};
use session::context::QueryContextRef;
use session::context::{Channel, QueryContext};
use snafu::ResultExt;
use crate::error::{self, Result};
@@ -74,7 +75,7 @@ pub enum OpentsdbPutResponse {
pub async fn put(
State(opentsdb_handler): State<OpentsdbProtocolHandlerRef>,
Query(params): Query<HashMap<String, String>>,
Extension(ctx): Extension<QueryContextRef>,
Extension(mut ctx): Extension<QueryContext>,
RawBody(body): RawBody,
) -> Result<(HttpStatusCode, Json<OpentsdbPutResponse>)> {
let summary = params.contains_key("summary");
@@ -86,6 +87,9 @@ pub async fn put(
.map(|point| point.clone().into())
.collect::<Vec<_>>();
ctx.set_channel(Channel::Opentsdb);
let ctx = Arc::new(ctx);
let response = if !summary && !details {
if let Err(e) = opentsdb_handler.exec(data_points, ctx.clone()).await {
// Not debugging purpose, failed fast.

View File

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use axum::extract::{RawBody, State};
use axum::http::header;
use axum::response::IntoResponse;
@@ -25,7 +27,7 @@ use opentelemetry_proto::tonic::collector::trace::v1::{
ExportTraceServiceRequest, ExportTraceServiceResponse,
};
use prost::Message;
use session::context::QueryContextRef;
use session::context::{Channel, QueryContext};
use snafu::prelude::*;
use super::header::{write_cost_header_map, CONTENT_TYPE_PROTOBUF};
@@ -36,10 +38,12 @@ use crate::query_handler::OpenTelemetryProtocolHandlerRef;
#[tracing::instrument(skip_all, fields(protocol = "otlp", request_type = "metrics"))]
pub async fn metrics(
State(handler): State<OpenTelemetryProtocolHandlerRef>,
Extension(query_ctx): Extension<QueryContextRef>,
Extension(mut query_ctx): Extension<QueryContext>,
RawBody(body): RawBody,
) -> Result<OtlpMetricsResponse> {
let db = query_ctx.get_db_string();
query_ctx.set_channel(Channel::Otlp);
let query_ctx = Arc::new(query_ctx);
let _timer = crate::metrics::METRIC_HTTP_OPENTELEMETRY_METRICS_ELAPSED
.with_label_values(&[db.as_str()])
.start_timer();
@@ -83,10 +87,12 @@ impl IntoResponse for OtlpMetricsResponse {
#[tracing::instrument(skip_all, fields(protocol = "otlp", request_type = "traces"))]
pub async fn traces(
State(handler): State<OpenTelemetryProtocolHandlerRef>,
Extension(query_ctx): Extension<QueryContextRef>,
Extension(mut query_ctx): Extension<QueryContext>,
RawBody(body): RawBody,
) -> Result<OtlpTracesResponse> {
let db = query_ctx.get_db_string();
query_ctx.set_channel(Channel::Otlp);
let query_ctx = Arc::new(query_ctx);
let _timer = crate::metrics::METRIC_HTTP_OPENTELEMETRY_TRACES_ELAPSED
.with_label_values(&[db.as_str()])
.start_timer();

View File

@@ -30,7 +30,7 @@ use object_pool::Pool;
use prost::Message;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use session::context::QueryContextRef;
use session::context::{Channel, QueryContext};
use snafu::prelude::*;
use super::header::{write_cost_header_map, GREPTIME_DB_HEADER_METRICS};
@@ -74,7 +74,7 @@ impl Default for RemoteWriteQuery {
pub async fn route_write_without_metric_engine(
handler: State<PromStoreProtocolHandlerRef>,
query: Query<RemoteWriteQuery>,
extension: Extension<QueryContextRef>,
extension: Extension<QueryContext>,
content_encoding: TypedHeader<headers::ContentEncoding>,
raw_body: RawBody,
) -> Result<impl IntoResponse> {
@@ -96,7 +96,7 @@ pub async fn route_write_without_metric_engine(
pub async fn route_write_without_metric_engine_and_strict_mode(
handler: State<PromStoreProtocolHandlerRef>,
query: Query<RemoteWriteQuery>,
extension: Extension<QueryContextRef>,
extension: Extension<QueryContext>,
content_encoding: TypedHeader<headers::ContentEncoding>,
raw_body: RawBody,
) -> Result<impl IntoResponse> {
@@ -120,7 +120,7 @@ pub async fn route_write_without_metric_engine_and_strict_mode(
pub async fn remote_write(
handler: State<PromStoreProtocolHandlerRef>,
query: Query<RemoteWriteQuery>,
extension: Extension<QueryContextRef>,
extension: Extension<QueryContext>,
content_encoding: TypedHeader<headers::ContentEncoding>,
raw_body: RawBody,
) -> Result<impl IntoResponse> {
@@ -144,7 +144,7 @@ pub async fn remote_write(
pub async fn remote_write_without_strict_mode(
handler: State<PromStoreProtocolHandlerRef>,
query: Query<RemoteWriteQuery>,
extension: Extension<QueryContextRef>,
extension: Extension<QueryContext>,
content_encoding: TypedHeader<headers::ContentEncoding>,
raw_body: RawBody,
) -> Result<impl IntoResponse> {
@@ -163,7 +163,7 @@ pub async fn remote_write_without_strict_mode(
async fn remote_write_impl(
State(handler): State<PromStoreProtocolHandlerRef>,
Query(params): Query<RemoteWriteQuery>,
Extension(mut query_ctx): Extension<QueryContextRef>,
Extension(mut query_ctx): Extension<QueryContext>,
content_encoding: TypedHeader<headers::ContentEncoding>,
RawBody(body): RawBody,
is_strict_mode: bool,
@@ -175,6 +175,7 @@ async fn remote_write_impl(
}
let db = params.db.clone().unwrap_or_default();
query_ctx.set_channel(Channel::Prometheus);
let _timer = crate::metrics::METRIC_HTTP_PROM_STORE_WRITE_ELAPSED
.with_label_values(&[db.as_str()])
.start_timer();
@@ -183,10 +184,9 @@ async fn remote_write_impl(
let (request, samples) = decode_remote_write_request(is_zstd, body, is_strict_mode).await?;
if let Some(physical_table) = params.physical_table {
let mut new_query_ctx = query_ctx.as_ref().clone();
new_query_ctx.set_extension(PHYSICAL_TABLE_PARAM, physical_table);
query_ctx = Arc::new(new_query_ctx);
query_ctx.set_extension(PHYSICAL_TABLE_PARAM, physical_table);
}
let query_ctx = Arc::new(query_ctx);
let output = handler.write(request, query_ctx, is_metric_engine).await?;
crate::metrics::PROM_STORE_REMOTE_WRITE_SAMPLES.inc_by(samples as u64);
@@ -224,10 +224,12 @@ impl IntoResponse for PromStoreResponse {
pub async fn remote_read(
State(handler): State<PromStoreProtocolHandlerRef>,
Query(params): Query<RemoteWriteQuery>,
Extension(query_ctx): Extension<QueryContextRef>,
Extension(mut query_ctx): Extension<QueryContext>,
RawBody(body): RawBody,
) -> Result<PromStoreResponse> {
let db = params.db.clone().unwrap_or_default();
query_ctx.set_channel(Channel::Prometheus);
let query_ctx = Arc::new(query_ctx);
let _timer = crate::metrics::METRIC_HTTP_PROM_STORE_READ_ELAPSED
.with_label_values(&[db.as_str()])
.start_timer();

View File

@@ -42,7 +42,7 @@ use schemars::JsonSchema;
use serde::de::{self, MapAccess, Visitor};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use session::context::{QueryContext, QueryContextBuilder, QueryContextRef};
use session::context::QueryContext;
use snafu::{Location, OptionExt, ResultExt};
pub use super::prometheus_resp::PrometheusJsonResponse;
@@ -122,7 +122,7 @@ pub struct FormatQuery {
pub async fn format_query(
State(_handler): State<PrometheusHandlerRef>,
Query(params): Query<InstantQuery>,
Extension(_query_ctx): Extension<QueryContextRef>,
Extension(_query_ctx): Extension<QueryContext>,
Form(form_params): Form<InstantQuery>,
) -> PrometheusJsonResponse {
let query = params.query.or(form_params.query).unwrap_or_default();
@@ -168,7 +168,7 @@ pub struct InstantQuery {
pub async fn instant_query(
State(handler): State<PrometheusHandlerRef>,
Query(params): Query<InstantQuery>,
Extension(mut query_ctx): Extension<QueryContextRef>,
Extension(mut query_ctx): Extension<QueryContext>,
Form(form_params): Form<InstantQuery>,
) -> PrometheusJsonResponse {
// Extract time from query string, or use current server time if not specified.
@@ -190,8 +190,9 @@ pub async fn instant_query(
// update catalog and schema in query context if necessary
if let Some(db) = &params.db {
let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
query_ctx = try_update_catalog_schema(query_ctx, &catalog, &schema);
try_update_catalog_schema(&mut query_ctx, &catalog, &schema);
}
let query_ctx = Arc::new(query_ctx);
let _timer = crate::metrics::METRIC_HTTP_PROMETHEUS_PROMQL_ELAPSED
.with_label_values(&[query_ctx.get_db_string().as_str(), "instant_query"])
@@ -226,7 +227,7 @@ pub struct RangeQuery {
pub async fn range_query(
State(handler): State<PrometheusHandlerRef>,
Query(params): Query<RangeQuery>,
Extension(mut query_ctx): Extension<QueryContextRef>,
Extension(mut query_ctx): Extension<QueryContext>,
Form(form_params): Form<RangeQuery>,
) -> PrometheusJsonResponse {
let prom_query = PromQuery {
@@ -243,8 +244,9 @@ pub async fn range_query(
// update catalog and schema in query context if necessary
if let Some(db) = &params.db {
let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
query_ctx = try_update_catalog_schema(query_ctx, &catalog, &schema);
try_update_catalog_schema(&mut query_ctx, &catalog, &schema);
}
let query_ctx = Arc::new(query_ctx);
let _timer = crate::metrics::METRIC_HTTP_PROMETHEUS_PROMQL_ELAPSED
.with_label_values(&[query_ctx.get_db_string().as_str(), "range_query"])
@@ -313,11 +315,12 @@ impl<'de> Deserialize<'de> for Matches {
pub async fn labels_query(
State(handler): State<PrometheusHandlerRef>,
Query(params): Query<LabelsQuery>,
Extension(query_ctx): Extension<QueryContextRef>,
Extension(mut query_ctx): Extension<QueryContext>,
Form(form_params): Form<LabelsQuery>,
) -> PrometheusJsonResponse {
let (catalog, schema) = get_catalog_schema(&params.db, &query_ctx);
let query_ctx = try_update_catalog_schema(query_ctx, &catalog, &schema);
try_update_catalog_schema(&mut query_ctx, &catalog, &schema);
let query_ctx = Arc::new(query_ctx);
let mut queries = params.matches.0;
if queries.is_empty() {
@@ -625,20 +628,10 @@ pub(crate) fn get_catalog_schema(db: &Option<String>, ctx: &QueryContext) -> (St
}
/// Update catalog and schema in [QueryContext] if necessary.
pub(crate) fn try_update_catalog_schema(
ctx: QueryContextRef,
catalog: &str,
schema: &str,
) -> QueryContextRef {
pub(crate) fn try_update_catalog_schema(ctx: &mut QueryContext, catalog: &str, schema: &str) {
if ctx.current_catalog() != catalog || ctx.current_schema() != schema {
Arc::new(
QueryContextBuilder::from_existing(&ctx)
.current_catalog(catalog.to_string())
.current_schema(schema.to_string())
.build(),
)
} else {
ctx
ctx.set_current_catalog(catalog);
ctx.set_current_schema(schema);
}
}
@@ -693,11 +686,12 @@ pub struct LabelValueQuery {
pub async fn label_values_query(
State(handler): State<PrometheusHandlerRef>,
Path(label_name): Path<String>,
Extension(query_ctx): Extension<QueryContextRef>,
Extension(mut query_ctx): Extension<QueryContext>,
Query(params): Query<LabelValueQuery>,
) -> PrometheusJsonResponse {
let (catalog, schema) = get_catalog_schema(&params.db, &query_ctx);
let query_ctx = try_update_catalog_schema(query_ctx, &catalog, &schema);
try_update_catalog_schema(&mut query_ctx, &catalog, &schema);
let query_ctx = Arc::new(query_ctx);
let _timer = crate::metrics::METRIC_HTTP_PROMETHEUS_PROMQL_ELAPSED
.with_label_values(&[query_ctx.get_db_string().as_str(), "label_values_query"])
@@ -955,7 +949,7 @@ pub struct SeriesQuery {
pub async fn series_query(
State(handler): State<PrometheusHandlerRef>,
Query(params): Query<SeriesQuery>,
Extension(mut query_ctx): Extension<QueryContextRef>,
Extension(mut query_ctx): Extension<QueryContext>,
Form(form_params): Form<SeriesQuery>,
) -> PrometheusJsonResponse {
let mut queries: Vec<String> = params.matches.0;
@@ -981,8 +975,9 @@ pub async fn series_query(
// update catalog and schema in query context if necessary
if let Some(db) = &params.db {
let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
query_ctx = try_update_catalog_schema(query_ctx, &catalog, &schema);
try_update_catalog_schema(&mut query_ctx, &catalog, &schema);
}
let query_ctx = Arc::new(query_ctx);
let _timer = crate::metrics::METRIC_HTTP_PROMETHEUS_PROMQL_ELAPSED
.with_label_values(&[query_ctx.get_db_string().as_str(), "series_query"])

View File

@@ -20,14 +20,14 @@ use axum::http;
use http_body::Body;
use hyper::{Request, StatusCode};
use servers::http::authorize::inner_auth;
use session::context::QueryContextRef;
use session::context::QueryContext;
#[tokio::test]
async fn test_http_auth() {
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
let req = inner_auth(None, req).await.unwrap();
let ctx: &QueryContextRef = req.extensions().get().unwrap();
let ctx: &QueryContext = req.extensions().get().unwrap();
let user_info = ctx.current_user();
let default = auth::userinfo_by_name(None);
assert_eq!(default.username(), user_info.username());
@@ -38,7 +38,7 @@ async fn test_http_auth() {
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
let req = mock_http_request(Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), None).unwrap();
let req = inner_auth(mock_user_provider.clone(), req).await.unwrap();
let ctx: &QueryContextRef = req.extensions().get().unwrap();
let ctx: &QueryContext = req.extensions().get().unwrap();
let user_info = ctx.current_user();
let default = auth::userinfo_by_name(None);
assert_eq!(default.username(), user_info.username());
@@ -79,7 +79,7 @@ async fn test_schema_validating() {
)
.unwrap();
let req = inner_auth(mock_user_provider.clone(), req).await.unwrap();
let ctx: &QueryContextRef = req.extensions().get().unwrap();
let ctx: &QueryContext = req.extensions().get().unwrap();
let user_info = ctx.current_user();
let default = auth::userinfo_by_name(None);
assert_eq!(default.username(), user_info.username());

View File

@@ -40,7 +40,7 @@ use crate::{
#[tokio::test]
async fn test_sql_not_provided() {
let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
let ctx = QueryContext::arc();
let ctx = QueryContext::with_db_name(None);
ctx.set_current_user(auth::userinfo_by_name(None));
let api_state = ApiState {
sql_handler,
@@ -74,7 +74,7 @@ async fn test_sql_output_rows() {
let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
let ctx = QueryContext::arc();
let ctx = QueryContext::with_db_name(None);
ctx.set_current_user(auth::userinfo_by_name(None));
let api_state = ApiState {
sql_handler,
@@ -180,7 +180,7 @@ async fn test_sql_output_rows() {
#[tokio::test]
async fn test_dashboard_sql_limit() {
let sql_handler = create_testing_sql_query_handler(MemTable::specified_numbers_table(2000));
let ctx = QueryContext::arc();
let ctx = QueryContext::with_db_name(None);
ctx.set_current_user(auth::userinfo_by_name(None));
let api_state = ApiState {
sql_handler,
@@ -226,7 +226,7 @@ async fn test_sql_form() {
let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
let ctx = QueryContext::arc();
let ctx = QueryContext::with_db_name(None);
ctx.set_current_user(auth::userinfo_by_name(None));
let api_state = ApiState {
sql_handler,

View File

@@ -20,5 +20,6 @@ common-macro.workspace = true
common-telemetry.workspace = true
common-time.workspace = true
derive_builder.workspace = true
meter-core.workspace = true
snafu.workspace = true
sql.workspace = true

View File

@@ -25,7 +25,7 @@ use common_catalog::{build_db_string, parse_catalog_and_schema_from_db_string};
use common_time::timezone::parse_timezone;
use common_time::Timezone;
use derive_builder::Builder;
use sql::dialect::{Dialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect};
use sql::dialect::{Dialect, GenericDialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect};
use crate::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
use crate::MutableInner;
@@ -47,6 +47,9 @@ pub struct QueryContext {
// The configuration parameter are used to store the parameters that are set by the user
#[builder(default)]
configuration_parameter: Arc<ConfigurationVariables>,
// Track which protocol the query comes from.
#[builder(default)]
channel: Channel,
}
impl Display for QueryContext {
@@ -94,7 +97,8 @@ impl From<&RegionRequestHeader> for QueryContext {
.current_catalog(ctx.current_catalog.clone())
.current_schema(ctx.current_schema.clone())
.timezone(parse_timezone(Some(&ctx.timezone)))
.extensions(ctx.extensions.clone());
.extensions(ctx.extensions.clone())
.channel(ctx.channel.into());
}
builder.build()
}
@@ -107,6 +111,7 @@ impl From<api::v1::QueryContext> for QueryContext {
.current_schema(ctx.current_schema)
.timezone(parse_timezone(Some(&ctx.timezone)))
.extensions(ctx.extensions)
.channel(ctx.channel.into())
.build()
}
}
@@ -117,6 +122,7 @@ impl From<QueryContext> for api::v1::QueryContext {
current_catalog,
mutable_inner,
extensions,
channel,
..
}: QueryContext,
) -> Self {
@@ -126,6 +132,7 @@ impl From<QueryContext> for api::v1::QueryContext {
current_schema: mutable_inner.schema.clone(),
timezone: mutable_inner.timezone.to_string(),
extensions,
channel: channel as u32,
}
}
}
@@ -142,6 +149,14 @@ impl QueryContext {
.build()
}
pub fn with_channel(catalog: &str, schema: &str, channel: Channel) -> QueryContext {
QueryContextBuilder::default()
.current_catalog(catalog.to_string())
.current_schema(schema.to_string())
.channel(channel)
.build()
}
pub fn with_db_name(db_name: Option<&str>) -> QueryContext {
let (catalog, schema) = db_name
.map(|db| {
@@ -172,6 +187,10 @@ impl QueryContext {
&self.current_catalog
}
pub fn set_current_catalog(&mut self, new_catalog: &str) {
self.current_catalog = new_catalog.to_string();
}
pub fn sql_dialect(&self) -> &(dyn Dialect + Send + Sync) {
&*self.sql_dialect
}
@@ -224,6 +243,14 @@ impl QueryContext {
pub fn configuration_parameter(&self) -> &ConfigurationVariables {
&self.configuration_parameter
}
pub fn channel(&self) -> Channel {
self.channel
}
pub fn set_channel(&mut self, channel: Channel) {
self.channel = channel;
}
}
impl QueryContextBuilder {
@@ -238,6 +265,7 @@ impl QueryContextBuilder {
.unwrap_or_else(|| Arc::new(GreptimeDbDialect {})),
extensions: self.extensions.unwrap_or_default(),
configuration_parameter: self.configuration_parameter.unwrap_or_default(),
channel: self.channel.unwrap_or_default(),
}
}
@@ -247,17 +275,6 @@ impl QueryContextBuilder {
.insert(key, value);
self
}
pub fn from_existing(context: &QueryContext) -> QueryContextBuilder {
QueryContextBuilder {
current_catalog: Some(context.current_catalog.clone()),
// note that this is a shallow copy
mutable_inner: Some(context.mutable_inner.clone()),
sql_dialect: Some(context.sql_dialect.clone()),
extensions: Some(context.extensions.clone()),
configuration_parameter: Some(context.configuration_parameter.clone()),
}
}
}
#[derive(Debug)]
@@ -289,10 +306,37 @@ impl ConnInfo {
}
}
#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Default, Clone, Copy)]
#[repr(u8)]
pub enum Channel {
Mysql,
Postgres,
#[default]
Unknown = 0,
Mysql = 1,
Postgres = 2,
Http = 3,
Prometheus = 4,
Otlp = 5,
Grpc = 6,
Influx = 7,
Opentsdb = 8,
}
impl From<u32> for Channel {
fn from(value: u32) -> Self {
match value {
1 => Self::Mysql,
2 => Self::Postgres,
3 => Self::Http,
4 => Self::Prometheus,
5 => Self::Otlp,
6 => Self::Grpc,
7 => Self::Influx,
8 => Self::Opentsdb,
_ => Self::Unknown,
}
}
}
impl Channel {
@@ -300,6 +344,7 @@ impl Channel {
match self {
Channel::Mysql => Arc::new(MySqlDialect {}),
Channel::Postgres => Arc::new(PostgreSqlDialect {}),
_ => Arc::new(GenericDialect {}),
}
}
}
@@ -309,6 +354,13 @@ impl Display for Channel {
match self {
Channel::Mysql => write!(f, "mysql"),
Channel::Postgres => write!(f, "postgres"),
Channel::Http => write!(f, "http"),
Channel::Prometheus => write!(f, "prometheus"),
Channel::Otlp => write!(f, "otlp"),
Channel::Grpc => write!(f, "grpc"),
Channel::Influx => write!(f, "influx"),
Channel::Opentsdb => write!(f, "opentsdb"),
Channel::Unknown => write!(f, "unknown"),
}
}
}

View File

@@ -79,6 +79,7 @@ impl Session {
.mutable_inner(self.mutable_inner.clone())
.sql_dialect(self.conn_info.channel.dialect())
.configuration_parameter(self.configuration_variables.clone())
.channel(self.conn_info.channel)
.build()
.into()
}

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
pub use sqlparser::dialect::{Dialect, MySqlDialect, PostgreSqlDialect};
pub use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect, PostgreSqlDialect};
/// GreptimeDb dialect
#[derive(Debug, Clone)]