feat: support "use" stmt part 1 (#672)

* feat: a bare sketch of session; support "use" in MySQL server; modify insertion and selection related codes in Datanode
This commit is contained in:
LFC
2022-12-01 17:05:32 +08:00
committed by GitHub
parent 2e17e9c4b5
commit 6127706b5b
60 changed files with 943 additions and 494 deletions

13
Cargo.lock generated
View File

@@ -1834,6 +1834,7 @@ dependencies = [
"serde",
"serde_json",
"servers",
"session",
"snafu",
"sql",
"storage",
@@ -2231,6 +2232,7 @@ dependencies = [
"serde",
"serde_json",
"servers",
"session",
"snafu",
"sql",
"sqlparser",
@@ -4515,6 +4517,7 @@ dependencies = [
"rand 0.8.5",
"serde",
"serde_json",
"session",
"snafu",
"sql",
"statrs",
@@ -5366,6 +5369,7 @@ dependencies = [
"rustpython-parser",
"rustpython-vm",
"serde",
"session",
"snafu",
"sql",
"storage",
@@ -5535,6 +5539,7 @@ dependencies = [
"script",
"serde",
"serde_json",
"session",
"snafu",
"snap",
"table",
@@ -5550,6 +5555,14 @@ dependencies = [
"tower-http",
]
[[package]]
name = "session"
version = "0.1.0"
dependencies = [
"arc-swap",
"common-telemetry",
]
[[package]]
name = "sha-1"
version = "0.10.0"

View File

@@ -28,6 +28,7 @@ members = [
"src/query",
"src/script",
"src/servers",
"src/session",
"src/sql",
"src/storage",
"src/store-api",

View File

@@ -28,31 +28,42 @@ use crate::error::{
DeserializeCatalogEntryValueSnafu, Error, InvalidCatalogSnafu, SerializeCatalogEntryValueSnafu,
};
const ALPHANUMERICS_NAME_PATTERN: &str = "[a-zA-Z_][a-zA-Z0-9_]*";
lazy_static! {
static ref CATALOG_KEY_PATTERN: Regex =
Regex::new(&format!("^{}-([a-zA-Z_]+)$", CATALOG_KEY_PREFIX)).unwrap();
static ref CATALOG_KEY_PATTERN: Regex = Regex::new(&format!(
"^{}-({})$",
CATALOG_KEY_PREFIX, ALPHANUMERICS_NAME_PATTERN
))
.unwrap();
}
lazy_static! {
static ref SCHEMA_KEY_PATTERN: Regex = Regex::new(&format!(
"^{}-([a-zA-Z_]+)-([a-zA-Z_]+)$",
SCHEMA_KEY_PREFIX
"^{}-({})-({})$",
SCHEMA_KEY_PREFIX, ALPHANUMERICS_NAME_PATTERN, ALPHANUMERICS_NAME_PATTERN
))
.unwrap();
}
lazy_static! {
static ref TABLE_GLOBAL_KEY_PATTERN: Regex = Regex::new(&format!(
"^{}-([a-zA-Z_]+)-([a-zA-Z_]+)-([a-zA-Z0-9_]+)$",
TABLE_GLOBAL_KEY_PREFIX
"^{}-({})-({})-({})$",
TABLE_GLOBAL_KEY_PREFIX,
ALPHANUMERICS_NAME_PATTERN,
ALPHANUMERICS_NAME_PATTERN,
ALPHANUMERICS_NAME_PATTERN
))
.unwrap();
}
lazy_static! {
static ref TABLE_REGIONAL_KEY_PATTERN: Regex = Regex::new(&format!(
"^{}-([a-zA-Z_]+)-([a-zA-Z_]+)-([a-zA-Z0-9_]+)-([0-9]+)$",
TABLE_REGIONAL_KEY_PREFIX
"^{}-({})-({})-({})-([0-9]+)$",
TABLE_REGIONAL_KEY_PREFIX,
ALPHANUMERICS_NAME_PATTERN,
ALPHANUMERICS_NAME_PATTERN,
ALPHANUMERICS_NAME_PATTERN
))
.unwrap();
}

View File

@@ -37,6 +37,7 @@ meta-srv = { path = "../meta-srv", features = ["mock"] }
metrics = "0.20"
object-store = { path = "../object-store" }
query = { path = "../query" }
session = { path = "../session" }
script = { path = "../script", features = ["python"], optional = true }
serde = "1.0"
serde_json = "1.0"

View File

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use api::result::{build_err_result, AdminResultBuilder, ObjectResultBuilder};
use api::v1::{
admin_expr, object_expr, select_expr, AdminExpr, AdminResult, Column, CreateDatabaseExpr,
@@ -26,6 +28,7 @@ use common_grpc_expr::insertion_expr_to_request;
use common_query::Output;
use query::plan::LogicalPlan;
use servers::query_handler::{GrpcAdminHandler, GrpcQueryHandler};
use session::context::QueryContext;
use snafu::prelude::*;
use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
use table::requests::CreateDatabaseRequest;
@@ -110,7 +113,9 @@ impl Instance {
async fn do_handle_select(&self, select_expr: SelectExpr) -> Result<Output> {
let expr = select_expr.expr;
match expr {
Some(select_expr::Expr::Sql(sql)) => self.execute_sql(&sql).await,
Some(select_expr::Expr::Sql(sql)) => {
self.execute_sql(&sql, Arc::new(QueryContext::new())).await
}
Some(select_expr::Expr::LogicalPlan(plan)) => self.execute_logical(plan).await,
Some(select_expr::Expr::PhysicalPlan(api::v1::PhysicalPlan { original_ql, plan })) => {
self.physical_planner

View File

@@ -13,25 +13,27 @@
// limitations under the License.
use async_trait::async_trait;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_error::prelude::BoxedError;
use common_query::Output;
use common_recordbatch::RecordBatches;
use common_telemetry::logging::{error, info};
use common_telemetry::timer;
use servers::query_handler::SqlQueryHandler;
use session::context::QueryContextRef;
use snafu::prelude::*;
use sql::ast::ObjectName;
use sql::statements::statement::Statement;
use table::engine::TableReference;
use table::requests::CreateDatabaseRequest;
use crate::error::{
BumpTableIdSnafu, CatalogNotFoundSnafu, CatalogSnafu, ExecuteSqlSnafu, ParseSqlSnafu, Result,
SchemaNotFoundSnafu, TableIdProviderNotFoundSnafu,
};
use crate::error::{self, BumpTableIdSnafu, ExecuteSqlSnafu, Result, TableIdProviderNotFoundSnafu};
use crate::instance::Instance;
use crate::metric;
use crate::sql::SqlRequest;
impl Instance {
pub async fn execute_sql(&self, sql: &str) -> Result<Output> {
pub async fn execute_sql(&self, sql: &str, query_ctx: QueryContextRef) -> Result<Output> {
let stmt = self
.query_engine
.sql_to_statement(sql)
@@ -41,7 +43,7 @@ impl Instance {
Statement::Query(_) => {
let logical_plan = self
.query_engine
.statement_to_plan(stmt)
.statement_to_plan(stmt, query_ctx)
.context(ExecuteSqlSnafu)?;
self.query_engine
@@ -50,20 +52,15 @@ impl Instance {
.context(ExecuteSqlSnafu)
}
Statement::Insert(i) => {
let (catalog_name, schema_name, _table_name) =
i.full_table_name().context(ParseSqlSnafu)?;
let schema_provider = self
.catalog_manager
.catalog(&catalog_name)
.context(CatalogSnafu)?
.context(CatalogNotFoundSnafu { name: catalog_name })?
.schema(&schema_name)
.context(CatalogSnafu)?
.context(SchemaNotFoundSnafu { name: schema_name })?;
let request = self.sql_handler.insert_to_request(schema_provider, *i)?;
self.sql_handler.execute(request).await
let (catalog, schema, table) =
table_idents_to_full_name(i.table_name(), query_ctx.clone())?;
let table_ref = TableReference::full(&catalog, &schema, &table);
let request = self.sql_handler.insert_to_request(
self.catalog_manager.clone(),
*i,
table_ref,
)?;
self.sql_handler.execute(request, query_ctx).await
}
Statement::CreateDatabase(c) => {
@@ -74,7 +71,7 @@ impl Instance {
info!("Creating a new database: {}", request.db_name);
self.sql_handler
.execute(SqlRequest::CreateDatabase(request))
.execute(SqlRequest::CreateDatabase(request), query_ctx)
.await
}
@@ -89,58 +86,116 @@ impl Instance {
let _engine_name = c.engine.clone();
// TODO(hl): Select table engine by engine_name
let request = self.sql_handler.create_to_request(table_id, c)?;
let catalog_name = &request.catalog_name;
let schema_name = &request.schema_name;
let table_name = &request.table_name;
let name = c.name.clone();
let (catalog, schema, table) = table_idents_to_full_name(&name, query_ctx.clone())?;
let table_ref = TableReference::full(&catalog, &schema, &table);
let request = self.sql_handler.create_to_request(table_id, c, table_ref)?;
let table_id = request.id;
info!(
"Creating table, catalog: {:?}, schema: {:?}, table name: {:?}, table id: {}",
catalog_name, schema_name, table_name, table_id
catalog, schema, table, table_id
);
self.sql_handler
.execute(SqlRequest::CreateTable(request))
.execute(SqlRequest::CreateTable(request), query_ctx)
.await
}
Statement::Alter(alter_table) => {
let req = self.sql_handler.alter_to_request(alter_table)?;
self.sql_handler.execute(SqlRequest::Alter(req)).await
let name = alter_table.table_name().clone();
let (catalog, schema, table) = table_idents_to_full_name(&name, query_ctx.clone())?;
let table_ref = TableReference::full(&catalog, &schema, &table);
let req = self.sql_handler.alter_to_request(alter_table, table_ref)?;
self.sql_handler
.execute(SqlRequest::Alter(req), query_ctx)
.await
}
Statement::DropTable(drop_table) => {
let req = self.sql_handler.drop_table_to_request(drop_table);
self.sql_handler.execute(SqlRequest::DropTable(req)).await
self.sql_handler
.execute(SqlRequest::DropTable(req), query_ctx)
.await
}
Statement::ShowDatabases(stmt) => {
self.sql_handler
.execute(SqlRequest::ShowDatabases(stmt))
.execute(SqlRequest::ShowDatabases(stmt), query_ctx)
.await
}
Statement::ShowTables(stmt) => {
self.sql_handler.execute(SqlRequest::ShowTables(stmt)).await
self.sql_handler
.execute(SqlRequest::ShowTables(stmt), query_ctx)
.await
}
Statement::Explain(stmt) => {
self.sql_handler
.execute(SqlRequest::Explain(Box::new(stmt)))
.execute(SqlRequest::Explain(Box::new(stmt)), query_ctx)
.await
}
Statement::DescribeTable(stmt) => {
self.sql_handler
.execute(SqlRequest::DescribeTable(stmt))
.execute(SqlRequest::DescribeTable(stmt), query_ctx)
.await
}
Statement::ShowCreateTable(_stmt) => {
unimplemented!("SHOW CREATE TABLE is unimplemented yet");
}
Statement::Use(db) => {
ensure!(
self.catalog_manager
.schema(DEFAULT_CATALOG_NAME, &db)
.context(error::CatalogSnafu)?
.is_some(),
error::SchemaNotFoundSnafu { name: &db }
);
query_ctx.set_current_schema(&db);
Ok(Output::RecordBatches(RecordBatches::empty()))
}
}
}
}
// TODO(LFC): Refactor consideration: move this function to some helper mod,
// could be done together or after `TableReference`'s refactoring, when issue #559 is resolved.
/// Converts maybe fully-qualified table name (`<catalog>.<schema>.<table>`) to tuple.
fn table_idents_to_full_name(
obj_name: &ObjectName,
query_ctx: QueryContextRef,
) -> Result<(String, String, String)> {
match &obj_name.0[..] {
[table] => Ok((
DEFAULT_CATALOG_NAME.to_string(),
query_ctx.current_schema().unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string()),
table.value.clone(),
)),
[schema, table] => Ok((
DEFAULT_CATALOG_NAME.to_string(),
schema.value.clone(),
table.value.clone(),
)),
[catalog, schema, table] => Ok((
catalog.value.clone(),
schema.value.clone(),
table.value.clone(),
)),
_ => error::InvalidSqlSnafu {
msg: format!(
"expect table name to be <catalog>.<schema>.<table>, <schema>.<table> or <table>, actual: {}",
obj_name
),
}.fail(),
}
}
#[async_trait]
impl SqlQueryHandler for Instance {
async fn do_query(&self, query: &str) -> servers::error::Result<Output> {
async fn do_query(
&self,
query: &str,
query_ctx: QueryContextRef,
) -> servers::error::Result<Output> {
let _timer = timer!(metric::METRIC_HANDLE_SQL_ELAPSED);
self.execute_sql(query)
self.execute_sql(query, query_ctx)
.await
.map_err(|e| {
error!(e; "Instance failed to execute sql");
@@ -149,3 +204,78 @@ impl SqlQueryHandler for Instance {
.context(servers::error::ExecuteQuerySnafu { query })
}
}
#[cfg(test)]
mod test {
use std::sync::Arc;
use session::context::QueryContext;
use super::*;
#[test]
fn test_table_idents_to_full_name() {
let my_catalog = "my_catalog";
let my_schema = "my_schema";
let my_table = "my_table";
let full = ObjectName(vec![my_catalog.into(), my_schema.into(), my_table.into()]);
let partial = ObjectName(vec![my_schema.into(), my_table.into()]);
let bare = ObjectName(vec![my_table.into()]);
let using_schema = "foo";
let query_ctx = Arc::new(QueryContext::with_current_schema(using_schema.to_string()));
let empty_ctx = Arc::new(QueryContext::new());
assert_eq!(
table_idents_to_full_name(&full, query_ctx.clone()).unwrap(),
(
my_catalog.to_string(),
my_schema.to_string(),
my_table.to_string()
)
);
assert_eq!(
table_idents_to_full_name(&full, empty_ctx.clone()).unwrap(),
(
my_catalog.to_string(),
my_schema.to_string(),
my_table.to_string()
)
);
assert_eq!(
table_idents_to_full_name(&partial, query_ctx.clone()).unwrap(),
(
DEFAULT_CATALOG_NAME.to_string(),
my_schema.to_string(),
my_table.to_string()
)
);
assert_eq!(
table_idents_to_full_name(&partial, empty_ctx.clone()).unwrap(),
(
DEFAULT_CATALOG_NAME.to_string(),
my_schema.to_string(),
my_table.to_string()
)
);
assert_eq!(
table_idents_to_full_name(&bare, query_ctx).unwrap(),
(
DEFAULT_CATALOG_NAME.to_string(),
using_schema.to_string(),
my_table.to_string()
)
);
assert_eq!(
table_idents_to_full_name(&bare, empty_ctx).unwrap(),
(
DEFAULT_CATALOG_NAME.to_string(),
DEFAULT_SCHEMA_NAME.to_string(),
my_table.to_string()
)
);
}
}

View File

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use api::result::AdminResultBuilder;
use api::v1::{AdminResult, AlterExpr, CreateExpr, DropTableExpr};
use common_error::prelude::{ErrorExt, StatusCode};
@@ -19,6 +21,7 @@ use common_grpc_expr::{alter_expr_to_request, create_expr_to_request};
use common_query::Output;
use common_telemetry::{error, info};
use futures::TryFutureExt;
use session::context::QueryContext;
use snafu::prelude::*;
use table::requests::DropTableRequest;
@@ -72,7 +75,12 @@ impl Instance {
let request = create_expr_to_request(table_id, expr).context(CreateExprToRequestSnafu);
let result = futures::future::ready(request)
.and_then(|request| self.sql_handler().execute(SqlRequest::CreateTable(request)))
.and_then(|request| {
self.sql_handler().execute(
SqlRequest::CreateTable(request),
Arc::new(QueryContext::new()),
)
})
.await;
match result {
Ok(Output::AffectedRows(rows)) => AdminResultBuilder::default()
@@ -103,7 +111,10 @@ impl Instance {
};
let result = futures::future::ready(request)
.and_then(|request| self.sql_handler().execute(SqlRequest::Alter(request)))
.and_then(|request| {
self.sql_handler()
.execute(SqlRequest::Alter(request), Arc::new(QueryContext::new()))
})
.await;
match result {
Ok(Output::AffectedRows(rows)) => AdminResultBuilder::default()
@@ -124,7 +135,10 @@ impl Instance {
schema_name: expr.schema_name,
table_name: expr.table_name,
};
let result = self.sql_handler().execute(SqlRequest::DropTable(req)).await;
let result = self
.sql_handler()
.execute(SqlRequest::DropTable(req), Arc::new(QueryContext::new()))
.await;
match result {
Ok(Output::AffectedRows(rows)) => AdminResultBuilder::default()
.status_code(StatusCode::Success as u32)

View File

@@ -12,13 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//! sql handler
use catalog::CatalogManagerRef;
use common_query::Output;
use common_telemetry::error;
use query::query_engine::QueryEngineRef;
use query::sql::{describe_table, explain, show_databases, show_tables};
use session::context::QueryContextRef;
use snafu::{OptionExt, ResultExt};
use sql::statements::describe::DescribeTable;
use sql::statements::explain::Explain;
@@ -67,7 +66,11 @@ impl SqlHandler {
}
}
pub async fn execute(&self, request: SqlRequest) -> Result<Output> {
// TODO(LFC): Refactor consideration: a context awareness "Planner".
// Now we have some query related state (like current using database in session context), maybe
// we could create a new struct called `Planner` that stores context and handle these queries
// there, instead of executing here in a "static" fashion.
pub async fn execute(&self, request: SqlRequest, query_ctx: QueryContextRef) -> Result<Output> {
let result = match request {
SqlRequest::Insert(req) => self.insert(req).await,
SqlRequest::CreateTable(req) => self.create_table(req).await,
@@ -78,12 +81,12 @@ impl SqlHandler {
show_databases(stmt, self.catalog_manager.clone()).context(ExecuteSqlSnafu)
}
SqlRequest::ShowTables(stmt) => {
show_tables(stmt, self.catalog_manager.clone()).context(ExecuteSqlSnafu)
show_tables(stmt, self.catalog_manager.clone(), query_ctx).context(ExecuteSqlSnafu)
}
SqlRequest::DescribeTable(stmt) => {
describe_table(stmt, self.catalog_manager.clone()).context(ExecuteSqlSnafu)
}
SqlRequest::Explain(stmt) => explain(stmt, self.query_engine.clone())
SqlRequest::Explain(stmt) => explain(stmt, self.query_engine.clone(), query_ctx)
.await
.context(ExecuteSqlSnafu),
};
@@ -114,7 +117,8 @@ mod tests {
use std::any::Any;
use std::sync::Arc;
use catalog::SchemaProvider;
use catalog::{CatalogList, SchemaProvider};
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_query::logical_plan::Expr;
use common_query::physical_plan::PhysicalPlanRef;
use common_time::timestamp::Timestamp;
@@ -234,9 +238,17 @@ mod tests {
.await
.unwrap(),
);
let catalog_provider = catalog_list.catalog(DEFAULT_CATALOG_NAME).unwrap().unwrap();
catalog_provider
.register_schema(
DEFAULT_SCHEMA_NAME.to_string(),
Arc::new(MockSchemaProvider {}),
)
.unwrap();
let factory = QueryEngineFactory::new(catalog_list.clone());
let query_engine = factory.query_engine();
let sql_handler = SqlHandler::new(table_engine, catalog_list, query_engine.clone());
let sql_handler = SqlHandler::new(table_engine, catalog_list.clone(), query_engine.clone());
let stmt = match query_engine.sql_to_statement(sql).unwrap() {
Statement::Insert(i) => i,
@@ -244,9 +256,8 @@ mod tests {
unreachable!()
}
};
let schema_provider = Arc::new(MockSchemaProvider {});
let request = sql_handler
.insert_to_request(schema_provider, *stmt)
.insert_to_request(catalog_list.clone(), *stmt, TableReference::bare("demo"))
.unwrap();
match request {

View File

@@ -16,7 +16,7 @@ use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_query::Output;
use snafu::prelude::*;
use sql::statements::alter::{AlterTable, AlterTableOperation};
use sql::statements::{column_def_to_schema, table_idents_to_full_name};
use sql::statements::column_def_to_schema;
use table::engine::{EngineContext, TableReference};
use table::requests::{AddColumnRequest, AlterKind, AlterTableRequest};
@@ -53,10 +53,11 @@ impl SqlHandler {
Ok(Output::AffectedRows(0))
}
pub(crate) fn alter_to_request(&self, alter_table: AlterTable) -> Result<AlterTableRequest> {
let (catalog_name, schema_name, table_name) =
table_idents_to_full_name(alter_table.table_name()).context(error::ParseSqlSnafu)?;
pub(crate) fn alter_to_request(
&self,
alter_table: AlterTable,
table_ref: TableReference,
) -> Result<AlterTableRequest> {
let alter_kind = match alter_table.alter_operation() {
AlterTableOperation::AddConstraint(table_constraint) => {
return error::InvalidSqlSnafu {
@@ -77,9 +78,9 @@ impl SqlHandler {
},
};
Ok(AlterTableRequest {
catalog_name: Some(catalog_name),
schema_name: Some(schema_name),
table_name,
catalog_name: Some(table_ref.catalog.to_string()),
schema_name: Some(table_ref.schema.to_string()),
table_name: table_ref.table.to_string(),
alter_kind,
})
}
@@ -112,7 +113,9 @@ mod tests {
async fn test_alter_to_request_with_adding_column() {
let handler = create_mock_sql_handler().await;
let alter_table = parse_sql("ALTER TABLE my_metric_1 ADD tagk_i STRING Null;");
let req = handler.alter_to_request(alter_table).unwrap();
let req = handler
.alter_to_request(alter_table, TableReference::bare("my_metric_1"))
.unwrap();
assert_eq!(req.catalog_name, Some("greptime".to_string()));
assert_eq!(req.schema_name, Some("public".to_string()));
assert_eq!(req.table_name, "my_metric_1");

View File

@@ -23,10 +23,10 @@ use common_telemetry::tracing::log::error;
use datatypes::schema::SchemaBuilder;
use snafu::{ensure, OptionExt, ResultExt};
use sql::ast::TableConstraint;
use sql::statements::column_def_to_schema;
use sql::statements::create::CreateTable;
use sql::statements::{column_def_to_schema, table_idents_to_full_name};
use store_api::storage::consts::TIME_INDEX_NAME;
use table::engine::EngineContext;
use table::engine::{EngineContext, TableReference};
use table::metadata::TableId;
use table::requests::*;
@@ -114,13 +114,11 @@ impl SqlHandler {
&self,
table_id: TableId,
stmt: CreateTable,
table_ref: TableReference,
) -> Result<CreateTableRequest> {
let mut ts_index = usize::MAX;
let mut primary_keys = vec![];
let (catalog_name, schema_name, table_name) =
table_idents_to_full_name(&stmt.name).context(error::ParseSqlSnafu)?;
let col_map = stmt
.columns
.iter()
@@ -187,8 +185,8 @@ impl SqlHandler {
if primary_keys.is_empty() {
info!(
"Creating table: {:?}.{:?}.{} but primary key not set, use time index column: {}",
catalog_name, schema_name, table_name, ts_index
"Creating table: {} with time index column: {} upon primary keys absent",
table_ref, ts_index
);
primary_keys.push(ts_index);
}
@@ -211,9 +209,9 @@ impl SqlHandler {
let request = CreateTableRequest {
id: table_id,
catalog_name,
schema_name,
table_name,
catalog_name: table_ref.catalog.to_string(),
schema_name: table_ref.schema.to_string(),
table_name: table_ref.table.to_string(),
desc: None,
schema,
region_numbers: vec![0],
@@ -261,7 +259,9 @@ mod tests {
TIME INDEX (ts),
PRIMARY KEY(host)) engine=mito with(regions=1);"#,
);
let c = handler.create_to_request(42, parsed_stmt).unwrap();
let c = handler
.create_to_request(42, parsed_stmt, TableReference::bare("demo_table"))
.unwrap();
assert_eq!("demo_table", c.table_name);
assert_eq!(42, c.id);
assert!(!c.create_if_not_exists);
@@ -282,7 +282,9 @@ mod tests {
memory double,
PRIMARY KEY(host)) engine=mito with(regions=1);"#,
);
let error = handler.create_to_request(42, parsed_stmt).unwrap_err();
let error = handler
.create_to_request(42, parsed_stmt, TableReference::bare("demo_table"))
.unwrap_err();
assert_matches!(error, Error::MissingTimestampColumn { .. });
}
@@ -299,7 +301,9 @@ mod tests {
memory double,
TIME INDEX (ts)) engine=mito with(regions=1);"#,
);
let c = handler.create_to_request(42, parsed_stmt).unwrap();
let c = handler
.create_to_request(42, parsed_stmt, TableReference::bare("demo_table"))
.unwrap();
assert_eq!(1, c.primary_key_indices.len());
assert_eq!(
c.schema.timestamp_index().unwrap(),
@@ -318,7 +322,9 @@ mod tests {
TIME INDEX (ts)) engine=mito with(regions=1);"#,
);
let error = handler.create_to_request(42, parsed_stmt).unwrap_err();
let error = handler
.create_to_request(42, parsed_stmt, TableReference::bare("demo_table"))
.unwrap_err();
assert_matches!(error, Error::KeyColumnNotFound { .. });
}
@@ -338,7 +344,9 @@ mod tests {
let handler = create_mock_sql_handler().await;
let error = handler.create_to_request(42, create_table).unwrap_err();
let error = handler
.create_to_request(42, create_table, TableReference::full("c", "s", "demo"))
.unwrap_err();
assert_matches!(error, Error::InvalidPrimaryKey { .. });
}
@@ -358,7 +366,9 @@ mod tests {
let handler = create_mock_sql_handler().await;
let request = handler.create_to_request(42, create_table).unwrap();
let request = handler
.create_to_request(42, create_table, TableReference::full("c", "s", "demo"))
.unwrap();
assert_eq!(42, request.id);
assert_eq!("c".to_string(), request.catalog_name);

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use catalog::SchemaProviderRef;
use catalog::CatalogManagerRef;
use common_query::Output;
use datatypes::prelude::{ConcreteDataType, VectorBuilder};
use snafu::{ensure, OptionExt, ResultExt};
@@ -23,7 +23,7 @@ use table::engine::TableReference;
use table::requests::*;
use crate::error::{
CatalogSnafu, ColumnNotFoundSnafu, ColumnValuesNumberMismatchSnafu, InsertSnafu, ParseSqlSnafu,
CatalogSnafu, ColumnNotFoundSnafu, ColumnValuesNumberMismatchSnafu, InsertSnafu,
ParseSqlValueSnafu, Result, TableNotFoundSnafu,
};
use crate::sql::{SqlHandler, SqlRequest};
@@ -49,19 +49,18 @@ impl SqlHandler {
pub(crate) fn insert_to_request(
&self,
schema_provider: SchemaProviderRef,
catalog_manager: CatalogManagerRef,
stmt: Insert,
table_ref: TableReference,
) -> Result<SqlRequest> {
let columns = stmt.columns();
let values = stmt.values().context(ParseSqlValueSnafu)?;
let (catalog_name, schema_name, table_name) =
stmt.full_table_name().context(ParseSqlSnafu)?;
let table = schema_provider
.table(&table_name)
let table = catalog_manager
.table(table_ref.catalog, table_ref.schema, table_ref.table)
.context(CatalogSnafu)?
.context(TableNotFoundSnafu {
table_name: &table_name,
table_name: table_ref.table,
})?;
let schema = table.schema();
let columns_num = if columns.is_empty() {
@@ -88,7 +87,7 @@ impl SqlHandler {
let column_schema =
schema.column_schema_by_name(column_name).with_context(|| {
ColumnNotFoundSnafu {
table_name: &table_name,
table_name: table_ref.table,
column_name: column_name.to_string(),
}
})?;
@@ -119,9 +118,9 @@ impl SqlHandler {
}
Ok(SqlRequest::Insert(InsertRequest {
catalog_name,
schema_name,
table_name,
catalog_name: table_ref.catalog.to_string(),
schema_name: table_ref.schema.to_string(),
table_name: table_ref.table.to_string(),
columns_values: columns_builders
.into_iter()
.map(|(c, _, mut b)| (c.to_owned(), b.finish()))

View File

@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use common_catalog::consts::DEFAULT_SCHEMA_NAME;
use common_query::Output;
use common_recordbatch::util;
use datafusion::arrow_print;
@@ -19,6 +22,7 @@ use datafusion_common::record_batch::RecordBatch as DfRecordBatch;
use datatypes::arrow::array::{Int64Array, UInt64Array, Utf8Array};
use datatypes::arrow_array::StringArray;
use datatypes::prelude::ConcreteDataType;
use session::context::QueryContext;
use crate::instance::Instance;
use crate::tests::test_util;
@@ -32,39 +36,33 @@ async fn test_create_database_and_insert_query() {
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
let output = instance.execute_sql("create database test").await.unwrap();
let output = execute_sql(&instance, "create database test").await;
assert!(matches!(output, Output::AffectedRows(1)));
let output = instance
.execute_sql(
r#"create table greptime.test.demo(
let output = execute_sql(
&instance,
r#"create table greptime.test.demo(
host STRING,
cpu DOUBLE,
memory DOUBLE,
ts bigint,
TIME INDEX(ts)
)"#,
)
.await
.unwrap();
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));
let output = instance
.execute_sql(
r#"insert into test.demo(host, cpu, memory, ts) values
let output = execute_sql(
&instance,
r#"insert into test.demo(host, cpu, memory, ts) values
('host1', 66.6, 1024, 1655276557000),
('host2', 88.8, 333.3, 1655276558000)
"#,
)
.await
.unwrap();
)
.await;
assert!(matches!(output, Output::AffectedRows(2)));
let query_output = instance
.execute_sql("select ts from test.demo order by ts")
.await
.unwrap();
let query_output = execute_sql(&instance, "select ts from test.demo order by ts").await;
match query_output {
Output::Stream(s) => {
let batches = util::collect(s).await.unwrap();
@@ -88,54 +86,50 @@ async fn test_issue477_same_table_name_in_different_databases() {
instance.start().await.unwrap();
// Create database a and b
let output = instance.execute_sql("create database a").await.unwrap();
let output = execute_sql(&instance, "create database a").await;
assert!(matches!(output, Output::AffectedRows(1)));
let output = instance.execute_sql("create database b").await.unwrap();
let output = execute_sql(&instance, "create database b").await;
assert!(matches!(output, Output::AffectedRows(1)));
// Create table a.demo and b.demo
let output = instance
.execute_sql(
r#"create table a.demo(
let output = execute_sql(
&instance,
r#"create table a.demo(
host STRING,
ts bigint,
TIME INDEX(ts)
)"#,
)
.await
.unwrap();
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));
let output = instance
.execute_sql(
r#"create table b.demo(
let output = execute_sql(
&instance,
r#"create table b.demo(
host STRING,
ts bigint,
TIME INDEX(ts)
)"#,
)
.await
.unwrap();
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));
// Insert different data into a.demo and b.demo
let output = instance
.execute_sql(
r#"insert into a.demo(host, ts) values
let output = execute_sql(
&instance,
r#"insert into a.demo(host, ts) values
('host1', 1655276557000)
"#,
)
.await
.unwrap();
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));
let output = instance
.execute_sql(
r#"insert into b.demo(host, ts) values
let output = execute_sql(
&instance,
r#"insert into b.demo(host, ts) values
('host2',1655276558000)
"#,
)
.await
.unwrap();
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));
// Query data and assert
@@ -157,7 +151,7 @@ async fn test_issue477_same_table_name_in_different_databases() {
}
async fn assert_query_result(instance: &Instance, sql: &str, ts: i64, host: &str) {
let query_output = instance.execute_sql(sql).await.unwrap();
let query_output = execute_sql(instance, sql).await;
match query_output {
Output::Stream(s) => {
let batches = util::collect(s).await.unwrap();
@@ -200,15 +194,14 @@ async fn setup_test_instance() -> Instance {
#[tokio::test(flavor = "multi_thread")]
async fn test_execute_insert() {
let instance = setup_test_instance().await;
let output = instance
.execute_sql(
r#"insert into demo(host, cpu, memory, ts) values
let output = execute_sql(
&instance,
r#"insert into demo(host, cpu, memory, ts) values
('host1', 66.6, 1024, 1655276557000),
('host2', 88.8, 333.3, 1655276558000)
"#,
)
.await
.unwrap();
)
.await;
assert!(matches!(output, Output::AffectedRows(2)));
}
@@ -228,22 +221,17 @@ async fn test_execute_insert_query_with_i64_timestamp() {
.await
.unwrap();
let output = instance
.execute_sql(
r#"insert into demo(host, cpu, memory, ts) values
let output = execute_sql(
&instance,
r#"insert into demo(host, cpu, memory, ts) values
('host1', 66.6, 1024, 1655276557000),
('host2', 88.8, 333.3, 1655276558000)
"#,
)
.await
.unwrap();
)
.await;
assert!(matches!(output, Output::AffectedRows(2)));
let query_output = instance
.execute_sql("select ts from demo order by ts")
.await
.unwrap();
let query_output = execute_sql(&instance, "select ts from demo order by ts").await;
match query_output {
Output::Stream(s) => {
let batches = util::collect(s).await.unwrap();
@@ -257,11 +245,7 @@ async fn test_execute_insert_query_with_i64_timestamp() {
_ => unreachable!(),
}
let query_output = instance
.execute_sql("select ts as time from demo order by ts")
.await
.unwrap();
let query_output = execute_sql(&instance, "select ts as time from demo order by ts").await;
match query_output {
Output::Stream(s) => {
let batches = util::collect(s).await.unwrap();
@@ -282,10 +266,7 @@ async fn test_execute_query() {
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
let output = instance
.execute_sql("select sum(number) from numbers limit 20")
.await
.unwrap();
let output = execute_sql(&instance, "select sum(number) from numbers limit 20").await;
match output {
Output::Stream(recordbatch) => {
let numbers = util::collect(recordbatch).await.unwrap();
@@ -309,7 +290,7 @@ async fn test_execute_show_databases_tables() {
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
let output = instance.execute_sql("show databases").await.unwrap();
let output = execute_sql(&instance, "show databases").await;
match output {
Output::RecordBatches(databases) => {
let databases = databases.take();
@@ -325,10 +306,7 @@ async fn test_execute_show_databases_tables() {
_ => unreachable!(),
}
let output = instance
.execute_sql("show databases like '%bl%'")
.await
.unwrap();
let output = execute_sql(&instance, "show databases like '%bl%'").await;
match output {
Output::RecordBatches(databases) => {
let databases = databases.take();
@@ -344,7 +322,7 @@ async fn test_execute_show_databases_tables() {
_ => unreachable!(),
}
let output = instance.execute_sql("show tables").await.unwrap();
let output = execute_sql(&instance, "show tables").await;
match output {
Output::RecordBatches(databases) => {
let databases = databases.take();
@@ -364,7 +342,7 @@ async fn test_execute_show_databases_tables() {
.await
.unwrap();
let output = instance.execute_sql("show tables").await.unwrap();
let output = execute_sql(&instance, "show tables").await;
match output {
Output::RecordBatches(databases) => {
let databases = databases.take();
@@ -376,10 +354,7 @@ async fn test_execute_show_databases_tables() {
}
// show tables like [string]
let output = instance
.execute_sql("show tables like 'de%'")
.await
.unwrap();
let output = execute_sql(&instance, "show tables like 'de%'").await;
match output {
Output::RecordBatches(databases) => {
let databases = databases.take();
@@ -404,9 +379,9 @@ pub async fn test_execute_create() {
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
let output = instance
.execute_sql(
r#"create table test_table(
let output = execute_sql(
&instance,
r#"create table test_table(
host string,
ts timestamp,
cpu double default 0,
@@ -414,56 +389,24 @@ pub async fn test_execute_create() {
TIME INDEX (ts),
PRIMARY KEY(host)
) engine=mito with(regions=1);"#,
)
.await
.unwrap();
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));
}
#[tokio::test(flavor = "multi_thread")]
pub async fn test_create_table_illegal_timestamp_type() {
common_telemetry::init_default_ut_logging();
let (opts, _guard) =
test_util::create_tmp_dir_and_datanode_opts("create_table_illegal_timestamp_type");
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
let output = instance
.execute_sql(
r#"create table test_table(
host string,
ts bigint,
cpu double default 0,
memory double,
TIME INDEX (ts),
PRIMARY KEY(host)
) engine=mito with(regions=1);"#,
)
.await
.unwrap();
match output {
Output::AffectedRows(rows) => {
assert_eq!(1, rows);
}
_ => unreachable!(),
}
}
async fn check_output_stream(output: Output, expected: Vec<&str>) {
match output {
Output::Stream(stream) => {
let recordbatches = util::collect(stream).await.unwrap();
let recordbatch = recordbatches
.into_iter()
.map(|r| r.df_recordbatch)
.collect::<Vec<DfRecordBatch>>();
let pretty_print = arrow_print::write(&recordbatch);
let pretty_print = pretty_print.lines().collect::<Vec<&str>>();
assert_eq!(pretty_print, expected);
}
let recordbatches = match output {
Output::Stream(stream) => util::collect(stream).await.unwrap(),
Output::RecordBatches(recordbatches) => recordbatches.take(),
_ => unreachable!(),
}
};
let recordbatches = recordbatches
.into_iter()
.map(|r| r.df_recordbatch)
.collect::<Vec<DfRecordBatch>>();
let pretty_print = arrow_print::write(&recordbatches);
let pretty_print = pretty_print.lines().collect::<Vec<&str>>();
assert_eq!(pretty_print, expected);
}
#[tokio::test]
@@ -479,35 +422,30 @@ async fn test_alter_table() {
.await
.unwrap();
// make sure table insertion is ok before altering table
instance
.execute_sql("insert into demo(host, cpu, memory, ts) values ('host1', 1.1, 100, 1000)")
.await
.unwrap();
execute_sql(
&instance,
"insert into demo(host, cpu, memory, ts) values ('host1', 1.1, 100, 1000)",
)
.await;
// Add column
let output = instance
.execute_sql("alter table demo add my_tag string null")
.await
.unwrap();
let output = execute_sql(&instance, "alter table demo add my_tag string null").await;
assert!(matches!(output, Output::AffectedRows(0)));
let output = instance
.execute_sql(
"insert into demo(host, cpu, memory, ts, my_tag) values ('host2', 2.2, 200, 2000, 'hello')",
)
.await
.unwrap();
let output = execute_sql(
&instance,
"insert into demo(host, cpu, memory, ts, my_tag) values ('host2', 2.2, 200, 2000, 'hello')",
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));
let output = instance
.execute_sql("insert into demo(host, cpu, memory, ts) values ('host3', 3.3, 300, 3000)")
.await
.unwrap();
let output = execute_sql(
&instance,
"insert into demo(host, cpu, memory, ts) values ('host3', 3.3, 300, 3000)",
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));
let output = instance
.execute_sql("select * from demo order by ts")
.await
.unwrap();
let output = execute_sql(&instance, "select * from demo order by ts").await;
let expected = vec![
"+-------+-----+--------+---------------------+--------+",
"| host | cpu | memory | ts | my_tag |",
@@ -520,16 +458,10 @@ async fn test_alter_table() {
check_output_stream(output, expected).await;
// Drop a column
let output = instance
.execute_sql("alter table demo drop column memory")
.await
.unwrap();
let output = execute_sql(&instance, "alter table demo drop column memory").await;
assert!(matches!(output, Output::AffectedRows(0)));
let output = instance
.execute_sql("select * from demo order by ts")
.await
.unwrap();
let output = execute_sql(&instance, "select * from demo order by ts").await;
let expected = vec![
"+-------+-----+---------------------+--------+",
"| host | cpu | ts | my_tag |",
@@ -542,16 +474,14 @@ async fn test_alter_table() {
check_output_stream(output, expected).await;
// insert a new row
let output = instance
.execute_sql("insert into demo(host, cpu, ts, my_tag) values ('host4', 400, 4000, 'world')")
.await
.unwrap();
let output = execute_sql(
&instance,
"insert into demo(host, cpu, ts, my_tag) values ('host4', 400, 4000, 'world')",
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));
let output = instance
.execute_sql("select * from demo order by ts")
.await
.unwrap();
let output = execute_sql(&instance, "select * from demo order by ts").await;
let expected = vec![
"+-------+-----+---------------------+--------+",
"| host | cpu | ts | my_tag |",
@@ -580,27 +510,26 @@ async fn test_insert_with_default_value_for_type(type_name: &str) {
) engine=mito with(regions=1);"#,
type_name
);
let output = instance.execute_sql(&create_sql).await.unwrap();
let output = execute_sql(&instance, &create_sql).await;
assert!(matches!(output, Output::AffectedRows(1)));
// Insert with ts.
instance
.execute_sql("insert into test_table(host, cpu, ts) values ('host1', 1.1, 1000)")
.await
.unwrap();
let output = execute_sql(
&instance,
"insert into test_table(host, cpu, ts) values ('host1', 1.1, 1000)",
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));
// Insert without ts, so it should be filled by default value.
let output = instance
.execute_sql("insert into test_table(host, cpu) values ('host2', 2.2)")
.await
.unwrap();
let output = execute_sql(
&instance,
"insert into test_table(host, cpu) values ('host2', 2.2)",
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));
let output = instance
.execute_sql("select host, cpu from test_table")
.await
.unwrap();
let output = execute_sql(&instance, "select host, cpu from test_table").await;
let expected = vec![
"+-------+-----+",
"| host | cpu |",
@@ -619,3 +548,70 @@ async fn test_insert_with_default_value() {
test_insert_with_default_value_for_type("timestamp").await;
test_insert_with_default_value_for_type("bigint").await;
}
#[tokio::test(flavor = "multi_thread")]
async fn test_use_database() {
let (opts, _guard) = test_util::create_tmp_dir_and_datanode_opts("use_database");
let instance = Instance::with_mock_meta_client(&opts).await.unwrap();
instance.start().await.unwrap();
let output = execute_sql(&instance, "create database db1").await;
assert!(matches!(output, Output::AffectedRows(1)));
let output = execute_sql_in_db(
&instance,
"create table tb1(col_i32 int, ts bigint, TIME INDEX(ts))",
"db1",
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));
let output = execute_sql_in_db(&instance, "show tables", "db1").await;
let expected = vec![
"+--------+",
"| Tables |",
"+--------+",
"| tb1 |",
"+--------+",
];
check_output_stream(output, expected).await;
let output = execute_sql_in_db(
&instance,
r#"insert into tb1(col_i32, ts) values (1, 1655276557000)"#,
"db1",
)
.await;
assert!(matches!(output, Output::AffectedRows(1)));
let output = execute_sql_in_db(&instance, "select col_i32 from tb1", "db1").await;
let expected = vec![
"+---------+",
"| col_i32 |",
"+---------+",
"| 1 |",
"+---------+",
];
check_output_stream(output, expected).await;
// Making a particular database the default by means of the USE statement does not preclude
// accessing tables in other databases.
let output = execute_sql(&instance, "select number from public.numbers limit 1").await;
let expected = vec![
"+--------+",
"| number |",
"+--------+",
"| 0 |",
"+--------+",
];
check_output_stream(output, expected).await;
}
async fn execute_sql(instance: &Instance, sql: &str) -> Output {
execute_sql_in_db(instance, sql, DEFAULT_SCHEMA_NAME).await
}
async fn execute_sql_in_db(instance: &Instance, sql: &str, db: &str) -> Output {
let query_ctx = Arc::new(QueryContext::with_current_schema(db.to_string()));
instance.execute_sql(sql, query_ctx).await.unwrap()
}

View File

@@ -38,6 +38,7 @@ prost = "0.11"
query = { path = "../query" }
serde = "1.0"
serde_json = "1.0"
session = { path = "../session" }
sqlparser = "0.15"
servers = { path = "../servers" }
snafu = { version = "0.7", features = ["backtraces"] }

View File

@@ -38,6 +38,7 @@ use common_error::prelude::{BoxedError, StatusCode};
use common_grpc::channel_manager::{ChannelConfig, ChannelManager};
use common_grpc::select::to_object_result;
use common_query::Output;
use common_recordbatch::RecordBatches;
use common_telemetry::{debug, error, info};
use distributed::DistInstance;
use meta_client::client::MetaClientBuilder;
@@ -47,6 +48,7 @@ use servers::query_handler::{
PrometheusProtocolHandler, ScriptHandler, ScriptHandlerRef, SqlQueryHandler,
};
use servers::{error as server_error, Mode};
use session::context::{QueryContext, QueryContextRef};
use snafu::prelude::*;
use sql::dialect::GenericDialect;
use sql::parser::ParserContext;
@@ -211,10 +213,15 @@ impl Instance {
self.script_handler = Some(handler);
}
pub async fn handle_select(&self, expr: Select, stmt: Statement) -> Result<Output> {
async fn handle_select(
&self,
expr: Select,
stmt: Statement,
query_ctx: QueryContextRef,
) -> Result<Output> {
if let Some(dist_instance) = &self.dist_instance {
let Select::Sql(sql) = expr;
dist_instance.handle_sql(&sql, stmt).await
dist_instance.handle_sql(&sql, stmt, query_ctx).await
} else {
// TODO(LFC): Refactor consideration: Datanode should directly execute statement in standalone mode to avoid parse SQL again.
// Find a better way to execute query between Frontend and Datanode in standalone mode.
@@ -298,10 +305,15 @@ impl Instance {
}
/// Handle explain expr
pub async fn handle_explain(&self, sql: &str, explain_stmt: Explain) -> Result<Output> {
pub async fn handle_explain(
&self,
sql: &str,
explain_stmt: Explain,
query_ctx: QueryContextRef,
) -> Result<Output> {
if let Some(dist_instance) = &self.dist_instance {
dist_instance
.handle_sql(sql, Statement::Explain(explain_stmt))
.handle_sql(sql, Statement::Explain(explain_stmt), query_ctx)
.await
} else {
Ok(Output::AffectedRows(0))
@@ -505,6 +517,26 @@ impl Instance {
let insert_request = insert_to_request(&schema_provider, *insert)?;
insert_request_to_insert_batch(&insert_request)
}
fn handle_use(&self, db: String, query_ctx: QueryContextRef) -> Result<Output> {
let catalog_manager = &self.catalog_manager;
if let Some(catalog_manager) = catalog_manager {
ensure!(
catalog_manager
.schema(DEFAULT_CATALOG_NAME, &db)
.context(error::CatalogSnafu)?
.is_some(),
error::SchemaNotFoundSnafu { schema_info: &db }
);
query_ctx.set_current_schema(&db);
Ok(Output::RecordBatches(RecordBatches::empty()))
} else {
// TODO(LFC): Handle "use" stmt here.
unimplemented!()
}
}
}
#[async_trait]
@@ -545,17 +577,23 @@ fn parse_stmt(sql: &str) -> Result<Statement> {
#[async_trait]
impl SqlQueryHandler for Instance {
async fn do_query(&self, query: &str) -> server_error::Result<Output> {
async fn do_query(
&self,
query: &str,
query_ctx: QueryContextRef,
) -> server_error::Result<Output> {
let stmt = parse_stmt(query)
.map_err(BoxedError::new)
.context(server_error::ExecuteQuerySnafu { query })?;
match stmt {
Statement::Query(_) => self
.handle_select(Select::Sql(query.to_string()), stmt)
.await
.map_err(BoxedError::new)
.context(server_error::ExecuteQuerySnafu { query }),
Statement::ShowDatabases(_)
| Statement::ShowTables(_)
| Statement::DescribeTable(_)
| Statement::Query(_) => {
self.handle_select(Select::Sql(query.to_string()), stmt, query_ctx)
.await
}
Statement::Insert(insert) => match self.mode {
Mode::Standalone => {
let (catalog_name, schema_name, table_name) = insert
@@ -578,10 +616,7 @@ impl SqlQueryHandler for Instance {
columns,
row_count,
};
self.handle_insert(expr)
.await
.map_err(BoxedError::new)
.context(server_error::ExecuteQuerySnafu { query })
self.handle_insert(expr).await
}
Mode::Distributed => {
let affected = self
@@ -604,55 +639,36 @@ impl SqlQueryHandler for Instance {
self.handle_create_table(create_expr, create.partitions)
.await
.map_err(BoxedError::new)
.context(server_error::ExecuteQuerySnafu { query })
}
Statement::ShowDatabases(_)
| Statement::ShowTables(_)
| Statement::DescribeTable(_) => self
.handle_select(Select::Sql(query.to_string()), stmt)
.await
.map_err(BoxedError::new)
.context(server_error::ExecuteQuerySnafu { query }),
Statement::CreateDatabase(c) => {
let expr = CreateDatabaseExpr {
database_name: c.name.to_string(),
};
self.handle_create_database(expr)
.await
.map_err(BoxedError::new)
.context(server_error::ExecuteQuerySnafu { query })
self.handle_create_database(expr).await
}
Statement::Alter(alter_stmt) => self
.handle_alter(
Statement::Alter(alter_stmt) => {
self.handle_alter(
AlterExpr::try_from(alter_stmt)
.map_err(BoxedError::new)
.context(server_error::ExecuteAlterSnafu { query })?,
)
.await
.map_err(BoxedError::new)
.context(server_error::ExecuteQuerySnafu { query }),
}
Statement::DropTable(drop_stmt) => {
let expr = DropTableExpr {
catalog_name: drop_stmt.catalog_name,
schema_name: drop_stmt.schema_name,
table_name: drop_stmt.table_name,
};
self.handle_drop_table(expr)
.await
.map_err(BoxedError::new)
.context(server_error::ExecuteQuerySnafu { query })
self.handle_drop_table(expr).await
}
Statement::Explain(explain_stmt) => {
self.handle_explain(query, explain_stmt, query_ctx).await
}
Statement::Explain(explain_stmt) => self
.handle_explain(query, explain_stmt)
.await
.map_err(BoxedError::new)
.context(server_error::ExecuteQuerySnafu { query }),
Statement::ShowCreateTable(_) => {
return server_error::NotSupportedSnafu { feat: query }.fail();
}
Statement::Use(db) => self.handle_use(db, query_ctx),
}
.map_err(BoxedError::new)
.context(server_error::ExecuteQuerySnafu { query })
@@ -716,7 +732,8 @@ impl GrpcQueryHandler for Instance {
})?;
match select {
select_expr::Expr::Sql(sql) => {
let output = SqlQueryHandler::do_query(self, sql).await;
let query_ctx = Arc::new(QueryContext::new());
let output = SqlQueryHandler::do_query(self, sql, query_ctx).await;
Ok(to_object_result(output).await)
}
_ => {
@@ -797,6 +814,8 @@ mod tests {
#[tokio::test]
async fn test_execute_sql() {
let query_ctx = Arc::new(QueryContext::new());
let instance = tests::create_frontend_instance().await;
let sql = r#"CREATE TABLE demo(
@@ -808,7 +827,9 @@ mod tests {
TIME INDEX (ts),
PRIMARY KEY(ts, host)
) engine=mito with(regions=1);"#;
let output = SqlQueryHandler::do_query(&*instance, sql).await.unwrap();
let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone())
.await
.unwrap();
match output {
Output::AffectedRows(rows) => assert_eq!(rows, 1),
_ => unreachable!(),
@@ -819,14 +840,18 @@ mod tests {
('frontend.host2', null, null, 2000),
('frontend.host3', 3.3, 300, 3000)
"#;
let output = SqlQueryHandler::do_query(&*instance, sql).await.unwrap();
let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone())
.await
.unwrap();
match output {
Output::AffectedRows(rows) => assert_eq!(rows, 3),
_ => unreachable!(),
}
let sql = "select * from demo";
let output = SqlQueryHandler::do_query(&*instance, sql).await.unwrap();
let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone())
.await
.unwrap();
match output {
Output::RecordBatches(recordbatches) => {
let pretty_print = recordbatches.pretty_print();
@@ -846,7 +871,9 @@ mod tests {
};
let sql = "select * from demo where ts>cast(1000000000 as timestamp)"; // use nanoseconds as where condition
let output = SqlQueryHandler::do_query(&*instance, sql).await.unwrap();
let output = SqlQueryHandler::do_query(&*instance, sql, query_ctx.clone())
.await
.unwrap();
match output {
Output::RecordBatches(recordbatches) => {
let pretty_print = recordbatches.pretty_print();

View File

@@ -33,6 +33,7 @@ use meta_client::rpc::{
};
use query::sql::{describe_table, explain, show_databases, show_tables};
use query::{QueryEngineFactory, QueryEngineRef};
use session::context::QueryContextRef;
use snafu::{ensure, OptionExt, ResultExt};
use sql::statements::create::Partitions;
use sql::statements::sql_value_to_value;
@@ -128,29 +129,31 @@ impl DistInstance {
Ok(Output::AffectedRows(region_routes.len()))
}
pub(crate) async fn handle_sql(&self, sql: &str, stmt: Statement) -> Result<Output> {
pub(crate) async fn handle_sql(
&self,
sql: &str,
stmt: Statement,
query_ctx: QueryContextRef,
) -> Result<Output> {
match stmt {
Statement::Query(_) => {
let plan = self
.query_engine
.statement_to_plan(stmt)
.statement_to_plan(stmt, query_ctx)
.context(error::ExecuteSqlSnafu { sql })?;
self.query_engine
.execute(&plan)
.await
.context(error::ExecuteSqlSnafu { sql })
self.query_engine.execute(&plan).await
}
Statement::ShowDatabases(stmt) => show_databases(stmt, self.catalog_manager.clone()),
Statement::ShowTables(stmt) => {
show_tables(stmt, self.catalog_manager.clone(), query_ctx)
}
Statement::DescribeTable(stmt) => describe_table(stmt, self.catalog_manager.clone()),
Statement::Explain(stmt) => {
explain(Box::new(stmt), self.query_engine.clone(), query_ctx).await
}
Statement::ShowDatabases(stmt) => show_databases(stmt, self.catalog_manager.clone())
.context(error::ExecuteSqlSnafu { sql }),
Statement::ShowTables(stmt) => show_tables(stmt, self.catalog_manager.clone())
.context(error::ExecuteSqlSnafu { sql }),
Statement::DescribeTable(stmt) => describe_table(stmt, self.catalog_manager.clone())
.context(error::ExecuteSqlSnafu { sql }),
Statement::Explain(stmt) => explain(Box::new(stmt), self.query_engine.clone())
.await
.context(error::ExecuteSqlSnafu { sql }),
_ => unreachable!(),
}
.context(error::ExecuteSqlSnafu { sql })
}
/// Handles distributed database creation

View File

@@ -60,9 +60,12 @@ impl Instance {
#[cfg(test)]
mod tests {
use std::sync::Arc;
use common_query::Output;
use datafusion::arrow_print;
use servers::query_handler::SqlQueryHandler;
use session::context::QueryContext;
use super::*;
use crate::tests;
@@ -121,7 +124,7 @@ mod tests {
assert!(result.is_ok());
let output = instance
.do_query("select * from my_metric_1")
.do_query("select * from my_metric_1", Arc::new(QueryContext::new()))
.await
.unwrap();
match output {

View File

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use api::prometheus::remote::read_request::ResponseType;
use api::prometheus::remote::{Query, QueryResult, ReadRequest, ReadResponse, WriteRequest};
use async_trait::async_trait;
@@ -25,6 +27,7 @@ use servers::error::{self, Result as ServerResult};
use servers::prometheus::{self, Metrics};
use servers::query_handler::{PrometheusProtocolHandler, PrometheusResponse};
use servers::Mode;
use session::context::QueryContext;
use snafu::{OptionExt, ResultExt};
use crate::instance::{parse_stmt, Instance};
@@ -93,7 +96,10 @@ impl Instance {
let object_result = if let Some(dist_instance) = &self.dist_instance {
let output = futures::future::ready(parse_stmt(&sql))
.and_then(|stmt| dist_instance.handle_sql(&sql, stmt))
.and_then(|stmt| {
let query_ctx = Arc::new(QueryContext::with_current_schema(db.to_string()));
dist_instance.handle_sql(&sql, stmt, query_ctx)
})
.await;
to_object_result(output).await.try_into()
} else {

View File

@@ -27,6 +27,7 @@ metrics = "0.20"
once_cell = "1.10"
serde = "1.0"
serde_json = "1.0"
session = { path = "../session" }
snafu = { version = "0.7", features = ["backtraces"] }
sql = { path = "../sql" }
table = { path = "../table" }

View File

@@ -32,6 +32,7 @@ use common_recordbatch::{EmptyRecordBatchStream, SendableRecordBatchStream};
use common_telemetry::timer;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::ExecutionPlan;
use session::context::QueryContextRef;
use snafu::{OptionExt, ResultExt};
use sql::dialect::GenericDialect;
use sql::parser::ParserContext;
@@ -46,7 +47,7 @@ use crate::physical_optimizer::PhysicalOptimizer;
use crate::physical_planner::PhysicalPlanner;
use crate::plan::LogicalPlan;
use crate::planner::Planner;
use crate::query_engine::{QueryContext, QueryEngineState};
use crate::query_engine::{QueryEngineContext, QueryEngineState};
use crate::{metric, QueryEngine};
pub(crate) struct DatafusionQueryEngine {
@@ -61,6 +62,7 @@ impl DatafusionQueryEngine {
}
}
// TODO(LFC): Refactor consideration: extract a "Planner" that stores query context and execute queries inside.
#[async_trait::async_trait]
impl QueryEngine for DatafusionQueryEngine {
fn name(&self) -> &str {
@@ -75,21 +77,25 @@ impl QueryEngine for DatafusionQueryEngine {
Ok(statement.remove(0))
}
fn statement_to_plan(&self, stmt: Statement) -> Result<LogicalPlan> {
let context_provider = DfContextProviderAdapter::new(self.state.clone());
fn statement_to_plan(
&self,
stmt: Statement,
query_ctx: QueryContextRef,
) -> Result<LogicalPlan> {
let context_provider = DfContextProviderAdapter::new(self.state.clone(), query_ctx);
let planner = DfPlanner::new(&context_provider);
planner.statement_to_plan(stmt)
}
fn sql_to_plan(&self, sql: &str) -> Result<LogicalPlan> {
fn sql_to_plan(&self, sql: &str, query_ctx: QueryContextRef) -> Result<LogicalPlan> {
let _timer = timer!(metric::METRIC_PARSE_SQL_ELAPSED);
let stmt = self.sql_to_statement(sql)?;
self.statement_to_plan(stmt)
self.statement_to_plan(stmt, query_ctx)
}
async fn execute(&self, plan: &LogicalPlan) -> Result<Output> {
let mut ctx = QueryContext::new(self.state.clone());
let mut ctx = QueryEngineContext::new(self.state.clone());
let logical_plan = self.optimize_logical_plan(&mut ctx, plan)?;
let physical_plan = self.create_physical_plan(&mut ctx, &logical_plan).await?;
let physical_plan = self.optimize_physical_plan(&mut ctx, physical_plan)?;
@@ -100,7 +106,7 @@ impl QueryEngine for DatafusionQueryEngine {
}
async fn execute_physical(&self, plan: &Arc<dyn PhysicalPlan>) -> Result<Output> {
let ctx = QueryContext::new(self.state.clone());
let ctx = QueryEngineContext::new(self.state.clone());
Ok(Output::Stream(self.execute_stream(&ctx, plan).await?))
}
@@ -127,7 +133,7 @@ impl QueryEngine for DatafusionQueryEngine {
impl LogicalOptimizer for DatafusionQueryEngine {
fn optimize_logical_plan(
&self,
_ctx: &mut QueryContext,
_: &mut QueryEngineContext,
plan: &LogicalPlan,
) -> Result<LogicalPlan> {
let _timer = timer!(metric::METRIC_OPTIMIZE_LOGICAL_ELAPSED);
@@ -151,7 +157,7 @@ impl LogicalOptimizer for DatafusionQueryEngine {
impl PhysicalPlanner for DatafusionQueryEngine {
async fn create_physical_plan(
&self,
_ctx: &mut QueryContext,
_: &mut QueryEngineContext,
logical_plan: &LogicalPlan,
) -> Result<Arc<dyn PhysicalPlan>> {
let _timer = timer!(metric::METRIC_CREATE_PHYSICAL_ELAPSED);
@@ -183,7 +189,7 @@ impl PhysicalPlanner for DatafusionQueryEngine {
impl PhysicalOptimizer for DatafusionQueryEngine {
fn optimize_physical_plan(
&self,
_ctx: &mut QueryContext,
_: &mut QueryEngineContext,
plan: Arc<dyn PhysicalPlan>,
) -> Result<Arc<dyn PhysicalPlan>> {
let _timer = timer!(metric::METRIC_OPTIMIZE_PHYSICAL_ELAPSED);
@@ -211,7 +217,7 @@ impl PhysicalOptimizer for DatafusionQueryEngine {
impl QueryExecutor for DatafusionQueryEngine {
async fn execute_stream(
&self,
ctx: &QueryContext,
ctx: &QueryEngineContext,
plan: &Arc<dyn PhysicalPlan>,
) -> Result<SendableRecordBatchStream> {
let _timer = timer!(metric::METRIC_EXEC_PLAN_ELAPSED);
@@ -250,6 +256,7 @@ mod tests {
use common_recordbatch::util;
use datafusion::field_util::{FieldExt, SchemaExt};
use datatypes::arrow::array::UInt64Array;
use session::context::QueryContext;
use table::table::numbers::NumbersTable;
use crate::query_engine::{QueryEngineFactory, QueryEngineRef};
@@ -277,7 +284,9 @@ mod tests {
let engine = create_test_engine();
let sql = "select sum(number) from numbers limit 20";
let plan = engine.sql_to_plan(sql).unwrap();
let plan = engine
.sql_to_plan(sql, Arc::new(QueryContext::new()))
.unwrap();
assert_eq!(
format!("{:?}", plan),
@@ -293,7 +302,9 @@ mod tests {
let engine = create_test_engine();
let sql = "select sum(number) from numbers limit 20";
let plan = engine.sql_to_plan(sql).unwrap();
let plan = engine
.sql_to_plan(sql, Arc::new(QueryContext::new()))
.unwrap();
let output = engine.execute(&plan).await.unwrap();
match output {

View File

@@ -21,6 +21,7 @@ use datafusion::physical_plan::udaf::AggregateUDF;
use datafusion::physical_plan::udf::ScalarUDF;
use datafusion::sql::planner::{ContextProvider, SqlToRel};
use datatypes::arrow::datatypes::DataType;
use session::context::QueryContextRef;
use snafu::ResultExt;
use sql::statements::explain::Explain;
use sql::statements::query::Query;
@@ -85,18 +86,20 @@ where
| Statement::CreateDatabase(_)
| Statement::Alter(_)
| Statement::Insert(_)
| Statement::DropTable(_) => unreachable!(),
| Statement::DropTable(_)
| Statement::Use(_) => unreachable!(),
}
}
}
pub(crate) struct DfContextProviderAdapter {
state: QueryEngineState,
query_ctx: QueryContextRef,
}
impl DfContextProviderAdapter {
pub(crate) fn new(state: QueryEngineState) -> Self {
Self { state }
pub(crate) fn new(state: QueryEngineState, query_ctx: QueryContextRef) -> Self {
Self { state, query_ctx }
}
}
@@ -104,11 +107,18 @@ impl DfContextProviderAdapter {
/// manage UDFs, UDAFs, variables by ourself in future.
impl ContextProvider for DfContextProviderAdapter {
fn get_table_provider(&self, name: TableReference) -> Option<Arc<dyn TableProvider>> {
self.state
.df_context()
.state
.lock()
.get_table_provider(name)
let schema = self.query_ctx.current_schema();
let execution_ctx = self.state.df_context().state.lock();
match name {
TableReference::Bare { table } if schema.is_some() => {
execution_ctx.get_table_provider(TableReference::Partial {
// unwrap safety: checked in this match's arm
schema: &schema.unwrap(),
table,
})
}
_ => execution_ctx.get_table_provider(name),
}
}
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {

View File

@@ -18,14 +18,14 @@ use common_query::physical_plan::PhysicalPlan;
use common_recordbatch::SendableRecordBatchStream;
use crate::error::Result;
use crate::query_engine::QueryContext;
use crate::query_engine::QueryEngineContext;
/// Executor to run [ExecutionPlan].
#[async_trait::async_trait]
pub trait QueryExecutor {
async fn execute_stream(
&self,
ctx: &QueryContext,
ctx: &QueryEngineContext,
plan: &Arc<dyn PhysicalPlan>,
) -> Result<SendableRecordBatchStream>;
}

View File

@@ -26,4 +26,6 @@ pub mod planner;
pub mod query_engine;
pub mod sql;
pub use crate::query_engine::{QueryContext, QueryEngine, QueryEngineFactory, QueryEngineRef};
pub use crate::query_engine::{
QueryEngine, QueryEngineContext, QueryEngineFactory, QueryEngineRef,
};

View File

@@ -14,12 +14,12 @@
use crate::error::Result;
use crate::plan::LogicalPlan;
use crate::query_engine::QueryContext;
use crate::query_engine::QueryEngineContext;
pub trait LogicalOptimizer {
fn optimize_logical_plan(
&self,
ctx: &mut QueryContext,
ctx: &mut QueryEngineContext,
plan: &LogicalPlan,
) -> Result<LogicalPlan>;
}

View File

@@ -17,12 +17,12 @@ use std::sync::Arc;
use common_query::physical_plan::PhysicalPlan;
use crate::error::Result;
use crate::query_engine::QueryContext;
use crate::query_engine::QueryEngineContext;
pub trait PhysicalOptimizer {
fn optimize_physical_plan(
&self,
ctx: &mut QueryContext,
ctx: &mut QueryEngineContext,
plan: Arc<dyn PhysicalPlan>,
) -> Result<Arc<dyn PhysicalPlan>>;
}

View File

@@ -18,7 +18,7 @@ use common_query::physical_plan::PhysicalPlan;
use crate::error::Result;
use crate::plan::LogicalPlan;
use crate::query_engine::QueryContext;
use crate::query_engine::QueryEngineContext;
/// Physical query planner that converts a `LogicalPlan` to an
/// `ExecutionPlan` suitable for execution.
@@ -27,7 +27,7 @@ pub trait PhysicalPlanner {
/// Create a physical plan from a logical plan
async fn create_physical_plan(
&self,
ctx: &mut QueryContext,
ctx: &mut QueryEngineContext,
logical_plan: &LogicalPlan,
) -> Result<Arc<dyn PhysicalPlan>>;
}

View File

@@ -23,12 +23,13 @@ use common_function::scalars::{FunctionRef, FUNCTION_REGISTRY};
use common_query::physical_plan::PhysicalPlan;
use common_query::prelude::ScalarUdf;
use common_query::Output;
use session::context::QueryContextRef;
use sql::statements::statement::Statement;
use crate::datafusion::DatafusionQueryEngine;
use crate::error::Result;
use crate::plan::LogicalPlan;
pub use crate::query_engine::context::QueryContext;
pub use crate::query_engine::context::QueryEngineContext;
pub use crate::query_engine::state::QueryEngineState;
#[async_trait::async_trait]
@@ -37,9 +38,10 @@ pub trait QueryEngine: Send + Sync {
fn sql_to_statement(&self, sql: &str) -> Result<Statement>;
fn statement_to_plan(&self, stmt: Statement) -> Result<LogicalPlan>;
fn statement_to_plan(&self, stmt: Statement, query_ctx: QueryContextRef)
-> Result<LogicalPlan>;
fn sql_to_plan(&self, sql: &str) -> Result<LogicalPlan>;
fn sql_to_plan(&self, sql: &str, query_ctx: QueryContextRef) -> Result<LogicalPlan>;
async fn execute(&self, plan: &LogicalPlan) -> Result<Output>;

View File

@@ -16,11 +16,11 @@
use crate::query_engine::state::QueryEngineState;
#[derive(Debug)]
pub struct QueryContext {
pub struct QueryEngineContext {
state: QueryEngineState,
}
impl QueryContext {
impl QueryEngineContext {
pub fn new(state: QueryEngineState) -> Self {
Self { state }
}

View File

@@ -22,6 +22,7 @@ use datatypes::prelude::*;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::{Helper, StringVector};
use once_cell::sync::Lazy;
use session::context::QueryContextRef;
use snafu::{ensure, OptionExt, ResultExt};
use sql::statements::describe::DescribeTable;
use sql::statements::explain::Explain;
@@ -109,7 +110,11 @@ pub fn show_databases(stmt: ShowDatabases, catalog_manager: CatalogManagerRef) -
Ok(Output::RecordBatches(records))
}
pub fn show_tables(stmt: ShowTables, catalog_manager: CatalogManagerRef) -> Result<Output> {
pub fn show_tables(
stmt: ShowTables,
catalog_manager: CatalogManagerRef,
query_ctx: QueryContextRef,
) -> Result<Output> {
// TODO(LFC): supports WHERE
ensure!(
matches!(stmt.kind, ShowKind::All | ShowKind::Like(_)),
@@ -118,9 +123,15 @@ pub fn show_tables(stmt: ShowTables, catalog_manager: CatalogManagerRef) -> Resu
}
);
let schema = stmt.database.as_deref().unwrap_or(DEFAULT_SCHEMA_NAME);
let schema = if let Some(database) = stmt.database {
database
} else {
query_ctx
.current_schema()
.unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string())
};
let schema = catalog_manager
.schema(DEFAULT_CATALOG_NAME, schema)
.schema(DEFAULT_CATALOG_NAME, &schema)
.context(error::CatalogSnafu)?
.context(error::SchemaNotFoundSnafu { schema })?;
let tables = schema.table_names().context(error::CatalogSnafu)?;
@@ -141,8 +152,12 @@ pub fn show_tables(stmt: ShowTables, catalog_manager: CatalogManagerRef) -> Resu
Ok(Output::RecordBatches(records))
}
pub async fn explain(stmt: Box<Explain>, query_engine: QueryEngineRef) -> Result<Output> {
let plan = query_engine.statement_to_plan(Statement::Explain(*stmt))?;
pub async fn explain(
stmt: Box<Explain>,
query_engine: QueryEngineRef,
query_ctx: QueryContextRef,
) -> Result<Output> {
let plan = query_engine.statement_to_plan(Statement::Explain(*stmt), query_ctx)?;
query_engine.execute(&plan).await
}

View File

@@ -24,6 +24,7 @@ use datatypes::types::PrimitiveElement;
use function::{create_query_engine, get_numbers_from_table};
use query::error::Result;
use query::QueryEngine;
use session::context::QueryContext;
#[tokio::test]
async fn test_argmax_aggregator() -> Result<()> {
@@ -95,7 +96,9 @@ async fn execute_argmax<'a>(
"select ARGMAX({}) as argmax from {}",
column_name, table_name
);
let plan = engine.sql_to_plan(&sql).unwrap();
let plan = engine
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
.unwrap();
let output = engine.execute(&plan).await.unwrap();
let recordbatch_stream = match output {

View File

@@ -25,6 +25,7 @@ use datatypes::types::PrimitiveElement;
use function::{create_query_engine, get_numbers_from_table};
use query::error::Result;
use query::QueryEngine;
use session::context::QueryContext;
#[tokio::test]
async fn test_argmin_aggregator() -> Result<()> {
@@ -96,7 +97,9 @@ async fn execute_argmin<'a>(
"select argmin({}) as argmin from {}",
column_name, table_name
);
let plan = engine.sql_to_plan(&sql).unwrap();
let plan = engine
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
.unwrap();
let output = engine.execute(&plan).await.unwrap();
let recordbatch_stream = match output {

View File

@@ -27,6 +27,7 @@ use datatypes::vectors::PrimitiveVector;
use query::query_engine::QueryEngineFactory;
use query::QueryEngine;
use rand::Rng;
use session::context::QueryContext;
use table::test_util::MemTable;
pub fn create_query_engine() -> Arc<dyn QueryEngine> {
@@ -80,7 +81,9 @@ where
for<'a> T: Scalar<RefType<'a> = T>,
{
let sql = format!("SELECT {} FROM {}", column_name, table_name);
let plan = engine.sql_to_plan(&sql).unwrap();
let plan = engine
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
.unwrap();
let output = engine.execute(&plan).await.unwrap();
let recordbatch_stream = match output {

View File

@@ -28,6 +28,7 @@ use function::{create_query_engine, get_numbers_from_table};
use num_traits::AsPrimitive;
use query::error::Result;
use query::QueryEngine;
use session::context::QueryContext;
#[tokio::test]
async fn test_mean_aggregator() -> Result<()> {
@@ -89,7 +90,9 @@ async fn execute_mean<'a>(
engine: Arc<dyn QueryEngine>,
) -> RecordResult<Vec<RecordBatch>> {
let sql = format!("select MEAN({}) as mean from {}", column_name, table_name);
let plan = engine.sql_to_plan(&sql).unwrap();
let plan = engine
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
.unwrap();
let output = engine.execute(&plan).await.unwrap();
let recordbatch_stream = match output {

View File

@@ -36,6 +36,7 @@ use datatypes::with_match_primitive_type_id;
use num_traits::AsPrimitive;
use query::error::Result;
use query::QueryEngineFactory;
use session::context::QueryContext;
use table::test_util::MemTable;
#[derive(Debug, Default)]
@@ -228,7 +229,7 @@ where
"select MY_SUM({}) as my_sum from {}",
column_name, table_name
);
let plan = engine.sql_to_plan(&sql)?;
let plan = engine.sql_to_plan(&sql, Arc::new(QueryContext::new()))?;
let output = engine.execute(&plan).await?;
let recordbatch_stream = match output {

View File

@@ -30,6 +30,7 @@ use function::{create_query_engine, get_numbers_from_table};
use num_traits::AsPrimitive;
use query::error::Result;
use query::{QueryEngine, QueryEngineFactory};
use session::context::QueryContext;
use table::test_util::MemTable;
#[tokio::test]
@@ -53,7 +54,9 @@ async fn test_percentile_aggregator() -> Result<()> {
async fn test_percentile_correctness() -> Result<()> {
let engine = create_correctness_engine();
let sql = String::from("select PERCENTILE(corr_number,88.0) as percentile from corr_numbers");
let plan = engine.sql_to_plan(&sql).unwrap();
let plan = engine
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
.unwrap();
let output = engine.execute(&plan).await.unwrap();
let recordbatch_stream = match output {
@@ -113,7 +116,9 @@ async fn execute_percentile<'a>(
"select PERCENTILE({},50.0) as percentile from {}",
column_name, table_name
);
let plan = engine.sql_to_plan(&sql).unwrap();
let plan = engine
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
.unwrap();
let output = engine.execute(&plan).await.unwrap();
let recordbatch_stream = match output {

View File

@@ -26,6 +26,7 @@ use function::{create_query_engine, get_numbers_from_table};
use num_traits::AsPrimitive;
use query::error::Result;
use query::QueryEngine;
use session::context::QueryContext;
#[tokio::test]
async fn test_polyval_aggregator() -> Result<()> {
@@ -92,7 +93,9 @@ async fn execute_polyval<'a>(
"select POLYVAL({}, 0) as polyval from {}",
column_name, table_name
);
let plan = engine.sql_to_plan(&sql).unwrap();
let plan = engine
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
.unwrap();
let output = engine.execute(&plan).await.unwrap();
let recordbatch_stream = match output {

View File

@@ -37,6 +37,7 @@ use query::plan::LogicalPlan;
use query::query_engine::QueryEngineFactory;
use query::QueryEngine;
use rand::Rng;
use session::context::QueryContext;
use table::table::adapter::DfTableProviderAdapter;
use table::table::numbers::NumbersTable;
use table::test_util::MemTable;
@@ -134,7 +135,10 @@ async fn test_udf() -> Result<()> {
engine.register_udf(udf);
let plan = engine.sql_to_plan("select pow(number, number) as p from numbers limit 10")?;
let plan = engine.sql_to_plan(
"select pow(number, number) as p from numbers limit 10",
Arc::new(QueryContext::new()),
)?;
let output = engine.execute(&plan).await?;
let recordbatch = match output {
@@ -242,7 +246,9 @@ where
for<'a> T: Scalar<RefType<'a> = T>,
{
let sql = format!("SELECT {} FROM {}", column_name, table_name);
let plan = engine.sql_to_plan(&sql).unwrap();
let plan = engine
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
.unwrap();
let output = engine.execute(&plan).await.unwrap();
let recordbatch_stream = match output {
@@ -330,7 +336,9 @@ async fn execute_median<'a>(
"select MEDIAN({}) as median from {}",
column_name, table_name
);
let plan = engine.sql_to_plan(&sql).unwrap();
let plan = engine
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
.unwrap();
let output = engine.execute(&plan).await.unwrap();
let recordbatch_stream = match output {

View File

@@ -26,6 +26,7 @@ use function::{create_query_engine, get_numbers_from_table};
use num_traits::AsPrimitive;
use query::error::Result;
use query::QueryEngine;
use session::context::QueryContext;
use statrs::distribution::{ContinuousCDF, Normal};
use statrs::statistics::Statistics;
@@ -94,7 +95,9 @@ async fn execute_scipy_stats_norm_cdf<'a>(
"select SCIPYSTATSNORMCDF({},2.0) as scipy_stats_norm_cdf from {}",
column_name, table_name
);
let plan = engine.sql_to_plan(&sql).unwrap();
let plan = engine
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
.unwrap();
let output = engine.execute(&plan).await.unwrap();
let recordbatch_stream = match output {

View File

@@ -26,6 +26,7 @@ use function::{create_query_engine, get_numbers_from_table};
use num_traits::AsPrimitive;
use query::error::Result;
use query::QueryEngine;
use session::context::QueryContext;
use statrs::distribution::{Continuous, Normal};
use statrs::statistics::Statistics;
@@ -94,7 +95,9 @@ async fn execute_scipy_stats_norm_pdf<'a>(
"select SCIPYSTATSNORMPDF({},2.0) as scipy_stats_norm_pdf from {}",
column_name, table_name
);
let plan = engine.sql_to_plan(&sql).unwrap();
let plan = engine
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
.unwrap();
let output = engine.execute(&plan).await.unwrap();
let recordbatch_stream = match output {

View File

@@ -48,6 +48,7 @@ rustpython-vm = { git = "https://github.com/RustPython/RustPython", optional = t
"default",
"freeze-stdlib",
] }
session = { path = "../session" }
snafu = { version = "0.7", features = ["backtraces"] }
sql = { path = "../sql" }
table = { path = "../table" }

View File

@@ -26,6 +26,7 @@ use common_recordbatch::{RecordBatch, RecordBatchStream, SendableRecordBatchStre
use datatypes::schema::SchemaRef;
use futures::Stream;
use query::QueryEngineRef;
use session::context::QueryContext;
use snafu::{ensure, ResultExt};
use sql::statements::statement::Statement;
@@ -93,7 +94,9 @@ impl Script for PyScript {
matches!(stmt, Statement::Query { .. }),
error::UnsupportedSqlSnafu { sql }
);
let plan = self.query_engine.statement_to_plan(stmt)?;
let plan = self
.query_engine
.statement_to_plan(stmt, Arc::new(QueryContext::new()))?;
let res = self.query_engine.execute(&plan).await?;
let copr = self.copr.clone();
match res {

View File

@@ -28,6 +28,7 @@ use datatypes::prelude::{ConcreteDataType, ScalarVector};
use datatypes::schema::{ColumnSchema, Schema, SchemaBuilder};
use datatypes::vectors::{StringVector, TimestampVector, VectorRef};
use query::QueryEngineRef;
use session::context::QueryContext;
use snafu::{ensure, OptionExt, ResultExt};
use table::requests::{CreateTableRequest, InsertRequest};
@@ -151,7 +152,7 @@ impl ScriptsTable {
let plan = self
.query_engine
.sql_to_plan(&sql)
.sql_to_plan(&sql, Arc::new(QueryContext::new()))
.context(FindScriptSnafu { name })?;
let stream = match self

View File

@@ -38,6 +38,7 @@ rand = "0.8"
schemars = "0.8"
serde = "1.0"
serde_json = "1.0"
session = { path = "../session" }
snafu = { version = "0.7", features = ["backtraces"] }
snap = "1"
table = { path = "../table" }

View File

@@ -487,6 +487,7 @@ mod test {
use datatypes::prelude::*;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::{StringVector, UInt32Vector};
use session::context::QueryContextRef;
use tokio::sync::mpsc;
use super::*;
@@ -498,7 +499,7 @@ mod test {
#[async_trait]
impl SqlQueryHandler for DummyInstance {
async fn do_query(&self, _query: &str) -> Result<Output> {
async fn do_query(&self, _: &str, _: QueryContextRef) -> Result<Output> {
unimplemented!()
}
}

View File

@@ -13,6 +13,7 @@
// limitations under the License.
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use aide::transform::TransformOperation;
@@ -21,6 +22,7 @@ use common_error::status_code::StatusCode;
use common_telemetry::metric;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use session::context::QueryContext;
use crate::http::{ApiState, JsonResponse};
@@ -39,7 +41,9 @@ pub async fn sql(
let sql_handler = &state.sql_handler;
let start = Instant::now();
let resp = if let Some(sql) = &params.sql {
JsonResponse::from_output(sql_handler.do_query(sql).await).await
// TODO(LFC): Sessions in http server.
let query_ctx = Arc::new(QueryContext::new());
JsonResponse::from_output(sql_handler.do_query(sql, query_ctx).await).await
} else {
JsonResponse::with_error(
"sql parameter is required.".to_string(),

View File

@@ -26,21 +26,25 @@ use datatypes::vectors::StringVector;
use once_cell::sync::Lazy;
use regex::bytes::RegexSet;
use regex::Regex;
use session::context::QueryContextRef;
// TODO(LFC): Include GreptimeDB's version and git commit tag etc.
const MYSQL_VERSION: &str = "8.0.26";
static SELECT_VAR_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new("(?i)^(SELECT @@(.*))").unwrap());
static MYSQL_CONN_JAVA_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new("(?i)^(/\\* mysql-connector-java(.*))").unwrap());
Lazy::new(|| Regex::new("(?i)^(/\\* mysql-connector-j(.*))").unwrap());
static SHOW_LOWER_CASE_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES LIKE 'lower_case_table_names'(.*))").unwrap());
static SHOW_COLLATION_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new("(?i)^(show collation where(.*))").unwrap());
static SHOW_VARIABLES_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new("(?i)^(SHOW VARIABLES(.*))").unwrap());
static SELECT_VERSION_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new(r"(?i)^(SELECT VERSION\(\s*\))").unwrap());
static SELECT_DATABASE_PATTERN: Lazy<Regex> =
Lazy::new(|| Regex::new(r"(?i)^(SELECT DATABASE\(\s*\))").unwrap());
// SELECT TIMEDIFF(NOW(), UTC_TIMESTAMP());
static SELECT_TIME_DIFF_FUNC_PATTERN: Lazy<Regex> =
@@ -248,13 +252,18 @@ fn check_show_variables(query: &str) -> Option<Output> {
}
// Check for SET or others query, this is the final check of the federated query.
fn check_others(query: &str) -> Option<Output> {
fn check_others(query: &str, query_ctx: QueryContextRef) -> Option<Output> {
if OTHER_NOT_SUPPORTED_STMT.is_match(query.as_bytes()) {
return Some(Output::RecordBatches(RecordBatches::empty()));
}
let recordbatches = if SELECT_VERSION_PATTERN.is_match(query) {
Some(select_function("version()", MYSQL_VERSION))
} else if SELECT_DATABASE_PATTERN.is_match(query) {
let schema = query_ctx
.current_schema()
.unwrap_or_else(|| "NULL".to_string());
Some(select_function("database()", &schema))
} else if SELECT_TIME_DIFF_FUNC_PATTERN.is_match(query) {
Some(select_function(
"TIMEDIFF(NOW(), UTC_TIMESTAMP())",
@@ -268,7 +277,7 @@ fn check_others(query: &str) -> Option<Output> {
// Check whether the query is a federated or driver setup command,
// and return some faked results if there are any.
pub fn check(query: &str) -> Option<Output> {
pub(crate) fn check(query: &str, query_ctx: QueryContextRef) -> Option<Output> {
// First to check the query is like "select @@variables".
let output = check_select_variable(query);
if output.is_some() {
@@ -282,25 +291,27 @@ pub fn check(query: &str) -> Option<Output> {
}
// Last check.
check_others(query)
check_others(query, query_ctx)
}
#[cfg(test)]
mod test {
use session::context::QueryContext;
use super::*;
#[test]
fn test_check() {
let query = "select 1";
let result = check(query);
let result = check(query, Arc::new(QueryContext::new()));
assert!(result.is_none());
let query = "select versiona";
let output = check(query);
let output = check(query, Arc::new(QueryContext::new()));
assert!(output.is_none());
fn test(query: &str, expected: Vec<&str>) {
let output = check(query);
let output = check(query, Arc::new(QueryContext::new()));
match output.unwrap() {
Output::RecordBatches(r) => {
assert_eq!(r.pretty_print().lines().collect::<Vec<_>>(), expected)

View File

@@ -16,11 +16,13 @@ use std::sync::Arc;
use std::time::Instant;
use async_trait::async_trait;
use common_query::Output;
use common_telemetry::{debug, error};
use opensrv_mysql::{
AsyncMysqlShim, ErrorKind, ParamParser, QueryResultWriter, StatementMetaWriter,
AsyncMysqlShim, ErrorKind, InitWriter, ParamParser, QueryResultWriter, StatementMetaWriter,
};
use rand::RngCore;
use session::Session;
use tokio::io::AsyncWrite;
use tokio::sync::RwLock;
@@ -36,7 +38,9 @@ pub struct MysqlInstanceShim {
query_handler: SqlQueryHandlerRef,
salt: [u8; 20],
client_addr: String,
// TODO(LFC): Break `Context` struct into different fields in `Session`, each with its own purpose.
ctx: Arc<RwLock<Option<Context>>>,
session: Arc<Session>,
}
impl MysqlInstanceShim {
@@ -59,8 +63,33 @@ impl MysqlInstanceShim {
salt: scramble,
client_addr,
ctx: Arc::new(RwLock::new(None)),
session: Arc::new(Session::new()),
}
}
async fn do_query(&self, query: &str) -> Result<Output> {
debug!("Start executing query: '{}'", query);
let start = Instant::now();
// TODO(LFC): Find a better way to deal with these special federated queries:
// `check` uses regex to filter out unsupported statements emitted by MySQL's federated
// components, this is quick and dirty, there must be a better way to do it.
let output =
if let Some(output) = crate::mysql::federated::check(query, self.session.context()) {
Ok(output)
} else {
self.query_handler
.do_query(query, self.session.context())
.await
};
debug!(
"Finished executing query: '{}', total time costs in microseconds: {}",
query,
start.elapsed().as_micros()
);
output
}
}
#[async_trait]
@@ -144,25 +173,20 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
query: &'a str,
writer: QueryResultWriter<'a, W>,
) -> Result<()> {
debug!("Start executing query: '{}'", query);
let start = Instant::now();
// TODO(LFC): Find a better way:
// `check` uses regex to filter out unsupported statements emitted by MySQL's federated
// components, this is quick and dirty, there must be a better way to do it.
let output = if let Some(output) = crate::mysql::federated::check(query) {
Ok(output)
} else {
self.query_handler.do_query(query).await
};
debug!(
"Finished executing query: '{}', total time costs in microseconds: {}",
query,
start.elapsed().as_micros()
);
let output = self.do_query(query).await;
let mut writer = MysqlResultWriter::new(writer);
writer.write(query, output).await
}
async fn on_init<'a>(&'a mut self, database: &'a str, w: InitWriter<'a, W>) -> Result<()> {
let query = format!("USE {}", database.trim());
let output = self.do_query(&query).await;
if let Err(e) = output {
w.error(ErrorKind::ER_UNKNOWN_ERROR, e.to_string().as_bytes())
.await
} else {
w.ok().await
}
.map_err(|e| e.into())
}
}

View File

@@ -13,6 +13,7 @@
// limitations under the License.
use std::ops::Deref;
use std::sync::Arc;
use async_trait::async_trait;
use common_query::Output;
@@ -26,6 +27,7 @@ use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{text_query_response, FieldInfo, Response, Tag, TextDataRowEncoder};
use pgwire::api::{ClientInfo, Type};
use pgwire::error::{PgWireError, PgWireResult};
use session::context::QueryContext;
use crate::error::{self, Error, Result};
use crate::query_handler::SqlQueryHandlerRef;
@@ -46,9 +48,11 @@ impl SimpleQueryHandler for PostgresServerHandler {
where
C: ClientInfo + Unpin + Send + Sync,
{
// TODO(LFC): Sessions in pg server.
let query_ctx = Arc::new(QueryContext::new());
let output = self
.query_handler
.do_query(query)
.do_query(query, query_ctx)
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

View File

@@ -18,6 +18,7 @@ use api::prometheus::remote::{ReadRequest, WriteRequest};
use api::v1::{AdminExpr, AdminResult, ObjectExpr, ObjectResult};
use async_trait::async_trait;
use common_query::Output;
use session::context::QueryContextRef;
use crate::error::Result;
use crate::influxdb::InfluxdbRequest;
@@ -44,7 +45,7 @@ pub type ScriptHandlerRef = Arc<dyn ScriptHandler + Send + Sync>;
#[async_trait]
pub trait SqlQueryHandler {
async fn do_query(&self, query: &str) -> Result<Output>;
async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Result<Output>;
}
#[async_trait]

View File

@@ -23,6 +23,7 @@ use servers::error::Result;
use servers::http::{HttpOptions, HttpServer};
use servers::influxdb::InfluxdbRequest;
use servers::query_handler::{InfluxdbLineProtocolHandler, SqlQueryHandler};
use session::context::QueryContextRef;
use tokio::sync::mpsc;
struct DummyInstance {
@@ -44,7 +45,7 @@ impl InfluxdbLineProtocolHandler for DummyInstance {
#[async_trait]
impl SqlQueryHandler for DummyInstance {
async fn do_query(&self, _query: &str) -> Result<Output> {
async fn do_query(&self, _: &str, _: QueryContextRef) -> Result<Output> {
unimplemented!()
}
}

View File

@@ -22,6 +22,7 @@ use servers::error::{self, Result};
use servers::http::{HttpOptions, HttpServer};
use servers::opentsdb::codec::DataPoint;
use servers::query_handler::{OpentsdbProtocolHandler, SqlQueryHandler};
use session::context::QueryContextRef;
use tokio::sync::mpsc;
struct DummyInstance {
@@ -44,7 +45,7 @@ impl OpentsdbProtocolHandler for DummyInstance {
#[async_trait]
impl SqlQueryHandler for DummyInstance {
async fn do_query(&self, _query: &str) -> Result<Output> {
async fn do_query(&self, _: &str, _: QueryContextRef) -> Result<Output> {
unimplemented!()
}
}

View File

@@ -27,6 +27,7 @@ use servers::http::{HttpOptions, HttpServer};
use servers::prometheus;
use servers::prometheus::{snappy_compress, Metrics};
use servers::query_handler::{PrometheusProtocolHandler, PrometheusResponse, SqlQueryHandler};
use session::context::QueryContextRef;
use tokio::sync::mpsc;
struct DummyInstance {
@@ -69,7 +70,7 @@ impl PrometheusProtocolHandler for DummyInstance {
#[async_trait]
impl SqlQueryHandler for DummyInstance {
async fn do_query(&self, _query: &str) -> Result<Output> {
async fn do_query(&self, _: &str, _: QueryContextRef) -> Result<Output> {
unimplemented!()
}
}

View File

@@ -31,6 +31,8 @@ mod http;
mod mysql;
use script::engine::{CompileContext, EvalContext, Script, ScriptEngine};
use script::python::{PyEngine, PyScript};
use session::context::QueryContextRef;
mod opentsdb;
mod postgres;
@@ -52,8 +54,8 @@ impl DummyInstance {
#[async_trait]
impl SqlQueryHandler for DummyInstance {
async fn do_query(&self, query: &str) -> Result<Output> {
let plan = self.query_engine.sql_to_plan(query).unwrap();
async fn do_query(&self, query: &str, query_ctx: QueryContextRef) -> Result<Output> {
let plan = self.query_engine.sql_to_plan(query, query_ctx).unwrap();
Ok(self.query_engine.execute(&plan).await.unwrap())
}
}

9
src/session/Cargo.toml Normal file
View File

@@ -0,0 +1,9 @@
[package]
name = "session"
version = "0.1.0"
edition = "2021"
license = "Apache-2.0"
[dependencies]
arc-swap = "1.5"
common-telemetry = { path = "../common/telemetry" }

View File

@@ -0,0 +1,56 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use common_telemetry::info;
pub type QueryContextRef = Arc<QueryContext>;
pub struct QueryContext {
current_schema: ArcSwapOption<String>,
}
impl Default for QueryContext {
fn default() -> Self {
Self::new()
}
}
impl QueryContext {
pub fn new() -> Self {
Self {
current_schema: ArcSwapOption::new(None),
}
}
pub fn with_current_schema(schema: String) -> Self {
Self {
current_schema: ArcSwapOption::new(Some(Arc::new(schema))),
}
}
pub fn current_schema(&self) -> Option<String> {
self.current_schema.load().as_deref().cloned()
}
pub fn set_current_schema(&self, schema: &str) {
let last = self.current_schema.swap(Some(Arc::new(schema.to_string())));
info!(
"set new session default schema: {:?}, swap old: {:?}",
schema, last
)
}
}

36
src/session/src/lib.rs Normal file
View File

@@ -0,0 +1,36 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod context;
use std::sync::Arc;
use crate::context::{QueryContext, QueryContextRef};
#[derive(Default)]
pub struct Session {
query_ctx: QueryContextRef,
}
impl Session {
pub fn new() -> Self {
Session {
query_ctx: Arc::new(QueryContext::new()),
}
}
pub fn context(&self) -> QueryContextRef {
self.query_ctx.clone()
}
}

View File

@@ -102,6 +102,21 @@ impl<'a> ParserContext<'a> {
Keyword::DROP => self.parse_drop(),
// TODO(LFC): Use "Keyword::USE" when we can upgrade to newer version of crate sqlparser.
Keyword::NoKeyword if w.value.to_lowercase() == "use" => {
self.parser.next_token();
let database_name =
self.parser
.parse_identifier()
.context(error::UnexpectedSnafu {
sql: self.sql,
expected: "a database name",
actual: self.peek_token_as_string(),
})?;
Ok(Statement::Use(database_name.value))
}
// todo(hl) support more statements.
_ => self.unsupported(self.peek_token_as_string()),
}

View File

@@ -42,6 +42,8 @@ use crate::error::{
SerializeColumnDefaultConstraintSnafu, UnsupportedDefaultValueSnafu,
};
// TODO(LFC): Get rid of this function, use session context aware version of "table_idents_to_full_name" instead.
// Current obstacles remain in some usage in Frontend, and other SQLs like "describe", "drop" etc.
/// Converts maybe fully-qualified table name (`<catalog>.<schema>.<table>` or `<table>` when
/// catalog and schema are default) to tuple.
pub fn table_idents_to_full_name(obj_name: &ObjectName) -> Result<(String, String, String)> {

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use sqlparser::ast::{SetExpr, Statement, UnaryOperator, Values};
use sqlparser::ast::{ObjectName, SetExpr, Statement, UnaryOperator, Values};
use sqlparser::parser::ParserError;
use crate::ast::{Expr, Value};
@@ -33,6 +33,13 @@ impl Insert {
}
}
pub fn table_name(&self) -> &ObjectName {
match &self.inner {
Statement::Insert { table_name, .. } => table_name,
_ => unreachable!(),
}
}
pub fn columns(&self) -> Vec<&String> {
match &self.inner {
Statement::Insert { columns, .. } => columns.iter().map(|ident| &ident.value).collect(),
@@ -110,15 +117,6 @@ mod tests {
use super::*;
use crate::parser::ParserContext;
#[test]
pub fn test_insert_convert() {
let sql = r"INSERT INTO tables_0 VALUES ( 'field_0', 0) ";
let mut stmts = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
assert_eq!(1, stmts.len());
let insert = stmts.pop().unwrap();
let _stmt: Statement = insert.try_into().unwrap();
}
#[test]
fn test_insert_value_with_unary_op() {
use crate::statements::statement::Statement;

View File

@@ -12,9 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use sqlparser::ast::Statement as SpStatement;
use sqlparser::parser::ParserError;
use crate::statements::alter::AlterTable;
use crate::statements::create::{CreateDatabase, CreateTable};
use crate::statements::describe::DescribeTable;
@@ -50,37 +47,7 @@ pub enum Statement {
DescribeTable(DescribeTable),
// EXPLAIN QUERY
Explain(Explain),
}
/// Converts Statement to sqlparser statement
impl TryFrom<Statement> for SpStatement {
type Error = sqlparser::parser::ParserError;
fn try_from(value: Statement) -> Result<Self, Self::Error> {
match value {
Statement::ShowDatabases(_) => Err(ParserError::ParserError(
"sqlparser does not support SHOW DATABASE query.".to_string(),
)),
Statement::ShowTables(_) => Err(ParserError::ParserError(
"sqlparser does not support SHOW TABLES query.".to_string(),
)),
Statement::ShowCreateTable(_) => Err(ParserError::ParserError(
"sqlparser does not support SHOW CREATE TABLE query.".to_string(),
)),
Statement::DescribeTable(_) => Err(ParserError::ParserError(
"sqlparser does not support DESCRIBE TABLE query.".to_string(),
)),
Statement::DropTable(_) => Err(ParserError::ParserError(
"sqlparser does not support DROP TABLE query.".to_string(),
)),
Statement::Query(s) => Ok(SpStatement::Query(Box::new(s.inner))),
Statement::Insert(i) => Ok(i.inner),
Statement::CreateDatabase(_) | Statement::CreateTable(_) | Statement::Alter(_) => {
unimplemented!()
}
Statement::Explain(e) => Ok(e.inner),
}
}
Use(String),
}
/// Comment hints from SQL.
@@ -92,24 +59,3 @@ pub struct Hint {
pub comment: String,
pub prefix: String,
}
#[cfg(test)]
mod tests {
use std::assert_matches::assert_matches;
use sqlparser::dialect::GenericDialect;
use super::*;
use crate::parser::ParserContext;
#[test]
pub fn test_statement_convert() {
let sql = "SELECT * FROM table_0";
let mut stmts = ParserContext::create_with_dialect(sql, &GenericDialect {}).unwrap();
assert_eq!(1, stmts.len());
let x = stmts.remove(0);
let statement = SpStatement::try_from(x).unwrap();
assert_matches!(statement, SpStatement::Query { .. });
}
}

View File

@@ -26,6 +26,27 @@ pub struct TableReference<'a> {
pub table: &'a str,
}
// TODO(LFC): Find a better place for `TableReference`,
// so that we can reuse the default catalog and schema consts.
// Could be done together with issue #559.
impl<'a> TableReference<'a> {
pub fn bare(table: &'a str) -> Self {
TableReference {
catalog: "greptime",
schema: "public",
table,
}
}
pub fn full(catalog: &'a str, schema: &'a str, table: &'a str) -> Self {
TableReference {
catalog,
schema,
table,
}
}
}
impl<'a> Display for TableReference<'a> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}.{}.{}", self.catalog, self.schema, self.table)