diff --git a/Cargo.lock b/Cargo.lock index 7404b2a25f..b76ae023ee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6312,6 +6312,7 @@ dependencies = [ "sha1", "snafu", "snap", + "sql", "strum", "table", "tempdir", diff --git a/src/datanode/src/instance/sql.rs b/src/datanode/src/instance/sql.rs index 80149dda5c..4b0fc0ed79 100644 --- a/src/datanode/src/instance/sql.rs +++ b/src/datanode/src/instance/sql.rs @@ -33,12 +33,11 @@ use crate::metric; use crate::sql::SqlRequest; impl Instance { - pub async fn execute_sql(&self, sql: &str, query_ctx: QueryContextRef) -> Result { - let stmt = self - .query_engine - .sql_to_statement(sql) - .context(ExecuteSqlSnafu)?; - + pub async fn execute_stmt( + &self, + stmt: Statement, + query_ctx: QueryContextRef, + ) -> Result { match stmt { Statement::Query(_) => { let logical_plan = self @@ -153,6 +152,14 @@ impl Instance { } } } + + pub async fn execute_sql(&self, sql: &str, query_ctx: QueryContextRef) -> Result { + let stmt = self + .query_engine + .sql_to_statement(sql) + .context(ExecuteSqlSnafu)?; + self.execute_stmt(stmt, query_ctx).await + } } // TODO(LFC): Refactor consideration: move this function to some helper mod, @@ -193,15 +200,33 @@ impl SqlQueryHandler for Instance { &self, query: &str, query_ctx: QueryContextRef, - ) -> servers::error::Result { + ) -> Vec> { let _timer = timer!(metric::METRIC_HANDLE_SQL_ELAPSED); - self.execute_sql(query, query_ctx) + // we assume sql string has only 1 statement in datanode + let result = self + .execute_sql(query, query_ctx) .await .map_err(|e| { error!(e; "Instance failed to execute sql"); BoxedError::new(e) }) - .context(servers::error::ExecuteQuerySnafu { query }) + .context(servers::error::ExecuteQuerySnafu { query }); + vec![result] + } + + async fn do_statement_query( + &self, + stmt: Statement, + query_ctx: QueryContextRef, + ) -> servers::error::Result { + let _timer = timer!(metric::METRIC_HANDLE_SQL_ELAPSED); + self.execute_stmt(stmt, query_ctx) + .await + .map_err(|e| { + error!(e; "Instance failed to execute sql"); + BoxedError::new(e) + }) + .context(servers::error::ExecuteStatementSnafu) } } diff --git a/src/frontend/src/error.rs b/src/frontend/src/error.rs index 2f40eec4b2..bf7542cd6d 100644 --- a/src/frontend/src/error.rs +++ b/src/frontend/src/error.rs @@ -387,6 +387,12 @@ pub enum Error { source: query::error::Error, }, + #[snafu(display("Failed to execute statement, source: {}", source))] + ExecuteStatement { + #[snafu(backtrace)] + source: query::error::Error, + }, + #[snafu(display("Failed to do vector computation, source: {}", source))] VectorComputation { #[snafu(backtrace)] @@ -536,6 +542,7 @@ impl ErrorExt for Error { Error::DeserializeInsertBatch { source, .. } => source.status_code(), Error::PrimaryKeyNotFound { .. } => StatusCode::InvalidArguments, Error::ExecuteSql { source, .. } => source.status_code(), + Error::ExecuteStatement { source, .. } => source.status_code(), Error::InsertBatchToRequest { source, .. } => source.status_code(), Error::CollectRecordbatchStream { source } | Error::CreateRecordbatches { source } => { source.status_code() diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 849b43c984..721fc5008a 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -461,31 +461,18 @@ impl FrontendInstance for Instance { } } -fn parse_stmt(sql: &str) -> Result { - let mut stmt = ParserContext::create_with_dialect(sql, &GenericDialect {}) - .context(error::ParseSqlSnafu)?; - // TODO(LFC): Support executing multiple SQL queries, - // which seems to be a major change to our whole server framework? - ensure!( - stmt.len() == 1, - error::InvalidSqlSnafu { - err_msg: "Currently executing multiple SQL queries are not supported." - } - ); - Ok(stmt.remove(0)) +fn parse_stmt(sql: &str) -> Result> { + ParserContext::create_with_dialect(sql, &GenericDialect {}).context(error::ParseSqlSnafu) } -#[async_trait] -impl SqlQueryHandler for Instance { - async fn do_query( +impl Instance { + async fn query_statement( &self, - query: &str, + stmt: Statement, query_ctx: QueryContextRef, ) -> server_error::Result { - let stmt = parse_stmt(query) - .map_err(BoxedError::new) - .context(server_error::ExecuteQuerySnafu { query })?; - + // TODO(sunng87): provide a better form to log or track statement + let query = &format!("{:?}", &stmt); match stmt { Statement::CreateDatabase(_) | Statement::ShowDatabases(_) @@ -494,7 +481,7 @@ impl SqlQueryHandler for Instance { | Statement::DescribeTable(_) | Statement::Explain(_) | Statement::Query(_) => { - return self.sql_handler.do_query(query, query_ctx).await; + return self.sql_handler.do_statement_query(stmt, query_ctx).await; } Statement::Insert(insert) => match self.mode { Mode::Standalone => { @@ -569,6 +556,45 @@ impl SqlQueryHandler for Instance { } } +#[async_trait] +impl SqlQueryHandler for Instance { + async fn do_query( + &self, + query: &str, + query_ctx: QueryContextRef, + ) -> Vec> { + match parse_stmt(query) + .map_err(BoxedError::new) + .context(server_error::ExecuteQuerySnafu { query }) + { + Ok(stmts) => { + let mut results = Vec::with_capacity(stmts.len()); + for stmt in stmts { + match self.query_statement(stmt, query_ctx.clone()).await { + Ok(output) => results.push(Ok(output)), + Err(e) => { + results.push(Err(e)); + break; + } + } + } + results + } + Err(e) => { + vec![Err(e)] + } + } + } + + async fn do_statement_query( + &self, + stmt: Statement, + query_ctx: QueryContextRef, + ) -> server_error::Result { + self.query_statement(stmt, query_ctx).await + } +} + #[async_trait] impl ScriptHandler for Instance { async fn insert_script(&self, name: &str, script: &str) -> server_error::Result<()> { @@ -671,6 +697,7 @@ mod tests { ) engine=mito with(regions=1);"#; let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) .await + .remove(0) .unwrap(); match output { Output::AffectedRows(rows) => assert_eq!(rows, 1), @@ -684,6 +711,7 @@ mod tests { "#; let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) .await + .remove(0) .unwrap(); match output { Output::AffectedRows(rows) => assert_eq!(rows, 3), @@ -693,6 +721,7 @@ mod tests { let sql = "select * from demo"; let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) .await + .remove(0) .unwrap(); match output { Output::RecordBatches(_) => { @@ -720,6 +749,7 @@ mod tests { let sql = "select * from demo where ts>cast(1000000000 as timestamp)"; // use nanoseconds as where condition let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone()) .await + .remove(0) .unwrap(); match output { Output::RecordBatches(_) => { diff --git a/src/frontend/src/instance/distributed.rs b/src/frontend/src/instance/distributed.rs index 5d46823f01..de09c1b4d3 100644 --- a/src/frontend/src/instance/distributed.rs +++ b/src/frontend/src/instance/distributed.rs @@ -142,14 +142,17 @@ impl DistInstance { Ok(Output::AffectedRows(0)) } - async fn handle_sql(&self, sql: &str, query_ctx: QueryContextRef) -> Result { - let stmt = parse_stmt(sql)?; + async fn handle_statement( + &self, + stmt: Statement, + query_ctx: QueryContextRef, + ) -> Result { match stmt { Statement::Query(_) => { let plan = self .query_engine .statement_to_plan(stmt, query_ctx) - .context(error::ExecuteSqlSnafu { sql })?; + .context(error::ExecuteStatementSnafu {})?; self.query_engine.execute(&plan).await } Statement::CreateDatabase(stmt) => { @@ -173,7 +176,30 @@ impl DistInstance { } _ => unreachable!(), } - .context(error::ExecuteSqlSnafu { sql }) + .context(error::ExecuteStatementSnafu) + } + + async fn handle_sql(&self, sql: &str, query_ctx: QueryContextRef) -> Vec> { + let stmts = parse_stmt(sql); + match stmts { + Ok(stmts) => { + let mut results = Vec::with_capacity(stmts.len()); + + for stmt in stmts { + let result = self.handle_statement(stmt, query_ctx.clone()).await; + let is_err = result.is_err(); + + results.push(result); + + if is_err { + break; + } + } + + results + } + Err(e) => vec![Err(e)], + } } /// Handles distributed database creation @@ -310,11 +336,26 @@ impl SqlQueryHandler for DistInstance { &self, query: &str, query_ctx: QueryContextRef, - ) -> server_error::Result { + ) -> Vec> { self.handle_sql(query, query_ctx) + .await + .into_iter() + .map(|r| { + r.map_err(BoxedError::new) + .context(server_error::ExecuteQuerySnafu { query }) + }) + .collect() + } + + async fn do_statement_query( + &self, + stmt: Statement, + query_ctx: QueryContextRef, + ) -> server_error::Result { + self.handle_statement(stmt, query_ctx) .await .map_err(BoxedError::new) - .context(server_error::ExecuteQuerySnafu { query }) + .context(server_error::ExecuteStatementSnafu) } } @@ -555,7 +596,7 @@ mod test { let cases = [ ( r" -CREATE TABLE rcx ( a INT, b STRING, c TIMESTAMP, TIME INDEX (c) ) +CREATE TABLE rcx ( a INT, b STRING, c TIMESTAMP, TIME INDEX (c) ) PARTITION BY RANGE COLUMNS (b) ( PARTITION r0 VALUES LESS THAN ('hz'), PARTITION r1 VALUES LESS THAN ('sh'), @@ -601,6 +642,7 @@ ENGINE=mito", let output = dist_instance .handle_sql(sql, QueryContext::arc()) .await + .remove(0) .unwrap(); match output { Output::AffectedRows(rows) => assert_eq!(rows, 1), @@ -611,6 +653,7 @@ ENGINE=mito", let output = dist_instance .handle_sql(sql, QueryContext::arc()) .await + .remove(0) .unwrap(); match output { Output::RecordBatches(r) => { @@ -649,6 +692,7 @@ ENGINE=mito", dist_instance .handle_sql(sql, QueryContext::arc()) .await + .remove(0) .unwrap(); let sql = " @@ -667,11 +711,16 @@ ENGINE=mito", dist_instance .handle_sql(sql, QueryContext::arc()) .await + .remove(0) .unwrap(); async fn assert_show_tables(instance: SqlQueryHandlerRef) { let sql = "show tables in test_show_tables"; - let output = instance.do_query(sql, QueryContext::arc()).await.unwrap(); + let output = instance + .do_query(sql, QueryContext::arc()) + .await + .remove(0) + .unwrap(); match output { Output::RecordBatches(r) => { let expected = r#"+--------------+ diff --git a/src/frontend/src/instance/opentsdb.rs b/src/frontend/src/instance/opentsdb.rs index 9bcec20bb7..89cb869fc5 100644 --- a/src/frontend/src/instance/opentsdb.rs +++ b/src/frontend/src/instance/opentsdb.rs @@ -130,6 +130,7 @@ mod tests { Arc::new(QueryContext::new()), ) .await + .remove(0) .unwrap(); match output { Output::Stream(stream) => { diff --git a/src/frontend/src/instance/prometheus.rs b/src/frontend/src/instance/prometheus.rs index 1257d186c8..d1a848e158 100644 --- a/src/frontend/src/instance/prometheus.rs +++ b/src/frontend/src/instance/prometheus.rs @@ -27,9 +27,9 @@ use servers::prometheus::{self, Metrics}; use servers::query_handler::{PrometheusProtocolHandler, PrometheusResponse}; use servers::Mode; use session::context::QueryContext; -use snafu::{OptionExt, ResultExt}; +use snafu::{ensure, OptionExt, ResultExt}; -use crate::instance::Instance; +use crate::instance::{parse_stmt, Instance}; const SAMPLES_RESPONSE_TYPE: i32 = ResponseType::Samples as i32; @@ -94,13 +94,26 @@ impl Instance { ); let query_ctx = Arc::new(QueryContext::with_current_schema(db.to_string())); - let output = self.sql_handler.do_query(&sql, query_ctx).await; + + let mut stmts = parse_stmt(&sql) + .map_err(BoxedError::new) + .context(error::ExecuteQuerySnafu { query: &sql })?; + + ensure!( + stmts.len() == 1, + error::InvalidQuerySnafu { + reason: "The sql has multiple statements".to_string() + } + ); + let stmt = stmts.remove(0); + + let output = self.sql_handler.do_statement_query(stmt, query_ctx).await; let object_result = to_object_result(output) .await .try_into() .map_err(BoxedError::new) - .context(error::ExecuteQuerySnafu { query: sql })?; + .context(error::ExecuteQuerySnafu { query: &sql })?; results.push((table_name, object_result)); } diff --git a/src/query/src/datafusion.rs b/src/query/src/datafusion.rs index 0968d99357..79c3bc7938 100644 --- a/src/query/src/datafusion.rs +++ b/src/query/src/datafusion.rs @@ -33,7 +33,7 @@ use common_telemetry::timer; use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::ExecutionPlan; use session::context::QueryContextRef; -use snafu::{OptionExt, ResultExt}; +use snafu::{ensure, OptionExt, ResultExt}; use sql::dialect::GenericDialect; use sql::parser::ParserContext; use sql::statements::statement::Statement; @@ -72,8 +72,7 @@ impl QueryEngine for DatafusionQueryEngine { fn sql_to_statement(&self, sql: &str) -> Result { let mut statement = ParserContext::create_with_dialect(sql, &GenericDialect {}) .context(error::ParseSqlSnafu)?; - // TODO(dennis): supports multi statement in one sql? - assert!(1 == statement.len()); + ensure!(1 == statement.len(), error::MultipleStatementsSnafu { sql }); Ok(statement.remove(0)) } @@ -280,6 +279,7 @@ mod tests { .sql_to_plan(sql, Arc::new(QueryContext::new())) .unwrap(); + // TODO(sunng87): do not rely on to_string for compare assert_eq!( format!("{:?}", plan), r#"DfPlan(Limit: skip=0, fetch=20 @@ -297,6 +297,7 @@ mod tests { let plan = engine .sql_to_plan(sql, Arc::new(QueryContext::new())) .unwrap(); + let output = engine.execute(&plan).await.unwrap(); match output { diff --git a/src/query/src/datafusion/error.rs b/src/query/src/datafusion/error.rs index b7bb5d9919..95ffc8d843 100644 --- a/src/query/src/datafusion/error.rs +++ b/src/query/src/datafusion/error.rs @@ -40,6 +40,9 @@ pub enum InnerError { source: sql::error::Error, }, + #[snafu(display("The SQL string has multiple statements, sql: {}", sql))] + MultipleStatements { sql: String, backtrace: Backtrace }, + #[snafu(display("Cannot plan SQL: {}, source: {}", sql, source))] PlanSql { sql: String, @@ -90,6 +93,7 @@ impl ErrorExt for InnerError { PlanSql { .. } => StatusCode::PlanQuery, ConvertDfRecordBatchStream { source } => source.status_code(), ExecutePhysicalPlan { source } => source.status_code(), + MultipleStatements { .. } => StatusCode::InvalidArguments, } } diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index 3abb18b1c2..d708cc5551 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -47,6 +47,7 @@ session = { path = "../session" } sha1 = "0.10" snafu = { version = "0.7", features = ["backtraces"] } snap = "1" +sql = { path = "../sql" } strum = { version = "0.24", features = ["derive"] } table = { path = "../table" } tokio = { version = "1.20", features = ["full"] } diff --git a/src/servers/src/error.rs b/src/servers/src/error.rs index a24cee44b6..5c738cfc12 100644 --- a/src/servers/src/error.rs +++ b/src/servers/src/error.rs @@ -78,6 +78,12 @@ pub enum Error { source: BoxedError, }, + #[snafu(display("Failed to execute sql statement, source: {}", source))] + ExecuteStatement { + #[snafu(backtrace)] + source: BoxedError, + }, + #[snafu(display("Failed to execute insert: {}, source: {}", msg, source))] ExecuteInsert { msg: String, @@ -257,6 +263,7 @@ impl ErrorExt for Error { InsertScript { source, .. } | ExecuteScript { source, .. } | ExecuteQuery { source, .. } + | ExecuteStatement { source, .. } | ExecuteInsert { source, .. } | ExecuteAlter { source, .. } | PutOpentsdbDataPoint { source, .. } => source.status_code(), diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 3885543e98..955c67c9cb 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -232,28 +232,52 @@ impl JsonResponse { } /// Create a json response from query result - async fn from_output(output: Result) -> Self { - match output { - Ok(Output::AffectedRows(rows)) => { - Self::with_output(Some(vec![JsonOutput::AffectedRows(rows)])) - } - Ok(Output::Stream(stream)) => match util::collect(stream).await { - Ok(rows) => match HttpRecordsOutput::try_from(rows) { - Ok(rows) => Self::with_output(Some(vec![JsonOutput::Records(rows)])), - Err(err) => Self::with_error(err, StatusCode::Internal), + async fn from_output(outputs: Vec>) -> Self { + // TODO(sunng87): this api response structure cannot represent error + // well. It hides successful execution results from error response + let mut results = Vec::with_capacity(outputs.len()); + for out in outputs { + match out { + Ok(Output::AffectedRows(rows)) => { + results.push(JsonOutput::AffectedRows(rows)); + } + Ok(Output::Stream(stream)) => { + // TODO(sunng87): streaming response + match util::collect(stream).await { + Ok(rows) => match HttpRecordsOutput::try_from(rows) { + Ok(rows) => { + results.push(JsonOutput::Records(rows)); + } + Err(err) => { + return Self::with_error(err, StatusCode::Internal); + } + }, + + Err(e) => { + return Self::with_error( + format!("Recordbatch error: {}", e), + e.status_code(), + ); + } + } + } + Ok(Output::RecordBatches(rbs)) => match HttpRecordsOutput::try_from(rbs.take()) { + Ok(rows) => { + results.push(JsonOutput::Records(rows)); + } + Err(err) => { + return Self::with_error(err, StatusCode::Internal); + } }, - Err(e) => Self::with_error(format!("Recordbatch error: {}", e), e.status_code()), - }, - Ok(Output::RecordBatches(recordbatches)) => { - match HttpRecordsOutput::try_from(recordbatches.take()) { - Ok(rows) => Self::with_output(Some(vec![JsonOutput::Records(rows)])), - Err(err) => Self::with_error(err, StatusCode::Internal), + Err(e) => { + return Self::with_error( + format!("Query engine output error: {}", e), + e.status_code(), + ); } } - Err(e) => { - Self::with_error(format!("Query engine output error: {}", e), e.status_code()) - } } + Self::with_output(Some(results)) } pub fn code(&self) -> u32 { @@ -519,7 +543,15 @@ mod test { #[async_trait] impl SqlQueryHandler for DummyInstance { - async fn do_query(&self, _: &str, _: QueryContextRef) -> Result { + async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec> { + unimplemented!() + } + + async fn do_statement_query( + &self, + _stmt: sql::statements::statement::Statement, + _query_ctx: QueryContextRef, + ) -> Result { unimplemented!() } } @@ -582,7 +614,8 @@ mod test { let recordbatch = RecordBatch::new(schema.clone(), columns).unwrap(); let recordbatches = RecordBatches::try_new(schema.clone(), vec![recordbatch]).unwrap(); - let json_resp = JsonResponse::from_output(Ok(Output::RecordBatches(recordbatches))).await; + let json_resp = + JsonResponse::from_output(vec![Ok(Output::RecordBatches(recordbatches))]).await; let json_output = &json_resp.output.unwrap()[0]; if let JsonOutput::Records(r) = json_output { diff --git a/src/servers/src/http/handler.rs b/src/servers/src/http/handler.rs index 37a36ca5b8..d148851be4 100644 --- a/src/servers/src/http/handler.rs +++ b/src/servers/src/http/handler.rs @@ -45,8 +45,11 @@ pub async fn sql( let sql_handler = &state.sql_handler; let start = Instant::now(); let resp = if let Some(sql) = ¶ms.sql { - // TODO(LFC): Sessions in http server. let query_ctx = Arc::new(QueryContext::new()); + if let Some(db) = params.database { + query_ctx.set_current_schema(db.as_ref()); + } + JsonResponse::from_output(sql_handler.do_query(sql, query_ctx).await).await } else { JsonResponse::with_error( @@ -78,8 +81,8 @@ pub struct HealthQuery {} #[derive(Debug, Serialize, Deserialize, JsonSchema, PartialEq, Eq)] pub struct HealthResponse {} -/// Handler to export healthy check -/// +/// Handler to export healthy check +/// /// Currently simply return status "200 OK" (default) with an empty json payload "{}" #[axum_macros::debug_handler] pub async fn health(Query(_params): Query) -> Json { diff --git a/src/servers/src/http/script.rs b/src/servers/src/http/script.rs index 5b9e17b600..04d0571b8d 100644 --- a/src/servers/src/http/script.rs +++ b/src/servers/src/http/script.rs @@ -91,7 +91,7 @@ pub async fn run_script( } let output = script_handler.execute_script(name.unwrap()).await; - let resp = JsonResponse::from_output(output).await; + let resp = JsonResponse::from_output(vec![output]).await; Json(resp.with_execution_time(start.elapsed().as_millis())) } else { diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index 55100dff9d..cfb24514a3 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -73,7 +73,7 @@ impl MysqlInstanceShim { } } - async fn do_query(&self, query: &str) -> Result { + async fn do_query(&self, query: &str) -> Vec> { debug!("Start executing query: '{}'", query); let start = Instant::now(); @@ -82,7 +82,7 @@ impl MysqlInstanceShim { // components, this is quick and dirty, there must be a better way to do it. let output = if let Some(output) = crate::mysql::federated::check(query, self.session.context()) { - Ok(output) + vec![Ok(output)] } else { self.query_handler .do_query(query, self.session.context()) @@ -193,14 +193,17 @@ impl AsyncMysqlShim for MysqlInstanceShi query: &'a str, writer: QueryResultWriter<'a, W>, ) -> Result<()> { - let output = self.do_query(query).await; + let outputs = self.do_query(query).await; let mut writer = MysqlResultWriter::new(writer); - writer.write(query, output).await + for output in outputs { + writer.write(query, output).await?; + } + Ok(()) } async fn on_init<'a>(&'a mut self, database: &'a str, w: InitWriter<'a, W>) -> Result<()> { let query = format!("USE {}", database.trim()); - let output = self.do_query(&query).await; + let output = self.do_query(&query).await.remove(0); if let Err(e) = output { w.error(ErrorKind::ER_UNKNOWN_ERROR, e.to_string().as_bytes()) .await diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 3d9b11c077..6fa4da6a16 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -26,7 +26,7 @@ use pgwire::api::portal::Portal; use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{text_query_response, FieldInfo, Response, Tag, TextDataRowEncoder}; use pgwire::api::{ClientInfo, Type}; -use pgwire::error::{PgWireError, PgWireResult}; +use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use session::context::QueryContext; use crate::error::{self, Error, Result}; @@ -61,36 +61,43 @@ impl SimpleQueryHandler for PostgresServerHandler { C: ClientInfo + Unpin + Send + Sync, { let query_ctx = query_context_from_client_info(client); - let output = self - .query_handler - .do_query(query, query_ctx) - .await - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + let outputs = self.query_handler.do_query(query, query_ctx).await; - match output { - Output::AffectedRows(rows) => Ok(vec![Response::Execution(Tag::new_for_execution( - "OK", - Some(rows), - ))]), - Output::Stream(record_stream) => { - let schema = record_stream.schema(); - recordbatches_to_query_response(record_stream, schema) - } - Output::RecordBatches(recordbatches) => { - let schema = recordbatches.schema(); - recordbatches_to_query_response( - stream::iter(recordbatches.take().into_iter().map(Ok)), - schema, - ) - } + let mut results = Vec::with_capacity(outputs.len()); + + for output in outputs { + let resp = match output { + Ok(Output::AffectedRows(rows)) => { + Response::Execution(Tag::new_for_execution("OK", Some(rows))) + } + Ok(Output::Stream(record_stream)) => { + let schema = record_stream.schema(); + recordbatches_to_query_response(record_stream, schema)? + } + Ok(Output::RecordBatches(recordbatches)) => { + let schema = recordbatches.schema(); + recordbatches_to_query_response( + stream::iter(recordbatches.take().into_iter().map(Ok)), + schema, + )? + } + Err(e) => Response::Error(Box::new(ErrorInfo::new( + "ERROR".to_string(), + "XX000".to_string(), + e.to_string(), + ))), + }; + results.push(resp); } + + Ok(results) } } fn recordbatches_to_query_response( recordbatches_stream: S, schema: SchemaRef, -) -> PgWireResult> +) -> PgWireResult where S: Stream> + Send + Unpin + 'static, { @@ -121,10 +128,10 @@ where }) }); - Ok(vec![Response::Query(text_query_response( + Ok(Response::Query(text_query_response( pg_schema, data_row_stream, - ))]) + ))) } fn schema_to_pg(origin: SchemaRef) -> Result> { diff --git a/src/servers/src/query_handler.rs b/src/servers/src/query_handler.rs index d9a48ba30f..3abb848737 100644 --- a/src/servers/src/query_handler.rs +++ b/src/servers/src/query_handler.rs @@ -19,6 +19,7 @@ use api::v1::{AdminExpr, AdminResult, ObjectExpr, ObjectResult}; use async_trait::async_trait; use common_query::Output; use session::context::QueryContextRef; +use sql::statements::statement::Statement; use crate::error::Result; use crate::influxdb::InfluxdbRequest; @@ -45,7 +46,13 @@ pub type ScriptHandlerRef = Arc; #[async_trait] pub trait SqlQueryHandler { - async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Result; + async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec>; + + async fn do_statement_query( + &self, + stmt: Statement, + query_ctx: QueryContextRef, + ) -> Result; } #[async_trait] diff --git a/src/servers/tests/http/influxdb_test.rs b/src/servers/tests/http/influxdb_test.rs index 13d5cab06f..547bb90900 100644 --- a/src/servers/tests/http/influxdb_test.rs +++ b/src/servers/tests/http/influxdb_test.rs @@ -46,7 +46,15 @@ impl InfluxdbLineProtocolHandler for DummyInstance { #[async_trait] impl SqlQueryHandler for DummyInstance { - async fn do_query(&self, _: &str, _: QueryContextRef) -> Result { + async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec> { + unimplemented!() + } + + async fn do_statement_query( + &self, + _stmt: sql::statements::statement::Statement, + _query_ctx: QueryContextRef, + ) -> Result { unimplemented!() } } diff --git a/src/servers/tests/http/opentsdb_test.rs b/src/servers/tests/http/opentsdb_test.rs index 3b51f66965..15ade25fab 100644 --- a/src/servers/tests/http/opentsdb_test.rs +++ b/src/servers/tests/http/opentsdb_test.rs @@ -45,7 +45,15 @@ impl OpentsdbProtocolHandler for DummyInstance { #[async_trait] impl SqlQueryHandler for DummyInstance { - async fn do_query(&self, _: &str, _: QueryContextRef) -> Result { + async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec> { + unimplemented!() + } + + async fn do_statement_query( + &self, + _stmt: sql::statements::statement::Statement, + _query_ctx: QueryContextRef, + ) -> Result { unimplemented!() } } diff --git a/src/servers/tests/http/prometheus_test.rs b/src/servers/tests/http/prometheus_test.rs index b7df350505..9a895e3bc4 100644 --- a/src/servers/tests/http/prometheus_test.rs +++ b/src/servers/tests/http/prometheus_test.rs @@ -70,7 +70,15 @@ impl PrometheusProtocolHandler for DummyInstance { #[async_trait] impl SqlQueryHandler for DummyInstance { - async fn do_query(&self, _: &str, _: QueryContextRef) -> Result { + async fn do_query(&self, _: &str, _: QueryContextRef) -> Vec> { + unimplemented!() + } + + async fn do_statement_query( + &self, + _stmt: sql::statements::statement::Statement, + _query_ctx: QueryContextRef, + ) -> Result { unimplemented!() } } diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index 63c8e2ebe2..5f2692bcab 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -54,9 +54,18 @@ impl DummyInstance { #[async_trait] impl SqlQueryHandler for DummyInstance { - async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Result { + async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Vec> { let plan = self.query_engine.sql_to_plan(query, query_ctx).unwrap(); - Ok(self.query_engine.execute(&plan).await.unwrap()) + let output = self.query_engine.execute(&plan).await.unwrap(); + vec![Ok(output)] + } + + async fn do_statement_query( + &self, + _stmt: sql::statements::statement::Statement, + _query_ctx: QueryContextRef, + ) -> Result { + unimplemented!() } } diff --git a/tests-integration/tests/http.rs b/tests-integration/tests/http.rs index 267c49e824..ecdfac3b62 100644 --- a/tests-integration/tests/http.rs +++ b/tests-integration/tests/http.rs @@ -59,13 +59,12 @@ macro_rules! http_tests { pub async fn test_sql_api(store_type: StorageType) { common_telemetry::init_default_ut_logging(); - let (app, mut guard) = setup_test_app(store_type, "sql_api").await; + let (app, mut guard) = setup_test_app_with_frontend(store_type, "sql_api").await; let client = TestClient::new(app); let res = client.get("/v1/sql").send().await; assert_eq!(res.status(), StatusCode::OK); let body = serde_json::from_str::(&res.text().await).unwrap(); - // body json: r#"{"code":1004,"error":"sql parameter is required."}"# assert_eq!(body.code(), 1004); assert_eq!(body.error().unwrap(), "sql parameter is required."); assert!(body.execution_time_ms().is_some()); @@ -77,9 +76,6 @@ pub async fn test_sql_api(store_type: StorageType) { assert_eq!(res.status(), StatusCode::OK); let body = serde_json::from_str::(&res.text().await).unwrap(); - // body json: - // r#"{"code":0,"output":[{"records":{"schema":{"column_schemas":[{"name":"number","data_type":"UInt32"}]},"rows":[[0],[1],[2],[3],[4],[5],[6],[7],[8],[9]]}}]}"# - assert!(body.success()); assert!(body.execution_time_ms().is_some()); @@ -107,7 +103,6 @@ pub async fn test_sql_api(store_type: StorageType) { assert_eq!(res.status(), StatusCode::OK); let body = serde_json::from_str::(&res.text().await).unwrap(); - // body json: r#"{"code":0,"output":[{"records":{"schema":{"column_schemas":[{"name":"host","data_type":"String"},{"name":"cpu","data_type":"Float64"},{"name":"memory","data_type":"Float64"},{"name":"ts","data_type":"Timestamp"}]},"rows":[["host",66.6,1024.0,0]]}}]}"# assert!(body.success()); assert!(body.execution_time_ms().is_some()); let output = body.output().unwrap(); @@ -128,8 +123,6 @@ pub async fn test_sql_api(store_type: StorageType) { assert_eq!(res.status(), StatusCode::OK); let body = serde_json::from_str::(&res.text().await).unwrap(); - // body json: - // r#"{"code":0,"output":[{"records":{"schema":{"column_schemas":[{"name":"cpu","data_type":"Float64"},{"name":"ts","data_type":"Timestamp"}]},"rows":[[66.6,0]]}}]}"# assert!(body.success()); assert!(body.execution_time_ms().is_some()); let output = body.output().unwrap(); @@ -150,8 +143,6 @@ pub async fn test_sql_api(store_type: StorageType) { assert_eq!(res.status(), StatusCode::OK); let body = serde_json::from_str::(&res.text().await).unwrap(); - // body json: - // r#"{"code":0,"output":[{"records":{"schema":{"column_schemas":[{"name":"c","data_type":"Float64"},{"name":"time","data_type":"Timestamp"}]},"rows":[[66.6,0]]}}]}"# assert!(body.success()); assert!(body.execution_time_ms().is_some()); let output = body.output().unwrap(); @@ -163,6 +154,44 @@ pub async fn test_sql_api(store_type: StorageType) { })).unwrap() ); + // test multi-statement + let res = client + .get("/v1/sql?sql=select cpu, ts from demo limit 1;select cpu, ts from demo where ts > 0;") + .send() + .await; + assert_eq!(res.status(), StatusCode::OK); + + let body = serde_json::from_str::(&res.text().await).unwrap(); + assert!(body.success()); + assert!(body.execution_time_ms().is_some()); + let outputs = body.output().unwrap(); + assert_eq!(outputs.len(), 2); + assert_eq!( + outputs[0], + serde_json::from_value::(json!({ + "records":{"schema":{"column_schemas":[{"name":"cpu","data_type":"Float64"},{"name":"ts","data_type":"TimestampMillisecond"}]},"rows":[[66.6,0]]} + })).unwrap() + ); + assert_eq!( + outputs[1], + serde_json::from_value::(json!({ + "records":{"rows":[]} + })) + .unwrap() + ); + + // test multi-statement with error + let res = client + .get("/v1/sql?sql=select cpu, ts from demo limit 1;select cpu, ts from demo2 where ts > 0;") + .send() + .await; + assert_eq!(res.status(), StatusCode::OK); + + let body = serde_json::from_str::(&res.text().await).unwrap(); + assert!(!body.success()); + assert!(body.execution_time_ms().is_some()); + assert!(body.error().unwrap().contains("not found")); + guard.remove_all().await; } @@ -206,7 +235,6 @@ def test(n): assert_eq!(res.status(), StatusCode::OK); let body = serde_json::from_str::(&res.text().await).unwrap(); - // body json: r#"{"code":0}"# assert_eq!(body.code(), 0); assert!(body.output().is_none()); @@ -215,8 +243,6 @@ def test(n): assert_eq!(res.status(), StatusCode::OK); let body = serde_json::from_str::(&res.text().await).unwrap(); - // body json: - // r#"{"code":0,"output":[{"records":{"schema":{"column_schemas":[{"name":"n","data_type":"Float64"}]},"rows":[[1.0],[2.0],[3.0],[4.0],[5.0],[6.0],[7.0],[8.0],[9.0],[10.0]]}}]}"# assert_eq!(body.code(), 0); assert!(body.execution_time_ms().is_some()); let output = body.output().unwrap();