// servers/http.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::convert::Infallible;
17use std::fmt::Display;
18use std::net::SocketAddr;
19use std::sync::Mutex as StdMutex;
20use std::time::Duration;
21
22use async_trait::async_trait;
23use auth::UserProviderRef;
24use axum::extract::{DefaultBodyLimit, Request};
25use axum::http::StatusCode as HttpStatusCode;
26use axum::response::{IntoResponse, Response};
27use axum::routing::Route;
28use axum::serve::ListenerExt;
29use axum::{Router, middleware, routing};
30use common_base::Plugins;
31use common_base::readable_size::ReadableSize;
32use common_recordbatch::RecordBatch;
33use common_telemetry::{debug, error, info};
34use common_time::Timestamp;
35use common_time::timestamp::TimeUnit;
36use datatypes::data_type::DataType;
37use datatypes::schema::SchemaRef;
38use event::{LogState, LogValidatorRef};
39use futures::FutureExt;
40use http::{HeaderValue, Method};
41use prost::DecodeError;
42use serde::{Deserialize, Serialize};
43use serde_json::Value;
44use snafu::{ResultExt, ensure};
45use tokio::sync::Mutex;
46use tokio::sync::oneshot::{self, Sender};
47use tonic::codegen::Service;
48use tower::{Layer, ServiceBuilder};
49use tower_http::compression::CompressionLayer;
50use tower_http::cors::{AllowOrigin, Any, CorsLayer};
51use tower_http::decompression::RequestDecompressionLayer;
52use tower_http::trace::TraceLayer;
53
54use self::authorize::AuthState;
55use self::result::table_result::TableResponse;
56use crate::configurator::HttpConfiguratorRef;
57use crate::elasticsearch;
58use crate::error::{
59    AddressBindSnafu, AlreadyStartedSnafu, Error, InternalIoSnafu, InvalidHeaderValueSnafu,
60    OtherSnafu, Result,
61};
62use crate::http::influxdb::{influxdb_health, influxdb_ping, influxdb_write_v1, influxdb_write_v2};
63use crate::http::otlp::OtlpState;
64use crate::http::prom_store::PromStoreState;
65use crate::http::prometheus::{
66    build_info_query, format_query, instant_query, label_values_query, labels_query, parse_query,
67    range_query, series_query,
68};
69use crate::http::result::arrow_result::ArrowResponse;
70use crate::http::result::csv_result::CsvResponse;
71use crate::http::result::error_result::ErrorResponse;
72use crate::http::result::greptime_result_v1::GreptimedbV1Response;
73use crate::http::result::influxdb_result_v1::InfluxdbV1Response;
74use crate::http::result::json_result::JsonResponse;
75use crate::http::result::null_result::NullResponse;
76use crate::interceptor::LogIngestInterceptorRef;
77use crate::metrics::http_metrics_layer;
78use crate::metrics_handler::MetricsHandler;
79use crate::prometheus_handler::PrometheusHandlerRef;
80use crate::query_handler::sql::ServerSqlQueryHandlerRef;
81use crate::query_handler::{
82    InfluxdbLineProtocolHandlerRef, JaegerQueryHandlerRef, LogQueryHandlerRef,
83    OpenTelemetryProtocolHandlerRef, OpentsdbProtocolHandlerRef, PipelineHandlerRef,
84    PromStoreProtocolHandlerRef,
85};
86use crate::request_memory_limiter::ServerMemoryLimiter;
87use crate::server::Server;
88
89pub mod authorize;
90#[cfg(feature = "dashboard")]
91mod dashboard;
92pub mod dyn_log;
93pub mod dyn_trace;
94pub mod event;
95pub mod extractor;
96pub mod handler;
97pub mod header;
98pub mod influxdb;
99pub mod jaeger;
100pub mod logs;
101pub mod loki;
102pub mod mem_prof;
103mod memory_limit;
104pub mod opentsdb;
105pub mod otlp;
106pub mod pprof;
107pub mod prom_store;
108pub mod prometheus;
109pub mod result;
110mod timeout;
111pub mod utils;
112
113use result::HttpOutputWriter;
114pub(crate) use timeout::DynamicTimeoutLayer;
115
116mod hints;
117mod read_preference;
118#[cfg(any(test, feature = "testing"))]
119pub mod test_helpers;
120
121pub const HTTP_API_VERSION: &str = "v1";
122pub const HTTP_API_PREFIX: &str = "/v1/";
123/// Default http body limit (64M).
124const DEFAULT_BODY_LIMIT: ReadableSize = ReadableSize::mb(64);
125
126/// Authorization header
127pub const AUTHORIZATION_HEADER: &str = "x-greptime-auth";
128
129// TODO(fys): This is a temporary workaround, it will be improved later
130pub static PUBLIC_APIS: [&str; 3] = ["/v1/influxdb/ping", "/v1/influxdb/health", "/v1/health"];
131
/// HTTP server holding the router and middleware configuration; usually constructed through
/// [`HttpServerBuilder::build`].
#[derive(Default)]
pub struct HttpServer {
    /// Routes registered through the builder. `make_app` locks this mutex and clones the router.
    router: StdMutex<Router>,
    /// Shutdown signal sender (a `tokio` oneshot). `HttpServerBuilder::build` sets it to `None`.
    // NOTE(review): presumably populated when the server is started — confirm in the `Server` impl.
    shutdown_tx: Mutex<Option<Sender<()>>>,
    /// Optional user provider handed to the authentication middleware in `build`.
    user_provider: Option<UserProviderRef>,
    /// Memory limiter handed to the memory-limit middleware in `build`.
    memory_limiter: ServerMemoryLimiter,

    // plugins
    plugins: Plugins,

    // server configs
    options: HttpOptions,
    /// Bound socket address; `None` when created by `HttpServerBuilder::build`.
    bind_addr: Option<SocketAddr>,
}
146
/// Configuration options of the HTTP server.
///
/// Because of `#[serde(default)]`, fields missing from the configuration fall back to the
/// values of [`HttpOptions::default`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(default)]
pub struct HttpOptions {
    /// Address the server listens on (`127.0.0.1:4000` by default).
    pub addr: String,

    /// Request timeout, (de)serialized in humantime format.
    /// A zero duration means the timeout layer is not installed by `HttpServer::build`.
    #[serde(with = "humantime_serde")]
    pub timeout: Duration,

    /// When set, the dashboard routes are not mounted (only relevant with the `dashboard`
    /// feature). Never read from or written to configuration files.
    #[serde(skip)]
    pub disable_dashboard: bool,

    /// Maximum request body size. `ReadableSize(0)` means no explicit limit is configured.
    pub body_limit: ReadableSize,

    /// Validation mode while decoding Prometheus remote write requests.
    pub prom_validation_mode: PromValidationMode,

    /// Origins accepted by the CORS layer; an empty list allows any origin.
    pub cors_allowed_origins: Vec<String>,

    /// Whether the CORS layer is installed at all.
    pub enable_cors: bool,
}
167
/// How strictly UTF-8 is validated when decoding Prometheus remote write payloads.
///
/// Serialized in `snake_case`: `strict`, `lossy`, `unchecked`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PromValidationMode {
    /// Force UTF8 validation; invalid input is rejected with a decode error.
    Strict,
    /// Allow lossy UTF8 strings; invalid sequences are replaced instead of rejected.
    Lossy,
    /// Do not validate UTF8 strings; the bytes are trusted as-is.
    Unchecked,
}
178
179impl PromValidationMode {
180    /// Decodes provided bytes to [String] with optional UTF-8 validation.
181    pub fn decode_string(&self, bytes: &[u8]) -> std::result::Result<String, DecodeError> {
182        let result = match self {
183            PromValidationMode::Strict => match String::from_utf8(bytes.to_vec()) {
184                Ok(s) => s,
185                Err(e) => {
186                    debug!("Invalid UTF-8 string value: {:?}, error: {:?}", bytes, e);
187                    return Err(DecodeError::new("invalid utf-8"));
188                }
189            },
190            PromValidationMode::Lossy => String::from_utf8_lossy(bytes).to_string(),
191            PromValidationMode::Unchecked => unsafe { String::from_utf8_unchecked(bytes.to_vec()) },
192        };
193        Ok(result)
194    }
195
196    pub(crate) fn validate_bytes(&self, bytes: &[u8]) -> std::result::Result<(), DecodeError> {
197        match self {
198            PromValidationMode::Strict => {
199                simdutf8::basic::from_utf8(bytes).map_err(|_| DecodeError::new("invalid utf-8"))?;
200                Ok(())
201            }
202            PromValidationMode::Lossy | PromValidationMode::Unchecked => Ok(()),
203        }
204    }
205}
206
207impl Default for HttpOptions {
208    fn default() -> Self {
209        Self {
210            addr: "127.0.0.1:4000".to_string(),
211            timeout: Duration::from_secs(0),
212            disable_dashboard: false,
213            body_limit: DEFAULT_BODY_LIMIT,
214            cors_allowed_origins: Vec::new(),
215            enable_cors: true,
216            prom_validation_mode: PromValidationMode::Strict,
217        }
218    }
219}
220
/// Name and data type name of one column of an [`OutputSchema`].
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct ColumnSchema {
    /// Column name.
    name: String,
    /// Name of the column's data type, as a string.
    data_type: String,
}
226
227impl ColumnSchema {
228    pub fn new(name: String, data_type: String) -> ColumnSchema {
229        ColumnSchema { name, data_type }
230    }
231}
232
/// Column layout of a query result as exposed over HTTP.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct OutputSchema {
    /// Columns in result order.
    column_schemas: Vec<ColumnSchema>,
}
237
238impl OutputSchema {
239    pub fn new(columns: Vec<ColumnSchema>) -> OutputSchema {
240        OutputSchema {
241            column_schemas: columns,
242        }
243    }
244}
245
246impl From<SchemaRef> for OutputSchema {
247    fn from(schema: SchemaRef) -> OutputSchema {
248        OutputSchema {
249            column_schemas: schema
250                .column_schemas()
251                .iter()
252                .map(|cs| ColumnSchema {
253                    name: cs.name.clone(),
254                    data_type: cs.data_type.name(),
255                })
256                .collect(),
257        }
258    }
259}
260
/// Rows of a query result together with their schema, as returned over HTTP.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
pub struct HttpRecordsOutput {
    /// Columns describing each entry of `rows`.
    schema: OutputSchema,
    /// Row values as JSON values.
    rows: Vec<Vec<Value>>,
    // total_rows is equal to rows.len() in most cases,
    // the Dashboard query result may be truncated, so we need to return the total_rows.
    // Deserializes to 0 when the field is absent.
    #[serde(default)]
    total_rows: usize,

    // plan level execution metrics
    // Omitted from the serialized output when empty; deserializes to an empty map when absent.
    #[serde(skip_serializing_if = "HashMap::is_empty")]
    #[serde(default)]
    metrics: HashMap<String, Value>,
}
275
276impl HttpRecordsOutput {
277    pub fn num_rows(&self) -> usize {
278        self.rows.len()
279    }
280
281    pub fn num_cols(&self) -> usize {
282        self.schema.column_schemas.len()
283    }
284
285    pub fn schema(&self) -> &OutputSchema {
286        &self.schema
287    }
288
289    pub fn rows(&self) -> &Vec<Vec<Value>> {
290        &self.rows
291    }
292}
293
294impl HttpRecordsOutput {
295    pub fn try_new(
296        schema: SchemaRef,
297        recordbatches: Vec<RecordBatch>,
298    ) -> std::result::Result<HttpRecordsOutput, Error> {
299        if recordbatches.is_empty() {
300            Ok(HttpRecordsOutput {
301                schema: OutputSchema::from(schema),
302                rows: vec![],
303                total_rows: 0,
304                metrics: Default::default(),
305            })
306        } else {
307            let num_rows = recordbatches.iter().map(|r| r.num_rows()).sum::<usize>();
308            let mut rows = Vec::with_capacity(num_rows);
309
310            for recordbatch in recordbatches {
311                let mut writer = HttpOutputWriter::new(schema.num_columns(), None);
312                writer.write(recordbatch, &mut rows)?;
313            }
314
315            Ok(HttpRecordsOutput {
316                schema: OutputSchema::from(schema),
317                total_rows: rows.len(),
318                rows,
319                metrics: Default::default(),
320            })
321        }
322    }
323}
324
/// A single query output: either an affected-row count or a set of records.
///
/// Variants are serialized in lowercase (`affectedrows`, `records`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum GreptimeQueryOutput {
    /// Number of rows affected by a statement.
    AffectedRows(usize),
    /// Result rows with their schema.
    Records(HttpRecordsOutput),
}
331
/// Output format in which SQL query results are returned over HTTP.
///
/// [`ResponseFormat::GreptimedbV1`] is the default.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseFormat {
    Arrow,
    /// Fields are `(with_names, with_types)`.
    Csv(bool, bool),
    Table,
    #[default]
    GreptimedbV1,
    InfluxdbV1,
    Json,
    Null,
}

impl ResponseFormat {
    /// Parses a format name (case-sensitive); unknown names yield `None`.
    ///
    /// The CSV family maps onto [`ResponseFormat::Csv`]: `csv` has neither names nor types,
    /// `csvwithnames` has names only, and `csvwithnamesandtypes` has both.
    pub fn parse(s: &str) -> Option<Self> {
        let format = match s {
            "arrow" => Self::Arrow,
            "csv" => Self::Csv(false, false),
            "csvwithnames" => Self::Csv(true, false),
            "csvwithnamesandtypes" => Self::Csv(true, true),
            "table" => Self::Table,
            "greptimedb_v1" => Self::GreptimedbV1,
            "influxdb_v1" => Self::InfluxdbV1,
            "json" => Self::Json,
            "null" => Self::Null,
            _ => return None,
        };
        Some(format)
    }

    /// Canonical name of the format. Every CSV flavour is reported as `csv`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Arrow => "arrow",
            Self::Csv(..) => "csv",
            Self::Table => "table",
            Self::GreptimedbV1 => "greptimedb_v1",
            Self::InfluxdbV1 => "influxdb_v1",
            Self::Json => "json",
            Self::Null => "null",
        }
    }
}
374
/// Timestamp precision selected by an InfluxDB-style `epoch` value (see [`Epoch::parse`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Epoch {
    /// `ns`
    Nanosecond,
    /// `u` or `µ`
    Microsecond,
    /// `ms`
    Millisecond,
    /// `s`
    Second,
}
382
383impl Epoch {
384    pub fn parse(s: &str) -> Option<Epoch> {
385        // Both u and µ indicate microseconds.
386        // epoch = [ns,u,µ,ms,s],
387        // For details, see the Influxdb documents.
388        // https://docs.influxdata.com/influxdb/v1/tools/api/#query-string-parameters-1
389        match s {
390            "ns" => Some(Epoch::Nanosecond),
391            "u" | "µ" => Some(Epoch::Microsecond),
392            "ms" => Some(Epoch::Millisecond),
393            "s" => Some(Epoch::Second),
394            _ => None, // just returns None for other cases
395        }
396    }
397
398    pub fn convert_timestamp(&self, ts: Timestamp) -> Option<Timestamp> {
399        match self {
400            Epoch::Nanosecond => ts.convert_to(TimeUnit::Nanosecond),
401            Epoch::Microsecond => ts.convert_to(TimeUnit::Microsecond),
402            Epoch::Millisecond => ts.convert_to(TimeUnit::Millisecond),
403            Epoch::Second => ts.convert_to(TimeUnit::Second),
404        }
405    }
406}
407
408impl Display for Epoch {
409    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
410        match self {
411            Epoch::Nanosecond => write!(f, "Epoch::Nanosecond"),
412            Epoch::Microsecond => write!(f, "Epoch::Microsecond"),
413            Epoch::Millisecond => write!(f, "Epoch::Millisecond"),
414            Epoch::Second => write!(f, "Epoch::Second"),
415        }
416    }
417}
418
/// An HTTP query response in one of the supported output formats, or an error response.
#[derive(Serialize, Deserialize, Debug)]
pub enum HttpResponse {
    /// Arrow-encoded result.
    Arrow(ArrowResponse),
    /// CSV-encoded result.
    Csv(CsvResponse),
    /// Table-formatted result.
    Table(TableResponse),
    /// Error response.
    Error(ErrorResponse),
    /// GreptimeDB v1 result.
    GreptimedbV1(GreptimedbV1Response),
    /// InfluxDB v1 compatible result.
    InfluxdbV1(InfluxdbV1Response),
    /// JSON result.
    Json(JsonResponse),
    /// Null-format result.
    Null(NullResponse),
}
430
431impl HttpResponse {
432    pub fn with_execution_time(self, execution_time: u64) -> Self {
433        match self {
434            HttpResponse::Arrow(resp) => resp.with_execution_time(execution_time).into(),
435            HttpResponse::Csv(resp) => resp.with_execution_time(execution_time).into(),
436            HttpResponse::Table(resp) => resp.with_execution_time(execution_time).into(),
437            HttpResponse::GreptimedbV1(resp) => resp.with_execution_time(execution_time).into(),
438            HttpResponse::InfluxdbV1(resp) => resp.with_execution_time(execution_time).into(),
439            HttpResponse::Json(resp) => resp.with_execution_time(execution_time).into(),
440            HttpResponse::Null(resp) => resp.with_execution_time(execution_time).into(),
441            HttpResponse::Error(resp) => resp.with_execution_time(execution_time).into(),
442        }
443    }
444
445    pub fn with_limit(self, limit: usize) -> Self {
446        match self {
447            HttpResponse::Csv(resp) => resp.with_limit(limit).into(),
448            HttpResponse::Table(resp) => resp.with_limit(limit).into(),
449            HttpResponse::GreptimedbV1(resp) => resp.with_limit(limit).into(),
450            HttpResponse::Json(resp) => resp.with_limit(limit).into(),
451            _ => self,
452        }
453    }
454}
455
456pub fn process_with_limit(
457    mut outputs: Vec<GreptimeQueryOutput>,
458    limit: usize,
459) -> Vec<GreptimeQueryOutput> {
460    outputs
461        .drain(..)
462        .map(|data| match data {
463            GreptimeQueryOutput::Records(mut records) => {
464                if records.rows.len() > limit {
465                    records.rows.truncate(limit);
466                    records.total_rows = limit;
467                }
468                GreptimeQueryOutput::Records(records)
469            }
470            _ => data,
471        })
472        .collect()
473}
474
475impl IntoResponse for HttpResponse {
476    fn into_response(self) -> Response {
477        match self {
478            HttpResponse::Arrow(resp) => resp.into_response(),
479            HttpResponse::Csv(resp) => resp.into_response(),
480            HttpResponse::Table(resp) => resp.into_response(),
481            HttpResponse::GreptimedbV1(resp) => resp.into_response(),
482            HttpResponse::InfluxdbV1(resp) => resp.into_response(),
483            HttpResponse::Json(resp) => resp.into_response(),
484            HttpResponse::Null(resp) => resp.into_response(),
485            HttpResponse::Error(resp) => resp.into_response(),
486        }
487    }
488}
489
490impl From<ArrowResponse> for HttpResponse {
491    fn from(value: ArrowResponse) -> Self {
492        HttpResponse::Arrow(value)
493    }
494}
495
496impl From<CsvResponse> for HttpResponse {
497    fn from(value: CsvResponse) -> Self {
498        HttpResponse::Csv(value)
499    }
500}
501
502impl From<TableResponse> for HttpResponse {
503    fn from(value: TableResponse) -> Self {
504        HttpResponse::Table(value)
505    }
506}
507
508impl From<ErrorResponse> for HttpResponse {
509    fn from(value: ErrorResponse) -> Self {
510        HttpResponse::Error(value)
511    }
512}
513
514impl From<GreptimedbV1Response> for HttpResponse {
515    fn from(value: GreptimedbV1Response) -> Self {
516        HttpResponse::GreptimedbV1(value)
517    }
518}
519
520impl From<InfluxdbV1Response> for HttpResponse {
521    fn from(value: InfluxdbV1Response) -> Self {
522        HttpResponse::InfluxdbV1(value)
523    }
524}
525
526impl From<JsonResponse> for HttpResponse {
527    fn from(value: JsonResponse) -> Self {
528        HttpResponse::Json(value)
529    }
530}
531
532impl From<NullResponse> for HttpResponse {
533    fn from(value: NullResponse) -> Self {
534        HttpResponse::Null(value)
535    }
536}
537
/// Router state for the SQL routes mounted by [`HttpServerBuilder::with_sql_handler`].
#[derive(Clone)]
pub struct ApiState {
    /// Server-side SQL query handler.
    pub sql_handler: ServerSqlQueryHandlerRef,
}
542
/// Router state for the config route merged by
/// [`HttpServerBuilder::with_greptime_config_options`].
#[derive(Clone)]
pub struct GreptimeOptionsConfigState {
    /// Configuration options string, stored verbatim as passed to the builder.
    pub greptime_config_options: String,
}
547
/// Builder that accumulates routes, plugins, a user provider and a memory limiter before
/// producing an [`HttpServer`].
pub struct HttpServerBuilder {
    /// Server options, moved into the built server.
    options: HttpOptions,
    /// Plugins, replaced wholesale by `with_plugins`.
    plugins: Plugins,
    /// Optional user provider used by the authentication middleware.
    user_provider: Option<UserProviderRef>,
    /// Routes accumulated by the `with_*` methods.
    router: Router,
    /// Global memory limiter for server protocols.
    memory_limiter: ServerMemoryLimiter,
}
555
impl HttpServerBuilder {
    /// Creates a builder from `options` with default plugins and memory limiter, no user
    /// provider and an empty router.
    pub fn new(options: HttpOptions) -> Self {
        Self {
            options,
            plugins: Plugins::default(),
            user_provider: None,
            router: Router::new(),
            memory_limiter: ServerMemoryLimiter::default(),
        }
    }

    /// Set a global memory limiter for all server protocols.
    pub fn with_memory_limiter(mut self, limiter: ServerMemoryLimiter) -> Self {
        self.memory_limiter = limiter;
        self
    }

    /// Nests the routes built by `HttpServer::route_sql` under `/v1`.
    pub fn with_sql_handler(self, sql_handler: ServerSqlQueryHandlerRef) -> Self {
        let sql_router = HttpServer::route_sql(ApiState { sql_handler });

        Self {
            router: self
                .router
                .nest(&format!("/{HTTP_API_VERSION}"), sql_router),
            ..self
        }
    }

    /// Nests the routes built by `HttpServer::route_logs` under `/v1`.
    pub fn with_logs_handler(self, logs_handler: LogQueryHandlerRef) -> Self {
        let logs_router = HttpServer::route_logs(logs_handler);

        Self {
            router: self
                .router
                .nest(&format!("/{HTTP_API_VERSION}"), logs_router),
            ..self
        }
    }

    /// Nests the OpenTSDB routes under `/v1/opentsdb`.
    pub fn with_opentsdb_handler(self, handler: OpentsdbProtocolHandlerRef) -> Self {
        Self {
            router: self.router.nest(
                &format!("/{HTTP_API_VERSION}/opentsdb"),
                HttpServer::route_opentsdb(handler),
            ),
            ..self
        }
    }

    /// Nests the InfluxDB routes under `/v1/influxdb`.
    pub fn with_influxdb_handler(self, handler: InfluxdbLineProtocolHandlerRef) -> Self {
        Self {
            router: self.router.nest(
                &format!("/{HTTP_API_VERSION}/influxdb"),
                HttpServer::route_influxdb(handler),
            ),
            ..self
        }
    }

    /// Nests the routes built by `HttpServer::route_prom` under `/v1/prometheus`.
    ///
    /// All arguments are bundled into the [`PromStoreState`] shared by those routes.
    pub fn with_prom_handler(
        self,
        handler: PromStoreProtocolHandlerRef,
        pipeline_handler: Option<PipelineHandlerRef>,
        prom_store_with_metric_engine: bool,
        prom_validation_mode: PromValidationMode,
    ) -> Self {
        let state = PromStoreState {
            prom_store_handler: handler,
            pipeline_handler,
            prom_store_with_metric_engine,
            prom_validation_mode,
        };

        Self {
            router: self.router.nest(
                &format!("/{HTTP_API_VERSION}/prometheus"),
                HttpServer::route_prom(state),
            ),
            ..self
        }
    }

    /// Nests the routes built by `HttpServer::route_prometheus` under
    /// `/v1/prometheus/api/v1`.
    pub fn with_prometheus_handler(self, handler: PrometheusHandlerRef) -> Self {
        Self {
            router: self.router.nest(
                &format!("/{HTTP_API_VERSION}/prometheus/api/v1"),
                HttpServer::route_prometheus(handler),
            ),
            ..self
        }
    }

    /// Nests the OTLP routes under `/v1/otlp`; `with_metric_engine` is forwarded to
    /// `HttpServer::route_otlp`.
    pub fn with_otlp_handler(
        self,
        handler: OpenTelemetryProtocolHandlerRef,
        with_metric_engine: bool,
    ) -> Self {
        Self {
            router: self.router.nest(
                &format!("/{HTTP_API_VERSION}/otlp"),
                HttpServer::route_otlp(handler, with_metric_engine),
            ),
            ..self
        }
    }

    /// Sets the user provider that the built server passes to its authentication middleware.
    pub fn with_user_provider(self, user_provider: UserProviderRef) -> Self {
        Self {
            user_provider: Some(user_provider),
            ..self
        }
    }

    /// Merges the `/metrics` route at the root of the router.
    pub fn with_metrics_handler(self, handler: MetricsHandler) -> Self {
        Self {
            router: self.router.merge(HttpServer::route_metrics(handler)),
            ..self
        }
    }

    /// Mounts the log-ingestion routes, all sharing one [`LogState`]:
    /// pipelines under `/v1`, the deprecated events API under `/v1/events`, Loki under
    /// `/v1/loki`, Elasticsearch under `/v1/elasticsearch`, and a `GET` version endpoint at
    /// `/v1/elasticsearch/`.
    pub fn with_log_ingest_handler(
        self,
        handler: PipelineHandlerRef,
        validator: Option<LogValidatorRef>,
        ingest_interceptor: Option<LogIngestInterceptorRef<Error>>,
    ) -> Self {
        let log_state = LogState {
            log_handler: handler,
            log_validator: validator,
            ingest_interceptor,
        };

        let router = self.router.nest(
            &format!("/{HTTP_API_VERSION}"),
            HttpServer::route_pipelines(log_state.clone()),
        );
        // deprecated since v0.11.0. Use `/logs` and `/pipelines` instead.
        let router = router.nest(
            &format!("/{HTTP_API_VERSION}/events"),
            #[allow(deprecated)]
            HttpServer::route_log_deprecated(log_state.clone()),
        );

        let router = router.nest(
            &format!("/{HTTP_API_VERSION}/loki"),
            HttpServer::route_loki(log_state.clone()),
        );

        let router = router.nest(
            &format!("/{HTTP_API_VERSION}/elasticsearch"),
            HttpServer::route_elasticsearch(log_state.clone()),
        );

        // The trailing-slash path is mounted separately so that `GET /v1/elasticsearch/` is
        // answered by the version handler.
        let router = router.nest(
            &format!("/{HTTP_API_VERSION}/elasticsearch/"),
            Router::new()
                .route("/", routing::get(elasticsearch::handle_get_version))
                .with_state(log_state),
        );

        Self { router, ..self }
    }

    /// Replaces the plugins handed to the built server.
    pub fn with_plugins(self, plugins: Plugins) -> Self {
        Self { plugins, ..self }
    }

    /// Merges the routes built by `HttpServer::route_config`, which serve `opts`.
    pub fn with_greptime_config_options(self, opts: String) -> Self {
        let config_router = HttpServer::route_config(GreptimeOptionsConfigState {
            greptime_config_options: opts,
        });

        Self {
            router: self.router.merge(config_router),
            ..self
        }
    }

    /// Nests the Jaeger routes under `/v1/jaeger`.
    pub fn with_jaeger_handler(self, handler: JaegerQueryHandlerRef) -> Self {
        Self {
            router: self.router.nest(
                &format!("/{HTTP_API_VERSION}/jaeger"),
                HttpServer::route_jaeger(handler),
            ),
            ..self
        }
    }

    /// Merges an arbitrary router into the accumulated routes.
    pub fn with_extra_router(self, router: Router) -> Self {
        Self {
            router: self.router.merge(router),
            ..self
        }
    }

    /// Wraps the routes registered so far with `layer`. Per axum's `Router::layer`, routes
    /// added after this call are not wrapped.
    pub fn add_layer<L>(self, layer: L) -> Self
    where
        L: Layer<Route> + Clone + Send + Sync + 'static,
        L::Service: Service<Request> + Clone + Send + Sync + 'static,
        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
        <L::Service as Service<Request>>::Future: Send + 'static,
    {
        Self {
            router: self.router.layer(layer),
            ..self
        }
    }

    /// Consumes the builder and creates the [`HttpServer`]. The shutdown sender and bind
    /// address start out unset.
    pub fn build(self) -> HttpServer {
        HttpServer {
            options: self.options,
            user_provider: self.user_provider,
            shutdown_tx: Mutex::new(None),
            plugins: self.plugins,
            router: StdMutex::new(self.router),
            bind_addr: None,
            memory_limiter: self.memory_limiter,
        }
    }
}
777
778impl HttpServer {
779    /// Gets the router and adds necessary root routes (health, status, dashboard).
780    pub fn make_app(&self) -> Router {
781        let mut router = {
782            let router = self.router.lock().unwrap();
783            router.clone()
784        };
785
786        router = router
787            .route(
788                "/health",
789                routing::get(handler::health).post(handler::health),
790            )
791            .route(
792                &format!("/{HTTP_API_VERSION}/health"),
793                routing::get(handler::health).post(handler::health),
794            )
795            .route(
796                "/ready",
797                routing::get(handler::health).post(handler::health),
798            );
799
800        router = router.route("/status", routing::get(handler::status));
801
802        #[cfg(feature = "dashboard")]
803        {
804            if !self.options.disable_dashboard {
805                info!("Enable dashboard service at '/dashboard'");
806                // redirect /dashboard to /dashboard/
807                router = router.route(
808                    "/dashboard",
809                    routing::get(|uri: axum::http::uri::Uri| async move {
810                        let path = uri.path();
811                        let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
812
813                        let new_uri = format!("{}/{}", path, query);
814                        axum::response::Redirect::permanent(&new_uri)
815                    }),
816                );
817
818                // "/dashboard" and "/dashboard/" are two different paths in Axum.
819                // We cannot nest "/dashboard/", because we already mapping "/dashboard/{*x}" while nesting "/dashboard".
820                // So we explicitly route "/dashboard/" here.
821                router = router
822                    .route(
823                        "/dashboard/",
824                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
825                    )
826                    .route(
827                        "/dashboard/{*x}",
828                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
829                    );
830            }
831        }
832
833        // Add a layer to collect HTTP metrics for axum.
834        router = router.route_layer(middleware::from_fn(http_metrics_layer));
835
836        router
837    }
838
839    /// Attaches middlewares and debug routes to the router.
840    /// Callers should call this method after [HttpServer::make_app()].
841    pub fn build(&self, router: Router) -> Result<Router> {
842        let timeout_layer = if self.options.timeout != Duration::default() {
843            Some(ServiceBuilder::new().layer(DynamicTimeoutLayer::new(self.options.timeout)))
844        } else {
845            info!("HTTP server timeout is disabled");
846            None
847        };
848        let body_limit_layer = if self.options.body_limit != ReadableSize(0) {
849            Some(
850                ServiceBuilder::new()
851                    .layer(DefaultBodyLimit::max(self.options.body_limit.0 as usize)),
852            )
853        } else {
854            info!("HTTP server body limit is disabled");
855            None
856        };
857        let cors_layer = if self.options.enable_cors {
858            Some(
859                CorsLayer::new()
860                    .allow_methods([
861                        Method::GET,
862                        Method::POST,
863                        Method::PUT,
864                        Method::DELETE,
865                        Method::HEAD,
866                    ])
867                    .allow_origin(if self.options.cors_allowed_origins.is_empty() {
868                        AllowOrigin::from(Any)
869                    } else {
870                        AllowOrigin::from(
871                            self.options
872                                .cors_allowed_origins
873                                .iter()
874                                .map(|s| {
875                                    HeaderValue::from_str(s.as_str())
876                                        .context(InvalidHeaderValueSnafu)
877                                })
878                                .collect::<Result<Vec<HeaderValue>>>()?,
879                        )
880                    })
881                    .allow_headers(Any),
882            )
883        } else {
884            info!("HTTP server cross-origin is disabled");
885            None
886        };
887
888        Ok(router
889            // middlewares
890            .layer(
891                ServiceBuilder::new()
892                    // disable on failure tracing. because printing out isn't very helpful,
893                    // and we have impl IntoResponse for Error. It will print out more detailed error messages
894                    .layer(TraceLayer::new_for_http().on_failure(()))
895                    .option_layer(cors_layer)
896                    .option_layer(timeout_layer)
897                    .option_layer(body_limit_layer)
898                    // memory limit layer - must be before body is consumed
899                    .layer(middleware::from_fn_with_state(
900                        self.memory_limiter.clone(),
901                        memory_limit::memory_limit_middleware,
902                    ))
903                    // auth layer
904                    .layer(middleware::from_fn_with_state(
905                        AuthState::new(self.user_provider.clone()),
906                        authorize::check_http_auth,
907                    ))
908                    .layer(middleware::from_fn(hints::extract_hints))
909                    .layer(middleware::from_fn(
910                        read_preference::extract_read_preference,
911                    )),
912            )
913            // Handlers for debug, we don't expect a timeout.
914            .nest(
915                "/debug",
916                Router::new()
917                    // handler for changing log level dynamically
918                    .route("/log_level", routing::post(dyn_log::dyn_log_handler))
919                    .route("/enable_trace", routing::post(dyn_trace::dyn_trace_handler))
920                    .nest(
921                        "/prof",
922                        Router::new()
923                            .route("/cpu", routing::post(pprof::pprof_handler))
924                            .route("/mem", routing::post(mem_prof::mem_prof_handler))
925                            .route("/mem/symbol", routing::post(mem_prof::symbolicate_handler))
926                            .route(
927                                "/mem/activate",
928                                routing::post(mem_prof::activate_heap_prof_handler),
929                            )
930                            .route(
931                                "/mem/deactivate",
932                                routing::post(mem_prof::deactivate_heap_prof_handler),
933                            )
934                            .route(
935                                "/mem/status",
936                                routing::get(mem_prof::heap_prof_status_handler),
937                            ) // jemalloc gdump flag status and toggle
938                            .route(
939                                "/mem/gdump",
940                                routing::get(mem_prof::gdump_status_handler)
941                                    .post(mem_prof::gdump_toggle_handler),
942                            ),
943                    ),
944            ))
945    }
946
947    fn route_metrics<S>(metrics_handler: MetricsHandler) -> Router<S> {
948        Router::new()
949            .route("/metrics", routing::get(handler::metrics))
950            .with_state(metrics_handler)
951    }
952
953    fn route_loki<S>(log_state: LogState) -> Router<S> {
954        Router::new()
955            .route("/api/v1/push", routing::post(loki::loki_ingest))
956            .layer(
957                ServiceBuilder::new()
958                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
959            )
960            .with_state(log_state)
961    }
962
963    fn route_elasticsearch<S>(log_state: LogState) -> Router<S> {
964        Router::new()
965            // Return fake responsefor HEAD '/' request.
966            .route(
967                "/",
968                routing::head((HttpStatusCode::OK, elasticsearch::elasticsearch_headers())),
969            )
970            // Return fake response for Elasticsearch version request.
971            .route("/", routing::get(elasticsearch::handle_get_version))
972            // Return fake response for Elasticsearch license request.
973            .route("/_license", routing::get(elasticsearch::handle_get_license))
974            .route("/_bulk", routing::post(elasticsearch::handle_bulk_api))
975            .route(
976                "/{index}/_bulk",
977                routing::post(elasticsearch::handle_bulk_api_with_index),
978            )
979            // Return fake response for Elasticsearch ilm request.
980            .route(
981                "/_ilm/policy/{*path}",
982                routing::any((
983                    HttpStatusCode::OK,
984                    elasticsearch::elasticsearch_headers(),
985                    axum::Json(serde_json::json!({})),
986                )),
987            )
988            // Return fake response for Elasticsearch index template request.
989            .route(
990                "/_index_template/{*path}",
991                routing::any((
992                    HttpStatusCode::OK,
993                    elasticsearch::elasticsearch_headers(),
994                    axum::Json(serde_json::json!({})),
995                )),
996            )
997            // Return fake response for Elasticsearch ingest pipeline request.
998            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/put-pipeline-api.html.
999            .route(
1000                "/_ingest/{*path}",
1001                routing::any((
1002                    HttpStatusCode::OK,
1003                    elasticsearch::elasticsearch_headers(),
1004                    axum::Json(serde_json::json!({})),
1005                )),
1006            )
1007            // Return fake response for Elasticsearch nodes discovery request.
1008            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/cluster.html.
1009            .route(
1010                "/_nodes/{*path}",
1011                routing::any((
1012                    HttpStatusCode::OK,
1013                    elasticsearch::elasticsearch_headers(),
1014                    axum::Json(serde_json::json!({})),
1015                )),
1016            )
1017            // Return fake response for Logstash APIs requests.
1018            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/logstash-apis.html
1019            .route(
1020                "/logstash/{*path}",
1021                routing::any((
1022                    HttpStatusCode::OK,
1023                    elasticsearch::elasticsearch_headers(),
1024                    axum::Json(serde_json::json!({})),
1025                )),
1026            )
1027            .route(
1028                "/_logstash/{*path}",
1029                routing::any((
1030                    HttpStatusCode::OK,
1031                    elasticsearch::elasticsearch_headers(),
1032                    axum::Json(serde_json::json!({})),
1033                )),
1034            )
1035            .layer(ServiceBuilder::new().layer(RequestDecompressionLayer::new()))
1036            .with_state(log_state)
1037    }
1038
1039    #[deprecated(since = "0.11.0", note = "Use `route_pipelines()` instead.")]
1040    fn route_log_deprecated<S>(log_state: LogState) -> Router<S> {
1041        Router::new()
1042            .route("/logs", routing::post(event::log_ingester))
1043            .route(
1044                "/pipelines/{pipeline_name}",
1045                routing::get(event::query_pipeline),
1046            )
1047            .route(
1048                "/pipelines/{pipeline_name}",
1049                routing::post(event::add_pipeline),
1050            )
1051            .route(
1052                "/pipelines/{pipeline_name}",
1053                routing::delete(event::delete_pipeline),
1054            )
1055            .route("/pipelines/dryrun", routing::post(event::pipeline_dryrun))
1056            .layer(
1057                ServiceBuilder::new()
1058                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1059            )
1060            .with_state(log_state)
1061    }
1062
1063    fn route_pipelines<S>(log_state: LogState) -> Router<S> {
1064        Router::new()
1065            .route("/ingest", routing::post(event::log_ingester))
1066            .route(
1067                "/pipelines/{pipeline_name}",
1068                routing::get(event::query_pipeline),
1069            )
1070            .route(
1071                "/pipelines/{pipeline_name}/ddl",
1072                routing::get(event::query_pipeline_ddl),
1073            )
1074            .route(
1075                "/pipelines/{pipeline_name}",
1076                routing::post(event::add_pipeline),
1077            )
1078            .route(
1079                "/pipelines/{pipeline_name}",
1080                routing::delete(event::delete_pipeline),
1081            )
1082            .route("/pipelines/_dryrun", routing::post(event::pipeline_dryrun))
1083            .layer(
1084                ServiceBuilder::new()
1085                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1086            )
1087            .with_state(log_state)
1088    }
1089
1090    fn route_sql<S>(api_state: ApiState) -> Router<S> {
1091        Router::new()
1092            .route("/sql", routing::get(handler::sql).post(handler::sql))
1093            .route(
1094                "/sql/parse",
1095                routing::get(handler::sql_parse).post(handler::sql_parse),
1096            )
1097            .route(
1098                "/sql/format",
1099                routing::get(handler::sql_format).post(handler::sql_format),
1100            )
1101            .route(
1102                "/promql",
1103                routing::get(handler::promql).post(handler::promql),
1104            )
1105            .with_state(api_state)
1106    }
1107
1108    fn route_logs<S>(log_handler: LogQueryHandlerRef) -> Router<S> {
1109        Router::new()
1110            .route("/logs", routing::get(logs::logs).post(logs::logs))
1111            .with_state(log_handler)
1112    }
1113
1114    /// Route Prometheus [HTTP API].
1115    ///
1116    /// [HTTP API]: https://prometheus.io/docs/prometheus/latest/querying/api/
1117    pub fn route_prometheus<S>(prometheus_handler: PrometheusHandlerRef) -> Router<S> {
1118        Router::new()
1119            .route(
1120                "/format_query",
1121                routing::post(format_query).get(format_query),
1122            )
1123            .route("/status/buildinfo", routing::get(build_info_query))
1124            .route("/query", routing::post(instant_query).get(instant_query))
1125            .route("/query_range", routing::post(range_query).get(range_query))
1126            .route("/labels", routing::post(labels_query).get(labels_query))
1127            .route("/series", routing::post(series_query).get(series_query))
1128            .route("/parse_query", routing::post(parse_query).get(parse_query))
1129            .route(
1130                "/label/{label_name}/values",
1131                routing::get(label_values_query),
1132            )
1133            .layer(ServiceBuilder::new().layer(CompressionLayer::new()))
1134            .with_state(prometheus_handler)
1135    }
1136
1137    /// Route Prometheus remote [read] and [write] API. In other places the related modules are
1138    /// called `prom_store`.
1139    ///
1140    /// [read]: https://prometheus.io/docs/prometheus/latest/querying/remote_read_api/
1141    /// [write]: https://prometheus.io/docs/concepts/remote_write_spec/
1142    fn route_prom<S>(state: PromStoreState) -> Router<S> {
1143        Router::new()
1144            .route("/read", routing::post(prom_store::remote_read))
1145            .route("/write", routing::post(prom_store::remote_write))
1146            .with_state(state)
1147    }
1148
1149    fn route_influxdb<S>(influxdb_handler: InfluxdbLineProtocolHandlerRef) -> Router<S> {
1150        Router::new()
1151            .route("/write", routing::post(influxdb_write_v1))
1152            .route("/api/v2/write", routing::post(influxdb_write_v2))
1153            .layer(
1154                ServiceBuilder::new()
1155                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1156            )
1157            .route("/ping", routing::get(influxdb_ping))
1158            .route("/health", routing::get(influxdb_health))
1159            .with_state(influxdb_handler)
1160    }
1161
1162    fn route_opentsdb<S>(opentsdb_handler: OpentsdbProtocolHandlerRef) -> Router<S> {
1163        Router::new()
1164            .route("/api/put", routing::post(opentsdb::put))
1165            .with_state(opentsdb_handler)
1166    }
1167
1168    fn route_otlp<S>(
1169        otlp_handler: OpenTelemetryProtocolHandlerRef,
1170        with_metric_engine: bool,
1171    ) -> Router<S> {
1172        Router::new()
1173            .route("/v1/metrics", routing::post(otlp::metrics))
1174            .route("/v1/traces", routing::post(otlp::traces))
1175            .route("/v1/logs", routing::post(otlp::logs))
1176            .layer(
1177                ServiceBuilder::new()
1178                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1179            )
1180            .with_state(OtlpState {
1181                with_metric_engine,
1182                handler: otlp_handler,
1183            })
1184    }
1185
1186    fn route_config<S>(state: GreptimeOptionsConfigState) -> Router<S> {
1187        Router::new()
1188            .route("/config", routing::get(handler::config))
1189            .with_state(state)
1190    }
1191
1192    fn route_jaeger<S>(handler: JaegerQueryHandlerRef) -> Router<S> {
1193        Router::new()
1194            .route("/api/services", routing::get(jaeger::handle_get_services))
1195            .route(
1196                "/api/services/{service_name}/operations",
1197                routing::get(jaeger::handle_get_operations_by_service),
1198            )
1199            .route(
1200                "/api/operations",
1201                routing::get(jaeger::handle_get_operations),
1202            )
1203            .route("/api/traces", routing::get(jaeger::handle_find_traces))
1204            .route(
1205                "/api/traces/{trace_id}",
1206                routing::get(jaeger::handle_get_trace),
1207            )
1208            .with_state(handler)
1209    }
1210}
1211
/// Name returned by `Server::name` for the HTTP server.
pub const HTTP_SERVER: &str = "HTTP_SERVER";
1213
1214#[async_trait]
1215impl Server for HttpServer {
1216    async fn shutdown(&self) -> Result<()> {
1217        let mut shutdown_tx = self.shutdown_tx.lock().await;
1218        if let Some(tx) = shutdown_tx.take()
1219            && tx.send(()).is_err()
1220        {
1221            info!("Receiver dropped, the HTTP server has already exited");
1222        }
1223        info!("Shutdown HTTP server");
1224
1225        Ok(())
1226    }
1227
1228    async fn start(&mut self, listening: SocketAddr) -> Result<()> {
1229        let (tx, rx) = oneshot::channel();
1230        let serve = {
1231            let mut shutdown_tx = self.shutdown_tx.lock().await;
1232            ensure!(
1233                shutdown_tx.is_none(),
1234                AlreadyStartedSnafu { server: "HTTP" }
1235            );
1236
1237            let mut app = self.make_app();
1238            if let Some(configurator) = self.plugins.get::<HttpConfiguratorRef<()>>() {
1239                app = configurator
1240                    .configure_http(app, ())
1241                    .await
1242                    .context(OtherSnafu)?;
1243            }
1244            let app = self.build(app)?;
1245            let listener = tokio::net::TcpListener::bind(listening)
1246                .await
1247                .context(AddressBindSnafu { addr: listening })?
1248                .tap_io(|tcp_stream| {
1249                    if let Err(e) = tcp_stream.set_nodelay(true) {
1250                        error!(e; "Failed to set TCP_NODELAY on incoming connection");
1251                    }
1252                });
1253            let serve = axum::serve(listener, app.into_make_service());
1254
1255            // FIXME(yingwen): Support keepalive.
1256            // See:
1257            // - https://github.com/tokio-rs/axum/discussions/2939
1258            // - https://stackoverflow.com/questions/73069718/how-do-i-keep-alive-tokiotcpstream-in-rust
1259            // let server = axum::Server::try_bind(&listening)
1260            //     .with_context(|_| AddressBindSnafu { addr: listening })?
1261            //     .tcp_nodelay(true)
1262            //     // Enable TCP keepalive to close the dangling established connections.
1263            //     // It's configured to let the keepalive probes first send after the connection sits
1264            //     // idle for 59 minutes, and then send every 10 seconds for 6 times.
1265            //     // So the connection will be closed after roughly 1 hour.
1266            //     .tcp_keepalive(Some(Duration::from_secs(59 * 60)))
1267            //     .tcp_keepalive_interval(Some(Duration::from_secs(10)))
1268            //     .tcp_keepalive_retries(Some(6))
1269            //     .serve(app.into_make_service());
1270
1271            *shutdown_tx = Some(tx);
1272
1273            serve
1274        };
1275        let listening = serve.local_addr().context(InternalIoSnafu)?;
1276        info!("HTTP server is bound to {}", listening);
1277
1278        common_runtime::spawn_global(async move {
1279            if let Err(e) = serve
1280                .with_graceful_shutdown(rx.map(drop))
1281                .await
1282                .context(InternalIoSnafu)
1283            {
1284                error!(e; "Failed to shutdown http server");
1285            }
1286        });
1287
1288        self.bind_addr = Some(listening);
1289        Ok(())
1290    }
1291
1292    fn name(&self) -> &str {
1293        HTTP_SERVER
1294    }
1295
1296    fn bind_addr(&self) -> Option<SocketAddr> {
1297        self.bind_addr
1298    }
1299
1300    fn as_any(&self) -> &dyn std::any::Any {
1301        self
1302    }
1303}
1304
1305#[cfg(test)]
1306mod test {
1307    use std::future::pending;
1308    use std::io::Cursor;
1309    use std::sync::Arc;
1310
1311    use arrow_ipc::reader::FileReader;
1312    use arrow_schema::DataType;
1313    use axum::handler::Handler;
1314    use axum::http::StatusCode;
1315    use axum::routing::get;
1316    use common_query::Output;
1317    use common_recordbatch::RecordBatches;
1318    use datafusion_expr::LogicalPlan;
1319    use datatypes::prelude::*;
1320    use datatypes::schema::{ColumnSchema, Schema};
1321    use datatypes::vectors::{StringVector, UInt32Vector};
1322    use header::constants::GREPTIME_DB_HEADER_TIMEOUT;
1323    use query::parser::PromQuery;
1324    use query::query_engine::DescribeResult;
1325    use session::context::QueryContextRef;
1326    use sql::statements::statement::Statement;
1327    use tokio::sync::mpsc;
1328    use tokio::time::Instant;
1329
1330    use super::*;
1331    use crate::http::test_helpers::TestClient;
1332    use crate::query_handler::sql::SqlQueryHandler;
1333
1334    struct DummyInstance {
1335        _tx: mpsc::Sender<(String, Vec<u8>)>,
1336    }
1337
1338    #[async_trait]
1339    impl SqlQueryHandler for DummyInstance {
1340        async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec<Result<Output>> {
1341            unimplemented!()
1342        }
1343
1344        async fn do_promql_query(&self, _: &PromQuery, _: QueryContextRef) -> Vec<Result<Output>> {
1345            unimplemented!()
1346        }
1347
1348        async fn do_exec_plan(
1349            &self,
1350            _stmt: Option<Statement>,
1351            _plan: LogicalPlan,
1352            _query_ctx: QueryContextRef,
1353        ) -> Result<Output> {
1354            unimplemented!()
1355        }
1356
1357        async fn do_describe(
1358            &self,
1359            _stmt: sql::statements::statement::Statement,
1360            _query_ctx: QueryContextRef,
1361        ) -> Result<Option<DescribeResult>> {
1362            unimplemented!()
1363        }
1364
1365        async fn is_valid_schema(&self, _catalog: &str, _schema: &str) -> Result<bool> {
1366            Ok(true)
1367        }
1368    }
1369
1370    fn timeout() -> DynamicTimeoutLayer {
1371        DynamicTimeoutLayer::new(Duration::from_millis(10))
1372    }
1373
1374    async fn forever() {
1375        pending().await
1376    }
1377
1378    fn make_test_app(tx: mpsc::Sender<(String, Vec<u8>)>) -> Router {
1379        make_test_app_custom(tx, HttpOptions::default())
1380    }
1381
1382    fn make_test_app_custom(tx: mpsc::Sender<(String, Vec<u8>)>, options: HttpOptions) -> Router {
1383        let instance = Arc::new(DummyInstance { _tx: tx });
1384        let server = HttpServerBuilder::new(options)
1385            .with_sql_handler(instance.clone())
1386            .build();
1387        server.build(server.make_app()).unwrap().route(
1388            "/test/timeout",
1389            get(forever.layer(ServiceBuilder::new().layer(timeout()))),
1390        )
1391    }
1392
1393    #[tokio::test]
1394    pub async fn test_cors() {
1395        // cors is on by default
1396        let (tx, _rx) = mpsc::channel(100);
1397        let app = make_test_app(tx);
1398        let client = TestClient::new(app).await;
1399
1400        let res = client.get("/health").send().await;
1401
1402        assert_eq!(res.status(), StatusCode::OK);
1403        assert_eq!(
1404            res.headers()
1405                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1406                .expect("expect cors header origin"),
1407            "*"
1408        );
1409
1410        let res = client.get("/v1/health").send().await;
1411
1412        assert_eq!(res.status(), StatusCode::OK);
1413        assert_eq!(
1414            res.headers()
1415                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1416                .expect("expect cors header origin"),
1417            "*"
1418        );
1419
1420        let res = client
1421            .options("/health")
1422            .header("Access-Control-Request-Headers", "x-greptime-auth")
1423            .header("Access-Control-Request-Method", "DELETE")
1424            .header("Origin", "https://example.com")
1425            .send()
1426            .await;
1427        assert_eq!(res.status(), StatusCode::OK);
1428        assert_eq!(
1429            res.headers()
1430                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1431                .expect("expect cors header origin"),
1432            "*"
1433        );
1434        assert_eq!(
1435            res.headers()
1436                .get(http::header::ACCESS_CONTROL_ALLOW_HEADERS)
1437                .expect("expect cors header headers"),
1438            "*"
1439        );
1440        assert_eq!(
1441            res.headers()
1442                .get(http::header::ACCESS_CONTROL_ALLOW_METHODS)
1443                .expect("expect cors header methods"),
1444            "GET,POST,PUT,DELETE,HEAD"
1445        );
1446    }
1447
1448    #[tokio::test]
1449    pub async fn test_cors_custom_origins() {
1450        // cors is on by default
1451        let (tx, _rx) = mpsc::channel(100);
1452        let origin = "https://example.com";
1453
1454        let options = HttpOptions {
1455            cors_allowed_origins: vec![origin.to_string()],
1456            ..Default::default()
1457        };
1458
1459        let app = make_test_app_custom(tx, options);
1460        let client = TestClient::new(app).await;
1461
1462        let res = client.get("/health").header("Origin", origin).send().await;
1463
1464        assert_eq!(res.status(), StatusCode::OK);
1465        assert_eq!(
1466            res.headers()
1467                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1468                .expect("expect cors header origin"),
1469            origin
1470        );
1471
1472        let res = client
1473            .get("/health")
1474            .header("Origin", "https://notallowed.com")
1475            .send()
1476            .await;
1477
1478        assert_eq!(res.status(), StatusCode::OK);
1479        assert!(
1480            !res.headers()
1481                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1482        );
1483    }
1484
1485    #[tokio::test]
1486    pub async fn test_cors_disabled() {
1487        // cors is on by default
1488        let (tx, _rx) = mpsc::channel(100);
1489
1490        let options = HttpOptions {
1491            enable_cors: false,
1492            ..Default::default()
1493        };
1494
1495        let app = make_test_app_custom(tx, options);
1496        let client = TestClient::new(app).await;
1497
1498        let res = client.get("/health").send().await;
1499
1500        assert_eq!(res.status(), StatusCode::OK);
1501        assert!(
1502            !res.headers()
1503                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1504        );
1505    }
1506
1507    #[test]
1508    fn test_http_options_default() {
1509        let default = HttpOptions::default();
1510        assert_eq!("127.0.0.1:4000".to_string(), default.addr);
1511        assert_eq!(Duration::from_secs(0), default.timeout)
1512    }
1513
1514    #[tokio::test]
1515    async fn test_http_server_request_timeout() {
1516        common_telemetry::init_default_ut_logging();
1517
1518        let (tx, _rx) = mpsc::channel(100);
1519        let app = make_test_app(tx);
1520        let client = TestClient::new(app).await;
1521        let res = client.get("/test/timeout").send().await;
1522        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1523
1524        let now = Instant::now();
1525        let res = client
1526            .get("/test/timeout")
1527            .header(GREPTIME_DB_HEADER_TIMEOUT, "20ms")
1528            .send()
1529            .await;
1530        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1531        let elapsed = now.elapsed();
1532        assert!(elapsed > Duration::from_millis(15));
1533
1534        tokio::time::timeout(
1535            Duration::from_millis(15),
1536            client
1537                .get("/test/timeout")
1538                .header(GREPTIME_DB_HEADER_TIMEOUT, "0s")
1539                .send(),
1540        )
1541        .await
1542        .unwrap_err();
1543
1544        tokio::time::timeout(
1545            Duration::from_millis(15),
1546            client
1547                .get("/test/timeout")
1548                .header(
1549                    GREPTIME_DB_HEADER_TIMEOUT,
1550                    humantime::format_duration(Duration::default()).to_string(),
1551                )
1552                .send(),
1553        )
1554        .await
1555        .unwrap_err();
1556    }
1557
1558    #[tokio::test]
1559    async fn test_schema_for_empty_response() {
1560        let column_schemas = vec![
1561            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1562            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1563        ];
1564        let schema = Arc::new(Schema::new(column_schemas));
1565
1566        let recordbatches = RecordBatches::try_new(schema.clone(), vec![]).unwrap();
1567        let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1568
1569        let json_resp = GreptimedbV1Response::from_output(outputs).await;
1570        if let HttpResponse::GreptimedbV1(json_resp) = json_resp {
1571            let json_output = &json_resp.output[0];
1572            if let GreptimeQueryOutput::Records(r) = json_output {
1573                assert_eq!(r.num_rows(), 0);
1574                assert_eq!(r.num_cols(), 2);
1575                assert_eq!(r.schema.column_schemas[0].name, "numbers");
1576                assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1577            } else {
1578                panic!("invalid output type");
1579            }
1580        } else {
1581            panic!("invalid format")
1582        }
1583    }
1584
1585    #[tokio::test]
1586    async fn test_recordbatches_conversion() {
1587        let column_schemas = vec![
1588            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1589            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1590        ];
1591        let schema = Arc::new(Schema::new(column_schemas));
1592        let columns: Vec<VectorRef> = vec![
1593            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
1594            Arc::new(StringVector::from(vec![
1595                None,
1596                Some("hello"),
1597                Some("greptime"),
1598                None,
1599            ])),
1600        ];
1601        let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();
1602
1603        for format in [
1604            ResponseFormat::GreptimedbV1,
1605            ResponseFormat::InfluxdbV1,
1606            ResponseFormat::Csv(true, true),
1607            ResponseFormat::Table,
1608            ResponseFormat::Arrow,
1609            ResponseFormat::Json,
1610            ResponseFormat::Null,
1611        ] {
1612            let recordbatches =
1613                RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
1614            let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1615            let json_resp = match format {
1616                ResponseFormat::Arrow => ArrowResponse::from_output(outputs, None).await,
1617                ResponseFormat::Csv(with_names, with_types) => {
1618                    CsvResponse::from_output(outputs, with_names, with_types).await
1619                }
1620                ResponseFormat::Table => TableResponse::from_output(outputs).await,
1621                ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
1622                ResponseFormat::InfluxdbV1 => InfluxdbV1Response::from_output(outputs, None).await,
1623                ResponseFormat::Json => JsonResponse::from_output(outputs).await,
1624                ResponseFormat::Null => NullResponse::from_output(outputs).await,
1625            };
1626
1627            match json_resp {
1628                HttpResponse::GreptimedbV1(resp) => {
1629                    let json_output = &resp.output[0];
1630                    if let GreptimeQueryOutput::Records(r) = json_output {
1631                        assert_eq!(r.num_rows(), 4);
1632                        assert_eq!(r.num_cols(), 2);
1633                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1634                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1635                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1636                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1637                    } else {
1638                        panic!("invalid output type");
1639                    }
1640                }
1641                HttpResponse::InfluxdbV1(resp) => {
1642                    let json_output = &resp.results()[0];
1643                    assert_eq!(json_output.num_rows(), 4);
1644                    assert_eq!(json_output.num_cols(), 2);
1645                    assert_eq!(json_output.series[0].columns.clone()[0], "numbers");
1646                    assert_eq!(
1647                        json_output.series[0].values[0][0],
1648                        serde_json::Value::from(1)
1649                    );
1650                    assert_eq!(json_output.series[0].values[0][1], serde_json::Value::Null);
1651                }
1652                HttpResponse::Csv(resp) => {
1653                    let output = &resp.output()[0];
1654                    if let GreptimeQueryOutput::Records(r) = output {
1655                        assert_eq!(r.num_rows(), 4);
1656                        assert_eq!(r.num_cols(), 2);
1657                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1658                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1659                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1660                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1661                    } else {
1662                        panic!("invalid output type");
1663                    }
1664                }
1665
1666                HttpResponse::Table(resp) => {
1667                    let output = &resp.output()[0];
1668                    if let GreptimeQueryOutput::Records(r) = output {
1669                        assert_eq!(r.num_rows(), 4);
1670                        assert_eq!(r.num_cols(), 2);
1671                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1672                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1673                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1674                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1675                    } else {
1676                        panic!("invalid output type");
1677                    }
1678                }
1679
1680                HttpResponse::Arrow(resp) => {
1681                    let output = resp.data;
1682                    let mut reader =
1683                        FileReader::try_new(Cursor::new(output), None).expect("Arrow reader error");
1684                    let schema = reader.schema();
1685                    assert_eq!(schema.fields[0].name(), "numbers");
1686                    assert_eq!(schema.fields[0].data_type(), &DataType::UInt32);
1687                    assert_eq!(schema.fields[1].name(), "strings");
1688                    assert_eq!(schema.fields[1].data_type(), &DataType::Utf8);
1689
1690                    let rb = reader.next().unwrap().expect("read record batch failed");
1691                    assert_eq!(rb.num_columns(), 2);
1692                    assert_eq!(rb.num_rows(), 4);
1693                }
1694
1695                HttpResponse::Json(resp) => {
1696                    let output = &resp.output()[0];
1697                    if let GreptimeQueryOutput::Records(r) = output {
1698                        assert_eq!(r.num_rows(), 4);
1699                        assert_eq!(r.num_cols(), 2);
1700                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1701                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1702                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1703                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1704                    } else {
1705                        panic!("invalid output type");
1706                    }
1707                }
1708
1709                HttpResponse::Null(resp) => {
1710                    assert_eq!(resp.rows(), 4);
1711                }
1712
1713                HttpResponse::Error(err) => unreachable!("{err:?}"),
1714            }
1715        }
1716    }
1717
1718    #[test]
1719    fn test_response_format_misc() {
1720        assert_eq!(ResponseFormat::default(), ResponseFormat::GreptimedbV1);
1721        assert_eq!(ResponseFormat::parse("arrow"), Some(ResponseFormat::Arrow));
1722        assert_eq!(
1723            ResponseFormat::parse("csv"),
1724            Some(ResponseFormat::Csv(false, false))
1725        );
1726        assert_eq!(
1727            ResponseFormat::parse("csvwithnames"),
1728            Some(ResponseFormat::Csv(true, false))
1729        );
1730        assert_eq!(
1731            ResponseFormat::parse("csvwithnamesandtypes"),
1732            Some(ResponseFormat::Csv(true, true))
1733        );
1734        assert_eq!(ResponseFormat::parse("table"), Some(ResponseFormat::Table));
1735        assert_eq!(
1736            ResponseFormat::parse("greptimedb_v1"),
1737            Some(ResponseFormat::GreptimedbV1)
1738        );
1739        assert_eq!(
1740            ResponseFormat::parse("influxdb_v1"),
1741            Some(ResponseFormat::InfluxdbV1)
1742        );
1743        assert_eq!(ResponseFormat::parse("json"), Some(ResponseFormat::Json));
1744        assert_eq!(ResponseFormat::parse("null"), Some(ResponseFormat::Null));
1745
1746        // invalid formats
1747        assert_eq!(ResponseFormat::parse("invalid"), None);
1748        assert_eq!(ResponseFormat::parse(""), None);
1749        assert_eq!(ResponseFormat::parse("CSV"), None); // Case sensitive
1750
1751        // as str
1752        assert_eq!(ResponseFormat::Arrow.as_str(), "arrow");
1753        assert_eq!(ResponseFormat::Csv(false, false).as_str(), "csv");
1754        assert_eq!(ResponseFormat::Csv(true, true).as_str(), "csv");
1755        assert_eq!(ResponseFormat::Table.as_str(), "table");
1756        assert_eq!(ResponseFormat::GreptimedbV1.as_str(), "greptimedb_v1");
1757        assert_eq!(ResponseFormat::InfluxdbV1.as_str(), "influxdb_v1");
1758        assert_eq!(ResponseFormat::Json.as_str(), "json");
1759        assert_eq!(ResponseFormat::Null.as_str(), "null");
1760        assert_eq!(ResponseFormat::default().as_str(), "greptimedb_v1");
1761    }
1762}