servers/
http.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::convert::Infallible;
17use std::fmt::Display;
18use std::net::SocketAddr;
19use std::sync::Mutex as StdMutex;
20use std::time::Duration;
21
22use async_trait::async_trait;
23use auth::UserProviderRef;
24use axum::extract::{DefaultBodyLimit, Request};
25use axum::http::StatusCode as HttpStatusCode;
26use axum::response::{IntoResponse, Response};
27use axum::routing::Route;
28use axum::serve::ListenerExt;
29use axum::{Router, middleware, routing};
30use common_base::Plugins;
31use common_base::readable_size::ReadableSize;
32use common_recordbatch::RecordBatch;
33use common_telemetry::{debug, error, info};
34use common_time::Timestamp;
35use common_time::timestamp::TimeUnit;
36use datatypes::data_type::DataType;
37use datatypes::schema::SchemaRef;
38use event::{LogState, LogValidatorRef};
39use futures::FutureExt;
40use http::{HeaderValue, Method};
41use prost::DecodeError;
42use serde::{Deserialize, Serialize};
43use serde_json::Value;
44use snafu::{ResultExt, ensure};
45use tokio::sync::Mutex;
46use tokio::sync::oneshot::{self, Sender};
47use tonic::codegen::Service;
48use tower::{Layer, ServiceBuilder};
49use tower_http::compression::CompressionLayer;
50use tower_http::cors::{AllowOrigin, Any, CorsLayer};
51use tower_http::decompression::RequestDecompressionLayer;
52use tower_http::trace::TraceLayer;
53
54use self::authorize::AuthState;
55use self::result::table_result::TableResponse;
56use crate::configurator::HttpConfiguratorRef;
57use crate::elasticsearch;
58use crate::error::{
59    AddressBindSnafu, AlreadyStartedSnafu, Error, InternalIoSnafu, InvalidHeaderValueSnafu,
60    OtherSnafu, Result,
61};
62use crate::http::influxdb::{influxdb_health, influxdb_ping, influxdb_write_v1, influxdb_write_v2};
63use crate::http::otlp::OtlpState;
64use crate::http::prom_store::PromStoreState;
65use crate::http::prometheus::{
66    build_info_query, format_query, instant_query, label_values_query, labels_query, parse_query,
67    range_query, series_query,
68};
69use crate::http::result::arrow_result::ArrowResponse;
70use crate::http::result::csv_result::CsvResponse;
71use crate::http::result::error_result::ErrorResponse;
72use crate::http::result::greptime_result_v1::GreptimedbV1Response;
73use crate::http::result::influxdb_result_v1::InfluxdbV1Response;
74use crate::http::result::json_result::JsonResponse;
75use crate::http::result::null_result::NullResponse;
76use crate::interceptor::LogIngestInterceptorRef;
77use crate::metrics::http_metrics_layer;
78use crate::metrics_handler::MetricsHandler;
79use crate::prometheus_handler::PrometheusHandlerRef;
80use crate::query_handler::sql::ServerSqlQueryHandlerRef;
81use crate::query_handler::{
82    InfluxdbLineProtocolHandlerRef, JaegerQueryHandlerRef, LogQueryHandlerRef,
83    OpenTelemetryProtocolHandlerRef, OpentsdbProtocolHandlerRef, PipelineHandlerRef,
84    PromStoreProtocolHandlerRef,
85};
86use crate::request_memory_limiter::ServerMemoryLimiter;
87use crate::server::Server;
88
89pub mod authorize;
90#[cfg(feature = "dashboard")]
91mod dashboard;
92pub mod dyn_log;
93pub mod dyn_trace;
94pub mod event;
95pub mod extractor;
96pub mod handler;
97pub mod header;
98pub mod influxdb;
99pub mod jaeger;
100pub mod logs;
101pub mod loki;
102pub mod mem_prof;
103mod memory_limit;
104pub mod opentsdb;
105pub mod otlp;
106pub mod pprof;
107pub mod prom_store;
108pub mod prometheus;
109pub mod result;
110mod timeout;
111pub mod utils;
112
113use result::HttpOutputWriter;
114pub(crate) use timeout::DynamicTimeoutLayer;
115
116mod hints;
117mod read_preference;
118#[cfg(any(test, feature = "testing"))]
119pub mod test_helpers;
120
/// Version segment of the HTTP API; versioned routes are nested under
/// `/{HTTP_API_VERSION}`.
pub const HTTP_API_VERSION: &str = "v1";
/// Path prefix of the versioned HTTP API, including the trailing slash.
pub const HTTP_API_PREFIX: &str = "/v1/";
/// Default http body limit (64M), used by [`HttpOptions::default`].
const DEFAULT_BODY_LIMIT: ReadableSize = ReadableSize::mb(64);

/// Name of the GreptimeDB-specific authorization header.
pub const AUTHORIZATION_HEADER: &str = "x-greptime-auth";

/// Paths treated as public APIs.
///
/// NOTE(review): presumably these are exempt from authentication; confirm
/// against the auth middleware.
// TODO(fys): This is a temporary workaround, it will be improved later
pub static PUBLIC_APIS: [&str; 3] = ["/v1/influxdb/ping", "/v1/influxdb/health", "/v1/health"];
131
/// HTTP server state: the routes assembled by [`HttpServerBuilder`] together
/// with the configuration consulted when the final axum application is built.
#[derive(Default)]
pub struct HttpServer {
    /// Router accumulated by the builder; `make_app` clones it under the lock.
    router: StdMutex<Router>,
    /// Sender half of a oneshot channel, `None` after `HttpServerBuilder::build`.
    /// NOTE(review): presumably used to signal shutdown; confirm in the `Server` impl.
    shutdown_tx: Mutex<Option<Sender<()>>>,
    /// Optional user provider handed to the auth middleware state.
    user_provider: Option<UserProviderRef>,
    /// Memory limiter handed to the request memory-limit middleware.
    memory_limiter: ServerMemoryLimiter,

    // plugins
    plugins: Plugins,

    // server configs
    options: HttpOptions,
    /// `None` after `HttpServerBuilder::build`.
    /// NOTE(review): presumably set to the bound address once the server starts; confirm.
    bind_addr: Option<SocketAddr>,
}
146
/// Configuration of the HTTP server.
///
/// Deserialized with `#[serde(default)]`, so fields missing from the input
/// fall back to the values of [`HttpOptions::default`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(default)]
pub struct HttpOptions {
    /// Server address string (defaults to `127.0.0.1:4000`).
    /// NOTE(review): presumably the listen address; confirm where the server binds.
    pub addr: String,

    /// Request timeout, (de)serialized in humantime format. A zero duration
    /// disables the timeout layer.
    #[serde(with = "humantime_serde")]
    pub timeout: Duration,

    /// When `true`, the dashboard routes are not registered. Skipped by serde,
    /// so it is never read from or written to serialized configuration.
    #[serde(skip)]
    pub disable_dashboard: bool,

    /// Maximum request body size. A size of zero disables the body limit layer.
    pub body_limit: ReadableSize,

    /// Validation mode while decoding Prometheus remote write requests.
    pub prom_validation_mode: PromValidationMode,

    /// Origins allowed by the CORS layer; an empty list allows any origin.
    pub cors_allowed_origins: Vec<String>,

    /// Whether the CORS layer is installed at all.
    pub enable_cors: bool,
}
167
/// UTF-8 validation policy applied when decoding string fields of Prometheus
/// remote write requests. Serialized in `snake_case` (`strict`, `lossy`,
/// `unchecked`).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PromValidationMode {
    /// Force UTF8 validation; invalid input is rejected with a decode error.
    Strict,
    /// Allow lossy UTF8 strings; invalid sequences are replaced during decoding.
    Lossy,
    /// Do not validate UTF8 strings.
    Unchecked,
}
178
/// Lookup table indexed by byte value: an entry is `true` when the byte may
/// appear after the first character of a Prometheus label name, i.e. it
/// matches `[a-zA-Z0-9_]`.
const IS_VALID_LABEL_REST: [bool; 256] = {
    let mut table = [false; 256];
    let mut byte = 0usize;
    while byte < 256 {
        table[byte] = matches!(byte as u8, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'_');
        byte += 1;
    }
    table
};

/// Returns `true` if the given byte slice is a valid Prometheus label name,
/// i.e., it matches `[a-zA-Z_][a-zA-Z0-9_]*`.
///
/// Since the allowed characters are pure ASCII, valid label names are
/// also valid UTF-8 by definition. An empty slice is never valid.
#[inline]
pub fn validate_label_name(name: &[u8]) -> bool {
    let Some((&head, tail)) = name.split_first() else {
        return false;
    };
    if !matches!(head, b'a'..=b'z' | b'A'..=b'Z' | b'_') {
        return false;
    }

    // Check the tail eight bytes at a time, combining the table lookups
    // without branching inside a group; the leftover bytes are checked one by one.
    let mut groups = tail.chunks_exact(8);
    for group in groups.by_ref() {
        let group_ok = group
            .iter()
            .fold(true, |ok, &b| ok & IS_VALID_LABEL_REST[b as usize]);
        if !group_ok {
            return false;
        }
    }

    groups
        .remainder()
        .iter()
        .all(|&b| IS_VALID_LABEL_REST[b as usize])
}
230
231impl PromValidationMode {
232    /// Decodes provided bytes to [String] with optional UTF-8 validation.
233    ///
234    /// Use this for label **values** and other general string fields.
235    pub fn decode_string(&self, bytes: &[u8]) -> std::result::Result<String, DecodeError> {
236        let result = match self {
237            PromValidationMode::Strict => match String::from_utf8(bytes.to_vec()) {
238                Ok(s) => s,
239                Err(e) => {
240                    debug!("Invalid UTF-8 string value: {:?}, error: {:?}", bytes, e);
241                    return Err(DecodeError::new("invalid utf-8"));
242                }
243            },
244            PromValidationMode::Lossy => String::from_utf8_lossy(bytes).to_string(),
245            PromValidationMode::Unchecked => unsafe { String::from_utf8_unchecked(bytes.to_vec()) },
246        };
247        Ok(result)
248    }
249
250    /// Decodes provided bytes to a label name [`&str`] with Prometheus label name validation.
251    ///
252    /// The check is always performed regardless of [`PromValidationMode`], as required by
253    /// the [Prometheus data model](https://prometheus.io/docs/concepts/data_model/#metric-names-and-labels).
254    pub fn decode_label_name<'a>(
255        &self,
256        bytes: &'a [u8],
257    ) -> std::result::Result<&'a str, DecodeError> {
258        if !validate_label_name(bytes) {
259            debug!(
260                "Invalid Prometheus label name: {:?}, must match [a-zA-Z_][a-zA-Z0-9_]*",
261                bytes
262            );
263            return Err(DecodeError::new(format!(
264                "invalid prometheus label name: '{}', must match [a-zA-Z_][a-zA-Z0-9_]*",
265                String::from_utf8_lossy(bytes)
266            )));
267        }
268
269        // SAFETY: `validate_label_name` only allows ASCII bytes,
270        // and ASCII is always valid UTF-8.
271        Ok(unsafe { std::str::from_utf8_unchecked(bytes) })
272    }
273}
274
275impl Default for HttpOptions {
276    fn default() -> Self {
277        Self {
278            addr: "127.0.0.1:4000".to_string(),
279            timeout: Duration::from_secs(0),
280            disable_dashboard: false,
281            body_limit: DEFAULT_BODY_LIMIT,
282            cors_allowed_origins: Vec::new(),
283            enable_cors: true,
284            prom_validation_mode: PromValidationMode::Strict,
285        }
286    }
287}
288
289#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
290pub struct ColumnSchema {
291    name: String,
292    data_type: String,
293}
294
295impl ColumnSchema {
296    pub fn new(name: String, data_type: String) -> ColumnSchema {
297        ColumnSchema { name, data_type }
298    }
299}
300
301#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
302pub struct OutputSchema {
303    column_schemas: Vec<ColumnSchema>,
304}
305
306impl OutputSchema {
307    pub fn new(columns: Vec<ColumnSchema>) -> OutputSchema {
308        OutputSchema {
309            column_schemas: columns,
310        }
311    }
312}
313
314impl From<SchemaRef> for OutputSchema {
315    fn from(schema: SchemaRef) -> OutputSchema {
316        OutputSchema {
317            column_schemas: schema
318                .column_schemas()
319                .iter()
320                .map(|cs| ColumnSchema {
321                    name: cs.name.clone(),
322                    data_type: cs.data_type.name(),
323                })
324                .collect(),
325        }
326    }
327}
328
329#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
330pub struct HttpRecordsOutput {
331    schema: OutputSchema,
332    rows: Vec<Vec<Value>>,
333    // total_rows is equal to rows.len() in most cases,
334    // the Dashboard query result may be truncated, so we need to return the total_rows.
335    #[serde(default)]
336    total_rows: usize,
337
338    // plan level execution metrics
339    #[serde(skip_serializing_if = "HashMap::is_empty")]
340    #[serde(default)]
341    metrics: HashMap<String, Value>,
342}
343
344impl HttpRecordsOutput {
345    pub fn num_rows(&self) -> usize {
346        self.rows.len()
347    }
348
349    pub fn num_cols(&self) -> usize {
350        self.schema.column_schemas.len()
351    }
352
353    pub fn schema(&self) -> &OutputSchema {
354        &self.schema
355    }
356
357    pub fn rows(&self) -> &Vec<Vec<Value>> {
358        &self.rows
359    }
360}
361
362impl HttpRecordsOutput {
363    pub fn try_new(
364        schema: SchemaRef,
365        recordbatches: Vec<RecordBatch>,
366    ) -> std::result::Result<HttpRecordsOutput, Error> {
367        if recordbatches.is_empty() {
368            Ok(HttpRecordsOutput {
369                schema: OutputSchema::from(schema),
370                rows: vec![],
371                total_rows: 0,
372                metrics: Default::default(),
373            })
374        } else {
375            let num_rows = recordbatches.iter().map(|r| r.num_rows()).sum::<usize>();
376            let mut rows = Vec::with_capacity(num_rows);
377
378            for recordbatch in recordbatches {
379                let mut writer = HttpOutputWriter::new(schema.num_columns(), None);
380                writer.write(recordbatch, &mut rows)?;
381            }
382
383            Ok(HttpRecordsOutput {
384                schema: OutputSchema::from(schema),
385                total_rows: rows.len(),
386                rows,
387                metrics: Default::default(),
388            })
389        }
390    }
391}
392
/// Result of a single query statement; variant names are serialized in lowercase.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum GreptimeQueryOutput {
    /// A row count, for outputs that carry no record data.
    AffectedRows(usize),
    /// The rows and schema produced by the statement.
    Records(HttpRecordsOutput),
}
399
/// Output format in which the result of a SQL query is presented.
/// Defaults to [`ResponseFormat::GreptimedbV1`].
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseFormat {
    Arrow,
    // (with_names, with_types)
    Csv(bool, bool),
    Table,
    #[default]
    GreptimedbV1,
    InfluxdbV1,
    Json,
    Null,
}

impl ResponseFormat {
    /// Parses a format name; matching is exact and case-sensitive.
    ///
    /// Returns `None` for names that are not recognized.
    pub fn parse(s: &str) -> Option<Self> {
        const KNOWN_FORMATS: [(&str, ResponseFormat); 9] = [
            ("arrow", ResponseFormat::Arrow),
            ("csv", ResponseFormat::Csv(false, false)),
            ("csvwithnames", ResponseFormat::Csv(true, false)),
            ("csvwithnamesandtypes", ResponseFormat::Csv(true, true)),
            ("table", ResponseFormat::Table),
            ("greptimedb_v1", ResponseFormat::GreptimedbV1),
            ("influxdb_v1", ResponseFormat::InfluxdbV1),
            ("json", ResponseFormat::Json),
            ("null", ResponseFormat::Null),
        ];

        KNOWN_FORMATS
            .iter()
            .find(|(name, _)| *name == s)
            .map(|&(_, format)| format)
    }

    /// Returns the canonical name of the format.
    ///
    /// Every CSV flavour is reported as `"csv"`, regardless of its
    /// `with_names` / `with_types` flags.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Arrow => "arrow",
            Self::Csv(..) => "csv",
            Self::Table => "table",
            Self::GreptimedbV1 => "greptimedb_v1",
            Self::InfluxdbV1 => "influxdb_v1",
            Self::Json => "json",
            Self::Null => "null",
        }
    }
}
442
443#[derive(Debug, Clone, Copy, PartialEq, Eq)]
444pub enum Epoch {
445    Nanosecond,
446    Microsecond,
447    Millisecond,
448    Second,
449}
450
451impl Epoch {
452    pub fn parse(s: &str) -> Option<Epoch> {
453        // Both u and µ indicate microseconds.
454        // epoch = [ns,u,µ,ms,s],
455        // For details, see the Influxdb documents.
456        // https://docs.influxdata.com/influxdb/v1/tools/api/#query-string-parameters-1
457        match s {
458            "ns" => Some(Epoch::Nanosecond),
459            "u" | "µ" => Some(Epoch::Microsecond),
460            "ms" => Some(Epoch::Millisecond),
461            "s" => Some(Epoch::Second),
462            _ => None, // just returns None for other cases
463        }
464    }
465
466    pub fn convert_timestamp(&self, ts: Timestamp) -> Option<Timestamp> {
467        match self {
468            Epoch::Nanosecond => ts.convert_to(TimeUnit::Nanosecond),
469            Epoch::Microsecond => ts.convert_to(TimeUnit::Microsecond),
470            Epoch::Millisecond => ts.convert_to(TimeUnit::Millisecond),
471            Epoch::Second => ts.convert_to(TimeUnit::Second),
472        }
473    }
474}
475
476impl Display for Epoch {
477    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
478        match self {
479            Epoch::Nanosecond => write!(f, "Epoch::Nanosecond"),
480            Epoch::Microsecond => write!(f, "Epoch::Microsecond"),
481            Epoch::Millisecond => write!(f, "Epoch::Millisecond"),
482            Epoch::Second => write!(f, "Epoch::Second"),
483        }
484    }
485}
486
487#[derive(Serialize, Deserialize, Debug)]
488pub enum HttpResponse {
489    Arrow(ArrowResponse),
490    Csv(CsvResponse),
491    Table(TableResponse),
492    Error(ErrorResponse),
493    GreptimedbV1(GreptimedbV1Response),
494    InfluxdbV1(InfluxdbV1Response),
495    Json(JsonResponse),
496    Null(NullResponse),
497}
498
499impl HttpResponse {
500    pub fn with_execution_time(self, execution_time: u64) -> Self {
501        match self {
502            HttpResponse::Arrow(resp) => resp.with_execution_time(execution_time).into(),
503            HttpResponse::Csv(resp) => resp.with_execution_time(execution_time).into(),
504            HttpResponse::Table(resp) => resp.with_execution_time(execution_time).into(),
505            HttpResponse::GreptimedbV1(resp) => resp.with_execution_time(execution_time).into(),
506            HttpResponse::InfluxdbV1(resp) => resp.with_execution_time(execution_time).into(),
507            HttpResponse::Json(resp) => resp.with_execution_time(execution_time).into(),
508            HttpResponse::Null(resp) => resp.with_execution_time(execution_time).into(),
509            HttpResponse::Error(resp) => resp.with_execution_time(execution_time).into(),
510        }
511    }
512
513    pub fn with_limit(self, limit: usize) -> Self {
514        match self {
515            HttpResponse::Csv(resp) => resp.with_limit(limit).into(),
516            HttpResponse::Table(resp) => resp.with_limit(limit).into(),
517            HttpResponse::GreptimedbV1(resp) => resp.with_limit(limit).into(),
518            HttpResponse::Json(resp) => resp.with_limit(limit).into(),
519            _ => self,
520        }
521    }
522}
523
524pub fn process_with_limit(
525    mut outputs: Vec<GreptimeQueryOutput>,
526    limit: usize,
527) -> Vec<GreptimeQueryOutput> {
528    outputs
529        .drain(..)
530        .map(|data| match data {
531            GreptimeQueryOutput::Records(mut records) => {
532                if records.rows.len() > limit {
533                    records.rows.truncate(limit);
534                    records.total_rows = limit;
535                }
536                GreptimeQueryOutput::Records(records)
537            }
538            _ => data,
539        })
540        .collect()
541}
542
543impl IntoResponse for HttpResponse {
544    fn into_response(self) -> Response {
545        match self {
546            HttpResponse::Arrow(resp) => resp.into_response(),
547            HttpResponse::Csv(resp) => resp.into_response(),
548            HttpResponse::Table(resp) => resp.into_response(),
549            HttpResponse::GreptimedbV1(resp) => resp.into_response(),
550            HttpResponse::InfluxdbV1(resp) => resp.into_response(),
551            HttpResponse::Json(resp) => resp.into_response(),
552            HttpResponse::Null(resp) => resp.into_response(),
553            HttpResponse::Error(resp) => resp.into_response(),
554        }
555    }
556}
557
558impl From<ArrowResponse> for HttpResponse {
559    fn from(value: ArrowResponse) -> Self {
560        HttpResponse::Arrow(value)
561    }
562}
563
564impl From<CsvResponse> for HttpResponse {
565    fn from(value: CsvResponse) -> Self {
566        HttpResponse::Csv(value)
567    }
568}
569
570impl From<TableResponse> for HttpResponse {
571    fn from(value: TableResponse) -> Self {
572        HttpResponse::Table(value)
573    }
574}
575
576impl From<ErrorResponse> for HttpResponse {
577    fn from(value: ErrorResponse) -> Self {
578        HttpResponse::Error(value)
579    }
580}
581
582impl From<GreptimedbV1Response> for HttpResponse {
583    fn from(value: GreptimedbV1Response) -> Self {
584        HttpResponse::GreptimedbV1(value)
585    }
586}
587
588impl From<InfluxdbV1Response> for HttpResponse {
589    fn from(value: InfluxdbV1Response) -> Self {
590        HttpResponse::InfluxdbV1(value)
591    }
592}
593
594impl From<JsonResponse> for HttpResponse {
595    fn from(value: JsonResponse) -> Self {
596        HttpResponse::Json(value)
597    }
598}
599
600impl From<NullResponse> for HttpResponse {
601    fn from(value: NullResponse) -> Self {
602        HttpResponse::Null(value)
603    }
604}
605
/// State passed to `HttpServer::route_sql` when the SQL routes are built.
#[derive(Clone)]
pub struct ApiState {
    /// Handler for the SQL queries.
    pub sql_handler: ServerSqlQueryHandlerRef,
}
610
/// State passed to `HttpServer::route_config` when the config routes are built.
#[derive(Clone)]
pub struct GreptimeOptionsConfigState {
    /// The GreptimeDB configuration options as a string.
    /// NOTE(review): the serialization format is not visible here; confirm at the call site.
    pub greptime_config_options: String,
}
615
616pub struct HttpServerBuilder {
617    options: HttpOptions,
618    plugins: Plugins,
619    user_provider: Option<UserProviderRef>,
620    router: Router,
621    memory_limiter: ServerMemoryLimiter,
622}
623
624impl HttpServerBuilder {
625    pub fn new(options: HttpOptions) -> Self {
626        Self {
627            options,
628            plugins: Plugins::default(),
629            user_provider: None,
630            router: Router::new(),
631            memory_limiter: ServerMemoryLimiter::default(),
632        }
633    }
634
635    /// Set a global memory limiter for all server protocols.
636    pub fn with_memory_limiter(mut self, limiter: ServerMemoryLimiter) -> Self {
637        self.memory_limiter = limiter;
638        self
639    }
640
641    pub fn with_sql_handler(self, sql_handler: ServerSqlQueryHandlerRef) -> Self {
642        let sql_router = HttpServer::route_sql(ApiState { sql_handler });
643
644        Self {
645            router: self
646                .router
647                .nest(&format!("/{HTTP_API_VERSION}"), sql_router),
648            ..self
649        }
650    }
651
652    pub fn with_logs_handler(self, logs_handler: LogQueryHandlerRef) -> Self {
653        let logs_router = HttpServer::route_logs(logs_handler);
654
655        Self {
656            router: self
657                .router
658                .nest(&format!("/{HTTP_API_VERSION}"), logs_router),
659            ..self
660        }
661    }
662
663    pub fn with_opentsdb_handler(self, handler: OpentsdbProtocolHandlerRef) -> Self {
664        Self {
665            router: self.router.nest(
666                &format!("/{HTTP_API_VERSION}/opentsdb"),
667                HttpServer::route_opentsdb(handler),
668            ),
669            ..self
670        }
671    }
672
673    pub fn with_influxdb_handler(self, handler: InfluxdbLineProtocolHandlerRef) -> Self {
674        Self {
675            router: self.router.nest(
676                &format!("/{HTTP_API_VERSION}/influxdb"),
677                HttpServer::route_influxdb(handler),
678            ),
679            ..self
680        }
681    }
682
683    pub fn with_prom_handler(
684        self,
685        handler: PromStoreProtocolHandlerRef,
686        pipeline_handler: Option<PipelineHandlerRef>,
687        prom_store_with_metric_engine: bool,
688        prom_validation_mode: PromValidationMode,
689    ) -> Self {
690        let state = PromStoreState {
691            prom_store_handler: handler,
692            pipeline_handler,
693            prom_store_with_metric_engine,
694            prom_validation_mode,
695        };
696
697        Self {
698            router: self.router.nest(
699                &format!("/{HTTP_API_VERSION}/prometheus"),
700                HttpServer::route_prom(state),
701            ),
702            ..self
703        }
704    }
705
706    pub fn with_prometheus_handler(self, handler: PrometheusHandlerRef) -> Self {
707        Self {
708            router: self.router.nest(
709                &format!("/{HTTP_API_VERSION}/prometheus/api/v1"),
710                HttpServer::route_prometheus(handler),
711            ),
712            ..self
713        }
714    }
715
716    pub fn with_otlp_handler(
717        self,
718        handler: OpenTelemetryProtocolHandlerRef,
719        with_metric_engine: bool,
720    ) -> Self {
721        Self {
722            router: self.router.nest(
723                &format!("/{HTTP_API_VERSION}/otlp"),
724                HttpServer::route_otlp(handler, with_metric_engine),
725            ),
726            ..self
727        }
728    }
729
730    pub fn with_user_provider(self, user_provider: UserProviderRef) -> Self {
731        Self {
732            user_provider: Some(user_provider),
733            ..self
734        }
735    }
736
737    pub fn with_metrics_handler(self, handler: MetricsHandler) -> Self {
738        Self {
739            router: self.router.merge(HttpServer::route_metrics(handler)),
740            ..self
741        }
742    }
743
744    pub fn with_log_ingest_handler(
745        self,
746        handler: PipelineHandlerRef,
747        validator: Option<LogValidatorRef>,
748        ingest_interceptor: Option<LogIngestInterceptorRef<Error>>,
749    ) -> Self {
750        let log_state = LogState {
751            log_handler: handler,
752            log_validator: validator,
753            ingest_interceptor,
754        };
755
756        let router = self.router.nest(
757            &format!("/{HTTP_API_VERSION}"),
758            HttpServer::route_pipelines(log_state.clone()),
759        );
760        // deprecated since v0.11.0. Use `/logs` and `/pipelines` instead.
761        let router = router.nest(
762            &format!("/{HTTP_API_VERSION}/events"),
763            #[allow(deprecated)]
764            HttpServer::route_log_deprecated(log_state.clone()),
765        );
766
767        let router = router.nest(
768            &format!("/{HTTP_API_VERSION}/loki"),
769            HttpServer::route_loki(log_state.clone()),
770        );
771
772        let router = router.nest(
773            &format!("/{HTTP_API_VERSION}/elasticsearch"),
774            HttpServer::route_elasticsearch(log_state.clone()),
775        );
776
777        let router = router.nest(
778            &format!("/{HTTP_API_VERSION}/elasticsearch/"),
779            Router::new()
780                .route("/", routing::get(elasticsearch::handle_get_version))
781                .with_state(log_state),
782        );
783
784        Self { router, ..self }
785    }
786
787    pub fn with_plugins(self, plugins: Plugins) -> Self {
788        Self { plugins, ..self }
789    }
790
791    pub fn with_greptime_config_options(self, opts: String) -> Self {
792        let config_router = HttpServer::route_config(GreptimeOptionsConfigState {
793            greptime_config_options: opts,
794        });
795
796        Self {
797            router: self.router.merge(config_router),
798            ..self
799        }
800    }
801
802    pub fn with_jaeger_handler(self, handler: JaegerQueryHandlerRef) -> Self {
803        Self {
804            router: self.router.nest(
805                &format!("/{HTTP_API_VERSION}/jaeger"),
806                HttpServer::route_jaeger(handler),
807            ),
808            ..self
809        }
810    }
811
812    pub fn with_extra_router(self, router: Router) -> Self {
813        Self {
814            router: self.router.merge(router),
815            ..self
816        }
817    }
818
819    pub fn add_layer<L>(self, layer: L) -> Self
820    where
821        L: Layer<Route> + Clone + Send + Sync + 'static,
822        L::Service: Service<Request> + Clone + Send + Sync + 'static,
823        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
824        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
825        <L::Service as Service<Request>>::Future: Send + 'static,
826    {
827        Self {
828            router: self.router.layer(layer),
829            ..self
830        }
831    }
832
833    pub fn build(self) -> HttpServer {
834        HttpServer {
835            options: self.options,
836            user_provider: self.user_provider,
837            shutdown_tx: Mutex::new(None),
838            plugins: self.plugins,
839            router: StdMutex::new(self.router),
840            bind_addr: None,
841            memory_limiter: self.memory_limiter,
842        }
843    }
844}
845
846impl HttpServer {
847    /// Gets the router and adds necessary root routes (health, status, dashboard).
848    pub fn make_app(&self) -> Router {
849        let mut router = {
850            let router = self.router.lock().unwrap();
851            router.clone()
852        };
853
854        router = router
855            .route(
856                "/health",
857                routing::get(handler::health).post(handler::health),
858            )
859            .route(
860                &format!("/{HTTP_API_VERSION}/health"),
861                routing::get(handler::health).post(handler::health),
862            )
863            .route(
864                "/ready",
865                routing::get(handler::health).post(handler::health),
866            );
867
868        router = router.route("/status", routing::get(handler::status));
869
870        #[cfg(feature = "dashboard")]
871        {
872            if !self.options.disable_dashboard {
873                info!("Enable dashboard service at '/dashboard'");
874                // redirect /dashboard to /dashboard/
875                router = router.route(
876                    "/dashboard",
877                    routing::get(|uri: axum::http::uri::Uri| async move {
878                        let path = uri.path();
879                        let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
880
881                        let new_uri = format!("{}/{}", path, query);
882                        axum::response::Redirect::permanent(&new_uri)
883                    }),
884                );
885
886                // "/dashboard" and "/dashboard/" are two different paths in Axum.
887                // We cannot nest "/dashboard/", because we already mapping "/dashboard/{*x}" while nesting "/dashboard".
888                // So we explicitly route "/dashboard/" here.
889                router = router
890                    .route(
891                        "/dashboard/",
892                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
893                    )
894                    .route(
895                        "/dashboard/{*x}",
896                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
897                    );
898            }
899        }
900
901        // Add a layer to collect HTTP metrics for axum.
902        router = router.route_layer(middleware::from_fn(http_metrics_layer));
903
904        router
905    }
906
907    /// Attaches middlewares and debug routes to the router.
908    /// Callers should call this method after [HttpServer::make_app()].
909    pub fn build(&self, router: Router) -> Result<Router> {
910        let timeout_layer = if self.options.timeout != Duration::default() {
911            Some(ServiceBuilder::new().layer(DynamicTimeoutLayer::new(self.options.timeout)))
912        } else {
913            info!("HTTP server timeout is disabled");
914            None
915        };
916        let body_limit_layer = if self.options.body_limit != ReadableSize(0) {
917            Some(
918                ServiceBuilder::new()
919                    .layer(DefaultBodyLimit::max(self.options.body_limit.0 as usize)),
920            )
921        } else {
922            info!("HTTP server body limit is disabled");
923            None
924        };
925        let cors_layer = if self.options.enable_cors {
926            Some(
927                CorsLayer::new()
928                    .allow_methods([
929                        Method::GET,
930                        Method::POST,
931                        Method::PUT,
932                        Method::DELETE,
933                        Method::HEAD,
934                    ])
935                    .allow_origin(if self.options.cors_allowed_origins.is_empty() {
936                        AllowOrigin::from(Any)
937                    } else {
938                        AllowOrigin::from(
939                            self.options
940                                .cors_allowed_origins
941                                .iter()
942                                .map(|s| {
943                                    HeaderValue::from_str(s.as_str())
944                                        .context(InvalidHeaderValueSnafu)
945                                })
946                                .collect::<Result<Vec<HeaderValue>>>()?,
947                        )
948                    })
949                    .allow_headers(Any),
950            )
951        } else {
952            info!("HTTP server cross-origin is disabled");
953            None
954        };
955
956        Ok(router
957            // middlewares
958            .layer(
959                ServiceBuilder::new()
960                    // disable on failure tracing. because printing out isn't very helpful,
961                    // and we have impl IntoResponse for Error. It will print out more detailed error messages
962                    .layer(TraceLayer::new_for_http().on_failure(()))
963                    .option_layer(cors_layer)
964                    .option_layer(timeout_layer)
965                    .option_layer(body_limit_layer)
966                    // memory limit layer - must be before body is consumed
967                    .layer(middleware::from_fn_with_state(
968                        self.memory_limiter.clone(),
969                        memory_limit::memory_limit_middleware,
970                    ))
971                    // auth layer
972                    .layer(middleware::from_fn_with_state(
973                        AuthState::new(self.user_provider.clone()),
974                        authorize::check_http_auth,
975                    ))
976                    .layer(middleware::from_fn(hints::extract_hints))
977                    .layer(middleware::from_fn(
978                        read_preference::extract_read_preference,
979                    )),
980            )
981            // Handlers for debug, we don't expect a timeout.
982            .nest(
983                "/debug",
984                Router::new()
985                    // handler for changing log level dynamically
986                    .route("/log_level", routing::post(dyn_log::dyn_log_handler))
987                    .route("/enable_trace", routing::post(dyn_trace::dyn_trace_handler))
988                    .nest(
989                        "/prof",
990                        Router::new()
991                            .route("/cpu", routing::post(pprof::pprof_handler))
992                            .route("/mem", routing::post(mem_prof::mem_prof_handler))
993                            .route("/mem/symbol", routing::post(mem_prof::symbolicate_handler))
994                            .route(
995                                "/mem/activate",
996                                routing::post(mem_prof::activate_heap_prof_handler),
997                            )
998                            .route(
999                                "/mem/deactivate",
1000                                routing::post(mem_prof::deactivate_heap_prof_handler),
1001                            )
1002                            .route(
1003                                "/mem/status",
1004                                routing::get(mem_prof::heap_prof_status_handler),
1005                            ) // jemalloc gdump flag status and toggle
1006                            .route(
1007                                "/mem/gdump",
1008                                routing::get(mem_prof::gdump_status_handler)
1009                                    .post(mem_prof::gdump_toggle_handler),
1010                            ),
1011                    ),
1012            ))
1013    }
1014
1015    fn route_metrics<S>(metrics_handler: MetricsHandler) -> Router<S> {
1016        Router::new()
1017            .route("/metrics", routing::get(handler::metrics))
1018            .with_state(metrics_handler)
1019    }
1020
1021    fn route_loki<S>(log_state: LogState) -> Router<S> {
1022        Router::new()
1023            .route("/api/v1/push", routing::post(loki::loki_ingest))
1024            .layer(
1025                ServiceBuilder::new()
1026                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1027            )
1028            .with_state(log_state)
1029    }
1030
1031    fn route_elasticsearch<S>(log_state: LogState) -> Router<S> {
1032        Router::new()
1033            // Return fake responsefor HEAD '/' request.
1034            .route(
1035                "/",
1036                routing::head((HttpStatusCode::OK, elasticsearch::elasticsearch_headers())),
1037            )
1038            // Return fake response for Elasticsearch version request.
1039            .route("/", routing::get(elasticsearch::handle_get_version))
1040            // Return fake response for Elasticsearch license request.
1041            .route("/_license", routing::get(elasticsearch::handle_get_license))
1042            .route("/_bulk", routing::post(elasticsearch::handle_bulk_api))
1043            .route(
1044                "/{index}/_bulk",
1045                routing::post(elasticsearch::handle_bulk_api_with_index),
1046            )
1047            // Return fake response for Elasticsearch ilm request.
1048            .route(
1049                "/_ilm/policy/{*path}",
1050                routing::any((
1051                    HttpStatusCode::OK,
1052                    elasticsearch::elasticsearch_headers(),
1053                    axum::Json(serde_json::json!({})),
1054                )),
1055            )
1056            // Return fake response for Elasticsearch index template request.
1057            .route(
1058                "/_index_template/{*path}",
1059                routing::any((
1060                    HttpStatusCode::OK,
1061                    elasticsearch::elasticsearch_headers(),
1062                    axum::Json(serde_json::json!({})),
1063                )),
1064            )
1065            // Return fake response for Elasticsearch ingest pipeline request.
1066            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/put-pipeline-api.html.
1067            .route(
1068                "/_ingest/{*path}",
1069                routing::any((
1070                    HttpStatusCode::OK,
1071                    elasticsearch::elasticsearch_headers(),
1072                    axum::Json(serde_json::json!({})),
1073                )),
1074            )
1075            // Return fake response for Elasticsearch nodes discovery request.
1076            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/cluster.html.
1077            .route(
1078                "/_nodes/{*path}",
1079                routing::any((
1080                    HttpStatusCode::OK,
1081                    elasticsearch::elasticsearch_headers(),
1082                    axum::Json(serde_json::json!({})),
1083                )),
1084            )
1085            // Return fake response for Logstash APIs requests.
1086            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/logstash-apis.html
1087            .route(
1088                "/logstash/{*path}",
1089                routing::any((
1090                    HttpStatusCode::OK,
1091                    elasticsearch::elasticsearch_headers(),
1092                    axum::Json(serde_json::json!({})),
1093                )),
1094            )
1095            .route(
1096                "/_logstash/{*path}",
1097                routing::any((
1098                    HttpStatusCode::OK,
1099                    elasticsearch::elasticsearch_headers(),
1100                    axum::Json(serde_json::json!({})),
1101                )),
1102            )
1103            .layer(ServiceBuilder::new().layer(RequestDecompressionLayer::new()))
1104            .with_state(log_state)
1105    }
1106
1107    #[deprecated(since = "0.11.0", note = "Use `route_pipelines()` instead.")]
1108    fn route_log_deprecated<S>(log_state: LogState) -> Router<S> {
1109        Router::new()
1110            .route("/logs", routing::post(event::log_ingester))
1111            .route(
1112                "/pipelines/{pipeline_name}",
1113                routing::get(event::query_pipeline),
1114            )
1115            .route(
1116                "/pipelines/{pipeline_name}",
1117                routing::post(event::add_pipeline),
1118            )
1119            .route(
1120                "/pipelines/{pipeline_name}",
1121                routing::delete(event::delete_pipeline),
1122            )
1123            .route("/pipelines/dryrun", routing::post(event::pipeline_dryrun))
1124            .layer(
1125                ServiceBuilder::new()
1126                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1127            )
1128            .with_state(log_state)
1129    }
1130
1131    fn route_pipelines<S>(log_state: LogState) -> Router<S> {
1132        Router::new()
1133            .route("/ingest", routing::post(event::log_ingester))
1134            .route(
1135                "/pipelines/{pipeline_name}",
1136                routing::get(event::query_pipeline),
1137            )
1138            .route(
1139                "/pipelines/{pipeline_name}/ddl",
1140                routing::get(event::query_pipeline_ddl),
1141            )
1142            .route(
1143                "/pipelines/{pipeline_name}",
1144                routing::post(event::add_pipeline),
1145            )
1146            .route(
1147                "/pipelines/{pipeline_name}",
1148                routing::delete(event::delete_pipeline),
1149            )
1150            .route("/pipelines/_dryrun", routing::post(event::pipeline_dryrun))
1151            .layer(
1152                ServiceBuilder::new()
1153                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1154            )
1155            .with_state(log_state)
1156    }
1157
1158    fn route_sql<S>(api_state: ApiState) -> Router<S> {
1159        Router::new()
1160            .route("/sql", routing::get(handler::sql).post(handler::sql))
1161            .route(
1162                "/sql/parse",
1163                routing::get(handler::sql_parse).post(handler::sql_parse),
1164            )
1165            .route(
1166                "/sql/format",
1167                routing::get(handler::sql_format).post(handler::sql_format),
1168            )
1169            .route(
1170                "/promql",
1171                routing::get(handler::promql).post(handler::promql),
1172            )
1173            .with_state(api_state)
1174    }
1175
1176    fn route_logs<S>(log_handler: LogQueryHandlerRef) -> Router<S> {
1177        Router::new()
1178            .route("/logs", routing::get(logs::logs).post(logs::logs))
1179            .with_state(log_handler)
1180    }
1181
1182    /// Route Prometheus [HTTP API].
1183    ///
1184    /// [HTTP API]: https://prometheus.io/docs/prometheus/latest/querying/api/
1185    pub fn route_prometheus<S>(prometheus_handler: PrometheusHandlerRef) -> Router<S> {
1186        Router::new()
1187            .route(
1188                "/format_query",
1189                routing::post(format_query).get(format_query),
1190            )
1191            .route("/status/buildinfo", routing::get(build_info_query))
1192            .route("/query", routing::post(instant_query).get(instant_query))
1193            .route("/query_range", routing::post(range_query).get(range_query))
1194            .route("/labels", routing::post(labels_query).get(labels_query))
1195            .route("/series", routing::post(series_query).get(series_query))
1196            .route("/parse_query", routing::post(parse_query).get(parse_query))
1197            .route(
1198                "/label/{label_name}/values",
1199                routing::get(label_values_query),
1200            )
1201            .layer(ServiceBuilder::new().layer(CompressionLayer::new()))
1202            .with_state(prometheus_handler)
1203    }
1204
1205    /// Route Prometheus remote [read] and [write] API. In other places the related modules are
1206    /// called `prom_store`.
1207    ///
1208    /// [read]: https://prometheus.io/docs/prometheus/latest/querying/remote_read_api/
1209    /// [write]: https://prometheus.io/docs/concepts/remote_write_spec/
1210    fn route_prom<S>(state: PromStoreState) -> Router<S> {
1211        Router::new()
1212            .route("/read", routing::post(prom_store::remote_read))
1213            .route("/write", routing::post(prom_store::remote_write))
1214            .with_state(state)
1215    }
1216
1217    fn route_influxdb<S>(influxdb_handler: InfluxdbLineProtocolHandlerRef) -> Router<S> {
1218        Router::new()
1219            .route("/write", routing::post(influxdb_write_v1))
1220            .route("/api/v2/write", routing::post(influxdb_write_v2))
1221            .layer(
1222                ServiceBuilder::new()
1223                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1224            )
1225            .route("/ping", routing::get(influxdb_ping))
1226            .route("/health", routing::get(influxdb_health))
1227            .with_state(influxdb_handler)
1228    }
1229
1230    fn route_opentsdb<S>(opentsdb_handler: OpentsdbProtocolHandlerRef) -> Router<S> {
1231        Router::new()
1232            .route("/api/put", routing::post(opentsdb::put))
1233            .with_state(opentsdb_handler)
1234    }
1235
1236    fn route_otlp<S>(
1237        otlp_handler: OpenTelemetryProtocolHandlerRef,
1238        with_metric_engine: bool,
1239    ) -> Router<S> {
1240        Router::new()
1241            .route("/v1/metrics", routing::post(otlp::metrics))
1242            .route("/v1/traces", routing::post(otlp::traces))
1243            .route("/v1/logs", routing::post(otlp::logs))
1244            .layer(
1245                ServiceBuilder::new()
1246                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1247            )
1248            .with_state(OtlpState {
1249                with_metric_engine,
1250                handler: otlp_handler,
1251            })
1252    }
1253
1254    fn route_config<S>(state: GreptimeOptionsConfigState) -> Router<S> {
1255        Router::new()
1256            .route("/config", routing::get(handler::config))
1257            .with_state(state)
1258    }
1259
1260    fn route_jaeger<S>(handler: JaegerQueryHandlerRef) -> Router<S> {
1261        Router::new()
1262            .route("/api/services", routing::get(jaeger::handle_get_services))
1263            .route(
1264                "/api/services/{service_name}/operations",
1265                routing::get(jaeger::handle_get_operations_by_service),
1266            )
1267            .route(
1268                "/api/operations",
1269                routing::get(jaeger::handle_get_operations),
1270            )
1271            .route("/api/traces", routing::get(jaeger::handle_find_traces))
1272            .route(
1273                "/api/traces/{trace_id}",
1274                routing::get(jaeger::handle_get_trace),
1275            )
1276            .with_state(handler)
1277    }
1278}
1279
/// Name returned by the `Server::name` implementation of `HttpServer`.
pub const HTTP_SERVER: &str = "HTTP_SERVER";
1281
1282#[async_trait]
1283impl Server for HttpServer {
1284    async fn shutdown(&self) -> Result<()> {
1285        let mut shutdown_tx = self.shutdown_tx.lock().await;
1286        if let Some(tx) = shutdown_tx.take()
1287            && tx.send(()).is_err()
1288        {
1289            info!("Receiver dropped, the HTTP server has already exited");
1290        }
1291        info!("Shutdown HTTP server");
1292
1293        Ok(())
1294    }
1295
1296    async fn start(&mut self, listening: SocketAddr) -> Result<()> {
1297        let (tx, rx) = oneshot::channel();
1298        let serve = {
1299            let mut shutdown_tx = self.shutdown_tx.lock().await;
1300            ensure!(
1301                shutdown_tx.is_none(),
1302                AlreadyStartedSnafu { server: "HTTP" }
1303            );
1304
1305            let mut app = self.make_app();
1306            if let Some(configurator) = self.plugins.get::<HttpConfiguratorRef<()>>() {
1307                app = configurator
1308                    .configure_http(app, ())
1309                    .await
1310                    .context(OtherSnafu)?;
1311            }
1312            let app = self.build(app)?;
1313            let listener = tokio::net::TcpListener::bind(listening)
1314                .await
1315                .context(AddressBindSnafu { addr: listening })?
1316                .tap_io(|tcp_stream| {
1317                    if let Err(e) = tcp_stream.set_nodelay(true) {
1318                        error!(e; "Failed to set TCP_NODELAY on incoming connection");
1319                    }
1320                });
1321            let serve = axum::serve(listener, app.into_make_service());
1322
1323            // FIXME(yingwen): Support keepalive.
1324            // See:
1325            // - https://github.com/tokio-rs/axum/discussions/2939
1326            // - https://stackoverflow.com/questions/73069718/how-do-i-keep-alive-tokiotcpstream-in-rust
1327            // let server = axum::Server::try_bind(&listening)
1328            //     .with_context(|_| AddressBindSnafu { addr: listening })?
1329            //     .tcp_nodelay(true)
1330            //     // Enable TCP keepalive to close the dangling established connections.
1331            //     // It's configured to let the keepalive probes first send after the connection sits
1332            //     // idle for 59 minutes, and then send every 10 seconds for 6 times.
1333            //     // So the connection will be closed after roughly 1 hour.
1334            //     .tcp_keepalive(Some(Duration::from_secs(59 * 60)))
1335            //     .tcp_keepalive_interval(Some(Duration::from_secs(10)))
1336            //     .tcp_keepalive_retries(Some(6))
1337            //     .serve(app.into_make_service());
1338
1339            *shutdown_tx = Some(tx);
1340
1341            serve
1342        };
1343        let listening = serve.local_addr().context(InternalIoSnafu)?;
1344        info!("HTTP server is bound to {}", listening);
1345
1346        common_runtime::spawn_global(async move {
1347            if let Err(e) = serve
1348                .with_graceful_shutdown(rx.map(drop))
1349                .await
1350                .context(InternalIoSnafu)
1351            {
1352                error!(e; "Failed to shutdown http server");
1353            }
1354        });
1355
1356        self.bind_addr = Some(listening);
1357        Ok(())
1358    }
1359
1360    fn name(&self) -> &str {
1361        HTTP_SERVER
1362    }
1363
1364    fn bind_addr(&self) -> Option<SocketAddr> {
1365        self.bind_addr
1366    }
1367
1368    fn as_any(&self) -> &dyn std::any::Any {
1369        self
1370    }
1371}
1372
1373#[cfg(test)]
1374mod test {
1375    use std::future::pending;
1376    use std::io::Cursor;
1377    use std::sync::Arc;
1378
1379    use arrow_ipc::reader::FileReader;
1380    use arrow_schema::DataType;
1381    use axum::handler::Handler;
1382    use axum::http::StatusCode;
1383    use axum::routing::get;
1384    use common_query::Output;
1385    use common_recordbatch::RecordBatches;
1386    use datafusion_expr::LogicalPlan;
1387    use datatypes::prelude::*;
1388    use datatypes::schema::{ColumnSchema, Schema};
1389    use datatypes::vectors::{StringVector, UInt32Vector};
1390    use header::constants::GREPTIME_DB_HEADER_TIMEOUT;
1391    use query::parser::PromQuery;
1392    use query::query_engine::DescribeResult;
1393    use session::context::QueryContextRef;
1394    use sql::statements::statement::Statement;
1395    use tokio::sync::mpsc;
1396    use tokio::time::Instant;
1397
1398    use super::*;
1399    use crate::http::test_helpers::TestClient;
1400    use crate::query_handler::sql::SqlQueryHandler;
1401
1402    struct DummyInstance {
1403        _tx: mpsc::Sender<(String, Vec<u8>)>,
1404    }
1405
1406    #[async_trait]
1407    impl SqlQueryHandler for DummyInstance {
1408        async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec<Result<Output>> {
1409            unimplemented!()
1410        }
1411
1412        async fn do_promql_query(&self, _: &PromQuery, _: QueryContextRef) -> Vec<Result<Output>> {
1413            unimplemented!()
1414        }
1415
1416        async fn do_exec_plan(
1417            &self,
1418            _stmt: Option<Statement>,
1419            _plan: LogicalPlan,
1420            _query_ctx: QueryContextRef,
1421        ) -> Result<Output> {
1422            unimplemented!()
1423        }
1424
1425        async fn do_describe(
1426            &self,
1427            _stmt: sql::statements::statement::Statement,
1428            _query_ctx: QueryContextRef,
1429        ) -> Result<Option<DescribeResult>> {
1430            unimplemented!()
1431        }
1432
1433        async fn is_valid_schema(&self, _catalog: &str, _schema: &str) -> Result<bool> {
1434            Ok(true)
1435        }
1436    }
1437
1438    fn timeout() -> DynamicTimeoutLayer {
1439        DynamicTimeoutLayer::new(Duration::from_millis(10))
1440    }
1441
1442    async fn forever() {
1443        pending().await
1444    }
1445
1446    fn make_test_app(tx: mpsc::Sender<(String, Vec<u8>)>) -> Router {
1447        make_test_app_custom(tx, HttpOptions::default())
1448    }
1449
1450    fn make_test_app_custom(tx: mpsc::Sender<(String, Vec<u8>)>, options: HttpOptions) -> Router {
1451        let instance = Arc::new(DummyInstance { _tx: tx });
1452        let server = HttpServerBuilder::new(options)
1453            .with_sql_handler(instance.clone())
1454            .build();
1455        server.build(server.make_app()).unwrap().route(
1456            "/test/timeout",
1457            get(forever.layer(ServiceBuilder::new().layer(timeout()))),
1458        )
1459    }
1460
1461    #[tokio::test]
1462    pub async fn test_cors() {
1463        // cors is on by default
1464        let (tx, _rx) = mpsc::channel(100);
1465        let app = make_test_app(tx);
1466        let client = TestClient::new(app).await;
1467
1468        let res = client.get("/health").send().await;
1469
1470        assert_eq!(res.status(), StatusCode::OK);
1471        assert_eq!(
1472            res.headers()
1473                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1474                .expect("expect cors header origin"),
1475            "*"
1476        );
1477
1478        let res = client.get("/v1/health").send().await;
1479
1480        assert_eq!(res.status(), StatusCode::OK);
1481        assert_eq!(
1482            res.headers()
1483                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1484                .expect("expect cors header origin"),
1485            "*"
1486        );
1487
1488        let res = client
1489            .options("/health")
1490            .header("Access-Control-Request-Headers", "x-greptime-auth")
1491            .header("Access-Control-Request-Method", "DELETE")
1492            .header("Origin", "https://example.com")
1493            .send()
1494            .await;
1495        assert_eq!(res.status(), StatusCode::OK);
1496        assert_eq!(
1497            res.headers()
1498                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1499                .expect("expect cors header origin"),
1500            "*"
1501        );
1502        assert_eq!(
1503            res.headers()
1504                .get(http::header::ACCESS_CONTROL_ALLOW_HEADERS)
1505                .expect("expect cors header headers"),
1506            "*"
1507        );
1508        assert_eq!(
1509            res.headers()
1510                .get(http::header::ACCESS_CONTROL_ALLOW_METHODS)
1511                .expect("expect cors header methods"),
1512            "GET,POST,PUT,DELETE,HEAD"
1513        );
1514    }
1515
1516    #[tokio::test]
1517    pub async fn test_cors_custom_origins() {
1518        // cors is on by default
1519        let (tx, _rx) = mpsc::channel(100);
1520        let origin = "https://example.com";
1521
1522        let options = HttpOptions {
1523            cors_allowed_origins: vec![origin.to_string()],
1524            ..Default::default()
1525        };
1526
1527        let app = make_test_app_custom(tx, options);
1528        let client = TestClient::new(app).await;
1529
1530        let res = client.get("/health").header("Origin", origin).send().await;
1531
1532        assert_eq!(res.status(), StatusCode::OK);
1533        assert_eq!(
1534            res.headers()
1535                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1536                .expect("expect cors header origin"),
1537            origin
1538        );
1539
1540        let res = client
1541            .get("/health")
1542            .header("Origin", "https://notallowed.com")
1543            .send()
1544            .await;
1545
1546        assert_eq!(res.status(), StatusCode::OK);
1547        assert!(
1548            !res.headers()
1549                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1550        );
1551    }
1552
1553    #[tokio::test]
1554    pub async fn test_cors_disabled() {
1555        // cors is on by default
1556        let (tx, _rx) = mpsc::channel(100);
1557
1558        let options = HttpOptions {
1559            enable_cors: false,
1560            ..Default::default()
1561        };
1562
1563        let app = make_test_app_custom(tx, options);
1564        let client = TestClient::new(app).await;
1565
1566        let res = client.get("/health").send().await;
1567
1568        assert_eq!(res.status(), StatusCode::OK);
1569        assert!(
1570            !res.headers()
1571                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1572        );
1573    }
1574
1575    #[test]
1576    fn test_http_options_default() {
1577        let default = HttpOptions::default();
1578        assert_eq!("127.0.0.1:4000".to_string(), default.addr);
1579        assert_eq!(Duration::from_secs(0), default.timeout)
1580    }
1581
1582    #[tokio::test]
1583    async fn test_http_server_request_timeout() {
1584        common_telemetry::init_default_ut_logging();
1585
1586        let (tx, _rx) = mpsc::channel(100);
1587        let app = make_test_app(tx);
1588        let client = TestClient::new(app).await;
1589        let res = client.get("/test/timeout").send().await;
1590        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1591
1592        let now = Instant::now();
1593        let res = client
1594            .get("/test/timeout")
1595            .header(GREPTIME_DB_HEADER_TIMEOUT, "20ms")
1596            .send()
1597            .await;
1598        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1599        let elapsed = now.elapsed();
1600        assert!(elapsed > Duration::from_millis(15));
1601
1602        tokio::time::timeout(
1603            Duration::from_millis(15),
1604            client
1605                .get("/test/timeout")
1606                .header(GREPTIME_DB_HEADER_TIMEOUT, "0s")
1607                .send(),
1608        )
1609        .await
1610        .unwrap_err();
1611
1612        tokio::time::timeout(
1613            Duration::from_millis(15),
1614            client
1615                .get("/test/timeout")
1616                .header(
1617                    GREPTIME_DB_HEADER_TIMEOUT,
1618                    humantime::format_duration(Duration::default()).to_string(),
1619                )
1620                .send(),
1621        )
1622        .await
1623        .unwrap_err();
1624    }
1625
1626    #[tokio::test]
1627    async fn test_schema_for_empty_response() {
1628        let column_schemas = vec![
1629            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1630            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1631        ];
1632        let schema = Arc::new(Schema::new(column_schemas));
1633
1634        let recordbatches = RecordBatches::try_new(schema.clone(), vec![]).unwrap();
1635        let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1636
1637        let json_resp = GreptimedbV1Response::from_output(outputs).await;
1638        if let HttpResponse::GreptimedbV1(json_resp) = json_resp {
1639            let json_output = &json_resp.output[0];
1640            if let GreptimeQueryOutput::Records(r) = json_output {
1641                assert_eq!(r.num_rows(), 0);
1642                assert_eq!(r.num_cols(), 2);
1643                assert_eq!(r.schema.column_schemas[0].name, "numbers");
1644                assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1645            } else {
1646                panic!("invalid output type");
1647            }
1648        } else {
1649            panic!("invalid format")
1650        }
1651    }
1652
1653    #[tokio::test]
1654    async fn test_recordbatches_conversion() {
1655        let column_schemas = vec![
1656            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1657            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1658        ];
1659        let schema = Arc::new(Schema::new(column_schemas));
1660        let columns: Vec<VectorRef> = vec![
1661            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
1662            Arc::new(StringVector::from(vec![
1663                None,
1664                Some("hello"),
1665                Some("greptime"),
1666                None,
1667            ])),
1668        ];
1669        let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();
1670
1671        for format in [
1672            ResponseFormat::GreptimedbV1,
1673            ResponseFormat::InfluxdbV1,
1674            ResponseFormat::Csv(true, true),
1675            ResponseFormat::Table,
1676            ResponseFormat::Arrow,
1677            ResponseFormat::Json,
1678            ResponseFormat::Null,
1679        ] {
1680            let recordbatches =
1681                RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
1682            let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1683            let json_resp = match format {
1684                ResponseFormat::Arrow => ArrowResponse::from_output(outputs, None).await,
1685                ResponseFormat::Csv(with_names, with_types) => {
1686                    CsvResponse::from_output(outputs, with_names, with_types).await
1687                }
1688                ResponseFormat::Table => TableResponse::from_output(outputs).await,
1689                ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
1690                ResponseFormat::InfluxdbV1 => InfluxdbV1Response::from_output(outputs, None).await,
1691                ResponseFormat::Json => JsonResponse::from_output(outputs).await,
1692                ResponseFormat::Null => NullResponse::from_output(outputs).await,
1693            };
1694
1695            match json_resp {
1696                HttpResponse::GreptimedbV1(resp) => {
1697                    let json_output = &resp.output[0];
1698                    if let GreptimeQueryOutput::Records(r) = json_output {
1699                        assert_eq!(r.num_rows(), 4);
1700                        assert_eq!(r.num_cols(), 2);
1701                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1702                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1703                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1704                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1705                    } else {
1706                        panic!("invalid output type");
1707                    }
1708                }
1709                HttpResponse::InfluxdbV1(resp) => {
1710                    let json_output = &resp.results()[0];
1711                    assert_eq!(json_output.num_rows(), 4);
1712                    assert_eq!(json_output.num_cols(), 2);
1713                    assert_eq!(json_output.series[0].columns.clone()[0], "numbers");
1714                    assert_eq!(
1715                        json_output.series[0].values[0][0],
1716                        serde_json::Value::from(1)
1717                    );
1718                    assert_eq!(json_output.series[0].values[0][1], serde_json::Value::Null);
1719                }
1720                HttpResponse::Csv(resp) => {
1721                    let output = &resp.output()[0];
1722                    if let GreptimeQueryOutput::Records(r) = output {
1723                        assert_eq!(r.num_rows(), 4);
1724                        assert_eq!(r.num_cols(), 2);
1725                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1726                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1727                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1728                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1729                    } else {
1730                        panic!("invalid output type");
1731                    }
1732                }
1733
1734                HttpResponse::Table(resp) => {
1735                    let output = &resp.output()[0];
1736                    if let GreptimeQueryOutput::Records(r) = output {
1737                        assert_eq!(r.num_rows(), 4);
1738                        assert_eq!(r.num_cols(), 2);
1739                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1740                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1741                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1742                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1743                    } else {
1744                        panic!("invalid output type");
1745                    }
1746                }
1747
1748                HttpResponse::Arrow(resp) => {
1749                    let output = resp.data;
1750                    let mut reader =
1751                        FileReader::try_new(Cursor::new(output), None).expect("Arrow reader error");
1752                    let schema = reader.schema();
1753                    assert_eq!(schema.fields[0].name(), "numbers");
1754                    assert_eq!(schema.fields[0].data_type(), &DataType::UInt32);
1755                    assert_eq!(schema.fields[1].name(), "strings");
1756                    assert_eq!(schema.fields[1].data_type(), &DataType::Utf8);
1757
1758                    let rb = reader.next().unwrap().expect("read record batch failed");
1759                    assert_eq!(rb.num_columns(), 2);
1760                    assert_eq!(rb.num_rows(), 4);
1761                }
1762
1763                HttpResponse::Json(resp) => {
1764                    let output = &resp.output()[0];
1765                    if let GreptimeQueryOutput::Records(r) = output {
1766                        assert_eq!(r.num_rows(), 4);
1767                        assert_eq!(r.num_cols(), 2);
1768                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1769                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1770                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1771                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1772                    } else {
1773                        panic!("invalid output type");
1774                    }
1775                }
1776
1777                HttpResponse::Null(resp) => {
1778                    assert_eq!(resp.rows(), 4);
1779                }
1780
1781                HttpResponse::Error(err) => unreachable!("{err:?}"),
1782            }
1783        }
1784    }
1785
1786    #[test]
1787    fn test_response_format_misc() {
1788        assert_eq!(ResponseFormat::default(), ResponseFormat::GreptimedbV1);
1789        assert_eq!(ResponseFormat::parse("arrow"), Some(ResponseFormat::Arrow));
1790        assert_eq!(
1791            ResponseFormat::parse("csv"),
1792            Some(ResponseFormat::Csv(false, false))
1793        );
1794        assert_eq!(
1795            ResponseFormat::parse("csvwithnames"),
1796            Some(ResponseFormat::Csv(true, false))
1797        );
1798        assert_eq!(
1799            ResponseFormat::parse("csvwithnamesandtypes"),
1800            Some(ResponseFormat::Csv(true, true))
1801        );
1802        assert_eq!(ResponseFormat::parse("table"), Some(ResponseFormat::Table));
1803        assert_eq!(
1804            ResponseFormat::parse("greptimedb_v1"),
1805            Some(ResponseFormat::GreptimedbV1)
1806        );
1807        assert_eq!(
1808            ResponseFormat::parse("influxdb_v1"),
1809            Some(ResponseFormat::InfluxdbV1)
1810        );
1811        assert_eq!(ResponseFormat::parse("json"), Some(ResponseFormat::Json));
1812        assert_eq!(ResponseFormat::parse("null"), Some(ResponseFormat::Null));
1813
1814        // invalid formats
1815        assert_eq!(ResponseFormat::parse("invalid"), None);
1816        assert_eq!(ResponseFormat::parse(""), None);
1817        assert_eq!(ResponseFormat::parse("CSV"), None); // Case sensitive
1818
1819        // as str
1820        assert_eq!(ResponseFormat::Arrow.as_str(), "arrow");
1821        assert_eq!(ResponseFormat::Csv(false, false).as_str(), "csv");
1822        assert_eq!(ResponseFormat::Csv(true, true).as_str(), "csv");
1823        assert_eq!(ResponseFormat::Table.as_str(), "table");
1824        assert_eq!(ResponseFormat::GreptimedbV1.as_str(), "greptimedb_v1");
1825        assert_eq!(ResponseFormat::InfluxdbV1.as_str(), "influxdb_v1");
1826        assert_eq!(ResponseFormat::Json.as_str(), "json");
1827        assert_eq!(ResponseFormat::Null.as_str(), "null");
1828        assert_eq!(ResponseFormat::default().as_str(), "greptimedb_v1");
1829    }
1830
1831    #[test]
1832    fn test_decode_label_name_strict() {
1833        let strict = PromValidationMode::Strict;
1834
1835        // Valid Prometheus label names
1836        assert!(strict.decode_label_name(b"__name__").is_ok());
1837        assert!(strict.decode_label_name(b"job").is_ok());
1838        assert!(strict.decode_label_name(b"instance").is_ok());
1839        assert!(strict.decode_label_name(b"_private").is_ok());
1840        assert!(strict.decode_label_name(b"label_with_underscores").is_ok());
1841        assert!(strict.decode_label_name(b"abc123").is_ok());
1842        assert!(strict.decode_label_name(b"A").is_ok());
1843        assert!(strict.decode_label_name(b"_").is_ok());
1844
1845        // Invalid: starts with digit
1846        assert!(strict.decode_label_name(b"0abc").is_err());
1847        assert!(strict.decode_label_name(b"123").is_err());
1848
1849        // Invalid: contains special characters
1850        assert!(strict.decode_label_name(b"label-name").is_err());
1851        assert!(strict.decode_label_name(b"label.name").is_err());
1852        assert!(strict.decode_label_name(b"label name").is_err());
1853        assert!(strict.decode_label_name(b"label/name").is_err());
1854
1855        // Invalid: empty
1856        assert!(strict.decode_label_name(b"").is_err());
1857
1858        // Invalid: non-ASCII UTF-8
1859        assert!(strict.decode_label_name("ラベル".as_bytes()).is_err());
1860
1861        // Invalid UTF-8 bytes should fail
1862        assert!(strict.decode_label_name(&[0xff, 0xfe]).is_err());
1863    }
1864
1865    #[test]
1866    fn test_decode_label_name_lossy() {
1867        let lossy = PromValidationMode::Lossy;
1868
1869        // Label name validation is always enforced.
1870        assert!(lossy.decode_label_name(b"__name__").is_ok());
1871        assert!(lossy.decode_label_name(b"label-name").is_err());
1872        assert!(lossy.decode_label_name(b"0abc").is_err());
1873
1874        // Invalid UTF-8 bytes fail the label-name byte check.
1875        assert!(lossy.decode_label_name(&[0xff, 0xfe]).is_err());
1876    }
1877
1878    #[test]
1879    fn test_decode_label_name_unchecked() {
1880        let unchecked = PromValidationMode::Unchecked;
1881
1882        // Label name validation is always enforced.
1883        assert!(unchecked.decode_label_name(b"__name__").is_ok());
1884        assert!(unchecked.decode_label_name(b"label-name").is_err());
1885        assert!(unchecked.decode_label_name(b"0abc").is_err());
1886    }
1887
1888    #[test]
1889    fn test_is_valid_prom_label_name_bytes() {
1890        use super::validate_label_name;
1891
1892        assert!(validate_label_name(b"__name__"));
1893        assert!(validate_label_name(b"job"));
1894        assert!(validate_label_name(b"_"));
1895        assert!(validate_label_name(b"A"));
1896        assert!(validate_label_name(b"abc123"));
1897        assert!(validate_label_name(b"_leading_underscore"));
1898
1899        assert!(!validate_label_name(b""));
1900        assert!(!validate_label_name(b"0starts_with_digit"));
1901        assert!(!validate_label_name(b"has-dash"));
1902        assert!(!validate_label_name(b"has.dot"));
1903        assert!(!validate_label_name(b"has space"));
1904        assert!(!validate_label_name(&[0xff, 0xfe]));
1905    }
1906}