Skip to main content

servers/
http.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashMap;
16use std::convert::Infallible;
17use std::fmt::Display;
18use std::net::SocketAddr;
19use std::sync::{Arc, Mutex as StdMutex};
20use std::time::Duration;
21
22use async_trait::async_trait;
23use auth::UserProviderRef;
24use axum::extract::{DefaultBodyLimit, Request};
25use axum::http::StatusCode as HttpStatusCode;
26use axum::response::{IntoResponse, Response};
27use axum::routing::Route;
28use axum::serve::ListenerExt;
29use axum::{Router, middleware, routing};
30use common_base::readable_size::ReadableSize;
31use common_recordbatch::RecordBatch;
32use common_telemetry::{error, info};
33use common_time::Timestamp;
34use common_time::timestamp::TimeUnit;
35use datatypes::data_type::DataType;
36use datatypes::schema::SchemaRef;
37use event::{LogState, LogValidatorRef};
38use futures::FutureExt;
39use http::{HeaderValue, Method};
40use serde::{Deserialize, Serialize};
41use serde_json::Value;
42use snafu::{ResultExt, ensure};
43use tokio::sync::Mutex;
44use tokio::sync::oneshot::{self, Sender};
45use tonic::codegen::Service;
46use tower::{Layer, ServiceBuilder};
47use tower_http::compression::CompressionLayer;
48use tower_http::cors::{AllowOrigin, Any, CorsLayer};
49use tower_http::decompression::RequestDecompressionLayer;
50use tower_http::trace::TraceLayer;
51
52use self::authorize::AuthState;
53use self::result::table_result::TableResponse;
54use crate::elasticsearch;
55use crate::error::{
56    AddressBindSnafu, AlreadyStartedSnafu, Error, InternalIoSnafu, InvalidHeaderValueSnafu, Result,
57};
58use crate::http::influxdb::{influxdb_health, influxdb_ping, influxdb_write_v1, influxdb_write_v2};
59use crate::http::otlp::OtlpState;
60use crate::http::prom_store::PromStoreState;
61use crate::http::prometheus::{
62    build_info_query, format_query, instant_query, label_values_query, labels_query, parse_query,
63    range_query, series_query,
64};
65use crate::http::result::arrow_result::ArrowResponse;
66use crate::http::result::csv_result::CsvResponse;
67use crate::http::result::error_result::ErrorResponse;
68use crate::http::result::greptime_result_v1::GreptimedbV1Response;
69use crate::http::result::influxdb_result_v1::InfluxdbV1Response;
70use crate::http::result::json_result::JsonResponse;
71use crate::http::result::null_result::NullResponse;
72use crate::interceptor::LogIngestInterceptorRef;
73use crate::metrics::http_metrics_layer;
74use crate::metrics_handler::MetricsHandler;
75use crate::pending_rows_batcher::PendingRowsBatcher;
76use crate::prometheus_handler::PrometheusHandlerRef;
77use crate::query_handler::sql::ServerSqlQueryHandlerRef;
78use crate::query_handler::{
79    DashboardHandlerRef, InfluxdbLineProtocolHandlerRef, JaegerQueryHandlerRef, LogQueryHandlerRef,
80    OpenTelemetryProtocolHandlerRef, OpentsdbProtocolHandlerRef, PipelineHandlerRef,
81    PromStoreProtocolHandlerRef,
82};
83use crate::request_memory_limiter::ServerMemoryLimiter;
84use crate::server::Server;
85
86pub mod authorize;
87#[cfg(feature = "dashboard")]
88mod dashboard;
89pub mod dyn_log;
90pub mod dyn_trace;
91pub mod event;
92pub mod extractor;
93pub mod handler;
94pub mod header;
95pub mod influxdb;
96pub mod jaeger;
97pub mod logs;
98pub mod loki;
99pub mod mem_prof;
100mod memory_limit;
101pub mod opentsdb;
102pub mod otlp;
103pub mod pprof;
104pub mod prom_store;
105pub mod prometheus;
106pub mod result;
107mod timeout;
108pub mod utils;
109
110use result::HttpOutputWriter;
111pub(crate) use timeout::DynamicTimeoutLayer;
112
113mod client_ip;
114use crate::prom_remote_write::validation::PromValidationMode;
115mod hints;
116mod read_preference;
117#[cfg(any(test, feature = "testing"))]
118pub mod test_helpers;
119
/// Version segment of the HTTP API; routes are nested under `/{HTTP_API_VERSION}`.
pub const HTTP_API_VERSION: &str = "v1";
/// Path prefix of the versioned HTTP API, including the trailing slash.
pub const HTTP_API_PREFIX: &str = "/v1/";
/// Default http body limit (64M).
const DEFAULT_BODY_LIMIT: ReadableSize = ReadableSize::mb(64);

/// Authorization header
// NOTE(review): presumably the name of a GreptimeDB-specific auth header read by the
// `authorize` module; confirm there.
pub const AUTHORIZATION_HEADER: &str = "x-greptime-auth";

// TODO(fys): This is a temporary workaround, it will be improved later
// NOTE(review): presumably the paths that are reachable without authentication; the
// consumer of this list is not visible here, confirm in the `authorize` module.
pub static PUBLIC_API_PREFIX: [&str; 3] =
    ["/v1/influxdb/ping", "/v1/influxdb/health", "/v1/health"];
131
/// HTTP server state: the routes to serve plus the settings used to build the middleware
/// stack around them.
///
/// Instances are created by [`HttpServerBuilder::build`].
#[derive(Default)]
pub struct HttpServer {
    /// Router holding the registered protocol routes; `make_app` clones it under the lock.
    router: StdMutex<Router>,
    /// Sender half of a oneshot channel, `None` right after construction.
    // NOTE(review): presumably used to signal shutdown of a running server; the code that
    // sets and consumes it is not visible here.
    shutdown_tx: Mutex<Option<Sender<()>>>,
    /// Optional user provider handed to the auth middleware state in `build`.
    user_provider: Option<UserProviderRef>,
    /// Limiter handed to the memory-limit middleware in `build`.
    memory_limiter: ServerMemoryLimiter,

    // server configs
    /// Options (timeout, body limit, CORS, dashboard) consulted by `make_app` and `build`.
    options: HttpOptions,
    /// `None` right after construction.
    // NOTE(review): presumably filled in once the server is bound to a socket; confirm.
    bind_addr: Option<SocketAddr>,
}
143
144#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
145#[serde(default)]
146pub struct HttpOptions {
147    pub addr: String,
148
149    #[serde(with = "humantime_serde")]
150    pub timeout: Duration,
151
152    #[serde(skip)]
153    pub disable_dashboard: bool,
154
155    pub body_limit: ReadableSize,
156
157    /// Validation mode while decoding Prometheus remote write requests.
158    pub prom_validation_mode: PromValidationMode,
159
160    pub cors_allowed_origins: Vec<String>,
161
162    pub enable_cors: bool,
163}
164
165impl Default for HttpOptions {
166    fn default() -> Self {
167        Self {
168            addr: "127.0.0.1:4000".to_string(),
169            timeout: Duration::from_secs(0),
170            disable_dashboard: false,
171            body_limit: DEFAULT_BODY_LIMIT,
172            cors_allowed_origins: Vec::new(),
173            enable_cors: true,
174            prom_validation_mode: PromValidationMode::Strict,
175        }
176    }
177}
178
179#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
180pub struct ColumnSchema {
181    name: String,
182    data_type: String,
183}
184
185impl ColumnSchema {
186    pub fn new(name: String, data_type: String) -> ColumnSchema {
187        ColumnSchema { name, data_type }
188    }
189}
190
191#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
192pub struct OutputSchema {
193    column_schemas: Vec<ColumnSchema>,
194}
195
196impl OutputSchema {
197    pub fn new(columns: Vec<ColumnSchema>) -> OutputSchema {
198        OutputSchema {
199            column_schemas: columns,
200        }
201    }
202}
203
204impl From<SchemaRef> for OutputSchema {
205    fn from(schema: SchemaRef) -> OutputSchema {
206        OutputSchema {
207            column_schemas: schema
208                .column_schemas()
209                .iter()
210                .map(|cs| ColumnSchema {
211                    name: cs.name.clone(),
212                    data_type: cs.data_type.name(),
213                })
214                .collect(),
215        }
216    }
217}
218
219#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
220pub struct HttpRecordsOutput {
221    schema: OutputSchema,
222    rows: Vec<Vec<Value>>,
223    // total_rows is equal to rows.len() in most cases,
224    // the Dashboard query result may be truncated, so we need to return the total_rows.
225    #[serde(default)]
226    total_rows: usize,
227
228    // plan level execution metrics
229    #[serde(skip_serializing_if = "HashMap::is_empty")]
230    #[serde(default)]
231    metrics: HashMap<String, Value>,
232}
233
234impl HttpRecordsOutput {
235    pub fn num_rows(&self) -> usize {
236        self.rows.len()
237    }
238
239    pub fn num_cols(&self) -> usize {
240        self.schema.column_schemas.len()
241    }
242
243    pub fn schema(&self) -> &OutputSchema {
244        &self.schema
245    }
246
247    pub fn rows(&self) -> &Vec<Vec<Value>> {
248        &self.rows
249    }
250}
251
252impl HttpRecordsOutput {
253    pub fn try_new(
254        schema: SchemaRef,
255        recordbatches: Vec<RecordBatch>,
256    ) -> std::result::Result<HttpRecordsOutput, Error> {
257        if recordbatches.is_empty() {
258            Ok(HttpRecordsOutput {
259                schema: OutputSchema::from(schema),
260                rows: vec![],
261                total_rows: 0,
262                metrics: Default::default(),
263            })
264        } else {
265            let num_rows = recordbatches.iter().map(|r| r.num_rows()).sum::<usize>();
266            let mut rows = Vec::with_capacity(num_rows);
267
268            for recordbatch in recordbatches {
269                let mut writer = HttpOutputWriter::new(schema.num_columns(), None);
270                writer.write(recordbatch, &mut rows)?;
271            }
272
273            Ok(HttpRecordsOutput {
274                schema: OutputSchema::from(schema),
275                total_rows: rows.len(),
276                rows,
277                metrics: Default::default(),
278            })
279        }
280    }
281}
282
/// Output of a query as carried in HTTP responses: either a row count or a record set.
///
/// With `rename_all = "lowercase"` serde uses the lowercased variant names
/// (`affectedrows`, `records`) when (de)serializing.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum GreptimeQueryOutput {
    /// A row count.
    // NOTE(review): presumably the number of rows affected by a write statement; confirm.
    AffectedRows(usize),
    /// Rows together with their schema.
    Records(HttpRecordsOutput),
}
289
/// It allows the results of SQL queries to be presented in different formats.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResponseFormat {
    Arrow,
    // (with_names, with_types)
    Csv(bool, bool),
    Table,
    #[default]
    GreptimedbV1,
    InfluxdbV1,
    Json,
    Null,
}

impl ResponseFormat {
    /// Parses a format name.
    ///
    /// Matching is exact and case-sensitive; unknown names yield `None`. The three CSV
    /// names select which of the `(with_names, with_types)` flags are set.
    pub fn parse(s: &str) -> Option<Self> {
        let format = match s {
            "arrow" => Self::Arrow,
            "csv" => Self::Csv(false, false),
            "csvwithnames" => Self::Csv(true, false),
            "csvwithnamesandtypes" => Self::Csv(true, true),
            "table" => Self::Table,
            "greptimedb_v1" => Self::GreptimedbV1,
            "influxdb_v1" => Self::InfluxdbV1,
            "json" => Self::Json,
            "null" => Self::Null,
            _ => return None,
        };
        Some(format)
    }

    /// Returns the canonical name of the format.
    ///
    /// Every CSV variant maps to `"csv"`, so the header flags do not survive a round trip
    /// through [`ResponseFormat::parse`].
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Arrow => "arrow",
            Self::Csv(..) => "csv",
            Self::Table => "table",
            Self::GreptimedbV1 => "greptimedb_v1",
            Self::InfluxdbV1 => "influxdb_v1",
            Self::Json => "json",
            Self::Null => "null",
        }
    }
}
332
333#[derive(Debug, Clone, Copy, PartialEq, Eq)]
334pub enum Epoch {
335    Nanosecond,
336    Microsecond,
337    Millisecond,
338    Second,
339}
340
341impl Epoch {
342    pub fn parse(s: &str) -> Option<Epoch> {
343        // Both u and µ indicate microseconds.
344        // epoch = [ns,u,µ,ms,s],
345        // For details, see the Influxdb documents.
346        // https://docs.influxdata.com/influxdb/v1/tools/api/#query-string-parameters-1
347        match s {
348            "ns" => Some(Epoch::Nanosecond),
349            "u" | "µ" => Some(Epoch::Microsecond),
350            "ms" => Some(Epoch::Millisecond),
351            "s" => Some(Epoch::Second),
352            _ => None, // just returns None for other cases
353        }
354    }
355
356    pub fn convert_timestamp(&self, ts: Timestamp) -> Option<Timestamp> {
357        match self {
358            Epoch::Nanosecond => ts.convert_to(TimeUnit::Nanosecond),
359            Epoch::Microsecond => ts.convert_to(TimeUnit::Microsecond),
360            Epoch::Millisecond => ts.convert_to(TimeUnit::Millisecond),
361            Epoch::Second => ts.convert_to(TimeUnit::Second),
362        }
363    }
364}
365
366impl Display for Epoch {
367    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
368        match self {
369            Epoch::Nanosecond => write!(f, "Epoch::Nanosecond"),
370            Epoch::Microsecond => write!(f, "Epoch::Microsecond"),
371            Epoch::Millisecond => write!(f, "Epoch::Millisecond"),
372            Epoch::Second => write!(f, "Epoch::Second"),
373        }
374    }
375}
376
377#[derive(Serialize, Deserialize, Debug)]
378pub enum HttpResponse {
379    Arrow(ArrowResponse),
380    Csv(CsvResponse),
381    Table(TableResponse),
382    Error(ErrorResponse),
383    GreptimedbV1(GreptimedbV1Response),
384    InfluxdbV1(InfluxdbV1Response),
385    Json(JsonResponse),
386    Null(NullResponse),
387}
388
389impl HttpResponse {
390    pub fn with_execution_time(self, execution_time: u64) -> Self {
391        match self {
392            HttpResponse::Arrow(resp) => resp.with_execution_time(execution_time).into(),
393            HttpResponse::Csv(resp) => resp.with_execution_time(execution_time).into(),
394            HttpResponse::Table(resp) => resp.with_execution_time(execution_time).into(),
395            HttpResponse::GreptimedbV1(resp) => resp.with_execution_time(execution_time).into(),
396            HttpResponse::InfluxdbV1(resp) => resp.with_execution_time(execution_time).into(),
397            HttpResponse::Json(resp) => resp.with_execution_time(execution_time).into(),
398            HttpResponse::Null(resp) => resp.with_execution_time(execution_time).into(),
399            HttpResponse::Error(resp) => resp.with_execution_time(execution_time).into(),
400        }
401    }
402
403    pub fn with_limit(self, limit: usize) -> Self {
404        match self {
405            HttpResponse::Csv(resp) => resp.with_limit(limit).into(),
406            HttpResponse::Table(resp) => resp.with_limit(limit).into(),
407            HttpResponse::GreptimedbV1(resp) => resp.with_limit(limit).into(),
408            HttpResponse::Json(resp) => resp.with_limit(limit).into(),
409            _ => self,
410        }
411    }
412}
413
414pub fn process_with_limit(
415    mut outputs: Vec<GreptimeQueryOutput>,
416    limit: usize,
417) -> Vec<GreptimeQueryOutput> {
418    outputs
419        .drain(..)
420        .map(|data| match data {
421            GreptimeQueryOutput::Records(mut records) => {
422                if records.rows.len() > limit {
423                    records.rows.truncate(limit);
424                    records.total_rows = limit;
425                }
426                GreptimeQueryOutput::Records(records)
427            }
428            _ => data,
429        })
430        .collect()
431}
432
433impl IntoResponse for HttpResponse {
434    fn into_response(self) -> Response {
435        match self {
436            HttpResponse::Arrow(resp) => resp.into_response(),
437            HttpResponse::Csv(resp) => resp.into_response(),
438            HttpResponse::Table(resp) => resp.into_response(),
439            HttpResponse::GreptimedbV1(resp) => resp.into_response(),
440            HttpResponse::InfluxdbV1(resp) => resp.into_response(),
441            HttpResponse::Json(resp) => resp.into_response(),
442            HttpResponse::Null(resp) => resp.into_response(),
443            HttpResponse::Error(resp) => resp.into_response(),
444        }
445    }
446}
447
448impl From<ArrowResponse> for HttpResponse {
449    fn from(value: ArrowResponse) -> Self {
450        HttpResponse::Arrow(value)
451    }
452}
453
454impl From<CsvResponse> for HttpResponse {
455    fn from(value: CsvResponse) -> Self {
456        HttpResponse::Csv(value)
457    }
458}
459
460impl From<TableResponse> for HttpResponse {
461    fn from(value: TableResponse) -> Self {
462        HttpResponse::Table(value)
463    }
464}
465
466impl From<ErrorResponse> for HttpResponse {
467    fn from(value: ErrorResponse) -> Self {
468        HttpResponse::Error(value)
469    }
470}
471
472impl From<GreptimedbV1Response> for HttpResponse {
473    fn from(value: GreptimedbV1Response) -> Self {
474        HttpResponse::GreptimedbV1(value)
475    }
476}
477
478impl From<InfluxdbV1Response> for HttpResponse {
479    fn from(value: InfluxdbV1Response) -> Self {
480        HttpResponse::InfluxdbV1(value)
481    }
482}
483
484impl From<JsonResponse> for HttpResponse {
485    fn from(value: JsonResponse) -> Self {
486        HttpResponse::Json(value)
487    }
488}
489
490impl From<NullResponse> for HttpResponse {
491    fn from(value: NullResponse) -> Self {
492        HttpResponse::Null(value)
493    }
494}
495
/// Router state carrying the SQL query handler; passed to `HttpServer::route_sql`.
#[derive(Clone)]
pub struct ApiState {
    pub sql_handler: ServerSqlQueryHandlerRef,
}

/// Router state carrying the configuration options as one string; passed to
/// `HttpServer::route_config`.
// NOTE(review): presumably the rendered (e.g. TOML) server configuration; confirm at the
// call site of `with_greptime_config_options`.
#[derive(Clone)]
pub struct GreptimeOptionsConfigState {
    pub greptime_config_options: String,
}

/// Router state carrying the dashboard handler.
#[derive(Clone)]
pub struct DashboardState {
    pub handler: DashboardHandlerRef,
}
510
511pub struct HttpServerBuilder {
512    options: HttpOptions,
513    user_provider: Option<UserProviderRef>,
514    router: Router,
515    memory_limiter: ServerMemoryLimiter,
516}
517
518impl HttpServerBuilder {
519    pub fn new(options: HttpOptions) -> Self {
520        Self {
521            options,
522            user_provider: None,
523            router: Router::new(),
524            memory_limiter: ServerMemoryLimiter::default(),
525        }
526    }
527
528    /// Set a global memory limiter for all server protocols.
529    pub fn with_memory_limiter(mut self, limiter: ServerMemoryLimiter) -> Self {
530        self.memory_limiter = limiter;
531        self
532    }
533
534    pub fn with_sql_handler(self, sql_handler: ServerSqlQueryHandlerRef) -> Self {
535        let sql_router = HttpServer::route_sql(ApiState { sql_handler });
536
537        Self {
538            router: self
539                .router
540                .nest(&format!("/{HTTP_API_VERSION}"), sql_router),
541            ..self
542        }
543    }
544
545    pub fn with_logs_handler(self, logs_handler: LogQueryHandlerRef) -> Self {
546        let logs_router = HttpServer::route_logs(logs_handler);
547
548        Self {
549            router: self
550                .router
551                .nest(&format!("/{HTTP_API_VERSION}"), logs_router),
552            ..self
553        }
554    }
555
556    pub fn with_opentsdb_handler(self, handler: OpentsdbProtocolHandlerRef) -> Self {
557        Self {
558            router: self.router.nest(
559                &format!("/{HTTP_API_VERSION}/opentsdb"),
560                HttpServer::route_opentsdb(handler),
561            ),
562            ..self
563        }
564    }
565
566    pub fn with_influxdb_handler(self, handler: InfluxdbLineProtocolHandlerRef) -> Self {
567        Self {
568            router: self.router.nest(
569                &format!("/{HTTP_API_VERSION}/influxdb"),
570                HttpServer::route_influxdb(handler),
571            ),
572            ..self
573        }
574    }
575
576    pub fn with_prom_handler(
577        self,
578        handler: PromStoreProtocolHandlerRef,
579        pipeline_handler: Option<PipelineHandlerRef>,
580        prom_store_with_metric_engine: bool,
581        prom_validation_mode: PromValidationMode,
582        pending_rows_batcher: Option<Arc<PendingRowsBatcher>>,
583    ) -> Self {
584        let state = PromStoreState {
585            prom_store_handler: handler,
586            pipeline_handler,
587            prom_store_with_metric_engine,
588            prom_validation_mode,
589            pending_rows_batcher,
590        };
591
592        Self {
593            router: self.router.nest(
594                &format!("/{HTTP_API_VERSION}/prometheus"),
595                HttpServer::route_prom(state),
596            ),
597            ..self
598        }
599    }
600
601    pub fn with_prometheus_handler(self, handler: PrometheusHandlerRef) -> Self {
602        Self {
603            router: self.router.nest(
604                &format!("/{HTTP_API_VERSION}/prometheus/api/v1"),
605                HttpServer::route_prometheus(handler),
606            ),
607            ..self
608        }
609    }
610
611    pub fn with_otlp_handler(
612        self,
613        handler: OpenTelemetryProtocolHandlerRef,
614        with_metric_engine: bool,
615    ) -> Self {
616        Self {
617            router: self.router.nest(
618                &format!("/{HTTP_API_VERSION}/otlp"),
619                HttpServer::route_otlp(handler, with_metric_engine),
620            ),
621            ..self
622        }
623    }
624
625    pub fn with_user_provider(self, user_provider: UserProviderRef) -> Self {
626        Self {
627            user_provider: Some(user_provider),
628            ..self
629        }
630    }
631
632    pub fn with_metrics_handler(self, handler: MetricsHandler) -> Self {
633        Self {
634            router: self.router.merge(HttpServer::route_metrics(handler)),
635            ..self
636        }
637    }
638
639    pub fn with_log_ingest_handler(
640        self,
641        handler: PipelineHandlerRef,
642        validator: Option<LogValidatorRef>,
643        ingest_interceptor: Option<LogIngestInterceptorRef<Error>>,
644    ) -> Self {
645        let log_state = LogState {
646            log_handler: handler,
647            log_validator: validator,
648            ingest_interceptor,
649        };
650
651        let router = self.router.nest(
652            &format!("/{HTTP_API_VERSION}"),
653            HttpServer::route_pipelines(log_state.clone()),
654        );
655        // deprecated since v0.11.0. Use `/logs` and `/pipelines` instead.
656        let router = router.nest(
657            &format!("/{HTTP_API_VERSION}/events"),
658            #[allow(deprecated)]
659            HttpServer::route_log_deprecated(log_state.clone()),
660        );
661
662        let router = router.nest(
663            &format!("/{HTTP_API_VERSION}/loki"),
664            HttpServer::route_loki(log_state.clone()),
665        );
666
667        let router = router.nest(
668            &format!("/{HTTP_API_VERSION}/elasticsearch"),
669            HttpServer::route_elasticsearch(log_state.clone()),
670        );
671
672        let router = router.nest(
673            &format!("/{HTTP_API_VERSION}/elasticsearch/"),
674            Router::new()
675                .route("/", routing::get(elasticsearch::handle_get_version))
676                .with_state(log_state),
677        );
678
679        Self { router, ..self }
680    }
681
682    pub fn with_greptime_config_options(self, opts: String) -> Self {
683        let config_router = HttpServer::route_config(GreptimeOptionsConfigState {
684            greptime_config_options: opts,
685        });
686
687        Self {
688            router: self.router.merge(config_router),
689            ..self
690        }
691    }
692
693    pub fn with_jaeger_handler(self, handler: JaegerQueryHandlerRef) -> Self {
694        Self {
695            router: self.router.nest(
696                &format!("/{HTTP_API_VERSION}/jaeger"),
697                HttpServer::route_jaeger(handler),
698            ),
699            ..self
700        }
701    }
702
703    pub fn with_dashboard_handler(self, handler: DashboardHandlerRef) -> Self {
704        Self {
705            router: self.router.nest(
706                &format!("/{HTTP_API_VERSION}/dashboards"),
707                HttpServer::route_dashboard(handler),
708            ),
709            ..self
710        }
711    }
712
713    pub fn with_extra_router(self, router: Router) -> Self {
714        Self {
715            router: self.router.merge(router),
716            ..self
717        }
718    }
719
720    pub fn add_layer<L>(self, layer: L) -> Self
721    where
722        L: Layer<Route> + Clone + Send + Sync + 'static,
723        L::Service: Service<Request> + Clone + Send + Sync + 'static,
724        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
725        <L::Service as Service<Request>>::Error: Into<Infallible> + 'static,
726        <L::Service as Service<Request>>::Future: Send + 'static,
727    {
728        Self {
729            router: self.router.layer(layer),
730            ..self
731        }
732    }
733
734    pub fn build(self) -> HttpServer {
735        HttpServer {
736            options: self.options,
737            user_provider: self.user_provider,
738            shutdown_tx: Mutex::new(None),
739            router: StdMutex::new(self.router),
740            bind_addr: None,
741            memory_limiter: self.memory_limiter,
742        }
743    }
744}
745
746impl HttpServer {
747    /// Gets the router and adds necessary root routes (health, status, dashboard).
748    pub fn make_app(&self) -> Router {
749        let mut router = {
750            let router = self.router.lock().unwrap();
751            router.clone()
752        };
753
754        router = router
755            .route("/", routing::get(handler::index))
756            .route(
757                "/health",
758                routing::get(handler::health).post(handler::health),
759            )
760            .route(
761                &format!("/{HTTP_API_VERSION}/health"),
762                routing::get(handler::health).post(handler::health),
763            )
764            .route(
765                "/ready",
766                routing::get(handler::health).post(handler::health),
767            );
768
769        router = router.route("/status", routing::get(handler::status));
770
771        #[cfg(feature = "dashboard")]
772        {
773            if !self.options.disable_dashboard {
774                info!("Enable dashboard service at '/dashboard'");
775                // redirect /dashboard to /dashboard/
776                router = router.route(
777                    "/dashboard",
778                    routing::get(|uri: axum::http::uri::Uri| async move {
779                        let path = uri.path();
780                        let query = uri.query().map(|q| format!("?{}", q)).unwrap_or_default();
781
782                        let new_uri = format!("{}/{}", path, query);
783                        axum::response::Redirect::permanent(&new_uri)
784                    }),
785                );
786
787                // "/dashboard" and "/dashboard/" are two different paths in Axum.
788                // We cannot nest "/dashboard/", because we already mapping "/dashboard/{*x}" while nesting "/dashboard".
789                // So we explicitly route "/dashboard/" here.
790                router = router
791                    .route(
792                        "/dashboard/",
793                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
794                    )
795                    .route(
796                        "/dashboard/{*x}",
797                        routing::get(dashboard::static_handler).post(dashboard::static_handler),
798                    );
799            }
800        }
801
802        // Add a layer to collect HTTP metrics for axum.
803        router = router.route_layer(middleware::from_fn(http_metrics_layer));
804
805        router
806    }
807
808    /// Attaches middlewares and debug routes to the router.
809    /// Callers should call this method after [HttpServer::make_app()].
810    pub fn build(&self, router: Router) -> Result<Router> {
811        let timeout_layer = if self.options.timeout != Duration::default() {
812            Some(ServiceBuilder::new().layer(DynamicTimeoutLayer::new(self.options.timeout)))
813        } else {
814            info!("HTTP server timeout is disabled");
815            None
816        };
817        let body_limit_layer = if self.options.body_limit != ReadableSize(0) {
818            Some(
819                ServiceBuilder::new()
820                    .layer(DefaultBodyLimit::max(self.options.body_limit.0 as usize)),
821            )
822        } else {
823            info!("HTTP server body limit is disabled");
824            None
825        };
826        let cors_layer = if self.options.enable_cors {
827            Some(
828                CorsLayer::new()
829                    .allow_methods([
830                        Method::GET,
831                        Method::POST,
832                        Method::PUT,
833                        Method::DELETE,
834                        Method::HEAD,
835                    ])
836                    .allow_origin(if self.options.cors_allowed_origins.is_empty() {
837                        AllowOrigin::from(Any)
838                    } else {
839                        AllowOrigin::from(
840                            self.options
841                                .cors_allowed_origins
842                                .iter()
843                                .map(|s| {
844                                    HeaderValue::from_str(s.as_str())
845                                        .context(InvalidHeaderValueSnafu)
846                                })
847                                .collect::<Result<Vec<HeaderValue>>>()?,
848                        )
849                    })
850                    .allow_headers(Any),
851            )
852        } else {
853            info!("HTTP server cross-origin is disabled");
854            None
855        };
856
857        Ok(router
858            // middlewares
859            .layer(
860                ServiceBuilder::new()
861                    // disable on failure tracing. because printing out isn't very helpful,
862                    // and we have impl IntoResponse for Error. It will print out more detailed error messages
863                    .layer(TraceLayer::new_for_http().on_failure(()))
864                    .option_layer(cors_layer)
865                    .option_layer(timeout_layer)
866                    .option_layer(body_limit_layer)
867                    // memory limit layer - must be before body is consumed
868                    .layer(middleware::from_fn_with_state(
869                        self.memory_limiter.clone(),
870                        memory_limit::memory_limit_middleware,
871                    ))
872                    // auth layer
873                    .layer(middleware::from_fn_with_state(
874                        AuthState::new(self.user_provider.clone()),
875                        authorize::check_http_auth,
876                    ))
877                    .layer(middleware::from_fn(hints::extract_hints))
878                    .layer(middleware::from_fn(client_ip::log_error_with_client_ip))
879                    .layer(middleware::from_fn(
880                        read_preference::extract_read_preference,
881                    )),
882            )
883            // Handlers for debug, we don't expect a timeout.
884            .nest(
885                "/debug",
886                Router::new()
887                    // handler for changing log level dynamically
888                    .route("/log_level", routing::post(dyn_log::dyn_log_handler))
889                    .route("/enable_trace", routing::post(dyn_trace::dyn_trace_handler))
890                    .nest(
891                        "/prof",
892                        Router::new()
893                            .route("/cpu", routing::post(pprof::pprof_handler))
894                            .route("/mem", routing::post(mem_prof::mem_prof_handler))
895                            .route("/mem/symbol", routing::post(mem_prof::symbolicate_handler))
896                            .route(
897                                "/mem/activate",
898                                routing::post(mem_prof::activate_heap_prof_handler),
899                            )
900                            .route(
901                                "/mem/deactivate",
902                                routing::post(mem_prof::deactivate_heap_prof_handler),
903                            )
904                            .route(
905                                "/mem/status",
906                                routing::get(mem_prof::heap_prof_status_handler),
907                            ) // jemalloc gdump flag status and toggle
908                            .route(
909                                "/mem/gdump",
910                                routing::get(mem_prof::gdump_status_handler)
911                                    .post(mem_prof::gdump_toggle_handler),
912                            ),
913                    ),
914            ))
915    }
916
917    fn route_metrics<S>(metrics_handler: MetricsHandler) -> Router<S> {
918        Router::new()
919            .route("/metrics", routing::get(handler::metrics))
920            .with_state(metrics_handler)
921    }
922
923    fn route_loki<S>(log_state: LogState) -> Router<S> {
924        Router::new()
925            .route("/api/v1/push", routing::post(loki::loki_ingest))
926            .layer(
927                ServiceBuilder::new()
928                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
929            )
930            .with_state(log_state)
931    }
932
933    fn route_elasticsearch<S>(log_state: LogState) -> Router<S> {
934        Router::new()
935            // Return fake responsefor HEAD '/' request.
936            .route(
937                "/",
938                routing::head((HttpStatusCode::OK, elasticsearch::elasticsearch_headers())),
939            )
940            // Return fake response for Elasticsearch version request.
941            .route("/", routing::get(elasticsearch::handle_get_version))
942            // Return fake response for Elasticsearch license request.
943            .route("/_license", routing::get(elasticsearch::handle_get_license))
944            .route("/_bulk", routing::post(elasticsearch::handle_bulk_api))
945            .route(
946                "/{index}/_bulk",
947                routing::post(elasticsearch::handle_bulk_api_with_index),
948            )
949            // Return fake response for Elasticsearch ilm request.
950            .route(
951                "/_ilm/policy/{*path}",
952                routing::any((
953                    HttpStatusCode::OK,
954                    elasticsearch::elasticsearch_headers(),
955                    axum::Json(serde_json::json!({})),
956                )),
957            )
958            // Return fake response for Elasticsearch index template request.
959            .route(
960                "/_index_template/{*path}",
961                routing::any((
962                    HttpStatusCode::OK,
963                    elasticsearch::elasticsearch_headers(),
964                    axum::Json(serde_json::json!({})),
965                )),
966            )
967            // Return fake response for Elasticsearch ingest pipeline request.
968            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/put-pipeline-api.html.
969            .route(
970                "/_ingest/{*path}",
971                routing::any((
972                    HttpStatusCode::OK,
973                    elasticsearch::elasticsearch_headers(),
974                    axum::Json(serde_json::json!({})),
975                )),
976            )
977            // Return fake response for Elasticsearch nodes discovery request.
978            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/cluster.html.
979            .route(
980                "/_nodes/{*path}",
981                routing::any((
982                    HttpStatusCode::OK,
983                    elasticsearch::elasticsearch_headers(),
984                    axum::Json(serde_json::json!({})),
985                )),
986            )
987            // Return fake response for Logstash APIs requests.
988            // See: https://www.elastic.co/guide/en/elasticsearch/reference/8.8/logstash-apis.html
989            .route(
990                "/logstash/{*path}",
991                routing::any((
992                    HttpStatusCode::OK,
993                    elasticsearch::elasticsearch_headers(),
994                    axum::Json(serde_json::json!({})),
995                )),
996            )
997            .route(
998                "/_logstash/{*path}",
999                routing::any((
1000                    HttpStatusCode::OK,
1001                    elasticsearch::elasticsearch_headers(),
1002                    axum::Json(serde_json::json!({})),
1003                )),
1004            )
1005            .layer(ServiceBuilder::new().layer(RequestDecompressionLayer::new()))
1006            .with_state(log_state)
1007    }
1008
1009    #[deprecated(since = "0.11.0", note = "Use `route_pipelines()` instead.")]
1010    fn route_log_deprecated<S>(log_state: LogState) -> Router<S> {
1011        Router::new()
1012            .route("/logs", routing::post(event::log_ingester))
1013            .route(
1014                "/pipelines/{pipeline_name}",
1015                routing::get(event::query_pipeline),
1016            )
1017            .route(
1018                "/pipelines/{pipeline_name}",
1019                routing::post(event::add_pipeline),
1020            )
1021            .route(
1022                "/pipelines/{pipeline_name}",
1023                routing::delete(event::delete_pipeline),
1024            )
1025            .route("/pipelines/dryrun", routing::post(event::pipeline_dryrun))
1026            .layer(
1027                ServiceBuilder::new()
1028                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1029            )
1030            .with_state(log_state)
1031    }
1032
1033    fn route_pipelines<S>(log_state: LogState) -> Router<S> {
1034        Router::new()
1035            .route("/ingest", routing::post(event::log_ingester))
1036            .route(
1037                "/pipelines/{pipeline_name}",
1038                routing::get(event::query_pipeline),
1039            )
1040            .route(
1041                "/pipelines/{pipeline_name}/ddl",
1042                routing::get(event::query_pipeline_ddl),
1043            )
1044            .route(
1045                "/pipelines/{pipeline_name}",
1046                routing::post(event::add_pipeline),
1047            )
1048            .route(
1049                "/pipelines/{pipeline_name}",
1050                routing::delete(event::delete_pipeline),
1051            )
1052            .route("/pipelines/_dryrun", routing::post(event::pipeline_dryrun))
1053            .layer(
1054                ServiceBuilder::new()
1055                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1056            )
1057            .with_state(log_state)
1058    }
1059
1060    fn route_sql<S>(api_state: ApiState) -> Router<S> {
1061        Router::new()
1062            .route("/sql", routing::get(handler::sql).post(handler::sql))
1063            .route(
1064                "/sql/parse",
1065                routing::get(handler::sql_parse).post(handler::sql_parse),
1066            )
1067            .route(
1068                "/sql/format",
1069                routing::get(handler::sql_format).post(handler::sql_format),
1070            )
1071            .route(
1072                "/promql",
1073                routing::get(handler::promql).post(handler::promql),
1074            )
1075            .with_state(api_state)
1076    }
1077
1078    fn route_logs<S>(log_handler: LogQueryHandlerRef) -> Router<S> {
1079        Router::new()
1080            .route("/logs", routing::get(logs::logs).post(logs::logs))
1081            .with_state(log_handler)
1082    }
1083
1084    /// Route Prometheus [HTTP API].
1085    ///
1086    /// [HTTP API]: https://prometheus.io/docs/prometheus/latest/querying/api/
1087    pub fn route_prometheus<S>(prometheus_handler: PrometheusHandlerRef) -> Router<S> {
1088        Router::new()
1089            .route(
1090                "/format_query",
1091                routing::post(format_query).get(format_query),
1092            )
1093            .route("/status/buildinfo", routing::get(build_info_query))
1094            .route("/query", routing::post(instant_query).get(instant_query))
1095            .route("/query_range", routing::post(range_query).get(range_query))
1096            .route("/labels", routing::post(labels_query).get(labels_query))
1097            .route("/series", routing::post(series_query).get(series_query))
1098            .route("/parse_query", routing::post(parse_query).get(parse_query))
1099            .route(
1100                "/label/{label_name}/values",
1101                routing::get(label_values_query),
1102            )
1103            .layer(ServiceBuilder::new().layer(CompressionLayer::new()))
1104            .with_state(prometheus_handler)
1105    }
1106
1107    /// Route Prometheus remote [read] and [write] API. In other places the related modules are
1108    /// called `prom_store`.
1109    ///
1110    /// [read]: https://prometheus.io/docs/prometheus/latest/querying/remote_read_api/
1111    /// [write]: https://prometheus.io/docs/concepts/remote_write_spec/
1112    fn route_prom<S>(state: PromStoreState) -> Router<S> {
1113        Router::new()
1114            .route("/read", routing::post(prom_store::remote_read))
1115            .route("/write", routing::post(prom_store::remote_write))
1116            .with_state(state)
1117    }
1118
1119    fn route_influxdb<S>(influxdb_handler: InfluxdbLineProtocolHandlerRef) -> Router<S> {
1120        Router::new()
1121            .route("/write", routing::post(influxdb_write_v1))
1122            .route("/api/v2/write", routing::post(influxdb_write_v2))
1123            .layer(
1124                ServiceBuilder::new()
1125                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1126            )
1127            .route("/ping", routing::get(influxdb_ping))
1128            .route("/health", routing::get(influxdb_health))
1129            .with_state(influxdb_handler)
1130    }
1131
1132    fn route_opentsdb<S>(opentsdb_handler: OpentsdbProtocolHandlerRef) -> Router<S> {
1133        Router::new()
1134            .route("/api/put", routing::post(opentsdb::put))
1135            .with_state(opentsdb_handler)
1136    }
1137
1138    fn route_otlp<S>(
1139        otlp_handler: OpenTelemetryProtocolHandlerRef,
1140        with_metric_engine: bool,
1141    ) -> Router<S> {
1142        Router::new()
1143            .route("/v1/metrics", routing::post(otlp::metrics))
1144            .route("/v1/traces", routing::post(otlp::traces))
1145            .route("/v1/logs", routing::post(otlp::logs))
1146            .layer(
1147                ServiceBuilder::new()
1148                    .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true)),
1149            )
1150            .with_state(OtlpState {
1151                with_metric_engine,
1152                handler: otlp_handler,
1153            })
1154    }
1155
1156    fn route_config<S>(state: GreptimeOptionsConfigState) -> Router<S> {
1157        Router::new()
1158            .route("/config", routing::get(handler::config))
1159            .with_state(state)
1160    }
1161
1162    fn route_jaeger<S>(handler: JaegerQueryHandlerRef) -> Router<S> {
1163        Router::new()
1164            .route("/api/services", routing::get(jaeger::handle_get_services))
1165            .route(
1166                "/api/services/{service_name}/operations",
1167                routing::get(jaeger::handle_get_operations_by_service),
1168            )
1169            .route(
1170                "/api/operations",
1171                routing::get(jaeger::handle_get_operations),
1172            )
1173            .route("/api/traces", routing::get(jaeger::handle_find_traces))
1174            .route(
1175                "/api/traces/{trace_id}",
1176                routing::get(jaeger::handle_get_trace),
1177            )
1178            .with_state(handler)
1179    }
1180
/// Builds the dashboard management router (only with the `dashboard` feature).
///
/// `GET /` lists dashboards; `POST` and `DELETE` on `/{dashboard_name}` add and delete one.
/// Requests pass through a decompression layer.
#[cfg(feature = "dashboard")]
fn route_dashboard<S>(handler: DashboardHandlerRef) -> Router<S> {
    use crate::http::dashboard::{add_dashboard, delete_dashboard, list_dashboards};

    let state = DashboardState { handler };
    let decompression = ServiceBuilder::new()
        .layer(RequestDecompressionLayer::new().pass_through_unaccepted(true));
    let dashboard_by_name = routing::post(add_dashboard).delete(delete_dashboard);

    Router::new()
        .route("/", routing::get(list_dashboards))
        .route("/{dashboard_name}", dashboard_by_name)
        .layer(decompression)
        .with_state(state)
}
1195
1196    #[cfg(not(feature = "dashboard"))]
1197    fn route_dashboard<S>(handler: DashboardHandlerRef) -> Router<S> {
1198        Router::new().with_state(DashboardState { handler })
1199    }
1200}
1201
/// Name returned by [`HttpServer`]'s `Server::name` implementation.
pub const HTTP_SERVER: &str = "HTTP_SERVER";
1203
1204#[async_trait]
1205impl Server for HttpServer {
1206    async fn shutdown(&self) -> Result<()> {
1207        let mut shutdown_tx = self.shutdown_tx.lock().await;
1208        if let Some(tx) = shutdown_tx.take()
1209            && tx.send(()).is_err()
1210        {
1211            info!("Receiver dropped, the HTTP server has already exited");
1212        }
1213        info!("Shutdown HTTP server");
1214
1215        Ok(())
1216    }
1217
1218    async fn start(&mut self, listening: SocketAddr) -> Result<()> {
1219        let (tx, rx) = oneshot::channel();
1220        let serve = {
1221            let mut shutdown_tx = self.shutdown_tx.lock().await;
1222            ensure!(
1223                shutdown_tx.is_none(),
1224                AlreadyStartedSnafu { server: "HTTP" }
1225            );
1226
1227            let app = self.build(self.make_app())?;
1228            let listener = tokio::net::TcpListener::bind(listening)
1229                .await
1230                .context(AddressBindSnafu { addr: listening })?
1231                .tap_io(|tcp_stream| {
1232                    if let Err(e) = tcp_stream.set_nodelay(true) {
1233                        error!(e; "Failed to set TCP_NODELAY on incoming connection");
1234                    }
1235                });
1236            let serve = axum::serve(
1237                listener,
1238                app.into_make_service_with_connect_info::<SocketAddr>(),
1239            );
1240
1241            // FIXME(yingwen): Support keepalive.
1242            // See:
1243            // - https://github.com/tokio-rs/axum/discussions/2939
1244            // - https://stackoverflow.com/questions/73069718/how-do-i-keep-alive-tokiotcpstream-in-rust
1245            // let server = axum::Server::try_bind(&listening)
1246            //     .with_context(|_| AddressBindSnafu { addr: listening })?
1247            //     .tcp_nodelay(true)
1248            //     // Enable TCP keepalive to close the dangling established connections.
1249            //     // It's configured to let the keepalive probes first send after the connection sits
1250            //     // idle for 59 minutes, and then send every 10 seconds for 6 times.
1251            //     // So the connection will be closed after roughly 1 hour.
1252            //     .tcp_keepalive(Some(Duration::from_secs(59 * 60)))
1253            //     .tcp_keepalive_interval(Some(Duration::from_secs(10)))
1254            //     .tcp_keepalive_retries(Some(6))
1255            //     .serve(app.into_make_service());
1256
1257            *shutdown_tx = Some(tx);
1258
1259            serve
1260        };
1261        let listening = serve.local_addr().context(InternalIoSnafu)?;
1262        info!("HTTP server is bound to {}", listening);
1263
1264        common_runtime::spawn_global(async move {
1265            if let Err(e) = serve
1266                .with_graceful_shutdown(rx.map(drop))
1267                .await
1268                .context(InternalIoSnafu)
1269            {
1270                error!(e; "Failed to shutdown http server");
1271            }
1272        });
1273
1274        self.bind_addr = Some(listening);
1275        Ok(())
1276    }
1277
1278    fn name(&self) -> &str {
1279        HTTP_SERVER
1280    }
1281
1282    fn bind_addr(&self) -> Option<SocketAddr> {
1283        self.bind_addr
1284    }
1285
1286    fn as_any(&self) -> &dyn std::any::Any {
1287        self
1288    }
1289}
1290
1291#[cfg(test)]
1292mod test {
1293    use std::future::pending;
1294    use std::io::Cursor;
1295    use std::sync::Arc;
1296
1297    use arrow_ipc::reader::StreamReader;
1298    use arrow_schema::DataType;
1299    use axum::handler::Handler;
1300    use axum::http::StatusCode;
1301    use axum::routing::get;
1302    use common_query::Output;
1303    use common_recordbatch::RecordBatches;
1304    use datafusion_expr::LogicalPlan;
1305    use datatypes::prelude::*;
1306    use datatypes::schema::{ColumnSchema, Schema};
1307    use datatypes::vectors::{StringVector, UInt32Vector};
1308    use header::constants::GREPTIME_DB_HEADER_TIMEOUT;
1309    use query::parser::PromQuery;
1310    use query::query_engine::DescribeResult;
1311    use session::context::QueryContextRef;
1312    use sql::statements::statement::Statement;
1313    use tokio::sync::mpsc;
1314    use tokio::time::Instant;
1315
1316    use super::*;
1317    use crate::http::test_helpers::TestClient;
1318    use crate::prom_remote_write::validation::validate_label_name;
1319    use crate::query_handler::sql::SqlQueryHandler;
1320
1321    struct DummyInstance {
1322        _tx: mpsc::Sender<(String, Vec<u8>)>,
1323    }
1324
1325    #[async_trait]
1326    impl SqlQueryHandler for DummyInstance {
1327        async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec<Result<Output>> {
1328            unimplemented!()
1329        }
1330
1331        async fn do_promql_query(&self, _: &PromQuery, _: QueryContextRef) -> Vec<Result<Output>> {
1332            unimplemented!()
1333        }
1334
1335        async fn do_exec_plan(
1336            &self,
1337            _plan: LogicalPlan,
1338            _stmt: Option<Statement>,
1339            _query_ctx: QueryContextRef,
1340        ) -> Result<Output> {
1341            unimplemented!()
1342        }
1343
1344        async fn do_describe(
1345            &self,
1346            _stmt: sql::statements::statement::Statement,
1347            _query_ctx: QueryContextRef,
1348        ) -> Result<Option<DescribeResult>> {
1349            unimplemented!()
1350        }
1351
1352        async fn is_valid_schema(&self, _catalog: &str, _schema: &str) -> Result<bool> {
1353            Ok(true)
1354        }
1355    }
1356
1357    fn timeout() -> DynamicTimeoutLayer {
1358        DynamicTimeoutLayer::new(Duration::from_millis(10))
1359    }
1360
1361    async fn forever() {
1362        pending().await
1363    }
1364
1365    fn make_test_app(tx: mpsc::Sender<(String, Vec<u8>)>) -> Router {
1366        make_test_app_custom(tx, HttpOptions::default())
1367    }
1368
1369    fn make_test_app_custom(tx: mpsc::Sender<(String, Vec<u8>)>, options: HttpOptions) -> Router {
1370        let instance = Arc::new(DummyInstance { _tx: tx });
1371        let server = HttpServerBuilder::new(options)
1372            .with_sql_handler(instance.clone())
1373            .build();
1374        server.build(server.make_app()).unwrap().route(
1375            "/test/timeout",
1376            get(forever.layer(ServiceBuilder::new().layer(timeout()))),
1377        )
1378    }
1379
1380    #[tokio::test]
1381    pub async fn test_cors() {
1382        // cors is on by default
1383        let (tx, _rx) = mpsc::channel(100);
1384        let app = make_test_app(tx);
1385        let client = TestClient::new(app).await;
1386
1387        let res = client.get("/health").send().await;
1388
1389        assert_eq!(res.status(), StatusCode::OK);
1390        assert_eq!(
1391            res.headers()
1392                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1393                .expect("expect cors header origin"),
1394            "*"
1395        );
1396
1397        let res = client.get("/v1/health").send().await;
1398
1399        assert_eq!(res.status(), StatusCode::OK);
1400        assert_eq!(
1401            res.headers()
1402                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1403                .expect("expect cors header origin"),
1404            "*"
1405        );
1406
1407        let res = client
1408            .options("/health")
1409            .header("Access-Control-Request-Headers", "x-greptime-auth")
1410            .header("Access-Control-Request-Method", "DELETE")
1411            .header("Origin", "https://example.com")
1412            .send()
1413            .await;
1414        assert_eq!(res.status(), StatusCode::OK);
1415        assert_eq!(
1416            res.headers()
1417                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1418                .expect("expect cors header origin"),
1419            "*"
1420        );
1421        assert_eq!(
1422            res.headers()
1423                .get(http::header::ACCESS_CONTROL_ALLOW_HEADERS)
1424                .expect("expect cors header headers"),
1425            "*"
1426        );
1427        assert_eq!(
1428            res.headers()
1429                .get(http::header::ACCESS_CONTROL_ALLOW_METHODS)
1430                .expect("expect cors header methods"),
1431            "GET,POST,PUT,DELETE,HEAD"
1432        );
1433    }
1434
1435    #[tokio::test]
1436    pub async fn test_cors_custom_origins() {
1437        // cors is on by default
1438        let (tx, _rx) = mpsc::channel(100);
1439        let origin = "https://example.com";
1440
1441        let options = HttpOptions {
1442            cors_allowed_origins: vec![origin.to_string()],
1443            ..Default::default()
1444        };
1445
1446        let app = make_test_app_custom(tx, options);
1447        let client = TestClient::new(app).await;
1448
1449        let res = client.get("/health").header("Origin", origin).send().await;
1450
1451        assert_eq!(res.status(), StatusCode::OK);
1452        assert_eq!(
1453            res.headers()
1454                .get(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1455                .expect("expect cors header origin"),
1456            origin
1457        );
1458
1459        let res = client
1460            .get("/health")
1461            .header("Origin", "https://notallowed.com")
1462            .send()
1463            .await;
1464
1465        assert_eq!(res.status(), StatusCode::OK);
1466        assert!(
1467            !res.headers()
1468                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1469        );
1470    }
1471
1472    #[tokio::test]
1473    pub async fn test_cors_disabled() {
1474        // cors is on by default
1475        let (tx, _rx) = mpsc::channel(100);
1476
1477        let options = HttpOptions {
1478            enable_cors: false,
1479            ..Default::default()
1480        };
1481
1482        let app = make_test_app_custom(tx, options);
1483        let client = TestClient::new(app).await;
1484
1485        let res = client.get("/health").send().await;
1486
1487        assert_eq!(res.status(), StatusCode::OK);
1488        assert!(
1489            !res.headers()
1490                .contains_key(http::header::ACCESS_CONTROL_ALLOW_ORIGIN)
1491        );
1492    }
1493
#[test]
fn test_http_options_default() {
    // The defaults listen on localhost:4000 with a zero timeout.
    let defaults = HttpOptions::default();
    assert_eq!(defaults.addr, "127.0.0.1:4000");
    assert_eq!(defaults.timeout, Duration::ZERO);
}
1500
1501    #[tokio::test]
1502    async fn test_http_server_request_timeout() {
1503        common_telemetry::init_default_ut_logging();
1504
1505        let (tx, _rx) = mpsc::channel(100);
1506        let app = make_test_app(tx);
1507        let client = TestClient::new(app).await;
1508        let res = client.get("/test/timeout").send().await;
1509        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1510
1511        let now = Instant::now();
1512        let res = client
1513            .get("/test/timeout")
1514            .header(GREPTIME_DB_HEADER_TIMEOUT, "20ms")
1515            .send()
1516            .await;
1517        assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT);
1518        let elapsed = now.elapsed();
1519        assert!(elapsed > Duration::from_millis(15));
1520
1521        tokio::time::timeout(
1522            Duration::from_millis(15),
1523            client
1524                .get("/test/timeout")
1525                .header(GREPTIME_DB_HEADER_TIMEOUT, "0s")
1526                .send(),
1527        )
1528        .await
1529        .unwrap_err();
1530
1531        tokio::time::timeout(
1532            Duration::from_millis(15),
1533            client
1534                .get("/test/timeout")
1535                .header(
1536                    GREPTIME_DB_HEADER_TIMEOUT,
1537                    humantime::format_duration(Duration::default()).to_string(),
1538                )
1539                .send(),
1540        )
1541        .await
1542        .unwrap_err();
1543    }
1544
1545    #[tokio::test]
1546    async fn test_schema_for_empty_response() {
1547        let column_schemas = vec![
1548            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1549            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1550        ];
1551        let schema = Arc::new(Schema::new(column_schemas));
1552
1553        let recordbatches = RecordBatches::try_new(schema.clone(), vec![]).unwrap();
1554        let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1555
1556        let json_resp = GreptimedbV1Response::from_output(outputs).await;
1557        if let HttpResponse::GreptimedbV1(json_resp) = json_resp {
1558            let json_output = &json_resp.output[0];
1559            if let GreptimeQueryOutput::Records(r) = json_output {
1560                assert_eq!(r.num_rows(), 0);
1561                assert_eq!(r.num_cols(), 2);
1562                assert_eq!(r.schema.column_schemas[0].name, "numbers");
1563                assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1564            } else {
1565                panic!("invalid output type");
1566            }
1567        } else {
1568            panic!("invalid format")
1569        }
1570    }
1571
1572    #[tokio::test]
1573    async fn test_recordbatches_conversion() {
1574        let column_schemas = vec![
1575            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
1576            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
1577        ];
1578        let schema = Arc::new(Schema::new(column_schemas));
1579        let columns: Vec<VectorRef> = vec![
1580            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
1581            Arc::new(StringVector::from(vec![
1582                None,
1583                Some("hello"),
1584                Some("greptime"),
1585                None,
1586            ])),
1587        ];
1588        let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap();
1589
1590        for format in [
1591            ResponseFormat::GreptimedbV1,
1592            ResponseFormat::InfluxdbV1,
1593            ResponseFormat::Csv(true, true),
1594            ResponseFormat::Table,
1595            ResponseFormat::Arrow,
1596            ResponseFormat::Json,
1597            ResponseFormat::Null,
1598        ] {
1599            let recordbatches =
1600                RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
1601            let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
1602            let json_resp = match format {
1603                ResponseFormat::Arrow => ArrowResponse::from_output(outputs, None).await,
1604                ResponseFormat::Csv(with_names, with_types) => {
1605                    CsvResponse::from_output(outputs, with_names, with_types).await
1606                }
1607                ResponseFormat::Table => TableResponse::from_output(outputs).await,
1608                ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
1609                ResponseFormat::InfluxdbV1 => InfluxdbV1Response::from_output(outputs, None).await,
1610                ResponseFormat::Json => JsonResponse::from_output(outputs).await,
1611                ResponseFormat::Null => NullResponse::from_output(outputs).await,
1612            };
1613
1614            match json_resp {
1615                HttpResponse::GreptimedbV1(resp) => {
1616                    let json_output = &resp.output[0];
1617                    if let GreptimeQueryOutput::Records(r) = json_output {
1618                        assert_eq!(r.num_rows(), 4);
1619                        assert_eq!(r.num_cols(), 2);
1620                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1621                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1622                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1623                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1624                    } else {
1625                        panic!("invalid output type");
1626                    }
1627                }
1628                HttpResponse::InfluxdbV1(resp) => {
1629                    let json_output = &resp.results()[0];
1630                    assert_eq!(json_output.num_rows(), 4);
1631                    assert_eq!(json_output.num_cols(), 2);
1632                    assert_eq!(json_output.series[0].columns.clone()[0], "numbers");
1633                    assert_eq!(
1634                        json_output.series[0].values[0][0],
1635                        serde_json::Value::from(1)
1636                    );
1637                    assert_eq!(json_output.series[0].values[0][1], serde_json::Value::Null);
1638                }
1639                HttpResponse::Csv(resp) => {
1640                    let output = &resp.output()[0];
1641                    if let GreptimeQueryOutput::Records(r) = output {
1642                        assert_eq!(r.num_rows(), 4);
1643                        assert_eq!(r.num_cols(), 2);
1644                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1645                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1646                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1647                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1648                    } else {
1649                        panic!("invalid output type");
1650                    }
1651                }
1652
1653                HttpResponse::Table(resp) => {
1654                    let output = &resp.output()[0];
1655                    if let GreptimeQueryOutput::Records(r) = output {
1656                        assert_eq!(r.num_rows(), 4);
1657                        assert_eq!(r.num_cols(), 2);
1658                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1659                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1660                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1661                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1662                    } else {
1663                        panic!("invalid output type");
1664                    }
1665                }
1666
1667                HttpResponse::Arrow(resp) => {
1668                    let output = resp.data;
1669                    let mut reader = StreamReader::try_new(Cursor::new(output), None)
1670                        .expect("Arrow reader error");
1671                    let schema = reader.schema();
1672                    assert_eq!(schema.fields[0].name(), "numbers");
1673                    assert_eq!(schema.fields[0].data_type(), &DataType::UInt32);
1674                    assert_eq!(schema.fields[1].name(), "strings");
1675                    assert_eq!(schema.fields[1].data_type(), &DataType::Utf8);
1676
1677                    let rb = reader.next().unwrap().expect("read record batch failed");
1678                    assert_eq!(rb.num_columns(), 2);
1679                    assert_eq!(rb.num_rows(), 4);
1680                }
1681
1682                HttpResponse::Json(resp) => {
1683                    let output = &resp.output()[0];
1684                    if let GreptimeQueryOutput::Records(r) = output {
1685                        assert_eq!(r.num_rows(), 4);
1686                        assert_eq!(r.num_cols(), 2);
1687                        assert_eq!(r.schema.column_schemas[0].name, "numbers");
1688                        assert_eq!(r.schema.column_schemas[0].data_type, "UInt32");
1689                        assert_eq!(r.rows[0][0], serde_json::Value::from(1));
1690                        assert_eq!(r.rows[0][1], serde_json::Value::Null);
1691                    } else {
1692                        panic!("invalid output type");
1693                    }
1694                }
1695
1696                HttpResponse::Null(resp) => {
1697                    assert_eq!(resp.rows(), 4);
1698                }
1699
1700                HttpResponse::Error(err) => unreachable!("{err:?}"),
1701            }
1702        }
1703    }
1704
/// Checks `ResponseFormat`'s default variant, the names `parse` accepts or
/// rejects, and the string each variant reports through `as_str`.
#[test]
fn test_response_format_misc() {
    assert_eq!(ResponseFormat::default(), ResponseFormat::GreptimedbV1);

    // Recognized names paired with the variant each one must parse to.
    let recognized = [
        ("arrow", ResponseFormat::Arrow),
        ("csv", ResponseFormat::Csv(false, false)),
        ("csvwithnames", ResponseFormat::Csv(true, false)),
        ("csvwithnamesandtypes", ResponseFormat::Csv(true, true)),
        ("table", ResponseFormat::Table),
        ("greptimedb_v1", ResponseFormat::GreptimedbV1),
        ("influxdb_v1", ResponseFormat::InfluxdbV1),
        ("json", ResponseFormat::Json),
        ("null", ResponseFormat::Null),
    ];
    for (name, expected) in recognized {
        assert_eq!(ResponseFormat::parse(name), Some(expected));
    }

    // Invalid formats: unknown, empty, and upper-cased (parsing is case sensitive).
    for name in ["invalid", "", "CSV"] {
        assert_eq!(ResponseFormat::parse(name), None);
    }

    // Variants paired with the text `as_str` must report; the CSV flags do
    // not change the reported name.
    let rendered = [
        (ResponseFormat::Arrow, "arrow"),
        (ResponseFormat::Csv(false, false), "csv"),
        (ResponseFormat::Csv(true, true), "csv"),
        (ResponseFormat::Table, "table"),
        (ResponseFormat::GreptimedbV1, "greptimedb_v1"),
        (ResponseFormat::InfluxdbV1, "influxdb_v1"),
        (ResponseFormat::Json, "json"),
        (ResponseFormat::Null, "null"),
    ];
    for (format, expected) in rendered {
        assert_eq!(format.as_str(), expected);
    }
    assert_eq!(ResponseFormat::default().as_str(), "greptimedb_v1");
}
1749
/// Verifies which byte sequences `PromValidationMode::Strict` accepts and
/// which it rejects when decoding a label name.
#[test]
fn test_decode_label_name_strict() {
    let mode = PromValidationMode::Strict;

    // Valid Prometheus label names.
    let accepted: &[&[u8]] = &[
        b"__name__",
        b"job",
        b"instance",
        b"_private",
        b"label_with_underscores",
        b"abc123",
        b"A",
        b"_",
    ];
    for &name in accepted {
        assert!(mode.decode_label_name(name).is_ok());
    }

    let rejected: &[&[u8]] = &[
        // Starts with a digit.
        b"0abc",
        b"123",
        // Contains special characters.
        b"label-name",
        b"label.name",
        b"label name",
        b"label/name",
        // Empty.
        b"",
        // Non-ASCII UTF-8.
        "ラベル".as_bytes(),
        // Invalid UTF-8 bytes.
        &[0xff, 0xfe],
    ];
    for &name in rejected {
        assert!(mode.decode_label_name(name).is_err());
    }
}
1783
/// Verifies that `PromValidationMode::Lossy` still applies label-name
/// validation when decoding.
#[test]
fn test_decode_label_name_lossy() {
    let mode = PromValidationMode::Lossy;

    // Label name validation is always enforced.
    assert!(mode.decode_label_name(b"__name__").is_ok());

    let rejected: &[&[u8]] = &[
        b"label-name",
        b"0abc",
        // Invalid UTF-8 bytes fail the label-name byte check.
        &[0xff, 0xfe],
    ];
    for &name in rejected {
        assert!(mode.decode_label_name(name).is_err());
    }
}
1796
/// Verifies that `PromValidationMode::Unchecked` still applies label-name
/// validation when decoding.
#[test]
fn test_decode_label_name_unchecked() {
    let mode = PromValidationMode::Unchecked;

    // Label name validation is always enforced.
    assert!(mode.decode_label_name(b"__name__").is_ok());

    let rejected: &[&[u8]] = &[b"label-name", b"0abc"];
    for &name in rejected {
        assert!(mode.decode_label_name(name).is_err());
    }
}
1806
/// Checks `validate_label_name` against byte strings it must accept and
/// byte strings it must reject.
#[test]
fn test_is_valid_prom_label_name_bytes() {
    let accepted: &[&[u8]] = &[
        b"__name__",
        b"job",
        b"_",
        b"A",
        b"abc123",
        b"_leading_underscore",
    ];
    for &name in accepted {
        assert!(validate_label_name(name));
    }

    let rejected: &[&[u8]] = &[
        b"",
        b"0starts_with_digit",
        b"has-dash",
        b"has.dot",
        b"has space",
        &[0xff, 0xfe],
    ];
    for &name in rejected {
        assert!(!validate_label_name(name));
    }
}
1823}