// servers/grpc/greptime_handler.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Handler for the Greptime Database service. It is implemented by the frontend.
16
17use std::str::FromStr;
18use std::time::Instant;
19
20use api::helper::request_type;
21use api::v1::{GreptimeRequest, RequestHeader};
22use auth::UserProviderRef;
23use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
24use common_catalog::parse_catalog_and_schema_from_db_string;
25use common_error::ext::ErrorExt;
26use common_error::status_code::StatusCode;
27use common_grpc::flight::do_put::DoPutResponse;
28use common_query::Output;
29use common_runtime::Runtime;
30use common_runtime::runtime::RuntimeTrait;
31use common_session::ReadPreference;
32use common_telemetry::tracing_context::{FutureExt, TracingContext};
33use common_telemetry::{debug, error, tracing, warn};
34use common_time::timezone::parse_timezone;
35use futures_util::StreamExt;
36use session::context::{Channel, QueryContextBuilder, QueryContextRef};
37use session::hints::{READ_PREFERENCE_HINT, is_reserved_extension_key};
38use snafu::{OptionExt, ResultExt};
39use tokio::sync::mpsc;
40use tokio::sync::mpsc::error::TrySendError;
41use tonic::Status;
42
43use crate::error::{InvalidQuerySnafu, JoinTaskSnafu, Result, UnknownHintSnafu};
44use crate::grpc::flight::PutRecordBatchRequestStream;
45use crate::grpc::{FlightCompression, TonicResult, context_auth};
46use crate::metrics::{self, METRIC_SERVER_GRPC_DB_REQUEST_TIMER};
47use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
48
49#[derive(Clone)]
50pub struct GreptimeRequestHandler {
51    handler: ServerGrpcQueryHandlerRef,
52    pub(crate) user_provider: Option<UserProviderRef>,
53    runtime: Option<Runtime>,
54    pub(crate) flight_compression: FlightCompression,
55}
56
57impl GreptimeRequestHandler {
58    pub fn new(
59        handler: ServerGrpcQueryHandlerRef,
60        user_provider: Option<UserProviderRef>,
61        runtime: Option<Runtime>,
62        flight_compression: FlightCompression,
63    ) -> Self {
64        Self {
65            handler,
66            user_provider,
67            runtime,
68            flight_compression,
69        }
70    }
71
72    #[tracing::instrument(skip_all, fields(protocol = "grpc", request_type = get_request_type(&request)))]
73    pub(crate) async fn handle_request(
74        &self,
75        request: GreptimeRequest,
76        hints: Vec<(String, String)>,
77    ) -> Result<Output> {
78        let query = request.request.context(InvalidQuerySnafu {
79            reason: "Expecting non-empty GreptimeRequest.",
80        })?;
81
82        let header = request.header.as_ref();
83        let query_ctx = create_query_context(Channel::Grpc, header, hints)?;
84        let user_info = context_auth::auth(self.user_provider.clone(), header, &query_ctx).await?;
85        query_ctx.set_current_user(user_info);
86
87        let handler = self.handler.clone();
88        let request_type = request_type(&query).to_string();
89        let db = query_ctx.get_db_string();
90        let timer = RequestTimer::new(db.clone(), request_type);
91        let tracing_context = TracingContext::from_current_span();
92
93        let result_future = async move {
94            handler
95                .do_query(query, query_ctx)
96                .trace(tracing_context.attach(tracing::info_span!(
97                    "GreptimeRequestHandler::handle_request_runtime"
98                )))
99                .await
100                .map_err(|e| {
101                    if e.status_code().should_log_error() {
102                        let root_error = e.root_cause().unwrap_or(&e);
103                        error!(e; "Failed to handle request, error: {}", root_error.to_string());
104                    } else {
105                        // Currently, we still print a debug log.
106                        debug!("Failed to handle request, err: {:?}", e);
107                    }
108                    e
109                })
110        };
111
112        match &self.runtime {
113            Some(runtime) => {
114                // Executes requests in another runtime to
115                // 1. prevent the execution from being cancelled unexpected by Tonic runtime;
116                //   - Refer to our blog for the rational behind it:
117                //     https://www.greptime.com/blogs/2023-01-12-hidden-control-flow.html
118                //   - Obtaining a `JoinHandle` to get the panic message (if there's any).
119                //     From its docs, `JoinHandle` is cancel safe. The task keeps running even it's handle been dropped.
120                // 2. avoid the handler blocks the gRPC runtime incidentally.
121                runtime
122                    .spawn(result_future)
123                    .await
124                    .context(JoinTaskSnafu)
125                    .inspect_err(|e| {
126                        timer.record(e.status_code());
127                    })?
128            }
129            None => result_future.await,
130        }
131    }
132
133    pub(crate) async fn put_record_batches(
134        &self,
135        stream: PutRecordBatchRequestStream,
136        result_sender: mpsc::Sender<TonicResult<DoPutResponse>>,
137        query_ctx: QueryContextRef,
138    ) {
139        let handler = self.handler.clone();
140        let runtime = self
141            .runtime
142            .clone()
143            .unwrap_or_else(common_runtime::global_runtime);
144        runtime.spawn(async move {
145            let mut result_stream = handler.handle_put_record_batch_stream(stream, query_ctx);
146
147            while let Some(result) = result_stream.next().await {
148                match &result {
149                    Ok(response) => {
150                        // Record the elapsed time metric from the response
151                        metrics::GRPC_BULK_INSERT_ELAPSED.observe(response.elapsed_secs());
152                    }
153                    Err(e) => {
154                        error!(e; "Failed to handle flight record batches");
155                    }
156                }
157
158                if let Err(e) =
159                    result_sender.try_send(result.map_err(|e| Status::from_error(Box::new(e))))
160                    && let TrySendError::Closed(_) = e
161                {
162                    warn!(r#""DoPut" client maybe unreachable, abort handling its message"#);
163                    break;
164                }
165            }
166        });
167    }
168}
169
170pub fn get_request_type(request: &GreptimeRequest) -> &'static str {
171    request
172        .request
173        .as_ref()
174        .map(request_type)
175        .unwrap_or_default()
176}
177
178/// Creates a new `QueryContext` from the provided request header and extensions.
179/// Strongly recommend setting an appropriate channel, as this is very helpful for statistics.
180pub(crate) fn create_query_context(
181    channel: Channel,
182    header: Option<&RequestHeader>,
183    mut extensions: Vec<(String, String)>,
184) -> Result<QueryContextRef> {
185    let (catalog, schema) = header
186        .map(|header| {
187            // We provide dbname field in newer versions of protos/sdks
188            // parse dbname from header in priority
189            if !header.dbname.is_empty() {
190                parse_catalog_and_schema_from_db_string(&header.dbname)
191            } else {
192                (
193                    if !header.catalog.is_empty() {
194                        header.catalog.to_lowercase()
195                    } else {
196                        DEFAULT_CATALOG_NAME.to_string()
197                    },
198                    if !header.schema.is_empty() {
199                        header.schema.to_lowercase()
200                    } else {
201                        DEFAULT_SCHEMA_NAME.to_string()
202                    },
203                )
204            }
205        })
206        .unwrap_or_else(|| {
207            (
208                DEFAULT_CATALOG_NAME.to_string(),
209                DEFAULT_SCHEMA_NAME.to_string(),
210            )
211        });
212    let timezone = parse_timezone(header.map(|h| h.timezone.as_str()));
213    let mut ctx_builder = QueryContextBuilder::default()
214        .current_catalog(catalog)
215        .current_schema(schema)
216        .timezone(timezone)
217        .channel(channel);
218
219    if let Some(x) = extensions
220        .iter()
221        .position(|(k, _)| k == READ_PREFERENCE_HINT)
222    {
223        let (k, v) = extensions.swap_remove(x);
224        let Ok(read_preference) = ReadPreference::from_str(&v) else {
225            return UnknownHintSnafu {
226                hint: format!("{k}={v}"),
227            }
228            .fail();
229        };
230        ctx_builder = ctx_builder.read_preference(read_preference);
231    }
232
233    for (key, value) in extensions {
234        if is_reserved_extension_key(&key) {
235            debug!(
236                key = key.as_str(),
237                "Ignoring reserved external query context extension key"
238            );
239            continue;
240        }
241        ctx_builder = ctx_builder.set_extension(key, value);
242    }
243    Ok(ctx_builder.build().into())
244}
245
246/// Histogram timer for handling gRPC request.
247///
248/// The timer records the elapsed time with [StatusCode::Success] on drop.
249pub(crate) struct RequestTimer {
250    start: Instant,
251    db: String,
252    request_type: String,
253    status_code: StatusCode,
254}
255
256impl RequestTimer {
257    /// Returns a new timer.
258    pub fn new(db: String, request_type: String) -> RequestTimer {
259        RequestTimer {
260            start: Instant::now(),
261            db,
262            request_type,
263            status_code: StatusCode::Success,
264        }
265    }
266
267    /// Consumes the timer and record the elapsed time with specific `status_code`.
268    pub fn record(mut self, status_code: StatusCode) {
269        self.status_code = status_code;
270    }
271}
272
273impl Drop for RequestTimer {
274    fn drop(&mut self) {
275        METRIC_SERVER_GRPC_DB_REQUEST_TIMER
276            .with_label_values(&[
277                self.db.as_str(),
278                self.request_type.as_str(),
279                self.status_code.as_ref(),
280            ])
281            .observe(self.start.elapsed().as_secs_f64());
282    }
283}
284
#[cfg(test)]
mod tests {
    use chrono::FixedOffset;
    use common_time::Timezone;
    use session::hints::REMOTE_QUERY_ID_EXTENSION_KEY;

    use super::*;

    /// Header fields and hints end up in the query context, and the context
    /// carries a remote-query-id extension that the caller did not supply.
    #[test]
    fn test_create_query_context() {
        let hints = vec![
            ("auto_create_table".to_string(), "true".to_string()),
            ("read_preference".to_string(), "leader".to_string()),
        ];
        let header = RequestHeader {
            catalog: "cat-a-log".to_string(),
            timezone: "+01:00".to_string(),
            ..Default::default()
        };

        let ctx = create_query_context(Channel::Unknown, Some(&header), hints).unwrap();

        assert_eq!(ctx.current_catalog(), "cat-a-log");
        assert_eq!(ctx.current_schema(), DEFAULT_SCHEMA_NAME);

        let expected_timezone = Timezone::Offset(FixedOffset::east_opt(3600).unwrap());
        assert_eq!(ctx.timezone(), expected_timezone);
        assert!(matches!(ctx.read_preference(), ReadPreference::Leader));

        // Sort by key so that the assertions do not depend on iteration order.
        let mut extensions: Vec<_> = ctx.extensions().into_iter().collect();
        extensions.sort_unstable_by(|a, b| a.0.cmp(&b.0));

        let expected_first = ("auto_create_table".to_string(), "true".to_string());
        assert_eq!(extensions[0], expected_first);
        assert_eq!(extensions[1].0, REMOTE_QUERY_ID_EXTENSION_KEY.to_string());
        assert_eq!(ctx.remote_query_id(), Some(extensions[1].1.as_str()));
    }

    /// A client-supplied remote query id must not override the one held by
    /// the context.
    #[test]
    fn test_create_query_context_ignores_remote_query_id_extension() {
        let spoofed_hint = (
            REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
            "spoofed-query-id".to_string(),
        );

        let ctx = create_query_context(Channel::Grpc, None, vec![spoofed_hint]).unwrap();

        assert_ne!(ctx.remote_query_id(), Some("spoofed-query-id"));
        assert_eq!(
            ctx.extension(REMOTE_QUERY_ID_EXTENSION_KEY),
            ctx.remote_query_id()
        );
    }
}