diff --git a/config/config.md b/config/config.md index f8ce99cae1..7eba8a472e 100644 --- a/config/config.md +++ b/config/config.md @@ -32,6 +32,7 @@ | `http.enable_cors` | Bool | `true` | HTTP CORS support, it's turned on by default
This allows browser to access http APIs without CORS restrictions | | `http.cors_allowed_origins` | Array | Unset | Customize allowed origins for HTTP CORS. | | `http.prom_validation_mode` | String | `strict` | Whether to enable validation for Prometheus remote write requests.
Available options:
- strict: deny invalid UTF-8 strings (default).
- lossy: allow invalid UTF-8 strings, replace invalid characters with REPLACEMENT_CHARACTER(U+FFFD).
- unchecked: do not valid strings. | +| `http.experimental_enable_explain_analyze_stream` | Bool | `false` | Experimental: enable POST /v1/sql/analyze/stream for streaming EXPLAIN ANALYZE VERBOSE metrics. | | `grpc` | -- | -- | The gRPC server options. | | `grpc.bind_addr` | String | `127.0.0.1:4001` | The address to bind the gRPC server. | | `grpc.runtime_size` | Integer | `8` | The number of server worker threads. | @@ -247,6 +248,7 @@ | `http.enable_cors` | Bool | `true` | HTTP CORS support, it's turned on by default
This allows browser to access http APIs without CORS restrictions | | `http.cors_allowed_origins` | Array | Unset | Customize allowed origins for HTTP CORS. | | `http.prom_validation_mode` | String | `strict` | Whether to enable validation for Prometheus remote write requests.
Available options:
- strict: deny invalid UTF-8 strings (default).
- lossy: allow invalid UTF-8 strings, replace invalid characters with REPLACEMENT_CHARACTER(U+FFFD).
- unchecked: do not valid strings. | +| `http.experimental_enable_explain_analyze_stream` | Bool | `false` | Experimental: enable POST /v1/sql/analyze/stream for streaming EXPLAIN ANALYZE VERBOSE metrics. | | `grpc` | -- | -- | The gRPC server options. | | `grpc.bind_addr` | String | `127.0.0.1:4001` | The address to bind the gRPC server. | | `grpc.server_addr` | String | `127.0.0.1:4001` | The address advertised to the metasrv, and used for connections from outside the host.
If left empty or unset, the server will automatically use the IP address of the first network interface
on the host, with the same port number as the one specified in `grpc.bind_addr`. | diff --git a/config/frontend.example.toml b/config/frontend.example.toml index 2331cdf028..06b4efd36e 100644 --- a/config/frontend.example.toml +++ b/config/frontend.example.toml @@ -57,6 +57,8 @@ cors_allowed_origins = ["https://example.com"] ## - lossy: allow invalid UTF-8 strings, replace invalid characters with REPLACEMENT_CHARACTER(U+FFFD). ## - unchecked: do not valid strings. prom_validation_mode = "strict" +## Experimental: enable POST /v1/sql/analyze/stream for streaming EXPLAIN ANALYZE VERBOSE metrics. +experimental_enable_explain_analyze_stream = false ## The gRPC server options. [grpc] diff --git a/config/standalone.example.toml b/config/standalone.example.toml index 79ed2814f3..407280278d 100644 --- a/config/standalone.example.toml +++ b/config/standalone.example.toml @@ -71,6 +71,8 @@ cors_allowed_origins = ["https://example.com"] ## - lossy: allow invalid UTF-8 strings, replace invalid characters with REPLACEMENT_CHARACTER(U+FFFD). ## - unchecked: do not valid strings. prom_validation_mode = "strict" +## Experimental: enable POST /v1/sql/analyze/stream for streaming EXPLAIN ANALYZE VERBOSE metrics. +experimental_enable_explain_analyze_stream = false ## The gRPC server options. [grpc] diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 1ee45ef438..8a6bc062f2 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -89,7 +89,7 @@ use sql::statements::comment::CommentObject; use sql::statements::copy::{CopyDatabase, CopyTable}; use sql::statements::statement::Statement; use sql::statements::tql::Tql; -use sqlparser::ast::ObjectName; +use sqlparser::ast::{AnalyzeFormat, ObjectName}; pub use standalone::StandaloneDatanodeManager; use table::requests::{OTLP_METRIC_COMPAT_KEY, OTLP_METRIC_COMPAT_PROM}; use tracing::Span; @@ -195,35 +195,66 @@ fn parse_stmt(sql: &str, dialect: &(dyn Dialect + Send + Sync)) -> Result Result<()> { + let Statement::Explain(explain) = stmt else { + return InvalidSqlSnafu { + err_msg: "only EXPLAIN ANALYZE VERBOSE statement is supported", + } + .fail(); + }; + ensure!( + explain.analyze && explain.verbose, + InvalidSqlSnafu { + err_msg: "statement must be EXPLAIN ANALYZE VERBOSE" + } + ); + match explain.format { + None | Some(AnalyzeFormat::JSON) => { + // Keep explicit FORMAT JSON accepted, but pass JSON through + // QueryContext.explain_format instead of the statement to avoid the + // planner's current `EXPLAIN VERBOSE with FORMAT` limitation. + explain.format = None; + Ok(()) + } + Some(_) => InvalidSqlSnafu { + err_msg: "only FORMAT JSON is supported for analyze stream", + } + .fail(), + } +} + impl Instance { + fn statement_slow_query_timer( + &self, + stmt: &Statement, + schema_name: String, + ) -> Option { + if !stmt.is_readonly() || !self.slow_query_options.enable { + return None; + } + + self.event_recorder.clone().map(|event_recorder| { + SlowQueryTimer::new( + CatalogQueryStatement::Sql(stmt.clone()), + schema_name, + self.slow_query_options.threshold, + self.slow_query_options.sample_ratio, + self.slow_query_options.record_type, + event_recorder, + ) + }) + } + async fn query_statement(&self, stmt: Statement, query_ctx: QueryContextRef) -> Result { check_permission(self.plugins.clone(), &stmt, &query_ctx)?; let query_interceptor = self.plugins.get::>(); let query_interceptor = query_interceptor.as_ref(); - let is_readonly_stmt = stmt.is_readonly(); if should_track_statement_process(&stmt) { let catalog_name = query_ctx.current_catalog().to_string(); let schema_name = query_ctx.current_schema(); - let slow_query_timer = if is_readonly_stmt { - self.slow_query_options - .enable - .then(|| self.event_recorder.clone()) - .flatten() - .map(|event_recorder| { - SlowQueryTimer::new( - CatalogQueryStatement::Sql(stmt.clone()), - schema_name.clone(), - self.slow_query_options.threshold, - self.slow_query_options.sample_ratio, - self.slow_query_options.record_type, - event_recorder, - ) - }) - } else { - None - }; + let slow_query_timer = self.statement_slow_query_timer(&stmt, schema_name.clone()); let ticket = self.process_manager.register_query( catalog_name, @@ -552,6 +583,62 @@ fn attach_timeout(output: Output, mut timeout: Duration) -> Result { } impl Instance { + #[tracing::instrument(skip_all, name = "SqlQueryHandler::do_analyze_stream_query")] + async fn do_analyze_stream_query_inner( + &self, + query: &str, + query_ctx: QueryContextRef, + ) -> Result { + ensure!(!self.is_suspended(), error::SuspendedSnafu); + + let query_interceptor_opt = self.plugins.get::>(); + let query_interceptor = query_interceptor_opt.as_ref(); + let query = query_interceptor.pre_parsing(query, query_ctx.clone())?; + let mut stmts = parse_stmt(query.as_ref(), query_ctx.sql_dialect()) + .and_then(|stmts| query_interceptor.post_parsing(stmts, query_ctx.clone()))?; + + ensure!( + stmts.len() == 1, + InvalidSqlSnafu { + err_msg: "only single EXPLAIN ANALYZE VERBOSE statement is supported" + } + ); + let mut stmt = stmts.remove(0); + validate_analyze_stream_statement(&mut stmt)?; + query_ctx.set_explain_format(AnalyzeFormat::JSON.to_string()); + + let checker_ref = self.plugins.get::(); + checker_ref + .as_ref() + .check_permission(query_ctx.current_user(), PermissionReq::SqlStatement(&stmt)) + .context(PermissionSnafu)?; + check_permission(self.plugins.clone(), &stmt, &query_ctx)?; + let catalog_name = query_ctx.current_catalog().to_string(); + let schema_name = query_ctx.current_schema(); + let slow_query_timer = self.statement_slow_query_timer(&stmt, schema_name.clone()); + let ticket = self.process_manager.register_query( + catalog_name, + vec![schema_name], + stmt.to_string(), + query_ctx.conn_info().to_string(), + Some(query_ctx.process_id()), + slow_query_timer, + ); + let query_fut = + self.exec_statement_with_timeout(stmt, query_ctx.clone(), query_interceptor); + let output = CancellableFuture::new(query_fut, ticket.cancellation_handle.clone()) + .await + .map_err(|_| error::CancelledSnafu.build())??; + let Output { meta, data } = output; + let data = match data { + OutputData::Stream(stream) => OutputData::Stream(Box::pin( + CancellableStreamWrapper::new_cancel_on_drop(stream, ticket), + )), + other => other, + }; + query_interceptor.post_execute(Output { data, meta }, query_ctx) + } + #[tracing::instrument(skip_all, name = "SqlQueryHandler::do_query")] async fn do_query_inner(&self, query: &str, query_ctx: QueryContextRef) -> Vec> { if self.is_suspended() { @@ -801,6 +888,17 @@ impl SqlQueryHandler for Instance { .collect() } + async fn do_analyze_stream_query( + &self, + query: &str, + query_ctx: QueryContextRef, + ) -> server_error::Result { + self.do_analyze_stream_query_inner(query, query_ctx) + .await + .map_err(BoxedError::new) + .context(ExecuteQuerySnafu) + } + async fn do_exec_plan( &self, plan: LogicalPlan, @@ -1262,6 +1360,46 @@ mod tests { use crate::frontend::FrontendOptions; use crate::instance::builder::FrontendBuilder; + fn parse_test_sql(sql: &str) -> Vec { + parse_stmt(sql, &GreptimeDbDialect {}).unwrap() + } + + #[test] + fn test_validate_analyze_stream_statement_strictness() { + for sql in [ + "select 1", + "explain analyze select 1", + "explain analyze verbose format text select 1", + "explain analyze verbose format graphviz select 1", + ] { + let mut stmts = parse_test_sql(sql); + assert!( + validate_analyze_stream_statement(&mut stmts[0]).is_err(), + "{sql}" + ); + } + + for sql in [ + "explain analyze verbose select 1", + "explain analyze verbose format json select 1", + ] { + let mut stmts = parse_test_sql(sql); + assert!( + validate_analyze_stream_statement(&mut stmts[0]).is_ok(), + "{sql}" + ); + let Statement::Explain(explain) = &stmts[0] else { + unreachable!(); + }; + assert!(explain.format.is_none()); + } + + assert_eq!( + parse_test_sql("explain analyze verbose select 1; select 2").len(), + 2 + ); + } + #[derive(Debug, Snafu)] enum TestError { #[snafu(display("Failed to build test cache registry"))] diff --git a/src/frontend/src/stream_wrapper.rs b/src/frontend/src/stream_wrapper.rs index 7ac27a1be2..60e4d5e930 100644 --- a/src/frontend/src/stream_wrapper.rs +++ b/src/frontend/src/stream_wrapper.rs @@ -24,6 +24,8 @@ use futures::Stream; pub struct CancellableStreamWrapper { inner: SendableRecordBatchStream, ticket: Ticket, + cancel_on_drop: bool, + finished: bool, } impl Unpin for CancellableStreamWrapper {} @@ -33,6 +35,17 @@ impl CancellableStreamWrapper { Self { inner: stream, ticket, + cancel_on_drop: false, + finished: false, + } + } + + pub fn new_cancel_on_drop(stream: SendableRecordBatchStream, ticket: Ticket) -> Self { + Self { + inner: stream, + ticket, + cancel_on_drop: true, + finished: false, } } } @@ -47,6 +60,9 @@ impl Stream for CancellableStreamWrapper { } if let Poll::Ready(res) = Pin::new(&mut this.inner).poll_next(cx) { + if res.is_none() { + this.finished = true; + } return Poll::Ready(res); } @@ -60,6 +76,14 @@ impl Stream for CancellableStreamWrapper { } } +impl Drop for CancellableStreamWrapper { + fn drop(&mut self) { + if self.cancel_on_drop && !self.finished { + self.ticket.cancellation_handle.cancel(); + } + } +} + impl RecordBatchStream for CancellableStreamWrapper { fn schema(&self) -> SchemaRef { self.inner.schema() @@ -371,4 +395,50 @@ mod tests { assert!(stream_result.is_some()); assert!(stream_result.unwrap().is_err()); } + + #[tokio::test] + async fn test_cancel_on_drop_cancels_unfinished_stream() { + let batch = create_test_batch(); + let mock_stream = MockRecordBatchStream::new(vec![Ok(batch)]); + let process_manager = Arc::new(ProcessManager::new("".to_string(), None)); + let ticket = process_manager.register_query( + "catalog".to_string(), + vec![], + "query".to_string(), + "client".to_string(), + None, + None, + ); + let cancellation_handle = ticket.cancellation_handle.clone(); + + let cancellable_stream = + CancellableStreamWrapper::new_cancel_on_drop(Box::pin(mock_stream), ticket); + drop(cancellable_stream); + + assert!(cancellation_handle.is_cancelled()); + } + + #[tokio::test] + async fn test_cancel_on_drop_does_not_cancel_finished_stream() { + let batch = create_test_batch(); + let mock_stream = MockRecordBatchStream::new(vec![Ok(batch)]); + let process_manager = Arc::new(ProcessManager::new("".to_string(), None)); + let ticket = process_manager.register_query( + "catalog".to_string(), + vec![], + "query".to_string(), + "client".to_string(), + None, + None, + ); + let cancellation_handle = ticket.cancellation_handle.clone(); + + let mut cancellable_stream = + CancellableStreamWrapper::new_cancel_on_drop(Box::pin(mock_stream), ticket); + assert!(cancellable_stream.next().await.unwrap().is_ok()); + assert!(cancellable_stream.next().await.is_none()); + drop(cancellable_stream); + + assert!(!cancellation_handle.is_cancelled()); + } } diff --git a/src/query/src/analyze.rs b/src/query/src/analyze.rs index daa9d9485d..6f6864ef2b 100644 --- a/src/query/src/analyze.rs +++ b/src/query/src/analyze.rs @@ -37,6 +37,7 @@ use datafusion_common::{DataFusionError, internal_err}; use datafusion_physical_expr::{Distribution, EquivalenceProperties, Partitioning}; use futures::StreamExt; use serde::Serialize; +use serde_json::{Value, json}; use sqlparser::ast::AnalyzeFormat; use crate::dist_plan::MergeScanExec; @@ -84,6 +85,52 @@ impl DistAnalyzeExec { properties.boundedness, ) } + + pub fn input(&self) -> &Arc { + &self.input + } +} + +/// Returns verbose analyze metrics as JSON values using the same `JsonMetrics` shape +/// as `EXPLAIN ANALYZE VERBOSE FORMAT JSON`. +/// +/// This reads metrics directly from a running physical plan for the experimental +/// HTTP analyze stream. It is a best-effort diagnostic live snapshot, not a +/// transactionally consistent snapshot; metric values may change while this +/// function traverses the plan. +pub fn analyze_plan_metrics_to_json_value( + plan: &Arc, + verbose: bool, +) -> serde_json::Result { + let input = plan + .as_any() + .downcast_ref::() + .map(|exec| exec.input().clone()) + .unwrap_or_else(|| plan.clone()); + + let mut stages = Vec::new(); + let mut collector = MetricCollector::new(verbose); + accept(input.as_ref(), &mut collector).unwrap(); + stages.push(json!({ + "stage": 0, + "node": 0, + "plan": JsonMetrics::from_record_batch_metrics(collector.record_batch_metrics), + })); + + let _ = input.apply(|plan| { + if let Some(merge_scan) = plan.as_any().downcast_ref::() { + for (node, metric) in merge_scan.sub_stage_metrics().into_iter().enumerate() { + stages.push(json!({ + "stage": 1, + "node": node, + "plan": JsonMetrics::from_record_batch_metrics(metric), + })); + } + } + Ok(TreeNodeRecursion::Continue) + }); + + Ok(Value::Array(stages)) } impl DisplayAs for DistAnalyzeExec { diff --git a/src/query/src/lib.rs b/src/query/src/lib.rs index 2b159a91b1..68d8ff3a9e 100644 --- a/src/query/src/lib.rs +++ b/src/query/src/lib.rs @@ -45,6 +45,7 @@ pub(crate) mod test_util; #[cfg(test)] mod tests; +pub use crate::analyze::analyze_plan_metrics_to_json_value; pub use crate::datafusion::DfContextProviderAdapter; pub use crate::query_engine::{ QueryEngine, QueryEngineContext, QueryEngineFactory, QueryEngineRef, diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 5f5462f264..aa0633ec45 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -166,6 +166,8 @@ pub struct HttpOptions { pub cors_allowed_origins: Vec, pub enable_cors: bool, + + pub experimental_enable_explain_analyze_stream: bool, } impl Default for HttpOptions { @@ -178,6 +180,7 @@ impl Default for HttpOptions { cors_allowed_origins: Vec::new(), enable_cors: true, prom_validation_mode: PromValidationMode::Strict, + experimental_enable_explain_analyze_stream: false, } } } @@ -502,6 +505,7 @@ impl From for HttpResponse { #[derive(Clone)] pub struct ApiState { pub sql_handler: ServerSqlQueryHandlerRef, + pub experimental_enable_explain_analyze_stream: bool, } #[derive(Clone)] @@ -538,7 +542,12 @@ impl HttpServerBuilder { } pub fn with_sql_handler(self, sql_handler: ServerSqlQueryHandlerRef) -> Self { - let sql_router = HttpServer::route_sql(ApiState { sql_handler }); + let sql_router = HttpServer::route_sql(ApiState { + sql_handler, + experimental_enable_explain_analyze_stream: self + .options + .experimental_enable_explain_analyze_stream, + }); Self { router: self @@ -1097,7 +1106,7 @@ impl HttpServer { } fn route_sql(api_state: ApiState) -> Router { - Router::new() + let mut router = Router::new() .route("/sql", routing::get(handler::sql).post(handler::sql)) .route( "/sql/parse", @@ -1110,8 +1119,16 @@ impl HttpServer { .route( "/promql", routing::get(handler::promql).post(handler::promql), - ) - .with_state(api_state) + ); + + if api_state.experimental_enable_explain_analyze_stream { + router = router.route( + "/sql/analyze/stream", + routing::post(handler::sql_analyze_stream), + ); + } + + router.with_state(api_state) } fn route_logs(log_handler: LogQueryHandlerRef) -> Router { @@ -1338,7 +1355,7 @@ mod test { use axum::handler::Handler; use axum::http::StatusCode; use axum::routing::get; - use common_query::Output; + use common_query::{Output, OutputData}; use common_recordbatch::RecordBatches; use datafusion_expr::LogicalPlan; use datatypes::prelude::*; @@ -1367,6 +1384,11 @@ mod test { unimplemented!() } + async fn do_analyze_stream_query(&self, _: &str, _: QueryContextRef) -> Result { + let stream = common_recordbatch::RecordBatches::empty().as_stream(); + Ok(Output::new(OutputData::Stream(stream), Default::default())) + } + async fn do_promql_query(&self, _: &PromQuery, _: QueryContextRef) -> Vec> { unimplemented!() } @@ -1416,6 +1438,31 @@ mod test { ) } + #[tokio::test] + pub async fn test_analyze_stream_route_config_gate() { + let (tx, _rx) = mpsc::channel(100); + let app = make_test_app_custom(tx, HttpOptions::default()); + let client = TestClient::new(app).await; + let res = client + .post("/v1/sql/analyze/stream?sql=EXPLAIN%20ANALYZE%20VERBOSE%20SELECT%201") + .send() + .await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + let (tx, _rx) = mpsc::channel(100); + let options = HttpOptions { + experimental_enable_explain_analyze_stream: true, + ..Default::default() + }; + let app = make_test_app_custom(tx, options); + let client = TestClient::new(app).await; + let res = client + .post("/v1/sql/analyze/stream?sql=EXPLAIN%20ANALYZE%20VERBOSE%20SELECT%201") + .send() + .await; + assert_ne!(res.status(), StatusCode::NOT_FOUND); + } + #[tokio::test] pub async fn test_cors() { // cors is on by default diff --git a/src/servers/src/http/handler.rs b/src/servers/src/http/handler.rs index 4cfbb57691..a6421d4e55 100644 --- a/src/servers/src/http/handler.rs +++ b/src/servers/src/http/handler.rs @@ -14,9 +14,11 @@ use std::collections::HashMap; use std::sync::Arc; -use std::time::Instant; +use std::time::{Duration, Instant}; +use axum::extract::rejection::FormRejection; use axum::extract::{Json, Query, State}; +use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::{IntoResponse, Response}; use axum::{Extension, Form}; use common_catalog::parse_catalog_and_schema_from_db_string; @@ -24,8 +26,11 @@ use common_error::ext::ErrorExt; use common_error::status_code::StatusCode; use common_plugins::GREPTIME_EXEC_WRITE_COST; use common_query::{Output, OutputData}; -use common_recordbatch::util; +use common_recordbatch::{RecordBatch, SendableRecordBatchStream, util}; use common_telemetry::tracing; +use datafusion::physical_plan::ExecutionPlan; +use datatypes::schema::SchemaRef; +use futures::StreamExt; use query::parser::{DEFAULT_LOOKBACK_STRING, PromQuery}; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -72,6 +77,39 @@ pub struct SqlQuery { pub limit: Option, // For arrow output pub compression: Option, + pub snapshot_interval_ms: Option, +} + +const DEFAULT_ANALYZE_SNAPSHOT_INTERVAL_MS: u64 = 5000; +const MIN_ANALYZE_SNAPSHOT_INTERVAL_MS: u64 = 1000; +const MAX_ANALYZE_SNAPSHOT_INTERVAL_MS: u64 = 60000; + +#[derive(Serialize)] +struct AnalyzeStreamPayload { + seq: u64, + state: &'static str, + partial: bool, + elapsed_ms: u64, + #[serde(skip_serializing_if = "Option::is_none")] + metrics: Option, + #[serde(skip_serializing_if = "Option::is_none")] + output: Option, + #[serde(skip_serializing_if = "Option::is_none")] + reason: Option, + #[serde(skip_serializing_if = "Option::is_none")] + code: Option, +} + +struct AnalyzeStreamState { + stream: SendableRecordBatchStream, + schema: SchemaRef, + plan: Option>, + batches: Vec, + seq: u64, + start: Instant, + requested_interval_ms: u64, + current_interval_ms: u64, + done: bool, } /// Handler to execute sql @@ -155,6 +193,274 @@ pub async fn sql( resp.with_execution_time(start.elapsed().as_millis() as u64) } +/// Handler to stream partial `EXPLAIN ANALYZE VERBOSE` metrics as SSE. +/// +/// This experimental endpoint is POST-only SSE, so browser `EventSource` does +/// not apply. Each `metrics` event carries a complete snapshot (not a delta); +/// large snapshots are throttled but never truncated. `final`, `canceled`, and +/// `error` are terminal events. If the client disconnects it won't receive a +/// `canceled` event, but the production frontend stream is dropped and +/// best-effort cancels the underlying query. +#[axum_macros::debug_handler] +#[tracing::instrument( + skip_all, + fields(protocol = "http", request_type = "sql_analyze_stream") +)] +pub async fn sql_analyze_stream( + State(state): State, + Query(query_params): Query, + Extension(mut query_ctx): Extension, + form_params: std::result::Result, FormRejection>, +) -> Response { + let start = Instant::now(); + let form_params = match form_params { + Ok(Form(params)) => params, + Err(err) => { + if err.status() != axum::http::StatusCode::UNSUPPORTED_MEDIA_TYPE { + return ErrorResponse::from_error_message( + StatusCode::InvalidArguments, + err.body_text(), + ) + .with_execution_time(start.elapsed().as_millis() as u64) + .into_response(); + } + SqlQuery::default() + } + }; + let sql_handler = &state.sql_handler; + if let Some(db) = &query_params.db.or(form_params.db) { + let (catalog, schema) = parse_catalog_and_schema_from_db_string(db); + query_ctx.set_current_catalog(&catalog); + query_ctx.set_current_schema(&schema); + } + query_ctx.set_channel(Channel::HttpSql); + let query_ctx = Arc::new(query_ctx); + + let Some(sql) = query_params.sql.or(form_params.sql) else { + return ErrorResponse::from_error_message( + StatusCode::InvalidArguments, + "sql parameter is required.".to_string(), + ) + .with_execution_time(start.elapsed().as_millis() as u64) + .into_response(); + }; + if let Some((status, msg)) = validate_schema(sql_handler.clone(), query_ctx.clone()).await { + return ErrorResponse::from_error_message(status, msg) + .with_execution_time(start.elapsed().as_millis() as u64) + .into_response(); + } + + let interval_ms = query_params + .snapshot_interval_ms + .or(form_params.snapshot_interval_ms) + .unwrap_or(DEFAULT_ANALYZE_SNAPSHOT_INTERVAL_MS) + .clamp( + MIN_ANALYZE_SNAPSHOT_INTERVAL_MS, + MAX_ANALYZE_SNAPSHOT_INTERVAL_MS, + ); + + let output = match state + .sql_handler + .do_analyze_stream_query(&sql, query_ctx.clone()) + .await + { + Ok(output) => output, + Err(err) => { + return ErrorResponse::from_error(err) + .with_execution_time(start.elapsed().as_millis() as u64) + .into_response(); + } + }; + + let plan = output.meta.plan.clone(); + let OutputData::Stream(stream) = output.data else { + return ErrorResponse::from_error_message( + StatusCode::InvalidArguments, + "analyze stream query must return a stream".to_string(), + ) + .with_execution_time(start.elapsed().as_millis() as u64) + .into_response(); + }; + let schema = stream.schema(); + + let sse_stream = futures::stream::unfold( + AnalyzeStreamState { + stream, + schema, + plan, + batches: Vec::new(), + seq: 0, + start, + requested_interval_ms: interval_ms, + current_interval_ms: interval_ms, + done: false, + }, + |mut state| async move { + if state.done { + return None; + } + let tick = tokio::time::sleep(Duration::from_millis(state.current_interval_ms)); + tokio::pin!(tick); + loop { + tokio::select! { + item = state.stream.next() => { + match item { + Some(Ok(batch)) => state.batches.push(batch), + Some(Err(err)) => { + let status = err.status_code(); + let event_name = if status == StatusCode::Cancelled { "canceled" } else { "error" }; + let (payload, _) = make_analyze_payload(AnalyzePayloadArgs { + seq: state.seq, + state: event_name, + partial: false, + start: state.start, + plan: state.plan.as_ref(), + output: None, + reason: Some(err.output_msg()), + code: Some(status as u32), + }); + state.seq += 1; + state.done = true; + return Some((Ok::(Event::default().event(event_name).data(payload)), state)); + } + None => { + let batches = std::mem::take(&mut state.batches); + let output = HttpRecordsOutput::try_new(state.schema.clone(), batches) + .map(GreptimeQueryOutput::Records); + let (event_name, payload) = make_final_analyze_event( + output.map_err(|err| (err.output_msg(), err.status_code() as u32)), + state.seq, + state.start, + state.plan.as_ref(), + ); + state.seq += 1; + state.done = true; + return Some((Ok::(Event::default().event(event_name).data(payload)), state)); + } + } + } + _ = &mut tick => { + if state.plan.is_some() { + let (payload, payload_bytes) = make_analyze_payload(AnalyzePayloadArgs { + seq: state.seq, + state: "metrics", + partial: true, + start: state.start, + plan: state.plan.as_ref(), + output: None, + reason: None, + code: None, + }); + state.current_interval_ms = adaptive_interval_ms(payload_bytes, state.requested_interval_ms); + state.seq += 1; + return Some((Ok::(Event::default().event("metrics").data(payload)), state)); + } + tick.as_mut().reset(tokio::time::Instant::now() + Duration::from_millis(state.current_interval_ms)); + } + } + } + }, + ); + + Sse::new(sse_stream) + .keep_alive(KeepAlive::new().interval(Duration::from_secs(15))) + .into_response() +} + +fn adaptive_interval_ms(payload_bytes: usize, requested_ms: u64) -> u64 { + if payload_bytes >= 10 * 1024 * 1024 { + requested_ms.max(30_000) + } else if payload_bytes >= 1024 * 1024 { + requested_ms.max(10_000) + } else { + requested_ms + } +} + +fn make_final_analyze_event( + output: std::result::Result, + seq: u64, + start: Instant, + plan: Option<&Arc>, +) -> (&'static str, String) { + match output { + Ok(output) => ( + "final", + make_analyze_payload(AnalyzePayloadArgs { + seq, + state: "final", + partial: false, + start, + plan, + output: Some(output), + reason: None, + code: None, + }) + .0, + ), + Err((reason, code)) => ( + "error", + make_analyze_payload(AnalyzePayloadArgs { + seq, + state: "error", + partial: false, + start, + plan, + output: None, + reason: Some(reason), + code: Some(code), + }) + .0, + ), + } +} + +struct AnalyzePayloadArgs<'a> { + seq: u64, + state: &'static str, + partial: bool, + start: Instant, + plan: Option<&'a Arc>, + output: Option, + reason: Option, + code: Option, +} + +fn make_analyze_payload(args: AnalyzePayloadArgs<'_>) -> (String, usize) { + let AnalyzePayloadArgs { + seq, + state, + partial, + start, + plan, + output, + reason, + code, + } = args; + let metrics = plan.and_then(|plan| query::analyze_plan_metrics_to_json_value(plan, true).ok()); + let payload = AnalyzeStreamPayload { + seq, + state, + partial, + elapsed_ms: start.elapsed().as_millis() as u64, + metrics, + output, + reason, + code, + }; + let payload_string = serde_json::to_string(&payload).unwrap_or_else(|e| { + serde_json::json!({ + "seq": seq, + "state": "error", + "partial": false, + "reason": format!("Failed to serialize SSE payload: {e}"), + }) + .to_string() + }); + let payload_bytes = payload_string.len(); + (payload_string, payload_bytes) +} + /// Handler to parse sql #[axum_macros::debug_handler] #[tracing::instrument(skip_all, fields(protocol = "http", request_type = "sql"))] @@ -500,3 +806,26 @@ pub async fn index() -> axum::response::Html { "#, )) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_final_analyze_event_uses_error_event_for_conversion_error() { + let (event_name, payload) = make_final_analyze_event( + Err(( + "conversion failed".to_string(), + StatusCode::InvalidArguments as u32, + )), + 7, + Instant::now(), + None, + ); + + assert_eq!(event_name, "error"); + let value: Value = serde_json::from_str(&payload).unwrap(); + assert_eq!(value["state"], "error"); + assert_eq!(value["reason"], "conversion failed"); + } +} diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index 88a539ea21..b0eb33f8fd 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -948,6 +948,10 @@ mod tests { unimplemented!() } + async fn do_analyze_stream_query(&self, _: &str, _: QueryContextRef) -> Result { + unimplemented!() + } + async fn do_promql_query(&self, _: &PromQuery, _: QueryContextRef) -> Vec> { unimplemented!() } diff --git a/src/servers/src/query_handler/sql.rs b/src/servers/src/query_handler/sql.rs index eb72190309..5ebb66d83b 100644 --- a/src/servers/src/query_handler/sql.rs +++ b/src/servers/src/query_handler/sql.rs @@ -30,6 +30,21 @@ pub type ServerSqlQueryHandlerRef = Arc; pub trait SqlQueryHandler { async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec>; + /// Executes the experimental HTTP analyze-stream query path. + /// + /// Implementations must validate that `query` is exactly one explicit + /// `EXPLAIN ANALYZE VERBOSE` statement and must return a streaming output. + /// `OutputMeta.plan` is used by the HTTP layer to emit metrics snapshots; + /// when it is absent, partial metrics may not be available. The returned + /// stream should support cancel-on-drop semantics (as the production + /// frontend implementation does) so client disconnect can best-effort cancel + /// the underlying query. + async fn do_analyze_stream_query( + &self, + query: &str, + query_ctx: QueryContextRef, + ) -> Result; + async fn do_exec_plan( &self, plan: LogicalPlan, diff --git a/src/servers/tests/http/http_handler_test.rs b/src/servers/tests/http/http_handler_test.rs index 12465ffbc6..68cfc3a295 100644 --- a/src/servers/tests/http/http_handler_test.rs +++ b/src/servers/tests/http/http_handler_test.rs @@ -13,31 +13,157 @@ // limitations under the License. use std::collections::HashMap; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::Duration; +use async_trait::async_trait; use axum::Form; use axum::extract::{Json, Query, State}; -use axum::http::header; +use axum::http::{StatusCode, header}; use axum::response::{IntoResponse, Response}; use bytes::Bytes; +use common_query::{Output, OutputData}; +use common_recordbatch::adapter::RecordBatchMetrics; +use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, SendableRecordBatchStream}; +use datafusion_expr::LogicalPlan; +use datatypes::schema::SchemaRef; +use futures::Stream; use headers::HeaderValue; use mime_guess::mime; +use query::parser::PromQuery; +use query::query_engine::DescribeResult; +use serde_json::Value; +use servers::error::Result; use servers::http::GreptimeQueryOutput::Records; +use servers::http::test_helpers::TestClient; use servers::http::{ - ApiState, GreptimeOptionsConfigState, GreptimeQueryOutput, HttpResponse, - handler as http_handler, + ApiState, GreptimeOptionsConfigState, GreptimeQueryOutput, HttpOptions, HttpResponse, + HttpServerBuilder, handler as http_handler, }; use servers::metrics_handler::MetricsHandler; -use session::context::QueryContext; +use servers::query_handler::sql::{ServerSqlQueryHandlerRef, SqlQueryHandler}; +use session::context::{QueryContext, QueryContextRef}; +use sql::statements::statement::Statement; use table::test_util::MemTable; use crate::create_testing_sql_query_handler; +struct DelayedRecordBatchStream { + inner: SendableRecordBatchStream, + schema: SchemaRef, + delay: Pin>, + delayed: bool, +} + +impl DelayedRecordBatchStream { + fn new(inner: SendableRecordBatchStream, delay: Duration) -> Self { + let schema = inner.schema(); + Self { + inner, + schema, + delay: Box::pin(tokio::time::sleep(delay)), + delayed: false, + } + } +} + +impl Stream for DelayedRecordBatchStream { + type Item = common_recordbatch::error::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if !self.delayed { + match self.delay.as_mut().poll(cx) { + Poll::Ready(()) => self.delayed = true, + Poll::Pending => return Poll::Pending, + } + } + + Pin::new(&mut self.inner).poll_next(cx) + } +} + +impl RecordBatchStream for DelayedRecordBatchStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn output_ordering(&self) -> Option<&[OrderOption]> { + self.inner.output_ordering() + } + + fn metrics(&self) -> Option { + self.inner.metrics() + } +} + +struct SlowAnalyzeStreamHandler { + inner: ServerSqlQueryHandlerRef, + delay: Duration, +} + +#[async_trait] +impl SqlQueryHandler for SlowAnalyzeStreamHandler { + async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec> { + self.inner.do_query(query, query_ctx).await + } + + async fn do_analyze_stream_query( + &self, + query: &str, + query_ctx: QueryContextRef, + ) -> Result { + let output = self.inner.do_analyze_stream_query(query, query_ctx).await?; + let Output { data, meta } = output; + let data = match data { + OutputData::Stream(stream) => { + OutputData::Stream(Box::pin(DelayedRecordBatchStream::new(stream, self.delay))) + } + data => data, + }; + Ok(Output { data, meta }) + } + + async fn do_exec_plan( + &self, + plan: LogicalPlan, + stmt: Option, + query_ctx: QueryContextRef, + ) -> Result { + self.inner.do_exec_plan(plan, stmt, query_ctx).await + } + + async fn do_promql_query( + &self, + query: &PromQuery, + query_ctx: QueryContextRef, + ) -> Vec> { + self.inner.do_promql_query(query, query_ctx).await + } + + async fn do_describe( + &self, + stmt: Statement, + query_ctx: QueryContextRef, + ) -> Result> { + self.inner.do_describe(stmt, query_ctx).await + } + + async fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result { + self.inner.is_valid_schema(catalog, schema).await + } +} + #[tokio::test] async fn test_sql_not_provided() { let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table()); let ctx = QueryContext::with_db_name(None); ctx.set_current_user(auth::userinfo_by_name(None)); - let api_state = ApiState { sql_handler }; + let api_state = ApiState { + sql_handler, + experimental_enable_explain_analyze_stream: false, + }; for format in ["greptimedb_v1", "influxdb_v1", "csv", "table"] { let query = http_handler::SqlQuery { @@ -68,7 +194,10 @@ async fn test_sql_output_rows() { let ctx = QueryContext::with_db_name(None); ctx.set_current_user(auth::userinfo_by_name(None)); - let api_state = ApiState { sql_handler }; + let api_state = ApiState { + sql_handler, + experimental_enable_explain_analyze_stream: false, + }; let query_sql = "select sum(uint32s) from numbers limit 20"; for format in ["greptimedb_v1", "influxdb_v1", "csv", "table"] { @@ -173,7 +302,10 @@ async fn test_dashboard_sql_limit() { let sql_handler = create_testing_sql_query_handler(MemTable::specified_numbers_table(2000)); let ctx = QueryContext::with_db_name(None); ctx.set_current_user(auth::userinfo_by_name(None)); - let api_state = ApiState { sql_handler }; + let api_state = ApiState { + sql_handler, + experimental_enable_explain_analyze_stream: false, + }; for format in ["greptimedb_v1", "csv", "table"] { let query = create_query(format, "select * from numbers", Some(1000)); let sql_response = http_handler::sql( @@ -216,7 +348,10 @@ async fn test_sql_form() { let ctx = QueryContext::with_db_name(None); ctx.set_current_user(auth::userinfo_by_name(None)); - let api_state = ApiState { sql_handler }; + let api_state = ApiState { + sql_handler, + experimental_enable_explain_analyze_stream: false, + }; for format in ["greptimedb_v1", "influxdb_v1", "csv", "table", "null"] { let form = create_form(format); @@ -332,6 +467,183 @@ async fn test_sql_form() { } } +#[tokio::test] +async fn test_analyze_stream_sse_e2e() { + common_telemetry::init_default_ut_logging(); + + let client = analyze_stream_test_client(create_testing_sql_query_handler( + MemTable::default_numbers_table(), + )) + .await; + + let response = client + .post("/v1/sql/analyze/stream?snapshot_interval_ms=1000") + .header(header::ACCEPT, "text/event-stream") + .form(&http_handler::SqlQuery { + sql: Some("EXPLAIN ANALYZE VERBOSE SELECT sum(uint32s) FROM numbers".to_string()), + ..Default::default() + }) + .send() + .await; + + assert_eq!(response.status(), StatusCode::OK); + assert!( + response + .headers() + .get(header::CONTENT_TYPE) + .and_then(|value| value.to_str().ok()) + .is_some_and(|value| value.starts_with("text/event-stream")) + ); + + let body = response.text().await; + let final_payload = sse_event_payload(&body, "final").expect(&body); + let final_payload: Value = serde_json::from_str(&final_payload).unwrap(); + assert_eq!(final_payload["state"], "final"); + assert_eq!(final_payload["partial"], false); + assert!( + final_payload["metrics"] + .as_array() + .is_some_and(|v| !v.is_empty()) + ); + assert!( + final_payload["output"]["records"]["rows"] + .as_array() + .is_some_and(|v| !v.is_empty()) + ); +} + +#[tokio::test] +async fn test_analyze_stream_route_rejects_invalid_sql() { + common_telemetry::init_default_ut_logging(); + + let client = analyze_stream_test_client(create_testing_sql_query_handler( + MemTable::default_numbers_table(), + )) + .await; + + for sql in [ + "SELECT 1", + "EXPLAIN ANALYZE SELECT 1", + "EXPLAIN ANALYZE VERBOSE FORMAT TEXT SELECT 1", + "EXPLAIN ANALYZE VERBOSE SELECT 1; SELECT 2", + ] { + let response = client + .post("/v1/sql/analyze/stream") + .form(&http_handler::SqlQuery { + sql: Some(sql.to_string()), + ..Default::default() + }) + .send() + .await; + + assert_ne!(response.status(), StatusCode::OK, "{sql}"); + let body: Value = response.json().await; + assert!(body.get("error").is_some(), "{sql}: {body}"); + } +} + +#[tokio::test] +async fn test_analyze_stream_route_accepts_explicit_format_json() { + common_telemetry::init_default_ut_logging(); + + let client = analyze_stream_test_client(create_testing_sql_query_handler( + MemTable::default_numbers_table(), + )) + .await; + + let response = client + .post("/v1/sql/analyze/stream") + .header(header::ACCEPT, "text/event-stream") + .form(&http_handler::SqlQuery { + sql: Some( + "EXPLAIN ANALYZE VERBOSE FORMAT JSON SELECT sum(uint32s) FROM numbers".to_string(), + ), + ..Default::default() + }) + .send() + .await; + + assert_eq!(response.status(), StatusCode::OK); + let body = response.text().await; + assert!(sse_event_payload(&body, "final").is_some(), "{body}"); +} + +#[tokio::test] +async fn test_analyze_stream_emits_metrics_before_final_when_stream_is_pending() { + common_telemetry::init_default_ut_logging(); + + let inner = create_testing_sql_query_handler(MemTable::default_numbers_table()); + let sql_handler = Arc::new(SlowAnalyzeStreamHandler { + inner, + delay: Duration::from_millis(1500), + }); + let options = HttpOptions { + experimental_enable_explain_analyze_stream: true, + ..Default::default() + }; + let server = HttpServerBuilder::new(options) + .with_sql_handler(sql_handler) + .build(); + let app = server.build(server.make_app()).unwrap(); + let client = TestClient::new(app).await; + + let response = client + .post("/v1/sql/analyze/stream?snapshot_interval_ms=1000") + .header(header::ACCEPT, "text/event-stream") + .form(&http_handler::SqlQuery { + sql: Some("EXPLAIN ANALYZE VERBOSE SELECT sum(uint32s) FROM numbers".to_string()), + ..Default::default() + }) + .send() + .await; + + assert_eq!(response.status(), StatusCode::OK); + let body = response.text().await; + let metrics_pos = body.find("event: metrics").expect(&body); + let final_pos = body.find("event: final").expect(&body); + assert!( + metrics_pos < final_pos, + "metrics event should be emitted before final event: {body}" + ); + + let metrics_payload = sse_event_payload(&body, "metrics").expect(&body); + let metrics_payload: Value = serde_json::from_str(&metrics_payload).unwrap(); + assert_eq!(metrics_payload["state"], "metrics"); + assert_eq!(metrics_payload["partial"], true); + assert!( + metrics_payload["metrics"] + .as_array() + .is_some_and(|v| !v.is_empty()) + ); +} + +fn sse_event_payload(body: &str, event_name: &str) -> Option { + body.split("\n\n").find_map(|event| { + let mut found = false; + let mut data = Vec::new(); + for line in event.lines() { + if line.strip_prefix("event: ") == Some(event_name) { + found = true; + } else if let Some(value) = line.strip_prefix("data: ") { + data.push(value); + } + } + found.then(|| data.join("\n")) + }) +} + +async fn analyze_stream_test_client(sql_handler: ServerSqlQueryHandlerRef) -> TestClient { + let options = HttpOptions { + experimental_enable_explain_analyze_stream: true, + ..Default::default() + }; + let server = HttpServerBuilder::new(options) + .with_sql_handler(sql_handler) + .build(); + let app = server.build(server.make_app()).unwrap(); + TestClient::new(app).await +} + lazy_static::lazy_static! { static ref TEST_METRIC: prometheus::Counter = prometheus::register_counter!("test_metrics", "test metrics").unwrap(); diff --git a/src/servers/tests/http/influxdb_test.rs b/src/servers/tests/http/influxdb_test.rs index 59ec9c138c..ec5cb9e7c9 100644 --- a/src/servers/tests/http/influxdb_test.rs +++ b/src/servers/tests/http/influxdb_test.rs @@ -64,6 +64,10 @@ impl SqlQueryHandler for DummyInstance { unimplemented!() } + async fn do_analyze_stream_query(&self, _: &str, _: QueryContextRef) -> Result { + unimplemented!() + } + async fn do_exec_plan( &self, _plan: LogicalPlan, diff --git a/src/servers/tests/http/opentsdb_test.rs b/src/servers/tests/http/opentsdb_test.rs index 49acd8625b..5520b73b22 100644 --- a/src/servers/tests/http/opentsdb_test.rs +++ b/src/servers/tests/http/opentsdb_test.rs @@ -56,6 +56,10 @@ impl SqlQueryHandler for DummyInstance { unimplemented!() } + async fn do_analyze_stream_query(&self, _: &str, _: QueryContextRef) -> Result { + unimplemented!() + } + async fn do_exec_plan( &self, _plan: LogicalPlan, diff --git a/src/servers/tests/http/prom_store_test.rs b/src/servers/tests/http/prom_store_test.rs index cd0d3e7822..0de22f34e6 100644 --- a/src/servers/tests/http/prom_store_test.rs +++ b/src/servers/tests/http/prom_store_test.rs @@ -124,6 +124,10 @@ impl SqlQueryHandler for DummyInstance { unimplemented!() } + async fn do_analyze_stream_query(&self, _: &str, _: QueryContextRef) -> Result { + unimplemented!() + } + async fn do_exec_plan( &self, _plan: LogicalPlan, diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index 767a5aa26c..080703c687 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -80,6 +80,61 @@ impl SqlQueryHandler for DummyInstance { results } + async fn do_analyze_stream_query( + &self, + query: &str, + query_ctx: QueryContextRef, + ) -> Result { + let mut statements = ParserContext::create_with_dialect( + query, + query_ctx.sql_dialect(), + ParseOptions::default(), + ) + .map_err(BoxedError::new) + .context(ExecuteQuerySnafu)?; + ensure!( + statements.len() == 1, + NotSupportedSnafu { + feat: "execute multiple statements in analyze stream query" + } + ); + + let mut statement = statements.remove(0); + let Statement::Explain(explain) = &mut statement else { + return NotSupportedSnafu { + feat: "non EXPLAIN ANALYZE VERBOSE analyze stream query", + } + .fail(); + }; + ensure!( + explain.analyze && explain.verbose, + NotSupportedSnafu { + feat: "non EXPLAIN ANALYZE VERBOSE analyze stream query" + } + ); + match &explain.format { + None => {} + Some(format) if format.to_string().eq_ignore_ascii_case("json") => { + explain.format = None; + } + Some(_) => { + return NotSupportedSnafu { + feat: "non-JSON analyze stream format", + } + .fail(); + } + } + query_ctx.set_explain_format("JSON".to_string()); + + self.query_engine + .planner() + .plan(&QueryStatement::Sql(statement), query_ctx.clone()) + .and_then(|plan| self.query_engine.execute(plan, query_ctx.clone())) + .await + .map_err(BoxedError::new) + .context(ExecuteQuerySnafu) + } + async fn do_exec_plan( &self, plan: LogicalPlan, diff --git a/tests-integration/tests/http.rs b/tests-integration/tests/http.rs index b96c717f9d..e5776292dd 100644 --- a/tests-integration/tests/http.rs +++ b/tests-integration/tests/http.rs @@ -1807,6 +1807,7 @@ body_limit = "64MiB" prom_validation_mode = "strict" cors_allowed_origins = [] enable_cors = true +experimental_enable_explain_analyze_stream = false [grpc] bind_addr = "127.0.0.1:4001"