1use std::str::FromStr;
18use std::time::Instant;
19
20use api::helper::request_type;
21use api::v1::{GreptimeRequest, RequestHeader};
22use auth::UserProviderRef;
23use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
24use common_catalog::parse_catalog_and_schema_from_db_string;
25use common_error::ext::ErrorExt;
26use common_error::status_code::StatusCode;
27use common_grpc::flight::do_put::DoPutResponse;
28use common_query::Output;
29use common_runtime::Runtime;
30use common_runtime::runtime::RuntimeTrait;
31use common_session::ReadPreference;
32use common_telemetry::tracing_context::{FutureExt, TracingContext};
33use common_telemetry::{debug, error, tracing, warn};
34use common_time::timezone::parse_timezone;
35use futures_util::StreamExt;
36use session::context::{Channel, QueryContextBuilder, QueryContextRef};
37use session::hints::{READ_PREFERENCE_HINT, is_reserved_extension_key};
38use snafu::{OptionExt, ResultExt};
39use tokio::sync::mpsc;
40use tokio::sync::mpsc::error::TrySendError;
41use tonic::Status;
42
43use crate::error::{InvalidQuerySnafu, JoinTaskSnafu, Result, UnknownHintSnafu};
44use crate::grpc::flight::PutRecordBatchRequestStream;
45use crate::grpc::{FlightCompression, TonicResult, context_auth};
46use crate::metrics::{self, METRIC_SERVER_GRPC_DB_REQUEST_TIMER};
47use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
48
49#[derive(Clone)]
50pub struct GreptimeRequestHandler {
51 handler: ServerGrpcQueryHandlerRef,
52 pub(crate) user_provider: Option<UserProviderRef>,
53 runtime: Option<Runtime>,
54 pub(crate) flight_compression: FlightCompression,
55}
56
57impl GreptimeRequestHandler {
58 pub fn new(
59 handler: ServerGrpcQueryHandlerRef,
60 user_provider: Option<UserProviderRef>,
61 runtime: Option<Runtime>,
62 flight_compression: FlightCompression,
63 ) -> Self {
64 Self {
65 handler,
66 user_provider,
67 runtime,
68 flight_compression,
69 }
70 }
71
72 #[tracing::instrument(skip_all, fields(protocol = "grpc", request_type = get_request_type(&request)))]
73 pub(crate) async fn handle_request(
74 &self,
75 request: GreptimeRequest,
76 hints: Vec<(String, String)>,
77 ) -> Result<Output> {
78 let query = request.request.context(InvalidQuerySnafu {
79 reason: "Expecting non-empty GreptimeRequest.",
80 })?;
81
82 let header = request.header.as_ref();
83 let query_ctx = create_query_context(Channel::Grpc, header, hints)?;
84 let user_info = context_auth::auth(self.user_provider.clone(), header, &query_ctx).await?;
85 query_ctx.set_current_user(user_info);
86
87 let handler = self.handler.clone();
88 let request_type = request_type(&query).to_string();
89 let db = query_ctx.get_db_string();
90 let timer = RequestTimer::new(db.clone(), request_type);
91 let tracing_context = TracingContext::from_current_span();
92
93 let result_future = async move {
94 handler
95 .do_query(query, query_ctx)
96 .trace(tracing_context.attach(tracing::info_span!(
97 "GreptimeRequestHandler::handle_request_runtime"
98 )))
99 .await
100 .map_err(|e| {
101 if e.status_code().should_log_error() {
102 let root_error = e.root_cause().unwrap_or(&e);
103 error!(e; "Failed to handle request, error: {}", root_error.to_string());
104 } else {
105 debug!("Failed to handle request, err: {:?}", e);
107 }
108 e
109 })
110 };
111
112 match &self.runtime {
113 Some(runtime) => {
114 runtime
122 .spawn(result_future)
123 .await
124 .context(JoinTaskSnafu)
125 .inspect_err(|e| {
126 timer.record(e.status_code());
127 })?
128 }
129 None => result_future.await,
130 }
131 }
132
133 pub(crate) async fn put_record_batches(
134 &self,
135 stream: PutRecordBatchRequestStream,
136 result_sender: mpsc::Sender<TonicResult<DoPutResponse>>,
137 query_ctx: QueryContextRef,
138 ) {
139 let handler = self.handler.clone();
140 let runtime = self
141 .runtime
142 .clone()
143 .unwrap_or_else(common_runtime::global_runtime);
144 runtime.spawn(async move {
145 let mut result_stream = handler.handle_put_record_batch_stream(stream, query_ctx);
146
147 while let Some(result) = result_stream.next().await {
148 match &result {
149 Ok(response) => {
150 metrics::GRPC_BULK_INSERT_ELAPSED.observe(response.elapsed_secs());
152 }
153 Err(e) => {
154 error!(e; "Failed to handle flight record batches");
155 }
156 }
157
158 if let Err(e) =
159 result_sender.try_send(result.map_err(|e| Status::from_error(Box::new(e))))
160 && let TrySendError::Closed(_) = e
161 {
162 warn!(r#""DoPut" client maybe unreachable, abort handling its message"#);
163 break;
164 }
165 }
166 });
167 }
168}
169
170pub fn get_request_type(request: &GreptimeRequest) -> &'static str {
171 request
172 .request
173 .as_ref()
174 .map(request_type)
175 .unwrap_or_default()
176}
177
178pub(crate) fn create_query_context(
181 channel: Channel,
182 header: Option<&RequestHeader>,
183 mut extensions: Vec<(String, String)>,
184) -> Result<QueryContextRef> {
185 let (catalog, schema) = header
186 .map(|header| {
187 if !header.dbname.is_empty() {
190 parse_catalog_and_schema_from_db_string(&header.dbname)
191 } else {
192 (
193 if !header.catalog.is_empty() {
194 header.catalog.to_lowercase()
195 } else {
196 DEFAULT_CATALOG_NAME.to_string()
197 },
198 if !header.schema.is_empty() {
199 header.schema.to_lowercase()
200 } else {
201 DEFAULT_SCHEMA_NAME.to_string()
202 },
203 )
204 }
205 })
206 .unwrap_or_else(|| {
207 (
208 DEFAULT_CATALOG_NAME.to_string(),
209 DEFAULT_SCHEMA_NAME.to_string(),
210 )
211 });
212 let timezone = parse_timezone(header.map(|h| h.timezone.as_str()));
213 let mut ctx_builder = QueryContextBuilder::default()
214 .current_catalog(catalog)
215 .current_schema(schema)
216 .timezone(timezone)
217 .channel(channel);
218
219 if let Some(x) = extensions
220 .iter()
221 .position(|(k, _)| k == READ_PREFERENCE_HINT)
222 {
223 let (k, v) = extensions.swap_remove(x);
224 let Ok(read_preference) = ReadPreference::from_str(&v) else {
225 return UnknownHintSnafu {
226 hint: format!("{k}={v}"),
227 }
228 .fail();
229 };
230 ctx_builder = ctx_builder.read_preference(read_preference);
231 }
232
233 for (key, value) in extensions {
234 if is_reserved_extension_key(&key) {
235 debug!(
236 key = key.as_str(),
237 "Ignoring reserved external query context extension key"
238 );
239 continue;
240 }
241 ctx_builder = ctx_builder.set_extension(key, value);
242 }
243 Ok(ctx_builder.build().into())
244}
245
246pub(crate) struct RequestTimer {
250 start: Instant,
251 db: String,
252 request_type: String,
253 status_code: StatusCode,
254}
255
256impl RequestTimer {
257 pub fn new(db: String, request_type: String) -> RequestTimer {
259 RequestTimer {
260 start: Instant::now(),
261 db,
262 request_type,
263 status_code: StatusCode::Success,
264 }
265 }
266
267 pub fn record(mut self, status_code: StatusCode) {
269 self.status_code = status_code;
270 }
271}
272
273impl Drop for RequestTimer {
274 fn drop(&mut self) {
275 METRIC_SERVER_GRPC_DB_REQUEST_TIMER
276 .with_label_values(&[
277 self.db.as_str(),
278 self.request_type.as_str(),
279 self.status_code.as_ref(),
280 ])
281 .observe(self.start.elapsed().as_secs_f64());
282 }
283}
284
#[cfg(test)]
mod tests {
    use chrono::FixedOffset;
    use common_time::Timezone;
    use session::hints::REMOTE_QUERY_ID_EXTENSION_KEY;

    use super::*;

    /// Header fields, timezone and hints must all be reflected in the built context.
    #[test]
    fn test_create_query_context() {
        let hints = vec![
            ("auto_create_table".to_string(), "true".to_string()),
            ("read_preference".to_string(), "leader".to_string()),
        ];
        let header = RequestHeader {
            catalog: "cat-a-log".to_string(),
            timezone: "+01:00".to_string(),
            ..Default::default()
        };
        let ctx = create_query_context(Channel::Unknown, Some(&header), hints).unwrap();

        assert_eq!(ctx.current_catalog(), "cat-a-log");
        assert_eq!(ctx.current_schema(), DEFAULT_SCHEMA_NAME);
        let one_hour_east = FixedOffset::east_opt(3600).unwrap();
        assert_eq!(ctx.timezone(), Timezone::Offset(one_hour_east));
        assert!(matches!(ctx.read_preference(), ReadPreference::Leader));

        // The read-preference hint is consumed. Besides the remaining hint, a remote
        // query id extension is expected even though the client did not send one.
        let mut pairs: Vec<_> = ctx.extensions().into_iter().collect();
        pairs.sort_unstable_by(|lhs, rhs| lhs.0.cmp(&rhs.0));
        let (first, second) = (&pairs[0], &pairs[1]);
        assert_eq!(
            *first,
            ("auto_create_table".to_string(), "true".to_string())
        );
        assert_eq!(second.0, REMOTE_QUERY_ID_EXTENSION_KEY.to_string());
        assert_eq!(ctx.remote_query_id(), Some(second.1.as_str()));
    }

    /// A client-supplied remote query id must not be honoured.
    #[test]
    fn test_create_query_context_ignores_remote_query_id_extension() {
        let spoofed = (
            REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
            "spoofed-query-id".to_string(),
        );
        let ctx = create_query_context(Channel::Grpc, None, vec![spoofed]).unwrap();

        assert_ne!(ctx.remote_query_id(), Some("spoofed-query-id"));
        // Whatever id the context carries must match the one stored under the key.
        assert_eq!(
            ctx.extension(REMOTE_QUERY_ID_EXTENSION_KEY),
            ctx.remote_query_id()
        );
    }
}