From b741a7181bc3a5cb1189e9094d317e841abf998f Mon Sep 17 00:00:00 2001
From: Ning Sun
Date: Wed, 31 Jul 2024 15:30:50 +0800
Subject: [PATCH] feat: track channels with query context and w/rcu (#4448)

* feat: add source channel to meter recorders

* feat: provide channel for query context

* fix: testing and extension get for query context

* chore: revert cargo toml structure changes

* fix: querycontext modification for prometheus and pipeline

* chore: switch git dependency to main branches

* chore: remove TODO

* refactor: rename other to unknown

---------

Co-authored-by: shuiyisong <113876041+shuiyisong@users.noreply.github.com>
---
 Cargo.lock                                  |  7 +-
 Cargo.toml                                  |  6 +-
 src/common/meta/src/rpc/ddl.rs              |  4 +
 src/operator/src/insert.rs                  |  7 +-
 src/query/src/dist_plan/merge_scan.rs       |  5 +-
 src/servers/src/grpc/authorize.rs           |  9 +--
 src/servers/src/grpc/otlp.rs                | 15 ++--
 src/servers/src/http/authorize.rs           |  4 +-
 src/servers/src/http/event.rs               | 17 ++++-
 src/servers/src/http/handler.rs             | 13 +++-
 src/servers/src/http/influxdb.rs            | 11 ++-
 src/servers/src/http/opentsdb.rs            |  8 +-
 src/servers/src/http/otlp.rs                | 12 ++-
 src/servers/src/http/prom_store.rs          | 22 +++---
 src/servers/src/http/prometheus.rs          | 45 +++++------
 src/servers/tests/http/authorize.rs         |  8 +-
 src/servers/tests/http/http_handler_test.rs |  8 +-
 src/session/Cargo.toml                      |  1 +
 src/session/src/context.rs                  | 84 +++++++++++++++++----
 src/session/src/lib.rs                      |  1 +
 src/sql/src/dialect.rs                      |  2 +-
 21 files changed, 193 insertions(+), 96 deletions(-)
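Reviewer note: the core of this patch is a `Channel` tag carried on `QueryContext`, stamped once at each protocol entry point and then frozen behind an `Arc`. The following is a minimal, self-contained sketch of that pattern; only `Channel`, `set_channel`, and the `Arc`-freeze step mirror the diff below, everything else is illustrative rather than GreptimeDB's actual API.

```rust
use std::sync::Arc;

#[derive(Debug, Clone, Copy, PartialEq, Default)]
#[repr(u8)]
enum Channel {
    #[default]
    Unknown = 0,
    Mysql = 1,
    Http = 3,
}

#[derive(Debug, Default, Clone)]
struct QueryContext {
    channel: Channel,
}

impl QueryContext {
    fn set_channel(&mut self, channel: Channel) {
        self.channel = channel;
    }
}

// Each protocol entry point stamps the context while it is still owned and
// mutable, then freezes it for the rest of the (shared) query pipeline.
fn http_entry(mut ctx: QueryContext) -> Arc<QueryContext> {
    ctx.set_channel(Channel::Http);
    Arc::new(ctx)
}

fn main() {
    let ctx = http_entry(QueryContext::default());
    assert_eq!(ctx.channel, Channel::Http);
}
```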
"c437b55725b7f5224fe9d46db21072b4a682ee4b" } humantime = "2.1" humantime-serde = "1.1" itertools = "0.10" lazy_static = "1.4" -meter-core = { git = "https://github.com/GreptimeTeam/greptime-meter.git", rev = "80b72716dcde47ec4161478416a5c6c21343364d" } +meter-core = { git = "https://github.com/GreptimeTeam/greptime-meter.git", rev = "049171eb16cb4249d8099751a0c46750d1fe88e7" } mockall = "0.11.4" moka = "0.12" notify = "6.1" @@ -238,7 +238,7 @@ table = { path = "src/table" } [workspace.dependencies.meter-macros] git = "https://github.com/GreptimeTeam/greptime-meter.git" -rev = "80b72716dcde47ec4161478416a5c6c21343364d" +rev = "049171eb16cb4249d8099751a0c46750d1fe88e7" [profile.release] debug = 1 diff --git a/src/common/meta/src/rpc/ddl.rs b/src/common/meta/src/rpc/ddl.rs index b5cf693c9a..ea14a1a5af 100644 --- a/src/common/meta/src/rpc/ddl.rs +++ b/src/common/meta/src/rpc/ddl.rs @@ -1079,6 +1079,7 @@ pub struct QueryContext { current_schema: String, timezone: String, extensions: HashMap, + channel: u8, } impl From for QueryContext { @@ -1088,6 +1089,7 @@ impl From for QueryContext { current_schema: query_context.current_schema().to_string(), timezone: query_context.timezone().to_string(), extensions: query_context.extensions(), + channel: query_context.channel() as u8, } } } @@ -1099,6 +1101,7 @@ impl From for PbQueryContext { current_schema, timezone, extensions, + channel, }: QueryContext, ) -> Self { PbQueryContext { @@ -1106,6 +1109,7 @@ impl From for PbQueryContext { current_schema, timezone, extensions, + channel: channel as u32, } } } diff --git a/src/operator/src/insert.rs b/src/operator/src/insert.rs index 4abda29c2e..556a58c49e 100644 --- a/src/operator/src/insert.rs +++ b/src/operator/src/insert.rs @@ -270,7 +270,12 @@ impl Inserter { requests: RegionInsertRequests, ctx: &QueryContextRef, ) -> Result { - let write_cost = write_meter!(ctx.current_catalog(), ctx.current_schema(), requests); + let write_cost = write_meter!( + ctx.current_catalog(), + ctx.current_schema(), + requests, + ctx.channel() as u8 + ); let request_factory = RegionRequestFactory::new(RegionRequestHeader { tracing_context: TracingContext::from_current_span().to_w3c(), dbname: ctx.get_db_string(), diff --git a/src/query/src/dist_plan/merge_scan.rs b/src/query/src/dist_plan/merge_scan.rs index 8f209a74f7..3bada2533a 100644 --- a/src/query/src/dist_plan/merge_scan.rs +++ b/src/query/src/dist_plan/merge_scan.rs @@ -194,6 +194,7 @@ impl MergeScanExec { let tracing_context = TracingContext::from_json(context.session_id().as_str()); let current_catalog = self.query_ctx.current_catalog().to_string(); let current_schema = self.query_ctx.current_schema().to_string(); + let current_channel = self.query_ctx.channel(); let timezone = self.query_ctx.timezone().to_string(); let extensions = self.query_ctx.extensions(); let target_partition = self.target_partition; @@ -221,6 +222,7 @@ impl MergeScanExec { current_schema: current_schema.clone(), timezone: timezone.clone(), extensions: extensions.clone(), + channel: current_channel as u32, }), }), region_id, @@ -271,7 +273,8 @@ impl MergeScanExec { ReadItem { cpu_time: metrics.elapsed_compute as u64, table_scan: metrics.memory_usage as u64 - } + }, + current_channel as u8 ); metric.record_greptime_exec_cost(value as usize); diff --git a/src/servers/src/grpc/authorize.rs b/src/servers/src/grpc/authorize.rs index b1abfcaf42..c9c7644e00 100644 --- a/src/servers/src/grpc/authorize.rs +++ b/src/servers/src/grpc/authorize.rs @@ -14,12 +14,11 @@ use std::pin::Pin; use 
diff --git a/src/servers/src/grpc/authorize.rs b/src/servers/src/grpc/authorize.rs
index b1abfcaf42..c9c7644e00 100644
--- a/src/servers/src/grpc/authorize.rs
+++ b/src/servers/src/grpc/authorize.rs
@@ -14,12 +14,11 @@
 
 use std::pin::Pin;
 use std::result::Result as StdResult;
-use std::sync::Arc;
 use std::task::{Context, Poll};
 
 use auth::UserProviderRef;
 use hyper::Body;
-use session::context::QueryContext;
+use session::context::{Channel, QueryContext};
 use tonic::body::BoxBody;
 use tonic::server::NamedService;
 use tower::{Layer, Service};
@@ -105,7 +104,7 @@ async fn do_auth(
 ) -> Result<(), tonic::Status> {
     let (catalog, schema) = extract_catalog_and_schema(req);
 
-    let query_ctx = Arc::new(QueryContext::with(&catalog, &schema));
+    let query_ctx = QueryContext::with_channel(&catalog, &schema, Channel::Grpc);
 
     let Some(user_provider) = user_provider else {
         query_ctx.set_current_user(auth::userinfo_by_name(None));
@@ -139,7 +138,7 @@ mod tests {
     use base64::Engine;
     use headers::Header;
     use hyper::{Body, Request};
-    use session::context::QueryContextRef;
+    use session::context::QueryContext;
 
     use crate::grpc::authorize::do_auth;
     use crate::http::header::GreptimeDbName;
@@ -197,7 +196,7 @@
         expected_schema: &str,
         expected_user_name: &str,
     ) {
-        let ctx = req.extensions().get::<QueryContextRef>().unwrap();
+        let ctx = req.extensions().get::<QueryContext>().unwrap();
 
         assert_eq!(expected_catalog, ctx.current_catalog());
         assert_eq!(expected_schema, ctx.current_schema());
diff --git a/src/servers/src/grpc/otlp.rs b/src/servers/src/grpc/otlp.rs
index c96aed7af1..76992e703f 100644
--- a/src/servers/src/grpc/otlp.rs
+++ b/src/servers/src/grpc/otlp.rs
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 use std::result::Result as StdResult;
+use std::sync::Arc;
 
 use opentelemetry_proto::tonic::collector::metrics::v1::metrics_service_server::MetricsService;
 use opentelemetry_proto::tonic::collector::metrics::v1::{
@@ -22,7 +23,7 @@ use opentelemetry_proto::tonic::collector::trace::v1::trace_service_server::Trac
 use opentelemetry_proto::tonic::collector::trace::v1::{
     ExportTraceServiceRequest, ExportTraceServiceResponse,
 };
-use session::context::QueryContextRef;
+use session::context::{Channel, QueryContext};
 use snafu::OptionExt;
 use tonic::{Request, Response, Status};
 
@@ -47,10 +48,12 @@ impl TraceService for OtlpService {
     ) -> StdResult<Response<ExportTraceServiceResponse>, Status> {
         let (_headers, extensions, req) = request.into_parts();
 
-        let ctx = extensions
-            .get::<QueryContextRef>()
+        let mut ctx = extensions
+            .get::<QueryContext>()
             .cloned()
             .context(error::MissingQueryContextSnafu)?;
+        ctx.set_channel(Channel::Otlp);
+        let ctx = Arc::new(ctx);
 
         let _ = self.handler.traces(req, ctx).await?;
 
@@ -68,10 +71,12 @@ impl MetricsService for OtlpService {
     ) -> StdResult<Response<ExportMetricsServiceResponse>, Status> {
         let (_headers, extensions, req) = request.into_parts();
 
-        let ctx = extensions
-            .get::<QueryContextRef>()
+        let mut ctx = extensions
+            .get::<QueryContext>()
             .cloned()
             .context(error::MissingQueryContextSnafu)?;
+        ctx.set_channel(Channel::Otlp);
+        let ctx = Arc::new(ctx);
 
         let _ = self.handler.metrics(req, ctx).await?;
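Reviewer note: the gRPC OTLP services above now store a plain `QueryContext` in the request extensions instead of an `Arc`'d one, so each handler can clone it out, mutate it, and freeze it. A minimal sketch of that clone-mutate-freeze step, with a stand-in `Extensions` type (tonic's `Extensions::get` similarly hands out shared references):

```rust
use std::sync::Arc;

#[derive(Clone, Debug, Default)]
struct QueryContext {
    channel: u8, // 0 = Unknown, 5 = Otlp in the enum this patch adds
}

// Stand-in for request extensions that only hand out shared references.
struct Extensions {
    ctx: QueryContext,
}

impl Extensions {
    fn get(&self) -> Option<&QueryContext> {
        Some(&self.ctx)
    }
}

fn handle(extensions: &Extensions) -> Arc<QueryContext> {
    // `.cloned()` turns the shared reference into an owned value we may mutate...
    let mut ctx = extensions.get().cloned().expect("missing query context");
    ctx.channel = 5; // ...stamp the protocol...
    Arc::new(ctx) // ...then freeze it for the async handler chain.
}

fn main() {
    let ext = Extensions { ctx: QueryContext::default() };
    assert_eq!(handle(&ext).channel, 5);
}
```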
diff --git a/src/servers/src/http/authorize.rs b/src/servers/src/http/authorize.rs
index d2471c851f..7578898c54 100644
--- a/src/servers/src/http/authorize.rs
+++ b/src/servers/src/http/authorize.rs
@@ -12,8 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use std::sync::Arc;
-
 use ::auth::UserProviderRef;
 use axum::extract::State;
 use axum::http::{self, Request, StatusCode};
@@ -68,7 +66,7 @@ pub async fn inner_auth(
         .current_schema(schema.clone())
         .timezone(timezone);
 
-    let query_ctx = Arc::new(query_ctx_builder.build());
+    let query_ctx = query_ctx_builder.build();
     let need_auth = need_auth(&req);
 
     // 2. check if auth is needed
diff --git a/src/servers/src/http/event.rs b/src/servers/src/http/event.rs
index 7be645b364..9381c1b7d1 100644
--- a/src/servers/src/http/event.rs
+++ b/src/servers/src/http/event.rs
@@ -32,7 +32,7 @@ use pipeline::{PipelineVersion, Value as PipelineValue};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use serde_json::{Deserializer, Value};
-use session::context::QueryContextRef;
+use session::context::{Channel, QueryContext, QueryContextRef};
 use snafu::{ensure, OptionExt, ResultExt};
 
 use crate::error::{
@@ -107,7 +107,7 @@ where
 pub async fn add_pipeline(
     State(state): State<LogState>,
     Path(pipeline_name): Path<String>,
-    Extension(query_ctx): Extension<QueryContextRef>,
+    Extension(mut query_ctx): Extension<QueryContext>,
     PipelineContent(payload): PipelineContent,
 ) -> Result<impl IntoResponse> {
     let start = Instant::now();
@@ -126,6 +126,9 @@ pub async fn add_pipeline(
             .build());
     }
 
+    query_ctx.set_channel(Channel::Http);
+    let query_ctx = Arc::new(query_ctx);
+
     let content_type = "yaml";
     let result = handler
         .insert_pipeline(&pipeline_name, content_type, &payload, query_ctx)
@@ -148,7 +151,7 @@ pub async fn add_pipeline(
 #[axum_macros::debug_handler]
 pub async fn delete_pipeline(
     State(state): State<LogState>,
-    Extension(query_ctx): Extension<QueryContextRef>,
+    Extension(mut query_ctx): Extension<QueryContext>,
     Query(query_params): Query<LogIngesterQueryParams>,
     Path(pipeline_name): Path<String>,
 ) -> Result<impl IntoResponse> {
@@ -167,6 +170,9 @@ pub async fn delete_pipeline(
 
     let version = to_pipeline_version(Some(version_str.clone())).context(PipelineSnafu)?;
 
+    query_ctx.set_channel(Channel::Http);
+    let query_ctx = Arc::new(query_ctx);
+
     handler
         .delete_pipeline(&pipeline_name, version, query_ctx)
         .await
@@ -231,7 +237,7 @@ fn transform_ndjson_array_factory(
 pub async fn log_ingester(
     State(log_state): State<LogState>,
     Query(query_params): Query<LogIngesterQueryParams>,
-    Extension(query_ctx): Extension<QueryContextRef>,
+    Extension(mut query_ctx): Extension<QueryContext>,
     TypedHeader(content_type): TypedHeader<ContentType>,
     payload: String,
 ) -> Result<impl IntoResponse> {
@@ -256,6 +262,9 @@ pub async fn log_ingester(
 
     let value = extract_pipeline_value_by_content_type(content_type, payload, ignore_errors)?;
 
+    query_ctx.set_channel(Channel::Http);
+    let query_ctx = Arc::new(query_ctx);
+
     ingest_logs_inner(
         handler,
         pipeline_name,
diff --git a/src/servers/src/http/handler.rs b/src/servers/src/http/handler.rs
index c5e599dfe9..d1690e79a9 100644
--- a/src/servers/src/http/handler.rs
+++ b/src/servers/src/http/handler.rs
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 use std::collections::HashMap;
+use std::sync::Arc;
 use std::time::Instant;
 
 use aide::transform::TransformOperation;
@@ -29,7 +30,7 @@ use query::parser::{PromQuery, DEFAULT_LOOKBACK_STRING};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
-use session::context::QueryContextRef;
+use session::context::{Channel, QueryContext, QueryContextRef};
 
 use super::header::collect_plan_metrics;
 use crate::http::arrow_result::ArrowResponse;
@@ -70,13 +71,16 @@ pub struct SqlQuery {
 pub async fn sql(
     State(state): State<ApiState>,
     Query(query_params): Query<SqlQuery>,
-    Extension(query_ctx): Extension<QueryContextRef>,
+    Extension(mut query_ctx): Extension<QueryContext>,
     Form(form_params): Form<SqlQuery>,
 ) -> HttpResponse {
     let start = Instant::now();
     let sql_handler = &state.sql_handler;
     let db = query_ctx.get_db_string();
 
+    query_ctx.set_channel(Channel::Http);
+    let query_ctx = Arc::new(query_ctx);
+
     let _timer = crate::metrics::METRIC_HTTP_SQL_ELAPSED
         .with_label_values(&[db.as_str()])
         .start_timer();
@@ -232,12 +236,15 @@ impl From<PromqlQuery> for PromQuery {
 pub async fn promql(
     State(state): State<ApiState>,
     Query(params): Query<PromqlQuery>,
-    Extension(query_ctx): Extension<QueryContextRef>,
+    Extension(mut query_ctx): Extension<QueryContext>,
 ) -> Response {
     let sql_handler = &state.sql_handler;
     let exec_start = Instant::now();
     let db = query_ctx.get_db_string();
 
+    query_ctx.set_channel(Channel::Http);
+    let query_ctx = Arc::new(query_ctx);
+
     let _timer = crate::metrics::METRIC_HTTP_PROMQL_ELAPSED
         .with_label_values(&[db.as_str()])
         .start_timer();
diff --git a/src/servers/src/http/influxdb.rs b/src/servers/src/http/influxdb.rs
index a9f806b3f2..0e27a772f6 100644
--- a/src/servers/src/http/influxdb.rs
+++ b/src/servers/src/http/influxdb.rs
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 use std::collections::HashMap;
+use std::sync::Arc;
 
 use axum::extract::{Query, State};
 use axum::http::StatusCode;
@@ -21,7 +22,7 @@ use axum::Extension;
 use common_catalog::consts::DEFAULT_SCHEMA_NAME;
 use common_grpc::precision::Precision;
 use common_telemetry::tracing;
-use session::context::QueryContextRef;
+use session::context::{Channel, QueryContext, QueryContextRef};
 
 use super::header::write_cost_header_map;
 use crate::error::{Result, TimePrecisionSnafu};
@@ -45,12 +46,14 @@ pub async fn influxdb_health() -> Result<impl IntoResponse> {
 pub async fn influxdb_write_v1(
     State(handler): State<InfluxdbLineProtocolHandlerRef>,
     Query(mut params): Query<HashMap<String, String>>,
-    Extension(query_ctx): Extension<QueryContextRef>,
+    Extension(mut query_ctx): Extension<QueryContext>,
     lines: String,
 ) -> Result<impl IntoResponse> {
     let db = params
         .remove("db")
         .unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string());
+    query_ctx.set_channel(Channel::Influx);
+    let query_ctx = Arc::new(query_ctx);
 
     let precision = params
         .get("precision")
@@ -65,7 +68,7 @@
 pub async fn influxdb_write_v2(
     State(handler): State<InfluxdbLineProtocolHandlerRef>,
     Query(mut params): Query<HashMap<String, String>>,
-    Extension(query_ctx): Extension<QueryContextRef>,
+    Extension(mut query_ctx): Extension<QueryContext>,
     lines: String,
 ) -> Result<impl IntoResponse> {
     let db = match (params.remove("db"), params.remove("bucket")) {
         (_, Some(bucket)) => bucket.clone(),
         (Some(db), None) => db.clone(),
         _ => DEFAULT_SCHEMA_NAME.to_string(),
     };
+    query_ctx.set_channel(Channel::Influx);
+    let query_ctx = Arc::new(query_ctx);
 
     let precision = params
         .get("precision")
diff --git a/src/servers/src/http/opentsdb.rs b/src/servers/src/http/opentsdb.rs
index a963fe81ef..17faba7c7d 100644
--- a/src/servers/src/http/opentsdb.rs
+++ b/src/servers/src/http/opentsdb.rs
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 use std::collections::HashMap;
+use std::sync::Arc;
 
 use axum::extract::{Query, RawBody, State};
 use axum::http::StatusCode as HttpStatusCode;
@@ -20,7 +21,7 @@ use axum::{Extension, Json};
 use common_error::ext::ErrorExt;
 use hyper::Body;
 use serde::{Deserialize, Serialize};
-use session::context::QueryContextRef;
+use session::context::{Channel, QueryContext};
 use snafu::ResultExt;
 
 use crate::error::{self, Result};
@@ -74,7 +75,7 @@ pub enum OpentsdbPutResponse {
 pub async fn put(
     State(opentsdb_handler): State<OpentsdbProtocolHandlerRef>,
     Query(params): Query<HashMap<String, String>>,
-    Extension(ctx): Extension<QueryContextRef>,
+    Extension(mut ctx): Extension<QueryContext>,
     RawBody(body): RawBody,
 ) -> Result<(HttpStatusCode, Json<OpentsdbPutResponse>)> {
     let summary = params.contains_key("summary");
@@ -86,6 +87,9 @@ pub async fn put(
         .map(|point| point.clone().into())
         .collect::<Vec<_>>();
 
+    ctx.set_channel(Channel::Opentsdb);
+    let ctx = Arc::new(ctx);
+
     let response = if !summary && !details {
         if let Err(e) = opentsdb_handler.exec(data_points, ctx.clone()).await {
             // Not debugging purpose, failed fast.
diff --git a/src/servers/src/http/otlp.rs b/src/servers/src/http/otlp.rs
index 17f98a3915..a04d1d42a0 100644
--- a/src/servers/src/http/otlp.rs
+++ b/src/servers/src/http/otlp.rs
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::sync::Arc;
+
 use axum::extract::{RawBody, State};
 use axum::http::header;
 use axum::response::IntoResponse;
@@ -25,7 +27,7 @@ use opentelemetry_proto::tonic::collector::trace::v1::{
     ExportTraceServiceRequest, ExportTraceServiceResponse,
 };
 use prost::Message;
-use session::context::QueryContextRef;
+use session::context::{Channel, QueryContext};
 use snafu::prelude::*;
 
 use super::header::{write_cost_header_map, CONTENT_TYPE_PROTOBUF};
@@ -36,10 +38,12 @@ use crate::query_handler::OpenTelemetryProtocolHandlerRef;
 #[tracing::instrument(skip_all, fields(protocol = "otlp", request_type = "metrics"))]
 pub async fn metrics(
     State(handler): State<OpenTelemetryProtocolHandlerRef>,
-    Extension(query_ctx): Extension<QueryContextRef>,
+    Extension(mut query_ctx): Extension<QueryContext>,
     RawBody(body): RawBody,
 ) -> Result<OtlpMetricsResponse> {
     let db = query_ctx.get_db_string();
+    query_ctx.set_channel(Channel::Otlp);
+    let query_ctx = Arc::new(query_ctx);
     let _timer = crate::metrics::METRIC_HTTP_OPENTELEMETRY_METRICS_ELAPSED
         .with_label_values(&[db.as_str()])
         .start_timer();
@@ -83,10 +87,12 @@ impl IntoResponse for OtlpMetricsResponse {
 #[tracing::instrument(skip_all, fields(protocol = "otlp", request_type = "traces"))]
 pub async fn traces(
     State(handler): State<OpenTelemetryProtocolHandlerRef>,
-    Extension(query_ctx): Extension<QueryContextRef>,
+    Extension(mut query_ctx): Extension<QueryContext>,
     RawBody(body): RawBody,
 ) -> Result<OtlpTracesResponse> {
     let db = query_ctx.get_db_string();
+    query_ctx.set_channel(Channel::Otlp);
+    let query_ctx = Arc::new(query_ctx);
     let _timer = crate::metrics::METRIC_HTTP_OPENTELEMETRY_TRACES_ELAPSED
         .with_label_values(&[db.as_str()])
         .start_timer();
diff --git a/src/servers/src/http/prom_store.rs b/src/servers/src/http/prom_store.rs
index fbc39c825a..953160de5b 100644
--- a/src/servers/src/http/prom_store.rs
+++ b/src/servers/src/http/prom_store.rs
@@ -30,7 +30,7 @@ use object_pool::Pool;
 use prost::Message;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use session::context::QueryContextRef;
+use session::context::{Channel, QueryContext};
 use snafu::prelude::*;
 
 use super::header::{write_cost_header_map, GREPTIME_DB_HEADER_METRICS};
@@ -74,7 +74,7 @@ impl Default for RemoteWriteQuery {
 pub async fn route_write_without_metric_engine(
     handler: State<PromStoreProtocolHandlerRef>,
     query: Query<RemoteWriteQuery>,
-    extension: Extension<QueryContextRef>,
+    extension: Extension<QueryContext>,
     content_encoding: TypedHeader<headers::ContentEncoding>,
     raw_body: RawBody,
 ) -> Result<impl IntoResponse> {
@@ -96,7 +96,7 @@ pub async fn route_write_without_metric_engine_and_strict_mode(
     handler: State<PromStoreProtocolHandlerRef>,
     query: Query<RemoteWriteQuery>,
-    extension: Extension<QueryContextRef>,
+    extension: Extension<QueryContext>,
     content_encoding: TypedHeader<headers::ContentEncoding>,
     raw_body: RawBody,
 ) -> Result<impl IntoResponse> {
@@ -120,7 +120,7 @@ pub async fn remote_write(
     handler: State<PromStoreProtocolHandlerRef>,
     query: Query<RemoteWriteQuery>,
-    extension: Extension<QueryContextRef>,
+    extension: Extension<QueryContext>,
     content_encoding: TypedHeader<headers::ContentEncoding>,
     raw_body: RawBody,
 ) -> Result<impl IntoResponse> {
@@ -144,7 +144,7 @@ pub async fn remote_write_without_strict_mode(
     handler: State<PromStoreProtocolHandlerRef>,
     query: Query<RemoteWriteQuery>,
-    extension: Extension<QueryContextRef>,
+    extension: Extension<QueryContext>,
     content_encoding: TypedHeader<headers::ContentEncoding>,
     raw_body: RawBody,
 ) -> Result<impl IntoResponse> {
@@ -163,7 +163,7 @@ async fn remote_write_impl(
     State(handler): State<PromStoreProtocolHandlerRef>,
     Query(params): Query<RemoteWriteQuery>,
-    Extension(mut query_ctx): Extension<QueryContextRef>,
+    Extension(mut query_ctx): Extension<QueryContext>,
     content_encoding: TypedHeader<headers::ContentEncoding>,
     RawBody(body): RawBody,
     is_strict_mode: bool,
@@ -175,6 +175,7 @@ async fn remote_write_impl(
     }
 
     let db = params.db.clone().unwrap_or_default();
+    query_ctx.set_channel(Channel::Prometheus);
     let _timer = crate::metrics::METRIC_HTTP_PROM_STORE_WRITE_ELAPSED
         .with_label_values(&[db.as_str()])
         .start_timer();
@@ -183,10 +184,9 @@ async fn remote_write_impl(
     let (request, samples) = decode_remote_write_request(is_zstd, body, is_strict_mode).await?;
 
     if let Some(physical_table) = params.physical_table {
-        let mut new_query_ctx = query_ctx.as_ref().clone();
-        new_query_ctx.set_extension(PHYSICAL_TABLE_PARAM, physical_table);
-        query_ctx = Arc::new(new_query_ctx);
+        query_ctx.set_extension(PHYSICAL_TABLE_PARAM, physical_table);
     }
+    let query_ctx = Arc::new(query_ctx);
 
     let output = handler.write(request, query_ctx, is_metric_engine).await?;
     crate::metrics::PROM_STORE_REMOTE_WRITE_SAMPLES.inc_by(samples as u64);
@@ -224,10 +224,12 @@ impl IntoResponse for PromStoreResponse {
 pub async fn remote_read(
     State(handler): State<PromStoreProtocolHandlerRef>,
     Query(params): Query<RemoteWriteQuery>,
-    Extension(query_ctx): Extension<QueryContextRef>,
+    Extension(mut query_ctx): Extension<QueryContext>,
     RawBody(body): RawBody,
 ) -> Result<PromStoreResponse> {
     let db = params.db.clone().unwrap_or_default();
+    query_ctx.set_channel(Channel::Prometheus);
+    let query_ctx = Arc::new(query_ctx);
     let _timer = crate::metrics::METRIC_HTTP_PROM_STORE_READ_ELAPSED
         .with_label_values(&[db.as_str()])
         .start_timer();
diff --git a/src/servers/src/http/prometheus.rs b/src/servers/src/http/prometheus.rs
index f3a907aaf9..f9d4a26b0b 100644
--- a/src/servers/src/http/prometheus.rs
+++ b/src/servers/src/http/prometheus.rs
@@ -42,7 +42,7 @@ use schemars::JsonSchema;
 use serde::de::{self, MapAccess, Visitor};
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
-use session::context::{QueryContext, QueryContextBuilder, QueryContextRef};
+use session::context::QueryContext;
 use snafu::{Location, OptionExt, ResultExt};
 
 pub use super::prometheus_resp::PrometheusJsonResponse;
@@ -122,7 +122,7 @@ pub struct FormatQuery {
 pub async fn format_query(
     State(_handler): State<PrometheusHandlerRef>,
     Query(params): Query<FormatQuery>,
-    Extension(_query_ctx): Extension<QueryContextRef>,
+    Extension(_query_ctx): Extension<QueryContext>,
     Form(form_params): Form<FormatQuery>,
 ) -> PrometheusJsonResponse {
     let query = params.query.or(form_params.query).unwrap_or_default();
@@ -168,7 +168,7 @@ pub struct InstantQuery {
 pub async fn instant_query(
     State(handler): State<PrometheusHandlerRef>,
     Query(params): Query<InstantQuery>,
-    Extension(mut query_ctx): Extension<QueryContextRef>,
+    Extension(mut query_ctx): Extension<QueryContext>,
     Form(form_params): Form<InstantQuery>,
 ) -> PrometheusJsonResponse {
     // Extract time from query string, or use current server time if not specified.
@@ -190,8 +190,9 @@ pub async fn instant_query(
     // update catalog and schema in query context if necessary
     if let Some(db) = &params.db {
         let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
-        query_ctx = try_update_catalog_schema(query_ctx, &catalog, &schema);
+        try_update_catalog_schema(&mut query_ctx, &catalog, &schema);
     }
+    let query_ctx = Arc::new(query_ctx);
 
     let _timer = crate::metrics::METRIC_HTTP_PROMETHEUS_PROMQL_ELAPSED
         .with_label_values(&[query_ctx.get_db_string().as_str(), "instant_query"])
@@ -226,7 +227,7 @@ pub struct RangeQuery {
 pub async fn range_query(
     State(handler): State<PrometheusHandlerRef>,
     Query(params): Query<RangeQuery>,
-    Extension(mut query_ctx): Extension<QueryContextRef>,
+    Extension(mut query_ctx): Extension<QueryContext>,
     Form(form_params): Form<RangeQuery>,
 ) -> PrometheusJsonResponse {
     let prom_query = PromQuery {
@@ -243,8 +244,9 @@ pub async fn range_query(
     // update catalog and schema in query context if necessary
     if let Some(db) = &params.db {
         let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
-        query_ctx = try_update_catalog_schema(query_ctx, &catalog, &schema);
+        try_update_catalog_schema(&mut query_ctx, &catalog, &schema);
     }
+    let query_ctx = Arc::new(query_ctx);
 
     let _timer = crate::metrics::METRIC_HTTP_PROMETHEUS_PROMQL_ELAPSED
         .with_label_values(&[query_ctx.get_db_string().as_str(), "range_query"])
@@ -313,11 +315,12 @@ impl<'de> Deserialize<'de> for Matches {
 pub async fn labels_query(
     State(handler): State<PrometheusHandlerRef>,
     Query(params): Query<LabelsQuery>,
-    Extension(query_ctx): Extension<QueryContextRef>,
+    Extension(mut query_ctx): Extension<QueryContext>,
     Form(form_params): Form<LabelsQuery>,
 ) -> PrometheusJsonResponse {
     let (catalog, schema) = get_catalog_schema(&params.db, &query_ctx);
-    let query_ctx = try_update_catalog_schema(query_ctx, &catalog, &schema);
+    try_update_catalog_schema(&mut query_ctx, &catalog, &schema);
+    let query_ctx = Arc::new(query_ctx);
 
     let mut queries = params.matches.0;
     if queries.is_empty() {
@@ -625,20 +628,10 @@ pub(crate) fn get_catalog_schema(db: &Option<String>, ctx: &QueryContext) -> (St
 }
 
 /// Update catalog and schema in [QueryContext] if necessary.
-pub(crate) fn try_update_catalog_schema(
-    ctx: QueryContextRef,
-    catalog: &str,
-    schema: &str,
-) -> QueryContextRef {
+pub(crate) fn try_update_catalog_schema(ctx: &mut QueryContext, catalog: &str, schema: &str) {
     if ctx.current_catalog() != catalog || ctx.current_schema() != schema {
-        Arc::new(
-            QueryContextBuilder::from_existing(&ctx)
-                .current_catalog(catalog.to_string())
-                .current_schema(schema.to_string())
-                .build(),
-        )
-    } else {
-        ctx
+        ctx.set_current_catalog(catalog);
+        ctx.set_current_schema(schema);
     }
 }
@@ -693,11 +686,12 @@ pub struct LabelValueQuery {
 pub async fn label_values_query(
     State(handler): State<PrometheusHandlerRef>,
     Path(label_name): Path<String>,
-    Extension(query_ctx): Extension<QueryContextRef>,
+    Extension(mut query_ctx): Extension<QueryContext>,
     Query(params): Query<LabelValueQuery>,
 ) -> PrometheusJsonResponse {
     let (catalog, schema) = get_catalog_schema(&params.db, &query_ctx);
-    let query_ctx = try_update_catalog_schema(query_ctx, &catalog, &schema);
+    try_update_catalog_schema(&mut query_ctx, &catalog, &schema);
+    let query_ctx = Arc::new(query_ctx);
 
     let _timer = crate::metrics::METRIC_HTTP_PROMETHEUS_PROMQL_ELAPSED
         .with_label_values(&[query_ctx.get_db_string().as_str(), "label_values_query"])
@@ -955,7 +949,7 @@ pub struct SeriesQuery {
 pub async fn series_query(
     State(handler): State<PrometheusHandlerRef>,
     Query(params): Query<SeriesQuery>,
-    Extension(mut query_ctx): Extension<QueryContextRef>,
+    Extension(mut query_ctx): Extension<QueryContext>,
     Form(form_params): Form<SeriesQuery>,
 ) -> PrometheusJsonResponse {
     let mut queries: Vec<String> = params.matches.0;
@@ -981,8 +975,9 @@ pub async fn series_query(
     // update catalog and schema in query context if necessary
     if let Some(db) = &params.db {
         let (catalog, schema) = parse_catalog_and_schema_from_db_string(db);
-        query_ctx = try_update_catalog_schema(query_ctx, &catalog, &schema);
+        try_update_catalog_schema(&mut query_ctx, &catalog, &schema);
     }
+    let query_ctx = Arc::new(query_ctx);
 
     let _timer = crate::metrics::METRIC_HTTP_PROMETHEUS_PROMQL_ELAPSED
         .with_label_values(&[query_ctx.get_db_string().as_str(), "series_query"])
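Reviewer note: `try_update_catalog_schema` used to rebuild a whole context behind a fresh `Arc` on every mismatch; after this change it mutates the owned context in place and the caller wraps it in an `Arc` exactly once. A simplified side-by-side sketch of the two shapes (field names are illustrative):

```rust
use std::sync::Arc;

#[derive(Clone, Debug)]
struct QueryContext {
    catalog: String,
    schema: String,
}

// Old shape (simplified): every mismatch paid for a rebuild plus a new Arc.
fn update_by_rebuild(ctx: Arc<QueryContext>, catalog: &str, schema: &str) -> Arc<QueryContext> {
    if ctx.catalog != catalog || ctx.schema != schema {
        Arc::new(QueryContext {
            catalog: catalog.to_string(),
            schema: schema.to_string(),
        })
    } else {
        ctx
    }
}

// New shape: mutate while still owned, freeze once at the end.
fn update_in_place(ctx: &mut QueryContext, catalog: &str, schema: &str) {
    if ctx.catalog != catalog || ctx.schema != schema {
        ctx.catalog = catalog.to_string();
        ctx.schema = schema.to_string();
    }
}

fn main() {
    let mut ctx = QueryContext { catalog: "greptime".into(), schema: "public".into() };
    update_in_place(&mut ctx, "greptime", "metrics");
    let ctx = Arc::new(ctx);
    assert_eq!(ctx.schema, "metrics");
    let _ = update_by_rebuild(ctx.clone(), "greptime", "metrics"); // no-op path
}
```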
diff --git a/src/servers/tests/http/authorize.rs b/src/servers/tests/http/authorize.rs
index ecbf6b98fb..36b721c877 100644
--- a/src/servers/tests/http/authorize.rs
+++ b/src/servers/tests/http/authorize.rs
@@ -20,14 +20,14 @@ use axum::http;
 use http_body::Body;
 use hyper::{Request, StatusCode};
 use servers::http::authorize::inner_auth;
-use session::context::QueryContextRef;
+use session::context::QueryContext;
 
 #[tokio::test]
 async fn test_http_auth() {
     // base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
     let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
     let req = inner_auth(None, req).await.unwrap();
-    let ctx: &QueryContextRef = req.extensions().get().unwrap();
+    let ctx: &QueryContext = req.extensions().get().unwrap();
     let user_info = ctx.current_user();
     let default = auth::userinfo_by_name(None);
     assert_eq!(default.username(), user_info.username());
@@ -38,7 +38,7 @@ async fn test_http_auth() {
     // base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
     let req = mock_http_request(Some("Basic Z3JlcHRpbWU6Z3JlcHRpbWU="), None).unwrap();
     let req = inner_auth(mock_user_provider.clone(), req).await.unwrap();
-    let ctx: &QueryContextRef = req.extensions().get().unwrap();
+    let ctx: &QueryContext = req.extensions().get().unwrap();
     let user_info = ctx.current_user();
     let default = auth::userinfo_by_name(None);
     assert_eq!(default.username(), user_info.username());
@@ -79,7 +79,7 @@ async fn test_schema_validating() {
     )
     .unwrap();
     let req = inner_auth(mock_user_provider.clone(), req).await.unwrap();
-    let ctx: &QueryContextRef = req.extensions().get().unwrap();
+    let ctx: &QueryContext = req.extensions().get().unwrap();
     let user_info = ctx.current_user();
     let default = auth::userinfo_by_name(None);
     assert_eq!(default.username(), user_info.username());
diff --git a/src/servers/tests/http/http_handler_test.rs b/src/servers/tests/http/http_handler_test.rs
index 9ff41ec434..d31352072b 100644
--- a/src/servers/tests/http/http_handler_test.rs
+++ b/src/servers/tests/http/http_handler_test.rs
@@ -40,7 +40,7 @@ use crate::{
 #[tokio::test]
 async fn test_sql_not_provided() {
     let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
-    let ctx = QueryContext::arc();
+    let ctx = QueryContext::with_db_name(None);
     ctx.set_current_user(auth::userinfo_by_name(None));
     let api_state = ApiState {
         sql_handler,
@@ -74,7 +74,7 @@ async fn test_sql_output_rows() {
 
     let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
 
-    let ctx = QueryContext::arc();
+    let ctx = QueryContext::with_db_name(None);
     ctx.set_current_user(auth::userinfo_by_name(None));
     let api_state = ApiState {
         sql_handler,
@@ -180,7 +180,7 @@ async fn test_sql_output_rows() {
 #[tokio::test]
 async fn test_dashboard_sql_limit() {
     let sql_handler = create_testing_sql_query_handler(MemTable::specified_numbers_table(2000));
-    let ctx = QueryContext::arc();
+    let ctx = QueryContext::with_db_name(None);
     ctx.set_current_user(auth::userinfo_by_name(None));
     let api_state = ApiState {
         sql_handler,
@@ -226,7 +226,7 @@ async fn test_sql_form() {
 
     let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
 
-    let ctx = QueryContext::arc();
+    let ctx = QueryContext::with_db_name(None);
     ctx.set_current_user(auth::userinfo_by_name(None));
     let api_state = ApiState {
         sql_handler,
diff --git a/src/session/Cargo.toml b/src/session/Cargo.toml
index 8e0baeaa0f..b6dbb00955 100644
--- a/src/session/Cargo.toml
+++ b/src/session/Cargo.toml
@@ -20,5 +20,6 @@ common-macro.workspace = true
 common-telemetry.workspace = true
 common-time.workspace = true
 derive_builder.workspace = true
+meter-core.workspace = true
 snafu.workspace = true
 sql.workspace = true
diff --git a/src/session/src/context.rs b/src/session/src/context.rs
index 4a9b2bf18e..15ea0a8baf 100644
--- a/src/session/src/context.rs
+++ b/src/session/src/context.rs
@@ -25,7 +25,7 @@ use common_catalog::{build_db_string, parse_catalog_and_schema_from_db_string};
 use common_time::timezone::parse_timezone;
 use common_time::Timezone;
 use derive_builder::Builder;
-use sql::dialect::{Dialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect};
+use sql::dialect::{Dialect, GenericDialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect};
 
 use crate::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
 use crate::MutableInner;
@@ -47,6 +47,9 @@ pub struct QueryContext {
     // The configuration parameter are used to store the parameters that are set by the user
     #[builder(default)]
     configuration_parameter: Arc<ConfigurationVariables>,
+    // Track which protocol the query comes from.
+    #[builder(default)]
+    channel: Channel,
 }
 
 impl Display for QueryContext {
@@ -94,7 +97,8 @@ impl From<&RegionRequestHeader> for QueryContext {
                 .current_catalog(ctx.current_catalog.clone())
                 .current_schema(ctx.current_schema.clone())
                 .timezone(parse_timezone(Some(&ctx.timezone)))
-                .extensions(ctx.extensions.clone());
+                .extensions(ctx.extensions.clone())
+                .channel(ctx.channel.into());
         }
         builder.build()
     }
@@ -107,6 +111,7 @@ impl From<api::v1::QueryContext> for QueryContext {
             .current_schema(ctx.current_schema)
             .timezone(parse_timezone(Some(&ctx.timezone)))
             .extensions(ctx.extensions)
+            .channel(ctx.channel.into())
             .build()
     }
 }
@@ -117,6 +122,7 @@ impl From<QueryContext> for api::v1::QueryContext {
         QueryContext {
             current_catalog,
             mutable_inner,
             extensions,
+            channel,
             ..
         }: QueryContext,
     ) -> Self {
@@ -126,6 +132,7 @@ impl From<QueryContext> for api::v1::QueryContext {
             current_schema: mutable_inner.schema.clone(),
             timezone: mutable_inner.timezone.to_string(),
             extensions,
+            channel: channel as u32,
         }
     }
 }
@@ -142,6 +149,14 @@ impl QueryContext {
             .build()
     }
 
+    pub fn with_channel(catalog: &str, schema: &str, channel: Channel) -> QueryContext {
+        QueryContextBuilder::default()
+            .current_catalog(catalog.to_string())
+            .current_schema(schema.to_string())
+            .channel(channel)
+            .build()
+    }
+
     pub fn with_db_name(db_name: Option<&str>) -> QueryContext {
         let (catalog, schema) = db_name
             .map(|db| {
@@ -172,6 +187,10 @@ impl QueryContext {
         &self.current_catalog
     }
 
+    pub fn set_current_catalog(&mut self, new_catalog: &str) {
+        self.current_catalog = new_catalog.to_string();
+    }
+
     pub fn sql_dialect(&self) -> &(dyn Dialect + Send + Sync) {
         &*self.sql_dialect
     }
@@ -224,6 +243,14 @@ impl QueryContext {
     pub fn configuration_parameter(&self) -> &ConfigurationVariables {
         &self.configuration_parameter
     }
+
+    pub fn channel(&self) -> Channel {
+        self.channel
+    }
+
+    pub fn set_channel(&mut self, channel: Channel) {
+        self.channel = channel;
+    }
 }
 
 impl QueryContextBuilder {
@@ -238,6 +265,7 @@ impl QueryContextBuilder {
                 .unwrap_or_else(|| Arc::new(GreptimeDbDialect {})),
             extensions: self.extensions.unwrap_or_default(),
             configuration_parameter: self.configuration_parameter.unwrap_or_default(),
+            channel: self.channel.unwrap_or_default(),
         }
     }
 
@@ -247,17 +275,6 @@ impl QueryContextBuilder {
             .insert(key, value);
         self
     }
-
-    pub fn from_existing(context: &QueryContext) -> QueryContextBuilder {
-        QueryContextBuilder {
-            current_catalog: Some(context.current_catalog.clone()),
-            // note that this is a shallow copy
-            mutable_inner: Some(context.mutable_inner.clone()),
-            sql_dialect: Some(context.sql_dialect.clone()),
-            extensions: Some(context.extensions.clone()),
-            configuration_parameter: Some(context.configuration_parameter.clone()),
-        }
-    }
 }
 
 #[derive(Debug)]
 pub struct ConnInfo {
@@ -289,10 +306,37 @@ impl ConnInfo {
     }
 }
 
-#[derive(Debug, PartialEq)]
+#[derive(Debug, PartialEq, Default, Clone, Copy)]
+#[repr(u8)]
 pub enum Channel {
-    Mysql,
-    Postgres,
+    #[default]
+    Unknown = 0,
+
+    Mysql = 1,
+    Postgres = 2,
+    Http = 3,
+    Prometheus = 4,
+    Otlp = 5,
+    Grpc = 6,
+    Influx = 7,
+    Opentsdb = 8,
+}
+
+impl From<u32> for Channel {
+    fn from(value: u32) -> Self {
+        match value {
+            1 => Self::Mysql,
+            2 => Self::Postgres,
+            3 => Self::Http,
+            4 => Self::Prometheus,
+            5 => Self::Otlp,
+            6 => Self::Grpc,
+            7 => Self::Influx,
+            8 => Self::Opentsdb,
+
+            _ => Self::Unknown,
+        }
+    }
+}
 
 impl Channel {
     pub fn dialect(&self) -> Arc<dyn Dialect + Send + Sync> {
         match self {
             Channel::Mysql => Arc::new(MySqlDialect {}),
             Channel::Postgres => Arc::new(PostgreSqlDialect {}),
+            _ => Arc::new(GenericDialect {}),
         }
     }
 }
@@ -309,6 +354,13 @@ impl Display for Channel {
         match self {
             Channel::Mysql => write!(f, "mysql"),
             Channel::Postgres => write!(f, "postgres"),
+            Channel::Http => write!(f, "http"),
+            Channel::Prometheus => write!(f, "prometheus"),
+            Channel::Otlp => write!(f, "otlp"),
+            Channel::Grpc => write!(f, "grpc"),
+            Channel::Influx => write!(f, "influx"),
+            Channel::Opentsdb => write!(f, "opentsdb"),
+            Channel::Unknown => write!(f, "unknown"),
         }
     }
 }
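Reviewer note: `Channel` crosses the wire as a plain `u32` inside `PbQueryContext` (see merge_scan.rs above), and `From<u32>` deliberately uses a catch-all so an unrecognized discriminant degrades to `Unknown` rather than failing. A small standalone check of that round trip, mirroring the enum above:

```rust
#[derive(Debug, Clone, Copy, PartialEq, Default)]
#[repr(u8)]
enum Channel {
    #[default]
    Unknown = 0,
    Mysql = 1,
    Postgres = 2,
    Http = 3,
}

impl From<u32> for Channel {
    fn from(value: u32) -> Self {
        match value {
            1 => Self::Mysql,
            2 => Self::Postgres,
            3 => Self::Http,
            _ => Self::Unknown,
        }
    }
}

fn main() {
    // Known discriminants survive the cast-out/parse-back round trip...
    for ch in [Channel::Mysql, Channel::Postgres, Channel::Http] {
        let wire = ch as u32; // what merge_scan puts into PbQueryContext
        assert_eq!(Channel::from(wire), ch);
    }
    // ...and a value from a newer peer degrades to Unknown instead of erroring.
    assert_eq!(Channel::from(42), Channel::Unknown);
}
```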
diff --git a/src/session/src/lib.rs b/src/session/src/lib.rs
index d9f3fcad18..ecfc02f230 100644
--- a/src/session/src/lib.rs
+++ b/src/session/src/lib.rs
@@ -79,6 +79,7 @@ impl Session {
             .mutable_inner(self.mutable_inner.clone())
             .sql_dialect(self.conn_info.channel.dialect())
             .configuration_parameter(self.configuration_variables.clone())
+            .channel(self.conn_info.channel)
             .build()
             .into()
     }
diff --git a/src/sql/src/dialect.rs b/src/sql/src/dialect.rs
index 5060444d0b..7b303ee286 100644
--- a/src/sql/src/dialect.rs
+++ b/src/sql/src/dialect.rs
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-pub use sqlparser::dialect::{Dialect, MySqlDialect, PostgreSqlDialect};
+pub use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect, PostgreSqlDialect};
 
 /// GreptimeDb dialect
 #[derive(Debug, Clone)]
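Reviewer note: re-exporting `GenericDialect` lets `Channel::dialect()` stay total now that the enum has non-SQL variants. A standalone sketch of the fallback behavior; the `sqlparser` version below is an assumption (the patch does not pin one), and only the match shape mirrors the change:

```rust
// Hypothetical standalone Cargo.toml dependency:
//   sqlparser = "0.45"   <- version is an assumption, not set by this patch
use std::sync::Arc;

use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect, PostgreSqlDialect};

#[derive(Clone, Copy)]
enum Channel {
    Mysql,
    Postgres,
    Http,
}

impl Channel {
    // SQL-speaking channels keep their native dialect; every other protocol
    // falls back to GenericDialect instead of being unrepresentable.
    fn dialect(&self) -> Arc<dyn Dialect + Send + Sync> {
        match self {
            Channel::Mysql => Arc::new(MySqlDialect {}),
            Channel::Postgres => Arc::new(PostgreSqlDialect {}),
            _ => Arc::new(GenericDialect {}),
        }
    }
}

fn main() {
    let _pg = Channel::Postgres.dialect();
    let d = Channel::Http.dialect();
    // GenericDialect accepts identifiers starting with a letter or underscore.
    assert!(d.is_identifier_start('_'));
}
```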