diff --git a/benchmarks/src/bin/nyc-taxi.rs b/benchmarks/src/bin/nyc-taxi.rs index 53e44c9688..1e60db69fa 100644 --- a/benchmarks/src/bin/nyc-taxi.rs +++ b/benchmarks/src/bin/nyc-taxi.rs @@ -29,7 +29,7 @@ use client::api::v1::column::Values; use client::api::v1::{ Column, ColumnDataType, ColumnDef, CreateTableExpr, InsertRequest, InsertRequests, SemanticType, }; -use client::{Client, Database, Output, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; +use client::{Client, Database, OutputData, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use futures_util::TryStreamExt; use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; @@ -502,9 +502,9 @@ async fn do_query(num_iter: usize, db: &Database, table_name: &str) { for i in 0..num_iter { let now = Instant::now(); let res = db.sql(&query).await.unwrap(); - match res { - Output::AffectedRows(_) | Output::RecordBatches(_) => (), - Output::Stream(stream, _) => { + match res.data { + OutputData::AffectedRows(_) | OutputData::RecordBatches(_) => (), + OutputData::Stream(stream) => { stream.try_collect::>().await.unwrap(); } } diff --git a/src/client/src/database.rs b/src/client/src/database.rs index fa665a4967..fe02032a7b 100644 --- a/src/client/src/database.rs +++ b/src/client/src/database.rs @@ -307,7 +307,7 @@ impl Database { reason: "Expect 'AffectedRows' Flight messages to be the one and the only!" } ); - Ok(Output::AffectedRows(rows)) + Ok(Output::new_with_affected_rows(rows)) } FlightMessage::Recordbatch(_) | FlightMessage::Metrics(_) => { IllegalFlightMessagesSnafu { @@ -340,7 +340,7 @@ impl Database { output_ordering: None, metrics: Default::default(), }; - Ok(Output::new_stream(Box::pin(record_batch_stream))) + Ok(Output::new_with_stream(Box::pin(record_batch_stream))) } } } diff --git a/src/client/src/lib.rs b/src/client/src/lib.rs index 7f8330f689..1a854c5daa 100644 --- a/src/client/src/lib.rs +++ b/src/client/src/lib.rs @@ -26,7 +26,7 @@ use api::v1::greptime_response::Response; use api::v1::{AffectedRows, GreptimeResponse}; pub use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_error::status_code::StatusCode; -pub use common_query::Output; +pub use common_query::{Output, OutputData, OutputMeta}; pub use common_recordbatch::{RecordBatches, SendableRecordBatchStream}; use snafu::OptionExt; diff --git a/src/cmd/src/cli/export.rs b/src/cmd/src/cli/export.rs index cf43a9dec1..1e0e261a9e 100644 --- a/src/cmd/src/cli/export.rs +++ b/src/cmd/src/cli/export.rs @@ -19,8 +19,7 @@ use async_trait::async_trait; use clap::{Parser, ValueEnum}; use client::api::v1::auth_header::AuthScheme; use client::api::v1::Basic; -use client::{Client, Database, DEFAULT_SCHEMA_NAME}; -use common_query::Output; +use client::{Client, Database, OutputData, DEFAULT_SCHEMA_NAME}; use common_recordbatch::util::collect; use common_telemetry::{debug, error, info, warn}; use datatypes::scalars::ScalarVector; @@ -142,7 +141,7 @@ impl Export { .with_context(|_| RequestDatabaseSnafu { sql: "show databases".to_string(), })?; - let Output::Stream(stream, _) = result else { + let OutputData::Stream(stream) = result.data else { NotDataFromOutputSnafu.fail()? }; let record_batch = collect(stream) @@ -183,7 +182,7 @@ impl Export { .sql(&sql) .await .with_context(|_| RequestDatabaseSnafu { sql })?; - let Output::Stream(stream, _) = result else { + let OutputData::Stream(stream) = result.data else { NotDataFromOutputSnafu.fail()? 
        };
        let Some(record_batch) = collect(stream)
@@ -235,7 +234,7 @@ impl Export {
            .sql(&sql)
            .await
            .with_context(|_| RequestDatabaseSnafu { sql })?;
-        let Output::Stream(stream, _) = result else {
+        let OutputData::Stream(stream) = result.data else {
            NotDataFromOutputSnafu.fail()?
        };
        let record_batch = collect(stream)
diff --git a/src/cmd/src/cli/repl.rs b/src/cmd/src/cli/repl.rs
index a6c5811224..63f04ee5ed 100644
--- a/src/cmd/src/cli/repl.rs
+++ b/src/cmd/src/cli/repl.rs
@@ -19,7 +19,7 @@ use std::time::Instant;
 use catalog::kvbackend::{
     CachedMetaKvBackend, CachedMetaKvBackendBuilder, KvBackendCatalogManager,
 };
-use client::{Client, Database, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
+use client::{Client, Database, OutputData, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
 use common_base::Plugins;
 use common_error::ext::ErrorExt;
 use common_query::Output;
@@ -184,15 +184,15 @@ impl Repl {
         }
         .context(RequestDatabaseSnafu { sql: &sql })?;

-        let either = match output {
-            Output::Stream(s, _) => {
+        let either = match output.data {
+            OutputData::Stream(s) => {
                 let x = RecordBatches::try_collect(s)
                     .await
                     .context(CollectRecordBatchesSnafu)?;
                 Either::Left(x)
             }
-            Output::RecordBatches(x) => Either::Left(x),
-            Output::AffectedRows(rows) => Either::Right(rows),
+            OutputData::RecordBatches(x) => Either::Left(x),
+            OutputData::AffectedRows(rows) => Either::Right(rows),
         };

         let end = Instant::now();
diff --git a/src/common/query/src/lib.rs b/src/common/query/src/lib.rs
index 4cb7282e53..1f00bfa22c 100644
--- a/src/common/query/src/lib.rs
+++ b/src/common/query/src/lib.rs
@@ -30,38 +30,87 @@ pub mod prelude;
 mod signature;

 use sqlparser_derive::{Visit, VisitMut};

-// sql output
-pub enum Output {
+/// Query output: the result data (previously the `Output` enum) plus execution metadata.
+#[derive(Debug)]
+pub struct Output {
+    pub data: OutputData,
+    pub meta: OutputMeta,
+}
+
+/// The result data carried back to the response/client/user interface
+/// (this was the original `Output` enum).
+pub enum OutputData {
     AffectedRows(usize),
     RecordBatches(RecordBatches),
-    Stream(SendableRecordBatchStream, Option<Arc<dyn PhysicalPlan>>),
+    Stream(SendableRecordBatchStream),
+}
+
+/// OutputMeta stores metadata produced during query execution.
+#[derive(Debug, Default)]
+pub struct OutputMeta {
+    /// May exist for query output; execution metrics can be retrieved from this plan.
+    pub plan: Option<Arc<dyn PhysicalPlan>>,
+    pub cost: usize,
 }

 impl Output {
-    // helper function to build original `Output::Stream`
-    pub fn new_stream(stream: SendableRecordBatchStream) -> Self {
-        Output::Stream(stream, None)
+    pub fn new_with_affected_rows(affected_rows: usize) -> Self {
+        Self {
+            data: OutputData::AffectedRows(affected_rows),
+            meta: Default::default(),
+        }
+    }
+
+    pub fn new_with_record_batches(recordbatches: RecordBatches) -> Self {
+        Self {
+            data: OutputData::RecordBatches(recordbatches),
+            meta: Default::default(),
+        }
+    }
+
+    pub fn new_with_stream(stream: SendableRecordBatchStream) -> Self {
+        Self {
+            data: OutputData::Stream(stream),
+            meta: Default::default(),
+        }
+    }
+
+    pub fn new(data: OutputData, meta: OutputMeta) -> Self {
+        Self { data, meta }
     }
 }

-impl Debug for Output {
+impl Debug for OutputData {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         match self {
-            Output::AffectedRows(rows) => write!(f, "Output::AffectedRows({rows})"),
-            Output::RecordBatches(recordbatches) => {
-                write!(f, "Output::RecordBatches({recordbatches:?})")
+            OutputData::AffectedRows(rows) => write!(f, "OutputData::AffectedRows({rows})"),
+            OutputData::RecordBatches(recordbatches) => {
+                write!(f, "OutputData::RecordBatches({recordbatches:?})")
             }
-            Output::Stream(_, df) => {
-                if df.is_some() {
-                    write!(f, "Output::Stream(<stream>, Some(<plan>))")
-                } else {
-                    write!(f, "Output::Stream(<stream>)")
-                }
+            OutputData::Stream(_) => {
+                write!(f, "OutputData::Stream(<stream>)")
             }
         }
     }
 }

+impl OutputMeta {
+    pub fn new(plan: Option<Arc<dyn PhysicalPlan>>, cost: usize) -> Self {
+        Self { plan, cost }
+    }
+
+    pub fn new_with_plan(plan: Arc<dyn PhysicalPlan>) -> Self {
+        Self {
+            plan: Some(plan),
+            cost: 0,
+        }
+    }
+
+    pub fn new_with_cost(cost: usize) -> Self {
+        Self { plan: None, cost }
+    }
+}
+
 pub use datafusion::physical_plan::ExecutionPlan as DfPhysicalPlan;

 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Visit, VisitMut)]
diff --git a/src/common/test-util/src/recordbatch.rs b/src/common/test-util/src/recordbatch.rs
index 6a10df0633..64a6262a08 100644
--- a/src/common/test-util/src/recordbatch.rs
+++ b/src/common/test-util/src/recordbatch.rs
@@ -13,7 +13,7 @@
 // limitations under the License.
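Editor's note: the `common_query` hunk above is the heart of this change. As a quick orientation, the sketch below shows how call sites are expected to use the new `data`/`meta` split and the constructors introduced here. It is illustrative only, not part of the patch; the helper function names (`write_result`, `query_result`, `affected_rows`) are made up for the example, while the `Output`, `OutputData`, `OutputMeta`, `SendableRecordBatchStream`, and `PhysicalPlan` paths come from the diff itself.

```rust
// Sketch of call-site usage under the new Output shape (illustrative only).
use std::sync::Arc;

use common_query::physical_plan::PhysicalPlan;
use common_query::{Output, OutputData, OutputMeta};
use common_recordbatch::SendableRecordBatchStream;

// A write path reports only affected rows; metadata stays at its default.
fn write_result(affected: usize) -> Output {
    Output::new_with_affected_rows(affected)
}

// A query path can attach the physical plan so callers may later read
// execution metrics from `output.meta.plan` (as the HTTP handlers do).
fn query_result(
    stream: SendableRecordBatchStream,
    plan: Arc<dyn PhysicalPlan>,
) -> Output {
    Output::new(OutputData::Stream(stream), OutputMeta::new_with_plan(plan))
}

// Consumers now match on `output.data` instead of matching `Output` directly.
fn affected_rows(output: Output) -> Option<usize> {
    match output.data {
        OutputData::AffectedRows(n) => Some(n),
        OutputData::RecordBatches(_) | OutputData::Stream(_) => None,
    }
}
```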
use client::Database; -use common_query::Output; +use common_query::OutputData; use common_recordbatch::util; pub enum ExpectedOutput<'a> { @@ -23,22 +23,24 @@ pub enum ExpectedOutput<'a> { pub async fn execute_and_check_output(db: &Database, sql: &str, expected: ExpectedOutput<'_>) { let output = db.sql(sql).await.unwrap(); + let output = output.data; + match (&output, expected) { - (Output::AffectedRows(x), ExpectedOutput::AffectedRows(y)) => { + (OutputData::AffectedRows(x), ExpectedOutput::AffectedRows(y)) => { assert_eq!(*x, y, "actual: \n{}", x) } - (Output::RecordBatches(_), ExpectedOutput::QueryResult(x)) - | (Output::Stream(_, _), ExpectedOutput::QueryResult(x)) => { + (OutputData::RecordBatches(_), ExpectedOutput::QueryResult(x)) + | (OutputData::Stream(_), ExpectedOutput::QueryResult(x)) => { check_output_stream(output, x).await } _ => panic!(), } } -pub async fn check_output_stream(output: Output, expected: &str) { +pub async fn check_output_stream(output: OutputData, expected: &str) { let recordbatches = match output { - Output::Stream(stream, _) => util::collect_batches(stream).await.unwrap(), - Output::RecordBatches(recordbatches) => recordbatches, + OutputData::Stream(stream) => util::collect_batches(stream).await.unwrap(), + OutputData::RecordBatches(recordbatches) => recordbatches, _ => unreachable!(), }; let pretty_print = recordbatches.pretty_print().unwrap(); diff --git a/src/datanode/src/region_server.rs b/src/datanode/src/region_server.rs index 773833408e..d9b74e02aa 100644 --- a/src/datanode/src/region_server.rs +++ b/src/datanode/src/region_server.rs @@ -27,7 +27,7 @@ use common_error::ext::BoxedError; use common_error::status_code::StatusCode; use common_query::logical_plan::Expr; use common_query::physical_plan::DfPhysicalPlanAdapter; -use common_query::{DfPhysicalPlan, Output}; +use common_query::{DfPhysicalPlan, OutputData}; use common_recordbatch::SendableRecordBatchStream; use common_runtime::Runtime; use common_telemetry::tracing::{self, info_span}; @@ -651,11 +651,11 @@ impl RegionServerInner { .await .context(ExecuteLogicalPlanSnafu)?; - match result { - Output::AffectedRows(_) | Output::RecordBatches(_) => { + match result.data { + OutputData::AffectedRows(_) | OutputData::RecordBatches(_) => { UnsupportedOutputSnafu { expected: "stream" }.fail() } - Output::Stream(stream, _) => Ok(stream), + OutputData::Stream(stream) => Ok(stream), } } diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 115613ef01..fb0d5f9913 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -28,6 +28,7 @@ use api::v1::meta::Role; use async_trait::async_trait; use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq}; use catalog::CatalogManagerRef; +use client::OutputData; use common_base::Plugins; use common_config::KvBackendConfig; use common_error::ext::BoxedError; @@ -401,13 +402,13 @@ impl SqlQueryHandler for Instance { /// Attaches a timer to the output and observes it once the output is exhausted. 
pub fn attach_timer(output: Output, timer: HistogramTimer) -> Output { - match output { - Output::AffectedRows(_) | Output::RecordBatches(_) => output, - Output::Stream(stream, plan) => { + match output.data { + OutputData::AffectedRows(_) | OutputData::RecordBatches(_) => output, + OutputData::Stream(stream) => { let stream = OnDone::new(stream, move || { timer.observe_duration(); }); - Output::Stream(Box::pin(stream), plan) + Output::new(OutputData::Stream(Box::pin(stream)), output.meta) } } } diff --git a/src/frontend/src/instance/grpc.rs b/src/frontend/src/instance/grpc.rs index 9f2a2de2a2..5dd20808f0 100644 --- a/src/frontend/src/instance/grpc.rs +++ b/src/frontend/src/instance/grpc.rs @@ -113,7 +113,7 @@ impl GrpcQueryHandler for Instance { .statement_executor .create_table_inner(&mut expr, None, &ctx) .await?; - Output::AffectedRows(0) + Output::new_with_affected_rows(0) } DdlExpr::Alter(expr) => self.statement_executor.alter_table_inner(expr).await?, DdlExpr::CreateDatabase(expr) => { diff --git a/src/frontend/src/instance/opentsdb.rs b/src/frontend/src/instance/opentsdb.rs index 4269e92339..946c3b9ff7 100644 --- a/src/frontend/src/instance/opentsdb.rs +++ b/src/frontend/src/instance/opentsdb.rs @@ -47,8 +47,8 @@ impl OpentsdbProtocolHandler for Instance { .map_err(BoxedError::new) .context(servers::error::ExecuteGrpcQuerySnafu)?; - Ok(match output { - common_query::Output::AffectedRows(rows) => rows, + Ok(match output.data { + common_query::OutputData::AffectedRows(rows) => rows, _ => unreachable!(), }) } diff --git a/src/frontend/src/instance/prom_store.rs b/src/frontend/src/instance/prom_store.rs index 22402bebff..5382cf9682 100644 --- a/src/frontend/src/instance/prom_store.rs +++ b/src/frontend/src/instance/prom_store.rs @@ -19,6 +19,7 @@ use api::prom_store::remote::{Query, QueryResult, ReadRequest, ReadResponse, Wri use api::v1::RowInsertRequests; use async_trait::async_trait; use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq}; +use client::OutputData; use common_catalog::format_full_table_name; use common_error::ext::BoxedError; use common_query::prelude::GREPTIME_PHYSICAL_TABLE; @@ -77,7 +78,7 @@ fn negotiate_response_type(accepted_response_types: &[i32]) -> ServerResult ServerResult { - let Output::Stream(stream, _) = output else { + let OutputData::Stream(stream) = output.data else { unreachable!() }; let recordbatches = RecordBatches::try_collect(stream) diff --git a/src/operator/src/delete.rs b/src/operator/src/delete.rs index f4e0e20e8c..2a4737a79a 100644 --- a/src/operator/src/delete.rs +++ b/src/operator/src/delete.rs @@ -91,7 +91,8 @@ impl Deleter { .await?; let affected_rows = self.do_request(deletes, &ctx).await?; - Ok(Output::AffectedRows(affected_rows as _)) + + Ok(Output::new_with_affected_rows(affected_rows)) } pub async fn handle_table_delete( diff --git a/src/operator/src/insert.rs b/src/operator/src/insert.rs index 6102b8354e..91abfdf5c7 100644 --- a/src/operator/src/insert.rs +++ b/src/operator/src/insert.rs @@ -111,7 +111,7 @@ impl Inserter { .await?; let affected_rows = self.do_request(inserts, &ctx).await?; - Ok(Output::AffectedRows(affected_rows as _)) + Ok(Output::new_with_affected_rows(affected_rows)) } /// Handle row inserts request with metric engine. 
@@ -149,7 +149,7 @@ impl Inserter { .await?; let affected_rows = self.do_request(inserts, &ctx).await?; - Ok(Output::AffectedRows(affected_rows as _)) + Ok(Output::new_with_affected_rows(affected_rows)) } pub async fn handle_table_insert( @@ -185,7 +185,7 @@ impl Inserter { .await?; let affected_rows = self.do_request(inserts, ctx).await?; - Ok(Output::AffectedRows(affected_rows as _)) + Ok(Output::new_with_affected_rows(affected_rows)) } } diff --git a/src/operator/src/statement.rs b/src/operator/src/statement.rs index b1948ce1e3..5231f99a58 100644 --- a/src/operator/src/statement.rs +++ b/src/operator/src/statement.rs @@ -123,11 +123,11 @@ impl StatementExecutor { CopyDirection::Export => self .copy_table_to(req, query_ctx) .await - .map(Output::AffectedRows), + .map(Output::new_with_affected_rows), CopyDirection::Import => self .copy_table_from(req, query_ctx) .await - .map(Output::AffectedRows), + .map(Output::new_with_affected_rows), } } @@ -152,15 +152,15 @@ impl StatementExecutor { Statement::CreateTable(stmt) => { let _ = self.create_table(stmt, query_ctx).await?; - Ok(Output::AffectedRows(0)) + Ok(Output::new_with_affected_rows(0)) } Statement::CreateTableLike(stmt) => { let _ = self.create_table_like(stmt, query_ctx).await?; - Ok(Output::AffectedRows(0)) + Ok(Output::new_with_affected_rows(0)) } Statement::CreateExternalTable(stmt) => { let _ = self.create_external_table(stmt, query_ctx).await?; - Ok(Output::AffectedRows(0)) + Ok(Output::new_with_affected_rows(0)) } Statement::Alter(alter_table) => self.alter_table(alter_table, query_ctx).await, Statement::DropTable(stmt) => { @@ -231,7 +231,7 @@ impl StatementExecutor { .fail() } } - Ok(Output::AffectedRows(0)) + Ok(Output::new_with_affected_rows(0)) } Statement::ShowVariables(show_variable) => self.show_variable(show_variable, query_ctx), } diff --git a/src/operator/src/statement/copy_database.rs b/src/operator/src/statement/copy_database.rs index 63236e3eb9..ca62fd97cd 100644 --- a/src/operator/src/statement/copy_database.rs +++ b/src/operator/src/statement/copy_database.rs @@ -15,10 +15,10 @@ use std::path::Path; use std::str::FromStr; +use client::Output; use common_datasource::file_format::Format; use common_datasource::lister::{Lister, Source}; use common_datasource::object_store::build_backend; -use common_query::Output; use common_telemetry::{debug, error, info, tracing}; use object_store::Entry; use regex::Regex; @@ -96,7 +96,7 @@ impl StatementExecutor { .await?; exported_rows += exported; } - Ok(Output::AffectedRows(exported_rows)) + Ok(Output::new_with_affected_rows(exported_rows)) } /// Imports data to database from a given location and returns total rows imported. 
@@ -169,7 +169,7 @@ impl StatementExecutor { } } } - Ok(Output::AffectedRows(rows_inserted)) + Ok(Output::new_with_affected_rows(rows_inserted)) } } diff --git a/src/operator/src/statement/copy_table_to.rs b/src/operator/src/statement/copy_table_to.rs index 58def6af54..2a4d4a0ca7 100644 --- a/src/operator/src/statement/copy_table_to.rs +++ b/src/operator/src/statement/copy_table_to.rs @@ -14,6 +14,7 @@ use std::sync::Arc; +use client::OutputData; use common_base::readable_size::ReadableSize; use common_datasource::file_format::csv::stream_to_csv; use common_datasource::file_format::json::stream_to_json; @@ -21,7 +22,6 @@ use common_datasource::file_format::parquet::stream_to_parquet; use common_datasource::file_format::Format; use common_datasource::object_store::{build_backend, parse_url}; use common_datasource::util::find_dir_and_filename; -use common_query::Output; use common_recordbatch::adapter::DfRecordBatchStreamAdapter; use common_recordbatch::SendableRecordBatchStream; use common_telemetry::{debug, tracing}; @@ -134,9 +134,9 @@ impl StatementExecutor { .execute(LogicalPlan::DfPlan(plan), query_ctx) .await .context(ExecLogicalPlanSnafu)?; - let stream = match output { - Output::Stream(stream, _) => stream, - Output::RecordBatches(record_batches) => record_batches.as_stream(), + let stream = match output.data { + OutputData::Stream(stream) => stream, + OutputData::RecordBatches(record_batches) => record_batches.as_stream(), _ => unreachable!(), }; diff --git a/src/operator/src/statement/ddl.rs b/src/operator/src/statement/ddl.rs index 8c76b0618b..c338f306b3 100644 --- a/src/operator/src/statement/ddl.rs +++ b/src/operator/src/statement/ddl.rs @@ -338,10 +338,10 @@ impl StatementExecutor { .await .context(error::InvalidateTableCacheSnafu)?; - Ok(Output::AffectedRows(0)) + Ok(Output::new_with_affected_rows(0)) } else if drop_if_exists { // DROP TABLE IF EXISTS meets table not found - ignored - Ok(Output::AffectedRows(0)) + Ok(Output::new_with_affected_rows(0)) } else { Err(TableNotFoundSnafu { table_name: table_name.to_string(), @@ -367,7 +367,7 @@ impl StatementExecutor { let table_id = table.table_info().table_id(); self.truncate_table_procedure(&table_name, table_id).await?; - Ok(Output::AffectedRows(0)) + Ok(Output::new_with_affected_rows(0)) } fn verify_alter( @@ -471,7 +471,7 @@ impl StatementExecutor { .await .context(error::InvalidateTableCacheSnafu)?; - Ok(Output::AffectedRows(0)) + Ok(Output::new_with_affected_rows(0)) } async fn create_table_procedure( @@ -580,7 +580,7 @@ impl StatementExecutor { if exists { return if create_if_not_exists { - Ok(Output::AffectedRows(1)) + Ok(Output::new_with_affected_rows(1)) } else { error::SchemaExistsSnafu { name: database }.fail() }; @@ -592,7 +592,7 @@ impl StatementExecutor { .await .context(TableMetadataManagerSnafu)?; - Ok(Output::AffectedRows(1)) + Ok(Output::new_with_affected_rows(1)) } } diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index 914a8b4f3c..0343f0b5fd 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -28,7 +28,7 @@ use common_function::function::FunctionRef; use common_function::scalars::aggregate::AggregateFunctionMetaRef; use common_query::physical_plan::{DfPhysicalPlanAdapter, PhysicalPlan, PhysicalPlanAdapter}; use common_query::prelude::ScalarUdf; -use common_query::Output; +use common_query::{Output, OutputData, OutputMeta}; use common_recordbatch::adapter::RecordBatchStreamAdapter; use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream}; use 
common_telemetry::tracing; @@ -90,9 +90,9 @@ impl DatafusionQueryEngine { optimized_physical_plan }; - Ok(Output::Stream( - self.execute_stream(&ctx, &physical_plan)?, - Some(physical_plan), + Ok(Output::new( + OutputData::Stream(self.execute_stream(&ctx, &physical_plan)?), + OutputMeta::new_with_plan(physical_plan), )) } @@ -121,9 +121,9 @@ impl DatafusionQueryEngine { let output = self .exec_query_plan(LogicalPlan::DfPlan((*dml.input).clone()), query_ctx.clone()) .await?; - let mut stream = match output { - Output::RecordBatches(batches) => batches.as_stream(), - Output::Stream(stream, _) => stream, + let mut stream = match output.data { + OutputData::RecordBatches(batches) => batches.as_stream(), + OutputData::Stream(stream) => stream, _ => unreachable!(), }; @@ -148,7 +148,7 @@ impl DatafusionQueryEngine { }; affected_rows += rows; } - Ok(Output::AffectedRows(affected_rows)) + Ok(Output::new_with_affected_rows(affected_rows)) } #[tracing::instrument(skip_all)] @@ -471,7 +471,6 @@ mod tests { use catalog::RegisterTableRequest; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID}; - use common_query::Output; use common_recordbatch::util; use datafusion::prelude::{col, lit}; use datatypes::prelude::ConcreteDataType; @@ -534,8 +533,8 @@ mod tests { let output = engine.execute(plan, QueryContext::arc()).await.unwrap(); - match output { - Output::Stream(recordbatch, _) => { + match output.data { + OutputData::Stream(recordbatch) => { let numbers = util::collect(recordbatch).await.unwrap(); assert_eq!(1, numbers.len()); assert_eq!(numbers[0].num_columns(), 1); diff --git a/src/query/src/sql.rs b/src/query/src/sql.rs index cc4462830e..da1ab58cde 100644 --- a/src/query/src/sql.rs +++ b/src/query/src/sql.rs @@ -237,10 +237,9 @@ async fn query_from_information_schema_table( .await .context(error::DataFusionSnafu)?; - Ok(Output::Stream( - Box::pin(RecordBatchStreamAdapter::try_new(stream).context(error::CreateRecordBatchSnafu)?), - None, - )) + Ok(Output::new_with_stream(Box::pin( + RecordBatchStreamAdapter::try_new(stream).context(error::CreateRecordBatchSnafu)?, + ))) } pub async fn show_tables( @@ -303,7 +302,7 @@ pub fn show_variable(stmt: ShowVariables, query_ctx: QueryContextRef) -> Result< vec![Arc::new(StringVector::from(vec![value])) as _], ) .context(error::CreateRecordBatchSnafu)?; - Ok(Output::RecordBatches(records)) + Ok(Output::new_with_record_batches(records)) } pub fn show_create_table( @@ -329,7 +328,7 @@ pub fn show_create_table( let records = RecordBatches::try_from_columns(SHOW_CREATE_TABLE_OUTPUT_SCHEMA.clone(), columns) .context(error::CreateRecordBatchSnafu)?; - Ok(Output::RecordBatches(records)) + Ok(Output::new_with_record_batches(records)) } pub fn describe_table(table: TableRef) -> Result { @@ -345,7 +344,7 @@ pub fn describe_table(table: TableRef) -> Result { ]; let records = RecordBatches::try_from_columns(DESCRIBE_TABLE_OUTPUT_SCHEMA.clone(), columns) .context(error::CreateRecordBatchSnafu)?; - Ok(Output::RecordBatches(records)) + Ok(Output::new_with_record_batches(records)) } fn describe_column_names(columns_schemas: &[ColumnSchema]) -> VectorRef { @@ -572,7 +571,7 @@ fn parse_file_table_format(options: &HashMap) -> Result { + Ok(Output { + data: OutputData::RecordBatches(record), + .. 
+ }) => { let record = record.take().first().cloned().unwrap(); let data = record.column(0); Ok(data.get(0).to_string()) diff --git a/src/query/src/tests.rs b/src/query/src/tests.rs index c2c8ace323..e92fbba577 100644 --- a/src/query/src/tests.rs +++ b/src/query/src/tests.rs @@ -13,7 +13,7 @@ // limitations under the License. use catalog::memory::MemoryCatalogManager; -use common_query::Output; +use common_query::OutputData; use common_recordbatch::{util, RecordBatch}; use session::context::QueryContext; use table::TableRef; @@ -43,7 +43,7 @@ async fn exec_selection(engine: QueryEngineRef, sql: &str) -> Vec { .plan(stmt, query_ctx.clone()) .await .unwrap(); - let Output::Stream(stream, _) = engine.execute(plan, query_ctx).await.unwrap() else { + let OutputData::Stream(stream) = engine.execute(plan, query_ctx).await.unwrap().data else { unreachable!() }; util::collect(stream).await.unwrap() diff --git a/src/query/src/tests/query_engine_test.rs b/src/query/src/tests/query_engine_test.rs index 3fcddd5043..99551dab0e 100644 --- a/src/query/src/tests/query_engine_test.rs +++ b/src/query/src/tests/query_engine_test.rs @@ -20,7 +20,7 @@ use common_base::Plugins; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, NUMBERS_TABLE_ID}; use common_error::ext::BoxedError; use common_query::prelude::{create_udf, make_scalar_function, Volatility}; -use common_query::Output; +use common_query::OutputData; use common_recordbatch::{util, RecordBatch}; use datafusion::datasource::DefaultTableSource; use datafusion_expr::logical_plan::builder::LogicalPlanBuilder; @@ -79,8 +79,8 @@ async fn test_datafusion_query_engine() -> Result<()> { let output = engine.execute(plan, QueryContext::arc()).await?; - let recordbatch = match output { - Output::Stream(recordbatch, _) => recordbatch, + let recordbatch = match output.data { + OutputData::Stream(recordbatch) => recordbatch, _ => unreachable!(), }; diff --git a/src/script/benches/py_benchmark.rs b/src/script/benches/py_benchmark.rs index 6568b21a22..0e748bee3a 100644 --- a/src/script/benches/py_benchmark.rs +++ b/src/script/benches/py_benchmark.rs @@ -17,7 +17,7 @@ use std::sync::Arc; use catalog::memory::MemoryCatalogManager; use common_catalog::consts::NUMBERS_TABLE_ID; -use common_query::Output; +use common_query::OutputData; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use futures::Future; use once_cell::sync::{Lazy, OnceCell}; @@ -69,9 +69,9 @@ async fn run_compiled(script: &PyScript) { .execute(HashMap::default(), EvalContext::default()) .await .unwrap(); - let _res = match output { - Output::Stream(s, _) => common_recordbatch::util::collect_batches(s).await.unwrap(), - Output::RecordBatches(rbs) => rbs, + let _res = match output.data { + OutputData::Stream(s) => common_recordbatch::util::collect_batches(s).await.unwrap(), + OutputData::RecordBatches(rbs) => rbs, _ => unreachable!(), }; } diff --git a/src/script/src/manager.rs b/src/script/src/manager.rs index a80140edf5..c7301ab8d3 100644 --- a/src/script/src/manager.rs +++ b/src/script/src/manager.rs @@ -211,6 +211,8 @@ impl ScriptManager { #[cfg(test)] mod tests { + use common_query::OutputData; + use super::*; use crate::test::setup_scripts_manager; @@ -261,8 +263,8 @@ def test() -> vector[str]: .await .unwrap(); - match output { - Output::RecordBatches(batches) => { + match output.data { + OutputData::RecordBatches(batches) => { let expected = "\ +-------+ | n | diff --git a/src/script/src/python/engine.rs b/src/script/src/python/engine.rs index 
a0991565ac..dbf1eec6e3 100644 --- a/src/script/src/python/engine.rs +++ b/src/script/src/python/engine.rs @@ -25,7 +25,7 @@ use common_function::function::Function; use common_function::function_registry::FUNCTION_REGISTRY; use common_query::error::{PyUdfSnafu, UdfTempRecordBatchSnafu}; use common_query::prelude::Signature; -use common_query::Output; +use common_query::{Output, OutputData}; use common_recordbatch::error::{ExternalSnafu, Result as RecordBatchResult}; use common_recordbatch::{ RecordBatch, RecordBatchStream, RecordBatches, SendableRecordBatchStream, @@ -311,10 +311,10 @@ impl Script for PyScript { .await .context(DatabaseQuerySnafu)?; let copr = self.copr.clone(); - match res { - Output::Stream(stream, _) => Ok(Output::new_stream(Box::pin(CoprStream::try_new( - stream, copr, params, ctx, - )?))), + match res.data { + OutputData::Stream(stream) => Ok(Output::new_with_stream(Box::pin( + CoprStream::try_new(stream, copr, params, ctx)?, + ))), _ => unreachable!(), } } else { @@ -324,7 +324,7 @@ impl Script for PyScript { .await .context(TokioJoinSnafu)??; let batches = RecordBatches::try_new(batch.schema.clone(), vec![batch]).unwrap(); - Ok(Output::RecordBatches(batches)) + Ok(Output::new_with_record_batches(batches)) } } } @@ -410,8 +410,8 @@ def test(number) -> vector[u32]: .execute(HashMap::default(), EvalContext::default()) .await .unwrap(); - let res = common_recordbatch::util::collect_batches(match output { - Output::Stream(s, _) => s, + let res = common_recordbatch::util::collect_batches(match output.data { + OutputData::Stream(s) => s, _ => unreachable!(), }) .await @@ -441,8 +441,8 @@ def test(**params) -> vector[i64]: .execute(params, EvalContext::default()) .await .unwrap(); - let res = match _output { - Output::RecordBatches(s) => s, + let res = match _output.data { + OutputData::RecordBatches(s) => s, _ => todo!(), }; let rb = res.iter().next().expect("One and only one recordbatch"); @@ -471,8 +471,8 @@ def test(number) -> vector[u32]: .execute(HashMap::new(), EvalContext::default()) .await .unwrap(); - let res = common_recordbatch::util::collect_batches(match _output { - Output::Stream(s, _) => s, + let res = common_recordbatch::util::collect_batches(match _output.data { + OutputData::Stream(s) => s, _ => todo!(), }) .await @@ -503,8 +503,8 @@ def test(a, b, c) -> vector[f64]: .execute(HashMap::new(), EvalContext::default()) .await .unwrap(); - match output { - Output::Stream(stream, _) => { + match output.data { + OutputData::Stream(stream) => { let numbers = util::collect(stream).await.unwrap(); assert_eq!(1, numbers.len()); @@ -541,8 +541,8 @@ def test(a) -> vector[i64]: .execute(HashMap::new(), EvalContext::default()) .await .unwrap(); - match output { - Output::Stream(stream, _) => { + match output.data { + OutputData::Stream(stream) => { let numbers = util::collect(stream).await.unwrap(); assert_eq!(1, numbers.len()); diff --git a/src/script/src/python/ffi_types/copr.rs b/src/script/src/python/ffi_types/copr.rs index a93b22c5bb..1af3f416f3 100644 --- a/src/script/src/python/ffi_types/copr.rs +++ b/src/script/src/python/ffi_types/copr.rs @@ -19,6 +19,7 @@ use std::collections::HashMap; use std::result::Result as StdResult; use std::sync::{Arc, Weak}; +use common_query::OutputData; use common_recordbatch::{RecordBatch, RecordBatches}; use datatypes::arrow::compute; use datatypes::data_type::{ConcreteDataType, DataType}; @@ -399,13 +400,14 @@ impl PyQueryEngine { .await .map_err(|e| e.to_string()); match res { - Ok(common_query::Output::AffectedRows(cnt)) => { - 
Ok(Either::AffectedRows(cnt)) - } - Ok(common_query::Output::RecordBatches(rbs)) => Ok(Either::Rb(rbs)), - Ok(common_query::Output::Stream(s, _)) => Ok(Either::Rb( - common_recordbatch::util::collect_batches(s).await.unwrap(), - )), + Ok(o) => match o.data { + OutputData::AffectedRows(cnt) => Ok(Either::AffectedRows(cnt)), + OutputData::RecordBatches(rbs) => Ok(Either::Rb(rbs)), + OutputData::Stream(s) => Ok(Either::Rb( + common_recordbatch::util::collect_batches(s).await.unwrap(), + )), + }, + Err(e) => Err(e), } })?; diff --git a/src/script/src/python/ffi_types/pair_tests.rs b/src/script/src/python/ffi_types/pair_tests.rs index ec01355357..37e3c76994 100644 --- a/src/script/src/python/ffi_types/pair_tests.rs +++ b/src/script/src/python/ffi_types/pair_tests.rs @@ -18,7 +18,7 @@ use std::collections::HashMap; use std::sync::Arc; use arrow::compute::kernels::numeric; -use common_query::Output; +use common_query::OutputData; use common_recordbatch::RecordBatch; use datafusion::arrow::array::Float64Array; use datafusion::arrow::compute; @@ -87,9 +87,9 @@ async fn integrated_py_copr_test() { .execute(HashMap::default(), EvalContext::default()) .await .unwrap(); - let res = match output { - Output::Stream(s, _) => common_recordbatch::util::collect_batches(s).await.unwrap(), - Output::RecordBatches(rbs) => rbs, + let res = match output.data { + OutputData::Stream(s) => common_recordbatch::util::collect_batches(s).await.unwrap(), + OutputData::RecordBatches(rbs) => rbs, _ => unreachable!(), }; let rb = res.iter().next().expect("One and only one recordbatch"); diff --git a/src/script/src/table.rs b/src/script/src/table.rs index 67e561bc67..6620cd86fd 100644 --- a/src/script/src/table.rs +++ b/src/script/src/table.rs @@ -24,7 +24,7 @@ use api::v1::{ }; use catalog::error::CompileScriptInternalSnafu; use common_error::ext::{BoxedError, ErrorExt}; -use common_query::Output; +use common_query::OutputData; use common_recordbatch::{util as record_util, RecordBatch, SendableRecordBatchStream}; use common_telemetry::logging; use common_time::util; @@ -230,9 +230,9 @@ impl ScriptsTable { .execute(LogicalPlan::DfPlan(plan), query_ctx(&table_info)) .await .context(ExecuteInternalStatementSnafu)?; - let stream = match output { - Output::Stream(stream, _) => stream, - Output::RecordBatches(record_batches) => record_batches.as_stream(), + let stream = match output.data { + OutputData::Stream(stream) => stream, + OutputData::RecordBatches(record_batches) => record_batches.as_stream(), _ => unreachable!(), }; @@ -285,9 +285,9 @@ impl ScriptsTable { .execute(LogicalPlan::DfPlan(plan), query_ctx(&table_info)) .await .context(ExecuteInternalStatementSnafu)?; - let stream = match output { - Output::Stream(stream, _) => stream, - Output::RecordBatches(record_batches) => record_batches.as_stream(), + let stream = match output.data { + OutputData::Stream(stream) => stream, + OutputData::RecordBatches(record_batches) => record_batches.as_stream(), _ => unreachable!(), }; Ok(stream) diff --git a/src/script/src/test.rs b/src/script/src/test.rs index b2beb799a7..f4d3e4537d 100644 --- a/src/script/src/test.rs +++ b/src/script/src/test.rs @@ -73,6 +73,6 @@ impl GrpcQueryHandler for MockGrpcQueryHandler { type Error = Error; async fn do_query(&self, _query: Request, _ctx: QueryContextRef) -> Result { - Ok(Output::AffectedRows(1)) + Ok(Output::new_with_affected_rows(1)) } } diff --git a/src/servers/src/grpc/database.rs b/src/servers/src/grpc/database.rs index 133f3db6f9..3e242fde11 100644 --- a/src/servers/src/grpc/database.rs 
+++ b/src/servers/src/grpc/database.rs @@ -17,7 +17,7 @@ use api::v1::greptime_response::Response as RawResponse; use api::v1::{AffectedRows, GreptimeRequest, GreptimeResponse, ResponseHeader}; use async_trait::async_trait; use common_error::status_code::StatusCode; -use common_query::Output; +use common_query::OutputData; use futures::StreamExt; use tonic::{Request, Response, Status, Streaming}; @@ -42,8 +42,8 @@ impl GreptimeDatabase for DatabaseService { ) -> TonicResult> { let request = request.into_inner(); let output = self.handler.handle_request(request).await?; - let message = match output { - Output::AffectedRows(rows) => GreptimeResponse { + let message = match output.data { + OutputData::AffectedRows(rows) => GreptimeResponse { header: Some(ResponseHeader { status: Some(api::v1::Status { status_code: StatusCode::Success as _, @@ -52,7 +52,7 @@ impl GreptimeDatabase for DatabaseService { }), response: Some(RawResponse::AffectedRows(AffectedRows { value: rows as _ })), }, - Output::Stream(_, _) | Output::RecordBatches(_) => { + OutputData::Stream(_) | OutputData::RecordBatches(_) => { return Err(Status::unimplemented("GreptimeDatabase::Handle for query")); } }; @@ -69,9 +69,9 @@ impl GreptimeDatabase for DatabaseService { while let Some(request) = stream.next().await { let request = request?; let output = self.handler.handle_request(request).await?; - match output { - Output::AffectedRows(rows) => affected_rows += rows, - Output::Stream(_, _) | Output::RecordBatches(_) => { + match output.data { + OutputData::AffectedRows(rows) => affected_rows += rows, + OutputData::Stream(_) | OutputData::RecordBatches(_) => { return Err(Status::unimplemented( "GreptimeDatabase::HandleRequests for query", )); diff --git a/src/servers/src/grpc/flight.rs b/src/servers/src/grpc/flight.rs index 0283780b5b..9ed2ed85d3 100644 --- a/src/servers/src/grpc/flight.rs +++ b/src/servers/src/grpc/flight.rs @@ -25,7 +25,7 @@ use arrow_flight::{ }; use async_trait::async_trait; use common_grpc::flight::{FlightEncoder, FlightMessage}; -use common_query::Output; +use common_query::{Output, OutputData}; use common_telemetry::tracing::info_span; use common_telemetry::tracing_context::{FutureExt, TracingContext}; use futures::Stream; @@ -174,16 +174,16 @@ fn to_flight_data_stream( output: Output, tracing_context: TracingContext, ) -> TonicStream { - match output { - Output::Stream(stream, _) => { + match output.data { + OutputData::Stream(stream) => { let stream = FlightRecordBatchStream::new(stream, tracing_context); Box::pin(stream) as _ } - Output::RecordBatches(x) => { + OutputData::RecordBatches(x) => { let stream = FlightRecordBatchStream::new(x.as_stream(), tracing_context); Box::pin(stream) as _ } - Output::AffectedRows(rows) => { + OutputData::AffectedRows(rows) => { let stream = tokio_stream::once(Ok( FlightEncoder::default().encode(FlightMessage::AffectedRows(rows)) )); diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index fde17b72a0..6a141bfa74 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -925,7 +925,7 @@ mod test { let schema = Arc::new(Schema::new(column_schemas)); let recordbatches = RecordBatches::try_new(schema.clone(), vec![]).unwrap(); - let outputs = vec![Ok(Output::RecordBatches(recordbatches))]; + let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))]; let json_resp = GreptimedbV1Response::from_output(outputs).await; if let HttpResponse::GreptimedbV1(json_resp) = json_resp { @@ -969,7 +969,7 @@ mod test { ] { let recordbatches = 
RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap(); - let outputs = vec![Ok(Output::RecordBatches(recordbatches))]; + let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))]; let json_resp = match format { ResponseFormat::Arrow => ArrowResponse::from_output(outputs).await, ResponseFormat::Csv => CsvResponse::from_output(outputs).await, diff --git a/src/servers/src/http/arrow_result.rs b/src/servers/src/http/arrow_result.rs index 3daad34f1d..025bd36cd8 100644 --- a/src/servers/src/http/arrow_result.rs +++ b/src/servers/src/http/arrow_result.rs @@ -20,7 +20,7 @@ use arrow_ipc::writer::FileWriter; use axum::http::{header, HeaderValue}; use axum::response::{IntoResponse, Response}; use common_error::status_code::StatusCode; -use common_query::Output; +use common_query::{Output, OutputData}; use common_recordbatch::RecordBatchStream; use futures::StreamExt; use schemars::JsonSchema; @@ -70,12 +70,12 @@ impl ArrowResponse { } match outputs.remove(0) { - Ok(output) => match output { - Output::AffectedRows(_rows) => HttpResponse::Arrow(ArrowResponse { + Ok(output) => match output.data { + OutputData::AffectedRows(_rows) => HttpResponse::Arrow(ArrowResponse { data: vec![], execution_time_ms: 0, }), - Output::RecordBatches(recordbatches) => { + OutputData::RecordBatches(recordbatches) => { let schema = recordbatches.schema(); match write_arrow_bytes(recordbatches.as_stream(), schema.arrow_schema()).await { @@ -89,7 +89,7 @@ impl ArrowResponse { } } - Output::Stream(recordbatches, _) => { + OutputData::Stream(recordbatches) => { let schema = recordbatches.schema(); match write_arrow_bytes(recordbatches, schema.arrow_schema()).await { Ok(payload) => HttpResponse::Arrow(ArrowResponse { diff --git a/src/servers/src/http/handler.rs b/src/servers/src/http/handler.rs index f8a51d02a6..7207f591e5 100644 --- a/src/servers/src/http/handler.rs +++ b/src/servers/src/http/handler.rs @@ -25,7 +25,7 @@ use common_error::ext::ErrorExt; use common_error::status_code::StatusCode; use common_plugins::GREPTIME_EXEC_PREFIX; use common_query::physical_plan::PhysicalPlan; -use common_query::Output; +use common_query::{Output, OutputData}; use common_recordbatch::util; use common_telemetry::tracing; use datafusion::physical_plan::metrics::MetricValue; @@ -140,46 +140,49 @@ pub async fn from_output( for out in outputs { match out { - Ok(Output::AffectedRows(rows)) => { - results.push(GreptimeQueryOutput::AffectedRows(rows)); - } - Ok(Output::Stream(stream, physical_plan)) => { - let schema = stream.schema().clone(); - // TODO(sunng87): streaming response - let mut http_record_output = match util::collect(stream).await { - Ok(rows) => match HttpRecordsOutput::try_new(schema, rows) { - Ok(rows) => rows, + Ok(o) => match o.data { + OutputData::AffectedRows(rows) => { + results.push(GreptimeQueryOutput::AffectedRows(rows)); + } + OutputData::Stream(stream) => { + let schema = stream.schema().clone(); + // TODO(sunng87): streaming response + let mut http_record_output = match util::collect(stream).await { + Ok(rows) => match HttpRecordsOutput::try_new(schema, rows) { + Ok(rows) => rows, + Err(err) => { + return Err(ErrorResponse::from_error(ty, err)); + } + }, Err(err) => { return Err(ErrorResponse::from_error(ty, err)); } - }, - Err(err) => { - return Err(ErrorResponse::from_error(ty, err)); - } - }; - if let Some(physical_plan) = physical_plan { - let mut result_map = HashMap::new(); + }; + if let Some(physical_plan) = o.meta.plan { + let mut result_map = HashMap::new(); - let mut tmp = 
vec![&mut merge_map, &mut result_map]; - collect_plan_metrics(physical_plan, &mut tmp); - let re = result_map - .into_iter() - .map(|(k, v)| (k, Value::from(v))) - .collect(); - http_record_output.metrics = re; - } - results.push(GreptimeQueryOutput::Records(http_record_output)) - } - Ok(Output::RecordBatches(rbs)) => { - match HttpRecordsOutput::try_new(rbs.schema(), rbs.take()) { - Ok(rows) => { - results.push(GreptimeQueryOutput::Records(rows)); + let mut tmp = vec![&mut merge_map, &mut result_map]; + collect_plan_metrics(physical_plan, &mut tmp); + let re = result_map + .into_iter() + .map(|(k, v)| (k, Value::from(v))) + .collect(); + http_record_output.metrics = re; } - Err(err) => { - return Err(ErrorResponse::from_error(ty, err)); + results.push(GreptimeQueryOutput::Records(http_record_output)) + } + OutputData::RecordBatches(rbs) => { + match HttpRecordsOutput::try_new(rbs.schema(), rbs.take()) { + Ok(rows) => { + results.push(GreptimeQueryOutput::Records(rows)); + } + Err(err) => { + return Err(ErrorResponse::from_error(ty, err)); + } } } - } + }, + Err(err) => { return Err(ErrorResponse::from_error(ty, err)); } diff --git a/src/servers/src/http/influxdb_result_v1.rs b/src/servers/src/http/influxdb_result_v1.rs index 05525ea128..1cc2ade276 100644 --- a/src/servers/src/http/influxdb_result_v1.rs +++ b/src/servers/src/http/influxdb_result_v1.rs @@ -16,7 +16,7 @@ use axum::http::HeaderValue; use axum::response::{IntoResponse, Response}; use axum::Json; use common_error::ext::ErrorExt; -use common_query::Output; +use common_query::{Output, OutputData}; use common_recordbatch::{util, RecordBatch}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -153,41 +153,45 @@ impl InfluxdbV1Response { for (statement_id, out) in outputs.into_iter().enumerate() { let statement_id = statement_id as u32; match out { - Ok(Output::AffectedRows(_)) => { - results.push(InfluxdbOutput { - statement_id, - series: vec![], - }); - } - Ok(Output::Stream(stream, _)) => { - // TODO(sunng87): streaming response - match util::collect(stream).await { - Ok(rows) => match InfluxdbRecordsOutput::try_from((epoch, rows)) { - Ok(rows) => { - results.push(InfluxdbOutput { - statement_id, - series: vec![rows], - }); - } - Err(err) => { - return make_error_response(err); - } - }, - Err(err) => { - return make_error_response(err); - } - } - } - Ok(Output::RecordBatches(rbs)) => { - match InfluxdbRecordsOutput::try_from((epoch, rbs.take())) { - Ok(rows) => { + Ok(o) => { + match o.data { + OutputData::AffectedRows(_) => { results.push(InfluxdbOutput { statement_id, - series: vec![rows], + series: vec![], }); } - Err(err) => { - return make_error_response(err); + OutputData::Stream(stream) => { + // TODO(sunng87): streaming response + match util::collect(stream).await { + Ok(rows) => match InfluxdbRecordsOutput::try_from((epoch, rows)) { + Ok(rows) => { + results.push(InfluxdbOutput { + statement_id, + series: vec![rows], + }); + } + Err(err) => { + return make_error_response(err); + } + }, + Err(err) => { + return make_error_response(err); + } + } + } + OutputData::RecordBatches(rbs) => { + match InfluxdbRecordsOutput::try_from((epoch, rbs.take())) { + Ok(rows) => { + results.push(InfluxdbOutput { + statement_id, + series: vec![rows], + }); + } + Err(err) => { + return make_error_response(err); + } + } } } } diff --git a/src/servers/src/http/prometheus.rs b/src/servers/src/http/prometheus.rs index 7cfa6d5715..bfece7a907 100644 --- a/src/servers/src/http/prometheus.rs +++ b/src/servers/src/http/prometheus.rs 
@@ -23,7 +23,7 @@ use common_catalog::parse_catalog_and_schema_from_db_string; use common_error::ext::ErrorExt; use common_error::status_code::StatusCode; use common_query::prelude::{GREPTIME_TIMESTAMP, GREPTIME_VALUE}; -use common_query::Output; +use common_query::{Output, OutputData}; use common_recordbatch::RecordBatches; use common_telemetry::tracing; use common_time::util::{current_time_rfc3339, yesterday_rfc3339}; @@ -336,19 +336,19 @@ async fn retrieve_series_from_query_result( series: &mut Vec>, table_name: &str, ) -> Result<()> { - match result? { - Output::RecordBatches(batches) => { + match result?.data { + OutputData::RecordBatches(batches) => { record_batches_to_series(batches, series, table_name)?; Ok(()) } - Output::Stream(stream, _) => { + OutputData::Stream(stream) => { let batches = RecordBatches::try_collect(stream) .await .context(CollectRecordbatchSnafu)?; record_batches_to_series(batches, series, table_name)?; Ok(()) } - Output::AffectedRows(_) => Err(Error::UnexpectedResult { + OutputData::AffectedRows(_) => Err(Error::UnexpectedResult { reason: "expected data result, but got affected rows".to_string(), location: Location::default(), }), @@ -360,19 +360,19 @@ async fn retrieve_labels_name_from_query_result( result: Result, labels: &mut HashSet, ) -> Result<()> { - match result? { - Output::RecordBatches(batches) => { + match result?.data { + OutputData::RecordBatches(batches) => { record_batches_to_labels_name(batches, labels)?; Ok(()) } - Output::Stream(stream, _) => { + OutputData::Stream(stream) => { let batches = RecordBatches::try_collect(stream) .await .context(CollectRecordbatchSnafu)?; record_batches_to_labels_name(batches, labels)?; Ok(()) } - Output::AffectedRows(_) => UnexpectedResultSnafu { + OutputData::AffectedRows(_) => UnexpectedResultSnafu { reason: "expected data result, but got affected rows".to_string(), } .fail(), @@ -569,17 +569,17 @@ async fn retrieve_label_values( label_name: &str, labels_values: &mut HashSet, ) -> Result<()> { - match result? { - Output::RecordBatches(batches) => { + match result?.data { + OutputData::RecordBatches(batches) => { retrieve_label_values_from_record_batch(batches, label_name, labels_values).await } - Output::Stream(stream, _) => { + OutputData::Stream(stream) => { let batches = RecordBatches::try_collect(stream) .await .context(CollectRecordbatchSnafu)?; retrieve_label_values_from_record_batch(batches, label_name, labels_values).await } - Output::AffectedRows(_) => UnexpectedResultSnafu { + OutputData::AffectedRows(_) => UnexpectedResultSnafu { reason: "expected data result, but got affected rows".to_string(), } .fail(), diff --git a/src/servers/src/http/prometheus_resp.rs b/src/servers/src/http/prometheus_resp.rs index e7a310faf5..775d7e3c11 100644 --- a/src/servers/src/http/prometheus_resp.rs +++ b/src/servers/src/http/prometheus_resp.rs @@ -20,7 +20,7 @@ use axum::response::{IntoResponse, Response}; use axum::Json; use common_error::ext::ErrorExt; use common_error::status_code::StatusCode; -use common_query::Output; +use common_query::{Output, OutputData}; use common_recordbatch::RecordBatches; use datatypes::prelude::ConcreteDataType; use datatypes::scalars::ScalarVector; @@ -107,40 +107,38 @@ impl PrometheusJsonResponse { result_type: ValueType, ) -> Self { let response: Result = try { - let resp = match result? 
{ - Output::RecordBatches(batches) => Self::success(Self::record_batches_to_data( - batches, - metric_name, - result_type, - )?), - Output::Stream(stream, physical_plan) => { - let record_batches = RecordBatches::try_collect(stream) - .await - .context(CollectRecordbatchSnafu)?; - let mut resp = Self::success(Self::record_batches_to_data( - record_batches, - metric_name, - result_type, - )?); - - if let Some(physical_plan) = physical_plan { - let mut result_map = HashMap::new(); - let mut tmp = vec![&mut result_map]; - collect_plan_metrics(physical_plan, &mut tmp); - - let re = result_map - .into_iter() - .map(|(k, v)| (k, Value::from(v))) - .collect(); - resp.resp_metrics = re; + let result = result?; + let mut resp = + match result.data { + OutputData::RecordBatches(batches) => Self::success( + Self::record_batches_to_data(batches, metric_name, result_type)?, + ), + OutputData::Stream(stream) => { + let record_batches = RecordBatches::try_collect(stream) + .await + .context(CollectRecordbatchSnafu)?; + Self::success(Self::record_batches_to_data( + record_batches, + metric_name, + result_type, + )?) } + OutputData::AffectedRows(_) => { + Self::error("Unexpected", "expected data result, but got affected rows") + } + }; - resp - } - Output::AffectedRows(_) => { - Self::error("Unexpected", "expected data result, but got affected rows") - } - }; + if let Some(physical_plan) = result.meta.plan { + let mut result_map = HashMap::new(); + let mut tmp = vec![&mut result_map]; + collect_plan_metrics(physical_plan, &mut tmp); + + let re = result_map + .into_iter() + .map(|(k, v)| (k, Value::from(v))) + .collect(); + resp.resp_metrics = re; + } resp }; diff --git a/src/servers/src/mysql/federated.rs b/src/servers/src/mysql/federated.rs index ea3020cd76..a9c9d630b8 100644 --- a/src/servers/src/mysql/federated.rs +++ b/src/servers/src/mysql/federated.rs @@ -229,7 +229,7 @@ fn select_variable(query: &str, query_context: QueryContextRef) -> Option Option { @@ -254,13 +254,13 @@ fn check_show_variables(query: &str) -> Option { } else { None }; - recordbatches.map(Output::RecordBatches) + recordbatches.map(Output::new_with_record_batches) } // Check for SET or others query, this is the final check of the federated query. 
fn check_others(query: &str, query_ctx: QueryContextRef) -> Option { if OTHER_NOT_SUPPORTED_STMT.is_match(query.as_bytes()) { - return Some(Output::RecordBatches(RecordBatches::empty())); + return Some(Output::new_with_record_batches(RecordBatches::empty())); } let recordbatches = if SELECT_DATABASE_PATTERN.is_match(query) { @@ -274,7 +274,7 @@ fn check_others(query: &str, query_ctx: QueryContextRef) -> Option { } else { None }; - recordbatches.map(Output::RecordBatches) + recordbatches.map(Output::new_with_record_batches) } // Check whether the query is a federated or driver setup command, @@ -301,6 +301,7 @@ pub(crate) fn check( #[cfg(test)] mod test { + use common_query::OutputData; use common_time::timezone::set_default_timezone; use session::context::{Channel, QueryContext}; use session::Session; @@ -321,8 +322,8 @@ mod test { fn test(query: &str, expected: &str) { let session = Arc::new(Session::new(None, Channel::Mysql)); let output = check(query, QueryContext::arc(), session.clone()); - match output.unwrap() { - Output::RecordBatches(r) => { + match output.unwrap().data { + OutputData::RecordBatches(r) => { assert_eq!(&r.pretty_print().unwrap(), expected) } _ => unreachable!(), diff --git a/src/servers/src/mysql/writer.rs b/src/servers/src/mysql/writer.rs index b01adf374c..3e245fd67c 100644 --- a/src/servers/src/mysql/writer.rs +++ b/src/servers/src/mysql/writer.rs @@ -15,7 +15,7 @@ use std::ops::Deref; use common_error::ext::ErrorExt; -use common_query::Output; +use common_query::{Output, OutputData}; use common_recordbatch::{RecordBatch, SendableRecordBatchStream}; use common_telemetry::{debug, error}; use datatypes::prelude::{ConcreteDataType, Value}; @@ -80,22 +80,22 @@ impl<'a, W: AsyncWrite + Unpin> MysqlResultWriter<'a, W> { // We don't support sending multiple query result because the RowWriter's lifetime is bound to // a local variable. 
match output { - Ok(output) => match output { - Output::Stream(stream, _) => { + Ok(output) => match output.data { + OutputData::Stream(stream) => { let query_result = QueryResult { schema: stream.schema(), stream, }; Self::write_query_result(query_result, self.writer, self.query_context).await?; } - Output::RecordBatches(recordbatches) => { + OutputData::RecordBatches(recordbatches) => { let query_result = QueryResult { schema: recordbatches.schema(), stream: recordbatches.as_stream(), }; Self::write_query_result(query_result, self.writer, self.query_context).await?; } - Output::AffectedRows(rows) => { + OutputData::AffectedRows(rows) => { let next_writer = Self::write_affected_rows(self.writer, rows).await?; return Ok(Some(MysqlResultWriter::new( next_writer, diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 050ba1488c..352d292758 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -16,7 +16,7 @@ use std::sync::Arc; use async_trait::async_trait; use common_error::ext::ErrorExt; -use common_query::Output; +use common_query::{Output, OutputData}; use common_recordbatch::error::Result as RecordBatchResult; use common_recordbatch::RecordBatch; use common_telemetry::tracing; @@ -74,15 +74,19 @@ fn output_to_query_response<'a>( field_format: &Format, ) -> PgWireResult> { match output { - Ok(Output::AffectedRows(rows)) => Ok(Response::Execution(Tag::new("OK").with_rows(rows))), - Ok(Output::Stream(record_stream, _)) => { - let schema = record_stream.schema(); - recordbatches_to_query_response(record_stream, schema, field_format) - } - Ok(Output::RecordBatches(recordbatches)) => { - let schema = recordbatches.schema(); - recordbatches_to_query_response(recordbatches.as_stream(), schema, field_format) - } + Ok(o) => match o.data { + OutputData::AffectedRows(rows) => { + Ok(Response::Execution(Tag::new("OK").with_rows(rows))) + } + OutputData::Stream(record_stream) => { + let schema = record_stream.schema(); + recordbatches_to_query_response(record_stream, schema, field_format) + } + OutputData::RecordBatches(recordbatches) => { + let schema = recordbatches.schema(); + recordbatches_to_query_response(recordbatches.as_stream(), schema, field_format) + } + }, Err(e) => Ok(Response::Error(Box::new(ErrorInfo::new( "ERROR".to_string(), "XX000".to_string(), diff --git a/src/servers/tests/interceptor.rs b/src/servers/tests/interceptor.rs index 1482a92078..ca5837a993 100644 --- a/src/servers/tests/interceptor.rs +++ b/src/servers/tests/interceptor.rs @@ -16,6 +16,7 @@ use std::borrow::Cow; use api::v1::greptime_request::Request; use api::v1::{InsertRequest, InsertRequests}; +use client::OutputData; use common_query::Output; use query::parser::PromQuery; use servers::error::{self, InternalSnafu, NotSupportedSnafu, Result}; @@ -101,8 +102,8 @@ impl PromQueryInterceptor for NoopInterceptor { output: Output, _query_ctx: QueryContextRef, ) -> std::result::Result { - match output { - Output::AffectedRows(1) => Ok(Output::AffectedRows(2)), + match output.data { + OutputData::AffectedRows(1) => Ok(Output::new_with_affected_rows(2)), _ => Ok(output), } } @@ -121,8 +122,14 @@ fn test_prom_interceptor() { let fail = PromQueryInterceptor::pre_execute(&di, &query, ctx.clone()); assert!(fail.is_err()); - let output = Output::AffectedRows(1); + let output = Output::new_with_affected_rows(1); let two = PromQueryInterceptor::post_execute(&di, output, ctx); assert!(two.is_ok()); - matches!(two.unwrap(), Output::AffectedRows(2)); + 
matches!( + two.unwrap(), + Output { + data: OutputData::AffectedRows(2), + .. + } + ); } diff --git a/src/servers/tests/py_script/mod.rs b/src/servers/tests/py_script/mod.rs index 5b63a91d1b..03232329b8 100644 --- a/src/servers/tests/py_script/mod.rs +++ b/src/servers/tests/py_script/mod.rs @@ -15,7 +15,7 @@ use std::collections::HashMap; use std::sync::Arc; -use common_query::Output; +use common_query::OutputData; use common_recordbatch::RecordBatch; use datatypes::prelude::ConcreteDataType; use datatypes::schema::{ColumnSchema, Schema}; @@ -70,8 +70,8 @@ def hello() -> vector[str]: .execute_script(query_ctx.clone(), name, HashMap::new()) .await?; - match output { - Output::RecordBatches(batches) => { + match output.data { + OutputData::RecordBatches(batches) => { let expected = "\ +-------+ | n | @@ -88,12 +88,12 @@ def hello() -> vector[str]: .await .remove(0) .unwrap(); - match res { - common_query::Output::AffectedRows(_) => (), - common_query::Output::RecordBatches(_) => { + match res.data { + OutputData::AffectedRows(_) => (), + OutputData::RecordBatches(_) => { unreachable!() } - common_query::Output::Stream(s, _) => { + OutputData::Stream(s) => { let batches = common_recordbatch::util::collect_batches(s).await.unwrap(); let expected = "\ +---------+ diff --git a/tests-integration/src/grpc.rs b/tests-integration/src/grpc.rs index db2de37f16..2d2a6294a9 100644 --- a/tests-integration/src/grpc.rs +++ b/tests-integration/src/grpc.rs @@ -26,6 +26,7 @@ mod test { CreateDatabaseExpr, CreateTableExpr, DdlRequest, DeleteRequest, DeleteRequests, DropTableExpr, InsertRequest, InsertRequests, QueryRequest, SemanticType, }; + use client::OutputData; use common_catalog::consts::MITO_ENGINE; use common_meta::rpc::router::region_distribution; use common_query::Output; @@ -78,7 +79,7 @@ mod test { })), }); let output = query(instance, request).await; - assert!(matches!(output, Output::AffectedRows(1))); + assert!(matches!(output.data, OutputData::AffectedRows(1))); let request = Request::Ddl(DdlRequest { expr: Some(DdlExpr::CreateTable(CreateTableExpr { @@ -109,7 +110,7 @@ mod test { })), }); let output = query(instance, request).await; - assert!(matches!(output, Output::AffectedRows(0))); + assert!(matches!(output.data, OutputData::AffectedRows(0))); let request = Request::Ddl(DdlRequest { expr: Some(DdlExpr::Alter(AlterExpr { @@ -132,13 +133,13 @@ mod test { })), }); let output = query(instance, request).await; - assert!(matches!(output, Output::AffectedRows(0))); + assert!(matches!(output.data, OutputData::AffectedRows(0))); let request = Request::Query(QueryRequest { query: Some(Query::Sql("INSERT INTO database_created_through_grpc.table_created_through_grpc (a, b, ts) VALUES ('s', 1, 1672816466000)".to_string())) }); let output = query(instance, request).await; - assert!(matches!(output, Output::AffectedRows(1))); + assert!(matches!(output.data, OutputData::AffectedRows(1))); let request = Request::Query(QueryRequest { query: Some(Query::Sql( @@ -147,7 +148,7 @@ mod test { )), }); let output = query(instance, request).await; - let Output::Stream(stream, _) = output else { + let OutputData::Stream(stream) = output.data else { unreachable!() }; let recordbatches = RecordBatches::try_collect(stream).await.unwrap(); @@ -168,7 +169,7 @@ mod test { })), }); let output = query(instance, request).await; - assert!(matches!(output, Output::AffectedRows(0))); + assert!(matches!(output.data, OutputData::AffectedRows(0))); } async fn verify_table_is_dropped(instance: &MockDistributedInstance) { @@ -307,7 
+308,7 @@ CREATE TABLE {table_name} ( query: Some(Query::Sql(sql)), }); let output = query(frontend, request).await; - assert!(matches!(output, Output::AffectedRows(0))); + assert!(matches!(output.data, OutputData::AffectedRows(0))); } async fn test_insert_delete_and_query_on_existing_table(instance: &Instance, table_name: &str) { @@ -376,7 +377,7 @@ CREATE TABLE {table_name} ( }), ) .await; - assert!(matches!(output, Output::AffectedRows(16))); + assert!(matches!(output.data, OutputData::AffectedRows(16))); let request = Request::Query(QueryRequest { query: Some(Query::Sql(format!( @@ -384,7 +385,7 @@ CREATE TABLE {table_name} ( ))), }); let output = query(instance, request.clone()).await; - let Output::Stream(stream, _) = output else { + let OutputData::Stream(stream) = output.data else { unreachable!() }; let recordbatches = RecordBatches::try_collect(stream).await.unwrap(); @@ -474,10 +475,10 @@ CREATE TABLE {table_name} ( }), ) .await; - assert!(matches!(output, Output::AffectedRows(6))); + assert!(matches!(output.data, OutputData::AffectedRows(6))); let output = query(instance, request).await; - let Output::Stream(stream, _) = output else { + let OutputData::Stream(stream) = output.data else { unreachable!() }; let recordbatches = RecordBatches::try_collect(stream).await.unwrap(); @@ -604,7 +605,7 @@ CREATE TABLE {table_name} ( inserts: vec![insert], }); let output = query(instance, request).await; - assert!(matches!(output, Output::AffectedRows(3))); + assert!(matches!(output.data, OutputData::AffectedRows(3))); let insert = InsertRequest { table_name: "auto_created_table".to_string(), @@ -643,7 +644,7 @@ CREATE TABLE {table_name} ( inserts: vec![insert], }); let output = query(instance, request).await; - assert!(matches!(output, Output::AffectedRows(3))); + assert!(matches!(output.data, OutputData::AffectedRows(3))); let request = Request::Query(QueryRequest { query: Some(Query::Sql( @@ -651,7 +652,7 @@ CREATE TABLE {table_name} ( )), }); let output = query(instance, request.clone()).await; - let Output::Stream(stream, _) = output else { + let OutputData::Stream(stream) = output.data else { unreachable!() }; let recordbatches = RecordBatches::try_collect(stream).await.unwrap(); @@ -690,10 +691,10 @@ CREATE TABLE {table_name} ( }), ) .await; - assert!(matches!(output, Output::AffectedRows(2))); + assert!(matches!(output.data, OutputData::AffectedRows(2))); let output = query(instance, request).await; - let Output::Stream(stream, _) = output else { + let OutputData::Stream(stream) = output.data else { unreachable!() }; let recordbatches = RecordBatches::try_collect(stream).await.unwrap(); @@ -780,7 +781,7 @@ CREATE TABLE {table_name} ( inserts: vec![insert], }); let output = query(instance, request).await; - assert!(matches!(output, Output::AffectedRows(8))); + assert!(matches!(output.data, OutputData::AffectedRows(8))); let request = Request::Query(QueryRequest { query: Some(Query::PromRangeQuery(api::v1::PromRangeQuery { @@ -791,7 +792,7 @@ CREATE TABLE {table_name} ( })), }); let output = query(instance, request).await; - let Output::Stream(stream, _) = output else { + let OutputData::Stream(stream) = output.data else { unreachable!() }; let recordbatches = RecordBatches::try_collect(stream).await.unwrap(); diff --git a/tests-integration/src/influxdb.rs b/tests-integration/src/influxdb.rs index 305bf55ab8..be73e415ac 100644 --- a/tests-integration/src/influxdb.rs +++ b/tests-integration/src/influxdb.rs @@ -16,7 +16,7 @@ mod test { use std::sync::Arc; - use common_query::Output; + 
use client::OutputData; use common_recordbatch::RecordBatches; use frontend::instance::Instance; use servers::influxdb::InfluxdbRequest; @@ -80,7 +80,7 @@ monitor1,host=host2 memory=1027"; ) .await; let output = output.remove(0).unwrap(); - let Output::Stream(stream, _) = output else { + let OutputData::Stream(stream) = output.data else { unreachable!() }; @@ -109,7 +109,7 @@ monitor1,host=host2 memory=1027 1663840496400340001"; ) .await; let output = output.remove(0).unwrap(); - let Output::Stream(stream, _) = output else { + let OutputData::Stream(stream) = output.data else { unreachable!() }; let recordbatches = RecordBatches::try_collect(stream).await.unwrap(); diff --git a/tests-integration/src/instance.rs b/tests-integration/src/instance.rs index 965d571d5a..90c66e15e9 100644 --- a/tests-integration/src/instance.rs +++ b/tests-integration/src/instance.rs @@ -20,6 +20,7 @@ mod tests { use std::sync::Arc; use api::v1::region::QueryRequest; + use client::OutputData; use common_base::Plugins; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_meta::key::table_name::TableNameKey; @@ -152,7 +153,7 @@ mod tests { async fn create_table(instance: &Instance, sql: &str) { let output = query(instance, sql).await; - let Output::AffectedRows(x) = output else { + let OutputData::AffectedRows(x) = output.data else { unreachable!() }; assert_eq!(x, 0); @@ -166,14 +167,14 @@ mod tests { ('MOSS', 100000000, 10000000000, 2335190400000) "#; let output = query(instance, sql).await; - let Output::AffectedRows(x) = output else { + let OutputData::AffectedRows(x) = output.data else { unreachable!() }; assert_eq!(x, 4); let sql = "SELECT * FROM demo WHERE ts > cast(1000000000 as timestamp) ORDER BY host"; // use nanoseconds as where condition let output = query(instance, sql).await; - let Output::Stream(s, _) = output else { + let OutputData::Stream(s) = output.data else { unreachable!() }; let batches = common_recordbatch::util::collect_batches(s).await.unwrap(); @@ -264,7 +265,7 @@ mod tests { async fn drop_table(instance: &Instance) { let sql = "DROP TABLE demo"; let output = query(instance, sql).await; - let Output::AffectedRows(x) = output else { + let OutputData::AffectedRows(x) = output.data else { unreachable!() }; assert_eq!(x, 0); @@ -326,8 +327,8 @@ mod tests { _query_ctx: QueryContextRef, ) -> Result { let _ = self.c.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - match &mut output { - Output::AffectedRows(rows) => { + match &mut output.data { + OutputData::AffectedRows(rows) => { assert_eq!(*rows, 0); // update output result *rows = 10; @@ -364,8 +365,8 @@ mod tests { // assert that the hook is called 3 times assert_eq!(4, counter_hook.c.load(std::sync::atomic::Ordering::Relaxed)); - match output { - Output::AffectedRows(rows) => assert_eq!(rows, 10), + match output.data { + OutputData::AffectedRows(rows) => assert_eq!(rows, 10), _ => unreachable!(), } } @@ -424,8 +425,8 @@ mod tests { .remove(0) .unwrap(); - match output { - Output::AffectedRows(rows) => assert_eq!(rows, 0), + match output.data { + OutputData::AffectedRows(rows) => assert_eq!(rows, 0), _ => unreachable!(), } diff --git a/tests-integration/src/opentsdb.rs b/tests-integration/src/opentsdb.rs index 497c214017..21f0896b2f 100644 --- a/tests-integration/src/opentsdb.rs +++ b/tests-integration/src/opentsdb.rs @@ -16,7 +16,7 @@ mod tests { use std::sync::Arc; - use common_query::Output; + use client::OutputData; use common_recordbatch::RecordBatches; use frontend::instance::Instance; use 
itertools::Itertools; @@ -83,8 +83,8 @@ mod tests { .await .remove(0) .unwrap(); - match output { - Output::Stream(stream, _) => { + match output.data { + OutputData::Stream(stream) => { let recordbatches = RecordBatches::try_collect(stream).await.unwrap(); let pretty_print = recordbatches.pretty_print().unwrap(); let expected = vec![ diff --git a/tests-integration/src/otlp.rs b/tests-integration/src/otlp.rs index 8a3b0cee25..b961bc2cf0 100644 --- a/tests-integration/src/otlp.rs +++ b/tests-integration/src/otlp.rs @@ -16,8 +16,7 @@ mod test { use std::sync::Arc; - use client::DEFAULT_CATALOG_NAME; - use common_query::Output; + use client::{OutputData, DEFAULT_CATALOG_NAME}; use common_recordbatch::RecordBatches; use frontend::instance::Instance; use opentelemetry_proto::tonic::collector::metrics::v1::ExportMetricsServiceRequest; @@ -75,7 +74,7 @@ mod test { ) .await; let output = output.remove(0).unwrap(); - let Output::Stream(stream, _) = output else { + let OutputData::Stream(stream) = output.data else { unreachable!() }; let recordbatches = RecordBatches::try_collect(stream).await.unwrap(); @@ -97,7 +96,7 @@ mod test { ) .await; let output = output.remove(0).unwrap(); - let Output::Stream(stream, _) = output else { + let OutputData::Stream(stream) = output.data else { unreachable!() }; let recordbatches = RecordBatches::try_collect(stream).await.unwrap(); @@ -117,7 +116,7 @@ mod test { .do_query("SELECT * FROM my_test_histo_sum", ctx.clone()) .await; let output = output.remove(0).unwrap(); - let Output::Stream(stream, _) = output else { + let OutputData::Stream(stream) = output.data else { unreachable!() }; let recordbatches = RecordBatches::try_collect(stream).await.unwrap(); @@ -135,7 +134,7 @@ mod test { .do_query("SELECT * FROM my_test_histo_count", ctx.clone()) .await; let output = output.remove(0).unwrap(); - let Output::Stream(stream, _) = output else { + let OutputData::Stream(stream) = output.data else { unreachable!() }; let recordbatches = RecordBatches::try_collect(stream).await.unwrap(); diff --git a/tests-integration/src/tests/instance_kafka_wal_test.rs b/tests-integration/src/tests/instance_kafka_wal_test.rs index f6f05b53d0..2135c49823 100644 --- a/tests-integration/src/tests/instance_kafka_wal_test.rs +++ b/tests-integration/src/tests/instance_kafka_wal_test.rs @@ -14,6 +14,7 @@ use std::sync::Arc; +use client::OutputData; use common_query::Output; use common_recordbatch::util; use datatypes::vectors::{TimestampMillisecondVector, VectorRef}; @@ -33,7 +34,7 @@ async fn test_create_database_and_insert_query(instance: Option { + match query_output.data { + OutputData::Stream(s) => { let batches = util::collect(s).await.unwrap(); assert_eq!(1, batches[0].num_columns()); assert_eq!( diff --git a/tests-integration/src/tests/instance_test.rs b/tests-integration/src/tests/instance_test.rs index 78653c1d15..7bede82e30 100644 --- a/tests-integration/src/tests/instance_test.rs +++ b/tests-integration/src/tests/instance_test.rs @@ -15,6 +15,7 @@ use std::env; use std::sync::Arc; +use client::OutputData; use common_catalog::consts::DEFAULT_CATALOG_NAME; use common_query::Output; use common_recordbatch::util; @@ -40,8 +41,8 @@ use crate::tests::test_util::{ async fn test_create_database_and_insert_query(instance: Arc) { let instance = instance.frontend(); - let output = execute_sql(&instance, "create database test").await; - assert!(matches!(output, Output::AffectedRows(1))); + let output = execute_sql(&instance, "create database test").await.data; + assert!(matches!(output, 
OutputData::AffectedRows(1))); let output = execute_sql( &instance, @@ -53,8 +54,9 @@ async fn test_create_database_and_insert_query(instance: Arc) TIME INDEX(ts) )"#, ) - .await; - assert!(matches!(output, Output::AffectedRows(0))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(0))); let output = execute_sql( &instance, @@ -63,12 +65,15 @@ async fn test_create_database_and_insert_query(instance: Arc) ('host2', 88.8, 333.3, 1655276558000) "#, ) - .await; - assert!(matches!(output, Output::AffectedRows(2))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(2))); - let query_output = execute_sql(&instance, "select ts from test.demo order by ts limit 1").await; + let query_output = execute_sql(&instance, "select ts from test.demo order by ts limit 1") + .await + .data; match query_output { - Output::Stream(s, _) => { + OutputData::Stream(s) => { let batches = util::collect(s).await.unwrap(); assert_eq!(1, batches[0].num_columns()); assert_eq!( @@ -106,10 +111,10 @@ PARTITION ON COLUMNS (n) ( TIME INDEX(ts) )"# }; - let output = execute_sql(&frontend, sql).await; - assert!(matches!(output, Output::AffectedRows(0))); + let output = execute_sql(&frontend, sql).await.data; + assert!(matches!(output, OutputData::AffectedRows(0))); - let output = execute_sql(&frontend, "show create table demo").await; + let output = execute_sql(&frontend, "show create table demo").await.data; let expected = if instance.is_distributed_mode() { r#"+-------+-------------------------------------+ @@ -192,12 +197,14 @@ async fn test_show_create_external_table(instance: Arc) { r#"create external table {table_name} with (location='{location}', format='{format}');"#, ), ) - .await; - assert!(matches!(output, Output::AffectedRows(0))); + .await.data; + assert!(matches!(output, OutputData::AffectedRows(0))); - let output = execute_sql(&fe_instance, &format!("show create table {table_name};")).await; + let output = execute_sql(&fe_instance, &format!("show create table {table_name};")) + .await + .data; - let Output::RecordBatches(record_batches) = output else { + let OutputData::RecordBatches(record_batches) = output else { unreachable!() }; @@ -232,10 +239,10 @@ async fn test_issue477_same_table_name_in_different_databases(instance: Arc, sql: &str, ts: i64, host: &str) { - let query_output = execute_sql(instance, sql).await; + let query_output = execute_sql(instance, sql).await.data; match query_output { - Output::Stream(s, _) => { + OutputData::Stream(s) => { let batches = util::collect(s).await.unwrap(); // let columns = batches[0].df_recordbatch.columns(); assert_eq!(2, batches[0].num_columns()); @@ -326,8 +337,9 @@ async fn test_execute_insert(instance: Arc) { &instance, "create table demo(host string, cpu double, memory double, ts timestamp time index);", ) - .await, - Output::AffectedRows(0) + .await + .data, + OutputData::AffectedRows(0) )); let output = execute_sql( @@ -337,8 +349,9 @@ async fn test_execute_insert(instance: Arc) { ('host2', 88.8, 333.3, 1655276558000) "#, ) - .await; - assert!(matches!(output, Output::AffectedRows(2))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(2))); } #[apply(both_instances_cases)] @@ -351,16 +364,18 @@ async fn test_execute_insert_by_select(instance: Arc) { &instance, "create table demo1(host string, cpu double, memory double, ts timestamp time index);", ) - .await, - Output::AffectedRows(0) + .await + .data, + OutputData::AffectedRows(0) )); assert!(matches!( execute_sql( &instance, "create table demo2(host 
string, cpu double, memory double, ts timestamp time index);", ) - .await, - Output::AffectedRows(0) + .await + .data, + OutputData::AffectedRows(0) )); let output = execute_sql( @@ -370,8 +385,9 @@ async fn test_execute_insert_by_select(instance: Arc) { ('host2', 88.8, 333.3, 1655276558000) "#, ) - .await; - assert!(matches!(output, Output::AffectedRows(2))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(2))); assert!(matches!( try_execute_sql(&instance, "insert into demo2(host) select * from demo1") @@ -401,10 +417,14 @@ async fn test_execute_insert_by_select(instance: Arc) { } )); - let output = execute_sql(&instance, "insert into demo2 select * from demo1").await; - assert!(matches!(output, Output::AffectedRows(2))); + let output = execute_sql(&instance, "insert into demo2 select * from demo1") + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(2))); - let output = execute_sql(&instance, "select * from demo2 order by ts").await; + let output = execute_sql(&instance, "select * from demo2 order by ts") + .await + .data; let expected = "\ +-------+------+--------+---------------------+ | host | cpu | memory | ts | @@ -419,9 +439,11 @@ async fn test_execute_insert_by_select(instance: Arc) { async fn test_execute_query(instance: Arc) { let instance = instance.frontend(); - let output = execute_sql(&instance, "select sum(number) from numbers limit 20").await; + let output = execute_sql(&instance, "select sum(number) from numbers limit 20") + .await + .data; match output { - Output::Stream(recordbatch, _) => { + OutputData::Stream(recordbatch) => { let numbers = util::collect(recordbatch).await.unwrap(); assert_eq!(1, numbers[0].num_columns()); assert_eq!(numbers[0].column(0).len(), 1); @@ -474,7 +496,7 @@ async fn test_execute_show_databases_tables(instance: Arc) { assert!(matches!(execute_sql( &instance, "create table demo(host string, cpu double, memory double, ts timestamp time index, primary key (host));", - ).await, Output::AffectedRows(0))); + ).await.data, OutputData::AffectedRows(0))); let output = execute_sql(&instance, "show tables").await; let expected = "\ @@ -539,8 +561,9 @@ async fn test_execute_create(instance: Arc) { PRIMARY KEY(host) ) engine=mito with(regions=1);"#, ) - .await; - assert!(matches!(output, Output::AffectedRows(0))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(0))); } #[apply(both_instances_cases)] @@ -561,8 +584,9 @@ async fn test_execute_external_create(instance: Arc) { ) with (location='{location}', format='csv');"# ), ) - .await; - assert!(matches!(output, Output::AffectedRows(0))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(0))); let output = execute_sql( &instance, @@ -576,8 +600,9 @@ async fn test_execute_external_create(instance: Arc) { ) with (location='{location}', format='csv');"# ), ) - .await; - assert!(matches!(output, Output::AffectedRows(0))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(0))); } #[apply(both_instances_cases)] @@ -591,8 +616,9 @@ async fn test_execute_external_create_infer_format(instance: Arc) { "create table demo(host string, cpu double, memory double, ts timestamp, time index(ts))", query_ctx.clone(), ) - .await; - assert!(matches!(output, Output::AffectedRows(0))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(0))); // make sure table insertion is ok before altering table name let output = execute_sql_with( @@ -1161,8 +1225,8 @@ async fn test_rename_table(instance: Arc) { "insert 
into demo(host, cpu, memory, ts) values ('host1', 1.1, 100, 1000), ('host2', 2.2, 200, 2000)", query_ctx.clone(), ) - .await; - assert!(matches!(output, Output::AffectedRows(2))); + .await.data; + assert!(matches!(output, OutputData::AffectedRows(2))); // rename table let output = execute_sql_with( @@ -1170,10 +1234,13 @@ async fn test_rename_table(instance: Arc) { "alter table demo rename test_table", query_ctx.clone(), ) - .await; - assert!(matches!(output, Output::AffectedRows(0))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(0))); - let output = execute_sql_with(&instance, "show tables", query_ctx.clone()).await; + let output = execute_sql_with(&instance, "show tables", query_ctx.clone()) + .await + .data; let expect = "\ +------------+ | Tables | @@ -1187,7 +1254,8 @@ async fn test_rename_table(instance: Arc) { "select * from test_table order by ts", query_ctx.clone(), ) - .await; + .await + .data; let expected = "\ +-------+-----+--------+---------------------+ | host | cpu | memory | ts | @@ -1209,8 +1277,8 @@ async fn test_rename_table(instance: Arc) { async fn test_create_table_after_rename_table(instance: Arc) { let instance = instance.frontend(); - let output = execute_sql(&instance, "create database db").await; - assert!(matches!(output, Output::AffectedRows(1))); + let output = execute_sql(&instance, "create database db").await.data; + assert!(matches!(output, OutputData::AffectedRows(1))); // create test table let table_name = "demo"; @@ -1220,8 +1288,8 @@ async fn test_create_table_after_rename_table(instance: Arc) { &format!("create table {table_name}(host string, cpu double, memory double, ts timestamp, time index(ts))"), query_ctx.clone(), ) - .await; - assert!(matches!(output, Output::AffectedRows(0))); + .await.data; + assert!(matches!(output, OutputData::AffectedRows(0))); // rename table let new_table_name = "test_table"; @@ -1230,8 +1298,9 @@ async fn test_create_table_after_rename_table(instance: Arc) { &format!("alter table {table_name} rename {new_table_name}"), query_ctx.clone(), ) - .await; - assert!(matches!(output, Output::AffectedRows(0))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(0))); // create table with same name // create test table @@ -1240,8 +1309,8 @@ async fn test_create_table_after_rename_table(instance: Arc) { &format!("create table {table_name}(host string, cpu double, memory double, ts timestamp, time index(ts))"), query_ctx.clone(), ) - .await; - assert!(matches!(output, Output::AffectedRows(0))); + .await.data; + assert!(matches!(output, OutputData::AffectedRows(0))); let expect = "\ +------------+ @@ -1250,7 +1319,9 @@ async fn test_create_table_after_rename_table(instance: Arc) { | demo | | test_table | +------------+"; - let output = execute_sql_with(&instance, "show tables", query_ctx).await; + let output = execute_sql_with(&instance, "show tables", query_ctx) + .await + .data; check_output_stream(output, expect).await; } @@ -1264,8 +1335,9 @@ async fn test_alter_table(instance: Arc) { &instance, "create table demo(host string, cpu double, memory double, ts timestamp time index);", ) - .await, - Output::AffectedRows(0) + .await + .data, + OutputData::AffectedRows(0) )); // make sure table insertion is ok before altering table @@ -1274,28 +1346,35 @@ async fn test_alter_table(instance: Arc) { &instance, "insert into demo(host, cpu, memory, ts) values ('host1', 1.1, 100, 1000)", ) - .await, - Output::AffectedRows(1) + .await + .data, + OutputData::AffectedRows(1) )); // Add column - let 
output = execute_sql(&instance, "alter table demo add my_tag string null").await; - assert!(matches!(output, Output::AffectedRows(0))); + let output = execute_sql(&instance, "alter table demo add my_tag string null") + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(0))); let output = execute_sql( &instance, "insert into demo(host, cpu, memory, ts, my_tag) values ('host2', 2.2, 200, 2000, 'hello')", ) - .await; - assert!(matches!(output, Output::AffectedRows(1))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(1))); let output = execute_sql( &instance, "insert into demo(host, cpu, memory, ts) values ('host3', 3.3, 300, 3000)", ) - .await; - assert!(matches!(output, Output::AffectedRows(1))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(1))); - let output = execute_sql(&instance, "select * from demo order by ts").await; + let output = execute_sql(&instance, "select * from demo order by ts") + .await + .data; let expected = "\ +-------+-----+--------+---------------------+--------+ | host | cpu | memory | ts | my_tag | @@ -1307,10 +1386,14 @@ async fn test_alter_table(instance: Arc) { check_output_stream(output, expected).await; // Drop a column - let output = execute_sql(&instance, "alter table demo drop column memory").await; - assert!(matches!(output, Output::AffectedRows(0))); + let output = execute_sql(&instance, "alter table demo drop column memory") + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(0))); - let output = execute_sql(&instance, "select * from demo order by ts").await; + let output = execute_sql(&instance, "select * from demo order by ts") + .await + .data; let expected = "\ +-------+-----+---------------------+--------+ | host | cpu | ts | my_tag | @@ -1326,10 +1409,13 @@ async fn test_alter_table(instance: Arc) { &instance, "insert into demo(host, cpu, ts, my_tag) values ('host4', 400, 4000, 'world')", ) - .await; - assert!(matches!(output, Output::AffectedRows(1))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(1))); - let output = execute_sql(&instance, "select * from demo order by ts").await; + let output = execute_sql(&instance, "select * from demo order by ts") + .await + .data; let expected = "\ +-------+-------+---------------------+--------+ | host | cpu | ts | my_tag | @@ -1353,26 +1439,30 @@ async fn test_insert_with_default_value_for_type(instance: Arc, type_n PRIMARY KEY(host) ) engine=mito with(regions=1);"#, ); - let output = execute_sql(&instance, &create_sql).await; - assert!(matches!(output, Output::AffectedRows(0))); + let output = execute_sql(&instance, &create_sql).await.data; + assert!(matches!(output, OutputData::AffectedRows(0))); // Insert with ts. let output = execute_sql( &instance, &format!("insert into {table_name}(host, cpu, ts) values ('host1', 1.1, 1000)"), ) - .await; - assert!(matches!(output, Output::AffectedRows(1))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(1))); // Insert without ts, so it should be filled by default value. 
let output = execute_sql( &instance, &format!("insert into {table_name}(host, cpu) values ('host2', 2.2)"), ) - .await; - assert!(matches!(output, Output::AffectedRows(1))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(1))); - let output = execute_sql(&instance, &format!("select host, cpu from {table_name}")).await; + let output = execute_sql(&instance, &format!("select host, cpu from {table_name}")) + .await + .data; let expected = "\ +-------+-----+ | host | cpu | @@ -1393,8 +1483,8 @@ async fn test_insert_with_default_value(instance: Arc) { async fn test_use_database(instance: Arc) { let instance = instance.frontend(); - let output = execute_sql(&instance, "create database db1").await; - assert!(matches!(output, Output::AffectedRows(1))); + let output = execute_sql(&instance, "create database db1").await.data; + assert!(matches!(output, OutputData::AffectedRows(1))); let query_ctx = QueryContext::with(DEFAULT_CATALOG_NAME, "db1"); let output = execute_sql_with( @@ -1402,10 +1492,13 @@ async fn test_use_database(instance: Arc) { "create table tb1(col_i32 int, ts timestamp, TIME INDEX(ts))", query_ctx.clone(), ) - .await; - assert!(matches!(output, Output::AffectedRows(0))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(0))); - let output = execute_sql_with(&instance, "show tables", query_ctx.clone()).await; + let output = execute_sql_with(&instance, "show tables", query_ctx.clone()) + .await + .data; let expected = "\ +--------+ | Tables | @@ -1419,10 +1512,13 @@ async fn test_use_database(instance: Arc) { r#"insert into tb1(col_i32, ts) values (1, 1655276557000)"#, query_ctx.clone(), ) - .await; - assert!(matches!(output, Output::AffectedRows(1))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(1))); - let output = execute_sql_with(&instance, "select col_i32 from tb1", query_ctx.clone()).await; + let output = execute_sql_with(&instance, "select col_i32 from tb1", query_ctx.clone()) + .await + .data; let expected = "\ +---------+ | col_i32 | @@ -1433,7 +1529,9 @@ async fn test_use_database(instance: Arc) { // Making a particular database the default by means of the USE statement does not preclude // accessing tables in other databases. 
- let output = execute_sql(&instance, "select number from public.numbers limit 1").await; + let output = execute_sql(&instance, "select number from public.numbers limit 1") + .await + .data; let expected = "\ +--------+ | number | @@ -1458,8 +1556,9 @@ async fn test_delete(instance: Arc) { PRIMARY KEY(host) ) engine=mito with(regions=1);"#, ) - .await; - assert!(matches!(output, Output::AffectedRows(0))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(0))); let output = execute_sql( &instance, @@ -1469,17 +1568,21 @@ async fn test_delete(instance: Arc) { ('host3', 88.8, 3072, 1655276559000) "#, ) - .await; - assert!(matches!(output, Output::AffectedRows(3))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(3))); let output = execute_sql( &instance, "delete from test_table where host = 'host1' and ts = 1655276557000 ", ) - .await; - assert!(matches!(output, Output::AffectedRows(1))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(1))); - let output = execute_sql(&instance, "select * from test_table").await; + let output = execute_sql(&instance, "select * from test_table") + .await + .data; let expect = "\ +-------+---------------------+------+--------+ | host | ts | cpu | memory | @@ -1501,7 +1604,7 @@ async fn test_execute_copy_to_s3(instance: Arc) { &instance, "create table demo(host string, cpu double, memory double, ts timestamp time index);", ) - .await, Output::AffectedRows(0))); + .await.data, OutputData::AffectedRows(0))); let output = execute_sql( &instance, @@ -1510,8 +1613,9 @@ async fn test_execute_copy_to_s3(instance: Arc) { ('host2', 88.8, 333.3, 1655276558000) "#, ) - .await; - assert!(matches!(output, Output::AffectedRows(2))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(2))); let key_id = env::var("GT_S3_ACCESS_KEY_ID").unwrap(); let key = env::var("GT_S3_ACCESS_KEY").unwrap(); let region = env::var("GT_S3_REGION").unwrap(); @@ -1521,8 +1625,8 @@ async fn test_execute_copy_to_s3(instance: Arc) { // exports let copy_to_stmt = format!("Copy demo TO 's3://{}/{}/export/demo.parquet' CONNECTION (ACCESS_KEY_ID='{}',SECRET_ACCESS_KEY='{}',REGION='{}')", bucket, root, key_id, key, region); - let output = execute_sql(&instance, &copy_to_stmt).await; - assert!(matches!(output, Output::AffectedRows(2))); + let output = execute_sql(&instance, &copy_to_stmt).await.data; + assert!(matches!(output, OutputData::AffectedRows(2))); } } } @@ -1539,7 +1643,7 @@ async fn test_execute_copy_from_s3(instance: Arc) { &instance, "create table demo(host string, cpu double, memory double, ts timestamp time index);", ) - .await, Output::AffectedRows(0))); + .await.data, OutputData::AffectedRows(0))); let output = execute_sql( &instance, @@ -1548,8 +1652,9 @@ async fn test_execute_copy_from_s3(instance: Arc) { ('host2', 88.8, 333.3, 1655276558000) "#, ) - .await; - assert!(matches!(output, Output::AffectedRows(2))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(2))); // export let root = uuid::Uuid::new_v4().to_string(); @@ -1559,8 +1664,8 @@ async fn test_execute_copy_from_s3(instance: Arc) { let copy_to_stmt = format!("Copy demo TO 's3://{}/{}/export/demo.parquet' CONNECTION (ACCESS_KEY_ID='{}',SECRET_ACCESS_KEY='{}',REGION='{}')", bucket, root, key_id, key, region); - let output = execute_sql(&instance, &copy_to_stmt).await; - assert!(matches!(output, Output::AffectedRows(2))); 
struct Test<'a> { sql: &'a str, @@ -1597,8 +1702,9 @@ async fn test_execute_copy_from_s3(instance: Arc) { test.table_name ), ) - .await, - Output::AffectedRows(0) + .await + .data, + OutputData::AffectedRows(0) )); let sql = format!( "{} CONNECTION (ACCESS_KEY_ID='{}',SECRET_ACCESS_KEY='{}',REGION='{}')", @@ -1606,14 +1712,15 @@ async fn test_execute_copy_from_s3(instance: Arc) { ); logging::info!("Running sql: {}", sql); - let output = execute_sql(&instance, &sql).await; - assert!(matches!(output, Output::AffectedRows(2))); + let output = execute_sql(&instance, &sql).await.data; + assert!(matches!(output, OutputData::AffectedRows(2))); let output = execute_sql( &instance, &format!("select * from {} order by ts", test.table_name), ) - .await; + .await + .data; let expected = "\ +-------+------+--------+---------------------+ | host | cpu | memory | ts | @@ -1637,7 +1744,7 @@ async fn test_execute_copy_from_orc_with_cast(instance: Arc) { &instance, "create table demo(bigint_direct timestamp(9), bigint_neg_direct timestamp(6), bigint_other timestamp(3), timestamp_simple timestamp(9), time index (bigint_other));", ) - .await, Output::AffectedRows(0))); + .await.data, OutputData::AffectedRows(0))); let filepath = find_testing_resource("/src/common/datasource/tests/orc/test.orc"); @@ -1645,11 +1752,12 @@ async fn test_execute_copy_from_orc_with_cast(instance: Arc) { &instance, &format!("copy demo from '{}' WITH(FORMAT='orc');", &filepath), ) - .await; + .await + .data; - assert!(matches!(output, Output::AffectedRows(5))); + assert!(matches!(output, OutputData::AffectedRows(5))); - let output = execute_sql(&instance, "select * from demo;").await; + let output = execute_sql(&instance, "select * from demo;").await.data; let expected = r#"+-------------------------------+----------------------------+-------------------------+----------------------------+ | bigint_direct | bigint_neg_direct | bigint_other | timestamp_simple | +-------------------------------+----------------------------+-------------------------+----------------------------+ @@ -1670,7 +1778,7 @@ async fn test_execute_copy_from_orc(instance: Arc) { &instance, "create table demo(double_a double, a float, b boolean, str_direct string, d string, e string, f string, int_short_repeated int, int_neg_short_repeated int, int_delta int, int_neg_delta int, int_direct int, int_neg_direct int, bigint_direct bigint, bigint_neg_direct bigint, bigint_other bigint, utf8_increase string, utf8_decrease string, timestamp_simple timestamp(9) time index, date_simple date);", ) - .await, Output::AffectedRows(0))); + .await.data, OutputData::AffectedRows(0))); let filepath = find_testing_resource("/src/common/datasource/tests/orc/test.orc"); @@ -1678,11 +1786,14 @@ async fn test_execute_copy_from_orc(instance: Arc) { &instance, &format!("copy demo from '{}' WITH(FORMAT='orc');", &filepath), ) - .await; + .await + .data; - assert!(matches!(output, Output::AffectedRows(5))); + assert!(matches!(output, OutputData::AffectedRows(5))); - let output = execute_sql(&instance, "select * from demo order by double_a;").await; + let output = execute_sql(&instance, "select * from demo order by double_a;") + .await + .data; let expected = r#"+----------+-----+-------+------------+-----+-----+-------+--------------------+------------------------+-----------+---------------+------------+----------------+---------------+-------------------+--------------+---------------+---------------+----------------------------+-------------+ | double_a | a | b | str_direct | d | e | f 
| int_short_repeated | int_neg_short_repeated | int_delta | int_neg_delta | int_direct | int_neg_direct | bigint_direct | bigint_neg_direct | bigint_other | utf8_increase | utf8_decrease | timestamp_simple | date_simple | +----------+-----+-------+------------+-----+-----+-------+--------------------+------------------------+-----------+---------------+------------+----------------+---------------+-------------------+--------------+---------------+---------------+----------------------------+-------------+ @@ -1704,7 +1815,7 @@ async fn test_cast_type_issue_1594(instance: Arc) { &instance, "create table tsbs_cpu(hostname STRING, environment STRING, usage_user DOUBLE, usage_system DOUBLE, usage_idle DOUBLE, usage_nice DOUBLE, usage_iowait DOUBLE, usage_irq DOUBLE, usage_softirq DOUBLE, usage_steal DOUBLE, usage_guest DOUBLE, usage_guest_nice DOUBLE, ts TIMESTAMP TIME INDEX, PRIMARY KEY(hostname));", ) - .await, Output::AffectedRows(0))); + .await.data, OutputData::AffectedRows(0))); let filepath = find_testing_resource("/src/common/datasource/tests/csv/type_cast.csv"); @@ -1714,9 +1825,11 @@ async fn test_cast_type_issue_1594(instance: Arc) { ) .await; - assert!(matches!(output, Output::AffectedRows(5))); + assert!(matches!(output.data, OutputData::AffectedRows(5))); - let output = execute_sql(&instance, "select * from tsbs_cpu order by hostname;").await; + let output = execute_sql(&instance, "select * from tsbs_cpu order by hostname;") + .await + .data; let expected = "\ +----------+-------------+------------+--------------+------------+------------+--------------+-----------+---------------+-------------+-------------+------------------+---------------------+ | hostname | environment | usage_user | usage_system | usage_idle | usage_nice | usage_iowait | usage_irq | usage_softirq | usage_steal | usage_guest | usage_guest_nice | ts | @@ -1736,14 +1849,16 @@ async fn test_information_schema_dot_tables(instance: Arc) { let sql = "create table another_table(i timestamp time index)"; let query_ctx = QueryContext::with("another_catalog", "another_schema"); - let output = execute_sql_with(&instance, sql, query_ctx.clone()).await; - assert!(matches!(output, Output::AffectedRows(0))); + let output = execute_sql_with(&instance, sql, query_ctx.clone()) + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(0))); // User can only see information schema under current catalog. // A necessary requirement to GreptimeCloud. 
let sql = "select table_catalog, table_schema, table_name, table_type, table_id, engine from information_schema.tables where table_type != 'SYSTEM VIEW' and table_name in ('columns', 'numbers', 'tables', 'another_table') order by table_name"; - let output = execute_sql(&instance, sql).await; + let output = execute_sql(&instance, sql).await.data; let expected = "\ +---------------+--------------------+------------+-----------------+----------+-------------+ | table_catalog | table_schema | table_name | table_type | table_id | engine | @@ -1755,7 +1870,7 @@ async fn test_information_schema_dot_tables(instance: Arc) { check_output_stream(output, expected).await; - let output = execute_sql_with(&instance, sql, query_ctx).await; + let output = execute_sql_with(&instance, sql, query_ctx).await.data; let expected = "\ +-----------------+--------------------+---------------+-----------------+----------+--------+ | table_catalog | table_schema | table_name | table_type | table_id | engine | @@ -1774,14 +1889,16 @@ async fn test_information_schema_dot_columns(instance: Arc) { let sql = "create table another_table(i timestamp time index)"; let query_ctx = QueryContext::with("another_catalog", "another_schema"); - let output = execute_sql_with(&instance, sql, query_ctx.clone()).await; - assert!(matches!(output, Output::AffectedRows(0))); + let output = execute_sql_with(&instance, sql, query_ctx.clone()) + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(0))); // User can only see information schema under current catalog. // A necessary requirement to GreptimeCloud. let sql = "select table_catalog, table_schema, table_name, column_name, data_type, semantic_type from information_schema.columns where table_name in ('columns', 'numbers', 'tables', 'another_table') order by table_name"; - let output = execute_sql(&instance, sql).await; + let output = execute_sql(&instance, sql).await.data; let expected = "\ +---------------+--------------------+------------+----------------+-----------+---------------+ | table_catalog | table_schema | table_name | column_name | data_type | semantic_type | @@ -1807,7 +1924,7 @@ async fn test_information_schema_dot_columns(instance: Arc) { check_output_stream(output, expected).await; - let output = execute_sql_with(&instance, sql, query_ctx).await; + let output = execute_sql_with(&instance, sql, query_ctx).await.data; let expected = "\ +-----------------+--------------------+---------------+----------------+----------------------+---------------+ | table_catalog | table_schema | table_name | column_name | data_type | semantic_type | @@ -1890,8 +2007,8 @@ async fn test_custom_storage(instance: Arc) { ) }; - let output = execute_sql(&instance.frontend(), &sql).await; - assert!(matches!(output, Output::AffectedRows(0))); + let output = execute_sql(&instance.frontend(), &sql).await.data; + assert!(matches!(output, OutputData::AffectedRows(0))); let output = execute_sql( &frontend, r#"insert into test_table(a, ts) values @@ -1899,10 +2016,13 @@ async fn test_custom_storage(instance: Arc) { (1000, 1655276558000) "#, ) - .await; - assert!(matches!(output, Output::AffectedRows(2))); + .await + .data; + assert!(matches!(output, OutputData::AffectedRows(2))); - let output = execute_sql(&frontend, "select * from test_table").await; + let output = execute_sql(&frontend, "select * from test_table") + .await + .data; let expected = "\ +------+---------------------+ | a | ts | @@ -1912,8 +2032,10 @@ async fn test_custom_storage(instance: Arc) { 
+------+---------------------+"; check_output_stream(output, expected).await; - let output = execute_sql(&frontend, "show create table test_table").await; - let Output::RecordBatches(record_batches) = output else { + let output = execute_sql(&frontend, "show create table test_table") + .await + .data; + let OutputData::RecordBatches(record_batches) = output else { unreachable!() }; @@ -1955,16 +2077,18 @@ WITH( ) }; assert_eq!(actual.to_string(), expect); - let output = execute_sql(&frontend, "truncate test_table").await; - assert!(matches!(output, Output::AffectedRows(0))); - let output = execute_sql(&frontend, "select * from test_table").await; + let output = execute_sql(&frontend, "truncate test_table").await.data; + assert!(matches!(output, OutputData::AffectedRows(0))); + let output = execute_sql(&frontend, "select * from test_table") + .await + .data; let expected = "\ ++ ++"; check_output_stream(output, expected).await; - let output = execute_sql(&frontend, "drop table test_table").await; - assert!(matches!(output, Output::AffectedRows(0))); + let output = execute_sql(&frontend, "drop table test_table").await.data; + assert!(matches!(output, OutputData::AffectedRows(0))); } } } diff --git a/tests-integration/src/tests/test_util.rs b/tests-integration/src/tests/test_util.rs index b6d47a3249..9f3bad224b 100644 --- a/tests-integration/src/tests/test_util.rs +++ b/tests-integration/src/tests/test_util.rs @@ -15,6 +15,7 @@ use std::env; use std::sync::Arc; +use client::OutputData; use common_query::Output; use common_recordbatch::util; use common_telemetry::warn; @@ -329,9 +330,9 @@ pub(crate) async fn check_unordered_output_stream(output: Output, expected: &str .unwrap() }; - let recordbatches = match output { - Output::Stream(stream, _) => util::collect_batches(stream).await.unwrap(), - Output::RecordBatches(recordbatches) => recordbatches, + let recordbatches = match output.data { + OutputData::Stream(stream) => util::collect_batches(stream).await.unwrap(), + OutputData::RecordBatches(recordbatches) => recordbatches, _ => unreachable!(), }; let pretty_print = sort_table(&recordbatches.pretty_print().unwrap()); diff --git a/tests-integration/tests/grpc.rs b/tests-integration/tests/grpc.rs index e54450589f..0d11456e59 100644 --- a/tests-integration/tests/grpc.rs +++ b/tests-integration/tests/grpc.rs @@ -20,7 +20,7 @@ use api::v1::{ PromqlRequest, RequestHeader, SemanticType, }; use auth::user_provider_from_option; -use client::{Client, Database, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; +use client::{Client, Database, OutputData, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_catalog::consts::MITO_ENGINE; use common_query::Output; use common_recordbatch::RecordBatches; @@ -311,7 +311,7 @@ pub async fn test_insert_and_select(store_type: StorageType) { // create let expr = testing_create_expr(); let result = db.create(expr).await.unwrap(); - assert!(matches!(result, Output::AffectedRows(0))); + assert!(matches!(result.data, OutputData::AffectedRows(0))); //alter let add_column = ColumnDef { @@ -335,7 +335,7 @@ pub async fn test_insert_and_select(store_type: StorageType) { kind: Some(kind), }; let result = db.alter(expr).await.unwrap(); - assert!(matches!(result, Output::AffectedRows(0))); + assert!(matches!(result.data, OutputData::AffectedRows(0))); // insert insert_and_assert(&db).await; @@ -373,7 +373,7 @@ async fn insert_and_assert(db: &Database) { ) .await .unwrap(); - assert!(matches!(result, Output::AffectedRows(2))); + assert!(matches!(result.data, 
OutputData::AffectedRows(2))); // select let output = db @@ -381,10 +381,10 @@ async fn insert_and_assert(db: &Database) { .await .unwrap(); - let record_batches = match output { - Output::RecordBatches(record_batches) => record_batches, - Output::Stream(stream, _) => RecordBatches::try_collect(stream).await.unwrap(), - Output::AffectedRows(_) => unreachable!(), + let record_batches = match output.data { + OutputData::RecordBatches(record_batches) => record_batches, + OutputData::Stream(stream) => RecordBatches::try_collect(stream).await.unwrap(), + OutputData::AffectedRows(_) => unreachable!(), }; let pretty = record_batches.pretty_print().unwrap(); @@ -479,14 +479,16 @@ pub async fn test_prom_gateway_query(store_type: StorageType) { assert!(matches!( db.sql("CREATE TABLE test(i DOUBLE, j TIMESTAMP TIME INDEX, k STRING PRIMARY KEY);") .await - .unwrap(), - Output::AffectedRows(0) + .unwrap() + .data, + OutputData::AffectedRows(0) )); assert!(matches!( db.sql(r#"INSERT INTO test VALUES (1, 1, "a"), (1, 1, "b"), (2, 2, "a");"#) .await - .unwrap(), - Output::AffectedRows(3) + .unwrap() + .data, + OutputData::AffectedRows(3) )); // Instant query using prometheus gateway service @@ -684,10 +686,10 @@ pub async fn test_grpc_timezone(store_type: StorageType) { } async fn to_batch(output: Output) -> String { - match output { - Output::RecordBatches(batch) => batch, - Output::Stream(stream, _) => RecordBatches::try_collect(stream).await.unwrap(), - Output::AffectedRows(_) => unreachable!(), + match output.data { + OutputData::RecordBatches(batch) => batch, + OutputData::Stream(stream) => RecordBatches::try_collect(stream).await.unwrap(), + OutputData::AffectedRows(_) => unreachable!(), } .pretty_print() .unwrap() diff --git a/tests-integration/tests/region_failover.rs b/tests-integration/tests/region_failover.rs index c6874ccb40..ae32bcbc82 100644 --- a/tests-integration/tests/region_failover.rs +++ b/tests-integration/tests/region_failover.rs @@ -16,6 +16,7 @@ use std::sync::Arc; use std::time::Duration; use catalog::kvbackend::{CachedMetaKvBackend, KvBackendCatalogManager}; +use client::OutputData; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_meta::key::table_route::TableRouteKey; use common_meta::key::{RegionDistribution, TableMetaKey}; @@ -104,7 +105,7 @@ pub async fn test_region_failover(store_type: StorageType) { let results = insert_values(&frontend, logical_timer).await; logical_timer += 1000; for result in results { - assert!(matches!(result.unwrap(), Output::AffectedRows(1))); + assert!(matches!(result.unwrap().data, OutputData::AffectedRows(1))); } assert!(has_route_cache(&frontend, table_id).await); @@ -144,7 +145,7 @@ pub async fn test_region_failover(store_type: StorageType) { let frontend = cluster.frontend.clone(); let results = insert_values(&frontend, logical_timer).await; for result in results { - assert!(matches!(result.unwrap(), Output::AffectedRows(1))); + assert!(matches!(result.unwrap().data, OutputData::AffectedRows(1))); } assert_values(&frontend).await; @@ -226,7 +227,7 @@ async fn assert_values(instance: &Arc) { | 55 | 2023-05-31T04:51:55 | | 55 | 2023-05-31T04:51:56 | +----+---------------------+"; - check_output_stream(result.unwrap(), expected).await; + check_output_stream(result.unwrap().data, expected).await; } async fn prepare_testing_table(cluster: &GreptimeDbCluster) -> TableId { diff --git a/tests-integration/tests/region_migration.rs b/tests-integration/tests/region_migration.rs index 18a5e12241..89175d9093 100644 --- 
a/tests-integration/tests/region_migration.rs +++ b/tests-integration/tests/region_migration.rs @@ -15,7 +15,7 @@ use std::sync::Arc; use std::time::Duration; -use client::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; +use client::{OutputData, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_meta::key::{RegionDistribution, TableMetadataManagerRef}; use common_meta::peer::Peer; use common_query::Output; @@ -136,7 +136,7 @@ pub async fn test_region_migration(store_type: StorageType, endpoints: Vec) { | 55 | 2023-05-31T04:51:55 | | 55 | 2023-05-31T04:51:56 | +----+---------------------+"; - check_output_stream(result.unwrap(), expected).await; + check_output_stream(result.unwrap().data, expected).await; } async fn prepare_testing_table(cluster: &GreptimeDbCluster) -> TableId { @@ -824,7 +824,7 @@ async fn prepare_testing_table(cluster: &GreptimeDbCluster) -> TableId { ); let mut result = cluster.frontend.do_query(&sql, QueryContext::arc()).await; let output = result.remove(0).unwrap(); - assert!(matches!(output, Output::AffectedRows(0))); + assert!(matches!(output.data, OutputData::AffectedRows(0))); let table = cluster .frontend @@ -852,7 +852,7 @@ async fn find_region_distribution( async fn find_region_distribution_by_sql(cluster: &GreptimeDbCluster) -> RegionDistribution { let query_ctx = QueryContext::arc(); - let Output::Stream(stream, _) = run_sql( + let OutputData::Stream(stream) = run_sql( &cluster.frontend, &format!(r#"select b.peer_id as datanode_id, a.greptime_partition_id as region_id @@ -862,7 +862,7 @@ async fn find_region_distribution_by_sql(cluster: &GreptimeDbCluster) -> RegionD ), query_ctx.clone(), ) - .await.unwrap() else { + .await.unwrap().data else { unreachable!(); }; @@ -901,13 +901,15 @@ async fn trigger_migration_by_sql( from_peer_id: u64, to_peer_id: u64, ) -> String { - let Output::Stream(stream, _) = run_sql( + let OutputData::Stream(stream) = run_sql( &cluster.frontend, &format!("select migrate_region({region_id}, {from_peer_id}, {to_peer_id})"), QueryContext::arc(), ) .await - .unwrap() else { + .unwrap() + .data + else { unreachable!(); }; @@ -924,13 +926,15 @@ async fn trigger_migration_by_sql( /// Query procedure state by SQL. 
async fn query_procedure_by_sql(instance: &Arc, pid: &str) -> String { - let Output::Stream(stream, _) = run_sql( + let OutputData::Stream(stream) = run_sql( instance, &format!("select procedure_state('{pid}')"), QueryContext::arc(), ) .await - .unwrap() else { + .unwrap() + .data + else { unreachable!(); }; diff --git a/tests/runner/src/env.rs b/tests/runner/src/env.rs index bd2848d4e7..0edf7e471e 100644 --- a/tests/runner/src/env.rs +++ b/tests/runner/src/env.rs @@ -28,7 +28,7 @@ use client::{ Client, Database as DB, Error as ClientError, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, }; use common_error::ext::ErrorExt; -use common_query::Output; +use common_query::{Output, OutputData}; use common_recordbatch::RecordBatches; use serde::Serialize; use sqlness::{Database, EnvController, QueryContext}; @@ -443,7 +443,7 @@ impl Database for GreptimeDB { .trim_end_matches(';'); client.set_schema(database); Box::new(ResultDisplayer { - result: Ok(Output::AffectedRows(0)), + result: Ok(Output::new_with_affected_rows(0)), }) as _ } else if query.trim().to_lowercase().starts_with("set time_zone") { // set time_zone='xxx' @@ -460,13 +460,19 @@ impl Database for GreptimeDB { client.set_timezone(timezone); Box::new(ResultDisplayer { - result: Ok(Output::AffectedRows(0)), + result: Ok(Output::new_with_affected_rows(0)), }) as _ } else { let mut result = client.sql(&query).await; - if let Ok(Output::Stream(stream, _)) = result { + if let Ok(Output { + data: OutputData::Stream(stream), + .. + }) = result + { match RecordBatches::try_collect(stream).await { - Ok(recordbatches) => result = Ok(Output::RecordBatches(recordbatches)), + Ok(recordbatches) => { + result = Ok(Output::new_with_record_batches(recordbatches)); + } Err(e) => { let status_code = e.status_code(); let msg = e.output_msg(); @@ -567,11 +573,11 @@ struct ResultDisplayer { impl Display for ResultDisplayer { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match &self.result { - Ok(result) => match result { - Output::AffectedRows(rows) => { + Ok(result) => match &result.data { + OutputData::AffectedRows(rows) => { write!(f, "Affected Rows: {rows}") } - Output::RecordBatches(recordbatches) => { + OutputData::RecordBatches(recordbatches) => { let pretty = recordbatches.pretty_print().map_err(|e| e.to_string()); match pretty { Ok(s) => write!(f, "{s}"), @@ -580,7 +586,7 @@ impl Display for ResultDisplayer { } } } - Output::Stream(_, _) => unreachable!(), + OutputData::Stream(_) => unreachable!(), }, Err(e) => { let status_code = e.status_code();