servers/
http.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::convert::Infallible;
17use std::fmt::Display;
18use std::net::SocketAddr;
19use std::sync::Mutex as StdMutex;
20use std::time::Duration;
21
22use async_trait::async_trait;
23use auth::UserProviderRef;
24use axum::extract::{DefaultBodyLimit, Request};
25use axum::http::StatusCode as HttpStatusCode;
26use axum::response::{IntoResponse, Response};
27use axum::routing::Route;
28use axum::serve::ListenerExt;
29use axum::{Router, middleware, routing};
30use common_base::Plugins;
31use common_base::readable_size::ReadableSize;
32use common_recordbatch::RecordBatch;
33use common_telemetry::{error, info};
34use common_time::Timestamp;
35use common_time::timestamp::TimeUnit;
36use datatypes::data_type::DataType;
37use datatypes::schema::SchemaRef;
38use event::{LogState, LogValidatorRef};
39use futures::FutureExt;
40use http::{HeaderValue, Method};
41use serde::{Deserialize, Serialize};
42use serde_json::Value;
43use snafu::{ResultExt, ensure};
44use tokio::sync::Mutex;
45use tokio::sync::oneshot::{self, Sender};
46use tonic::codegen::Service;
47use tower::{Layer, ServiceBuilder};
48use tower_http::compression::CompressionLayer;
49use tower_http::cors::{AllowOrigin, Any, CorsLayer};
50use tower_http::decompression::RequestDecompressionLayer;
51use tower_http::trace::TraceLayer;
52
53use self::authorize::AuthState;
54use self::result::table_result::TableResponse;
55use crate::configurator::HttpConfiguratorRef;
56use crate::elasticsearch;
57use crate::error::{
58    AddressBindSnafu, AlreadyStartedSnafu, Error, InternalIoSnafu, InvalidHeaderValueSnafu,
59    OtherSnafu, Result,
60};
61use crate::http::influxdb::{influxdb_health, influxdb_ping, influxdb_write_v1, influxdb_write_v2};
62use crate::http::otlp::OtlpState;
63use crate::http::prom_store::PromStoreState;
64use crate::http::prometheus::{
65    build_info_query, format_query, instant_query, label_values_query, labels_query, parse_query,
66    range_query, series_query,
67};
68use crate::http::result::arrow_result::ArrowResponse;
69use crate::http::result::csv_result::CsvResponse;
70use crate::http::result::error_result::ErrorResponse;
71use crate::http::result::greptime_result_v1::GreptimedbV1Response;
72use crate::http::result::influxdb_result_v1::InfluxdbV1Response;
73use crate::http::result::json_result::JsonResponse;
74use crate::http::result::null_result::NullResponse;
75use crate::interceptor::LogIngestInterceptorRef;
76use crate::metrics::http_metrics_layer;
77use crate::metrics_handler::MetricsHandler;
78use crate::prometheus_handler::PrometheusHandlerRef;
79use crate::query_handler::sql::ServerSqlQueryHandlerRef;
80use crate::query_handler::{
81    DashboardHandlerRef, InfluxdbLineProtocolHandlerRef, JaegerQueryHandlerRef, LogQueryHandlerRef,
82    OpenTelemetryProtocolHandlerRef, OpentsdbProtocolHandlerRef, PipelineHandlerRef,
83    PromStoreProtocolHandlerRef,
84};
85use crate::request_memory_limiter::ServerMemoryLimiter;
86use crate::server::Server;
87
88pub mod authorize;
89#[cfg(feature = "dashboard")]
90mod dashboard;
91pub mod dyn_log;
92pub mod dyn_trace;
93pub mod event;
94pub mod extractor;
95pub mod handler;
96pub mod header;
97pub mod influxdb;
98pub mod jaeger;
99pub mod logs;
100pub mod loki;
101pub mod mem_prof;
102mod memory_limit;
103pub mod opentsdb;
104pub mod otlp;
105pub mod pprof;
106pub mod prom_store;
107pub mod prometheus;
108pub mod result;
109mod timeout;
110pub mod utils;
111
112use result::HttpOutputWriter;
113pub(crate) use timeout::DynamicTimeoutLayer;
114
115mod client_ip;
116use crate::prom_remote_write::validation::PromValidationMode;
117mod hints;
118mod read_preference;
119#[cfg(any(test, feature = "testing"))]
120pub mod test_helpers;
121
/// Version segment of the HTTP API; routers are nested under `/{HTTP_API_VERSION}/...`.
pub const HTTP_API_VERSION: &str = "v1";
/// The version segment wrapped in slashes, i.e. `/{HTTP_API_VERSION}/`.
pub const HTTP_API_PREFIX: &str = "/v1/";
124/// Default http body limit (64M).
125const DEFAULT_BODY_LIMIT: ReadableSize = ReadableSize::mb(64);
126
/// Name of the GreptimeDB-specific authorization header.
pub const AUTHORIZATION_HEADER: &str = "x-greptime-auth";

/// Request paths treated as public APIs.
/// NOTE(review): presumably consulted by the auth middleware to skip credential checks —
/// confirm in the `authorize` module.
// TODO(fys): This is a temporary workaround, it will be improved later
pub static PUBLIC_APIS: [&str; 3] = ["/v1/influxdb/ping", "/v1/influxdb/health", "/v1/health"];
132
133#[derive(Default)]
134pub struct HttpServer {
135    router: StdMutex<Router>,
136    shutdown_tx: Mutex<Option<Sender<()>>>,
137    user_provider: Option<UserProviderRef>,
138    memory_limiter: ServerMemoryLimiter,
139
140    // plugins
141    plugins: Plugins,
142
143    // server configs
144    options: HttpOptions,
145    bind_addr: Option<SocketAddr>,
146}
147
148#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
149#[serde(default)]
150pub struct HttpOptions {
151    pub addr: String,
152
153    #[serde(with = "humantime_serde")]
154    pub timeout: Duration,
155
156    #[serde(skip)]
157    pub disable_dashboard: bool,
158
159    pub body_limit: ReadableSize,
160
161    /// Validation mode while decoding Prometheus remote write requests.
162    pub prom_validation_mode: PromValidationMode,
163
164    pub cors_allowed_origins: Vec<String>,
165
166    pub enable_cors: bool,
167}
168
169impl Default for HttpOptions {
170    fn default() -> Self {
171        Self {
172            addr: "127.0.0.1:4000".to_string(),
173            timeout: Duration::from_secs(0),
174            disable_dashboard: false,
175            body_limit: DEFAULT_BODY_LIMIT,
176            cors_allowed_origins: Vec::new(),
177            enable_cors: true,
178            prom_validation_mode: PromValidationMode::Strict,
179        }
180    }
181}
182
183#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
184pub struct ColumnSchema {
185    name: String,
186    data_type: String,
187}
188
189impl ColumnSchema {
190    pub fn new(name: String, data_type: String) -> ColumnSchema {
191        ColumnSchema { name, data_type }
192    }
193}
194
195#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
196pub struct OutputSchema {
197    column_schemas: Vec<ColumnSchema>,
198}
199
200impl OutputSchema {
201    pub fn new(columns: Vec<ColumnSchema>) -> OutputSchema {
202        OutputSchema {
203            column_schemas: columns,
204        }
205    }
206}
207
208impl From<SchemaRef> for OutputSchema {
209    fn from(schema: SchemaRef) -> OutputSchema {
210        OutputSchema {
211            column_schemas: schema
212                .column_schemas()
213                .iter()
214                .map(|cs| ColumnSchema {
215                    name: cs.name.clone(),
216                    data_type: cs.data_type.name(),
217                })
218                .collect(),
219        }
220    }
221}
222
223#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
224pub struct HttpRecordsOutput {
225    schema: OutputSchema,
226    rows: Vec<Vec<Value>>,
227    // total_rows is equal to rows.len() in most cases,
228    // the Dashboard query result may be truncated, so we need to return the total_rows.
229    #[serde(default)]
230    total_rows: usize,
231
232    // plan level execution metrics
233    #[serde(skip_serializing_if = "HashMap::is_empty")]
234    #[serde(default)]
235    metrics: HashMap<String, Value>,
236}
237
238impl HttpRecordsOutput {
239    pub fn num_rows(&self) -> usize {
240        self.rows.len()
241    }
242
243    pub fn num_cols(&self) -> usize {
244        self.schema.column_schemas.len()
245    }
246
247    pub fn schema(&self) -> &OutputSchema {
248        &self.schema
249    }
250
251    pub fn rows(&self) -> &Vec<Vec<Value>> {
252        &self.rows
253    }
254}
255
256impl HttpRecordsOutput {
257    pub fn try_new(
258        schema: SchemaRef,
259        recordbatches: Vec<RecordBatch>,
260    ) -> std::result::Result<HttpRecordsOutput, Error> {
261        if recordbatches.is_empty() {
262            Ok(HttpRecordsOutput {
263                schema: OutputSchema::from(schema),
264                rows: vec![],
265                total_rows: 0,
266                metrics: Default::default(),
267            })
268        } else {
269            let num_rows = recordbatches.iter().map(|r| r.num_rows()).sum::<usize>();
270            let mut rows = Vec::with_capacity(num_rows);
271
272            for recordbatch in recordbatches {
273                let mut writer = HttpOutputWriter::new(schema.num_columns(), None);
274                writer.write(recordbatch, &mut rows)?;
275            }
276
277            Ok(HttpRecordsOutput {
278                schema: OutputSchema::from(schema),
279                total_rows: rows.len(),
280                rows,
281                metrics: Default::default(),
282            })
283        }
284    }
285}
286
287#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
288#[serde(rename_all = "lowercase")]
289pub enum GreptimeQueryOutput {
290    AffectedRows(usize),
291    Records(HttpRecordsOutput),
292}
293
/// Output format used to present SQL query results over HTTP.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseFormat {
    Arrow,
    /// CSV output; the flags are `(with_names, with_types)`.
    Csv(bool, bool),
    Table,
    #[default]
    GreptimedbV1,
    InfluxdbV1,
    Json,
    Null,
}

impl ResponseFormat {
    /// Parses a case-sensitive format name and returns `None` for unknown names.
    ///
    /// `csv`, `csvwithnames` and `csvwithnamesandtypes` select the CSV variant with
    /// increasing header detail.
    pub fn parse(s: &str) -> Option<Self> {
        let format = match s {
            "arrow" => Self::Arrow,
            "csv" => Self::Csv(false, false),
            "csvwithnames" => Self::Csv(true, false),
            "csvwithnamesandtypes" => Self::Csv(true, true),
            "table" => Self::Table,
            "greptimedb_v1" => Self::GreptimedbV1,
            "influxdb_v1" => Self::InfluxdbV1,
            "json" => Self::Json,
            "null" => Self::Null,
            _ => return None,
        };
        Some(format)
    }

    /// Canonical name of this format.
    ///
    /// Every CSV variant maps to `"csv"`, so this is not an exact inverse of
    /// [`ResponseFormat::parse`] for `csvwithnames` / `csvwithnamesandtypes`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Arrow => "arrow",
            Self::Csv(..) => "csv",
            Self::Table => "table",
            Self::GreptimedbV1 => "greptimedb_v1",
            Self::InfluxdbV1 => "influxdb_v1",
            Self::Json => "json",
            Self::Null => "null",
        }
    }
}
336
/// Timestamp precision an InfluxDB-compatible client asks for via the `epoch` parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Epoch {
    Nanosecond,
    Microsecond,
    Millisecond,
    Second,
}
344
345impl Epoch {
346    pub fn parse(s: &str) -> Option<Epoch> {
347        // Both u and µ indicate microseconds.
348        // epoch = [ns,u,µ,ms,s],
349        // For details, see the Influxdb documents.
350        // https://docs.influxdata.com/influxdb/v1/tools/api/#query-string-parameters-1
351        match s {
352            "ns" => Some(Epoch::Nanosecond),
353            "u" | "µ" => Some(Epoch::Microsecond),
354            "ms" => Some(Epoch::Millisecond),
355            "s" => Some(Epoch::Second),
356            _ => None, // just returns None for other cases
357        }
358    }
359
360    pub fn convert_timestamp(&self, ts: Timestamp) -> Option<Timestamp> {
361        match self {
362            Epoch::Nanosecond => ts.convert_to(TimeUnit::Nanosecond),
363            Epoch::Microsecond => ts.convert_to(TimeUnit::Microsecond),
364            Epoch::Millisecond => ts.convert_to(TimeUnit::Millisecond),
365            Epoch::Second => ts.convert_to(TimeUnit::Second),
366        }
367    }
368}
369
370impl Display for Epoch {
371    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
372        match self {
373            Epoch::Nanosecond => write!(f, "Epoch::Nanosecond"),
374            Epoch::Microsecond => write!(f, "Epoch::Microsecond"),
375            Epoch::Millisecond => write!(f, "Epoch::Millisecond"),
376            Epoch::Second => write!(f, "Epoch::Second"),
377        }
378    }
379}
380
381#[derive(Serialize, Deserialize, Debug)]
382pub enum HttpResponse {
383    Arrow(ArrowResponse),
384    Csv(CsvResponse),
385    Table(TableResponse),
386    Error(ErrorResponse),
387    GreptimedbV1(GreptimedbV1Response),
388    InfluxdbV1(InfluxdbV1Response),
389    Json(JsonResponse),
390    Null(NullResponse),
391}
392
393impl HttpResponse {
394    pub fn with_execution_time(self, execution_time: u64) -> Self {
395        match self {
396            HttpResponse::Arrow(resp) => resp.with_execution_time(execution_time).into(),
397            HttpResponse::Csv(resp) => resp.with_execution_time(execution_time).into(),
398            HttpResponse::Table(resp) => resp.with_execution_time(execution_time).into(),
399            HttpResponse::GreptimedbV1(resp) => resp.with_execution_time(execution_time).into(),
400            HttpResponse::InfluxdbV1(resp) => resp.with_execution_time(execution_time).into(),
401            HttpResponse::Json(resp) => resp.with_execution_time(execution_time).into(),
402            HttpResponse::Null(resp) => resp.with_execution_time(execution_time).into(),
403            HttpResponse::Error(resp) => resp.with_execution_time(execution_time).into(),
404        }
405    }
406
407    pub fn with_limit(self, limit: usize) -> Self {
408        match self {
409            HttpResponse::Csv(resp) => resp.with_limit(limit).into(),
410            HttpResponse::Table(resp) => resp.with_limit(limit).into(),
411            HttpResponse::GreptimedbV1(resp) => resp.with_limit(limit).into(),
412            HttpResponse::Json(resp) => resp.with_limit(limit).into(),
413            _ => self,
414        }
415    }
416}
417
418pub fn process_with_limit(
419    mut outputs: Vec<GreptimeQueryOutput>,
420    limit: usize,
421) -> Vec<GreptimeQueryOutput> {
422    outputs
423        .drain(..)
424        .map(|data| match data {
425            GreptimeQueryOutput::Records(mut records) => {
426                if records.rows.len() > limit {
427                    records.rows.truncate(limit);
428                    records.total_rows = limit;
429                }
430                GreptimeQueryOutput::Records(records)
431            }
432            _ => data,
433        })
434        .collect()
435}
436
437impl IntoResponse for HttpResponse {
438    fn into_response(self) -> Response {
439        match self {
440            HttpResponse::Arrow(resp) => resp.into_response(),
441            HttpResponse::Csv(resp) => resp.into_response(),
442            HttpResponse::Table(resp) => resp.into_response(),
443            HttpResponse::GreptimedbV1(resp) => resp.into_response(),
444            HttpResponse::InfluxdbV1(resp) => resp.into_response(),
445            HttpResponse::Json(resp) => resp.into_response(),
446            HttpResponse::Null(resp) => resp.into_response(),
447            HttpResponse::Error(resp) => resp.into_response(),
448        }
449    }
450}
451
452impl From<ArrowResponse> for HttpResponse {
453    fn from(value: ArrowResponse) -> Self {
454        HttpResponse::Arrow(value)
455    }
456}
457
458impl From<CsvResponse> for HttpResponse {
459    fn from(value: CsvResponse) -> Self {
460        HttpResponse::Csv(value)
461    }
462}
463
464impl From<TableResponse> for HttpResponse {
465    fn from(value: TableResponse) -> Self {
466        HttpResponse::Table(value)
467    }
468}
469
470impl From<ErrorResponse> for HttpResponse {
471    fn from(value: ErrorResponse) -> Self {
472        HttpResponse::Error(value)
473    }
474}
475
476impl From<GreptimedbV1Response> for HttpResponse {
477    fn from(value: GreptimedbV1Response) -> Self {
478        HttpResponse::GreptimedbV1(value)
479    }
480}
481
482impl From<InfluxdbV1Response> for HttpResponse {
483    fn from(value: InfluxdbV1Response) -> Self {
484        HttpResponse::InfluxdbV1(value)
485    }
486}
487
488impl From<JsonResponse> for HttpResponse {
489    fn from(value: JsonResponse) -> Self {
490        HttpResponse::Json(value)
491    }
492}
493
494impl From<NullResponse> for HttpResponse {
495    fn from(value: NullResponse) -> Self {
496        HttpResponse::Null(value)
497    }
498}
499
500#[derive(Clone)]
501pub struct ApiState {
502    pub sql_handler: ServerSqlQueryHandlerRef,
503}
504
505#[derive(Clone)]
506pub struct GreptimeOptionsConfigState {
507    pub greptime_config_options: String,
508}
509
510#[derive(Clone)]
511pub struct DashboardState {
512    pub handler: DashboardHandlerRef,
513}
514
515pub struct HttpServerBuilder {
516    options: HttpOptions,
517    plugins: Plugins,
518    user_provider: Option<UserProviderRef>,
519    router: Router,
520    memory_limiter: ServerMemoryLimiter,
521}
522
523impl HttpServerBuilder {
524    pub fn new(options: HttpOptions) -> Self {
525        Self {
526            options,
527            plugins: Plugins::default(),
528            user_provider: None,
529            router: Router::new(),
530            memory_limiter: ServerMemoryLimiter::default(),
531        }
532    }
533
534    /// Set a global memory limiter for all server protocols.
535    pub fn with_memory_limiter(mut self, limiter: ServerMemoryLimiter) -> Self {
536        self.memory_limiter = limiter;
537        self
538    }
539
540    pub fn with_sql_handler(self, sql_handler: ServerSqlQueryHandlerRef) -> Self {
541        let sql_router = HttpServer::route_sql(ApiState { sql_handler });
542
543        Self {
544            router: self
545                .router
546                .nest(&format!("/{HTTP_API_VERSION}"), sql_router),
547            ..self
548        }
549    }
550
551    pub fn with_logs_handler(self, logs_handler: LogQueryHandlerRef) -> Self {
552        let logs_router = HttpServer::route_logs(logs_handler);
553
554        Self {
555            router: self
556                .router
557                .nest(&format!("/{HTTP_API_VERSION}"), logs_router),
558            ..self
559        }
560    }
561
562    pub fn with_opentsdb_handler(self, handler: OpentsdbProtocolHandlerRef) -> Self {
563        Self {
564            router: self.router.nest(
565                &format!("/{HTTP_API_VERSION}/opentsdb"),
566                HttpServer::route_opentsdb(handler),
567            ),
568            ..self
569        }
570    }
571
572    pub fn with_influxdb_handler(self, handler: InfluxdbLineProtocolHandlerRef) -> Self {
573        Self {
574            router: self.router.nest(
575                &format!("/{HTTP_API_VERSION}/influxdb"),
576                HttpServer::route_influxdb(handler),
577            ),
578            ..self
579        }
580    }
581
582    pub fn with_prom_handler(
583        self,
584        handler: PromStoreProtocolHandlerRef,
585        pipeline_handler: Option<PipelineHandlerRef>,
586        prom_store_with_metric_engine: bool,
587        prom_validation_mode: PromValidationMode,
588    ) -> Self {
589        let state = PromStoreState {
590            prom_store_handler: handler,
591            pipeline_handler,
592            prom_store_with_metric_engine,
593            prom_validation_mode,
594        };
595
596        Self {
597            router: self.router.nest(
598                &format!("/{HTTP_API_VERSION}/prometheus"),
599                HttpServer::route_prom(state),
600            ),
601            ..self
602        }
603    }
604
605    pub fn with_prometheus_handler(self, handler: PrometheusHandlerRef) -> Self {
606        Self {
607            router: self.router.nest(
608                &format!("/{HTTP_API_VERSION}/prometheus/api/v1"),
609                HttpServer::route_prometheus(handler),
610            ),
611            ..self
612        }
613    }
614
615    pub fn with_otlp_handler(
616        self,
617        handler: OpenTelemetryProtocolHandlerRef,
618        with_metric_engine: bool,
619    ) -> Self {
620        Self {
621            router: self.router.nest(
622                &format!("/{HTTP_API_VERSION}/otlp"),
623                HttpServer::route_otlp(handler, with_metric_engine),
624            ),
625            ..self
626        }
627    }
628
629    pub fn with_user_provider(self, user_provider: UserProviderRef) -> Self {
630        Self {
631            user_provider: Some(user_provider),
632            ..self
633        }
634    }
635
636    pub fn with_metrics_handler(self, handler: MetricsHandler) -> Self {
637        Self {
638            router: self.router.merge(HttpServer::route_metrics(handler)),
639            ..self
640        }
641    }
642
643    pub fn with_log_ingest_handler(
644        self,
645        handler: PipelineHandlerRef,
646        validator: Option<LogValidatorRef>,
647        ingest_interceptor: Option<LogIngestInterceptorRef<Error>>,
648    ) -> Self {
649        let log_state = LogState {
650            log_handler: handler,
651            log_validator: validator,
652            ingest_interceptor,
653        };
654
655        let router = self.router.nest(
656            &format!("/{HTTP_API_VERSION}"),
657            HttpServer::route_pipelines(log_state.clone()),
658        );
659        // deprecated since v0.11.0. Use `/logs` and `/pipelines` instead.
660        let router = router.nest(
661            &format!("/{HTTP_API_VERSION}/events"),
662            #[allow(deprecated)]
663            HttpServer::route_log_deprecated(log_state.clone()),
664        );
665
666        let router = router.nest(
667            &format!("/{HTTP_API_VERSION}/loki"),
668            HttpServer::route_loki(log_state.clone()),
669        );
670
671        let router = router.nest(
672            &format!("/{HTTP_API_VERSION}/elasticsearch"),
673            HttpServer::route_elasticsearch(log_state.clone()),
674        );
675
676        let router = router.nest(
677            &format!("/{HTTP_API_VERSION}/elasticsearch/"),
678            Router::new()
679                .route("/", routing::get(elasticsearch::handle_get_version))
680                .with_state(log_state),
681        );
682
683        Self { router, ..self }
684    }
685
686    pub fn with_plugins(self, plugins: Plugins) -> Self {
687        Self { plugins, ..self }
688    }
689
690    pub fn with_greptime_config_options(self, opts: String) -> Self {
691        let config_router = HttpServer::route_config(GreptimeOptionsConfigState {
692            greptime_config_options: opts,
693        });
694
695        Self {
696            router: self.router.merge(config_router),
697            ..self
698        }
699    }
700
701    pub fn with_jaeger_handler(self, handler: JaegerQueryHandlerRef) -> Self {
702        Self {
703            router: self.router.nest(
704                &format!("/{HTTP_API_VERSION}/jaeger"),
705                HttpServer::route_jaeger(handler),
706            ),
707            ..self
708        }
709    }
710
711    pub fn with_dashboard_handler(self, handler: DashboardHandlerRef) -> Self {
712        Self {
713            router: self.router.nest(
714                &format!("/{HTTP_API_VERSION}/dashboards"),
715                HttpServer::route_dashboard(handler),
716            ),
717            ..self
718        }
719    }
720
721    pub fn with_extra_router(self, router: Router) -> Self {
722        Self {
723            router: self.router.merge(router),
724            ..self
725        }
726    }
727
728    pub fn add_layer<L>(self, layer: L) -> Self
729    where
730        L: Layer<Route> + Clone + Send + Sync + 'static,
731        L::Service: Service<Request> + Clone + Send + Sync + 'static,
732        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
733        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
734        <L::Service as Service<Request>>::Future: Send + 'static,
735    {
736        Self {
737            router: self.router.layer(layer),
738            ..self
739        }
740    }
741
742    pub fn build(self) -> HttpServer {
743        HttpServer {
744            options: self.options,
745            user_provider: self.user_provider,
746            shutdown_tx: Mutex::new(None),
747            plugins: self.plugins,
748            router: StdMutex::new(self.router),
749            bind_addr: None,
750            memory_limiter: self.memory_limiter,
751        }
752    }
753}
754
755impl HttpServer {
756    /// Gets the router and adds necessary root routes (health, status, dashboard).
757    pub fn make_app(&self) -> Router {
758        let mut router = {
759            let router = self.router.lock().unwrap();
760            router.clone()
761        };
762
763        router = router
764            .route(
765                "/health",
766                routing::get(handler::health).post(handler::health),
767            )
768            .route(
769                &format!("/{HTTP_API_VERSION}/health"),
770                routing::get(handler::health).post(handler::health),
771            )
772            .route(
773                "/ready",
774                routing::get(handler::health).post(handler::health),
775            );
776
777        router = router.route("/status", routing::get(handler::status));
778
779        #[cfg(feature = "dashboard")]
780        {
781            if !self.options.disable_dashboard {
782                info!("Enable dashboard service at '/dashboard'");
783                // redirect /dashboard to /dashboard/
784                router = router.route(
785                    "/dashboard",
786                    routing::get(|uri: axum::http::uri::Uri| async move {
787                        let path = uri.path();
788                        let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
789
790                        let new_uri = format!("{}/{}", path, query);
791                        axum::response::Redirect::permanent(&new_uri)
792                    }),
793                );
794
795                // "/dashboard" and "/dashboard/" are two different paths in Axum.
796                // We cannot nest "/dashboard/", because we already mapping "/dashboard/{*x}" while nesting "/dashboard".
797                // So we explicitly route "/dashboard/" here.
798                router = router
799                    .route(
800                        "/dashboard/",
801                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
802                    )
803                    .route(
804                        "/dashboard/{*x}",
805                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
806                    );
807            }
808        }
809
810        // Add a layer to collect HTTP metrics for axum.
811        router = router.route_layer(middleware::from_fn(http_metrics_layer));
812
813        router
814    }
815
816    /// Attaches middlewares and debug routes to the router.
817    /// Callers should call this method after [HttpServer::make_app()].
818    pub fn build(&self, router: Router) -> Result<Router> {
819        let timeout_layer = if self.options.timeout != Duration::default() {
820            Some(ServiceBuilder::new().layer(DynamicTimeoutLayer::new(self.options.timeout)))
821        } else {
822            info!("HTTP server timeout is disabled");
823            None
824        };
825        let body_limit_layer = if self.options.body_limit != ReadableSize(0) {
826            Some(
827                ServiceBuilder::new()
828                    .layer(DefaultBodyLimit::max(self.options.body_limit.0 as usize)),
829            )
830        } else {
831            info!("HTTP server body limit is disabled");
832            None
833        };
834        let cors_layer = if self.options.enable_cors {
835            Some(
836                CorsLayer::new()
837                    .allow_methods([
838                        Method::GET,
839                        Method::POST,
840                        Method::PUT,
841                        Method::DELETE,
842                        Method::HEAD,
843                    ])
844                    .allow_origin(if self.options.cors_allowed_origins.is_empty() {
845                        AllowOrigin::from(Any)
846                    } else {
847                        AllowOrigin::from(
848                            self.options
849                                .cors_allowed_origins
850                                .iter()
851                                .map(|s| {
852                                    HeaderValue::from_str(s.as_str())
853                                        .context(InvalidHeaderValueSnafu)
854                                })
855                                .collect::<Result<Vec<HeaderValue>>>()?,
856                        )
857                    })
858                    .allow_headers(Any),
859            )
860        } else {
861            info!("HTTP server cross-origin is disabled");
862            None
863        };
864
865        Ok(router
866            // middlewares
867            .layer(
868                ServiceBuilder::new()
869                    // disable on failure tracing. because printing out isn't very helpful,
870                    // and we have impl IntoResponse for Error. It will print out more detailed error messages
871                    .layer(TraceLayer::new_for_http().on_failure(()))
872                    .option_layer(cors_layer)
873                    .option_layer(timeout_layer)
874                    .option_layer(body_limit_layer)
875                    // memory limit layer - must be before body is consumed
876                    .layer(middleware::from_fn_with_state(
877                        self.memory_limiter.clone(),
878                        memory_limit::memory_limit_middleware,
879                    ))
880                    // auth layer
881                    .layer(middleware::from_fn_with_state(
882                        AuthState::new(self.user_provider.clone()),
883                        authorize::check_http_auth,
884                    ))
885                    .layer(middleware::from_fn(hints::extract_hints))
886                    .layer(middleware::from_fn(client_ip::log_error_with_client_ip))
887                    .layer(middleware::from_fn(
888                        read_preference::extract_read_preference,
889                    )),
890            )
891            // Handlers for debug, we don't expect a timeout.
892            .nest(
893                "/debug",
894                Router::new()
895                    // handler for changing log level dynamically
896                    .route("/log_level", routing::post(dyn_log::dyn_log_handler))
897                    .route("/enable_trace", routing::post(dyn_trace::dyn_trace_handler))
898                    .nest(
899                        "/prof",
900                        Router::new()
901                            .route("/cpu", routing::post(pprof::pprof_handler))
902                            .route("/mem", routing::post(mem_prof::mem_prof_handler))
903                            .route("/mem/symbol", routing::post(mem_prof::symbolicate_handler))
904                            .route(
905                                "/mem/activate",
906                                routing::post(mem_prof::activate_heap_prof_handler),
907                            )
908                            .route(
909                                "/mem/deactivate",
910                                routing::post(mem_prof::deactivate_heap_prof_handler),
911                            )
912                            .route(
913                                "/mem/status",
914                                routing::get(mem_prof::heap_prof_status_handler),
915                            ) // jemalloc gdump flag status and toggle
916                            .route(
917                                "/mem/gdump",
918                                routing::get(mem_prof::gdump_status_handler)
919                                    .post(mem_prof::gdump_toggle_handler),
920                            ),
921                    ),
922            ))
923    }
924
925    fn route_metrics<S>(metrics_handler: MetricsHandler) -> Router<S> {
926        Router::new()
927            .route("/metrics", routing::get(handler::metrics))
928            .with_state(metrics_handler)
929    }
930
931    fn route_loki<S>(log_state: LogState) -> Router<S> {
932        Router::new()
933            .route("/api/v1/push", routing::post(loki::loki_ingest))
934            .layer(
935                ServiceBuilder::new()
936                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
937            )
938            .with_state(log_state)
939    }
940
941    fn route_elasticsearch<S>(log_state: LogState) -> Router<S> {
942        Router::new()
943            // Return fake responsefor HEAD '/' request.
944            .route(
945                "/",
946                routing::head((HttpStatusCode::OK, elasticsearch::elasticsearch_headers())),
947            )
948            // Return fake response for Elasticsearch version request.
949            .route("/", routing::get(elasticsearch::handle_get_version))
950            // Return fake response for Elasticsearch license request.
951            .route("/_license", routing::get(elasticsearch::handle_get_license))
952            .route("/_bulk", routing::post(elasticsearch::handle_bulk_api))
953            .route(
954                "/{index}/_bulk",
955                routing::post(elasticsearch::handle_bulk_api_with_index),
956            )
957            // Return fake response for Elasticsearch ilm request.
958            .route(
959                "/_ilm/policy/{*path}",
960                routing::any((
961                    HttpStatusCode::OK,
962                    elasticsearch::elasticsearch_headers(),
963                    axum::Json(serde_json::json!({})),
964                )),
965            )
966            // Return fake response for Elasticsearch index template request.
967            .route(
968                "/_index_template/{*path}",
969                routing::any((
970                    HttpStatusCode::OK,
971                    elasticsearch::elasticsearch_headers(),
972                    axum::Json(serde_json::json!({})),
973                )),
974            )
975            // Return fake response for Elasticsearch ingest pipeline request.
976            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/put-pipeline-api.html.
977            .route(
978                "/_ingest/{*path}",
979                routing::any((
980                    HttpStatusCode::OK,
981                    elasticsearch::elasticsearch_headers(),
982                    axum::Json(serde_json::json!({})),
983                )),
984            )
985            // Return fake response for Elasticsearch nodes discovery request.
986            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/cluster.html.
987            .route(
988                "/_nodes/{*path}",
989                routing::any((
990                    HttpStatusCode::OK,
991                    elasticsearch::elasticsearch_headers(),
992                    axum::Json(serde_json::json!({})),
993                )),
994            )
995            // Return fake response for Logstash APIs requests.
996            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/logstash-apis.html
997            .route(
998                "/logstash/{*path}",
999                routing::any((
1000                    HttpStatusCode::OK,
1001                    elasticsearch::elasticsearch_headers(),
1002                    axum::Json(serde_json::json!({})),
1003                )),
1004            )
1005            .route(
1006                "/_logstash/{*path}",
1007                routing::any((
1008                    HttpStatusCode::OK,
1009                    elasticsearch::elasticsearch_headers(),
1010                    axum::Json(serde_json::json!({})),
1011                )),
1012            )
1013            .layer(ServiceBuilder::new().layer(RequestDecompressionLayer::new()))
1014            .with_state(log_state)
1015    }
1016
1017    #[deprecated(since = "0.11.0", note = "Use `route_pipelines()` instead.")]
1018    fn route_log_deprecated<S>(log_state: LogState) -> Router<S> {
1019        Router::new()
1020            .route("/logs", routing::post(event::log_ingester))
1021            .route(
1022                "/pipelines/{pipeline_name}",
1023                routing::get(event::query_pipeline),
1024            )
1025            .route(
1026                "/pipelines/{pipeline_name}",
1027                routing::post(event::add_pipeline),
1028            )
1029            .route(
1030                "/pipelines/{pipeline_name}",
1031                routing::delete(event::delete_pipeline),
1032            )
1033            .route("/pipelines/dryrun", routing::post(event::pipeline_dryrun))
1034            .layer(
1035                ServiceBuilder::new()
1036                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1037            )
1038            .with_state(log_state)
1039    }
1040
1041    fn route_pipelines<S>(log_state: LogState) -> Router<S> {
1042        Router::new()
1043            .route("/ingest", routing::post(event::log_ingester))
1044            .route(
1045                "/pipelines/{pipeline_name}",
1046                routing::get(event::query_pipeline),
1047            )
1048            .route(
1049                "/pipelines/{pipeline_name}/ddl",
1050                routing::get(event::query_pipeline_ddl),
1051            )
1052            .route(
1053                "/pipelines/{pipeline_name}",
1054                routing::post(event::add_pipeline),
1055            )
1056            .route(
1057                "/pipelines/{pipeline_name}",
1058                routing::delete(event::delete_pipeline),
1059            )
1060            .route("/pipelines/_dryrun", routing::post(event::pipeline_dryrun))
1061            .layer(
1062                ServiceBuilder::new()
1063                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1064            )
1065            .with_state(log_state)
1066    }
1067
1068    fn route_sql<S>(api_state: ApiState) -> Router<S> {
1069        Router::new()
1070            .route("/sql", routing::get(handler::sql).post(handler::sql))
1071            .route(
1072                "/sql/parse",
1073                routing::get(handler::sql_parse).post(handler::sql_parse),
1074            )
1075            .route(
1076                "/sql/format",
1077                routing::get(handler::sql_format).post(handler::sql_format),
1078            )
1079            .route(
1080                "/promql",
1081                routing::get(handler::promql).post(handler::promql),
1082            )
1083            .with_state(api_state)
1084    }
1085
1086    fn route_logs<S>(log_handler: LogQueryHandlerRef) -> Router<S> {
1087        Router::new()
1088            .route("/logs", routing::get(logs::logs).post(logs::logs))
1089            .with_state(log_handler)
1090    }
1091
1092    /// Route Prometheus [HTTP API].
1093    ///
1094    /// [HTTP API]: https://prometheus.io/docs/prometheus/latest/querying/api/
1095    pub fn route_prometheus<S>(prometheus_handler: PrometheusHandlerRef) -> Router<S> {
1096        Router::new()
1097            .route(
1098                "/format_query",
1099                routing::post(format_query).get(format_query),
1100            )
1101            .route("/status/buildinfo", routing::get(build_info_query))
1102            .route("/query", routing::post(instant_query).get(instant_query))
1103            .route("/query_range", routing::post(range_query).get(range_query))
1104            .route("/labels", routing::post(labels_query).get(labels_query))
1105            .route("/series", routing::post(series_query).get(series_query))
1106            .route("/parse_query", routing::post(parse_query).get(parse_query))
1107            .route(
1108                "/label/{label_name}/values",
1109                routing::get(label_values_query),
1110            )
1111            .layer(ServiceBuilder::new().layer(CompressionLayer::new()))
1112            .with_state(prometheus_handler)
1113    }
1114
1115    /// Route Prometheus remote [read] and [write] API. In other places the related modules are
1116    /// called `prom_store`.
1117    ///
1118    /// [read]: https://prometheus.io/docs/prometheus/latest/querying/remote_read_api/
1119    /// [write]: https://prometheus.io/docs/concepts/remote_write_spec/
1120    fn route_prom<S>(state: PromStoreState) -> Router<S> {
1121        Router::new()
1122            .route("/read", routing::post(prom_store::remote_read))
1123            .route("/write", routing::post(prom_store::remote_write))
1124            .with_state(state)
1125    }
1126
1127    fn route_influxdb<S>(influxdb_handler: InfluxdbLineProtocolHandlerRef) -> Router<S> {
1128        Router::new()
1129            .route("/write", routing::post(influxdb_write_v1))
1130            .route("/api/v2/write", routing::post(influxdb_write_v2))
1131            .layer(
1132                ServiceBuilder::new()
1133                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1134            )
1135            .route("/ping", routing::get(influxdb_ping))
1136            .route("/health", routing::get(influxdb_health))
1137            .with_state(influxdb_handler)
1138    }
1139
1140    fn route_opentsdb<S>(opentsdb_handler: OpentsdbProtocolHandlerRef) -> Router<S> {
1141        Router::new()
1142            .route("/api/put", routing::post(opentsdb::put))
1143            .with_state(opentsdb_handler)
1144    }
1145
1146    fn route_otlp<S>(
1147        otlp_handler: OpenTelemetryProtocolHandlerRef,
1148        with_metric_engine: bool,
1149    ) -> Router<S> {
1150        Router::new()
1151            .route("/v1/metrics", routing::post(otlp::metrics))
1152            .route("/v1/traces", routing::post(otlp::traces))
1153            .route("/v1/logs", routing::post(otlp::logs))
1154            .layer(
1155                ServiceBuilder::new()
1156                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1157            )
1158            .with_state(OtlpState {
1159                with_metric_engine,
1160                handler: otlp_handler,
1161            })
1162    }
1163
1164    fn route_config<S>(state: GreptimeOptionsConfigState) -> Router<S> {
1165        Router::new()
1166            .route("/config", routing::get(handler::config))
1167            .with_state(state)
1168    }
1169
1170    fn route_jaeger<S>(handler: JaegerQueryHandlerRef) -> Router<S> {
1171        Router::new()
1172            .route("/api/services", routing::get(jaeger::handle_get_services))
1173            .route(
1174                "/api/services/{service_name}/operations",
1175                routing::get(jaeger::handle_get_operations_by_service),
1176            )
1177            .route(
1178                "/api/operations",
1179                routing::get(jaeger::handle_get_operations),
1180            )
1181            .route("/api/traces", routing::get(jaeger::handle_find_traces))
1182            .route(
1183                "/api/traces/{trace_id}",
1184                routing::get(jaeger::handle_get_trace),
1185            )
1186            .with_state(handler)
1187    }
1188
1189    #[cfg(feature = "dashboard")]
1190    fn route_dashboard<S>(handler: DashboardHandlerRef) -> Router<S> {
1191        use crate::http::dashboard::{add_dashboard, delete_dashboard, list_dashboards};
1192
1193        Router::new()
1194            .route("/", routing::get(list_dashboards))
1195            .route("/{dashboard_name}", routing::post(add_dashboard))
1196            .route("/{dashboard_name}", routing::delete(delete_dashboard))
1197            .layer(
1198                ServiceBuilder::new()
1199                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1200            )
1201            .with_state(DashboardState { handler })
1202    }
1203
1204    #[cfg(not(feature = "dashboard"))]
1205    fn route_dashboard<S>(handler: DashboardHandlerRef) -> Router<S> {
1206        Router::new().with_state(DashboardState { handler })
1207    }
1208}
1209
/// Name reported by [`HttpServer`] through [`Server::name`].
pub const HTTP_SERVER: &str = "HTTP_SERVER";
1211
1212#[async_trait]
1213impl Server for HttpServer {
1214    async fn shutdown(&self) -> Result<()> {
1215        let mut shutdown_tx = self.shutdown_tx.lock().await;
1216        if let Some(tx) = shutdown_tx.take()
1217            && tx.send(()).is_err()
1218        {
1219            info!("Receiver dropped, the HTTP server has already exited");
1220        }
1221        info!("Shutdown HTTP server");
1222
1223        Ok(())
1224    }
1225
1226    async fn start(&mut self, listening: SocketAddr) -> Result<()> {
1227        let (tx, rx) = oneshot::channel();
1228        let serve = {
1229            let mut shutdown_tx = self.shutdown_tx.lock().await;
1230            ensure!(
1231                shutdown_tx.is_none(),
1232                AlreadyStartedSnafu { server: "HTTP" }
1233            );
1234
1235            let mut app = self.make_app();
1236            if let Some(configurator) = self.plugins.get::<HttpConfiguratorRef<()>>() {
1237                app = configurator
1238                    .configure_http(app, ())
1239                    .await
1240                    .context(OtherSnafu)?;
1241            }
1242            let app = self.build(app)?;
1243            let listener = tokio::net::TcpListener::bind(listening)
1244                .await
1245                .context(AddressBindSnafu { addr: listening })?
1246                .tap_io(|tcp_stream| {
1247                    if let Err(e) = tcp_stream.set_nodelay(true) {
1248                        error!(e; "Failed to set TCP_NODELAY on incoming connection");
1249                    }
1250                });
1251            let serve = axum::serve(
1252                listener,
1253                app.into_make_service_with_connect_info::<SocketAddr>(),
1254            );
1255
1256            // FIXME(yingwen): Support keepalive.
1257            // See:
1258            // - https://github.com/tokio-rs/axum/discussions/2939
1259            // - https://stackoverflow.com/questions/73069718/how-do-i-keep-alive-tokiotcpstream-in-rust
1260            // let server = axum::Server::try_bind(&listening)
1261            //     .with_context(|_| AddressBindSnafu { addr: listening })?
1262            //     .tcp_nodelay(true)
1263            //     // Enable TCP keepalive to close the dangling established connections.
1264            //     // It's configured to let the keepalive probes first send after the connection sits
1265            //     // idle for 59 minutes, and then send every 10 seconds for 6 times.
1266            //     // So the connection will be closed after roughly 1 hour.
1267            //     .tcp_keepalive(Some(Duration::from_secs(59 * 60)))
1268            //     .tcp_keepalive_interval(Some(Duration::from_secs(10)))
1269            //     .tcp_keepalive_retries(Some(6))
1270            //     .serve(app.into_make_service());
1271
1272            *shutdown_tx = Some(tx);
1273
1274            serve
1275        };
1276        let listening = serve.local_addr().context(InternalIoSnafu)?;
1277        info!("HTTP server is bound to {}", listening);
1278
1279        common_runtime::spawn_global(async move {
1280            if let Err(e) = serve
1281                .with_graceful_shutdown(rx.map(drop))
1282                .await
1283                .context(InternalIoSnafu)
1284            {
1285                error!(e; "Failed to shutdown http server");
1286            }
1287        });
1288
1289        self.bind_addr = Some(listening);
1290        Ok(())
1291    }
1292
1293    fn name(&self) -> &str {
1294        HTTP_SERVER
1295    }
1296
1297    fn bind_addr(&self) -> Option<SocketAddr> {
1298        self.bind_addr
1299    }
1300
1301    fn as_any(&self) -> &dyn std::any::Any {
1302        self
1303    }
1304}
1305
1306#[cfg(test)]
1307mod test {
1308    use std::future::pending;
1309    use std::io::Cursor;
1310    use std::sync::Arc;
1311
1312    use arrow_ipc::reader::FileReader;
1313    use arrow_schema::DataType;
1314    use axum::handler::Handler;
1315    use axum::http::StatusCode;
1316    use axum::routing::get;
1317    use common_query::Output;
1318    use common_recordbatch::RecordBatches;
1319    use datafusion_expr::LogicalPlan;
1320    use datatypes::prelude::*;
1321    use datatypes::schema::{ColumnSchema, Schema};
1322    use datatypes::vectors::{StringVector, UInt32Vector};
1323    use header::constants::GREPTIME_DB_HEADER_TIMEOUT;
1324    use query::parser::PromQuery;
1325    use query::query_engine::DescribeResult;
1326    use session::context::QueryContextRef;
1327    use sql::statements::statement::Statement;
1328    use tokio::sync::mpsc;
1329    use tokio::time::Instant;
1330
1331    use super::*;
1332    use crate::http::test_helpers::TestClient;
1333    use crate::prom_remote_write::validation::validate_label_name;
1334    use crate::query_handler::sql::SqlQueryHandler;
1335
1336    struct DummyInstance {
1337        _tx: mpsc::Sender<(String, Vec<u8>)>,
1338    }
1339
1340    #[async_trait]
1341    impl SqlQueryHandler for DummyInstance {
1342        async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec<Result<Output>> {
1343            unimplemented!()
1344        }
1345
1346        async fn do_promql_query(&self, _: &PromQuery, _: QueryContextRef) -> Vec<Result<Output>> {
1347            unimplemented!()
1348        }
1349
1350        async fn do_exec_plan(
1351            &self,
1352            _stmt: Option<Statement>,
1353            _plan: LogicalPlan,
1354            _query_ctx: QueryContextRef,
1355        ) -> Result<Output> {
1356            unimplemented!()
1357        }
1358
1359        async fn do_describe(
1360            &self,
1361            _stmt: sql::statements::statement::Statement,
1362            _query_ctx: QueryContextRef,
1363        ) -> Result<Option<DescribeResult>> {
1364            unimplemented!()
1365        }
1366
1367        async fn is_valid_schema(&self, _catalog: &str, _schema: &str) -> Result<bool> {
1368            Ok(true)
1369        }
1370    }
1371
1372    fn timeout() -> DynamicTimeoutLayer {
1373        DynamicTimeoutLayer::new(Duration::from_millis(10))
1374    }
1375
1376    async fn forever() {
1377        pending().await
1378    }
1379
1380    fn make_test_app(tx: mpsc::Sender<(String, Vec<u8>)>) -> Router {
1381        make_test_app_custom(tx, HttpOptions::default())
1382    }
1383
1384    fn make_test_app_custom(tx: mpsc::Sender<(String, Vec<u8>)>, options: HttpOptions) -> Router {
1385        let instance = Arc::new(DummyInstance { _tx: tx });
1386        let server = HttpServerBuilder::new(options)
1387            .with_sql_handler(instance.clone())
1388            .build();
1389        server.build(server.make_app()).unwrap().route(
1390            "/test/timeout",
1391            get(forever.layer(ServiceBuilder::new().layer(timeout()))),
1392        )
1393    }
1394
1395    #[tokio::test]
1396    pub async fn test_cors() {
1397        // cors is on by default
1398        let (tx, _rx) = mpsc::channel(100);
1399        let app = make_test_app(tx);
1400        let client = TestClient::new(app).await;
1401
1402        let res = client.get("/health").send().await;
1403
1404        assert_eq!(res.status(), StatusCode::OK);
1405        assert_eq!(
1406            res.headers()
1407                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1408                .expect("expect cors header origin"),
1409            "*"
1410        );
1411
1412        let res = client.get("/v1/health").send().await;
1413
1414        assert_eq!(res.status(), StatusCode::OK);
1415        assert_eq!(
1416            res.headers()
1417                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1418                .expect("expect cors header origin"),
1419            "*"
1420        );
1421
1422        let res = client
1423            .options("/health")
1424            .header("Access-Control-Request-Headers", "x-greptime-auth")
1425            .header("Access-Control-Request-Method", "DELETE")
1426            .header("Origin", "https://example.com")
1427            .send()
1428            .await;
1429        assert_eq!(res.status(), StatusCode::OK);
1430        assert_eq!(
1431            res.headers()
1432                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1433                .expect("expect cors header origin"),
1434            "*"
1435        );
1436        assert_eq!(
1437            res.headers()
1438                .get(http::header::ACCESS_CONTROL_ALLOW_HEADERS)
1439                .expect("expect cors header headers"),
1440            "*"
1441        );
1442        assert_eq!(
1443            res.headers()
1444                .get(http::header::ACCESS_CONTROL_ALLOW_METHODS)
1445                .expect("expect cors header methods"),
1446            "GET,POST,PUT,DELETE,HEAD"
1447        );
1448    }
1449
1450    #[tokio::test]
1451    pub async fn test_cors_custom_origins() {
1452        // cors is on by default
1453        let (tx, _rx) = mpsc::channel(100);
1454        let origin = "https://example.com";
1455
1456        let options = HttpOptions {
1457            cors_allowed_origins: vec![origin.to_string()],
1458            ..Default::default()
1459        };
1460
1461        let app = make_test_app_custom(tx, options);
1462        let client = TestClient::new(app).await;
1463
1464        let res = client.get("/health").header("Origin", origin).send().await;
1465
1466        assert_eq!(res.status(), StatusCode::OK);
1467        assert_eq!(
1468            res.headers()
1469                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1470                .expect("expect cors header origin"),
1471            origin
1472        );
1473
1474        let res = client
1475            .get("/health")
1476            .header("Origin", "https://notallowed.com")
1477            .send()
1478            .await;
1479
1480        assert_eq!(res.status(), StatusCode::OK);
1481        assert!(
1482            !res.headers()
1483                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1484        );
1485    }
1486
1487    #[tokio::test]
1488    pub async fn test_cors_disabled() {
1489        // cors is on by default
1490        let (tx, _rx) = mpsc::channel(100);
1491
1492        let options = HttpOptions {
1493            enable_cors: false,
1494            ..Default::default()
1495        };
1496
1497        let app = make_test_app_custom(tx, options);
1498        let client = TestClient::new(app).await;
1499
1500        let res = client.get("/health").send().await;
1501
1502        assert_eq!(res.status(), StatusCode::OK);
1503        assert!(
1504            !res.headers()
1505                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1506        );
1507    }
1508
1509    #[test]
1510    fn test_http_options_default() {
1511        let default = HttpOptions::default();
1512        assert_eq!("127.0.0.1:4000".to_string(), default.addr);
1513        assert_eq!(Duration::from_secs(0), default.timeout)
1514    }
1515
1516    #[tokio::test]
1517    async fn test_http_server_request_timeout() {
1518        common_telemetry::init_default_ut_logging();
1519
1520        let (tx, _rx) = mpsc::channel(100);
1521        let app = make_test_app(tx);
1522        let client = TestClient::new(app).await;
1523        let res = client.get("/test/timeout").send().await;
1524        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1525
1526        let now = Instant::now();
1527        let res = client
1528            .get("/test/timeout")
1529            .header(GREPTIME_DB_HEADER_TIMEOUT, "20ms")
1530            .send()
1531            .await;
1532        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1533        let elapsed = now.elapsed();
1534        assert!(elapsed > Duration::from_millis(15));
1535
1536        tokio::time::timeout(
1537            Duration::from_millis(15),
1538            client
1539                .get("/test/timeout")
1540                .header(GREPTIME_DB_HEADER_TIMEOUT, "0s")
1541                .send(),
1542        )
1543        .await
1544        .unwrap_err();
1545
1546        tokio::time::timeout(
1547            Duration::from_millis(15),
1548            client
1549                .get("/test/timeout")
1550                .header(
1551                    GREPTIME_DB_HEADER_TIMEOUT,
1552                    humantime::format_duration(Duration::default()).to_string(),
1553                )
1554                .send(),
1555        )
1556        .await
1557        .unwrap_err();
1558    }
1559
1560    #[tokio::test]
1561    async fn test_schema_for_empty_response() {
1562        let column_schemas = vec![
1563            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1564            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1565        ];
1566        let schema = Arc::new(Schema::new(column_schemas));
1567
1568        let recordbatches = RecordBatches::try_new(schema.clone(), vec![]).unwrap();
1569        let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1570
1571        let json_resp = GreptimedbV1Response::from_output(outputs).await;
1572        if let HttpResponse::GreptimedbV1(json_resp) = json_resp {
1573            let json_output = &json_resp.output[0];
1574            if let GreptimeQueryOutput::Records(r) = json_output {
1575                assert_eq!(r.num_rows(), 0);
1576                assert_eq!(r.num_cols(), 2);
1577                assert_eq!(r.schema.column_schemas[0].name, "numbers");
1578                assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1579            } else {
1580                panic!("invalid output type");
1581            }
1582        } else {
1583            panic!("invalid format")
1584        }
1585    }
1586
1587    #[tokio::test]
1588    async fn test_recordbatches_conversion() {
1589        let column_schemas = vec![
1590            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1591            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1592        ];
1593        let schema = Arc::new(Schema::new(column_schemas));
1594        let columns: Vec<VectorRef> = vec![
1595            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
1596            Arc::new(StringVector::from(vec![
1597                None,
1598                Some("hello"),
1599                Some("greptime"),
1600                None,
1601            ])),
1602        ];
1603        let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();
1604
1605        for format in [
1606            ResponseFormat::GreptimedbV1,
1607            ResponseFormat::InfluxdbV1,
1608            ResponseFormat::Csv(true, true),
1609            ResponseFormat::Table,
1610            ResponseFormat::Arrow,
1611            ResponseFormat::Json,
1612            ResponseFormat::Null,
1613        ] {
1614            let recordbatches =
1615                RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
1616            let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1617            let json_resp = match format {
1618                ResponseFormat::Arrow => ArrowResponse::from_output(outputs, None).await,
1619                ResponseFormat::Csv(with_names, with_types) => {
1620                    CsvResponse::from_output(outputs, with_names, with_types).await
1621                }
1622                ResponseFormat::Table => TableResponse::from_output(outputs).await,
1623                ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
1624                ResponseFormat::InfluxdbV1 => InfluxdbV1Response::from_output(outputs, None).await,
1625                ResponseFormat::Json => JsonResponse::from_output(outputs).await,
1626                ResponseFormat::Null => NullResponse::from_output(outputs).await,
1627            };
1628
1629            match json_resp {
1630                HttpResponse::GreptimedbV1(resp) => {
1631                    let json_output = &resp.output[0];
1632                    if let GreptimeQueryOutput::Records(r) = json_output {
1633                        assert_eq!(r.num_rows(), 4);
1634                        assert_eq!(r.num_cols(), 2);
1635                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1636                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1637                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1638                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1639                    } else {
1640                        panic!("invalid output type");
1641                    }
1642                }
1643                HttpResponse::InfluxdbV1(resp) => {
1644                    let json_output = &resp.results()[0];
1645                    assert_eq!(json_output.num_rows(), 4);
1646                    assert_eq!(json_output.num_cols(), 2);
1647                    assert_eq!(json_output.series[0].columns.clone()[0], "numbers");
1648                    assert_eq!(
1649                        json_output.series[0].values[0][0],
1650                        serde_json::Value::from(1)
1651                    );
1652                    assert_eq!(json_output.series[0].values[0][1], serde_json::Value::Null);
1653                }
1654                HttpResponse::Csv(resp) => {
1655                    let output = &resp.output()[0];
1656                    if let GreptimeQueryOutput::Records(r) = output {
1657                        assert_eq!(r.num_rows(), 4);
1658                        assert_eq!(r.num_cols(), 2);
1659                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1660                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1661                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1662                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1663                    } else {
1664                        panic!("invalid output type");
1665                    }
1666                }
1667
1668                HttpResponse::Table(resp) => {
1669                    let output = &resp.output()[0];
1670                    if let GreptimeQueryOutput::Records(r) = output {
1671                        assert_eq!(r.num_rows(), 4);
1672                        assert_eq!(r.num_cols(), 2);
1673                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1674                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1675                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1676                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1677                    } else {
1678                        panic!("invalid output type");
1679                    }
1680                }
1681
1682                HttpResponse::Arrow(resp) => {
1683                    let output = resp.data;
1684                    let mut reader =
1685                        FileReader::try_new(Cursor::new(output), None).expect("Arrow reader error");
1686                    let schema = reader.schema();
1687                    assert_eq!(schema.fields[0].name(), "numbers");
1688                    assert_eq!(schema.fields[0].data_type(), &DataType::UInt32);
1689                    assert_eq!(schema.fields[1].name(), "strings");
1690                    assert_eq!(schema.fields[1].data_type(), &DataType::Utf8);
1691
1692                    let rb = reader.next().unwrap().expect("read record batch failed");
1693                    assert_eq!(rb.num_columns(), 2);
1694                    assert_eq!(rb.num_rows(), 4);
1695                }
1696
1697                HttpResponse::Json(resp) => {
1698                    let output = &resp.output()[0];
1699                    if let GreptimeQueryOutput::Records(r) = output {
1700                        assert_eq!(r.num_rows(), 4);
1701                        assert_eq!(r.num_cols(), 2);
1702                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1703                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1704                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1705                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1706                    } else {
1707                        panic!("invalid output type");
1708                    }
1709                }
1710
1711                HttpResponse::Null(resp) => {
1712                    assert_eq!(resp.rows(), 4);
1713                }
1714
1715                HttpResponse::Error(err) => unreachable!("{err:?}"),
1716            }
1717        }
1718    }
1719
/// Exercises `ResponseFormat`: its default, name parsing (including rejection of
/// unknown or differently-cased names) and the canonical `as_str` spelling.
#[test]
fn test_response_format_misc() {
    assert_eq!(ResponseFormat::default(), ResponseFormat::GreptimedbV1);
    assert_eq!(ResponseFormat::default().as_str(), "greptimedb_v1");

    // Every recognised name paired with the variant it must parse into.
    let accepted = [
        ("arrow", ResponseFormat::Arrow),
        ("csv", ResponseFormat::Csv(false, false)),
        ("csvwithnames", ResponseFormat::Csv(true, false)),
        ("csvwithnamesandtypes", ResponseFormat::Csv(true, true)),
        ("table", ResponseFormat::Table),
        ("greptimedb_v1", ResponseFormat::GreptimedbV1),
        ("influxdb_v1", ResponseFormat::InfluxdbV1),
        ("json", ResponseFormat::Json),
        ("null", ResponseFormat::Null),
    ];
    for (name, expected) in accepted {
        assert_eq!(ResponseFormat::parse(name), Some(expected), "parsing {name:?}");
    }

    // Unknown and empty names are rejected; matching is case sensitive ("CSV").
    for name in ["invalid", "", "CSV"] {
        assert_eq!(ResponseFormat::parse(name), None, "parsing {name:?}");
    }

    // Canonical string per variant; the CSV flavours all report "csv".
    let canonical = [
        (ResponseFormat::Arrow, "arrow"),
        (ResponseFormat::Csv(false, false), "csv"),
        (ResponseFormat::Csv(true, true), "csv"),
        (ResponseFormat::Table, "table"),
        (ResponseFormat::GreptimedbV1, "greptimedb_v1"),
        (ResponseFormat::InfluxdbV1, "influxdb_v1"),
        (ResponseFormat::Json, "json"),
        (ResponseFormat::Null, "null"),
    ];
    for (format, expected) in canonical {
        assert_eq!(format.as_str(), expected);
    }
}
1764
/// Strict mode accepts well-formed Prometheus label names and rejects names
/// that start with a digit, contain special characters, are empty, are
/// non-ASCII, or are not valid UTF-8.
#[test]
fn test_decode_label_name_strict() {
    let strict = PromValidationMode::Strict;

    // Well-formed label names.
    let accepted: &[&[u8]] = &[
        b"__name__",
        b"job",
        b"instance",
        b"_private",
        b"label_with_underscores",
        b"abc123",
        b"A",
        b"_",
    ];
    for &name in accepted {
        assert!(
            strict.decode_label_name(name).is_ok(),
            "expected {:?} to be accepted",
            String::from_utf8_lossy(name)
        );
    }

    let rejected: &[&[u8]] = &[
        // Starts with a digit.
        b"0abc",
        b"123",
        // Contains special characters.
        b"label-name",
        b"label.name",
        b"label name",
        b"label/name",
        // Empty.
        b"",
        // Valid UTF-8, but not ASCII.
        "ラベル".as_bytes(),
        // Not valid UTF-8 at all.
        &[0xff, 0xfe],
    ];
    for &name in rejected {
        assert!(
            strict.decode_label_name(name).is_err(),
            "expected {:?} to be rejected",
            String::from_utf8_lossy(name)
        );
    }
}
1798
/// Lossy mode does not relax label-name validation: a well-formed name is
/// accepted while malformed names are still rejected.
#[test]
fn test_decode_label_name_lossy() {
    let lossy = PromValidationMode::Lossy;

    assert!(lossy.decode_label_name(b"__name__").is_ok());

    // The final entry is invalid UTF-8, which fails the label-name byte check.
    let rejected: [&[u8]; 3] = [b"label-name", b"0abc", &[0xff, 0xfe]];
    for name in rejected {
        assert!(
            lossy.decode_label_name(name).is_err(),
            "expected {:?} to be rejected",
            String::from_utf8_lossy(name)
        );
    }
}
1811
/// Checks that `Unchecked` mode still enforces label-name validation, including
/// rejecting byte sequences that are not valid UTF-8.
#[test]
fn test_decode_label_name_unchecked() {
    let unchecked = PromValidationMode::Unchecked;

    // Label name validation is always enforced.
    assert!(unchecked.decode_label_name(b"__name__").is_ok());
    assert!(unchecked.decode_label_name(b"label-name").is_err());
    assert!(unchecked.decode_label_name(b"0abc").is_err());

    // Invalid UTF-8 bytes fail the label-name byte check, so skipping UTF-8
    // checks must not let them through in this mode either.
    assert!(unchecked.decode_label_name(&[0xff, 0xfe]).is_err());
}
1821
/// `validate_label_name` accepts names made of letters, digits and underscores
/// that do not start with a digit. It rejects empty names, names containing
/// dash, dot or space, and non-UTF-8 bytes.
#[test]
fn test_is_valid_prom_label_name_bytes() {
    let valid: [&[u8]; 6] = [
        b"__name__",
        b"job",
        b"_",
        b"A",
        b"abc123",
        b"_leading_underscore",
    ];
    for name in valid {
        assert!(
            validate_label_name(name),
            "expected {:?} to be valid",
            String::from_utf8_lossy(name)
        );
    }

    let invalid: [&[u8]; 6] = [
        b"",
        b"0starts_with_digit",
        b"has-dash",
        b"has.dot",
        b"has space",
        &[0xff, 0xfe],
    ];
    for name in invalid {
        assert!(
            !validate_label_name(name),
            "expected {:?} to be invalid",
            String::from_utf8_lossy(name)
        );
    }
}
1838}